Commit d6764b7d authored by Ted Themistokleous's avatar Ted Themistokleous
Browse files

fixup! Simplify algebra for 0*x, x*0 and 0/x operations

parent 41e3b9f6
......@@ -41,8 +41,6 @@
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_TRACE_PASSES);
auto lit_broadcast() { return match::any_of(match::is_constant(), match::name("broadcast")); }
auto not_lit_broadcast() { return match::none_of(match::is_constant(), match::name("broadcast")); }
auto op_lit_broadcast(std::string op, std::string x, std::string y)
......@@ -859,12 +857,12 @@ struct find_zero_div_const
void apply [[noreturn]] (const module& m, const match::matcher_result& r) const
{
if(enabled(MIGRAPHX_TRACE_PASSES{}))
{
m.debug_print();
std::cout << "ERROR:DIV_BY_ZERO: ";
m.debug_print(r.result);
}
// if(enabled(MIGRAPHX_TRACE_MATCHES{}))
//{
// m.debug_print();
// std::cout << "ERROR:DIV_BY_ZERO: ";
// m.debug_print(r.result);
//}
MIGRAPHX_THROW("ERROR: Matched division by zero in pass");
}
};
......@@ -921,21 +919,18 @@ struct find_zero_ops
auto matcher() const
{
auto mul_zero = match::name("mul")(
match::either_arg(0, 1)(match::has_value(0.0f), match::any().bind("x")));
match::either_arg(0, 1)(match::has_value(0.0f).bind("x"), match::any()));
auto div_zero =
match::name("div")(match::args(match::has_value(0.0f), match::any().bind("x")));
match::name("div")(match::args(match::has_value(0.0f).bind("x"), match::any()));
return match::any_of(mul_zero, div_zero);
}
void apply(module& m, const match::matcher_result& r) const
{
auto ins = r.result;
auto x_ins = r.instructions["x"];
auto zero = m.add_literal(0);
auto ret = m.add_return({zero});
auto ins = r.result;
auto zero_ins = r.instructions["x"];
m.remove_instruction((x_ins));
m.replace_instruction(ins, ret);
m.replace_instruction(ins, zero_ins);
}
};
......
......@@ -936,42 +936,47 @@ TEST_CASE(simplify_zero_mult_const)
{
migraphx::module m1;
{
auto x = m1.add_parameter("x", {migraphx::shape::int32_type, {1}});
auto zero = m1.add_literal(0);
m1.add_instruction(migraphx::make_op("mul"), x, zero);
auto x = m1.add_parameter("x", {migraphx::shape::int32_type, {1}});
auto zero = m1.add_literal(0);
auto mul_ins = m1.add_instruction(migraphx::make_op("mul"), x, zero);
m1.add_return({mul_ins});
}
run_pass(m1);
migraphx::module m2;
{
m2.add_parameter("x", {migraphx::shape::int32_type, {1}});
auto zero = m2.add_literal(0);
m2.add_return({zero});
}
migraphx::module m3;
{
auto zero = m3.add_literal(0);
auto x = m3.add_parameter("x", {migraphx::shape::int32_type, {1}});
m3.add_instruction(migraphx::make_op("mul"), zero, x);
auto x = m3.add_parameter("x", {migraphx::shape::int32_type, {1}});
auto zero = m3.add_literal(0);
auto mul_ins = m3.add_instruction(migraphx::make_op("mul"), zero, x);
m3.add_return({mul_ins});
}
run_pass(m3);
EXPECT((m1 == m3) && (m1 == m2));
EXPECT((m1 == m2) && (m3 == m2));
}
TEST_CASE(simplify_zero_div_const)
{
migraphx::module m1;
{
auto x = m1.add_parameter("x", {migraphx::shape::int32_type, {1}});
auto zero = m1.add_literal(0);
m1.add_instruction(migraphx::make_op("div"), zero, x);
auto zero = m1.add_literal(0);
auto x = m1.add_parameter("x", {migraphx::shape::int32_type, {1}});
auto div_ins = m1.add_instruction(migraphx::make_op("div"), zero, x);
m1.add_return({div_ins});
}
run_pass(m1);
migraphx::module m2;
{
auto zero = m2.add_literal(0);
m2.add_parameter("x", {migraphx::shape::int32_type, {1}});
m2.add_return({zero});
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment