Commit 8b3bf2fe authored by Shucai Xiao
Browse files

a naive optimization of rocblas_gemm call

parent 3b5c6c7f
...@@ -454,26 +454,50 @@ argument miopen_gemm::compute(context& ctx, ...@@ -454,26 +454,50 @@ argument miopen_gemm::compute(context& ctx,
auto to_pointer = [&](auto&& arg, std::size_t offset = 0) { auto to_pointer = [&](auto&& arg, std::size_t offset = 0) {
return to_rocblas_type(as.from(arg.data() + offset)); return to_rocblas_type(as.from(arg.data() + offset));
}; };
generic_rocblas_batched_gemm( if (num_matrices > 1)
as, {
ctx.get_stream().get_rocblas(), generic_rocblas_batched_gemm(
transb ? rocblas_operation_transpose : rocblas_operation_none, as,
transa ? rocblas_operation_transpose : rocblas_operation_none, ctx.get_stream().get_rocblas(),
n, transb ? rocblas_operation_transpose : rocblas_operation_none,
m, transa ? rocblas_operation_transpose : rocblas_operation_none,
k, n,
&alpha_r, m,
to_pointer(args[1], k * n * num_matrices * b_ind), k,
ldb, &alpha_r,
k * n, to_pointer(args[1], k * n * num_matrices * b_ind),
to_pointer(args[0], m * k * num_matrices * a_ind), ldb,
lda, k * n,
m * k, to_pointer(args[0], m * k * num_matrices * a_ind),
&beta_r, lda,
to_pointer(args[2], m * n * num_matrices * out_ind), m * k,
ldc, &beta_r,
m * n, to_pointer(args[2], m * n * num_matrices * out_ind),
num_matrices); ldc,
m * n,
num_matrices);
}
// num_matrices per call is 1
else
{
generic_rocblas_gemm(
as,
ctx.get_stream().get_rocblas(),
transb ? rocblas_operation_transpose : rocblas_operation_none,
transa ? rocblas_operation_transpose : rocblas_operation_none,
n,
m,
k,
&alpha_r,
to_pointer(args[1], k * n * num_matrices * b_ind),
ldb,
to_pointer(args[0], m * k * num_matrices * a_ind),
lda,
&beta_r,
to_pointer(args[2], m * n * num_matrices * out_ind),
ldc);
}
}); });
}); });
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment