Commit 638ff250 authored by Ted Themistokleous's avatar Ted Themistokleous
Browse files

Refactor parse_if tests to use add_ vs insert_ instructions and fix returns

Change things in the test cases so that we're not just replicating what
the code does but use add_instruction to dictate what we should expect
for the output of fixing empty const cases.

Had to also switch an insert to add of a literal in the empty case to achieve
this in parse_if as well.

Moved the return instructions to the end of each subgraph to also fix readability
of each test.
parent b5d1db2e
......@@ -127,8 +127,7 @@ struct parse_if : op_parser<parse_if>
auto handle_empty_branch = [](module_ref& mdl, int index, const shape& out_shape) {
shape gen_shape(shape(out_shape.type(), {1}, {0}));
auto literal_ins =
mdl->insert_literal(std::prev(mdl->end()), literal(gen_shape, {0}));
auto literal_ins = mdl->add_literal(literal(gen_shape, {0}));
auto unsqueeze_ins = mdl->insert_instruction(
std::prev(mdl->end()),
make_op("scalar", {{"scalar_bcst_dims", out_shape.lens()}}),
......
......@@ -2598,23 +2598,17 @@ TEST_CASE(if_then_empty_constant_test)
auto y = mm->add_parameter("y", s);
auto* then_mod = p.create_module("If_4_if");
auto then_lit = then_mod->add_literal(migraphx::shape::int64_type);
then_mod->add_return({then_lit});
then_mod->add_literal(migraphx::shape::int64_type);
migraphx::shape gen_shape(migraphx::shape(s.type(), {1}, {0}));
auto literal_ins =
then_mod->insert_literal((std::prev(then_mod->end())), migraphx::literal(gen_shape, {0}));
auto unsqueeze_ins =
then_mod->insert_instruction(std::prev(then_mod->end()),
migraphx::make_op("scalar", {{"scalar_bcst_dims", s.lens()}}),
literal_ins);
auto broad_ins =
then_mod->insert_instruction(std::prev(then_mod->end()),
migraphx::make_op("multibroadcast", {{"out_lens", s.lens()}}),
unsqueeze_ins);
auto contig_out = then_mod->insert_instruction(
std::prev(then_mod->end()), migraphx::make_op("contiguous"), broad_ins);
then_mod->replace_instruction(std::prev(then_mod->end())->inputs().at(0), contig_out);
auto literal_ins = then_mod->add_literal(migraphx::literal(gen_shape, {0}));
auto unsqueeze_ins = then_mod->add_instruction(
migraphx::make_op("scalar", {{"scalar_bcst_dims", s.lens()}}), literal_ins);
auto broad_ins = then_mod->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", s.lens()}}), unsqueeze_ins);
auto contig_out = then_mod->add_instruction(migraphx::make_op("contiguous"), broad_ins);
then_mod->add_return({contig_out});
auto* else_mod = p.create_module("If_4_else");
auto re = else_mod->add_instruction(migraphx::make_op("mul"), y, l2);
......@@ -2642,39 +2636,28 @@ TEST_CASE(if_then_empty_constant_multi_output_test)
auto* then_mod = p.create_module("If_4_if");
auto lit = then_mod->add_literal(migraphx::shape::int64_type);
auto lit2 = then_mod->add_literal(migraphx::shape::int64_type);
then_mod->add_return({lit, lit2});
then_mod->add_literal(migraphx::shape::int64_type);
then_mod->add_literal(migraphx::shape::int64_type);
migraphx::shape gen_shape(migraphx::shape(s.type(), {1}, {0}));
auto literal_ins =
then_mod->insert_literal((std::prev(then_mod->end())), migraphx::literal(gen_shape, {0}));
auto unsqueeze_ins =
then_mod->insert_instruction(std::prev(then_mod->end()),
migraphx::make_op("scalar", {{"scalar_bcst_dims", s.lens()}}),
literal_ins);
auto broad_ins =
then_mod->insert_instruction(std::prev(then_mod->end()),
migraphx::make_op("multibroadcast", {{"out_lens", s.lens()}}),
unsqueeze_ins);
auto contig_out = then_mod->insert_instruction(
std::prev(then_mod->end()), migraphx::make_op("contiguous"), broad_ins);
then_mod->replace_instruction(std::prev(then_mod->end())->inputs().at(0), contig_out);
auto literal_ins = then_mod->add_literal(migraphx::literal(gen_shape, {0}));
auto unsqueeze_ins = then_mod->add_instruction(
migraphx::make_op("scalar", {{"scalar_bcst_dims", s.lens()}}), literal_ins);
auto broad_ins = then_mod->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", s.lens()}}), unsqueeze_ins);
auto contig_out = then_mod->add_instruction(migraphx::make_op("contiguous"), broad_ins);
migraphx::shape gen_shape2(migraphx::shape(s.type(), {1}, {0}));
auto literal_ins2 =
then_mod->insert_literal(std::prev(then_mod->end()), migraphx::literal(gen_shape2, {0}));
auto unsqueeze_ins2 =
then_mod->insert_instruction(std::prev(then_mod->end()),
migraphx::make_op("scalar", {{"scalar_bcst_dims", s.lens()}}),
literal_ins2);
auto broad_ins2 =
then_mod->insert_instruction(std::prev(then_mod->end()),
migraphx::make_op("multibroadcast", {{"out_lens", s.lens()}}),
unsqueeze_ins2);
auto contig_out2 = then_mod->insert_instruction(
std::prev(then_mod->end()), migraphx::make_op("contiguous"), broad_ins2);
then_mod->replace_instruction(std::prev(then_mod->end())->inputs().at(1), contig_out2);
auto literal_ins2 = then_mod->add_literal(migraphx::literal(gen_shape2, {0}));
auto unsqueeze_ins2 = then_mod->add_instruction(
migraphx::make_op("scalar", {{"scalar_bcst_dims", s.lens()}}), literal_ins2);
auto broad_ins2 = then_mod->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", s.lens()}}), unsqueeze_ins2);
auto contig_out2 = then_mod->add_instruction(migraphx::make_op("contiguous"), broad_ins2);
then_mod->add_return({contig_out, contig_out2});
auto* else_mod = p.create_module("If_4_else");
auto mul = else_mod->add_instruction(migraphx::make_op("mul"), y, l2);
......@@ -2707,23 +2690,19 @@ TEST_CASE(if_else_empty_constant_test)
then_mod->add_return({rt});
auto* else_mod = p.create_module("If_4_else");
auto else_lit = else_mod->add_literal(s.type());
else_mod->add_return({else_lit});
else_mod->add_literal(s.type());
migraphx::shape gen_shape(migraphx::shape(s.type(), {1}, {0}));
auto literal_ins =
else_mod->insert_literal((std::prev(else_mod->end())), migraphx::literal(gen_shape, {0}));
auto unsqueeze_ins =
else_mod->insert_instruction(std::prev(else_mod->end()),
migraphx::make_op("scalar", {{"scalar_bcst_dims", s.lens()}}),
literal_ins);
auto broad_ins =
else_mod->insert_instruction(std::prev(else_mod->end()),
migraphx::make_op("multibroadcast", {{"out_lens", s.lens()}}),
unsqueeze_ins);
auto contig_out = else_mod->insert_instruction(
std::prev(else_mod->end()), migraphx::make_op("contiguous"), broad_ins);
else_mod->replace_instruction(std::prev(else_mod->end())->inputs().at(0), contig_out);
auto literal_ins = else_mod->add_literal(migraphx::literal(gen_shape, {0}));
auto unsqueeze_ins = else_mod->add_instruction(
migraphx::make_op("scalar", {{"scalar_bcst_dims", s.lens()}}), literal_ins);
auto broad_ins = else_mod->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", s.lens()}}), unsqueeze_ins);
auto contig_out = else_mod->add_instruction(migraphx::make_op("contiguous"), broad_ins);
else_mod->add_return({contig_out});
auto ret = mm->add_instruction(migraphx::make_op("if"), {cond}, {then_mod, else_mod});
auto r = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), ret);
......@@ -2752,39 +2731,26 @@ TEST_CASE(if_else_empty_constant_multi_output_test)
auto* else_mod = p.create_module("If_4_else");
auto lit = else_mod->add_literal(migraphx::shape::int64_type);
auto lit2 = else_mod->add_literal(migraphx::shape::int64_type);
else_mod->add_return({lit, lit2});
else_mod->add_literal(migraphx::shape::int64_type);
else_mod->add_literal(migraphx::shape::int64_type);
migraphx::shape gen_shape(migraphx::shape(s.type(), {1}, {0}));
auto literal_ins =
else_mod->insert_literal((std::prev(else_mod->end())), migraphx::literal(gen_shape, {0}));
auto unsqueeze_ins =
else_mod->insert_instruction(std::prev(else_mod->end()),
migraphx::make_op("scalar", {{"scalar_bcst_dims", s.lens()}}),
literal_ins);
auto broad_ins =
else_mod->insert_instruction(std::prev(else_mod->end()),
migraphx::make_op("multibroadcast", {{"out_lens", s.lens()}}),
unsqueeze_ins);
auto contig_out = else_mod->insert_instruction(
std::prev(else_mod->end()), migraphx::make_op("contiguous"), broad_ins);
else_mod->replace_instruction(std::prev(else_mod->end())->inputs().at(0), contig_out);
auto literal_ins = else_mod->add_literal(migraphx::literal(gen_shape, {0}));
auto unsqueeze_ins = else_mod->add_instruction(
migraphx::make_op("scalar", {{"scalar_bcst_dims", s.lens()}}), literal_ins);
auto broad_ins = else_mod->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", s.lens()}}), unsqueeze_ins);
auto contig_out = else_mod->add_instruction(migraphx::make_op("contiguous"), broad_ins);
migraphx::shape gen_shape2(migraphx::shape(s.type(), {1}, {0}));
auto literal_ins2 =
else_mod->insert_literal(std::prev(else_mod->end()), migraphx::literal(gen_shape2, {0}));
auto unsqueeze_ins2 =
else_mod->insert_instruction(std::prev(else_mod->end()),
migraphx::make_op("scalar", {{"scalar_bcst_dims", s.lens()}}),
literal_ins2);
auto broad_ins2 =
else_mod->insert_instruction(std::prev(else_mod->end()),
migraphx::make_op("multibroadcast", {{"out_lens", s.lens()}}),
unsqueeze_ins2);
auto contig_out2 = else_mod->insert_instruction(
std::prev(else_mod->end()), migraphx::make_op("contiguous"), broad_ins2);
else_mod->replace_instruction(std::prev(else_mod->end())->inputs().at(1), contig_out2);
auto literal_ins2 = else_mod->add_literal(migraphx::literal(gen_shape2, {0}));
auto unsqueeze_ins2 = else_mod->add_instruction(
migraphx::make_op("scalar", {{"scalar_bcst_dims", s.lens()}}), literal_ins2);
auto broad_ins2 = else_mod->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", s.lens()}}), unsqueeze_ins2);
auto contig_out2 = else_mod->add_instruction(migraphx::make_op("contiguous"), broad_ins2);
else_mod->add_return({contig_out, contig_out2});
auto ret = mm->add_instruction(migraphx::make_op("if"), {cond}, {then_mod, else_mod});
auto r = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), ret);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment