"vscode:/vscode.git/clone" did not exist on "32f661f015c34f7b6bc2131b8d71e13ba296030e"
Commit 7da12e6e authored by Paul
Browse files

Fuse two reductions

parent 50b1b842
......@@ -218,16 +218,31 @@ struct find_pointwise_reduce
struct find_reduce_pointwise
{
// Builds a matcher for a "multibroadcast" instruction whose argument 0
// satisfies `ms...`; the matched multibroadcast is bound as "broadcast".
// The result is wrapped with match::skip(match::name("contiguous")),
// presumably so that intervening "contiguous" instructions are looked
// through -- confirm against the matcher library.
template <class... Ms>
static auto match_broadcast(Ms... ms)
{
    auto broadcast       = match::name("multibroadcast")(match::arg(0)(ms...)).bind("broadcast");
    auto skip_contiguous = match::skip(match::name("contiguous"));
    return skip_contiguous(broadcast);
}
// Builds a matcher applied over an instruction's inputs via
// match::any_of[match::inputs()]. The input that satisfies `ms...` is
// bound as "input".
template <class... Ms>
static auto any_input(Ms... ms)
{
    auto bound_input = match::any(ms...).bind("input");
    return match::any_of[match::inputs()](bound_input);
}
// Matches a "pointwise" instruction that consumes a single-use
// "fused_reduce" (bound as "reduce"), either directly as an input or
// through a multibroadcast of it (see match_broadcast, which binds
// "broadcast"). The matching input itself is bound as "input" by
// any_input, which apply() relies on.
//
// The earlier early-return form of this matcher only bound "reduce" and
// left everything below it unreachable, so "input" was never bound; it
// has been removed.
auto matcher() const
{
    auto reduce = match::name("fused_reduce")(match::used_once()).bind("reduce");
    // The reduce feeds the pointwise directly...
    auto reduce_input = any_input(reduce);
    // ...or via a multibroadcast of the reduce output.
    auto broadcast_reduce_input = any_input(match_broadcast(reduce));
    return match::name("pointwise")(match::any_of(reduce_input, broadcast_reduce_input));
}
void apply(module_pass_manager& mpm, const match::matcher_result& r) const
{
auto pw = r.result;
auto reduce = r.instructions["reduce"];
auto input = r.instructions["input"];
const auto* old_rm = reduce->module_inputs().front();
auto* rm = mpm.create_module(old_rm->name() + ":pointwise");
......@@ -235,7 +250,17 @@ struct find_reduce_pointwise
std::unordered_map<instruction_ref, instruction_ref> map_ins;
// Copy module instructions
insert_module_in_submodule(rm, reduce, map_ins);
map_ins[reduce] = get_returns(*rm).front();
if(contains(r.instructions, "broadcast"))
{
auto broadcast = r.instructions["broadcast"];
map_ins[broadcast->inputs().front()] = get_returns(*rm).front();
auto bout = insert_ins_in_submodule(rm, broadcast, map_ins);
map_ins[input] = bout.front();
}
else
{
map_ins[input] = get_returns(*rm).front();
}
auto out = insert_ins_in_submodule(rm, pw, map_ins);
rm->replace_return(out);
......@@ -244,14 +269,52 @@ struct find_reduce_pointwise
mpm.get_module().replace_instruction(pw, reduce->get_operator(), new_inputs, {rm});
}
};
// Rewrite that merges a "fused_reduce" instruction into another
// "fused_reduce" that consumes it, producing a single fused_reduce whose
// submodule contains the instructions of both.
struct find_reduce_reduce
{
    // Matches a "fused_reduce" having, among its inputs, a "fused_reduce"
    // that satisfies match::used_once(); that input is bound as "reduce".
    auto matcher() const
    {
        auto producer = match::name("fused_reduce")(match::used_once()).bind("reduce");
        return match::name("fused_reduce")(match::any_of[match::inputs()](producer));
    }

    // Replaces the matched consumer with a fused_reduce over a newly
    // created submodule. Does nothing when the two instructions carry
    // different operators.
    void apply(module_pass_manager& mpm, const match::matcher_result& r) const
    {
        auto consumer = r.result;
        auto producer = r.instructions["reduce"];

        // The replacement reuses the consumer's operator, so both
        // instructions are required to have the same one.
        if(consumer->get_operator() != producer->get_operator())
            return;

        const auto* consumer_mod = consumer->module_inputs().front();
        const auto* producer_mod = producer->module_inputs().front();
        const auto fused_name    = consumer_mod->name() + ":" + producer_mod->name();
        auto* fused              = mpm.create_module(fused_name);
        fused->set_bypass();

        std::unordered_map<instruction_ref, instruction_ref> map_ins;

        // Copy the producer's submodule first and let its first return
        // value stand in for the producer instruction itself.
        insert_module_in_submodule(fused, producer, map_ins);
        map_ins[producer] = get_returns(*fused).front();

        // Then copy the consumer's submodule; its outputs become the
        // returns of the fused module.
        auto outputs = insert_module_in_submodule(fused, consumer, map_ins);
        fused->replace_return(outputs);

        auto new_inputs = find_inputs(fused, map_ins);
        mpm.get_module().replace_instruction(
            consumer, consumer->get_operator(), new_inputs, {fused});
    }
};
} // namespace
// Runs the fuse_reduce pass: creates the reduce modules, then repeatedly
// applies the reduce/pointwise and reduce/reduce fusion matchers, running
// dead code elimination after each step.
//
// The former single-shot find_matches + dead_code_elimination pair that
// preceded the loop has been dropped: it was the superseded version of the
// loop body (without find_reduce_reduce) and only added a redundant pass.
void fuse_reduce::apply(module_pass_manager& mpm) const
{
    // Fixed number of fusion rounds. Presumably several rounds are needed
    // because one fusion can expose another -- TODO confirm the choice of 4.
    constexpr int fusion_rounds = 4;

    create_reduce_modules(mpm);
    mpm.run_pass(dead_code_elimination{});
    for(int i = 0; i < fusion_rounds; ++i)
    {
        match::find_matches(
            mpm, find_reduce_pointwise{}, find_pointwise_reduce{}, find_reduce_reduce{});
        mpm.run_pass(dead_code_elimination{});
    }
}
} // namespace MIGRAPHX_INLINE_NS
......
......@@ -303,6 +303,10 @@ std::string generate_reduce(const module& rm, const std::string& name)
{"args", join_strings(args, ", ")},
{"call", call_function}});
}
else if(ins->name() == "multibroadcast")
{
return names.at(ins->inputs().front());
}
MIGRAPHX_THROW("Unknown operator: " + ins->name());
});
f.set_attributes({"__device__"}).set_generic_types(m).set_name(name);
......
......@@ -481,7 +481,7 @@ __device__ void fused_reduce(Output output, F f)
}
else
{
r.outer([&] { output[out_idx] = result; });
r.outer([&] { output[out_idx] = implicit_conversion(result); });
}
});
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment