"docs/en_US/RemoteMachineMode.md" did not exist on "bbf4760ca3ae7bab222f453f446e2b152a47fbbe"
Commit b3d7872a authored by Ted Themistokleous's avatar Ted Themistokleous
Browse files

Attempt to fix scalar type and shape interpretation

Comming back to this once I've fixed parse_constant, looks like unallocated empty literals break this right now.
parent c7d194ea
...@@ -69,6 +69,7 @@ struct parse_if : op_parser<parse_if> ...@@ -69,6 +69,7 @@ struct parse_if : op_parser<parse_if>
assert(then_out_shapes.size() == else_out_shapes.size()); assert(then_out_shapes.size() == else_out_shapes.size());
// Must have the same type for both if/else blocks by onnx spec // Must have the same type for both if/else blocks by onnx spec
// Add exception for empty constant scalars
if(then_out_shapes.at(0).type() != else_out_shapes.at(0).type() && if(then_out_shapes.at(0).type() != else_out_shapes.at(0).type() &&
(then_out_shapes.at(0).elements() > 0) && (else_out_shapes.at(0).elements() > 0)) (then_out_shapes.at(0).elements() > 0) && (else_out_shapes.at(0).elements() > 0))
{ {
...@@ -82,24 +83,45 @@ struct parse_if : op_parser<parse_if> ...@@ -82,24 +83,45 @@ struct parse_if : op_parser<parse_if>
{ {
if(then_out_shapes.at(0).scalar() && not else_out_shapes.at(0).scalar()) if(then_out_shapes.at(0).scalar() && not else_out_shapes.at(0).scalar())
{ {
auto convert_ins = std::prev(then_mdl->end());
if(then_out_shapes.at(0).type() != else_out_shapes.at(0).type() &&
then_out_shapes.at(0).elements() < 1)
{
convert_ins = then_mdl->insert_instruction(
convert_ins,
migraphx::make_op("convert",
{{"target_type", else_out_shapes.at(0).type()}}),
convert_ins->inputs().back());
// then_mdl->replace_return({convert_ins});
}
migraphx::shape s{else_out_shapes.at(0).type(), migraphx::shape s{else_out_shapes.at(0).type(),
else_out_shapes.at(0).lens(), else_out_shapes.at(0).lens(),
else_out_shapes.at(0).strides()}; else_out_shapes.at(0).strides()};
auto reshape_ins = then_mdl->insert_instruction( auto reshape_ins = then_mdl->insert_instruction(
std::prev(then_mdl->end()), convert_ins, migraphx::make_op("unsqueeze", {{"axes", {0, 1}}}), convert_ins);
migraphx::make_op("broadcast", {{"out_lens", s.lens()}}),
std::prev(then_mdl->end())->inputs().front());
then_mdl->replace_return({reshape_ins}); then_mdl->replace_return({reshape_ins});
} }
else if(not then_out_shapes.at(0).scalar() && else_out_shapes.at(0).scalar()) else if(not then_out_shapes.at(0).scalar() && else_out_shapes.at(0).scalar())
{ {
auto convert_ins = std::prev(else_mdl->end());
if(then_out_shapes.at(0).type() != else_out_shapes.at(0).type() &&
else_out_shapes.at(0).elements() < 1)
{
convert_ins = then_mdl->insert_instruction(
std::prev(then_mdl->end()),
migraphx::make_op("convert",
{{"target_type", then_out_shapes.at(0).type()}}),
std::prev(then_mdl->end())->inputs().front());
then_mdl->replace_return({convert_ins});
}
migraphx::shape s{then_out_shapes.at(0).type(), migraphx::shape s{then_out_shapes.at(0).type(),
then_out_shapes.at(0).lens(), then_out_shapes.at(0).lens(),
then_out_shapes.at(0).strides()}; then_out_shapes.at(0).strides()};
auto reshape_ins = else_mdl->insert_instruction( auto reshape_ins = then_mdl->insert_instruction(
std::prev(else_mdl->end()), convert_ins, migraphx::make_op("unsqueeze", {{"axes", {0, 1}}}), convert_ins);
migraphx::make_op("broadcast", {{"out_lens", s.lens()}}),
std::prev(else_mdl->end())->inputs().front());
else_mdl->replace_return({reshape_ins}); else_mdl->replace_return({reshape_ins});
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment