Commit e7c17bd6 authored by charlie's avatar charlie
Browse files

Tidy style fix

parent 801a349c
......@@ -39,19 +39,19 @@ struct parse_gemm : op_parser<parse_gemm>
onnx_parser::node_info info,
std::vector<instruction_ref> args) const
{
auto A = args[0];
auto B = args[1];
if(A->get_shape().ndim() != 2 or B->get_shape().ndim() != 2)
auto a_arg = args[0];
auto b_arg = args[1];
if(a_arg->get_shape().ndim() != 2 or b_arg->get_shape().ndim() != 2)
{
MIGRAPHX_THROW("PARSE_GEMM: A and B should be rank 2, A is rank " +
std::to_string(A->get_shape().ndim()) + "B is rank " +
std::to_string(B->get_shape().ndim()));
std::to_string(a_arg->get_shape().ndim()) + "B is rank " +
std::to_string(b_arg->get_shape().ndim()));
}
float alpha = 1.0f;
float beta = 1.0f;
bool transa = false;
bool transb = false;
float alpha = 1.0f;
float beta = 1.0f;
bool trans_a = false;
bool trans_b = false;
if(contains(info.attributes, "alpha"))
{
alpha = parser.parse_value(info.attributes.at("alpha")).at<float>();
......@@ -62,11 +62,11 @@ struct parse_gemm : op_parser<parse_gemm>
}
if(contains(info.attributes, "transA"))
{
transa = parser.parse_value(info.attributes.at("transA")).at<bool>();
trans_a = parser.parse_value(info.attributes.at("transA")).at<bool>();
}
if(contains(info.attributes, "transB"))
{
transb = parser.parse_value(info.attributes.at("transB")).at<bool>();
trans_b = parser.parse_value(info.attributes.at("transB")).at<bool>();
}
std::vector<int64_t> perm(2);
......@@ -74,24 +74,28 @@ struct parse_gemm : op_parser<parse_gemm>
// swap the last two elements
std::swap(*perm.rbegin(), *(perm.rbegin() + 1));
auto dot_type = A->get_shape().type();
auto dot_type = a_arg->get_shape().type();
if(alpha != 1.0f)
{
auto alpha_literal = info.add_literal(alpha);
A = info.add_broadcastable_binary_op("mul", alpha_literal, A);
a_arg = info.add_broadcastable_binary_op("mul", alpha_literal, a_arg);
if(A->get_shape().type() != dot_type)
if(a_arg->get_shape().type() != dot_type)
{
A = info.add_instruction(make_op("convert", {{"target_type", dot_type}}), A);
a_arg =
info.add_instruction(make_op("convert", {{"target_type", dot_type}}), a_arg);
}
}
A = (transa) ? info.add_instruction(make_op("transpose", {{"permutation", perm}}), A) : A;
B = (transb) ? info.add_instruction(make_op("transpose", {{"permutation", perm}}), args[1])
: args[1];
a_arg = (trans_a)
? info.add_instruction(make_op("transpose", {{"permutation", perm}}), a_arg)
: a_arg;
b_arg = (trans_b)
? info.add_instruction(make_op("transpose", {{"permutation", perm}}), args[1])
: args[1];
auto ret = info.add_instruction(make_op("dot"), A, B);
auto ret = info.add_instruction(make_op("dot"), a_arg, b_arg);
if(args.size() == 3)
{
......@@ -104,24 +108,24 @@ struct parse_gemm : op_parser<parse_gemm>
}
if(not float_equal(beta, 0.0f) and args[2]->get_shape().elements() > 0)
{
auto out_lens = A->get_shape().lens();
out_lens.back() = B->get_shape().lens().back();
auto C = args[2];
auto C_lens = C->get_shape().lens();
if(not std::equal(out_lens.begin(), out_lens.end(), C_lens.begin(), C_lens.end()))
auto out_lens = a_arg->get_shape().lens();
out_lens.back() = b_arg->get_shape().lens().back();
auto c_arg = args[2];
auto c_lens = c_arg->get_shape().lens();
if(not std::equal(out_lens.begin(), out_lens.end(), c_lens.begin(), c_lens.end()))
{
C = info.add_instruction(make_op("multibroadcast", {{"out_lens", out_lens}}),
args[2]);
c_arg = info.add_instruction(
make_op("multibroadcast", {{"out_lens", out_lens}}), args[2]);
}
auto beta_literal = info.add_literal(beta);
auto beta_C = info.add_broadcastable_binary_op("mul", C, beta_literal);
if(beta_C->get_shape().type() != dot_type)
auto beta_c = info.add_broadcastable_binary_op("mul", c_arg, beta_literal);
if(beta_c->get_shape().type() != dot_type)
{
beta_C = info.add_instruction(make_op("convert", {{"target_type", dot_type}}),
beta_C);
beta_c = info.add_instruction(make_op("convert", {{"target_type", dot_type}}),
beta_c);
}
return info.add_instruction(make_op("add"), ret, beta_C);
return info.add_instruction(make_op("add"), ret, beta_c);
}
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment