Commit 8d69498c authored by Shucai Xiao's avatar Shucai Xiao
Browse files

code cleanup for extending the gemm operator

parent cda6f573
...@@ -840,12 +840,13 @@ struct dot ...@@ -840,12 +840,13 @@ struct dot
MIGRAPHX_THROW("DOT: second matrix, dimensions before matrix dims must be 1"); MIGRAPHX_THROW("DOT: second matrix, dimensions before matrix dims must be 1");
} }
std::size_t n_dims = a.lens().size(); std::size_t dim_0 = a.lens().size() - 2;
if(a.lens()[n_dims - 1] != b.lens()[n_dims - 2]) std::size_t dim_1 = a.lens().size() - 1;
if(a.lens()[dim_1] != b.lens()[dim_0])
MIGRAPHX_THROW("Inner dimensions do not match: {" + to_string_range(a.lens()) + MIGRAPHX_THROW("Inner dimensions do not match: {" + to_string_range(a.lens()) +
"} x {" + to_string_range(b.lens()) + "}"); "} x {" + to_string_range(b.lens()) + "}");
auto out_lens = a.lens(); auto out_lens = a.lens();
out_lens[n_dims - 1] = b.lens()[n_dims - 1]; out_lens[dim_1] = b.lens()[dim_1];
return {t, out_lens}; return {t, out_lens};
} }
}; };
......
...@@ -469,7 +469,6 @@ struct onnx_parser ...@@ -469,7 +469,6 @@ struct onnx_parser
{ {
transb = parse_value(attributes.at("transB")).at<bool>(); transb = parse_value(attributes.at("transB")).at<bool>();
} }
std::vector<int64_t> perm = {1, 0}; std::vector<int64_t> perm = {1, 0};
auto l1 = (transa) ? prog.add_instruction(op::transpose{perm}, args[0]) : args[0]; auto l1 = (transa) ? prog.add_instruction(op::transpose{perm}, args[0]) : args[0];
auto l2 = (transb) ? prog.add_instruction(op::transpose{perm}, args[1]) : args[1]; auto l2 = (transb) ? prog.add_instruction(op::transpose{perm}, args[1]) : args[1];
...@@ -490,10 +489,7 @@ struct onnx_parser ...@@ -490,10 +489,7 @@ struct onnx_parser
return add_broadcastable_binary_op(l3, l4, op::add{}); return add_broadcastable_binary_op(l3, l4, op::add{});
} }
} }
return prog.add_instruction(op::dot{alpha, beta}, l1, l2);
auto dot_res = prog.add_instruction(op::dot{alpha, beta}, l1, l2);
return dot_res;
} }
instruction_ref instruction_ref
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment