Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
6f8d1427
Commit
6f8d1427
authored
Jul 07, 2022
by
Paul
Browse files
Add find_dot_transpose
parent
f2531606
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
50 additions
and
0 deletions
+50
-0
src/simplify_reshapes.cpp
src/simplify_reshapes.cpp
+50
-0
No files found.
src/simplify_reshapes.cpp
View file @
6f8d1427
...
@@ -664,6 +664,56 @@ struct find_slice_transpose
...
@@ -664,6 +664,56 @@ struct find_slice_transpose
}
}
};
};
struct
find_dot_transpose
{
auto
matcher
()
const
{
return
match
::
name
(
"transpose"
)(
match
::
args
(
match
::
name
(
"dot"
)));
}
template
<
class
Vector
>
static
bool
is_swapped
(
const
Vector
&
perm
,
std
::
size_t
i
,
std
::
size_t
j
)
{
if
(
i
>=
perm
.
size
()
or
j
>=
perm
.
size
())
return
false
;
auto
perm2
=
perm
;
std
::
iota
(
perm2
.
begin
(),
perm2
.
end
(),
0
);
std
::
swap
(
perm2
[
i
],
perm2
[
j
]);
return
perm2
==
perm
;
}
template
<
class
Vector
>
static
std
::
size_t
get_batch_elements
(
const
Vector
&
v
)
{
return
std
::
accumulate
(
v
.
begin
(),
v
.
end
()
-
2
,
1
,
std
::
multiplies
<>
{});
}
void
apply
(
module
&
m
,
const
match
::
matcher_result
&
r
)
const
{
auto
ins
=
r
.
result
;
auto
dot
=
ins
->
inputs
().
front
();
auto
am
=
ins
->
inputs
().
front
();
auto
bm
=
ins
->
inputs
().
front
();
auto
transpose
=
any_cast
<
op
::
transpose
>
(
ins
->
get_operator
());
auto
perm
=
transpose
.
dims
;
auto
last
=
perm
.
size
()
-
1
;
// Row/column swapped
if
(
is_swapped
(
perm
,
last
-
1
,
last
))
{
// Parameters are transposed and flipped
auto
am_t
=
m
.
insert_instruction
(
ins
,
transpose
,
bm
);
auto
bm_t
=
m
.
insert_instruction
(
ins
,
transpose
,
am
);
auto
new_dot
=
m
.
insert_instruction
(
ins
,
dot
->
get_operator
(),
am_t
,
bm_t
);
m
.
replace_instruction
(
dot
,
new_dot
);
}
else
if
(
is_swapped
(
perm
,
last
-
1
,
last
-
2
))
{
if
(
get_batch_elements
(
ins
->
get_shape
().
lens
())
!=
ins
->
get_shape
().
lens
()[
last
-
2
])
return
;
}
}
};
void
simplify_reshapes
::
apply
(
module
&
m
)
const
void
simplify_reshapes
::
apply
(
module
&
m
)
const
{
{
for
(
int
i
=
0
;
i
<
2
;
i
++
)
for
(
int
i
=
0
;
i
<
2
;
i
++
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment