Commit b5d1db2e authored by Ted Themistokleous's avatar Ted Themistokleous
Browse files

Change Unsqueeze to squeeze to parse_if for trailing dimensions

Default to smaller dimension with trailing 1 case instead of unsqueezing to
the larger dimension.

More analysis on other networks concludes that when putting in two operands to
and IF block, the output should take the smaller of the shapes instead of the larger

Modified tests in onnx_test.cpp to parse to the correct output as well.
parent ef0b52e7
...@@ -48,11 +48,11 @@ inline bool all_but_last_dims_equal(const std::vector<size_t>& lens_a, ...@@ -48,11 +48,11 @@ inline bool all_but_last_dims_equal(const std::vector<size_t>& lens_a,
} }
}; };
void unsqueeze_last_op(module_ref mdl, int index, const std::vector<size_t>& out_shape) void squeeze_last_op(module_ref mdl, int index, const std::vector<size_t>& out_shape)
{ {
auto convert_ins = auto convert_ins =
mdl->insert_instruction(std::prev(mdl->end()), mdl->insert_instruction(std::prev(mdl->end()),
make_op("unsqueeze", {{"axes", {out_shape.size() - 1}}}), make_op("squeeze", {{"axes", {out_shape.size() - 1}}}),
std::prev(mdl->end())->inputs().at(index)); std::prev(mdl->end())->inputs().at(index));
mdl->replace_instruction(std::prev(mdl->end())->inputs().at(index), convert_ins); mdl->replace_instruction(std::prev(mdl->end())->inputs().at(index), convert_ins);
} }
...@@ -168,14 +168,14 @@ struct parse_if : op_parser<parse_if> ...@@ -168,14 +168,14 @@ struct parse_if : op_parser<parse_if>
auto last_then = then_lens.back(); auto last_then = then_lens.back();
auto last_else = else_lens.back(); auto last_else = else_lens.back();
// Find which dim to unsqueeze // Find which dim to squeeze
if((then_lens.size() < else_lens.size()) && (last_else == 1)) if((then_lens.size() < else_lens.size()) && (last_else == 1))
{ {
unsqueeze_last_op(then_mdl, i, else_lens); squeeze_last_op(else_mdl, i, else_lens);
} }
else if((then_lens.size() > else_lens.size()) && (last_then == 1)) else if((then_lens.size() > else_lens.size()) && (last_then == 1))
{ {
unsqueeze_last_op(else_mdl, i, then_lens); squeeze_last_op(then_mdl, i, then_lens);
} }
} }
else if(rank_delta > 1) else if(rank_delta > 1)
......
...@@ -2402,12 +2402,12 @@ TEST_CASE(if_else_trailing_one_shape_test) ...@@ -2402,12 +2402,12 @@ TEST_CASE(if_else_trailing_one_shape_test)
auto* then_mod = p.create_module("If_5_if"); auto* then_mod = p.create_module("If_5_if");
auto rt = then_mod->add_instruction(migraphx::make_op("add"), x, l1); auto rt = then_mod->add_instruction(migraphx::make_op("add"), x, l1);
then_mod->add_return({rt}); auto broad_rt = then_mod->add_instruction(migraphx::make_op("squeeze", {{"axes", {1}}}), rt);
then_mod->add_return({broad_rt});
auto* else_mod = p.create_module("If_5_else"); auto* else_mod = p.create_module("If_5_else");
auto re = else_mod->add_instruction(migraphx::make_op("mul"), y, l2); auto re = else_mod->add_instruction(migraphx::make_op("mul"), y, l2);
auto broad_re = else_mod->add_instruction(migraphx::make_op("unsqueeze", {{"axes", {1}}}), re); else_mod->add_return({re});
else_mod->add_return({broad_re});
auto ret = mm->add_instruction(migraphx::make_op("if"), {cond}, {then_mod, else_mod}); auto ret = mm->add_instruction(migraphx::make_op("if"), {cond}, {then_mod, else_mod});
auto r = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), ret); auto r = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), ret);
...@@ -2570,12 +2570,12 @@ TEST_CASE(if_then_trailing_one_shape_test) ...@@ -2570,12 +2570,12 @@ TEST_CASE(if_then_trailing_one_shape_test)
auto* then_mod = p.create_module("If_5_if"); auto* then_mod = p.create_module("If_5_if");
auto rt = then_mod->add_instruction(migraphx::make_op("add"), x, l1); auto rt = then_mod->add_instruction(migraphx::make_op("add"), x, l1);
auto broad_rt = then_mod->add_instruction(migraphx::make_op("unsqueeze", {{"axes", {1}}}), rt); then_mod->add_return({rt});
then_mod->add_return({broad_rt});
auto* else_mod = p.create_module("If_5_else"); auto* else_mod = p.create_module("If_5_else");
auto re = else_mod->add_instruction(migraphx::make_op("mul"), y, l2); auto re = else_mod->add_instruction(migraphx::make_op("mul"), y, l2);
else_mod->add_return({re}); auto broad_re = else_mod->add_instruction(migraphx::make_op("squeeze", {{"axes", {1}}}), re);
else_mod->add_return({broad_re});
auto ret = mm->add_instruction(migraphx::make_op("if"), {cond}, {then_mod, else_mod}); auto ret = mm->add_instruction(migraphx::make_op("if"), {cond}, {then_mod, else_mod});
auto r = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), ret); auto r = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), ret);
...@@ -2810,18 +2810,18 @@ TEST_CASE(if_then_else_multi_output_shapes_test) ...@@ -2810,18 +2810,18 @@ TEST_CASE(if_then_else_multi_output_shapes_test)
auto x = mm->add_parameter("x", s_trail); auto x = mm->add_parameter("x", s_trail);
auto y = mm->add_parameter("y", s); auto y = mm->add_parameter("y", s);
auto* then_mod = p.create_module("If_5_if"); auto* then_mod = p.create_module("If_5_if");
auto rt = then_mod->add_instruction(migraphx::make_op("add"), x, l1); auto rt = then_mod->add_instruction(migraphx::make_op("add"), x, l1);
auto rt2 = then_mod->add_instruction(migraphx::make_op("add"), x, x); auto rt2 = then_mod->add_instruction(migraphx::make_op("add"), x, x);
then_mod->add_return({rt, rt2}); auto unsqueeze = then_mod->add_instruction(migraphx::make_op("squeeze", {{"axes", {2}}}), rt);
auto unsqueeze2 = then_mod->add_instruction(migraphx::make_op("squeeze", {{"axes", {2}}}), rt2);
then_mod->add_return({unsqueeze, unsqueeze2});
auto* else_mod = p.create_module("If_5_else"); auto* else_mod = p.create_module("If_5_else");
auto re = else_mod->add_instruction(migraphx::make_op("mul"), y, l2); auto re = else_mod->add_instruction(migraphx::make_op("mul"), y, l2);
auto re2 = else_mod->add_instruction(migraphx::make_op("sub"), y, l2); auto re2 = else_mod->add_instruction(migraphx::make_op("sub"), y, l2);
auto unsqueeze = else_mod->add_instruction(migraphx::make_op("unsqueeze", {{"axes", {2}}}), re); else_mod->add_return({re, re2});
auto unsqueeze2 =
else_mod->add_instruction(migraphx::make_op("unsqueeze", {{"axes", {2}}}), re2);
else_mod->add_return({unsqueeze, unsqueeze2});
auto ret = mm->add_instruction(migraphx::make_op("if"), {cond}, {then_mod, else_mod}); auto ret = mm->add_instruction(migraphx::make_op("if"), {cond}, {then_mod, else_mod});
auto r = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), ret); auto r = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), ret);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment