Commit 783e9474 authored by Khalique Ahmed's avatar Khalique Ahmed
Browse files

change type for alpha beta

parent 297bfdd0
...@@ -83,8 +83,8 @@ void gemm_impl(context& ctx, ...@@ -83,8 +83,8 @@ void gemm_impl(context& ctx,
auto a_lens = args[0].get_shape().lens(); auto a_lens = args[0].get_shape().lens();
auto b_lens = args[1].get_shape().lens(); auto b_lens = args[1].get_shape().lens();
output_shape.visit_type([&](auto as) { output_shape.visit_type([&](auto as) {
auto alpha_r = compute_fp32 ? alpha : as(alpha); auto alpha_r = as(alpha);
auto beta_r = compute_fp32 ? beta : as(beta); auto beta_r = as(beta);
auto out_lens = output_shape.lens(); auto out_lens = output_shape.lens();
rocblas_int m = out_lens[dim_0]; rocblas_int m = out_lens[dim_0];
...@@ -104,64 +104,124 @@ void gemm_impl(context& ctx, ...@@ -104,64 +104,124 @@ void gemm_impl(context& ctx,
// column-major format. When doing a C = A * B, we actually do // column-major format. When doing a C = A * B, we actually do
// C^T = (B^T) * (A^T). That is the reason we input args[1] as // C^T = (B^T) * (A^T). That is the reason we input args[1] as
// A and args[0] as B in calling the rocblas_gemm. // A and args[0] as B in calling the rocblas_gemm.
rocblas_invoke(&rocblas_gemm_ex,
ctx.get_stream().get_rocblas(), if(compute_fp32)
transb ? rocblas_operation_transpose : rocblas_operation_none, rocblas_invoke(&rocblas_gemm_ex,
transa ? rocblas_operation_transpose : rocblas_operation_none, ctx.get_stream().get_rocblas(),
n, transb ? rocblas_operation_transpose : rocblas_operation_none,
m, transa ? rocblas_operation_transpose : rocblas_operation_none,
k, n,
&alpha_r, m,
to_pointer(args.at(1)), k,
arg_type, &alpha,
ldb, to_pointer(args.at(1)),
to_pointer(args.at(0)), arg_type,
arg_type, ldb,
lda, to_pointer(args.at(0)),
&beta_r, arg_type,
to_pointer(args[2]), lda,
output_type, &beta,
ldc, to_pointer(args[2]),
is_3inputs ? to_pointer(args[3]) : to_pointer(args[2]), output_type,
output_type, ldc,
ldc, is_3inputs ? to_pointer(args[3]) : to_pointer(args[2]),
compute_type, output_type,
rocblas_gemm_algo_standard, ldc,
0, compute_type,
flag); rocblas_gemm_algo_standard,
0,
flag);
else
rocblas_invoke(&rocblas_gemm_ex,
ctx.get_stream().get_rocblas(),
transb ? rocblas_operation_transpose : rocblas_operation_none,
transa ? rocblas_operation_transpose : rocblas_operation_none,
n,
m,
k,
&alpha_r,
to_pointer(args.at(1)),
arg_type,
ldb,
to_pointer(args.at(0)),
arg_type,
lda,
&beta_r,
to_pointer(args[2]),
output_type,
ldc,
is_3inputs ? to_pointer(args[3]) : to_pointer(args[2]),
output_type,
ldc,
compute_type,
rocblas_gemm_algo_standard,
0,
flag);
} }
else else
{ {
rocblas_invoke(&rocblas_gemm_strided_batched_ex, if(compute_fp32)
ctx.get_stream().get_rocblas(), rocblas_invoke(&rocblas_gemm_strided_batched_ex,
transb ? rocblas_operation_transpose : rocblas_operation_none, ctx.get_stream().get_rocblas(),
transa ? rocblas_operation_transpose : rocblas_operation_none, transb ? rocblas_operation_transpose : rocblas_operation_none,
n, transa ? rocblas_operation_transpose : rocblas_operation_none,
m, n,
k, m,
&alpha_r, k,
to_pointer(args.at(1)), &alpha,
arg_type, to_pointer(args.at(1)),
ldb, arg_type,
k * n, ldb,
to_pointer(args.at(0)), k * n,
arg_type, to_pointer(args.at(0)),
lda, arg_type,
m * k, lda,
&beta_r, m * k,
to_pointer(args[2]), &beta,
output_type, to_pointer(args[2]),
ldc, output_type,
m * n, ldc,
is_3inputs ? to_pointer(args[3]) : to_pointer(args[2]), m * n,
output_type, is_3inputs ? to_pointer(args[3]) : to_pointer(args[2]),
ldc, output_type,
m * n, ldc,
num_matrices, m * n,
compute_type, num_matrices,
rocblas_gemm_algo_standard, compute_type,
0, rocblas_gemm_algo_standard,
flag); 0,
flag);
else
rocblas_invoke(&rocblas_gemm_strided_batched_ex,
ctx.get_stream().get_rocblas(),
transb ? rocblas_operation_transpose : rocblas_operation_none,
transa ? rocblas_operation_transpose : rocblas_operation_none,
n,
m,
k,
&alpha_r,
to_pointer(args.at(1)),
arg_type,
ldb,
k * n,
to_pointer(args.at(0)),
arg_type,
lda,
m * k,
&beta_r,
to_pointer(args[2]),
output_type,
ldc,
m * n,
is_3inputs ? to_pointer(args[3]) : to_pointer(args[2]),
output_type,
ldc,
m * n,
num_matrices,
compute_type,
rocblas_gemm_algo_standard,
0,
flag);
} }
}); });
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment