Commit 3e067a8a authored by Ted Themistokleous's avatar Ted Themistokleous
Browse files

Fix trailing 1 shape mismatch with unsqueeze instead of outline

Fixes cases for trailing one testcases
parent f24c65c3
...@@ -129,15 +129,18 @@ struct parse_if : op_parser<parse_if> ...@@ -129,15 +129,18 @@ struct parse_if : op_parser<parse_if>
throw_shapes(); throw_shapes();
} }
// Find which dim to pad // Find which dim to unsqueeze
if(then_shape.size() < else_shape.size()) if(then_shape.size() < else_shape.size())
{ {
auto last_else = *(--(else_shape.end())); auto last_else = *(--(else_shape.end()));
if(last_else <= 1) if(last_else <= 1)
{ {
auto convert_ins = then_mdl->add_outline(else_out_shapes.at(0)); auto convert_ins = then_mdl->add_instruction(
make_op("unsqueeze", {{"axes", {else_shape.size() - 1}}}),
--(--then_mdl->end()));
then_mdl->replace_return({convert_ins}); then_mdl->replace_return({convert_ins});
then_mdl->remove_instruction({--convert_ins});
} }
} }
else else
...@@ -146,8 +149,11 @@ struct parse_if : op_parser<parse_if> ...@@ -146,8 +149,11 @@ struct parse_if : op_parser<parse_if>
if(last_then <= 1) if(last_then <= 1)
{ {
auto convert_ins = else_mdl->add_outline(then_out_shapes.at(0)); auto convert_ins = else_mdl->add_instruction(
make_op("unsqueeze", {{"axes", {then_shape.size() - 1}}}),
--(--else_mdl->end()));
else_mdl->replace_return({convert_ins}); else_mdl->replace_return({convert_ins});
else_mdl->remove_instruction({--convert_ins});
} }
} }
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment