Unverified commit 6ccc0bff, authored by TJian, committed by GitHub
Browse files

Merge EmbeddedLLM/vllm-rocm into vLLM main (#1836)


Co-authored-by: Philipp Moritz <pcmoritz@gmail.com>
Co-authored-by: Amir Balwel <amoooori04@gmail.com>
Co-authored-by: root <kuanfu.liu@akirakan.com>
Co-authored-by: tjtanaa <tunjian.tan@embeddedllm.com>
Co-authored-by: kuanfu <kuanfu.liu@embeddedllm.com>
Co-authored-by: miloice <17350011+kliuae@users.noreply.github.com>
parent c8e7eb1e
--- /opt/conda/envs/py_3.10/lib/python3.10/site-packages/xformers/ops/fmha/common.py 2023-11-29 03:17:03.930103539 +0000
+++ common.py 2023-11-28 16:14:19.846233146 +0000
@@ -298,8 +298,8 @@
dtype = d.query.dtype
if device_type not in cls.SUPPORTED_DEVICES:
reasons.append(f"device={device_type} (supported: {cls.SUPPORTED_DEVICES})")
- if device_type == "cuda" and not _built_with_cuda:
- reasons.append("xFormers wasn't build with CUDA support")
+ #if device_type == "cuda" and not _built_with_cuda:
+ # reasons.append("xFormers wasn't build with CUDA support")
if device_type == "cuda":
device_capability = torch.cuda.get_device_capability(d.device)
if device_capability < cls.CUDA_MINIMUM_COMPUTE_CAPABILITY:
--- /opt/conda/envs/py_3.10/lib/python3.10/site-packages/xformers/ops/fmha/flash.py 2023-11-29 03:17:03.930103539 +0000
+++ flash.py 2023-11-28 16:14:25.206128903 +0000
@@ -31,39 +31,39 @@
FLASH_VERSION = "0.0.0"
try:
- try:
- from ... import _C_flashattention # type: ignore[attr-defined]
- from ..._cpp_lib import _build_metadata
-
- if _build_metadata is not None:
- FLASH_VERSION = _build_metadata.flash_version
- except ImportError:
- import flash_attn
- from flash_attn.flash_attn_interface import flash_attn_cuda as _C_flashattention
-
- FLASH_VERSION = flash_attn.__version__
- flash_ver_parsed = tuple(int(s) for s in FLASH_VERSION.split(".")[:2])
- if flash_ver_parsed < (2, 3):
- raise ImportError("Requires 2.3 for sliding window support")
+ #try:
+ # from ... import _C_flashattention # type: ignore[attr-defined]
+ # from ..._cpp_lib import _build_metadata
+
+ # if _build_metadata is not None:
+ # FLASH_VERSION = _build_metadata.flash_version
+ #except ImportError:
+ import flash_attn
+ from flash_attn.flash_attn_interface import flash_attn_cuda as _C_flashattention
+
+ FLASH_VERSION = flash_attn.__version__
+ # flash_ver_parsed = tuple(int(s) for s in FLASH_VERSION.split(".")[:2])
+ # if flash_ver_parsed < (2, 3):
+ # raise ImportError("Requires 2.3 for sliding window support")
# create library so that flash-attn goes through the PyTorch Dispatcher
- _flash_lib = torch.library.Library("xformers_flash", "DEF")
+ #_flash_lib = torch.library.Library("xformers_flash", "DEF")
- _flash_lib.define(
- "flash_fwd(Tensor query, Tensor key, Tensor value, "
- "Tensor? cu_seqlens_q, Tensor? cu_seqlens_k, "
- "int max_seqlen_q, int max_seqlen_k, "
- "float p, float softmax_scale, "
- "bool is_causal, int window_size, bool return_softmax) -> (Tensor, Tensor, Tensor)"
- )
-
- _flash_lib.define(
- "flash_bwd(Tensor dout, Tensor query, Tensor key, Tensor value, "
- "Tensor out, Tensor softmax_lse_, Tensor dq, Tensor dk, Tensor dv, "
- "Tensor cu_seqlens_q, Tensor cu_seqlens_k, "
- "int max_seqlen_q, int max_seqlen_k, "
- "float p, float softmax_scale, bool is_causal, int window_size, Tensor rng_state) -> (Tensor, Tensor, Tensor)"
- )
+ #_flash_lib.define(
+ # "flash_fwd(Tensor query, Tensor key, Tensor value, "
+ # "Tensor? cu_seqlens_q, Tensor? cu_seqlens_k, "
+ # "int max_seqlen_q, int max_seqlen_k, "
+ # "float p, float softmax_scale, "
+ # "bool is_causal, int window_size, bool return_softmax) -> (Tensor, Tensor, Tensor)"
+ #)
+
+ #_flash_lib.define(
+ # "flash_bwd(Tensor dout, Tensor query, Tensor key, Tensor value, "
+ # "Tensor out, Tensor softmax_lse_, Tensor dq, Tensor dk, Tensor dv, "
+ # "Tensor cu_seqlens_q, Tensor cu_seqlens_k, "
+ # "int max_seqlen_q, int max_seqlen_k, "
+ # "float p, float softmax_scale, bool is_causal, int window_size, Tensor rng_state) -> (Tensor, Tensor, Tensor)"
+ #)
def _flash_fwd(
query,
@@ -98,8 +98,8 @@
p,
softmax_scale,
is_causal,
- window_size - 1, # window_size_left
- -1, # window_size_right
+ # window_size - 1, # window_size_left
+ # -1, # window_size_right
return_softmax,
None, # rng
)
@@ -127,8 +127,8 @@
softmax_scale,
False,
is_causal,
- window_size - 1, # window_size_left
- -1, # window_size_right
+ # window_size - 1, # window_size_left
+ # -1, # window_size_right
return_softmax,
None,
)
@@ -169,8 +169,8 @@
p,
softmax_scale,
is_causal,
- window_size - 1, # window_size_left
- -1, # window_size_right
+ # window_size - 1, # window_size_left
+ # -1, # window_size_right
None,
rng_state,
)
@@ -193,15 +193,15 @@
softmax_scale,
False, # zero_tensors
is_causal,
- window_size - 1, # window_size_left
- -1, # window_size_right
+ # window_size - 1, # window_size_left
+ # -1, # window_size_right
None,
rng_state,
)
return dq, dk, dv
- _flash_lib.impl("flash_fwd", _flash_fwd, "CUDA")
- _flash_lib.impl("flash_bwd", _flash_bwd, "CUDA")
+ #_flash_lib.impl("flash_fwd", _flash_fwd, "CUDA")
+ #_flash_lib.impl("flash_bwd", _flash_bwd, "CUDA")
except ImportError:
pass
@@ -348,7 +348,7 @@
implementation.
"""
- OPERATOR = get_operator("xformers_flash", "flash_fwd")
+ OPERATOR = _flash_fwd # get_operator("xformers_flash", "flash_fwd")
SUPPORTED_DEVICES: Set[str] = {"cuda"}
CUDA_MINIMUM_COMPUTE_CAPABILITY = (8, 0)
SUPPORTED_DTYPES: Set[torch.dtype] = {torch.half, torch.bfloat16}
......@@ -8,27 +8,83 @@ import warnings
from packaging.version import parse, Version
import setuptools
import torch
from torch.utils.cpp_extension import BuildExtension, CUDAExtension, CUDA_HOME
from torch.utils.cpp_extension import BuildExtension, CUDAExtension, CUDA_HOME, ROCM_HOME
ROOT_DIR = os.path.dirname(__file__)
MAIN_CUDA_VERSION = "12.1"
# Supported NVIDIA GPU architectures.
SUPPORTED_ARCHS = {"7.0", "7.5", "8.0", "8.6", "8.9", "9.0"}
NVIDIA_SUPPORTED_ARCHS = {"7.0", "7.5", "8.0", "8.6", "8.9", "9.0"}
ROCM_SUPPORTED_ARCHS = {"gfx90a", "gfx908", "gfx906", "gfx1030", "gfx1100"}
# SUPPORTED_ARCHS = NVIDIA_SUPPORTED_ARCHS.union(ROCM_SUPPORTED_ARCHS)
def _is_hip() -> bool:
    """Tell whether the installed PyTorch reports a HIP version.

    Returns:
        True when ``torch.version.hip`` is set, False when it is ``None``.
    """
    hip_version = torch.version.hip
    return hip_version is not None
def _is_cuda() -> bool:
    """Tell whether the installed PyTorch reports a CUDA version.

    Returns:
        True when ``torch.version.cuda`` is set, False when it is ``None``.
    """
    cuda_version = torch.version.cuda
    return cuda_version is not None
# Compiler flags.
CXX_FLAGS = ["-g", "-O2", "-std=c++17"]
# TODO(woosuk): Should we use -O3?
NVCC_FLAGS = ["-O2", "-std=c++17"]
if _is_hip():
if ROCM_HOME is None:
raise RuntimeError(
"Cannot find ROCM_HOME. ROCm must be available to build the package."
)
NVCC_FLAGS += ["-DUSE_ROCM"]
if _is_cuda() and CUDA_HOME is None:
raise RuntimeError(
"Cannot find CUDA_HOME. CUDA must be available to build the package.")
ABI = 1 if torch._C._GLIBCXX_USE_CXX11_ABI else 0
CXX_FLAGS += [f"-D_GLIBCXX_USE_CXX11_ABI={ABI}"]
NVCC_FLAGS += [f"-D_GLIBCXX_USE_CXX11_ABI={ABI}"]
if CUDA_HOME is None:
raise RuntimeError(
"Cannot find CUDA_HOME. CUDA must be available to build the package.")
def get_amdgpu_offload_arch():
    """Return the output of the ``amdgpu-offload-arch`` tool.

    Runs ``/opt/rocm/llvm/bin/amdgpu-offload-arch`` with no arguments and
    returns its standard output, decoded as UTF-8 with surrounding
    whitespace stripped.

    Raises:
        RuntimeError: If the tool exits with a non-zero status, or if the
            executable does not exist at the hard-coded path. The original
            exception is chained as the cause.
    """
    command = "/opt/rocm/llvm/bin/amdgpu-offload-arch"
    try:
        output = subprocess.check_output([command])
    except subprocess.CalledProcessError as e:
        error_message = f"Error: {e}"
        raise RuntimeError(error_message) from e
    except FileNotFoundError as e:
        # The executable is missing: surface it as a build error.
        error_message = f"The command {command} was not found."
        raise RuntimeError(error_message) from e
    # Every failure path above raises, so no fallback return is needed.
    return output.decode('utf-8').strip()
def get_hipcc_rocm_version():
    """Return the HIP version reported by ``hipcc --version``.

    Runs ``hipcc --version`` with stderr merged into stdout and searches the
    combined output for ``HIP version: <token>``.

    Returns:
        The non-whitespace token following ``HIP version:`` as a string, or
        ``None`` when ``hipcc`` cannot be launched, exits with a non-zero
        status, or prints no matching line. A short diagnostic is printed
        in each failure case.
    """
    try:
        result = subprocess.run(['hipcc', '--version'],
                                stdout=subprocess.PIPE,
                                stderr=subprocess.STDOUT,
                                text=True)
    except FileNotFoundError:
        # hipcc is not on PATH; report it like the other failure modes
        # instead of letting the exception escape.
        print("Error running 'hipcc --version'")
        return None

    # A non-zero exit status means the version output cannot be trusted.
    if result.returncode != 0:
        print("Error running 'hipcc --version'")
        return None

    # Extract the version using a regular expression.
    match = re.search(r'HIP version: (\S+)', result.stdout)
    if match is None:
        print("Could not find HIP version in the output")
        return None
    return match.group(1)
def get_nvcc_cuda_version(cuda_dir: str) -> Version:
......@@ -61,20 +117,22 @@ def get_torch_arch_list() -> Set[str]:
return set()
# Filter out the invalid architectures and print a warning.
valid_archs = SUPPORTED_ARCHS.union({s + "+PTX" for s in SUPPORTED_ARCHS})
valid_archs = NVIDIA_SUPPORTED_ARCHS.union(
{s + "+PTX"
for s in NVIDIA_SUPPORTED_ARCHS})
arch_list = torch_arch_list.intersection(valid_archs)
# If none of the specified architectures are valid, raise an error.
if not arch_list:
raise RuntimeError(
"None of the CUDA architectures in `TORCH_CUDA_ARCH_LIST` env "
"None of the CUDA/ROCM architectures in `TORCH_CUDA_ARCH_LIST` env "
f"variable ({env_arch_list}) is supported. "
f"Supported CUDA architectures are: {valid_archs}.")
f"Supported CUDA/ROCM architectures are: {valid_archs}.")
invalid_arch_list = torch_arch_list - valid_archs
if invalid_arch_list:
warnings.warn(
f"Unsupported CUDA architectures ({invalid_arch_list}) are "
f"Unsupported CUDA/ROCM architectures ({invalid_arch_list}) are "
"excluded from the `TORCH_CUDA_ARCH_LIST` env variable "
f"({env_arch_list}). Supported CUDA architectures are: "
f"({env_arch_list}). Supported CUDA/ROCM architectures are: "
f"{valid_archs}.",
stacklevel=2)
return arch_list
......@@ -82,7 +140,7 @@ def get_torch_arch_list() -> Set[str]:
# First, check the TORCH_CUDA_ARCH_LIST environment variable.
compute_capabilities = get_torch_arch_list()
if not compute_capabilities:
if _is_cuda() and not compute_capabilities:
# If TORCH_CUDA_ARCH_LIST is not defined or empty, target all available
# GPUs on the current machine.
device_count = torch.cuda.device_count()
......@@ -93,69 +151,84 @@ if not compute_capabilities:
"GPUs with compute capability below 7.0 are not supported.")
compute_capabilities.add(f"{major}.{minor}")
nvcc_cuda_version = get_nvcc_cuda_version(CUDA_HOME)
if not compute_capabilities:
# If no GPU is specified nor available, add all supported architectures
# based on the NVCC CUDA version.
compute_capabilities = SUPPORTED_ARCHS.copy()
if nvcc_cuda_version < Version("11.1"):
compute_capabilities.remove("8.6")
if _is_cuda():
nvcc_cuda_version = get_nvcc_cuda_version(CUDA_HOME)
if not compute_capabilities:
# If no GPU is specified nor available, add all supported architectures
# based on the NVCC CUDA version.
compute_capabilities = NVIDIA_SUPPORTED_ARCHS.copy()
if nvcc_cuda_version < Version("11.1"):
compute_capabilities.remove("8.6")
if nvcc_cuda_version < Version("11.8"):
compute_capabilities.remove("8.9")
compute_capabilities.remove("9.0")
# Validate the NVCC CUDA version.
if nvcc_cuda_version < Version("11.0"):
raise RuntimeError(
"CUDA 11.0 or higher is required to build the package.")
if (nvcc_cuda_version < Version("11.1")
and any(cc.startswith("8.6") for cc in compute_capabilities)):
raise RuntimeError(
"CUDA 11.1 or higher is required for compute capability 8.6.")
if nvcc_cuda_version < Version("11.8"):
compute_capabilities.remove("8.9")
compute_capabilities.remove("9.0")
# Validate the NVCC CUDA version.
if nvcc_cuda_version < Version("11.0"):
raise RuntimeError("CUDA 11.0 or higher is required to build the package.")
if (nvcc_cuda_version < Version("11.1")
and any(cc.startswith("8.6") for cc in compute_capabilities)):
raise RuntimeError(
"CUDA 11.1 or higher is required for compute capability 8.6.")
if nvcc_cuda_version < Version("11.8"):
if any(cc.startswith("8.9") for cc in compute_capabilities):
# CUDA 11.8 is required to generate the code targeting compute capability 8.9.
# However, GPUs with compute capability 8.9 can also run the code generated by
# the previous versions of CUDA 11 and targeting compute capability 8.0.
# Therefore, if CUDA 11.8 is not available, we target compute capability 8.0
# instead of 8.9.
warnings.warn(
"CUDA 11.8 or higher is required for compute capability 8.9. "
"Targeting compute capability 8.0 instead.",
stacklevel=2)
compute_capabilities = set(cc for cc in compute_capabilities
if not cc.startswith("8.9"))
compute_capabilities.add("8.0+PTX")
if any(cc.startswith("9.0") for cc in compute_capabilities):
if any(cc.startswith("8.9") for cc in compute_capabilities):
# CUDA 11.8 is required to generate the code targeting compute capability 8.9.
# However, GPUs with compute capability 8.9 can also run the code generated by
# the previous versions of CUDA 11 and targeting compute capability 8.0.
# Therefore, if CUDA 11.8 is not available, we target compute capability 8.0
# instead of 8.9.
warnings.warn(
"CUDA 11.8 or higher is required for compute capability 8.9. "
"Targeting compute capability 8.0 instead.",
stacklevel=2)
compute_capabilities = set(cc for cc in compute_capabilities
if not cc.startswith("8.9"))
compute_capabilities.add("8.0+PTX")
if any(cc.startswith("9.0") for cc in compute_capabilities):
raise RuntimeError(
"CUDA 11.8 or higher is required for compute capability 9.0.")
# Add target compute capabilities to NVCC flags.
for capability in compute_capabilities:
num = capability[0] + capability[2]
NVCC_FLAGS += ["-gencode", f"arch=compute_{num},code=sm_{num}"]
if capability.endswith("+PTX"):
NVCC_FLAGS += [
"-gencode", f"arch=compute_{num},code=compute_{num}"
]
# Use NVCC threads to parallelize the build.
if nvcc_cuda_version >= Version("11.2"):
nvcc_threads = int(os.getenv("NVCC_THREADS", 8))
num_threads = min(os.cpu_count(), nvcc_threads)
NVCC_FLAGS += ["--threads", str(num_threads)]
elif _is_hip():
amd_arch = get_amdgpu_offload_arch()
if amd_arch not in ROCM_SUPPORTED_ARCHS:
raise RuntimeError(
"CUDA 11.8 or higher is required for compute capability 9.0.")
f"Only the following arch is supported: {ROCM_SUPPORTED_ARCHS}"
f"amdgpu_arch_found: {amd_arch}")
ext_modules = []
# Add target compute capabilities to NVCC flags.
for capability in compute_capabilities:
num = capability[0] + capability[2]
NVCC_FLAGS += ["-gencode", f"arch=compute_{num},code=sm_{num}"]
if capability.endswith("+PTX"):
NVCC_FLAGS += ["-gencode", f"arch=compute_{num},code=compute_{num}"]
vllm_extension_sources = [
"csrc/cache_kernels.cu",
"csrc/attention/attention_kernels.cu",
"csrc/pos_encoding_kernels.cu",
"csrc/activation_kernels.cu",
"csrc/layernorm_kernels.cu",
"csrc/quantization/squeezellm/quant_cuda_kernel.cu",
"csrc/cuda_utils_kernels.cu",
"csrc/pybind.cpp",
]
# Use NVCC threads to parallelize the build.
if nvcc_cuda_version >= Version("11.2"):
nvcc_threads = int(os.getenv("NVCC_THREADS", 8))
num_threads = min(os.cpu_count(), nvcc_threads)
NVCC_FLAGS += ["--threads", str(num_threads)]
if _is_cuda():
vllm_extension_sources.append("csrc/quantization/awq/gemm_kernels.cu")
ext_modules = []
vllm_extension = CUDAExtension(
name="vllm._C",
sources=[
"csrc/cache_kernels.cu",
"csrc/attention/attention_kernels.cu",
"csrc/pos_encoding_kernels.cu",
"csrc/activation_kernels.cu",
"csrc/layernorm_kernels.cu",
"csrc/quantization/awq/gemm_kernels.cu",
"csrc/quantization/squeezellm/quant_cuda_kernel.cu",
"csrc/cuda_utils_kernels.cu",
"csrc/pybind.cpp",
],
sources=vllm_extension_sources,
extra_compile_args={
"cxx": CXX_FLAGS,
"nvcc": NVCC_FLAGS,
......@@ -183,10 +256,19 @@ def find_version(filepath: str) -> str:
def get_vllm_version() -> str:
    """Build the package version string with a toolkit-specific suffix.

    The base version comes from ``find_version`` applied to
    ``vllm/__init__.py``. On HIP builds a ``+rocm<digits>`` suffix is added
    from the version reported by ``get_hipcc_rocm_version``; otherwise a
    ``+cu<digits>`` suffix is added from ``nvcc_cuda_version``. In both
    cases the suffix is omitted when the toolkit version string equals
    ``MAIN_CUDA_VERSION``.

    Raises:
        RuntimeError: On HIP builds when the HIP version cannot be
            determined.
    """
    version = find_version(get_path("vllm", "__init__.py"))

    if _is_hip():
        # Get the HIP version.
        hipcc_version = get_hipcc_rocm_version()
        if hipcc_version is None:
            # Fail with a clear build error rather than an AttributeError
            # on ``None.replace`` below.
            raise RuntimeError(
                "Cannot determine the HIP version from 'hipcc --version'.")
        # NOTE(review): the HIP version is compared against the CUDA main
        # version constant; confirm this is intended.
        if hipcc_version != MAIN_CUDA_VERSION:
            rocm_version_str = hipcc_version.replace(".", "")[:3]
            version += f"+rocm{rocm_version_str}"
    else:
        # Only read nvcc_cuda_version on non-HIP builds.
        cuda_version = str(nvcc_cuda_version)
        if cuda_version != MAIN_CUDA_VERSION:
            cuda_version_str = cuda_version.replace(".", "")[:3]
            version += f"+cu{cuda_version_str}"

    return version
......@@ -201,8 +283,12 @@ def read_readme() -> str:
def get_requirements() -> List[str]:
    """Get Python package dependencies for the current build target.

    HIP builds read ``requirements-rocm.txt``; all other builds read
    ``requirements.txt``. Only the selected file is opened.

    Returns:
        The file's stripped contents split on newlines, one entry per line.
    """
    requirements_file = ("requirements-rocm.txt"
                         if _is_hip() else "requirements.txt")
    with open(get_path(requirements_file)) as f:
        requirements = f.read().strip().split("\n")
    return requirements
......
......@@ -6,7 +6,7 @@ from transformers import PretrainedConfig
from vllm.logger import init_logger
from vllm.transformers_utils.config import get_config
from vllm.utils import get_cpu_memory
from vllm.utils import get_cpu_memory, is_hip
logger = init_logger(__name__)
......@@ -98,12 +98,27 @@ class ModelConfig:
def _verify_load_format(self) -> None:
    """Validate and normalise ``self.load_format``.

    The value is lower-cased and checked against the supported formats.
    On HIP builds, formats listed in ``rocm_not_supported_load_format``
    are rejected and ``"auto"`` is rewritten to ``"pt"``. The normalised
    value is stored back on ``self.load_format``.

    Raises:
        ValueError: If the format is unknown, or is not supported on ROCm.
    """
    load_format = self.load_format.lower()
    supported_load_format = [
        "auto", "pt", "safetensors", "npcache", "dummy"
    ]
    rocm_not_supported_load_format = ["safetensors"]
    if load_format not in supported_load_format:
        raise ValueError(
            f"Unknown load format: {self.load_format}. Must be one of "
            "'auto', 'pt', 'safetensors', 'npcache', or 'dummy'.")
    if is_hip():
        # Use the named list so the check and the message stay in sync.
        if load_format in rocm_not_supported_load_format:
            rocm_supported_load_format = [
                f for f in supported_load_format
                if (f not in rocm_not_supported_load_format)
            ]
            raise ValueError(
                f"load format \'{load_format}\' is not supported in ROCm. "
                f"Supported load format are "
                f"{rocm_supported_load_format}")
        # Force ROCm to load from pt weights if nothing specific is set.
        if load_format == "auto":
            load_format = "pt"
    self.load_format = load_format
def _verify_tokenizer_mode(self) -> None:
......@@ -116,6 +131,7 @@ class ModelConfig:
def _verify_quantization(self) -> None:
supported_quantization = ["awq", "squeezellm"]
rocm_not_supported_quantization = ["awq"]
if self.quantization is not None:
self.quantization = self.quantization.lower()
......@@ -137,6 +153,11 @@ class ModelConfig:
raise ValueError(
f"Unknown quantization method: {self.quantization}. Must "
f"be one of {supported_quantization}.")
if is_hip(
) and self.quantization in rocm_not_supported_quantization:
raise ValueError(
f"{self.quantization} quantization is currently not supported "
f"in ROCm.")
logger.warning(f"{self.quantization} quantization is not fully "
"optimized yet. The speed can be slower than "
"non-quantized models.")
......@@ -364,6 +385,8 @@ _STR_DTYPE_TO_TORCH_DTYPE = {
"bfloat16": torch.bfloat16,
}
_ROCM_NOT_SUPPORTED_DTYPE = ["float", "float32"]
def _get_and_verify_dtype(
config: PretrainedConfig,
......@@ -393,6 +416,14 @@ def _get_and_verify_dtype(
else:
raise ValueError(f"Unknown dtype: {dtype}")
if is_hip() and torch_dtype == torch.float32:
rocm_supported_dtypes = [
k for k, v in _STR_DTYPE_TO_TORCH_DTYPE.items()
if (k not in _ROCM_NOT_SUPPORTED_DTYPE)
]
raise ValueError(f"dtype \'{dtype}\' is not supported in ROCm. "
f"Supported dtypes are {rocm_supported_dtypes}")
# Verify the dtype.
if torch_dtype != config_dtype:
if torch_dtype == torch.float32:
......
......@@ -3,6 +3,7 @@ from typing import Optional, Tuple, TYPE_CHECKING
from vllm.config import ParallelConfig
from vllm.logger import init_logger
from vllm.utils import is_hip
logger = init_logger(__name__)
......@@ -73,7 +74,12 @@ def initialize_cluster(
"Ray is not installed. Please install Ray to use distributed "
"serving.")
# Connect to a ray cluster.
ray.init(address=ray_address, ignore_reinit_error=True)
if is_hip():
ray.init(address=ray_address,
ignore_reinit_error=True,
num_gpus=parallel_config.world_size)
else:
ray.init(address=ray_address, ignore_reinit_error=True)
if not parallel_config.worker_use_ray:
# Initialize cluster locally.
......
......@@ -10,6 +10,7 @@ from xformers.ops.fmha.attn_bias import (BlockDiagonalCausalMask,
from vllm._C import ops
from vllm._C import cache_ops
from vllm.model_executor.input_metadata import InputMetadata
from vllm.utils import is_hip
_SUPPORTED_HEAD_SIZES = [64, 80, 96, 112, 128, 256]
# Should be the same as PARTITION_SIZE in `paged_attention_v2_launcher`.
......@@ -160,6 +161,8 @@ class PagedAttention(nn.Module):
attn_bias=input_metadata.attn_bias,
p=0.0,
scale=self.scale,
op=xops.fmha.MemoryEfficientAttentionFlashAttentionOp[0] if
(is_hip()) else None,
)
output = out.view_as(query)
else:
......
......@@ -7,6 +7,7 @@ from vllm._C import ops
from vllm.model_executor.layers.linear import (LinearMethodBase,
set_weight_attrs)
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
from vllm.utils import is_hip
class SqueezeLLMConfig(QuantizationConfig):
......@@ -114,9 +115,14 @@ class SqueezeLLMLinearMethod(LinearMethodBase):
lookup_table = weights["lookup_table"]
out_shape = x.shape[:-1] + (qweight.shape[-1], )
reshaped_x = x.reshape(-1, x.shape[-1])
# NOTE: The output tensor should be zero-initialized.
out = torch.zeros(out_shape, device="cuda", dtype=torch.float16)
ops.squeezellm_gemm(reshaped_x, qweight, out, lookup_table)
if is_hip():
out_f = torch.zeros(out_shape, device="cuda", dtype=torch.float)
ops.squeezellm_gemm(reshaped_x, qweight, out_f, lookup_table)
out = out_f.to(dtype=torch.float16)
else:
# NOTE: The output tensor should be zero-initialized.
out = torch.zeros(out_shape, device="cuda", dtype=torch.float16)
ops.squeezellm_gemm(reshaped_x, qweight, out, lookup_table)
if bias is not None:
out = out + bias
......
......@@ -10,6 +10,10 @@ from vllm.config import ModelConfig
from vllm.model_executor.models import *
from vllm.model_executor.weight_utils import (get_quant_config,
initialize_dummy_weights)
from vllm.utils import is_hip
from vllm.logger import init_logger
logger = init_logger(__name__)
# TODO(woosuk): Lazy-load the model classes.
_MODEL_REGISTRY = {
......@@ -39,6 +43,18 @@ _MODEL_REGISTRY = {
"YiForCausalLM": YiForCausalLM,
}
# Models to be disabled in ROCm
_ROCM_UNSUPPORTED_MODELS = []
if is_hip():
for rocm_model in _ROCM_UNSUPPORTED_MODELS:
del _MODEL_REGISTRY[rocm_model]
# Models partially supported in ROCm
_ROCM_PARTIALLY_SUPPORTED_MODELS = {
"MistralForCausalLM":
"Sliding window attention is not supported in ROCm's flash attention",
}
@contextlib.contextmanager
def _set_default_torch_dtype(dtype: torch.dtype):
......@@ -53,7 +69,15 @@ def _get_model_architecture(config: PretrainedConfig) -> Type[nn.Module]:
architectures = getattr(config, "architectures", [])
for arch in architectures:
if arch in _MODEL_REGISTRY:
if is_hip() and arch in _ROCM_PARTIALLY_SUPPORTED_MODELS:
logger.warning(
f"{arch} is not fully supported in ROCm. Reason: "
f"{_ROCM_PARTIALLY_SUPPORTED_MODELS[arch]}")
return _MODEL_REGISTRY[arch]
elif arch in _ROCM_UNSUPPORTED_MODELS:
raise ValueError(
f"Model architecture {arch} is not supported by ROCm for now. \n"
f"Supported architectures {list(_MODEL_REGISTRY.keys())}")
raise ValueError(
f"Model architectures {architectures} are not supported for now. "
f"Supported architectures: {list(_MODEL_REGISTRY.keys())}")
......
......@@ -27,10 +27,14 @@ class Counter:
self.counter = 0
def is_hip() -> bool:
    """Tell whether the installed PyTorch reports a HIP version.

    Returns:
        True when ``torch.version.hip`` is set, False when it is ``None``.
    """
    hip_version = torch.version.hip
    return hip_version is not None
def get_max_shared_memory_bytes(gpu: int = 0) -> int:
    """Returns the maximum shared memory per thread block in bytes.

    Args:
        gpu: Device index passed to ``cuda_utils.get_device_attribute``.

    Returns:
        The queried attribute value converted to ``int``.
    """
    # https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__TYPES.html
    # The attribute id is 97 on non-HIP builds and 74 on HIP builds.
    # NOTE(review): 74 is presumably the HIP enum value for the equivalent
    # attribute; confirm against the HIP runtime headers.
    cudaDevAttrMaxSharedMemoryPerBlockOptin = 97 if not is_hip() else 74
    max_shared_mem = cuda_utils.get_device_attribute(
        cudaDevAttrMaxSharedMemoryPerBlockOptin, gpu)
    return int(max_shared_mem)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment