Commit 395e1170 authored by Ted Themistokleous's avatar Ted Themistokleous
Browse files

Addressing review comments from Umang

- rename then/else_shapes -> then/else_lens
- change assert for shape size to reusing throw()
- return lengths in handle_empty_branch
- use handle_empty_branch return value for lens update
- use std::prev() instead of predecriment operator for modifying nodes
parent 1120ed2f
...@@ -68,7 +68,15 @@ struct parse_if : op_parser<parse_if> ...@@ -68,7 +68,15 @@ struct parse_if : op_parser<parse_if>
auto then_out_shapes = then_mdl->get_output_shapes(); auto then_out_shapes = then_mdl->get_output_shapes();
auto else_out_shapes = else_mdl->get_output_shapes(); auto else_out_shapes = else_mdl->get_output_shapes();
assert(then_out_shapes.size() == else_out_shapes.size()); auto throw_shapes = [&]() {
MIGRAPHX_THROW("PARSE_IF: " + info.name +
" then and else sub_graphs must compatible shapes ");
};
if(then_out_shapes.size() != else_out_shapes.size())
{
throw_shapes();
}
// Must have the same type for both if/else blocks by onnx spec // Must have the same type for both if/else blocks by onnx spec
// Add exception for empty constant scalars // Add exception for empty constant scalars
...@@ -82,15 +90,11 @@ struct parse_if : op_parser<parse_if> ...@@ -82,15 +90,11 @@ struct parse_if : op_parser<parse_if>
if(not then_out_shapes.at(0).dynamic() && not else_out_shapes.at(0).dynamic()) if(not then_out_shapes.at(0).dynamic() && not else_out_shapes.at(0).dynamic())
{ {
auto then_shape = then_out_shapes.at(0).lens(); auto then_lens = then_out_shapes.at(0).lens();
auto else_shape = else_out_shapes.at(0).lens(); auto else_lens = else_out_shapes.at(0).lens();
auto throw_shapes = [&]() {
MIGRAPHX_THROW("PARSE_IF: " + info.name +
" then and else sub_graphs must compatible shapes ");
};
// Throw error if both branches have zero output shapes. Not possible for static inputs // Throw error if both branches have zero output shapes. Not possible for static inputs
if(then_shape.empty() && else_shape.empty()) if(then_lens.empty() && else_lens.empty())
{ {
throw_shapes(); throw_shapes();
} }
...@@ -98,30 +102,29 @@ struct parse_if : op_parser<parse_if> ...@@ -98,30 +102,29 @@ struct parse_if : op_parser<parse_if>
auto handle_empty_branch = [](module_ref& mdl, const shape& out_shape) { auto handle_empty_branch = [](module_ref& mdl, const shape& out_shape) {
auto outline_ins = mdl->add_outline(out_shape); auto outline_ins = mdl->add_outline(out_shape);
mdl->replace_return({outline_ins}); mdl->replace_return({outline_ins});
return out_shape.lens();
}; };
// Handle one empty branch by setting output identical to the other // Handle one empty branch by setting output identical to the other
// need to update the then_shape before we do further checks // need to update the then_shape before we do further checks
if(then_shape.empty()) if(then_lens.empty())
{ {
handle_empty_branch(then_mdl, else_out_shapes.at(0)); then_lens = handle_empty_branch(then_mdl, else_out_shapes.at(0));
then_shape = else_shape;
} }
else if(else_shape.empty()) else if(else_lens.empty())
{ {
handle_empty_branch(else_mdl, then_out_shapes.at(0)); else_lens = handle_empty_branch(else_mdl, then_out_shapes.at(0));
else_shape = then_shape;
} }
else else
{ {
// check equivilant length dims, and (x1,x2,.., xn, 1) == (x1,x2,..,xn) // check equivilant length dims, and (x1,x2,.., xn, 1) == (x1,x2,..,xn)
int dim_delta = abs((static_cast<int>(then_shape.size() - else_shape.size()))); int dim_delta = abs((static_cast<int>(then_lens.size() - else_lens.size())));
if(dim_delta <= 1) if(dim_delta <= 1)
{ {
// make sure dims are equivalent in static shapes // make sure dims are equivalent in static shapes
if(not equal(then_shape.begin(), then_shape.end(), else_shape.begin()) && if(not equal(then_lens.begin(), then_lens.end(), else_lens.begin()) &&
not equal(else_shape.begin(), else_shape.end(), then_shape.begin())) not equal(else_lens.begin(), else_lens.end(), then_lens.begin()))
{ {
throw_shapes(); throw_shapes();
} }
...@@ -130,22 +133,22 @@ struct parse_if : op_parser<parse_if> ...@@ -130,22 +133,22 @@ struct parse_if : op_parser<parse_if>
const std::vector<size_t>& out_shape) { const std::vector<size_t>& out_shape) {
auto convert_ins = mdl->add_instruction( auto convert_ins = mdl->add_instruction(
make_op("unsqueeze", {{"axes", {out_shape.size() - 1}}}), make_op("unsqueeze", {{"axes", {out_shape.size() - 1}}}),
--(--mdl->end())); std::prev(std::prev(mdl->end())));
mdl->replace_return({convert_ins}); mdl->replace_return({convert_ins});
mdl->remove_instruction({--convert_ins}); mdl->remove_instruction({std::prev(convert_ins)});
}; };
auto last_then = *(--(then_shape.end())); auto last_then = *(std::prev(then_lens.end()));
auto last_else = *(--(else_shape.end())); auto last_else = *(std::prev(else_lens.end()));
// Find which dim to unsqueeze // Find which dim to unsqueeze
if((then_shape.size() < else_shape.size()) && (last_else == 1)) if((then_lens.size() < else_lens.size()) && (last_else == 1))
{ {
unsqueeze_last_op(then_mdl, else_shape); unsqueeze_last_op(then_mdl, else_lens);
} }
else if((then_shape.size() > else_shape.size()) && (last_then == 1)) else if((then_lens.size() > else_lens.size()) && (last_then == 1))
{ {
unsqueeze_last_op(else_mdl, then_shape); unsqueeze_last_op(else_mdl, then_lens);
} }
} }
else else
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment