Commit b3d7872a authored by Ted Themistokleous's avatar Ted Themistokleous
Browse files

Attempt to fix scalar type and shape interpretation

Comming back to this once I've fixed parse_constant, looks like unallocated empty literals break this right now.
parent c7d194ea
...@@ -69,6 +69,7 @@ struct parse_if : op_parser<parse_if> ...@@ -69,6 +69,7 @@ struct parse_if : op_parser<parse_if>
assert(then_out_shapes.size() == else_out_shapes.size()); assert(then_out_shapes.size() == else_out_shapes.size());
// Must have the same type for both if/else blocks by onnx spec // Must have the same type for both if/else blocks by onnx spec
// Add exception for empty constant scalars
if(then_out_shapes.at(0).type() != else_out_shapes.at(0).type() && if(then_out_shapes.at(0).type() != else_out_shapes.at(0).type() &&
(then_out_shapes.at(0).elements() > 0) && (else_out_shapes.at(0).elements() > 0)) (then_out_shapes.at(0).elements() > 0) && (else_out_shapes.at(0).elements() > 0))
{ {
...@@ -82,24 +83,45 @@ struct parse_if : op_parser<parse_if> ...@@ -82,24 +83,45 @@ struct parse_if : op_parser<parse_if>
{ {
if(then_out_shapes.at(0).scalar() && not else_out_shapes.at(0).scalar()) if(then_out_shapes.at(0).scalar() && not else_out_shapes.at(0).scalar())
{ {
auto convert_ins = std::prev(then_mdl->end());
if(then_out_shapes.at(0).type() != else_out_shapes.at(0).type() &&
then_out_shapes.at(0).elements() < 1)
{
convert_ins = then_mdl->insert_instruction(
convert_ins,
migraphx::make_op("convert",
{{"target_type", else_out_shapes.at(0).type()}}),
convert_ins->inputs().back());
// then_mdl->replace_return({convert_ins});
}
migraphx::shape s{else_out_shapes.at(0).type(), migraphx::shape s{else_out_shapes.at(0).type(),
else_out_shapes.at(0).lens(), else_out_shapes.at(0).lens(),
else_out_shapes.at(0).strides()}; else_out_shapes.at(0).strides()};
auto reshape_ins = then_mdl->insert_instruction( auto reshape_ins = then_mdl->insert_instruction(
std::prev(then_mdl->end()), convert_ins, migraphx::make_op("unsqueeze", {{"axes", {0, 1}}}), convert_ins);
migraphx::make_op("broadcast", {{"out_lens", s.lens()}}),
std::prev(then_mdl->end())->inputs().front());
then_mdl->replace_return({reshape_ins}); then_mdl->replace_return({reshape_ins});
} }
else if(not then_out_shapes.at(0).scalar() && else_out_shapes.at(0).scalar()) else if(not then_out_shapes.at(0).scalar() && else_out_shapes.at(0).scalar())
{ {
auto convert_ins = std::prev(else_mdl->end());
if(then_out_shapes.at(0).type() != else_out_shapes.at(0).type() &&
else_out_shapes.at(0).elements() < 1)
{
convert_ins = then_mdl->insert_instruction(
std::prev(then_mdl->end()),
migraphx::make_op("convert",
{{"target_type", then_out_shapes.at(0).type()}}),
std::prev(then_mdl->end())->inputs().front());
then_mdl->replace_return({convert_ins});
}
migraphx::shape s{then_out_shapes.at(0).type(), migraphx::shape s{then_out_shapes.at(0).type(),
then_out_shapes.at(0).lens(), then_out_shapes.at(0).lens(),
then_out_shapes.at(0).strides()}; then_out_shapes.at(0).strides()};
auto reshape_ins = else_mdl->insert_instruction( auto reshape_ins = then_mdl->insert_instruction(
std::prev(else_mdl->end()), convert_ins, migraphx::make_op("unsqueeze", {{"axes", {0, 1}}}), convert_ins);
migraphx::make_op("broadcast", {{"out_lens", s.lens()}}),
std::prev(else_mdl->end())->inputs().front());
else_mdl->replace_return({reshape_ins}); else_mdl->replace_return({reshape_ins});
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment