"vscode:/vscode.git/clone" did not exist on "2e414b7c922b516152ca7d9586a76dc3734aa212"
Commit e612b60c authored by Ted Themistokleous's avatar Ted Themistokleous
Browse files

Avoid unneeded nesting

parent 881a4bd4
...@@ -115,76 +115,78 @@ struct parse_if : op_parser<parse_if> ...@@ -115,76 +115,78 @@ struct parse_if : op_parser<parse_if>
else_out_shape.type_string()); else_out_shape.type_string());
} }
if(not then_out_shape.dynamic() and not else_out_shape.dynamic()) if(then_out_shape.dynamic() or else_out_shape.dynamic())
{ {
auto then_lens = then_out_shape.lens(); continue;
auto else_lens = else_out_shape.lens(); }
// Throw error if both branches have zero output shapes. Not possible for static auto then_lens = then_out_shape.lens();
// inputs auto else_lens = else_out_shape.lens();
if(then_lens.empty() and else_lens.empty())
{
throw_shapes();
}
auto handle_empty_branch = [](module_ref& mdl, int index, const shape& out_shape) { // Throw error if both branches have zero output shapes. Not possible for static
shape gen_shape(shape(out_shape.type(), {1}, {0})); // inputs
auto literal_ins = if(then_lens.empty() and else_lens.empty())
mdl->insert_literal(std::prev(mdl->end()), literal(gen_shape, {0})); {
auto unsqueeze_ins = mdl->insert_instruction( throw_shapes();
std::prev(mdl->end()), }
make_op("scalar", {{"scalar_bcst_dims", out_shape.lens()}}),
literal_ins); auto handle_empty_branch = [](module_ref& mdl, int index, const shape& out_shape) {
auto broad_ins = mdl->insert_instruction( shape gen_shape(shape(out_shape.type(), {1}, {0}));
std::prev(mdl->end()), auto literal_ins =
make_op("multibroadcast", {{"out_lens", out_shape.lens()}}), mdl->insert_literal(std::prev(mdl->end()), literal(gen_shape, {0}));
unsqueeze_ins); auto unsqueeze_ins = mdl->insert_instruction(
auto contig_out = mdl->insert_instruction( std::prev(mdl->end()),
std::prev(mdl->end()), make_op("contiguous"), broad_ins); make_op("scalar", {{"scalar_bcst_dims", out_shape.lens()}}),
mdl->replace_instruction(std::prev(mdl->end())->inputs().at(index), contig_out); literal_ins);
return out_shape.lens(); auto broad_ins = mdl->insert_instruction(
}; std::prev(mdl->end()),
make_op("multibroadcast", {{"out_lens", out_shape.lens()}}),
// Handle one empty branch by setting output identical to the other unsqueeze_ins);
// need to update the then_shape before we do further checks auto contig_out = mdl->insert_instruction(
if(then_lens.empty()) std::prev(mdl->end()), make_op("contiguous"), broad_ins);
{ mdl->replace_instruction(std::prev(mdl->end())->inputs().at(index), contig_out);
then_lens = handle_empty_branch(then_mdl, i, else_out_shape); return out_shape.lens();
} };
else if(else_lens.empty())
// Handle one empty branch by setting output identical to the other
// need to update the then_shape before we do further checks
if(then_lens.empty())
{
then_lens = handle_empty_branch(then_mdl, i, else_out_shape);
}
else if(else_lens.empty())
{
else_lens = handle_empty_branch(else_mdl, i, then_out_shape);
}
// check equivalent length dims, and (x1,x2,.., xn, 1) == (x1,x2,..,xn)
int rank_delta = abs((static_cast<int>(then_lens.size() - else_lens.size())));
if(rank_delta == 1)
{
// make sure dims are equivalent in static shapes
if(not all_but_last_dims_equal(then_lens, else_lens))
{ {
else_lens = handle_empty_branch(else_mdl, i, then_out_shape); throw_shapes();
} }
// check equivalent length dims, and (x1,x2,.., xn, 1) == (x1,x2,..,xn) auto last_then = then_lens.back();
int rank_delta = abs((static_cast<int>(then_lens.size() - else_lens.size()))); auto last_else = else_lens.back();
if(rank_delta == 1) // Find which dim to unsqueeze
if((then_lens.size() < else_lens.size()) && (last_else == 1))
{ {
// make sure dims are equivalent in static shapes unsqueeze_last_op(then_mdl, i, else_lens);
if(not all_but_last_dims_equal(then_lens, else_lens))
{
throw_shapes();
}
auto last_then = then_lens.back();
auto last_else = else_lens.back();
// Find which dim to unsqueeze
if((then_lens.size() < else_lens.size()) && (last_else == 1))
{
unsqueeze_last_op(then_mdl, i, else_lens);
}
else if((then_lens.size() > else_lens.size()) && (last_then == 1))
{
unsqueeze_last_op(else_mdl, i, then_lens);
}
} }
else if(rank_delta > 1) else if((then_lens.size() > else_lens.size()) && (last_then == 1))
{ {
throw_shapes(); unsqueeze_last_op(else_mdl, i, then_lens);
} }
} }
else if(rank_delta > 1)
{
throw_shapes();
}
} }
auto if_ret = info.add_instruction(make_op("if"), args, {then_mdl, else_mdl}); auto if_ret = info.add_instruction(make_op("if"), args, {then_mdl, else_mdl});
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment