"vscode:/vscode.git/clone" did not exist on "ac232be520d1047eda64d1481fe0a42ef0e4580f"
Commit 7a141a27 authored by Ted Themistokleous's avatar Ted Themistokleous Committed by Ted Themistokleous
Browse files

First attempt at adding proper reshape for then/else modules in parse_if

parent dd6540bd
...@@ -73,36 +73,38 @@ struct parse_if : op_parser<parse_if> ...@@ -73,36 +73,38 @@ struct parse_if : op_parser<parse_if>
{ {
MIGRAPHX_THROW("PARSE_IF: " + info.name + MIGRAPHX_THROW("PARSE_IF: " + info.name +
" then and else sub_grahps must have same output type! " + " then and else sub_grahps must have same output type! " +
std::to_string(then_out_shapes.at(0).type()) + " vs " + then_out_shapes.at(0).type_string() + " vs " +
std::to_string(else_out_shapes.at(0).type())); else_out_shapes.at(0).type_string());
} }
// If either argument returns non scalar, promote the scalar to a 1D tensor to meet the if(not then_out_shapes.at(0).scalar() && not else_out_shapes.at(0).scalar())
// shape requirements
// if and only if the first dimension matches
if(not then_out_shapes.at(0).scalar() || not else_out_shapes.at(0).scalar)
{ {
if(then_out_shapes.at(0).scalar()) // First dimension must agree
if(then_out_shapes.at(0).lens().at(0) != else_out_shapes.at(0).lens().at(0))
{ {
if(then_out_shapes.at(0).lens().at(0) != else_out_shapes.at(0).lens().at(0)) MIGRAPHX_THROW("PARSE_IF: " + then_out_shapes.at(0).type_string() + " & " +
{ else_out_shapes.at(0).type_string() +
MIGRAPHX_THROW("PARSE_IF: " + info.name + " are incompatible output shapes for then/cases");
"then out incompatible output shape with else");
}
migraphx::shape s(
then_out_shapes.at(0).type(), {then_out_shapes.at(0).lens().at(0), 1}, {1, 1});
then_mdl->add_outline(s);
} }
else if(else_out_shapes.at(0).scalar())
auto then_out_strides = then_out_shapes.at(0).strides();
auto else_out_strides = else_out_shapes.at(0).strides();
if(then_out_strides.size() > else_out_strides.size())
{
else_mdl->insert_instruction(
std::prev(else_mdl->end()),
migraphx::make_op(
"reshape", {{"dims", {{else_out_shapes.at(0).lens().at(0), 1}, {1, 1}}}}),
std::prev(else_mdl->end())->inputs().front());
}
else if(then_out_strides.size() < else_out_strides.size())
{ {
if(else_out_shapes.at(0).lens().at(0) != else_out_shapes.at(0).lens().at(0)) then_mdl->insert_instruction(
{ std::prev(then_mdl->end()),
MIGRAPHX_THROW("PARSE_IF: " + info.name + migraphx::make_op(
"else out incompatible output shape with then"); "reshape", {{"dims", {{then_out_shapes.at(0).lens().at(0), 1}, {1, 1}}}}),
} std::prev(then_mdl->end())->inputs().front());
migraphx::shape s(
else_out_shapes.at(0).type(), {else_out_shapes.at(0).lens().at(0), 1}, {1, 1});
else_mdl->add_outline(s);
} }
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment