Commit 49c745bd authored by Paul's avatar Paul
Browse files

Handle literals in add fusions

parent a752db35
...@@ -89,7 +89,9 @@ struct find_add_lit_broadcast ...@@ -89,7 +89,9 @@ struct find_add_lit_broadcast
void simplify_algebra::apply(program& p) const void simplify_algebra::apply(program& p) const
{ {
match::find_matches(p, find_add_lit_broadcast{}, find_mul_conv{}); // Run simplifications twice
for(int i=0;i<2;i++)
match::find_matches(p, find_add_lit_broadcast{}, find_mul_conv{});
} }
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
......
...@@ -225,7 +225,7 @@ struct find_add_relu ...@@ -225,7 +225,7 @@ struct find_add_relu
return match::name("gpu::relu")( return match::name("gpu::relu")(
match::arg(0)(match::any_of(match::name("gpu::add"), match::arg(0)(match::any_of(match::name("gpu::add"),
match::name("hip::triadd"), match::name("hip::triadd"),
match::any_of[match::inputs()](match::standard_shape())) match::any_of(match::name("@literal"), match::any_of[match::inputs()](match::standard_shape())))
.bind("add"))); .bind("add")));
} }
...@@ -252,7 +252,7 @@ struct find_triadd ...@@ -252,7 +252,7 @@ struct find_triadd
{ {
return match::name("gpu::add")(match::either_arg(0, 1)( return match::name("gpu::add")(match::either_arg(0, 1)(
match::name("gpu::add").bind("add"), match::name("gpu::add").bind("add"),
match::any(match::any_of[match::inputs()](match::standard_shape())).bind("input"))); match::any(match::any_of(match::name("@literal"), match::any_of[match::inputs()](match::standard_shape()))).bind("input")));
} }
void apply(program& p, match::matcher_result r) const void apply(program& p, match::matcher_result r) const
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment