Unverified Commit 482045ee authored by youkaichao's avatar youkaichao Committed by GitHub
Browse files

[hardware][misc] introduce platform abstraction (#6080)

parent 9d6a8daa
...@@ -8,13 +8,13 @@ import pytest ...@@ -8,13 +8,13 @@ import pytest
import torch import torch
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
from vllm.utils import get_device_capability_stateless from vllm.platforms import current_platform
CUDA_DEVICES = [ CUDA_DEVICES = [
f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2) f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)
] ]
capability = get_device_capability_stateless() capability = current_platform.get_device_capability()
capability = capability[0] * 10 + capability[1] capability = capability[0] * 10 + capability[1]
......
import torch import torch
from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
from vllm.utils import get_device_capability_stateless from vllm.platforms import current_platform
def is_quant_method_supported(quant_method: str) -> bool: def is_quant_method_supported(quant_method: str) -> bool:
...@@ -9,7 +9,7 @@ def is_quant_method_supported(quant_method: str) -> bool: ...@@ -9,7 +9,7 @@ def is_quant_method_supported(quant_method: str) -> bool:
if not torch.cuda.is_available(): if not torch.cuda.is_available():
return False return False
capability = get_device_capability_stateless() capability = current_platform.get_device_capability()
capability = capability[0] * 10 + capability[1] capability = capability[0] * 10 + capability[1]
return (capability >= return (capability >=
QUANTIZATION_METHODS[quant_method].get_min_capability()) QUANTIZATION_METHODS[quant_method].get_min_capability())
...@@ -2,13 +2,14 @@ import math ...@@ -2,13 +2,14 @@ import math
import torch import torch
from vllm.utils import get_device_capability_stateless, is_cpu, is_hip from vllm.platforms import current_platform
from vllm.utils import is_cpu, is_hip
from .utils import (dense_to_crow_col, get_head_sliding_step, from .utils import (dense_to_crow_col, get_head_sliding_step,
get_sparse_attn_mask) get_sparse_attn_mask)
IS_COMPUTE_8_OR_ABOVE = (torch.cuda.is_available() IS_COMPUTE_8_OR_ABOVE = (torch.cuda.is_available()
and get_device_capability_stateless()[0] >= 8) and current_platform.get_device_capability()[0] >= 8)
if IS_COMPUTE_8_OR_ABOVE: if IS_COMPUTE_8_OR_ABOVE:
from .blocksparse_attention_kernel import blocksparse_flash_attn_varlen_fwd from .blocksparse_attention_kernel import blocksparse_flash_attn_varlen_fwd
......
...@@ -5,7 +5,7 @@ import torch ...@@ -5,7 +5,7 @@ import torch
import triton import triton
import triton.language as tl import triton.language as tl
from vllm.utils import get_device_capability_stateless from vllm.platforms import current_platform
if triton.__version__ >= "2.1.0": if triton.__version__ >= "2.1.0":
...@@ -685,7 +685,7 @@ if triton.__version__ >= "2.1.0": ...@@ -685,7 +685,7 @@ if triton.__version__ >= "2.1.0":
alibi_slopes=None, alibi_slopes=None,
sliding_window=None): sliding_window=None):
cap = get_device_capability_stateless() cap = current_platform.get_device_capability()
BLOCK = 128 if cap[0] >= 8 else 64 BLOCK = 128 if cap[0] >= 8 else 64
# shape constraints # shape constraints
Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1] Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1]
......
...@@ -5,14 +5,14 @@ from typing import Optional ...@@ -5,14 +5,14 @@ from typing import Optional
import torch import torch
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
from vllm.utils import get_device_capability_stateless from vllm.platforms import current_platform
def _check_punica_support(): def _check_punica_support():
if ops.is_custom_op_supported("_punica_C::dispatch_bgmv"): if ops.is_custom_op_supported("_punica_C::dispatch_bgmv"):
return return
if get_device_capability_stateless() < (8, 0): if current_platform.get_device_capability() < (8, 0):
raise ImportError( raise ImportError(
"punica LoRA kernels require compute capability >= 8.0") "punica LoRA kernels require compute capability >= 8.0")
else: else:
......
...@@ -14,7 +14,7 @@ from vllm.model_executor.layers.quantization.compressed_tensors.schemes import ( ...@@ -14,7 +14,7 @@ from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
from vllm.model_executor.layers.quantization.compressed_tensors.utils import ( from vllm.model_executor.layers.quantization.compressed_tensors.utils import (
CompressionFormat, QuantizationArgs, QuantizationStrategy, CompressionFormat, QuantizationArgs, QuantizationStrategy,
find_first_name_or_class_match) find_first_name_or_class_match)
from vllm.utils import get_device_capability_stateless from vllm.platforms import current_platform
class CompressedTensorsConfig(QuantizationConfig): class CompressedTensorsConfig(QuantizationConfig):
...@@ -85,7 +85,7 @@ class CompressedTensorsConfig(QuantizationConfig): ...@@ -85,7 +85,7 @@ class CompressedTensorsConfig(QuantizationConfig):
return [] return []
def _check_gptq_and_marlin_can_run(self): def _check_gptq_and_marlin_can_run(self):
capability = get_device_capability_stateless() capability = current_platform.get_device_capability()
capability = capability[0] * 10 + capability[1] capability = capability[0] * 10 + capability[1]
if capability < 80: if capability < 80:
raise RuntimeError("The quantization config is not supported for ", raise RuntimeError("The quantization config is not supported for ",
......
...@@ -12,7 +12,8 @@ from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase ...@@ -12,7 +12,8 @@ from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
from vllm.model_executor.layers.quantization.base_config import ( from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig, QuantizeMethodBase) QuantizationConfig, QuantizeMethodBase)
from vllm.model_executor.utils import set_weight_attrs from vllm.model_executor.utils import set_weight_attrs
from vllm.utils import get_device_capability_stateless, print_warning_once from vllm.platforms import current_platform
from vllm.utils import print_warning_once
ACTIVATION_SCHEMES = ["static", "dynamic"] ACTIVATION_SCHEMES = ["static", "dynamic"]
...@@ -20,7 +21,7 @@ logger = init_logger(__name__) ...@@ -20,7 +21,7 @@ logger = init_logger(__name__)
def cutlass_fp8_supported() -> bool: def cutlass_fp8_supported() -> bool:
capability = get_device_capability_stateless() capability = current_platform.get_device_capability()
capability = capability[0] * 10 + capability[1] capability = capability[0] * 10 + capability[1]
return ops.cutlass_scaled_mm_supports_fp8(capability) return ops.cutlass_scaled_mm_supports_fp8(capability)
......
...@@ -12,7 +12,7 @@ from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase, ...@@ -12,7 +12,7 @@ from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
from vllm.model_executor.layers.quantization.base_config import ( from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig) QuantizationConfig)
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
from vllm.utils import get_device_capability_stateless from vllm.platforms import current_platform
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -173,7 +173,7 @@ class GPTQMarlinConfig(QuantizationConfig): ...@@ -173,7 +173,7 @@ class GPTQMarlinConfig(QuantizationConfig):
return False return False
# If the capability of the device is too low, cannot convert. # If the capability of the device is too low, cannot convert.
major, minor = get_device_capability_stateless() major, minor = current_platform.get_device_capability()
device_capability = major * 10 + minor device_capability = major * 10 + minor
if device_capability < cls.get_min_capability(): if device_capability < cls.get_min_capability():
return False return False
......
...@@ -12,9 +12,9 @@ from vllm.model_executor.layers.quantization.utils.marlin_perms import ( ...@@ -12,9 +12,9 @@ from vllm.model_executor.layers.quantization.utils.marlin_perms import (
marlin_perm, marlin_scale_perm, marlin_scale_perm_single) marlin_perm, marlin_scale_perm, marlin_scale_perm_single)
from vllm.model_executor.layers.quantization.utils.quant_utils import ( from vllm.model_executor.layers.quantization.utils.quant_utils import (
get_pack_factor, quantize_weights, sort_weights) get_pack_factor, quantize_weights, sort_weights)
from vllm.utils import get_device_capability_stateless from vllm.platforms import current_platform
__cuda_arch = get_device_capability_stateless() __cuda_arch = current_platform.get_device_capability()
MARLIN_TILE = 16 MARLIN_TILE = 16
......
...@@ -35,7 +35,8 @@ from vllm.model_executor.model_loader.weight_utils import ( ...@@ -35,7 +35,8 @@ from vllm.model_executor.model_loader.weight_utils import (
from vllm.model_executor.models.interfaces import (supports_lora, from vllm.model_executor.models.interfaces import (supports_lora,
supports_vision) supports_vision)
from vllm.model_executor.utils import set_weight_attrs from vllm.model_executor.utils import set_weight_attrs
from vllm.utils import get_device_capability_stateless, is_tpu from vllm.platforms import current_platform
from vllm.utils import is_tpu
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -46,7 +47,7 @@ def _get_quantization_config( ...@@ -46,7 +47,7 @@ def _get_quantization_config(
"""Get the quantization config.""" """Get the quantization config."""
if model_config.quantization is not None: if model_config.quantization is not None:
quant_config = get_quant_config(model_config, load_config) quant_config = get_quant_config(model_config, load_config)
capability = get_device_capability_stateless() capability = current_platform.get_device_capability()
capability = capability[0] * 10 + capability[1] capability = capability[0] * 10 + capability[1]
if capability < quant_config.get_min_capability(): if capability < quant_config.get_min_capability():
raise ValueError( raise ValueError(
......
from typing import Optional
import torch
from .interface import Platform, PlatformEnum
current_platform: Optional[Platform]
if torch.version.cuda is not None:
from .cuda import CudaPlatform
current_platform = CudaPlatform()
elif torch.version.hip is not None:
from .rocm import RocmPlatform
current_platform = RocmPlatform()
else:
current_platform = None
__all__ = ['Platform', 'PlatformEnum', 'current_platform']
"""Code inside this file can safely assume cuda platform, e.g. importing
pynvml. However, it should not initialize cuda context.
"""
from functools import lru_cache, wraps
from typing import Tuple
import pynvml
from .interface import Platform, PlatformEnum
def with_nvml_context(fn):
@wraps(fn)
def wrapper(*args, **kwargs):
pynvml.nvmlInit()
try:
return fn(*args, **kwargs)
finally:
pynvml.nvmlShutdown()
return wrapper
class CudaPlatform(Platform):
_enum = PlatformEnum.CUDA
@staticmethod
@lru_cache(maxsize=8)
@with_nvml_context
def get_device_capability(device_id: int = 0) -> Tuple[int, int]:
handle = pynvml.nvmlDeviceGetHandleByIndex(device_id)
return pynvml.nvmlDeviceGetCudaComputeCapability(handle)
import enum
from typing import Tuple
class PlatformEnum(enum.Enum):
CUDA = enum.auto()
ROCM = enum.auto()
class Platform:
_enum: PlatformEnum
def is_cuda(self) -> bool:
return self._enum == PlatformEnum.CUDA
def is_rocm(self) -> bool:
return self._enum == PlatformEnum.ROCM
@staticmethod
def get_device_capability(device_id: int = 0) -> Tuple[int, int]:
raise NotImplementedError
from functools import lru_cache
from typing import Tuple
import torch
from .interface import Platform, PlatformEnum
class RocmPlatform(Platform):
_enum = PlatformEnum.ROCM
@staticmethod
@lru_cache(maxsize=8)
def get_device_capability(device_id: int = 0) -> Tuple[int, int]:
return torch.cuda.get_device_capability(device_id)
...@@ -866,13 +866,6 @@ def is_full_nvlink(device_ids: List[int]) -> bool: ...@@ -866,13 +866,6 @@ def is_full_nvlink(device_ids: List[int]) -> bool:
return True return True
@lru_cache(maxsize=8)
@with_nvml_context
def get_device_capability_stateless(device_id: int = 0) -> Tuple[int, int]:
handle = pynvml.nvmlDeviceGetHandleByIndex(device_id)
return pynvml.nvmlDeviceGetCudaComputeCapability(handle)
#From: https://stackoverflow.com/a/4104188/2749989 #From: https://stackoverflow.com/a/4104188/2749989
def run_once(f): def run_once(f):
......
...@@ -15,8 +15,8 @@ from vllm.distributed import (ensure_model_parallel_initialized, ...@@ -15,8 +15,8 @@ from vllm.distributed import (ensure_model_parallel_initialized,
from vllm.lora.request import LoRARequest from vllm.lora.request import LoRARequest
from vllm.model_executor import set_random_seed from vllm.model_executor import set_random_seed
from vllm.model_executor.model_loader.tensorizer import TensorizerConfig from vllm.model_executor.model_loader.tensorizer import TensorizerConfig
from vllm.platforms import current_platform
from vllm.sequence import ExecuteModelRequest from vllm.sequence import ExecuteModelRequest
from vllm.utils import get_device_capability_stateless
from vllm.worker.cache_engine import CacheEngine from vllm.worker.cache_engine import CacheEngine
from vllm.worker.embedding_model_runner import EmbeddingModelRunner from vllm.worker.embedding_model_runner import EmbeddingModelRunner
from vllm.worker.model_runner import GPUModelRunnerBase, ModelRunner from vllm.worker.model_runner import GPUModelRunnerBase, ModelRunner
...@@ -333,7 +333,7 @@ def init_worker_distributed_environment( ...@@ -333,7 +333,7 @@ def init_worker_distributed_environment(
def _check_if_gpu_supports_dtype(torch_dtype: torch.dtype): def _check_if_gpu_supports_dtype(torch_dtype: torch.dtype):
# Check if the GPU supports the dtype. # Check if the GPU supports the dtype.
if torch_dtype == torch.bfloat16: if torch_dtype == torch.bfloat16:
compute_capability = get_device_capability_stateless() compute_capability = current_platform.get_device_capability()
if compute_capability[0] < 8: if compute_capability[0] < 8:
gpu_name = torch.cuda.get_device_name() gpu_name = torch.cuda.get_device_name()
raise ValueError( raise ValueError(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment