Unverified commit 6ccc0bff authored by TJian, committed by GitHub

Merge EmbeddedLLM/vllm-rocm into vLLM main (#1836)


Co-authored-by: Philipp Moritz <pcmoritz@gmail.com>
Co-authored-by: Amir Balwel <amoooori04@gmail.com>
Co-authored-by: root <kuanfu.liu@akirakan.com>
Co-authored-by: tjtanaa <tunjian.tan@embeddedllm.com>
Co-authored-by: kuanfu <kuanfu.liu@embeddedllm.com>
Co-authored-by: miloice <17350011+kliuae@users.noreply.github.com>
parent c8e7eb1e
--- /opt/conda/envs/py_3.10/lib/python3.10/site-packages/xformers/ops/fmha/common.py 2023-11-29 03:17:03.930103539 +0000
+++ common.py 2023-11-28 16:14:19.846233146 +0000
@@ -298,8 +298,8 @@
dtype = d.query.dtype
if device_type not in cls.SUPPORTED_DEVICES:
reasons.append(f"device={device_type} (supported: {cls.SUPPORTED_DEVICES})")
- if device_type == "cuda" and not _built_with_cuda:
- reasons.append("xFormers wasn't build with CUDA support")
+ #if device_type == "cuda" and not _built_with_cuda:
+ # reasons.append("xFormers wasn't build with CUDA support")
if device_type == "cuda":
device_capability = torch.cuda.get_device_capability(d.device)
if device_capability < cls.CUDA_MINIMUM_COMPUTE_CAPABILITY:
--- /opt/conda/envs/py_3.10/lib/python3.10/site-packages/xformers/ops/fmha/flash.py 2023-11-29 03:17:03.930103539 +0000
+++ flash.py 2023-11-28 16:14:25.206128903 +0000
@@ -31,39 +31,39 @@
FLASH_VERSION = "0.0.0"
try:
- try:
- from ... import _C_flashattention # type: ignore[attr-defined]
- from ..._cpp_lib import _build_metadata
-
- if _build_metadata is not None:
- FLASH_VERSION = _build_metadata.flash_version
- except ImportError:
- import flash_attn
- from flash_attn.flash_attn_interface import flash_attn_cuda as _C_flashattention
-
- FLASH_VERSION = flash_attn.__version__
- flash_ver_parsed = tuple(int(s) for s in FLASH_VERSION.split(".")[:2])
- if flash_ver_parsed < (2, 3):
- raise ImportError("Requires 2.3 for sliding window support")
+ #try:
+ # from ... import _C_flashattention # type: ignore[attr-defined]
+ # from ..._cpp_lib import _build_metadata
+
+ # if _build_metadata is not None:
+ # FLASH_VERSION = _build_metadata.flash_version
+ #except ImportError:
+ import flash_attn
+ from flash_attn.flash_attn_interface import flash_attn_cuda as _C_flashattention
+
+ FLASH_VERSION = flash_attn.__version__
+ # flash_ver_parsed = tuple(int(s) for s in FLASH_VERSION.split(".")[:2])
+ # if flash_ver_parsed < (2, 3):
+ # raise ImportError("Requires 2.3 for sliding window support")
# create library so that flash-attn goes through the PyTorch Dispatcher
- _flash_lib = torch.library.Library("xformers_flash", "DEF")
+ #_flash_lib = torch.library.Library("xformers_flash", "DEF")
- _flash_lib.define(
- "flash_fwd(Tensor query, Tensor key, Tensor value, "
- "Tensor? cu_seqlens_q, Tensor? cu_seqlens_k, "
- "int max_seqlen_q, int max_seqlen_k, "
- "float p, float softmax_scale, "
- "bool is_causal, int window_size, bool return_softmax) -> (Tensor, Tensor, Tensor)"
- )
-
- _flash_lib.define(
- "flash_bwd(Tensor dout, Tensor query, Tensor key, Tensor value, "
- "Tensor out, Tensor softmax_lse_, Tensor dq, Tensor dk, Tensor dv, "
- "Tensor cu_seqlens_q, Tensor cu_seqlens_k, "
- "int max_seqlen_q, int max_seqlen_k, "
- "float p, float softmax_scale, bool is_causal, int window_size, Tensor rng_state) -> (Tensor, Tensor, Tensor)"
- )
+ #_flash_lib.define(
+ # "flash_fwd(Tensor query, Tensor key, Tensor value, "
+ # "Tensor? cu_seqlens_q, Tensor? cu_seqlens_k, "
+ # "int max_seqlen_q, int max_seqlen_k, "
+ # "float p, float softmax_scale, "
+ # "bool is_causal, int window_size, bool return_softmax) -> (Tensor, Tensor, Tensor)"
+ #)
+
+ #_flash_lib.define(
+ # "flash_bwd(Tensor dout, Tensor query, Tensor key, Tensor value, "
+ # "Tensor out, Tensor softmax_lse_, Tensor dq, Tensor dk, Tensor dv, "
+ # "Tensor cu_seqlens_q, Tensor cu_seqlens_k, "
+ # "int max_seqlen_q, int max_seqlen_k, "
+ # "float p, float softmax_scale, bool is_causal, int window_size, Tensor rng_state) -> (Tensor, Tensor, Tensor)"
+ #)
def _flash_fwd(
query,
@@ -98,8 +98,8 @@
p,
softmax_scale,
is_causal,
- window_size - 1, # window_size_left
- -1, # window_size_right
+ # window_size - 1, # window_size_left
+ # -1, # window_size_right
return_softmax,
None, # rng
)
@@ -127,8 +127,8 @@
softmax_scale,
False,
is_causal,
- window_size - 1, # window_size_left
- -1, # window_size_right
+ # window_size - 1, # window_size_left
+ # -1, # window_size_right
return_softmax,
None,
)
@@ -169,8 +169,8 @@
p,
softmax_scale,
is_causal,
- window_size - 1, # window_size_left
- -1, # window_size_right
+ # window_size - 1, # window_size_left
+ # -1, # window_size_right
None,
rng_state,
)
@@ -193,15 +193,15 @@
softmax_scale,
False, # zero_tensors
is_causal,
- window_size - 1, # window_size_left
- -1, # window_size_right
+ # window_size - 1, # window_size_left
+ # -1, # window_size_right
None,
rng_state,
)
return dq, dk, dv
- _flash_lib.impl("flash_fwd", _flash_fwd, "CUDA")
- _flash_lib.impl("flash_bwd", _flash_bwd, "CUDA")
+ #_flash_lib.impl("flash_fwd", _flash_fwd, "CUDA")
+ #_flash_lib.impl("flash_bwd", _flash_bwd, "CUDA")
except ImportError:
pass
@@ -348,7 +348,7 @@
implementation.
"""
- OPERATOR = get_operator("xformers_flash", "flash_fwd")
+ OPERATOR = _flash_fwd # get_operator("xformers_flash", "flash_fwd")
SUPPORTED_DEVICES: Set[str] = {"cuda"}
CUDA_MINIMUM_COMPUTE_CAPABILITY = (8, 0)
SUPPORTED_DTYPES: Set[torch.dtype] = {torch.half, torch.bfloat16}
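
Note: the two patches above bypass xFormers' bundled _C_flashattention CUDA extension and route its flash-attention operator through the flash_attn package built for ROCm. A minimal smoke test after applying the patches might look like the sketch below; it is illustrative only and not part of this commit (tensor shapes are arbitrary, and it assumes a working ROCm PyTorch plus the patched xformers and flash-attn installs).

# Hypothetical check that the patched xFormers dispatches attention correctly.
import torch
import xformers.ops as xops

q = torch.randn(1, 128, 8, 64, device="cuda", dtype=torch.float16)
k = torch.randn(1, 128, 8, 64, device="cuda", dtype=torch.float16)
v = torch.randn(1, 128, 8, 64, device="cuda", dtype=torch.float16)

# MemoryEfficientAttentionFlashAttentionOp is the (forward, backward) op pair
# that the patched flash.py now backs with flash_attn's kernels.
out = xops.memory_efficient_attention(
    q, k, v, op=xops.fmha.MemoryEfficientAttentionFlashAttentionOp)
print(out.shape)  # expected: torch.Size([1, 128, 8, 64])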
@@ -8,27 +8,83 @@ import warnings

from packaging.version import parse, Version
import setuptools
import torch
-from torch.utils.cpp_extension import BuildExtension, CUDAExtension, CUDA_HOME
+from torch.utils.cpp_extension import BuildExtension, CUDAExtension, CUDA_HOME, ROCM_HOME

ROOT_DIR = os.path.dirname(__file__)

MAIN_CUDA_VERSION = "12.1"

# Supported NVIDIA GPU architectures.
-SUPPORTED_ARCHS = {"7.0", "7.5", "8.0", "8.6", "8.9", "9.0"}
+NVIDIA_SUPPORTED_ARCHS = {"7.0", "7.5", "8.0", "8.6", "8.9", "9.0"}
+ROCM_SUPPORTED_ARCHS = {"gfx90a", "gfx908", "gfx906", "gfx1030", "gfx1100"}
+# SUPPORTED_ARCHS = NVIDIA_SUPPORTED_ARCHS.union(ROCM_SUPPORTED_ARCHS)
+
+
+def _is_hip() -> bool:
+    return torch.version.hip is not None
+
+
+def _is_cuda() -> bool:
+    return torch.version.cuda is not None

# Compiler flags.
CXX_FLAGS = ["-g", "-O2", "-std=c++17"]
# TODO(woosuk): Should we use -O3?
NVCC_FLAGS = ["-O2", "-std=c++17"]

+if _is_hip():
+    if ROCM_HOME is None:
+        raise RuntimeError(
+            "Cannot find ROCM_HOME. ROCm must be available to build the package."
+        )
+    NVCC_FLAGS += ["-DUSE_ROCM"]
+
+if _is_cuda() and CUDA_HOME is None:
+    raise RuntimeError(
+        "Cannot find CUDA_HOME. CUDA must be available to build the package.")
+
ABI = 1 if torch._C._GLIBCXX_USE_CXX11_ABI else 0
CXX_FLAGS += [f"-D_GLIBCXX_USE_CXX11_ABI={ABI}"]
NVCC_FLAGS += [f"-D_GLIBCXX_USE_CXX11_ABI={ABI}"]

-if CUDA_HOME is None:
-    raise RuntimeError(
-        "Cannot find CUDA_HOME. CUDA must be available to build the package.")
+
+def get_amdgpu_offload_arch():
+    command = "/opt/rocm/llvm/bin/amdgpu-offload-arch"
+    try:
+        output = subprocess.check_output([command])
+        return output.decode('utf-8').strip()
+    except subprocess.CalledProcessError as e:
+        error_message = f"Error: {e}"
+        raise RuntimeError(error_message) from e
+    except FileNotFoundError as e:
+        # If the command is not found, print an error message
+        error_message = f"The command {command} was not found."
+        raise RuntimeError(error_message) from e
+
+    return None
+
+
+def get_hipcc_rocm_version():
+    # Run the hipcc --version command
+    result = subprocess.run(['hipcc', '--version'],
+                            stdout=subprocess.PIPE,
+                            stderr=subprocess.STDOUT,
+                            text=True)
+
+    # Check if the command was executed successfully
+    if result.returncode != 0:
+        print("Error running 'hipcc --version'")
+        return None
+
+    # Extract the version using a regular expression
+    match = re.search(r'HIP version: (\S+)', result.stdout)
+    if match:
+        # Return the version string
+        return match.group(1)
+    else:
+        print("Could not find HIP version in the output")
+        return None


def get_nvcc_cuda_version(cuda_dir: str) -> Version:

@@ -61,20 +117,22 @@ def get_torch_arch_list() -> Set[str]:
        return set()

    # Filter out the invalid architectures and print a warning.
-    valid_archs = SUPPORTED_ARCHS.union({s + "+PTX" for s in SUPPORTED_ARCHS})
+    valid_archs = NVIDIA_SUPPORTED_ARCHS.union(
+        {s + "+PTX"
+         for s in NVIDIA_SUPPORTED_ARCHS})
    arch_list = torch_arch_list.intersection(valid_archs)
    # If none of the specified architectures are valid, raise an error.
    if not arch_list:
        raise RuntimeError(
-            "None of the CUDA architectures in `TORCH_CUDA_ARCH_LIST` env "
+            "None of the CUDA/ROCM architectures in `TORCH_CUDA_ARCH_LIST` env "
            f"variable ({env_arch_list}) is supported. "
-            f"Supported CUDA architectures are: {valid_archs}.")
+            f"Supported CUDA/ROCM architectures are: {valid_archs}.")
    invalid_arch_list = torch_arch_list - valid_archs
    if invalid_arch_list:
        warnings.warn(
-            f"Unsupported CUDA architectures ({invalid_arch_list}) are "
+            f"Unsupported CUDA/ROCM architectures ({invalid_arch_list}) are "
            "excluded from the `TORCH_CUDA_ARCH_LIST` env variable "
-            f"({env_arch_list}). Supported CUDA architectures are: "
+            f"({env_arch_list}). Supported CUDA/ROCM architectures are: "
            f"{valid_archs}.",
            stacklevel=2)
    return arch_list

@@ -82,7 +140,7 @@ def get_torch_arch_list() -> Set[str]:
# First, check the TORCH_CUDA_ARCH_LIST environment variable.
compute_capabilities = get_torch_arch_list()
-if not compute_capabilities:
+if _is_cuda() and not compute_capabilities:
    # If TORCH_CUDA_ARCH_LIST is not defined or empty, target all available
    # GPUs on the current machine.
    device_count = torch.cuda.device_count()

@@ -93,69 +151,84 @@ if not compute_capabilities:
                "GPUs with compute capability below 7.0 are not supported.")
        compute_capabilities.add(f"{major}.{minor}")

-nvcc_cuda_version = get_nvcc_cuda_version(CUDA_HOME)
-if not compute_capabilities:
-    # If no GPU is specified nor available, add all supported architectures
-    # based on the NVCC CUDA version.
-    compute_capabilities = SUPPORTED_ARCHS.copy()
-    if nvcc_cuda_version < Version("11.1"):
-        compute_capabilities.remove("8.6")
-    if nvcc_cuda_version < Version("11.8"):
-        compute_capabilities.remove("8.9")
-        compute_capabilities.remove("9.0")
-
-# Validate the NVCC CUDA version.
-if nvcc_cuda_version < Version("11.0"):
-    raise RuntimeError("CUDA 11.0 or higher is required to build the package.")
-if (nvcc_cuda_version < Version("11.1")
-        and any(cc.startswith("8.6") for cc in compute_capabilities)):
-    raise RuntimeError(
-        "CUDA 11.1 or higher is required for compute capability 8.6.")
-if nvcc_cuda_version < Version("11.8"):
-    if any(cc.startswith("8.9") for cc in compute_capabilities):
-        # CUDA 11.8 is required to generate the code targeting compute capability 8.9.
-        # However, GPUs with compute capability 8.9 can also run the code generated by
-        # the previous versions of CUDA 11 and targeting compute capability 8.0.
-        # Therefore, if CUDA 11.8 is not available, we target compute capability 8.0
-        # instead of 8.9.
-        warnings.warn(
-            "CUDA 11.8 or higher is required for compute capability 8.9. "
-            "Targeting compute capability 8.0 instead.",
-            stacklevel=2)
-        compute_capabilities = set(cc for cc in compute_capabilities
-                                   if not cc.startswith("8.9"))
-        compute_capabilities.add("8.0+PTX")
-    if any(cc.startswith("9.0") for cc in compute_capabilities):
-        raise RuntimeError(
-            "CUDA 11.8 or higher is required for compute capability 9.0.")
-
-# Add target compute capabilities to NVCC flags.
-for capability in compute_capabilities:
-    num = capability[0] + capability[2]
-    NVCC_FLAGS += ["-gencode", f"arch=compute_{num},code=sm_{num}"]
-    if capability.endswith("+PTX"):
-        NVCC_FLAGS += ["-gencode", f"arch=compute_{num},code=compute_{num}"]
-
-# Use NVCC threads to parallelize the build.
-if nvcc_cuda_version >= Version("11.2"):
-    nvcc_threads = int(os.getenv("NVCC_THREADS", 8))
-    num_threads = min(os.cpu_count(), nvcc_threads)
-    NVCC_FLAGS += ["--threads", str(num_threads)]
+if _is_cuda():
+    nvcc_cuda_version = get_nvcc_cuda_version(CUDA_HOME)
+    if not compute_capabilities:
+        # If no GPU is specified nor available, add all supported architectures
+        # based on the NVCC CUDA version.
+        compute_capabilities = NVIDIA_SUPPORTED_ARCHS.copy()
+        if nvcc_cuda_version < Version("11.1"):
+            compute_capabilities.remove("8.6")
+        if nvcc_cuda_version < Version("11.8"):
+            compute_capabilities.remove("8.9")
+            compute_capabilities.remove("9.0")
+
+    # Validate the NVCC CUDA version.
+    if nvcc_cuda_version < Version("11.0"):
+        raise RuntimeError(
+            "CUDA 11.0 or higher is required to build the package.")
+    if (nvcc_cuda_version < Version("11.1")
+            and any(cc.startswith("8.6") for cc in compute_capabilities)):
+        raise RuntimeError(
+            "CUDA 11.1 or higher is required for compute capability 8.6.")
+    if nvcc_cuda_version < Version("11.8"):
+        if any(cc.startswith("8.9") for cc in compute_capabilities):
+            # CUDA 11.8 is required to generate the code targeting compute capability 8.9.
+            # However, GPUs with compute capability 8.9 can also run the code generated by
+            # the previous versions of CUDA 11 and targeting compute capability 8.0.
+            # Therefore, if CUDA 11.8 is not available, we target compute capability 8.0
+            # instead of 8.9.
+            warnings.warn(
+                "CUDA 11.8 or higher is required for compute capability 8.9. "
+                "Targeting compute capability 8.0 instead.",
+                stacklevel=2)
+            compute_capabilities = set(cc for cc in compute_capabilities
+                                       if not cc.startswith("8.9"))
+            compute_capabilities.add("8.0+PTX")
+        if any(cc.startswith("9.0") for cc in compute_capabilities):
+            raise RuntimeError(
+                "CUDA 11.8 or higher is required for compute capability 9.0.")
+
+    # Add target compute capabilities to NVCC flags.
+    for capability in compute_capabilities:
+        num = capability[0] + capability[2]
+        NVCC_FLAGS += ["-gencode", f"arch=compute_{num},code=sm_{num}"]
+        if capability.endswith("+PTX"):
+            NVCC_FLAGS += [
+                "-gencode", f"arch=compute_{num},code=compute_{num}"
+            ]
+
+    # Use NVCC threads to parallelize the build.
+    if nvcc_cuda_version >= Version("11.2"):
+        nvcc_threads = int(os.getenv("NVCC_THREADS", 8))
+        num_threads = min(os.cpu_count(), nvcc_threads)
+        NVCC_FLAGS += ["--threads", str(num_threads)]
+
+elif _is_hip():
+    amd_arch = get_amdgpu_offload_arch()
+    if amd_arch not in ROCM_SUPPORTED_ARCHS:
+        raise RuntimeError(
+            f"Only the following arch is supported: {ROCM_SUPPORTED_ARCHS}"
+            f"amdgpu_arch_found: {amd_arch}")

ext_modules = []

+vllm_extension_sources = [
+    "csrc/cache_kernels.cu",
+    "csrc/attention/attention_kernels.cu",
+    "csrc/pos_encoding_kernels.cu",
+    "csrc/activation_kernels.cu",
+    "csrc/layernorm_kernels.cu",
+    "csrc/quantization/squeezellm/quant_cuda_kernel.cu",
+    "csrc/cuda_utils_kernels.cu",
+    "csrc/pybind.cpp",
+]
+
+if _is_cuda():
+    vllm_extension_sources.append("csrc/quantization/awq/gemm_kernels.cu")
+
vllm_extension = CUDAExtension(
    name="vllm._C",
-    sources=[
-        "csrc/cache_kernels.cu",
-        "csrc/attention/attention_kernels.cu",
-        "csrc/pos_encoding_kernels.cu",
-        "csrc/activation_kernels.cu",
-        "csrc/layernorm_kernels.cu",
-        "csrc/quantization/awq/gemm_kernels.cu",
-        "csrc/quantization/squeezellm/quant_cuda_kernel.cu",
-        "csrc/cuda_utils_kernels.cu",
-        "csrc/pybind.cpp",
-    ],
+    sources=vllm_extension_sources,
    extra_compile_args={
        "cxx": CXX_FLAGS,
        "nvcc": NVCC_FLAGS,

@@ -183,10 +256,19 @@ def find_version(filepath: str) -> str:

def get_vllm_version() -> str:
    version = find_version(get_path("vllm", "__init__.py"))
-    cuda_version = str(nvcc_cuda_version)
-    if cuda_version != MAIN_CUDA_VERSION:
-        cuda_version_str = cuda_version.replace(".", "")[:3]
-        version += f"+cu{cuda_version_str}"
+
+    if _is_hip():
+        # Get the HIP version
+        hipcc_version = get_hipcc_rocm_version()
+        if hipcc_version != MAIN_CUDA_VERSION:
+            rocm_version_str = hipcc_version.replace(".", "")[:3]
+            version += f"+rocm{rocm_version_str}"
+    else:
+        cuda_version = str(nvcc_cuda_version)
+        if cuda_version != MAIN_CUDA_VERSION:
+            cuda_version_str = cuda_version.replace(".", "")[:3]
+            version += f"+cu{cuda_version_str}"
+
    return version

@@ -201,8 +283,12 @@ def read_readme() -> str:

def get_requirements() -> List[str]:
    """Get Python package dependencies from requirements.txt."""
-    with open(get_path("requirements.txt")) as f:
-        requirements = f.read().strip().split("\n")
+    if _is_hip():
+        with open(get_path("requirements-rocm.txt")) as f:
+            requirements = f.read().strip().split("\n")
+    else:
+        with open(get_path("requirements.txt")) as f:
+            requirements = f.read().strip().split("\n")
    return requirements
...
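
For reference, the version-tagging branch added to get_vllm_version() above appends a +rocm or +cu local suffix whenever the detected toolkit differs from MAIN_CUDA_VERSION. The sketch below restates that logic in isolation with made-up inputs (the base version "0.2.2" and the toolkit versions are hypothetical, not values taken from this commit).

# Standalone sketch of the suffix logic; not the setup.py code itself.
MAIN_CUDA_VERSION = "12.1"

def version_with_suffix(base, hip_version=None, cuda_version=None):
    version = base
    if hip_version is not None:  # ROCm build path
        if hip_version != MAIN_CUDA_VERSION:
            version += f"+rocm{hip_version.replace('.', '')[:3]}"
    elif cuda_version is not None:  # CUDA build path
        if cuda_version != MAIN_CUDA_VERSION:
            version += f"+cu{cuda_version.replace('.', '')[:3]}"
    return version

print(version_with_suffix("0.2.2", hip_version="5.7.0"))   # 0.2.2+rocm570
print(version_with_suffix("0.2.2", cuda_version="11.8"))   # 0.2.2+cu118
print(version_with_suffix("0.2.2", cuda_version="12.1"))   # 0.2.2 (matches MAIN_CUDA_VERSION)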
@@ -6,7 +6,7 @@ from transformers import PretrainedConfig

from vllm.logger import init_logger
from vllm.transformers_utils.config import get_config
-from vllm.utils import get_cpu_memory
+from vllm.utils import get_cpu_memory, is_hip

logger = init_logger(__name__)

@@ -98,12 +98,27 @@ class ModelConfig:

    def _verify_load_format(self) -> None:
        load_format = self.load_format.lower()
-        if load_format not in [
-                "auto", "pt", "safetensors", "npcache", "dummy"
-        ]:
+        supported_load_format = [
+            "auto", "pt", "safetensors", "npcache", "dummy"
+        ]
+        rocm_not_supported_load_format = ["safetensors"]
+        if load_format not in supported_load_format:
            raise ValueError(
                f"Unknown load format: {self.load_format}. Must be one of "
                "'auto', 'pt', 'safetensors', 'npcache', or 'dummy'.")
+        if is_hip():
+            if load_format in ["safetensors"]:
+                rocm_supported_load_format = [
+                    f for f in supported_load_format
+                    if (f not in rocm_not_supported_load_format)
+                ]
+                raise ValueError(
+                    f"load format \'{load_format}\' is not supported in ROCm. "
+                    f"Supported load format are "
+                    f"{rocm_supported_load_format}")
+            # Force ROCm to load from pt weights if nothing specific is set
+            if load_format == "auto":
+                load_format = "pt"
        self.load_format = load_format

    def _verify_tokenizer_mode(self) -> None:

@@ -116,6 +131,7 @@ class ModelConfig:

    def _verify_quantization(self) -> None:
        supported_quantization = ["awq", "squeezellm"]
+        rocm_not_supported_quantization = ["awq"]
        if self.quantization is not None:
            self.quantization = self.quantization.lower()

@@ -137,6 +153,11 @@ class ModelConfig:
                raise ValueError(
                    f"Unknown quantization method: {self.quantization}. Must "
                    f"be one of {supported_quantization}.")
+            if is_hip(
+            ) and self.quantization in rocm_not_supported_quantization:
+                raise ValueError(
+                    f"{self.quantization} quantization is currently not supported "
+                    f"in ROCm.")
            logger.warning(f"{self.quantization} quantization is not fully "
                           "optimized yet. The speed can be slower than "
                           "non-quantized models.")

@@ -364,6 +385,8 @@ _STR_DTYPE_TO_TORCH_DTYPE = {
    "bfloat16": torch.bfloat16,
}

+_ROCM_NOT_SUPPORTED_DTYPE = ["float", "float32"]
+

def _get_and_verify_dtype(
    config: PretrainedConfig,

@@ -393,6 +416,14 @@ def _get_and_verify_dtype(
        else:
            raise ValueError(f"Unknown dtype: {dtype}")

+    if is_hip() and torch_dtype == torch.float32:
+        rocm_supported_dtypes = [
+            k for k, v in _STR_DTYPE_TO_TORCH_DTYPE.items()
+            if (k not in _ROCM_NOT_SUPPORTED_DTYPE)
+        ]
+        raise ValueError(f"dtype \'{dtype}\' is not supported in ROCm. "
+                         f"Supported dtypes are {rocm_supported_dtypes}")
+
    # Verify the dtype.
    if torch_dtype != config_dtype:
        if torch_dtype == torch.float32:
...
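
From the user's side, the ModelConfig checks above mean that on ROCm the engine rejects safetensors weights, AWQ quantization, and float32, and maps load_format="auto" to "pt". A hypothetical invocation is sketched below; the model name and settings are chosen only for illustration, not taken from this commit.

# Illustrative only; any supported Hugging Face model id works the same way.
from vllm import LLM, SamplingParams

llm = LLM(model="facebook/opt-125m",
          dtype="float16",      # float32 would raise on ROCm
          load_format="auto")   # resolved to "pt" when running on ROCm
outputs = llm.generate(["Hello, my name is"], SamplingParams(max_tokens=16))
print(outputs[0].outputs[0].text)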
@@ -3,6 +3,7 @@ from typing import Optional, Tuple, TYPE_CHECKING

from vllm.config import ParallelConfig
from vllm.logger import init_logger
+from vllm.utils import is_hip

logger = init_logger(__name__)

@@ -73,7 +74,12 @@ def initialize_cluster(
            "Ray is not installed. Please install Ray to use distributed "
            "serving.")
    # Connect to a ray cluster.
-    ray.init(address=ray_address, ignore_reinit_error=True)
+    if is_hip():
+        ray.init(address=ray_address,
+                 ignore_reinit_error=True,
+                 num_gpus=parallel_config.world_size)
+    else:
+        ray.init(address=ray_address, ignore_reinit_error=True)

    if not parallel_config.worker_use_ray:
        # Initialize cluster locally.
...
@@ -10,6 +10,7 @@ from xformers.ops.fmha.attn_bias import (BlockDiagonalCausalMask,
from vllm._C import ops
from vllm._C import cache_ops
from vllm.model_executor.input_metadata import InputMetadata
+from vllm.utils import is_hip

_SUPPORTED_HEAD_SIZES = [64, 80, 96, 112, 128, 256]
# Should be the same as PARTITION_SIZE in `paged_attention_v2_launcher`.

@@ -160,6 +161,8 @@ class PagedAttention(nn.Module):
                attn_bias=input_metadata.attn_bias,
                p=0.0,
                scale=self.scale,
+                op=xops.fmha.MemoryEfficientAttentionFlashAttentionOp[0] if
+                (is_hip()) else None,
            )
            output = out.view_as(query)
        else:
...
@@ -7,6 +7,7 @@ from vllm._C import ops
from vllm.model_executor.layers.linear import (LinearMethodBase,
                                               set_weight_attrs)
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
+from vllm.utils import is_hip


class SqueezeLLMConfig(QuantizationConfig):

@@ -114,9 +115,14 @@ class SqueezeLLMLinearMethod(LinearMethodBase):
        lookup_table = weights["lookup_table"]
        out_shape = x.shape[:-1] + (qweight.shape[-1], )
        reshaped_x = x.reshape(-1, x.shape[-1])
-        # NOTE: The output tensor should be zero-initialized.
-        out = torch.zeros(out_shape, device="cuda", dtype=torch.float16)
-        ops.squeezellm_gemm(reshaped_x, qweight, out, lookup_table)
+        if is_hip():
+            out_f = torch.zeros(out_shape, device="cuda", dtype=torch.float)
+            ops.squeezellm_gemm(reshaped_x, qweight, out_f, lookup_table)
+            out = out_f.to(dtype=torch.float16)
+        else:
+            # NOTE: The output tensor should be zero-initialized.
+            out = torch.zeros(out_shape, device="cuda", dtype=torch.float16)
+            ops.squeezellm_gemm(reshaped_x, qweight, out, lookup_table)

        if bias is not None:
            out = out + bias
...
@@ -10,6 +10,10 @@ from vllm.config import ModelConfig
from vllm.model_executor.models import *
from vllm.model_executor.weight_utils import (get_quant_config,
                                              initialize_dummy_weights)
+from vllm.utils import is_hip
+from vllm.logger import init_logger
+
+logger = init_logger(__name__)

# TODO(woosuk): Lazy-load the model classes.
_MODEL_REGISTRY = {

@@ -39,6 +43,18 @@ _MODEL_REGISTRY = {
    "YiForCausalLM": YiForCausalLM,
}

+# Models to be disabled in ROCm
+_ROCM_UNSUPPORTED_MODELS = []
+if is_hip():
+    for rocm_model in _ROCM_UNSUPPORTED_MODELS:
+        del _MODEL_REGISTRY[rocm_model]
+
+# Models partially supported in ROCm
+_ROCM_PARTIALLY_SUPPORTED_MODELS = {
+    "MistralForCausalLM":
+    "Sliding window attention is not supported in ROCm's flash attention",
+}
+

@contextlib.contextmanager
def _set_default_torch_dtype(dtype: torch.dtype):

@@ -53,7 +69,15 @@ def _get_model_architecture(config: PretrainedConfig) -> Type[nn.Module]:
    architectures = getattr(config, "architectures", [])
    for arch in architectures:
        if arch in _MODEL_REGISTRY:
+            if is_hip() and arch in _ROCM_PARTIALLY_SUPPORTED_MODELS:
+                logger.warning(
+                    f"{arch} is not fully supported in ROCm. Reason: "
+                    f"{_ROCM_PARTIALLY_SUPPORTED_MODELS[arch]}")
            return _MODEL_REGISTRY[arch]
+        elif arch in _ROCM_UNSUPPORTED_MODELS:
+            raise ValueError(
+                f"Model architecture {arch} is not supported by ROCm for now. \n"
+                f"Supported architectures {list(_MODEL_REGISTRY.keys())}")
    raise ValueError(
        f"Model architectures {architectures} are not supported for now. "
        f"Supported architectures: {list(_MODEL_REGISTRY.keys())}")
...
@@ -27,10 +27,14 @@ class Counter:
        self.counter = 0


+def is_hip() -> bool:
+    return torch.version.hip is not None
+
+
def get_max_shared_memory_bytes(gpu: int = 0) -> int:
    """Returns the maximum shared memory per thread block in bytes."""
    # https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__TYPES.html
-    cudaDevAttrMaxSharedMemoryPerBlockOptin = 97
+    cudaDevAttrMaxSharedMemoryPerBlockOptin = 97 if not is_hip() else 74
    max_shared_mem = cuda_utils.get_device_attribute(
        cudaDevAttrMaxSharedMemoryPerBlockOptin, gpu)
    return int(max_shared_mem)
...