Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
c27a6376
Commit
c27a6376
authored
Jul 02, 2022
by
Paul
Browse files
Const fold adds for gemms
parent
cf9cec1c
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
10 additions
and
2 deletions
+10
-2
src/targets/gpu/fuse_ops.cpp
src/targets/gpu/fuse_ops.cpp
+10
-2
No files found.
src/targets/gpu/fuse_ops.cpp
View file @
c27a6376
...
...
@@ -50,6 +50,7 @@
#include <migraphx/array.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/op/clip.hpp>
#include <migraphx/op/contiguous.hpp>
#include <cmath>
#include <set>
...
...
@@ -974,9 +975,8 @@ struct find_gemm_pointwise
{
return
precompile_name
(
"pointwise"
)(
match
::
nargs
(
3
),
match
::
all_of
[
match
::
inputs
()](
match
::
standard_shape
()),
match
::
either_arg
(
0
,
1
)(
match
::
any
(
).
bind
(
"c"
),
match
::
any
_of
(
match
::
standard_shape
(),
match
::
is_constant
()
).
bind
(
"c"
),
match
::
name
(
"gpu::gemm"
)(
match
::
nargs
(
3
),
match
::
used_once
()).
bind
(
"gemm"
)));
}
...
...
@@ -1052,6 +1052,14 @@ struct find_gemm_pointwise
gemm
,
ins
->
module_inputs
().
front
(),
ins
->
inputs
().
front
()
==
gemm_ins
?
0
:
1
))
return
;
// const-fold input if not standard shape since rocblas can't handle it
if
(
not
c_ins
->
get_shape
().
standard
())
{
auto
c
=
op
::
contiguous
{};
auto
l
=
c
.
compute
(
c
.
compute_shape
({
c_ins
->
get_shape
()}),
{
c_ins
->
eval
()});
c_ins
=
m
.
add_literal
(
l
.
get_shape
(),
l
.
data
());
}
auto
inputs
=
gemm_ins
->
inputs
();
inputs
.
pop_back
();
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment