Unverified Commit b30dfa03 authored by Matthew Bonanni's avatar Matthew Bonanni Committed by GitHub
Browse files

[Attention] Refactor CUDA attention backend selection logic (#24794)


Signed-off-by: default avatarMatthew Bonanni <mbonanni@redhat.com>
Signed-off-by: default avatarMatthew Bonanni <mbonanni001@gmail.com>
Co-authored-by: default avatarLuka Govedič <ProExpertProg@users.noreply.github.com>
parent 2e78150d
......@@ -890,11 +890,16 @@ steps:
- vllm/model_executor/layers/fused_moe/flashinfer_cutlass_prepare_finalize.py
- vllm/model_executor/layers/quantization/utils/flashinfer_utils.py
- vllm/v1/attention/backends/flashinfer.py
- vllm/v1/attention/backends/mla/cutlass_mla.py
- vllm/v1/attention/backends/mla/flashinfer_mla.py
- vllm/platforms/cuda.py
- vllm/attention/selector.py
commands:
- nvidia-smi
- python3 examples/offline_inference/basic/chat.py
# Attention
# num_heads2 broken by https://github.com/flashinfer-ai/flashinfer/issues/1353
- pytest -v -s tests/kernels/attention/test_attention_selector.py
- pytest -v -s tests/kernels/attention/test_flashinfer.py -k 'not num_heads2'
- pytest -v -s tests/kernels/attention/test_flashinfer_trtllm_attention.py
- pytest -v -s tests/kernels/attention/test_cutlass_mla_decode.py
......
......@@ -10,7 +10,7 @@ from tests.utils import flat_product
from tests.v1.attention.utils import BatchSpec, create_common_attn_metadata
from vllm._custom_ops import cutlass_scaled_fp4_mm, scaled_fp4_quant
from vllm.attention import Attention, AttentionMetadata
from vllm.attention.backends.registry import _Backend
from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.attention.selector import global_force_attn_backend_context_manager
from vllm.compilation.fusion_attn import ATTN_OP, AttnFusionPass
from vllm.compilation.fx_utils import find_op_nodes
......@@ -104,7 +104,7 @@ class AttentionQuantPatternModel(torch.nn.Module):
# TODO(luka) use get_kv_cache_stride_order
# Create dummy KV cache for the selected backend
if backend == _Backend.ROCM_ATTN:
if backend == AttentionBackendEnum.ROCM_ATTN:
# k/v as 1st dimention
# HND: [num_blocks, num_kv_heads, block_size, head_size]
kv_cache = torch.zeros(
......@@ -116,7 +116,7 @@ class AttentionQuantPatternModel(torch.nn.Module):
dtype=self.kv_cache_dtype,
device=self.device,
)
elif backend == _Backend.ROCM_AITER_UNIFIED_ATTN:
elif backend == AttentionBackendEnum.ROCM_AITER_UNIFIED_ATTN:
# k/v as 1st dimention
# NHD: [num_blocks, block_size, num_kv_heads, head_size]
kv_cache = torch.zeros(
......@@ -128,7 +128,7 @@ class AttentionQuantPatternModel(torch.nn.Module):
dtype=self.kv_cache_dtype,
device=self.device,
)
elif backend == _Backend.TRITON_ATTN:
elif backend == AttentionBackendEnum.TRITON_ATTN:
# k/v as 2nd dimention
# NHD: [num_blocks, block_size, num_kv_heads, head_size]
kv_cache = torch.zeros(
......@@ -140,7 +140,7 @@ class AttentionQuantPatternModel(torch.nn.Module):
dtype=self.kv_cache_dtype,
device=self.device,
)
elif backend == _Backend.FLASHINFER:
elif backend == AttentionBackendEnum.FLASHINFER:
kv_cache = torch.zeros(
num_blocks,
2,
......@@ -244,8 +244,8 @@ MODELS_FP8: list[tuple[str, type]] = []
MODELS_FP4: list[tuple[str, type]] = []
HEADS: list[tuple[int, int]] = []
SPLIT_ATTENTION: list[bool] = []
BACKENDS_FP8: list[_Backend] = []
BACKENDS_FP4: list[_Backend] = []
BACKENDS_FP8: list[AttentionBackendEnum] = []
BACKENDS_FP4: list[AttentionBackendEnum] = []
if current_platform.is_cuda():
HEADS = [(64, 8), (40, 8)]
......@@ -261,8 +261,8 @@ if current_platform.is_cuda():
TestAttentionNvfp4QuantPatternModel,
)
]
BACKENDS_FP8 = [_Backend.TRITON_ATTN, _Backend.FLASHINFER]
BACKENDS_FP4 = [_Backend.FLASHINFER]
BACKENDS_FP8 = [AttentionBackendEnum.TRITON_ATTN, AttentionBackendEnum.FLASHINFER]
BACKENDS_FP4 = [AttentionBackendEnum.FLASHINFER]
elif current_platform.is_rocm():
HEADS = [(32, 8), (40, 8)]
......@@ -270,9 +270,9 @@ elif current_platform.is_rocm():
("amd/Llama-3.1-8B-Instruct-FP8-KV", TestAttentionFp8StaticQuantPatternModel)
]
BACKENDS = [
_Backend.ROCM_AITER_UNIFIED_ATTN,
_Backend.ROCM_ATTN,
_Backend.TRITON_ATTN,
AttentionBackendEnum.ROCM_AITER_UNIFIED_ATTN,
AttentionBackendEnum.ROCM_ATTN,
AttentionBackendEnum.TRITON_ATTN,
]
......@@ -302,11 +302,11 @@ def test_attention_quant_pattern(
custom_ops: str,
model_name: str,
model_class: type[AttentionQuantPatternModel],
backend: _Backend,
backend: AttentionBackendEnum,
dist_init,
):
"""Test AttentionStaticQuantPattern fusion pass"""
if backend == _Backend.FLASHINFER and (
if backend == AttentionBackendEnum.FLASHINFER and (
not current_platform.is_device_capability((10, 0)) or not has_flashinfer()
):
pytest.skip("FlashInfer attn fusion requires Blackwell and flashinfer")
......@@ -314,6 +314,7 @@ def test_attention_quant_pattern(
custom_ops_list = custom_ops.split(",") if custom_ops else []
device = torch.device("cuda:0")
torch.set_default_dtype(dtype)
torch.manual_seed(42)
vllm_config = VllmConfig(
......@@ -402,7 +403,7 @@ def test_attention_quant_pattern(
result_fused_1 = model_compiled(q, k, v)
if backend == _Backend.FLASHINFER:
if backend == AttentionBackendEnum.FLASHINFER:
# With the Flashinfer backend after the 1st round of the forward
# pass, output quant scale should be loaded into the attn layer's
# _o_scale_float, the 2nd round should reuse the loaded
......
......@@ -11,7 +11,7 @@ from typing import Any, NamedTuple
import pytest
import regex as re
from tests.v1.attention.utils import _Backend
from tests.v1.attention.utils import AttentionBackendEnum
from vllm import LLM, SamplingParams
from vllm.config import CompilationConfig, CompilationMode, CUDAGraphMode, PassConfig
from vllm.platforms import current_platform
......@@ -24,7 +24,7 @@ from ..utils import flat_product, multi_gpu_test
class ModelBackendTestCase(NamedTuple):
model_name: str
model_kwargs: dict[str, Any]
backend: _Backend
backend: AttentionBackendEnum
attention_fusions: int
allreduce_fusions: int | None = None
......@@ -39,14 +39,14 @@ if current_platform.is_cuda():
# Use smaller model for L40s in CI
model_name="RedHatAI/Meta-Llama-3.1-8B-Instruct-FP8",
model_kwargs=dict(max_model_len=1024),
backend=_Backend.TRITON_ATTN,
backend=AttentionBackendEnum.TRITON_ATTN,
attention_fusions=32,
allreduce_fusions=65,
),
ModelBackendTestCase(
model_name="nvidia/Llama-4-Scout-17B-16E-Instruct-FP8",
model_kwargs=dict(max_model_len=1024, kv_cache_dtype="fp8"),
backend=_Backend.FLASHINFER,
backend=AttentionBackendEnum.FLASHINFER,
attention_fusions=48,
allreduce_fusions=96,
),
......@@ -56,7 +56,7 @@ if current_platform.is_cuda():
ModelBackendTestCase(
model_name="nvidia/Llama-3.1-8B-Instruct-FP4",
model_kwargs=dict(max_model_len=1024, kv_cache_dtype="fp8"),
backend=_Backend.FLASHINFER,
backend=AttentionBackendEnum.FLASHINFER,
attention_fusions=32,
allreduce_fusions=65,
),
......@@ -67,7 +67,7 @@ if current_platform.is_cuda():
ModelBackendTestCase(
model_name="meta-llama/Llama-3.1-8B-Instruct",
model_kwargs=dict(max_model_len=1024),
backend=_Backend.TRITON_ATTN,
backend=AttentionBackendEnum.TRITON_ATTN,
attention_fusions=0,
allreduce_fusions=65,
),
......@@ -85,19 +85,19 @@ elif current_platform.is_rocm():
ModelBackendTestCase(
model_name="amd/Llama-3.1-8B-Instruct-FP8-KV",
model_kwargs=dict(max_model_len=1024),
backend=_Backend.TRITON_ATTN,
backend=AttentionBackendEnum.TRITON_ATTN,
attention_fusions=32,
),
ModelBackendTestCase(
model_name="amd/Llama-3.1-8B-Instruct-FP8-KV",
model_kwargs=dict(max_model_len=1024),
backend=_Backend.ROCM_ATTN,
backend=AttentionBackendEnum.ROCM_ATTN,
attention_fusions=32,
),
ModelBackendTestCase(
model_name="amd/Llama-3.1-8B-Instruct-FP8-KV",
model_kwargs=dict(max_model_len=1024),
backend=_Backend.ROCM_AITER_UNIFIED_ATTN,
backend=AttentionBackendEnum.ROCM_AITER_UNIFIED_ATTN,
attention_fusions=32,
),
]
......@@ -117,7 +117,7 @@ CUSTOM_OPS_FP8 = ["-quant_fp8", "+quant_fp8"]
def test_attn_quant(
model_name: str,
model_kwargs: dict[str, Any],
backend: _Backend,
backend: AttentionBackendEnum,
attention_fusions: int,
allreduce_fusions: int,
custom_ops: str,
......@@ -125,7 +125,7 @@ def test_attn_quant(
caplog_mp_spawn,
monkeypatch,
):
if backend == _Backend.FLASHINFER and (
if backend == AttentionBackendEnum.FLASHINFER and (
not current_platform.is_device_capability((10, 0)) or not has_flashinfer()
):
pytest.skip("FlashInfer attn fusion requires Blackwell and flashinfer")
......@@ -208,7 +208,7 @@ def custom_ops_product(*custom_ops_lists: list[str]) -> Iterable[str]:
def test_tp2_attn_quant_allreduce_rmsnorm(
model_name: str,
model_kwargs: dict,
backend: _Backend,
backend: AttentionBackendEnum,
attention_fusions: int,
allreduce_fusions: int,
custom_ops: str,
......
......@@ -3,13 +3,13 @@
import pytest
from vllm.attention.backends.registry import _Backend
from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.config.multimodal import MultiModalConfig
def test_mm_encoder_attn_backend_str_conversion():
config = MultiModalConfig(mm_encoder_attn_backend="FLASH_ATTN")
assert config.mm_encoder_attn_backend == _Backend.FLASH_ATTN
assert config.mm_encoder_attn_backend == AttentionBackendEnum.FLASH_ATTN
def test_mm_encoder_attn_backend_invalid():
......@@ -20,6 +20,6 @@ def test_mm_encoder_attn_backend_invalid():
def test_mm_encoder_attn_backend_hash_updates():
base_hash = MultiModalConfig().compute_hash()
overridden_hash = MultiModalConfig(
mm_encoder_attn_backend=_Backend.FLASH_ATTN
mm_encoder_attn_backend=AttentionBackendEnum.FLASH_ATTN
).compute_hash()
assert base_hash != overridden_hash
......@@ -120,12 +120,13 @@ def test_env(
elif device == "cuda":
with patch("vllm.platforms.current_platform", CudaPlatform()):
capability = torch.cuda.get_device_capability()
if use_mla:
# CUDA MLA backend logic:
# - CUTLASS_MLA: only supported with block_size == 128
# and Blackwell GPUs (SM 10.0), V1 only
# and Blackwell GPUs (SM 10.x), V1 only
# - FLASHINFER_MLA: only supported on Blackwell GPUs
# (SM 10.0+), V1 only
# (SM 10.x), V1 only
# - FLASHMLA: only supported with block_size == 64
# - FLASH_ATTN_MLA: V1 only
# - TRITON_MLA: fallback for other cases
......@@ -134,58 +135,72 @@ def test_env(
if block_size != 128:
# CUTLASS_MLA only supports block_size == 128
pytest.skip("CUTLASS_MLA only supports block_size 128")
else:
backend = get_attn_backend(
16, torch.float16, None, block_size, use_mla=use_mla
)
expected = "CUTLASS_MLA"
assert backend.get_name() == expected
if capability[0] != 10:
pytest.skip("CUTLASS MLA is not supported on this platform")
backend = get_attn_backend(
576, torch.float16, None, block_size, use_mla=use_mla
)
expected = "CUTLASS_MLA"
assert backend.get_name() == expected
elif name == "FLASHINFER_MLA":
if capability[0] != 10:
pytest.skip(
"FlashInfer MLA is not supported on this platform"
)
if block_size not in [32, 64]:
# FlashInfer MLA only supports block_size 32 or 64
pytest.skip(
"FlashInfer MLA only supports block_size 32 or 64"
)
else:
backend = get_attn_backend(
16, torch.float16, None, block_size, use_mla=use_mla
)
expected = "FLASHINFER_MLA"
assert backend.get_name() == expected
backend = get_attn_backend(
576, torch.float16, None, block_size, use_mla=use_mla
)
expected = "FLASHINFER_MLA"
assert backend.get_name() == expected
elif name == "FLASHMLA":
if block_size != 64:
# FlashMLA only supports block_size == 64
pytest.skip("FlashMLA only supports block_size 64")
else:
from vllm.v1.attention.backends.mla.flashmla import (
is_flashmla_dense_supported,
)
from vllm.v1.attention.backends.mla.flashmla import (
is_flashmla_dense_supported,
)
is_supported, _ = is_flashmla_dense_supported()
if not is_supported:
pytest.skip("FlashMLA not supported on this platform")
else:
backend = get_attn_backend(
16, torch.float16, None, block_size, use_mla=use_mla
)
expected = name
assert backend.get_name() == expected
is_supported, _ = is_flashmla_dense_supported()
if not is_supported:
pytest.skip("FlashMLA not supported on this platform")
backend = get_attn_backend(
576,
torch.float16,
None,
block_size,
use_mla=use_mla,
)
expected = name
assert backend.get_name() == expected
elif name == "FLASH_ATTN_MLA":
from vllm.attention.utils.fa_utils import (
flash_attn_supports_mla,
)
if not flash_attn_supports_mla():
pytest.skip(
"FlashAttention MLA not supported on this platform"
)
backend = get_attn_backend(
16, torch.float16, None, block_size, use_mla=use_mla
576, torch.float16, None, block_size, use_mla=use_mla
)
expected = "FLASH_ATTN_MLA"
assert backend.get_name() == expected
else:
# TRITON_MLA or other fallback
backend = get_attn_backend(
16, torch.float16, None, block_size, use_mla=use_mla
576, torch.float16, None, block_size, use_mla=use_mla
)
expected = "TRITON_MLA"
assert backend.get_name() == expected
elif name == "FLASHINFER":
backend = get_attn_backend(
16, torch.float16, None, block_size, use_mla=use_mla
64, torch.float16, None, block_size, use_mla=use_mla
)
expected = "FLASHINFER"
assert backend.get_name() == expected
......
......@@ -11,7 +11,7 @@ from unittest.mock import patch
import pytest
import torch
from vllm.attention.backends.registry import _Backend
from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.attention.layer import MultiHeadAttention
from vllm.attention.selector import _cached_get_attn_backend
from vllm.platforms import current_platform
......@@ -43,14 +43,14 @@ def test_mha_attn_platform(device: str):
patch("vllm.model_executor.models.vision.current_platform", CpuPlatform()),
):
attn = MultiHeadAttention(16, 64, scale=1)
assert attn.attn_backend == _Backend.TORCH_SDPA
assert attn.attn_backend == AttentionBackendEnum.TORCH_SDPA
elif device == "hip":
with (
patch("vllm.attention.layer.current_platform", RocmPlatform()),
patch("vllm.model_executor.models.vision.current_platform", RocmPlatform()),
):
attn = MultiHeadAttention(16, 64, scale=1)
assert attn.attn_backend == _Backend.TORCH_SDPA
assert attn.attn_backend == AttentionBackendEnum.TORCH_SDPA
else:
# Test CUDA with head_size=64 (divisible by 32)
# - should use vLLM's FlashAttention
......@@ -59,7 +59,7 @@ def test_mha_attn_platform(device: str):
patch("vllm.model_executor.models.vision.current_platform", CudaPlatform()),
):
attn = MultiHeadAttention(16, 64, scale=1)
assert attn.attn_backend == _Backend.FLASH_ATTN
assert attn.attn_backend == AttentionBackendEnum.FLASH_ATTN
# Test CUDA with head_size=72 (not divisible by 32)
# - with upstream FA not available
......@@ -73,7 +73,7 @@ def test_mha_attn_platform(device: str):
),
):
attn = MultiHeadAttention(16, 72, scale=1)
assert attn.attn_backend == _Backend.XFORMERS
assert attn.attn_backend == AttentionBackendEnum.XFORMERS
# Test CUDA with head_size=72 (not divisible by 32)
# - with upstream FA available
......@@ -96,7 +96,7 @@ def test_mha_attn_platform(device: str):
),
):
attn = MultiHeadAttention(16, 72, scale=1)
assert attn.attn_backend == _Backend.FLASH_ATTN
assert attn.attn_backend == AttentionBackendEnum.FLASH_ATTN
def ref_attention(
......
......@@ -93,6 +93,17 @@ def can_initialize(
"pickle error when loading `transformers.models.auto.CONFIG_MAPPING`"
)
if model_arch == "DeepseekV32ForCausalLM":
from vllm.platforms import current_platform
capability = current_platform.get_device_capability()
if capability and capability.major < 9:
pytest.skip(
f"DeepseekV32 requires Hopper (9.0+) or Blackwell (10.0+) "
f"for FLASHMLA_SPARSE backend. Current device has compute "
f"capability {capability.major}.{capability.minor}"
)
with (
patch.object(V1EngineCore, "_initialize_kv_caches", _initialize_kv_caches_v1),
monkeypatch.context() as m,
......
......@@ -15,7 +15,7 @@ from tests.v1.attention.utils import (
create_vllm_config,
try_get_attention_backend,
)
from vllm.attention.backends.registry import _Backend
from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.config import ModelConfig
from vllm.platforms import current_platform
from vllm.utils.math_utils import cdiv
......@@ -27,11 +27,11 @@ from vllm.v1.attention.backends.utils import (
from vllm.v1.kv_cache_interface import FullAttentionSpec
BACKENDS_TO_TEST = [
_Backend.FLASH_ATTN,
_Backend.FLASHINFER,
_Backend.FLEX_ATTENTION,
_Backend.TRITON_ATTN,
_Backend.TREE_ATTN,
AttentionBackendEnum.FLASH_ATTN,
AttentionBackendEnum.FLASHINFER,
AttentionBackendEnum.FLEX_ATTENTION,
AttentionBackendEnum.TRITON_ATTN,
AttentionBackendEnum.TREE_ATTN,
"FLEX_ATTENTION_SLOW",
]
......@@ -39,7 +39,7 @@ BACKENDS_TO_TEST = [
try:
import flashinfer # noqa: F401
except ImportError:
BACKENDS_TO_TEST.remove(_Backend.FLASHINFER)
BACKENDS_TO_TEST.remove(AttentionBackendEnum.FLASHINFER)
def _convert_dtype_to_torch(dtype):
......@@ -192,7 +192,7 @@ class MockAttentionLayer:
def run_attention_backend(
backend: _Backend,
backend: AttentionBackendEnum,
kv_cache_spec: FullAttentionSpec,
layer_names: list[str],
vllm_config,
......@@ -211,13 +211,13 @@ def run_attention_backend(
use_direct_block_mask = is_torch_equal_or_newer("2.9.0.dev0")
if backend == "FLEX_ATTENTION_SLOW":
actual_backend = _Backend.FLEX_ATTENTION
actual_backend = AttentionBackendEnum.FLEX_ATTENTION
use_direct_block_mask = False
builder_cls, impl_cls = try_get_attention_backend(actual_backend)
# Mock flashinfer's get_per_layer_parameters if needed
if actual_backend == _Backend.FLASHINFER:
if actual_backend == AttentionBackendEnum.FLASHINFER:
import unittest.mock
from vllm.v1.attention.backends.utils import PerLayerParameters
......@@ -246,7 +246,7 @@ def run_attention_backend(
else:
# Build metadata
builder = builder_cls(kv_cache_spec, layer_names, vllm_config, device)
if actual_backend == _Backend.FLEX_ATTENTION:
if actual_backend == AttentionBackendEnum.FLEX_ATTENTION:
builder.direct_build = use_direct_block_mask
attn_metadata = builder.build(
common_prefix_len=0,
......@@ -289,7 +289,7 @@ def run_attention_backend(
def _test_backend_correctness(
batch_spec: BatchSpec,
model: str,
backend_to_test: list[_Backend | str],
backend_to_test: list[AttentionBackendEnum | str],
mask_mod,
*,
block_size: int = 16,
......@@ -455,17 +455,20 @@ def _test_backend_correctness(
# Select the appropriate KV cache format for each backend
kv_cache_for_backend = kv_cache
reset_kv_cache_layout = False
if backend_name in (_Backend.FLASHINFER, _Backend.TRITON_ATTN):
if backend_name in (
AttentionBackendEnum.FLASHINFER,
AttentionBackendEnum.TRITON_ATTN,
):
kv_cache_for_backend = kv_cache.transpose(0, 1)
if backend_name == _Backend.FLASHINFER:
if backend_name == AttentionBackendEnum.FLASHINFER:
# For FlashInfer default to HND layout and
kv_cache_for_backend = (
kv_cache_for_backend.transpose(2, 3).contiguous().transpose(2, 3)
)
set_kv_cache_layout("HND")
reset_kv_cache_layout = True
elif backend_name == _Backend.TRITON_ATTN:
elif backend_name == AttentionBackendEnum.TRITON_ATTN:
kv_cache_for_backend = kv_cache_for_backend.contiguous()
try:
......@@ -547,7 +550,9 @@ def test_causal_backend_correctness(
batch_spec = BATCH_SPECS[batch_spec_name]
LARGE_BLOCK_BACKENDS = (
[_Backend.FLEX_ATTENTION] if is_torch_equal_or_newer("2.9.0.dev0") else []
[AttentionBackendEnum.FLEX_ATTENTION]
if is_torch_equal_or_newer("2.9.0.dev0")
else []
)
SMALL_BLOCK_BACKENDS = [
x for x in BACKENDS_TO_TEST if x not in LARGE_BLOCK_BACKENDS
......@@ -573,9 +578,9 @@ def test_causal_backend_correctness(
SLIDING_WINDOW_BACKENDS_TO_TEST = [
_Backend.FLASH_ATTN,
_Backend.FLEX_ATTENTION,
_Backend.TRITON_ATTN,
AttentionBackendEnum.FLASH_ATTN,
AttentionBackendEnum.FLEX_ATTENTION,
AttentionBackendEnum.TRITON_ATTN,
"FLEX_ATTENTION_SLOW",
]
......@@ -612,7 +617,9 @@ def test_sliding_window_backend_correctness(
)
LARGE_BLOCK_BACKENDS = (
[_Backend.FLEX_ATTENTION] if is_torch_equal_or_newer("2.9.0.dev0") else []
[AttentionBackendEnum.FLEX_ATTENTION]
if is_torch_equal_or_newer("2.9.0.dev0")
else []
)
SMALL_BLOCK_BACKENDS = [
x for x in SLIDING_WINDOW_BACKENDS_TO_TEST if x not in LARGE_BLOCK_BACKENDS
......
......@@ -18,12 +18,11 @@ from tests.v1.attention.utils import (
try_get_attention_backend,
)
from vllm import _custom_ops as ops
from vllm.attention.backends.registry import _Backend, backend_to_class_str
from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.attention.ops.flashmla import is_flashmla_dense_supported
from vllm.attention.utils.fa_utils import flash_attn_supports_mla
from vllm.config.vllm import set_current_vllm_config
from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
from vllm.utils.import_utils import resolve_obj_by_qualname
from vllm.utils.math_utils import cdiv
from vllm.utils.torch_utils import STR_DTYPE_TO_TORCH_DTYPE
from vllm.v1.attention.backends.mla.common import QueryLenSupport
......@@ -31,25 +30,25 @@ from vllm.v1.attention.backends.utils import CommonAttentionMetadata
from vllm.v1.kv_cache_interface import FullAttentionSpec
BACKENDS_TO_TEST = [
_Backend.CUTLASS_MLA,
_Backend.FLASHMLA,
_Backend.FLASH_ATTN_MLA,
_Backend.FLASHINFER_MLA,
_Backend.TRITON_MLA,
AttentionBackendEnum.CUTLASS_MLA,
AttentionBackendEnum.FLASHMLA,
AttentionBackendEnum.FLASH_ATTN_MLA,
AttentionBackendEnum.FLASHINFER_MLA,
AttentionBackendEnum.TRITON_MLA,
]
# Remove sm100 backends from the list if not using sm100
if not torch.cuda.is_available() or torch.cuda.get_device_properties(0).major < 10:
BACKENDS_TO_TEST.remove(_Backend.CUTLASS_MLA)
BACKENDS_TO_TEST.remove(_Backend.FLASHINFER_MLA)
BACKENDS_TO_TEST.remove(AttentionBackendEnum.CUTLASS_MLA)
BACKENDS_TO_TEST.remove(AttentionBackendEnum.FLASHINFER_MLA)
# Remove FLASH_ATTN_MLA from the list if not supported
if not flash_attn_supports_mla():
BACKENDS_TO_TEST.remove(_Backend.FLASH_ATTN_MLA)
BACKENDS_TO_TEST.remove(AttentionBackendEnum.FLASH_ATTN_MLA)
# Remove FLASHMLA from the list if not supported
if not is_flashmla_dense_supported()[0]:
BACKENDS_TO_TEST.remove(_Backend.FLASHMLA)
BACKENDS_TO_TEST.remove(AttentionBackendEnum.FLASHMLA)
SPEC_DECODE_BACKENDS = []
for backend in BACKENDS_TO_TEST:
......@@ -62,9 +61,7 @@ for backend in BACKENDS_TO_TEST:
BACKEND_BLOCK_SIZES = {}
for backend in BACKENDS_TO_TEST:
backend_class_str = backend_to_class_str(backend)
backend_class = resolve_obj_by_qualname(backend_class_str)
supported_sizes = backend_class.get_supported_kernel_block_size()
supported_sizes = backend.get_class().supported_kernel_block_sizes
if supported_sizes:
default_size = supported_sizes[0]
block_size = (
......@@ -291,7 +288,7 @@ class MockMLAAttentionLayer(AttentionLayerBase):
def run_attention_backend(
backend: _Backend,
backend: AttentionBackendEnum,
kv_cache_spec: FullAttentionSpec,
layer_names: list[str],
vllm_config,
......@@ -813,7 +810,7 @@ def test_backend_correctness(
# Create a summary for the single-line failure message
backend_names = []
for f in failures:
if "[_Backend." in f:
if "[AttentionBackendEnum." in f:
backend_name = f.split("[")[1].split("]")[0]
backend_names.append(backend_name)
......
......@@ -8,7 +8,7 @@ import pytest
import torch
from vllm.attention.backends.abstract import AttentionImpl
from vllm.attention.backends.registry import _Backend, backend_to_class_str
from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.config import (
CacheConfig,
CompilationConfig,
......@@ -20,7 +20,6 @@ from vllm.config import (
VllmConfig,
)
from vllm.config.model import ModelDType
from vllm.utils.import_utils import resolve_obj_by_qualname
from vllm.v1.attention.backends.utils import (
AttentionMetadataBuilder,
CommonAttentionMetadata,
......@@ -120,15 +119,14 @@ def create_common_attn_metadata(
def try_get_attention_backend(
backend: _Backend,
backend: AttentionBackendEnum,
) -> tuple[type[AttentionMetadataBuilder], type[AttentionImpl]]:
"""Try to get the attention backend class, skipping test if not found."""
backend_class_str = backend_to_class_str(backend)
try:
backend_class = resolve_obj_by_qualname(backend_class_str)
backend_class = backend.get_class()
return backend_class.get_builder_cls(), backend_class.get_impl_cls()
except ImportError as e:
pytest.skip(f"{backend_class_str} not available: {e}")
pytest.skip(f"{backend.name} not available: {e}")
raise AssertionError("unreachable") from None
......
......@@ -13,7 +13,7 @@ from tests.v1.attention.utils import (
create_standard_kv_cache_spec,
try_get_attention_backend,
)
from vllm.attention.backends.registry import _Backend
from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.config import (
CacheConfig,
DeviceConfig,
......@@ -534,11 +534,17 @@ def test_propose(method, attn_backend, num_speculative_tokens, monkeypatch):
sampling_metadata = mock.MagicMock()
if attn_backend == "FLASH_ATTN":
attn_metadata_builder_cls, _ = try_get_attention_backend(_Backend.FLASH_ATTN)
attn_metadata_builder_cls, _ = try_get_attention_backend(
AttentionBackendEnum.FLASH_ATTN
)
elif attn_backend == "TRITON_ATTN":
attn_metadata_builder_cls, _ = try_get_attention_backend(_Backend.TRITON_ATTN)
attn_metadata_builder_cls, _ = try_get_attention_backend(
AttentionBackendEnum.TRITON_ATTN
)
elif attn_backend == "TREE_ATTN":
attn_metadata_builder_cls, _ = try_get_attention_backend(_Backend.TREE_ATTN)
attn_metadata_builder_cls, _ = try_get_attention_backend(
AttentionBackendEnum.TREE_ATTN
)
else:
raise ValueError(f"Unsupported attention backend: {attn_backend}")
......@@ -673,7 +679,9 @@ def test_propose_tree(spec_token_tree):
proposer.attn_layer_names = ["layer.0"]
# Get the tree attention metadata builder.
attn_metadata_builder_cls, _ = try_get_attention_backend(_Backend.TREE_ATTN)
attn_metadata_builder_cls, _ = try_get_attention_backend(
AttentionBackendEnum.TREE_ATTN
)
attn_metadata_builder = attn_metadata_builder_cls(
kv_cache_spec=create_standard_kv_cache_spec(proposer.vllm_config),
layer_names=proposer.attn_layer_names,
......
......@@ -12,7 +12,7 @@ from tests.v1.attention.utils import (
create_standard_kv_cache_spec,
try_get_attention_backend,
)
from vllm.attention.backends.registry import _Backend
from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.config import (
CacheConfig,
DeviceConfig,
......@@ -177,7 +177,9 @@ def test_mtp_propose(num_speculative_tokens, monkeypatch):
sampling_metadata = mock.MagicMock()
# Setup attention metadata
attn_metadata_builder_cls, _ = try_get_attention_backend(_Backend.FLASH_ATTN)
attn_metadata_builder_cls, _ = try_get_attention_backend(
AttentionBackendEnum.FLASH_ATTN
)
attn_metadata_builder = attn_metadata_builder_cls(
kv_cache_spec=create_standard_kv_cache_spec(proposer.vllm_config),
......
......@@ -10,7 +10,7 @@ from tests.v1.attention.utils import (
create_vllm_config,
try_get_attention_backend,
)
from vllm.attention.backends.registry import _Backend
from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.config import ParallelConfig, SpeculativeConfig
from vllm.v1.attention.backends.utils import CommonAttentionMetadata
......@@ -35,7 +35,7 @@ def forward_attention(
block_table: torch.Tensor,
slot_mapping: torch.Tensor,
seqlen_k: int,
backend: _Backend,
backend: AttentionBackendEnum,
spec_token_tree: str | None = None,
num_spec_tokens: int = 0,
) -> torch.Tensor:
......@@ -241,7 +241,7 @@ def test_tree_attn_correctness() -> None:
block_table=block_table,
slot_mapping=tree_slot_mapping,
seqlen_k=seqlen_k,
backend=_Backend.TREE_ATTN,
backend=AttentionBackendEnum.TREE_ATTN,
spec_token_tree=spec_token_tree,
num_spec_tokens=tree_size_q - 1,
).view(batch_size, -1, num_heads, dim_per_head)
......@@ -278,7 +278,7 @@ def test_tree_attn_correctness() -> None:
block_table=block_table,
slot_mapping=branch_slot_mapping,
seqlen_k=sequence_position + q_len,
backend=_Backend.FLASH_ATTN,
backend=AttentionBackendEnum.FLASH_ATTN,
).view(batch_size, -1, num_heads, dim_per_head)
# Compare the outputs.
......
......@@ -185,9 +185,7 @@ def _make_mock_backend_for_kernel_block_size(
supported_sizes: list[int | MultipleOf],
):
class _MockBackend:
@staticmethod
def get_supported_kernel_block_size():
return supported_sizes
supported_kernel_block_sizes = supported_sizes
return _MockBackend()
......@@ -466,13 +464,20 @@ def test_kv_cache_stride_order(monkeypatch, model_runner):
# This test checks if GPUModelRunner initializes correctly when an attention
# backend enforces a non-default KV cache stride order.
n_heads = model_runner.model_config.get_num_kv_heads(model_runner.parallel_config)
expected_kv_cache_shape = [
2,
NUM_BLOCKS,
BLOCK_SIZE,
n_heads,
model_runner.model_config.get_head_size(),
]
head_size = model_runner.model_config.get_head_size()
# Get the expected shape from the backend's get_kv_cache_shape method
# to ensure compatibility with different backends (triton vs flexattention)
attn_backend = None
for attn_group in model_runner._attn_group_iterator():
attn_backend = attn_group.backend
break
assert attn_backend is not None, "No attention backend found"
expected_kv_cache_shape = list(
attn_backend.get_kv_cache_shape(NUM_BLOCKS, BLOCK_SIZE, n_heads, head_size)
)
# TODO mla test
default_stride = tuple(range(5))
# Permutation that gets you back to expected kv shape
......
......@@ -2,13 +2,18 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from abc import ABC, abstractmethod
from typing import Generic, Protocol, TypeVar
from typing import TYPE_CHECKING, ClassVar, Generic, Protocol, TypeVar, get_args
import torch
from vllm.model_executor.layers.linear import ColumnParallelLinear
from vllm.model_executor.layers.quantization.utils.quant_utils import QuantKey
if TYPE_CHECKING:
from vllm.config.cache import CacheDType
from vllm.platforms.interface import DeviceCapability
from vllm.v1.attention.backends.utils import KVCacheLayoutType
class AttentionType:
"""
......@@ -40,6 +45,9 @@ class AttentionBackend(ABC):
# calling the custom op. When piecewise cudagraph is enabled, this
# makes sure the output tensor is allocated inside the cudagraph.
accept_output_buffer: bool = False
supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16]
supported_kernel_block_sizes: ClassVar[list[int | MultipleOf]] = [MultipleOf(1)]
supported_kv_cache_dtypes: ClassVar[list["CacheDType"]] = ["auto"]
@staticmethod
@abstractmethod
......@@ -51,10 +59,6 @@ class AttentionBackend(ABC):
def get_impl_cls() -> type["AttentionImpl"]:
raise NotImplementedError
@classmethod
def get_supported_kernel_block_size(cls) -> list[int | MultipleOf]:
return cls.get_impl_cls().get_supported_kernel_block_size()
@staticmethod
@abstractmethod
def get_builder_cls(): # -> Type["AttentionMetadataBuilder"]:
......@@ -79,6 +83,136 @@ class AttentionBackend(ABC):
def full_cls_name(cls) -> tuple[str, str]:
return (cls.__module__, cls.__qualname__)
@classmethod
def get_supported_head_sizes(cls) -> list[int]:
return []
@classmethod
def supports_head_size(cls, head_size: int) -> bool:
supported_head_sizes = cls.get_supported_head_sizes()
return (not supported_head_sizes) or head_size in supported_head_sizes
@classmethod
def supports_dtype(cls, dtype: torch.dtype) -> bool:
return dtype in cls.supported_dtypes
@classmethod
def supports_kv_cache_dtype(cls, kv_cache_dtype: "CacheDType | None") -> bool:
if kv_cache_dtype is None:
return True
return (not cls.supported_kv_cache_dtypes) or (
kv_cache_dtype in cls.supported_kv_cache_dtypes
)
@classmethod
def supports_block_size(cls, block_size: int | None) -> bool:
from vllm.config.cache import BlockSize
if block_size is None:
return True
valid_sizes = get_args(BlockSize)
if block_size not in valid_sizes:
return False
if not cls.supported_kernel_block_sizes:
return True
for supported_size in cls.supported_kernel_block_sizes:
is_multiple_of = (
isinstance(supported_size, MultipleOf)
and block_size % supported_size.base == 0
)
is_int_equal = (
isinstance(supported_size, int) and block_size == supported_size
)
if is_multiple_of or is_int_equal:
return True
return False
@classmethod
def is_mla(cls) -> bool:
return False
@classmethod
def supports_sink(cls) -> bool:
return False
@classmethod
def is_sparse(cls) -> bool:
return False
@classmethod
def supports_compute_capability(cls, capability: "DeviceCapability") -> bool:
return True
@classmethod
def supports_combination(
cls,
head_size: int,
dtype: torch.dtype,
kv_cache_dtype: "CacheDType | None",
block_size: int | None,
use_mla: bool,
has_sink: bool,
use_sparse: bool,
device_capability: "DeviceCapability",
) -> str | None:
return None
@classmethod
def validate_configuration(
cls,
head_size: int,
dtype: torch.dtype,
kv_cache_dtype: "CacheDType | None",
block_size: int | None,
use_mla: bool,
has_sink: bool,
use_sparse: bool,
device_capability: "DeviceCapability",
) -> list[str]:
invalid_reasons = []
if not cls.supports_head_size(head_size):
invalid_reasons.append("head_size not supported")
if not cls.supports_dtype(dtype):
invalid_reasons.append("dtype not supported")
if not cls.supports_kv_cache_dtype(kv_cache_dtype):
invalid_reasons.append("kv_cache_dtype not supported")
if not cls.supports_block_size(block_size):
invalid_reasons.append("block_size not supported")
if use_mla != cls.is_mla():
if use_mla:
invalid_reasons.append("MLA not supported")
else:
invalid_reasons.append("non-MLA not supported")
if has_sink and not cls.supports_sink():
invalid_reasons.append("sink setting not supported")
if use_sparse != cls.is_sparse():
if use_sparse:
invalid_reasons.append("sparse not supported")
else:
invalid_reasons.append("non-sparse not supported")
if not cls.supports_compute_capability(device_capability):
invalid_reasons.append("compute capability not supported")
combination_reason = cls.supports_combination(
head_size,
dtype,
kv_cache_dtype,
block_size,
use_mla,
has_sink,
use_sparse,
device_capability,
)
if combination_reason is not None:
invalid_reasons.append(combination_reason)
return invalid_reasons
@classmethod
def get_required_kv_cache_layout(cls) -> "KVCacheLayoutType | None":
return None
class AttentionMetadata:
pass
......@@ -151,11 +285,6 @@ class AttentionImpl(ABC, Generic[T]):
) -> None:
raise NotImplementedError
@staticmethod
def get_supported_kernel_block_size() -> list[int | MultipleOf]:
# TODO: implement this function for all backends.
return [MultipleOf(1)]
@abstractmethod
def forward(
self,
......
......@@ -3,108 +3,192 @@
"""Attention backend registry"""
import enum
from collections.abc import Callable
from typing import TYPE_CHECKING, cast
from vllm.logger import init_logger
from vllm.utils.import_utils import resolve_obj_by_qualname
if TYPE_CHECKING:
from vllm.attention.backends.abstract import AttentionBackend
class _Backend(enum.Enum):
FLASH_ATTN = enum.auto()
TRITON_ATTN = enum.auto()
XFORMERS = enum.auto()
ROCM_ATTN = enum.auto()
ROCM_AITER_MLA = enum.auto()
ROCM_AITER_FA = enum.auto() # used for ViT attn backend
TORCH_SDPA = enum.auto()
FLASHINFER = enum.auto()
FLASHINFER_MLA = enum.auto()
TRITON_MLA = enum.auto()
CUTLASS_MLA = enum.auto()
FLASHMLA = enum.auto()
FLASHMLA_SPARSE = enum.auto()
FLASH_ATTN_MLA = enum.auto()
PALLAS = enum.auto()
IPEX = enum.auto()
NO_ATTENTION = enum.auto()
FLEX_ATTENTION = enum.auto()
TREE_ATTN = enum.auto()
ROCM_AITER_UNIFIED_ATTN = enum.auto()
BACKEND_MAP = {
_Backend.FLASH_ATTN: "vllm.v1.attention.backends.flash_attn.FlashAttentionBackend", # noqa: E501
_Backend.TRITON_ATTN: "vllm.v1.attention.backends.triton_attn.TritonAttentionBackend", # noqa: E501
_Backend.XFORMERS: "vllm.v1.attention.backends.xformers.XFormersAttentionBackend", # noqa: E501
_Backend.ROCM_ATTN: "vllm.v1.attention.backends.rocm_attn.RocmAttentionBackend", # noqa: E501
_Backend.ROCM_AITER_MLA: "vllm.v1.attention.backends.mla.rocm_aiter_mla.AiterMLABackend", # noqa: E501
_Backend.ROCM_AITER_FA: "vllm.v1.attention.backends.rocm_aiter_fa.AiterFlashAttentionBackend", # noqa: E501
_Backend.TORCH_SDPA: "vllm.v1.attention.backends.cpu_attn.TorchSDPABackend", # noqa: E501
_Backend.FLASHINFER: "vllm.v1.attention.backends.flashinfer.FlashInferBackend", # noqa: E501
_Backend.FLASHINFER_MLA: "vllm.v1.attention.backends.mla.flashinfer_mla.FlashInferMLABackend", # noqa: E501
_Backend.TRITON_MLA: "vllm.v1.attention.backends.mla.triton_mla.TritonMLABackend", # noqa: E501
_Backend.CUTLASS_MLA: "vllm.v1.attention.backends.mla.cutlass_mla.CutlassMLABackend", # noqa: E501
_Backend.FLASHMLA: "vllm.v1.attention.backends.mla.flashmla.FlashMLABackend", # noqa: E501
_Backend.FLASHMLA_SPARSE: "vllm.v1.attention.backends.mla.flashmla_sparse.FlashMLASparseBackend", # noqa: E501
_Backend.FLASH_ATTN_MLA: "vllm.v1.attention.backends.mla.flashattn_mla.FlashAttnMLABackend", # noqa: E501
_Backend.PALLAS: "vllm.v1.attention.backends.pallas.PallasAttentionBackend", # noqa: E501
_Backend.FLEX_ATTENTION: "vllm.v1.attention.backends.flex_attention.FlexAttentionBackend", # noqa: E501
_Backend.TREE_ATTN: "vllm.v1.attention.backends.tree_attn.TreeAttentionBackend", # noqa: E501
_Backend.ROCM_AITER_UNIFIED_ATTN: "vllm.v1.attention.backends.rocm_aiter_unified_attn.RocmAiterUnifiedAttentionBackend", # noqa: E501
}
def register_attn_backend(backend: _Backend, class_path: str | None = None):
"""
Decorator: register a custom attention backend into BACKEND_MAPPING.
- If class_path is provided, use it.
- Otherwise, auto-generate from the class object.
Validation: only checks if 'backend' is a valid _Backend enum member.
Overwriting existing mappings is allowed. This enables other hardware
platforms to plug in custom out-of-tree backends.
"""
if not isinstance(backend, _Backend):
raise ValueError(f"{backend} is not a valid _Backend enum value.")
logger = init_logger(__name__)
def decorator(cls):
path = class_path or f"{cls.__module__}.{cls.__qualname__}"
BACKEND_MAP[backend] = path
return cls
return decorator
class _AttentionBackendEnumMeta(enum.EnumMeta):
"""Metaclass for AttentionBackendEnum to provide better error messages."""
def __getitem__(cls, name: str):
"""Get backend by name with helpful error messages."""
try:
return super().__getitem__(name)
except KeyError:
members = cast("dict[str, AttentionBackendEnum]", cls.__members__).values()
valid_backends = ", ".join(m.name for m in members)
raise ValueError(
f"Unknown attention backend: '{name}'. "
f"Valid options are: {valid_backends}"
) from None
def backend_to_class_str(backend: _Backend) -> str:
"""Get the backend class string
Args:
backend: The backend enum value
class AttentionBackendEnum(enum.Enum, metaclass=_AttentionBackendEnumMeta):
"""Enumeration of all supported attention backends.
Returns:
The backend class string
"""
return BACKEND_MAP[backend]
The enum value is the default class path, but this can be overridden
at runtime using register_backend().
To get the actual backend class (respecting overrides), use:
backend.get_class()
"""
def backend_to_class(backend: _Backend) -> type:
"""Get the backend class.
FLASH_ATTN = "vllm.v1.attention.backends.flash_attn.FlashAttentionBackend"
TRITON_ATTN = "vllm.v1.attention.backends.triton_attn.TritonAttentionBackend"
XFORMERS = "vllm.v1.attention.backends.xformers.XFormersAttentionBackend"
ROCM_ATTN = "vllm.v1.attention.backends.rocm_attn.RocmAttentionBackend"
ROCM_AITER_MLA = "vllm.v1.attention.backends.mla.rocm_aiter_mla.AiterMLABackend"
ROCM_AITER_FA = (
"vllm.v1.attention.backends.rocm_aiter_fa.AiterFlashAttentionBackend"
)
TORCH_SDPA = "vllm.v1.attention.backends.cpu_attn.TorchSDPABackend"
FLASHINFER = "vllm.v1.attention.backends.flashinfer.FlashInferBackend"
FLASHINFER_MLA = (
"vllm.v1.attention.backends.mla.flashinfer_mla.FlashInferMLABackend"
)
TRITON_MLA = "vllm.v1.attention.backends.mla.triton_mla.TritonMLABackend"
CUTLASS_MLA = "vllm.v1.attention.backends.mla.cutlass_mla.CutlassMLABackend"
FLASHMLA = "vllm.v1.attention.backends.mla.flashmla.FlashMLABackend"
FLASHMLA_SPARSE = (
"vllm.v1.attention.backends.mla.flashmla_sparse.FlashMLASparseBackend"
)
FLASH_ATTN_MLA = "vllm.v1.attention.backends.mla.flashattn_mla.FlashAttnMLABackend"
PALLAS = "vllm.v1.attention.backends.pallas.PallasAttentionBackend"
IPEX = "vllm.v1.attention.backends.ipex.IpexAttentionBackend"
NO_ATTENTION = "vllm.v1.attention.backends.no_attention.NoAttentionBackend"
FLEX_ATTENTION = "vllm.v1.attention.backends.flex_attention.FlexAttentionBackend"
TREE_ATTN = "vllm.v1.attention.backends.tree_attn.TreeAttentionBackend"
ROCM_AITER_UNIFIED_ATTN = (
"vllm.v1.attention.backends.rocm_aiter_unified_attn."
"RocmAiterUnifiedAttentionBackend"
)
# Placeholder for third-party/custom backends - must be registered before use
CUSTOM = ""
def get_path(self, include_classname: bool = True) -> str:
"""Get the class path for this backend (respects overrides).
Returns:
The fully qualified class path string
Raises:
ValueError: If Backend.CUSTOM is used without being registered
"""
path = _OVERRIDES.get(self, self.value)
if not path:
raise ValueError(
f"Backend {self.name} must be registered before use. "
f"Use register_backend(Backend.{self.name}, 'your.module.YourClass')"
)
if not include_classname:
path = path.rsplit(".", 1)[0]
return path
def get_class(self) -> "type[AttentionBackend]":
"""Get the backend class (respects overrides).
Returns:
The backend class
Raises:
ImportError: If the backend class cannot be imported
ValueError: If Backend.CUSTOM is used without being registered
"""
return resolve_obj_by_qualname(self.get_path())
def is_overridden(self) -> bool:
"""Check if this backend has been overridden.
Returns:
True if the backend has a registered override
"""
return self in _OVERRIDES
def clear_override(self) -> None:
"""Clear any override for this backend, reverting to the default."""
_OVERRIDES.pop(self, None)
_OVERRIDES: dict[AttentionBackendEnum, str] = {}
def register_backend(
backend: AttentionBackendEnum, class_path: str | None = None
) -> Callable[[type], type]:
"""Register or override a backend implementation.
Args:
backend: The backend enum value
backend: The AttentionBackendEnum member to register
class_path: Optional class path. If not provided and used as
decorator, will be auto-generated from the class.
Returns:
The backend class
Decorator function if class_path is None, otherwise a no-op
Examples:
# Override an existing backend
@register_backend(AttentionBackendEnum.FLASH_ATTN)
class MyCustomFlashAttn:
...
# Register a custom third-party backend
@register_backend(AttentionBackendEnum.CUSTOM)
class MyCustomBackend:
...
# Direct registration
register_backend(
AttentionBackendEnum.CUSTOM,
"my.module.MyCustomBackend"
)
"""
backend_class_name = backend_to_class_str(backend)
return resolve_obj_by_qualname(backend_class_name)
def decorator(cls: type) -> type:
_OVERRIDES[backend] = f"{cls.__module__}.{cls.__qualname__}"
return cls
def backend_name_to_enum(backend_name: str) -> _Backend | None:
"""
Convert a string backend name to a _Backend enum value.
if class_path is not None:
_OVERRIDES[backend] = class_path
return lambda x: x
Returns:
_Backend: enum value if backend_name is a valid in-tree type
None: otherwise it's an invalid in-tree type or an out-of-tree platform
is loaded.
return decorator
# Backwards compatibility alias for plugins
class _BackendMeta(type):
"""Metaclass to provide deprecation warnings when accessing _Backend."""
def __getattribute__(cls, name: str):
if name not in ("__class__", "__mro__", "__name__"):
logger.warning(
"_Backend has been renamed to AttentionBackendEnum. "
"Please update your code to use AttentionBackendEnum instead. "
"_Backend will be removed in a future release."
)
return getattr(AttentionBackendEnum, name)
def __getitem__(cls, name: str):
logger.warning(
"_Backend has been renamed to AttentionBackendEnum. "
"Please update your code to use AttentionBackendEnum instead. "
"_Backend will be removed in a future release."
)
return AttentionBackendEnum[name]
class _Backend(metaclass=_BackendMeta):
"""Deprecated: Use AttentionBackendEnum instead.
This class is provided for backwards compatibility with plugins
and will be removed in a future release.
"""
assert backend_name is not None
return _Backend[backend_name] if backend_name in _Backend.__members__ else None
pass
......@@ -12,7 +12,7 @@ import torch.nn.functional as F
import vllm.envs as envs
from vllm.attention import AttentionType
from vllm.attention.backends.abstract import AttentionBackend, MLAAttentionImpl
from vllm.attention.backends.registry import _Backend, backend_name_to_enum
from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.attention.selector import get_attn_backend
from vllm.attention.utils.kv_sharing_utils import validate_kv_sharing_target
from vllm.config import CacheConfig, get_current_vllm_config
......@@ -99,40 +99,44 @@ def check_upstream_fa_availability(dtype: torch.dtype):
def maybe_get_vit_flash_attn_backend(
attn_backend: _Backend,
attn_backend: AttentionBackendEnum,
use_upstream_fa: bool,
attn_backend_override: _Backend | None = None,
) -> tuple[_Backend, Callable | None]:
attn_backend_override: AttentionBackendEnum | None = None,
) -> tuple[AttentionBackendEnum, Callable | None]:
if current_platform.is_rocm():
if envs.VLLM_ROCM_USE_AITER and envs.VLLM_ROCM_USE_AITER_MHA and on_gfx9():
attn_backend = _Backend.ROCM_AITER_FA
attn_backend = AttentionBackendEnum.ROCM_AITER_FA
elif (
check_upstream_fa_availability(torch.get_default_dtype())
and on_gfx9()
and attn_backend_override is None
):
attn_backend = _Backend.FLASH_ATTN
attn_backend = AttentionBackendEnum.FLASH_ATTN
use_upstream_fa = True
else:
return _Backend.TORCH_SDPA, None
return AttentionBackendEnum.TORCH_SDPA, None
elif current_platform.is_cuda():
if attn_backend != _Backend.FLASH_ATTN and check_upstream_fa_availability(
torch.get_default_dtype()
if (
attn_backend != AttentionBackendEnum.FLASH_ATTN
and check_upstream_fa_availability(torch.get_default_dtype())
):
attn_backend = _Backend.FLASH_ATTN
attn_backend = AttentionBackendEnum.FLASH_ATTN
use_upstream_fa = True
elif current_platform.is_xpu():
assert attn_backend == _Backend.FLASH_ATTN, (
assert attn_backend == AttentionBackendEnum.FLASH_ATTN, (
"XPU platform only supports FLASH_ATTN as vision attention backend."
)
use_upstream_fa = False
else:
return _Backend.TORCH_SDPA, None
return AttentionBackendEnum.TORCH_SDPA, None
if attn_backend in {_Backend.FLASH_ATTN, _Backend.ROCM_AITER_FA}:
if attn_backend == _Backend.ROCM_AITER_FA:
if attn_backend in {
AttentionBackendEnum.FLASH_ATTN,
AttentionBackendEnum.ROCM_AITER_FA,
}:
if attn_backend == AttentionBackendEnum.ROCM_AITER_FA:
from aiter import flash_attn_varlen_func
else:
if use_upstream_fa:
......@@ -309,7 +313,7 @@ class Attention(nn.Module, AttentionLayerBase):
kv_sharing_target_layer_name,
**extra_impl_args,
)
self.backend = backend_name_to_enum(self.attn_backend.get_name())
self.backend = AttentionBackendEnum[self.attn_backend.get_name()]
self.dtype = dtype
# For cuda-alike (CUDA and ROCM) and cpu platforms, we control how
......@@ -530,13 +534,13 @@ class MultiHeadAttention(nn.Module):
backend
if backend
in {
_Backend.TORCH_SDPA,
_Backend.XFORMERS,
_Backend.PALLAS,
_Backend.ROCM_AITER_FA,
_Backend.FLASH_ATTN,
AttentionBackendEnum.TORCH_SDPA,
AttentionBackendEnum.XFORMERS,
AttentionBackendEnum.PALLAS,
AttentionBackendEnum.ROCM_AITER_FA,
AttentionBackendEnum.FLASH_ATTN,
}
else _Backend.TORCH_SDPA
else AttentionBackendEnum.TORCH_SDPA
)
self.attn_backend, self._flash_attn_varlen_func = (
......@@ -547,17 +551,23 @@ class MultiHeadAttention(nn.Module):
)
)
if self.attn_backend == _Backend.XFORMERS and not check_xformers_availability():
self.attn_backend = _Backend.TORCH_SDPA
if (
self.attn_backend == AttentionBackendEnum.XFORMERS
and not check_xformers_availability()
):
self.attn_backend = AttentionBackendEnum.TORCH_SDPA
self.is_flash_attn_backend = self.attn_backend in {
_Backend.FLASH_ATTN,
_Backend.ROCM_AITER_FA,
AttentionBackendEnum.FLASH_ATTN,
AttentionBackendEnum.ROCM_AITER_FA,
}
# this condition is just to make sure that the
# use_upstream_fa in the log is correct
if current_platform.is_rocm() and self.attn_backend == _Backend.FLASH_ATTN:
if (
current_platform.is_rocm()
and self.attn_backend == AttentionBackendEnum.FLASH_ATTN
):
use_upstream_fa = True
logger.info_once(
......@@ -606,17 +616,17 @@ class MultiHeadAttention(nn.Module):
max_seqlen_k=kv_len,
softmax_scale=self.scale,
)
elif self.attn_backend == _Backend.XFORMERS:
elif self.attn_backend == AttentionBackendEnum.XFORMERS:
from xformers import ops as xops
out = xops.memory_efficient_attention_forward(
query, key, value, scale=self.scale
)
elif self.attn_backend == _Backend.TORCH_SDPA:
elif self.attn_backend == AttentionBackendEnum.TORCH_SDPA:
query, key, value = (x.transpose(1, 2) for x in (query, key, value))
out = F.scaled_dot_product_attention(query, key, value, scale=self.scale)
out = out.transpose(1, 2)
elif self.attn_backend == _Backend.PALLAS:
elif self.attn_backend == AttentionBackendEnum.PALLAS:
query, key, value = (x.transpose(1, 2) for x in (query, key, value))
from torch_xla.experimental.custom_kernel import flash_attention
......
......@@ -4,14 +4,15 @@
import os
from collections.abc import Generator
from contextlib import contextmanager
from dataclasses import dataclass
from functools import cache
from typing import cast, get_args
import torch
import vllm.envs as envs
from vllm.attention.backends.abstract import AttentionBackend
from vllm.attention.backends.registry import _Backend, backend_name_to_enum
from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.config.cache import CacheDType
from vllm.logger import init_logger
from vllm.utils import STR_BACKEND_ENV_VAR
from vllm.utils.import_utils import resolve_obj_by_qualname
......@@ -19,18 +20,18 @@ from vllm.utils.import_utils import resolve_obj_by_qualname
logger = init_logger(__name__)
def get_env_variable_attn_backend() -> _Backend | None:
def get_env_variable_attn_backend() -> AttentionBackendEnum | None:
"""
Get the backend override specified by the vLLM attention
backend environment variable, if one is specified.
Returns:
* _Backend enum value if an override is specified
* AttentionBackendEnum value if an override is specified
* None otherwise
"""
backend_name = os.environ.get(STR_BACKEND_ENV_VAR)
return None if backend_name is None else backend_name_to_enum(backend_name)
return None if backend_name is None else AttentionBackendEnum[backend_name]
# Global state allows a particular choice of backend
......@@ -40,10 +41,10 @@ def get_env_variable_attn_backend() -> _Backend | None:
#
# THIS SELECTION TAKES PRECEDENCE OVER THE
# VLLM_ATTENTION_BACKEND ENVIRONMENT VARIABLE
forced_attn_backend: _Backend | None = None
forced_attn_backend: AttentionBackendEnum | None = None
def global_force_attn_backend(attn_backend: _Backend | None) -> None:
def global_force_attn_backend(attn_backend: AttentionBackendEnum | None) -> None:
"""
Force all attention operations to use a specified backend.
......@@ -58,7 +59,7 @@ def global_force_attn_backend(attn_backend: _Backend | None) -> None:
forced_attn_backend = attn_backend
def get_global_forced_attn_backend() -> _Backend | None:
def get_global_forced_attn_backend() -> AttentionBackendEnum | None:
"""
Get the currently-forced choice of attention backend,
or None if auto-selection is currently enabled.
......@@ -66,78 +67,28 @@ def get_global_forced_attn_backend() -> _Backend | None:
return forced_attn_backend
@dataclass(frozen=True)
class _IsSupported:
can_import: bool
head_size: bool
dtype: bool
def __bool__(self) -> bool:
return self.can_import and self.head_size and self.dtype
def is_attn_backend_supported(
attn_backend: str | type[AttentionBackend],
head_size: int,
dtype: torch.dtype,
*,
allow_import_error: bool = True,
) -> _IsSupported:
if isinstance(attn_backend, str):
try:
attn_backend = resolve_obj_by_qualname(attn_backend)
except ImportError:
if not allow_import_error:
raise
return _IsSupported(can_import=False, head_size=False, dtype=False)
assert isinstance(attn_backend, type)
# TODO: Update the interface once V0 is removed
if get_supported_head_sizes := getattr(
attn_backend, "get_supported_head_sizes", None
):
is_head_size_supported = head_size in get_supported_head_sizes()
elif validate_head_size := getattr(attn_backend, "validate_head_size", None):
try:
validate_head_size(head_size)
is_head_size_supported = True
except Exception:
is_head_size_supported = False
else:
raise NotImplementedError(
f"{attn_backend.__name__} does not support head size validation"
)
if get_supported_dtypes := getattr(attn_backend, "get_supported_dtypes", None):
is_dtype_supported = dtype in get_supported_dtypes()
else:
raise NotImplementedError(
f"{attn_backend.__name__} does not support dtype validation"
)
return _IsSupported(
can_import=True,
head_size=is_head_size_supported,
dtype=is_dtype_supported,
)
def get_attn_backend(
head_size: int,
dtype: torch.dtype,
kv_cache_dtype: str | None,
block_size: int,
block_size: int | None,
use_mla: bool = False,
has_sink: bool = False,
use_sparse: bool = False,
) -> type[AttentionBackend]:
"""Selects which attention backend to use and lazily imports it."""
if kv_cache_dtype is not None:
valid_cache_dtypes = get_args(CacheDType)
assert kv_cache_dtype in valid_cache_dtypes, (
f"Invalid kv_cache_dtype: {kv_cache_dtype}. "
f"Valid values are: {valid_cache_dtypes}"
)
return _cached_get_attn_backend(
head_size=head_size,
dtype=dtype,
kv_cache_dtype=kv_cache_dtype,
kv_cache_dtype=cast(CacheDType | None, kv_cache_dtype),
block_size=block_size,
use_mla=use_mla,
has_sink=has_sink,
......@@ -149,8 +100,8 @@ def get_attn_backend(
def _cached_get_attn_backend(
head_size: int,
dtype: torch.dtype,
kv_cache_dtype: str | None,
block_size: int,
kv_cache_dtype: CacheDType | None,
block_size: int | None,
use_mla: bool = False,
has_sink: bool = False,
use_sparse: bool = False,
......@@ -161,7 +112,9 @@ def _cached_get_attn_backend(
# THIS SELECTION OVERRIDES THE VLLM_ATTENTION_BACKEND
# ENVIRONMENT VARIABLE.
selected_backend = None
backend_by_global_setting: _Backend | None = get_global_forced_attn_backend()
backend_by_global_setting: AttentionBackendEnum | None = (
get_global_forced_attn_backend()
)
if backend_by_global_setting is not None:
selected_backend = backend_by_global_setting
else:
......@@ -177,12 +130,13 @@ def _cached_get_attn_backend(
STR_BACKEND_ENV_VAR,
)
backend_by_env_var = backend_by_env_var.removesuffix("_VLLM_V1")
selected_backend = backend_name_to_enum(backend_by_env_var)
if selected_backend is None:
try:
selected_backend = AttentionBackendEnum[backend_by_env_var]
except KeyError as e:
raise ValueError(
f"Invalid attention backend: '{backend_by_env_var}'. "
f"Valid backends are: {list(_Backend.__members__.keys())}"
)
f"Invalid attention backend: '{backend_by_env_var}'. Valid "
f"backends are: {list(AttentionBackendEnum.__members__.keys())}"
) from e
# get device-specific attn_backend
from vllm.platforms import current_platform
......@@ -202,12 +156,26 @@ def _cached_get_attn_backend(
raise ValueError(
f"Invalid attention backend for {current_platform.device_name}"
)
return resolve_obj_by_qualname(attention_cls)
backend = resolve_obj_by_qualname(attention_cls)
# Adjust kv cache layout if the selected backend requires a specific one
required_layout = backend.get_required_kv_cache_layout()
if required_layout is not None:
from vllm.v1.attention.backends.utils import set_kv_cache_layout
set_kv_cache_layout(required_layout)
logger.info(
"Using %s KV cache layout for %s backend.",
required_layout,
backend.get_name(),
)
return backend
@contextmanager
def global_force_attn_backend_context_manager(
attn_backend: _Backend,
attn_backend: AttentionBackendEnum,
) -> Generator[None, None, None]:
"""
Globally force a vLLM attention backend override within a
......
......@@ -21,7 +21,15 @@ else:
logger = init_logger(__name__)
BlockSize = Literal[1, 8, 16, 32, 64, 128, 256]
CacheDType = Literal["auto", "bfloat16", "fp8", "fp8_e4m3", "fp8_e5m2", "fp8_inc"]
CacheDType = Literal[
"auto",
"bfloat16",
"fp8",
"fp8_e4m3",
"fp8_e5m2",
"fp8_inc",
"fp8_ds_mla",
]
MambaDType = Literal["auto", "float32"]
PrefixCachingHashAlgo = Literal["sha256", "sha256_cbor"]
KVOffloadingBackend = Literal["native", "lmcache"]
......
......@@ -45,7 +45,7 @@ if TYPE_CHECKING:
import vllm.model_executor.layers.quantization as me_quant
import vllm.model_executor.models as me_models
from vllm.attention.backends.registry import _Backend
from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.config.load import LoadConfig
from vllm.config.parallel import ParallelConfig
from vllm.model_executor.layers.quantization import QuantizationMethods
......@@ -53,7 +53,7 @@ if TYPE_CHECKING:
else:
PretrainedConfig = Any
_Backend = Any
AttentionBackendEnum = Any
me_quant = LazyLoader(
"model_executor", globals(), "vllm.model_executor.layers.quantization"
)
......@@ -302,7 +302,7 @@ class ModelConfig:
mm_processor_cache_type: InitVar[MMCacheType | None] = None
mm_shm_cache_max_object_size_mb: InitVar[int | None] = None
mm_encoder_tp_mode: InitVar[MMEncoderTPMode | None] = None
mm_encoder_attn_backend: InitVar[_Backend | str | None] = None
mm_encoder_attn_backend: InitVar[AttentionBackendEnum | str | None] = None
interleave_mm_strings: InitVar[bool | None] = None
skip_mm_profiling: InitVar[bool | None] = None
video_pruning_rate: InitVar[float | None] = None
......@@ -420,7 +420,7 @@ class ModelConfig:
mm_processor_cache_type: MMCacheType | None,
mm_shm_cache_max_object_size_mb: int | None,
mm_encoder_tp_mode: MMEncoderTPMode | None,
mm_encoder_attn_backend: _Backend | str | None,
mm_encoder_attn_backend: AttentionBackendEnum | str | None,
interleave_mm_strings: bool | None,
skip_mm_profiling: bool | None,
video_pruning_rate: float | None,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment