Unverified Commit 4f3cc417 authored by Ted Themistokleous's avatar Ted Themistokleous Committed by GitHub
Browse files

Simplify unit algebraic ops (#1281)

Simplified algebraic operations (x*1), x*(-1), x/1, 0+x & x+0,  x-0, 0-x, 0*x, x*0, and 0/x operations
parent f7d987ba
......@@ -933,6 +933,73 @@ struct find_div_const
}
};
struct find_unit_ops
{
auto matcher() const
{
auto mul_1 = match::name("mul")(
match::either_arg(0, 1)(match::has_value(1.0f), match::any().bind("x")));
auto div_1 =
match::name("div")(match::args(match::any().bind("x"), match::has_value(1.0f)));
auto add_0 = match::name("add")(
match::either_arg(0, 1)(match::has_value(0.0f, 1e-12), match::any().bind("x")));
auto sub_0 =
match::name("sub")(match::args(match::any().bind("x"), match::has_value(0.0f)));
return match::any_of(mul_1, div_1, add_0, sub_0);
}
void apply(module& m, const match::matcher_result& r) const
{
auto ins = r.result;
auto c_in = r.instructions["x"];
m.replace_instruction(ins, c_in);
}
};
struct find_neg_unit_ops
{
auto matcher() const
{
auto mul_neg_1 = match::name("mul")(
match::either_arg(0, 1)(match::has_value(-1.0f), match::any().bind("x")));
auto div_neg_1 =
match::name("div")(match::args(match::any().bind("x"), match::has_value(-1.0f)));
auto sub_0 =
match::name("sub")(match::args(match::has_value(0.0f), match::any().bind("x")));
return match::any_of(mul_neg_1, div_neg_1, sub_0);
}
void apply(module& m, const match::matcher_result& r) const
{
auto ins = r.result;
auto c_in = r.instructions["x"];
auto neg = m.add_instruction(make_op("neg"), c_in);
m.replace_instruction(ins, neg);
}
};
struct find_zero_ops
{
auto matcher() const
{
auto mul_zero = match::name("mul")(
match::either_arg(0, 1)(match::has_value(0.0f).bind("x"), match::any()));
auto div_zero =
match::name("div")(match::args(match::has_value(0.0f).bind("x"), match::any()));
return match::any_of(mul_zero, div_zero);
}
void apply(module& m, const match::matcher_result& r) const
{
auto ins = r.result;
auto zero_ins = r.instructions["x"];
m.replace_instruction(ins, zero_ins);
}
};
struct find_sub_const
{
auto matcher() const
......@@ -1149,6 +1216,9 @@ void simplify_algebra::apply(module& m) const
find_mul_conv{},
find_mul_slice_conv{},
find_mul_add{},
find_unit_ops{},
find_neg_unit_ops{},
find_zero_ops{},
find_dot_add{},
find_div_const{},
find_sub_const{},
......
......@@ -122,6 +122,33 @@ TEST_CASE(simplify_add3)
EXPECT(m1 == m2);
}
TEST_CASE(simplify_zero_add_constant)
{
migraphx::module m1;
{
auto x = m1.add_parameter("x", {migraphx::shape::int32_type, {1}});
auto zero = m1.add_literal(0);
m1.add_instruction(migraphx::make_op("add"), zero, x);
}
run_pass(m1);
migraphx::module m2;
{
auto x = m2.add_parameter("x", {migraphx::shape::int32_type, {1}});
m2.add_instruction(migraphx::make_op("identity"), x);
}
migraphx::module m3;
{
auto x = m3.add_parameter("x", {migraphx::shape::int32_type, {1}});
auto zero = m3.add_literal(0);
m3.add_instruction(migraphx::make_op("add"), x, zero);
}
run_pass(m3);
EXPECT((m1 == m2) && (m2 == m3));
}
TEST_CASE(simplify_add_broadcast1)
{
migraphx::shape inner{migraphx::shape::int32_type, {2}};
......@@ -435,7 +462,7 @@ TEST_CASE(simplify_mul_add)
migraphx::module m1;
{
auto x = m1.add_parameter("x", {migraphx::shape::int32_type, {1}});
auto one = m1.add_literal(1);
auto one = m1.add_literal(3);
auto two = m1.add_literal(2);
auto sum = m1.add_instruction(migraphx::make_op("add"), one, x);
auto mul = m1.add_instruction(migraphx::make_op("mul"), sum, two);
......@@ -446,7 +473,7 @@ TEST_CASE(simplify_mul_add)
migraphx::module m2;
{
auto x = m2.add_parameter("x", {migraphx::shape::int32_type, {1}});
auto one = m2.add_literal(1);
auto one = m2.add_literal(3);
auto two = m2.add_literal(2);
auto mul1 = m2.add_instruction(migraphx::make_op("mul"), two, x);
auto mul2 = m2.add_instruction(migraphx::make_op("mul"), two, one);
......@@ -883,6 +910,341 @@ TEST_CASE(simplify_div_const)
EXPECT(m1 == m2);
}
TEST_CASE(simplify_unit_mult_const)
{
migraphx::module m1;
{
auto x = m1.add_parameter("x", {migraphx::shape::int32_type, {1}});
auto unit = m1.add_literal(1);
m1.add_instruction(migraphx::make_op("mul"), x, unit);
}
run_pass(m1);
migraphx::module m2;
{
auto x = m2.add_parameter("x", {migraphx::shape::int32_type, {1}});
m2.add_instruction(migraphx::make_op("identity"), x);
}
EXPECT(m1 == m2);
}
TEST_CASE(simplify_unit_mult_const2)
{
migraphx::module m1;
{
auto unit = m1.add_literal(1);
auto x = m1.add_parameter("x", {migraphx::shape::int32_type, {1}});
m1.add_instruction(migraphx::make_op("mul"), unit, x);
}
run_pass(m1);
migraphx::module m2;
{
auto x = m2.add_parameter("x", {migraphx::shape::int32_type, {1}});
m2.add_instruction(migraphx::make_op("identity"), x);
}
EXPECT(m1 == m2);
}
TEST_CASE(simplify_unit_mult_const_vec)
{
migraphx::shape unit_shape{migraphx::shape::int32_type, {2}};
migraphx::shape x_shape{migraphx::shape::int32_type, {1, 2, 3, 3}};
migraphx::module m1;
{
auto unit = m1.add_literal({unit_shape, {1, 1}});
auto x = m1.add_parameter("x", x_shape);
auto unitb = m1.add_instruction(
migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", {1, 2, 3, 3}}}), unit);
m1.add_instruction(migraphx::make_op("mul"), x, unitb);
}
run_pass(m1);
migraphx::module m2;
{
auto x = m2.add_parameter("x", x_shape);
m2.add_instruction(migraphx::make_op("identity"), x);
}
EXPECT(m1 == m2);
}
TEST_CASE(simplify_unit_mult_const_vec2)
{
migraphx::shape unit_shape{migraphx::shape::int32_type, {2}};
migraphx::shape x_shape{migraphx::shape::int32_type, {1, 2, 3, 3}};
migraphx::module m1;
{
auto unit = m1.add_literal({unit_shape, {1, 1}});
auto x = m1.add_parameter("x", x_shape);
auto unitb = m1.add_instruction(
migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", {1, 2, 3, 3}}}), unit);
m1.add_instruction(migraphx::make_op("mul"), unitb, x);
}
run_pass(m1);
migraphx::module m2;
{
auto x = m2.add_parameter("x", x_shape);
m2.add_instruction(migraphx::make_op("identity"), x);
}
EXPECT(m1 == m2);
}
TEST_CASE(simplify_unit_div_const)
{
migraphx::module m1;
{
auto x = m1.add_parameter("x", {migraphx::shape::int32_type, {1}});
auto unit = m1.add_literal(1);
auto div = m1.add_instruction(migraphx::make_op("div"), x, unit);
m1.add_return({div});
}
run_pass(m1);
migraphx::module m2;
{
auto x = m2.add_parameter("x", {migraphx::shape::int32_type, {1}});
m2.add_return({x});
}
EXPECT(m1 == m2);
}
TEST_CASE(simplify_unit_div_const_vec)
{
migraphx::shape unit_shape{migraphx::shape::int32_type, {2}};
migraphx::shape x_shape{migraphx::shape::int32_type, {1, 2, 3, 3}};
migraphx::module m1;
{
auto unit = m1.add_literal({unit_shape, {1, 1}});
auto x = m1.add_parameter("x", x_shape);
auto unitb = m1.add_instruction(
migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", {1, 2, 3, 3}}}), unit);
m1.add_instruction(migraphx::make_op("div"), x, unitb);
}
run_pass(m1);
migraphx::module m2;
{
auto x = m2.add_parameter("x", x_shape);
m2.add_instruction(migraphx::make_op("identity"), x);
}
EXPECT(m1 == m2);
}
TEST_CASE(simplify_neg_unit_mult_const)
{
migraphx::module m1;
{
auto x = m1.add_parameter("x", {migraphx::shape::int32_type, {1}});
auto unit = m1.add_literal(-1);
m1.add_instruction(migraphx::make_op("mul"), x, unit);
}
run_pass(m1);
migraphx::module m2;
{
auto x = m2.add_parameter("x", {migraphx::shape::int32_type, {1}});
m2.add_instruction(migraphx::make_op("neg"), x);
}
EXPECT((m1 == m2));
}
TEST_CASE(simplify_neg_unit_mult_const2)
{
migraphx::module m1;
{
auto unit = m1.add_literal(-1);
auto x = m1.add_parameter("x", {migraphx::shape::int32_type, {1}});
m1.add_instruction(migraphx::make_op("mul"), unit, x);
}
run_pass(m1);
migraphx::module m2;
{
auto x = m2.add_parameter("x", {migraphx::shape::int32_type, {1}});
m2.add_instruction(migraphx::make_op("neg"), x);
}
EXPECT((m1 == m2));
}
TEST_CASE(simplify_neg_unit_mul_const_vec)
{
migraphx::shape unit_shape{migraphx::shape::int32_type, {2}};
migraphx::shape x_shape{migraphx::shape::int32_type, {1, 2, 3, 3}};
migraphx::module m1;
{
auto unit = m1.add_literal({unit_shape, {-1, -1}});
auto x = m1.add_parameter("x", x_shape);
auto unitb = m1.add_instruction(
migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", {1, 2, 3, 3}}}), unit);
m1.add_instruction(migraphx::make_op("mul"), x, unitb);
}
run_pass(m1);
migraphx::module m2;
{
auto x = m2.add_parameter("x", x_shape);
m2.add_instruction(migraphx::make_op("neg"), x);
}
EXPECT(m1 == m2);
}
TEST_CASE(simplify_neg_unit_mul_const_vec2)
{
migraphx::shape zero_shape{migraphx::shape::int32_type, {2}};
migraphx::shape x_shape{migraphx::shape::int32_type, {1, 2, 3, 3}};
migraphx::module m1;
{
auto unit = m1.add_literal({zero_shape, {-1, -1}});
auto x = m1.add_parameter("x", x_shape);
auto unitb = m1.add_instruction(
migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", {1, 2, 3, 3}}}), unit);
m1.add_instruction(migraphx::make_op("mul"), unitb, x);
}
run_pass(m1);
migraphx::module m2;
{
auto x = m2.add_parameter("x", x_shape);
m2.add_instruction(migraphx::make_op("neg"), x);
}
EXPECT(m1 == m2);
}
TEST_CASE(simplify_neg_unit_div_const)
{
migraphx::module m1;
{
auto x = m1.add_parameter("x", {migraphx::shape::int32_type, {1}});
auto unit = m1.add_literal(-1);
m1.add_instruction(migraphx::make_op("div"), x, unit);
}
run_pass(m1);
migraphx::module m2;
{
auto x = m2.add_parameter("x", {migraphx::shape::int32_type, {1}});
m2.add_instruction(migraphx::make_op("neg"), x);
}
EXPECT(m1 == m2);
}
TEST_CASE(simplify_neg_unit_div_const_vec)
{
migraphx::shape unit_shape{migraphx::shape::int32_type, {2}};
migraphx::shape x_shape{migraphx::shape::int32_type, {1, 2, 3, 3}};
migraphx::module m1;
{
auto unit = m1.add_literal({unit_shape, {-1, -1}});
auto x = m1.add_parameter("x", x_shape);
auto unitb = m1.add_instruction(
migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", {1, 2, 3, 3}}}), unit);
m1.add_instruction(migraphx::make_op("div"), x, unitb);
}
run_pass(m1);
migraphx::module m2;
{
auto x = m2.add_parameter("x", x_shape);
m2.add_instruction(migraphx::make_op("neg"), x);
}
EXPECT(m1 == m2);
}
TEST_CASE(simplify_sub_zero_const)
{
migraphx::module m1;
{
auto x = m1.add_parameter("x", {migraphx::shape::int32_type, {1}});
auto zero = m1.add_literal(0);
m1.add_instruction(migraphx::make_op("sub"), x, zero);
}
run_pass(m1);
migraphx::module m2;
{
auto x = m2.add_parameter("x", {migraphx::shape::int32_type, {1}});
m2.add_instruction(migraphx::make_op("identity"), x);
}
EXPECT(m1 == m2);
}
TEST_CASE(simplify_sub_zero_const_vec)
{
migraphx::shape zero_shape{migraphx::shape::int32_type, {2}};
migraphx::shape x_shape{migraphx::shape::int32_type, {1, 2, 3, 3}};
migraphx::module m1;
{
auto zero = m1.add_literal({zero_shape, {0, 0}});
auto x = m1.add_parameter("x", x_shape);
auto zerob = m1.add_instruction(
migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", {1, 2, 3, 3}}}), zero);
m1.add_instruction(migraphx::make_op("sub"), x, zerob);
}
run_pass(m1);
migraphx::module m2;
{
auto x = m2.add_parameter("x", x_shape);
m2.add_instruction(migraphx::make_op("identity"), x);
}
EXPECT(m1 == m2);
}
TEST_CASE(simplify_sub_neg_zero_const)
{
migraphx::module m1;
{
auto x = m1.add_parameter("x", {migraphx::shape::int32_type, {1}});
auto zero = m1.add_literal(0);
m1.add_instruction(migraphx::make_op("sub"), zero, x);
}
run_pass(m1);
migraphx::module m2;
{
auto x = m2.add_parameter("x", {migraphx::shape::int32_type, {1}});
m2.add_instruction(migraphx::make_op("neg"), x);
}
EXPECT(m1 == m2);
}
TEST_CASE(simplify_sub_neg_zero_const_vec)
{
migraphx::shape zero_shape{migraphx::shape::int32_type, {2}};
migraphx::shape x_shape{migraphx::shape::int32_type, {1, 2, 3, 3}};
migraphx::module m1;
{
auto zero = m1.add_literal({zero_shape, {0, 0}});
auto x = m1.add_parameter("x", x_shape);
auto zerob = m1.add_instruction(
migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", {1, 2, 3, 3}}}), zero);
m1.add_instruction(migraphx::make_op("sub"), zerob, x);
}
run_pass(m1);
migraphx::module m2;
{
auto x = m2.add_parameter("x", x_shape);
m2.add_instruction(migraphx::make_op("neg"), x);
}
EXPECT(m1 == m2);
}
TEST_CASE(simplify_sub_const)
{
migraphx::module m1;
......@@ -903,6 +1265,150 @@ TEST_CASE(simplify_sub_const)
EXPECT(m1 == m2);
}
TEST_CASE(simplify_zero_mult_const)
{
migraphx::module m1;
{
auto x = m1.add_parameter("x", {migraphx::shape::int32_type, {1}});
auto zero = m1.add_literal(0);
auto mul_ins = m1.add_instruction(migraphx::make_op("mul"), x, zero);
m1.add_return({mul_ins});
}
run_pass(m1);
migraphx::module m2;
{
m2.add_parameter("x", {migraphx::shape::int32_type, {1}});
auto zero = m2.add_literal(0);
m2.add_return({zero});
}
EXPECT(m1 == m2);
}
TEST_CASE(simplify_zero_mult_const2)
{
migraphx::module m1;
{
auto x = m1.add_parameter("x", {migraphx::shape::int32_type, {1}});
auto zero = m1.add_literal(0);
auto mul_ins = m1.add_instruction(migraphx::make_op("mul"), zero, x);
m1.add_return({mul_ins});
}
run_pass(m1);
migraphx::module m2;
{
m2.add_parameter("x", {migraphx::shape::int32_type, {1}});
auto zero = m2.add_literal(0);
m2.add_return({zero});
}
EXPECT(m1 == m2);
}
TEST_CASE(simplify_zero_mul_const_vec)
{
migraphx::shape zero_shape{migraphx::shape::int32_type, {2}};
migraphx::shape x_shape{migraphx::shape::int32_type, {1, 2, 3, 3}};
migraphx::module m1;
{
auto zero = m1.add_literal({zero_shape, {0, 0}});
auto x = m1.add_parameter("x", x_shape);
auto zerob = m1.add_instruction(
migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", {1, 2, 3, 3}}}), zero);
auto mul_ins = m1.add_instruction(migraphx::make_op("mul"), x, zerob);
m1.add_return({mul_ins});
}
run_pass(m1);
migraphx::module m2;
{
auto zero = m2.add_literal({zero_shape, {0, 0}});
m2.add_parameter("x", x_shape);
auto zerob = m2.add_instruction(
migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", {1, 2, 3, 3}}}), zero);
m2.add_return({zerob});
}
EXPECT(m1 == m2);
}
TEST_CASE(simplify_zero_mul_const_vec2)
{
migraphx::shape zero_shape{migraphx::shape::int32_type, {2}};
migraphx::shape x_shape{migraphx::shape::int32_type, {1, 2, 3, 3}};
migraphx::module m1;
{
auto zero = m1.add_literal({zero_shape, {0, 0}});
auto x = m1.add_parameter("x", x_shape);
auto zerob = m1.add_instruction(
migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", {1, 2, 3, 3}}}), zero);
auto mul_ins = m1.add_instruction(migraphx::make_op("mul"), zerob, x);
m1.add_return({mul_ins});
}
run_pass(m1);
migraphx::module m2;
{
auto zero = m2.add_literal({zero_shape, {0, 0}});
m2.add_parameter("x", x_shape);
auto zerob = m2.add_instruction(
migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", {1, 2, 3, 3}}}), zero);
m2.add_return({zerob});
}
EXPECT(m1 == m2);
}
TEST_CASE(simplify_zero_div_const)
{
migraphx::module m1;
{
auto zero = m1.add_literal(0);
auto x = m1.add_parameter("x", {migraphx::shape::int32_type, {1}});
auto div_ins = m1.add_instruction(migraphx::make_op("div"), zero, x);
m1.add_return({div_ins});
}
run_pass(m1);
migraphx::module m2;
{
auto zero = m2.add_literal(0);
m2.add_parameter("x", {migraphx::shape::int32_type, {1}});
m2.add_return({zero});
}
EXPECT(m1 == m2);
}
TEST_CASE(simplify_zero_div_const_vec)
{
migraphx::shape zero_shape{migraphx::shape::int32_type, {2}};
migraphx::shape x_shape{migraphx::shape::int32_type, {1, 2, 3, 3}};
migraphx::module m1;
{
auto x = m1.add_parameter("x", x_shape);
auto zero = m1.add_literal({zero_shape, {0, 0}});
auto zerob = m1.add_instruction(
migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", {1, 2, 3, 3}}}), zero);
auto div_ins = m1.add_instruction(migraphx::make_op("div"), zerob, x);
m1.add_return({div_ins});
}
run_pass(m1);
migraphx::module m2;
{
m2.add_parameter("x", x_shape);
auto zero = m2.add_literal({zero_shape, {0, 0}});
auto zerob = m2.add_instruction(
migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", {1, 2, 3, 3}}}), zero);
m2.add_return({zerob});
}
EXPECT(m1 == m2);
}
TEST_CASE(simplify_rsqrt)
{
migraphx::module m1;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment