"vllm/vscode:/vscode.git/clone" did not exist on "526df28fb250c90970cd3db0f68a21797dbe555f"
Unverified Commit 406e2c9d authored by Tim Moon's avatar Tim Moon Committed by GitHub
Browse files

Fix incorrect version checks for atomic GEMM (#2095)



* Fix incorrect version checks for atomic GEMM
Signed-off-by: default avatarTim Moon <tmoon@nvidia.com>

* Fix typo
Signed-off-by: default avatarTim Moon <tmoon@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



---------
Signed-off-by: default avatarTim Moon <tmoon@nvidia.com>
Co-authored-by: default avatarpre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
parent 96944a81
...@@ -517,22 +517,22 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD, ...@@ -517,22 +517,22 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD,
&epilogue, sizeof(epilogue))); &epilogue, sizeof(epilogue)));
if (counter != nullptr) { if (counter != nullptr) {
#if !(CUDA_VERSION >= 12020 && CUBLAS_VERSION >= 13000) #if !(CUDA_VERSION >= 12020 && CUDA_VERSION < 13000)
NVTE_ERROR("Atomic GEMM requires CUDA >=12.2.0 and <13.0.0, but compile-time CUDA verson is ", NVTE_ERROR("Atomic GEMM requires CUDA >=12.2.0 and <13.0.0, but compile-time CUDA version is ",
CUDA_VERSION); CUDA_VERSION);
#endif #endif
#if !(CUBLAS_VERSION >= 120205 && CUBLAS_VERSION < 130000) #if !(CUBLAS_VERSION >= 120205 && CUBLAS_VERSION < 130000)
NVTE_ERROR( NVTE_ERROR(
"Atomic GEMM requires cuBLAS >=12.2.5 and <13.0.0, but compile-time cuBLAS verson is ", "Atomic GEMM requires cuBLAS >=12.2.5 and <13.0.0, but compile-time cuBLAS version is ",
CUBLAS_VERSION); CUBLAS_VERSION);
#endif #endif
#if CUDA_VERSION >= 12020 && CUBLAS_VERSION >= 120205 && CUDA_VERSION < 13000 && \ #if CUDA_VERSION >= 12020 && CUBLAS_VERSION >= 120205 && CUDA_VERSION < 13000 && \
CUBLAS_VERSION < 130000 CUBLAS_VERSION < 130000
NVTE_CHECK(cuda::cudart_version() >= 12020 && cuda::cudart_version() < 13000, NVTE_CHECK(cuda::cudart_version() >= 12020 && cuda::cudart_version() < 13000,
"Atomic GEMM requires CUDA >=12.2.0 and <13.0.0, but run-time CUDA verson is ", "Atomic GEMM requires CUDA >=12.2.0 and <13.0.0, but run-time CUDA version is ",
cuda::cudart_version()); cuda::cudart_version());
NVTE_CHECK(cublas_version() >= 120205 && cublas_version() < 130000, NVTE_CHECK(cublas_version() >= 120205 && cublas_version() < 130000,
"Atomic GEMM requires cuBLAS >=12.2.5 and <13.0.0, but run-time cuBLAS verson is ", "Atomic GEMM requires cuBLAS >=12.2.5 and <13.0.0, but run-time cuBLAS version is ",
cublas_version()); cublas_version());
if (m_split == 0) m_split = 1; if (m_split == 0) m_split = 1;
if (n_split == 0) n_split = 1; if (n_split == 0) n_split = 1;
...@@ -658,20 +658,22 @@ void nvte_cublas_atomic_gemm(const NVTETensor A, const NVTETensor B, NVTETensor ...@@ -658,20 +658,22 @@ void nvte_cublas_atomic_gemm(const NVTETensor A, const NVTETensor B, NVTETensor
using namespace transformer_engine; using namespace transformer_engine;
// Check CUDA and cuBLAS versions // Check CUDA and cuBLAS versions
#if !(CUDA_VERSION >= 12020 && CUBLAS_VERSION >= 13000) #if !(CUDA_VERSION >= 12020 && CUDA_VERSION < 13000)
NVTE_ERROR("Atomic GEMM requires CUDA >=12.2.0 and <13.0.0, but compile-time CUDA verson is ", NVTE_ERROR("Atomic GEMM requires CUDA >=12.2.0 and <13.0.0, but compile-time CUDA version is ",
CUDA_VERSION); CUDA_VERSION);
#endif #endif
#if !(CUBLAS_VERSION >= 120205 && CUBLAS_VERSION < 130000) #if !(CUBLAS_VERSION >= 120205 && CUBLAS_VERSION < 130000)
NVTE_ERROR("Atomic GEMM requires cuBLAS >=12.2.5 and <13.0.0, but compile-time cuBLAS verson is ", NVTE_ERROR(
CUBLAS_VERSION); "Atomic GEMM requires cuBLAS >=12.2.5 and <13.0.0, but compile-time cuBLAS version is ",
CUBLAS_VERSION);
#endif #endif
NVTE_CHECK(cuda::cudart_version() >= 12020 && cuda::cudart_version() < 13000, NVTE_CHECK(
"Atomic GEMM requires CUDA version >=12.2.0 and <13.0.0, but run-time CUDA verson is ", cuda::cudart_version() >= 12020 && cuda::cudart_version() < 13000,
cuda::cudart_version()); "Atomic GEMM requires CUDA version >=12.2.0 and <13.0.0, but run-time CUDA version is ",
cuda::cudart_version());
NVTE_CHECK( NVTE_CHECK(
cublas_version() >= 120205 && cublas_version() < 130000, cublas_version() >= 120205 && cublas_version() < 130000,
"Atomic GEMM requires cuBLAS version >=12.2.5 and <13.0.0, but run-time cuBLAS verson is ", "Atomic GEMM requires cuBLAS version >=12.2.5 and <13.0.0, but run-time cuBLAS version is ",
cublas_version()); cublas_version());
const Tensor *inputA = convertNVTETensorCheck(A); const Tensor *inputA = convertNVTETensorCheck(A);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment