Unverified Commit 406e2c9d authored by Tim Moon's avatar Tim Moon Committed by GitHub
Browse files

Fix incorrect version checks for atomic GEMM (#2095)



* Fix incorrect version checks for atomic GEMM
Signed-off-by: default avatarTim Moon <tmoon@nvidia.com>

* Fix typo
Signed-off-by: default avatarTim Moon <tmoon@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



---------
Signed-off-by: default avatarTim Moon <tmoon@nvidia.com>
Co-authored-by: default avatarpre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
parent 96944a81
......@@ -517,22 +517,22 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD,
&epilogue, sizeof(epilogue)));
if (counter != nullptr) {
#if !(CUDA_VERSION >= 12020 && CUBLAS_VERSION >= 13000)
NVTE_ERROR("Atomic GEMM requires CUDA >=12.2.0 and <13.0.0, but compile-time CUDA verson is ",
#if !(CUDA_VERSION >= 12020 && CUDA_VERSION < 13000)
NVTE_ERROR("Atomic GEMM requires CUDA >=12.2.0 and <13.0.0, but compile-time CUDA version is ",
CUDA_VERSION);
#endif
#if !(CUBLAS_VERSION >= 120205 && CUBLAS_VERSION < 130000)
NVTE_ERROR(
"Atomic GEMM requires cuBLAS >=12.2.5 and <13.0.0, but compile-time cuBLAS verson is ",
"Atomic GEMM requires cuBLAS >=12.2.5 and <13.0.0, but compile-time cuBLAS version is ",
CUBLAS_VERSION);
#endif
#if CUDA_VERSION >= 12020 && CUBLAS_VERSION >= 120205 && CUDA_VERSION < 13000 && \
CUBLAS_VERSION < 130000
NVTE_CHECK(cuda::cudart_version() >= 12020 && cuda::cudart_version() < 13000,
"Atomic GEMM requires CUDA >=12.2.0 and <13.0.0, but run-time CUDA verson is ",
"Atomic GEMM requires CUDA >=12.2.0 and <13.0.0, but run-time CUDA version is ",
cuda::cudart_version());
NVTE_CHECK(cublas_version() >= 120205 && cublas_version() < 130000,
"Atomic GEMM requires cuBLAS >=12.2.5 and <13.0.0, but run-time cuBLAS verson is ",
"Atomic GEMM requires cuBLAS >=12.2.5 and <13.0.0, but run-time cuBLAS version is ",
cublas_version());
if (m_split == 0) m_split = 1;
if (n_split == 0) n_split = 1;
......@@ -658,20 +658,22 @@ void nvte_cublas_atomic_gemm(const NVTETensor A, const NVTETensor B, NVTETensor
using namespace transformer_engine;
// Check CUDA and cuBLAS versions
#if !(CUDA_VERSION >= 12020 && CUBLAS_VERSION >= 13000)
NVTE_ERROR("Atomic GEMM requires CUDA >=12.2.0 and <13.0.0, but compile-time CUDA verson is ",
#if !(CUDA_VERSION >= 12020 && CUDA_VERSION < 13000)
NVTE_ERROR("Atomic GEMM requires CUDA >=12.2.0 and <13.0.0, but compile-time CUDA version is ",
CUDA_VERSION);
#endif
#if !(CUBLAS_VERSION >= 120205 && CUBLAS_VERSION < 130000)
NVTE_ERROR("Atomic GEMM requires cuBLAS >=12.2.5 and <13.0.0, but compile-time cuBLAS verson is ",
NVTE_ERROR(
"Atomic GEMM requires cuBLAS >=12.2.5 and <13.0.0, but compile-time cuBLAS version is ",
CUBLAS_VERSION);
#endif
NVTE_CHECK(cuda::cudart_version() >= 12020 && cuda::cudart_version() < 13000,
"Atomic GEMM requires CUDA version >=12.2.0 and <13.0.0, but run-time CUDA verson is ",
NVTE_CHECK(
cuda::cudart_version() >= 12020 && cuda::cudart_version() < 13000,
"Atomic GEMM requires CUDA version >=12.2.0 and <13.0.0, but run-time CUDA version is ",
cuda::cudart_version());
NVTE_CHECK(
cublas_version() >= 120205 && cublas_version() < 130000,
"Atomic GEMM requires cuBLAS version >=12.2.5 and <13.0.0, but run-time cuBLAS verson is ",
"Atomic GEMM requires cuBLAS version >=12.2.5 and <13.0.0, but run-time cuBLAS version is ",
cublas_version());
const Tensor *inputA = convertNVTETensorCheck(A);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment