Commit cd9560ed authored by Ted Themistokleous's avatar Ted Themistokleous
Browse files

PR comments from Paul

-make all_but_last_dims_equal func instead of lambda
-rename dim_delta -> rank_delta
-make unsqueeze_last_op func instead of lambda
-Handle multi output cases of changing output instructions
-capture shape at each output shape at start of loop via .at() operator
-replace instances of && with and
parent 48a85620
...@@ -35,6 +35,28 @@ namespace migraphx { ...@@ -35,6 +35,28 @@ namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
namespace onnx { namespace onnx {
inline bool all_but_last_dims_equal(const std::vector<size_t>& lens_a,
const std::vector<size_t>& lens_b)
{
if(lens_a.size() <= lens_b.size())
{
return std::equal(lens_a.begin(), lens_a.end(), lens_b.begin());
}
else
{
return std::equal(lens_b.begin(), lens_b.end(), lens_a.begin());
}
};
void unsqueeze_last_op(module_ref mdl, int index, const std::vector<size_t>& out_shape)
{
auto convert_ins =
mdl->insert_instruction(std::prev(mdl->end()),
make_op("unsqueeze", {{"axes", {out_shape.size() - 1}}}),
std::prev(mdl->end())->inputs().at(index));
mdl->replace_instruction(std::prev(mdl->end())->inputs().at(index), convert_ins);
}
struct parse_if : op_parser<parse_if> struct parse_if : op_parser<parse_if>
{ {
std::vector<op_desc> operators() const { return {{"If"}}; } std::vector<op_desc> operators() const { return {{"If"}}; }
...@@ -81,23 +103,26 @@ struct parse_if : op_parser<parse_if> ...@@ -81,23 +103,26 @@ struct parse_if : op_parser<parse_if>
// Add checks for each output shape // Add checks for each output shape
for(int i = 0; i < then_out_shapes.size(); i++) for(int i = 0; i < then_out_shapes.size(); i++)
{ {
const auto& then_out_shape = then_out_shapes.at(i);
const auto& else_out_shape = else_out_shapes.at(i);
// Must have the same type for both if/else blocks by onnx spec // Must have the same type for both if/else blocks by onnx spec
if(then_out_shapes.at(i).type() != else_out_shapes.at(i).type()) if(then_out_shape.type() != else_out_shape.type())
{ {
MIGRAPHX_THROW("PARSE_IF: " + info.name + MIGRAPHX_THROW("PARSE_IF: " + info.name +
" then and else sub_grahps must have same output type! " + " then and else sub_grahps must have same output type! " +
then_out_shapes.at(i).type_string() + " vs " + then_out_shape.type_string() + " vs " +
else_out_shapes.at(i).type_string()); else_out_shape.type_string());
} }
if(not then_out_shapes.at(i).dynamic() && not else_out_shapes.at(i).dynamic()) if(not then_out_shape.dynamic() and not else_out_shape.dynamic())
{ {
auto then_lens = then_out_shapes.at(i).lens(); auto then_lens = then_out_shape.lens();
auto else_lens = else_out_shapes.at(i).lens(); auto else_lens = else_out_shape.lens();
// Throw error if both branches have zero output shapes. Not possible for static // Throw error if both branches have zero output shapes. Not possible for static
// inputs // inputs
if(then_lens.empty() && else_lens.empty()) if(then_lens.empty() and else_lens.empty())
{ {
throw_shapes(); throw_shapes();
} }
...@@ -112,29 +137,17 @@ struct parse_if : op_parser<parse_if> ...@@ -112,29 +137,17 @@ struct parse_if : op_parser<parse_if>
// need to update the then_shape before we do further checks // need to update the then_shape before we do further checks
if(then_lens.empty()) if(then_lens.empty())
{ {
then_lens = handle_empty_branch(then_mdl, else_out_shapes.at(i)); then_lens = handle_empty_branch(then_mdl, else_out_shape);
} }
else if(else_lens.empty()) else if(else_lens.empty())
{ {
else_lens = handle_empty_branch(else_mdl, then_out_shapes.at(i)); else_lens = handle_empty_branch(else_mdl, then_out_shape);
} }
auto all_but_last_dims_equal = [](const std::vector<size_t>& lens_a,
const std::vector<size_t>& lens_b) {
if(lens_a.size() <= lens_b.size())
{
return std::equal(lens_a.begin(), lens_a.end(), lens_b.begin());
}
else
{
return std::equal(lens_b.begin(), lens_b.end(), lens_a.begin());
}
};
// check equivalent length dims, and (x1,x2,.., xn, 1) == (x1,x2,..,xn) // check equivalent length dims, and (x1,x2,.., xn, 1) == (x1,x2,..,xn)
int dim_delta = abs((static_cast<int>(then_lens.size() - else_lens.size()))); int rank_delta = abs((static_cast<int>(then_lens.size() - else_lens.size())));
if(dim_delta == 1) if(rank_delta == 1)
{ {
// make sure dims are equivalent in static shapes // make sure dims are equivalent in static shapes
if(not all_but_last_dims_equal(then_lens, else_lens)) if(not all_but_last_dims_equal(then_lens, else_lens))
...@@ -145,26 +158,17 @@ struct parse_if : op_parser<parse_if> ...@@ -145,26 +158,17 @@ struct parse_if : op_parser<parse_if>
auto last_then = then_lens.back(); auto last_then = then_lens.back();
auto last_else = else_lens.back(); auto last_else = else_lens.back();
auto unsqueeze_last_op = [](module_ref& mdl,
const std::vector<size_t>& out_shape) {
auto convert_ins = mdl->add_instruction(
make_op("unsqueeze", {{"axes", {out_shape.size() - 1}}}),
std::prev(mdl->end())->inputs().front());
mdl->replace_return({convert_ins});
mdl->remove_instruction({std::prev(convert_ins)});
};
// Find which dim to unsqueeze // Find which dim to unsqueeze
if((then_lens.size() < else_lens.size()) && (last_else == 1)) if((then_lens.size() < else_lens.size()) && (last_else == 1))
{ {
unsqueeze_last_op(then_mdl, else_lens); unsqueeze_last_op(then_mdl, i, else_lens);
} }
else if((then_lens.size() > else_lens.size()) && (last_then == 1)) else if((then_lens.size() > else_lens.size()) && (last_then == 1))
{ {
unsqueeze_last_op(else_mdl, then_lens); unsqueeze_last_op(else_mdl, i, then_lens);
} }
} }
else if(dim_delta > 1) else if(rank_delta > 1)
{ {
throw_shapes(); throw_shapes();
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment