#include #include #include #include #include #include #include #include void run_pass(migraphx::program& p) { migraphx::run_passes(p, {migraphx::inline_module{}, migraphx::dead_code_elimination{}}); } TEST_CASE(cannot_inline_both) { auto create_program = [] { migraphx::program p; auto* mm = p.get_main_module(); migraphx::shape sd{migraphx::shape::float_type, {2, 3}}; auto x = mm->add_parameter("x", sd); std::vector one(sd.elements(), 1); std::vector two(sd.elements(), 2); auto* then_smod = p.create_module("then_smod"); auto l1 = then_smod->add_literal(migraphx::literal{sd, one}); auto r1 = then_smod->add_instruction(migraphx::make_op("add"), x, l1); then_smod->add_return({r1}); auto* else_smod = p.create_module("else_smod"); auto l2 = else_smod->add_literal(migraphx::literal{sd, two}); auto r2 = else_smod->add_instruction(migraphx::make_op("mul"), x, l2); else_smod->add_return({r2}); migraphx::shape s_cond{migraphx::shape::bool_type, {1}}; auto cond = mm->add_parameter("cond", s_cond); auto ret = mm->add_instruction(migraphx::make_op("if"), {cond}, {then_smod, else_smod}); mm->add_return({ret}); return p; }; auto p = create_program(); run_pass(p); EXPECT(p == create_program()); } TEST_CASE(cannot_inline_one) { auto create_program = [] { migraphx::program p; auto* mm = p.get_main_module(); migraphx::shape cond_s{migraphx::shape::bool_type}; migraphx::shape s{migraphx::shape::float_type, {5}}; auto cond = mm->add_parameter("cond", cond_s); auto x = mm->add_parameter("x", s); auto* then_mod = p.create_module("If_0_if"); std::vector data1 = {1, 2, 3, 4, 5}; auto l1 = then_mod->add_literal(migraphx::literal(s, data1)); then_mod->add_return({l1, x}); auto* else_mod = p.create_module("If_0_else"); std::vector data2 = {5, 4, 3, 2, 1}; auto l2 = else_mod->add_literal(migraphx::literal(s, data2)); auto s2 = else_mod->add_instruction(migraphx::make_op("add"), x, l2); else_mod->add_return({s2, l2}); auto ret = mm->add_instruction(migraphx::make_op("if"), {cond}, {then_mod, else_mod}); mm->add_return({ret}); return p; }; auto p = create_program(); run_pass(p); EXPECT(p == create_program()); } TEST_CASE(inline_if_test) { auto create_program = [] { migraphx::program p; auto* mm = p.get_main_module(); migraphx::shape sc{migraphx::shape::bool_type, {1}}; auto cond = mm->add_literal(migraphx::literal(sc, {1})); migraphx::shape s{migraphx::shape::float_type, {2, 3}}; std::vector ones(s.elements(), 1.0f); auto l1 = mm->add_literal(s, ones); std::vector rand = {-1.26487, -2.42279, 0.990835, 1.63072, 0.812238, -0.174946}; auto l2 = mm->add_literal(s, rand); auto x = mm->add_parameter("x", s); auto sm = mm->add_instruction(migraphx::make_op("add"), l1, x); auto y = mm->add_parameter("y", s); auto* then_mod = p.create_module("If_5_if"); auto rt = then_mod->add_instruction(migraphx::make_op("add"), x, sm); then_mod->add_outline(s); then_mod->add_return({rt}); auto* else_mod = p.create_module("If_5_else"); auto re = else_mod->add_instruction(migraphx::make_op("mul"), y, l2); else_mod->add_return({re}); auto ret = mm->add_instruction(migraphx::make_op("if"), {cond}, {then_mod, else_mod}); auto r = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), ret); mm->add_return({r}); return p; }; auto create_inline = [] { migraphx::program p; auto* mm = p.get_main_module(); migraphx::shape s{migraphx::shape::float_type, {2, 3}}; std::vector ones(s.elements(), 1.0f); auto l1 = mm->add_literal(s, ones); std::vector rand = {-1.26487, -2.42279, 0.990835, 1.63072, 0.812238, -0.174946}; mm->add_literal(s, rand); auto x = mm->add_parameter("x", s); auto sm = mm->add_instruction(migraphx::make_op("add"), l1, x); mm->add_parameter("y", s); auto r = mm->add_instruction(migraphx::make_op("add"), x, sm); mm->add_return({r}); return p; }; auto p = create_program(); run_pass(p); EXPECT(p == create_inline()); } TEST_CASE(inline_else_test) { auto create_program = [] { migraphx::program p; auto* mm = p.get_main_module(); migraphx::shape sc{migraphx::shape::bool_type, {1}}; auto cond = mm->add_literal(migraphx::literal(sc, {0})); migraphx::shape s{migraphx::shape::float_type, {2, 3}}; std::vector ones(s.elements(), 1.0f); auto l1 = mm->add_literal(s, ones); std::vector rand = {-1.26487, -2.42279, 0.990835, 1.63072, 0.812238, -0.174946}; auto l2 = mm->add_literal(s, rand); auto x = mm->add_parameter("x", s); auto y = mm->add_parameter("y", s); auto* then_mod = p.create_module("If_5_if"); auto rt = then_mod->add_instruction(migraphx::make_op("add"), x, l1); then_mod->add_return({rt}); auto* else_mod = p.create_module("If_5_else"); else_mod->add_parameter("e", s); else_mod->add_literal(migraphx::literal(s, ones)); auto re = else_mod->add_instruction(migraphx::make_op("mul"), y, l2); else_mod->add_return({re}); auto ret = mm->add_instruction(migraphx::make_op("if"), {cond}, {then_mod, else_mod}); auto r = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), ret); mm->add_return({r}); return p; }; auto create_inline = [] { migraphx::program p; auto* mm = p.get_main_module(); migraphx::shape s{migraphx::shape::float_type, {2, 3}}; std::vector ones(s.elements(), 1.0f); mm->add_literal(s, ones); std::vector rand = {-1.26487, -2.42279, 0.990835, 1.63072, 0.812238, -0.174946}; auto l2 = mm->add_literal(s, rand); mm->add_parameter("x", s); auto y = mm->add_parameter("y", s); mm->add_parameter("e", s); auto r = mm->add_instruction(migraphx::make_op("mul"), y, l2); mm->add_return({r}); return p; }; auto p = create_program(); run_pass(p); EXPECT(p == create_inline()); } TEST_CASE(if_recursive_test) { auto create_program = [] { migraphx::program p; auto* mm = p.get_main_module(); migraphx::shape cond_s{migraphx::shape::bool_type}; migraphx::shape xs{migraphx::shape::float_type, {2, 3}}; migraphx::shape ys{migraphx::shape::float_type, {3, 3}}; std::vector datax = {1, 2, 3, 4, 5, 6}; std::vector datay = {8, 7, 6, 5, 4, 3, 2, 1, 0}; auto lx = mm->add_literal(migraphx::literal(xs, datax)); auto ly = mm->add_literal(migraphx::literal(ys, datay)); auto cond = mm->add_literal(migraphx::literal(cond_s, {0})); auto x1 = mm->add_parameter("x1", xs); auto x2 = mm->add_parameter("x2", xs); auto y2 = mm->add_parameter("y2", ys); auto cond1 = mm->add_parameter("cond", cond_s); auto* then_mod = p.create_module("If_5_if"); auto l1 = then_mod->add_literal(migraphx::literal(ys, datay)); auto a1 = then_mod->add_instruction(migraphx::make_op("add"), x1, lx); then_mod->add_return({a1, l1}); auto* then_mod1 = p.create_module("If_6_if"); auto l11 = then_mod1->add_literal(migraphx::literal(ys, datay)); auto a11 = then_mod1->add_instruction(migraphx::make_op("add"), x2, lx); then_mod1->add_return({a11, l11}); auto* else_mod1 = p.create_module("If_6_else"); auto l21 = else_mod1->add_literal(migraphx::literal(xs, datax)); auto a21 = else_mod1->add_instruction(migraphx::make_op("mul"), y2, ly); else_mod1->add_return({l21, a21}); auto* else_mod = p.create_module("If_5_else"); auto l2 = else_mod->add_literal(migraphx::literal(xs, datax)); auto a2 = else_mod->add_instruction(migraphx::make_op("if"), {cond1}, {then_mod1, else_mod1}); auto a3 = else_mod->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 1}}), a2); else_mod->add_return({l2, a3}); auto ret = mm->add_instruction(migraphx::make_op("if"), {cond}, {then_mod, else_mod}); auto r = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 1}}), ret); mm->add_return({r}); return p; }; auto create_inline = [] { migraphx::program p; auto* mm = p.get_main_module(); migraphx::shape cond_s{migraphx::shape::bool_type}; migraphx::shape xs{migraphx::shape::float_type, {2, 3}}; migraphx::shape ys{migraphx::shape::float_type, {3, 3}}; std::vector datax = {1, 2, 3, 4, 5, 6}; std::vector datay = {8, 7, 6, 5, 4, 3, 2, 1, 0}; auto lx = mm->add_literal(migraphx::literal(xs, datax)); auto ly = mm->add_literal(migraphx::literal(ys, datay)); mm->add_parameter("x1", xs); auto x2 = mm->add_parameter("x2", xs); auto y2 = mm->add_parameter("y2", ys); auto cond1 = mm->add_parameter("cond", cond_s); auto* then_mod1 = p.create_module("If_6_if"); auto l11 = then_mod1->add_literal(migraphx::literal(ys, datay)); auto a11 = then_mod1->add_instruction(migraphx::make_op("add"), x2, lx); then_mod1->add_return({a11, l11}); auto* else_mod1 = p.create_module("If_6_else"); auto l21 = else_mod1->add_literal(migraphx::literal(xs, datax)); auto a21 = else_mod1->add_instruction(migraphx::make_op("mul"), y2, ly); else_mod1->add_return({l21, a21}); auto ret = mm->add_instruction(migraphx::make_op("if"), {cond1}, {then_mod1, else_mod1}); auto r = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 1}}), ret); mm->add_return({r}); return p; }; auto p = create_program(); run_pass(p); EXPECT(p == create_inline()); } TEST_CASE(if_recursive_cond0_test) { auto create_program = [] { migraphx::program p; auto* mm = p.get_main_module(); migraphx::shape cond_s{migraphx::shape::bool_type}; migraphx::shape xs{migraphx::shape::float_type, {2, 3}}; migraphx::shape ys{migraphx::shape::float_type, {3, 3}}; std::vector datax = {1, 2, 3, 4, 5, 6}; std::vector datay = {8, 7, 6, 5, 4, 3, 2, 1, 0}; auto lx = mm->add_literal(migraphx::literal(xs, datax)); auto ly = mm->add_literal(migraphx::literal(ys, datay)); auto cond = mm->add_literal(migraphx::literal(cond_s, {0})); auto x1 = mm->add_parameter("x1", xs); auto x2 = mm->add_parameter("x2", xs); auto y2 = mm->add_parameter("y2", ys); auto* then_mod = p.create_module("If_5_if"); auto l1 = then_mod->add_literal(migraphx::literal(ys, datay)); auto a1 = then_mod->add_instruction(migraphx::make_op("add"), x1, lx); then_mod->add_return({a1, l1}); auto* then_mod1 = p.create_module("If_6_if"); auto l11 = then_mod1->add_literal(migraphx::literal(ys, datay)); auto a11 = then_mod1->add_instruction(migraphx::make_op("add"), x2, lx); then_mod1->add_return({a11, l11}); auto* else_mod1 = p.create_module("If_6_else"); auto l21 = else_mod1->add_literal(migraphx::literal(xs, datax)); auto a21 = else_mod1->add_instruction(migraphx::make_op("mul"), y2, ly); else_mod1->add_return({l21, a21}); auto* else_mod = p.create_module("If_5_else"); auto l2 = else_mod->add_literal(migraphx::literal(xs, datax)); auto a2 = else_mod->add_instruction(migraphx::make_op("if"), {cond}, {then_mod1, else_mod1}); auto a3 = else_mod->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 1}}), a2); else_mod->add_return({l2, a3}); auto ret = mm->add_instruction(migraphx::make_op("if"), {cond}, {then_mod, else_mod}); auto r = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 1}}), ret); mm->add_return({r}); return p; }; auto create_inline = [] { migraphx::program p; auto* mm = p.get_main_module(); migraphx::shape cond_s{migraphx::shape::bool_type}; migraphx::shape xs{migraphx::shape::float_type, {2, 3}}; migraphx::shape ys{migraphx::shape::float_type, {3, 3}}; std::vector datax = {1, 2, 3, 4, 5, 6}; std::vector datay = {8, 7, 6, 5, 4, 3, 2, 1, 0}; mm->add_literal(migraphx::literal(xs, datax)); auto ly = mm->add_literal(migraphx::literal(ys, datay)); mm->add_parameter("x1", xs); mm->add_parameter("x2", xs); auto y2 = mm->add_parameter("y2", ys); auto m = mm->add_instruction(migraphx::make_op("mul"), y2, ly); mm->add_return({m}); return p; }; auto p = create_program(); run_pass(p); EXPECT(p == create_inline()); } TEST_CASE(inline_tuple_true_test) { auto create_program = [] { migraphx::program p; auto* mm = p.get_main_module(); migraphx::shape sc{migraphx::shape::bool_type, {1}}; auto cond = mm->add_literal(migraphx::literal(sc, {1})); migraphx::shape sd{migraphx::shape::float_type, {1}}; auto l1 = mm->add_literal(migraphx::literal(sd, {1})); auto l2 = mm->add_literal(migraphx::literal(sd, {2})); auto l3 = mm->add_literal(migraphx::literal(sd, {3})); migraphx::shape sx{migraphx::shape::float_type, {1, 4}}; migraphx::shape sy{migraphx::shape::float_type, {3, 4}}; auto x = mm->add_parameter("x", sx); auto y = mm->add_parameter("y", sy); auto* then_mod = p.create_module("If_6_if"); auto m1 = then_mod->add_instruction( migraphx::make_op("multibroadcast", {{"out_lens", {1, 4}}}), l1); auto add0 = then_mod->add_instruction(migraphx::make_op("add"), x, m1); auto m2 = then_mod->add_instruction( migraphx::make_op("multibroadcast", {{"out_lens", {3, 4}}}), l2); auto mul0 = then_mod->add_instruction(migraphx::make_op("mul"), y, m2); then_mod->add_return({add0, mul0}); auto* else_mod = p.create_module("If_6_else"); auto me1 = else_mod->add_instruction( migraphx::make_op("multibroadcast", {{"out_lens", {1, 4}}}), l3); auto mul1 = else_mod->add_instruction(migraphx::make_op("mul"), x, me1); auto me2 = else_mod->add_instruction( migraphx::make_op("multibroadcast", {{"out_lens", {3, 4}}}), l3); auto add1 = else_mod->add_instruction(migraphx::make_op("add"), y, me2); else_mod->add_return({mul1, add1}); auto ret = mm->add_instruction(migraphx::make_op("if"), {cond}, {then_mod, else_mod}); auto r0 = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), ret); auto r1 = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 1}}), ret); mm->add_return({r0, r1}); return p; }; auto create_inline = [] { migraphx::program p; auto* mm = p.get_main_module(); migraphx::shape sd{migraphx::shape::float_type, {1}}; auto l1 = mm->add_literal(migraphx::literal(sd, {1})); auto l2 = mm->add_literal(migraphx::literal(sd, {2})); mm->add_literal(migraphx::literal(sd, {3})); migraphx::shape sx{migraphx::shape::float_type, {1, 4}}; migraphx::shape sy{migraphx::shape::float_type, {3, 4}}; auto x = mm->add_parameter("x", sx); auto y = mm->add_parameter("y", sy); auto m1 = mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {1, 4}}}), l1); auto add = mm->add_instruction(migraphx::make_op("add"), x, m1); auto m2 = mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {3, 4}}}), l2); auto mul = mm->add_instruction(migraphx::make_op("mul"), y, m2); mm->add_return({add, mul}); return p; }; auto p = create_program(); run_pass(p); EXPECT(p == create_inline()); } TEST_CASE(inline_tuple_false_test) { auto create_program = [] { migraphx::program p; auto* mm = p.get_main_module(); migraphx::shape sc{migraphx::shape::bool_type, {1}}; auto cond = mm->add_literal(migraphx::literal(sc, {0})); migraphx::shape sd{migraphx::shape::float_type, {1}}; auto l1 = mm->add_literal(migraphx::literal(sd, {1})); auto l2 = mm->add_literal(migraphx::literal(sd, {2})); auto l3 = mm->add_literal(migraphx::literal(sd, {3})); migraphx::shape sx{migraphx::shape::float_type, {1, 4}}; migraphx::shape sy{migraphx::shape::float_type, {3, 4}}; auto x = mm->add_parameter("x", sx); auto y = mm->add_parameter("y", sy); auto* then_mod = p.create_module("If_6_if"); auto m1 = then_mod->add_instruction( migraphx::make_op("multibroadcast", {{"out_lens", {1, 4}}}), l1); auto add0 = then_mod->add_instruction(migraphx::make_op("add"), x, m1); auto m2 = then_mod->add_instruction( migraphx::make_op("multibroadcast", {{"out_lens", {3, 4}}}), l2); auto mul0 = then_mod->add_instruction(migraphx::make_op("mul"), y, m2); then_mod->add_return({add0, mul0}); auto* else_mod = p.create_module("If_6_else"); auto me1 = else_mod->add_instruction( migraphx::make_op("multibroadcast", {{"out_lens", {1, 4}}}), l3); auto mul1 = else_mod->add_instruction(migraphx::make_op("mul"), x, me1); auto me2 = else_mod->add_instruction( migraphx::make_op("multibroadcast", {{"out_lens", {3, 4}}}), l3); auto add1 = else_mod->add_instruction(migraphx::make_op("add"), y, me2); else_mod->add_return({mul1, add1}); auto ret = mm->add_instruction(migraphx::make_op("if"), {cond}, {then_mod, else_mod}); auto r0 = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), ret); auto r1 = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 1}}), ret); mm->add_return({r0, r1}); return p; }; auto create_inline = [] { migraphx::program p; auto* mm = p.get_main_module(); migraphx::shape sc{migraphx::shape::bool_type, {1}}; migraphx::shape sd{migraphx::shape::float_type, {1}}; mm->add_literal(migraphx::literal(sd, {1})); mm->add_literal(migraphx::literal(sd, {2})); auto l3 = mm->add_literal(migraphx::literal(sd, {3})); migraphx::shape sx{migraphx::shape::float_type, {1, 4}}; migraphx::shape sy{migraphx::shape::float_type, {3, 4}}; auto x = mm->add_parameter("x", sx); auto y = mm->add_parameter("y", sy); auto m1 = mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {1, 4}}}), l3); auto mul = mm->add_instruction(migraphx::make_op("mul"), x, m1); auto m2 = mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {3, 4}}}), l3); auto add = mm->add_instruction(migraphx::make_op("add"), y, m2); mm->add_return({mul, add}); return p; }; auto p = create_program(); run_pass(p); EXPECT(p == create_inline()); } int main(int argc, const char* argv[]) { test::run(argc, argv); }