Commit 02f359b2 authored by Shucai Xiao

first correct version of gemm implementation

parent f3ea46e5
@@ -882,7 +882,7 @@ struct dot
         if(b_size > n_dim)
         {
-            std::copy(b.rbegin() + n_dim, b.rend(), out_lens.rbegin() + n_dim);
+            std::copy(b.rbegin() + n_dim, b.rend(), out_lens.begin() + n_dim);
         }
         std::reverse(out_lens.begin(), out_lens.end());
@@ -956,8 +956,8 @@ struct dot
             is_b_appended = true;
         }
-        std::size_t dim_0 = a_lens.size() - 2;
-        std::size_t dim_1 = b_lens.size() - 1;
+        std::size_t dim_1 = a_lens.size() - 1;
+        std::size_t dim_0 = b_lens.size() - 2;
         if(a_lens[dim_1] != b_lens[dim_0])
         {
             MIGRAPHX_THROW("DOT : dimension mismatch, operand A: {" + to_string_range(a.lens()) +
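The two dot hunks above fix which axes are compared: for C = A × B, A's last dimension (the inner "k" axis) must match B's second-to-last dimension, not B's last. A minimal standalone sketch of the corrected check (the helper name and error handling here are illustrative, not the MIGraphX API):

#include <cstddef>
#include <stdexcept>
#include <vector>

// Hypothetical helper mirroring the corrected check: for C = A * B,
// A's last dimension must equal B's second-to-last dimension.
void check_dot_dims(const std::vector<std::size_t>& a_lens,
                    const std::vector<std::size_t>& b_lens)
{
    std::size_t dim_1 = a_lens.size() - 1; // k axis of A
    std::size_t dim_0 = b_lens.size() - 2; // k axis of B
    if(a_lens[dim_1] != b_lens[dim_0])
        throw std::runtime_error("DOT: dimension mismatch");
}

int main()
{
    check_dot_dims({2, 3, 4}, {2, 4, 5}); // ok: k == 4 on both sides
    // check_dot_dims({2, 3, 4}, {2, 5, 4}); // would throw: 4 != 5
}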
@@ -82,7 +82,7 @@ void migemm_impl(tensor_view<T> cmat,
     std::size_t nb_dims = b_lens.size();
     auto k = a_lens[na_dims - 1];
-    assert(a_lens[na_dims - 1] == b_lens[nb_dims - 1]);
+    assert(a_lens[na_dims - 1] == b_lens[nb_dims - 2]);
     assert(c_lens[nc_dims - 2] == a_lens[na_dims - 2]);
     assert(c_lens[nc_dims - 1] == b_lens[nb_dims - 1]);
@@ -92,13 +92,13 @@ void migemm_impl(tensor_view<T> cmat,
     std::vector<std::size_t> b_idx(nb_dims);
     shape_for_each(cmat.get_shape(), [&](const auto& c_idx) {
-        std::transform(c_lens.begin() + a_len_diff,
-                       c_lens.end(),
+        std::transform(c_idx.begin() + a_len_diff,
+                       c_idx.end(),
                        a_lens.begin(),
                        a_idx.begin(),
                        [&](auto i, auto j) { return (j == 1) ? 0 : i; });
-        std::transform(c_lens.begin() + b_len_diff,
-                       c_lens.end(),
+        std::transform(c_idx.begin() + b_len_diff,
+                       c_idx.end(),
                        b_lens.begin(),
                        b_idx.begin(),
                        [&](auto i, auto j) { return (j == 1) ? 0 : i; });
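The migemm_impl fix above is subtle: the broadcast mapping must be driven by the current output element index c_idx, not by the output lengths c_lens, clamping any size-1 (broadcast) axis of the operand to index 0. A self-contained sketch of that mapping with made-up shapes (same-rank operands, so the a_len_diff offset is zero here):

#include <algorithm>
#include <cassert>
#include <cstddef>
#include <vector>

int main()
{
    // Operand A has a broadcast batch axis of length 1; the output has 3.
    std::vector<std::size_t> a_lens = {1, 3, 4};
    std::vector<std::size_t> c_idx  = {2, 1, 2}; // current output element
    std::vector<std::size_t> a_idx(a_lens.size());

    // For each axis, read from index 0 when A's length there is 1
    // (broadcasting); otherwise follow the output index.
    std::transform(c_idx.begin(),
                   c_idx.end(),
                   a_lens.begin(),
                   a_idx.begin(),
                   [](auto i, auto j) { return (j == 1) ? 0 : i; });

    assert((a_idx == std::vector<std::size_t>{0, 1, 2}));
}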
@@ -240,6 +240,125 @@ void miopen_gemm::fill_result(const shape& output_shape,
     }
 }
+
+argument miopen_gemm::batch_matmul(context& ctx,
+                                   const shape& output_shape,
+                                   const std::vector<argument>& args) const
+{
+    bool transa     = args[0].get_shape().transposed();
+    bool transb     = args[1].get_shape().transposed();
+    auto a_lens     = args[0].get_shape().lens();
+    auto b_lens     = args[1].get_shape().lens();
+    auto out_lens   = output_shape.lens();
+    auto an_dim     = a_lens.size();
+    auto bn_dim     = b_lens.size();
+    auto outn_dim   = out_lens.size();
+    rocblas_int lda = args[0].get_shape().strides()[transa ? an_dim - 1 : an_dim - 2];
+    rocblas_int ldb = args[1].get_shape().strides()[transb ? bn_dim - 1 : bn_dim - 2];
+    rocblas_int ldc = args[2].get_shape().strides()[outn_dim - 2];
+    rocblas_int m   = out_lens[outn_dim - 2];
+    rocblas_int n   = out_lens[outn_dim - 1];
+    rocblas_int k   = a_lens[an_dim - 1];
+    float beta      = 0.0f;
+    std::vector<std::size_t> a_batch_lens(a_lens.begin(), a_lens.begin() + an_dim - 2);
+    std::vector<std::size_t> b_batch_lens(b_lens.begin(), b_lens.begin() + bn_dim - 2);
+    if(a_batch_lens == b_batch_lens || a_batch_lens.empty() || b_batch_lens.empty())
+    {
+        std::size_t numa_matrices = std::accumulate(a_batch_lens.begin(),
+                                                    a_batch_lens.end(),
+                                                    std::size_t{1},
+                                                    std::multiplies<std::size_t>());
+        std::size_t numb_matrices = std::accumulate(b_batch_lens.begin(),
+                                                    b_batch_lens.end(),
+                                                    std::size_t{1},
+                                                    std::multiplies<std::size_t>());
+        std::size_t num_matrices = std::max(numa_matrices, numb_matrices);
+        rocblas_int stride_a     = (numa_matrices == 1) ? 0 : m * k;
+        rocblas_int stride_b     = (numb_matrices == 1) ? 0 : k * n;
+        rocblas_int stride_c     = m * n;
+        output_shape.visit_type([&](auto as) {
+            auto alpha_r    = to_rocblas_type(as(op.alpha));
+            auto beta_r     = to_rocblas_type(as(beta));
+            auto to_pointer = [&](auto&& arg) { return to_rocblas_type(as.from(arg.data())); };
+            generic_rocblas_batched_gemm(
+                as,
+                ctx.get_stream().get_rocblas(),
+                transb ? rocblas_operation_transpose : rocblas_operation_none,
+                transa ? rocblas_operation_transpose : rocblas_operation_none,
+                n,
+                m,
+                k,
+                &alpha_r,
+                to_pointer(args[1]),
+                ldb,
+                stride_b,
+                to_pointer(args[0]),
+                lda,
+                stride_a,
+                &beta_r,
+                to_pointer(args[2]),
+                ldc,
+                stride_c,
+                num_matrices);
+        });
+    }
+    else
+    {
+        std::vector<std::size_t> out_batch_lens(out_lens.begin(), out_lens.begin() + outn_dim - 2);
+        shape::type_t t = output_shape.type();
+        shape a_batch_shape{t, a_batch_lens};
+        shape b_batch_shape{t, b_batch_lens};
+        shape out_batch_shape{t, out_batch_lens};
+        std::size_t a_len_diff = outn_dim - an_dim;
+        std::size_t b_len_diff = outn_dim - bn_dim;
+        shape_for_each(out_batch_shape, [&](auto out_idx) {
+            std::size_t out_ind = out_batch_shape.index(out_idx.begin(), out_idx.end());
+            auto type_size      = output_shape.type_size();
+            std::vector<std::size_t> a_idx(a_batch_lens.size());
+            std::vector<std::size_t> b_idx(b_batch_lens.size());
+            std::transform(out_idx.begin() + a_len_diff,
+                           out_idx.end(),
+                           a_batch_lens.begin(),
+                           a_idx.begin(),
+                           [&](auto i, auto j) { return (j == 1) ? 0 : i; });
+            std::transform(out_idx.begin() + b_len_diff,
+                           out_idx.end(),
+                           b_batch_lens.begin(),
+                           b_idx.begin(),
+                           [&](auto i, auto j) { return (j == 1) ? 0 : i; });
+            std::size_t a_ind = a_batch_shape.index(a_idx.begin(), a_idx.end());
+            std::size_t b_ind = b_batch_shape.index(b_idx.begin(), b_idx.end());
+            output_shape.visit_type([&](auto as) {
+                auto alpha_r    = to_rocblas_type(as(op.alpha));
+                auto beta_r     = to_rocblas_type(as(beta));
+                auto to_pointer = [&](auto&& arg, std::size_t offset = 0) {
+                    return to_rocblas_type(as.from(arg.data() + offset));
+                };
+                generic_rocblas_gemm(
+                    as,
+                    ctx.get_stream().get_rocblas(),
+                    transb ? rocblas_operation_transpose : rocblas_operation_none,
+                    transa ? rocblas_operation_transpose : rocblas_operation_none,
+                    n,
+                    m,
+                    k,
+                    &alpha_r,
+                    to_pointer(args[1], k * n * b_ind * type_size),
+                    ldb,
+                    to_pointer(args[0], m * k * a_ind * type_size),
+                    lda,
+                    &beta_r,
+                    to_pointer(args[2], m * n * out_ind * type_size),
+                    ldc);
+            });
+        });
+    }
+    return args[2];
+}
+
 argument miopen_gemm::compute(context& ctx,
                               const shape& output_shape,
                               const std::vector<argument>& args) const
@@ -262,20 +381,6 @@ argument miopen_gemm::compute(context& ctx,
         rocblas_int n = out_lens[1];
         rocblas_int k = args[0].get_shape().lens()[1];
         auto to_pointer = [&](auto&& arg) { return to_rocblas_type(as.from(arg.data())); };
-        auto cpu_a   = migraphx::gpu::from_gpu(args[0]);
-        auto cpu_b   = migraphx::gpu::from_gpu(args[1]);
-        auto cpu_res = migraphx::gpu::from_gpu(args[3]);
-        std::cout << "gpu::gemm, cpu_a = " << cpu_a << std::endl;
-        std::cout << "gpu::gemm, cpu_b = " << cpu_b << std::endl;
-        std::cout << "gpu::gemm, cpu_res = " << cpu_res << std::endl;
-        std::cout << "gpu::gemm, transb = " << transb << std::endl;
-        std::cout << "gpu::gemm, transa = " << transb << std::endl;
-        std::cout << "gpu::gemm, m = " << m << std::endl;
-        std::cout << "gpu::gemm, n = " << n << std::endl;
-        std::cout << "gpu::gemm, k = " << k << std::endl;
-        std::cout << "gpu::gemm, lda = " << lda << std::endl;
-        std::cout << "gpu::gemm, ldb = " << ldb << std::endl;
-        std::cout << "gpu::gemm, ldc = " << ldc << std::endl;
         generic_rocblas_gemm(as,
                              ctx.get_stream().get_rocblas(),
@@ -300,7 +405,6 @@ argument miopen_gemm::compute(context& ctx,
     // 2 input arguments cases
     // vector inner product
-    std::size_t type_size = output_shape.type_size();
     if(output_shape.elements() == 1)
     {
         assert(args[0].get_shape().elements() == args[1].get_shape().elements());
@@ -324,39 +428,49 @@ argument miopen_gemm::compute(context& ctx,
     else if(args[1].get_shape().lens().size() == 1)
     {
         auto a_lens = args[0].get_shape().lens();
+        auto b_lens = args[1].get_shape().lens();
         std::size_t dim_0 = a_lens.size() - 2;
         std::size_t dim_1 = a_lens.size() - 1;
-        bool trans    = args[0].get_shape().transposed();
-        rocblas_int m = a_lens[trans ? dim_1 : dim_0];
-        rocblas_int n = a_lens[trans ? dim_0 : dim_1];
+        bool transa     = args[0].get_shape().transposed();
+        bool transb     = false;
+        rocblas_int lda = args[0].get_shape().strides()[transa ? dim_1 : dim_0];
+        rocblas_int ldb = 1;
+        rocblas_int ldc = 1;
+        rocblas_int m   = a_lens[dim_0];
+        rocblas_int n   = 1;
+        rocblas_int k   = a_lens[dim_1];
         float beta = 0.0f;
-        rocblas_int lda = args[0].get_shape().strides()[trans ? dim_1 : dim_0];
         assert(a_lens.back() == args[1].get_shape().elements());
         std::size_t batch_num = std::accumulate(
             a_lens.rbegin() + 2, a_lens.rend(), std::size_t{1}, std::multiplies<std::size_t>());
         output_shape.visit_type([&](auto as) {
             auto alpha_r = to_rocblas_type(as(op.alpha));
             auto beta_r  = to_rocblas_type(as(beta));
-            auto to_pointer = [&](auto&& arg, std::size_t offset = 0) {
-                return to_rocblas_type(as.from(arg.data() + offset));
+            auto to_pointer = [&](auto&& arg) {
+                return to_rocblas_type(as.from(arg.data()));
             };
-            for(std::size_t batch_no = 0; batch_no < batch_num; ++batch_no)
-            {
-                generic_rocblas_gemv(as,
-                                     ctx.get_stream().get_rocblas(),
-                                     trans ? rocblas_operation_transpose : rocblas_operation_none,
-                                     m,
-                                     n,
-                                     &alpha_r,
-                                     to_pointer(args[0], batch_no * m * n * type_size),
-                                     lda,
-                                     to_pointer(args[1]),
-                                     1,
-                                     &beta_r,
-                                     to_pointer(args[2], batch_no * n * type_size),
-                                     1);
-            }
+            generic_rocblas_batched_gemm(
+                as,
+                ctx.get_stream().get_rocblas(),
+                transb ? rocblas_operation_transpose : rocblas_operation_none,
+                transa ? rocblas_operation_transpose : rocblas_operation_none,
+                n,
+                m,
+                k,
+                &alpha_r,
+                to_pointer(args[1]),
+                ldb,
+                0,
+                to_pointer(args[0]),
+                lda,
+                m * k,
+                &beta_r,
+                to_pointer(args[2]),
+                ldc,
+                m * n,
+                batch_num);
         });
     }
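The hunk above drops the per-batch gemv loop: a vector of length k is just a k × 1 matrix, so the product becomes a single strided-batched GEMM with n = 1 and ldb = ldc = 1, and a batch stride of 0 reuses the vector across the batch. A tiny sketch of the underlying equivalence, using a naive row-major GEMM in place of rocBLAS (all names are illustrative):

#include <cassert>
#include <cstddef>
#include <vector>

int main()
{
    // y = A * x, with A 2x3 row-major and x of length 3. Treating x as a
    // 3x1 matrix turns the matrix-vector product into a GEMM with n = 1.
    std::size_t m = 2, n = 1, k = 3;
    std::vector<float> a = {1, 2, 3, 4, 5, 6};
    std::vector<float> x = {1, 1, 1};
    std::vector<float> y(m * n); // zero-initialized
    for(std::size_t i = 0; i < m; i++)
        for(std::size_t j = 0; j < n; j++)
            for(std::size_t p = 0; p < k; p++)
                y[i * n + j] += a[i * k + p] * x[p * n + j];
    assert((y == std::vector<float>{6, 15}));
}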
     // vector * matrix
@@ -380,26 +494,6 @@ argument miopen_gemm::compute(context& ctx,
         std::size_t batch_num = std::accumulate(
             b_lens.rbegin() + 2, b_lens.rend(), std::size_t{1}, std::multiplies<std::size_t>());
-        auto cpu_a   = migraphx::gpu::from_gpu(args[0]);
-        auto cpu_b   = migraphx::gpu::from_gpu(args[1]);
-        auto cpu_res = migraphx::gpu::from_gpu(args[2]);
-        std::cout << "gpu::gemm, cpu_a = " << cpu_a << std::endl;
-        std::cout << "gpu::gemm, cpu_b = " << cpu_b << std::endl;
-        std::cout << "gpu::gemm, cpu_res = " << cpu_res << std::endl;
-        std::cout << "gpu::gemm, transb = " << transb << std::endl;
-        std::cout << "gpu::gemm, transa = " << transb << std::endl;
-        std::cout << "gpu::gemm, m = " << m << std::endl;
-        std::cout << "gpu::gemm, n = " << n << std::endl;
-        std::cout << "gpu::gemm, k = " << k << std::endl;
-        std::cout << "gpu::gemm, lda = " << lda << std::endl;
-        std::cout << "gpu::gemm, ldb = " << ldb << std::endl;
-        std::cout << "gpu::gemm, ldc = " << ldc << std::endl;
-        output_shape.visit_type([&](auto as) {
-            auto to_pointer = [&](auto&& arg) { return to_rocblas_type(as.from(arg.data())); };
-            hipMemset(to_pointer(args[2]), 0, output_shape.bytes());
-        });
         output_shape.visit_type([&](auto as) {
             auto alpha_r = to_rocblas_type(as(op.alpha));
             auto beta_r  = to_rocblas_type(as(beta));
@@ -426,129 +520,11 @@ argument miopen_gemm::compute(context& ctx,
                 m * n,
                 batch_num);
         });
+        return args[2];
     }
     // (batch) matrix multiplication
     else
     {
-        bool transa   = args[0].get_shape().transposed();
-        bool transb   = args[1].get_shape().transposed();
-        auto a_lens   = args[0].get_shape().lens();
-        auto b_lens   = args[1].get_shape().lens();
-        auto out_lens = output_shape.lens();
-        rocblas_int lda =
-            args[0].get_shape().strides()[transa ? a_lens.size() - 1 : a_lens.size() - 2];
-        rocblas_int ldb =
-            args[1].get_shape().strides()[transb ? b_lens.size() - 1 : b_lens.size() - 2];
-        rocblas_int ldc = args[2].get_shape().strides()[out_lens.size() - 2];
-        rocblas_int m   = out_lens[out_lens.size() - 2];
-        rocblas_int n   = out_lens[out_lens.size() - 1];
-        rocblas_int k   = args[0].get_shape().lens()[a_lens.size() - 1];
-        float beta      = 0.0f;
-        auto input_dims = std::min(a_lens.size(), b_lens.size());
-        std::size_t axis{0};
-        for(axis = 2; axis < input_dims; ++axis)
-        {
-            if(a_lens[a_lens.size() - axis] != b_lens[b_lens.size() - axis])
-            {
-                break;
-            }
-        }
-        // The number of matrices that can be computed in one call
-        // batch_num > 1, we need to call the batch_gemm function,
-        // otherwise, call the gemm function directly
-        std::size_t num_matrices =
-            std::accumulate(a_lens.rbegin() + 2,
-                            (axis == a_lens.size() ? a_lens.rend() : a_lens.rbegin() + axis),
-                            std::size_t{1},
-                            std::multiplies<std::size_t>());
-        std::size_t a_len_diff = out_lens.size() - a_lens.size();
-        std::size_t b_len_diff = out_lens.size() - b_lens.size();
-        std::vector<std::size_t> a_batch_lens(a_lens.begin(),
-                                              a_lens.begin() + a_lens.size() - axis);
-        std::vector<std::size_t> b_batch_lens(b_lens.begin(),
-                                              b_lens.begin() + b_lens.size() - axis);
-        std::vector<std::size_t> out_batch_lens(out_lens.begin(),
-                                                out_lens.begin() + out_lens.size() - axis);
-        shape::type_t t = output_shape.type();
-        shape a_batch_shape{t, a_batch_lens};
-        shape b_batch_shape{t, b_batch_lens};
-        shape out_batch_shape{t, out_batch_lens};
-        shape_for_each(out_batch_shape, [&](auto out_idx) {
-            std::size_t out_ind = out_batch_shape.index(out_idx.begin(), out_idx.end());
-            std::vector<std::size_t> a_idx(a_lens.size() - axis);
-            std::vector<std::size_t> b_idx(b_lens.size() - axis);
-            std::transform(out_idx.begin() + a_len_diff,
-                           out_idx.end(),
-                           a_batch_lens.begin(),
-                           a_idx.begin(),
-                           [&](auto i, auto j) { return (j == 1) ? 0 : i; });
-            std::transform(out_idx.begin() + b_len_diff,
-                           out_idx.end(),
-                           b_batch_lens.begin(),
-                           b_idx.begin(),
-                           [&](auto i, auto j) { return (j == 1) ? 0 : i; });
-            std::size_t a_ind = a_batch_shape.index(a_idx.begin(), b_idx.end());
-            std::size_t b_ind = b_batch_shape.index(b_idx.begin(), b_idx.end());
-            output_shape.visit_type([&](auto as) {
-                auto alpha_r = to_rocblas_type(as(op.alpha));
-                auto beta_r  = to_rocblas_type(as(beta));
-                auto to_pointer = [&](auto&& arg, std::size_t offset = 0) {
-                    return to_rocblas_type(as.from(arg.data() + offset));
-                };
-                if(num_matrices > 1)
-                {
-                    generic_rocblas_batched_gemm(
-                        as,
-                        ctx.get_stream().get_rocblas(),
-                        transb ? rocblas_operation_transpose : rocblas_operation_none,
-                        transa ? rocblas_operation_transpose : rocblas_operation_none,
-                        n,
-                        m,
-                        k,
-                        &alpha_r,
-                        to_pointer(args[1], k * n * num_matrices * b_ind * type_size),
-                        ldb,
-                        k * n,
-                        to_pointer(args[0], m * k * num_matrices * a_ind * type_size),
-                        lda,
-                        m * k,
-                        &beta_r,
-                        to_pointer(args[2], m * n * num_matrices * out_ind * type_size),
-                        ldc,
-                        m * n,
-                        num_matrices);
-                }
-                // num_matrices per call is 1
-                else
-                {
-                    generic_rocblas_gemm(
-                        as,
-                        ctx.get_stream().get_rocblas(),
-                        transb ? rocblas_operation_transpose : rocblas_operation_none,
-                        transa ? rocblas_operation_transpose : rocblas_operation_none,
-                        n,
-                        m,
-                        k,
-                        &alpha_r,
-                        to_pointer(args[1], k * n * b_ind * type_size),
-                        ldb,
-                        to_pointer(args[0], m * k * a_ind * type_size),
-                        lda,
-                        &beta_r,
-                        to_pointer(args[2], m * n * out_ind * type_size),
-                        ldc);
-                }
-            });
-        });
+        batch_matmul(ctx, output_shape, args);
     }
     return args[2];
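Throughout the GPU code the operands and the m/n dimensions are swapped when calling rocBLAS (B before A, n before m): rocBLAS assumes column-major storage, and for row-major buffers C = A·B has the same memory layout as the column-major product Cᵀ = Bᵀ·Aᵀ. The strided-batched path additionally passes a batch stride of 0 to broadcast one matrix across the whole batch, as stride_a/stride_b do in batch_matmul. A CPU sketch of both ideas under those assumptions (gemm_cm_strided is a hypothetical reference, not the rocBLAS API):

#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical column-major strided-batched GEMM reference: for each
// batch bt, C[bt] = A[bt] * B[bt]; a batch stride of 0 reuses the same
// matrix for every batch (the broadcast case above).
void gemm_cm_strided(std::size_t m, std::size_t n, std::size_t k,
                     const float* a, std::size_t lda, std::size_t stride_a,
                     const float* b, std::size_t ldb, std::size_t stride_b,
                     float* c, std::size_t ldc, std::size_t stride_c,
                     std::size_t batches)
{
    for(std::size_t bt = 0; bt < batches; bt++)
        for(std::size_t j = 0; j < n; j++)
            for(std::size_t i = 0; i < m; i++)
            {
                float acc = 0;
                for(std::size_t p = 0; p < k; p++)
                    acc += a[bt * stride_a + p * lda + i] * b[bt * stride_b + j * ldb + p];
                c[bt * stride_c + j * ldc + i] = acc;
            }
}

int main()
{
    // Row-major A (2x3) and B (3x2); C = A * B is 2x2 row-major.
    std::vector<float> a = {1, 2, 3, 4, 5, 6};
    std::vector<float> b = {1, 0, 0, 1, 1, 1};
    std::vector<float> c(4);
    // Row-major C = A*B == column-major C^T = B^T * A^T, so B is passed
    // first and m/n are swapped, exactly as the rocBLAS calls above do.
    gemm_cm_strided(2, 2, 3, b.data(), 2, 0, a.data(), 3, 0, c.data(), 2, 0, 1);
    assert((c == std::vector<float>{4, 5, 10, 11}));

    // A batch of two A matrices sharing a single B (batch stride 0 on B),
    // mirroring stride_b = 0 when numb_matrices == 1 in batch_matmul.
    std::vector<float> a2 = {1, 2, 3, 4, 5, 6, 6, 5, 4, 3, 2, 1};
    std::vector<float> c2(8);
    gemm_cm_strided(2, 2, 3, b.data(), 2, 0, a2.data(), 3, 6, c2.data(), 2, 4, 2);
    assert((c2 == std::vector<float>{4, 5, 10, 11, 10, 9, 4, 3}));
}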
@@ -21,6 +21,8 @@ struct miopen_gemm
     private:
     void fill_result(const shape& output_shape, const argument& result, const argument& c) const;
+    argument batch_matmul(context& ctx, const shape& output_shape,
+                          const std::vector<argument>& args) const;
 };
 } // namespace gpu