Unverified Commit c56646e4 authored by Kirthi Shankar Sivamani, committed by GitHub
Browse files

Fix compute type for GEMM (#296)


Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent 574f1b41
......@@ -104,10 +104,11 @@ void cublas_gemm(const Tensor *inputA,
int64_t ld_gelumat = (int64_t) ldd;
// default to tf32 except for e5m2 inputs where the config is not supported
cublasComputeType_t gemm_compute_type = (A_type == CUDA_R_8F_E5M2 || B_type == CUDA_R_8F_E5M2)
? CUBLAS_COMPUTE_32F
: CUBLAS_COMPUTE_32F_FAST_TF32;
// Use TF32 only for pure FP32 GEMM.
cublasComputeType_t gemm_compute_type = CUBLAS_COMPUTE_32F;
if (A_type == CUDA_R_32F && B_type == CUDA_R_32F && D_type == CUDA_R_32F) {
gemm_compute_type = CUBLAS_COMPUTE_32F_FAST_TF32;
}
// Create matrix descriptors. Not setting any extra attributes.
NVTE_CHECK_CUBLAS(cublasLtMatrixLayoutCreate(&Adesc, A_type,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment