Unverified Commit 2aaa4238 authored by Matthew Bonanni's avatar Matthew Bonanni Committed by GitHub
Browse files

[Attention] Move Backend enum into registry (#25893)


Signed-off-by: default avatarMatthew Bonanni <mbonanni@redhat.com>
parent ad2d7880
......@@ -41,6 +41,7 @@ from transformers.models.qwen2_vl.image_processing_qwen2_vl import smart_resize
from transformers.models.qwen2_vl.video_processing_qwen2_vl import (
Qwen2VLVideoProcessor)
from vllm.attention.backends.registry import _Backend
from vllm.attention.layer import check_upstream_fa_availability
from vllm.config import VllmConfig
from vllm.distributed import parallel_state, tensor_model_parallel_all_gather
......@@ -65,7 +66,6 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor,
BaseProcessingInfo, PromptReplacement,
PromptUpdate)
from vllm.multimodal.profiling import BaseDummyInputsBuilder
from vllm.platforms import _Backend
from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.utils.tensor_schema import TensorSchema, TensorShape
......
......@@ -43,6 +43,7 @@ from transformers.models.qwen3_vl.video_processing_qwen3_vl import (
smart_resize as video_smart_resize)
from transformers.video_utils import VideoMetadata
from vllm.attention.backends.registry import _Backend
from vllm.attention.layer import check_upstream_fa_availability
from vllm.compilation.decorators import support_torch_compile
from vllm.config import VllmConfig
......@@ -66,7 +67,6 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor,
PromptReplacement, PromptUpdate,
PromptUpdateDetails)
from vllm.multimodal.profiling import BaseDummyInputsBuilder
from vllm.platforms import _Backend
from vllm.sequence import IntermediateTensors
from vllm.utils import is_list_of
......
......@@ -13,6 +13,7 @@ from torch.nn import functional as F
from transformers import Siglip2VisionConfig
from transformers.configuration_utils import PretrainedConfig
from vllm.attention.backends.registry import _Backend
from vllm.attention.layer import check_upstream_fa_availability
from vllm.distributed import divide, get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import get_act_fn
......@@ -22,7 +23,6 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.platforms import _Backend
from .vision import get_vit_attn_backend
......
......@@ -10,11 +10,12 @@ from typing import (Callable, Final, Generic, Literal, Optional, Protocol,
import torch
from transformers import PretrainedConfig
from vllm.attention.backends.registry import _Backend
from vllm.distributed import (get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
tensor_model_parallel_all_gather)
from vllm.logger import init_logger
from vllm.platforms import _Backend, current_platform
from vllm.platforms import current_platform
logger = init_logger(__name__)
......
......@@ -9,7 +9,6 @@ from vllm import envs
from vllm.plugins import load_plugins_by_group
from vllm.utils import resolve_obj_by_qualname, supports_xccl
from .interface import _Backend # noqa: F401
from .interface import CpuArchEnum, Platform, PlatformEnum
logger = logging.getLogger(__name__)
......
......@@ -15,13 +15,15 @@ import torch
from vllm.logger import init_logger
from vllm.utils import DEFAULT_MAX_NUM_BATCHED_TOKENS
from .interface import CpuArchEnum, Platform, PlatformEnum, _Backend
from .interface import CpuArchEnum, Platform, PlatformEnum
logger = init_logger(__name__)
if TYPE_CHECKING:
from vllm.attention.backends.registry import _Backend
from vllm.config import VllmConfig
else:
_Backend = None
VllmConfig = None
......@@ -90,10 +92,11 @@ class CpuPlatform(Platform):
return "cpu"
@classmethod
def get_attn_backend_cls(cls, selected_backend: _Backend, head_size: int,
def get_attn_backend_cls(cls, selected_backend: "_Backend", head_size: int,
dtype: torch.dtype, kv_cache_dtype: Optional[str],
block_size: int, use_v1: bool, use_mla: bool,
has_sink: bool, use_sparse: bool) -> str:
from vllm.attention.backends.registry import _Backend
if selected_backend and selected_backend != _Backend.TORCH_SDPA:
logger.info("Cannot use %s backend on CPU.", selected_backend)
if use_mla:
......
......@@ -20,10 +20,13 @@ import vllm.envs as envs
from vllm.logger import init_logger
from vllm.utils import cuda_device_count_stateless, import_pynvml
from .interface import DeviceCapability, Platform, PlatformEnum, _Backend
from .interface import DeviceCapability, Platform, PlatformEnum
if TYPE_CHECKING:
from vllm.attention.backends.registry import _Backend
from vllm.config import ModelConfig, VllmConfig
else:
_Backend = None
logger = init_logger(__name__)
......@@ -202,7 +205,8 @@ class CudaPlatformBase(Platform):
@classmethod
def get_vit_attn_backend(cls, head_size: int,
dtype: torch.dtype) -> _Backend:
dtype: torch.dtype) -> "_Backend":
from vllm.attention.backends.registry import _Backend
# For Blackwell GPUs, force TORCH_SDPA for now.
# See https://github.com/facebookresearch/xformers/issues/1317#issuecomment-3199392579 # noqa: E501
......@@ -230,6 +234,7 @@ class CudaPlatformBase(Platform):
def get_attn_backend_cls(cls, selected_backend, head_size, dtype,
kv_cache_dtype, block_size, use_v1, use_mla,
has_sink, use_sparse) -> str:
from vllm.attention.backends.registry import _Backend
if use_mla:
if not use_v1:
raise RuntimeError(
......
......@@ -17,12 +17,14 @@ from vllm.inputs import ProcessorInputs, PromptType
from vllm.logger import init_logger
if TYPE_CHECKING:
from vllm.attention.backends.registry import _Backend
from vllm.config import ModelConfig, VllmConfig
from vllm.lora.request import LoRARequest
from vllm.pooling_params import PoolingParams
from vllm.sampling_params import SamplingParams
from vllm.utils import FlexibleArgumentParser
else:
_Backend = None
ModelConfig = None
VllmConfig = None
LoRARequest = None
......@@ -38,30 +40,6 @@ def in_wsl() -> bool:
return "microsoft" in " ".join(uname()).lower()
class _Backend(enum.Enum):
FLASH_ATTN = enum.auto()
TRITON_ATTN = enum.auto()
XFORMERS = enum.auto()
ROCM_FLASH = enum.auto()
ROCM_AITER_MLA = enum.auto() # Supported by V1
ROCM_AITER_FA = enum.auto() # used for ViT attn backend
TORCH_SDPA = enum.auto()
FLASHINFER = enum.auto()
FLASHINFER_MLA = enum.auto()
TRITON_MLA = enum.auto() # Supported by V1
CUTLASS_MLA = enum.auto()
FLASHMLA = enum.auto() # Supported by V1
FLASH_ATTN_MLA = enum.auto() # Supported by V1
PALLAS = enum.auto()
IPEX = enum.auto()
DUAL_CHUNK_FLASH_ATTN = enum.auto()
DIFFERENTIAL_FLASH_ATTN = enum.auto()
NO_ATTENTION = enum.auto()
FLEX_ATTENTION = enum.auto()
TREE_ATTN = enum.auto()
ROCM_ATTN = enum.auto()
class PlatformEnum(enum.Enum):
CUDA = enum.auto()
ROCM = enum.auto()
......@@ -187,11 +165,12 @@ class Platform:
@classmethod
def get_vit_attn_backend(cls, head_size: int,
dtype: torch.dtype) -> _Backend:
dtype: torch.dtype) -> "_Backend":
from vllm.attention.backends.registry import _Backend
return _Backend.TORCH_SDPA
@classmethod
def get_attn_backend_cls(cls, selected_backend: _Backend, head_size: int,
def get_attn_backend_cls(cls, selected_backend: "_Backend", head_size: int,
dtype: torch.dtype, kv_cache_dtype: Optional[str],
block_size: int, use_v1: bool, use_mla: bool,
has_sink: bool, use_sparse: bool) -> str:
......
......@@ -14,10 +14,13 @@ import vllm.envs as envs
from vllm.logger import init_logger
from vllm.utils import cuda_device_count_stateless
from .interface import DeviceCapability, Platform, PlatformEnum, _Backend
from .interface import DeviceCapability, Platform, PlatformEnum
if TYPE_CHECKING:
from vllm.attention.backends.registry import _Backend
from vllm.config import ModelConfig, VllmConfig
else:
_Backend = None
logger = init_logger(__name__)
......@@ -182,7 +185,8 @@ class RocmPlatform(Platform):
@classmethod
def get_vit_attn_backend(cls, head_size: int,
dtype: torch.dtype) -> _Backend:
dtype: torch.dtype) -> "_Backend":
from vllm.attention.backends.registry import _Backend
if (envs.VLLM_ROCM_USE_AITER and envs.VLLM_ROCM_USE_AITER_MHA
and on_gfx9()):
# Note: AITER FA is only supported for Qwen-VL models.
......@@ -196,6 +200,7 @@ class RocmPlatform(Platform):
def get_attn_backend_cls(cls, selected_backend, head_size, dtype,
kv_cache_dtype, block_size, use_v1, use_mla,
has_sink, use_sparse) -> str:
from vllm.attention.backends.registry import _Backend
if use_sparse:
raise NotImplementedError(
"Sparse Attention is not supported on ROCm.")
......
......@@ -11,9 +11,10 @@ from vllm.logger import init_logger
from vllm.sampling_params import SamplingParams, SamplingType
from vllm.utils import DEFAULT_MAX_NUM_BATCHED_TOKENS
from .interface import Platform, PlatformEnum, _Backend
from .interface import Platform, PlatformEnum
if TYPE_CHECKING:
from vllm.attention.backends.registry import _Backend
from vllm.config import BlockSize, ModelConfig, VllmConfig
from vllm.pooling_params import PoolingParams
else:
......@@ -21,6 +22,7 @@ else:
ModelConfig = None
VllmConfig = None
PoolingParams = None
_Backend = None
logger = init_logger(__name__)
......@@ -46,10 +48,11 @@ class TpuPlatform(Platform):
]
@classmethod
def get_attn_backend_cls(cls, selected_backend: _Backend, head_size: int,
def get_attn_backend_cls(cls, selected_backend: "_Backend", head_size: int,
dtype: torch.dtype, kv_cache_dtype: Optional[str],
block_size: int, use_v1: bool, use_mla: bool,
has_sink, use_sparse) -> str:
from vllm.attention.backends.registry import _Backend
if use_sparse:
raise NotImplementedError(
"Sparse Attention is not supported on TPU.")
......
......@@ -10,13 +10,15 @@ import vllm.envs as envs
from vllm.logger import init_logger
from vllm.utils import DEFAULT_MAX_NUM_BATCHED_TOKENS
from .interface import DeviceCapability, Platform, PlatformEnum, _Backend
from .interface import DeviceCapability, Platform, PlatformEnum
if TYPE_CHECKING:
from vllm.attention.backends.registry import _Backend
from vllm.config import ModelConfig, VllmConfig
else:
ModelConfig = None
VllmConfig = None
_Backend = None
logger = init_logger(__name__)
......@@ -33,10 +35,11 @@ class XPUPlatform(Platform):
device_control_env_var: str = "ZE_AFFINITY_MASK"
@classmethod
def get_attn_backend_cls(cls, selected_backend: _Backend, head_size: int,
def get_attn_backend_cls(cls, selected_backend: "_Backend", head_size: int,
dtype: torch.dtype, kv_cache_dtype: Optional[str],
block_size: int, use_v1: bool, use_mla: bool,
has_sink: bool, use_sparse) -> str:
from vllm.attention.backends.registry import _Backend
if use_sparse:
raise NotImplementedError(
"Sparse Attention is not supported on XPU.")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment