Commit 62c746eb authored by Ted Themistokleous's avatar Ted Themistokleous
Browse files

Simplify parse_if to remove literal and broadcasting for empty branches

Just grab the last output from the non empty branch and use the identity
operator to get the proper shape for the output branch. In the case of
an empty branch (empty tensor, of some type) this tends to mean "Do nothing"
so we're folding the output of the other flow branch here, and thus if somehow,
we do reach this at eval, should throw an error signalling either one of two things

1. Onnx model is invalid
2. The model has run into an error condition with its control flow.

Since IF is an odd operator that can adjust axes, and other operators in a data driven
fashion, this would serve as a check at compile and or/runtime.
parent 260d7aa7
...@@ -125,32 +125,25 @@ struct parse_if : op_parser<parse_if> ...@@ -125,32 +125,25 @@ struct parse_if : op_parser<parse_if>
assert(not(then_lens.empty() and else_lens.empty())); assert(not(then_lens.empty() and else_lens.empty()));
auto handle_empty_branch = [](module_ref& mdl, int index, const shape& out_shape) { auto handle_empty_branch = [](module_ref& mdl, int index, module_ref& other_mdl) {
shape gen_shape(shape(out_shape.type(), {1}, {0})); auto identity_ins =
auto literal_ins = mdl->add_literal(literal(gen_shape, {0})); mdl->insert_instruction(std::prev(mdl->end()),
auto unsqueeze_ins = mdl->insert_instruction( make_op("identity"),
std::prev(mdl->end()), std::prev(other_mdl->end())->inputs().at(index));
make_op("scalar", {{"scalar_bcst_dims", out_shape.lens()}}), mdl->replace_instruction(std::prev(mdl->end())->inputs().at(index), identity_ins);
literal_ins);
auto broad_ins = mdl->insert_instruction(
std::prev(mdl->end()),
make_op("multibroadcast", {{"out_lens", out_shape.lens()}}),
unsqueeze_ins);
auto contig_out = mdl->insert_instruction(
std::prev(mdl->end()), make_op("contiguous"), broad_ins);
mdl->replace_instruction(std::prev(mdl->end())->inputs().at(index), contig_out);
return out_shape.lens();
}; };
// Handle one empty branch by setting output identical to the other // Handle one empty branch by setting output identical to the other
// need to update the then_shape before we do further checks // need to update the then_shape before we do further checks
if(then_lens.empty()) if(then_lens.empty())
{ {
then_lens = handle_empty_branch(then_mdl, i, else_out_shape); handle_empty_branch(then_mdl, i, else_mdl);
then_lens = else_lens;
} }
else if(else_lens.empty()) else if(else_lens.empty())
{ {
else_lens = handle_empty_branch(else_mdl, i, then_out_shape); handle_empty_branch(else_mdl, i, then_mdl);
else_lens = then_lens;
} }
// check equivalent length dims, and (x1,x2,.., xn, 1) == (x1,x2,..,xn) // check equivalent length dims, and (x1,x2,.., xn, 1) == (x1,x2,..,xn)
......
...@@ -2597,23 +2597,15 @@ TEST_CASE(if_then_empty_constant_test) ...@@ -2597,23 +2597,15 @@ TEST_CASE(if_then_empty_constant_test)
auto l2 = mm->add_literal(s, rand); auto l2 = mm->add_literal(s, rand);
auto y = mm->add_parameter("y", s); auto y = mm->add_parameter("y", s);
auto* then_mod = p.create_module("If_4_if");
then_mod->add_literal(migraphx::shape::int64_type);
migraphx::shape gen_shape(migraphx::shape(s.type(), {1}, {0}));
auto literal_ins = then_mod->add_literal(migraphx::literal(gen_shape, {0}));
auto unsqueeze_ins = then_mod->add_instruction(
migraphx::make_op("scalar", {{"scalar_bcst_dims", s.lens()}}), literal_ins);
auto broad_ins = then_mod->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", s.lens()}}), unsqueeze_ins);
auto contig_out = then_mod->add_instruction(migraphx::make_op("contiguous"), broad_ins);
then_mod->add_return({contig_out});
auto* else_mod = p.create_module("If_4_else"); auto* else_mod = p.create_module("If_4_else");
auto re = else_mod->add_instruction(migraphx::make_op("mul"), y, l2); auto re = else_mod->add_instruction(migraphx::make_op("mul"), y, l2);
else_mod->add_return({re}); else_mod->add_return({re});
auto* then_mod = p.create_module("If_4_if");
then_mod->add_literal(migraphx::shape::int64_type);
auto identity_ins = then_mod->add_instruction(migraphx::make_op("identity"), re);
then_mod->add_return({identity_ins});
auto ret = mm->add_instruction(migraphx::make_op("if"), {cond}, {then_mod, else_mod}); auto ret = mm->add_instruction(migraphx::make_op("if"), {cond}, {then_mod, else_mod});
auto r = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), ret); auto r = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), ret);
mm->add_return({r}); mm->add_return({r});
...@@ -2634,36 +2626,18 @@ TEST_CASE(if_then_empty_constant_multi_output_test) ...@@ -2634,36 +2626,18 @@ TEST_CASE(if_then_empty_constant_multi_output_test)
auto l2 = mm->add_literal(s, rand); auto l2 = mm->add_literal(s, rand);
auto y = mm->add_parameter("y", s); auto y = mm->add_parameter("y", s);
auto* then_mod = p.create_module("If_4_if");
then_mod->add_literal(migraphx::shape::int64_type);
then_mod->add_literal(migraphx::shape::int64_type);
migraphx::shape gen_shape(migraphx::shape(s.type(), {1}, {0}));
auto literal_ins = then_mod->add_literal(migraphx::literal(gen_shape, {0}));
auto unsqueeze_ins = then_mod->add_instruction(
migraphx::make_op("scalar", {{"scalar_bcst_dims", s.lens()}}), literal_ins);
auto broad_ins = then_mod->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", s.lens()}}), unsqueeze_ins);
auto contig_out = then_mod->add_instruction(migraphx::make_op("contiguous"), broad_ins);
migraphx::shape gen_shape2(migraphx::shape(s.type(), {1}, {0}));
auto literal_ins2 = then_mod->add_literal(migraphx::literal(gen_shape2, {0}));
auto unsqueeze_ins2 = then_mod->add_instruction(
migraphx::make_op("scalar", {{"scalar_bcst_dims", s.lens()}}), literal_ins2);
auto broad_ins2 = then_mod->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", s.lens()}}), unsqueeze_ins2);
auto contig_out2 = then_mod->add_instruction(migraphx::make_op("contiguous"), broad_ins2);
then_mod->add_return({contig_out, contig_out2});
auto* else_mod = p.create_module("If_4_else"); auto* else_mod = p.create_module("If_4_else");
auto mul = else_mod->add_instruction(migraphx::make_op("mul"), y, l2); auto mul = else_mod->add_instruction(migraphx::make_op("mul"), y, l2);
auto sub = else_mod->add_instruction(migraphx::make_op("sub"), y, l2); auto sub = else_mod->add_instruction(migraphx::make_op("sub"), y, l2);
else_mod->add_return({mul, sub}); else_mod->add_return({mul, sub});
auto* then_mod = p.create_module("If_4_if");
then_mod->add_literal(migraphx::shape::int64_type);
then_mod->add_literal(migraphx::shape::int64_type);
auto identity_ins = then_mod->add_instruction(migraphx::make_op("identity"), mul);
auto identity_ins2 = then_mod->add_instruction(migraphx::make_op("identity"), sub);
then_mod->add_return({identity_ins, identity_ins2});
auto ret = mm->add_instruction(migraphx::make_op("if"), {cond}, {then_mod, else_mod}); auto ret = mm->add_instruction(migraphx::make_op("if"), {cond}, {then_mod, else_mod});
auto r = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), ret); auto r = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), ret);
auto r2 = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 1}}), ret); auto r2 = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 1}}), ret);
...@@ -2690,19 +2664,9 @@ TEST_CASE(if_else_empty_constant_test) ...@@ -2690,19 +2664,9 @@ TEST_CASE(if_else_empty_constant_test)
then_mod->add_return({rt}); then_mod->add_return({rt});
auto* else_mod = p.create_module("If_4_else"); auto* else_mod = p.create_module("If_4_else");
else_mod->add_literal(s.type()); else_mod->add_literal(s.type());
auto identity_ins = else_mod->add_instruction(migraphx::make_op("identity"), rt);
migraphx::shape gen_shape(migraphx::shape(s.type(), {1}, {0})); else_mod->add_return({identity_ins});
auto literal_ins = else_mod->add_literal(migraphx::literal(gen_shape, {0}));
auto unsqueeze_ins = else_mod->add_instruction(
migraphx::make_op("scalar", {{"scalar_bcst_dims", s.lens()}}), literal_ins);
auto broad_ins = else_mod->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", s.lens()}}), unsqueeze_ins);
auto contig_out = else_mod->add_instruction(migraphx::make_op("contiguous"), broad_ins);
else_mod->add_return({contig_out});
auto ret = mm->add_instruction(migraphx::make_op("if"), {cond}, {then_mod, else_mod}); auto ret = mm->add_instruction(migraphx::make_op("if"), {cond}, {then_mod, else_mod});
auto r = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), ret); auto r = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), ret);
...@@ -2730,27 +2694,11 @@ TEST_CASE(if_else_empty_constant_multi_output_test) ...@@ -2730,27 +2694,11 @@ TEST_CASE(if_else_empty_constant_multi_output_test)
then_mod->add_return({mul, sub}); then_mod->add_return({mul, sub});
auto* else_mod = p.create_module("If_4_else"); auto* else_mod = p.create_module("If_4_else");
else_mod->add_literal(migraphx::shape::int64_type); else_mod->add_literal(migraphx::shape::int64_type);
else_mod->add_literal(migraphx::shape::int64_type); else_mod->add_literal(migraphx::shape::int64_type);
auto identity_ins = else_mod->add_instruction(migraphx::make_op("identity"), mul);
migraphx::shape gen_shape(migraphx::shape(s.type(), {1}, {0})); auto identity_ins2 = else_mod->add_instruction(migraphx::make_op("identity"), sub);
auto literal_ins = else_mod->add_literal(migraphx::literal(gen_shape, {0})); else_mod->add_return({identity_ins, identity_ins2});
auto unsqueeze_ins = else_mod->add_instruction(
migraphx::make_op("scalar", {{"scalar_bcst_dims", s.lens()}}), literal_ins);
auto broad_ins = else_mod->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", s.lens()}}), unsqueeze_ins);
auto contig_out = else_mod->add_instruction(migraphx::make_op("contiguous"), broad_ins);
migraphx::shape gen_shape2(migraphx::shape(s.type(), {1}, {0}));
auto literal_ins2 = else_mod->add_literal(migraphx::literal(gen_shape2, {0}));
auto unsqueeze_ins2 = else_mod->add_instruction(
migraphx::make_op("scalar", {{"scalar_bcst_dims", s.lens()}}), literal_ins2);
auto broad_ins2 = else_mod->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", s.lens()}}), unsqueeze_ins2);
auto contig_out2 = else_mod->add_instruction(migraphx::make_op("contiguous"), broad_ins2);
else_mod->add_return({contig_out, contig_out2});
auto ret = mm->add_instruction(migraphx::make_op("if"), {cond}, {then_mod, else_mod}); auto ret = mm->add_instruction(migraphx::make_op("if"), {cond}, {then_mod, else_mod});
auto r = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), ret); auto r = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), ret);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment