Unverified Commit b30dfa03 authored by Matthew Bonanni's avatar Matthew Bonanni Committed by GitHub
Browse files

[Attention] Refactor CUDA attention backend selection logic (#24794)


Signed-off-by: default avatarMatthew Bonanni <mbonanni@redhat.com>
Signed-off-by: default avatarMatthew Bonanni <mbonanni001@gmail.com>
Co-authored-by: default avatarLuka Govedič <ProExpertProg@users.noreply.github.com>
parent 2e78150d
...@@ -11,9 +11,9 @@ from pydantic.dataclasses import dataclass ...@@ -11,9 +11,9 @@ from pydantic.dataclasses import dataclass
from vllm.config.utils import config from vllm.config.utils import config
if TYPE_CHECKING: if TYPE_CHECKING:
from vllm.attention.backends.registry import _Backend from vllm.attention.backends.registry import AttentionBackendEnum
else: else:
_Backend = Any AttentionBackendEnum = Any
@dataclass @dataclass
...@@ -125,10 +125,10 @@ class MultiModalConfig: ...@@ -125,10 +125,10 @@ class MultiModalConfig:
DP (which is controlled by `--data-parallel-size`). DP (which is controlled by `--data-parallel-size`).
This is only supported on a per-model basis and falls back to This is only supported on a per-model basis and falls back to
`"weights"` if the encoder does not support DP.""" `"weights"` if the encoder does not support DP."""
mm_encoder_attn_backend: _Backend | None = None mm_encoder_attn_backend: AttentionBackendEnum | None = None
"""Optional override for the multi-modal encoder attention backend when """Optional override for the multi-modal encoder attention backend when
using vision transformers. Accepts any value from using vision transformers. Accepts any value from
`vllm.attention.backends.registry._Backend` (e.g. `FLASH_ATTN`).""" `vllm.attention.backends.registry.AttentionBackendEnum` (e.g. `FLASH_ATTN`)."""
interleave_mm_strings: bool = False interleave_mm_strings: bool = False
"""Enable fully interleaved support for multimodal prompts, while using """Enable fully interleaved support for multimodal prompts, while using
--chat-template-content-format=string.""" --chat-template-content-format=string."""
...@@ -167,26 +167,16 @@ class MultiModalConfig: ...@@ -167,26 +167,16 @@ class MultiModalConfig:
@field_validator("mm_encoder_attn_backend", mode="before") @field_validator("mm_encoder_attn_backend", mode="before")
@classmethod @classmethod
def _validate_mm_encoder_attn_backend(cls, value: object) -> _Backend | None: def _validate_mm_encoder_attn_backend(
from vllm.attention.backends.registry import ( cls, value: str | AttentionBackendEnum | None
_Backend as BackendEnum, ) -> AttentionBackendEnum | None:
) if value is None or isinstance(value, AttentionBackendEnum):
from vllm.attention.backends.registry import (
backend_name_to_enum,
)
if value is None or isinstance(value, BackendEnum):
return value return value
if isinstance(value, str): assert isinstance(value, str), (
candidate = backend_name_to_enum(value.upper()) "mm_encoder_attn_backend must be a string or an AttentionBackendEnum."
if candidate is not None:
return candidate
valid_backends = ", ".join(sorted(BackendEnum.__members__.keys()))
raise ValueError(
f"Invalid mm encoder attention backend. Expected one of: {valid_backends}."
) )
return AttentionBackendEnum[value.upper()]
@model_validator(mode="after") @model_validator(mode="after")
def _validate_multimodal_config(self): def _validate_multimodal_config(self):
......
...@@ -21,7 +21,7 @@ import torch ...@@ -21,7 +21,7 @@ import torch
import zmq import zmq
from vllm import envs from vllm import envs
from vllm.attention.backends.registry import _Backend, backend_name_to_enum from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.attention.selector import get_attn_backend from vllm.attention.selector import get_attn_backend
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.distributed.kv_transfer.kv_connector.v1.base import ( from vllm.distributed.kv_transfer.kv_connector.v1.base import (
...@@ -876,9 +876,9 @@ class NixlConnectorWorker: ...@@ -876,9 +876,9 @@ class NixlConnectorWorker:
use_mla=self.use_mla, use_mla=self.use_mla,
) )
self.backend_name = backend.get_name() self.backend_name = backend.get_name()
attn_backend = backend_name_to_enum(self.backend_name) attn_backend = AttentionBackendEnum[self.backend_name]
self._use_flashinfer = attn_backend == _Backend.FLASHINFER self._use_flashinfer = attn_backend == AttentionBackendEnum.FLASHINFER
self._use_pallas = attn_backend == _Backend.PALLAS self._use_pallas = attn_backend == AttentionBackendEnum.PALLAS
self.kv_cache_layout = get_kv_cache_layout() self.kv_cache_layout = get_kv_cache_layout()
self.host_buffer_kv_cache_layout = self.kv_cache_layout self.host_buffer_kv_cache_layout = self.kv_cache_layout
logger.debug("Detected attention backend %s", self.backend_name) logger.debug("Detected attention backend %s", self.backend_name)
......
...@@ -32,7 +32,7 @@ from pydantic.fields import FieldInfo ...@@ -32,7 +32,7 @@ from pydantic.fields import FieldInfo
from typing_extensions import TypeIs, deprecated from typing_extensions import TypeIs, deprecated
import vllm.envs as envs import vllm.envs as envs
from vllm.attention.backends.registry import _Backend from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.config import ( from vllm.config import (
CacheConfig, CacheConfig,
CompilationConfig, CompilationConfig,
...@@ -462,7 +462,7 @@ class EngineArgs: ...@@ -462,7 +462,7 @@ class EngineArgs:
MultiModalConfig.mm_shm_cache_max_object_size_mb MultiModalConfig.mm_shm_cache_max_object_size_mb
) )
mm_encoder_tp_mode: MMEncoderTPMode = MultiModalConfig.mm_encoder_tp_mode mm_encoder_tp_mode: MMEncoderTPMode = MultiModalConfig.mm_encoder_tp_mode
mm_encoder_attn_backend: _Backend | str | None = ( mm_encoder_attn_backend: AttentionBackendEnum | str | None = (
MultiModalConfig.mm_encoder_attn_backend MultiModalConfig.mm_encoder_attn_backend
) )
io_processor_plugin: str | None = None io_processor_plugin: str | None = None
......
...@@ -626,14 +626,14 @@ environment_variables: dict[str, Callable[[], Any]] = { ...@@ -626,14 +626,14 @@ environment_variables: dict[str, Callable[[], Any]] = {
# - "FLASH_ATTN_MLA": use FlashAttention for MLA # - "FLASH_ATTN_MLA": use FlashAttention for MLA
# - "FLASHINFER_MLA": use FlashInfer for MLA # - "FLASHINFER_MLA": use FlashInfer for MLA
# - "CUTLASS_MLA": use CUTLASS for MLA # - "CUTLASS_MLA": use CUTLASS for MLA
# All possible options loaded dynamically from _Backend enum # All possible options loaded dynamically from AttentionBackendEnum
"VLLM_ATTENTION_BACKEND": env_with_choices( "VLLM_ATTENTION_BACKEND": env_with_choices(
"VLLM_ATTENTION_BACKEND", "VLLM_ATTENTION_BACKEND",
None, None,
lambda: list( lambda: list(
__import__( __import__(
"vllm.attention.backends.registry", fromlist=["_Backend"] "vllm.attention.backends.registry", fromlist=["AttentionBackendEnum"]
)._Backend.__members__.keys() ).AttentionBackendEnum.__members__.keys()
), ),
), ),
# If set, vllm will use flashinfer sampler # If set, vllm will use flashinfer sampler
......
...@@ -9,7 +9,7 @@ import torch.nn.functional as F ...@@ -9,7 +9,7 @@ import torch.nn.functional as F
from torch.nn import LayerNorm from torch.nn import LayerNorm
from transformers.models.qwen2_vl import Qwen2VLProcessor from transformers.models.qwen2_vl import Qwen2VLProcessor
from vllm.attention.backends.registry import _Backend from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.attention.layer import ( from vllm.attention.layer import (
check_upstream_fa_availability, check_upstream_fa_availability,
maybe_get_vit_flash_attn_backend, maybe_get_vit_flash_attn_backend,
...@@ -256,7 +256,7 @@ class DotsVisionAttention(nn.Module): ...@@ -256,7 +256,7 @@ class DotsVisionAttention(nn.Module):
quant_config: QuantizationConfig | None = None, quant_config: QuantizationConfig | None = None,
prefix: str = "", prefix: str = "",
use_data_parallel: bool = False, use_data_parallel: bool = False,
attn_backend_override: _Backend | None = None, attn_backend_override: AttentionBackendEnum | None = None,
) -> None: ) -> None:
super().__init__() super().__init__()
...@@ -303,17 +303,17 @@ class DotsVisionAttention(nn.Module): ...@@ -303,17 +303,17 @@ class DotsVisionAttention(nn.Module):
) )
) )
if self.attn_backend not in { if self.attn_backend not in {
_Backend.FLASH_ATTN, AttentionBackendEnum.FLASH_ATTN,
_Backend.TORCH_SDPA, AttentionBackendEnum.TORCH_SDPA,
_Backend.XFORMERS, AttentionBackendEnum.XFORMERS,
_Backend.ROCM_AITER_FA, AttentionBackendEnum.ROCM_AITER_FA,
}: }:
raise RuntimeError( raise RuntimeError(
f"Unsupported vision attention backend: {self.attn_backend}" f"Unsupported vision attention backend: {self.attn_backend}"
) )
self.is_flash_attn_backend = self.attn_backend in { self.is_flash_attn_backend = self.attn_backend in {
_Backend.FLASH_ATTN, AttentionBackendEnum.FLASH_ATTN,
_Backend.ROCM_AITER_FA, AttentionBackendEnum.ROCM_AITER_FA,
} }
def forward( def forward(
...@@ -361,7 +361,7 @@ class DotsVisionAttention(nn.Module): ...@@ -361,7 +361,7 @@ class DotsVisionAttention(nn.Module):
self.num_attention_heads_per_partition, self.num_attention_heads_per_partition,
self.hidden_size_per_attention_head, self.hidden_size_per_attention_head,
) )
elif self.attn_backend == _Backend.TORCH_SDPA: elif self.attn_backend == AttentionBackendEnum.TORCH_SDPA:
outputs = [] outputs = []
for i in range(1, len(cu_seqlens)): for i in range(1, len(cu_seqlens)):
s = int(cu_seqlens[i - 1]) s = int(cu_seqlens[i - 1])
...@@ -373,7 +373,7 @@ class DotsVisionAttention(nn.Module): ...@@ -373,7 +373,7 @@ class DotsVisionAttention(nn.Module):
out_i = out_i.permute(0, 2, 1, 3) out_i = out_i.permute(0, 2, 1, 3)
outputs.append(out_i) outputs.append(out_i)
context_layer = torch.cat(outputs, dim=1) if outputs else q[:, :0] context_layer = torch.cat(outputs, dim=1) if outputs else q[:, :0]
elif self.attn_backend == _Backend.XFORMERS: elif self.attn_backend == AttentionBackendEnum.XFORMERS:
from xformers import ops as xops from xformers import ops as xops
from xformers.ops.fmha.attn_bias import BlockDiagonalMask from xformers.ops.fmha.attn_bias import BlockDiagonalMask
...@@ -514,7 +514,7 @@ class DotsVisionBlock(nn.Module): ...@@ -514,7 +514,7 @@ class DotsVisionBlock(nn.Module):
quant_config: QuantizationConfig | None = None, quant_config: QuantizationConfig | None = None,
prefix: str = "", prefix: str = "",
use_data_parallel: bool = False, use_data_parallel: bool = False,
attn_backend_override: _Backend | None = None, attn_backend_override: AttentionBackendEnum | None = None,
): ):
super().__init__() super().__init__()
...@@ -567,7 +567,7 @@ class DotsVisionTransformer(nn.Module): ...@@ -567,7 +567,7 @@ class DotsVisionTransformer(nn.Module):
require_post_norm: bool | None = None, require_post_norm: bool | None = None,
prefix: str = "", prefix: str = "",
use_data_parallel: bool = False, use_data_parallel: bool = False,
attn_backend_override: _Backend | None = None, attn_backend_override: AttentionBackendEnum | None = None,
) -> None: ) -> None:
super().__init__() super().__init__()
self.config = config self.config = config
...@@ -582,10 +582,11 @@ class DotsVisionTransformer(nn.Module): ...@@ -582,10 +582,11 @@ class DotsVisionTransformer(nn.Module):
dtype=torch.get_default_dtype(), dtype=torch.get_default_dtype(),
attn_backend_override=attn_backend_override, attn_backend_override=attn_backend_override,
) )
if self.attn_backend != _Backend.FLASH_ATTN and check_upstream_fa_availability( if (
torch.get_default_dtype() self.attn_backend != AttentionBackendEnum.FLASH_ATTN
and check_upstream_fa_availability(torch.get_default_dtype())
): ):
self.attn_backend = _Backend.FLASH_ATTN self.attn_backend = AttentionBackendEnum.FLASH_ATTN
self.out_hidden_size = config.hidden_size self.out_hidden_size = config.hidden_size
# Keep blocks for compatibility with other vision towers # Keep blocks for compatibility with other vision towers
num_layers = ( num_layers = (
...@@ -666,11 +667,11 @@ class DotsVisionTransformer(nn.Module): ...@@ -666,11 +667,11 @@ class DotsVisionTransformer(nn.Module):
) -> tuple[int | None, list[int] | None]: ) -> tuple[int | None, list[int] | None]:
max_seqlen, seqlens = None, None max_seqlen, seqlens = None, None
if ( if (
self.attn_backend == _Backend.FLASH_ATTN self.attn_backend == AttentionBackendEnum.FLASH_ATTN
or self.attn_backend == _Backend.ROCM_AITER_FA or self.attn_backend == AttentionBackendEnum.ROCM_AITER_FA
): ):
max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item() max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
elif self.attn_backend == _Backend.XFORMERS: elif self.attn_backend == AttentionBackendEnum.XFORMERS:
seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist() seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
return max_seqlen, seqlens return max_seqlen, seqlens
......
...@@ -36,7 +36,7 @@ import torch.nn.functional as F ...@@ -36,7 +36,7 @@ import torch.nn.functional as F
from einops import rearrange, repeat from einops import rearrange, repeat
from transformers import BatchFeature, PretrainedConfig from transformers import BatchFeature, PretrainedConfig
from vllm.attention.backends.registry import _Backend from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.attention.layer import ( from vllm.attention.layer import (
check_upstream_fa_availability, check_upstream_fa_availability,
maybe_get_vit_flash_attn_backend, maybe_get_vit_flash_attn_backend,
...@@ -164,7 +164,7 @@ class Ernie4_5_VisionAttention(nn.Module): ...@@ -164,7 +164,7 @@ class Ernie4_5_VisionAttention(nn.Module):
projection_size: int, projection_size: int,
quant_config: QuantizationConfig | None = None, quant_config: QuantizationConfig | None = None,
prefix: str = "", prefix: str = "",
attn_backend_override: _Backend | None = None, attn_backend_override: AttentionBackendEnum | None = None,
) -> None: ) -> None:
super().__init__() super().__init__()
# Per attention head and per partition values. # Per attention head and per partition values.
...@@ -211,17 +211,17 @@ class Ernie4_5_VisionAttention(nn.Module): ...@@ -211,17 +211,17 @@ class Ernie4_5_VisionAttention(nn.Module):
) )
if self.attn_backend not in { if self.attn_backend not in {
_Backend.FLASH_ATTN, AttentionBackendEnum.FLASH_ATTN,
_Backend.TORCH_SDPA, AttentionBackendEnum.TORCH_SDPA,
_Backend.XFORMERS, AttentionBackendEnum.XFORMERS,
_Backend.ROCM_AITER_FA, AttentionBackendEnum.ROCM_AITER_FA,
}: }:
raise RuntimeError( raise RuntimeError(
f"Ernie45-VL does not support {self.attn_backend} backend now." f"Ernie45-VL does not support {self.attn_backend} backend now."
) )
self.is_flash_attn_backend = self.attn_backend in { self.is_flash_attn_backend = self.attn_backend in {
_Backend.FLASH_ATTN, AttentionBackendEnum.FLASH_ATTN,
_Backend.ROCM_AITER_FA, AttentionBackendEnum.ROCM_AITER_FA,
} }
def split_qkv(self, qkv: torch.Tensor) -> tuple[torch.Tensor, ...]: def split_qkv(self, qkv: torch.Tensor) -> tuple[torch.Tensor, ...]:
...@@ -291,7 +291,7 @@ class Ernie4_5_VisionAttention(nn.Module): ...@@ -291,7 +291,7 @@ class Ernie4_5_VisionAttention(nn.Module):
context_layer = rearrange( context_layer = rearrange(
output, "(b s) h d -> s b (h d)", b=batch_size output, "(b s) h d -> s b (h d)", b=batch_size
).contiguous() ).contiguous()
elif self.attn_backend == _Backend.TORCH_SDPA: elif self.attn_backend == AttentionBackendEnum.TORCH_SDPA:
# Execute attention entry by entry for speed & less VRAM. # Execute attention entry by entry for speed & less VRAM.
outputs = [] outputs = []
for i in range(1, len(cu_seqlens)): for i in range(1, len(cu_seqlens)):
...@@ -310,7 +310,7 @@ class Ernie4_5_VisionAttention(nn.Module): ...@@ -310,7 +310,7 @@ class Ernie4_5_VisionAttention(nn.Module):
context_layer = rearrange( context_layer = rearrange(
context_layer, "b s h d -> s b (h d)" context_layer, "b s h d -> s b (h d)"
).contiguous() ).contiguous()
elif self.attn_backend == _Backend.XFORMERS: elif self.attn_backend == AttentionBackendEnum.XFORMERS:
from xformers import ops as xops from xformers import ops as xops
from xformers.ops.fmha.attn_bias import BlockDiagonalMask from xformers.ops.fmha.attn_bias import BlockDiagonalMask
...@@ -370,7 +370,7 @@ class Ernie4_5_VisionBlock(nn.Module): ...@@ -370,7 +370,7 @@ class Ernie4_5_VisionBlock(nn.Module):
norm_layer: Callable[[int], nn.Module] | None = None, norm_layer: Callable[[int], nn.Module] | None = None,
quant_config: QuantizationConfig | None = None, quant_config: QuantizationConfig | None = None,
prefix: str = "", prefix: str = "",
attn_backend_override: _Backend | None = None, attn_backend_override: AttentionBackendEnum | None = None,
) -> None: ) -> None:
super().__init__() super().__init__()
...@@ -463,7 +463,7 @@ class Ernie4_5_VisionTransformer(nn.Module): ...@@ -463,7 +463,7 @@ class Ernie4_5_VisionTransformer(nn.Module):
norm_eps: float = 1e-6, norm_eps: float = 1e-6,
quant_config: QuantizationConfig | None = None, quant_config: QuantizationConfig | None = None,
prefix: str = "", prefix: str = "",
attn_backend_override: _Backend | None = None, attn_backend_override: AttentionBackendEnum | None = None,
) -> None: ) -> None:
super().__init__() super().__init__()
patch_size = vision_config.patch_size patch_size = vision_config.patch_size
...@@ -515,10 +515,11 @@ class Ernie4_5_VisionTransformer(nn.Module): ...@@ -515,10 +515,11 @@ class Ernie4_5_VisionTransformer(nn.Module):
dtype=torch.get_default_dtype(), dtype=torch.get_default_dtype(),
attn_backend_override=attn_backend_override, attn_backend_override=attn_backend_override,
) )
if self.attn_backend != _Backend.FLASH_ATTN and check_upstream_fa_availability( if (
torch.get_default_dtype() self.attn_backend != AttentionBackendEnum.FLASH_ATTN
and check_upstream_fa_availability(torch.get_default_dtype())
): ):
self.attn_backend = _Backend.FLASH_ATTN self.attn_backend = AttentionBackendEnum.FLASH_ATTN
@property @property
def dtype(self) -> torch.dtype: def dtype(self) -> torch.dtype:
...@@ -565,11 +566,11 @@ class Ernie4_5_VisionTransformer(nn.Module): ...@@ -565,11 +566,11 @@ class Ernie4_5_VisionTransformer(nn.Module):
) -> tuple[int | None, list[int] | None]: ) -> tuple[int | None, list[int] | None]:
max_seqlen, seqlens = None, None max_seqlen, seqlens = None, None
if ( if (
self.attn_backend == _Backend.FLASH_ATTN self.attn_backend == AttentionBackendEnum.FLASH_ATTN
or self.attn_backend == _Backend.ROCM_AITER_FA or self.attn_backend == AttentionBackendEnum.ROCM_AITER_FA
): ):
max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item() max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
elif self.attn_backend == _Backend.XFORMERS: elif self.attn_backend == AttentionBackendEnum.XFORMERS:
seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist() seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
return max_seqlen, seqlens return max_seqlen, seqlens
......
...@@ -46,7 +46,7 @@ from transformers.models.glm4v.image_processing_glm4v import ( ...@@ -46,7 +46,7 @@ from transformers.models.glm4v.image_processing_glm4v import (
from transformers.models.glm4v.video_processing_glm4v import Glm4vVideoProcessor from transformers.models.glm4v.video_processing_glm4v import Glm4vVideoProcessor
from transformers.video_utils import VideoMetadata from transformers.video_utils import VideoMetadata
from vllm.attention.backends.registry import _Backend from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.attention.layer import ( from vllm.attention.layer import (
check_upstream_fa_availability, check_upstream_fa_availability,
maybe_get_vit_flash_attn_backend, maybe_get_vit_flash_attn_backend,
...@@ -252,7 +252,7 @@ class Glm4vVisionAttention(nn.Module): ...@@ -252,7 +252,7 @@ class Glm4vVisionAttention(nn.Module):
quant_config: QuantizationConfig | None = None, quant_config: QuantizationConfig | None = None,
prefix: str = "", prefix: str = "",
use_data_parallel: bool = False, use_data_parallel: bool = False,
attn_backend_override: _Backend | None = None, attn_backend_override: AttentionBackendEnum | None = None,
) -> None: ) -> None:
super().__init__() super().__init__()
# Per attention head and per partition values. # Per attention head and per partition values.
...@@ -306,18 +306,18 @@ class Glm4vVisionAttention(nn.Module): ...@@ -306,18 +306,18 @@ class Glm4vVisionAttention(nn.Module):
) )
if self.attn_backend not in { if self.attn_backend not in {
_Backend.FLASH_ATTN, AttentionBackendEnum.FLASH_ATTN,
_Backend.TORCH_SDPA, AttentionBackendEnum.TORCH_SDPA,
_Backend.XFORMERS, AttentionBackendEnum.XFORMERS,
_Backend.ROCM_AITER_FA, AttentionBackendEnum.ROCM_AITER_FA,
}: }:
raise RuntimeError( raise RuntimeError(
f"GLM-4V does not support {self.attn_backend} backend now." f"GLM-4V does not support {self.attn_backend} backend now."
) )
self.is_flash_attn_backend = self.attn_backend in { self.is_flash_attn_backend = self.attn_backend in {
_Backend.FLASH_ATTN, AttentionBackendEnum.FLASH_ATTN,
_Backend.ROCM_AITER_FA, AttentionBackendEnum.ROCM_AITER_FA,
} }
def split_qkv(self, qkv: torch.Tensor) -> tuple[torch.Tensor, ...]: def split_qkv(self, qkv: torch.Tensor) -> tuple[torch.Tensor, ...]:
...@@ -377,7 +377,7 @@ class Glm4vVisionAttention(nn.Module): ...@@ -377,7 +377,7 @@ class Glm4vVisionAttention(nn.Module):
context_layer = rearrange( context_layer = rearrange(
output, "(b s) h d -> s b (h d)", b=batch_size output, "(b s) h d -> s b (h d)", b=batch_size
).contiguous() ).contiguous()
elif self.attn_backend == _Backend.TORCH_SDPA: elif self.attn_backend == AttentionBackendEnum.TORCH_SDPA:
# Execute attention entry by entry for speed & less VRAM. # Execute attention entry by entry for speed & less VRAM.
outputs = [] outputs = []
for i in range(1, len(cu_seqlens)): for i in range(1, len(cu_seqlens)):
...@@ -396,7 +396,7 @@ class Glm4vVisionAttention(nn.Module): ...@@ -396,7 +396,7 @@ class Glm4vVisionAttention(nn.Module):
context_layer = rearrange( context_layer = rearrange(
context_layer, "b s h d -> s b (h d)" context_layer, "b s h d -> s b (h d)"
).contiguous() ).contiguous()
elif self.attn_backend == _Backend.XFORMERS: elif self.attn_backend == AttentionBackendEnum.XFORMERS:
from xformers import ops as xops from xformers import ops as xops
from xformers.ops.fmha.attn_bias import BlockDiagonalMask from xformers.ops.fmha.attn_bias import BlockDiagonalMask
...@@ -425,7 +425,7 @@ class Glm4vVisionBlock(nn.Module): ...@@ -425,7 +425,7 @@ class Glm4vVisionBlock(nn.Module):
quant_config: QuantizationConfig | None = None, quant_config: QuantizationConfig | None = None,
prefix: str = "", prefix: str = "",
use_data_parallel: bool = False, use_data_parallel: bool = False,
attn_backend_override: _Backend | None = None, attn_backend_override: AttentionBackendEnum | None = None,
) -> None: ) -> None:
super().__init__() super().__init__()
if norm_layer is None: if norm_layer is None:
...@@ -703,7 +703,7 @@ class Glm4vVisionTransformer(nn.Module): ...@@ -703,7 +703,7 @@ class Glm4vVisionTransformer(nn.Module):
quant_config: QuantizationConfig | None = None, quant_config: QuantizationConfig | None = None,
prefix: str = "", prefix: str = "",
use_data_parallel: bool = False, use_data_parallel: bool = False,
attn_backend_override: _Backend | None = None, attn_backend_override: AttentionBackendEnum | None = None,
) -> None: ) -> None:
super().__init__() super().__init__()
...@@ -772,10 +772,11 @@ class Glm4vVisionTransformer(nn.Module): ...@@ -772,10 +772,11 @@ class Glm4vVisionTransformer(nn.Module):
dtype=torch.get_default_dtype(), dtype=torch.get_default_dtype(),
attn_backend_override=attn_backend_override, attn_backend_override=attn_backend_override,
) )
if self.attn_backend != _Backend.FLASH_ATTN and check_upstream_fa_availability( if (
torch.get_default_dtype() self.attn_backend != AttentionBackendEnum.FLASH_ATTN
and check_upstream_fa_availability(torch.get_default_dtype())
): ):
self.attn_backend = _Backend.FLASH_ATTN self.attn_backend = AttentionBackendEnum.FLASH_ATTN
@property @property
def dtype(self) -> torch.dtype: def dtype(self) -> torch.dtype:
...@@ -824,8 +825,8 @@ class Glm4vVisionTransformer(nn.Module): ...@@ -824,8 +825,8 @@ class Glm4vVisionTransformer(nn.Module):
max_seqlen, seqlens = None, None max_seqlen, seqlens = None, None
seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist() seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
if ( if (
self.attn_backend == _Backend.FLASH_ATTN self.attn_backend == AttentionBackendEnum.FLASH_ATTN
or self.attn_backend == _Backend.ROCM_AITER_FA or self.attn_backend == AttentionBackendEnum.ROCM_AITER_FA
): ):
max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item() max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
return max_seqlen, seqlens return max_seqlen, seqlens
......
...@@ -16,7 +16,7 @@ from transformers.feature_extraction_utils import BatchFeature ...@@ -16,7 +16,7 @@ from transformers.feature_extraction_utils import BatchFeature
from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling
from transformers.utils import torch_int from transformers.utils import torch_int
from vllm.attention.backends.registry import _Backend from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.attention.layer import ( from vllm.attention.layer import (
maybe_get_vit_flash_attn_backend, maybe_get_vit_flash_attn_backend,
) )
...@@ -360,7 +360,7 @@ class KeyeSiglipAttention(nn.Module): ...@@ -360,7 +360,7 @@ class KeyeSiglipAttention(nn.Module):
config: PretrainedConfig, config: PretrainedConfig,
quant_config: QuantizationConfig | None = None, quant_config: QuantizationConfig | None = None,
prefix: str = "", prefix: str = "",
attn_backend_override: _Backend | None = None, attn_backend_override: AttentionBackendEnum | None = None,
): ):
super().__init__() super().__init__()
self.config = config self.config = config
...@@ -414,17 +414,17 @@ class KeyeSiglipAttention(nn.Module): ...@@ -414,17 +414,17 @@ class KeyeSiglipAttention(nn.Module):
) )
if self.attn_backend not in { if self.attn_backend not in {
_Backend.FLASH_ATTN, AttentionBackendEnum.FLASH_ATTN,
_Backend.XFORMERS, AttentionBackendEnum.XFORMERS,
_Backend.ROCM_AITER_FA, AttentionBackendEnum.ROCM_AITER_FA,
}: }:
raise RuntimeError( raise RuntimeError(
f"Keye-VL does not support {self.attn_backend} backend now." f"Keye-VL does not support {self.attn_backend} backend now."
) )
self.is_flash_attn_backend = self.attn_backend in { self.is_flash_attn_backend = self.attn_backend in {
_Backend.FLASH_ATTN, AttentionBackendEnum.FLASH_ATTN,
_Backend.ROCM_AITER_FA, AttentionBackendEnum.ROCM_AITER_FA,
} }
def forward( def forward(
...@@ -489,7 +489,7 @@ class KeyeSiglipAttention(nn.Module): ...@@ -489,7 +489,7 @@ class KeyeSiglipAttention(nn.Module):
softmax_scale=self.scale, softmax_scale=self.scale,
) )
context_layer = rearrange(output, "(b s) ... -> b s ...", b=batch_size) context_layer = rearrange(output, "(b s) ... -> b s ...", b=batch_size)
elif self.attn_backend == _Backend.XFORMERS: elif self.attn_backend == AttentionBackendEnum.XFORMERS:
from xformers import ops as xops from xformers import ops as xops
from xformers.ops.fmha.attn_bias import BlockDiagonalMask from xformers.ops.fmha.attn_bias import BlockDiagonalMask
...@@ -536,7 +536,7 @@ class KeyeSiglipEncoderLayer(nn.Module): ...@@ -536,7 +536,7 @@ class KeyeSiglipEncoderLayer(nn.Module):
config: PretrainedConfig, config: PretrainedConfig,
quant_config: QuantizationConfig | None = None, quant_config: QuantizationConfig | None = None,
prefix: str = "", prefix: str = "",
attn_backend_override: _Backend | None = None, attn_backend_override: AttentionBackendEnum | None = None,
): ):
super().__init__() super().__init__()
self.embed_dim = config.hidden_size self.embed_dim = config.hidden_size
...@@ -590,7 +590,7 @@ class KeyeSiglipEncoder(nn.Module): ...@@ -590,7 +590,7 @@ class KeyeSiglipEncoder(nn.Module):
config: PretrainedConfig, config: PretrainedConfig,
quant_config: QuantizationConfig | None = None, quant_config: QuantizationConfig | None = None,
prefix: str = "", prefix: str = "",
attn_backend_override: _Backend | None = None, attn_backend_override: AttentionBackendEnum | None = None,
): ):
super().__init__() super().__init__()
self.config = config self.config = config
...@@ -685,7 +685,7 @@ class KeyeSiglipVisionTransformer(nn.Module): ...@@ -685,7 +685,7 @@ class KeyeSiglipVisionTransformer(nn.Module):
config: PretrainedConfig, config: PretrainedConfig,
quant_config: QuantizationConfig | None = None, quant_config: QuantizationConfig | None = None,
prefix: str = "", prefix: str = "",
attn_backend_override: _Backend | None = None, attn_backend_override: AttentionBackendEnum | None = None,
): ):
super().__init__() super().__init__()
self.config = config self.config = config
...@@ -768,7 +768,7 @@ class KeyeSiglipVisionModel(nn.Module): ...@@ -768,7 +768,7 @@ class KeyeSiglipVisionModel(nn.Module):
config: PretrainedConfig, config: PretrainedConfig,
quant_config: QuantizationConfig | None = None, quant_config: QuantizationConfig | None = None,
prefix: str = "", prefix: str = "",
attn_backend_override: _Backend | None = None, attn_backend_override: AttentionBackendEnum | None = None,
): ):
super().__init__() super().__init__()
......
...@@ -10,7 +10,7 @@ import torch ...@@ -10,7 +10,7 @@ import torch
import torch.nn as nn import torch.nn as nn
from transformers import BaseImageProcessor, BatchFeature, PretrainedConfig from transformers import BaseImageProcessor, BatchFeature, PretrainedConfig
from vllm.attention.backends.registry import _Backend from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.config.multimodal import BaseDummyOptions from vllm.config.multimodal import BaseDummyOptions
from vllm.model_executor.layers.linear import ReplicatedLinear from vllm.model_executor.layers.linear import ReplicatedLinear
...@@ -106,7 +106,7 @@ class VisualTokenizer(torch.nn.Module): ...@@ -106,7 +106,7 @@ class VisualTokenizer(torch.nn.Module):
quant_config: QuantizationConfig | None = None, quant_config: QuantizationConfig | None = None,
prefix: str = "", prefix: str = "",
use_data_parallel: bool = False, use_data_parallel: bool = False,
attn_backend_override: _Backend | None = None, attn_backend_override: AttentionBackendEnum | None = None,
): ):
super().__init__() super().__init__()
self.config = config self.config = config
...@@ -135,7 +135,7 @@ class VisualTokenizer(torch.nn.Module): ...@@ -135,7 +135,7 @@ class VisualTokenizer(torch.nn.Module):
quant_config: QuantizationConfig | None = None, quant_config: QuantizationConfig | None = None,
prefix: str = "", prefix: str = "",
use_data_parallel: bool = False, use_data_parallel: bool = False,
attn_backend_override: _Backend | None = None, attn_backend_override: AttentionBackendEnum | None = None,
): ):
model_type = config.model_type model_type = config.model_type
if model_type == "siglip2_navit": if model_type == "siglip2_navit":
......
...@@ -31,7 +31,7 @@ from transformers.modeling_outputs import ( ...@@ -31,7 +31,7 @@ from transformers.modeling_outputs import (
) )
from transformers.utils import torch_int from transformers.utils import torch_int
from vllm.attention.backends.registry import _Backend from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.attention.layer import ( from vllm.attention.layer import (
check_upstream_fa_availability, check_upstream_fa_availability,
maybe_get_vit_flash_attn_backend, maybe_get_vit_flash_attn_backend,
...@@ -580,8 +580,8 @@ class SiglipAttention(nn.Module): ...@@ -580,8 +580,8 @@ class SiglipAttention(nn.Module):
projection_size: int, projection_size: int,
quant_config: QuantizationConfig | None = None, quant_config: QuantizationConfig | None = None,
prefix: str = "", prefix: str = "",
attn_backend: _Backend = _Backend.TORCH_SDPA, attn_backend: AttentionBackendEnum = AttentionBackendEnum.TORCH_SDPA,
attn_backend_override: _Backend | None = None, attn_backend_override: AttentionBackendEnum | None = None,
use_upstream_fa: bool = False, use_upstream_fa: bool = False,
) -> None: ) -> None:
super().__init__() super().__init__()
...@@ -621,8 +621,8 @@ class SiglipAttention(nn.Module): ...@@ -621,8 +621,8 @@ class SiglipAttention(nn.Module):
) )
) )
self.is_flash_attn_backend = self.attn_backend in { self.is_flash_attn_backend = self.attn_backend in {
_Backend.FLASH_ATTN, AttentionBackendEnum.FLASH_ATTN,
_Backend.ROCM_AITER_FA, AttentionBackendEnum.ROCM_AITER_FA,
} }
def split_qkv(self, qkv: torch.Tensor) -> tuple[torch.Tensor, ...]: def split_qkv(self, qkv: torch.Tensor) -> tuple[torch.Tensor, ...]:
...@@ -680,10 +680,10 @@ class SiglipAttention(nn.Module): ...@@ -680,10 +680,10 @@ class SiglipAttention(nn.Module):
cu_seqlens, cu_seqlens,
max_seqlen, max_seqlen,
batch_size, batch_size,
self.attn_backend == _Backend.ROCM_AITER_FA, self.attn_backend == AttentionBackendEnum.ROCM_AITER_FA,
self.use_upstream_fa, self.use_upstream_fa,
) )
elif self.attn_backend == _Backend.TORCH_SDPA: elif self.attn_backend == AttentionBackendEnum.TORCH_SDPA:
outputs = [] outputs = []
for i in range(1, len(cu_seqlens)): for i in range(1, len(cu_seqlens)):
start_idx = cu_seqlens[i - 1] start_idx = cu_seqlens[i - 1]
...@@ -702,7 +702,7 @@ class SiglipAttention(nn.Module): ...@@ -702,7 +702,7 @@ class SiglipAttention(nn.Module):
context_layer = rearrange( context_layer = rearrange(
context_layer, "b s h d -> s b (h d)" context_layer, "b s h d -> s b (h d)"
).contiguous() ).contiguous()
elif self.attn_backend == _Backend.XFORMERS: elif self.attn_backend == AttentionBackendEnum.XFORMERS:
if seqlens is None: if seqlens is None:
raise ValueError("xFormers attention backend requires seqlens tensor.") raise ValueError("xFormers attention backend requires seqlens tensor.")
context_layer = vit_xformers_attn_wrapper(q, k, v, seqlens) context_layer = vit_xformers_attn_wrapper(q, k, v, seqlens)
...@@ -786,8 +786,8 @@ class SiglipEncoderLayer(nn.Module): ...@@ -786,8 +786,8 @@ class SiglipEncoderLayer(nn.Module):
quant_config: QuantizationConfig | None = None, quant_config: QuantizationConfig | None = None,
prefix: str = "", prefix: str = "",
*, *,
attn_backend: _Backend = _Backend.TORCH_SDPA, attn_backend: AttentionBackendEnum = AttentionBackendEnum.TORCH_SDPA,
attn_backend_override: _Backend | None = None, attn_backend_override: AttentionBackendEnum | None = None,
use_upstream_fa: bool = False, use_upstream_fa: bool = False,
): ):
super().__init__() super().__init__()
...@@ -847,7 +847,7 @@ class SiglipEncoder(nn.Module): ...@@ -847,7 +847,7 @@ class SiglipEncoder(nn.Module):
config: PretrainedConfig, config: PretrainedConfig,
quant_config: QuantizationConfig | None = None, quant_config: QuantizationConfig | None = None,
prefix: str = "", prefix: str = "",
attn_backend_override: _Backend | None = None, attn_backend_override: AttentionBackendEnum | None = None,
): ):
super().__init__() super().__init__()
self.config = config self.config = config
...@@ -861,16 +861,16 @@ class SiglipEncoder(nn.Module): ...@@ -861,16 +861,16 @@ class SiglipEncoder(nn.Module):
) )
self.use_upstream_fa = False self.use_upstream_fa = False
if self.attn_backend not in { if self.attn_backend not in {
_Backend.FLASH_ATTN, AttentionBackendEnum.FLASH_ATTN,
_Backend.ROCM_AITER_FA, AttentionBackendEnum.ROCM_AITER_FA,
} and check_upstream_fa_availability(torch.get_default_dtype()): } and check_upstream_fa_availability(torch.get_default_dtype()):
self.attn_backend = _Backend.FLASH_ATTN self.attn_backend = AttentionBackendEnum.FLASH_ATTN
self.use_upstream_fa = True self.use_upstream_fa = True
if self.attn_backend not in { if self.attn_backend not in {
_Backend.FLASH_ATTN, AttentionBackendEnum.FLASH_ATTN,
_Backend.TORCH_SDPA, AttentionBackendEnum.TORCH_SDPA,
_Backend.XFORMERS, AttentionBackendEnum.XFORMERS,
_Backend.ROCM_AITER_FA, AttentionBackendEnum.ROCM_AITER_FA,
}: }:
raise RuntimeError( raise RuntimeError(
f"PaddleOCR-VL does not support {self.attn_backend} backend now." f"PaddleOCR-VL does not support {self.attn_backend} backend now."
...@@ -943,9 +943,12 @@ class SiglipEncoder(nn.Module): ...@@ -943,9 +943,12 @@ class SiglipEncoder(nn.Module):
max_seqlen = None max_seqlen = None
seqlens = None seqlens = None
if self.attn_backend in {_Backend.FLASH_ATTN, _Backend.ROCM_AITER_FA}: if self.attn_backend in {
AttentionBackendEnum.FLASH_ATTN,
AttentionBackendEnum.ROCM_AITER_FA,
}:
max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max() max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max()
elif self.attn_backend == _Backend.XFORMERS: elif self.attn_backend == AttentionBackendEnum.XFORMERS:
seqlens = cu_seqlens[1:] - cu_seqlens[:-1] seqlens = cu_seqlens[1:] - cu_seqlens[:-1]
hidden_states = inputs_embeds hidden_states = inputs_embeds
...@@ -966,7 +969,7 @@ class SiglipVisionTransformer(nn.Module): ...@@ -966,7 +969,7 @@ class SiglipVisionTransformer(nn.Module):
config: PretrainedConfig, config: PretrainedConfig,
quant_config: QuantizationConfig | None = None, quant_config: QuantizationConfig | None = None,
prefix: str = "", prefix: str = "",
attn_backend_override: _Backend | None = None, attn_backend_override: AttentionBackendEnum | None = None,
): ):
super().__init__() super().__init__()
self.config = config self.config = config
...@@ -1016,7 +1019,7 @@ class SiglipVisionModel(nn.Module): ...@@ -1016,7 +1019,7 @@ class SiglipVisionModel(nn.Module):
config, config,
quant_config: QuantizationConfig | None = None, quant_config: QuantizationConfig | None = None,
prefix: str = "", prefix: str = "",
attn_backend_override: _Backend | None = None, attn_backend_override: AttentionBackendEnum | None = None,
): ):
super().__init__() super().__init__()
......
...@@ -42,7 +42,7 @@ from transformers.models.qwen2_5_vl.configuration_qwen2_5_vl import ( ...@@ -42,7 +42,7 @@ from transformers.models.qwen2_5_vl.configuration_qwen2_5_vl import (
Qwen2_5_VLVisionConfig, Qwen2_5_VLVisionConfig,
) )
from vllm.attention.backends.registry import _Backend from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.attention.layer import maybe_get_vit_flash_attn_backend from vllm.attention.layer import maybe_get_vit_flash_attn_backend
from vllm.attention.ops.vit_attn_wrappers import ( from vllm.attention.ops.vit_attn_wrappers import (
vit_flash_attn_wrapper, vit_flash_attn_wrapper,
...@@ -315,9 +315,9 @@ class Qwen2_5_VisionAttention(nn.Module): ...@@ -315,9 +315,9 @@ class Qwen2_5_VisionAttention(nn.Module):
quant_config: QuantizationConfig | None = None, quant_config: QuantizationConfig | None = None,
prefix: str = "", prefix: str = "",
use_data_parallel: bool = False, use_data_parallel: bool = False,
attn_backend: _Backend = _Backend.TORCH_SDPA, attn_backend: AttentionBackendEnum = AttentionBackendEnum.TORCH_SDPA,
use_upstream_fa: bool = False, use_upstream_fa: bool = False,
attn_backend_override: _Backend | None = None, attn_backend_override: AttentionBackendEnum | None = None,
) -> None: ) -> None:
super().__init__() super().__init__()
# Per attention head and per partition values. # Per attention head and per partition values.
...@@ -364,13 +364,16 @@ class Qwen2_5_VisionAttention(nn.Module): ...@@ -364,13 +364,16 @@ class Qwen2_5_VisionAttention(nn.Module):
# On ROCm with FLASH_ATTN backend, upstream flash_attn is used # On ROCm with FLASH_ATTN backend, upstream flash_attn is used
from vllm.platforms import current_platform from vllm.platforms import current_platform
if current_platform.is_rocm() and self.attn_backend == _Backend.FLASH_ATTN: if (
current_platform.is_rocm()
and self.attn_backend == AttentionBackendEnum.FLASH_ATTN
):
self.use_upstream_fa = True self.use_upstream_fa = True
if current_platform.is_xpu(): if current_platform.is_xpu():
self.use_upstream_fa = False self.use_upstream_fa = False
self.is_flash_attn_backend = self.attn_backend in { self.is_flash_attn_backend = self.attn_backend in {
_Backend.FLASH_ATTN, AttentionBackendEnum.FLASH_ATTN,
_Backend.ROCM_AITER_FA, AttentionBackendEnum.ROCM_AITER_FA,
} }
def split_qkv(self, qkv: torch.Tensor) -> tuple[torch.Tensor, ...]: def split_qkv(self, qkv: torch.Tensor) -> tuple[torch.Tensor, ...]:
...@@ -431,10 +434,10 @@ class Qwen2_5_VisionAttention(nn.Module): ...@@ -431,10 +434,10 @@ class Qwen2_5_VisionAttention(nn.Module):
cu_seqlens, cu_seqlens,
max_seqlen, max_seqlen,
batch_size, batch_size,
self.attn_backend == _Backend.ROCM_AITER_FA, self.attn_backend == AttentionBackendEnum.ROCM_AITER_FA,
self.use_upstream_fa, self.use_upstream_fa,
) )
elif self.attn_backend == _Backend.TORCH_SDPA: elif self.attn_backend == AttentionBackendEnum.TORCH_SDPA:
# Execute attention entry by entry for speed & less VRAM. # Execute attention entry by entry for speed & less VRAM.
from vllm.platforms import current_platform from vllm.platforms import current_platform
...@@ -450,7 +453,7 @@ class Qwen2_5_VisionAttention(nn.Module): ...@@ -450,7 +453,7 @@ class Qwen2_5_VisionAttention(nn.Module):
v, v,
cu_seqlens, cu_seqlens,
) )
elif self.attn_backend == _Backend.XFORMERS: elif self.attn_backend == AttentionBackendEnum.XFORMERS:
context_layer = vit_xformers_attn_wrapper(q, k, v, seqlens) context_layer = vit_xformers_attn_wrapper(q, k, v, seqlens)
output, _ = self.proj(context_layer) output, _ = self.proj(context_layer)
...@@ -478,9 +481,9 @@ class Qwen2_5_VisionBlock(nn.Module): ...@@ -478,9 +481,9 @@ class Qwen2_5_VisionBlock(nn.Module):
quant_config: QuantizationConfig | None = None, quant_config: QuantizationConfig | None = None,
prefix: str = "", prefix: str = "",
use_data_parallel: bool = False, use_data_parallel: bool = False,
attn_backend: _Backend = _Backend.TORCH_SDPA, attn_backend: AttentionBackendEnum = AttentionBackendEnum.TORCH_SDPA,
use_upstream_fa: bool = False, use_upstream_fa: bool = False,
attn_backend_override: _Backend | None = None, attn_backend_override: AttentionBackendEnum | None = None,
) -> None: ) -> None:
super().__init__() super().__init__()
if norm_layer is None: if norm_layer is None:
...@@ -656,7 +659,7 @@ class Qwen2_5_VisionTransformer(nn.Module): ...@@ -656,7 +659,7 @@ class Qwen2_5_VisionTransformer(nn.Module):
quant_config: QuantizationConfig | None = None, quant_config: QuantizationConfig | None = None,
prefix: str = "", prefix: str = "",
use_data_parallel: bool = False, use_data_parallel: bool = False,
attn_backend_override: _Backend | None = None, attn_backend_override: AttentionBackendEnum | None = None,
) -> None: ) -> None:
super().__init__() super().__init__()
...@@ -708,10 +711,10 @@ class Qwen2_5_VisionTransformer(nn.Module): ...@@ -708,10 +711,10 @@ class Qwen2_5_VisionTransformer(nn.Module):
) )
if self.attn_backend not in { if self.attn_backend not in {
_Backend.FLASH_ATTN, AttentionBackendEnum.FLASH_ATTN,
_Backend.TORCH_SDPA, AttentionBackendEnum.TORCH_SDPA,
_Backend.XFORMERS, AttentionBackendEnum.XFORMERS,
_Backend.ROCM_AITER_FA, AttentionBackendEnum.ROCM_AITER_FA,
}: }:
raise RuntimeError( raise RuntimeError(
f"Qwen2.5-VL does not support {self.attn_backend} backend now." f"Qwen2.5-VL does not support {self.attn_backend} backend now."
...@@ -850,9 +853,12 @@ class Qwen2_5_VisionTransformer(nn.Module): ...@@ -850,9 +853,12 @@ class Qwen2_5_VisionTransformer(nn.Module):
) -> tuple[torch.Tensor, torch.Tensor]: ) -> tuple[torch.Tensor, torch.Tensor]:
max_seqlen = torch.zeros([], device=cu_seqlens.device) max_seqlen = torch.zeros([], device=cu_seqlens.device)
seqlens = torch.zeros(1, device=cu_seqlens.device) seqlens = torch.zeros(1, device=cu_seqlens.device)
if self.attn_backend in {_Backend.FLASH_ATTN, _Backend.ROCM_AITER_FA}: if self.attn_backend in {
AttentionBackendEnum.FLASH_ATTN,
AttentionBackendEnum.ROCM_AITER_FA,
}:
max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max() max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max()
elif self.attn_backend == _Backend.XFORMERS: elif self.attn_backend == AttentionBackendEnum.XFORMERS:
seqlens = cu_seqlens[1:] - cu_seqlens[:-1] seqlens = cu_seqlens[1:] - cu_seqlens[:-1]
return max_seqlen, seqlens return max_seqlen, seqlens
......
...@@ -43,7 +43,7 @@ from transformers.models.qwen2_vl.configuration_qwen2_vl import ( ...@@ -43,7 +43,7 @@ from transformers.models.qwen2_vl.configuration_qwen2_vl import (
from transformers.models.qwen2_vl.image_processing_qwen2_vl import smart_resize from transformers.models.qwen2_vl.image_processing_qwen2_vl import smart_resize
from transformers.models.qwen2_vl.video_processing_qwen2_vl import Qwen2VLVideoProcessor from transformers.models.qwen2_vl.video_processing_qwen2_vl import Qwen2VLVideoProcessor
from vllm.attention.backends.registry import _Backend from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.attention.layer import ( from vllm.attention.layer import (
check_upstream_fa_availability, check_upstream_fa_availability,
maybe_get_vit_flash_attn_backend, maybe_get_vit_flash_attn_backend,
...@@ -329,7 +329,7 @@ class Qwen2VisionAttention(nn.Module): ...@@ -329,7 +329,7 @@ class Qwen2VisionAttention(nn.Module):
quant_config: QuantizationConfig | None = None, quant_config: QuantizationConfig | None = None,
prefix: str = "", prefix: str = "",
use_data_parallel: bool = False, use_data_parallel: bool = False,
attn_backend_override: _Backend | None = None, attn_backend_override: AttentionBackendEnum | None = None,
) -> None: ) -> None:
super().__init__() super().__init__()
# Per attention head and per partition values. # Per attention head and per partition values.
...@@ -378,18 +378,18 @@ class Qwen2VisionAttention(nn.Module): ...@@ -378,18 +378,18 @@ class Qwen2VisionAttention(nn.Module):
) )
if self.attn_backend not in { if self.attn_backend not in {
_Backend.FLASH_ATTN, AttentionBackendEnum.FLASH_ATTN,
_Backend.TORCH_SDPA, AttentionBackendEnum.TORCH_SDPA,
_Backend.XFORMERS, AttentionBackendEnum.XFORMERS,
_Backend.ROCM_AITER_FA, AttentionBackendEnum.ROCM_AITER_FA,
}: }:
raise RuntimeError( raise RuntimeError(
f"Qwen2-VL does not support {self.attn_backend} backend now." f"Qwen2-VL does not support {self.attn_backend} backend now."
) )
self.is_flash_attn_backend = self.attn_backend in { self.is_flash_attn_backend = self.attn_backend in {
_Backend.FLASH_ATTN, AttentionBackendEnum.FLASH_ATTN,
_Backend.ROCM_AITER_FA, AttentionBackendEnum.ROCM_AITER_FA,
} }
def split_qkv(self, qkv: torch.Tensor) -> tuple[torch.Tensor, ...]: def split_qkv(self, qkv: torch.Tensor) -> tuple[torch.Tensor, ...]:
...@@ -460,7 +460,7 @@ class Qwen2VisionAttention(nn.Module): ...@@ -460,7 +460,7 @@ class Qwen2VisionAttention(nn.Module):
context_layer = rearrange( context_layer = rearrange(
output, "(b s) h d -> s b (h d)", b=batch_size output, "(b s) h d -> s b (h d)", b=batch_size
).contiguous() ).contiguous()
elif self.attn_backend == _Backend.TORCH_SDPA: elif self.attn_backend == AttentionBackendEnum.TORCH_SDPA:
# Execute attention entry by entry for speed & less VRAM. # Execute attention entry by entry for speed & less VRAM.
from vllm.platforms import current_platform from vllm.platforms import current_platform
...@@ -485,7 +485,7 @@ class Qwen2VisionAttention(nn.Module): ...@@ -485,7 +485,7 @@ class Qwen2VisionAttention(nn.Module):
context_layer = rearrange( context_layer = rearrange(
context_layer, "b s h d -> s b (h d)" context_layer, "b s h d -> s b (h d)"
).contiguous() ).contiguous()
elif self.attn_backend == _Backend.XFORMERS: elif self.attn_backend == AttentionBackendEnum.XFORMERS:
from xformers import ops as xops from xformers import ops as xops
from xformers.ops.fmha.attn_bias import BlockDiagonalMask from xformers.ops.fmha.attn_bias import BlockDiagonalMask
...@@ -515,7 +515,7 @@ class Qwen2VisionBlock(nn.Module): ...@@ -515,7 +515,7 @@ class Qwen2VisionBlock(nn.Module):
quant_config: QuantizationConfig | None = None, quant_config: QuantizationConfig | None = None,
prefix: str = "", prefix: str = "",
use_data_parallel: bool = False, use_data_parallel: bool = False,
attn_backend_override: _Backend | None = None, attn_backend_override: AttentionBackendEnum | None = None,
) -> None: ) -> None:
super().__init__() super().__init__()
if norm_layer is None: if norm_layer is None:
...@@ -679,7 +679,7 @@ class Qwen2VisionTransformer(nn.Module): ...@@ -679,7 +679,7 @@ class Qwen2VisionTransformer(nn.Module):
quant_config: QuantizationConfig | None = None, quant_config: QuantizationConfig | None = None,
prefix: str = "", prefix: str = "",
use_data_parallel: bool = False, use_data_parallel: bool = False,
attn_backend_override: _Backend | None = None, attn_backend_override: AttentionBackendEnum | None = None,
) -> None: ) -> None:
super().__init__() super().__init__()
...@@ -739,10 +739,11 @@ class Qwen2VisionTransformer(nn.Module): ...@@ -739,10 +739,11 @@ class Qwen2VisionTransformer(nn.Module):
dtype=torch.get_default_dtype(), dtype=torch.get_default_dtype(),
attn_backend_override=attn_backend_override, attn_backend_override=attn_backend_override,
) )
if self.attn_backend != _Backend.FLASH_ATTN and check_upstream_fa_availability( if (
torch.get_default_dtype() self.attn_backend != AttentionBackendEnum.FLASH_ATTN
and check_upstream_fa_availability(torch.get_default_dtype())
): ):
self.attn_backend = _Backend.FLASH_ATTN self.attn_backend = AttentionBackendEnum.FLASH_ATTN
@property @property
def dtype(self) -> torch.dtype: def dtype(self) -> torch.dtype:
...@@ -789,9 +790,12 @@ class Qwen2VisionTransformer(nn.Module): ...@@ -789,9 +790,12 @@ class Qwen2VisionTransformer(nn.Module):
self, cu_seqlens: torch.Tensor self, cu_seqlens: torch.Tensor
) -> tuple[int | None, list[int] | None]: ) -> tuple[int | None, list[int] | None]:
max_seqlen, seqlens = None, None max_seqlen, seqlens = None, None
if self.attn_backend in {_Backend.FLASH_ATTN, _Backend.ROCM_AITER_FA}: if self.attn_backend in {
AttentionBackendEnum.FLASH_ATTN,
AttentionBackendEnum.ROCM_AITER_FA,
}:
max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item() max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
elif self.attn_backend == _Backend.XFORMERS: elif self.attn_backend == AttentionBackendEnum.XFORMERS:
seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist() seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
return max_seqlen, seqlens return max_seqlen, seqlens
......
...@@ -47,7 +47,7 @@ from transformers.models.qwen3_omni_moe.processing_qwen3_omni_moe import ( ...@@ -47,7 +47,7 @@ from transformers.models.qwen3_omni_moe.processing_qwen3_omni_moe import (
) )
from transformers.models.whisper import WhisperFeatureExtractor from transformers.models.whisper import WhisperFeatureExtractor
from vllm.attention.backends.registry import _Backend from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.attention.layer import check_upstream_fa_availability from vllm.attention.layer import check_upstream_fa_availability
from vllm.compilation.decorators import support_torch_compile from vllm.compilation.decorators import support_torch_compile
from vllm.config import VllmConfig from vllm.config import VllmConfig
...@@ -301,7 +301,7 @@ class Qwen3Omni_VisionTransformer(nn.Module): ...@@ -301,7 +301,7 @@ class Qwen3Omni_VisionTransformer(nn.Module):
norm_eps: float = 1e-6, norm_eps: float = 1e-6,
quant_config: QuantizationConfig | None = None, quant_config: QuantizationConfig | None = None,
prefix: str = "", prefix: str = "",
attn_backend_override: _Backend | None = None, attn_backend_override: AttentionBackendEnum | None = None,
) -> None: ) -> None:
super().__init__() super().__init__()
self.hidden_size = vision_config.hidden_size self.hidden_size = vision_config.hidden_size
...@@ -377,10 +377,11 @@ class Qwen3Omni_VisionTransformer(nn.Module): ...@@ -377,10 +377,11 @@ class Qwen3Omni_VisionTransformer(nn.Module):
dtype=torch.get_default_dtype(), dtype=torch.get_default_dtype(),
attn_backend_override=attn_backend_override, attn_backend_override=attn_backend_override,
) )
if self.attn_backend != _Backend.FLASH_ATTN and check_upstream_fa_availability( if (
torch.get_default_dtype() self.attn_backend != AttentionBackendEnum.FLASH_ATTN
and check_upstream_fa_availability(torch.get_default_dtype())
): ):
self.attn_backend = _Backend.FLASH_ATTN self.attn_backend = AttentionBackendEnum.FLASH_ATTN
@property @property
def dtype(self) -> torch.dtype: def dtype(self) -> torch.dtype:
...@@ -490,9 +491,9 @@ class Qwen3Omni_VisionTransformer(nn.Module): ...@@ -490,9 +491,9 @@ class Qwen3Omni_VisionTransformer(nn.Module):
) -> tuple[torch.Tensor, torch.Tensor]: ) -> tuple[torch.Tensor, torch.Tensor]:
max_seqlen = torch.zeros([], device=cu_seqlens.device) max_seqlen = torch.zeros([], device=cu_seqlens.device)
seqlens = torch.zeros(1, device=cu_seqlens.device) seqlens = torch.zeros(1, device=cu_seqlens.device)
if self.attn_backend == _Backend.FLASH_ATTN: if self.attn_backend == AttentionBackendEnum.FLASH_ATTN:
max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max() max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max()
elif self.attn_backend == _Backend.XFORMERS: elif self.attn_backend == AttentionBackendEnum.XFORMERS:
seqlens = cu_seqlens[1:] - cu_seqlens[:-1] seqlens = cu_seqlens[1:] - cu_seqlens[:-1]
return max_seqlen, seqlens return max_seqlen, seqlens
......
...@@ -49,7 +49,7 @@ from transformers.models.qwen3_vl.video_processing_qwen3_vl import ( ...@@ -49,7 +49,7 @@ from transformers.models.qwen3_vl.video_processing_qwen3_vl import (
) )
from transformers.video_utils import VideoMetadata from transformers.video_utils import VideoMetadata
from vllm.attention.backends.registry import _Backend from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.attention.layer import check_upstream_fa_availability from vllm.attention.layer import check_upstream_fa_availability
from vllm.compilation.decorators import support_torch_compile from vllm.compilation.decorators import support_torch_compile
from vllm.config import VllmConfig from vllm.config import VllmConfig
...@@ -198,7 +198,7 @@ class Qwen3_VisionBlock(nn.Module): ...@@ -198,7 +198,7 @@ class Qwen3_VisionBlock(nn.Module):
quant_config: QuantizationConfig | None = None, quant_config: QuantizationConfig | None = None,
prefix: str = "", prefix: str = "",
use_data_parallel: bool = False, use_data_parallel: bool = False,
attn_backend: _Backend = _Backend.TORCH_SDPA, attn_backend: AttentionBackendEnum = AttentionBackendEnum.TORCH_SDPA,
use_upstream_fa: bool = False, use_upstream_fa: bool = False,
) -> None: ) -> None:
super().__init__() super().__init__()
...@@ -306,7 +306,7 @@ class Qwen3_VisionTransformer(nn.Module): ...@@ -306,7 +306,7 @@ class Qwen3_VisionTransformer(nn.Module):
quant_config: QuantizationConfig | None = None, quant_config: QuantizationConfig | None = None,
prefix: str = "", prefix: str = "",
use_data_parallel: bool = False, use_data_parallel: bool = False,
attn_backend_override: _Backend | None = None, attn_backend_override: AttentionBackendEnum | None = None,
) -> None: ) -> None:
super().__init__() super().__init__()
self.hidden_size = vision_config.hidden_size self.hidden_size = vision_config.hidden_size
...@@ -372,18 +372,18 @@ class Qwen3_VisionTransformer(nn.Module): ...@@ -372,18 +372,18 @@ class Qwen3_VisionTransformer(nn.Module):
) )
use_upstream_fa = False use_upstream_fa = False
if ( if (
self.attn_backend != _Backend.FLASH_ATTN self.attn_backend != AttentionBackendEnum.FLASH_ATTN
and self.attn_backend != _Backend.ROCM_AITER_FA and self.attn_backend != AttentionBackendEnum.ROCM_AITER_FA
and check_upstream_fa_availability(torch.get_default_dtype()) and check_upstream_fa_availability(torch.get_default_dtype())
): ):
self.attn_backend = _Backend.FLASH_ATTN self.attn_backend = AttentionBackendEnum.FLASH_ATTN
use_upstream_fa = True use_upstream_fa = True
if self.attn_backend not in { if self.attn_backend not in {
_Backend.FLASH_ATTN, AttentionBackendEnum.FLASH_ATTN,
_Backend.TORCH_SDPA, AttentionBackendEnum.TORCH_SDPA,
_Backend.XFORMERS, AttentionBackendEnum.XFORMERS,
_Backend.ROCM_AITER_FA, AttentionBackendEnum.ROCM_AITER_FA,
}: }:
raise RuntimeError( raise RuntimeError(
f"Qwen3-VL does not support {self.attn_backend} backend now." f"Qwen3-VL does not support {self.attn_backend} backend now."
...@@ -510,11 +510,11 @@ class Qwen3_VisionTransformer(nn.Module): ...@@ -510,11 +510,11 @@ class Qwen3_VisionTransformer(nn.Module):
max_seqlen = torch.zeros([], device=cu_seqlens.device) max_seqlen = torch.zeros([], device=cu_seqlens.device)
seqlens = torch.zeros(1, device=cu_seqlens.device) seqlens = torch.zeros(1, device=cu_seqlens.device)
if ( if (
self.attn_backend == _Backend.FLASH_ATTN self.attn_backend == AttentionBackendEnum.FLASH_ATTN
or self.attn_backend == _Backend.ROCM_AITER_FA or self.attn_backend == AttentionBackendEnum.ROCM_AITER_FA
): ):
max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max() max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max()
elif self.attn_backend == _Backend.XFORMERS: elif self.attn_backend == AttentionBackendEnum.XFORMERS:
seqlens = cu_seqlens[1:] - cu_seqlens[:-1] seqlens = cu_seqlens[1:] - cu_seqlens[:-1]
return max_seqlen, seqlens return max_seqlen, seqlens
......
...@@ -12,7 +12,7 @@ from torch.nn import functional as F ...@@ -12,7 +12,7 @@ from torch.nn import functional as F
from transformers import Siglip2VisionConfig from transformers import Siglip2VisionConfig
from transformers.configuration_utils import PretrainedConfig from transformers.configuration_utils import PretrainedConfig
from vllm.attention.backends.registry import _Backend from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.attention.layer import maybe_get_vit_flash_attn_backend from vllm.attention.layer import maybe_get_vit_flash_attn_backend
from vllm.distributed import divide, get_tensor_model_parallel_world_size from vllm.distributed import divide, get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.activation import get_act_fn
...@@ -208,7 +208,7 @@ class Siglip2Attention(nn.Module): ...@@ -208,7 +208,7 @@ class Siglip2Attention(nn.Module):
quant_config: QuantizationConfig | None = None, quant_config: QuantizationConfig | None = None,
prefix: str = "", prefix: str = "",
use_data_parallel: bool = False, use_data_parallel: bool = False,
attn_backend_override: _Backend | None = None, attn_backend_override: AttentionBackendEnum | None = None,
): ):
super().__init__() super().__init__()
self.config = config self.config = config
...@@ -264,14 +264,14 @@ class Siglip2Attention(nn.Module): ...@@ -264,14 +264,14 @@ class Siglip2Attention(nn.Module):
) )
if self.attn_backend not in { if self.attn_backend not in {
_Backend.FLASH_ATTN, AttentionBackendEnum.FLASH_ATTN,
_Backend.TORCH_SDPA, AttentionBackendEnum.TORCH_SDPA,
_Backend.ROCM_AITER_FA, AttentionBackendEnum.ROCM_AITER_FA,
}: }:
self.attn_backend = _Backend.TORCH_SDPA self.attn_backend = AttentionBackendEnum.TORCH_SDPA
self.is_flash_attn_backend = self.attn_backend in { self.is_flash_attn_backend = self.attn_backend in {
_Backend.FLASH_ATTN, AttentionBackendEnum.FLASH_ATTN,
_Backend.ROCM_AITER_FA, AttentionBackendEnum.ROCM_AITER_FA,
} }
def forward( def forward(
...@@ -308,7 +308,7 @@ class Siglip2Attention(nn.Module): ...@@ -308,7 +308,7 @@ class Siglip2Attention(nn.Module):
attn_output = self.flash_attn_varlen_func( attn_output = self.flash_attn_varlen_func(
queries, keys, values, cu_seqlens, cu_seqlens, max_seqlen, max_seqlen queries, keys, values, cu_seqlens, cu_seqlens, max_seqlen, max_seqlen
).reshape(seq_length, -1) ).reshape(seq_length, -1)
elif self.attn_backend == _Backend.TORCH_SDPA: elif self.attn_backend == AttentionBackendEnum.TORCH_SDPA:
# Execute attention entry by entry for speed & less VRAM. # Execute attention entry by entry for speed & less VRAM.
batch_size = cu_seqlens.shape[0] - 1 batch_size = cu_seqlens.shape[0] - 1
outputs = [] outputs = []
...@@ -376,7 +376,7 @@ class Siglip2EncoderLayer(nn.Module): ...@@ -376,7 +376,7 @@ class Siglip2EncoderLayer(nn.Module):
quant_config: QuantizationConfig | None = None, quant_config: QuantizationConfig | None = None,
prefix: str = "", prefix: str = "",
use_data_parallel: bool = False, use_data_parallel: bool = False,
attn_backend_override: _Backend | None = None, attn_backend_override: AttentionBackendEnum | None = None,
): ):
super().__init__() super().__init__()
self.embed_dim = config.hidden_size self.embed_dim = config.hidden_size
...@@ -440,7 +440,7 @@ class Siglip2Encoder(nn.Module): ...@@ -440,7 +440,7 @@ class Siglip2Encoder(nn.Module):
quant_config: QuantizationConfig | None = None, quant_config: QuantizationConfig | None = None,
prefix: str = "", prefix: str = "",
use_data_parallel: bool = False, use_data_parallel: bool = False,
attn_backend_override: _Backend | None = None, attn_backend_override: AttentionBackendEnum | None = None,
): ):
super().__init__() super().__init__()
self.config = config self.config = config
...@@ -626,7 +626,7 @@ class Siglip2VisionTransformer(nn.Module): ...@@ -626,7 +626,7 @@ class Siglip2VisionTransformer(nn.Module):
quant_config: QuantizationConfig | None = None, quant_config: QuantizationConfig | None = None,
prefix: str = "", prefix: str = "",
use_data_parallel: bool = False, use_data_parallel: bool = False,
attn_backend_override: _Backend | None = None, attn_backend_override: AttentionBackendEnum | None = None,
): ):
super().__init__() super().__init__()
self.config = config self.config = config
...@@ -667,7 +667,7 @@ class Siglip2NavitModel(torch.nn.Module): ...@@ -667,7 +667,7 @@ class Siglip2NavitModel(torch.nn.Module):
quant_config: QuantizationConfig | None = None, quant_config: QuantizationConfig | None = None,
prefix: str = "", prefix: str = "",
use_data_parallel: bool = False, use_data_parallel: bool = False,
attn_backend_override: _Backend | None = None, attn_backend_override: AttentionBackendEnum | None = None,
): ):
super().__init__() super().__init__()
......
...@@ -10,7 +10,7 @@ from typing import Final, Generic, Literal, Protocol, TypeAlias, TypeVar ...@@ -10,7 +10,7 @@ from typing import Final, Generic, Literal, Protocol, TypeAlias, TypeVar
import torch import torch
from transformers import PretrainedConfig from transformers import PretrainedConfig
from vllm.attention.backends.registry import _Backend from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.distributed import ( from vllm.distributed import (
get_tensor_model_parallel_rank, get_tensor_model_parallel_rank,
...@@ -83,8 +83,8 @@ def get_vit_attn_backend( ...@@ -83,8 +83,8 @@ def get_vit_attn_backend(
head_size: int, head_size: int,
dtype: torch.dtype, dtype: torch.dtype,
*, *,
attn_backend_override: _Backend | None = None, attn_backend_override: AttentionBackendEnum | None = None,
) -> _Backend: ) -> AttentionBackendEnum:
""" """
Get the available attention backend for Vision Transformer. Get the available attention backend for Vision Transformer.
""" """
...@@ -94,7 +94,7 @@ def get_vit_attn_backend( ...@@ -94,7 +94,7 @@ def get_vit_attn_backend(
# Lazy import to avoid circular dependency # Lazy import to avoid circular dependency
from vllm.attention.selector import get_env_variable_attn_backend from vllm.attention.selector import get_env_variable_attn_backend
selected_backend: _Backend | None = get_env_variable_attn_backend() selected_backend: AttentionBackendEnum | None = get_env_variable_attn_backend()
if selected_backend is not None: if selected_backend is not None:
return selected_backend return selected_backend
......
...@@ -23,10 +23,10 @@ from .interface import CpuArchEnum, Platform, PlatformEnum ...@@ -23,10 +23,10 @@ from .interface import CpuArchEnum, Platform, PlatformEnum
logger = init_logger(__name__) logger = init_logger(__name__)
if TYPE_CHECKING: if TYPE_CHECKING:
from vllm.attention.backends.registry import _Backend from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.config import VllmConfig from vllm.config import VllmConfig
else: else:
_Backend = None AttentionBackendEnum = None
VllmConfig = None VllmConfig = None
...@@ -127,7 +127,7 @@ class CpuPlatform(Platform): ...@@ -127,7 +127,7 @@ class CpuPlatform(Platform):
@classmethod @classmethod
def get_attn_backend_cls( def get_attn_backend_cls(
cls, cls,
selected_backend: "_Backend", selected_backend: "AttentionBackendEnum",
head_size: int, head_size: int,
dtype: torch.dtype, dtype: torch.dtype,
kv_cache_dtype: str | None, kv_cache_dtype: str | None,
...@@ -137,9 +137,9 @@ class CpuPlatform(Platform): ...@@ -137,9 +137,9 @@ class CpuPlatform(Platform):
has_sink: bool, has_sink: bool,
use_sparse: bool, use_sparse: bool,
) -> str: ) -> str:
from vllm.attention.backends.registry import _Backend from vllm.attention.backends.registry import AttentionBackendEnum
if selected_backend and selected_backend != _Backend.TORCH_SDPA: if selected_backend and selected_backend != AttentionBackendEnum.TORCH_SDPA:
logger.info("Cannot use %s backend on CPU.", selected_backend) logger.info("Cannot use %s backend on CPU.", selected_backend)
if use_mla: if use_mla:
raise NotImplementedError("MLA is not supported on CPU.") raise NotImplementedError("MLA is not supported on CPU.")
...@@ -148,7 +148,7 @@ class CpuPlatform(Platform): ...@@ -148,7 +148,7 @@ class CpuPlatform(Platform):
logger.info("Using Torch SDPA backend.") logger.info("Using Torch SDPA backend.")
if not use_v1: if not use_v1:
raise ValueError("CPU backend only supports V1.") raise ValueError("CPU backend only supports V1.")
return "vllm.v1.attention.backends.cpu_attn.TorchSDPABackend" return AttentionBackendEnum.TORCH_SDPA.get_path()
@classmethod @classmethod
def get_device_total_memory(cls, device_id: int = 0) -> int: def get_device_total_memory(cls, device_id: int = 0) -> int:
......
...@@ -22,10 +22,13 @@ from vllm.utils.torch_utils import cuda_device_count_stateless ...@@ -22,10 +22,13 @@ from vllm.utils.torch_utils import cuda_device_count_stateless
from .interface import DeviceCapability, Platform, PlatformEnum from .interface import DeviceCapability, Platform, PlatformEnum
if TYPE_CHECKING: if TYPE_CHECKING:
from vllm.attention.backends.registry import _Backend from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.config.cache import CacheDType
else: else:
_Backend = None AttentionBackendEnum = None
VllmConfig = None
CacheDType = None
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -39,6 +42,49 @@ pynvml = import_pynvml() ...@@ -39,6 +42,49 @@ pynvml = import_pynvml()
torch.backends.cuda.enable_cudnn_sdp(False) torch.backends.cuda.enable_cudnn_sdp(False)
@cache
def _get_backend_priorities(
use_mla: bool,
device_capability: DeviceCapability,
) -> list[AttentionBackendEnum]:
"""Get backend priorities with lazy import to avoid circular dependency."""
from vllm.attention.backends.registry import AttentionBackendEnum
if use_mla:
if device_capability.major == 10:
return [
AttentionBackendEnum.CUTLASS_MLA,
AttentionBackendEnum.FLASHINFER_MLA,
AttentionBackendEnum.FLASHMLA,
AttentionBackendEnum.FLASH_ATTN_MLA,
AttentionBackendEnum.TRITON_MLA,
AttentionBackendEnum.FLASHMLA_SPARSE,
]
else:
return [
AttentionBackendEnum.FLASHMLA,
AttentionBackendEnum.FLASH_ATTN_MLA,
AttentionBackendEnum.FLASHINFER_MLA,
AttentionBackendEnum.TRITON_MLA,
AttentionBackendEnum.FLASHMLA_SPARSE,
]
else:
if device_capability.major == 10:
return [
AttentionBackendEnum.FLASHINFER,
AttentionBackendEnum.FLASH_ATTN,
AttentionBackendEnum.TRITON_ATTN,
AttentionBackendEnum.FLEX_ATTENTION,
]
else:
return [
AttentionBackendEnum.FLASH_ATTN,
AttentionBackendEnum.FLASHINFER,
AttentionBackendEnum.TRITON_ATTN,
AttentionBackendEnum.FLEX_ATTENTION,
]
def with_nvml_context(fn: Callable[_P, _R]) -> Callable[_P, _R]: def with_nvml_context(fn: Callable[_P, _R]) -> Callable[_P, _R]:
@wraps(fn) @wraps(fn)
def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> _R: def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> _R:
...@@ -216,217 +262,171 @@ class CudaPlatformBase(Platform): ...@@ -216,217 +262,171 @@ class CudaPlatformBase(Platform):
return torch.cuda.max_memory_allocated(device) return torch.cuda.max_memory_allocated(device)
@classmethod @classmethod
def get_vit_attn_backend(cls, head_size: int, dtype: torch.dtype) -> "_Backend": def get_vit_attn_backend(
from vllm.attention.backends.registry import _Backend cls, head_size: int, dtype: torch.dtype
) -> "AttentionBackendEnum":
from vllm.attention.backends.registry import AttentionBackendEnum
# For Blackwell GPUs, force TORCH_SDPA for now. # For Blackwell GPUs, force TORCH_SDPA for now.
# See https://github.com/facebookresearch/xformers/issues/1317#issuecomment-3199392579 # noqa: E501 # See https://github.com/facebookresearch/xformers/issues/1317#issuecomment-3199392579 # noqa: E501
if cls.has_device_capability(100): if cls.has_device_capability(100):
return _Backend.TORCH_SDPA return AttentionBackendEnum.TORCH_SDPA
if dtype not in (torch.float16, torch.bfloat16): if dtype not in (torch.float16, torch.bfloat16):
return _Backend.XFORMERS return AttentionBackendEnum.XFORMERS
if cls.has_device_capability(80): if cls.has_device_capability(80):
FLASH_ATTN_V1 = ( backend_class = AttentionBackendEnum.FLASH_ATTN.get_class()
"vllm.v1.attention.backends.flash_attn.FlashAttentionBackend" # noqa: E501 if backend_class.supports_head_size(
) head_size
from vllm.attention.selector import is_attn_backend_supported ) and backend_class.supports_dtype(dtype):
return AttentionBackendEnum.FLASH_ATTN
is_default_fa_supported = is_attn_backend_supported(
FLASH_ATTN_V1, head_size, dtype, allow_import_error=False
)
if is_default_fa_supported:
return _Backend.FLASH_ATTN
else: else:
# Fallback to XFORMERS return AttentionBackendEnum.XFORMERS
return _Backend.XFORMERS
else: else:
# Fallback for Volta/Turing GPUs or FA not supported # Fallback for Volta/Turing GPUs or FA not supported
return _Backend.XFORMERS return AttentionBackendEnum.XFORMERS
@classmethod @classmethod
def get_attn_backend_cls( def get_valid_backends(
cls, cls,
selected_backend,
head_size, head_size,
dtype, dtype,
kv_cache_dtype, kv_cache_dtype,
block_size, block_size,
use_v1,
use_mla, use_mla,
has_sink, has_sink,
use_sparse, use_sparse,
) -> str: device_capability,
from vllm.attention.backends.registry import _Backend ) -> tuple[
list[tuple["AttentionBackendEnum", int]],
if use_mla: dict["AttentionBackendEnum", list[str]],
# explicitly reject non-MLA backends when MLA is enabled to avoid ]:
# silently selecting an incompatible backend (e.g., FLASHINFER). valid_backends_priorities = []
if selected_backend in { invalid_reasons = {}
_Backend.FLASHINFER,
_Backend.FLASH_ATTN, backend_priorities = _get_backend_priorities(use_mla, device_capability)
_Backend.TRITON_ATTN, for priority, backend in enumerate(backend_priorities):
_Backend.TREE_ATTN, try:
_Backend.XFORMERS, backend_class = backend.get_class()
}: invalid_reasons_i = backend_class.validate_configuration(
raise ValueError( head_size,
f"Attention backend {selected_backend} incompatible with MLA. " dtype,
"Please use one of the MLA backends: FLASHINFER_MLA, CUTLASS_MLA, " kv_cache_dtype,
"FLASHMLA, FLASH_ATTN_MLA, or TRITON_MLA. Alternatively, set " block_size,
"VLLM_MLA_DISABLE=1 to disable MLA for this model." use_mla,
has_sink,
use_sparse,
device_capability,
) )
except ImportError:
invalid_reasons_i = ["ImportError"]
if invalid_reasons_i:
invalid_reasons[backend] = invalid_reasons_i
else:
valid_backends_priorities.append((backend, priority))
from vllm.attention.ops.flashmla import is_flashmla_dense_supported return valid_backends_priorities, invalid_reasons
from vllm.attention.utils.fa_utils import flash_attn_supports_mla
if use_sparse: @classmethod
logger.info_once("Using Sparse MLA backend.") def get_attn_backend_cls(
return ( cls,
"vllm.v1.attention.backends.mla.flashmla_sparse." selected_backend: "AttentionBackendEnum",
"FlashMLASparseBackend" head_size: int,
dtype: torch.dtype,
kv_cache_dtype: "CacheDType | None",
block_size: int | None,
use_v1: bool,
use_mla: bool,
has_sink: bool,
use_sparse: bool,
) -> str:
if not use_v1:
raise RuntimeError(
"V0 attention backends have been removed. Set VLLM_USE_V1=1 "
"to select a supported backend."
) )
use_cutlassmla = selected_backend == _Backend.CUTLASS_MLA or ( device_capability = cls.get_device_capability()
selected_backend is None assert device_capability is not None
and cls.is_device_capability(100)
and block_size % 128 == 0 # First try checking just the selected backend, if there is one.
) if selected_backend is not None:
use_flashinfermla = selected_backend == _Backend.FLASHINFER_MLA or ( try:
selected_backend is None backend_class = selected_backend.get_class()
and cls.is_device_capability(100) invalid_reasons = backend_class.validate_configuration(
and (block_size == 32 or block_size % 64 == 0) head_size,
dtype,
kv_cache_dtype,
None,
use_mla,
has_sink,
use_sparse,
device_capability,
) )
use_flashmla = selected_backend == _Backend.FLASHMLA or ( except ImportError:
selected_backend is None and is_flashmla_dense_supported()[0] invalid_reasons = ["ImportError"]
if invalid_reasons:
raise ValueError(
f"Selected backend {selected_backend} is not valid for "
f"this configuration. Reason: {invalid_reasons}"
) )
use_flashattn = selected_backend == _Backend.FLASH_ATTN_MLA or ( else:
selected_backend is None and flash_attn_supports_mla() logger.info("Using %s backend.", selected_backend)
return selected_backend.get_path()
# No selected backend or the selected backend is invalid,
# so we try finding a valid backend.
valid_backends_priorities, invalid_reasons = cls.get_valid_backends(
head_size,
dtype,
kv_cache_dtype,
None,
use_mla,
has_sink,
use_sparse,
device_capability,
) )
use_triton = selected_backend == _Backend.TRITON_MLA or ( reasons_str = (
selected_backend is None "{"
+ ", ".join(
f"{backend.name}: [{', '.join(reasons)}]"
for backend, reasons in invalid_reasons.items()
) )
+ "}"
if use_cutlassmla:
logger.info_once("Using Cutlass MLA backend.", scope="local")
return "vllm.v1.attention.backends.mla.cutlass_mla.CutlassMLABackend"
if use_flashinfermla:
from vllm.v1.attention.backends.utils import set_kv_cache_layout
set_kv_cache_layout("HND")
logger.info_once("Using FlashInfer MLA backend.")
return (
"vllm.v1.attention.backends.mla.flashinfer_mla.FlashInferMLABackend"
) )
if use_flashmla: config_str = (
if block_size % 64 != 0: f"head_size: {head_size}, dtype: {dtype}, "
logger.warning( f"kv_cache_dtype: {kv_cache_dtype}, block_size: {block_size}, "
"FlashMLA backend is not supported for block size %d" f"use_mla: {use_mla}, has_sink: {has_sink}, use_sparse: {use_sparse}"
" (currently only supports block size 64).",
block_size,
) )
else: logger.debug_once(
logger.info_once("Using FlashMLA backend.") f"Some attention backends are not valid for {cls.device_name} with "
return "vllm.v1.attention.backends.mla.flashmla.FlashMLABackend" f"{config_str}. Reasons: {reasons_str}."
if use_flashattn:
logger.info_once("Using FlashAttention MLA backend.")
return (
"vllm.v1.attention.backends.mla.flashattn_mla.FlashAttnMLABackend"
) )
if use_triton: if len(valid_backends_priorities) == 0:
logger.info_once("Using Triton MLA backend.") raise ValueError(
return "vllm.v1.attention.backends.mla.triton_mla.TritonMLABackend" f"No valid attention backend found for {cls.device_name} "
f"with {config_str}. Reasons: {reasons_str}."
FLASHINFER_V1 = "vllm.v1.attention.backends.flashinfer.FlashInferBackend" # noqa: E501
FLEX_ATTENTION_V1 = (
"vllm.v1.attention.backends.flex_attention.FlexAttentionBackend" # noqa: E501
) )
TRITON_ATTN = "vllm.v1.attention.backends.triton_attn.TritonAttentionBackend" # noqa: E501
FLASH_ATTN_V1 = "vllm.v1.attention.backends.flash_attn.FlashAttentionBackend" # noqa: E501
TREE_ATTN_V1 = "vllm.v1.attention.backends.tree_attn.TreeAttentionBackend" # noqa: E501
XFORMERS_V1 = "vllm.v1.attention.backends.xformers.XFormersAttentionBackend" # noqa: E501
use_fp8_kv_cache = kv_cache_dtype is not None and kv_cache_dtype.startswith( # We have found some valid backends. Select the one with the
"fp8" # highest priority.
logger.info(
"Valid backends: %s", [b[0].name for b in valid_backends_priorities]
) )
sorted_indices = sorted(
if selected_backend == _Backend.FLASHINFER: range(len(valid_backends_priorities)),
logger.info_once("Using FlashInfer backend.") key=lambda i: valid_backends_priorities[i][1],
if cls.has_device_capability(100):
from vllm.v1.attention.backends.utils import set_kv_cache_layout
set_kv_cache_layout("HND")
return FLASHINFER_V1
elif selected_backend == _Backend.FLEX_ATTENTION:
logger.info_once("Using FlexAttention backend.")
return FLEX_ATTENTION_V1
elif selected_backend == _Backend.TRITON_ATTN:
logger.info_once("Using Triton backend.")
return TRITON_ATTN
elif selected_backend == _Backend.FLASH_ATTN:
logger.info_once("Using Flash Attention backend.")
return FLASH_ATTN_V1
elif selected_backend == _Backend.TREE_ATTN:
logger.info_once("Using Tree Attention backend.")
return TREE_ATTN_V1
elif selected_backend == _Backend.XFORMERS:
logger.info_once("Using XFormers backend.")
return XFORMERS_V1
from vllm.attention.selector import is_attn_backend_supported
# Default backends for V1 engine
# Prefer FlashInfer for Blackwell GPUs if installed
if cls.is_device_capability(100):
if is_default_backend_supported := is_attn_backend_supported(
FLASHINFER_V1, head_size, dtype
):
from vllm.v1.attention.backends.utils import set_kv_cache_layout
logger.info_once(
"Using FlashInfer backend with HND KV cache layout on "
"V1 engine by default for Blackwell (SM 10.0) GPUs."
) )
set_kv_cache_layout("HND") selected_index = sorted_indices[0]
selected_backend = valid_backends_priorities[selected_index][0]
return FLASHINFER_V1 logger.info(
"Using %s backend.",
if not is_default_backend_supported.can_import: selected_backend.name,
logger.warning_once(
"FlashInfer failed to import on Blackwell (SM 10.0) GPUs; "
"it is recommended to install FlashInfer for better "
"performance."
) )
# FlashAttention is the default for SM 8.0+ GPUs return selected_backend.get_path()
if cls.has_device_capability(80):
if (has_sink or use_fp8_kv_cache) and not cls.is_device_capability(90):
logger.info_once("Using Triton backend.")
return TRITON_ATTN
elif is_default_backend_supported := is_attn_backend_supported(
FLASH_ATTN_V1, head_size, dtype, allow_import_error=False
):
logger.info_once("Using Flash Attention backend.")
return FLASH_ATTN_V1
# FlexAttention is the default for older GPUs
else:
logger.info_once("Using FlexAttention backend.")
return FLEX_ATTENTION_V1
assert not is_default_backend_supported
use_flex_attention_reason = {}
if not is_default_backend_supported.head_size:
use_flex_attention_reason["head_size"] = head_size
if not is_default_backend_supported.dtype:
use_flex_attention_reason["dtype"] = dtype
logger.info_once(
"Using FlexAttention backend for %s.",
", ".join(f"{k}={v}" for k, v in use_flex_attention_reason.items()),
)
return FLEX_ATTENTION_V1
@classmethod @classmethod
def get_punica_wrapper(cls) -> str: def get_punica_wrapper(cls) -> str:
......
...@@ -17,8 +17,9 @@ from vllm.logger import init_logger ...@@ -17,8 +17,9 @@ from vllm.logger import init_logger
if TYPE_CHECKING: if TYPE_CHECKING:
from torch.distributed import PrefixStore, ProcessGroup from torch.distributed import PrefixStore, ProcessGroup
from vllm.attention.backends.registry import _Backend from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.config.cache import CacheDType
from vllm.inputs import ProcessorInputs, PromptType from vllm.inputs import ProcessorInputs, PromptType
from vllm.pooling_params import PoolingParams from vllm.pooling_params import PoolingParams
from vllm.sampling_params import SamplingParams from vllm.sampling_params import SamplingParams
...@@ -58,6 +59,31 @@ class DeviceCapability(NamedTuple): ...@@ -58,6 +59,31 @@ class DeviceCapability(NamedTuple):
major: int major: int
minor: int minor: int
def __lt__(self, other: Any) -> bool:
if not isinstance(other, DeviceCapability):
return NotImplemented
return (self.major, self.minor) < (other.major, other.minor)
def __le__(self, other: Any) -> bool:
if not isinstance(other, DeviceCapability):
return NotImplemented
return (self.major, self.minor) <= (other.major, other.minor)
def __eq__(self, other: Any) -> bool:
if not isinstance(other, DeviceCapability):
return NotImplemented
return (self.major, self.minor) == (other.major, other.minor)
def __ge__(self, other: Any) -> bool:
if not isinstance(other, DeviceCapability):
return NotImplemented
return (self.major, self.minor) >= (other.major, other.minor)
def __gt__(self, other: Any) -> bool:
if not isinstance(other, DeviceCapability):
return NotImplemented
return (self.major, self.minor) > (other.major, other.minor)
def as_version_str(self) -> str: def as_version_str(self) -> str:
return f"{self.major}.{self.minor}" return f"{self.major}.{self.minor}"
...@@ -173,19 +199,21 @@ class Platform: ...@@ -173,19 +199,21 @@ class Platform:
import vllm._moe_C # noqa: F401 import vllm._moe_C # noqa: F401
@classmethod @classmethod
def get_vit_attn_backend(cls, head_size: int, dtype: torch.dtype) -> "_Backend": def get_vit_attn_backend(
# Import _Backend here to avoid circular import. cls, head_size: int, dtype: torch.dtype
from vllm.attention.backends.registry import _Backend ) -> "AttentionBackendEnum":
# Import AttentionBackendEnum here to avoid circular import.
from vllm.attention.backends.registry import AttentionBackendEnum
return _Backend.TORCH_SDPA return AttentionBackendEnum.TORCH_SDPA
@classmethod @classmethod
def get_attn_backend_cls( def get_attn_backend_cls(
cls, cls,
selected_backend: "_Backend", selected_backend: "AttentionBackendEnum",
head_size: int, head_size: int,
dtype: torch.dtype, dtype: torch.dtype,
kv_cache_dtype: str | None, kv_cache_dtype: "CacheDType | None",
block_size: int, block_size: int,
use_v1: bool, use_v1: bool,
use_mla: bool, use_mla: bool,
......
...@@ -14,10 +14,10 @@ from vllm.utils.torch_utils import cuda_device_count_stateless ...@@ -14,10 +14,10 @@ from vllm.utils.torch_utils import cuda_device_count_stateless
from .interface import DeviceCapability, Platform, PlatformEnum from .interface import DeviceCapability, Platform, PlatformEnum
if TYPE_CHECKING: if TYPE_CHECKING:
from vllm.attention.backends.registry import _Backend from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.config import VllmConfig from vllm.config import VllmConfig
else: else:
_Backend = None AttentionBackendEnum = None
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -204,21 +204,23 @@ class RocmPlatform(Platform): ...@@ -204,21 +204,23 @@ class RocmPlatform(Platform):
] ]
@classmethod @classmethod
def get_vit_attn_backend(cls, head_size: int, dtype: torch.dtype) -> _Backend: def get_vit_attn_backend(
cls, head_size: int, dtype: torch.dtype
) -> AttentionBackendEnum:
from importlib.util import find_spec from importlib.util import find_spec
from vllm._aiter_ops import rocm_aiter_ops from vllm._aiter_ops import rocm_aiter_ops
from vllm.attention.backends.registry import _Backend from vllm.attention.backends.registry import AttentionBackendEnum
if rocm_aiter_ops.is_mha_enabled(): if rocm_aiter_ops.is_mha_enabled():
# Note: AITER FA is only supported for Qwen-VL models. # Note: AITER FA is only supported for Qwen-VL models.
# TODO: Add support for other VL models in their model class. # TODO: Add support for other VL models in their model class.
return _Backend.ROCM_AITER_FA return AttentionBackendEnum.ROCM_AITER_FA
if on_gfx9() and find_spec("flash_attn") is not None: if on_gfx9() and find_spec("flash_attn") is not None:
return _Backend.FLASH_ATTN return AttentionBackendEnum.FLASH_ATTN
return _Backend.TORCH_SDPA return AttentionBackendEnum.TORCH_SDPA
@classmethod @classmethod
def get_attn_backend_cls( def get_attn_backend_cls(
...@@ -234,7 +236,7 @@ class RocmPlatform(Platform): ...@@ -234,7 +236,7 @@ class RocmPlatform(Platform):
use_sparse, use_sparse,
) -> str: ) -> str:
from vllm._aiter_ops import rocm_aiter_ops from vllm._aiter_ops import rocm_aiter_ops
from vllm.attention.backends.registry import _Backend from vllm.attention.backends.registry import AttentionBackendEnum
if use_sparse: if use_sparse:
raise NotImplementedError("Sparse Attention is not supported on ROCm.") raise NotImplementedError("Sparse Attention is not supported on ROCm.")
...@@ -248,55 +250,52 @@ class RocmPlatform(Platform): ...@@ -248,55 +250,52 @@ class RocmPlatform(Platform):
if use_mla: if use_mla:
if selected_backend is None: if selected_backend is None:
selected_backend = ( selected_backend = (
_Backend.ROCM_AITER_MLA AttentionBackendEnum.ROCM_AITER_MLA
if rocm_aiter_ops.is_mla_enabled() or block_size == 1 if rocm_aiter_ops.is_mla_enabled() or block_size == 1
else _Backend.TRITON_MLA else AttentionBackendEnum.TRITON_MLA
) )
if selected_backend == _Backend.TRITON_MLA: if selected_backend == AttentionBackendEnum.TRITON_MLA:
if block_size != 1: if block_size != 1:
logger.info_once("Using Triton MLA backend.") logger.info_once("Using Triton MLA backend.")
return "vllm.v1.attention.backends.mla.triton_mla.TritonMLABackend" return AttentionBackendEnum.TRITON_MLA.get_path()
raise ValueError( raise ValueError(
f" The selected backend, {selected_backend.name}," f" The selected backend, {selected_backend.name},"
f"does not support block size {block_size}." f"does not support block size {block_size}."
) )
if selected_backend == _Backend.ROCM_AITER_MLA: if selected_backend == AttentionBackendEnum.ROCM_AITER_MLA:
logger.info("Using AITER MLA backend.") logger.info("Using AITER MLA backend.")
return "vllm.v1.attention.backends.mla.rocm_aiter_mla.AiterMLABackend" # noqa: E501 return AttentionBackendEnum.ROCM_AITER_MLA.get_path()
raise ValueError( raise ValueError(
f" The selected backend, {selected_backend.name}," f" The selected backend, {selected_backend.name},"
f"is not MLA type while requested for MLA backend." f"is not MLA type while requested for MLA backend."
) )
if selected_backend == _Backend.FLEX_ATTENTION: if selected_backend == AttentionBackendEnum.FLEX_ATTENTION:
logger.info("Using FlexAttention backend.") logger.info("Using FlexAttention backend.")
return "vllm.v1.attention.backends.flex_attention.FlexAttentionBackend" return "vllm.v1.attention.backends.flex_attention.FlexAttentionBackend"
if ( if (
rocm_aiter_ops.is_mha_enabled() rocm_aiter_ops.is_mha_enabled()
) or selected_backend == _Backend.ROCM_AITER_FA: ) or selected_backend == AttentionBackendEnum.ROCM_AITER_FA:
logger.info("Using Aiter Flash Attention backend.") logger.info("Using Aiter Flash Attention backend.")
return "vllm.v1.attention.backends.rocm_aiter_fa.AiterFlashAttentionBackend" return AttentionBackendEnum.ROCM_AITER_FA.get_path()
if ( if (
rocm_aiter_ops.is_triton_unified_attn_enabled() rocm_aiter_ops.is_triton_unified_attn_enabled()
) or selected_backend == _Backend.ROCM_AITER_UNIFIED_ATTN: ) or selected_backend == AttentionBackendEnum.ROCM_AITER_UNIFIED_ATTN:
logger.info("Using Aiter Unified Attention backend.") logger.info("Using Aiter Unified Attention backend.")
return ( return AttentionBackendEnum.ROCM_AITER_UNIFIED_ATTN.get_path()
"vllm.v1.attention.backends."
"rocm_aiter_unified_attn.RocmAiterUnifiedAttentionBackend"
)
if ( if (
envs.VLLM_V1_USE_PREFILL_DECODE_ATTENTION envs.VLLM_V1_USE_PREFILL_DECODE_ATTENTION
or selected_backend == _Backend.ROCM_ATTN or selected_backend == AttentionBackendEnum.ROCM_ATTN
): ):
# rocm specific backend, with aiter and/or # rocm specific backend, with aiter and/or
# triton prefix-prefill # triton prefix-prefill
logger.info("Using Rocm Attention backend.") logger.info("Using Rocm Attention backend.")
return "vllm.v1.attention.backends.rocm_attn.RocmAttentionBackend" return AttentionBackendEnum.ROCM_ATTN.get_path()
# default case, using triton unified attention # default case, using triton unified attention
logger.info("Using Triton Attention backend.") logger.info("Using Triton Attention backend.")
return "vllm.v1.attention.backends.triton_attn.TritonAttentionBackend" return AttentionBackendEnum.TRITON_ATTN.get_path()
@classmethod @classmethod
def set_device(cls, device: torch.device) -> None: def set_device(cls, device: torch.device) -> None:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment