Commit 1e90dfd7 authored by Ted Themistokleous's avatar Ted Themistokleous
Browse files

Fix tidy issues and cleanup branches with lambdas

- use empty() instead of size() == 0 for checking each condition
- Make each branch a lambda instead of repeating code.
parent 9d174e76
...@@ -90,29 +90,29 @@ struct parse_if : op_parser<parse_if> ...@@ -90,29 +90,29 @@ struct parse_if : op_parser<parse_if>
}; };
// Throw error if both branches have zero output shapes. Not possible for static inputs // Throw error if both branches have zero output shapes. Not possible for static inputs
if(then_shape.size() == 0 && else_shape.size() == 0) if(then_shape.empty() && else_shape.empty())
{ {
throw_shapes(); throw_shapes();
} }
auto handle_empty_branch = [](module_ref& mdl, std::vector<size_t>& out_shape) {
auto convert_ins =
mdl->insert_instruction(--mdl->end(),
make_op("multibroadcast", {{"out_lens", out_shape}}),
{--(--mdl->end())});
mdl->replace_return({convert_ins});
};
// Handle one empty branch by setting output identical to the other // Handle one empty branch by setting output identical to the other
// need to update the then_shape before we do further checks // need to update the then_shape before we do further checks
if(then_shape.size() == 0) if(then_shape.empty())
{ {
auto convert_ins = then_mdl->insert_instruction( handle_empty_branch(then_mdl, else_shape);
--else_mdl->end(),
make_op("multibroadcast", {{"out_lens", else_shape}}),
{--(--then_mdl->end())});
then_mdl->replace_return({convert_ins});
then_shape = else_shape; then_shape = else_shape;
} }
else if(else_shape.size() == 0) else if(else_shape.empty())
{ {
auto convert_ins = else_mdl->insert_instruction( handle_empty_branch(else_mdl, then_shape);
--else_mdl->end(),
make_op("multibroadcast", {{"out_lens", then_shape}}),
{--(--else_mdl->end())});
else_mdl->replace_return({convert_ins});
else_shape = then_shape; else_shape = then_shape;
} }
else else
...@@ -129,32 +129,27 @@ struct parse_if : op_parser<parse_if> ...@@ -129,32 +129,27 @@ struct parse_if : op_parser<parse_if>
throw_shapes(); throw_shapes();
} }
// Find which dim to unsqueeze auto unsqueeze_last_op = [](module_ref& mdl, std::vector<size_t>& out_shape) {
if(then_shape.size() < else_shape.size()) auto last_else = *(--(out_shape.end()));
{
auto last_else = *(--(else_shape.end()));
if(last_else <= 1) if(last_else <= 1)
{ {
auto convert_ins = then_mdl->add_instruction( auto convert_ins = mdl->add_instruction(
make_op("unsqueeze", {{"axes", {else_shape.size() - 1}}}), make_op("unsqueeze", {{"axes", {out_shape.size() - 1}}}),
--(--then_mdl->end())); --(--mdl->end()));
then_mdl->replace_return({convert_ins}); mdl->replace_return({convert_ins});
then_mdl->remove_instruction({--convert_ins}); mdl->remove_instruction({--convert_ins});
}
} }
else };
{
auto last_then = *(--(then_shape.end()));
if(last_then <= 1) // Find which dim to unsqueeze
if(then_shape.size() < else_shape.size())
{ {
auto convert_ins = else_mdl->add_instruction( unsqueeze_last_op(then_mdl, else_shape);
make_op("unsqueeze", {{"axes", {then_shape.size() - 1}}}),
--(--else_mdl->end()));
else_mdl->replace_return({convert_ins});
else_mdl->remove_instruction({--convert_ins});
} }
else
{
unsqueeze_last_op(else_mdl, then_shape);
} }
} }
else else
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment