Unverified commit f6027b28, authored by wangxiyuan, committed by GitHub
Browse files

[1/N][Platform] Cleanup useless function (#26982)


Signed-off-by: wangxiyuan <wangxiyuan1007@gmail.com>
parent ab3e8004
...@@ -9,6 +9,7 @@ Note: these tests will only pass on L4 GPU. ...@@ -9,6 +9,7 @@ Note: these tests will only pass on L4 GPU.
import pytest import pytest
from tests.quantization.utils import is_quant_method_supported from tests.quantization.utils import is_quant_method_supported
from vllm.attention.utils.fa_utils import flash_attn_supports_fp8
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.utils import STR_BACKEND_ENV_VAR from vllm.utils import STR_BACKEND_ENV_VAR
...@@ -69,8 +70,10 @@ def test_models( ...@@ -69,8 +70,10 @@ def test_models(
if kv_cache_dtype == "fp8_e5m2" and current_platform.is_rocm(): if kv_cache_dtype == "fp8_e5m2" and current_platform.is_rocm():
pytest.skip(f"{kv_cache_dtype} is currently not supported on ROCm/HIP.") pytest.skip(f"{kv_cache_dtype} is currently not supported on ROCm/HIP.")
if not current_platform.is_kv_cache_dtype_supported(kv_cache_dtype, None): if not flash_attn_supports_fp8():
pytest.skip(f"{kv_cache_dtype} is not supported on this platform.") pytest.skip(
f"{kv_cache_dtype} is not supported on this GPU type with {backend} attention."
)
with monkeypatch.context() as m: with monkeypatch.context() as m:
m.setenv("TOKENIZERS_PARALLELISM", "true") m.setenv("TOKENIZERS_PARALLELISM", "true")
......
...@@ -356,10 +356,6 @@ def test_compressed_tensors_fp8(vllm_runner): ...@@ -356,10 +356,6 @@ def test_compressed_tensors_fp8(vllm_runner):
assert output assert output
@pytest.mark.skipif(
not current_platform.is_kv_cache_dtype_supported("fp8", None),
reason="FP8 KV cache is not supported on this device.",
)
@pytest.mark.skipif( @pytest.mark.skipif(
not current_platform.is_cuda(), reason="This test is skipped on non-CUDA platform." not current_platform.is_cuda(), reason="This test is skipped on non-CUDA platform."
) )
......
...@@ -23,7 +23,7 @@ from .interface import DeviceCapability, Platform, PlatformEnum ...@@ -23,7 +23,7 @@ from .interface import DeviceCapability, Platform, PlatformEnum
if TYPE_CHECKING: if TYPE_CHECKING:
from vllm.attention.backends.registry import _Backend from vllm.attention.backends.registry import _Backend
from vllm.config import ModelConfig, VllmConfig from vllm.config import VllmConfig
else: else:
_Backend = None _Backend = None
...@@ -457,49 +457,6 @@ class CudaPlatformBase(Platform): ...@@ -457,49 +457,6 @@ class CudaPlatformBase(Platform):
def device_count(cls) -> int: def device_count(cls) -> int:
return cuda_device_count_stateless() return cuda_device_count_stateless()
@classmethod
def is_kv_cache_dtype_supported(
cls, kv_cache_dtype: str, model_config: "ModelConfig"
) -> bool:
fp8_attention = kv_cache_dtype.startswith("fp8")
attention_backend = envs.VLLM_ATTENTION_BACKEND
supported = False
if model_config is not None and model_config.use_mla:
# Default to CutlassMLA for blackwell,
# FlashMLA otherwise
if attention_backend is None:
if cls.is_device_capability(100):
attention_backend = "CUTLASS_MLA"
else:
attention_backend = "FLASHMLA"
# Only FlashMLA and CUTLASS_MLA support fp8
if attention_backend in ["FLASHMLA", "CUTLASS_MLA", "FLASHINFER_MLA"]:
supported = True
else:
supported = not fp8_attention
else:
# Default to FlashAttention
if attention_backend is None:
attention_backend = "FLASH_ATTN"
# All Blackwell backends support fp8
if cls.is_device_capability(100):
supported = True
elif attention_backend == "FLASH_ATTN":
if fp8_attention:
from vllm.attention.utils.fa_utils import flash_attn_supports_fp8
supported = flash_attn_supports_fp8()
else:
supported = True
elif attention_backend == "FLASHINFER":
supported = True
elif attention_backend == "TRITON_ATTN":
supported = cls.supports_fp8()
return supported
@classmethod @classmethod
def check_if_supports_dtype(cls, dtype: torch.dtype): def check_if_supports_dtype(cls, dtype: torch.dtype):
if dtype == torch.bfloat16: # noqa: SIM102 if dtype == torch.bfloat16: # noqa: SIM102
......
...@@ -7,28 +7,23 @@ import platform ...@@ -7,28 +7,23 @@ import platform
import random import random
import sys import sys
from datetime import timedelta from datetime import timedelta
from platform import uname
from typing import TYPE_CHECKING, Any, NamedTuple from typing import TYPE_CHECKING, Any, NamedTuple
import numpy as np import numpy as np
import torch import torch
from torch.distributed import PrefixStore, ProcessGroup
from vllm.inputs import ProcessorInputs, PromptType
from vllm.logger import init_logger from vllm.logger import init_logger
if TYPE_CHECKING: if TYPE_CHECKING:
from torch.distributed import PrefixStore, ProcessGroup
from vllm.attention.backends.registry import _Backend from vllm.attention.backends.registry import _Backend
from vllm.config import ModelConfig, VllmConfig from vllm.config import VllmConfig
from vllm.inputs import ProcessorInputs, PromptType
from vllm.pooling_params import PoolingParams from vllm.pooling_params import PoolingParams
from vllm.sampling_params import SamplingParams from vllm.sampling_params import SamplingParams
from vllm.utils import FlexibleArgumentParser from vllm.utils import FlexibleArgumentParser
else: else:
_Backend = object
ModelConfig = object
VllmConfig = object
PoolingParams = object
SamplingParams = object
FlexibleArgumentParser = object FlexibleArgumentParser = object
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -36,7 +31,7 @@ logger = init_logger(__name__) ...@@ -36,7 +31,7 @@ logger = init_logger(__name__)
def in_wsl() -> bool: def in_wsl() -> bool:
# Reference: https://github.com/microsoft/WSL/issues/4071 # Reference: https://github.com/microsoft/WSL/issues/4071
return "microsoft" in " ".join(uname()).lower() return "microsoft" in " ".join(platform.uname()).lower()
class PlatformEnum(enum.Enum): class PlatformEnum(enum.Enum):
...@@ -178,7 +173,8 @@ class Platform: ...@@ -178,7 +173,8 @@ class Platform:
import vllm._moe_C # noqa: F401 import vllm._moe_C # noqa: F401
@classmethod @classmethod
def get_vit_attn_backend(cls, head_size: int, dtype: torch.dtype) -> _Backend: def get_vit_attn_backend(cls, head_size: int, dtype: torch.dtype) -> "_Backend":
# Import _Backend here to avoid circular import.
from vllm.attention.backends.registry import _Backend from vllm.attention.backends.registry import _Backend
return _Backend.TORCH_SDPA return _Backend.TORCH_SDPA
...@@ -186,7 +182,7 @@ class Platform: ...@@ -186,7 +182,7 @@ class Platform:
@classmethod @classmethod
def get_attn_backend_cls( def get_attn_backend_cls(
cls, cls,
selected_backend: _Backend, selected_backend: "_Backend",
head_size: int, head_size: int,
dtype: torch.dtype, dtype: torch.dtype,
kv_cache_dtype: str | None, kv_cache_dtype: str | None,
...@@ -317,7 +313,7 @@ class Platform: ...@@ -317,7 +313,7 @@ class Platform:
pass pass
@classmethod @classmethod
def check_and_update_config(cls, vllm_config: VllmConfig) -> None: def check_and_update_config(cls, vllm_config: "VllmConfig") -> None:
""" """
Check and update the configuration for the current platform. Check and update the configuration for the current platform.
...@@ -498,9 +494,9 @@ class Platform: ...@@ -498,9 +494,9 @@ class Platform:
@classmethod @classmethod
def validate_request( def validate_request(
cls, cls,
prompt: PromptType, prompt: "PromptType",
params: SamplingParams | PoolingParams, params: "SamplingParams | PoolingParams",
processed_inputs: ProcessorInputs, processed_inputs: "ProcessorInputs",
) -> None: ) -> None:
"""Raises if this request is unsupported on this platform""" """Raises if this request is unsupported on this platform"""
...@@ -543,25 +539,16 @@ class Platform: ...@@ -543,25 +539,16 @@ class Platform:
def stateless_init_device_torch_dist_pg( def stateless_init_device_torch_dist_pg(
cls, cls,
backend: str, backend: str,
prefix_store: PrefixStore, prefix_store: "PrefixStore",
group_rank: int, group_rank: int,
group_size: int, group_size: int,
timeout: timedelta, timeout: timedelta,
) -> ProcessGroup: ) -> "ProcessGroup":
""" """
Init platform-specific torch distributed process group. Init platform-specific torch distributed process group.
""" """
raise NotImplementedError raise NotImplementedError
@classmethod
def is_kv_cache_dtype_supported(
cls, kv_cache_dtype: str, model_config: ModelConfig
) -> bool:
"""
Returns if the kv_cache_dtype is supported by the current platform.
"""
return False
@classmethod @classmethod
def check_if_supports_dtype(cls, dtype: torch.dtype): def check_if_supports_dtype(cls, dtype: torch.dtype):
""" """
......
...@@ -15,7 +15,7 @@ from .interface import DeviceCapability, Platform, PlatformEnum ...@@ -15,7 +15,7 @@ from .interface import DeviceCapability, Platform, PlatformEnum
if TYPE_CHECKING: if TYPE_CHECKING:
from vllm.attention.backends.registry import _Backend from vllm.attention.backends.registry import _Backend
from vllm.config import ModelConfig, VllmConfig from vllm.config import VllmConfig
else: else:
_Backend = None _Backend = None
...@@ -474,12 +474,6 @@ class RocmPlatform(Platform): ...@@ -474,12 +474,6 @@ class RocmPlatform(Platform):
def device_count(cls) -> int: def device_count(cls) -> int:
return cuda_device_count_stateless() return cuda_device_count_stateless()
@classmethod
def is_kv_cache_dtype_supported(
cls, kv_cache_dtype: str, model_config: "ModelConfig"
) -> bool:
return True
@classmethod @classmethod
def check_if_supports_dtype(cls, dtype: torch.dtype): def check_if_supports_dtype(cls, dtype: torch.dtype):
if dtype == torch.bfloat16: # noqa: SIM102 if dtype == torch.bfloat16: # noqa: SIM102
......
...@@ -222,12 +222,6 @@ class TpuPlatform(Platform): ...@@ -222,12 +222,6 @@ class TpuPlatform(Platform):
): ):
raise ValueError("Torch XLA does not support per-request seed.") raise ValueError("Torch XLA does not support per-request seed.")
@classmethod
def is_kv_cache_dtype_supported(
cls, kv_cache_dtype: str, model_config: "ModelConfig"
) -> bool:
return True
@classmethod @classmethod
@torch.compile(backend="openxla") @torch.compile(backend="openxla")
def insert_blocks_to_device( def insert_blocks_to_device(
......
...@@ -86,22 +86,6 @@ class XPUPlatform(Platform): ...@@ -86,22 +86,6 @@ class XPUPlatform(Platform):
logger.info("Using Flash Attention backend on V1 engine.") logger.info("Using Flash Attention backend on V1 engine.")
return "vllm.v1.attention.backends.flash_attn.FlashAttentionBackend" return "vllm.v1.attention.backends.flash_attn.FlashAttentionBackend"
@classmethod
def is_kv_cache_dtype_supported(
cls, kv_cache_dtype: str, model_config: "ModelConfig"
) -> bool:
"""
Check if the kv_cache_dtype is supported.
XPU only support fp8 kv cache with triton backend.
"""
if (
envs.is_set("VLLM_ATTENTION_BACKEND")
and envs.VLLM_ATTENTION_BACKEND == "TRITON_ATTN"
):
return kv_cache_dtype in ["fp8_e4m3", "fp8_e5m2", "fp8"]
return False
@classmethod @classmethod
def set_device(cls, device: torch.device) -> None: def set_device(cls, device: torch.device) -> None:
""" """
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment