"vscode:/vscode.git/clone" did not exist on "50141f2a6daffbcb1bd62f4698daced0e3343889"
Unverified Commit 0d2606bb authored by Umang Yadav's avatar Umang Yadav Committed by GitHub
Browse files

Change attributes names to be more consistent and reflect better meaning (#916)

* rename broadcast and multibroadcast output_lens attribute to out_lens attribute, and change tests and source code to reflect the same

* change the reshape attribute from dims to out_lens

* change transpose attribute's name from dims to perm to reflect better meaning

* use permutation instead of perm for transpose

clang formaating

* use dims instead of out_lens for reshape

clang formatting
parent d8a2a933
......@@ -96,7 +96,7 @@ instruction_ref insert_common_op(module& m,
if(input->get_shape().lens() != common.lens())
{
input = m.insert_instruction(
ins, make_op("multibroadcast", {{"output_lens", common.lens()}}), input);
ins, make_op("multibroadcast", {{"out_lens", common.lens()}}), input);
}
if(input->get_shape().type() != common.type())
{
......
......@@ -39,9 +39,7 @@ struct find_dot_add
{
auto alpha = p.add_literal(literal{shape{a_ins->get_shape().type()}, {dot.alpha}});
auto alpha_broadcast = p.insert_instruction(
ins,
make_op("multibroadcast", {{"output_lens", a_ins->get_shape().lens()}}),
alpha);
ins, make_op("multibroadcast", {{"out_lens", a_ins->get_shape().lens()}}), alpha);
a_ins = p.insert_instruction(ins, make_op("mul"), a_ins, alpha_broadcast);
}
auto dot_ins = p.insert_instruction(ins, make_op(ins->name(), {{"beta", 0}}), a_ins, b_ins);
......@@ -51,7 +49,7 @@ struct find_dot_add
{
auto beta = p.add_literal(literal{shape{c_ins->get_shape().type()}, {dot.beta}});
auto beta_broadcast = p.insert_instruction(
ins, make_op("multibroadcast", {{"output_lens", ins->get_shape().lens()}}), beta);
ins, make_op("multibroadcast", {{"out_lens", ins->get_shape().lens()}}), beta);
c_ins = p.insert_instruction(ins, make_op("mul"), c_ins, beta_broadcast);
}
p.replace_instruction(ins, make_op("add"), dot_ins, c_ins);
......@@ -72,9 +70,7 @@ struct find_dot_alpha
{
auto alpha = p.add_literal(literal{shape{a_ins->get_shape().type()}, {dot.alpha}});
auto alpha_broadcast = p.insert_instruction(
ins,
make_op("multibroadcast", {{"output_lens", a_ins->get_shape().lens()}}),
alpha);
ins, make_op("multibroadcast", {{"out_lens", a_ins->get_shape().lens()}}), alpha);
a_ins = p.insert_instruction(ins, make_op("mul"), a_ins, alpha_broadcast);
}
p.replace_instruction(ins, make_op(ins->name(), {{"beta", 0}}), a_ins, b_ins);
......
......@@ -30,7 +30,7 @@ struct broadcast
template <class Self, class F>
static auto reflect(Self& self, F f)
{
return pack(f(self.axis, "axis"), f(self.broadcast_lens, "dims"));
return pack(f(self.axis, "axis"), f(self.broadcast_lens, "out_lens"));
}
std::string name() const { return "broadcast"; }
......
......@@ -23,7 +23,7 @@ struct multibroadcast
template <class Self, class F>
static auto reflect(Self& self, F f)
{
return pack(f(self.output_lens, "output_lens"));
return pack(f(self.output_lens, "out_lens"));
}
std::string name() const { return "multibroadcast"; }
......
......@@ -21,7 +21,7 @@ struct transpose
template <class Self, class F>
static auto reflect(Self& self, F f)
{
return pack(f(self.dims, "dims"));
return pack(f(self.dims, "permutation"));
}
std::string name() const { return "transpose"; }
......
......@@ -85,7 +85,7 @@ instruction_ref onnx_parser::node_info::add_bias(const std::vector<instruction_r
if(args.size() == 3)
{
auto bias_bcast = mod->add_instruction(
make_op("broadcast", {{"axis", axis}, {"dims", curr_ins->get_shape().lens()}}),
make_op("broadcast", {{"axis", axis}, {"out_lens", curr_ins->get_shape().lens()}}),
args[2]);
return mod->add_instruction(make_op("add"), curr_ins, bias_bcast);
}
......
......@@ -36,7 +36,8 @@ struct parse_binary_op : op_parser<parse_binary_op>
{
uint64_t axis = parser.parse_value(info.attributes.at("axis")).at<uint64_t>();
auto l = info.add_instruction(
make_op("broadcast", {{"axis", axis}, {"dims", args[0]->get_shape().lens()}}),
make_op("broadcast",
{{"axis", axis}, {"out_lens", args[0]->get_shape().lens()}}),
args[1]);
return info.add_instruction(make_op(opd.op_name), args[0], l);
}
......
......@@ -47,13 +47,13 @@ struct parse_clip : op_parser<parse_clip>
if(min_used)
{
min_arg = info.add_instruction(make_op("multibroadcast", {{"output_lens", input_lens}}),
min_arg = info.add_instruction(make_op("multibroadcast", {{"out_lens", input_lens}}),
min_arg);
}
if(max_used)
{
max_arg = info.add_instruction(make_op("multibroadcast", {{"output_lens", input_lens}}),
max_arg = info.add_instruction(make_op("multibroadcast", {{"out_lens", input_lens}}),
max_arg);
}
......
......@@ -29,11 +29,11 @@ struct parse_dequantizelinear : op_parser<parse_dequantizelinear>
{
auto tuned_axis = tune_axis(n_dim, axis, opd.op_name);
x_scale = info.add_instruction(
make_op("broadcast", {{"axis", tuned_axis}, {"dims", input_lens}}), args[1]);
make_op("broadcast", {{"axis", tuned_axis}, {"out_lens", input_lens}}), args[1]);
}
else
{
x_scale = info.add_instruction(make_op("multibroadcast", {{"output_lens", input_lens}}),
x_scale = info.add_instruction(make_op("multibroadcast", {{"out_lens", input_lens}}),
args[1]);
}
......@@ -44,13 +44,13 @@ struct parse_dequantizelinear : op_parser<parse_dequantizelinear>
{
auto tuned_axis = tune_axis(n_dim, axis, opd.op_name);
x_zero_point = info.add_instruction(
make_op("broadcast", {{"axis", tuned_axis}, {"dims", input_lens}}),
make_op("broadcast", {{"axis", tuned_axis}, {"out_lens", input_lens}}),
x_zero_point);
}
else
{
x_zero_point = info.add_instruction(
make_op("multibroadcast", {{"output_lens", input_lens}}), x_zero_point);
make_op("multibroadcast", {{"out_lens", input_lens}}), x_zero_point);
}
return info.add_instruction(
......
......@@ -24,8 +24,7 @@ struct parse_expand : op_parser<parse_expand>
std::vector<std::size_t> dims;
arg_s.visit([&](auto input) { dims.assign(input.begin(), input.end()); });
auto out_lens = compute_broadcasted_lens(in_lens, dims);
return info.add_instruction(make_op("multibroadcast", {{"output_lens", out_lens}}),
args[0]);
return info.add_instruction(make_op("multibroadcast", {{"out_lens", out_lens}}), args[0]);
}
};
......
......@@ -63,8 +63,8 @@ struct parse_gather_elements : op_parser<parse_gather_elements>
info.add_literal(literal(ind_s, data_indices.begin(), data_indices.end()));
auto l_dim_idx = info.add_literal(literal(ind_s, vec_axis_ind.begin(), vec_axis_ind.end()));
auto l_stride = info.add_literal(literal{{ind_s.type(), {1}}, {axis_stride}});
l_stride = info.add_instruction(make_op("multibroadcast", {{"output_lens", ind_s.lens()}}),
l_stride);
l_stride =
info.add_instruction(make_op("multibroadcast", {{"out_lens", ind_s.lens()}}), l_stride);
auto dim_diff = info.add_instruction(make_op("sub"), arg_ind, l_dim_idx);
auto delta = info.add_instruction(make_op("mul"), dim_diff, l_stride);
auto ind = info.add_instruction(make_op("add"), l_shape_idx, delta);
......
......@@ -55,8 +55,10 @@ struct parse_gemm : op_parser<parse_gemm>
}
}
l1 = (transa) ? info.add_instruction(make_op("transpose", {{"dims", perm}}), l1) : l1;
auto l2 = (transb) ? info.add_instruction(make_op("transpose", {{"dims", perm}}), args[1])
l1 =
(transa) ? info.add_instruction(make_op("transpose", {{"permutation", perm}}), l1) : l1;
auto l2 = (transb)
? info.add_instruction(make_op("transpose", {{"permutation", perm}}), args[1])
: args[1];
auto ret = info.add_instruction(make_op("dot", {{"alpha", 1.0f}, {"beta", 0.0f}}), l1, l2);
......@@ -71,8 +73,8 @@ struct parse_gemm : op_parser<parse_gemm>
auto l3_lens = l3->get_shape().lens();
if(!std::equal(out_lens.begin(), out_lens.end(), l3_lens.begin(), l3_lens.end()))
{
l3 = info.add_instruction(
make_op("multibroadcast", {{"output_lens", out_lens}}), args[2]);
l3 = info.add_instruction(make_op("multibroadcast", {{"out_lens", out_lens}}),
args[2]);
}
auto beta_literal = info.add_literal(beta);
auto beta_l3 = info.add_broadcastable_binary_op("mul", l3, beta_literal);
......
......@@ -40,7 +40,7 @@ struct parse_imagescalar : op_parser<parse_imagescalar>
auto img_scaled =
info.add_instruction(migraphx::make_op("mul"), args.front(), scale_tensor);
auto bias_bcast = info.add_instruction(
migraphx::make_op("broadcast", {{"axis", 1}, {"dims", input_lens}}), bias_vals);
migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", input_lens}}), bias_vals);
return info.add_instruction(migraphx::make_op("add"), img_scaled, bias_bcast);
}
};
......
......@@ -38,23 +38,23 @@ struct parse_instancenorm : op_parser<parse_instancenorm>
auto mean = info.add_instruction(make_op("reduce_mean", {{"axes", axes}}), x);
auto mean_bcast =
info.add_instruction(make_op("multibroadcast", {{"output_lens", dims}}), mean);
info.add_instruction(make_op("multibroadcast", {{"out_lens", dims}}), mean);
auto l0 = info.add_instruction(make_op("sqdiff"), x, mean_bcast);
auto variance = info.add_instruction(make_op("reduce_mean", {{"axes", axes}}), l0);
auto l1 = info.add_instruction(make_op("sub"), x, mean_bcast);
auto epsilon_literal = info.add_literal(epsilon);
auto epsilon_bcast = info.add_instruction(
make_op("multibroadcast", {{"output_lens", dims}}), epsilon_literal);
auto epsilon_bcast =
info.add_instruction(make_op("multibroadcast", {{"out_lens", dims}}), epsilon_literal);
auto variance_bcast =
info.add_instruction(make_op("multibroadcast", {{"output_lens", dims}}), variance);
info.add_instruction(make_op("multibroadcast", {{"out_lens", dims}}), variance);
auto l2 = info.add_instruction(make_op("add"), variance_bcast, epsilon_bcast);
auto l3 = info.add_instruction(make_op("rsqrt"), l2);
auto l4 = info.add_instruction(make_op("mul"), l1, l3);
auto scale_bcast =
info.add_instruction(make_op("broadcast", {{"axis", 1}, {"dims", dims}}), scale);
info.add_instruction(make_op("broadcast", {{"axis", 1}, {"out_lens", dims}}), scale);
;
auto bias_bcast =
info.add_instruction(make_op("broadcast", {{"axis", 1}, {"dims", dims}}), bias);
info.add_instruction(make_op("broadcast", {{"axis", 1}, {"out_lens", dims}}), bias);
auto l5 = info.add_instruction(make_op("mul"), l4, scale_bcast);
return info.add_instruction(make_op("add"), l5, bias_bcast);
}
......
......@@ -58,12 +58,12 @@ struct parse_matmul : op_parser<parse_matmul>
if(l0_lens != l0_broadcasted_lens)
{
bl0 = info.add_instruction(
make_op("multibroadcast", {{"output_lens", l0_broadcasted_lens}}), l0);
make_op("multibroadcast", {{"out_lens", l0_broadcasted_lens}}), l0);
}
if(l1_lens != l1_broadcasted_lens)
{
bl1 = info.add_instruction(
make_op("multibroadcast", {{"output_lens", l1_broadcasted_lens}}), l1);
make_op("multibroadcast", {{"out_lens", l1_broadcasted_lens}}), l1);
}
}
......
......@@ -45,7 +45,8 @@ struct parse_onehot : op_parser<parse_onehot>
std::vector<int64_t> perm(n_rank - 1);
std::iota(perm.begin(), perm.end(), 0);
perm.insert(perm.begin() + tuned_axis, n_rank - 1);
auto tr_out = info.add_instruction(make_op("transpose", {{"dims", perm}}), gather_out);
auto tr_out =
info.add_instruction(make_op("transpose", {{"permutation", perm}}), gather_out);
auto lens = tr_out->get_shape().lens();
auto off_val = info.add_instruction(
......@@ -54,9 +55,9 @@ struct parse_onehot : op_parser<parse_onehot>
make_op("slice", {{"axes", {0}}, {"starts", {1}}, {"ends", {2}}}), args[2]);
auto diff = info.add_instruction(make_op("sub"), on_val, off_val);
auto unsq_off_val =
info.add_instruction(make_op("multibroadcast", {{"output_lens", lens}}), off_val);
info.add_instruction(make_op("multibroadcast", {{"out_lens", lens}}), off_val);
auto unsq_diff_val =
info.add_instruction(make_op("multibroadcast", {{"output_lens", lens}}), diff);
info.add_instruction(make_op("multibroadcast", {{"out_lens", lens}}), diff);
auto l_mul = info.add_instruction(make_op("mul"), tr_out, unsq_diff_val);
return info.add_instruction(make_op("add"), l_mul, unsq_off_val);
}
......
......@@ -29,11 +29,11 @@ struct parse_quantizelinear : op_parser<parse_quantizelinear>
{
auto tuned_axis = tune_axis(n_dim, axis, opd.op_name);
y_scale = info.add_instruction(
make_op("broadcast", {{"axis", tuned_axis}, {"dims", input_lens}}), args[1]);
make_op("broadcast", {{"axis", tuned_axis}, {"out_lens", input_lens}}), args[1]);
}
else
{
y_scale = info.add_instruction(make_op("multibroadcast", {{"output_lens", input_lens}}),
y_scale = info.add_instruction(make_op("multibroadcast", {{"out_lens", input_lens}}),
args[1]);
}
......@@ -44,13 +44,13 @@ struct parse_quantizelinear : op_parser<parse_quantizelinear>
{
auto tuned_axis = tune_axis(n_dim, axis, opd.op_name);
y_zero_point = info.add_instruction(
make_op("broadcast", {{"axis", tuned_axis}, {"dims", input_lens}}),
make_op("broadcast", {{"axis", tuned_axis}, {"out_lens", input_lens}}),
y_zero_point);
}
else
{
y_zero_point = info.add_instruction(
make_op("multibroadcast", {{"output_lens", input_lens}}), y_zero_point);
make_op("multibroadcast", {{"out_lens", input_lens}}), y_zero_point);
}
return info.add_instruction(make_op("quantizelinear"), args[0], y_scale, y_zero_point);
......
......@@ -35,9 +35,9 @@ struct parse_selu : op_parser<parse_selu>
if(lens != std::vector<std::size_t>{1})
{
l_alpha =
info.add_instruction(make_op("multibroadcast", {{"output_lens", lens}}), l_alpha);
info.add_instruction(make_op("multibroadcast", {{"out_lens", lens}}), l_alpha);
l_gamma =
info.add_instruction(make_op("multibroadcast", {{"output_lens", lens}}), l_gamma);
info.add_instruction(make_op("multibroadcast", {{"out_lens", lens}}), l_gamma);
}
auto sign_x = info.add_instruction(make_op("sign"), args[0]);
......
......@@ -21,7 +21,7 @@ struct parse_transpose : op_parser<parse_transpose>
auto&& perm_vals = info.attributes["perm"].ints();
perm = std::vector<int64_t>(perm_vals.begin(), perm_vals.end());
}
return info.add_instruction(make_op("transpose", {{"dims", perm}}), args.front());
return info.add_instruction(make_op("transpose", {{"permutation", perm}}), args.front());
}
};
......
......@@ -23,19 +23,19 @@ struct parse_where : op_parser<parse_where>
lens = compute_broadcasted_lens(lens, args[2]->get_shape().lens());
if(cond->get_shape().lens() != lens)
{
cond = info.add_instruction(make_op("multibroadcast", {{"output_lens", lens}}), cond);
cond = info.add_instruction(make_op("multibroadcast", {{"out_lens", lens}}), cond);
}
if(args[1]->get_shape().lens() != lens)
{
args[1] =
info.add_instruction(make_op("multibroadcast", {{"output_lens", lens}}), args[1]);
info.add_instruction(make_op("multibroadcast", {{"out_lens", lens}}), args[1]);
}
if(args[2]->get_shape().lens() != lens)
{
args[2] =
info.add_instruction(make_op("multibroadcast", {{"output_lens", lens}}), args[2]);
info.add_instruction(make_op("multibroadcast", {{"out_lens", lens}}), args[2]);
}
// compute index
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment