Unverified Commit 0d2606bb authored by Umang Yadav's avatar Umang Yadav Committed by GitHub
Browse files

Change attributes names to be more consistent and reflect better meaning (#916)

* rename broadcast and multibroadcast output_lens attribute to out_lens attribute, and change tests and source code to reflect the same

* change the reshape attribute from dims to out_lens

* change transpose attribute's name from dims to perm to reflect better meaning

* use permutation instead of perm for transpose

clang formaating

* use dims instead of out_lens for reshape

clang formatting
parent d8a2a933
...@@ -96,7 +96,7 @@ instruction_ref insert_common_op(module& m, ...@@ -96,7 +96,7 @@ instruction_ref insert_common_op(module& m,
if(input->get_shape().lens() != common.lens()) if(input->get_shape().lens() != common.lens())
{ {
input = m.insert_instruction( input = m.insert_instruction(
ins, make_op("multibroadcast", {{"output_lens", common.lens()}}), input); ins, make_op("multibroadcast", {{"out_lens", common.lens()}}), input);
} }
if(input->get_shape().type() != common.type()) if(input->get_shape().type() != common.type())
{ {
......
...@@ -39,9 +39,7 @@ struct find_dot_add ...@@ -39,9 +39,7 @@ struct find_dot_add
{ {
auto alpha = p.add_literal(literal{shape{a_ins->get_shape().type()}, {dot.alpha}}); auto alpha = p.add_literal(literal{shape{a_ins->get_shape().type()}, {dot.alpha}});
auto alpha_broadcast = p.insert_instruction( auto alpha_broadcast = p.insert_instruction(
ins, ins, make_op("multibroadcast", {{"out_lens", a_ins->get_shape().lens()}}), alpha);
make_op("multibroadcast", {{"output_lens", a_ins->get_shape().lens()}}),
alpha);
a_ins = p.insert_instruction(ins, make_op("mul"), a_ins, alpha_broadcast); a_ins = p.insert_instruction(ins, make_op("mul"), a_ins, alpha_broadcast);
} }
auto dot_ins = p.insert_instruction(ins, make_op(ins->name(), {{"beta", 0}}), a_ins, b_ins); auto dot_ins = p.insert_instruction(ins, make_op(ins->name(), {{"beta", 0}}), a_ins, b_ins);
...@@ -51,7 +49,7 @@ struct find_dot_add ...@@ -51,7 +49,7 @@ struct find_dot_add
{ {
auto beta = p.add_literal(literal{shape{c_ins->get_shape().type()}, {dot.beta}}); auto beta = p.add_literal(literal{shape{c_ins->get_shape().type()}, {dot.beta}});
auto beta_broadcast = p.insert_instruction( auto beta_broadcast = p.insert_instruction(
ins, make_op("multibroadcast", {{"output_lens", ins->get_shape().lens()}}), beta); ins, make_op("multibroadcast", {{"out_lens", ins->get_shape().lens()}}), beta);
c_ins = p.insert_instruction(ins, make_op("mul"), c_ins, beta_broadcast); c_ins = p.insert_instruction(ins, make_op("mul"), c_ins, beta_broadcast);
} }
p.replace_instruction(ins, make_op("add"), dot_ins, c_ins); p.replace_instruction(ins, make_op("add"), dot_ins, c_ins);
...@@ -72,9 +70,7 @@ struct find_dot_alpha ...@@ -72,9 +70,7 @@ struct find_dot_alpha
{ {
auto alpha = p.add_literal(literal{shape{a_ins->get_shape().type()}, {dot.alpha}}); auto alpha = p.add_literal(literal{shape{a_ins->get_shape().type()}, {dot.alpha}});
auto alpha_broadcast = p.insert_instruction( auto alpha_broadcast = p.insert_instruction(
ins, ins, make_op("multibroadcast", {{"out_lens", a_ins->get_shape().lens()}}), alpha);
make_op("multibroadcast", {{"output_lens", a_ins->get_shape().lens()}}),
alpha);
a_ins = p.insert_instruction(ins, make_op("mul"), a_ins, alpha_broadcast); a_ins = p.insert_instruction(ins, make_op("mul"), a_ins, alpha_broadcast);
} }
p.replace_instruction(ins, make_op(ins->name(), {{"beta", 0}}), a_ins, b_ins); p.replace_instruction(ins, make_op(ins->name(), {{"beta", 0}}), a_ins, b_ins);
......
...@@ -30,7 +30,7 @@ struct broadcast ...@@ -30,7 +30,7 @@ struct broadcast
template <class Self, class F> template <class Self, class F>
static auto reflect(Self& self, F f) static auto reflect(Self& self, F f)
{ {
return pack(f(self.axis, "axis"), f(self.broadcast_lens, "dims")); return pack(f(self.axis, "axis"), f(self.broadcast_lens, "out_lens"));
} }
std::string name() const { return "broadcast"; } std::string name() const { return "broadcast"; }
......
...@@ -23,7 +23,7 @@ struct multibroadcast ...@@ -23,7 +23,7 @@ struct multibroadcast
template <class Self, class F> template <class Self, class F>
static auto reflect(Self& self, F f) static auto reflect(Self& self, F f)
{ {
return pack(f(self.output_lens, "output_lens")); return pack(f(self.output_lens, "out_lens"));
} }
std::string name() const { return "multibroadcast"; } std::string name() const { return "multibroadcast"; }
......
...@@ -21,7 +21,7 @@ struct transpose ...@@ -21,7 +21,7 @@ struct transpose
template <class Self, class F> template <class Self, class F>
static auto reflect(Self& self, F f) static auto reflect(Self& self, F f)
{ {
return pack(f(self.dims, "dims")); return pack(f(self.dims, "permutation"));
} }
std::string name() const { return "transpose"; } std::string name() const { return "transpose"; }
......
...@@ -85,7 +85,7 @@ instruction_ref onnx_parser::node_info::add_bias(const std::vector<instruction_r ...@@ -85,7 +85,7 @@ instruction_ref onnx_parser::node_info::add_bias(const std::vector<instruction_r
if(args.size() == 3) if(args.size() == 3)
{ {
auto bias_bcast = mod->add_instruction( auto bias_bcast = mod->add_instruction(
make_op("broadcast", {{"axis", axis}, {"dims", curr_ins->get_shape().lens()}}), make_op("broadcast", {{"axis", axis}, {"out_lens", curr_ins->get_shape().lens()}}),
args[2]); args[2]);
return mod->add_instruction(make_op("add"), curr_ins, bias_bcast); return mod->add_instruction(make_op("add"), curr_ins, bias_bcast);
} }
......
...@@ -36,7 +36,8 @@ struct parse_binary_op : op_parser<parse_binary_op> ...@@ -36,7 +36,8 @@ struct parse_binary_op : op_parser<parse_binary_op>
{ {
uint64_t axis = parser.parse_value(info.attributes.at("axis")).at<uint64_t>(); uint64_t axis = parser.parse_value(info.attributes.at("axis")).at<uint64_t>();
auto l = info.add_instruction( auto l = info.add_instruction(
make_op("broadcast", {{"axis", axis}, {"dims", args[0]->get_shape().lens()}}), make_op("broadcast",
{{"axis", axis}, {"out_lens", args[0]->get_shape().lens()}}),
args[1]); args[1]);
return info.add_instruction(make_op(opd.op_name), args[0], l); return info.add_instruction(make_op(opd.op_name), args[0], l);
} }
......
...@@ -47,13 +47,13 @@ struct parse_clip : op_parser<parse_clip> ...@@ -47,13 +47,13 @@ struct parse_clip : op_parser<parse_clip>
if(min_used) if(min_used)
{ {
min_arg = info.add_instruction(make_op("multibroadcast", {{"output_lens", input_lens}}), min_arg = info.add_instruction(make_op("multibroadcast", {{"out_lens", input_lens}}),
min_arg); min_arg);
} }
if(max_used) if(max_used)
{ {
max_arg = info.add_instruction(make_op("multibroadcast", {{"output_lens", input_lens}}), max_arg = info.add_instruction(make_op("multibroadcast", {{"out_lens", input_lens}}),
max_arg); max_arg);
} }
......
...@@ -29,11 +29,11 @@ struct parse_dequantizelinear : op_parser<parse_dequantizelinear> ...@@ -29,11 +29,11 @@ struct parse_dequantizelinear : op_parser<parse_dequantizelinear>
{ {
auto tuned_axis = tune_axis(n_dim, axis, opd.op_name); auto tuned_axis = tune_axis(n_dim, axis, opd.op_name);
x_scale = info.add_instruction( x_scale = info.add_instruction(
make_op("broadcast", {{"axis", tuned_axis}, {"dims", input_lens}}), args[1]); make_op("broadcast", {{"axis", tuned_axis}, {"out_lens", input_lens}}), args[1]);
} }
else else
{ {
x_scale = info.add_instruction(make_op("multibroadcast", {{"output_lens", input_lens}}), x_scale = info.add_instruction(make_op("multibroadcast", {{"out_lens", input_lens}}),
args[1]); args[1]);
} }
...@@ -44,13 +44,13 @@ struct parse_dequantizelinear : op_parser<parse_dequantizelinear> ...@@ -44,13 +44,13 @@ struct parse_dequantizelinear : op_parser<parse_dequantizelinear>
{ {
auto tuned_axis = tune_axis(n_dim, axis, opd.op_name); auto tuned_axis = tune_axis(n_dim, axis, opd.op_name);
x_zero_point = info.add_instruction( x_zero_point = info.add_instruction(
make_op("broadcast", {{"axis", tuned_axis}, {"dims", input_lens}}), make_op("broadcast", {{"axis", tuned_axis}, {"out_lens", input_lens}}),
x_zero_point); x_zero_point);
} }
else else
{ {
x_zero_point = info.add_instruction( x_zero_point = info.add_instruction(
make_op("multibroadcast", {{"output_lens", input_lens}}), x_zero_point); make_op("multibroadcast", {{"out_lens", input_lens}}), x_zero_point);
} }
return info.add_instruction( return info.add_instruction(
......
...@@ -24,8 +24,7 @@ struct parse_expand : op_parser<parse_expand> ...@@ -24,8 +24,7 @@ struct parse_expand : op_parser<parse_expand>
std::vector<std::size_t> dims; std::vector<std::size_t> dims;
arg_s.visit([&](auto input) { dims.assign(input.begin(), input.end()); }); arg_s.visit([&](auto input) { dims.assign(input.begin(), input.end()); });
auto out_lens = compute_broadcasted_lens(in_lens, dims); auto out_lens = compute_broadcasted_lens(in_lens, dims);
return info.add_instruction(make_op("multibroadcast", {{"output_lens", out_lens}}), return info.add_instruction(make_op("multibroadcast", {{"out_lens", out_lens}}), args[0]);
args[0]);
} }
}; };
......
...@@ -63,8 +63,8 @@ struct parse_gather_elements : op_parser<parse_gather_elements> ...@@ -63,8 +63,8 @@ struct parse_gather_elements : op_parser<parse_gather_elements>
info.add_literal(literal(ind_s, data_indices.begin(), data_indices.end())); info.add_literal(literal(ind_s, data_indices.begin(), data_indices.end()));
auto l_dim_idx = info.add_literal(literal(ind_s, vec_axis_ind.begin(), vec_axis_ind.end())); auto l_dim_idx = info.add_literal(literal(ind_s, vec_axis_ind.begin(), vec_axis_ind.end()));
auto l_stride = info.add_literal(literal{{ind_s.type(), {1}}, {axis_stride}}); auto l_stride = info.add_literal(literal{{ind_s.type(), {1}}, {axis_stride}});
l_stride = info.add_instruction(make_op("multibroadcast", {{"output_lens", ind_s.lens()}}), l_stride =
l_stride); info.add_instruction(make_op("multibroadcast", {{"out_lens", ind_s.lens()}}), l_stride);
auto dim_diff = info.add_instruction(make_op("sub"), arg_ind, l_dim_idx); auto dim_diff = info.add_instruction(make_op("sub"), arg_ind, l_dim_idx);
auto delta = info.add_instruction(make_op("mul"), dim_diff, l_stride); auto delta = info.add_instruction(make_op("mul"), dim_diff, l_stride);
auto ind = info.add_instruction(make_op("add"), l_shape_idx, delta); auto ind = info.add_instruction(make_op("add"), l_shape_idx, delta);
......
...@@ -55,8 +55,10 @@ struct parse_gemm : op_parser<parse_gemm> ...@@ -55,8 +55,10 @@ struct parse_gemm : op_parser<parse_gemm>
} }
} }
l1 = (transa) ? info.add_instruction(make_op("transpose", {{"dims", perm}}), l1) : l1; l1 =
auto l2 = (transb) ? info.add_instruction(make_op("transpose", {{"dims", perm}}), args[1]) (transa) ? info.add_instruction(make_op("transpose", {{"permutation", perm}}), l1) : l1;
auto l2 = (transb)
? info.add_instruction(make_op("transpose", {{"permutation", perm}}), args[1])
: args[1]; : args[1];
auto ret = info.add_instruction(make_op("dot", {{"alpha", 1.0f}, {"beta", 0.0f}}), l1, l2); auto ret = info.add_instruction(make_op("dot", {{"alpha", 1.0f}, {"beta", 0.0f}}), l1, l2);
...@@ -71,8 +73,8 @@ struct parse_gemm : op_parser<parse_gemm> ...@@ -71,8 +73,8 @@ struct parse_gemm : op_parser<parse_gemm>
auto l3_lens = l3->get_shape().lens(); auto l3_lens = l3->get_shape().lens();
if(!std::equal(out_lens.begin(), out_lens.end(), l3_lens.begin(), l3_lens.end())) if(!std::equal(out_lens.begin(), out_lens.end(), l3_lens.begin(), l3_lens.end()))
{ {
l3 = info.add_instruction( l3 = info.add_instruction(make_op("multibroadcast", {{"out_lens", out_lens}}),
make_op("multibroadcast", {{"output_lens", out_lens}}), args[2]); args[2]);
} }
auto beta_literal = info.add_literal(beta); auto beta_literal = info.add_literal(beta);
auto beta_l3 = info.add_broadcastable_binary_op("mul", l3, beta_literal); auto beta_l3 = info.add_broadcastable_binary_op("mul", l3, beta_literal);
......
...@@ -40,7 +40,7 @@ struct parse_imagescalar : op_parser<parse_imagescalar> ...@@ -40,7 +40,7 @@ struct parse_imagescalar : op_parser<parse_imagescalar>
auto img_scaled = auto img_scaled =
info.add_instruction(migraphx::make_op("mul"), args.front(), scale_tensor); info.add_instruction(migraphx::make_op("mul"), args.front(), scale_tensor);
auto bias_bcast = info.add_instruction( auto bias_bcast = info.add_instruction(
migraphx::make_op("broadcast", {{"axis", 1}, {"dims", input_lens}}), bias_vals); migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", input_lens}}), bias_vals);
return info.add_instruction(migraphx::make_op("add"), img_scaled, bias_bcast); return info.add_instruction(migraphx::make_op("add"), img_scaled, bias_bcast);
} }
}; };
......
...@@ -38,23 +38,23 @@ struct parse_instancenorm : op_parser<parse_instancenorm> ...@@ -38,23 +38,23 @@ struct parse_instancenorm : op_parser<parse_instancenorm>
auto mean = info.add_instruction(make_op("reduce_mean", {{"axes", axes}}), x); auto mean = info.add_instruction(make_op("reduce_mean", {{"axes", axes}}), x);
auto mean_bcast = auto mean_bcast =
info.add_instruction(make_op("multibroadcast", {{"output_lens", dims}}), mean); info.add_instruction(make_op("multibroadcast", {{"out_lens", dims}}), mean);
auto l0 = info.add_instruction(make_op("sqdiff"), x, mean_bcast); auto l0 = info.add_instruction(make_op("sqdiff"), x, mean_bcast);
auto variance = info.add_instruction(make_op("reduce_mean", {{"axes", axes}}), l0); auto variance = info.add_instruction(make_op("reduce_mean", {{"axes", axes}}), l0);
auto l1 = info.add_instruction(make_op("sub"), x, mean_bcast); auto l1 = info.add_instruction(make_op("sub"), x, mean_bcast);
auto epsilon_literal = info.add_literal(epsilon); auto epsilon_literal = info.add_literal(epsilon);
auto epsilon_bcast = info.add_instruction( auto epsilon_bcast =
make_op("multibroadcast", {{"output_lens", dims}}), epsilon_literal); info.add_instruction(make_op("multibroadcast", {{"out_lens", dims}}), epsilon_literal);
auto variance_bcast = auto variance_bcast =
info.add_instruction(make_op("multibroadcast", {{"output_lens", dims}}), variance); info.add_instruction(make_op("multibroadcast", {{"out_lens", dims}}), variance);
auto l2 = info.add_instruction(make_op("add"), variance_bcast, epsilon_bcast); auto l2 = info.add_instruction(make_op("add"), variance_bcast, epsilon_bcast);
auto l3 = info.add_instruction(make_op("rsqrt"), l2); auto l3 = info.add_instruction(make_op("rsqrt"), l2);
auto l4 = info.add_instruction(make_op("mul"), l1, l3); auto l4 = info.add_instruction(make_op("mul"), l1, l3);
auto scale_bcast = auto scale_bcast =
info.add_instruction(make_op("broadcast", {{"axis", 1}, {"dims", dims}}), scale); info.add_instruction(make_op("broadcast", {{"axis", 1}, {"out_lens", dims}}), scale);
; ;
auto bias_bcast = auto bias_bcast =
info.add_instruction(make_op("broadcast", {{"axis", 1}, {"dims", dims}}), bias); info.add_instruction(make_op("broadcast", {{"axis", 1}, {"out_lens", dims}}), bias);
auto l5 = info.add_instruction(make_op("mul"), l4, scale_bcast); auto l5 = info.add_instruction(make_op("mul"), l4, scale_bcast);
return info.add_instruction(make_op("add"), l5, bias_bcast); return info.add_instruction(make_op("add"), l5, bias_bcast);
} }
......
...@@ -58,12 +58,12 @@ struct parse_matmul : op_parser<parse_matmul> ...@@ -58,12 +58,12 @@ struct parse_matmul : op_parser<parse_matmul>
if(l0_lens != l0_broadcasted_lens) if(l0_lens != l0_broadcasted_lens)
{ {
bl0 = info.add_instruction( bl0 = info.add_instruction(
make_op("multibroadcast", {{"output_lens", l0_broadcasted_lens}}), l0); make_op("multibroadcast", {{"out_lens", l0_broadcasted_lens}}), l0);
} }
if(l1_lens != l1_broadcasted_lens) if(l1_lens != l1_broadcasted_lens)
{ {
bl1 = info.add_instruction( bl1 = info.add_instruction(
make_op("multibroadcast", {{"output_lens", l1_broadcasted_lens}}), l1); make_op("multibroadcast", {{"out_lens", l1_broadcasted_lens}}), l1);
} }
} }
......
...@@ -45,7 +45,8 @@ struct parse_onehot : op_parser<parse_onehot> ...@@ -45,7 +45,8 @@ struct parse_onehot : op_parser<parse_onehot>
std::vector<int64_t> perm(n_rank - 1); std::vector<int64_t> perm(n_rank - 1);
std::iota(perm.begin(), perm.end(), 0); std::iota(perm.begin(), perm.end(), 0);
perm.insert(perm.begin() + tuned_axis, n_rank - 1); perm.insert(perm.begin() + tuned_axis, n_rank - 1);
auto tr_out = info.add_instruction(make_op("transpose", {{"dims", perm}}), gather_out); auto tr_out =
info.add_instruction(make_op("transpose", {{"permutation", perm}}), gather_out);
auto lens = tr_out->get_shape().lens(); auto lens = tr_out->get_shape().lens();
auto off_val = info.add_instruction( auto off_val = info.add_instruction(
...@@ -54,9 +55,9 @@ struct parse_onehot : op_parser<parse_onehot> ...@@ -54,9 +55,9 @@ struct parse_onehot : op_parser<parse_onehot>
make_op("slice", {{"axes", {0}}, {"starts", {1}}, {"ends", {2}}}), args[2]); make_op("slice", {{"axes", {0}}, {"starts", {1}}, {"ends", {2}}}), args[2]);
auto diff = info.add_instruction(make_op("sub"), on_val, off_val); auto diff = info.add_instruction(make_op("sub"), on_val, off_val);
auto unsq_off_val = auto unsq_off_val =
info.add_instruction(make_op("multibroadcast", {{"output_lens", lens}}), off_val); info.add_instruction(make_op("multibroadcast", {{"out_lens", lens}}), off_val);
auto unsq_diff_val = auto unsq_diff_val =
info.add_instruction(make_op("multibroadcast", {{"output_lens", lens}}), diff); info.add_instruction(make_op("multibroadcast", {{"out_lens", lens}}), diff);
auto l_mul = info.add_instruction(make_op("mul"), tr_out, unsq_diff_val); auto l_mul = info.add_instruction(make_op("mul"), tr_out, unsq_diff_val);
return info.add_instruction(make_op("add"), l_mul, unsq_off_val); return info.add_instruction(make_op("add"), l_mul, unsq_off_val);
} }
......
...@@ -29,11 +29,11 @@ struct parse_quantizelinear : op_parser<parse_quantizelinear> ...@@ -29,11 +29,11 @@ struct parse_quantizelinear : op_parser<parse_quantizelinear>
{ {
auto tuned_axis = tune_axis(n_dim, axis, opd.op_name); auto tuned_axis = tune_axis(n_dim, axis, opd.op_name);
y_scale = info.add_instruction( y_scale = info.add_instruction(
make_op("broadcast", {{"axis", tuned_axis}, {"dims", input_lens}}), args[1]); make_op("broadcast", {{"axis", tuned_axis}, {"out_lens", input_lens}}), args[1]);
} }
else else
{ {
y_scale = info.add_instruction(make_op("multibroadcast", {{"output_lens", input_lens}}), y_scale = info.add_instruction(make_op("multibroadcast", {{"out_lens", input_lens}}),
args[1]); args[1]);
} }
...@@ -44,13 +44,13 @@ struct parse_quantizelinear : op_parser<parse_quantizelinear> ...@@ -44,13 +44,13 @@ struct parse_quantizelinear : op_parser<parse_quantizelinear>
{ {
auto tuned_axis = tune_axis(n_dim, axis, opd.op_name); auto tuned_axis = tune_axis(n_dim, axis, opd.op_name);
y_zero_point = info.add_instruction( y_zero_point = info.add_instruction(
make_op("broadcast", {{"axis", tuned_axis}, {"dims", input_lens}}), make_op("broadcast", {{"axis", tuned_axis}, {"out_lens", input_lens}}),
y_zero_point); y_zero_point);
} }
else else
{ {
y_zero_point = info.add_instruction( y_zero_point = info.add_instruction(
make_op("multibroadcast", {{"output_lens", input_lens}}), y_zero_point); make_op("multibroadcast", {{"out_lens", input_lens}}), y_zero_point);
} }
return info.add_instruction(make_op("quantizelinear"), args[0], y_scale, y_zero_point); return info.add_instruction(make_op("quantizelinear"), args[0], y_scale, y_zero_point);
......
...@@ -35,9 +35,9 @@ struct parse_selu : op_parser<parse_selu> ...@@ -35,9 +35,9 @@ struct parse_selu : op_parser<parse_selu>
if(lens != std::vector<std::size_t>{1}) if(lens != std::vector<std::size_t>{1})
{ {
l_alpha = l_alpha =
info.add_instruction(make_op("multibroadcast", {{"output_lens", lens}}), l_alpha); info.add_instruction(make_op("multibroadcast", {{"out_lens", lens}}), l_alpha);
l_gamma = l_gamma =
info.add_instruction(make_op("multibroadcast", {{"output_lens", lens}}), l_gamma); info.add_instruction(make_op("multibroadcast", {{"out_lens", lens}}), l_gamma);
} }
auto sign_x = info.add_instruction(make_op("sign"), args[0]); auto sign_x = info.add_instruction(make_op("sign"), args[0]);
......
...@@ -21,7 +21,7 @@ struct parse_transpose : op_parser<parse_transpose> ...@@ -21,7 +21,7 @@ struct parse_transpose : op_parser<parse_transpose>
auto&& perm_vals = info.attributes["perm"].ints(); auto&& perm_vals = info.attributes["perm"].ints();
perm = std::vector<int64_t>(perm_vals.begin(), perm_vals.end()); perm = std::vector<int64_t>(perm_vals.begin(), perm_vals.end());
} }
return info.add_instruction(make_op("transpose", {{"dims", perm}}), args.front()); return info.add_instruction(make_op("transpose", {{"permutation", perm}}), args.front());
} }
}; };
......
...@@ -23,19 +23,19 @@ struct parse_where : op_parser<parse_where> ...@@ -23,19 +23,19 @@ struct parse_where : op_parser<parse_where>
lens = compute_broadcasted_lens(lens, args[2]->get_shape().lens()); lens = compute_broadcasted_lens(lens, args[2]->get_shape().lens());
if(cond->get_shape().lens() != lens) if(cond->get_shape().lens() != lens)
{ {
cond = info.add_instruction(make_op("multibroadcast", {{"output_lens", lens}}), cond); cond = info.add_instruction(make_op("multibroadcast", {{"out_lens", lens}}), cond);
} }
if(args[1]->get_shape().lens() != lens) if(args[1]->get_shape().lens() != lens)
{ {
args[1] = args[1] =
info.add_instruction(make_op("multibroadcast", {{"output_lens", lens}}), args[1]); info.add_instruction(make_op("multibroadcast", {{"out_lens", lens}}), args[1]);
} }
if(args[2]->get_shape().lens() != lens) if(args[2]->get_shape().lens() != lens)
{ {
args[2] = args[2] =
info.add_instruction(make_op("multibroadcast", {{"output_lens", lens}}), args[2]); info.add_instruction(make_op("multibroadcast", {{"out_lens", lens}}), args[2]);
} }
// compute index // compute index
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment