Commit 90911355 authored by Paul's avatar Paul
Browse files

Handle rewriting mul and add

parent 3c160a3f
......@@ -427,10 +427,14 @@ struct find_conv_add
struct find_add_lit_broadcast
{
static auto match_op(const std::string& name)
{
return match::name(name)(
match::either_arg(0, 1)(op_lit_broadcast(name, "a", "x"), lit_broadcast().bind("b")));
}
auto matcher() const
{
return match::name("add")(
match::either_arg(0, 1)(op_lit_broadcast("add", "a", "x"), lit_broadcast().bind("b")));
return match::any_of(match_op("add"), match_op("mul"));
}
void apply(module& m, const match::matcher_result& r) const
......@@ -439,18 +443,23 @@ struct find_add_lit_broadcast
auto x_ins = r.instructions["x"];
auto a_ins = r.instructions["a"];
auto b_ins = r.instructions["b"];
auto op = ins->get_operator();
auto sumab = m.insert_instruction(ins, make_op("add"), a_ins, b_ins);
m.replace_instruction(ins, make_op("add"), x_ins, sumab);
auto sumab = m.insert_instruction(ins, op, a_ins, b_ins);
m.replace_instruction(ins, op, x_ins, sumab);
}
};
struct find_double_add_lit_broadcast
{
static auto match_op(const std::string& name)
{
return match::name(name)(
match::args(op_lit_broadcast(name, "a", "x"), op_lit_broadcast(name, "b", "y")));
}
auto matcher() const
{
return match::name("add")(
match::args(op_lit_broadcast("add", "a", "x"), op_lit_broadcast("add", "b", "y")));
return match::any_of(match_op("add"), match_op("mul"));
}
void apply(module& m, const match::matcher_result& r) const
......@@ -460,7 +469,7 @@ struct find_double_add_lit_broadcast
auto y_ins = r.instructions["y"];
auto a_ins = r.instructions["a"];
auto b_ins = r.instructions["b"];
auto op = ins->get_operator();
instruction_ref sumab;
if(a_ins->name() == "broadcast" and b_ins->name() == "broadcast")
......@@ -469,16 +478,16 @@ struct find_double_add_lit_broadcast
return;
auto op = a_ins->get_operator();
auto presum = m.insert_instruction(
ins, make_op("add"), a_ins->inputs().at(0), b_ins->inputs().at(0));
ins, op, a_ins->inputs().at(0), b_ins->inputs().at(0));
sumab = m.insert_instruction(ins, op, presum);
}
else
{
sumab = m.insert_instruction(ins, make_op("add"), a_ins, b_ins);
sumab = m.insert_instruction(ins, op, a_ins, b_ins);
}
auto sumxy = m.insert_instruction(ins, make_op("add"), x_ins, y_ins);
m.replace_instruction(ins, make_op("add"), sumxy, sumab);
auto sumxy = m.insert_instruction(ins, op, x_ins, y_ins);
m.replace_instruction(ins, op, sumxy, sumab);
}
};
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment