Commit 7bc8dd35 authored by Ted Themistokleous's avatar Ted Themistokleous
Browse files

Fix parse_if cases for output result of if

The onnx spec mentions that the output shape of the resulting then/else branches must share the same type, but not the same shape. The only requirement is that the first dimension is compatible should one of the inputs have rank of one.

Without this we prematurely assert when an if is requred on the following case

int64, {1234, 1} (this is a 1 rank tensor)
int64, {1234}  (this result is scalar)
parent 6040f741
...@@ -65,13 +65,39 @@ struct parse_if : op_parser<parse_if> ...@@ -65,13 +65,39 @@ struct parse_if : op_parser<parse_if>
auto then_out_shapes = then_mdl->get_output_shapes(); auto then_out_shapes = then_mdl->get_output_shapes();
auto else_out_shapes = else_mdl->get_output_shapes(); auto else_out_shapes = else_mdl->get_output_shapes();
if(not std::equal(then_out_shapes.begin(),
then_out_shapes.end(), assert(then_out_shapes.size() == else_out_shapes.size());
else_out_shapes.begin(),
else_out_shapes.end())) // Must have the same type for both if/else blocks by onnx spec
if(then_out_shapes.at(0).type() != else_out_shapes.at(0).type())
{ {
MIGRAPHX_THROW("PARSE_IF: " + info.name + MIGRAPHX_THROW("PARSE_IF: " + info.name +
" then and else sub_grahps must have same output shapes!"); " then and else sub_grahps must have same output type! " +
std::to_string(then_out_shapes.at(0).type()) + " vs " +
std::to_string(else_out_shapes.at(0).type()));
}
// Need to check static shapes result
if(not then_out_shapes.at(0).dynamic() && not else_out_shapes.at(0).dynamic())
{
if(then_out_shapes.at(0).scalar())
{
if(then_out_shapes.at(0).lens().at(0) != else_out_shapes.at(0).lens().at(0) ||
then_out_shapes.at(0).strides().at(0) != 1)
{
MIGRAPHX_THROW("PARSE_IF: " + info.name +
"then out incompatible output shape with else");
}
}
else if(else_out_shapes.at(0).scalar())
{
if(else_out_shapes.at(0).lens().at(0) != else_out_shapes.at(0).lens().at(0) ||
else_out_shapes.at(0).strides().at(0) == 1)
{
MIGRAPHX_THROW("PARSE_IF: " + info.name +
"else out incompatible output shape with then");
}
}
} }
auto if_ret = info.add_instruction(make_op("if"), args, {then_mdl, else_mdl}); auto if_ret = info.add_instruction(make_op("if"), args, {then_mdl, else_mdl});
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment