"git@developer.sourcefind.cn:dadigang/Ventoy.git" did not exist on "65a2cb16857c99cb75f521fcf550febe96bc07e0"
Commit 34aeda23 authored by Ted Themistokleous's avatar Ted Themistokleous
Browse files

Fix tidy issue and clean up logic

parent 372206b3
...@@ -115,57 +115,54 @@ struct parse_if : op_parser<parse_if> ...@@ -115,57 +115,54 @@ struct parse_if : op_parser<parse_if>
{ {
else_lens = handle_empty_branch(else_mdl, then_out_shapes.at(0)); else_lens = handle_empty_branch(else_mdl, then_out_shapes.at(0));
} }
else
{
// check equivilant length dims, and (x1,x2,.., xn, 1) == (x1,x2,..,xn)
int dim_delta = abs((static_cast<int>(then_lens.size() - else_lens.size())));
if(dim_delta <= 1)
{
auto all_but_last_dims_equal = [](std::vector<size_t>& lens_A,
std::vector<size_t>& lens_B) {
if(lens_A.size() <= lens_B.size())
{
return equal(lens_A.begin(), lens_A.end(), lens_B.begin());
}
else
{
return equal(lens_B.begin(), lens_B.end(), lens_A.begin());
}
};
// make sure dims are equivalent in static shapes
if(not all_but_last_dims_equal(then_lens, else_lens))
{
throw_shapes();
}
auto unsqueeze_last_op = [](module_ref& mdl, // check equivilant length dims, and (x1,x2,.., xn, 1) == (x1,x2,..,xn)
const std::vector<size_t>& out_shape) { int dim_delta = abs((static_cast<int>(then_lens.size() - else_lens.size())));
auto convert_ins = mdl->add_instruction(
make_op("unsqueeze", {{"axes", {out_shape.size() - 1}}}),
std::prev(std::prev(mdl->end())));
mdl->replace_return({convert_ins});
mdl->remove_instruction({std::prev(convert_ins)});
};
auto last_then = *(std::prev(then_lens.end())); if(dim_delta == 1)
auto last_else = *(std::prev(else_lens.end())); {
auto all_but_last_dims_equal = [](std::vector<size_t>& lens_a,
// Find which dim to unsqueeze std::vector<size_t>& lens_b) {
if((then_lens.size() < else_lens.size()) && (last_else == 1)) if(lens_a.size() <= lens_b.size())
{ {
unsqueeze_last_op(then_mdl, else_lens); return equal(lens_a.begin(), lens_a.end(), lens_b.begin());
} }
else if((then_lens.size() > else_lens.size()) && (last_then == 1)) else
{ {
unsqueeze_last_op(else_mdl, then_lens); return equal(lens_b.begin(), lens_b.end(), lens_a.begin());
} }
} };
else
// make sure dims are equivalent in static shapes
if(not all_but_last_dims_equal(then_lens, else_lens))
{ {
throw_shapes(); throw_shapes();
} }
auto unsqueeze_last_op = [](module_ref& mdl, const std::vector<size_t>& out_shape) {
auto convert_ins = mdl->add_instruction(
make_op("unsqueeze", {{"axes", {out_shape.size() - 1}}}),
std::prev(std::prev(mdl->end())));
mdl->replace_return({convert_ins});
mdl->remove_instruction({std::prev(convert_ins)});
};
auto last_then = *(std::prev(then_lens.end()));
auto last_else = *(std::prev(else_lens.end()));
// Find which dim to unsqueeze
if((then_lens.size() < else_lens.size()) && (last_else == 1))
{
unsqueeze_last_op(then_mdl, else_lens);
}
else if((then_lens.size() > else_lens.size()) && (last_then == 1))
{
unsqueeze_last_op(else_mdl, then_lens);
}
}
else if(dim_delta > 1)
{
throw_shapes();
} }
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment