Commit e77a7947 authored by Paul's avatar Paul
Browse files

Add simplifiy_algebra tests

parent 6f40b531
...@@ -186,9 +186,9 @@ struct nop ...@@ -186,9 +186,9 @@ struct nop
migraphx::shape compute_shape(const std::vector<migraphx::shape>&) const { return {}; } migraphx::shape compute_shape(const std::vector<migraphx::shape>&) const { return {}; }
}; };
inline migraphx::literal get_2x2() inline migraphx::literal get_2x2(int base=0)
{ {
return migraphx::literal{{migraphx::shape::float_type, {2, 2}}, {1, 2, 3, 4}}; return migraphx::literal{{migraphx::shape::float_type, {2, 2}}, {base+1, base+2, base+3, base+4}};
} }
inline migraphx::literal get_2x2_transposed() inline migraphx::literal get_2x2_transposed()
......
...@@ -358,7 +358,33 @@ TEST_CASE(simplify_mul_add) ...@@ -358,7 +358,33 @@ TEST_CASE(simplify_mul_add)
EXPECT(m1 == m2); EXPECT(m1 == m2);
} }
TEST_CASE(simplify_inner_broadcast) TEST_CASE(simplify_dot_add)
{
migraphx::module m1;
{
auto x = m1.add_parameter("x", {migraphx::shape::float_type, {2, 2}});
auto one = m1.add_literal(get_2x2());
auto two = m1.add_literal(get_2x2(1));
auto sum = m1.add_instruction(migraphx::make_op("add"), one, x);
auto dot = m1.add_instruction(migraphx::make_op("dot"), sum, two);
m1.add_instruction(pass_op{}, dot);
}
run_pass(m1);
migraphx::module m2;
{
auto x = m2.add_parameter("x", {migraphx::shape::float_type, {2, 2}});
auto one = m2.add_literal(get_2x2());
auto two = m2.add_literal(get_2x2(1));
auto dot1 = m2.add_instruction(migraphx::make_op("dot"), x, two);
auto dot2 = m2.add_instruction(migraphx::make_op("dot"), one, two);
auto sum = m2.add_instruction(migraphx::make_op("add"), dot1, dot2);
m2.add_instruction(pass_op{}, sum);
}
EXPECT(m1 == m2);
}
TEST_CASE(simplify_inner_broadcast1)
{ {
auto b = migraphx::op::broadcast{1, {2, 1, 4, 5}}; auto b = migraphx::op::broadcast{1, {2, 1, 4, 5}};
migraphx::module m1; migraphx::module m1;
...@@ -383,6 +409,31 @@ TEST_CASE(simplify_inner_broadcast) ...@@ -383,6 +409,31 @@ TEST_CASE(simplify_inner_broadcast)
EXPECT(m1 == m2); EXPECT(m1 == m2);
} }
TEST_CASE(simplify_inner_broadcast2)
{
auto b = migraphx::op::multibroadcast{{2, 1, 4, 5}};
migraphx::module m1;
{
auto x = m1.add_parameter("x", {migraphx::shape::int32_type, {1, 1, 1, 1}});
auto y = m1.add_parameter("y", {migraphx::shape::int32_type, {1, 1, 1, 1}});
auto xb = m1.add_instruction(b, x);
auto yb = m1.add_instruction(b, y);
auto sum = m1.add_instruction(migraphx::make_op("add"), xb, yb);
m1.add_instruction(pass_op{}, sum);
}
run_pass(m1);
migraphx::module m2;
{
auto x = m2.add_parameter("x", {migraphx::shape::int32_type, {1, 1, 1, 1}});
auto y = m2.add_parameter("y", {migraphx::shape::int32_type, {1, 1, 1, 1}});
auto sum = m2.add_instruction(migraphx::make_op("add"), x, y);
auto sumb = m2.add_instruction(b, sum);
m2.add_instruction(pass_op{}, sumb);
}
EXPECT(m1 == m2);
}
TEST_CASE(simplify_add_conv1) TEST_CASE(simplify_add_conv1)
{ {
migraphx::module m; migraphx::module m;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment