Unverified commit 8c1fb507, authored by Mengqing Cao and committed by GitHub
Browse files

[Platform][Refactor] Extract func `get_default_attn_backend` to `Platform` (#10358)


Signed-off-by: Mengqing Cao <cmq0113@163.com>
parent 7eb719df
...@@ -5,6 +5,7 @@ import torch ...@@ -5,6 +5,7 @@ import torch
from tests.kernels.utils import override_backend_env_variable from tests.kernels.utils import override_backend_env_variable
from vllm.attention.selector import which_attn_to_use from vllm.attention.selector import which_attn_to_use
from vllm.platforms import cpu, cuda, openvino, rocm
from vllm.utils import STR_FLASH_ATTN_VAL, STR_INVALID_VAL from vllm.utils import STR_FLASH_ATTN_VAL, STR_INVALID_VAL
...@@ -19,24 +20,26 @@ def test_env(name: str, device: str, monkeypatch): ...@@ -19,24 +20,26 @@ def test_env(name: str, device: str, monkeypatch):
override_backend_env_variable(monkeypatch, name) override_backend_env_variable(monkeypatch, name)
if device == "cpu": if device == "cpu":
with patch("vllm.attention.selector.current_platform.is_cpu", with patch("vllm.attention.selector.current_platform",
return_value=True): cpu.CpuPlatform()):
backend = which_attn_to_use(16, torch.float16, torch.float16, 16, backend = which_attn_to_use(16, torch.float16, torch.float16, 16,
False) False)
assert backend.name == "TORCH_SDPA" assert backend.name == "TORCH_SDPA"
elif device == "hip": elif device == "hip":
with patch("vllm.attention.selector.current_platform.is_rocm", with patch("vllm.attention.selector.current_platform",
return_value=True): rocm.RocmPlatform()):
backend = which_attn_to_use(16, torch.float16, torch.float16, 16, backend = which_attn_to_use(16, torch.float16, torch.float16, 16,
False) False)
assert backend.name == "ROCM_FLASH" assert backend.name == "ROCM_FLASH"
elif device == "openvino": elif device == "openvino":
with patch("vllm.attention.selector.current_platform.is_openvino", with patch("vllm.attention.selector.current_platform",
return_value=True): openvino.OpenVinoPlatform()):
backend = which_attn_to_use(16, torch.float16, torch.float16, 16, backend = which_attn_to_use(16, torch.float16, torch.float16, 16,
False) False)
assert backend.name == "OPENVINO" assert backend.name == "OPENVINO"
else: else:
with patch("vllm.attention.selector.current_platform",
cuda.CudaPlatform()):
backend = which_attn_to_use(16, torch.float16, torch.float16, 16, backend = which_attn_to_use(16, torch.float16, torch.float16, 16,
False) False)
assert backend.name == name assert backend.name == name
......
import enum
import os import os
from contextlib import contextmanager from contextlib import contextmanager
from functools import lru_cache from functools import lru_cache
...@@ -9,26 +8,12 @@ import torch ...@@ -9,26 +8,12 @@ import torch
import vllm.envs as envs import vllm.envs as envs
from vllm.attention.backends.abstract import AttentionBackend from vllm.attention.backends.abstract import AttentionBackend
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.platforms import current_platform from vllm.platforms import _Backend, current_platform
from vllm.utils import STR_BACKEND_ENV_VAR from vllm.utils import STR_BACKEND_ENV_VAR
logger = init_logger(__name__) logger = init_logger(__name__)
class _Backend(enum.Enum):
FLASH_ATTN = enum.auto()
FLASH_ATTN_VLLM_V1 = enum.auto()
XFORMERS = enum.auto()
ROCM_FLASH = enum.auto()
TORCH_SDPA = enum.auto()
OPENVINO = enum.auto()
FLASHINFER = enum.auto()
HPU_ATTN = enum.auto()
PALLAS = enum.auto()
IPEX = enum.auto()
NO_ATTENTION = enum.auto()
def backend_name_to_enum(backend_name: str) -> _Backend: def backend_name_to_enum(backend_name: str) -> _Backend:
assert backend_name is not None assert backend_name is not None
...@@ -216,40 +201,11 @@ def which_attn_to_use(head_size: int, ...@@ -216,40 +201,11 @@ def which_attn_to_use(head_size: int,
if backend_by_env_var is not None: if backend_by_env_var is not None:
selected_backend = backend_name_to_enum(backend_by_env_var) selected_backend = backend_name_to_enum(backend_by_env_var)
if current_platform.is_cpu(): # get device-specific default attn_backend
if selected_backend != _Backend.TORCH_SDPA: default_backend = current_platform.get_default_attn_backend(
logger.info("Cannot use %s backend on CPU.", selected_backend) selected_backend)
return _Backend.TORCH_SDPA if default_backend is not None:
return default_backend
if current_platform.is_openvino():
if selected_backend != _Backend.OPENVINO:
logger.info("Cannot use %s backend on OpenVINO.", selected_backend)
return _Backend.OPENVINO
if current_platform.is_xpu():
if selected_backend != _Backend.IPEX:
logger.info("Cannot use %s backend on XPU.", selected_backend)
return _Backend.IPEX
if current_platform.is_tpu():
if selected_backend != _Backend.PALLAS:
logger.info("Cannot use %s backend on TPU.", selected_backend)
return _Backend.PALLAS
if current_platform.is_rocm():
# AMD GPUs.
selected_backend = (_Backend.ROCM_FLASH if selected_backend
== _Backend.FLASH_ATTN else selected_backend)
if selected_backend == _Backend.ROCM_FLASH:
if not current_platform.has_device_capability(90):
# not Instinct series GPUs.
logger.info("flash_attn is not supported on NAVI GPUs.")
else:
logger.info("%s is not supported in AMD GPUs.", selected_backend)
return _Backend.ROCM_FLASH
if current_platform.is_hpu():
return _Backend.HPU_ATTN
if use_v1: if use_v1:
return _Backend.FLASH_ATTN_VLLM_V1 return _Backend.FLASH_ATTN_VLLM_V1
......
...@@ -13,7 +13,6 @@ from torch.nn import functional as F ...@@ -13,7 +13,6 @@ from torch.nn import functional as F
from transformers import PretrainedConfig from transformers import PretrainedConfig
from vllm.attention import Attention, AttentionMetadata from vllm.attention import Attention, AttentionMetadata
from vllm.attention.selector import _Backend
from vllm.compilation.decorators import support_torch_compile from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank, from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
...@@ -38,6 +37,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import ( ...@@ -38,6 +37,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs
from vllm.multimodal.utils import cached_get_tokenizer from vllm.multimodal.utils import cached_get_tokenizer
from vllm.platforms import _Backend
from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, IntermediateTensors, from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, IntermediateTensors,
SequenceData) SequenceData)
from vllm.transformers_utils.processor import get_processor from vllm.transformers_utils.processor import get_processor
......
...@@ -39,7 +39,6 @@ from transformers.models.qwen2_vl.image_processing_qwen2_vl import ( ...@@ -39,7 +39,6 @@ from transformers.models.qwen2_vl.image_processing_qwen2_vl import (
make_batched_images, make_batched_videos, smart_resize) make_batched_images, make_batched_videos, smart_resize)
from vllm.attention import AttentionMetadata from vllm.attention import AttentionMetadata
from vllm.attention.selector import _Backend
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.distributed import get_pp_group, parallel_state from vllm.distributed import get_pp_group, parallel_state
from vllm.distributed import utils as dist_utils from vllm.distributed import utils as dist_utils
...@@ -65,6 +64,7 @@ from vllm.multimodal.image import cached_get_image_processor ...@@ -65,6 +64,7 @@ from vllm.multimodal.image import cached_get_image_processor
from vllm.multimodal.inputs import (MultiModalData, MultiModalDataDict, from vllm.multimodal.inputs import (MultiModalData, MultiModalDataDict,
MultiModalKwargs) MultiModalKwargs)
from vllm.multimodal.utils import cached_get_tokenizer from vllm.multimodal.utils import cached_get_tokenizer
from vllm.platforms import _Backend
from vllm.sequence import IntermediateTensors, PoolerOutput, SequenceData from vllm.sequence import IntermediateTensors, PoolerOutput, SequenceData
from vllm.transformers_utils.config import uses_mrope from vllm.transformers_utils.config import uses_mrope
from vllm.transformers_utils.processor import cached_get_processor from vllm.transformers_utils.processor import cached_get_processor
......
...@@ -9,13 +9,13 @@ from torch.func import functional_call ...@@ -9,13 +9,13 @@ from torch.func import functional_call
from transformers import PretrainedConfig from transformers import PretrainedConfig
import vllm.envs as envs import vllm.envs as envs
from vllm.attention.selector import (_Backend, backend_name_to_enum, from vllm.attention.selector import (backend_name_to_enum,
get_global_forced_attn_backend) get_global_forced_attn_backend)
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.multimodal import MultiModalPlaceholderMap, NestedTensors from vllm.multimodal import MultiModalPlaceholderMap, NestedTensors
from vllm.platforms import current_platform from vllm.platforms import _Backend, current_platform
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from vllm.utils import is_pin_memory_available from vllm.utils import is_pin_memory_available
......
from .interface import _Backend # noqa: F401
from .interface import Platform, PlatformEnum, UnspecifiedPlatform from .interface import Platform, PlatformEnum, UnspecifiedPlatform
current_platform: Platform current_platform: Platform
......
...@@ -5,7 +5,9 @@ import torch ...@@ -5,7 +5,9 @@ import torch
from vllm.logger import init_logger from vllm.logger import init_logger
from .interface import Platform, PlatformEnum from .interface import Platform, PlatformEnum, _Backend
logger = init_logger(__name__)
if TYPE_CHECKING: if TYPE_CHECKING:
from vllm.config import VllmConfig from vllm.config import VllmConfig
...@@ -22,6 +24,12 @@ class CpuPlatform(Platform): ...@@ -22,6 +24,12 @@ class CpuPlatform(Platform):
def get_device_name(cls, device_id: int = 0) -> str: def get_device_name(cls, device_id: int = 0) -> str:
return "cpu" return "cpu"
@classmethod
def get_default_attn_backend(cls, selected_backend: _Backend) -> _Backend:
if selected_backend != _Backend.TORCH_SDPA:
logger.info("Cannot use %s backend on CPU.", selected_backend)
return _Backend.TORCH_SDPA
@classmethod @classmethod
def get_device_total_memory(cls, device_id: int = 0) -> int: def get_device_total_memory(cls, device_id: int = 0) -> int:
return psutil.virtual_memory().total return psutil.virtual_memory().total
......
import torch import torch
from .interface import Platform, PlatformEnum from .interface import Platform, PlatformEnum, _Backend
class HpuPlatform(Platform): class HpuPlatform(Platform):
_enum = PlatformEnum.HPU _enum = PlatformEnum.HPU
@classmethod
def get_default_attn_backend(cls, selected_backend: _Backend) -> _Backend:
return _Backend.HPU_ATTN
@staticmethod @staticmethod
def inference_mode(): def inference_mode():
return torch.no_grad() return torch.no_grad()
...@@ -11,6 +11,20 @@ else: ...@@ -11,6 +11,20 @@ else:
VllmConfig = None VllmConfig = None
class _Backend(enum.Enum):
FLASH_ATTN = enum.auto()
FLASH_ATTN_VLLM_V1 = enum.auto()
XFORMERS = enum.auto()
ROCM_FLASH = enum.auto()
TORCH_SDPA = enum.auto()
OPENVINO = enum.auto()
FLASHINFER = enum.auto()
HPU_ATTN = enum.auto()
PALLAS = enum.auto()
IPEX = enum.auto()
NO_ATTENTION = enum.auto()
class PlatformEnum(enum.Enum): class PlatformEnum(enum.Enum):
CUDA = enum.auto() CUDA = enum.auto()
ROCM = enum.auto() ROCM = enum.auto()
...@@ -71,6 +85,11 @@ class Platform: ...@@ -71,6 +85,11 @@ class Platform:
"""Stateless version of :func:`torch.cuda.is_available`.""" """Stateless version of :func:`torch.cuda.is_available`."""
return self._enum in (PlatformEnum.CUDA, PlatformEnum.ROCM) return self._enum in (PlatformEnum.CUDA, PlatformEnum.ROCM)
@classmethod
def get_default_attn_backend(cls, selected_backend: _Backend):
"""Get the default attention backend of a device."""
return None
@classmethod @classmethod
def get_device_capability( def get_device_capability(
cls, cls,
......
...@@ -3,7 +3,7 @@ import torch ...@@ -3,7 +3,7 @@ import torch
import vllm.envs as envs import vllm.envs as envs
from vllm.logger import init_logger from vllm.logger import init_logger
from .interface import Platform, PlatformEnum from .interface import Platform, PlatformEnum, _Backend
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -11,6 +11,12 @@ logger = init_logger(__name__) ...@@ -11,6 +11,12 @@ logger = init_logger(__name__)
class OpenVinoPlatform(Platform): class OpenVinoPlatform(Platform):
_enum = PlatformEnum.OPENVINO _enum = PlatformEnum.OPENVINO
@classmethod
def get_default_attn_backend(cls, selected_backend: _Backend) -> _Backend:
if selected_backend != _Backend.OPENVINO:
logger.info("Cannot use %s backend on OpenVINO.", selected_backend)
return _Backend.OPENVINO
@classmethod @classmethod
def get_device_name(self, device_id: int = 0) -> str: def get_device_name(self, device_id: int = 0) -> str:
return "openvino" return "openvino"
......
...@@ -5,7 +5,7 @@ import torch ...@@ -5,7 +5,7 @@ import torch
from vllm.logger import init_logger from vllm.logger import init_logger
from .interface import DeviceCapability, Platform, PlatformEnum from .interface import DeviceCapability, Platform, PlatformEnum, _Backend
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -19,6 +19,18 @@ if os.environ.get("VLLM_WORKER_MULTIPROC_METHOD", None) in ["fork", None]: ...@@ -19,6 +19,18 @@ if os.environ.get("VLLM_WORKER_MULTIPROC_METHOD", None) in ["fork", None]:
class RocmPlatform(Platform): class RocmPlatform(Platform):
_enum = PlatformEnum.ROCM _enum = PlatformEnum.ROCM
@classmethod
def get_default_attn_backend(cls, selected_backend: _Backend) -> _Backend:
selected_backend = (_Backend.ROCM_FLASH if selected_backend
== _Backend.FLASH_ATTN else selected_backend)
if selected_backend == _Backend.ROCM_FLASH:
if not cls.has_device_capability(90):
# not Instinct series GPUs.
logger.info("flash_attn is not supported on NAVI GPUs.")
else:
logger.info("%s is not supported in AMD GPUs.", selected_backend)
return _Backend.ROCM_FLASH
@classmethod @classmethod
@lru_cache(maxsize=8) @lru_cache(maxsize=8)
def get_device_capability(cls, device_id: int = 0) -> DeviceCapability: def get_device_capability(cls, device_id: int = 0) -> DeviceCapability:
......
...@@ -3,17 +3,27 @@ from typing import TYPE_CHECKING ...@@ -3,17 +3,27 @@ from typing import TYPE_CHECKING
import torch import torch
from .interface import Platform, PlatformEnum from vllm.logger import init_logger
from .interface import Platform, PlatformEnum, _Backend
if TYPE_CHECKING: if TYPE_CHECKING:
from vllm.config import VllmConfig from vllm.config import VllmConfig
else: else:
VllmConfig = None VllmConfig = None
logger = init_logger(__name__)
class TpuPlatform(Platform): class TpuPlatform(Platform):
_enum = PlatformEnum.TPU _enum = PlatformEnum.TPU
@classmethod
def get_default_attn_backend(cls, selected_backend: _Backend) -> _Backend:
if selected_backend != _Backend.PALLAS:
logger.info("Cannot use %s backend on TPU.", selected_backend)
return _Backend.PALLAS
@classmethod @classmethod
def get_device_name(cls, device_id: int = 0) -> str: def get_device_name(cls, device_id: int = 0) -> str:
raise NotImplementedError raise NotImplementedError
......
import torch import torch
from .interface import DeviceCapability, Platform, PlatformEnum from vllm.logger import init_logger
from .interface import DeviceCapability, Platform, PlatformEnum, _Backend
logger = init_logger(__name__)
class XPUPlatform(Platform): class XPUPlatform(Platform):
_enum = PlatformEnum.XPU _enum = PlatformEnum.XPU
@classmethod
def get_default_attn_backend(cls, selected_backend: _Backend) -> _Backend:
if selected_backend != _Backend.IPEX:
logger.info("Cannot use %s backend on XPU.", selected_backend)
return _Backend.IPEX
@staticmethod @staticmethod
def get_device_capability(device_id: int = 0) -> DeviceCapability: def get_device_capability(device_id: int = 0) -> DeviceCapability:
major, minor, *_ = torch.xpu.get_device_capability( major, minor, *_ = torch.xpu.get_device_capability(
......
...@@ -8,7 +8,7 @@ import torch.distributed ...@@ -8,7 +8,7 @@ import torch.distributed
from vllm.attention.backends.abstract import (AttentionBackend, from vllm.attention.backends.abstract import (AttentionBackend,
AttentionMetadata) AttentionMetadata)
from vllm.attention.backends.utils import PAD_SLOT_ID from vllm.attention.backends.utils import PAD_SLOT_ID
from vllm.attention.selector import (_Backend, get_env_variable_attn_backend, from vllm.attention.selector import (get_env_variable_attn_backend,
get_global_forced_attn_backend) get_global_forced_attn_backend)
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.forward_context import set_forward_context from vllm.forward_context import set_forward_context
...@@ -18,6 +18,7 @@ from vllm.model_executor import SamplingMetadata ...@@ -18,6 +18,7 @@ from vllm.model_executor import SamplingMetadata
from vllm.model_executor.layers.sampler import SamplerOutput from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.multimodal import (MULTIMODAL_REGISTRY, MultiModalKwargs, from vllm.multimodal import (MULTIMODAL_REGISTRY, MultiModalKwargs,
MultiModalRegistry) MultiModalRegistry)
from vllm.platforms import _Backend
from vllm.sampling_params import SamplingParams from vllm.sampling_params import SamplingParams
from vllm.sequence import (IntermediateTensors, PoolerOutput, from vllm.sequence import (IntermediateTensors, PoolerOutput,
SequenceGroupMetadata) SequenceGroupMetadata)
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment