"src/include/Sequence.hpp" did not exist on "fd8de384170d6100a837b19e37139665c89e2054"
Commit 2b182551 authored by Ted Themistokleous's avatar Ted Themistokleous
Browse files

Be strict on types comming out of then/else branches

Had errors with the parsing of empty constants. Reverting this set of changes.
parent b3d7872a
......@@ -70,8 +70,7 @@ struct parse_if : op_parser<parse_if>
// Must have the same type for both if/else blocks by onnx spec
// Add exception for empty constant scalars
if(then_out_shapes.at(0).type() != else_out_shapes.at(0).type() &&
(then_out_shapes.at(0).elements() > 0) && (else_out_shapes.at(0).elements() > 0))
if(then_out_shapes.at(0).type() != else_out_shapes.at(0).type())
{
MIGRAPHX_THROW("PARSE_IF: " + info.name +
" then and else sub_grahps must have same output type! " +
......@@ -81,46 +80,20 @@ struct parse_if : op_parser<parse_if>
if(not then_out_shapes.at(0).dynamic() && not else_out_shapes.at(0).dynamic())
{
// unsqueeze up to a 1d vector for now
if(then_out_shapes.at(0).scalar() && not else_out_shapes.at(0).scalar())
{
auto convert_ins = std::prev(then_mdl->end());
if(then_out_shapes.at(0).type() != else_out_shapes.at(0).type() &&
then_out_shapes.at(0).elements() < 1)
{
convert_ins = then_mdl->insert_instruction(
convert_ins,
migraphx::make_op("convert",
{{"target_type", else_out_shapes.at(0).type()}}),
convert_ins->inputs().back());
// then_mdl->replace_return({convert_ins});
}
migraphx::shape s{else_out_shapes.at(0).type(),
else_out_shapes.at(0).lens(),
else_out_shapes.at(0).strides()};
auto ins = std::prev(then_mdl->end());
auto reshape_ins = then_mdl->insert_instruction(
convert_ins, migraphx::make_op("unsqueeze", {{"axes", {0, 1}}}), convert_ins);
ins, migraphx::make_op("unsqueeze", {{"axes", {0}}}), ins);
then_mdl->replace_return({reshape_ins});
}
else if(not then_out_shapes.at(0).scalar() && else_out_shapes.at(0).scalar())
{
auto convert_ins = std::prev(else_mdl->end());
if(then_out_shapes.at(0).type() != else_out_shapes.at(0).type() &&
else_out_shapes.at(0).elements() < 1)
{
convert_ins = then_mdl->insert_instruction(
std::prev(then_mdl->end()),
migraphx::make_op("convert",
{{"target_type", then_out_shapes.at(0).type()}}),
std::prev(then_mdl->end())->inputs().front());
then_mdl->replace_return({convert_ins});
}
migraphx::shape s{then_out_shapes.at(0).type(),
then_out_shapes.at(0).lens(),
then_out_shapes.at(0).strides()};
auto ins = std::prev(else_mdl->end());
auto reshape_ins = then_mdl->insert_instruction(
convert_ins, migraphx::make_op("unsqueeze", {{"axes", {0, 1}}}), convert_ins);
ins, migraphx::make_op("unsqueeze", {{"axes", {0}}}), ins);
else_mdl->replace_return({reshape_ins});
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment