Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
7da12e6e
Commit
7da12e6e
authored
Jan 25, 2023
by
Paul
Browse files
Fuse two reductions
parent
50b1b842
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
73 additions
and
6 deletions
+73
-6
src/fuse_reduce.cpp
src/fuse_reduce.cpp
+68
-5
src/targets/gpu/compile_gen.cpp
src/targets/gpu/compile_gen.cpp
+4
-0
src/targets/gpu/kernels/include/migraphx/kernels/reduce.hpp
src/targets/gpu/kernels/include/migraphx/kernels/reduce.hpp
+1
-1
No files found.
src/fuse_reduce.cpp
View file @
7da12e6e
...
...
@@ -218,16 +218,31 @@ struct find_pointwise_reduce
struct
find_reduce_pointwise
{
// Builds a matcher for a "multibroadcast" instruction whose argument 0
// satisfies every matcher in `ms`. The multibroadcast itself is bound under
// the name "broadcast", and the whole thing is wrapped in
// match::skip(match::name("contiguous")).
// NOTE(review): presumably the skip wrapper lets intervening "contiguous"
// instructions be stepped over -- confirm against the matcher library.
template <class... Ms>
static auto match_broadcast(Ms... ms)
{
    auto bound_broadcast = match::name("multibroadcast")(match::arg(0)(ms...)).bind("broadcast");
    return match::skip(match::name("contiguous"))(bound_broadcast);
}
// Builds a matcher that succeeds when any of the instruction's inputs
// satisfies match::any(ms...). The sub-matcher is bound under the name
// "input" so that apply() can look it up in the matcher result.
template <class... Ms>
static auto any_input(Ms... ms)
{
    auto bound_input = match::any(ms...).bind("input");
    return match::any_of[match::inputs()](bound_input);
}
// Returns the matcher used by find_reduce_pointwise: a "pointwise"
// instruction that has a single-use "fused_reduce" (bound as "reduce") among
// its inputs.
// NOTE(review): this text comes from a rendered diff, so the first `return`
// below appears to be the pre-change matcher and the statements after it the
// post-change one. As written here everything after the first `return` would
// be unreachable -- confirm against the actual file.
auto
matcher
()
const
{
// Form 1: the fused_reduce must be a direct input of the pointwise.
return
match
::
name
(
"pointwise"
)(
match
::
any_of
[
match
::
inputs
()](
match
::
name
(
"fused_reduce"
)(
match
::
used_once
()).
bind
(
"reduce"
)));
// Form 2: the same fused_reduce matcher, kept in a local so it can be reused
// in both alternatives below.
auto
reduce
=
match
::
name
(
"fused_reduce"
)(
match
::
used_once
()).
bind
(
"reduce"
);
// Alternative A: the reduce is directly one of the inputs (bound as "input"
// by any_input).
auto
reduce_input
=
any_input
(
reduce
);
// Alternative B: the reduce reaches the pointwise through the pattern built
// by match_broadcast (a "multibroadcast" bound as "broadcast").
auto
broadcast_reduce_input
=
any_input
(
match_broadcast
(
reduce
));
// Accept a pointwise that satisfies either alternative.
return
match
::
name
(
"pointwise"
)(
match
::
any_of
(
reduce_input
,
broadcast_reduce_input
));
}
void
apply
(
module_pass_manager
&
mpm
,
const
match
::
matcher_result
&
r
)
const
{
auto
pw
=
r
.
result
;
auto
reduce
=
r
.
instructions
[
"reduce"
];
auto
input
=
r
.
instructions
[
"input"
];
const
auto
*
old_rm
=
reduce
->
module_inputs
().
front
();
auto
*
rm
=
mpm
.
create_module
(
old_rm
->
name
()
+
":pointwise"
);
...
...
@@ -235,7 +250,17 @@ struct find_reduce_pointwise
std
::
unordered_map
<
instruction_ref
,
instruction_ref
>
map_ins
;
// Copy module instructions
insert_module_in_submodule
(
rm
,
reduce
,
map_ins
);
map_ins
[
reduce
]
=
get_returns
(
*
rm
).
front
();
if
(
contains
(
r
.
instructions
,
"broadcast"
))
{
auto
broadcast
=
r
.
instructions
[
"broadcast"
];
map_ins
[
broadcast
->
inputs
().
front
()]
=
get_returns
(
*
rm
).
front
();
auto
bout
=
insert_ins_in_submodule
(
rm
,
broadcast
,
map_ins
);
map_ins
[
input
]
=
bout
.
front
();
}
else
{
map_ins
[
input
]
=
get_returns
(
*
rm
).
front
();
}
auto
out
=
insert_ins_in_submodule
(
rm
,
pw
,
map_ins
);
rm
->
replace_return
(
out
);
...
...
@@ -244,14 +269,52 @@ struct find_reduce_pointwise
mpm
.
get_module
().
replace_instruction
(
pw
,
reduce
->
get_operator
(),
new_inputs
,
{
rm
});
}
};
// Rewrite that merges a "fused_reduce" with another "fused_reduce" feeding
// one of its inputs, replacing the pair with a single fused_reduce whose
// submodule contains the instructions of both.
struct find_reduce_reduce
{
    // Matches a "fused_reduce" having, among its inputs, a "fused_reduce"
    // that satisfies used_once(); the input one is bound as "reduce".
    auto matcher() const
    {
        auto producer = match::name("fused_reduce")(match::used_once()).bind("reduce");
        return match::name("fused_reduce")(match::any_of[match::inputs()](producer));
    }

    // mpm: pass manager used to create the merged submodule and to rewrite
    //      the enclosing module.
    // r:   matcher result; r.result is the consuming fused_reduce and
    //      r.instructions["reduce"] the fused_reduce that feeds it.
    void apply(module_pass_manager& mpm, const match::matcher_result& r) const
    {
        auto consumer = r.result;
        auto producer = r.instructions["reduce"];

        // Only merge when both instructions carry the same operator.
        if(consumer->get_operator() != producer->get_operator())
            return;

        const auto* consumer_rm = consumer->module_inputs().front();
        const auto* producer_rm = producer->module_inputs().front();

        // The merged module is named "<consumer module>:<producer module>".
        auto* fused = mpm.create_module(consumer_rm->name() + ":" + producer_rm->name());
        fused->set_bypass();

        std::unordered_map<instruction_ref, instruction_ref> ins_map;

        // Copy the producer's (the bound "reduce") instructions first, then
        // map the producer instruction to the first return of the new module
        // so the consumer's instructions are wired to it.
        insert_module_in_submodule(fused, producer, ins_map);
        ins_map[producer] = get_returns(*fused).front();

        // Copy the consumer's instructions; what this yields becomes the
        // return of the merged module.
        auto results = insert_module_in_submodule(fused, consumer, ins_map);
        fused->replace_return(results);

        auto new_inputs = find_inputs(fused, ins_map);
        mpm.get_module().replace_instruction(
            consumer, consumer->get_operator(), new_inputs, {fused});
    }
};
}
// namespace
// Entry point of the fuse_reduce pass. Calls create_reduce_modules on the
// pass manager, then alternates matcher-driven rewrites with
// dead_code_elimination.
// NOTE(review): this text comes from a rendered diff, so the stand-alone
// find_matches/dead_code_elimination pair before the loop is probably the
// removed side of the hunk and the loop the added side -- confirm against
// the actual file before relying on the exact sequence shown here.
void
fuse_reduce
::
apply
(
module_pass_manager
&
mpm
)
const
{
create_reduce_modules
(
mpm
);
mpm
.
run_pass
(
dead_code_elimination
{});
// Single round using only the reduce<->pointwise matchers.
match
::
find_matches
(
mpm
,
find_reduce_pointwise
{},
find_pointwise_reduce
{});
mpm
.
run_pass
(
dead_code_elimination
{});
// Four rounds that additionally apply find_reduce_reduce, each followed by
// dead code elimination.
// NOTE(review): the round count 4 is hard-coded; presumably repeated rounds
// let fusions exposed by an earlier round be matched -- rationale not shown.
for
(
int
i
=
0
;
i
<
4
;
i
++
)
{
match
::
find_matches
(
mpm
,
find_reduce_pointwise
{},
find_pointwise_reduce
{},
find_reduce_reduce
{});
mpm
.
run_pass
(
dead_code_elimination
{});
}
}
}
// namespace MIGRAPHX_INLINE_NS
...
...
src/targets/gpu/compile_gen.cpp
View file @
7da12e6e
...
...
@@ -303,6 +303,10 @@ std::string generate_reduce(const module& rm, const std::string& name)
{
"args"
,
join_strings
(
args
,
", "
)},
{
"call"
,
call_function
}});
}
else
if
(
ins
->
name
()
==
"multibroadcast"
)
{
return
names
.
at
(
ins
->
inputs
().
front
());
}
MIGRAPHX_THROW
(
"Unknown operator: "
+
ins
->
name
());
});
f
.
set_attributes
({
"__device__"
}).
set_generic_types
(
m
).
set_name
(
name
);
...
...
src/targets/gpu/kernels/include/migraphx/kernels/reduce.hpp
View file @
7da12e6e
...
...
@@ -481,7 +481,7 @@ __device__ void fused_reduce(Output output, F f)
}
else
{
r
.
outer
([
&
]
{
output
[
out_idx
]
=
result
;
});
r
.
outer
([
&
]
{
output
[
out_idx
]
=
implicit_conversion
(
result
)
;
});
}
});
}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment