#include #include #include #include #include void run_pass(migraphx::program& p) { migraphx::run_passes(*p.get_main_module(), {migraphx::decompose{}}); } TEST_CASE(dot_add) { migraphx::program p1; { auto* mm1 = p1.get_main_module(); auto x = mm1->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {2, 2}}); auto y = mm1->add_parameter("y", migraphx::shape{migraphx::shape::float_type, {2, 2}}); auto z = mm1->add_parameter("z", migraphx::shape{migraphx::shape::float_type, {2, 2}}); auto dot = mm1->add_instruction(migraphx::make_op("dot"), x, y, z); mm1->add_instruction(migraphx::make_op("identity"), dot); } run_pass(p1); migraphx::program p2; { auto* mm2 = p2.get_main_module(); auto x = mm2->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {2, 2}}); auto y = mm2->add_parameter("y", migraphx::shape{migraphx::shape::float_type, {2, 2}}); auto z = mm2->add_parameter("z", migraphx::shape{migraphx::shape::float_type, {2, 2}}); auto dot = mm2->add_instruction(migraphx::make_op("dot", {{"alpha", 1}, {"beta", 0}}), x, y); auto add = mm2->add_instruction(migraphx::make_op("add"), dot, z); mm2->add_instruction(migraphx::make_op("identity"), add); } EXPECT(p1 == p2); } TEST_CASE(dot_add_beta_float) { migraphx::program p1; { auto* mm1 = p1.get_main_module(); auto x = mm1->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {2, 2}}); auto y = mm1->add_parameter("y", migraphx::shape{migraphx::shape::float_type, {2, 2}}); auto z = mm1->add_parameter("z", migraphx::shape{migraphx::shape::float_type, {2, 2}}); auto dot = mm1->add_instruction( migraphx::make_op("dot", {{"alpha", 1.0}, {"beta", 0.5}}), x, y, z); mm1->add_instruction(migraphx::make_op("identity"), dot); } run_pass(p1); migraphx::program p2; { auto* mm2 = p2.get_main_module(); auto x = mm2->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {2, 2}}); auto y = mm2->add_parameter("y", migraphx::shape{migraphx::shape::float_type, {2, 2}}); auto z = mm2->add_parameter("z", 
migraphx::shape{migraphx::shape::float_type, {2, 2}}); auto dot = mm2->add_instruction(migraphx::make_op("dot", {{"alpha", 1}, {"beta", 0}}), x, y); auto beta = mm2->add_literal( migraphx::literal{migraphx::shape{migraphx::shape::float_type}, {0.5}}); auto beta_broadcast = mm2->add_instruction( migraphx::make_op("multibroadcast", {{"output_lens", {2, 2}}}), beta); auto mul = mm2->add_instruction(migraphx::make_op("mul"), z, beta_broadcast); auto add = mm2->add_instruction(migraphx::make_op("add"), dot, mul); mm2->add_instruction(migraphx::make_op("identity"), add); } EXPECT(p1 == p2); } TEST_CASE(dot_add_beta_half) { migraphx::program p1; { auto* mm1 = p1.get_main_module(); auto x = mm1->add_parameter("x", migraphx::shape{migraphx::shape::half_type, {2, 2}}); auto y = mm1->add_parameter("y", migraphx::shape{migraphx::shape::half_type, {2, 2}}); auto z = mm1->add_parameter("z", migraphx::shape{migraphx::shape::half_type, {2, 2}}); auto dot = mm1->add_instruction( migraphx::make_op("dot", {{"alpha", 1.0}, {"beta", 0.5}}), x, y, z); mm1->add_instruction(migraphx::make_op("identity"), dot); } run_pass(p1); migraphx::program p2; { auto* mm2 = p2.get_main_module(); auto x = mm2->add_parameter("x", migraphx::shape{migraphx::shape::half_type, {2, 2}}); auto y = mm2->add_parameter("y", migraphx::shape{migraphx::shape::half_type, {2, 2}}); auto z = mm2->add_parameter("z", migraphx::shape{migraphx::shape::half_type, {2, 2}}); auto dot = mm2->add_instruction(migraphx::make_op("dot", {{"alpha", 1}, {"beta", 0}}), x, y); auto beta = mm2->add_literal(migraphx::literal{migraphx::shape{migraphx::shape::half_type}, {0.5}}); auto beta_broadcast = mm2->add_instruction( migraphx::make_op("multibroadcast", {{"output_lens", {2, 2}}}), beta); auto mul = mm2->add_instruction(migraphx::make_op("mul"), z, beta_broadcast); auto add = mm2->add_instruction(migraphx::make_op("add"), dot, mul); mm2->add_instruction(migraphx::make_op("identity"), add); } EXPECT(p1 == p2); } 
// Double-precision variant of the beta lowering: after `decompose`, the program
// should compute dot(alpha = 1, beta = 0) plus z scaled by a broadcast double
// literal 0.5.
TEST_CASE(dot_add_beta_double)
{
    const migraphx::shape mat{migraphx::shape::double_type, {2, 2}};

    migraphx::program p1;
    {
        auto* mod = p1.get_main_module();
        auto a    = mod->add_parameter("x", mat);
        auto b    = mod->add_parameter("y", mat);
        auto c    = mod->add_parameter("z", mat);
        auto gemm = mod->add_instruction(
            migraphx::make_op("dot", {{"alpha", 1.0}, {"beta", 0.5}}), a, b, c);
        mod->add_instruction(migraphx::make_op("identity"), gemm);
    }
    run_pass(p1);

    migraphx::program p2;
    {
        auto* mod = p2.get_main_module();
        auto a    = mod->add_parameter("x", mat);
        auto b    = mod->add_parameter("y", mat);
        auto c    = mod->add_parameter("z", mat);
        auto gemm = mod->add_instruction(
            migraphx::make_op("dot", {{"alpha", 1}, {"beta", 0}}), a, b);
        auto scale = mod->add_literal(
            migraphx::literal{migraphx::shape{migraphx::shape::double_type}, {0.5}});
        auto scale_bc = mod->add_instruction(
            migraphx::make_op("multibroadcast", {{"output_lens", {2, 2}}}), scale);
        auto scaled_c = mod->add_instruction(migraphx::make_op("mul"), c, scale_bc);
        auto total    = mod->add_instruction(migraphx::make_op("add"), gemm, scaled_c);
        mod->add_instruction(migraphx::make_op("identity"), total);
    }

    EXPECT(p1 == p2);
}

// For int32 operands with a fractional beta, `decompose` is expected to leave the
// program unchanged, so the result is compared with a copy taken before the pass.
TEST_CASE(dot_add_beta_int)
{
    const migraphx::shape mat{migraphx::shape::int32_type, {2, 2}};

    migraphx::program p1;
    {
        auto* mod = p1.get_main_module();
        auto a    = mod->add_parameter("x", mat);
        auto b    = mod->add_parameter("y", mat);
        auto c    = mod->add_parameter("z", mat);
        auto gemm = mod->add_instruction(
            migraphx::make_op("dot", {{"alpha", 1.0}, {"beta", 0.5}}), a, b, c);
        mod->add_instruction(migraphx::make_op("identity"), gemm);
    }

    // Snapshot before the pass runs.
    migraphx::program p2{p1};
    run_pass(p1);

    EXPECT(p1 == p2);
}

// Test entry point: hands the command line to the test framework's runner.
int main(int argc, const char* argv[])
{
    test::run(argc, argv);
}