Commit 8c1b643d authored by Paul's avatar Paul
Browse files

Format

parent 90911355
...@@ -432,10 +432,7 @@ struct find_add_lit_broadcast ...@@ -432,10 +432,7 @@ struct find_add_lit_broadcast
return match::name(name)( return match::name(name)(
match::either_arg(0, 1)(op_lit_broadcast(name, "a", "x"), lit_broadcast().bind("b"))); match::either_arg(0, 1)(op_lit_broadcast(name, "a", "x"), lit_broadcast().bind("b")));
} }
auto matcher() const auto matcher() const { return match::any_of(match_op("add"), match_op("mul")); }
{
return match::any_of(match_op("add"), match_op("mul"));
}
void apply(module& m, const match::matcher_result& r) const void apply(module& m, const match::matcher_result& r) const
{ {
...@@ -443,7 +440,7 @@ struct find_add_lit_broadcast ...@@ -443,7 +440,7 @@ struct find_add_lit_broadcast
auto x_ins = r.instructions["x"]; auto x_ins = r.instructions["x"];
auto a_ins = r.instructions["a"]; auto a_ins = r.instructions["a"];
auto b_ins = r.instructions["b"]; auto b_ins = r.instructions["b"];
auto op = ins->get_operator(); auto op = ins->get_operator();
auto sumab = m.insert_instruction(ins, op, a_ins, b_ins); auto sumab = m.insert_instruction(ins, op, a_ins, b_ins);
m.replace_instruction(ins, op, x_ins, sumab); m.replace_instruction(ins, op, x_ins, sumab);
...@@ -457,10 +454,7 @@ struct find_double_add_lit_broadcast ...@@ -457,10 +454,7 @@ struct find_double_add_lit_broadcast
return match::name(name)( return match::name(name)(
match::args(op_lit_broadcast(name, "a", "x"), op_lit_broadcast(name, "b", "y"))); match::args(op_lit_broadcast(name, "a", "x"), op_lit_broadcast(name, "b", "y")));
} }
auto matcher() const auto matcher() const { return match::any_of(match_op("add"), match_op("mul")); }
{
return match::any_of(match_op("add"), match_op("mul"));
}
void apply(module& m, const match::matcher_result& r) const void apply(module& m, const match::matcher_result& r) const
{ {
...@@ -469,7 +463,7 @@ struct find_double_add_lit_broadcast ...@@ -469,7 +463,7 @@ struct find_double_add_lit_broadcast
auto y_ins = r.instructions["y"]; auto y_ins = r.instructions["y"];
auto a_ins = r.instructions["a"]; auto a_ins = r.instructions["a"];
auto b_ins = r.instructions["b"]; auto b_ins = r.instructions["b"];
auto op = ins->get_operator(); auto op = ins->get_operator();
instruction_ref sumab; instruction_ref sumab;
if(a_ins->name() == "broadcast" and b_ins->name() == "broadcast") if(a_ins->name() == "broadcast" and b_ins->name() == "broadcast")
...@@ -477,8 +471,8 @@ struct find_double_add_lit_broadcast ...@@ -477,8 +471,8 @@ struct find_double_add_lit_broadcast
if(a_ins->inputs().at(0)->get_shape() != b_ins->inputs().at(0)->get_shape()) if(a_ins->inputs().at(0)->get_shape() != b_ins->inputs().at(0)->get_shape())
return; return;
auto op = a_ins->get_operator(); auto op = a_ins->get_operator();
auto presum = m.insert_instruction( auto presum =
ins, op, a_ins->inputs().at(0), b_ins->inputs().at(0)); m.insert_instruction(ins, op, a_ins->inputs().at(0), b_ins->inputs().at(0));
sumab = m.insert_instruction(ins, op, presum); sumab = m.insert_instruction(ins, op, presum);
} }
else else
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment