Commit c7d194ea authored by Ted Themistokleous's avatar Ted Themistokleous
Browse files

Handle scalar conversion primarily empty scalars

Need to have this to force the output to be a compatible to both the else/then cases

Still a work in progress
parent 3e924a30
...@@ -69,7 +69,8 @@ struct parse_if : op_parser<parse_if> ...@@ -69,7 +69,8 @@ struct parse_if : op_parser<parse_if>
assert(then_out_shapes.size() == else_out_shapes.size()); assert(then_out_shapes.size() == else_out_shapes.size());
// Must have the same type for both if/else blocks by onnx spec // Must have the same type for both if/else blocks by onnx spec
if(then_out_shapes.at(0).type() != else_out_shapes.at(0).type()) if(then_out_shapes.at(0).type() != else_out_shapes.at(0).type() &&
(then_out_shapes.at(0).elements() > 0) && (else_out_shapes.at(0).elements() > 0))
{ {
MIGRAPHX_THROW("PARSE_IF: " + info.name + MIGRAPHX_THROW("PARSE_IF: " + info.name +
" then and else sub_grahps must have same output type! " + " then and else sub_grahps must have same output type! " +
...@@ -79,6 +80,29 @@ struct parse_if : op_parser<parse_if> ...@@ -79,6 +80,29 @@ struct parse_if : op_parser<parse_if>
if(not then_out_shapes.at(0).dynamic() && not else_out_shapes.at(0).dynamic()) if(not then_out_shapes.at(0).dynamic() && not else_out_shapes.at(0).dynamic())
{ {
if(then_out_shapes.at(0).scalar() && not else_out_shapes.at(0).scalar())
{
migraphx::shape s{else_out_shapes.at(0).type(),
else_out_shapes.at(0).lens(),
else_out_shapes.at(0).strides()};
auto reshape_ins = then_mdl->insert_instruction(
std::prev(then_mdl->end()),
migraphx::make_op("broadcast", {{"out_lens", s.lens()}}),
std::prev(then_mdl->end())->inputs().front());
then_mdl->replace_return({reshape_ins});
}
else if(not then_out_shapes.at(0).scalar() && else_out_shapes.at(0).scalar())
{
migraphx::shape s{then_out_shapes.at(0).type(),
then_out_shapes.at(0).lens(),
then_out_shapes.at(0).strides()};
auto reshape_ins = else_mdl->insert_instruction(
std::prev(else_mdl->end()),
migraphx::make_op("broadcast", {{"out_lens", s.lens()}}),
std::prev(else_mdl->end())->inputs().front());
else_mdl->replace_return({reshape_ins});
}
// First dimension must agree // First dimension must agree
if(then_out_shapes.at(0).lens().at(0) != else_out_shapes.at(0).lens().at(0)) if(then_out_shapes.at(0).lens().at(0) != else_out_shapes.at(0).lens().at(0))
{ {
...@@ -90,6 +114,7 @@ struct parse_if : op_parser<parse_if> ...@@ -90,6 +114,7 @@ struct parse_if : op_parser<parse_if>
auto then_out_strides = then_out_shapes.at(0).strides(); auto then_out_strides = then_out_shapes.at(0).strides();
auto else_out_strides = else_out_shapes.at(0).strides(); auto else_out_strides = else_out_shapes.at(0).strides();
// Generate compatible output types based on largest dimension with rank 1 tensor
if(then_out_strides.size() > else_out_strides.size()) if(then_out_strides.size() > else_out_strides.size())
{ {
auto reshape_ins = else_mdl->insert_instruction( auto reshape_ins = else_mdl->insert_instruction(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment