Commit 395e1170 authored by Ted Themistokleous's avatar Ted Themistokleous
Browse files

Addressing review comments from Umang

- rename then/else_shapes -> then/else_lens
- change assert for shape size to reusing throw()
- return lengths in handle_empty_branch
- use handle_empty_branch return value for lens update
- use std::prev() instead of predecriment operator for modifying nodes
parent 1120ed2f
......@@ -68,7 +68,15 @@ struct parse_if : op_parser<parse_if>
auto then_out_shapes = then_mdl->get_output_shapes();
auto else_out_shapes = else_mdl->get_output_shapes();
assert(then_out_shapes.size() == else_out_shapes.size());
auto throw_shapes = [&]() {
MIGRAPHX_THROW("PARSE_IF: " + info.name +
" then and else sub_graphs must compatible shapes ");
};
if(then_out_shapes.size() != else_out_shapes.size())
{
throw_shapes();
}
// Must have the same type for both if/else blocks by onnx spec
// Add exception for empty constant scalars
......@@ -82,15 +90,11 @@ struct parse_if : op_parser<parse_if>
if(not then_out_shapes.at(0).dynamic() && not else_out_shapes.at(0).dynamic())
{
auto then_shape = then_out_shapes.at(0).lens();
auto else_shape = else_out_shapes.at(0).lens();
auto throw_shapes = [&]() {
MIGRAPHX_THROW("PARSE_IF: " + info.name +
" then and else sub_graphs must compatible shapes ");
};
auto then_lens = then_out_shapes.at(0).lens();
auto else_lens = else_out_shapes.at(0).lens();
// Throw error if both branches have zero output shapes. Not possible for static inputs
if(then_shape.empty() && else_shape.empty())
if(then_lens.empty() && else_lens.empty())
{
throw_shapes();
}
......@@ -98,30 +102,29 @@ struct parse_if : op_parser<parse_if>
auto handle_empty_branch = [](module_ref& mdl, const shape& out_shape) {
auto outline_ins = mdl->add_outline(out_shape);
mdl->replace_return({outline_ins});
return out_shape.lens();
};
// Handle one empty branch by setting output identical to the other
// need to update the then_shape before we do further checks
if(then_shape.empty())
if(then_lens.empty())
{
handle_empty_branch(then_mdl, else_out_shapes.at(0));
then_shape = else_shape;
then_lens = handle_empty_branch(then_mdl, else_out_shapes.at(0));
}
else if(else_shape.empty())
else if(else_lens.empty())
{
handle_empty_branch(else_mdl, then_out_shapes.at(0));
else_shape = then_shape;
else_lens = handle_empty_branch(else_mdl, then_out_shapes.at(0));
}
else
{
// check equivilant length dims, and (x1,x2,.., xn, 1) == (x1,x2,..,xn)
int dim_delta = abs((static_cast<int>(then_shape.size() - else_shape.size())));
int dim_delta = abs((static_cast<int>(then_lens.size() - else_lens.size())));
if(dim_delta <= 1)
{
// make sure dims are equivalent in static shapes
if(not equal(then_shape.begin(), then_shape.end(), else_shape.begin()) &&
not equal(else_shape.begin(), else_shape.end(), then_shape.begin()))
if(not equal(then_lens.begin(), then_lens.end(), else_lens.begin()) &&
not equal(else_lens.begin(), else_lens.end(), then_lens.begin()))
{
throw_shapes();
}
......@@ -130,22 +133,22 @@ struct parse_if : op_parser<parse_if>
const std::vector<size_t>& out_shape) {
auto convert_ins = mdl->add_instruction(
make_op("unsqueeze", {{"axes", {out_shape.size() - 1}}}),
--(--mdl->end()));
std::prev(std::prev(mdl->end())));
mdl->replace_return({convert_ins});
mdl->remove_instruction({--convert_ins});
mdl->remove_instruction({std::prev(convert_ins)});
};
auto last_then = *(--(then_shape.end()));
auto last_else = *(--(else_shape.end()));
auto last_then = *(std::prev(then_lens.end()));
auto last_else = *(std::prev(else_lens.end()));
// Find which dim to unsqueeze
if((then_shape.size() < else_shape.size()) && (last_else == 1))
if((then_lens.size() < else_lens.size()) && (last_else == 1))
{
unsqueeze_last_op(then_mdl, else_shape);
unsqueeze_last_op(then_mdl, else_lens);
}
else if((then_shape.size() > else_shape.size()) && (last_then == 1))
else if((then_lens.size() > else_lens.size()) && (last_then == 1))
{
unsqueeze_last_op(else_mdl, then_shape);
unsqueeze_last_op(else_mdl, then_lens);
}
}
else
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment