Commit fab1acce authored by zhuwenwen's avatar zhuwenwen
Browse files

[Feature] Support vllm v0.20.0

parent 88d34c64
...@@ -37,7 +37,7 @@ install(CODE "set(CMAKE_INSTALL_LOCAL_ONLY TRUE)" ALL_COMPONENTS) ...@@ -37,7 +37,7 @@ install(CODE "set(CMAKE_INSTALL_LOCAL_ONLY TRUE)" ALL_COMPONENTS)
set(PYTHON_SUPPORTED_VERSIONS "3.10" "3.11" "3.12" "3.13" "3.14") set(PYTHON_SUPPORTED_VERSIONS "3.10" "3.11" "3.12" "3.13" "3.14")
# Supported AMD GPU architectures. # Supported AMD GPU architectures.
set(HIP_SUPPORTED_ARCHS "gfx906;gfx908;gfx90a;gfx942;gfx950;gfx1030;gfx1100;gfx1101;gfx1102;gfx1103;gfx1150;gfx1151;gfx1152;gfx1153;gfx1200;gfx1201") set(HIP_SUPPORTED_ARCHS "gfx906;gfx908;gfx90a;gfx942;gfx950;gfx1030;gfx1100;gfx1101;gfx1102;gfx1103;gfx1150;gfx1151;gfx1152;gfx1153;gfx1200;gfx1201;gfx928;gfx936;gfx938")
# ROCm installation prefix. Default to /opt/rocm but allow override via # ROCm installation prefix. Default to /opt/rocm but allow override via
# -DROCM_PATH=/your/rocm/path when invoking cmake. # -DROCM_PATH=/your/rocm/path when invoking cmake.
...@@ -1240,7 +1240,7 @@ if(VLLM_GPU_LANG STREQUAL "HIP") ...@@ -1240,7 +1240,7 @@ if(VLLM_GPU_LANG STREQUAL "HIP")
endif() endif()
# For CUDA and HIP builds also build the triton_kernels external package. # For CUDA and HIP builds also build the triton_kernels external package.
if(VLLM_GPU_LANG STREQUAL "CUDA" OR VLLM_GPU_LANG STREQUAL "HIP") if(VLLM_GPU_LANG STREQUAL "CUDA")
include(cmake/external_projects/triton_kernels.cmake) include(cmake/external_projects/triton_kernels.cmake)
endif() endif()
......
...@@ -931,6 +931,22 @@ void convert_fp8(torch::Tensor& dst_cache, torch::Tensor& src_cache, ...@@ -931,6 +931,22 @@ void convert_fp8(torch::Tensor& dst_cache, torch::Tensor& src_cache,
CALL_CONVERT_FP8(__nv_bfloat16, uint8_t, CALL_CONVERT_FP8(__nv_bfloat16, uint8_t,
vllm::Fp8KVCacheDataType::kFp8E4M3); vllm::Fp8KVCacheDataType::kFp8E4M3);
} }
} else if (kv_cache_dtype == "fp8_e5m2") {
if (src_cache.dtype() == at::ScalarType::Float) {
CALL_CONVERT_FP8(uint8_t, float, vllm::Fp8KVCacheDataType::kFp8E5M2);
} else if (src_cache.dtype() == at::ScalarType::Half) {
CALL_CONVERT_FP8(uint8_t, uint16_t, vllm::Fp8KVCacheDataType::kFp8E5M2);
} else if (src_cache.dtype() == at::ScalarType::BFloat16) {
CALL_CONVERT_FP8(uint8_t, __nv_bfloat16,
vllm::Fp8KVCacheDataType::kFp8E5M2);
} else if (dst_cache.dtype() == at::ScalarType::Float) {
CALL_CONVERT_FP8(float, uint8_t, vllm::Fp8KVCacheDataType::kFp8E5M2);
} else if (dst_cache.dtype() == at::ScalarType::Half) {
CALL_CONVERT_FP8(uint16_t, uint8_t, vllm::Fp8KVCacheDataType::kFp8E5M2);
} else if (dst_cache.dtype() == at::ScalarType::BFloat16) {
CALL_CONVERT_FP8(__nv_bfloat16, uint8_t,
vllm::Fp8KVCacheDataType::kFp8E5M2);
}
} else { } else {
TORCH_CHECK(false, "Unsupported data type: ", kv_cache_dtype); TORCH_CHECK(false, "Unsupported data type: ", kv_cache_dtype);
} }
...@@ -1156,9 +1172,9 @@ __global__ void cp_gather_and_upconvert_fp8_kv_cache( ...@@ -1156,9 +1172,9 @@ __global__ void cp_gather_and_upconvert_fp8_kv_cache(
const uint2 fp8_hi = make_uint2(fp8_data.z, fp8_data.w); const uint2 fp8_hi = make_uint2(fp8_data.z, fp8_data.w);
#ifdef USE_ROCM #ifdef USE_ROCM
const bf16_8_t bf16_lo = const bf16_8_t bf16_lo =
fp8::scaled_vec_conversion<bf16_8_t, uint2>(fp8_lo, scale); fp8::scaled_vec_conversion<bf16_8_t, uint2>(fp8_lo, scale, vllm::Fp8KVCacheDataType::kFp8E4M3);
const bf16_8_t bf16_hi = const bf16_8_t bf16_hi =
fp8::scaled_vec_conversion<bf16_8_t, uint2>(fp8_hi, scale); fp8::scaled_vec_conversion<bf16_8_t, uint2>(fp8_hi, scale, vllm::Fp8KVCacheDataType::kFp8E4M3);
#else #else
const bf16_8_t bf16_lo = const bf16_8_t bf16_lo =
fp8::scaled_vec_conversion<bf16_8_t, uint2>(fp8_lo, scale, __NV_E4M3); fp8::scaled_vec_conversion<bf16_8_t, uint2>(fp8_lo, scale, __NV_E4M3);
......
...@@ -8,6 +8,8 @@ ...@@ -8,6 +8,8 @@
#include <cassert> #include <cassert>
#ifdef USE_ROCM #ifdef USE_ROCM
#include <cuda_bf16.h>
#include <cuda_fp16.h>
#include <hip/hip_runtime.h> #include <hip/hip_runtime.h>
#else #else
#include <cuda_bf16.h> #include <cuda_bf16.h>
......
...@@ -40,15 +40,15 @@ ...@@ -40,15 +40,15 @@
#ifdef USE_ROCM #ifdef USE_ROCM
#define FINAL_MASK 0xffffffffffffffffULL #define FINAL_MASK 0xffffffffffffffffULL
#if defined(HIP_VERSION) && HIP_VERSION < 70000000 // #if defined(HIP_VERSION) && HIP_VERSION < 70000000
// On ROCm versions before 7.0, __syncwarp isn't defined. The below // // On ROCm versions before 7.0, __syncwarp isn't defined. The below
// implementation is copy/pasted from the implementation in ROCm 7.0 // // implementation is copy/pasted from the implementation in ROCm 7.0
__device__ inline void __syncwarp() { // __device__ inline void __syncwarp() {
__builtin_amdgcn_fence(__ATOMIC_RELEASE, "wavefront"); // __builtin_amdgcn_fence(__ATOMIC_RELEASE, "wavefront");
__builtin_amdgcn_wave_barrier(); // __builtin_amdgcn_wave_barrier();
__builtin_amdgcn_fence(__ATOMIC_ACQUIRE, "wavefront"); // __builtin_amdgcn_fence(__ATOMIC_ACQUIRE, "wavefront");
} // }
#endif // #endif
#else #else
#define FINAL_MASK 0xffffffff #define FINAL_MASK 0xffffffff
#endif #endif
......
...@@ -12,7 +12,9 @@ https://github.com/qwopqwop200/GPTQ-for-LLaMa ...@@ -12,7 +12,9 @@ https://github.com/qwopqwop200/GPTQ-for-LLaMa
#include <cuda_runtime.h> #include <cuda_runtime.h>
#include <cuda_fp16.h> #include <cuda_fp16.h>
#ifndef USE_ROCM
#include "compat.cuh" #include "compat.cuh"
#endif
#include "matrix_view.cuh" #include "matrix_view.cuh"
#include "qdq_2.cuh" #include "qdq_2.cuh"
#include "qdq_3.cuh" #include "qdq_3.cuh"
......
...@@ -47,15 +47,19 @@ __device__ __forceinline__ fp8_type scaled_fp8_conversion(float const val, ...@@ -47,15 +47,19 @@ __device__ __forceinline__ fp8_type scaled_fp8_conversion(float const val,
x = val / scale; x = val / scale;
} }
float r = // float r =
fmaxf(-quant_type_max_v<fp8_type>, fminf(x, quant_type_max_v<fp8_type>)); // fmaxf(-quant_type_max_v<fp8_type>, fminf(x, quant_type_max_v<fp8_type>));
#ifndef USE_ROCM #ifndef USE_ROCM
// Use hardware cvt instruction for fp8 on nvidia // Use hardware cvt instruction for fp8 on nvidia
// Currently only support fp8_type = c10::Float8_e4m3fn // Currently only support fp8_type = c10::Float8_e4m3fn
return fp8::vec_conversion<fp8_type, float>(r); return fp8::vec_conversion<fp8_type, float>(r);
#else #else
fp8_type *test;
uint8_t test_uint8 = fp8::float_to_fp8_e4m3(x);
test = (fp8_type*)(&test_uint8);
return *test;
// Use hardware cvt instruction for fp8 on rocm // Use hardware cvt instruction for fp8 on rocm
return fp8::cvt_c10<fp8_type>(r); // return fp8::cvt_c10<fp8_type>(r);
#endif #endif
} }
......
...@@ -16,8 +16,13 @@ packaging>=24.2 ...@@ -16,8 +16,13 @@ packaging>=24.2
setuptools>=77.0.3,<80.0.0 setuptools>=77.0.3,<80.0.0
setuptools-scm>=8 setuptools-scm>=8
runai-model-streamer[s3,gcs,azure]==0.15.7 runai-model-streamer[s3,gcs,azure]==0.15.7
conch-triton-kernels==1.2.1 # conch-triton-kernels==1.2.1
timm>=1.0.17 timm>=1.0.17
# amd-quark: required for Quark quantization on ROCm # amd-quark: required for Quark quantization on ROCm
# To be consistent with test_quark.py # To be consistent with test_quark.py
amd-quark>=0.8.99 amd-quark>=0.8.99
# Other necessary dependencies
torch == 2.10.0
torchvision == 0.25.0
flash_attn == 2.8.3
...@@ -20,6 +20,12 @@ from setuptools import Extension, setup ...@@ -20,6 +20,12 @@ from setuptools import Extension, setup
from setuptools.command.build_ext import build_ext from setuptools.command.build_ext import build_ext
from setuptools_scm import get_version from setuptools_scm import get_version
from torch.utils.cpp_extension import CUDA_HOME, ROCM_HOME from torch.utils.cpp_extension import CUDA_HOME, ROCM_HOME
from typing import Optional, Union
# Absolute path of the directory that contains this setup script.
pwd = os.path.dirname(os.path.abspath(__file__))

# Opt-in switch: only ADD_GIT_VERSION=1 enables embedding the git commit hash
# in the generated version string; the variable defaults to "0" (disabled).
add_git_version = int(os.environ.get('ADD_GIT_VERSION', '0')) == 1
def load_module_from_path(module_name, path): def load_module_from_path(module_name, path):
...@@ -365,7 +371,7 @@ class cmake_build_ext(build_ext): ...@@ -365,7 +371,7 @@ class cmake_build_ext(build_ext):
os.makedirs(os.path.dirname(dst_file), exist_ok=True) os.makedirs(os.path.dirname(dst_file), exist_ok=True)
self.copy_file(file, dst_file) self.copy_file(file, dst_file)
if _is_cuda() or _is_hip(): if _is_cuda():
# copy vllm/third_party/triton_kernels/**/*.py from self.build_lib # copy vllm/third_party/triton_kernels/**/*.py from self.build_lib
# to current directory so that they can be included in the editable # to current directory so that they can be included in the editable
# build # build
...@@ -895,6 +901,94 @@ def get_nvcc_cuda_version() -> Version: ...@@ -895,6 +901,94 @@ def get_nvcc_cuda_version() -> Version:
return nvcc_cuda_version return nvcc_cuda_version
def get_sha(root: Union[str, Path]) -> str:
    """Return the commit hash that ``git rev-parse HEAD`` reports for ``root``.

    Args:
        root: Directory in which the git command is executed.

    Returns:
        The stripped command output decoded as ASCII, or the literal string
        ``'Unknown'`` if running git or decoding its output fails for any
        reason.
    """
    git_cmd = ['git', 'rev-parse', 'HEAD']
    try:
        raw_output = subprocess.check_output(git_cmd, cwd=root)
        commit = raw_output.decode('ascii').strip()
    except Exception:
        # Deliberately broad: a missing git binary, a non-repository
        # directory or undecodable output all map to the same sentinel.
        commit = 'Unknown'
    return commit
def get_version_add(sha: Optional[str] = None) -> str:
    """Regenerate ``vllm/version.py`` with a DAS/DTK-tagged ``__hcu_version__``.

    The local version suffix is assembled as follows:

    * ``das`` by default, or ``das.<first 7 chars of the commit hash>`` when
      the module-level ``add_git_version`` flag is set and a commit hash is
      available;
    * ``.dtk<version>`` is appended when the ``ROCM_PATH`` environment
      variable is set, where ``<version>`` is the first line of
      ``$ROCM_PATH/.info/rocm_version`` with dots and surrounding whitespace
      removed.

    Args:
        sha: Commit hash to embed. ``None`` means "look it up with
            ``get_sha``"; the sentinel ``'Unknown'`` means no hash is
            available and the plain ``das`` suffix is used.

    Returns:
        The local version suffix that was written into ``__hcu_version__``
        (the part after ``0.20.0+``).

    Raises:
        OSError: If ``ROCM_PATH`` is set but its version file cannot be read,
            or if ``vllm/version.py`` cannot be written.
    """
    vllm_root = os.path.dirname(os.path.abspath(__file__))
    add_version_path = os.path.join(vllm_root, "vllm", "version.py")

    # Best effort: mark the checkout as a git "safe.directory". An argument
    # list (no shell) keeps paths with spaces or shell metacharacters intact;
    # a missing git binary must not abort the build.
    try:
        subprocess.run(
            ["git", "config", "--global", "--add", "safe.directory", vllm_root],
            check=False,
        )
    except OSError:
        pass

    # Resolve the hash *before* testing for the 'Unknown' sentinel so that a
    # failed lookup falls back to the plain suffix instead of "das.Unknown",
    # and so that ``version`` is bound on every path.
    version = 'das'
    if add_git_version:
        if sha is None:
            sha = get_sha(vllm_root)
        if sha != 'Unknown':
            version = 'das.' + sha[:7]

    # dtk version
    rocm_path = os.getenv("ROCM_PATH")
    if rocm_path:
        rocm_version_path = os.path.join(rocm_path, '.info', "rocm_version")
        with open(rocm_version_path, 'r', encoding='utf-8') as file:
            # strip() drops the trailing newline; left in place it would end
            # up inside the single-quoted f-string of the generated file and
            # make that file a syntax error.
            rocm_version = file.readline().strip().replace(".", "")
        if rocm_version:
            version += ".dtk" + rocm_version

    # Doubled braces below are literal braces in the generated file; only
    # {version} is substituted here.
    new_version_content = f"""
try:
    __version__ = "0.20.0"
    __version_tuple__ = (0, 20, 0)
    __hcu_version__ = f'0.20.0+{version}'
    from vllm.version import __version__, __version_tuple__, __hcu_version__
except Exception as e:
    import warnings
    warnings.warn("Failed to read commit hash:\\n" + str(e),
                  RuntimeWarning,
                  stacklevel=2)
    __version__ = "dev"
    __version_tuple__ = (0, 0, __version__)


def _prev_minor_version_was(version_str):
    '''Check whether a given version matches the previous minor version.
    Return True if version_str matches the previous minor version.
    For example - return True if the current version if 0.7.4 and the
    supplied version_str is '0.6'.
    Used for --show-hidden-metrics-for-version.
    '''
    # Match anything if this is a dev tree
    if __version_tuple__[0:2] == (0, 0):
        return True
    # Note - this won't do the right thing when we release 1.0!
    # assert __version_tuple__[0] == 0
    assert isinstance(__version_tuple__[1], int)
    return version_str == f"{{__version_tuple__[0]}}.{{__version_tuple__[1] - 1}}"


def _prev_minor_version():
    '''For the purpose of testing, return a previous minor version number.'''
    # In dev tree, this will return "0.-1", but that will work fine"
    assert isinstance(__version_tuple__[1], int)
    return f"{{__version_tuple__[0]}}.{{__version_tuple__[1] - 1}}"
"""
    with open(add_version_path, encoding="utf-8", mode="w") as file:
        file.write(new_version_content)
    return version
def get_version(*args, **kwargs):
    """Return the HCU version string, regenerating ``vllm/version.py`` first.

    This definition rebinds the module-level name ``get_version`` that was
    imported from ``setuptools_scm``, while other code in this file still
    calls ``get_version(write_to=...)``. To keep those calls working, any
    call that passes arguments is forwarded unchanged to
    ``setuptools_scm.get_version``.

    Called without arguments, it runs ``get_version_add()`` to rewrite
    ``vllm/version.py`` next to this script, executes that file in a private
    namespace and returns its ``__hcu_version__`` value.

    Returns:
        The ``__hcu_version__`` string (no-argument form), or whatever
        ``setuptools_scm.get_version`` returns (argument form).
    """
    if args or kwargs:
        from setuptools_scm import get_version as _scm_get_version
        return _scm_get_version(*args, **kwargs)

    get_version_add()
    # Resolve the file relative to this script (the location get_version_add
    # writes to) rather than the current working directory.
    version_file = os.path.join(
        os.path.dirname(os.path.abspath(__file__)), 'vllm', 'version.py')
    # An explicit namespace is required: since Python 3.13 (PEP 667), names
    # assigned by exec() inside a function are not visible through a later
    # locals() call.
    namespace = {}
    with open(version_file, encoding='utf-8') as f:
        exec(compile(f.read(), version_file, 'exec'), namespace)
    return namespace['__hcu_version__']
def get_vllm_version() -> str: def get_vllm_version() -> str:
# Allow overriding the version. This is useful to build platform-specific # Allow overriding the version. This is useful to build platform-specific
# wheels (e.g. CPU, TPU) without modifying the source. # wheels (e.g. CPU, TPU) without modifying the source.
...@@ -903,6 +997,7 @@ def get_vllm_version() -> str: ...@@ -903,6 +997,7 @@ def get_vllm_version() -> str:
os.environ["SETUPTOOLS_SCM_PRETEND_VERSION"] = env_version os.environ["SETUPTOOLS_SCM_PRETEND_VERSION"] = env_version
return get_version(write_to="vllm/_version.py") return get_version(write_to="vllm/_version.py")
if not _is_hip():
version = get_version(write_to="vllm/_version.py") version = get_version(write_to="vllm/_version.py")
sep = "+" if "+" not in version else "." # dev versions might contain + sep = "+" if "+" not in version else "." # dev versions might contain +
...@@ -921,9 +1016,10 @@ def get_vllm_version() -> str: ...@@ -921,9 +1016,10 @@ def get_vllm_version() -> str:
version += f"{sep}cu{cuda_version_str}" version += f"{sep}cu{cuda_version_str}"
elif _is_hip(): elif _is_hip():
# Get the Rocm Version # Get the Rocm Version
rocm_version = get_rocm_version() or torch.version.hip # rocm_version = get_rocm_version() or torch.version.hip
if rocm_version and rocm_version != envs.VLLM_MAIN_CUDA_VERSION: # if rocm_version and rocm_version != envs.VLLM_MAIN_CUDA_VERSION:
version += f"{sep}rocm{rocm_version.replace('.', '')[:3]}" # version += f"{sep}rocm{rocm_version.replace('.', '')[:3]}"
version = get_version()
elif _is_tpu(): elif _is_tpu():
version += f"{sep}tpu" version += f"{sep}tpu"
elif _is_cpu(): elif _is_cpu():
...@@ -991,7 +1087,7 @@ if _is_cuda() or _is_hip(): ...@@ -991,7 +1087,7 @@ if _is_cuda() or _is_hip():
ext_modules.append(CMakeExtension(name="vllm.cumem_allocator")) ext_modules.append(CMakeExtension(name="vllm.cumem_allocator"))
# Optional since this doesn't get built (produce an .so file). This is just # Optional since this doesn't get built (produce an .so file). This is just
# copying the relevant .py files from the source repository. # copying the relevant .py files from the source repository.
ext_modules.append(CMakeExtension(name="vllm.triton_kernels", optional=True)) # ext_modules.append(CMakeExtension(name="vllm.triton_kernels", optional=True))
if _is_hip(): if _is_hip():
ext_modules.append(CMakeExtension(name="vllm._rocm_C")) ext_modules.append(CMakeExtension(name="vllm._rocm_C"))
......
...@@ -44,10 +44,10 @@ except ImportError as e: ...@@ -44,10 +44,10 @@ except ImportError as e:
logger.warning("Failed to import from vllm._C with %r", e) logger.warning("Failed to import from vllm._C with %r", e)
# import custom ops, trigger op registration # import custom ops, trigger op registration
try: # try:
import vllm._rocm_C # noqa: F401 # import vllm._rocm_C # noqa: F401
except ImportError as e: # except ImportError as e:
logger.warning("Failed to import from vllm._rocm_C with %r", e) # logger.warning("Failed to import from vllm._rocm_C with %r", e)
# Models not supported by ROCm. # Models not supported by ROCm.
_ROCM_UNSUPPORTED_MODELS: list[str] = [] _ROCM_UNSUPPORTED_MODELS: list[str] = []
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment