Unverified commit c1c19f3e, authored by Rick Ho and committed by GitHub
Browse files

Merge pull request #171 from laekov/xptree/fix-bf16

fix cublas gemm call for bf16 input
parents cd8372b3 a7975801
......@@ -122,12 +122,13 @@ inline cublasStatus_t cublasXgemm(cublasHandle_t handle,
// TODO: Support bf16 for HIP
assert(false);
#else
const float alpha_fp32(*alpha), beta_fp32(*beta);
return cublasSgemmEx(handle, transa, transb, m, n, k,
(const float*)alpha,
(const void*)A, CUDA_R_16F, lda,
(const void*)B, CUDA_R_16F, ldb,
(const float*)beta,
(void*)C, CUDA_R_16F, ldc);
(const float*)&alpha_fp32,
(const void*)A, CUDA_R_16BF, lda,
(const void*)B, CUDA_R_16BF, ldb,
(const float*)&beta_fp32,
(void*)C, CUDA_R_16BF, ldc);
#endif
}
#endif // CUBLAS_WRAPPER_H
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment