Unverified commit 5f063a80, authored by Joe Runde, committed by GitHub
Browse files

[bugfix] add supports_v1 platform interface (#15417)


Signed-off-by: Joe Runde <Joseph.Runde@ibm.com>
parent 5d8e1c92
......@@ -1666,9 +1666,8 @@ class EngineArgs:
_raise_or_fallback(feature_name=name, recommend_to_remove=True)
return False
# No support for device type other than CUDA, AMD (experimental) or
# TPU (experimental) so far.
if not (current_platform.is_cuda_alike() or current_platform.is_tpu()):
# Platforms must decide if they can support v1 for this model
if not current_platform.supports_v1(model_config=model_config):
_raise_or_fallback(
feature_name=f"device type={current_platform.device_type}",
recommend_to_remove=False)
......
......@@ -20,8 +20,9 @@ from vllm.utils import import_pynvml
from .interface import DeviceCapability, Platform, PlatformEnum, _Backend
if TYPE_CHECKING:
from vllm.config import VllmConfig
from vllm.config import ModelConfig, VllmConfig
else:
ModelConfig = None
VllmConfig = None
logger = init_logger(__name__)
......@@ -303,6 +304,10 @@ class CudaPlatformBase(Platform):
def supports_fp8(cls) -> bool:
return cls.has_device_capability(89)
@classmethod
def supports_v1(cls, model_config: ModelConfig) -> bool:
# Report whether this platform can support v1 for `model_config`.
# The CUDA platform answers True unconditionally: `model_config` is not
# inspected here and is accepted only to match Platform.supports_v1.
# The caller visible in this diff passes it by keyword
# (`supports_v1(model_config=model_config)`), so the parameter name is part
# of the interface.
return True
# NVML utils
# Note that NVML is not affected by `CUDA_VISIBLE_DEVICES`,
......
......@@ -12,9 +12,10 @@ import torch
from vllm.logger import init_logger
if TYPE_CHECKING:
from vllm.config import VllmConfig
from vllm.config import ModelConfig, VllmConfig
from vllm.utils import FlexibleArgumentParser
else:
ModelConfig = None
VllmConfig = None
FlexibleArgumentParser = None
......@@ -371,6 +372,13 @@ class Platform:
or parallel_config.distributed_executor_backend
== "external_launcher")
@classmethod
def supports_v1(cls, model_config: ModelConfig) -> bool:
"""Returns whether the current platform can support v1 for the supplied
model configuration.
"""
# Base-class default is False, so a platform has to override this method to
# opt in to v1. In this diff CudaPlatformBase, RocmPlatform and TpuPlatform
# each override it to return True; `model_config` is not inspected here.
return False
class UnspecifiedPlatform(Platform):
_enum = PlatformEnum.UNSPECIFIED
......
......@@ -12,8 +12,9 @@ from vllm.logger import init_logger
from .interface import DeviceCapability, Platform, PlatformEnum, _Backend
if TYPE_CHECKING:
from vllm.config import VllmConfig
from vllm.config import ModelConfig, VllmConfig
else:
ModelConfig = None
VllmConfig = None
logger = init_logger(__name__)
......@@ -249,3 +250,8 @@ class RocmPlatform(Platform):
return torch.float8_e4m3fnuz
else:
return torch.float8_e4m3fn
@classmethod
def supports_v1(cls, model_config: ModelConfig) -> bool:
# Report whether this platform can support v1 for `model_config`.
# V1 support on AMD GPUs is experimental, but it is still advertised as
# available: this returns True unconditionally and does not inspect
# `model_config`, which is accepted only to match Platform.supports_v1.
return True
......@@ -10,8 +10,9 @@ from vllm.logger import init_logger
from .interface import Platform, PlatformEnum, _Backend
if TYPE_CHECKING:
from vllm.config import VllmConfig
from vllm.config import ModelConfig, VllmConfig
else:
ModelConfig = None
VllmConfig = None
logger = init_logger(__name__)
......@@ -127,3 +128,8 @@ class TpuPlatform(Platform):
@classmethod
def use_all_gather(cls) -> bool:
# Always True for this platform class; takes no input besides `cls`.
# NOTE(review): what the flag controls is not visible in this excerpt --
# presumably it selects an all-gather based collective path; confirm against
# the callers before relying on that.
return True
@classmethod
def supports_v1(cls, model_config: ModelConfig) -> bool:
# Report whether this platform can support v1 for `model_config`.
# V1 support on TPU is experimental, but it is still advertised as
# available: this returns True unconditionally and does not inspect
# `model_config`, which is accepted only to match Platform.supports_v1.
return True
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment