Commit 0b639e94 authored by Shucai Xiao's avatar Shucai Xiao
Browse files

optimize the calling of rocmblas function.

parent 3d200e1c
...@@ -117,26 +117,47 @@ argument miopen_gemm::compute(context& ctx, ...@@ -117,26 +117,47 @@ argument miopen_gemm::compute(context& ctx,
auto alpha_r = to_rocblas_type(as(alpha)); auto alpha_r = to_rocblas_type(as(alpha));
auto beta_r = to_rocblas_type(as(beta)); auto beta_r = to_rocblas_type(as(beta));
auto to_pointer = [&](auto&& arg) { return to_rocblas_type(as.from(arg.data())); }; auto to_pointer = [&](auto&& arg) { return to_rocblas_type(as.from(arg.data())); };
generic_rocblas_batched_gemm(as, // call the strided implementation only if there are multiple matrices
ctx.get_stream().get_rocblas(), if (batch_num > 1)
transb ? rocblas_operation_transpose : rocblas_operation_none, {
transa ? rocblas_operation_transpose : rocblas_operation_none, generic_rocblas_batched_gemm(as,
n, ctx.get_stream().get_rocblas(),
m, transb ? rocblas_operation_transpose : rocblas_operation_none,
k, transa ? rocblas_operation_transpose : rocblas_operation_none,
&alpha_r, n,
to_pointer(args[1]), m,
ldb, k,
k * n, &alpha_r,
to_pointer(args[0]), to_pointer(args[1]),
lda, ldb,
m * k, k * n,
&beta_r, to_pointer(args[0]),
to_pointer(args[2]), lda,
ldc, m * k,
m * n, &beta_r,
batch_num); to_pointer(args[2]),
ldc,
m * n,
batch_num);
}
else
{
generic_rocblas_gemm(as,
ctx.get_stream().get_rocblas(),
transb ? rocblas_operation_transpose : rocblas_operation_none,
transa ? rocblas_operation_transpose : rocblas_operation_none,
n,
m,
k,
&alpha_r,
to_pointer(args[1]),
ldb,
to_pointer(args[0]),
lda,
&beta_r,
to_pointer(args[2]),
ldc);
}
}); });
return args[2]; return args[2];
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment