"git@developer.sourcefind.cn:gaoqiong/migraphx.git" did not exist on "cfaf5be69d4b0c9f46654aa4ce2c2a65612a7743"
Commit 8423d683 authored by Paul's avatar Paul
Browse files

Format

parent 7da12e6e
...@@ -218,13 +218,14 @@ struct find_pointwise_reduce ...@@ -218,13 +218,14 @@ struct find_pointwise_reduce
struct find_reduce_pointwise struct find_reduce_pointwise
{ {
template<class... Ms> template <class... Ms>
static auto match_broadcast(Ms... ms) static auto match_broadcast(Ms... ms)
{ {
return match::skip(match::name("contiguous"))(match::name("multibroadcast")(match::arg(0)(ms...)).bind("broadcast")); return match::skip(match::name("contiguous"))(
match::name("multibroadcast")(match::arg(0)(ms...)).bind("broadcast"));
} }
template<class... Ms> template <class... Ms>
static auto any_input(Ms... ms) static auto any_input(Ms... ms)
{ {
return match::any_of[match::inputs()](match::any(ms...).bind("input")); return match::any_of[match::inputs()](match::any(ms...).bind("input"));
...@@ -232,7 +233,7 @@ struct find_reduce_pointwise ...@@ -232,7 +233,7 @@ struct find_reduce_pointwise
auto matcher() const auto matcher() const
{ {
auto reduce = match::name("fused_reduce")(match::used_once()).bind("reduce"); auto reduce = match::name("fused_reduce")(match::used_once()).bind("reduce");
auto reduce_input = any_input(reduce); auto reduce_input = any_input(reduce);
auto broadcast_reduce_input = any_input(match_broadcast(reduce)); auto broadcast_reduce_input = any_input(match_broadcast(reduce));
return match::name("pointwise")(match::any_of(reduce_input, broadcast_reduce_input)); return match::name("pointwise")(match::any_of(reduce_input, broadcast_reduce_input));
...@@ -242,7 +243,7 @@ struct find_reduce_pointwise ...@@ -242,7 +243,7 @@ struct find_reduce_pointwise
{ {
auto pw = r.result; auto pw = r.result;
auto reduce = r.instructions["reduce"]; auto reduce = r.instructions["reduce"];
auto input = r.instructions["input"]; auto input = r.instructions["input"];
const auto* old_rm = reduce->module_inputs().front(); const auto* old_rm = reduce->module_inputs().front();
auto* rm = mpm.create_module(old_rm->name() + ":pointwise"); auto* rm = mpm.create_module(old_rm->name() + ":pointwise");
...@@ -252,10 +253,10 @@ struct find_reduce_pointwise ...@@ -252,10 +253,10 @@ struct find_reduce_pointwise
insert_module_in_submodule(rm, reduce, map_ins); insert_module_in_submodule(rm, reduce, map_ins);
if(contains(r.instructions, "broadcast")) if(contains(r.instructions, "broadcast"))
{ {
auto broadcast = r.instructions["broadcast"]; auto broadcast = r.instructions["broadcast"];
map_ins[broadcast->inputs().front()] = get_returns(*rm).front(); map_ins[broadcast->inputs().front()] = get_returns(*rm).front();
auto bout = insert_ins_in_submodule(rm, broadcast, map_ins); auto bout = insert_ins_in_submodule(rm, broadcast, map_ins);
map_ins[input] = bout.front(); map_ins[input] = bout.front();
} }
else else
{ {
...@@ -283,12 +284,12 @@ struct find_reduce_reduce ...@@ -283,12 +284,12 @@ struct find_reduce_reduce
auto reduce1 = r.result; auto reduce1 = r.result;
auto reduce2 = r.instructions["reduce"]; auto reduce2 = r.instructions["reduce"];
if (reduce1->get_operator() != reduce2->get_operator()) if(reduce1->get_operator() != reduce2->get_operator())
return; return;
const auto* rm1 = reduce1->module_inputs().front(); const auto* rm1 = reduce1->module_inputs().front();
const auto* rm2 = reduce2->module_inputs().front(); const auto* rm2 = reduce2->module_inputs().front();
auto* rm = mpm.create_module(rm1->name() + ":" + rm2->name()); auto* rm = mpm.create_module(rm1->name() + ":" + rm2->name());
rm->set_bypass(); rm->set_bypass();
std::unordered_map<instruction_ref, instruction_ref> map_ins; std::unordered_map<instruction_ref, instruction_ref> map_ins;
...@@ -310,9 +311,10 @@ void fuse_reduce::apply(module_pass_manager& mpm) const ...@@ -310,9 +311,10 @@ void fuse_reduce::apply(module_pass_manager& mpm) const
{ {
create_reduce_modules(mpm); create_reduce_modules(mpm);
mpm.run_pass(dead_code_elimination{}); mpm.run_pass(dead_code_elimination{});
for(int i=0;i<4;i++) for(int i = 0; i < 4; i++)
{ {
match::find_matches(mpm, find_reduce_pointwise{}, find_pointwise_reduce{}, find_reduce_reduce{}); match::find_matches(
mpm, find_reduce_pointwise{}, find_pointwise_reduce{}, find_reduce_reduce{});
mpm.run_pass(dead_code_elimination{}); mpm.run_pass(dead_code_elimination{});
} }
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment