#include #include #include #include #include struct simplify_algebra_target { std::string name() const { return "simplify_algebra"; } std::vector get_passes(migraphx::context&) const { return {migraphx::simplify_algebra{}, migraphx::dead_code_elimination{}}; } migraphx::context get_context() const { return {}; } }; TEST_CASE(simplify_add1) { migraphx::program p1; { auto x = p1.add_parameter("x", {migraphx::shape::int32_type, {1}}); auto y = p1.add_parameter("y", {migraphx::shape::int32_type, {1}}); auto one = p1.add_literal(1); auto two = p1.add_literal(2); auto sum1 = p1.add_instruction(migraphx::op::add{}, x, one); auto sum2 = p1.add_instruction(migraphx::op::add{}, y, two); auto sum3 = p1.add_instruction(migraphx::op::add{}, sum1, sum2); p1.add_instruction(pass_op{}, sum3); } p1.compile(simplify_algebra_target{}); migraphx::program p2; { auto x = p2.add_parameter("x", {migraphx::shape::int32_type, {1}}); auto y = p2.add_parameter("y", {migraphx::shape::int32_type, {1}}); auto one = p2.add_literal(1); auto two = p2.add_literal(2); auto sum1 = p2.add_instruction(migraphx::op::add{}, one, two); auto sum2 = p2.add_instruction(migraphx::op::add{}, x, y); auto sum3 = p2.add_instruction(migraphx::op::add{}, sum2, sum1); p2.add_instruction(pass_op{}, sum3); } EXPECT(p1 == p2); } TEST_CASE(simplify_add2) { migraphx::program p1; { auto x = p1.add_parameter("x", {migraphx::shape::int32_type, {1}}); auto y = p1.add_parameter("y", {migraphx::shape::int32_type, {1}}); auto one = p1.add_literal(1); auto two = p1.add_literal(2); auto sum1 = p1.add_instruction(migraphx::op::add{}, one, x); auto sum2 = p1.add_instruction(migraphx::op::add{}, two, y); auto sum3 = p1.add_instruction(migraphx::op::add{}, sum1, sum2); p1.add_instruction(pass_op{}, sum3); } p1.compile(simplify_algebra_target{}); migraphx::program p2; { auto x = p2.add_parameter("x", {migraphx::shape::int32_type, {1}}); auto y = p2.add_parameter("y", {migraphx::shape::int32_type, {1}}); auto one = p2.add_literal(1); auto two = p2.add_literal(2); auto sum1 = p2.add_instruction(migraphx::op::add{}, one, two); auto sum2 = p2.add_instruction(migraphx::op::add{}, x, y); auto sum3 = p2.add_instruction(migraphx::op::add{}, sum2, sum1); p2.add_instruction(pass_op{}, sum3); } EXPECT(p1 == p2); } TEST_CASE(simplify_add3) { migraphx::program p1; { auto x = p1.add_parameter("x", {migraphx::shape::int32_type, {1}}); auto one = p1.add_literal(1); auto two = p1.add_literal(2); auto sum1 = p1.add_instruction(migraphx::op::add{}, one, x); auto sum2 = p1.add_instruction(migraphx::op::add{}, one, two); auto sum3 = p1.add_instruction(migraphx::op::add{}, sum1, sum2); p1.add_instruction(pass_op{}, sum3); } p1.compile(simplify_algebra_target{}); migraphx::program p2; { auto x = p2.add_parameter("x", {migraphx::shape::int32_type, {1}}); auto one = p2.add_literal(1); auto two = p2.add_literal(2); auto sum1 = p2.add_instruction(migraphx::op::add{}, one, x); auto sum2 = p2.add_instruction(migraphx::op::add{}, one, two); auto sum3 = p2.add_instruction(migraphx::op::add{}, sum1, sum2); p2.add_instruction(pass_op{}, sum3); } EXPECT(p1 == p2); } TEST_CASE(simplify_add_broadcast1) { migraphx::shape inner{migraphx::shape::int32_type, {2}}; migraphx::shape outer{migraphx::shape::int32_type, {1, 2, 3, 3}}; migraphx::op::broadcast b{1, {1, 2, 3, 3}}; migraphx::program p1; { auto x = p1.add_parameter("x", outer); auto y = p1.add_parameter("y", outer); auto one = p1.add_literal({inner, {1, 1}}); auto oneb = p1.add_instruction(b, one); auto two = p1.add_literal({inner, {2, 2}}); auto twob = p1.add_instruction(b, two); auto sum1 = p1.add_instruction(migraphx::op::add{}, x, oneb); auto sum2 = p1.add_instruction(migraphx::op::add{}, y, twob); auto sum3 = p1.add_instruction(migraphx::op::add{}, sum1, sum2); p1.add_instruction(pass_op{}, sum3); } p1.compile(simplify_algebra_target{}); migraphx::program p2; { auto x = p2.add_parameter("x", outer); auto y = p2.add_parameter("y", outer); auto one = p2.add_literal({inner, {1, 1}}); auto two = p2.add_literal({inner, {2, 2}}); auto sum1 = p2.add_instruction(migraphx::op::add{}, one, two); auto sum1b = p2.add_instruction(b, sum1); auto sum2 = p2.add_instruction(migraphx::op::add{}, x, y); auto sum3 = p2.add_instruction(migraphx::op::add{}, sum2, sum1b); p2.add_instruction(pass_op{}, sum3); } EXPECT(p1 == p2); } TEST_CASE(simplify_add_broadcast2) { migraphx::shape inner{migraphx::shape::int32_type, {2}}; migraphx::shape outer{migraphx::shape::int32_type, {1, 2, 3, 3}}; migraphx::op::broadcast b{1, {1, 2, 3, 3}}; auto create_program = [&] { migraphx::program p; auto x = p.add_parameter("x", outer); auto y = p.add_parameter("y", outer); auto one = p.add_literal({inner, {1, 1}}); auto oneb = p.add_instruction(b, one); auto two = p.add_literal({outer, {2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2}}); auto sum1 = p.add_instruction(migraphx::op::add{}, x, oneb); auto sum2 = p.add_instruction(migraphx::op::add{}, y, two); auto sum3 = p.add_instruction(migraphx::op::add{}, sum1, sum2); p.add_instruction(pass_op{}, sum3); return p; }; migraphx::program p1 = create_program(); p1.compile(simplify_algebra_target{}); migraphx::program p2 = create_program(); EXPECT(p1 == p2); } // TODO: Add test case // TEST_CASE(simplify_add4) void simplify_add4() { migraphx::program p1; { auto x = p1.add_parameter("x", {migraphx::shape::int32_type, {1}}); auto y = p1.add_parameter("y", {migraphx::shape::int32_type, {1}}); auto one = p1.add_literal(1); auto two = p1.add_literal(2); auto sum1 = p1.add_instruction(migraphx::op::add{}, one, x); auto sum2 = p1.add_instruction(migraphx::op::add{}, sum1, y); auto sum3 = p1.add_instruction(migraphx::op::add{}, sum2, two); p1.add_instruction(pass_op{}, sum3); } p1.compile(simplify_algebra_target{}); migraphx::program p2; { auto x = p2.add_parameter("x", {migraphx::shape::int32_type, {1}}); auto y = p2.add_parameter("y", {migraphx::shape::int32_type, {1}}); auto one = p2.add_literal(1); auto two = p2.add_literal(2); auto sum1 = p2.add_instruction(migraphx::op::add{}, one, two); auto sum2 = p2.add_instruction(migraphx::op::add{}, x, y); auto sum3 = p2.add_instruction(migraphx::op::add{}, sum2, sum1); p2.add_instruction(pass_op{}, sum3); } EXPECT(p1 == p2); } int main(int argc, const char* argv[]) { test::run(argc, argv); }