Commit 8423d683 authored by Paul's avatar Paul
Browse files

Format

parent 7da12e6e
......@@ -218,13 +218,14 @@ struct find_pointwise_reduce
struct find_reduce_pointwise
{
template<class... Ms>
template <class... Ms>
static auto match_broadcast(Ms... ms)
{
return match::skip(match::name("contiguous"))(match::name("multibroadcast")(match::arg(0)(ms...)).bind("broadcast"));
return match::skip(match::name("contiguous"))(
match::name("multibroadcast")(match::arg(0)(ms...)).bind("broadcast"));
}
template<class... Ms>
template <class... Ms>
static auto any_input(Ms... ms)
{
return match::any_of[match::inputs()](match::any(ms...).bind("input"));
......@@ -232,7 +233,7 @@ struct find_reduce_pointwise
auto matcher() const
{
auto reduce = match::name("fused_reduce")(match::used_once()).bind("reduce");
auto reduce = match::name("fused_reduce")(match::used_once()).bind("reduce");
auto reduce_input = any_input(reduce);
auto broadcast_reduce_input = any_input(match_broadcast(reduce));
return match::name("pointwise")(match::any_of(reduce_input, broadcast_reduce_input));
......@@ -242,7 +243,7 @@ struct find_reduce_pointwise
{
auto pw = r.result;
auto reduce = r.instructions["reduce"];
auto input = r.instructions["input"];
auto input = r.instructions["input"];
const auto* old_rm = reduce->module_inputs().front();
auto* rm = mpm.create_module(old_rm->name() + ":pointwise");
......@@ -252,10 +253,10 @@ struct find_reduce_pointwise
insert_module_in_submodule(rm, reduce, map_ins);
if(contains(r.instructions, "broadcast"))
{
auto broadcast = r.instructions["broadcast"];
auto broadcast = r.instructions["broadcast"];
map_ins[broadcast->inputs().front()] = get_returns(*rm).front();
auto bout = insert_ins_in_submodule(rm, broadcast, map_ins);
map_ins[input] = bout.front();
auto bout = insert_ins_in_submodule(rm, broadcast, map_ins);
map_ins[input] = bout.front();
}
else
{
......@@ -283,12 +284,12 @@ struct find_reduce_reduce
auto reduce1 = r.result;
auto reduce2 = r.instructions["reduce"];
if (reduce1->get_operator() != reduce2->get_operator())
if(reduce1->get_operator() != reduce2->get_operator())
return;
const auto* rm1 = reduce1->module_inputs().front();
const auto* rm2 = reduce2->module_inputs().front();
auto* rm = mpm.create_module(rm1->name() + ":" + rm2->name());
auto* rm = mpm.create_module(rm1->name() + ":" + rm2->name());
rm->set_bypass();
std::unordered_map<instruction_ref, instruction_ref> map_ins;
......@@ -310,9 +311,10 @@ void fuse_reduce::apply(module_pass_manager& mpm) const
{
create_reduce_modules(mpm);
mpm.run_pass(dead_code_elimination{});
for(int i=0;i<4;i++)
for(int i = 0; i < 4; i++)
{
match::find_matches(mpm, find_reduce_pointwise{}, find_pointwise_reduce{}, find_reduce_reduce{});
match::find_matches(
mpm, find_reduce_pointwise{}, find_pointwise_reduce{}, find_reduce_reduce{});
mpm.run_pass(dead_code_elimination{});
}
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment