Unverified commit c1c19f3e, authored by Rick Ho and committed by GitHub
Browse files

Merge pull request #171 from laekov/xptree/fix-bf16

fix cublas gemm call for bf16 input
parents cd8372b3 a7975801
...@@ -122,12 +122,13 @@ inline cublasStatus_t cublasXgemm(cublasHandle_t handle, ...@@ -122,12 +122,13 @@ inline cublasStatus_t cublasXgemm(cublasHandle_t handle,
// TODO: Support bf16 for HIP // TODO: Support bf16 for HIP
assert(false); assert(false);
#else #else
const float alpha_fp32(*alpha), beta_fp32(*beta);
return cublasSgemmEx(handle, transa, transb, m, n, k, return cublasSgemmEx(handle, transa, transb, m, n, k,
(const float*)alpha, (const float*)&alpha_fp32,
(const void*)A, CUDA_R_16F, lda, (const void*)A, CUDA_R_16BF, lda,
(const void*)B, CUDA_R_16F, ldb, (const void*)B, CUDA_R_16BF, ldb,
(const float*)beta, (const float*)&beta_fp32,
(void*)C, CUDA_R_16F, ldc); (void*)C, CUDA_R_16BF, ldc);
#endif #endif
} }
#endif // CUBLAS_WRAPPER_H #endif // CUBLAS_WRAPPER_H
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment