Commit 728fe848 authored by Ted Themistokleous's avatar Ted Themistokleous
Browse files

Add in additional vector tests for simplify algebra changes

parent 47dbf164
...@@ -799,6 +799,52 @@ TEST_CASE(simplify_unit_mult_const2) ...@@ -799,6 +799,52 @@ TEST_CASE(simplify_unit_mult_const2)
EXPECT(m1 == m2); EXPECT(m1 == m2);
} }
TEST_CASE(simplify_unit_mult_const_vec)
{
migraphx::shape inner{migraphx::shape::int32_type, {2}};
migraphx::shape outer{migraphx::shape::int32_type, {1, 2, 3, 3}};
migraphx::op::broadcast b{1, {1, 2, 3, 3}};
migraphx::module m1;
{
auto unit = m1.add_literal({inner, {1, 1}});
auto x = m1.add_parameter("x", outer);
auto unitb = m1.add_instruction(b, unit);
m1.add_instruction(migraphx::make_op("mul"), x, unitb);
}
run_pass(m1);
migraphx::module m2;
{
auto x = m2.add_parameter("x", outer);
m2.add_instruction(migraphx::make_op("identity"), x);
}
EXPECT(m1 == m2);
}
TEST_CASE(simplify_unit_mult_const_vec2)
{
migraphx::shape inner{migraphx::shape::int32_type, {2}};
migraphx::shape outer{migraphx::shape::int32_type, {1, 2, 3, 3}};
migraphx::op::broadcast b{1, {1, 2, 3, 3}};
migraphx::module m1;
{
auto unit = m1.add_literal({inner, {1, 1}});
auto x = m1.add_parameter("x", outer);
auto unitb = m1.add_instruction(b, unit);
m1.add_instruction(migraphx::make_op("mul"), unitb, x);
}
run_pass(m1);
migraphx::module m2;
{
auto x = m2.add_parameter("x", outer);
m2.add_instruction(migraphx::make_op("identity"), x);
}
EXPECT(m1 == m2);
}
TEST_CASE(simplify_unit_div_const) TEST_CASE(simplify_unit_div_const)
{ {
migraphx::module m1; migraphx::module m1;
...@@ -819,6 +865,29 @@ TEST_CASE(simplify_unit_div_const) ...@@ -819,6 +865,29 @@ TEST_CASE(simplify_unit_div_const)
EXPECT(m1 == m2); EXPECT(m1 == m2);
} }
TEST_CASE(simplify_unit_div_const_vec)
{
migraphx::shape inner{migraphx::shape::int32_type, {2}};
migraphx::shape outer{migraphx::shape::int32_type, {1, 2, 3, 3}};
migraphx::op::broadcast b{1, {1, 2, 3, 3}};
migraphx::module m1;
{
auto unit = m1.add_literal({inner, {1, 1}});
auto x = m1.add_parameter("x", outer);
auto unitb = m1.add_instruction(b, unit);
m1.add_instruction(migraphx::make_op("div"), x, unitb);
}
run_pass(m1);
migraphx::module m2;
{
auto x = m2.add_parameter("x", outer);
m2.add_instruction(migraphx::make_op("identity"), x);
}
EXPECT(m1 == m2);
}
TEST_CASE(simplify_neg_unit_mult_const) TEST_CASE(simplify_neg_unit_mult_const)
{ {
migraphx::module m1; migraphx::module m1;
...@@ -857,6 +926,52 @@ TEST_CASE(simplify_neg_unit_mult_const2) ...@@ -857,6 +926,52 @@ TEST_CASE(simplify_neg_unit_mult_const2)
EXPECT((m1 == m2)); EXPECT((m1 == m2));
} }
TEST_CASE(simplify_neg_unit_mul_const_vec)
{
migraphx::shape inner{migraphx::shape::int32_type, {2}};
migraphx::shape outer{migraphx::shape::int32_type, {1, 2, 3, 3}};
migraphx::op::broadcast b{1, {1, 2, 3, 3}};
migraphx::module m1;
{
auto unit = m1.add_literal({inner, {-1, -1}});
auto x = m1.add_parameter("x", outer);
auto unitb = m1.add_instruction(b, unit);
m1.add_instruction(migraphx::make_op("mul"), x, unitb);
}
run_pass(m1);
migraphx::module m2;
{
auto x = m2.add_parameter("x", outer);
m2.add_instruction(migraphx::make_op("neg"), x);
}
EXPECT(m1 == m2);
}
TEST_CASE(simplify_neg_unit_mul_const_vec2)
{
migraphx::shape inner{migraphx::shape::int32_type, {2}};
migraphx::shape outer{migraphx::shape::int32_type, {1, 2, 3, 3}};
migraphx::op::broadcast b{1, {1, 2, 3, 3}};
migraphx::module m1;
{
auto unit = m1.add_literal({inner, {-1, -1}});
auto x = m1.add_parameter("x", outer);
auto unitb = m1.add_instruction(b, unit);
m1.add_instruction(migraphx::make_op("mul"), unitb, x);
}
run_pass(m1);
migraphx::module m2;
{
auto x = m2.add_parameter("x", outer);
m2.add_instruction(migraphx::make_op("neg"), x);
}
EXPECT(m1 == m2);
}
TEST_CASE(simplify_neg_unit_div_const) TEST_CASE(simplify_neg_unit_div_const)
{ {
migraphx::module m1; migraphx::module m1;
...@@ -876,6 +991,29 @@ TEST_CASE(simplify_neg_unit_div_const) ...@@ -876,6 +991,29 @@ TEST_CASE(simplify_neg_unit_div_const)
EXPECT(m1 == m2); EXPECT(m1 == m2);
} }
TEST_CASE(simplify_neg_unit_div_const_vec)
{
migraphx::shape inner{migraphx::shape::int32_type, {2}};
migraphx::shape outer{migraphx::shape::int32_type, {1, 2, 3, 3}};
migraphx::op::broadcast b{1, {1, 2, 3, 3}};
migraphx::module m1;
{
auto unit = m1.add_literal({inner, {-1, -1}});
auto x = m1.add_parameter("x", outer);
auto unitb = m1.add_instruction(b, unit);
m1.add_instruction(migraphx::make_op("div"), x, unitb);
}
run_pass(m1);
migraphx::module m2;
{
auto x = m2.add_parameter("x", outer);
m2.add_instruction(migraphx::make_op("neg"), x);
}
EXPECT(m1 == m2);
}
TEST_CASE(simplify_sub_zero_const) TEST_CASE(simplify_sub_zero_const)
{ {
migraphx::module m1; migraphx::module m1;
...@@ -894,6 +1032,29 @@ TEST_CASE(simplify_sub_zero_const) ...@@ -894,6 +1032,29 @@ TEST_CASE(simplify_sub_zero_const)
EXPECT(m1 == m2); EXPECT(m1 == m2);
} }
TEST_CASE(simplify_sub_zero_const_vec)
{
migraphx::shape inner{migraphx::shape::int32_type, {2}};
migraphx::shape outer{migraphx::shape::int32_type, {1, 2, 3, 3}};
migraphx::op::broadcast b{1, {1, 2, 3, 3}};
migraphx::module m1;
{
auto zero = m1.add_literal({inner, {0, 0}});
auto x = m1.add_parameter("x", outer);
auto zerob = m1.add_instruction(b, zero);
m1.add_instruction(migraphx::make_op("sub"), x, zerob);
}
run_pass(m1);
migraphx::module m2;
{
auto x = m2.add_parameter("x", outer);
m2.add_instruction(migraphx::make_op("identity"), x);
}
EXPECT(m1 == m2);
}
TEST_CASE(simplify_sub_neg_zero_const) TEST_CASE(simplify_sub_neg_zero_const)
{ {
migraphx::module m1; migraphx::module m1;
...@@ -912,6 +1073,29 @@ TEST_CASE(simplify_sub_neg_zero_const) ...@@ -912,6 +1073,29 @@ TEST_CASE(simplify_sub_neg_zero_const)
EXPECT(m1 == m2); EXPECT(m1 == m2);
} }
TEST_CASE(simplify_sub_neg_zero_const_vec)
{
migraphx::shape inner{migraphx::shape::int32_type, {2}};
migraphx::shape outer{migraphx::shape::int32_type, {1, 2, 3, 3}};
migraphx::op::broadcast b{1, {1, 2, 3, 3}};
migraphx::module m1;
{
auto zero = m1.add_literal({inner, {0, 0}});
auto x = m1.add_parameter("x", outer);
auto zerob = m1.add_instruction(b, zero);
m1.add_instruction(migraphx::make_op("sub"), zerob, x);
}
run_pass(m1);
migraphx::module m2;
{
auto x = m2.add_parameter("x", outer);
m2.add_instruction(migraphx::make_op("neg"), x);
}
EXPECT(m1 == m2);
}
TEST_CASE(simplify_div_zero_const) TEST_CASE(simplify_div_zero_const)
{ {
migraphx::module m1; migraphx::module m1;
...@@ -996,6 +1180,58 @@ TEST_CASE(simplify_zero_mult_const2) ...@@ -996,6 +1180,58 @@ TEST_CASE(simplify_zero_mult_const2)
EXPECT(m1 == m2); EXPECT(m1 == m2);
} }
TEST_CASE(simplify_zero_mul_const_vec)
{
migraphx::shape inner{migraphx::shape::int32_type, {2}};
migraphx::shape outer{migraphx::shape::int32_type, {1, 2, 3, 3}};
migraphx::op::broadcast b{1, {1, 2, 3, 3}};
migraphx::module m1;
{
auto zero = m1.add_literal({inner, {0, 0}});
auto x = m1.add_parameter("x", outer);
auto zerob = m1.add_instruction(b, zero);
auto mul_ins = m1.add_instruction(migraphx::make_op("mul"), x, zerob);
m1.add_return({mul_ins});
}
run_pass(m1);
migraphx::module m2;
{
auto zero = m2.add_literal({inner, {0, 0}});
m2.add_parameter("x", outer);
auto zerob = m2.add_instruction(b, zero);
m2.add_return({zerob});
}
EXPECT(m1 == m2);
}
TEST_CASE(simplify_zero_mul_const_vec2)
{
migraphx::shape inner{migraphx::shape::int32_type, {2}};
migraphx::shape outer{migraphx::shape::int32_type, {1, 2, 3, 3}};
migraphx::op::broadcast b{1, {1, 2, 3, 3}};
migraphx::module m1;
{
auto zero = m1.add_literal({inner, {0, 0}});
auto x = m1.add_parameter("x", outer);
auto zerob = m1.add_instruction(b, zero);
auto mul_ins = m1.add_instruction(migraphx::make_op("mul"), zerob, x);
m1.add_return({mul_ins});
}
run_pass(m1);
migraphx::module m2;
{
auto zero = m2.add_literal({inner, {0, 0}});
m2.add_parameter("x", outer);
auto zerob = m2.add_instruction(b, zero);
m2.add_return({zerob});
}
EXPECT(m1 == m2);
}
TEST_CASE(simplify_zero_div_const) TEST_CASE(simplify_zero_div_const)
{ {
migraphx::module m1; migraphx::module m1;
...@@ -1017,6 +1253,32 @@ TEST_CASE(simplify_zero_div_const) ...@@ -1017,6 +1253,32 @@ TEST_CASE(simplify_zero_div_const)
EXPECT(m1 == m2); EXPECT(m1 == m2);
} }
TEST_CASE(simplify_zero_div_const_vec)
{
migraphx::shape inner{migraphx::shape::int32_type, {2}};
migraphx::shape outer{migraphx::shape::int32_type, {1, 2, 3, 3}};
migraphx::op::broadcast b{1, {1, 2, 3, 3}};
migraphx::module m1;
{
auto x = m1.add_parameter("x", outer);
auto zero = m1.add_literal({inner, {0, 0}});
auto zerob = m1.add_instruction(b, zero);
auto div_ins = m1.add_instruction(migraphx::make_op("div"), zerob, x);
m1.add_return({div_ins});
}
run_pass(m1);
migraphx::module m2;
{
m2.add_parameter("x", outer);
auto zero = m2.add_literal({inner, {0, 0}});
auto zerob = m2.add_instruction(b, zero);
m2.add_return({zerob});
}
EXPECT(m1 == m2);
}
TEST_CASE(simplify_rsqrt) TEST_CASE(simplify_rsqrt)
{ {
migraphx::module m1; migraphx::module m1;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment