Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
7da12e6e
"vscode:/vscode.git/clone" did not exist on "32f661f015c34f7b6bc2131b8d71e13ba296030e"
Commit
7da12e6e
authored
Jan 25, 2023
by
Paul
Browse files
Fuse two reductions
parent
50b1b842
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
73 additions
and
6 deletions
+73
-6
src/fuse_reduce.cpp
src/fuse_reduce.cpp
+68
-5
src/targets/gpu/compile_gen.cpp
src/targets/gpu/compile_gen.cpp
+4
-0
src/targets/gpu/kernels/include/migraphx/kernels/reduce.hpp
src/targets/gpu/kernels/include/migraphx/kernels/reduce.hpp
+1
-1
No files found.
src/fuse_reduce.cpp
View file @
7da12e6e
...
...
@@ -218,16 +218,31 @@ struct find_pointwise_reduce
struct
find_reduce_pointwise
{
template
<
class
...
Ms
>
static
auto
match_broadcast
(
Ms
...
ms
)
{
return
match
::
skip
(
match
::
name
(
"contiguous"
))(
match
::
name
(
"multibroadcast"
)(
match
::
arg
(
0
)(
ms
...)).
bind
(
"broadcast"
));
}
template
<
class
...
Ms
>
static
auto
any_input
(
Ms
...
ms
)
{
return
match
::
any_of
[
match
::
inputs
()](
match
::
any
(
ms
...).
bind
(
"input"
));
}
auto
matcher
()
const
{
return
match
::
name
(
"pointwise"
)(
match
::
any_of
[
match
::
inputs
()](
match
::
name
(
"fused_reduce"
)(
match
::
used_once
()).
bind
(
"reduce"
)));
auto
reduce
=
match
::
name
(
"fused_reduce"
)(
match
::
used_once
()).
bind
(
"reduce"
);
auto
reduce_input
=
any_input
(
reduce
);
auto
broadcast_reduce_input
=
any_input
(
match_broadcast
(
reduce
));
return
match
::
name
(
"pointwise"
)(
match
::
any_of
(
reduce_input
,
broadcast_reduce_input
));
}
void
apply
(
module_pass_manager
&
mpm
,
const
match
::
matcher_result
&
r
)
const
{
auto
pw
=
r
.
result
;
auto
reduce
=
r
.
instructions
[
"reduce"
];
auto
input
=
r
.
instructions
[
"input"
];
const
auto
*
old_rm
=
reduce
->
module_inputs
().
front
();
auto
*
rm
=
mpm
.
create_module
(
old_rm
->
name
()
+
":pointwise"
);
...
...
@@ -235,7 +250,17 @@ struct find_reduce_pointwise
std
::
unordered_map
<
instruction_ref
,
instruction_ref
>
map_ins
;
// Copy module instructions
insert_module_in_submodule
(
rm
,
reduce
,
map_ins
);
map_ins
[
reduce
]
=
get_returns
(
*
rm
).
front
();
if
(
contains
(
r
.
instructions
,
"broadcast"
))
{
auto
broadcast
=
r
.
instructions
[
"broadcast"
];
map_ins
[
broadcast
->
inputs
().
front
()]
=
get_returns
(
*
rm
).
front
();
auto
bout
=
insert_ins_in_submodule
(
rm
,
broadcast
,
map_ins
);
map_ins
[
input
]
=
bout
.
front
();
}
else
{
map_ins
[
input
]
=
get_returns
(
*
rm
).
front
();
}
auto
out
=
insert_ins_in_submodule
(
rm
,
pw
,
map_ins
);
rm
->
replace_return
(
out
);
...
...
@@ -244,14 +269,52 @@ struct find_reduce_pointwise
mpm
.
get_module
().
replace_instruction
(
pw
,
reduce
->
get_operator
(),
new_inputs
,
{
rm
});
}
};
struct
find_reduce_reduce
{
auto
matcher
()
const
{
return
match
::
name
(
"fused_reduce"
)(
match
::
any_of
[
match
::
inputs
()](
match
::
name
(
"fused_reduce"
)(
match
::
used_once
()).
bind
(
"reduce"
)));
}
void
apply
(
module_pass_manager
&
mpm
,
const
match
::
matcher_result
&
r
)
const
{
auto
reduce1
=
r
.
result
;
auto
reduce2
=
r
.
instructions
[
"reduce"
];
if
(
reduce1
->
get_operator
()
!=
reduce2
->
get_operator
())
return
;
const
auto
*
rm1
=
reduce1
->
module_inputs
().
front
();
const
auto
*
rm2
=
reduce2
->
module_inputs
().
front
();
auto
*
rm
=
mpm
.
create_module
(
rm1
->
name
()
+
":"
+
rm2
->
name
());
rm
->
set_bypass
();
std
::
unordered_map
<
instruction_ref
,
instruction_ref
>
map_ins
;
// Copy reduce1 instructions
insert_module_in_submodule
(
rm
,
reduce2
,
map_ins
);
map_ins
[
reduce2
]
=
get_returns
(
*
rm
).
front
();
auto
out
=
insert_module_in_submodule
(
rm
,
reduce1
,
map_ins
);
rm
->
replace_return
(
out
);
auto
new_inputs
=
find_inputs
(
rm
,
map_ins
);
mpm
.
get_module
().
replace_instruction
(
reduce1
,
reduce1
->
get_operator
(),
new_inputs
,
{
rm
});
}
};
}
// namespace
void
fuse_reduce
::
apply
(
module_pass_manager
&
mpm
)
const
{
create_reduce_modules
(
mpm
);
mpm
.
run_pass
(
dead_code_elimination
{});
match
::
find_matches
(
mpm
,
find_reduce_pointwise
{},
find_pointwise_reduce
{});
mpm
.
run_pass
(
dead_code_elimination
{});
for
(
int
i
=
0
;
i
<
4
;
i
++
)
{
match
::
find_matches
(
mpm
,
find_reduce_pointwise
{},
find_pointwise_reduce
{},
find_reduce_reduce
{});
mpm
.
run_pass
(
dead_code_elimination
{});
}
}
}
// namespace MIGRAPHX_INLINE_NS
...
...
src/targets/gpu/compile_gen.cpp
View file @
7da12e6e
...
...
@@ -303,6 +303,10 @@ std::string generate_reduce(const module& rm, const std::string& name)
{
"args"
,
join_strings
(
args
,
", "
)},
{
"call"
,
call_function
}});
}
else
if
(
ins
->
name
()
==
"multibroadcast"
)
{
return
names
.
at
(
ins
->
inputs
().
front
());
}
MIGRAPHX_THROW
(
"Unknown operator: "
+
ins
->
name
());
});
f
.
set_attributes
({
"__device__"
}).
set_generic_types
(
m
).
set_name
(
name
);
...
...
src/targets/gpu/kernels/include/migraphx/kernels/reduce.hpp
View file @
7da12e6e
...
...
@@ -481,7 +481,7 @@ __device__ void fused_reduce(Output output, F f)
}
else
{
r
.
outer
([
&
]
{
output
[
out_idx
]
=
result
;
});
r
.
outer
([
&
]
{
output
[
out_idx
]
=
implicit_conversion
(
result
)
;
});
}
});
}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment