Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
d294b663
"...git@developer.sourcefind.cn:chenpangpang/open-webui.git" did not exist on "437d7ff6f5df4b42ee25dd7136bea4da91739615"
Commit
d294b663
authored
Jun 01, 2023
by
Paul
Browse files
Improve the batch fold check
parent
d8110fc4
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
8 additions
and
5 deletions
+8
-5
src/targets/gpu/jit/ck_gemm.cpp
src/targets/gpu/jit/ck_gemm.cpp
+8
-5
No files found.
src/targets/gpu/jit/ck_gemm.cpp
View file @
d294b663
...
@@ -269,12 +269,15 @@ struct ck_gemm_compiler : compiler<ck_gemm_compiler>
...
@@ -269,12 +269,15 @@ struct ck_gemm_compiler : compiler<ck_gemm_compiler>
bool
can_fold_batch
(
const
std
::
vector
<
shape
>&
inputs
)
const
bool
can_fold_batch
(
const
std
::
vector
<
shape
>&
inputs
)
const
{
{
const
auto
&
a_shape
=
inputs
[
0
];
const
auto
&
b_shape
=
inputs
[
1
];
const
auto
&
b_shape
=
inputs
[
1
];
// cppcheck-suppress unreadVariable
if
(
std
::
any_of
(
inputs
.
begin
()
+
2
,
inputs
.
end
()
-
1
,
[](
auto
input
)
{
auto
rank
=
a_shape
.
lens
().
size
();
return
input
.
broadcasted
();
auto
b_strides
=
b_shape
.
strides
();
}))
return
rank
>=
3
and
b_strides
[
rank
-
3
]
==
0
;
return
false
;
const
auto
&
b_strides
=
b_shape
.
strides
();
return
std
::
all_of
(
b_strides
.
begin
(),
b_strides
.
end
()
-
3
,
[](
auto
stride
)
{
return
stride
==
0
;
});
}
}
ck
::
host
::
device_gemm_multiple_d
::
Problem
create_problem
(
const
std
::
vector
<
shape
>&
inputs
,
ck
::
host
::
device_gemm_multiple_d
::
Problem
create_problem
(
const
std
::
vector
<
shape
>&
inputs
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment