Commit e612b60c authored by Ted Themistokleous's avatar Ted Themistokleous
Browse files

Avoid unneeded nesting

parent 881a4bd4
......@@ -115,76 +115,78 @@ struct parse_if : op_parser<parse_if>
else_out_shape.type_string());
}
if(not then_out_shape.dynamic() and not else_out_shape.dynamic())
if(then_out_shape.dynamic() or else_out_shape.dynamic())
{
auto then_lens = then_out_shape.lens();
auto else_lens = else_out_shape.lens();
continue;
}
// Throw error if both branches have zero output shapes. Not possible for static
// inputs
if(then_lens.empty() and else_lens.empty())
{
throw_shapes();
}
auto then_lens = then_out_shape.lens();
auto else_lens = else_out_shape.lens();
auto handle_empty_branch = [](module_ref& mdl, int index, const shape& out_shape) {
shape gen_shape(shape(out_shape.type(), {1}, {0}));
auto literal_ins =
mdl->insert_literal(std::prev(mdl->end()), literal(gen_shape, {0}));
auto unsqueeze_ins = mdl->insert_instruction(
std::prev(mdl->end()),
make_op("scalar", {{"scalar_bcst_dims", out_shape.lens()}}),
literal_ins);
auto broad_ins = mdl->insert_instruction(
std::prev(mdl->end()),
make_op("multibroadcast", {{"out_lens", out_shape.lens()}}),
unsqueeze_ins);
auto contig_out = mdl->insert_instruction(
std::prev(mdl->end()), make_op("contiguous"), broad_ins);
mdl->replace_instruction(std::prev(mdl->end())->inputs().at(index), contig_out);
return out_shape.lens();
};
// Handle one empty branch by setting output identical to the other
// need to update the then_shape before we do further checks
if(then_lens.empty())
{
then_lens = handle_empty_branch(then_mdl, i, else_out_shape);
}
else if(else_lens.empty())
// Throw error if both branches have zero output shapes. Not possible for static
// inputs
if(then_lens.empty() and else_lens.empty())
{
throw_shapes();
}
auto handle_empty_branch = [](module_ref& mdl, int index, const shape& out_shape) {
shape gen_shape(shape(out_shape.type(), {1}, {0}));
auto literal_ins =
mdl->insert_literal(std::prev(mdl->end()), literal(gen_shape, {0}));
auto unsqueeze_ins = mdl->insert_instruction(
std::prev(mdl->end()),
make_op("scalar", {{"scalar_bcst_dims", out_shape.lens()}}),
literal_ins);
auto broad_ins = mdl->insert_instruction(
std::prev(mdl->end()),
make_op("multibroadcast", {{"out_lens", out_shape.lens()}}),
unsqueeze_ins);
auto contig_out = mdl->insert_instruction(
std::prev(mdl->end()), make_op("contiguous"), broad_ins);
mdl->replace_instruction(std::prev(mdl->end())->inputs().at(index), contig_out);
return out_shape.lens();
};
// Handle one empty branch by setting output identical to the other
// need to update the then_shape before we do further checks
if(then_lens.empty())
{
then_lens = handle_empty_branch(then_mdl, i, else_out_shape);
}
else if(else_lens.empty())
{
else_lens = handle_empty_branch(else_mdl, i, then_out_shape);
}
// check equivalent length dims, and (x1,x2,.., xn, 1) == (x1,x2,..,xn)
int rank_delta = abs((static_cast<int>(then_lens.size() - else_lens.size())));
if(rank_delta == 1)
{
// make sure dims are equivalent in static shapes
if(not all_but_last_dims_equal(then_lens, else_lens))
{
else_lens = handle_empty_branch(else_mdl, i, then_out_shape);
throw_shapes();
}
// check equivalent length dims, and (x1,x2,.., xn, 1) == (x1,x2,..,xn)
int rank_delta = abs((static_cast<int>(then_lens.size() - else_lens.size())));
auto last_then = then_lens.back();
auto last_else = else_lens.back();
if(rank_delta == 1)
// Find which dim to unsqueeze
if((then_lens.size() < else_lens.size()) && (last_else == 1))
{
// make sure dims are equivalent in static shapes
if(not all_but_last_dims_equal(then_lens, else_lens))
{
throw_shapes();
}
auto last_then = then_lens.back();
auto last_else = else_lens.back();
// Find which dim to unsqueeze
if((then_lens.size() < else_lens.size()) && (last_else == 1))
{
unsqueeze_last_op(then_mdl, i, else_lens);
}
else if((then_lens.size() > else_lens.size()) && (last_then == 1))
{
unsqueeze_last_op(else_mdl, i, then_lens);
}
unsqueeze_last_op(then_mdl, i, else_lens);
}
else if(rank_delta > 1)
else if((then_lens.size() > else_lens.size()) && (last_then == 1))
{
throw_shapes();
unsqueeze_last_op(else_mdl, i, then_lens);
}
}
else if(rank_delta > 1)
{
throw_shapes();
}
}
auto if_ret = info.add_instruction(make_op("if"), args, {then_mdl, else_mdl});
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment