"vscode:/vscode.git/clone" did not exist on "3977fa99321601f5dca70fba24b7c5ab2041dfbb"
Commit c7d194ea authored by Ted Themistokleous's avatar Ted Themistokleous
Browse files

Handle scalar conversion primarily empty scalars

Need to have this to force the output to be a compatible to both the else/then cases

Still a work in progress
parent 3e924a30
...@@ -69,7 +69,8 @@ struct parse_if : op_parser<parse_if> ...@@ -69,7 +69,8 @@ struct parse_if : op_parser<parse_if>
assert(then_out_shapes.size() == else_out_shapes.size()); assert(then_out_shapes.size() == else_out_shapes.size());
// Must have the same type for both if/else blocks by onnx spec // Must have the same type for both if/else blocks by onnx spec
if(then_out_shapes.at(0).type() != else_out_shapes.at(0).type()) if(then_out_shapes.at(0).type() != else_out_shapes.at(0).type() &&
(then_out_shapes.at(0).elements() > 0) && (else_out_shapes.at(0).elements() > 0))
{ {
MIGRAPHX_THROW("PARSE_IF: " + info.name + MIGRAPHX_THROW("PARSE_IF: " + info.name +
" then and else sub_grahps must have same output type! " + " then and else sub_grahps must have same output type! " +
...@@ -79,6 +80,29 @@ struct parse_if : op_parser<parse_if> ...@@ -79,6 +80,29 @@ struct parse_if : op_parser<parse_if>
if(not then_out_shapes.at(0).dynamic() && not else_out_shapes.at(0).dynamic()) if(not then_out_shapes.at(0).dynamic() && not else_out_shapes.at(0).dynamic())
{ {
if(then_out_shapes.at(0).scalar() && not else_out_shapes.at(0).scalar())
{
migraphx::shape s{else_out_shapes.at(0).type(),
else_out_shapes.at(0).lens(),
else_out_shapes.at(0).strides()};
auto reshape_ins = then_mdl->insert_instruction(
std::prev(then_mdl->end()),
migraphx::make_op("broadcast", {{"out_lens", s.lens()}}),
std::prev(then_mdl->end())->inputs().front());
then_mdl->replace_return({reshape_ins});
}
else if(not then_out_shapes.at(0).scalar() && else_out_shapes.at(0).scalar())
{
migraphx::shape s{then_out_shapes.at(0).type(),
then_out_shapes.at(0).lens(),
then_out_shapes.at(0).strides()};
auto reshape_ins = else_mdl->insert_instruction(
std::prev(else_mdl->end()),
migraphx::make_op("broadcast", {{"out_lens", s.lens()}}),
std::prev(else_mdl->end())->inputs().front());
else_mdl->replace_return({reshape_ins});
}
// First dimension must agree // First dimension must agree
if(then_out_shapes.at(0).lens().at(0) != else_out_shapes.at(0).lens().at(0)) if(then_out_shapes.at(0).lens().at(0) != else_out_shapes.at(0).lens().at(0))
{ {
...@@ -90,6 +114,7 @@ struct parse_if : op_parser<parse_if> ...@@ -90,6 +114,7 @@ struct parse_if : op_parser<parse_if>
auto then_out_strides = then_out_shapes.at(0).strides(); auto then_out_strides = then_out_shapes.at(0).strides();
auto else_out_strides = else_out_shapes.at(0).strides(); auto else_out_strides = else_out_shapes.at(0).strides();
// Generate compatible output types based on largest dimension with rank 1 tensor
if(then_out_strides.size() > else_out_strides.size()) if(then_out_strides.size() > else_out_strides.size())
{ {
auto reshape_ins = else_mdl->insert_instruction( auto reshape_ins = else_mdl->insert_instruction(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment