"docs/vscode:/vscode.git/clone" did not exist on "19d98e0c7db96713f0e2201649159431177a56e2"
Unverified commit 2aaa4238, authored by Matthew Bonanni, committed by GitHub
Browse files

[Attention] Move Backend enum into registry (#25893)


Signed-off-by: Matthew Bonanni <mbonanni@redhat.com>
parent ad2d7880
...@@ -41,6 +41,7 @@ from transformers.models.qwen2_vl.image_processing_qwen2_vl import smart_resize ...@@ -41,6 +41,7 @@ from transformers.models.qwen2_vl.image_processing_qwen2_vl import smart_resize
from transformers.models.qwen2_vl.video_processing_qwen2_vl import ( from transformers.models.qwen2_vl.video_processing_qwen2_vl import (
Qwen2VLVideoProcessor) Qwen2VLVideoProcessor)
from vllm.attention.backends.registry import _Backend
from vllm.attention.layer import check_upstream_fa_availability from vllm.attention.layer import check_upstream_fa_availability
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.distributed import parallel_state, tensor_model_parallel_all_gather from vllm.distributed import parallel_state, tensor_model_parallel_all_gather
...@@ -65,7 +66,6 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor, ...@@ -65,7 +66,6 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor,
BaseProcessingInfo, PromptReplacement, BaseProcessingInfo, PromptReplacement,
PromptUpdate) PromptUpdate)
from vllm.multimodal.profiling import BaseDummyInputsBuilder from vllm.multimodal.profiling import BaseDummyInputsBuilder
from vllm.platforms import _Backend
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.tokenizer import AnyTokenizer from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.utils.tensor_schema import TensorSchema, TensorShape from vllm.utils.tensor_schema import TensorSchema, TensorShape
......
...@@ -43,6 +43,7 @@ from transformers.models.qwen3_vl.video_processing_qwen3_vl import ( ...@@ -43,6 +43,7 @@ from transformers.models.qwen3_vl.video_processing_qwen3_vl import (
smart_resize as video_smart_resize) smart_resize as video_smart_resize)
from transformers.video_utils import VideoMetadata from transformers.video_utils import VideoMetadata
from vllm.attention.backends.registry import _Backend
from vllm.attention.layer import check_upstream_fa_availability from vllm.attention.layer import check_upstream_fa_availability
from vllm.compilation.decorators import support_torch_compile from vllm.compilation.decorators import support_torch_compile
from vllm.config import VllmConfig from vllm.config import VllmConfig
...@@ -66,7 +67,6 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor, ...@@ -66,7 +67,6 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor,
PromptReplacement, PromptUpdate, PromptReplacement, PromptUpdate,
PromptUpdateDetails) PromptUpdateDetails)
from vllm.multimodal.profiling import BaseDummyInputsBuilder from vllm.multimodal.profiling import BaseDummyInputsBuilder
from vllm.platforms import _Backend
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from vllm.utils import is_list_of from vllm.utils import is_list_of
......
...@@ -13,6 +13,7 @@ from torch.nn import functional as F ...@@ -13,6 +13,7 @@ from torch.nn import functional as F
from transformers import Siglip2VisionConfig from transformers import Siglip2VisionConfig
from transformers.configuration_utils import PretrainedConfig from transformers.configuration_utils import PretrainedConfig
from vllm.attention.backends.registry import _Backend
from vllm.attention.layer import check_upstream_fa_availability from vllm.attention.layer import check_upstream_fa_availability
from vllm.distributed import divide, get_tensor_model_parallel_world_size from vllm.distributed import divide, get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.activation import get_act_fn
...@@ -22,7 +23,6 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear, ...@@ -22,7 +23,6 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
RowParallelLinear) RowParallelLinear)
from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.platforms import _Backend
from .vision import get_vit_attn_backend from .vision import get_vit_attn_backend
......
...@@ -10,11 +10,12 @@ from typing import (Callable, Final, Generic, Literal, Optional, Protocol, ...@@ -10,11 +10,12 @@ from typing import (Callable, Final, Generic, Literal, Optional, Protocol,
import torch import torch
from transformers import PretrainedConfig from transformers import PretrainedConfig
from vllm.attention.backends.registry import _Backend
from vllm.distributed import (get_tensor_model_parallel_rank, from vllm.distributed import (get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size, get_tensor_model_parallel_world_size,
tensor_model_parallel_all_gather) tensor_model_parallel_all_gather)
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.platforms import _Backend, current_platform from vllm.platforms import current_platform
logger = init_logger(__name__) logger = init_logger(__name__)
......
...@@ -9,7 +9,6 @@ from vllm import envs ...@@ -9,7 +9,6 @@ from vllm import envs
from vllm.plugins import load_plugins_by_group from vllm.plugins import load_plugins_by_group
from vllm.utils import resolve_obj_by_qualname, supports_xccl from vllm.utils import resolve_obj_by_qualname, supports_xccl
from .interface import _Backend # noqa: F401
from .interface import CpuArchEnum, Platform, PlatformEnum from .interface import CpuArchEnum, Platform, PlatformEnum
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
......
...@@ -15,13 +15,15 @@ import torch ...@@ -15,13 +15,15 @@ import torch
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.utils import DEFAULT_MAX_NUM_BATCHED_TOKENS from vllm.utils import DEFAULT_MAX_NUM_BATCHED_TOKENS
from .interface import CpuArchEnum, Platform, PlatformEnum, _Backend from .interface import CpuArchEnum, Platform, PlatformEnum
logger = init_logger(__name__) logger = init_logger(__name__)
if TYPE_CHECKING: if TYPE_CHECKING:
from vllm.attention.backends.registry import _Backend
from vllm.config import VllmConfig from vllm.config import VllmConfig
else: else:
_Backend = None
VllmConfig = None VllmConfig = None
...@@ -90,10 +92,11 @@ class CpuPlatform(Platform): ...@@ -90,10 +92,11 @@ class CpuPlatform(Platform):
return "cpu" return "cpu"
@classmethod @classmethod
def get_attn_backend_cls(cls, selected_backend: _Backend, head_size: int, def get_attn_backend_cls(cls, selected_backend: "_Backend", head_size: int,
dtype: torch.dtype, kv_cache_dtype: Optional[str], dtype: torch.dtype, kv_cache_dtype: Optional[str],
block_size: int, use_v1: bool, use_mla: bool, block_size: int, use_v1: bool, use_mla: bool,
has_sink: bool, use_sparse: bool) -> str: has_sink: bool, use_sparse: bool) -> str:
from vllm.attention.backends.registry import _Backend
if selected_backend and selected_backend != _Backend.TORCH_SDPA: if selected_backend and selected_backend != _Backend.TORCH_SDPA:
logger.info("Cannot use %s backend on CPU.", selected_backend) logger.info("Cannot use %s backend on CPU.", selected_backend)
if use_mla: if use_mla:
......
...@@ -20,10 +20,13 @@ import vllm.envs as envs ...@@ -20,10 +20,13 @@ import vllm.envs as envs
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.utils import cuda_device_count_stateless, import_pynvml from vllm.utils import cuda_device_count_stateless, import_pynvml
from .interface import DeviceCapability, Platform, PlatformEnum, _Backend from .interface import DeviceCapability, Platform, PlatformEnum
if TYPE_CHECKING: if TYPE_CHECKING:
from vllm.attention.backends.registry import _Backend
from vllm.config import ModelConfig, VllmConfig from vllm.config import ModelConfig, VllmConfig
else:
_Backend = None
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -202,7 +205,8 @@ class CudaPlatformBase(Platform): ...@@ -202,7 +205,8 @@ class CudaPlatformBase(Platform):
@classmethod @classmethod
def get_vit_attn_backend(cls, head_size: int, def get_vit_attn_backend(cls, head_size: int,
dtype: torch.dtype) -> _Backend: dtype: torch.dtype) -> "_Backend":
from vllm.attention.backends.registry import _Backend
# For Blackwell GPUs, force TORCH_SDPA for now. # For Blackwell GPUs, force TORCH_SDPA for now.
# See https://github.com/facebookresearch/xformers/issues/1317#issuecomment-3199392579 # noqa: E501 # See https://github.com/facebookresearch/xformers/issues/1317#issuecomment-3199392579 # noqa: E501
...@@ -230,6 +234,7 @@ class CudaPlatformBase(Platform): ...@@ -230,6 +234,7 @@ class CudaPlatformBase(Platform):
def get_attn_backend_cls(cls, selected_backend, head_size, dtype, def get_attn_backend_cls(cls, selected_backend, head_size, dtype,
kv_cache_dtype, block_size, use_v1, use_mla, kv_cache_dtype, block_size, use_v1, use_mla,
has_sink, use_sparse) -> str: has_sink, use_sparse) -> str:
from vllm.attention.backends.registry import _Backend
if use_mla: if use_mla:
if not use_v1: if not use_v1:
raise RuntimeError( raise RuntimeError(
......
...@@ -17,12 +17,14 @@ from vllm.inputs import ProcessorInputs, PromptType ...@@ -17,12 +17,14 @@ from vllm.inputs import ProcessorInputs, PromptType
from vllm.logger import init_logger from vllm.logger import init_logger
if TYPE_CHECKING: if TYPE_CHECKING:
from vllm.attention.backends.registry import _Backend
from vllm.config import ModelConfig, VllmConfig from vllm.config import ModelConfig, VllmConfig
from vllm.lora.request import LoRARequest from vllm.lora.request import LoRARequest
from vllm.pooling_params import PoolingParams from vllm.pooling_params import PoolingParams
from vllm.sampling_params import SamplingParams from vllm.sampling_params import SamplingParams
from vllm.utils import FlexibleArgumentParser from vllm.utils import FlexibleArgumentParser
else: else:
_Backend = None
ModelConfig = None ModelConfig = None
VllmConfig = None VllmConfig = None
LoRARequest = None LoRARequest = None
...@@ -38,30 +40,6 @@ def in_wsl() -> bool: ...@@ -38,30 +40,6 @@ def in_wsl() -> bool:
return "microsoft" in " ".join(uname()).lower() return "microsoft" in " ".join(uname()).lower()
class _Backend(enum.Enum):
    """Identifiers for the attention backends a platform may select.

    Platform hooks such as ``get_attn_backend_cls`` take a member of this
    enum as ``selected_backend`` and ``get_vit_attn_backend`` returns one.

    Values come from ``enum.auto()``, so they simply follow declaration
    order. Refer to members by name; new members should be appended so the
    existing numeric values do not shift.
    """

    FLASH_ATTN = enum.auto()
    TRITON_ATTN = enum.auto()
    XFORMERS = enum.auto()
    ROCM_FLASH = enum.auto()
    ROCM_AITER_MLA = enum.auto()  # Supported by V1
    ROCM_AITER_FA = enum.auto()  # used for ViT attn backend
    TORCH_SDPA = enum.auto()
    FLASHINFER = enum.auto()
    FLASHINFER_MLA = enum.auto()
    TRITON_MLA = enum.auto()  # Supported by V1
    CUTLASS_MLA = enum.auto()
    FLASHMLA = enum.auto()  # Supported by V1
    FLASH_ATTN_MLA = enum.auto()  # Supported by V1
    PALLAS = enum.auto()
    IPEX = enum.auto()
    DUAL_CHUNK_FLASH_ATTN = enum.auto()
    DIFFERENTIAL_FLASH_ATTN = enum.auto()
    NO_ATTENTION = enum.auto()
    FLEX_ATTENTION = enum.auto()
    TREE_ATTN = enum.auto()
    ROCM_ATTN = enum.auto()
class PlatformEnum(enum.Enum): class PlatformEnum(enum.Enum):
CUDA = enum.auto() CUDA = enum.auto()
ROCM = enum.auto() ROCM = enum.auto()
...@@ -187,11 +165,12 @@ class Platform: ...@@ -187,11 +165,12 @@ class Platform:
@classmethod @classmethod
def get_vit_attn_backend(cls, head_size: int, def get_vit_attn_backend(cls, head_size: int,
dtype: torch.dtype) -> _Backend: dtype: torch.dtype) -> "_Backend":
from vllm.attention.backends.registry import _Backend
return _Backend.TORCH_SDPA return _Backend.TORCH_SDPA
@classmethod @classmethod
def get_attn_backend_cls(cls, selected_backend: _Backend, head_size: int, def get_attn_backend_cls(cls, selected_backend: "_Backend", head_size: int,
dtype: torch.dtype, kv_cache_dtype: Optional[str], dtype: torch.dtype, kv_cache_dtype: Optional[str],
block_size: int, use_v1: bool, use_mla: bool, block_size: int, use_v1: bool, use_mla: bool,
has_sink: bool, use_sparse: bool) -> str: has_sink: bool, use_sparse: bool) -> str:
......
...@@ -14,10 +14,13 @@ import vllm.envs as envs ...@@ -14,10 +14,13 @@ import vllm.envs as envs
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.utils import cuda_device_count_stateless from vllm.utils import cuda_device_count_stateless
from .interface import DeviceCapability, Platform, PlatformEnum, _Backend from .interface import DeviceCapability, Platform, PlatformEnum
if TYPE_CHECKING: if TYPE_CHECKING:
from vllm.attention.backends.registry import _Backend
from vllm.config import ModelConfig, VllmConfig from vllm.config import ModelConfig, VllmConfig
else:
_Backend = None
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -182,7 +185,8 @@ class RocmPlatform(Platform): ...@@ -182,7 +185,8 @@ class RocmPlatform(Platform):
@classmethod @classmethod
def get_vit_attn_backend(cls, head_size: int, def get_vit_attn_backend(cls, head_size: int,
dtype: torch.dtype) -> _Backend: dtype: torch.dtype) -> "_Backend":
from vllm.attention.backends.registry import _Backend
if (envs.VLLM_ROCM_USE_AITER and envs.VLLM_ROCM_USE_AITER_MHA if (envs.VLLM_ROCM_USE_AITER and envs.VLLM_ROCM_USE_AITER_MHA
and on_gfx9()): and on_gfx9()):
# Note: AITER FA is only supported for Qwen-VL models. # Note: AITER FA is only supported for Qwen-VL models.
...@@ -196,6 +200,7 @@ class RocmPlatform(Platform): ...@@ -196,6 +200,7 @@ class RocmPlatform(Platform):
def get_attn_backend_cls(cls, selected_backend, head_size, dtype, def get_attn_backend_cls(cls, selected_backend, head_size, dtype,
kv_cache_dtype, block_size, use_v1, use_mla, kv_cache_dtype, block_size, use_v1, use_mla,
has_sink, use_sparse) -> str: has_sink, use_sparse) -> str:
from vllm.attention.backends.registry import _Backend
if use_sparse: if use_sparse:
raise NotImplementedError( raise NotImplementedError(
"Sparse Attention is not supported on ROCm.") "Sparse Attention is not supported on ROCm.")
......
...@@ -11,9 +11,10 @@ from vllm.logger import init_logger ...@@ -11,9 +11,10 @@ from vllm.logger import init_logger
from vllm.sampling_params import SamplingParams, SamplingType from vllm.sampling_params import SamplingParams, SamplingType
from vllm.utils import DEFAULT_MAX_NUM_BATCHED_TOKENS from vllm.utils import DEFAULT_MAX_NUM_BATCHED_TOKENS
from .interface import Platform, PlatformEnum, _Backend from .interface import Platform, PlatformEnum
if TYPE_CHECKING: if TYPE_CHECKING:
from vllm.attention.backends.registry import _Backend
from vllm.config import BlockSize, ModelConfig, VllmConfig from vllm.config import BlockSize, ModelConfig, VllmConfig
from vllm.pooling_params import PoolingParams from vllm.pooling_params import PoolingParams
else: else:
...@@ -21,6 +22,7 @@ else: ...@@ -21,6 +22,7 @@ else:
ModelConfig = None ModelConfig = None
VllmConfig = None VllmConfig = None
PoolingParams = None PoolingParams = None
_Backend = None
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -46,10 +48,11 @@ class TpuPlatform(Platform): ...@@ -46,10 +48,11 @@ class TpuPlatform(Platform):
] ]
@classmethod @classmethod
def get_attn_backend_cls(cls, selected_backend: _Backend, head_size: int, def get_attn_backend_cls(cls, selected_backend: "_Backend", head_size: int,
dtype: torch.dtype, kv_cache_dtype: Optional[str], dtype: torch.dtype, kv_cache_dtype: Optional[str],
block_size: int, use_v1: bool, use_mla: bool, block_size: int, use_v1: bool, use_mla: bool,
has_sink, use_sparse) -> str: has_sink, use_sparse) -> str:
from vllm.attention.backends.registry import _Backend
if use_sparse: if use_sparse:
raise NotImplementedError( raise NotImplementedError(
"Sparse Attention is not supported on TPU.") "Sparse Attention is not supported on TPU.")
......
...@@ -10,13 +10,15 @@ import vllm.envs as envs ...@@ -10,13 +10,15 @@ import vllm.envs as envs
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.utils import DEFAULT_MAX_NUM_BATCHED_TOKENS from vllm.utils import DEFAULT_MAX_NUM_BATCHED_TOKENS
from .interface import DeviceCapability, Platform, PlatformEnum, _Backend from .interface import DeviceCapability, Platform, PlatformEnum
if TYPE_CHECKING: if TYPE_CHECKING:
from vllm.attention.backends.registry import _Backend
from vllm.config import ModelConfig, VllmConfig from vllm.config import ModelConfig, VllmConfig
else: else:
ModelConfig = None ModelConfig = None
VllmConfig = None VllmConfig = None
_Backend = None
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -33,10 +35,11 @@ class XPUPlatform(Platform): ...@@ -33,10 +35,11 @@ class XPUPlatform(Platform):
device_control_env_var: str = "ZE_AFFINITY_MASK" device_control_env_var: str = "ZE_AFFINITY_MASK"
@classmethod @classmethod
def get_attn_backend_cls(cls, selected_backend: _Backend, head_size: int, def get_attn_backend_cls(cls, selected_backend: "_Backend", head_size: int,
dtype: torch.dtype, kv_cache_dtype: Optional[str], dtype: torch.dtype, kv_cache_dtype: Optional[str],
block_size: int, use_v1: bool, use_mla: bool, block_size: int, use_v1: bool, use_mla: bool,
has_sink: bool, use_sparse) -> str: has_sink: bool, use_sparse) -> str:
from vllm.attention.backends.registry import _Backend
if use_sparse: if use_sparse:
raise NotImplementedError( raise NotImplementedError(
"Sparse Attention is not supported on XPU.") "Sparse Attention is not supported on XPU.")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment