Commit d45bd3ba authored by Khalique Ahmed
Browse files

formatting

parent 783e9474
...@@ -107,121 +107,121 @@ void gemm_impl(context& ctx, ...@@ -107,121 +107,121 @@ void gemm_impl(context& ctx,
if(compute_fp32) if(compute_fp32)
rocblas_invoke(&rocblas_gemm_ex, rocblas_invoke(&rocblas_gemm_ex,
ctx.get_stream().get_rocblas(), ctx.get_stream().get_rocblas(),
transb ? rocblas_operation_transpose : rocblas_operation_none, transb ? rocblas_operation_transpose : rocblas_operation_none,
transa ? rocblas_operation_transpose : rocblas_operation_none, transa ? rocblas_operation_transpose : rocblas_operation_none,
n, n,
m, m,
k, k,
&alpha, &alpha,
to_pointer(args.at(1)), to_pointer(args.at(1)),
arg_type, arg_type,
ldb, ldb,
to_pointer(args.at(0)), to_pointer(args.at(0)),
arg_type, arg_type,
lda, lda,
&beta, &beta,
to_pointer(args[2]), to_pointer(args[2]),
output_type, output_type,
ldc, ldc,
is_3inputs ? to_pointer(args[3]) : to_pointer(args[2]), is_3inputs ? to_pointer(args[3]) : to_pointer(args[2]),
output_type, output_type,
ldc, ldc,
compute_type, compute_type,
rocblas_gemm_algo_standard, rocblas_gemm_algo_standard,
0, 0,
flag); flag);
else else
rocblas_invoke(&rocblas_gemm_ex, rocblas_invoke(&rocblas_gemm_ex,
ctx.get_stream().get_rocblas(), ctx.get_stream().get_rocblas(),
transb ? rocblas_operation_transpose : rocblas_operation_none, transb ? rocblas_operation_transpose : rocblas_operation_none,
transa ? rocblas_operation_transpose : rocblas_operation_none, transa ? rocblas_operation_transpose : rocblas_operation_none,
n, n,
m, m,
k, k,
&alpha_r, &alpha_r,
to_pointer(args.at(1)), to_pointer(args.at(1)),
arg_type, arg_type,
ldb, ldb,
to_pointer(args.at(0)), to_pointer(args.at(0)),
arg_type, arg_type,
lda, lda,
&beta_r, &beta_r,
to_pointer(args[2]), to_pointer(args[2]),
output_type, output_type,
ldc, ldc,
is_3inputs ? to_pointer(args[3]) : to_pointer(args[2]), is_3inputs ? to_pointer(args[3]) : to_pointer(args[2]),
output_type, output_type,
ldc, ldc,
compute_type, compute_type,
rocblas_gemm_algo_standard, rocblas_gemm_algo_standard,
0, 0,
flag); flag);
} }
else else
{ {
if(compute_fp32) if(compute_fp32)
rocblas_invoke(&rocblas_gemm_strided_batched_ex, rocblas_invoke(&rocblas_gemm_strided_batched_ex,
ctx.get_stream().get_rocblas(), ctx.get_stream().get_rocblas(),
transb ? rocblas_operation_transpose : rocblas_operation_none, transb ? rocblas_operation_transpose : rocblas_operation_none,
transa ? rocblas_operation_transpose : rocblas_operation_none, transa ? rocblas_operation_transpose : rocblas_operation_none,
n, n,
m, m,
k, k,
&alpha, &alpha,
to_pointer(args.at(1)), to_pointer(args.at(1)),
arg_type, arg_type,
ldb, ldb,
k * n, k * n,
to_pointer(args.at(0)), to_pointer(args.at(0)),
arg_type, arg_type,
lda, lda,
m * k, m * k,
&beta, &beta,
to_pointer(args[2]), to_pointer(args[2]),
output_type, output_type,
ldc, ldc,
m * n, m * n,
is_3inputs ? to_pointer(args[3]) : to_pointer(args[2]), is_3inputs ? to_pointer(args[3]) : to_pointer(args[2]),
output_type, output_type,
ldc, ldc,
m * n, m * n,
num_matrices, num_matrices,
compute_type, compute_type,
rocblas_gemm_algo_standard, rocblas_gemm_algo_standard,
0, 0,
flag); flag);
else else
rocblas_invoke(&rocblas_gemm_strided_batched_ex, rocblas_invoke(&rocblas_gemm_strided_batched_ex,
ctx.get_stream().get_rocblas(), ctx.get_stream().get_rocblas(),
transb ? rocblas_operation_transpose : rocblas_operation_none, transb ? rocblas_operation_transpose : rocblas_operation_none,
transa ? rocblas_operation_transpose : rocblas_operation_none, transa ? rocblas_operation_transpose : rocblas_operation_none,
n, n,
m, m,
k, k,
&alpha_r, &alpha_r,
to_pointer(args.at(1)), to_pointer(args.at(1)),
arg_type, arg_type,
ldb, ldb,
k * n, k * n,
to_pointer(args.at(0)), to_pointer(args.at(0)),
arg_type, arg_type,
lda, lda,
m * k, m * k,
&beta_r, &beta_r,
to_pointer(args[2]), to_pointer(args[2]),
output_type, output_type,
ldc, ldc,
m * n, m * n,
is_3inputs ? to_pointer(args[3]) : to_pointer(args[2]), is_3inputs ? to_pointer(args[3]) : to_pointer(args[2]),
output_type, output_type,
ldc, ldc,
m * n, m * n,
num_matrices, num_matrices,
compute_type, compute_type,
rocblas_gemm_algo_standard, rocblas_gemm_algo_standard,
0, 0,
flag); flag);
} }
}); });
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment