Commit e7c17bd6 authored by charlie's avatar charlie
Browse files

Tidy style fix

parent 801a349c
...@@ -39,19 +39,19 @@ struct parse_gemm : op_parser<parse_gemm> ...@@ -39,19 +39,19 @@ struct parse_gemm : op_parser<parse_gemm>
onnx_parser::node_info info, onnx_parser::node_info info,
std::vector<instruction_ref> args) const std::vector<instruction_ref> args) const
{ {
auto A = args[0]; auto a_arg = args[0];
auto B = args[1]; auto b_arg = args[1];
if(A->get_shape().ndim() != 2 or B->get_shape().ndim() != 2) if(a_arg->get_shape().ndim() != 2 or b_arg->get_shape().ndim() != 2)
{ {
MIGRAPHX_THROW("PARSE_GEMM: A and B should be rank 2, A is rank " + MIGRAPHX_THROW("PARSE_GEMM: A and B should be rank 2, A is rank " +
std::to_string(A->get_shape().ndim()) + "B is rank " + std::to_string(a_arg->get_shape().ndim()) + "B is rank " +
std::to_string(B->get_shape().ndim())); std::to_string(b_arg->get_shape().ndim()));
} }
float alpha = 1.0f; float alpha = 1.0f;
float beta = 1.0f; float beta = 1.0f;
bool transa = false; bool trans_a = false;
bool transb = false; bool trans_b = false;
if(contains(info.attributes, "alpha")) if(contains(info.attributes, "alpha"))
{ {
alpha = parser.parse_value(info.attributes.at("alpha")).at<float>(); alpha = parser.parse_value(info.attributes.at("alpha")).at<float>();
...@@ -62,11 +62,11 @@ struct parse_gemm : op_parser<parse_gemm> ...@@ -62,11 +62,11 @@ struct parse_gemm : op_parser<parse_gemm>
} }
if(contains(info.attributes, "transA")) if(contains(info.attributes, "transA"))
{ {
transa = parser.parse_value(info.attributes.at("transA")).at<bool>(); trans_a = parser.parse_value(info.attributes.at("transA")).at<bool>();
} }
if(contains(info.attributes, "transB")) if(contains(info.attributes, "transB"))
{ {
transb = parser.parse_value(info.attributes.at("transB")).at<bool>(); trans_b = parser.parse_value(info.attributes.at("transB")).at<bool>();
} }
std::vector<int64_t> perm(2); std::vector<int64_t> perm(2);
...@@ -74,24 +74,28 @@ struct parse_gemm : op_parser<parse_gemm> ...@@ -74,24 +74,28 @@ struct parse_gemm : op_parser<parse_gemm>
// swap the last two elements // swap the last two elements
std::swap(*perm.rbegin(), *(perm.rbegin() + 1)); std::swap(*perm.rbegin(), *(perm.rbegin() + 1));
auto dot_type = A->get_shape().type(); auto dot_type = a_arg->get_shape().type();
if(alpha != 1.0f) if(alpha != 1.0f)
{ {
auto alpha_literal = info.add_literal(alpha); auto alpha_literal = info.add_literal(alpha);
A = info.add_broadcastable_binary_op("mul", alpha_literal, A); a_arg = info.add_broadcastable_binary_op("mul", alpha_literal, a_arg);
if(A->get_shape().type() != dot_type) if(a_arg->get_shape().type() != dot_type)
{ {
A = info.add_instruction(make_op("convert", {{"target_type", dot_type}}), A); a_arg =
info.add_instruction(make_op("convert", {{"target_type", dot_type}}), a_arg);
} }
} }
A = (transa) ? info.add_instruction(make_op("transpose", {{"permutation", perm}}), A) : A; a_arg = (trans_a)
B = (transb) ? info.add_instruction(make_op("transpose", {{"permutation", perm}}), args[1]) ? info.add_instruction(make_op("transpose", {{"permutation", perm}}), a_arg)
: a_arg;
b_arg = (trans_b)
? info.add_instruction(make_op("transpose", {{"permutation", perm}}), args[1])
: args[1]; : args[1];
auto ret = info.add_instruction(make_op("dot"), A, B); auto ret = info.add_instruction(make_op("dot"), a_arg, b_arg);
if(args.size() == 3) if(args.size() == 3)
{ {
...@@ -104,24 +108,24 @@ struct parse_gemm : op_parser<parse_gemm> ...@@ -104,24 +108,24 @@ struct parse_gemm : op_parser<parse_gemm>
} }
if(not float_equal(beta, 0.0f) and args[2]->get_shape().elements() > 0) if(not float_equal(beta, 0.0f) and args[2]->get_shape().elements() > 0)
{ {
auto out_lens = A->get_shape().lens(); auto out_lens = a_arg->get_shape().lens();
out_lens.back() = B->get_shape().lens().back(); out_lens.back() = b_arg->get_shape().lens().back();
auto C = args[2]; auto c_arg = args[2];
auto C_lens = C->get_shape().lens(); auto c_lens = c_arg->get_shape().lens();
if(not std::equal(out_lens.begin(), out_lens.end(), C_lens.begin(), C_lens.end())) if(not std::equal(out_lens.begin(), out_lens.end(), c_lens.begin(), c_lens.end()))
{ {
C = info.add_instruction(make_op("multibroadcast", {{"out_lens", out_lens}}), c_arg = info.add_instruction(
args[2]); make_op("multibroadcast", {{"out_lens", out_lens}}), args[2]);
} }
auto beta_literal = info.add_literal(beta); auto beta_literal = info.add_literal(beta);
auto beta_C = info.add_broadcastable_binary_op("mul", C, beta_literal); auto beta_c = info.add_broadcastable_binary_op("mul", c_arg, beta_literal);
if(beta_C->get_shape().type() != dot_type) if(beta_c->get_shape().type() != dot_type)
{ {
beta_C = info.add_instruction(make_op("convert", {{"target_type", dot_type}}), beta_c = info.add_instruction(make_op("convert", {{"target_type", dot_type}}),
beta_C); beta_c);
} }
return info.add_instruction(make_op("add"), ret, beta_C); return info.add_instruction(make_op("add"), ret, beta_c);
} }
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment