Commit 7422984c authored by Paul's avatar Paul
Browse files

Fuse pointwise across broadcasts

parent d78bcdfb
...@@ -173,12 +173,65 @@ static bool find_pointwise_modules(module& m) ...@@ -173,12 +173,65 @@ static bool find_pointwise_modules(module& m)
return changed; return changed;
} }
static instruction_ref find_broadcasted_pointwise(instruction_ref ins, std::vector<operation>& ops)
{
if(ins->outputs().size() != 1)
return ins;
if(contains({"contiguous", "broadcast", "multibroadcast"}, ins->name()))
{
ops.push_back(ins->get_operator());
return find_broadcasted_pointwise(ins->inputs().front(), ops);
}
return ins;
}
static void remove_broadcasts(module& m)
{
auto last = std::prev(m.end());
for(auto ins : iterator_for(m))
{
if(ins->name() != "pointwise")
continue;
if(ins->outputs().empty() and ins != last)
continue;
auto inputs = ins->inputs();
for(auto input : inputs)
{
if(input->outputs().size() != 1)
continue;
if(input->name() == "pointwise")
continue;
std::vector<operation> ops;
auto pins = find_broadcasted_pointwise(input, ops);
if(ops.empty())
continue;
if(pins->name() != "pointwise")
continue;
if(pins->outputs().size() != 1)
continue;
auto pinputs = pins->inputs();
std::transform(pinputs.begin(), pinputs.end(), pinputs.begin(), [&](auto x) {
for(auto op : ops)
{
x = m.insert_instruction(pins, op, x);
}
return x;
});
auto nins =
m.insert_instruction(pins, pins->get_operator(), pinputs, pins->module_inputs());
m.replace_instruction(input, nins);
}
}
}
void fuse_pointwise::apply(module_pass_manager& mpm) const void fuse_pointwise::apply(module_pass_manager& mpm) const
{ {
create_pointwise_modules(mpm); create_pointwise_modules(mpm);
mpm.run_pass(dead_code_elimination{}); mpm.run_pass(dead_code_elimination{});
for(int i = 0; i < 8; i++) for(int i = 0; i < 8; i++)
{ {
remove_broadcasts(mpm.get_module());
mpm.run_pass(dead_code_elimination{});
if(not find_pointwise_modules(mpm.get_module())) if(not find_pointwise_modules(mpm.get_module()))
break; break;
mpm.run_pass(dead_code_elimination{}); mpm.run_pass(dead_code_elimination{});
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment