Commit 8d69498c authored by Shucai Xiao's avatar Shucai Xiao
Browse files

code cleanup for extending the gemm operator

parent cda6f573
......@@ -840,12 +840,13 @@ struct dot
MIGRAPHX_THROW("DOT: second matrix, dimensions before matrix dims must be 1");
}
std::size_t n_dims = a.lens().size();
if(a.lens()[n_dims - 1] != b.lens()[n_dims - 2])
std::size_t dim_0 = a.lens().size() - 2;
std::size_t dim_1 = a.lens().size() - 1;
if(a.lens()[dim_1] != b.lens()[dim_0])
MIGRAPHX_THROW("Inner dimensions do not match: {" + to_string_range(a.lens()) +
"} x {" + to_string_range(b.lens()) + "}");
auto out_lens = a.lens();
out_lens[n_dims - 1] = b.lens()[n_dims - 1];
out_lens[dim_1] = b.lens()[dim_1];
return {t, out_lens};
}
};
......
......@@ -469,7 +469,6 @@ struct onnx_parser
{
transb = parse_value(attributes.at("transB")).at<bool>();
}
std::vector<int64_t> perm = {1, 0};
auto l1 = (transa) ? prog.add_instruction(op::transpose{perm}, args[0]) : args[0];
auto l2 = (transb) ? prog.add_instruction(op::transpose{perm}, args[1]) : args[1];
......@@ -490,10 +489,7 @@ struct onnx_parser
return add_broadcastable_binary_op(l3, l4, op::add{});
}
}
auto dot_res = prog.add_instruction(op::dot{alpha, beta}, l1, l2);
return dot_res;
return prog.add_instruction(op::dot{alpha, beta}, l1, l2);
}
instruction_ref
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment