Commit 9310bff0 authored by Ted Themistokleous's avatar Ted Themistokleous
Browse files

Add proper broadcast operands to simplify algebra zero tests

parent f76d7970
...@@ -851,21 +851,21 @@ TEST_CASE(simplify_unit_mult_const2) ...@@ -851,21 +851,21 @@ TEST_CASE(simplify_unit_mult_const2)
TEST_CASE(simplify_unit_mult_const_vec) TEST_CASE(simplify_unit_mult_const_vec)
{ {
migraphx::shape inner{migraphx::shape::int32_type, {2}}; migraphx::shape unit_shape{migraphx::shape::int32_type, {2}};
migraphx::shape outer{migraphx::shape::int32_type, {1, 2, 3, 3}}; migraphx::shape x_shape{migraphx::shape::int32_type, {1, 2, 3, 3}};
migraphx::op::broadcast b{1, {1, 2, 3, 3}};
migraphx::module m1; migraphx::module m1;
{ {
auto unit = m1.add_literal({inner, {1, 1}}); auto unit = m1.add_literal({unit_shape, {1, 1}});
auto x = m1.add_parameter("x", outer); auto x = m1.add_parameter("x", x_shape);
auto unitb = m1.add_instruction(b, unit); auto unitb = m1.add_instruction(
migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", {1, 2, 3, 3}}}), unit);
m1.add_instruction(migraphx::make_op("mul"), x, unitb); m1.add_instruction(migraphx::make_op("mul"), x, unitb);
} }
run_pass(m1); run_pass(m1);
migraphx::module m2; migraphx::module m2;
{ {
auto x = m2.add_parameter("x", outer); auto x = m2.add_parameter("x", x_shape);
m2.add_instruction(migraphx::make_op("identity"), x); m2.add_instruction(migraphx::make_op("identity"), x);
} }
...@@ -874,21 +874,21 @@ TEST_CASE(simplify_unit_mult_const_vec) ...@@ -874,21 +874,21 @@ TEST_CASE(simplify_unit_mult_const_vec)
TEST_CASE(simplify_unit_mult_const_vec2) TEST_CASE(simplify_unit_mult_const_vec2)
{ {
migraphx::shape inner{migraphx::shape::int32_type, {2}}; migraphx::shape unit_shape{migraphx::shape::int32_type, {2}};
migraphx::shape outer{migraphx::shape::int32_type, {1, 2, 3, 3}}; migraphx::shape x_shape{migraphx::shape::int32_type, {1, 2, 3, 3}};
migraphx::op::broadcast b{1, {1, 2, 3, 3}};
migraphx::module m1; migraphx::module m1;
{ {
auto unit = m1.add_literal({inner, {1, 1}}); auto unit = m1.add_literal({unit_shape, {1, 1}});
auto x = m1.add_parameter("x", outer); auto x = m1.add_parameter("x", x_shape);
auto unitb = m1.add_instruction(b, unit); auto unitb = m1.add_instruction(
migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", {1, 2, 3, 3}}}), unit);
m1.add_instruction(migraphx::make_op("mul"), unitb, x); m1.add_instruction(migraphx::make_op("mul"), unitb, x);
} }
run_pass(m1); run_pass(m1);
migraphx::module m2; migraphx::module m2;
{ {
auto x = m2.add_parameter("x", outer); auto x = m2.add_parameter("x", x_shape);
m2.add_instruction(migraphx::make_op("identity"), x); m2.add_instruction(migraphx::make_op("identity"), x);
} }
...@@ -917,21 +917,21 @@ TEST_CASE(simplify_unit_div_const) ...@@ -917,21 +917,21 @@ TEST_CASE(simplify_unit_div_const)
TEST_CASE(simplify_unit_div_const_vec) TEST_CASE(simplify_unit_div_const_vec)
{ {
migraphx::shape inner{migraphx::shape::int32_type, {2}}; migraphx::shape unit_shape{migraphx::shape::int32_type, {2}};
migraphx::shape outer{migraphx::shape::int32_type, {1, 2, 3, 3}}; migraphx::shape x_shape{migraphx::shape::int32_type, {1, 2, 3, 3}};
migraphx::op::broadcast b{1, {1, 2, 3, 3}};
migraphx::module m1; migraphx::module m1;
{ {
auto unit = m1.add_literal({inner, {1, 1}}); auto unit = m1.add_literal({unit_shape, {1, 1}});
auto x = m1.add_parameter("x", outer); auto x = m1.add_parameter("x", x_shape);
auto unitb = m1.add_instruction(b, unit); auto unitb = m1.add_instruction(
migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", {1, 2, 3, 3}}}), unit);
m1.add_instruction(migraphx::make_op("div"), x, unitb); m1.add_instruction(migraphx::make_op("div"), x, unitb);
} }
run_pass(m1); run_pass(m1);
migraphx::module m2; migraphx::module m2;
{ {
auto x = m2.add_parameter("x", outer); auto x = m2.add_parameter("x", x_shape);
m2.add_instruction(migraphx::make_op("identity"), x); m2.add_instruction(migraphx::make_op("identity"), x);
} }
...@@ -978,21 +978,21 @@ TEST_CASE(simplify_neg_unit_mult_const2) ...@@ -978,21 +978,21 @@ TEST_CASE(simplify_neg_unit_mult_const2)
TEST_CASE(simplify_neg_unit_mul_const_vec) TEST_CASE(simplify_neg_unit_mul_const_vec)
{ {
migraphx::shape inner{migraphx::shape::int32_type, {2}}; migraphx::shape unit_shape{migraphx::shape::int32_type, {2}};
migraphx::shape outer{migraphx::shape::int32_type, {1, 2, 3, 3}}; migraphx::shape x_shape{migraphx::shape::int32_type, {1, 2, 3, 3}};
migraphx::op::broadcast b{1, {1, 2, 3, 3}};
migraphx::module m1; migraphx::module m1;
{ {
auto unit = m1.add_literal({inner, {-1, -1}}); auto unit = m1.add_literal({unit_shape, {-1, -1}});
auto x = m1.add_parameter("x", outer); auto x = m1.add_parameter("x", x_shape);
auto unitb = m1.add_instruction(b, unit); auto unitb = m1.add_instruction(
migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", {1, 2, 3, 3}}}), unit);
m1.add_instruction(migraphx::make_op("mul"), x, unitb); m1.add_instruction(migraphx::make_op("mul"), x, unitb);
} }
run_pass(m1); run_pass(m1);
migraphx::module m2; migraphx::module m2;
{ {
auto x = m2.add_parameter("x", outer); auto x = m2.add_parameter("x", x_shape);
m2.add_instruction(migraphx::make_op("neg"), x); m2.add_instruction(migraphx::make_op("neg"), x);
} }
...@@ -1001,21 +1001,21 @@ TEST_CASE(simplify_neg_unit_mul_const_vec) ...@@ -1001,21 +1001,21 @@ TEST_CASE(simplify_neg_unit_mul_const_vec)
TEST_CASE(simplify_neg_unit_mul_const_vec2) TEST_CASE(simplify_neg_unit_mul_const_vec2)
{ {
migraphx::shape inner{migraphx::shape::int32_type, {2}}; migraphx::shape zero_shape{migraphx::shape::int32_type, {2}};
migraphx::shape outer{migraphx::shape::int32_type, {1, 2, 3, 3}}; migraphx::shape x_shape{migraphx::shape::int32_type, {1, 2, 3, 3}};
migraphx::op::broadcast b{1, {1, 2, 3, 3}};
migraphx::module m1; migraphx::module m1;
{ {
auto unit = m1.add_literal({inner, {-1, -1}}); auto unit = m1.add_literal({zero_shape, {-1, -1}});
auto x = m1.add_parameter("x", outer); auto x = m1.add_parameter("x", x_shape);
auto unitb = m1.add_instruction(b, unit); auto unitb = m1.add_instruction(
migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", {1, 2, 3, 3}}}), unit);
m1.add_instruction(migraphx::make_op("mul"), unitb, x); m1.add_instruction(migraphx::make_op("mul"), unitb, x);
} }
run_pass(m1); run_pass(m1);
migraphx::module m2; migraphx::module m2;
{ {
auto x = m2.add_parameter("x", outer); auto x = m2.add_parameter("x", x_shape);
m2.add_instruction(migraphx::make_op("neg"), x); m2.add_instruction(migraphx::make_op("neg"), x);
} }
...@@ -1043,21 +1043,21 @@ TEST_CASE(simplify_neg_unit_div_const) ...@@ -1043,21 +1043,21 @@ TEST_CASE(simplify_neg_unit_div_const)
TEST_CASE(simplify_neg_unit_div_const_vec) TEST_CASE(simplify_neg_unit_div_const_vec)
{ {
migraphx::shape inner{migraphx::shape::int32_type, {2}}; migraphx::shape unit_shape{migraphx::shape::int32_type, {2}};
migraphx::shape outer{migraphx::shape::int32_type, {1, 2, 3, 3}}; migraphx::shape x_shape{migraphx::shape::int32_type, {1, 2, 3, 3}};
migraphx::op::broadcast b{1, {1, 2, 3, 3}};
migraphx::module m1; migraphx::module m1;
{ {
auto unit = m1.add_literal({inner, {-1, -1}}); auto unit = m1.add_literal({unit_shape, {-1, -1}});
auto x = m1.add_parameter("x", outer); auto x = m1.add_parameter("x", x_shape);
auto unitb = m1.add_instruction(b, unit); auto unitb = m1.add_instruction(
migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", {1, 2, 3, 3}}}), unit);
m1.add_instruction(migraphx::make_op("div"), x, unitb); m1.add_instruction(migraphx::make_op("div"), x, unitb);
} }
run_pass(m1); run_pass(m1);
migraphx::module m2; migraphx::module m2;
{ {
auto x = m2.add_parameter("x", outer); auto x = m2.add_parameter("x", x_shape);
m2.add_instruction(migraphx::make_op("neg"), x); m2.add_instruction(migraphx::make_op("neg"), x);
} }
...@@ -1084,21 +1084,21 @@ TEST_CASE(simplify_sub_zero_const) ...@@ -1084,21 +1084,21 @@ TEST_CASE(simplify_sub_zero_const)
TEST_CASE(simplify_sub_zero_const_vec) TEST_CASE(simplify_sub_zero_const_vec)
{ {
migraphx::shape inner{migraphx::shape::int32_type, {2}}; migraphx::shape zero_shape{migraphx::shape::int32_type, {2}};
migraphx::shape outer{migraphx::shape::int32_type, {1, 2, 3, 3}}; migraphx::shape x_shape{migraphx::shape::int32_type, {1, 2, 3, 3}};
migraphx::op::broadcast b{1, {1, 2, 3, 3}};
migraphx::module m1; migraphx::module m1;
{ {
auto zero = m1.add_literal({inner, {0, 0}}); auto zero = m1.add_literal({zero_shape, {0, 0}});
auto x = m1.add_parameter("x", outer); auto x = m1.add_parameter("x", x_shape);
auto zerob = m1.add_instruction(b, zero); auto zerob = m1.add_instruction(
migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", {1, 2, 3, 3}}}), zero);
m1.add_instruction(migraphx::make_op("sub"), x, zerob); m1.add_instruction(migraphx::make_op("sub"), x, zerob);
} }
run_pass(m1); run_pass(m1);
migraphx::module m2; migraphx::module m2;
{ {
auto x = m2.add_parameter("x", outer); auto x = m2.add_parameter("x", x_shape);
m2.add_instruction(migraphx::make_op("identity"), x); m2.add_instruction(migraphx::make_op("identity"), x);
} }
...@@ -1125,21 +1125,21 @@ TEST_CASE(simplify_sub_neg_zero_const) ...@@ -1125,21 +1125,21 @@ TEST_CASE(simplify_sub_neg_zero_const)
TEST_CASE(simplify_sub_neg_zero_const_vec) TEST_CASE(simplify_sub_neg_zero_const_vec)
{ {
migraphx::shape inner{migraphx::shape::int32_type, {2}}; migraphx::shape zero_shape{migraphx::shape::int32_type, {2}};
migraphx::shape outer{migraphx::shape::int32_type, {1, 2, 3, 3}}; migraphx::shape x_shape{migraphx::shape::int32_type, {1, 2, 3, 3}};
migraphx::op::broadcast b{1, {1, 2, 3, 3}};
migraphx::module m1; migraphx::module m1;
{ {
auto zero = m1.add_literal({inner, {0, 0}}); auto zero = m1.add_literal({zero_shape, {0, 0}});
auto x = m1.add_parameter("x", outer); auto x = m1.add_parameter("x", x_shape);
auto zerob = m1.add_instruction(b, zero); auto zerob = m1.add_instruction(
migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", {1, 2, 3, 3}}}), zero);
m1.add_instruction(migraphx::make_op("sub"), zerob, x); m1.add_instruction(migraphx::make_op("sub"), zerob, x);
} }
run_pass(m1); run_pass(m1);
migraphx::module m2; migraphx::module m2;
{ {
auto x = m2.add_parameter("x", outer); auto x = m2.add_parameter("x", x_shape);
m2.add_instruction(migraphx::make_op("neg"), x); m2.add_instruction(migraphx::make_op("neg"), x);
} }
...@@ -1210,14 +1210,14 @@ TEST_CASE(simplify_zero_mult_const2) ...@@ -1210,14 +1210,14 @@ TEST_CASE(simplify_zero_mult_const2)
TEST_CASE(simplify_zero_mul_const_vec) TEST_CASE(simplify_zero_mul_const_vec)
{ {
migraphx::shape inner{migraphx::shape::int32_type, {2}}; migraphx::shape zero_shape{migraphx::shape::int32_type, {2}};
migraphx::shape outer{migraphx::shape::int32_type, {1, 2, 3, 3}}; migraphx::shape x_shape{migraphx::shape::int32_type, {1, 2, 3, 3}};
migraphx::op::broadcast b{1, {1, 2, 3, 3}};
migraphx::module m1; migraphx::module m1;
{ {
auto zero = m1.add_literal({inner, {0, 0}}); auto zero = m1.add_literal({zero_shape, {0, 0}});
auto x = m1.add_parameter("x", outer); auto x = m1.add_parameter("x", x_shape);
auto zerob = m1.add_instruction(b, zero); auto zerob = m1.add_instruction(
migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", {1, 2, 3, 3}}}), zero);
auto mul_ins = m1.add_instruction(migraphx::make_op("mul"), x, zerob); auto mul_ins = m1.add_instruction(migraphx::make_op("mul"), x, zerob);
m1.add_return({mul_ins}); m1.add_return({mul_ins});
} }
...@@ -1225,9 +1225,10 @@ TEST_CASE(simplify_zero_mul_const_vec) ...@@ -1225,9 +1225,10 @@ TEST_CASE(simplify_zero_mul_const_vec)
migraphx::module m2; migraphx::module m2;
{ {
auto zero = m2.add_literal({inner, {0, 0}}); auto zero = m2.add_literal({zero_shape, {0, 0}});
m2.add_parameter("x", outer); m2.add_parameter("x", x_shape);
auto zerob = m2.add_instruction(b, zero); auto zerob = m2.add_instruction(
migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", {1, 2, 3, 3}}}), zero);
m2.add_return({zerob}); m2.add_return({zerob});
} }
...@@ -1236,14 +1237,14 @@ TEST_CASE(simplify_zero_mul_const_vec) ...@@ -1236,14 +1237,14 @@ TEST_CASE(simplify_zero_mul_const_vec)
TEST_CASE(simplify_zero_mul_const_vec2) TEST_CASE(simplify_zero_mul_const_vec2)
{ {
migraphx::shape inner{migraphx::shape::int32_type, {2}}; migraphx::shape zero_shape{migraphx::shape::int32_type, {2}};
migraphx::shape outer{migraphx::shape::int32_type, {1, 2, 3, 3}}; migraphx::shape x_shape{migraphx::shape::int32_type, {1, 2, 3, 3}};
migraphx::op::broadcast b{1, {1, 2, 3, 3}};
migraphx::module m1; migraphx::module m1;
{ {
auto zero = m1.add_literal({inner, {0, 0}}); auto zero = m1.add_literal({zero_shape, {0, 0}});
auto x = m1.add_parameter("x", outer); auto x = m1.add_parameter("x", x_shape);
auto zerob = m1.add_instruction(b, zero); auto zerob = m1.add_instruction(
migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", {1, 2, 3, 3}}}), zero);
auto mul_ins = m1.add_instruction(migraphx::make_op("mul"), zerob, x); auto mul_ins = m1.add_instruction(migraphx::make_op("mul"), zerob, x);
m1.add_return({mul_ins}); m1.add_return({mul_ins});
} }
...@@ -1251,9 +1252,10 @@ TEST_CASE(simplify_zero_mul_const_vec2) ...@@ -1251,9 +1252,10 @@ TEST_CASE(simplify_zero_mul_const_vec2)
migraphx::module m2; migraphx::module m2;
{ {
auto zero = m2.add_literal({inner, {0, 0}}); auto zero = m2.add_literal({zero_shape, {0, 0}});
m2.add_parameter("x", outer); m2.add_parameter("x", x_shape);
auto zerob = m2.add_instruction(b, zero); auto zerob = m2.add_instruction(
migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", {1, 2, 3, 3}}}), zero);
m2.add_return({zerob}); m2.add_return({zerob});
} }
...@@ -1285,12 +1287,12 @@ TEST_CASE(simplify_zero_div_const_vec) ...@@ -1285,12 +1287,12 @@ TEST_CASE(simplify_zero_div_const_vec)
{ {
migraphx::shape zero_shape{migraphx::shape::int32_type, {2}}; migraphx::shape zero_shape{migraphx::shape::int32_type, {2}};
migraphx::shape x_shape{migraphx::shape::int32_type, {1, 2, 3, 3}}; migraphx::shape x_shape{migraphx::shape::int32_type, {1, 2, 3, 3}};
migraphx::op::broadcast b{1, {1, 2, 3, 3}};
migraphx::module m1; migraphx::module m1;
{ {
auto x = m1.add_parameter("x", x_shape); auto x = m1.add_parameter("x", x_shape);
auto zero = m1.add_literal({zero_shape, {0, 0}}); auto zero = m1.add_literal({zero_shape, {0, 0}});
auto zerob = m1.add_instruction(b, zero); auto zerob = m1.add_instruction(
migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", {1, 2, 3, 3}}}), zero);
auto div_ins = m1.add_instruction(migraphx::make_op("div"), zerob, x); auto div_ins = m1.add_instruction(migraphx::make_op("div"), zerob, x);
m1.add_return({div_ins}); m1.add_return({div_ins});
} }
...@@ -1300,7 +1302,8 @@ TEST_CASE(simplify_zero_div_const_vec) ...@@ -1300,7 +1302,8 @@ TEST_CASE(simplify_zero_div_const_vec)
{ {
m2.add_parameter("x", x_shape); m2.add_parameter("x", x_shape);
auto zero = m2.add_literal({zero_shape, {0, 0}}); auto zero = m2.add_literal({zero_shape, {0, 0}});
auto zerob = m2.add_instruction(b, zero); auto zerob = m2.add_instruction(
migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", {1, 2, 3, 3}}}), zero);
m2.add_return({zerob}); m2.add_return({zerob});
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment