Commit 8b08a86f authored by Ted Themistokleous's avatar Ted Themistokleous Committed by Ted Themistokleous
Browse files

Attempt to fix scalar type and shape interpretation

Comming back to this once I've fixed parse_constant, looks like unallocated empty literals break this right now.
parent 07e1755a
......@@ -69,6 +69,7 @@ struct parse_if : op_parser<parse_if>
assert(then_out_shapes.size() == else_out_shapes.size());
// Must have the same type for both if/else blocks by onnx spec
// Add exception for empty constant scalars
if(then_out_shapes.at(0).type() != else_out_shapes.at(0).type() &&
(then_out_shapes.at(0).elements() > 0) && (else_out_shapes.at(0).elements() > 0))
{
......@@ -82,24 +83,45 @@ struct parse_if : op_parser<parse_if>
{
if(then_out_shapes.at(0).scalar() && not else_out_shapes.at(0).scalar())
{
auto convert_ins = std::prev(then_mdl->end());
if(then_out_shapes.at(0).type() != else_out_shapes.at(0).type() &&
then_out_shapes.at(0).elements() < 1)
{
convert_ins = then_mdl->insert_instruction(
convert_ins,
migraphx::make_op("convert",
{{"target_type", else_out_shapes.at(0).type()}}),
convert_ins->inputs().back());
// then_mdl->replace_return({convert_ins});
}
migraphx::shape s{else_out_shapes.at(0).type(),
else_out_shapes.at(0).lens(),
else_out_shapes.at(0).strides()};
auto reshape_ins = then_mdl->insert_instruction(
std::prev(then_mdl->end()),
migraphx::make_op("broadcast", {{"out_lens", s.lens()}}),
std::prev(then_mdl->end())->inputs().front());
convert_ins, migraphx::make_op("unsqueeze", {{"axes", {0, 1}}}), convert_ins);
then_mdl->replace_return({reshape_ins});
}
else if(not then_out_shapes.at(0).scalar() && else_out_shapes.at(0).scalar())
{
auto convert_ins = std::prev(else_mdl->end());
if(then_out_shapes.at(0).type() != else_out_shapes.at(0).type() &&
else_out_shapes.at(0).elements() < 1)
{
convert_ins = then_mdl->insert_instruction(
std::prev(then_mdl->end()),
migraphx::make_op("convert",
{{"target_type", then_out_shapes.at(0).type()}}),
std::prev(then_mdl->end())->inputs().front());
then_mdl->replace_return({convert_ins});
}
migraphx::shape s{then_out_shapes.at(0).type(),
then_out_shapes.at(0).lens(),
then_out_shapes.at(0).strides()};
auto reshape_ins = else_mdl->insert_instruction(
std::prev(else_mdl->end()),
migraphx::make_op("broadcast", {{"out_lens", s.lens()}}),
std::prev(else_mdl->end())->inputs().front());
auto reshape_ins = then_mdl->insert_instruction(
convert_ins, migraphx::make_op("unsqueeze", {{"axes", {0, 1}}}), convert_ins);
else_mdl->replace_return({reshape_ins});
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment