Commit 2b182551 authored by Ted Themistokleous's avatar Ted Themistokleous
Browse files

Be strict on types comming out of then/else branches

Had errors with the parsing of empty constants. Reverting this set of changes.
parent b3d7872a
...@@ -70,8 +70,7 @@ struct parse_if : op_parser<parse_if> ...@@ -70,8 +70,7 @@ struct parse_if : op_parser<parse_if>
// Must have the same type for both if/else blocks by onnx spec // Must have the same type for both if/else blocks by onnx spec
// Add exception for empty constant scalars // Add exception for empty constant scalars
if(then_out_shapes.at(0).type() != else_out_shapes.at(0).type() && if(then_out_shapes.at(0).type() != else_out_shapes.at(0).type())
(then_out_shapes.at(0).elements() > 0) && (else_out_shapes.at(0).elements() > 0))
{ {
MIGRAPHX_THROW("PARSE_IF: " + info.name + MIGRAPHX_THROW("PARSE_IF: " + info.name +
" then and else sub_grahps must have same output type! " + " then and else sub_grahps must have same output type! " +
...@@ -81,46 +80,20 @@ struct parse_if : op_parser<parse_if> ...@@ -81,46 +80,20 @@ struct parse_if : op_parser<parse_if>
if(not then_out_shapes.at(0).dynamic() && not else_out_shapes.at(0).dynamic()) if(not then_out_shapes.at(0).dynamic() && not else_out_shapes.at(0).dynamic())
{ {
// unsqueeze up to a 1d vector for now
if(then_out_shapes.at(0).scalar() && not else_out_shapes.at(0).scalar()) if(then_out_shapes.at(0).scalar() && not else_out_shapes.at(0).scalar())
{ {
auto convert_ins = std::prev(then_mdl->end()); auto ins = std::prev(then_mdl->end());
if(then_out_shapes.at(0).type() != else_out_shapes.at(0).type() &&
then_out_shapes.at(0).elements() < 1)
{
convert_ins = then_mdl->insert_instruction(
convert_ins,
migraphx::make_op("convert",
{{"target_type", else_out_shapes.at(0).type()}}),
convert_ins->inputs().back());
// then_mdl->replace_return({convert_ins});
}
migraphx::shape s{else_out_shapes.at(0).type(),
else_out_shapes.at(0).lens(),
else_out_shapes.at(0).strides()};
auto reshape_ins = then_mdl->insert_instruction( auto reshape_ins = then_mdl->insert_instruction(
convert_ins, migraphx::make_op("unsqueeze", {{"axes", {0, 1}}}), convert_ins); ins, migraphx::make_op("unsqueeze", {{"axes", {0}}}), ins);
then_mdl->replace_return({reshape_ins}); then_mdl->replace_return({reshape_ins});
} }
else if(not then_out_shapes.at(0).scalar() && else_out_shapes.at(0).scalar()) else if(not then_out_shapes.at(0).scalar() && else_out_shapes.at(0).scalar())
{ {
auto convert_ins = std::prev(else_mdl->end()); auto ins = std::prev(else_mdl->end());
if(then_out_shapes.at(0).type() != else_out_shapes.at(0).type() &&
else_out_shapes.at(0).elements() < 1)
{
convert_ins = then_mdl->insert_instruction(
std::prev(then_mdl->end()),
migraphx::make_op("convert",
{{"target_type", then_out_shapes.at(0).type()}}),
std::prev(then_mdl->end())->inputs().front());
then_mdl->replace_return({convert_ins});
}
migraphx::shape s{then_out_shapes.at(0).type(),
then_out_shapes.at(0).lens(),
then_out_shapes.at(0).strides()};
auto reshape_ins = then_mdl->insert_instruction( auto reshape_ins = then_mdl->insert_instruction(
convert_ins, migraphx::make_op("unsqueeze", {{"axes", {0, 1}}}), convert_ins); ins, migraphx::make_op("unsqueeze", {{"axes", {0}}}), ins);
else_mdl->replace_return({reshape_ins}); else_mdl->replace_return({reshape_ins});
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment