Commit 90911355 authored by Paul's avatar Paul
Browse files

Handle rewriting mul and add

parent 3c160a3f
...@@ -427,10 +427,14 @@ struct find_conv_add ...@@ -427,10 +427,14 @@ struct find_conv_add
struct find_add_lit_broadcast struct find_add_lit_broadcast
{ {
static auto match_op(const std::string& name)
{
return match::name(name)(
match::either_arg(0, 1)(op_lit_broadcast(name, "a", "x"), lit_broadcast().bind("b")));
}
auto matcher() const auto matcher() const
{ {
return match::name("add")( return match::any_of(match_op("add"), match_op("mul"));
match::either_arg(0, 1)(op_lit_broadcast("add", "a", "x"), lit_broadcast().bind("b")));
} }
void apply(module& m, const match::matcher_result& r) const void apply(module& m, const match::matcher_result& r) const
...@@ -439,18 +443,23 @@ struct find_add_lit_broadcast ...@@ -439,18 +443,23 @@ struct find_add_lit_broadcast
auto x_ins = r.instructions["x"]; auto x_ins = r.instructions["x"];
auto a_ins = r.instructions["a"]; auto a_ins = r.instructions["a"];
auto b_ins = r.instructions["b"]; auto b_ins = r.instructions["b"];
auto op = ins->get_operator();
auto sumab = m.insert_instruction(ins, make_op("add"), a_ins, b_ins); auto sumab = m.insert_instruction(ins, op, a_ins, b_ins);
m.replace_instruction(ins, make_op("add"), x_ins, sumab); m.replace_instruction(ins, op, x_ins, sumab);
} }
}; };
struct find_double_add_lit_broadcast struct find_double_add_lit_broadcast
{ {
static auto match_op(const std::string& name)
{
return match::name(name)(
match::args(op_lit_broadcast(name, "a", "x"), op_lit_broadcast(name, "b", "y")));
}
auto matcher() const auto matcher() const
{ {
return match::name("add")( return match::any_of(match_op("add"), match_op("mul"));
match::args(op_lit_broadcast("add", "a", "x"), op_lit_broadcast("add", "b", "y")));
} }
void apply(module& m, const match::matcher_result& r) const void apply(module& m, const match::matcher_result& r) const
...@@ -460,7 +469,7 @@ struct find_double_add_lit_broadcast ...@@ -460,7 +469,7 @@ struct find_double_add_lit_broadcast
auto y_ins = r.instructions["y"]; auto y_ins = r.instructions["y"];
auto a_ins = r.instructions["a"]; auto a_ins = r.instructions["a"];
auto b_ins = r.instructions["b"]; auto b_ins = r.instructions["b"];
auto op = ins->get_operator();
instruction_ref sumab; instruction_ref sumab;
if(a_ins->name() == "broadcast" and b_ins->name() == "broadcast") if(a_ins->name() == "broadcast" and b_ins->name() == "broadcast")
...@@ -469,16 +478,16 @@ struct find_double_add_lit_broadcast ...@@ -469,16 +478,16 @@ struct find_double_add_lit_broadcast
return; return;
auto op = a_ins->get_operator(); auto op = a_ins->get_operator();
auto presum = m.insert_instruction( auto presum = m.insert_instruction(
ins, make_op("add"), a_ins->inputs().at(0), b_ins->inputs().at(0)); ins, op, a_ins->inputs().at(0), b_ins->inputs().at(0));
sumab = m.insert_instruction(ins, op, presum); sumab = m.insert_instruction(ins, op, presum);
} }
else else
{ {
sumab = m.insert_instruction(ins, make_op("add"), a_ins, b_ins); sumab = m.insert_instruction(ins, op, a_ins, b_ins);
} }
auto sumxy = m.insert_instruction(ins, make_op("add"), x_ins, y_ins); auto sumxy = m.insert_instruction(ins, op, x_ins, y_ins);
m.replace_instruction(ins, make_op("add"), sumxy, sumab); m.replace_instruction(ins, op, sumxy, sumab);
} }
}; };
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment