Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
7422984c
Commit
7422984c
authored
Sep 11, 2022
by
Paul
Browse files
Fuse pointwise across broadcasts
parent
d78bcdfb
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
53 additions
and
0 deletions
+53
-0
src/fuse_pointwise.cpp
src/fuse_pointwise.cpp
+53
-0
No files found.
src/fuse_pointwise.cpp
View file @
7422984c
...
@@ -173,12 +173,65 @@ static bool find_pointwise_modules(module& m)
...
@@ -173,12 +173,65 @@ static bool find_pointwise_modules(module& m)
return
changed
;
return
changed
;
}
}
static
instruction_ref
find_broadcasted_pointwise
(
instruction_ref
ins
,
std
::
vector
<
operation
>&
ops
)
{
if
(
ins
->
outputs
().
size
()
!=
1
)
return
ins
;
if
(
contains
({
"contiguous"
,
"broadcast"
,
"multibroadcast"
},
ins
->
name
()))
{
ops
.
push_back
(
ins
->
get_operator
());
return
find_broadcasted_pointwise
(
ins
->
inputs
().
front
(),
ops
);
}
return
ins
;
}
static
void
remove_broadcasts
(
module
&
m
)
{
auto
last
=
std
::
prev
(
m
.
end
());
for
(
auto
ins
:
iterator_for
(
m
))
{
if
(
ins
->
name
()
!=
"pointwise"
)
continue
;
if
(
ins
->
outputs
().
empty
()
and
ins
!=
last
)
continue
;
auto
inputs
=
ins
->
inputs
();
for
(
auto
input
:
inputs
)
{
if
(
input
->
outputs
().
size
()
!=
1
)
continue
;
if
(
input
->
name
()
==
"pointwise"
)
continue
;
std
::
vector
<
operation
>
ops
;
auto
pins
=
find_broadcasted_pointwise
(
input
,
ops
);
if
(
ops
.
empty
())
continue
;
if
(
pins
->
name
()
!=
"pointwise"
)
continue
;
if
(
pins
->
outputs
().
size
()
!=
1
)
continue
;
auto
pinputs
=
pins
->
inputs
();
std
::
transform
(
pinputs
.
begin
(),
pinputs
.
end
(),
pinputs
.
begin
(),
[
&
](
auto
x
)
{
for
(
auto
op
:
ops
)
{
x
=
m
.
insert_instruction
(
pins
,
op
,
x
);
}
return
x
;
});
auto
nins
=
m
.
insert_instruction
(
pins
,
pins
->
get_operator
(),
pinputs
,
pins
->
module_inputs
());
m
.
replace_instruction
(
input
,
nins
);
}
}
}
void
fuse_pointwise
::
apply
(
module_pass_manager
&
mpm
)
const
void
fuse_pointwise
::
apply
(
module_pass_manager
&
mpm
)
const
{
{
create_pointwise_modules
(
mpm
);
create_pointwise_modules
(
mpm
);
mpm
.
run_pass
(
dead_code_elimination
{});
mpm
.
run_pass
(
dead_code_elimination
{});
for
(
int
i
=
0
;
i
<
8
;
i
++
)
for
(
int
i
=
0
;
i
<
8
;
i
++
)
{
{
remove_broadcasts
(
mpm
.
get_module
());
mpm
.
run_pass
(
dead_code_elimination
{});
if
(
not
find_pointwise_modules
(
mpm
.
get_module
()))
if
(
not
find_pointwise_modules
(
mpm
.
get_module
()))
break
;
break
;
mpm
.
run_pass
(
dead_code_elimination
{});
mpm
.
run_pass
(
dead_code_elimination
{});
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment