Commit 2b1428ff authored by yuguo

[DCU] fix 2.5 compile issues

parent b4a2489f
@@ -15,10 +15,10 @@ from typing import List
 def install_requirements() -> List[str]:
     """Install dependencies for TE/JAX extensions."""
     reqs = ["torch>=2.1", "einops"]
-    reqs.append(
-        "nvdlfw-inspect @"
-        " git+https://github.com/NVIDIA/nvidia-dlfw-inspect.git@v0.1#egg=nvdlfw-inspect"
-    )
+    # reqs.append(
+    #     "nvdlfw-inspect @"
+    #     " git+https://github.com/NVIDIA/nvidia-dlfw-inspect.git@v0.1#egg=nvdlfw-inspect"
+    # )
     return reqs
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
[build-system]
requires = ["setuptools>=61.0", "cmake>=3.21", "wheel", "pybind11[global]", "ninja", "pip", "torch>=2.1", "jax[cuda12]", "flax>=0.7.1"]
# Use legacy backend to import local packages in setup.py
build-backend = "setuptools.build_meta:__legacy__"
@@ -692,6 +692,7 @@ void nvte_cublas_atomic_gemm(const NVTETensor A, const NVTETensor B, NVTETensor
              "Cuda version >=12.2 and <13.0 is required for atomic gemm.");
   NVTE_CHECK(cublasLtGetVersion() >= 120205 && cublasLtGetVersion() < 130000,
              "Cublas version >=12.2.5 and <13.0 is required for atomic gemm.");
+#endif
   using namespace transformer_engine;
   const Tensor *inputA = convertNVTETensorCheck(A);
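The lone added #endif implies a matching opening directive introduced above the visible context. Below is a minimal, self-contained sketch of the presumed intent, assuming the guard is an #ifndef on __HIP_PLATFORM_AMD__ so the CUDA/cuBLASLt version checks are compiled out on DCU builds; the opening condition does not appear in the hunk and is an assumption.

    // Hedged sketch: only the closing #endif appears in the diff; the
    // opening #ifndef below is an assumed reconstruction of the guard.
    #include <cstdio>

    int main() {
    #ifndef __HIP_PLATFORM_AMD__
      // NVIDIA builds: the CUDA >=12.2 and cuBLASLt >=12.2.5 checks run here.
      std::printf("CUDA build: atomic-gemm version checks enabled\n");
    #else
      // HIP/ROCm (DCU) builds: the checks are removed by the preprocessor,
      // avoiding references to cuBLASLt version APIs that do not exist on ROCm.
      std::printf("HIP build: atomic-gemm version checks skipped\n");
    #endif
      return 0;
    }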
@@ -1003,7 +1003,7 @@ static inline int getIntEnv(const char *name, int defval, int minval)
  */
 static void init_hipblaslt_handles(hipblasLtHandle_t* hipblaslt_handles) {
   NVTE_CHECK(hipblaslt_handles != nullptr);
-  for (int i = 0; i < num_streams; i++) {
+  for (int i = 0; i < compute_num_streams; i++) {
     NVTE_CHECK_HIPBLASLT(hipblasLtCreate(&hipblaslt_handles[i]));
   }
 }
@@ -1842,13 +1842,13 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD,
   if (use_hipblaslt || !use_rocblas)
   {
     // Check compute_stream_offset valid.
-    NVTE_CHECK(compute_stream_offset >= -1 && compute_stream_offset < num_streams);
+    NVTE_CHECK(compute_stream_offset >= -1 && compute_stream_offset < compute_num_streams);
     hipblasLtHandle_t handle = nullptr;
     if (compute_stream_offset != -1) {
       // Init hipblaslt handles (once, globally)
       static std::once_flag init_flag;
-      static hipblasLtHandle_t hipblaslt_handles[num_streams];
+      static hipblasLtHandle_t hipblaslt_handles[compute_num_streams];
       std::call_once(init_flag, init_hipblaslt_handles, hipblaslt_handles);
       handle = hipblaslt_handles[compute_stream_offset];
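As a reading aid, here is a hedged, self-contained sketch of the once-only handle-pool pattern the hunk above relies on. FakeHandle, init_handles, and get_handle are hypothetical stand-ins; only compute_num_streams and the std::once_flag / std::call_once structure come from the diff. Sizing the pool and bounding the init loop with the same constant is the point of the rename: an offset validated against one constant then indexes an array sized by the same one.

    // Minimal sketch of the once-only handle pool (stand-in types; the real
    // code uses hipblasLtHandle_t and hipblasLtCreate).
    #include <cstdio>
    #include <mutex>

    constexpr int compute_num_streams = 2;
    using FakeHandle = int;  // stand-in for hipblasLtHandle_t

    static void init_handles(FakeHandle* handles) {
      // Same constant bounds the loop and sizes the array below, so every
      // slot is initialized exactly once.
      for (int i = 0; i < compute_num_streams; i++) {
        handles[i] = 100 + i;  // stand-in for hipblasLtCreate(&handles[i])
      }
    }

    FakeHandle get_handle(int compute_stream_offset) {
      static std::once_flag init_flag;
      static FakeHandle handles[compute_num_streams];
      std::call_once(init_flag, init_handles, handles);
      return handles[compute_stream_offset];
    }

    int main() {
      std::printf("handle for stream 1: %d\n", get_handle(1));
      return 0;
    }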
@@ -132,6 +132,7 @@ void nvte_cublas_batchgemm(const NVTETensor A, const NVTETensor B, NVTETensor D,
  */
 namespace transformer_engine {
 #ifdef __HIP_PLATFORM_AMD__
+constexpr int compute_num_streams = 2;
 // Add for batchgemm stream
 constexpr int num_batchgemm_streams = 1;
 #endif
@@ -10,6 +10,7 @@
 #include "multi_stream.h"
 #include <transformer_engine/multi_stream.h>
+#include <transformer_engine/gemm.h>
 #include <mutex>
 #include <vector>
@@ -51,7 +52,7 @@ cudaEvent_t get_compute_stream_event(int idx) {
 int get_num_compute_streams() {
 #ifdef __HIP_PLATFORM_AMD__
-  static constexpr int num_compute_streams = 2;
+  static constexpr int num_compute_streams = compute_num_streams;
 #else
   static constexpr int num_compute_streams = 4;
 #endif
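Taken together with the constant added in the previous file, the new #include <transformer_engine/gemm.h> suggests compute_num_streams now lives in a header shared by gemm.cu and multi_stream.cpp, so the hipBLASLt handle-pool size and the stream count reported here cannot drift apart. A single-file sketch of that single-source-of-truth layout follows; the header/file split is implied by the diff, and the section labels are illustrative.

    // Single-file sketch; in the tree the constant sits in a header reachable
    // via <transformer_engine/gemm.h> (an inference from the diff).
    #include <cstdio>

    // --- shared header (e.g. transformer_engine/gemm.h) ---
    #ifdef __HIP_PLATFORM_AMD__
    constexpr int compute_num_streams = 2;
    #endif

    // --- multi_stream.cpp ---
    int get_num_compute_streams() {
    #ifdef __HIP_PLATFORM_AMD__
      // HIP/ROCm builds reuse the shared constant instead of a local literal,
      // matching the handle-pool size used by cublas_gemm in gemm.cu.
      static constexpr int num_compute_streams = compute_num_streams;
    #else
      static constexpr int num_compute_streams = 4;
    #endif
      return num_compute_streams;
    }

    int main() {
      std::printf("compute streams: %d\n", get_num_compute_streams());
      return 0;
    }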
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
[build-system]
requires = ["setuptools>=61.0", "pip", "torch>=2.1"]
# Use legacy backend to import local packages in setup.py
build-backend = "setuptools.build_meta:__legacy__"