Commit 49c745bd authored by Paul's avatar Paul
Browse files

Handle literals in add fusions

parent a752db35
......@@ -89,6 +89,8 @@ struct find_add_lit_broadcast
void simplify_algebra::apply(program& p) const
{
// Run simplifications twice
for(int i=0;i<2;i++)
match::find_matches(p, find_add_lit_broadcast{}, find_mul_conv{});
}
......
......@@ -225,7 +225,7 @@ struct find_add_relu
return match::name("gpu::relu")(
match::arg(0)(match::any_of(match::name("gpu::add"),
match::name("hip::triadd"),
match::any_of[match::inputs()](match::standard_shape()))
match::any_of(match::name("@literal"), match::any_of[match::inputs()](match::standard_shape())))
.bind("add")));
}
......@@ -252,7 +252,7 @@ struct find_triadd
{
return match::name("gpu::add")(match::either_arg(0, 1)(
match::name("gpu::add").bind("add"),
match::any(match::any_of[match::inputs()](match::standard_shape())).bind("input")));
match::any(match::any_of(match::name("@literal"), match::any_of[match::inputs()](match::standard_shape()))).bind("input")));
}
void apply(program& p, match::matcher_result r) const
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment