Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
6d34b90f
Commit
6d34b90f
authored
Jun 06, 2022
by
Paul
Browse files
Format
parent
00df057a
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
13 additions
and
14 deletions
+13
-14
src/simplify_algebra.cpp
src/simplify_algebra.cpp
+13
-14
No files found.
src/simplify_algebra.cpp
View file @
6d34b90f
...
...
@@ -191,9 +191,8 @@ struct find_dot_add
{
return
match
::
name
(
"dot"
)(
match
::
either_arg
(
0
,
1
)(
match
::
name
(
"add"
)(
match
::
either_arg
(
0
,
1
)(
match
::
any
().
bind
(
"x"
),
match
::
any_of
(
match
::
is_constant
()).
bind
(
"b"
)),
match
::
either_arg
(
0
,
1
)(
match
::
any
().
bind
(
"x"
),
match
::
any_of
(
match
::
is_constant
()).
bind
(
"b"
)),
match
::
none_of
(
match
::
args
(
match
::
is_constant
(),
match
::
is_constant
())),
match
::
used_once
()),
match
::
is_constant
().
bind
(
"a"
)));
...
...
@@ -210,9 +209,9 @@ struct find_dot_add
const
bool
flipped
=
a_ins
==
ins
->
inputs
().
back
();
auto
insert_dot
=
[
&
](
auto
x
,
auto
y
)
{
if
(
flipped
)
if
(
flipped
)
return
m
.
insert_instruction
(
ins
,
make_op
(
"dot"
),
y
,
x
);
else
else
return
m
.
insert_instruction
(
ins
,
make_op
(
"dot"
),
x
,
y
);
};
...
...
@@ -283,29 +282,29 @@ struct find_inner_broadcast
{
auto
matcher
()
const
{
return
pointwise
(
match
::
all_of
[
match
::
inputs
()](
match
::
broadcast_shape
(),
match
::
name
(
"broadcast"
,
"multibroadcast"
)));
return
pointwise
(
match
::
all_of
[
match
::
inputs
()](
match
::
broadcast_shape
(),
match
::
name
(
"broadcast"
,
"multibroadcast"
)));
}
void
apply
(
module
&
m
,
const
match
::
matcher_result
&
r
)
const
{
auto
ins
=
r
.
result
;
auto
ins
=
r
.
result
;
auto
inputs
=
ins
->
inputs
();
if
(
inputs
.
empty
())
if
(
inputs
.
empty
())
return
;
std
::
transform
(
inputs
.
begin
(),
inputs
.
end
(),
inputs
.
begin
(),
[
&
](
auto
i
)
{
if
(
contains
({
"broadcast"
,
"multibroadcast"
},
i
->
name
()))
if
(
contains
({
"broadcast"
,
"multibroadcast"
},
i
->
name
()))
return
i
->
inputs
().
front
();
else
return
i
;
});
if
(
not
std
::
all_of
(
inputs
.
begin
(),
inputs
.
end
(),
[
&
](
auto
&
x
)
{
return
x
->
get_shape
()
==
inputs
.
front
()
->
get_shape
();
}))
if
(
not
std
::
all_of
(
inputs
.
begin
(),
inputs
.
end
(),
[
&
](
auto
&
x
)
{
return
x
->
get_shape
()
==
inputs
.
front
()
->
get_shape
();
}))
return
;
auto
op
=
m
.
insert_instruction
(
ins
,
ins
->
get_operator
(),
inputs
);
auto
op
=
m
.
insert_instruction
(
ins
,
ins
->
get_operator
(),
inputs
);
auto
bop
=
std
::
find_if
(
ins
->
inputs
().
begin
(),
ins
->
inputs
().
end
(),
[
&
](
auto
i
)
{
return
contains
({
"broadcast"
,
"multibroadcast"
},
i
->
name
());
});
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment