Commit baf55e7f authored by Ted Themistokleous's avatar Ted Themistokleous
Browse files

Fix another tidy issue with complexity

Just clean up the logic on this so we're not 5 levels deep on ifs.
parent 1e90dfd7
...@@ -129,25 +129,24 @@ struct parse_if : op_parser<parse_if> ...@@ -129,25 +129,24 @@ struct parse_if : op_parser<parse_if>
throw_shapes(); throw_shapes();
} }
auto unsqueeze_last_op = [](module_ref& mdl, std::vector<size_t>& out_shape) { auto unsqueeze_last_op = [](module_ref& mdl,
auto last_else = *(--(out_shape.end())); const std::vector<size_t>& out_shape) {
auto convert_ins = mdl->add_instruction(
if(last_else <= 1) make_op("unsqueeze", {{"axes", {out_shape.size() - 1}}}),
{ --(--mdl->end()));
auto convert_ins = mdl->add_instruction( mdl->replace_return({convert_ins});
make_op("unsqueeze", {{"axes", {out_shape.size() - 1}}}), mdl->remove_instruction({--convert_ins});
--(--mdl->end()));
mdl->replace_return({convert_ins});
mdl->remove_instruction({--convert_ins});
}
}; };
auto last_then = *(--(then_shape.end()));
auto last_else = *(--(else_shape.end()));
// Find which dim to unsqueeze // Find which dim to unsqueeze
if(then_shape.size() < else_shape.size()) if((then_shape.size() < else_shape.size()) && (last_else == 1))
{ {
unsqueeze_last_op(then_mdl, else_shape); unsqueeze_last_op(then_mdl, else_shape);
} }
else else if((then_shape.size() > else_shape.size()) && (last_then == 1))
{ {
unsqueeze_last_op(else_mdl, then_shape); unsqueeze_last_op(else_mdl, then_shape);
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment