Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
39bbf87c
Commit
39bbf87c
authored
May 17, 2022
by
Paul
Browse files
Format
parent
dcd3d04b
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
28 additions
and
19 deletions
+28
-19
src/simplify_reshapes.cpp
src/simplify_reshapes.cpp
+28
-19
No files found.
src/simplify_reshapes.cpp
View file @
39bbf87c
...
...
@@ -582,51 +582,60 @@ struct find_transpose_slice
{
auto
matcher
()
const
{
return
match
::
any
(
match
::
any_of
[
match
::
outputs
()](
match
::
name
(
"slice"
)(
match
::
output
(
match
::
name
(
"transpose"
)))));
return
match
::
any
(
match
::
any_of
[
match
::
outputs
()](
match
::
name
(
"slice"
)(
match
::
output
(
match
::
name
(
"transpose"
)))));
}
static
std
::
vector
<
int64_t
>
find_common_perm
(
const
std
::
vector
<
instruction_ref
>&
transposes
)
{
std
::
map
<
std
::
vector
<
int64_t
>
,
int64_t
>
count
;
for
(
auto
t
:
transposes
)
for
(
auto
t
:
transposes
)
{
auto
perm
=
t
->
get_operator
().
to_value
()[
"permutation"
].
to_vector
<
int64_t
>
();
count
[
perm
]
++
;
}
return
std
::
max_element
(
count
.
begin
(),
count
.
end
(),
by
(
std
::
less
<>
{},
[](
auto
&&
p
)
{
return
p
.
second
;
}))
->
first
;
count
.
begin
(),
count
.
end
(),
by
(
std
::
less
<>
{},
[](
auto
&&
p
)
{
return
p
.
second
;
}))
->
first
;
}
void
apply
(
module
&
m
,
const
match
::
matcher_result
&
r
)
const
{
auto
ins
=
r
.
result
;
std
::
vector
<
instruction_ref
>
splits
;
std
::
copy_if
(
ins
->
outputs
().
begin
(),
ins
->
outputs
().
end
(),
std
::
back_inserter
(
splits
),
[
&
](
instruction_ref
out
)
{
return
out
->
name
()
==
"slice"
and
out
->
outputs
().
size
()
==
1
and
out
->
outputs
().
front
()
->
name
()
==
"transpose"
;
std
::
copy_if
(
ins
->
outputs
().
begin
(),
ins
->
outputs
().
end
(),
std
::
back_inserter
(
splits
),
[
&
](
instruction_ref
out
)
{
return
out
->
name
()
==
"slice"
and
out
->
outputs
().
size
()
==
1
and
out
->
outputs
().
front
()
->
name
()
==
"transpose"
;
});
if
(
splits
.
size
()
<
2
)
if
(
splits
.
size
()
<
2
)
return
;
std
::
vector
<
instruction_ref
>
transposes
;
std
::
transform
(
splits
.
begin
(),
splits
.
end
(),
std
::
back_inserter
(
transposes
),
[](
auto
split
)
{
return
split
->
outputs
().
front
();
});
std
::
transform
(
splits
.
begin
(),
splits
.
end
(),
std
::
back_inserter
(
transposes
),
[](
auto
split
)
{
return
split
->
outputs
().
front
();
});
auto
perm
=
find_common_perm
(
transposes
);
auto
iperm
=
invert_permutation
(
perm
);
auto
pre
=
m
.
insert_instruction
(
std
::
next
(
ins
),
make_op
(
"transpose"
,
{{
"permutation"
,
perm
}}),
ins
);
for
(
auto
i
:
range
(
transposes
.
size
()))
auto
pre
=
m
.
insert_instruction
(
std
::
next
(
ins
),
make_op
(
"transpose"
,
{{
"permutation"
,
perm
}}),
ins
);
for
(
auto
i
:
range
(
transposes
.
size
()))
{
auto
split
=
splits
[
i
];
auto
t
=
transposes
[
i
];
auto
op
=
any_cast
<
op
::
slice
>
(
split
->
get_operator
());
for
(
auto
&
axis
:
op
.
axes
)
for
(
auto
&
axis
:
op
.
axes
)
{
axis
=
iperm
[
axis
];
}
auto
new_ins
=
m
.
insert_instruction
(
t
,
op
,
pre
);
if
(
t
->
get_operator
()
!=
pre
->
get_operator
())
if
(
t
->
get_operator
()
!=
pre
->
get_operator
())
{
auto
curr
=
t
->
get_operator
().
to_value
()[
"permutation"
].
to_vector
<
int64_t
>
();
new_ins
=
m
.
insert_instruction
(
t
,
make_op
(
"transpose"
,
{{
"permutation"
,
reorder_dims
(
iperm
,
curr
)}}),
new_ins
);
new_ins
=
m
.
insert_instruction
(
t
,
make_op
(
"transpose"
,
{{
"permutation"
,
reorder_dims
(
iperm
,
curr
)}}),
new_ins
);
}
m
.
replace_instruction
(
t
,
new_ins
);
}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment