"experiments/pyexps/farmem.py" did not exist on "cc179404dde0c2787c6e108a12745d7c8f1a1dcc"
Commit d799e44e authored by Ted Themistokleous's avatar Ted Themistokleous Committed by Ted Themistokleous
Browse files

Force scalars to be the same shape when comming out of an if then/else block

parent db6e3a5c
......@@ -77,26 +77,32 @@ struct parse_if : op_parser<parse_if>
std::to_string(else_out_shapes.at(0).type()));
}
// Need to check static shapes result
if(not then_out_shapes.at(0).dynamic() && not else_out_shapes.at(0).dynamic())
// If either argument returns non scalar, promote the scalar to a 1D tensor to meet the
// shape requirements
// if and only if the first dimension matches
if(not then_out_shapes.at(0).scalar() || not else_out_shapes.at(0).scalar)
{
if(then_out_shapes.at(0).scalar())
{
if(then_out_shapes.at(0).lens().at(0) != else_out_shapes.at(0).lens().at(0) ||
then_out_shapes.at(0).strides().at(0) != 1)
if(then_out_shapes.at(0).lens().at(0) != else_out_shapes.at(0).lens().at(0))
{
MIGRAPHX_THROW("PARSE_IF: " + info.name +
"then out incompatible output shape with else");
}
migraphx::shape s(
then_out_shapes.at(0).type(), {then_out_shapes.at(0).lens().at(0), 1}, {1, 1});
then_mdl->add_outline(s);
}
else if(else_out_shapes.at(0).scalar())
{
if(else_out_shapes.at(0).lens().at(0) != else_out_shapes.at(0).lens().at(0) ||
else_out_shapes.at(0).strides().at(0) == 1)
if(else_out_shapes.at(0).lens().at(0) != else_out_shapes.at(0).lens().at(0))
{
MIGRAPHX_THROW("PARSE_IF: " + info.name +
"else out incompatible output shape with then");
}
migraphx::shape s(
else_out_shapes.at(0).type(), {else_out_shapes.at(0).lens().at(0), 1}, {1, 1});
else_mdl->add_outline(s);
}
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment