Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
1e34e1ad
Commit
1e34e1ad
authored
Apr 08, 2023
by
Paul
Browse files
Rewrite reshapes and broadcast
parent
5722eb1b
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
36 additions
and
0 deletions
+36
-0
src/simplify_reshapes.cpp
src/simplify_reshapes.cpp
+36
-0
No files found.
src/simplify_reshapes.cpp
View file @
1e34e1ad
...
...
@@ -865,6 +865,41 @@ struct find_reshape_gemm
}
};
struct
find_broadcast_reshaper
{
auto
matcher
()
const
{
auto
broadcast
=
match
::
broadcast_shape
(
match
::
skip
(
match
::
broadcast_shape
())(
match
::
any
().
bind
(
"x"
))).
bind
(
"broadcast"
);
return
match
::
name
(
reshaper_names
())(
match
::
args
(
broadcast
));
}
void
apply
(
module
&
m
,
const
match
::
matcher_result
&
r
)
const
{
auto
ins
=
r
.
result
;
auto
broadcast_ins
=
r
.
instructions
[
"broadcast"
];
auto
x_ins
=
r
.
instructions
[
"x"
];
auto
broadcast_shape
=
broadcast_ins
->
get_shape
();
auto
result_shape
=
ins
->
get_shape
();
if
(
std
::
accumulate
(
broadcast_shape
.
strides
().
begin
(),
broadcast_shape
.
strides
().
end
(),
0
)
!=
1
)
return
;
auto
baxis
=
std
::
find
(
broadcast_shape
.
strides
().
begin
(),
broadcast_shape
.
strides
().
end
(),
1
)
-
broadcast_shape
.
strides
().
begin
();
auto
relements
=
result_shape
.
lens
();
std
::
partial_sum
(
relements
.
begin
(),
relements
.
end
(),
relements
.
begin
(),
std
::
multiplies
<>
{});
auto
prefix_elements
=
std
::
accumulate
(
broadcast_shape
.
lens
().
begin
(),
broadcast_shape
.
lens
().
begin
()
+
baxis
+
1
,
1
,
std
::
multiplies
<>
{});
auto
axis
=
std
::
find
(
relements
.
begin
(),
relements
.
end
(),
prefix_elements
)
-
relements
.
begin
();
if
(
axis
>=
relements
.
size
())
return
;
if
(
x_ins
->
get_shape
().
lens
().
size
()
>
1
)
x_ins
=
m
.
insert_instruction
(
ins
,
make_op
(
"squeeze"
),
x_ins
);
m
.
replace_instruction
(
ins
,
make_op
(
"broadcast"
,
{{
"axis"
,
axis
},
{
"out_lens"
,
ins
->
get_shape
().
lens
()}}),
x_ins
);
}
};
void
simplify_reshapes
::
apply
(
module
&
m
)
const
{
for
(
int
i
=
0
;
i
<
4
;
i
++
)
...
...
@@ -874,6 +909,7 @@ void simplify_reshapes::apply(module& m) const
find_resize
{},
find_nop_reshapes
{},
find_reshaper
{},
find_broadcast_reshaper
{},
// find_reshape_cont{},
find_transpose
{},
find_concat_transpose
{},
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment