Unverified Commit 0d2606bb authored by Umang Yadav's avatar Umang Yadav Committed by GitHub
Browse files

Change attributes names to be more consistent and reflect better meaning (#916)

* rename broadcast and multibroadcast output_lens attribute to out_lens attribute, and change tests and source code to reflect the same

* change the reshape attribute from dims to out_lens

* change transpose attribute's name from dims to perm to reflect better meaning

* use permutation instead of perm for transpose

clang formaating

* use dims instead of out_lens for reshape

clang formatting
parent d8a2a933
......@@ -96,7 +96,7 @@ instruction_ref insert_common_op(module& m,
if(input->get_shape().lens() != common.lens())
{
input = m.insert_instruction(
ins, make_op("multibroadcast", {{"output_lens", common.lens()}}), input);
ins, make_op("multibroadcast", {{"out_lens", common.lens()}}), input);
}
if(input->get_shape().type() != common.type())
{
......
......@@ -39,9 +39,7 @@ struct find_dot_add
{
auto alpha = p.add_literal(literal{shape{a_ins->get_shape().type()}, {dot.alpha}});
auto alpha_broadcast = p.insert_instruction(
ins,
make_op("multibroadcast", {{"output_lens", a_ins->get_shape().lens()}}),
alpha);
ins, make_op("multibroadcast", {{"out_lens", a_ins->get_shape().lens()}}), alpha);
a_ins = p.insert_instruction(ins, make_op("mul"), a_ins, alpha_broadcast);
}
auto dot_ins = p.insert_instruction(ins, make_op(ins->name(), {{"beta", 0}}), a_ins, b_ins);
......@@ -51,7 +49,7 @@ struct find_dot_add
{
auto beta = p.add_literal(literal{shape{c_ins->get_shape().type()}, {dot.beta}});
auto beta_broadcast = p.insert_instruction(
ins, make_op("multibroadcast", {{"output_lens", ins->get_shape().lens()}}), beta);
ins, make_op("multibroadcast", {{"out_lens", ins->get_shape().lens()}}), beta);
c_ins = p.insert_instruction(ins, make_op("mul"), c_ins, beta_broadcast);
}
p.replace_instruction(ins, make_op("add"), dot_ins, c_ins);
......@@ -72,9 +70,7 @@ struct find_dot_alpha
{
auto alpha = p.add_literal(literal{shape{a_ins->get_shape().type()}, {dot.alpha}});
auto alpha_broadcast = p.insert_instruction(
ins,
make_op("multibroadcast", {{"output_lens", a_ins->get_shape().lens()}}),
alpha);
ins, make_op("multibroadcast", {{"out_lens", a_ins->get_shape().lens()}}), alpha);
a_ins = p.insert_instruction(ins, make_op("mul"), a_ins, alpha_broadcast);
}
p.replace_instruction(ins, make_op(ins->name(), {{"beta", 0}}), a_ins, b_ins);
......
......@@ -30,7 +30,7 @@ struct broadcast
template <class Self, class F>
static auto reflect(Self& self, F f)
{
return pack(f(self.axis, "axis"), f(self.broadcast_lens, "dims"));
return pack(f(self.axis, "axis"), f(self.broadcast_lens, "out_lens"));
}
std::string name() const { return "broadcast"; }
......
......@@ -23,7 +23,7 @@ struct multibroadcast
template <class Self, class F>
static auto reflect(Self& self, F f)
{
return pack(f(self.output_lens, "output_lens"));
return pack(f(self.output_lens, "out_lens"));
}
std::string name() const { return "multibroadcast"; }
......
......@@ -21,7 +21,7 @@ struct transpose
template <class Self, class F>
static auto reflect(Self& self, F f)
{
return pack(f(self.dims, "dims"));
return pack(f(self.dims, "permutation"));
}
std::string name() const { return "transpose"; }
......
......@@ -85,7 +85,7 @@ instruction_ref onnx_parser::node_info::add_bias(const std::vector<instruction_r
if(args.size() == 3)
{
auto bias_bcast = mod->add_instruction(
make_op("broadcast", {{"axis", axis}, {"dims", curr_ins->get_shape().lens()}}),
make_op("broadcast", {{"axis", axis}, {"out_lens", curr_ins->get_shape().lens()}}),
args[2]);
return mod->add_instruction(make_op("add"), curr_ins, bias_bcast);
}
......
......@@ -36,7 +36,8 @@ struct parse_binary_op : op_parser<parse_binary_op>
{
uint64_t axis = parser.parse_value(info.attributes.at("axis")).at<uint64_t>();
auto l = info.add_instruction(
make_op("broadcast", {{"axis", axis}, {"dims", args[0]->get_shape().lens()}}),
make_op("broadcast",
{{"axis", axis}, {"out_lens", args[0]->get_shape().lens()}}),
args[1]);
return info.add_instruction(make_op(opd.op_name), args[0], l);
}
......
......@@ -47,13 +47,13 @@ struct parse_clip : op_parser<parse_clip>
if(min_used)
{
min_arg = info.add_instruction(make_op("multibroadcast", {{"output_lens", input_lens}}),
min_arg = info.add_instruction(make_op("multibroadcast", {{"out_lens", input_lens}}),
min_arg);
}
if(max_used)
{
max_arg = info.add_instruction(make_op("multibroadcast", {{"output_lens", input_lens}}),
max_arg = info.add_instruction(make_op("multibroadcast", {{"out_lens", input_lens}}),
max_arg);
}
......
......@@ -29,11 +29,11 @@ struct parse_dequantizelinear : op_parser<parse_dequantizelinear>
{
auto tuned_axis = tune_axis(n_dim, axis, opd.op_name);
x_scale = info.add_instruction(
make_op("broadcast", {{"axis", tuned_axis}, {"dims", input_lens}}), args[1]);
make_op("broadcast", {{"axis", tuned_axis}, {"out_lens", input_lens}}), args[1]);
}
else
{
x_scale = info.add_instruction(make_op("multibroadcast", {{"output_lens", input_lens}}),
x_scale = info.add_instruction(make_op("multibroadcast", {{"out_lens", input_lens}}),
args[1]);
}
......@@ -44,13 +44,13 @@ struct parse_dequantizelinear : op_parser<parse_dequantizelinear>
{
auto tuned_axis = tune_axis(n_dim, axis, opd.op_name);
x_zero_point = info.add_instruction(
make_op("broadcast", {{"axis", tuned_axis}, {"dims", input_lens}}),
make_op("broadcast", {{"axis", tuned_axis}, {"out_lens", input_lens}}),
x_zero_point);
}
else
{
x_zero_point = info.add_instruction(
make_op("multibroadcast", {{"output_lens", input_lens}}), x_zero_point);
make_op("multibroadcast", {{"out_lens", input_lens}}), x_zero_point);
}
return info.add_instruction(
......
......@@ -24,8 +24,7 @@ struct parse_expand : op_parser<parse_expand>
std::vector<std::size_t> dims;
arg_s.visit([&](auto input) { dims.assign(input.begin(), input.end()); });
auto out_lens = compute_broadcasted_lens(in_lens, dims);
return info.add_instruction(make_op("multibroadcast", {{"output_lens", out_lens}}),
args[0]);
return info.add_instruction(make_op("multibroadcast", {{"out_lens", out_lens}}), args[0]);
}
};
......
......@@ -63,8 +63,8 @@ struct parse_gather_elements : op_parser<parse_gather_elements>
info.add_literal(literal(ind_s, data_indices.begin(), data_indices.end()));
auto l_dim_idx = info.add_literal(literal(ind_s, vec_axis_ind.begin(), vec_axis_ind.end()));
auto l_stride = info.add_literal(literal{{ind_s.type(), {1}}, {axis_stride}});
l_stride = info.add_instruction(make_op("multibroadcast", {{"output_lens", ind_s.lens()}}),
l_stride);
l_stride =
info.add_instruction(make_op("multibroadcast", {{"out_lens", ind_s.lens()}}), l_stride);
auto dim_diff = info.add_instruction(make_op("sub"), arg_ind, l_dim_idx);
auto delta = info.add_instruction(make_op("mul"), dim_diff, l_stride);
auto ind = info.add_instruction(make_op("add"), l_shape_idx, delta);
......
......@@ -55,8 +55,10 @@ struct parse_gemm : op_parser<parse_gemm>
}
}
l1 = (transa) ? info.add_instruction(make_op("transpose", {{"dims", perm}}), l1) : l1;
auto l2 = (transb) ? info.add_instruction(make_op("transpose", {{"dims", perm}}), args[1])
l1 =
(transa) ? info.add_instruction(make_op("transpose", {{"permutation", perm}}), l1) : l1;
auto l2 = (transb)
? info.add_instruction(make_op("transpose", {{"permutation", perm}}), args[1])
: args[1];
auto ret = info.add_instruction(make_op("dot", {{"alpha", 1.0f}, {"beta", 0.0f}}), l1, l2);
......@@ -71,8 +73,8 @@ struct parse_gemm : op_parser<parse_gemm>
auto l3_lens = l3->get_shape().lens();
if(!std::equal(out_lens.begin(), out_lens.end(), l3_lens.begin(), l3_lens.end()))
{
l3 = info.add_instruction(
make_op("multibroadcast", {{"output_lens", out_lens}}), args[2]);
l3 = info.add_instruction(make_op("multibroadcast", {{"out_lens", out_lens}}),
args[2]);
}
auto beta_literal = info.add_literal(beta);
auto beta_l3 = info.add_broadcastable_binary_op("mul", l3, beta_literal);
......
......@@ -40,7 +40,7 @@ struct parse_imagescalar : op_parser<parse_imagescalar>
auto img_scaled =
info.add_instruction(migraphx::make_op("mul"), args.front(), scale_tensor);
auto bias_bcast = info.add_instruction(
migraphx::make_op("broadcast", {{"axis", 1}, {"dims", input_lens}}), bias_vals);
migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", input_lens}}), bias_vals);
return info.add_instruction(migraphx::make_op("add"), img_scaled, bias_bcast);
}
};
......
......@@ -38,23 +38,23 @@ struct parse_instancenorm : op_parser<parse_instancenorm>
auto mean = info.add_instruction(make_op("reduce_mean", {{"axes", axes}}), x);
auto mean_bcast =
info.add_instruction(make_op("multibroadcast", {{"output_lens", dims}}), mean);
info.add_instruction(make_op("multibroadcast", {{"out_lens", dims}}), mean);
auto l0 = info.add_instruction(make_op("sqdiff"), x, mean_bcast);
auto variance = info.add_instruction(make_op("reduce_mean", {{"axes", axes}}), l0);
auto l1 = info.add_instruction(make_op("sub"), x, mean_bcast);
auto epsilon_literal = info.add_literal(epsilon);
auto epsilon_bcast = info.add_instruction(
make_op("multibroadcast", {{"output_lens", dims}}), epsilon_literal);
auto epsilon_bcast =
info.add_instruction(make_op("multibroadcast", {{"out_lens", dims}}), epsilon_literal);
auto variance_bcast =
info.add_instruction(make_op("multibroadcast", {{"output_lens", dims}}), variance);
info.add_instruction(make_op("multibroadcast", {{"out_lens", dims}}), variance);
auto l2 = info.add_instruction(make_op("add"), variance_bcast, epsilon_bcast);
auto l3 = info.add_instruction(make_op("rsqrt"), l2);
auto l4 = info.add_instruction(make_op("mul"), l1, l3);
auto scale_bcast =
info.add_instruction(make_op("broadcast", {{"axis", 1}, {"dims", dims}}), scale);
info.add_instruction(make_op("broadcast", {{"axis", 1}, {"out_lens", dims}}), scale);
;
auto bias_bcast =
info.add_instruction(make_op("broadcast", {{"axis", 1}, {"dims", dims}}), bias);
info.add_instruction(make_op("broadcast", {{"axis", 1}, {"out_lens", dims}}), bias);
auto l5 = info.add_instruction(make_op("mul"), l4, scale_bcast);
return info.add_instruction(make_op("add"), l5, bias_bcast);
}
......
......@@ -58,12 +58,12 @@ struct parse_matmul : op_parser<parse_matmul>
if(l0_lens != l0_broadcasted_lens)
{
bl0 = info.add_instruction(
make_op("multibroadcast", {{"output_lens", l0_broadcasted_lens}}), l0);
make_op("multibroadcast", {{"out_lens", l0_broadcasted_lens}}), l0);
}
if(l1_lens != l1_broadcasted_lens)
{
bl1 = info.add_instruction(
make_op("multibroadcast", {{"output_lens", l1_broadcasted_lens}}), l1);
make_op("multibroadcast", {{"out_lens", l1_broadcasted_lens}}), l1);
}
}
......
......@@ -45,7 +45,8 @@ struct parse_onehot : op_parser<parse_onehot>
std::vector<int64_t> perm(n_rank - 1);
std::iota(perm.begin(), perm.end(), 0);
perm.insert(perm.begin() + tuned_axis, n_rank - 1);
auto tr_out = info.add_instruction(make_op("transpose", {{"dims", perm}}), gather_out);
auto tr_out =
info.add_instruction(make_op("transpose", {{"permutation", perm}}), gather_out);
auto lens = tr_out->get_shape().lens();
auto off_val = info.add_instruction(
......@@ -54,9 +55,9 @@ struct parse_onehot : op_parser<parse_onehot>
make_op("slice", {{"axes", {0}}, {"starts", {1}}, {"ends", {2}}}), args[2]);
auto diff = info.add_instruction(make_op("sub"), on_val, off_val);
auto unsq_off_val =
info.add_instruction(make_op("multibroadcast", {{"output_lens", lens}}), off_val);
info.add_instruction(make_op("multibroadcast", {{"out_lens", lens}}), off_val);
auto unsq_diff_val =
info.add_instruction(make_op("multibroadcast", {{"output_lens", lens}}), diff);
info.add_instruction(make_op("multibroadcast", {{"out_lens", lens}}), diff);
auto l_mul = info.add_instruction(make_op("mul"), tr_out, unsq_diff_val);
return info.add_instruction(make_op("add"), l_mul, unsq_off_val);
}
......
......@@ -29,11 +29,11 @@ struct parse_quantizelinear : op_parser<parse_quantizelinear>
{
auto tuned_axis = tune_axis(n_dim, axis, opd.op_name);
y_scale = info.add_instruction(
make_op("broadcast", {{"axis", tuned_axis}, {"dims", input_lens}}), args[1]);
make_op("broadcast", {{"axis", tuned_axis}, {"out_lens", input_lens}}), args[1]);
}
else
{
y_scale = info.add_instruction(make_op("multibroadcast", {{"output_lens", input_lens}}),
y_scale = info.add_instruction(make_op("multibroadcast", {{"out_lens", input_lens}}),
args[1]);
}
......@@ -44,13 +44,13 @@ struct parse_quantizelinear : op_parser<parse_quantizelinear>
{
auto tuned_axis = tune_axis(n_dim, axis, opd.op_name);
y_zero_point = info.add_instruction(
make_op("broadcast", {{"axis", tuned_axis}, {"dims", input_lens}}),
make_op("broadcast", {{"axis", tuned_axis}, {"out_lens", input_lens}}),
y_zero_point);
}
else
{
y_zero_point = info.add_instruction(
make_op("multibroadcast", {{"output_lens", input_lens}}), y_zero_point);
make_op("multibroadcast", {{"out_lens", input_lens}}), y_zero_point);
}
return info.add_instruction(make_op("quantizelinear"), args[0], y_scale, y_zero_point);
......
......@@ -35,9 +35,9 @@ struct parse_selu : op_parser<parse_selu>
if(lens != std::vector<std::size_t>{1})
{
l_alpha =
info.add_instruction(make_op("multibroadcast", {{"output_lens", lens}}), l_alpha);
info.add_instruction(make_op("multibroadcast", {{"out_lens", lens}}), l_alpha);
l_gamma =
info.add_instruction(make_op("multibroadcast", {{"output_lens", lens}}), l_gamma);
info.add_instruction(make_op("multibroadcast", {{"out_lens", lens}}), l_gamma);
}
auto sign_x = info.add_instruction(make_op("sign"), args[0]);
......
......@@ -21,7 +21,7 @@ struct parse_transpose : op_parser<parse_transpose>
auto&& perm_vals = info.attributes["perm"].ints();
perm = std::vector<int64_t>(perm_vals.begin(), perm_vals.end());
}
return info.add_instruction(make_op("transpose", {{"dims", perm}}), args.front());
return info.add_instruction(make_op("transpose", {{"permutation", perm}}), args.front());
}
};
......
......@@ -23,19 +23,19 @@ struct parse_where : op_parser<parse_where>
lens = compute_broadcasted_lens(lens, args[2]->get_shape().lens());
if(cond->get_shape().lens() != lens)
{
cond = info.add_instruction(make_op("multibroadcast", {{"output_lens", lens}}), cond);
cond = info.add_instruction(make_op("multibroadcast", {{"out_lens", lens}}), cond);
}
if(args[1]->get_shape().lens() != lens)
{
args[1] =
info.add_instruction(make_op("multibroadcast", {{"output_lens", lens}}), args[1]);
info.add_instruction(make_op("multibroadcast", {{"out_lens", lens}}), args[1]);
}
if(args[2]->get_shape().lens() != lens)
{
args[2] =
info.add_instruction(make_op("multibroadcast", {{"output_lens", lens}}), args[2]);
info.add_instruction(make_op("multibroadcast", {{"out_lens", lens}}), args[2]);
}
// compute index
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment