Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
49dc6d12
Commit
49dc6d12
authored
Apr 03, 2023
by
Paul
Browse files
Match mul_add reshapes
parent
ce2423ce
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
35 additions
and
2 deletions
+35
-2
src/simplify_reshapes.cpp
src/simplify_reshapes.cpp
+35
-2
No files found.
src/simplify_reshapes.cpp
View file @
49dc6d12
...
...
@@ -587,13 +587,14 @@ struct find_reshape_cont
};
// match sequence of transpose --> contiguous --> reshaper_op
auto
match_transpose_contiguous_reshaper
()
template
<
class
...
Ms
>
auto
match_transpose_contiguous_reshaper
(
Ms
...
ms
)
{
return
match
::
name
({
"reshape"
,
"squeeze"
,
"unsqueeze"
})(
match
::
used_once
(),
match
::
args
(
match
::
name
(
"contiguous"
)(
match
::
used_once
(),
match
::
args
(
match
::
transpose_shape
().
bind
(
"trans_ins"
)))
match
::
used_once
(),
match
::
args
(
match
::
transpose_shape
(
ms
...
).
bind
(
"trans_ins"
)))
.
bind
(
"cont_ins"
)))
.
bind
(
"reshaper_ins"
);
};
...
...
@@ -626,6 +627,37 @@ struct find_transpose_contiguous_reshaper_unary
}
};
struct
find_mul_add_transpose_contiguous_reshaper_gemm
{
auto
matcher
()
const
{
auto
pw
=
match
::
name
(
"mul"
,
"add"
)(
match
::
used_once
(),
match
::
either_arg
(
0
,
1
)(
match
::
is_constant
().
bind
(
"c"
),
match
::
any
().
bind
(
"x"
)));
return
match
::
name
(
"dot"
)(
match
::
either_arg
(
0
,
1
)(
match_transpose_contiguous_reshaper
(
pw
.
bind
(
"pointwise"
)),
match
::
is_constant
()));
}
void
apply
(
module
&
m
,
const
match
::
matcher_result
&
r
)
const
{
auto
ins
=
r
.
result
;
auto
reshaper_ins
=
r
.
instructions
[
"reshaper_ins"
];
auto
trans_ins
=
r
.
instructions
[
"trans_ins"
];
auto
x_ins
=
r
.
instructions
[
"x"
];
auto
c_ins
=
r
.
instructions
[
"c"
];
auto
pw_ins
=
r
.
instructions
[
"pointwise"
];
auto
insert_reshapes
=
[
&
](
auto
x
)
{
auto
t
=
m
.
insert_instruction
(
ins
,
trans_ins
->
get_operator
(),
x
);
auto
c
=
m
.
insert_instruction
(
ins
,
make_op
(
"contiguous"
),
t
);
return
m
.
insert_instruction
(
ins
,
reshaper_ins
->
get_operator
(),
c
);
};
if
(
x_ins
->
name
()
==
"mul"
)
{
x_ins
=
m
.
insert_instruction
(
ins
,
make_op
(
"mul"
),
{
insert_reshapes
(
x_ins
->
inputs
()[
0
]),
insert_reshapes
(
x_ins
->
inputs
()[
1
])});
}
auto
y_ins
=
m
.
insert_instruction
(
ins
,
pw_ins
->
get_operator
(),
{
x_ins
,
insert_reshapes
(
c_ins
)});
m
.
replace_instruction
(
reshaper_ins
,
y_ins
);
}
};
struct
find_slice_transpose
{
auto
matcher
()
const
...
...
@@ -844,6 +876,7 @@ void simplify_reshapes::apply(module& m) const
find_transpose_slice
{},
find_slice_transpose
{},
find_transpose_contiguous_reshaper_unary
{},
find_mul_add_transpose_contiguous_reshaper_gemm
{},
find_reshape_gemm
{});
dead_code_elimination
{}.
apply
(
m
);
}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment