Commit 98bf0b8b authored by Ted Themistokleous's avatar Ted Themistokleous
Browse files

Force scalars to be the same shape when comming out of an if then/else block

parent 7bc8dd35
......@@ -77,26 +77,32 @@ struct parse_if : op_parser<parse_if>
std::to_string(else_out_shapes.at(0).type()));
}
// Need to check static shapes result
if(not then_out_shapes.at(0).dynamic() && not else_out_shapes.at(0).dynamic())
// If either argument returns non scalar, promote the scalar to a 1D tensor to meet the
// shape requirements
// if and only if the first dimension matches
if(not then_out_shapes.at(0).scalar() || not else_out_shapes.at(0).scalar)
{
if(then_out_shapes.at(0).scalar())
{
if(then_out_shapes.at(0).lens().at(0) != else_out_shapes.at(0).lens().at(0) ||
then_out_shapes.at(0).strides().at(0) != 1)
if(then_out_shapes.at(0).lens().at(0) != else_out_shapes.at(0).lens().at(0))
{
MIGRAPHX_THROW("PARSE_IF: " + info.name +
"then out incompatible output shape with else");
}
migraphx::shape s(
then_out_shapes.at(0).type(), {then_out_shapes.at(0).lens().at(0), 1}, {1, 1});
then_mdl->add_outline(s);
}
else if(else_out_shapes.at(0).scalar())
{
if(else_out_shapes.at(0).lens().at(0) != else_out_shapes.at(0).lens().at(0) ||
else_out_shapes.at(0).strides().at(0) == 1)
if(else_out_shapes.at(0).lens().at(0) != else_out_shapes.at(0).lens().at(0))
{
MIGRAPHX_THROW("PARSE_IF: " + info.name +
"else out incompatible output shape with then");
}
migraphx::shape s(
else_out_shapes.at(0).type(), {else_out_shapes.at(0).lens().at(0), 1}, {1, 1});
else_mdl->add_outline(s);
}
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment