Commit 34aeda23 authored by Ted Themistokleous's avatar Ted Themistokleous
Browse files

Fix tidy issue and clean up logic

parent 372206b3
...@@ -115,22 +115,21 @@ struct parse_if : op_parser<parse_if> ...@@ -115,22 +115,21 @@ struct parse_if : op_parser<parse_if>
{ {
else_lens = handle_empty_branch(else_mdl, then_out_shapes.at(0)); else_lens = handle_empty_branch(else_mdl, then_out_shapes.at(0));
} }
else
{
// check equivilant length dims, and (x1,x2,.., xn, 1) == (x1,x2,..,xn) // check equivilant length dims, and (x1,x2,.., xn, 1) == (x1,x2,..,xn)
int dim_delta = abs((static_cast<int>(then_lens.size() - else_lens.size()))); int dim_delta = abs((static_cast<int>(then_lens.size() - else_lens.size())));
if(dim_delta <= 1) if(dim_delta == 1)
{ {
auto all_but_last_dims_equal = [](std::vector<size_t>& lens_A, auto all_but_last_dims_equal = [](std::vector<size_t>& lens_a,
std::vector<size_t>& lens_B) { std::vector<size_t>& lens_b) {
if(lens_A.size() <= lens_B.size()) if(lens_a.size() <= lens_b.size())
{ {
return equal(lens_A.begin(), lens_A.end(), lens_B.begin()); return equal(lens_a.begin(), lens_a.end(), lens_b.begin());
} }
else else
{ {
return equal(lens_B.begin(), lens_B.end(), lens_A.begin()); return equal(lens_b.begin(), lens_b.end(), lens_a.begin());
} }
}; };
...@@ -140,8 +139,7 @@ struct parse_if : op_parser<parse_if> ...@@ -140,8 +139,7 @@ struct parse_if : op_parser<parse_if>
throw_shapes(); throw_shapes();
} }
auto unsqueeze_last_op = [](module_ref& mdl, auto unsqueeze_last_op = [](module_ref& mdl, const std::vector<size_t>& out_shape) {
const std::vector<size_t>& out_shape) {
auto convert_ins = mdl->add_instruction( auto convert_ins = mdl->add_instruction(
make_op("unsqueeze", {{"axes", {out_shape.size() - 1}}}), make_op("unsqueeze", {{"axes", {out_shape.size() - 1}}}),
std::prev(std::prev(mdl->end()))); std::prev(std::prev(mdl->end())));
...@@ -162,12 +160,11 @@ struct parse_if : op_parser<parse_if> ...@@ -162,12 +160,11 @@ struct parse_if : op_parser<parse_if>
unsqueeze_last_op(else_mdl, then_lens); unsqueeze_last_op(else_mdl, then_lens);
} }
} }
else else if(dim_delta > 1)
{ {
throw_shapes(); throw_shapes();
} }
} }
}
auto if_ret = info.add_instruction(make_op("if"), args, {then_mdl, else_mdl}); auto if_ret = info.add_instruction(make_op("if"), args, {then_mdl, else_mdl});
auto out_s = if_ret->get_shape(); auto out_s = if_ret->get_shape();
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment