Commit 7da12e6e authored by Paul's avatar Paul
Browse files

Fuse two reductions

parent 50b1b842
...@@ -218,16 +218,31 @@ struct find_pointwise_reduce ...@@ -218,16 +218,31 @@ struct find_pointwise_reduce
struct find_reduce_pointwise struct find_reduce_pointwise
{ {
template<class... Ms>
static auto match_broadcast(Ms... ms)
{
return match::skip(match::name("contiguous"))(match::name("multibroadcast")(match::arg(0)(ms...)).bind("broadcast"));
}
template<class... Ms>
static auto any_input(Ms... ms)
{
return match::any_of[match::inputs()](match::any(ms...).bind("input"));
}
auto matcher() const auto matcher() const
{ {
return match::name("pointwise")(match::any_of[match::inputs()]( auto reduce = match::name("fused_reduce")(match::used_once()).bind("reduce");
match::name("fused_reduce")(match::used_once()).bind("reduce"))); auto reduce_input = any_input(reduce);
auto broadcast_reduce_input = any_input(match_broadcast(reduce));
return match::name("pointwise")(match::any_of(reduce_input, broadcast_reduce_input));
} }
void apply(module_pass_manager& mpm, const match::matcher_result& r) const void apply(module_pass_manager& mpm, const match::matcher_result& r) const
{ {
auto pw = r.result; auto pw = r.result;
auto reduce = r.instructions["reduce"]; auto reduce = r.instructions["reduce"];
auto input = r.instructions["input"];
const auto* old_rm = reduce->module_inputs().front(); const auto* old_rm = reduce->module_inputs().front();
auto* rm = mpm.create_module(old_rm->name() + ":pointwise"); auto* rm = mpm.create_module(old_rm->name() + ":pointwise");
...@@ -235,7 +250,17 @@ struct find_reduce_pointwise ...@@ -235,7 +250,17 @@ struct find_reduce_pointwise
std::unordered_map<instruction_ref, instruction_ref> map_ins; std::unordered_map<instruction_ref, instruction_ref> map_ins;
// Copy module instructions // Copy module instructions
insert_module_in_submodule(rm, reduce, map_ins); insert_module_in_submodule(rm, reduce, map_ins);
map_ins[reduce] = get_returns(*rm).front(); if(contains(r.instructions, "broadcast"))
{
auto broadcast = r.instructions["broadcast"];
map_ins[broadcast->inputs().front()] = get_returns(*rm).front();
auto bout = insert_ins_in_submodule(rm, broadcast, map_ins);
map_ins[input] = bout.front();
}
else
{
map_ins[input] = get_returns(*rm).front();
}
auto out = insert_ins_in_submodule(rm, pw, map_ins); auto out = insert_ins_in_submodule(rm, pw, map_ins);
rm->replace_return(out); rm->replace_return(out);
...@@ -244,14 +269,52 @@ struct find_reduce_pointwise ...@@ -244,14 +269,52 @@ struct find_reduce_pointwise
mpm.get_module().replace_instruction(pw, reduce->get_operator(), new_inputs, {rm}); mpm.get_module().replace_instruction(pw, reduce->get_operator(), new_inputs, {rm});
} }
}; };
struct find_reduce_reduce
{
auto matcher() const
{
return match::name("fused_reduce")(match::any_of[match::inputs()](
match::name("fused_reduce")(match::used_once()).bind("reduce")));
}
void apply(module_pass_manager& mpm, const match::matcher_result& r) const
{
auto reduce1 = r.result;
auto reduce2 = r.instructions["reduce"];
if (reduce1->get_operator() != reduce2->get_operator())
return;
const auto* rm1 = reduce1->module_inputs().front();
const auto* rm2 = reduce2->module_inputs().front();
auto* rm = mpm.create_module(rm1->name() + ":" + rm2->name());
rm->set_bypass();
std::unordered_map<instruction_ref, instruction_ref> map_ins;
// Copy reduce1 instructions
insert_module_in_submodule(rm, reduce2, map_ins);
map_ins[reduce2] = get_returns(*rm).front();
auto out = insert_module_in_submodule(rm, reduce1, map_ins);
rm->replace_return(out);
auto new_inputs = find_inputs(rm, map_ins);
mpm.get_module().replace_instruction(reduce1, reduce1->get_operator(), new_inputs, {rm});
}
};
} // namespace } // namespace
void fuse_reduce::apply(module_pass_manager& mpm) const void fuse_reduce::apply(module_pass_manager& mpm) const
{ {
create_reduce_modules(mpm); create_reduce_modules(mpm);
mpm.run_pass(dead_code_elimination{}); mpm.run_pass(dead_code_elimination{});
match::find_matches(mpm, find_reduce_pointwise{}, find_pointwise_reduce{}); for(int i=0;i<4;i++)
{
match::find_matches(mpm, find_reduce_pointwise{}, find_pointwise_reduce{}, find_reduce_reduce{});
mpm.run_pass(dead_code_elimination{}); mpm.run_pass(dead_code_elimination{});
}
} }
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
......
...@@ -303,6 +303,10 @@ std::string generate_reduce(const module& rm, const std::string& name) ...@@ -303,6 +303,10 @@ std::string generate_reduce(const module& rm, const std::string& name)
{"args", join_strings(args, ", ")}, {"args", join_strings(args, ", ")},
{"call", call_function}}); {"call", call_function}});
} }
else if(ins->name() == "multibroadcast")
{
return names.at(ins->inputs().front());
}
MIGRAPHX_THROW("Unknown operator: " + ins->name()); MIGRAPHX_THROW("Unknown operator: " + ins->name());
}); });
f.set_attributes({"__device__"}).set_generic_types(m).set_name(name); f.set_attributes({"__device__"}).set_generic_types(m).set_name(name);
......
...@@ -481,7 +481,7 @@ __device__ void fused_reduce(Output output, F f) ...@@ -481,7 +481,7 @@ __device__ void fused_reduce(Output output, F f)
} }
else else
{ {
r.outer([&] { output[out_idx] = result; }); r.outer([&] { output[out_idx] = implicit_conversion(result); });
} }
}); });
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment