Commit ad3c4c1d authored by Khalique Ahmed's avatar Khalique Ahmed
Browse files

use void pointer to select alpha beta

parent 4fffcdd5
...@@ -83,9 +83,22 @@ void gemm_impl(context& ctx, ...@@ -83,9 +83,22 @@ void gemm_impl(context& ctx,
auto a_lens = args[0].get_shape().lens(); auto a_lens = args[0].get_shape().lens();
auto b_lens = args[1].get_shape().lens(); auto b_lens = args[1].get_shape().lens();
output_shape.visit_type([&](auto as) { output_shape.visit_type([&](auto as) {
auto alpha_r = as(alpha); auto alpha_r = as(alpha);
auto beta_r = as(beta); auto beta_r = as(beta);
// use void pointer to select different data type if using fp32 mode
void* alpha_v{&alpha_r};
void* beta_v{&beta_r};
if(compute_fp32)
{
alpha_v = α
beta_v = β
}
auto out_lens = output_shape.lens(); auto out_lens = output_shape.lens();
rocblas_int m = out_lens[dim_0]; rocblas_int m = out_lens[dim_0];
rocblas_int n = out_lens[dim_1]; rocblas_int n = out_lens[dim_1];
...@@ -100,128 +113,64 @@ void gemm_impl(context& ctx, ...@@ -100,128 +113,64 @@ void gemm_impl(context& ctx,
out_lens.rbegin() + 2, out_lens.rend(), std::size_t{1}, std::multiplies<std::size_t>()); out_lens.rbegin() + 2, out_lens.rend(), std::size_t{1}, std::multiplies<std::size_t>());
if(num_matrices == 1) if(num_matrices == 1)
{ {
// the rocblas_gemm API handles inputs and output matrices as rocblas_invoke(&rocblas_gemm_ex,
// column-major format. When doing a C = A * B, we actually do ctx.get_stream().get_rocblas(),
// C^T = (B^T) * (A^T). That is the reason we input args[1] as transb ? rocblas_operation_transpose : rocblas_operation_none,
// A and args[0] as B in calling the rocblas_gemm. transa ? rocblas_operation_transpose : rocblas_operation_none,
n,
if(compute_fp32) m,
rocblas_invoke(&rocblas_gemm_ex, k,
ctx.get_stream().get_rocblas(), alpha_v,
transb ? rocblas_operation_transpose : rocblas_operation_none, to_pointer(args.at(1)),
transa ? rocblas_operation_transpose : rocblas_operation_none, arg_type,
n, ldb,
m, to_pointer(args.at(0)),
k, arg_type,
&alpha, lda,
to_pointer(args.at(1)), beta_v,
arg_type, to_pointer(args[2]),
ldb, output_type,
to_pointer(args.at(0)), ldc,
arg_type, is_3inputs ? to_pointer(args[3]) : to_pointer(args[2]),
lda, output_type,
&beta, ldc,
to_pointer(args[2]), compute_type,
output_type, rocblas_gemm_algo_standard,
ldc, 0,
is_3inputs ? to_pointer(args[3]) : to_pointer(args[2]), flag);
output_type,
ldc,
compute_type,
rocblas_gemm_algo_standard,
0,
flag);
else
rocblas_invoke(&rocblas_gemm_ex,
ctx.get_stream().get_rocblas(),
transb ? rocblas_operation_transpose : rocblas_operation_none,
transa ? rocblas_operation_transpose : rocblas_operation_none,
n,
m,
k,
&alpha_r,
to_pointer(args.at(1)),
arg_type,
ldb,
to_pointer(args.at(0)),
arg_type,
lda,
&beta_r,
to_pointer(args[2]),
output_type,
ldc,
is_3inputs ? to_pointer(args[3]) : to_pointer(args[2]),
output_type,
ldc,
compute_type,
rocblas_gemm_algo_standard,
0,
flag);
} }
else else
{ {
if(compute_fp32) rocblas_invoke(&rocblas_gemm_strided_batched_ex,
rocblas_invoke(&rocblas_gemm_strided_batched_ex, ctx.get_stream().get_rocblas(),
ctx.get_stream().get_rocblas(), transb ? rocblas_operation_transpose : rocblas_operation_none,
transb ? rocblas_operation_transpose : rocblas_operation_none, transa ? rocblas_operation_transpose : rocblas_operation_none,
transa ? rocblas_operation_transpose : rocblas_operation_none, n,
n, m,
m, k,
k, alpha_v,
&alpha, to_pointer(args.at(1)),
to_pointer(args.at(1)), arg_type,
arg_type, ldb,
ldb, k * n,
k * n, to_pointer(args.at(0)),
to_pointer(args.at(0)), arg_type,
arg_type, lda,
lda, m * k,
m * k, beta_v,
&beta, to_pointer(args[2]),
to_pointer(args[2]), output_type,
output_type, ldc,
ldc, m * n,
m * n, is_3inputs ? to_pointer(args[3]) : to_pointer(args[2]),
is_3inputs ? to_pointer(args[3]) : to_pointer(args[2]), output_type,
output_type, ldc,
ldc, m * n,
m * n, num_matrices,
num_matrices, compute_type,
compute_type, rocblas_gemm_algo_standard,
rocblas_gemm_algo_standard, 0,
0, flag);
flag);
else
rocblas_invoke(&rocblas_gemm_strided_batched_ex,
ctx.get_stream().get_rocblas(),
transb ? rocblas_operation_transpose : rocblas_operation_none,
transa ? rocblas_operation_transpose : rocblas_operation_none,
n,
m,
k,
&alpha_r,
to_pointer(args.at(1)),
arg_type,
ldb,
k * n,
to_pointer(args.at(0)),
arg_type,
lda,
m * k,
&beta_r,
to_pointer(args[2]),
output_type,
ldc,
m * n,
is_3inputs ? to_pointer(args[3]) : to_pointer(args[2]),
output_type,
ldc,
m * n,
num_matrices,
compute_type,
rocblas_gemm_algo_standard,
0,
flag);
} }
}); });
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment