"git@developer.sourcefind.cn:kecinstone/2024-pra-vllm.git" did not exist on "31d2ab4aff51c537dd4bc82451efbc194e0b8f2b"
Commit 2b1428ff authored by yuguo
Browse files

[DCU] fix 2.5 compile issues

parent b4a2489f
......@@ -15,10 +15,10 @@ from typing import List
def install_requirements() -> List[str]:
    """Return the pip requirements for the TE/JAX extensions.

    Returns:
        List of PEP 508 requirement strings to be installed.
    """
    reqs = ["torch>=2.1", "einops"]
    # nvdlfw-inspect is intentionally disabled here — presumably because the
    # git-pinned package breaks this (DCU) build per the "fix 2.5 compile
    # issues" change; re-enable once it builds on this platform (TODO confirm).
    # reqs.append(
    #     "nvdlfw-inspect @"
    #     " git+https://github.com/NVIDIA/nvidia-dlfw-inspect.git@v0.1#egg=nvdlfw-inspect"
    # )
    return reqs
......
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
[build-system]
requires = ["setuptools>=61.0", "cmake>=3.21", "wheel", "pybind11[global]", "ninja", "pip", "torch>=2.1", "jax[cuda12]", "flax>=0.7.1"]
# Use legacy backend to import local packages in setup.py
build-backend = "setuptools.build_meta:__legacy__"
......@@ -692,6 +692,7 @@ void nvte_cublas_atomic_gemm(const NVTETensor A, const NVTETensor B, NVTETensor
"Cuda version >=12.2 and <13.0 is required for atomic gemm.");
NVTE_CHECK(cublasLtGetVersion() >= 120205 && cublasLtGetVersion() < 130000,
"Cublas version >=12.2.5 and <13.0 is required for atomic gemm.");
#endif
using namespace transformer_engine;
const Tensor *inputA = convertNVTETensorCheck(A);
......
......@@ -1003,7 +1003,7 @@ static inline int getIntEnv(const char *name, int defval, int minval)
*/
/*! \brief Create one hipBLASLt handle per compute stream.
 *
 *  Populates the first \c compute_num_streams entries of the given array
 *  via hipblasLtCreate; aborts through NVTE_CHECK* on a null pointer or a
 *  hipBLASLt failure. Intended to run once (callers guard with
 *  std::call_once).
 *
 *  \param[out] hipblaslt_handles  Array with room for at least
 *                                 \c compute_num_streams handles; non-null.
 */
static void init_hipblaslt_handles(hipblasLtHandle_t* hipblaslt_handles) {
  NVTE_CHECK(hipblaslt_handles != nullptr);
  // Single loop over compute_num_streams: the old identifier num_streams was
  // renamed in this change, so only the new constant is referenced here.
  for (int i = 0; i < compute_num_streams; i++) {
    NVTE_CHECK_HIPBLASLT(hipblasLtCreate(&hipblaslt_handles[i]));
  }
}
......@@ -1842,13 +1842,13 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD,
if (use_hipblaslt || !use_rocblas)
{
// Check compute_stream_offset valid.
NVTE_CHECK(compute_stream_offset >= -1 && compute_stream_offset < num_streams);
NVTE_CHECK(compute_stream_offset >= -1 && compute_stream_offset < compute_num_streams);
hipblasLtHandle_t handle = nullptr;
if (compute_stream_offset != -1) {
// Init hipblaslt handles (once, globally)
static std::once_flag init_flag;
static hipblasLtHandle_t hipblaslt_handles[num_streams];
static hipblasLtHandle_t hipblaslt_handles[compute_num_streams];
std::call_once(init_flag, init_hipblaslt_handles, hipblaslt_handles);
handle = hipblaslt_handles[compute_stream_offset];
......
......@@ -132,6 +132,7 @@ void nvte_cublas_batchgemm(const NVTETensor A, const NVTETensor B, NVTETensor D,
*/
namespace transformer_engine {
#ifdef __HIP_PLATFORM_AMD__
constexpr int compute_num_streams = 2;
// Add for batchgemm stream
constexpr int num_batchgemm_streams = 1;
#endif
......
......@@ -10,6 +10,7 @@
#include "multi_stream.h"
#include <transformer_engine/multi_stream.h>
#include <transformer_engine/gemm.h>
#include <mutex>
#include <vector>
......@@ -51,7 +52,7 @@ cudaEvent_t get_compute_stream_event(int idx) {
int get_num_compute_streams() {
#ifdef __HIP_PLATFORM_AMD__
static constexpr int num_compute_streams = 2;
static constexpr int num_compute_streams = compute_num_streams;
#else
static constexpr int num_compute_streams = 4;
#endif
......
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
[build-system]
requires = ["setuptools>=61.0", "pip", "torch>=2.1"]
# Use legacy backend to import local packages in setup.py
build-backend = "setuptools.build_meta:__legacy__"
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment