Commit c7b4fadd authored by Paul's avatar Paul
Browse files

Add mul add simplifications

parent 09357b62
...@@ -17,14 +17,18 @@ auto op_lit_broadcast(std::string op, std::string x, std::string y) ...@@ -17,14 +17,18 @@ auto op_lit_broadcast(std::string op, std::string x, std::string y)
not_lit_broadcast().bind(std::move(y)))); not_lit_broadcast().bind(std::move(y))));
} }
auto conv_const_weights()
{
return match::name("convolution")(match::used_once(),
match::args(match::any(), match::is_constant().bind("w")));
}
struct find_mul_conv struct find_mul_conv
{ {
auto matcher() const auto matcher() const
{ {
return match::name("mul")(match::either_arg(0, 1)( return match::name("mul")(match::either_arg(0, 1)(
match::name("convolution")(match::used_once(), conv_const_weights().bind("conv"),
match::args(match::any(), match::is_constant().bind("w")))
.bind("conv"),
match::name("broadcast").bind("a"))); match::name("broadcast").bind("a")));
} }
...@@ -48,6 +52,30 @@ struct find_mul_conv ...@@ -48,6 +52,30 @@ struct find_mul_conv
} }
}; };
struct find_mul_add
{
auto matcher() const
{
return match::name("mul")(match::either_arg(0, 1)(
match::name("add")(match::either_arg(0, 1)(match::any().bind("x"), match::any_of(conv_const_weights(), match::is_constant()).bind("y"))),
match::is_constant().bind("a")
));
}
void apply(program& p, match::matcher_result r) const
{
auto ins = r.result;
auto a_ins = r.instructions["a"];
auto x_ins = r.instructions["x"];
auto y_ins = r.instructions["y"];
auto xa_ins = p.insert_instruction(ins, op::mul{}, x_ins, a_ins);
auto ya_ins = p.insert_instruction(ins, op::mul{}, y_ins, a_ins);
auto sum_xa_ya = p.insert_instruction(ins, op::add{}, xa_ins, ya_ins);
p.replace_instruction(ins, sum_xa_ya);
}
};
struct find_add_lit_broadcast struct find_add_lit_broadcast
{ {
auto matcher() const auto matcher() const
...@@ -89,9 +117,9 @@ struct find_add_lit_broadcast ...@@ -89,9 +117,9 @@ struct find_add_lit_broadcast
void simplify_algebra::apply(program& p) const void simplify_algebra::apply(program& p) const
{ {
// Run simplifications twice // Run simplifications multiple times
for(int i = 0; i < 2; i++) for(int i = 0; i < 4; i++)
match::find_matches(p, find_add_lit_broadcast{}, find_mul_conv{}); match::find_matches(p, find_add_lit_broadcast{}, find_mul_conv{}, find_mul_add{});
} }
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment