Commit f9c38c09 authored by Shucai Xiao

Support stacked (more than 2-D) matrices as inputs together with C; for now, C must be the same shape as A * B.

parent ca28e1e8
@@ -808,10 +808,12 @@ struct gather
 };
 // The dot operation is a combination of the onnx GEMM and MatMul operators.
-// For GEMM, it support the C matrix in the formula alpha * AB + beta * C,
-// in which C is broadcastable to the shape of AB. For the transpose of A
-// and B, we add a tranpose operator beforehand if the onnx gemm operator
-// indicates a transpose.
+// For GEMM, it supports two cases: 1) in the formula alpha * AB + beta * C,
+// A and B are 2-D matrices and C is broadcastable to the shape of A * B.
+// For the transpose of A and B, we add a transpose operator beforehand if
+// the onnx gemm operator indicates a transpose is required. 2) A and B have
+// more than 2 dims; then the dims except the last two in A and B need to
+// be the same, and C should be the same shape as A * B.
 // For MatMul, it has the same definition as numpy.matmul, which means
 // A and B can have 1 to N dims. For a 1-dim A, it is a vector * matrix;
 // for a 1-dim B, it is a matrix * vector. Note that there is no support
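As a standalone illustration of the stacked-matrix rule the new comment describes (hypothetical names, not MIGraphX code; both inputs are assumed to have at least two dims):

#include <algorithm>
#include <cstddef>
#include <stdexcept>
#include <vector>

using dims = std::vector<std::size_t>;

// Stacked case: all dims except the last two must match between A and B,
// and the output copies A's dims with the last one replaced by B's last.
dims batched_dot_dims(const dims& a, const dims& b)
{
    if(!std::equal(a.rbegin() + 2, a.rend(), b.rbegin() + 2))
        throw std::runtime_error("batch dims of A and B differ");
    if(a.back() != b[b.size() - 2])
        throw std::runtime_error("inner dims do not match");
    dims out   = a;
    out.back() = b.back();
    return out;
}

int main()
{
    // {3, 2, 4, 5} x {3, 2, 5, 7} -> {3, 2, 4, 7}; C must then be {3, 2, 4, 7}.
    return batched_dot_dims({3, 2, 4, 5}, {3, 2, 5, 7}) == dims{3, 2, 4, 7} ? 0 : 1;
}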
@@ -893,36 +895,66 @@ struct dot
     std::string name() const { return "dot"; }
     shape compute_shape(std::vector<shape> inputs) const
     {
-        // If there are 3 inputs, then A and B must be matrices and
-        // C should be broadcastable to A * B
+        // If there are 3 inputs, there are two scenarios:
+        // 1. A and B are 2-D matrices and C is broadcastable to A * B
+        // 2. A and B are stacks of matrices; the batch dims must then be
+        //    the same for A and B, and C must be the same shape as A * B.
+        //    (For now, we add this requirement to simplify the
+        //    implementation; we can relax it later.)
         if(inputs.size() == 3)
         {
-            check_shapes{inputs, *this}.has(3).same_type();
-            check_shapes{{inputs[0]}, *this}.only_dims(2);
-            check_shapes{{inputs[1]}, *this}.only_dims(2);
             auto a_lens = inputs[0].lens();
             auto b_lens = inputs[1].lens();
+            auto out_lens = a_lens;
             auto t      = inputs[0].type();
-            if(a_lens[1] != b_lens[0])
-            {
-                MIGRAPHX_THROW("DOT : dimension mismatch, operand A: {" + to_string_range(a_lens) +
-                               "}, cannot multiply operand B: {" + to_string_range(b_lens) + "}");
-            }
-
-            auto out_lens = a_lens;
-            out_lens[1]   = b_lens[1];
-
-            // check whether C is broadcastable to A * B
-            auto c_lens = inputs[2].lens();
-            if(c_lens.size() > 2 ||
-               (c_lens.size() == 1 && (c_lens[0] != 1 && c_lens[0] != b_lens[1])) ||
-               (c_lens.size() == 2 && (c_lens[0] != 1 && c_lens[0] != a_lens[0])) ||
-               (c_lens.size() == 2 && (c_lens[1] != 1 && c_lens[1] != b_lens[1])))
-            {
-                MIGRAPHX_THROW("DOT: C {" + to_string_range(c_lens) +
-                               "} is not broadcastable to A * B {" + to_string_range(out_lens) +
-                               "}");
-            }
+            if(inputs[1].lens().size() > 2)
+            {
+                if(!std::equal(a_lens.rbegin() + 2, a_lens.rend(), b_lens.rbegin() + 2))
+                {
+                    MIGRAPHX_THROW("DOT: dimension mismatch, operand A: {" +
+                                   to_string_range(a_lens) + "}, cannot multiply operand B: {" +
+                                   to_string_range(b_lens) + "}");
+                }
+                std::size_t dim_0 = a_lens.size() - 2;
+                std::size_t dim_1 = a_lens.size() - 1;
+                if(a_lens[dim_1] != b_lens[dim_0])
+                    MIGRAPHX_THROW("Inner dimensions do not match, operand A: {" +
+                                   to_string_range(a_lens) + "}, operand B: {" +
+                                   to_string_range(b_lens) + "}");
+                out_lens[dim_1] = b_lens[dim_1];
+
+                // C should be the same shape as A * B
+                auto c_lens = inputs[2].lens();
+                if(!std::equal(c_lens.begin(), c_lens.end(), out_lens.begin()))
+                {
+                    MIGRAPHX_THROW("DOT: dimension mismatch, operand C: {" +
+                                   to_string_range(c_lens) + "}, cannot add to operand A * B: {" +
+                                   to_string_range(out_lens) + "}");
+                }
+            }
+            else
+            {
+                check_shapes{inputs, *this}.has(3).same_type();
+                check_shapes{{inputs[0]}, *this}.only_dims(2);
+                check_shapes{{inputs[1]}, *this}.only_dims(2);
+                if(a_lens[1] != b_lens[0])
+                {
+                    MIGRAPHX_THROW("DOT: dimension mismatch, operand A: {" +
+                                   to_string_range(a_lens) + "}, cannot multiply operand B: {" +
+                                   to_string_range(b_lens) + "}");
+                }
+                out_lens[1] = b_lens[1];
+
+                // check whether C is broadcastable to A * B
+                auto c_lens = inputs[2].lens();
+                if(c_lens.size() > 2 ||
+                   (c_lens.size() == 1 && (c_lens[0] != 1 && c_lens[0] != b_lens[1])) ||
+                   (c_lens.size() == 2 && (c_lens[0] != 1 && c_lens[0] != a_lens[0])) ||
+                   (c_lens.size() == 2 && (c_lens[1] != 1 && c_lens[1] != b_lens[1])))
+                {
+                    MIGRAPHX_THROW("DOT: C {" + to_string_range(c_lens) +
+                                   "} is not broadcastable to A * B {" + to_string_range(out_lens) +
+                                   "}");
+                }
+            }
         }
         return {t, out_lens};
...
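To make the new three-input, stacked path concrete, a hypothetical snippet in the style of the unit tests at the bottom of this commit (the include path is an assumption):

#include <migraphx/operators.hpp> // assumed header providing migraphx::op::dot
#include <cstddef>
#include <vector>

int main()
{
    migraphx::shape a{migraphx::shape::float_type, {3, 2, 4, 5}};
    migraphx::shape b{migraphx::shape::float_type, {3, 2, 5, 7}};
    migraphx::shape c{migraphx::shape::float_type, {3, 2, 4, 7}};
    // In the stacked path, C must equal A * B exactly, so this succeeds and
    // yields {float_type, {3, 2, 4, 7}}; any other shape for c would throw.
    auto out = migraphx::op::dot{}.compute_shape({a, b, c});
    return out.lens() == std::vector<std::size_t>{3, 2, 4, 7} ? 0 : 1;
}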
@@ -369,23 +369,25 @@ argument miopen_gemm::compute(context& ctx,
     bool is_3inputs = (args.size() == 4);
     if(is_3inputs)
     {
         fill_result(output_shape, args[3], args[2]);
         output_shape.visit_type([&](auto as) {
+            auto n_dim   = output_shape.lens().size();
+            auto dim_1   = n_dim - 1;
+            auto dim_0   = n_dim - 2;
             auto alpha_r = to_rocblas_type(as(op.alpha));
             auto beta_r  = to_rocblas_type(as(op.beta));
             bool transa  = args[0].get_shape().transposed();
             bool transb  = args[1].get_shape().transposed();
-            rocblas_int lda = args[0].get_shape().strides()[transa ? 1 : 0];
-            rocblas_int ldb = args[1].get_shape().strides()[transb ? 1 : 0];
-            rocblas_int ldc = args[3].get_shape().strides()[0];
+            rocblas_int lda = args[0].get_shape().strides()[transa ? dim_1 : dim_0];
+            rocblas_int ldb = args[1].get_shape().strides()[transb ? dim_1 : dim_0];
+            rocblas_int ldc = args[3].get_shape().strides()[dim_0];
             auto out_lens   = output_shape.lens();
-            rocblas_int m   = out_lens[0];
-            rocblas_int n   = out_lens[1];
-            rocblas_int k   = args[0].get_shape().lens()[1];
+            rocblas_int m   = out_lens[dim_0];
+            rocblas_int n   = out_lens[dim_1];
+            rocblas_int k   = args[0].get_shape().lens()[dim_1];
+            auto num_matrices = std::accumulate(
+                out_lens.rbegin() + 2, out_lens.rend(), std::size_t{1}, std::multiplies<std::size_t>());
             auto to_pointer = [&](auto&& arg) { return to_rocblas_type(as.from(arg.data())); };
-            generic_rocblas_gemm(as,
+            generic_rocblas_batched_gemm(as,
                                  ctx.get_stream().get_rocblas(),
                                  transb ? rocblas_operation_transpose : rocblas_operation_none,
                                  transa ? rocblas_operation_transpose : rocblas_operation_none,
@@ -395,11 +397,15 @@ argument miopen_gemm::compute(context& ctx,
                                  &alpha_r,
                                  to_pointer(args[1]),
                                  ldb,
+                                 k * n,
                                  to_pointer(args[0]),
                                  lda,
+                                 m * k,
                                  &beta_r,
                                  to_pointer(args[3]),
-                                 ldc);
+                                 ldc,
+                                 m * n,
+                                 num_matrices);
         });
...
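The switch from generic_rocblas_gemm to generic_rocblas_batched_gemm adds three per-matrix element strides and a matrix count, presumably forwarded to rocBLAS's strided-batched GEMM. A standalone sketch of that bookkeeping (the helper name mirrors the diff; the values in main() are illustrative):

#include <cstddef>
#include <functional>
#include <numeric>
#include <vector>

// Mirrors the num_matrices computation above: the number of 2-D matrices in
// a stacked tensor is the product of every dim except the trailing two.
std::size_t num_matrices(const std::vector<std::size_t>& lens)
{
    return std::accumulate(
        lens.rbegin() + 2, lens.rend(), std::size_t{1}, std::multiplies<std::size_t>());
}

int main()
{
    // Output of A{3,2,4,5} * B{3,2,5,7} is {3,2,4,7}: m = 4, n = 7, k = 5.
    std::vector<std::size_t> out_lens = {3, 2, 4, 7};
    std::size_t m = 4, n = 7, k = 5;
    // Contiguous per-matrix strides passed to the batched gemm:
    // B advances by k * n elements per batch, A by m * k, and C by m * n.
    return (num_matrices(out_lens) == 6 && k * n == 35 && m * k == 20 && m * n == 28) ? 0 : 1;
}

Passing B before A with the trans flags swapped, as the code above does, is the usual trick for computing a row-major product with a column-major BLAS, since (AB)^T = B^T A^T.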
@@ -420,6 +420,20 @@ TEST_CASE(dot)
                      s_m2);
     }
+    {
+        migraphx::shape s_m1{migraphx::shape::float_type, {1, 1, 4, 5}};
+        migraphx::shape s_m2{migraphx::shape::float_type, {1, 2, 5, 7}};
+        expect_shape(migraphx::shape{migraphx::shape::float_type, {1, 2, 4, 7}},
+                     migraphx::op::dot{}, s_m1, s_m2);
+    }
+    {
+        migraphx::shape s_m1{migraphx::shape::float_type, {1, 2, 4, 5}};
+        migraphx::shape s_m2{migraphx::shape::float_type, {2, 1, 5, 7}};
+        expect_shape(migraphx::shape{migraphx::shape::float_type, {2, 2, 4, 7}},
+                     migraphx::op::dot{}, s_m1, s_m2);
+    }
     {
         migraphx::shape s_m1{migraphx::shape::float_type, {3, 1, 4, 6}};
         migraphx::shape s_m2{migraphx::shape::float_type, {3, 1, 5, 7}};
@@ -433,14 +447,14 @@ TEST_CASE(dot)
     }
     {
-        migraphx::shape s_m1{migraphx::shape::float_type, {1, 1, 4, 5}};
+        migraphx::shape s_m1{migraphx::shape::float_type, {1, 3, 4, 5}};
         migraphx::shape s_m2{migraphx::shape::float_type, {1, 2, 5, 7}};
         throws_shape(migraphx::op::dot{}, s_m1, s_m2);
     }
     {
         migraphx::shape s_m1{migraphx::shape::float_type, {1, 2, 4, 5}};
-        migraphx::shape s_m2{migraphx::shape::float_type, {2, 1, 5, 7}};
+        migraphx::shape s_m2{migraphx::shape::float_type, {2, 3, 5, 7}};
         throws_shape(migraphx::op::dot{}, s_m1, s_m2);
     }
 }
...
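The two new expect_shape cases accept stacks whose batch dims line up after broadcasting size-1 dims ({1, 1, 4, 5} x {1, 2, 5, 7} -> {1, 2, 4, 7}), while the reworked throws_shape cases now use genuinely incompatible batch dims. As a reference for what one batch of such a product computes, a naive standalone sketch for the equal-batch case (hypothetical helper, not part of the commit):

#include <cstddef>
#include <vector>

// Naive reference: multiplies `batches` independent row-major m x k and
// k x n matrices stored contiguously, one gemm per batch slice.
std::vector<float> batched_matmul(const std::vector<float>& a,
                                  const std::vector<float>& b,
                                  std::size_t batches,
                                  std::size_t m,
                                  std::size_t k,
                                  std::size_t n)
{
    std::vector<float> c(batches * m * n, 0.0f);
    for(std::size_t p = 0; p < batches; ++p)
        for(std::size_t i = 0; i < m; ++i)
            for(std::size_t j = 0; j < n; ++j)
                for(std::size_t l = 0; l < k; ++l)
                    c[p * m * n + i * n + j] +=
                        a[p * m * k + i * k + l] * b[p * k * n + l * n + j];
    return c;
}

int main()
{
    // One batch of 1x2 * 2x1: [1 2] * [3 4]^T = [11].
    auto c = batched_matmul({1, 2}, {3, 4}, 1, 1, 2, 1);
    return c[0] == 11.0f ? 0 : 1;
}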