Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
45bdaf27
Commit
45bdaf27
authored
Nov 28, 2022
by
Paul
Browse files
Format
parent
2262efe0
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
19 additions
and
19 deletions
+19
-19
src/targets/gpu/rewrite_ops.cpp
src/targets/gpu/rewrite_ops.cpp
+19
-19
No files found.
src/targets/gpu/rewrite_ops.cpp
View file @
45bdaf27
...
...
@@ -33,42 +33,45 @@ namespace {
MIGRAPHX_PRED_MATCHER
(
col_matrix
,
instruction_ref
ins
)
{
if
(
not
ins
->
get_shape
().
transposed
())
if
(
not
ins
->
get_shape
().
transposed
())
return
false
;
if
(
ins
->
get_shape
().
ndim
()
<
2
)
if
(
ins
->
get_shape
().
ndim
()
<
2
)
return
false
;
auto
perm
=
find_permutation
(
ins
->
get_shape
());
auto
n
=
perm
.
size
()
-
1
;
return
perm
[
n
]
==
n
-
1
and
perm
[
n
-
1
]
==
n
;
return
perm
[
n
]
==
n
-
1
and
perm
[
n
-
1
]
==
n
;
}
MIGRAPHX_PRED_MATCHER
(
broadcast_matrix_dims
,
instruction_ref
ins
)
{
if
(
not
ins
->
get_shape
().
broadcasted
())
if
(
not
ins
->
get_shape
().
broadcasted
())
return
false
;
if
(
ins
->
get_shape
().
ndim
()
<
2
)
if
(
ins
->
get_shape
().
ndim
()
<
2
)
return
false
;
return
std
::
any_of
(
ins
->
get_shape
().
lens
().
rbegin
(),
ins
->
get_shape
().
lens
().
rend
()
+
2
,
[](
auto
i
)
{
return
i
==
0
;
});
return
std
::
any_of
(
ins
->
get_shape
().
lens
().
rbegin
(),
ins
->
get_shape
().
lens
().
rend
()
+
2
,
[](
auto
i
)
{
return
i
==
0
;
});
}
struct
find_dot_const
{
auto
matcher
()
const
{
return
match
::
name
(
"dot"
)(
match
::
arg
(
1
)(
match
::
is_constant
(),
match
::
none_of
(
col_matrix
(),
broadcast_matrix_dims
()),
match
::
skip_broadcasts
(
match
::
any
().
bind
(
"w"
))))(
match
::
none_of
(
match
::
is_constant
()));
return
match
::
name
(
"dot"
)(
match
::
arg
(
1
)(
match
::
is_constant
(),
match
::
none_of
(
col_matrix
(),
broadcast_matrix_dims
()),
match
::
skip_broadcasts
(
match
::
any
().
bind
(
"w"
))))(
match
::
none_of
(
match
::
is_constant
()));
}
void
apply
(
module
&
m
,
const
match
::
matcher_result
&
r
)
const
{
auto
ins
=
r
.
result
;
auto
w
=
r
.
instructions
[
"w"
];
if
(
w
->
get_shape
().
ndim
()
<
2
)
if
(
w
->
get_shape
().
ndim
()
<
2
)
return
;
auto
perm
=
find_permutation
(
w
->
get_shape
());
auto
n
=
perm
.
size
()
-
1
;
std
::
swap
(
perm
[
n
],
perm
[
n
-
1
]);
std
::
swap
(
perm
[
n
],
perm
[
n
-
1
]);
auto
wl
=
m
.
insert_instruction
(
std
::
next
(
w
),
make_op
(
"layout"
,
{{
"permutation"
,
perm
}}),
w
);
m
.
replace_instruction
(
w
,
wl
);
}
...
...
@@ -76,10 +79,7 @@ struct find_dot_const
}
// namespace
void
rewrite_ops
::
apply
(
module
&
m
)
const
{
match
::
find_matches
(
m
,
find_dot_const
{});
}
void
rewrite_ops
::
apply
(
module
&
m
)
const
{
match
::
find_matches
(
m
,
find_dot_const
{});
}
}
// namespace gpu
}
// namespace MIGRAPHX_INLINE_NS
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment