Commit d6764b7d authored by Ted Themistokleous's avatar Ted Themistokleous
Browse files

fixup! Simplify algebra for 0*x, x*0 and 0/x operations

parent 41e3b9f6
...@@ -41,8 +41,6 @@ ...@@ -41,8 +41,6 @@
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_TRACE_PASSES);
auto lit_broadcast() { return match::any_of(match::is_constant(), match::name("broadcast")); } auto lit_broadcast() { return match::any_of(match::is_constant(), match::name("broadcast")); }
auto not_lit_broadcast() { return match::none_of(match::is_constant(), match::name("broadcast")); } auto not_lit_broadcast() { return match::none_of(match::is_constant(), match::name("broadcast")); }
auto op_lit_broadcast(std::string op, std::string x, std::string y) auto op_lit_broadcast(std::string op, std::string x, std::string y)
...@@ -859,12 +857,12 @@ struct find_zero_div_const ...@@ -859,12 +857,12 @@ struct find_zero_div_const
void apply [[noreturn]] (const module& m, const match::matcher_result& r) const void apply [[noreturn]] (const module& m, const match::matcher_result& r) const
{ {
if(enabled(MIGRAPHX_TRACE_PASSES{})) // if(enabled(MIGRAPHX_TRACE_MATCHES{}))
{ //{
m.debug_print(); // m.debug_print();
std::cout << "ERROR:DIV_BY_ZERO: "; // std::cout << "ERROR:DIV_BY_ZERO: ";
m.debug_print(r.result); // m.debug_print(r.result);
} //}
MIGRAPHX_THROW("ERROR: Matched division by zero in pass"); MIGRAPHX_THROW("ERROR: Matched division by zero in pass");
} }
}; };
...@@ -921,21 +919,18 @@ struct find_zero_ops ...@@ -921,21 +919,18 @@ struct find_zero_ops
auto matcher() const auto matcher() const
{ {
auto mul_zero = match::name("mul")( auto mul_zero = match::name("mul")(
match::either_arg(0, 1)(match::has_value(0.0f), match::any().bind("x"))); match::either_arg(0, 1)(match::has_value(0.0f).bind("x"), match::any()));
auto div_zero = auto div_zero =
match::name("div")(match::args(match::has_value(0.0f), match::any().bind("x"))); match::name("div")(match::args(match::has_value(0.0f).bind("x"), match::any()));
return match::any_of(mul_zero, div_zero); return match::any_of(mul_zero, div_zero);
} }
void apply(module& m, const match::matcher_result& r) const void apply(module& m, const match::matcher_result& r) const
{ {
auto ins = r.result; auto ins = r.result;
auto x_ins = r.instructions["x"]; auto zero_ins = r.instructions["x"];
auto zero = m.add_literal(0);
auto ret = m.add_return({zero});
m.remove_instruction((x_ins)); m.replace_instruction(ins, zero_ins);
m.replace_instruction(ins, ret);
} }
}; };
......
...@@ -936,42 +936,47 @@ TEST_CASE(simplify_zero_mult_const) ...@@ -936,42 +936,47 @@ TEST_CASE(simplify_zero_mult_const)
{ {
migraphx::module m1; migraphx::module m1;
{ {
auto x = m1.add_parameter("x", {migraphx::shape::int32_type, {1}}); auto x = m1.add_parameter("x", {migraphx::shape::int32_type, {1}});
auto zero = m1.add_literal(0); auto zero = m1.add_literal(0);
m1.add_instruction(migraphx::make_op("mul"), x, zero); auto mul_ins = m1.add_instruction(migraphx::make_op("mul"), x, zero);
m1.add_return({mul_ins});
} }
run_pass(m1); run_pass(m1);
migraphx::module m2; migraphx::module m2;
{ {
m2.add_parameter("x", {migraphx::shape::int32_type, {1}});
auto zero = m2.add_literal(0); auto zero = m2.add_literal(0);
m2.add_return({zero}); m2.add_return({zero});
} }
migraphx::module m3; migraphx::module m3;
{ {
auto zero = m3.add_literal(0); auto x = m3.add_parameter("x", {migraphx::shape::int32_type, {1}});
auto x = m3.add_parameter("x", {migraphx::shape::int32_type, {1}}); auto zero = m3.add_literal(0);
m3.add_instruction(migraphx::make_op("mul"), zero, x); auto mul_ins = m3.add_instruction(migraphx::make_op("mul"), zero, x);
m3.add_return({mul_ins});
} }
run_pass(m3); run_pass(m3);
EXPECT((m1 == m3) && (m1 == m2)); EXPECT((m1 == m2) && (m3 == m2));
} }
TEST_CASE(simplify_zero_div_const) TEST_CASE(simplify_zero_div_const)
{ {
migraphx::module m1; migraphx::module m1;
{ {
auto x = m1.add_parameter("x", {migraphx::shape::int32_type, {1}}); auto zero = m1.add_literal(0);
auto zero = m1.add_literal(0); auto x = m1.add_parameter("x", {migraphx::shape::int32_type, {1}});
m1.add_instruction(migraphx::make_op("div"), zero, x); auto div_ins = m1.add_instruction(migraphx::make_op("div"), zero, x);
m1.add_return({div_ins});
} }
run_pass(m1); run_pass(m1);
migraphx::module m2; migraphx::module m2;
{ {
auto zero = m2.add_literal(0); auto zero = m2.add_literal(0);
m2.add_parameter("x", {migraphx::shape::int32_type, {1}});
m2.add_return({zero}); m2.add_return({zero});
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment