Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
f1c131de
Commit
f1c131de
authored
Jul 07, 2022
by
Paul
Browse files
Format
parent
6f8d1427
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
15 additions
and
18 deletions
+15
-18
src/simplify_reshapes.cpp
src/simplify_reshapes.cpp
+15
-18
No files found.
src/simplify_reshapes.cpp
View file @
f1c131de
...
...
@@ -666,15 +666,12 @@ struct find_slice_transpose
struct
find_dot_transpose
{
auto
matcher
()
const
{
return
match
::
name
(
"transpose"
)(
match
::
args
(
match
::
name
(
"dot"
)));
}
auto
matcher
()
const
{
return
match
::
name
(
"transpose"
)(
match
::
args
(
match
::
name
(
"dot"
)));
}
template
<
class
Vector
>
template
<
class
Vector
>
static
bool
is_swapped
(
const
Vector
&
perm
,
std
::
size_t
i
,
std
::
size_t
j
)
{
if
(
i
>=
perm
.
size
()
or
j
>=
perm
.
size
())
if
(
i
>=
perm
.
size
()
or
j
>=
perm
.
size
())
return
false
;
auto
perm2
=
perm
;
std
::
iota
(
perm2
.
begin
(),
perm2
.
end
(),
0
);
...
...
@@ -682,7 +679,7 @@ struct find_dot_transpose
return
perm2
==
perm
;
}
template
<
class
Vector
>
template
<
class
Vector
>
static
std
::
size_t
get_batch_elements
(
const
Vector
&
v
)
{
return
std
::
accumulate
(
v
.
begin
(),
v
.
end
()
-
2
,
1
,
std
::
multiplies
<>
{});
...
...
@@ -690,25 +687,25 @@ struct find_dot_transpose
void
apply
(
module
&
m
,
const
match
::
matcher_result
&
r
)
const
{
auto
ins
=
r
.
result
;
auto
dot
=
ins
->
inputs
().
front
();
auto
am
=
ins
->
inputs
().
front
();
auto
bm
=
ins
->
inputs
().
front
();
auto
ins
=
r
.
result
;
auto
dot
=
ins
->
inputs
().
front
();
auto
am
=
ins
->
inputs
().
front
();
auto
bm
=
ins
->
inputs
().
front
();
auto
transpose
=
any_cast
<
op
::
transpose
>
(
ins
->
get_operator
());
auto
perm
=
transpose
.
dims
;
auto
last
=
perm
.
size
()
-
1
;
auto
perm
=
transpose
.
dims
;
auto
last
=
perm
.
size
()
-
1
;
// Row/column swapped
if
(
is_swapped
(
perm
,
last
-
1
,
last
))
if
(
is_swapped
(
perm
,
last
-
1
,
last
))
{
// Parameters are transposed and flipped
auto
am_t
=
m
.
insert_instruction
(
ins
,
transpose
,
bm
);
auto
bm_t
=
m
.
insert_instruction
(
ins
,
transpose
,
am
);
auto
am_t
=
m
.
insert_instruction
(
ins
,
transpose
,
bm
);
auto
bm_t
=
m
.
insert_instruction
(
ins
,
transpose
,
am
);
auto
new_dot
=
m
.
insert_instruction
(
ins
,
dot
->
get_operator
(),
am_t
,
bm_t
);
m
.
replace_instruction
(
dot
,
new_dot
);
}
else
if
(
is_swapped
(
perm
,
last
-
1
,
last
-
2
))
else
if
(
is_swapped
(
perm
,
last
-
1
,
last
-
2
))
{
if
(
get_batch_elements
(
ins
->
get_shape
().
lens
())
!=
ins
->
get_shape
().
lens
()[
last
-
2
])
if
(
get_batch_elements
(
ins
->
get_shape
().
lens
())
!=
ins
->
get_shape
().
lens
()[
last
-
2
])
return
;
}
}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment