Commit 25e8c688 authored by huangwb

first runnable TGI changes on DCU platform

parent 2d0a7173
@@ -1392,9 +1392,11 @@ fn main() -> Result<(), LauncherError> {
             vec![]
         }
         _ => {
-            let cuda_graphs = vec![1, 2, 4, 8, 16, 32];
-            tracing::info!("Using default cuda graphs {cuda_graphs:?}");
-            cuda_graphs
+            // let cuda_graphs = vec![1, 2, 4, 8, 16, 32];
+            // tracing::info!("Using default cuda graphs {cuda_graphs:?}");
+            // cuda_graphs
+            tracing::info!("Currently disabling cuda graphs by default, may enable them in the future");
+            vec![]
         }
     };

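For context: with this change the launcher only enables cuda graphs when the user passes --cuda-graphs explicitly; the fallback arm now logs a notice and returns an empty list. As a rough illustration of what graph capture involves on the PyTorch side (this uses the stock torch.cuda.CUDAGraph API, not TGI's own capture path, and model/example_input are placeholders), a capture attempt like the one below can fail on platforms whose HIP graph support is still immature, which is why defaulting to an empty list is the conservative choice on DCU:

import torch

# Hedged sketch: generic CUDA/HIP graph capture with the standard PyTorch API.
def try_capture(model, example_input):
    static_input = example_input.clone()
    graph = torch.cuda.CUDAGraph()
    try:
        # Warm up on a side stream before capture, as PyTorch recommends.
        s = torch.cuda.Stream()
        s.wait_stream(torch.cuda.current_stream())
        with torch.cuda.stream(s):
            model(static_input)
        torch.cuda.current_stream().wait_stream(s)

        # Capture one forward pass into the graph.
        with torch.cuda.graph(graph):
            static_output = model(static_input)
    except RuntimeError:
        # Graph capture unsupported or failed: fall back to eager execution,
        # which is what an empty cuda-graphs list amounts to.
        return None
    return graph, static_input, static_output
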
@@ -19,8 +19,10 @@ gen-server:
 install: gen-server
 	pip install pip --upgrade
-	pip install -r requirements_cuda.txt
-	pip install -e ".[bnb, accelerate, quantize, peft, outlines]"
+	pip install -r requirements_rocm.txt
+	# pip install -e ".[bnb, accelerate, quantize, peft, outlines]"
+	pip install -e ".[accelerate, quantize, peft, outlines]"
 
 run-dev:
 	SAFETENSORS_FAST_GPU=1 python -m torch.distributed.run --nproc_per_node=2 text_generation_server/cli.py serve bigscience/bloom-560m --sharded

@@ -46,10 +46,10 @@ __device__ __forceinline__ void atomicAdd_half2(half2* address, half2 val)
 #if defined(__CUDA_ARCH__) || defined(USE_ROCM)
 #if __CUDA_ARCH__ < 700 || defined(USE_ROCM)
-__device__ __forceinline__ void atomicAdd(half* address, half val) { atomicAdd_half(address, val); }
+// __device__ __forceinline__ void atomicAdd(half* address, half val) { atomicAdd_half(address, val); }
 
 #if __CUDA_ARCH__ < 600 || defined(USE_ROCM)
-__device__ __forceinline__ void atomicAdd(half2* address, half2 val) { atomicAdd_half2(address, val); }
+// __device__ __forceinline__ void atomicAdd(half2* address, half2 val) { atomicAdd_half2(address, val); }
 #endif
 
 #endif

@@ -44,10 +44,10 @@ __device__ __forceinline__ void atomicAdd_half2(half2* address, half2 val)
 #if defined(__CUDA_ARCH__) || defined(USE_ROCM)
 #if __CUDA_ARCH__ < 700 || defined(USE_ROCM)
-__device__ __forceinline__ void atomicAdd(half* address, half val) { atomicAdd_half(address, val); }
+// __device__ __forceinline__ void atomicAdd(half* address, half val) { atomicAdd_half(address, val); }
 
 #if __CUDA_ARCH__ < 600 || defined(USE_ROCM)
-__device__ __forceinline__ void atomicAdd(half2* address, half2 val) { atomicAdd_half2(address, val); }
+// __device__ __forceinline__ void atomicAdd(half2* address, half2 val) { atomicAdd_half2(address, val); }
 #endif
 
 #endif

@@ -23,6 +23,36 @@
 #include "q_gemm_kernel.cuh"
 #include "q_gemm_kernel_gptq.cuh"
 
+#if defined(USE_ROCM)
+#include <hipblas/hipblas.h>
+
+__host__ __forceinline__ hipblasStatus_t __compat_hipblasHgemm(hipblasHandle_t handle,
+                                                               hipblasOperation_t transA,
+                                                               hipblasOperation_t transB,
+                                                               int m,
+                                                               int n,
+                                                               int k,
+                                                               const half* alpha,
+                                                               const half* AP,
+                                                               int lda,
+                                                               const half* BP,
+                                                               int ldb,
+                                                               const half* beta,
+                                                               half* CP,
+                                                               int ldc) {
+    return hipblasHgemm(handle, transA, transB, m, n, k,
+                        reinterpret_cast<const hipblasHalf *>(alpha),
+                        reinterpret_cast<const hipblasHalf *>(AP), lda,
+                        reinterpret_cast<const hipblasHalf *>(BP), ldb,
+                        reinterpret_cast<const hipblasHalf *>(beta),
+                        reinterpret_cast<hipblasHalf *>(CP), ldc);
+}
+#define hipblasHgemm __compat_hipblasHgemm
+
+// Previous versions of PyTorch were converting to rocBLAS instead of hipBLAS.
+#define rocblas_operation_none HIPBLAS_OP_N
+#define rocblas_hgemm __compat_hipblasHgemm
+#endif
+
 void gemm_half_q_half_cuda_part
 (
     const half* a,

@@ -6,6 +6,7 @@ extra_cuda_cflags = ["-lineinfo", "-O3"]
 
 if torch.version.hip:
     extra_cuda_cflags += ["-DHIPBLAS_USE_HIP_HALF"]
+    extra_cuda_cflags += ["-DUSE_ROCM"]
 
 extra_compile_args = {
     "nvcc": extra_cuda_cflags,

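These flags end up in torch.utils.cpp_extension, which routes the "nvcc" entry to hipcc when PyTorch was built for ROCm, so -DUSE_ROCM is what switches on the hipBLAS shim and the atomicAdd changes above at compile time. A minimal, hypothetical setup.py sketch showing where extra_compile_args plugs in (the extension name and source paths are illustrative, not the actual TGI layout):

import torch
from setuptools import setup
from torch.utils.cpp_extension import BuildExtension, CUDAExtension

# Mirrors the flag logic above; on a ROCm/DCU build torch.version.hip is set.
extra_cuda_cflags = ["-lineinfo", "-O3"]
if torch.version.hip:
    extra_cuda_cflags += ["-DHIPBLAS_USE_HIP_HALF", "-DUSE_ROCM"]

setup(
    name="example_quant_kernels",  # illustrative name
    ext_modules=[
        CUDAExtension(
            name="example_quant_kernels",
            sources=[
                "example_kernels/ext.cpp",        # illustrative paths
                "example_kernels/cuda/q_gemm.cu",
            ],
            extra_compile_args={"cxx": ["-O3"], "nvcc": extra_cuda_cflags},
        )
    ],
    cmdclass={"build_ext": BuildExtension},
)
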
@@ -49,10 +49,10 @@ grpcio-tools = "^1.51.1"
 pytest = "^7.3.0"
 
-[[tool.poetry.source]]
-name = "pytorch-gpu-src"
-url = "https://download.pytorch.org/whl/cu121"
-priority = "explicit"
+#[[tool.poetry.source]]
+#name = "pytorch-gpu-src"
+#url = "https://download.pytorch.org/whl/cu121"
+#priority = "explicit"
 
 [tool.pytest.ini_options]
 markers = ["private: marks tests as requiring an admin hf token (deselect with '-m \"not private\"')"]

@@ -69,10 +69,10 @@ try:
     from text_generation_server.models.idefics import IDEFICSSharded
     from text_generation_server.models.llava_next import LlavaNext
     from text_generation_server.models.flash_mistral import FlashMistral
-    from text_generation_server.models.flash_mixtral import FlashMixtral
+    # from text_generation_server.models.flash_mixtral import FlashMixtral
     from text_generation_server.models.flash_phi import FlashPhi
     from text_generation_server.models.flash_starcoder2 import FlashStarcoder2
-    from text_generation_server.models.flash_dbrx import FlashDbrx
+    # from text_generation_server.models.flash_dbrx import FlashDbrx
     from text_generation_server.utils.flash_attn import HAS_FLASH_ATTN_V2_CUDA
 except ImportError as e:
@@ -87,8 +87,8 @@ if FLASH_ATTENTION:
     __all__.append(FlashLlama)
     __all__.append(IDEFICSSharded)
     __all__.append(FlashMistral)
-    __all__.append(FlashMixtral)
-    __all__.append(FlashDbrx)
+    # __all__.append(FlashMixtral)
+    # __all__.append(FlashDbrx)
     __all__.append(FlashPhi)
     __all__.append(FlashQwen2)
     __all__.append(FlashStarcoder2)

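Commenting the Mixtral and DBRX imports out (rather than deleting them) keeps them easy to restore once their kernels build on DCU; the surrounding try/except already treats any import failure as "flash attention unavailable". A small, self-contained sketch of that optional-import pattern (module and package names here are made up, not TGI's real layout):

FLASH_ATTENTION = True
__all__ = []

try:
    # Assumed example modules; an ImportError from any of them disables the
    # whole flash-attention model family instead of crashing the server.
    from example_models.flash_llama import FlashLlama
    # from example_models.flash_mixtral import FlashMixtral  # kernels not built on this platform
except ImportError:
    FLASH_ATTENTION = False

if FLASH_ATTENTION:
    __all__.append("FlashLlama")
    # __all__.append("FlashMixtral")
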
@@ -33,7 +33,7 @@ try:
         "Use the official Docker image (ghcr.io/huggingface/text-generation-inference:latest) "
         f"or install flash attention v2 with `cd server && make install install-flash-attention-v2{architecture_suffix}`"
     )
-    if not (is_sm8x or is_sm90):
+    if not (is_sm8x or is_sm90) and IS_CUDA_SYSTEM:
         raise ImportError(
             f"GPU with CUDA capability {major} {minor} is not supported for "
             "Flash Attention V2"

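The added `and IS_CUDA_SYSTEM` matters because on ROCm the (major, minor) pair reported by torch.cuda.get_device_capability() does not correspond to NVIDIA sm_80/sm_90 architectures, so the old check would reject every DCU even when the ROCm flash-attention build works. A hedged sketch of how such a guard typically fits together (the exact way TGI derives these variables may differ):

import torch

# Assumed to mirror text_generation_server.utils.import_utils.
IS_CUDA_SYSTEM = torch.version.cuda is not None
IS_ROCM_SYSTEM = torch.version.hip is not None

major, minor = torch.cuda.get_device_capability()
is_sm8x = major == 8 and minor >= 0
is_sm90 = major == 9 and minor == 0

# Only enforce the NVIDIA compute-capability requirement on CUDA systems;
# on ROCm the pair above describes a GCN/CDNA target, not an sm_XX level.
if not (is_sm8x or is_sm90) and IS_CUDA_SYSTEM:
    raise ImportError(
        f"GPU with CUDA capability {major} {minor} is not supported for Flash Attention V2"
    )
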
 import torch
 
 from text_generation_server.utils.import_utils import IS_CUDA_SYSTEM, IS_ROCM_SYSTEM
+from loguru import logger
 
 _PARTITION_SIZE = 512
@@ -21,7 +22,8 @@ def reshape_and_cache(
     elif IS_ROCM_SYSTEM:
         from vllm import cache_ops
 
-        cache_ops.reshape_and_cache(key, value, key_cache, value_cache, slots)
+        cache_ops.reshape_and_cache(key, value, key_cache, value_cache, slots.int())
+        # cache_ops.reshape_and_cache(key, value, key_cache, value_cache, slots)
     else:
         raise ValueError("vllm is not supported on your system")

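The only functional change here is the slots.int() cast: TGI builds the slot-mapping tensor with 64-bit indices, and the ROCm build of the vllm reshape_and_cache kernel used on DCU apparently expects 32-bit ones (this is inferred from the cast, not documented behaviour). A trivial illustration of the narrowing, with made-up values:

import torch

# Illustrative only: in TGI, `slots` holds the flat KV-cache slot index for
# each token in the batch and is created as int64 by default.
slots = torch.tensor([0, 1, 5, 17], dtype=torch.int64)

# Narrow to int32 before handing the tensor to a kernel compiled against
# 32-bit slot indices; without the cast the call can fail a dtype check.
slots_for_kernel = slots.int()
assert slots_for_kernel.dtype == torch.int32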