Unverified Commit b30dfa03 authored by Matthew Bonanni's avatar Matthew Bonanni Committed by GitHub
Browse files

[Attention] Refactor CUDA attention backend selection logic (#24794)


Signed-off-by: default avatarMatthew Bonanni <mbonanni@redhat.com>
Signed-off-by: default avatarMatthew Bonanni <mbonanni001@gmail.com>
Co-authored-by: default avatarLuka Govedič <ProExpertProg@users.noreply.github.com>
parent 2e78150d
...@@ -890,11 +890,16 @@ steps: ...@@ -890,11 +890,16 @@ steps:
- vllm/model_executor/layers/fused_moe/flashinfer_cutlass_prepare_finalize.py - vllm/model_executor/layers/fused_moe/flashinfer_cutlass_prepare_finalize.py
- vllm/model_executor/layers/quantization/utils/flashinfer_utils.py - vllm/model_executor/layers/quantization/utils/flashinfer_utils.py
- vllm/v1/attention/backends/flashinfer.py - vllm/v1/attention/backends/flashinfer.py
- vllm/v1/attention/backends/mla/cutlass_mla.py
- vllm/v1/attention/backends/mla/flashinfer_mla.py
- vllm/platforms/cuda.py
- vllm/attention/selector.py
commands: commands:
- nvidia-smi - nvidia-smi
- python3 examples/offline_inference/basic/chat.py - python3 examples/offline_inference/basic/chat.py
# Attention # Attention
# num_heads2 broken by https://github.com/flashinfer-ai/flashinfer/issues/1353 # num_heads2 broken by https://github.com/flashinfer-ai/flashinfer/issues/1353
- pytest -v -s tests/kernels/attention/test_attention_selector.py
- pytest -v -s tests/kernels/attention/test_flashinfer.py -k 'not num_heads2' - pytest -v -s tests/kernels/attention/test_flashinfer.py -k 'not num_heads2'
- pytest -v -s tests/kernels/attention/test_flashinfer_trtllm_attention.py - pytest -v -s tests/kernels/attention/test_flashinfer_trtllm_attention.py
- pytest -v -s tests/kernels/attention/test_cutlass_mla_decode.py - pytest -v -s tests/kernels/attention/test_cutlass_mla_decode.py
......
...@@ -10,7 +10,7 @@ from tests.utils import flat_product ...@@ -10,7 +10,7 @@ from tests.utils import flat_product
from tests.v1.attention.utils import BatchSpec, create_common_attn_metadata from tests.v1.attention.utils import BatchSpec, create_common_attn_metadata
from vllm._custom_ops import cutlass_scaled_fp4_mm, scaled_fp4_quant from vllm._custom_ops import cutlass_scaled_fp4_mm, scaled_fp4_quant
from vllm.attention import Attention, AttentionMetadata from vllm.attention import Attention, AttentionMetadata
from vllm.attention.backends.registry import _Backend from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.attention.selector import global_force_attn_backend_context_manager from vllm.attention.selector import global_force_attn_backend_context_manager
from vllm.compilation.fusion_attn import ATTN_OP, AttnFusionPass from vllm.compilation.fusion_attn import ATTN_OP, AttnFusionPass
from vllm.compilation.fx_utils import find_op_nodes from vllm.compilation.fx_utils import find_op_nodes
...@@ -104,7 +104,7 @@ class AttentionQuantPatternModel(torch.nn.Module): ...@@ -104,7 +104,7 @@ class AttentionQuantPatternModel(torch.nn.Module):
# TODO(luka) use get_kv_cache_stride_order # TODO(luka) use get_kv_cache_stride_order
# Create dummy KV cache for the selected backend # Create dummy KV cache for the selected backend
if backend == _Backend.ROCM_ATTN: if backend == AttentionBackendEnum.ROCM_ATTN:
# k/v as 1st dimention # k/v as 1st dimention
# HND: [num_blocks, num_kv_heads, block_size, head_size] # HND: [num_blocks, num_kv_heads, block_size, head_size]
kv_cache = torch.zeros( kv_cache = torch.zeros(
...@@ -116,7 +116,7 @@ class AttentionQuantPatternModel(torch.nn.Module): ...@@ -116,7 +116,7 @@ class AttentionQuantPatternModel(torch.nn.Module):
dtype=self.kv_cache_dtype, dtype=self.kv_cache_dtype,
device=self.device, device=self.device,
) )
elif backend == _Backend.ROCM_AITER_UNIFIED_ATTN: elif backend == AttentionBackendEnum.ROCM_AITER_UNIFIED_ATTN:
# k/v as 1st dimention # k/v as 1st dimention
# NHD: [num_blocks, block_size, num_kv_heads, head_size] # NHD: [num_blocks, block_size, num_kv_heads, head_size]
kv_cache = torch.zeros( kv_cache = torch.zeros(
...@@ -128,7 +128,7 @@ class AttentionQuantPatternModel(torch.nn.Module): ...@@ -128,7 +128,7 @@ class AttentionQuantPatternModel(torch.nn.Module):
dtype=self.kv_cache_dtype, dtype=self.kv_cache_dtype,
device=self.device, device=self.device,
) )
elif backend == _Backend.TRITON_ATTN: elif backend == AttentionBackendEnum.TRITON_ATTN:
# k/v as 2nd dimention # k/v as 2nd dimention
# NHD: [num_blocks, block_size, num_kv_heads, head_size] # NHD: [num_blocks, block_size, num_kv_heads, head_size]
kv_cache = torch.zeros( kv_cache = torch.zeros(
...@@ -140,7 +140,7 @@ class AttentionQuantPatternModel(torch.nn.Module): ...@@ -140,7 +140,7 @@ class AttentionQuantPatternModel(torch.nn.Module):
dtype=self.kv_cache_dtype, dtype=self.kv_cache_dtype,
device=self.device, device=self.device,
) )
elif backend == _Backend.FLASHINFER: elif backend == AttentionBackendEnum.FLASHINFER:
kv_cache = torch.zeros( kv_cache = torch.zeros(
num_blocks, num_blocks,
2, 2,
...@@ -244,8 +244,8 @@ MODELS_FP8: list[tuple[str, type]] = [] ...@@ -244,8 +244,8 @@ MODELS_FP8: list[tuple[str, type]] = []
MODELS_FP4: list[tuple[str, type]] = [] MODELS_FP4: list[tuple[str, type]] = []
HEADS: list[tuple[int, int]] = [] HEADS: list[tuple[int, int]] = []
SPLIT_ATTENTION: list[bool] = [] SPLIT_ATTENTION: list[bool] = []
BACKENDS_FP8: list[_Backend] = [] BACKENDS_FP8: list[AttentionBackendEnum] = []
BACKENDS_FP4: list[_Backend] = [] BACKENDS_FP4: list[AttentionBackendEnum] = []
if current_platform.is_cuda(): if current_platform.is_cuda():
HEADS = [(64, 8), (40, 8)] HEADS = [(64, 8), (40, 8)]
...@@ -261,8 +261,8 @@ if current_platform.is_cuda(): ...@@ -261,8 +261,8 @@ if current_platform.is_cuda():
TestAttentionNvfp4QuantPatternModel, TestAttentionNvfp4QuantPatternModel,
) )
] ]
BACKENDS_FP8 = [_Backend.TRITON_ATTN, _Backend.FLASHINFER] BACKENDS_FP8 = [AttentionBackendEnum.TRITON_ATTN, AttentionBackendEnum.FLASHINFER]
BACKENDS_FP4 = [_Backend.FLASHINFER] BACKENDS_FP4 = [AttentionBackendEnum.FLASHINFER]
elif current_platform.is_rocm(): elif current_platform.is_rocm():
HEADS = [(32, 8), (40, 8)] HEADS = [(32, 8), (40, 8)]
...@@ -270,9 +270,9 @@ elif current_platform.is_rocm(): ...@@ -270,9 +270,9 @@ elif current_platform.is_rocm():
("amd/Llama-3.1-8B-Instruct-FP8-KV", TestAttentionFp8StaticQuantPatternModel) ("amd/Llama-3.1-8B-Instruct-FP8-KV", TestAttentionFp8StaticQuantPatternModel)
] ]
BACKENDS = [ BACKENDS = [
_Backend.ROCM_AITER_UNIFIED_ATTN, AttentionBackendEnum.ROCM_AITER_UNIFIED_ATTN,
_Backend.ROCM_ATTN, AttentionBackendEnum.ROCM_ATTN,
_Backend.TRITON_ATTN, AttentionBackendEnum.TRITON_ATTN,
] ]
...@@ -302,11 +302,11 @@ def test_attention_quant_pattern( ...@@ -302,11 +302,11 @@ def test_attention_quant_pattern(
custom_ops: str, custom_ops: str,
model_name: str, model_name: str,
model_class: type[AttentionQuantPatternModel], model_class: type[AttentionQuantPatternModel],
backend: _Backend, backend: AttentionBackendEnum,
dist_init, dist_init,
): ):
"""Test AttentionStaticQuantPattern fusion pass""" """Test AttentionStaticQuantPattern fusion pass"""
if backend == _Backend.FLASHINFER and ( if backend == AttentionBackendEnum.FLASHINFER and (
not current_platform.is_device_capability((10, 0)) or not has_flashinfer() not current_platform.is_device_capability((10, 0)) or not has_flashinfer()
): ):
pytest.skip("FlashInfer attn fusion requires Blackwell and flashinfer") pytest.skip("FlashInfer attn fusion requires Blackwell and flashinfer")
...@@ -314,6 +314,7 @@ def test_attention_quant_pattern( ...@@ -314,6 +314,7 @@ def test_attention_quant_pattern(
custom_ops_list = custom_ops.split(",") if custom_ops else [] custom_ops_list = custom_ops.split(",") if custom_ops else []
device = torch.device("cuda:0") device = torch.device("cuda:0")
torch.set_default_dtype(dtype)
torch.manual_seed(42) torch.manual_seed(42)
vllm_config = VllmConfig( vllm_config = VllmConfig(
...@@ -402,7 +403,7 @@ def test_attention_quant_pattern( ...@@ -402,7 +403,7 @@ def test_attention_quant_pattern(
result_fused_1 = model_compiled(q, k, v) result_fused_1 = model_compiled(q, k, v)
if backend == _Backend.FLASHINFER: if backend == AttentionBackendEnum.FLASHINFER:
# With the Flashinfer backend after the 1st round of the forward # With the Flashinfer backend after the 1st round of the forward
# pass, output quant scale should be loaded into the attn layer's # pass, output quant scale should be loaded into the attn layer's
# _o_scale_float, the 2nd round should reuse the loaded # _o_scale_float, the 2nd round should reuse the loaded
......
...@@ -11,7 +11,7 @@ from typing import Any, NamedTuple ...@@ -11,7 +11,7 @@ from typing import Any, NamedTuple
import pytest import pytest
import regex as re import regex as re
from tests.v1.attention.utils import _Backend from tests.v1.attention.utils import AttentionBackendEnum
from vllm import LLM, SamplingParams from vllm import LLM, SamplingParams
from vllm.config import CompilationConfig, CompilationMode, CUDAGraphMode, PassConfig from vllm.config import CompilationConfig, CompilationMode, CUDAGraphMode, PassConfig
from vllm.platforms import current_platform from vllm.platforms import current_platform
...@@ -24,7 +24,7 @@ from ..utils import flat_product, multi_gpu_test ...@@ -24,7 +24,7 @@ from ..utils import flat_product, multi_gpu_test
class ModelBackendTestCase(NamedTuple): class ModelBackendTestCase(NamedTuple):
model_name: str model_name: str
model_kwargs: dict[str, Any] model_kwargs: dict[str, Any]
backend: _Backend backend: AttentionBackendEnum
attention_fusions: int attention_fusions: int
allreduce_fusions: int | None = None allreduce_fusions: int | None = None
...@@ -39,14 +39,14 @@ if current_platform.is_cuda(): ...@@ -39,14 +39,14 @@ if current_platform.is_cuda():
# Use smaller model for L40s in CI # Use smaller model for L40s in CI
model_name="RedHatAI/Meta-Llama-3.1-8B-Instruct-FP8", model_name="RedHatAI/Meta-Llama-3.1-8B-Instruct-FP8",
model_kwargs=dict(max_model_len=1024), model_kwargs=dict(max_model_len=1024),
backend=_Backend.TRITON_ATTN, backend=AttentionBackendEnum.TRITON_ATTN,
attention_fusions=32, attention_fusions=32,
allreduce_fusions=65, allreduce_fusions=65,
), ),
ModelBackendTestCase( ModelBackendTestCase(
model_name="nvidia/Llama-4-Scout-17B-16E-Instruct-FP8", model_name="nvidia/Llama-4-Scout-17B-16E-Instruct-FP8",
model_kwargs=dict(max_model_len=1024, kv_cache_dtype="fp8"), model_kwargs=dict(max_model_len=1024, kv_cache_dtype="fp8"),
backend=_Backend.FLASHINFER, backend=AttentionBackendEnum.FLASHINFER,
attention_fusions=48, attention_fusions=48,
allreduce_fusions=96, allreduce_fusions=96,
), ),
...@@ -56,7 +56,7 @@ if current_platform.is_cuda(): ...@@ -56,7 +56,7 @@ if current_platform.is_cuda():
ModelBackendTestCase( ModelBackendTestCase(
model_name="nvidia/Llama-3.1-8B-Instruct-FP4", model_name="nvidia/Llama-3.1-8B-Instruct-FP4",
model_kwargs=dict(max_model_len=1024, kv_cache_dtype="fp8"), model_kwargs=dict(max_model_len=1024, kv_cache_dtype="fp8"),
backend=_Backend.FLASHINFER, backend=AttentionBackendEnum.FLASHINFER,
attention_fusions=32, attention_fusions=32,
allreduce_fusions=65, allreduce_fusions=65,
), ),
...@@ -67,7 +67,7 @@ if current_platform.is_cuda(): ...@@ -67,7 +67,7 @@ if current_platform.is_cuda():
ModelBackendTestCase( ModelBackendTestCase(
model_name="meta-llama/Llama-3.1-8B-Instruct", model_name="meta-llama/Llama-3.1-8B-Instruct",
model_kwargs=dict(max_model_len=1024), model_kwargs=dict(max_model_len=1024),
backend=_Backend.TRITON_ATTN, backend=AttentionBackendEnum.TRITON_ATTN,
attention_fusions=0, attention_fusions=0,
allreduce_fusions=65, allreduce_fusions=65,
), ),
...@@ -85,19 +85,19 @@ elif current_platform.is_rocm(): ...@@ -85,19 +85,19 @@ elif current_platform.is_rocm():
ModelBackendTestCase( ModelBackendTestCase(
model_name="amd/Llama-3.1-8B-Instruct-FP8-KV", model_name="amd/Llama-3.1-8B-Instruct-FP8-KV",
model_kwargs=dict(max_model_len=1024), model_kwargs=dict(max_model_len=1024),
backend=_Backend.TRITON_ATTN, backend=AttentionBackendEnum.TRITON_ATTN,
attention_fusions=32, attention_fusions=32,
), ),
ModelBackendTestCase( ModelBackendTestCase(
model_name="amd/Llama-3.1-8B-Instruct-FP8-KV", model_name="amd/Llama-3.1-8B-Instruct-FP8-KV",
model_kwargs=dict(max_model_len=1024), model_kwargs=dict(max_model_len=1024),
backend=_Backend.ROCM_ATTN, backend=AttentionBackendEnum.ROCM_ATTN,
attention_fusions=32, attention_fusions=32,
), ),
ModelBackendTestCase( ModelBackendTestCase(
model_name="amd/Llama-3.1-8B-Instruct-FP8-KV", model_name="amd/Llama-3.1-8B-Instruct-FP8-KV",
model_kwargs=dict(max_model_len=1024), model_kwargs=dict(max_model_len=1024),
backend=_Backend.ROCM_AITER_UNIFIED_ATTN, backend=AttentionBackendEnum.ROCM_AITER_UNIFIED_ATTN,
attention_fusions=32, attention_fusions=32,
), ),
] ]
...@@ -117,7 +117,7 @@ CUSTOM_OPS_FP8 = ["-quant_fp8", "+quant_fp8"] ...@@ -117,7 +117,7 @@ CUSTOM_OPS_FP8 = ["-quant_fp8", "+quant_fp8"]
def test_attn_quant( def test_attn_quant(
model_name: str, model_name: str,
model_kwargs: dict[str, Any], model_kwargs: dict[str, Any],
backend: _Backend, backend: AttentionBackendEnum,
attention_fusions: int, attention_fusions: int,
allreduce_fusions: int, allreduce_fusions: int,
custom_ops: str, custom_ops: str,
...@@ -125,7 +125,7 @@ def test_attn_quant( ...@@ -125,7 +125,7 @@ def test_attn_quant(
caplog_mp_spawn, caplog_mp_spawn,
monkeypatch, monkeypatch,
): ):
if backend == _Backend.FLASHINFER and ( if backend == AttentionBackendEnum.FLASHINFER and (
not current_platform.is_device_capability((10, 0)) or not has_flashinfer() not current_platform.is_device_capability((10, 0)) or not has_flashinfer()
): ):
pytest.skip("FlashInfer attn fusion requires Blackwell and flashinfer") pytest.skip("FlashInfer attn fusion requires Blackwell and flashinfer")
...@@ -208,7 +208,7 @@ def custom_ops_product(*custom_ops_lists: list[str]) -> Iterable[str]: ...@@ -208,7 +208,7 @@ def custom_ops_product(*custom_ops_lists: list[str]) -> Iterable[str]:
def test_tp2_attn_quant_allreduce_rmsnorm( def test_tp2_attn_quant_allreduce_rmsnorm(
model_name: str, model_name: str,
model_kwargs: dict, model_kwargs: dict,
backend: _Backend, backend: AttentionBackendEnum,
attention_fusions: int, attention_fusions: int,
allreduce_fusions: int, allreduce_fusions: int,
custom_ops: str, custom_ops: str,
......
...@@ -3,13 +3,13 @@ ...@@ -3,13 +3,13 @@
import pytest import pytest
from vllm.attention.backends.registry import _Backend from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.config.multimodal import MultiModalConfig from vllm.config.multimodal import MultiModalConfig
def test_mm_encoder_attn_backend_str_conversion(): def test_mm_encoder_attn_backend_str_conversion():
config = MultiModalConfig(mm_encoder_attn_backend="FLASH_ATTN") config = MultiModalConfig(mm_encoder_attn_backend="FLASH_ATTN")
assert config.mm_encoder_attn_backend == _Backend.FLASH_ATTN assert config.mm_encoder_attn_backend == AttentionBackendEnum.FLASH_ATTN
def test_mm_encoder_attn_backend_invalid(): def test_mm_encoder_attn_backend_invalid():
...@@ -20,6 +20,6 @@ def test_mm_encoder_attn_backend_invalid(): ...@@ -20,6 +20,6 @@ def test_mm_encoder_attn_backend_invalid():
def test_mm_encoder_attn_backend_hash_updates(): def test_mm_encoder_attn_backend_hash_updates():
base_hash = MultiModalConfig().compute_hash() base_hash = MultiModalConfig().compute_hash()
overridden_hash = MultiModalConfig( overridden_hash = MultiModalConfig(
mm_encoder_attn_backend=_Backend.FLASH_ATTN mm_encoder_attn_backend=AttentionBackendEnum.FLASH_ATTN
).compute_hash() ).compute_hash()
assert base_hash != overridden_hash assert base_hash != overridden_hash
...@@ -120,12 +120,13 @@ def test_env( ...@@ -120,12 +120,13 @@ def test_env(
elif device == "cuda": elif device == "cuda":
with patch("vllm.platforms.current_platform", CudaPlatform()): with patch("vllm.platforms.current_platform", CudaPlatform()):
capability = torch.cuda.get_device_capability()
if use_mla: if use_mla:
# CUDA MLA backend logic: # CUDA MLA backend logic:
# - CUTLASS_MLA: only supported with block_size == 128 # - CUTLASS_MLA: only supported with block_size == 128
# and Blackwell GPUs (SM 10.0), V1 only # and Blackwell GPUs (SM 10.x), V1 only
# - FLASHINFER_MLA: only supported on Blackwell GPUs # - FLASHINFER_MLA: only supported on Blackwell GPUs
# (SM 10.0+), V1 only # (SM 10.x), V1 only
# - FLASHMLA: only supported with block_size == 64 # - FLASHMLA: only supported with block_size == 64
# - FLASH_ATTN_MLA: V1 only # - FLASH_ATTN_MLA: V1 only
# - TRITON_MLA: fallback for other cases # - TRITON_MLA: fallback for other cases
...@@ -134,21 +135,25 @@ def test_env( ...@@ -134,21 +135,25 @@ def test_env(
if block_size != 128: if block_size != 128:
# CUTLASS_MLA only supports block_size == 128 # CUTLASS_MLA only supports block_size == 128
pytest.skip("CUTLASS_MLA only supports block_size 128") pytest.skip("CUTLASS_MLA only supports block_size 128")
else: if capability[0] != 10:
pytest.skip("CUTLASS MLA is not supported on this platform")
backend = get_attn_backend( backend = get_attn_backend(
16, torch.float16, None, block_size, use_mla=use_mla 576, torch.float16, None, block_size, use_mla=use_mla
) )
expected = "CUTLASS_MLA" expected = "CUTLASS_MLA"
assert backend.get_name() == expected assert backend.get_name() == expected
elif name == "FLASHINFER_MLA": elif name == "FLASHINFER_MLA":
if capability[0] != 10:
pytest.skip(
"FlashInfer MLA is not supported on this platform"
)
if block_size not in [32, 64]: if block_size not in [32, 64]:
# FlashInfer MLA only supports block_size 32 or 64 # FlashInfer MLA only supports block_size 32 or 64
pytest.skip( pytest.skip(
"FlashInfer MLA only supports block_size 32 or 64" "FlashInfer MLA only supports block_size 32 or 64"
) )
else:
backend = get_attn_backend( backend = get_attn_backend(
16, torch.float16, None, block_size, use_mla=use_mla 576, torch.float16, None, block_size, use_mla=use_mla
) )
expected = "FLASHINFER_MLA" expected = "FLASHINFER_MLA"
assert backend.get_name() == expected assert backend.get_name() == expected
...@@ -156,7 +161,6 @@ def test_env( ...@@ -156,7 +161,6 @@ def test_env(
if block_size != 64: if block_size != 64:
# FlashMLA only supports block_size == 64 # FlashMLA only supports block_size == 64
pytest.skip("FlashMLA only supports block_size 64") pytest.skip("FlashMLA only supports block_size 64")
else:
from vllm.v1.attention.backends.mla.flashmla import ( from vllm.v1.attention.backends.mla.flashmla import (
is_flashmla_dense_supported, is_flashmla_dense_supported,
) )
...@@ -164,28 +168,39 @@ def test_env( ...@@ -164,28 +168,39 @@ def test_env(
is_supported, _ = is_flashmla_dense_supported() is_supported, _ = is_flashmla_dense_supported()
if not is_supported: if not is_supported:
pytest.skip("FlashMLA not supported on this platform") pytest.skip("FlashMLA not supported on this platform")
else:
backend = get_attn_backend( backend = get_attn_backend(
16, torch.float16, None, block_size, use_mla=use_mla 576,
torch.float16,
None,
block_size,
use_mla=use_mla,
) )
expected = name expected = name
assert backend.get_name() == expected assert backend.get_name() == expected
elif name == "FLASH_ATTN_MLA": elif name == "FLASH_ATTN_MLA":
from vllm.attention.utils.fa_utils import (
flash_attn_supports_mla,
)
if not flash_attn_supports_mla():
pytest.skip(
"FlashAttention MLA not supported on this platform"
)
backend = get_attn_backend( backend = get_attn_backend(
16, torch.float16, None, block_size, use_mla=use_mla 576, torch.float16, None, block_size, use_mla=use_mla
) )
expected = "FLASH_ATTN_MLA" expected = "FLASH_ATTN_MLA"
assert backend.get_name() == expected assert backend.get_name() == expected
else: else:
# TRITON_MLA or other fallback # TRITON_MLA or other fallback
backend = get_attn_backend( backend = get_attn_backend(
16, torch.float16, None, block_size, use_mla=use_mla 576, torch.float16, None, block_size, use_mla=use_mla
) )
expected = "TRITON_MLA" expected = "TRITON_MLA"
assert backend.get_name() == expected assert backend.get_name() == expected
elif name == "FLASHINFER": elif name == "FLASHINFER":
backend = get_attn_backend( backend = get_attn_backend(
16, torch.float16, None, block_size, use_mla=use_mla 64, torch.float16, None, block_size, use_mla=use_mla
) )
expected = "FLASHINFER" expected = "FLASHINFER"
assert backend.get_name() == expected assert backend.get_name() == expected
......
...@@ -11,7 +11,7 @@ from unittest.mock import patch ...@@ -11,7 +11,7 @@ from unittest.mock import patch
import pytest import pytest
import torch import torch
from vllm.attention.backends.registry import _Backend from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.attention.layer import MultiHeadAttention from vllm.attention.layer import MultiHeadAttention
from vllm.attention.selector import _cached_get_attn_backend from vllm.attention.selector import _cached_get_attn_backend
from vllm.platforms import current_platform from vllm.platforms import current_platform
...@@ -43,14 +43,14 @@ def test_mha_attn_platform(device: str): ...@@ -43,14 +43,14 @@ def test_mha_attn_platform(device: str):
patch("vllm.model_executor.models.vision.current_platform", CpuPlatform()), patch("vllm.model_executor.models.vision.current_platform", CpuPlatform()),
): ):
attn = MultiHeadAttention(16, 64, scale=1) attn = MultiHeadAttention(16, 64, scale=1)
assert attn.attn_backend == _Backend.TORCH_SDPA assert attn.attn_backend == AttentionBackendEnum.TORCH_SDPA
elif device == "hip": elif device == "hip":
with ( with (
patch("vllm.attention.layer.current_platform", RocmPlatform()), patch("vllm.attention.layer.current_platform", RocmPlatform()),
patch("vllm.model_executor.models.vision.current_platform", RocmPlatform()), patch("vllm.model_executor.models.vision.current_platform", RocmPlatform()),
): ):
attn = MultiHeadAttention(16, 64, scale=1) attn = MultiHeadAttention(16, 64, scale=1)
assert attn.attn_backend == _Backend.TORCH_SDPA assert attn.attn_backend == AttentionBackendEnum.TORCH_SDPA
else: else:
# Test CUDA with head_size=64 (divisible by 32) # Test CUDA with head_size=64 (divisible by 32)
# - should use vLLM's FlashAttention # - should use vLLM's FlashAttention
...@@ -59,7 +59,7 @@ def test_mha_attn_platform(device: str): ...@@ -59,7 +59,7 @@ def test_mha_attn_platform(device: str):
patch("vllm.model_executor.models.vision.current_platform", CudaPlatform()), patch("vllm.model_executor.models.vision.current_platform", CudaPlatform()),
): ):
attn = MultiHeadAttention(16, 64, scale=1) attn = MultiHeadAttention(16, 64, scale=1)
assert attn.attn_backend == _Backend.FLASH_ATTN assert attn.attn_backend == AttentionBackendEnum.FLASH_ATTN
# Test CUDA with head_size=72 (not divisible by 32) # Test CUDA with head_size=72 (not divisible by 32)
# - with upstream FA not available # - with upstream FA not available
...@@ -73,7 +73,7 @@ def test_mha_attn_platform(device: str): ...@@ -73,7 +73,7 @@ def test_mha_attn_platform(device: str):
), ),
): ):
attn = MultiHeadAttention(16, 72, scale=1) attn = MultiHeadAttention(16, 72, scale=1)
assert attn.attn_backend == _Backend.XFORMERS assert attn.attn_backend == AttentionBackendEnum.XFORMERS
# Test CUDA with head_size=72 (not divisible by 32) # Test CUDA with head_size=72 (not divisible by 32)
# - with upstream FA available # - with upstream FA available
...@@ -96,7 +96,7 @@ def test_mha_attn_platform(device: str): ...@@ -96,7 +96,7 @@ def test_mha_attn_platform(device: str):
), ),
): ):
attn = MultiHeadAttention(16, 72, scale=1) attn = MultiHeadAttention(16, 72, scale=1)
assert attn.attn_backend == _Backend.FLASH_ATTN assert attn.attn_backend == AttentionBackendEnum.FLASH_ATTN
def ref_attention( def ref_attention(
......
...@@ -93,6 +93,17 @@ def can_initialize( ...@@ -93,6 +93,17 @@ def can_initialize(
"pickle error when loading `transformers.models.auto.CONFIG_MAPPING`" "pickle error when loading `transformers.models.auto.CONFIG_MAPPING`"
) )
if model_arch == "DeepseekV32ForCausalLM":
from vllm.platforms import current_platform
capability = current_platform.get_device_capability()
if capability and capability.major < 9:
pytest.skip(
f"DeepseekV32 requires Hopper (9.0+) or Blackwell (10.0+) "
f"for FLASHMLA_SPARSE backend. Current device has compute "
f"capability {capability.major}.{capability.minor}"
)
with ( with (
patch.object(V1EngineCore, "_initialize_kv_caches", _initialize_kv_caches_v1), patch.object(V1EngineCore, "_initialize_kv_caches", _initialize_kv_caches_v1),
monkeypatch.context() as m, monkeypatch.context() as m,
......
...@@ -15,7 +15,7 @@ from tests.v1.attention.utils import ( ...@@ -15,7 +15,7 @@ from tests.v1.attention.utils import (
create_vllm_config, create_vllm_config,
try_get_attention_backend, try_get_attention_backend,
) )
from vllm.attention.backends.registry import _Backend from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.config import ModelConfig from vllm.config import ModelConfig
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.utils.math_utils import cdiv from vllm.utils.math_utils import cdiv
...@@ -27,11 +27,11 @@ from vllm.v1.attention.backends.utils import ( ...@@ -27,11 +27,11 @@ from vllm.v1.attention.backends.utils import (
from vllm.v1.kv_cache_interface import FullAttentionSpec from vllm.v1.kv_cache_interface import FullAttentionSpec
BACKENDS_TO_TEST = [ BACKENDS_TO_TEST = [
_Backend.FLASH_ATTN, AttentionBackendEnum.FLASH_ATTN,
_Backend.FLASHINFER, AttentionBackendEnum.FLASHINFER,
_Backend.FLEX_ATTENTION, AttentionBackendEnum.FLEX_ATTENTION,
_Backend.TRITON_ATTN, AttentionBackendEnum.TRITON_ATTN,
_Backend.TREE_ATTN, AttentionBackendEnum.TREE_ATTN,
"FLEX_ATTENTION_SLOW", "FLEX_ATTENTION_SLOW",
] ]
...@@ -39,7 +39,7 @@ BACKENDS_TO_TEST = [ ...@@ -39,7 +39,7 @@ BACKENDS_TO_TEST = [
try: try:
import flashinfer # noqa: F401 import flashinfer # noqa: F401
except ImportError: except ImportError:
BACKENDS_TO_TEST.remove(_Backend.FLASHINFER) BACKENDS_TO_TEST.remove(AttentionBackendEnum.FLASHINFER)
def _convert_dtype_to_torch(dtype): def _convert_dtype_to_torch(dtype):
...@@ -192,7 +192,7 @@ class MockAttentionLayer: ...@@ -192,7 +192,7 @@ class MockAttentionLayer:
def run_attention_backend( def run_attention_backend(
backend: _Backend, backend: AttentionBackendEnum,
kv_cache_spec: FullAttentionSpec, kv_cache_spec: FullAttentionSpec,
layer_names: list[str], layer_names: list[str],
vllm_config, vllm_config,
...@@ -211,13 +211,13 @@ def run_attention_backend( ...@@ -211,13 +211,13 @@ def run_attention_backend(
use_direct_block_mask = is_torch_equal_or_newer("2.9.0.dev0") use_direct_block_mask = is_torch_equal_or_newer("2.9.0.dev0")
if backend == "FLEX_ATTENTION_SLOW": if backend == "FLEX_ATTENTION_SLOW":
actual_backend = _Backend.FLEX_ATTENTION actual_backend = AttentionBackendEnum.FLEX_ATTENTION
use_direct_block_mask = False use_direct_block_mask = False
builder_cls, impl_cls = try_get_attention_backend(actual_backend) builder_cls, impl_cls = try_get_attention_backend(actual_backend)
# Mock flashinfer's get_per_layer_parameters if needed # Mock flashinfer's get_per_layer_parameters if needed
if actual_backend == _Backend.FLASHINFER: if actual_backend == AttentionBackendEnum.FLASHINFER:
import unittest.mock import unittest.mock
from vllm.v1.attention.backends.utils import PerLayerParameters from vllm.v1.attention.backends.utils import PerLayerParameters
...@@ -246,7 +246,7 @@ def run_attention_backend( ...@@ -246,7 +246,7 @@ def run_attention_backend(
else: else:
# Build metadata # Build metadata
builder = builder_cls(kv_cache_spec, layer_names, vllm_config, device) builder = builder_cls(kv_cache_spec, layer_names, vllm_config, device)
if actual_backend == _Backend.FLEX_ATTENTION: if actual_backend == AttentionBackendEnum.FLEX_ATTENTION:
builder.direct_build = use_direct_block_mask builder.direct_build = use_direct_block_mask
attn_metadata = builder.build( attn_metadata = builder.build(
common_prefix_len=0, common_prefix_len=0,
...@@ -289,7 +289,7 @@ def run_attention_backend( ...@@ -289,7 +289,7 @@ def run_attention_backend(
def _test_backend_correctness( def _test_backend_correctness(
batch_spec: BatchSpec, batch_spec: BatchSpec,
model: str, model: str,
backend_to_test: list[_Backend | str], backend_to_test: list[AttentionBackendEnum | str],
mask_mod, mask_mod,
*, *,
block_size: int = 16, block_size: int = 16,
...@@ -455,17 +455,20 @@ def _test_backend_correctness( ...@@ -455,17 +455,20 @@ def _test_backend_correctness(
# Select the appropriate KV cache format for each backend # Select the appropriate KV cache format for each backend
kv_cache_for_backend = kv_cache kv_cache_for_backend = kv_cache
reset_kv_cache_layout = False reset_kv_cache_layout = False
if backend_name in (_Backend.FLASHINFER, _Backend.TRITON_ATTN): if backend_name in (
AttentionBackendEnum.FLASHINFER,
AttentionBackendEnum.TRITON_ATTN,
):
kv_cache_for_backend = kv_cache.transpose(0, 1) kv_cache_for_backend = kv_cache.transpose(0, 1)
if backend_name == _Backend.FLASHINFER: if backend_name == AttentionBackendEnum.FLASHINFER:
# For FlashInfer default to HND layout and # For FlashInfer default to HND layout and
kv_cache_for_backend = ( kv_cache_for_backend = (
kv_cache_for_backend.transpose(2, 3).contiguous().transpose(2, 3) kv_cache_for_backend.transpose(2, 3).contiguous().transpose(2, 3)
) )
set_kv_cache_layout("HND") set_kv_cache_layout("HND")
reset_kv_cache_layout = True reset_kv_cache_layout = True
elif backend_name == _Backend.TRITON_ATTN: elif backend_name == AttentionBackendEnum.TRITON_ATTN:
kv_cache_for_backend = kv_cache_for_backend.contiguous() kv_cache_for_backend = kv_cache_for_backend.contiguous()
try: try:
...@@ -547,7 +550,9 @@ def test_causal_backend_correctness( ...@@ -547,7 +550,9 @@ def test_causal_backend_correctness(
batch_spec = BATCH_SPECS[batch_spec_name] batch_spec = BATCH_SPECS[batch_spec_name]
LARGE_BLOCK_BACKENDS = ( LARGE_BLOCK_BACKENDS = (
[_Backend.FLEX_ATTENTION] if is_torch_equal_or_newer("2.9.0.dev0") else [] [AttentionBackendEnum.FLEX_ATTENTION]
if is_torch_equal_or_newer("2.9.0.dev0")
else []
) )
SMALL_BLOCK_BACKENDS = [ SMALL_BLOCK_BACKENDS = [
x for x in BACKENDS_TO_TEST if x not in LARGE_BLOCK_BACKENDS x for x in BACKENDS_TO_TEST if x not in LARGE_BLOCK_BACKENDS
...@@ -573,9 +578,9 @@ def test_causal_backend_correctness( ...@@ -573,9 +578,9 @@ def test_causal_backend_correctness(
SLIDING_WINDOW_BACKENDS_TO_TEST = [ SLIDING_WINDOW_BACKENDS_TO_TEST = [
_Backend.FLASH_ATTN, AttentionBackendEnum.FLASH_ATTN,
_Backend.FLEX_ATTENTION, AttentionBackendEnum.FLEX_ATTENTION,
_Backend.TRITON_ATTN, AttentionBackendEnum.TRITON_ATTN,
"FLEX_ATTENTION_SLOW", "FLEX_ATTENTION_SLOW",
] ]
...@@ -612,7 +617,9 @@ def test_sliding_window_backend_correctness( ...@@ -612,7 +617,9 @@ def test_sliding_window_backend_correctness(
) )
LARGE_BLOCK_BACKENDS = ( LARGE_BLOCK_BACKENDS = (
[_Backend.FLEX_ATTENTION] if is_torch_equal_or_newer("2.9.0.dev0") else [] [AttentionBackendEnum.FLEX_ATTENTION]
if is_torch_equal_or_newer("2.9.0.dev0")
else []
) )
SMALL_BLOCK_BACKENDS = [ SMALL_BLOCK_BACKENDS = [
x for x in SLIDING_WINDOW_BACKENDS_TO_TEST if x not in LARGE_BLOCK_BACKENDS x for x in SLIDING_WINDOW_BACKENDS_TO_TEST if x not in LARGE_BLOCK_BACKENDS
......
...@@ -18,12 +18,11 @@ from tests.v1.attention.utils import ( ...@@ -18,12 +18,11 @@ from tests.v1.attention.utils import (
try_get_attention_backend, try_get_attention_backend,
) )
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
from vllm.attention.backends.registry import _Backend, backend_to_class_str from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.attention.ops.flashmla import is_flashmla_dense_supported from vllm.attention.ops.flashmla import is_flashmla_dense_supported
from vllm.attention.utils.fa_utils import flash_attn_supports_mla from vllm.attention.utils.fa_utils import flash_attn_supports_mla
from vllm.config.vllm import set_current_vllm_config from vllm.config.vllm import set_current_vllm_config
from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
from vllm.utils.import_utils import resolve_obj_by_qualname
from vllm.utils.math_utils import cdiv from vllm.utils.math_utils import cdiv
from vllm.utils.torch_utils import STR_DTYPE_TO_TORCH_DTYPE from vllm.utils.torch_utils import STR_DTYPE_TO_TORCH_DTYPE
from vllm.v1.attention.backends.mla.common import QueryLenSupport from vllm.v1.attention.backends.mla.common import QueryLenSupport
...@@ -31,25 +30,25 @@ from vllm.v1.attention.backends.utils import CommonAttentionMetadata ...@@ -31,25 +30,25 @@ from vllm.v1.attention.backends.utils import CommonAttentionMetadata
from vllm.v1.kv_cache_interface import FullAttentionSpec from vllm.v1.kv_cache_interface import FullAttentionSpec
BACKENDS_TO_TEST = [ BACKENDS_TO_TEST = [
_Backend.CUTLASS_MLA, AttentionBackendEnum.CUTLASS_MLA,
_Backend.FLASHMLA, AttentionBackendEnum.FLASHMLA,
_Backend.FLASH_ATTN_MLA, AttentionBackendEnum.FLASH_ATTN_MLA,
_Backend.FLASHINFER_MLA, AttentionBackendEnum.FLASHINFER_MLA,
_Backend.TRITON_MLA, AttentionBackendEnum.TRITON_MLA,
] ]
# Remove sm100 backends from the list if not using sm100 # Remove sm100 backends from the list if not using sm100
if not torch.cuda.is_available() or torch.cuda.get_device_properties(0).major < 10: if not torch.cuda.is_available() or torch.cuda.get_device_properties(0).major < 10:
BACKENDS_TO_TEST.remove(_Backend.CUTLASS_MLA) BACKENDS_TO_TEST.remove(AttentionBackendEnum.CUTLASS_MLA)
BACKENDS_TO_TEST.remove(_Backend.FLASHINFER_MLA) BACKENDS_TO_TEST.remove(AttentionBackendEnum.FLASHINFER_MLA)
# Remove FLASH_ATTN_MLA from the list if not supported # Remove FLASH_ATTN_MLA from the list if not supported
if not flash_attn_supports_mla(): if not flash_attn_supports_mla():
BACKENDS_TO_TEST.remove(_Backend.FLASH_ATTN_MLA) BACKENDS_TO_TEST.remove(AttentionBackendEnum.FLASH_ATTN_MLA)
# Remove FLASHMLA from the list if not supported # Remove FLASHMLA from the list if not supported
if not is_flashmla_dense_supported()[0]: if not is_flashmla_dense_supported()[0]:
BACKENDS_TO_TEST.remove(_Backend.FLASHMLA) BACKENDS_TO_TEST.remove(AttentionBackendEnum.FLASHMLA)
SPEC_DECODE_BACKENDS = [] SPEC_DECODE_BACKENDS = []
for backend in BACKENDS_TO_TEST: for backend in BACKENDS_TO_TEST:
...@@ -62,9 +61,7 @@ for backend in BACKENDS_TO_TEST: ...@@ -62,9 +61,7 @@ for backend in BACKENDS_TO_TEST:
BACKEND_BLOCK_SIZES = {} BACKEND_BLOCK_SIZES = {}
for backend in BACKENDS_TO_TEST: for backend in BACKENDS_TO_TEST:
backend_class_str = backend_to_class_str(backend) supported_sizes = backend.get_class().supported_kernel_block_sizes
backend_class = resolve_obj_by_qualname(backend_class_str)
supported_sizes = backend_class.get_supported_kernel_block_size()
if supported_sizes: if supported_sizes:
default_size = supported_sizes[0] default_size = supported_sizes[0]
block_size = ( block_size = (
...@@ -291,7 +288,7 @@ class MockMLAAttentionLayer(AttentionLayerBase): ...@@ -291,7 +288,7 @@ class MockMLAAttentionLayer(AttentionLayerBase):
def run_attention_backend( def run_attention_backend(
backend: _Backend, backend: AttentionBackendEnum,
kv_cache_spec: FullAttentionSpec, kv_cache_spec: FullAttentionSpec,
layer_names: list[str], layer_names: list[str],
vllm_config, vllm_config,
...@@ -813,7 +810,7 @@ def test_backend_correctness( ...@@ -813,7 +810,7 @@ def test_backend_correctness(
# Create a summary for the single-line failure message # Create a summary for the single-line failure message
backend_names = [] backend_names = []
for f in failures: for f in failures:
if "[_Backend." in f: if "[AttentionBackendEnum." in f:
backend_name = f.split("[")[1].split("]")[0] backend_name = f.split("[")[1].split("]")[0]
backend_names.append(backend_name) backend_names.append(backend_name)
......
...@@ -8,7 +8,7 @@ import pytest ...@@ -8,7 +8,7 @@ import pytest
import torch import torch
from vllm.attention.backends.abstract import AttentionImpl from vllm.attention.backends.abstract import AttentionImpl
from vllm.attention.backends.registry import _Backend, backend_to_class_str from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.config import ( from vllm.config import (
CacheConfig, CacheConfig,
CompilationConfig, CompilationConfig,
...@@ -20,7 +20,6 @@ from vllm.config import ( ...@@ -20,7 +20,6 @@ from vllm.config import (
VllmConfig, VllmConfig,
) )
from vllm.config.model import ModelDType from vllm.config.model import ModelDType
from vllm.utils.import_utils import resolve_obj_by_qualname
from vllm.v1.attention.backends.utils import ( from vllm.v1.attention.backends.utils import (
AttentionMetadataBuilder, AttentionMetadataBuilder,
CommonAttentionMetadata, CommonAttentionMetadata,
...@@ -120,15 +119,14 @@ def create_common_attn_metadata( ...@@ -120,15 +119,14 @@ def create_common_attn_metadata(
def try_get_attention_backend( def try_get_attention_backend(
backend: _Backend, backend: AttentionBackendEnum,
) -> tuple[type[AttentionMetadataBuilder], type[AttentionImpl]]: ) -> tuple[type[AttentionMetadataBuilder], type[AttentionImpl]]:
"""Try to get the attention backend class, skipping test if not found.""" """Try to get the attention backend class, skipping test if not found."""
backend_class_str = backend_to_class_str(backend)
try: try:
backend_class = resolve_obj_by_qualname(backend_class_str) backend_class = backend.get_class()
return backend_class.get_builder_cls(), backend_class.get_impl_cls() return backend_class.get_builder_cls(), backend_class.get_impl_cls()
except ImportError as e: except ImportError as e:
pytest.skip(f"{backend_class_str} not available: {e}") pytest.skip(f"{backend.name} not available: {e}")
raise AssertionError("unreachable") from None raise AssertionError("unreachable") from None
......
...@@ -13,7 +13,7 @@ from tests.v1.attention.utils import ( ...@@ -13,7 +13,7 @@ from tests.v1.attention.utils import (
create_standard_kv_cache_spec, create_standard_kv_cache_spec,
try_get_attention_backend, try_get_attention_backend,
) )
from vllm.attention.backends.registry import _Backend from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.config import ( from vllm.config import (
CacheConfig, CacheConfig,
DeviceConfig, DeviceConfig,
...@@ -534,11 +534,17 @@ def test_propose(method, attn_backend, num_speculative_tokens, monkeypatch): ...@@ -534,11 +534,17 @@ def test_propose(method, attn_backend, num_speculative_tokens, monkeypatch):
sampling_metadata = mock.MagicMock() sampling_metadata = mock.MagicMock()
if attn_backend == "FLASH_ATTN": if attn_backend == "FLASH_ATTN":
attn_metadata_builder_cls, _ = try_get_attention_backend(_Backend.FLASH_ATTN) attn_metadata_builder_cls, _ = try_get_attention_backend(
AttentionBackendEnum.FLASH_ATTN
)
elif attn_backend == "TRITON_ATTN": elif attn_backend == "TRITON_ATTN":
attn_metadata_builder_cls, _ = try_get_attention_backend(_Backend.TRITON_ATTN) attn_metadata_builder_cls, _ = try_get_attention_backend(
AttentionBackendEnum.TRITON_ATTN
)
elif attn_backend == "TREE_ATTN": elif attn_backend == "TREE_ATTN":
attn_metadata_builder_cls, _ = try_get_attention_backend(_Backend.TREE_ATTN) attn_metadata_builder_cls, _ = try_get_attention_backend(
AttentionBackendEnum.TREE_ATTN
)
else: else:
raise ValueError(f"Unsupported attention backend: {attn_backend}") raise ValueError(f"Unsupported attention backend: {attn_backend}")
...@@ -673,7 +679,9 @@ def test_propose_tree(spec_token_tree): ...@@ -673,7 +679,9 @@ def test_propose_tree(spec_token_tree):
proposer.attn_layer_names = ["layer.0"] proposer.attn_layer_names = ["layer.0"]
# Get the tree attention metadata builder. # Get the tree attention metadata builder.
attn_metadata_builder_cls, _ = try_get_attention_backend(_Backend.TREE_ATTN) attn_metadata_builder_cls, _ = try_get_attention_backend(
AttentionBackendEnum.TREE_ATTN
)
attn_metadata_builder = attn_metadata_builder_cls( attn_metadata_builder = attn_metadata_builder_cls(
kv_cache_spec=create_standard_kv_cache_spec(proposer.vllm_config), kv_cache_spec=create_standard_kv_cache_spec(proposer.vllm_config),
layer_names=proposer.attn_layer_names, layer_names=proposer.attn_layer_names,
......
...@@ -12,7 +12,7 @@ from tests.v1.attention.utils import ( ...@@ -12,7 +12,7 @@ from tests.v1.attention.utils import (
create_standard_kv_cache_spec, create_standard_kv_cache_spec,
try_get_attention_backend, try_get_attention_backend,
) )
from vllm.attention.backends.registry import _Backend from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.config import ( from vllm.config import (
CacheConfig, CacheConfig,
DeviceConfig, DeviceConfig,
...@@ -177,7 +177,9 @@ def test_mtp_propose(num_speculative_tokens, monkeypatch): ...@@ -177,7 +177,9 @@ def test_mtp_propose(num_speculative_tokens, monkeypatch):
sampling_metadata = mock.MagicMock() sampling_metadata = mock.MagicMock()
# Setup attention metadata # Setup attention metadata
attn_metadata_builder_cls, _ = try_get_attention_backend(_Backend.FLASH_ATTN) attn_metadata_builder_cls, _ = try_get_attention_backend(
AttentionBackendEnum.FLASH_ATTN
)
attn_metadata_builder = attn_metadata_builder_cls( attn_metadata_builder = attn_metadata_builder_cls(
kv_cache_spec=create_standard_kv_cache_spec(proposer.vllm_config), kv_cache_spec=create_standard_kv_cache_spec(proposer.vllm_config),
......
...@@ -10,7 +10,7 @@ from tests.v1.attention.utils import ( ...@@ -10,7 +10,7 @@ from tests.v1.attention.utils import (
create_vllm_config, create_vllm_config,
try_get_attention_backend, try_get_attention_backend,
) )
from vllm.attention.backends.registry import _Backend from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.config import ParallelConfig, SpeculativeConfig from vllm.config import ParallelConfig, SpeculativeConfig
from vllm.v1.attention.backends.utils import CommonAttentionMetadata from vllm.v1.attention.backends.utils import CommonAttentionMetadata
...@@ -35,7 +35,7 @@ def forward_attention( ...@@ -35,7 +35,7 @@ def forward_attention(
block_table: torch.Tensor, block_table: torch.Tensor,
slot_mapping: torch.Tensor, slot_mapping: torch.Tensor,
seqlen_k: int, seqlen_k: int,
backend: _Backend, backend: AttentionBackendEnum,
spec_token_tree: str | None = None, spec_token_tree: str | None = None,
num_spec_tokens: int = 0, num_spec_tokens: int = 0,
) -> torch.Tensor: ) -> torch.Tensor:
...@@ -241,7 +241,7 @@ def test_tree_attn_correctness() -> None: ...@@ -241,7 +241,7 @@ def test_tree_attn_correctness() -> None:
block_table=block_table, block_table=block_table,
slot_mapping=tree_slot_mapping, slot_mapping=tree_slot_mapping,
seqlen_k=seqlen_k, seqlen_k=seqlen_k,
backend=_Backend.TREE_ATTN, backend=AttentionBackendEnum.TREE_ATTN,
spec_token_tree=spec_token_tree, spec_token_tree=spec_token_tree,
num_spec_tokens=tree_size_q - 1, num_spec_tokens=tree_size_q - 1,
).view(batch_size, -1, num_heads, dim_per_head) ).view(batch_size, -1, num_heads, dim_per_head)
...@@ -278,7 +278,7 @@ def test_tree_attn_correctness() -> None: ...@@ -278,7 +278,7 @@ def test_tree_attn_correctness() -> None:
block_table=block_table, block_table=block_table,
slot_mapping=branch_slot_mapping, slot_mapping=branch_slot_mapping,
seqlen_k=sequence_position + q_len, seqlen_k=sequence_position + q_len,
backend=_Backend.FLASH_ATTN, backend=AttentionBackendEnum.FLASH_ATTN,
).view(batch_size, -1, num_heads, dim_per_head) ).view(batch_size, -1, num_heads, dim_per_head)
# Compare the outputs. # Compare the outputs.
......
...@@ -185,9 +185,7 @@ def _make_mock_backend_for_kernel_block_size( ...@@ -185,9 +185,7 @@ def _make_mock_backend_for_kernel_block_size(
supported_sizes: list[int | MultipleOf], supported_sizes: list[int | MultipleOf],
): ):
class _MockBackend: class _MockBackend:
@staticmethod supported_kernel_block_sizes = supported_sizes
def get_supported_kernel_block_size():
return supported_sizes
return _MockBackend() return _MockBackend()
...@@ -466,13 +464,20 @@ def test_kv_cache_stride_order(monkeypatch, model_runner): ...@@ -466,13 +464,20 @@ def test_kv_cache_stride_order(monkeypatch, model_runner):
# This test checks if GPUModelRunner initializes correctly when an attention # This test checks if GPUModelRunner initializes correctly when an attention
# backend enforces a non-default KV cache stride order. # backend enforces a non-default KV cache stride order.
n_heads = model_runner.model_config.get_num_kv_heads(model_runner.parallel_config) n_heads = model_runner.model_config.get_num_kv_heads(model_runner.parallel_config)
expected_kv_cache_shape = [ head_size = model_runner.model_config.get_head_size()
2,
NUM_BLOCKS, # Get the expected shape from the backend's get_kv_cache_shape method
BLOCK_SIZE, # to ensure compatibility with different backends (triton vs flexattention)
n_heads, attn_backend = None
model_runner.model_config.get_head_size(), for attn_group in model_runner._attn_group_iterator():
] attn_backend = attn_group.backend
break
assert attn_backend is not None, "No attention backend found"
expected_kv_cache_shape = list(
attn_backend.get_kv_cache_shape(NUM_BLOCKS, BLOCK_SIZE, n_heads, head_size)
)
# TODO mla test # TODO mla test
default_stride = tuple(range(5)) default_stride = tuple(range(5))
# Permutation that gets you back to expected kv shape # Permutation that gets you back to expected kv shape
......
...@@ -2,13 +2,18 @@ ...@@ -2,13 +2,18 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import Generic, Protocol, TypeVar from typing import TYPE_CHECKING, ClassVar, Generic, Protocol, TypeVar, get_args
import torch import torch
from vllm.model_executor.layers.linear import ColumnParallelLinear from vllm.model_executor.layers.linear import ColumnParallelLinear
from vllm.model_executor.layers.quantization.utils.quant_utils import QuantKey from vllm.model_executor.layers.quantization.utils.quant_utils import QuantKey
if TYPE_CHECKING:
from vllm.config.cache import CacheDType
from vllm.platforms.interface import DeviceCapability
from vllm.v1.attention.backends.utils import KVCacheLayoutType
class AttentionType: class AttentionType:
""" """
...@@ -40,6 +45,9 @@ class AttentionBackend(ABC): ...@@ -40,6 +45,9 @@ class AttentionBackend(ABC):
# calling the custom op. When piecewise cudagraph is enabled, this # calling the custom op. When piecewise cudagraph is enabled, this
# makes sure the output tensor is allocated inside the cudagraph. # makes sure the output tensor is allocated inside the cudagraph.
accept_output_buffer: bool = False accept_output_buffer: bool = False
supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16]
supported_kernel_block_sizes: ClassVar[list[int | MultipleOf]] = [MultipleOf(1)]
supported_kv_cache_dtypes: ClassVar[list["CacheDType"]] = ["auto"]
@staticmethod @staticmethod
@abstractmethod @abstractmethod
...@@ -51,10 +59,6 @@ class AttentionBackend(ABC): ...@@ -51,10 +59,6 @@ class AttentionBackend(ABC):
def get_impl_cls() -> type["AttentionImpl"]: def get_impl_cls() -> type["AttentionImpl"]:
raise NotImplementedError raise NotImplementedError
@classmethod
def get_supported_kernel_block_size(cls) -> list[int | MultipleOf]:
return cls.get_impl_cls().get_supported_kernel_block_size()
@staticmethod @staticmethod
@abstractmethod @abstractmethod
def get_builder_cls(): # -> Type["AttentionMetadataBuilder"]: def get_builder_cls(): # -> Type["AttentionMetadataBuilder"]:
...@@ -79,6 +83,136 @@ class AttentionBackend(ABC): ...@@ -79,6 +83,136 @@ class AttentionBackend(ABC):
def full_cls_name(cls) -> tuple[str, str]: def full_cls_name(cls) -> tuple[str, str]:
return (cls.__module__, cls.__qualname__) return (cls.__module__, cls.__qualname__)
@classmethod
def get_supported_head_sizes(cls) -> list[int]:
return []
@classmethod
def supports_head_size(cls, head_size: int) -> bool:
supported_head_sizes = cls.get_supported_head_sizes()
return (not supported_head_sizes) or head_size in supported_head_sizes
@classmethod
def supports_dtype(cls, dtype: torch.dtype) -> bool:
return dtype in cls.supported_dtypes
@classmethod
def supports_kv_cache_dtype(cls, kv_cache_dtype: "CacheDType | None") -> bool:
if kv_cache_dtype is None:
return True
return (not cls.supported_kv_cache_dtypes) or (
kv_cache_dtype in cls.supported_kv_cache_dtypes
)
@classmethod
def supports_block_size(cls, block_size: int | None) -> bool:
from vllm.config.cache import BlockSize
if block_size is None:
return True
valid_sizes = get_args(BlockSize)
if block_size not in valid_sizes:
return False
if not cls.supported_kernel_block_sizes:
return True
for supported_size in cls.supported_kernel_block_sizes:
is_multiple_of = (
isinstance(supported_size, MultipleOf)
and block_size % supported_size.base == 0
)
is_int_equal = (
isinstance(supported_size, int) and block_size == supported_size
)
if is_multiple_of or is_int_equal:
return True
return False
@classmethod
def is_mla(cls) -> bool:
return False
@classmethod
def supports_sink(cls) -> bool:
return False
@classmethod
def is_sparse(cls) -> bool:
return False
@classmethod
def supports_compute_capability(cls, capability: "DeviceCapability") -> bool:
return True
@classmethod
def supports_combination(
cls,
head_size: int,
dtype: torch.dtype,
kv_cache_dtype: "CacheDType | None",
block_size: int | None,
use_mla: bool,
has_sink: bool,
use_sparse: bool,
device_capability: "DeviceCapability",
) -> str | None:
return None
@classmethod
def validate_configuration(
cls,
head_size: int,
dtype: torch.dtype,
kv_cache_dtype: "CacheDType | None",
block_size: int | None,
use_mla: bool,
has_sink: bool,
use_sparse: bool,
device_capability: "DeviceCapability",
) -> list[str]:
invalid_reasons = []
if not cls.supports_head_size(head_size):
invalid_reasons.append("head_size not supported")
if not cls.supports_dtype(dtype):
invalid_reasons.append("dtype not supported")
if not cls.supports_kv_cache_dtype(kv_cache_dtype):
invalid_reasons.append("kv_cache_dtype not supported")
if not cls.supports_block_size(block_size):
invalid_reasons.append("block_size not supported")
if use_mla != cls.is_mla():
if use_mla:
invalid_reasons.append("MLA not supported")
else:
invalid_reasons.append("non-MLA not supported")
if has_sink and not cls.supports_sink():
invalid_reasons.append("sink setting not supported")
if use_sparse != cls.is_sparse():
if use_sparse:
invalid_reasons.append("sparse not supported")
else:
invalid_reasons.append("non-sparse not supported")
if not cls.supports_compute_capability(device_capability):
invalid_reasons.append("compute capability not supported")
combination_reason = cls.supports_combination(
head_size,
dtype,
kv_cache_dtype,
block_size,
use_mla,
has_sink,
use_sparse,
device_capability,
)
if combination_reason is not None:
invalid_reasons.append(combination_reason)
return invalid_reasons
@classmethod
def get_required_kv_cache_layout(cls) -> "KVCacheLayoutType | None":
return None
class AttentionMetadata: class AttentionMetadata:
pass pass
...@@ -151,11 +285,6 @@ class AttentionImpl(ABC, Generic[T]): ...@@ -151,11 +285,6 @@ class AttentionImpl(ABC, Generic[T]):
) -> None: ) -> None:
raise NotImplementedError raise NotImplementedError
@staticmethod
def get_supported_kernel_block_size() -> list[int | MultipleOf]:
# TODO: implement this function for all backends.
return [MultipleOf(1)]
@abstractmethod @abstractmethod
def forward( def forward(
self, self,
......
...@@ -3,108 +3,192 @@ ...@@ -3,108 +3,192 @@
"""Attention backend registry""" """Attention backend registry"""
import enum import enum
from collections.abc import Callable
from typing import TYPE_CHECKING, cast
from vllm.logger import init_logger
from vllm.utils.import_utils import resolve_obj_by_qualname from vllm.utils.import_utils import resolve_obj_by_qualname
if TYPE_CHECKING:
from vllm.attention.backends.abstract import AttentionBackend
class _Backend(enum.Enum): logger = init_logger(__name__)
FLASH_ATTN = enum.auto()
TRITON_ATTN = enum.auto()
XFORMERS = enum.auto()
ROCM_ATTN = enum.auto()
ROCM_AITER_MLA = enum.auto()
ROCM_AITER_FA = enum.auto() # used for ViT attn backend
TORCH_SDPA = enum.auto()
FLASHINFER = enum.auto()
FLASHINFER_MLA = enum.auto()
TRITON_MLA = enum.auto()
CUTLASS_MLA = enum.auto()
FLASHMLA = enum.auto()
FLASHMLA_SPARSE = enum.auto()
FLASH_ATTN_MLA = enum.auto()
PALLAS = enum.auto()
IPEX = enum.auto()
NO_ATTENTION = enum.auto()
FLEX_ATTENTION = enum.auto()
TREE_ATTN = enum.auto()
ROCM_AITER_UNIFIED_ATTN = enum.auto()
BACKEND_MAP = {
_Backend.FLASH_ATTN: "vllm.v1.attention.backends.flash_attn.FlashAttentionBackend", # noqa: E501
_Backend.TRITON_ATTN: "vllm.v1.attention.backends.triton_attn.TritonAttentionBackend", # noqa: E501
_Backend.XFORMERS: "vllm.v1.attention.backends.xformers.XFormersAttentionBackend", # noqa: E501
_Backend.ROCM_ATTN: "vllm.v1.attention.backends.rocm_attn.RocmAttentionBackend", # noqa: E501
_Backend.ROCM_AITER_MLA: "vllm.v1.attention.backends.mla.rocm_aiter_mla.AiterMLABackend", # noqa: E501
_Backend.ROCM_AITER_FA: "vllm.v1.attention.backends.rocm_aiter_fa.AiterFlashAttentionBackend", # noqa: E501
_Backend.TORCH_SDPA: "vllm.v1.attention.backends.cpu_attn.TorchSDPABackend", # noqa: E501
_Backend.FLASHINFER: "vllm.v1.attention.backends.flashinfer.FlashInferBackend", # noqa: E501
_Backend.FLASHINFER_MLA: "vllm.v1.attention.backends.mla.flashinfer_mla.FlashInferMLABackend", # noqa: E501
_Backend.TRITON_MLA: "vllm.v1.attention.backends.mla.triton_mla.TritonMLABackend", # noqa: E501
_Backend.CUTLASS_MLA: "vllm.v1.attention.backends.mla.cutlass_mla.CutlassMLABackend", # noqa: E501
_Backend.FLASHMLA: "vllm.v1.attention.backends.mla.flashmla.FlashMLABackend", # noqa: E501
_Backend.FLASHMLA_SPARSE: "vllm.v1.attention.backends.mla.flashmla_sparse.FlashMLASparseBackend", # noqa: E501
_Backend.FLASH_ATTN_MLA: "vllm.v1.attention.backends.mla.flashattn_mla.FlashAttnMLABackend", # noqa: E501
_Backend.PALLAS: "vllm.v1.attention.backends.pallas.PallasAttentionBackend", # noqa: E501
_Backend.FLEX_ATTENTION: "vllm.v1.attention.backends.flex_attention.FlexAttentionBackend", # noqa: E501
_Backend.TREE_ATTN: "vllm.v1.attention.backends.tree_attn.TreeAttentionBackend", # noqa: E501
_Backend.ROCM_AITER_UNIFIED_ATTN: "vllm.v1.attention.backends.rocm_aiter_unified_attn.RocmAiterUnifiedAttentionBackend", # noqa: E501
}
def register_attn_backend(backend: _Backend, class_path: str | None = None):
"""
Decorator: register a custom attention backend into BACKEND_MAPPING.
- If class_path is provided, use it.
- Otherwise, auto-generate from the class object.
Validation: only checks if 'backend' is a valid _Backend enum member.
Overwriting existing mappings is allowed. This enables other hardware
platforms to plug in custom out-of-tree backends.
"""
if not isinstance(backend, _Backend):
raise ValueError(f"{backend} is not a valid _Backend enum value.")
def decorator(cls):
path = class_path or f"{cls.__module__}.{cls.__qualname__}"
BACKEND_MAP[backend] = path
return cls
return decorator class _AttentionBackendEnumMeta(enum.EnumMeta):
"""Metaclass for AttentionBackendEnum to provide better error messages."""
def __getitem__(cls, name: str):
"""Get backend by name with helpful error messages."""
try:
return super().__getitem__(name)
except KeyError:
members = cast("dict[str, AttentionBackendEnum]", cls.__members__).values()
valid_backends = ", ".join(m.name for m in members)
raise ValueError(
f"Unknown attention backend: '{name}'. "
f"Valid options are: {valid_backends}"
) from None
def backend_to_class_str(backend: _Backend) -> str:
"""Get the backend class string
Args: class AttentionBackendEnum(enum.Enum, metaclass=_AttentionBackendEnumMeta):
backend: The backend enum value """Enumeration of all supported attention backends.
Returns: The enum value is the default class path, but this can be overridden
The backend class string at runtime using register_backend().
To get the actual backend class (respecting overrides), use:
backend.get_class()
""" """
return BACKEND_MAP[backend]
FLASH_ATTN = "vllm.v1.attention.backends.flash_attn.FlashAttentionBackend"
TRITON_ATTN = "vllm.v1.attention.backends.triton_attn.TritonAttentionBackend"
XFORMERS = "vllm.v1.attention.backends.xformers.XFormersAttentionBackend"
ROCM_ATTN = "vllm.v1.attention.backends.rocm_attn.RocmAttentionBackend"
ROCM_AITER_MLA = "vllm.v1.attention.backends.mla.rocm_aiter_mla.AiterMLABackend"
ROCM_AITER_FA = (
"vllm.v1.attention.backends.rocm_aiter_fa.AiterFlashAttentionBackend"
)
TORCH_SDPA = "vllm.v1.attention.backends.cpu_attn.TorchSDPABackend"
FLASHINFER = "vllm.v1.attention.backends.flashinfer.FlashInferBackend"
FLASHINFER_MLA = (
"vllm.v1.attention.backends.mla.flashinfer_mla.FlashInferMLABackend"
)
TRITON_MLA = "vllm.v1.attention.backends.mla.triton_mla.TritonMLABackend"
CUTLASS_MLA = "vllm.v1.attention.backends.mla.cutlass_mla.CutlassMLABackend"
FLASHMLA = "vllm.v1.attention.backends.mla.flashmla.FlashMLABackend"
FLASHMLA_SPARSE = (
"vllm.v1.attention.backends.mla.flashmla_sparse.FlashMLASparseBackend"
)
FLASH_ATTN_MLA = "vllm.v1.attention.backends.mla.flashattn_mla.FlashAttnMLABackend"
PALLAS = "vllm.v1.attention.backends.pallas.PallasAttentionBackend"
IPEX = "vllm.v1.attention.backends.ipex.IpexAttentionBackend"
NO_ATTENTION = "vllm.v1.attention.backends.no_attention.NoAttentionBackend"
FLEX_ATTENTION = "vllm.v1.attention.backends.flex_attention.FlexAttentionBackend"
TREE_ATTN = "vllm.v1.attention.backends.tree_attn.TreeAttentionBackend"
ROCM_AITER_UNIFIED_ATTN = (
"vllm.v1.attention.backends.rocm_aiter_unified_attn."
"RocmAiterUnifiedAttentionBackend"
)
# Placeholder for third-party/custom backends - must be registered before use
CUSTOM = ""
def get_path(self, include_classname: bool = True) -> str:
"""Get the class path for this backend (respects overrides).
def backend_to_class(backend: _Backend) -> type: Returns:
"""Get the backend class. The fully qualified class path string
Args: Raises:
backend: The backend enum value ValueError: If Backend.CUSTOM is used without being registered
"""
path = _OVERRIDES.get(self, self.value)
if not path:
raise ValueError(
f"Backend {self.name} must be registered before use. "
f"Use register_backend(Backend.{self.name}, 'your.module.YourClass')"
)
if not include_classname:
path = path.rsplit(".", 1)[0]
return path
def get_class(self) -> "type[AttentionBackend]":
"""Get the backend class (respects overrides).
Returns: Returns:
The backend class The backend class
Raises:
ImportError: If the backend class cannot be imported
ValueError: If Backend.CUSTOM is used without being registered
""" """
backend_class_name = backend_to_class_str(backend) return resolve_obj_by_qualname(self.get_path())
return resolve_obj_by_qualname(backend_class_name)
def is_overridden(self) -> bool:
"""Check if this backend has been overridden.
def backend_name_to_enum(backend_name: str) -> _Backend | None: Returns:
True if the backend has a registered override
""" """
Convert a string backend name to a _Backend enum value. return self in _OVERRIDES
def clear_override(self) -> None:
"""Clear any override for this backend, reverting to the default."""
_OVERRIDES.pop(self, None)
_OVERRIDES: dict[AttentionBackendEnum, str] = {}
def register_backend(
backend: AttentionBackendEnum, class_path: str | None = None
) -> Callable[[type], type]:
"""Register or override a backend implementation.
Args:
backend: The AttentionBackendEnum member to register
class_path: Optional class path. If not provided and used as
decorator, will be auto-generated from the class.
Returns: Returns:
_Backend: enum value if backend_name is a valid in-tree type Decorator function if class_path is None, otherwise a no-op
None: otherwise it's an invalid in-tree type or an out-of-tree platform
is loaded. Examples:
# Override an existing backend
@register_backend(AttentionBackendEnum.FLASH_ATTN)
class MyCustomFlashAttn:
...
# Register a custom third-party backend
@register_backend(AttentionBackendEnum.CUSTOM)
class MyCustomBackend:
...
# Direct registration
register_backend(
AttentionBackendEnum.CUSTOM,
"my.module.MyCustomBackend"
)
"""
def decorator(cls: type) -> type:
_OVERRIDES[backend] = f"{cls.__module__}.{cls.__qualname__}"
return cls
if class_path is not None:
_OVERRIDES[backend] = class_path
return lambda x: x
return decorator
# Backwards compatibility alias for plugins
class _BackendMeta(type):
"""Metaclass to provide deprecation warnings when accessing _Backend."""
def __getattribute__(cls, name: str):
if name not in ("__class__", "__mro__", "__name__"):
logger.warning(
"_Backend has been renamed to AttentionBackendEnum. "
"Please update your code to use AttentionBackendEnum instead. "
"_Backend will be removed in a future release."
)
return getattr(AttentionBackendEnum, name)
def __getitem__(cls, name: str):
logger.warning(
"_Backend has been renamed to AttentionBackendEnum. "
"Please update your code to use AttentionBackendEnum instead. "
"_Backend will be removed in a future release."
)
return AttentionBackendEnum[name]
class _Backend(metaclass=_BackendMeta):
"""Deprecated: Use AttentionBackendEnum instead.
This class is provided for backwards compatibility with plugins
and will be removed in a future release.
""" """
assert backend_name is not None
return _Backend[backend_name] if backend_name in _Backend.__members__ else None pass
...@@ -12,7 +12,7 @@ import torch.nn.functional as F ...@@ -12,7 +12,7 @@ import torch.nn.functional as F
import vllm.envs as envs import vllm.envs as envs
from vllm.attention import AttentionType from vllm.attention import AttentionType
from vllm.attention.backends.abstract import AttentionBackend, MLAAttentionImpl from vllm.attention.backends.abstract import AttentionBackend, MLAAttentionImpl
from vllm.attention.backends.registry import _Backend, backend_name_to_enum from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.attention.selector import get_attn_backend from vllm.attention.selector import get_attn_backend
from vllm.attention.utils.kv_sharing_utils import validate_kv_sharing_target from vllm.attention.utils.kv_sharing_utils import validate_kv_sharing_target
from vllm.config import CacheConfig, get_current_vllm_config from vllm.config import CacheConfig, get_current_vllm_config
...@@ -99,40 +99,44 @@ def check_upstream_fa_availability(dtype: torch.dtype): ...@@ -99,40 +99,44 @@ def check_upstream_fa_availability(dtype: torch.dtype):
def maybe_get_vit_flash_attn_backend( def maybe_get_vit_flash_attn_backend(
attn_backend: _Backend, attn_backend: AttentionBackendEnum,
use_upstream_fa: bool, use_upstream_fa: bool,
attn_backend_override: _Backend | None = None, attn_backend_override: AttentionBackendEnum | None = None,
) -> tuple[_Backend, Callable | None]: ) -> tuple[AttentionBackendEnum, Callable | None]:
if current_platform.is_rocm(): if current_platform.is_rocm():
if envs.VLLM_ROCM_USE_AITER and envs.VLLM_ROCM_USE_AITER_MHA and on_gfx9(): if envs.VLLM_ROCM_USE_AITER and envs.VLLM_ROCM_USE_AITER_MHA and on_gfx9():
attn_backend = _Backend.ROCM_AITER_FA attn_backend = AttentionBackendEnum.ROCM_AITER_FA
elif ( elif (
check_upstream_fa_availability(torch.get_default_dtype()) check_upstream_fa_availability(torch.get_default_dtype())
and on_gfx9() and on_gfx9()
and attn_backend_override is None and attn_backend_override is None
): ):
attn_backend = _Backend.FLASH_ATTN attn_backend = AttentionBackendEnum.FLASH_ATTN
use_upstream_fa = True use_upstream_fa = True
else: else:
return _Backend.TORCH_SDPA, None return AttentionBackendEnum.TORCH_SDPA, None
elif current_platform.is_cuda(): elif current_platform.is_cuda():
if attn_backend != _Backend.FLASH_ATTN and check_upstream_fa_availability( if (
torch.get_default_dtype() attn_backend != AttentionBackendEnum.FLASH_ATTN
and check_upstream_fa_availability(torch.get_default_dtype())
): ):
attn_backend = _Backend.FLASH_ATTN attn_backend = AttentionBackendEnum.FLASH_ATTN
use_upstream_fa = True use_upstream_fa = True
elif current_platform.is_xpu(): elif current_platform.is_xpu():
assert attn_backend == _Backend.FLASH_ATTN, ( assert attn_backend == AttentionBackendEnum.FLASH_ATTN, (
"XPU platform only supports FLASH_ATTN as vision attention backend." "XPU platform only supports FLASH_ATTN as vision attention backend."
) )
use_upstream_fa = False use_upstream_fa = False
else: else:
return _Backend.TORCH_SDPA, None return AttentionBackendEnum.TORCH_SDPA, None
if attn_backend in {_Backend.FLASH_ATTN, _Backend.ROCM_AITER_FA}: if attn_backend in {
if attn_backend == _Backend.ROCM_AITER_FA: AttentionBackendEnum.FLASH_ATTN,
AttentionBackendEnum.ROCM_AITER_FA,
}:
if attn_backend == AttentionBackendEnum.ROCM_AITER_FA:
from aiter import flash_attn_varlen_func from aiter import flash_attn_varlen_func
else: else:
if use_upstream_fa: if use_upstream_fa:
...@@ -309,7 +313,7 @@ class Attention(nn.Module, AttentionLayerBase): ...@@ -309,7 +313,7 @@ class Attention(nn.Module, AttentionLayerBase):
kv_sharing_target_layer_name, kv_sharing_target_layer_name,
**extra_impl_args, **extra_impl_args,
) )
self.backend = backend_name_to_enum(self.attn_backend.get_name()) self.backend = AttentionBackendEnum[self.attn_backend.get_name()]
self.dtype = dtype self.dtype = dtype
# For cuda-alike (CUDA and ROCM) and cpu platforms, we control how # For cuda-alike (CUDA and ROCM) and cpu platforms, we control how
...@@ -530,13 +534,13 @@ class MultiHeadAttention(nn.Module): ...@@ -530,13 +534,13 @@ class MultiHeadAttention(nn.Module):
backend backend
if backend if backend
in { in {
_Backend.TORCH_SDPA, AttentionBackendEnum.TORCH_SDPA,
_Backend.XFORMERS, AttentionBackendEnum.XFORMERS,
_Backend.PALLAS, AttentionBackendEnum.PALLAS,
_Backend.ROCM_AITER_FA, AttentionBackendEnum.ROCM_AITER_FA,
_Backend.FLASH_ATTN, AttentionBackendEnum.FLASH_ATTN,
} }
else _Backend.TORCH_SDPA else AttentionBackendEnum.TORCH_SDPA
) )
self.attn_backend, self._flash_attn_varlen_func = ( self.attn_backend, self._flash_attn_varlen_func = (
...@@ -547,17 +551,23 @@ class MultiHeadAttention(nn.Module): ...@@ -547,17 +551,23 @@ class MultiHeadAttention(nn.Module):
) )
) )
if self.attn_backend == _Backend.XFORMERS and not check_xformers_availability(): if (
self.attn_backend = _Backend.TORCH_SDPA self.attn_backend == AttentionBackendEnum.XFORMERS
and not check_xformers_availability()
):
self.attn_backend = AttentionBackendEnum.TORCH_SDPA
self.is_flash_attn_backend = self.attn_backend in { self.is_flash_attn_backend = self.attn_backend in {
_Backend.FLASH_ATTN, AttentionBackendEnum.FLASH_ATTN,
_Backend.ROCM_AITER_FA, AttentionBackendEnum.ROCM_AITER_FA,
} }
# this condition is just to make sure that the # this condition is just to make sure that the
# use_upstream_fa in the log is correct # use_upstream_fa in the log is correct
if current_platform.is_rocm() and self.attn_backend == _Backend.FLASH_ATTN: if (
current_platform.is_rocm()
and self.attn_backend == AttentionBackendEnum.FLASH_ATTN
):
use_upstream_fa = True use_upstream_fa = True
logger.info_once( logger.info_once(
...@@ -606,17 +616,17 @@ class MultiHeadAttention(nn.Module): ...@@ -606,17 +616,17 @@ class MultiHeadAttention(nn.Module):
max_seqlen_k=kv_len, max_seqlen_k=kv_len,
softmax_scale=self.scale, softmax_scale=self.scale,
) )
elif self.attn_backend == _Backend.XFORMERS: elif self.attn_backend == AttentionBackendEnum.XFORMERS:
from xformers import ops as xops from xformers import ops as xops
out = xops.memory_efficient_attention_forward( out = xops.memory_efficient_attention_forward(
query, key, value, scale=self.scale query, key, value, scale=self.scale
) )
elif self.attn_backend == _Backend.TORCH_SDPA: elif self.attn_backend == AttentionBackendEnum.TORCH_SDPA:
query, key, value = (x.transpose(1, 2) for x in (query, key, value)) query, key, value = (x.transpose(1, 2) for x in (query, key, value))
out = F.scaled_dot_product_attention(query, key, value, scale=self.scale) out = F.scaled_dot_product_attention(query, key, value, scale=self.scale)
out = out.transpose(1, 2) out = out.transpose(1, 2)
elif self.attn_backend == _Backend.PALLAS: elif self.attn_backend == AttentionBackendEnum.PALLAS:
query, key, value = (x.transpose(1, 2) for x in (query, key, value)) query, key, value = (x.transpose(1, 2) for x in (query, key, value))
from torch_xla.experimental.custom_kernel import flash_attention from torch_xla.experimental.custom_kernel import flash_attention
......
...@@ -4,14 +4,15 @@ ...@@ -4,14 +4,15 @@
import os import os
from collections.abc import Generator from collections.abc import Generator
from contextlib import contextmanager from contextlib import contextmanager
from dataclasses import dataclass
from functools import cache from functools import cache
from typing import cast, get_args
import torch import torch
import vllm.envs as envs import vllm.envs as envs
from vllm.attention.backends.abstract import AttentionBackend from vllm.attention.backends.abstract import AttentionBackend
from vllm.attention.backends.registry import _Backend, backend_name_to_enum from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.config.cache import CacheDType
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.utils import STR_BACKEND_ENV_VAR from vllm.utils import STR_BACKEND_ENV_VAR
from vllm.utils.import_utils import resolve_obj_by_qualname from vllm.utils.import_utils import resolve_obj_by_qualname
...@@ -19,18 +20,18 @@ from vllm.utils.import_utils import resolve_obj_by_qualname ...@@ -19,18 +20,18 @@ from vllm.utils.import_utils import resolve_obj_by_qualname
logger = init_logger(__name__) logger = init_logger(__name__)
def get_env_variable_attn_backend() -> _Backend | None: def get_env_variable_attn_backend() -> AttentionBackendEnum | None:
""" """
Get the backend override specified by the vLLM attention Get the backend override specified by the vLLM attention
backend environment variable, if one is specified. backend environment variable, if one is specified.
Returns: Returns:
* _Backend enum value if an override is specified * AttentionBackendEnum value if an override is specified
* None otherwise * None otherwise
""" """
backend_name = os.environ.get(STR_BACKEND_ENV_VAR) backend_name = os.environ.get(STR_BACKEND_ENV_VAR)
return None if backend_name is None else backend_name_to_enum(backend_name) return None if backend_name is None else AttentionBackendEnum[backend_name]
# Global state allows a particular choice of backend # Global state allows a particular choice of backend
...@@ -40,10 +41,10 @@ def get_env_variable_attn_backend() -> _Backend | None: ...@@ -40,10 +41,10 @@ def get_env_variable_attn_backend() -> _Backend | None:
# #
# THIS SELECTION TAKES PRECEDENCE OVER THE # THIS SELECTION TAKES PRECEDENCE OVER THE
# VLLM_ATTENTION_BACKEND ENVIRONMENT VARIABLE # VLLM_ATTENTION_BACKEND ENVIRONMENT VARIABLE
forced_attn_backend: _Backend | None = None forced_attn_backend: AttentionBackendEnum | None = None
def global_force_attn_backend(attn_backend: _Backend | None) -> None: def global_force_attn_backend(attn_backend: AttentionBackendEnum | None) -> None:
""" """
Force all attention operations to use a specified backend. Force all attention operations to use a specified backend.
...@@ -58,7 +59,7 @@ def global_force_attn_backend(attn_backend: _Backend | None) -> None: ...@@ -58,7 +59,7 @@ def global_force_attn_backend(attn_backend: _Backend | None) -> None:
forced_attn_backend = attn_backend forced_attn_backend = attn_backend
def get_global_forced_attn_backend() -> _Backend | None: def get_global_forced_attn_backend() -> AttentionBackendEnum | None:
""" """
Get the currently-forced choice of attention backend, Get the currently-forced choice of attention backend,
or None if auto-selection is currently enabled. or None if auto-selection is currently enabled.
...@@ -66,78 +67,28 @@ def get_global_forced_attn_backend() -> _Backend | None: ...@@ -66,78 +67,28 @@ def get_global_forced_attn_backend() -> _Backend | None:
return forced_attn_backend return forced_attn_backend
@dataclass(frozen=True)
class _IsSupported:
can_import: bool
head_size: bool
dtype: bool
def __bool__(self) -> bool:
return self.can_import and self.head_size and self.dtype
def is_attn_backend_supported(
attn_backend: str | type[AttentionBackend],
head_size: int,
dtype: torch.dtype,
*,
allow_import_error: bool = True,
) -> _IsSupported:
if isinstance(attn_backend, str):
try:
attn_backend = resolve_obj_by_qualname(attn_backend)
except ImportError:
if not allow_import_error:
raise
return _IsSupported(can_import=False, head_size=False, dtype=False)
assert isinstance(attn_backend, type)
# TODO: Update the interface once V0 is removed
if get_supported_head_sizes := getattr(
attn_backend, "get_supported_head_sizes", None
):
is_head_size_supported = head_size in get_supported_head_sizes()
elif validate_head_size := getattr(attn_backend, "validate_head_size", None):
try:
validate_head_size(head_size)
is_head_size_supported = True
except Exception:
is_head_size_supported = False
else:
raise NotImplementedError(
f"{attn_backend.__name__} does not support head size validation"
)
if get_supported_dtypes := getattr(attn_backend, "get_supported_dtypes", None):
is_dtype_supported = dtype in get_supported_dtypes()
else:
raise NotImplementedError(
f"{attn_backend.__name__} does not support dtype validation"
)
return _IsSupported(
can_import=True,
head_size=is_head_size_supported,
dtype=is_dtype_supported,
)
def get_attn_backend( def get_attn_backend(
head_size: int, head_size: int,
dtype: torch.dtype, dtype: torch.dtype,
kv_cache_dtype: str | None, kv_cache_dtype: str | None,
block_size: int, block_size: int | None,
use_mla: bool = False, use_mla: bool = False,
has_sink: bool = False, has_sink: bool = False,
use_sparse: bool = False, use_sparse: bool = False,
) -> type[AttentionBackend]: ) -> type[AttentionBackend]:
"""Selects which attention backend to use and lazily imports it.""" """Selects which attention backend to use and lazily imports it."""
if kv_cache_dtype is not None:
valid_cache_dtypes = get_args(CacheDType)
assert kv_cache_dtype in valid_cache_dtypes, (
f"Invalid kv_cache_dtype: {kv_cache_dtype}. "
f"Valid values are: {valid_cache_dtypes}"
)
return _cached_get_attn_backend( return _cached_get_attn_backend(
head_size=head_size, head_size=head_size,
dtype=dtype, dtype=dtype,
kv_cache_dtype=kv_cache_dtype, kv_cache_dtype=cast(CacheDType | None, kv_cache_dtype),
block_size=block_size, block_size=block_size,
use_mla=use_mla, use_mla=use_mla,
has_sink=has_sink, has_sink=has_sink,
...@@ -149,8 +100,8 @@ def get_attn_backend( ...@@ -149,8 +100,8 @@ def get_attn_backend(
def _cached_get_attn_backend( def _cached_get_attn_backend(
head_size: int, head_size: int,
dtype: torch.dtype, dtype: torch.dtype,
kv_cache_dtype: str | None, kv_cache_dtype: CacheDType | None,
block_size: int, block_size: int | None,
use_mla: bool = False, use_mla: bool = False,
has_sink: bool = False, has_sink: bool = False,
use_sparse: bool = False, use_sparse: bool = False,
...@@ -161,7 +112,9 @@ def _cached_get_attn_backend( ...@@ -161,7 +112,9 @@ def _cached_get_attn_backend(
# THIS SELECTION OVERRIDES THE VLLM_ATTENTION_BACKEND # THIS SELECTION OVERRIDES THE VLLM_ATTENTION_BACKEND
# ENVIRONMENT VARIABLE. # ENVIRONMENT VARIABLE.
selected_backend = None selected_backend = None
backend_by_global_setting: _Backend | None = get_global_forced_attn_backend() backend_by_global_setting: AttentionBackendEnum | None = (
get_global_forced_attn_backend()
)
if backend_by_global_setting is not None: if backend_by_global_setting is not None:
selected_backend = backend_by_global_setting selected_backend = backend_by_global_setting
else: else:
...@@ -177,12 +130,13 @@ def _cached_get_attn_backend( ...@@ -177,12 +130,13 @@ def _cached_get_attn_backend(
STR_BACKEND_ENV_VAR, STR_BACKEND_ENV_VAR,
) )
backend_by_env_var = backend_by_env_var.removesuffix("_VLLM_V1") backend_by_env_var = backend_by_env_var.removesuffix("_VLLM_V1")
selected_backend = backend_name_to_enum(backend_by_env_var) try:
if selected_backend is None: selected_backend = AttentionBackendEnum[backend_by_env_var]
except KeyError as e:
raise ValueError( raise ValueError(
f"Invalid attention backend: '{backend_by_env_var}'. " f"Invalid attention backend: '{backend_by_env_var}'. Valid "
f"Valid backends are: {list(_Backend.__members__.keys())}" f"backends are: {list(AttentionBackendEnum.__members__.keys())}"
) ) from e
# get device-specific attn_backend # get device-specific attn_backend
from vllm.platforms import current_platform from vllm.platforms import current_platform
...@@ -202,12 +156,26 @@ def _cached_get_attn_backend( ...@@ -202,12 +156,26 @@ def _cached_get_attn_backend(
raise ValueError( raise ValueError(
f"Invalid attention backend for {current_platform.device_name}" f"Invalid attention backend for {current_platform.device_name}"
) )
return resolve_obj_by_qualname(attention_cls) backend = resolve_obj_by_qualname(attention_cls)
# Adjust kv cache layout if the selected backend requires a specific one
required_layout = backend.get_required_kv_cache_layout()
if required_layout is not None:
from vllm.v1.attention.backends.utils import set_kv_cache_layout
set_kv_cache_layout(required_layout)
logger.info(
"Using %s KV cache layout for %s backend.",
required_layout,
backend.get_name(),
)
return backend
@contextmanager @contextmanager
def global_force_attn_backend_context_manager( def global_force_attn_backend_context_manager(
attn_backend: _Backend, attn_backend: AttentionBackendEnum,
) -> Generator[None, None, None]: ) -> Generator[None, None, None]:
""" """
Globally force a vLLM attention backend override within a Globally force a vLLM attention backend override within a
......
...@@ -21,7 +21,15 @@ else: ...@@ -21,7 +21,15 @@ else:
logger = init_logger(__name__) logger = init_logger(__name__)
BlockSize = Literal[1, 8, 16, 32, 64, 128, 256] BlockSize = Literal[1, 8, 16, 32, 64, 128, 256]
CacheDType = Literal["auto", "bfloat16", "fp8", "fp8_e4m3", "fp8_e5m2", "fp8_inc"] CacheDType = Literal[
"auto",
"bfloat16",
"fp8",
"fp8_e4m3",
"fp8_e5m2",
"fp8_inc",
"fp8_ds_mla",
]
MambaDType = Literal["auto", "float32"] MambaDType = Literal["auto", "float32"]
PrefixCachingHashAlgo = Literal["sha256", "sha256_cbor"] PrefixCachingHashAlgo = Literal["sha256", "sha256_cbor"]
KVOffloadingBackend = Literal["native", "lmcache"] KVOffloadingBackend = Literal["native", "lmcache"]
......
...@@ -45,7 +45,7 @@ if TYPE_CHECKING: ...@@ -45,7 +45,7 @@ if TYPE_CHECKING:
import vllm.model_executor.layers.quantization as me_quant import vllm.model_executor.layers.quantization as me_quant
import vllm.model_executor.models as me_models import vllm.model_executor.models as me_models
from vllm.attention.backends.registry import _Backend from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.config.load import LoadConfig from vllm.config.load import LoadConfig
from vllm.config.parallel import ParallelConfig from vllm.config.parallel import ParallelConfig
from vllm.model_executor.layers.quantization import QuantizationMethods from vllm.model_executor.layers.quantization import QuantizationMethods
...@@ -53,7 +53,7 @@ if TYPE_CHECKING: ...@@ -53,7 +53,7 @@ if TYPE_CHECKING:
else: else:
PretrainedConfig = Any PretrainedConfig = Any
_Backend = Any AttentionBackendEnum = Any
me_quant = LazyLoader( me_quant = LazyLoader(
"model_executor", globals(), "vllm.model_executor.layers.quantization" "model_executor", globals(), "vllm.model_executor.layers.quantization"
) )
...@@ -302,7 +302,7 @@ class ModelConfig: ...@@ -302,7 +302,7 @@ class ModelConfig:
mm_processor_cache_type: InitVar[MMCacheType | None] = None mm_processor_cache_type: InitVar[MMCacheType | None] = None
mm_shm_cache_max_object_size_mb: InitVar[int | None] = None mm_shm_cache_max_object_size_mb: InitVar[int | None] = None
mm_encoder_tp_mode: InitVar[MMEncoderTPMode | None] = None mm_encoder_tp_mode: InitVar[MMEncoderTPMode | None] = None
mm_encoder_attn_backend: InitVar[_Backend | str | None] = None mm_encoder_attn_backend: InitVar[AttentionBackendEnum | str | None] = None
interleave_mm_strings: InitVar[bool | None] = None interleave_mm_strings: InitVar[bool | None] = None
skip_mm_profiling: InitVar[bool | None] = None skip_mm_profiling: InitVar[bool | None] = None
video_pruning_rate: InitVar[float | None] = None video_pruning_rate: InitVar[float | None] = None
...@@ -420,7 +420,7 @@ class ModelConfig: ...@@ -420,7 +420,7 @@ class ModelConfig:
mm_processor_cache_type: MMCacheType | None, mm_processor_cache_type: MMCacheType | None,
mm_shm_cache_max_object_size_mb: int | None, mm_shm_cache_max_object_size_mb: int | None,
mm_encoder_tp_mode: MMEncoderTPMode | None, mm_encoder_tp_mode: MMEncoderTPMode | None,
mm_encoder_attn_backend: _Backend | str | None, mm_encoder_attn_backend: AttentionBackendEnum | str | None,
interleave_mm_strings: bool | None, interleave_mm_strings: bool | None,
skip_mm_profiling: bool | None, skip_mm_profiling: bool | None,
video_pruning_rate: float | None, video_pruning_rate: float | None,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment