Commit a7975801 authored by xptree's avatar xptree
Browse files

fix cublas gemm call for bf16 input

parent cd8372b3
......@@ -122,12 +122,13 @@ inline cublasStatus_t cublasXgemm(cublasHandle_t handle,
// TODO: Support bf16 for HIP
assert(false);
#else
const float alpha_fp32(*alpha), beta_fp32(*beta);
return cublasSgemmEx(handle, transa, transb, m, n, k,
(const float*)alpha,
(const void*)A, CUDA_R_16F, lda,
(const void*)B, CUDA_R_16F, ldb,
(const float*)beta,
(void*)C, CUDA_R_16F, ldc);
(const float*)&alpha_fp32,
(const void*)A, CUDA_R_16BF, lda,
(const void*)B, CUDA_R_16BF, ldb,
(const float*)&beta_fp32,
(void*)C, CUDA_R_16BF, ldc);
#endif
}
#endif // CUBLAS_WRAPPER_H
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment