Unverified Commit b30dfa03 authored by Matthew Bonanni's avatar Matthew Bonanni Committed by GitHub
Browse files

[Attention] Refactor CUDA attention backend selection logic (#24794)


Signed-off-by: default avatarMatthew Bonanni <mbonanni@redhat.com>
Signed-off-by: default avatarMatthew Bonanni <mbonanni001@gmail.com>
Co-authored-by: default avatarLuka Govedič <ProExpertProg@users.noreply.github.com>
parent 2e78150d
...@@ -15,16 +15,15 @@ from vllm.utils import DEFAULT_MAX_NUM_BATCHED_TOKENS ...@@ -15,16 +15,15 @@ from vllm.utils import DEFAULT_MAX_NUM_BATCHED_TOKENS
from .interface import Platform, PlatformEnum from .interface import Platform, PlatformEnum
if TYPE_CHECKING: if TYPE_CHECKING:
from vllm.attention.backends.registry import _Backend from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.config import ModelConfig, VllmConfig from vllm.config import VllmConfig
from vllm.config.cache import BlockSize from vllm.config.cache import BlockSize
from vllm.pooling_params import PoolingParams from vllm.pooling_params import PoolingParams
else: else:
BlockSize = None BlockSize = None
ModelConfig = None
VllmConfig = None VllmConfig = None
PoolingParams = None PoolingParams = None
_Backend = None AttentionBackendEnum = None
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -54,7 +53,7 @@ class TpuPlatform(Platform): ...@@ -54,7 +53,7 @@ class TpuPlatform(Platform):
@classmethod @classmethod
def get_attn_backend_cls( def get_attn_backend_cls(
cls, cls,
selected_backend: "_Backend", selected_backend: "AttentionBackendEnum",
head_size: int, head_size: int,
dtype: torch.dtype, dtype: torch.dtype,
kv_cache_dtype: str | None, kv_cache_dtype: str | None,
...@@ -64,17 +63,17 @@ class TpuPlatform(Platform): ...@@ -64,17 +63,17 @@ class TpuPlatform(Platform):
has_sink, has_sink,
use_sparse, use_sparse,
) -> str: ) -> str:
from vllm.attention.backends.registry import _Backend from vllm.attention.backends.registry import AttentionBackendEnum
if use_sparse: if use_sparse:
raise NotImplementedError("Sparse Attention is not supported on TPU.") raise NotImplementedError("Sparse Attention is not supported on TPU.")
if selected_backend != _Backend.PALLAS: if selected_backend != AttentionBackendEnum.PALLAS:
logger.info("Cannot use %s backend on TPU.", selected_backend) logger.info("Cannot use %s backend on TPU.", selected_backend)
if not use_v1: if not use_v1:
raise ValueError("TPU backend only supports V1.") raise ValueError("TPU backend only supports V1.")
logger.info("Using Pallas V1 backend.") logger.info("Using Pallas V1 backend.")
return "vllm.v1.attention.backends.pallas.PallasAttentionBackend" return AttentionBackendEnum.PALLAS.get_path()
@classmethod @classmethod
def set_device(cls, device: torch.device) -> None: def set_device(cls, device: torch.device) -> None:
......
...@@ -14,12 +14,11 @@ from vllm.utils import DEFAULT_MAX_NUM_BATCHED_TOKENS ...@@ -14,12 +14,11 @@ from vllm.utils import DEFAULT_MAX_NUM_BATCHED_TOKENS
from .interface import DeviceCapability, Platform, PlatformEnum from .interface import DeviceCapability, Platform, PlatformEnum
if TYPE_CHECKING: if TYPE_CHECKING:
from vllm.attention.backends.registry import _Backend from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.config import ModelConfig, VllmConfig from vllm.config import VllmConfig
else: else:
ModelConfig = None
VllmConfig = None VllmConfig = None
_Backend = None AttentionBackendEnum = None
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -44,7 +43,7 @@ class XPUPlatform(Platform): ...@@ -44,7 +43,7 @@ class XPUPlatform(Platform):
@classmethod @classmethod
def get_attn_backend_cls( def get_attn_backend_cls(
cls, cls,
selected_backend: "_Backend", selected_backend: "AttentionBackendEnum",
head_size: int, head_size: int,
dtype: torch.dtype, dtype: torch.dtype,
kv_cache_dtype: str | None, kv_cache_dtype: str | None,
...@@ -62,18 +61,19 @@ class XPUPlatform(Platform): ...@@ -62,18 +61,19 @@ class XPUPlatform(Platform):
"only NHD layout is supported by XPU attention kernels." "only NHD layout is supported by XPU attention kernels."
) )
from vllm.attention.backends.registry import _Backend from vllm.attention.backends.registry import AttentionBackendEnum
if use_sparse: if use_sparse:
raise NotImplementedError("Sparse Attention is not supported on XPU.") raise NotImplementedError("Sparse Attention is not supported on XPU.")
TRITON_ATTN = "vllm.v1.attention.backends.triton_attn.TritonAttentionBackend" # noqa: E501 use_v1 = envs.VLLM_USE_V1
FLASH_ATTN = "vllm.v1.attention.backends.flash_attn.FlashAttentionBackend" # noqa: E501 if not use_v1:
if selected_backend == _Backend.TRITON_ATTN: raise ValueError("XPU backend only supports V1.")
if selected_backend == AttentionBackendEnum.TRITON_ATTN:
logger.info_once("Using Triton backend.") logger.info_once("Using Triton backend.")
return TRITON_ATTN return AttentionBackendEnum.TRITON_ATTN.get_path()
elif selected_backend == _Backend.FLASH_ATTN: elif selected_backend == AttentionBackendEnum.FLASH_ATTN:
logger.info_once("Using Flash Attention backend.") logger.info_once("Using Flash Attention backend.")
return FLASH_ATTN return AttentionBackendEnum.FLASH_ATTN.get_path()
elif selected_backend: elif selected_backend:
raise ValueError( raise ValueError(
f"Invalid attention backend for {cls.device_name}, " f"Invalid attention backend for {cls.device_name}, "
...@@ -81,7 +81,7 @@ class XPUPlatform(Platform): ...@@ -81,7 +81,7 @@ class XPUPlatform(Platform):
) )
logger.info("Using Flash Attention backend.") logger.info("Using Flash Attention backend.")
return "vllm.v1.attention.backends.flash_attn.FlashAttentionBackend" return AttentionBackendEnum.FLASH_ATTN.get_path()
@classmethod @classmethod
def set_device(cls, device: torch.device) -> None: def set_device(cls, device: torch.device) -> None:
...@@ -113,10 +113,10 @@ class XPUPlatform(Platform): ...@@ -113,10 +113,10 @@ class XPUPlatform(Platform):
return device_props.total_memory return device_props.total_memory
@classmethod @classmethod
def get_vit_attn_backend(cls, head_size: int, dtype: torch.dtype) -> _Backend: def get_vit_attn_backend(
from vllm.attention.backends.registry import _Backend cls, head_size: int, dtype: torch.dtype
) -> AttentionBackendEnum:
return _Backend.FLASH_ATTN return AttentionBackendEnum.FLASH_ATTN
@classmethod @classmethod
def inference_mode(cls): def inference_mode(cls):
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from dataclasses import dataclass from dataclasses import dataclass
from typing import Optional from typing import ClassVar, Optional
import numpy as np import numpy as np
import torch import torch
...@@ -40,23 +40,16 @@ logger = init_logger(__name__) ...@@ -40,23 +40,16 @@ logger = init_logger(__name__)
class TorchSDPABackend(AttentionBackend): class TorchSDPABackend(AttentionBackend):
accept_output_buffer: bool = False accept_output_buffer: bool = False
supported_dtypes: ClassVar[list[torch.dtype]] = [
torch.float16,
torch.bfloat16,
torch.float32,
]
@classmethod @classmethod
def get_supported_dtypes(cls) -> list[torch.dtype]: def get_supported_head_sizes(cls) -> list[int]:
return [torch.float16, torch.bfloat16, torch.float32]
@classmethod
def validate_head_size(cls, head_size: int) -> None:
attn_impl = _get_paged_attn_impl() attn_impl = _get_paged_attn_impl()
is_valid, supported_head_sizes = attn_impl.validate_head_size(head_size) return attn_impl.get_supported_head_sizes()
if not is_valid:
attn_type = cls.__name__.removesuffix("Backend")
raise ValueError(
f"Head size {head_size} is not supported by {attn_type}. "
f"Supported head sizes are: {supported_head_sizes}. "
"Set VLLM_ATTENTION_BACKEND=FLEX_ATTENTION to use "
"FlexAttention backend which supports all head sizes."
)
@staticmethod @staticmethod
def get_name() -> str: def get_name() -> str:
...@@ -759,9 +752,8 @@ def _make_sliding_window_bias( ...@@ -759,9 +752,8 @@ def _make_sliding_window_bias(
class _PagedAttention: class _PagedAttention:
@staticmethod @staticmethod
def validate_head_size(head_size: int) -> tuple[bool, list[int]]: def get_supported_head_sizes() -> list[int]:
SUPPORT_HS = [32, 64, 80, 96, 112, 128, 192, 256] return [32, 64, 80, 96, 112, 128, 192, 256]
return head_size in SUPPORT_HS, SUPPORT_HS
@staticmethod @staticmethod
def get_kv_cache_shape( def get_kv_cache_shape(
...@@ -861,8 +853,8 @@ class _PagedAttention: ...@@ -861,8 +853,8 @@ class _PagedAttention:
class _IPEXPagedAttention(_PagedAttention): class _IPEXPagedAttention(_PagedAttention):
@staticmethod @staticmethod
def validate_head_size(head_size: int) -> tuple[bool, list[int]]: def get_supported_head_sizes() -> list[int]:
return True, [] return []
@staticmethod @staticmethod
def split_kv_cache( def split_kv_cache(
......
...@@ -3,6 +3,7 @@ ...@@ -3,6 +3,7 @@
"""Attention layer with FlashAttention.""" """Attention layer with FlashAttention."""
from dataclasses import dataclass from dataclasses import dataclass
from typing import ClassVar
import numpy as np import numpy as np
import torch import torch
...@@ -32,11 +33,13 @@ if is_flash_attn_varlen_func_available(): ...@@ -32,11 +33,13 @@ if is_flash_attn_varlen_func_available():
reshape_and_cache_flash, reshape_and_cache_flash,
) )
from vllm.config import VllmConfig, get_layers_from_vllm_config from vllm.config import VllmConfig, get_layers_from_vllm_config
from vllm.config.cache import CacheDType
from vllm.distributed.parallel_state import get_dcp_group from vllm.distributed.parallel_state import get_dcp_group
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.batch_invariant import ( from vllm.model_executor.layers.batch_invariant import (
vllm_is_batch_invariant, vllm_is_batch_invariant,
) )
from vllm.platforms.interface import DeviceCapability
from vllm.utils.math_utils import cdiv from vllm.utils.math_utils import cdiv
from vllm.v1.attention.backends.utils import ( from vllm.v1.attention.backends.utils import (
AttentionCGSupport, AttentionCGSupport,
...@@ -52,34 +55,12 @@ logger = init_logger(__name__) ...@@ -52,34 +55,12 @@ logger = init_logger(__name__)
class FlashAttentionBackend(AttentionBackend): class FlashAttentionBackend(AttentionBackend):
accept_output_buffer: bool = True accept_output_buffer: bool = True
supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16]
@classmethod
def get_supported_dtypes(cls) -> list[torch.dtype]:
return [torch.float16, torch.bfloat16]
@classmethod
def get_supported_head_sizes(cls) -> list[int]:
return [32, 64, 96, 128, 160, 192, 224, 256]
@staticmethod
def get_supported_kernel_block_size() -> list[int | MultipleOf]:
# NOTE(tdoublep): while in principle, FA supports # NOTE(tdoublep): while in principle, FA supports
# MultipleOf(16), these are the block sizes that do not # MultipleOf(16), these are the block sizes that do not
# suffer from the NaN propagation problem described here: # suffer from the NaN propagation problem described here:
# https://github.com/Dao-AILab/flash-attention/issues/1974 # https://github.com/Dao-AILab/flash-attention/issues/1974
return [16, 32, 64] supported_kernel_block_sizes: ClassVar[list[int | MultipleOf]] = [16, 32, 64]
@classmethod
def validate_head_size(cls, head_size: int) -> None:
supported_head_sizes = cls.get_supported_head_sizes()
if head_size not in supported_head_sizes:
attn_type = cls.__name__.removesuffix("Backend")
raise ValueError(
f"Head size {head_size} is not supported by {attn_type}. "
f"Supported head sizes are: {supported_head_sizes}. "
"Set VLLM_ATTENTION_BACKEND=FLEX_ATTENTION to use "
"FlexAttention backend which supports all head sizes."
)
@staticmethod @staticmethod
def get_name() -> str: def get_name() -> str:
...@@ -125,6 +106,38 @@ class FlashAttentionBackend(AttentionBackend): ...@@ -125,6 +106,38 @@ class FlashAttentionBackend(AttentionBackend):
else: else:
raise ValueError(f"Unrecognized FP8 dtype: {kv_cache_dtype}") raise ValueError(f"Unrecognized FP8 dtype: {kv_cache_dtype}")
@classmethod
def get_supported_head_sizes(cls) -> list[int]:
return [32, 64, 96, 128, 160, 192, 224, 256]
@classmethod
def supports_kv_cache_dtype(cls, kv_cache_dtype: CacheDType | None) -> bool:
if kv_cache_dtype is None:
return True
if kv_cache_dtype.startswith("fp8"):
return flash_attn_supports_fp8()
return kv_cache_dtype in ["auto"]
@classmethod
def supports_compute_capability(cls, capability: DeviceCapability) -> bool:
return capability >= DeviceCapability(8, 0)
@classmethod
def supports_combination(
cls,
head_size: int,
dtype: torch.dtype,
kv_cache_dtype: CacheDType | None,
block_size: int,
use_mla: bool,
has_sink: bool,
use_sparse: bool,
device_capability: DeviceCapability,
) -> str | None:
if has_sink and device_capability < DeviceCapability(9, 0):
return "sink not supported on compute capability < 9.0"
return None
@dataclass @dataclass
class FlashAttentionMetadata: class FlashAttentionMetadata:
...@@ -481,8 +494,6 @@ class FlashAttentionImpl(AttentionImpl): ...@@ -481,8 +494,6 @@ class FlashAttentionImpl(AttentionImpl):
self.num_queries_per_kv = self.num_heads // self.num_kv_heads self.num_queries_per_kv = self.num_heads // self.num_kv_heads
FlashAttentionBackend.validate_head_size(head_size)
self.attn_type = attn_type self.attn_type = attn_type
self.vllm_flash_attn_version = get_flash_attn_version() self.vllm_flash_attn_version = get_flash_attn_version()
# Cache the batch invariant result for use in forward passes # Cache the batch invariant result for use in forward passes
......
...@@ -23,6 +23,7 @@ from vllm.attention.backends.abstract import ( ...@@ -23,6 +23,7 @@ from vllm.attention.backends.abstract import (
MultipleOf, MultipleOf,
) )
from vllm.config import CUDAGraphMode, VllmConfig from vllm.config import CUDAGraphMode, VllmConfig
from vllm.config.cache import CacheDType
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.batch_invariant import ( from vllm.model_executor.layers.batch_invariant import (
vllm_is_batch_invariant, vllm_is_batch_invariant,
...@@ -33,6 +34,7 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import ( ...@@ -33,6 +34,7 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
kNvfp4Quant, kNvfp4Quant,
) )
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.platforms.interface import DeviceCapability
from vllm.triton_utils import tl, triton from vllm.triton_utils import tl, triton
from vllm.utils.flashinfer import ( from vllm.utils.flashinfer import (
can_use_trtllm_attention, can_use_trtllm_attention,
...@@ -45,6 +47,7 @@ from vllm.v1.attention.backends.utils import ( ...@@ -45,6 +47,7 @@ from vllm.v1.attention.backends.utils import (
AttentionCGSupport, AttentionCGSupport,
AttentionMetadataBuilder, AttentionMetadataBuilder,
CommonAttentionMetadata, CommonAttentionMetadata,
KVCacheLayoutType,
get_kv_cache_layout, get_kv_cache_layout,
get_per_layer_parameters, get_per_layer_parameters,
infer_global_hyperparameters, infer_global_hyperparameters,
...@@ -158,34 +161,17 @@ def trtllm_prefill_attn_kvfp8_dequant( ...@@ -158,34 +161,17 @@ def trtllm_prefill_attn_kvfp8_dequant(
class FlashInferBackend(AttentionBackend): class FlashInferBackend(AttentionBackend):
accept_output_buffer: bool = True accept_output_buffer: bool = True
supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16]
@classmethod
def get_supported_dtypes(cls) -> list[torch.dtype]:
return [torch.float16, torch.bfloat16]
@classmethod
def get_supported_head_sizes(cls) -> list[int]:
# https://github.com/flashinfer-ai/flashinfer/blob/3d55c71a62052c590c130897d3a3db49b14fcc34/include/flashinfer/utils.cuh#L157
return [64, 128, 256]
@staticmethod
def get_supported_kernel_block_size() -> list[int | MultipleOf]:
# Note: Not sure for all platforms, # Note: Not sure for all platforms,
# but on Blackwell, only support a page size of # but on Blackwell, only support a page size of
# 16, 32, 64 # 16, 32, 64
return [16, 32, 64] supported_kernel_block_sizes: ClassVar[list[int | MultipleOf]] = [16, 32, 64]
supported_kv_cache_dtypes: ClassVar[list[CacheDType]] = [
@classmethod "auto",
def validate_head_size(cls, head_size: int) -> None: "fp8",
supported_head_sizes = cls.get_supported_head_sizes() "fp8_e4m3",
if head_size not in supported_head_sizes: "fp8_e5m2",
attn_type = cls.__name__.removesuffix("Backend") ]
raise ValueError(
f"Head size {head_size} is not supported by {attn_type}. "
f"Supported head sizes are: {supported_head_sizes}. "
"Set VLLM_ATTENTION_BACKEND=FLEX_ATTENTION to use "
"FlexAttention backend which supports all head sizes."
)
@staticmethod @staticmethod
def get_name() -> str: def get_name() -> str:
...@@ -231,6 +217,26 @@ class FlashInferBackend(AttentionBackend): ...@@ -231,6 +217,26 @@ class FlashInferBackend(AttentionBackend):
else: else:
raise ValueError(f"Unrecognized FP8 dtype: {kv_cache_dtype}") raise ValueError(f"Unrecognized FP8 dtype: {kv_cache_dtype}")
@classmethod
def get_supported_head_sizes(cls) -> list[int]:
# https://github.com/flashinfer-ai/flashinfer/blob/3d55c71a62052c590c130897d3a3db49b14fcc34/include/flashinfer/utils.cuh#L157
return [64, 128, 256]
@classmethod
def supports_compute_capability(cls, capability: DeviceCapability) -> bool:
return capability >= DeviceCapability(7, 5) and capability <= DeviceCapability(
12, 1
)
@classmethod
def get_required_kv_cache_layout(cls) -> KVCacheLayoutType | None:
from vllm.platforms import current_platform
capability = current_platform.get_device_capability()
if capability is not None and capability.major == 10:
return "HND"
return None
@dataclass @dataclass
class FlashInferMetadata: class FlashInferMetadata:
...@@ -328,7 +334,6 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]): ...@@ -328,7 +334,6 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
) )
self.num_kv_heads = self.kv_cache_spec.num_kv_heads self.num_kv_heads = self.kv_cache_spec.num_kv_heads
self.head_dim = self.kv_cache_spec.head_size self.head_dim = self.kv_cache_spec.head_size
FlashInferBackend.validate_head_size(self.head_dim)
self.page_size = self.kv_cache_spec.block_size self.page_size = self.kv_cache_spec.block_size
self.cache_dtype = self.cache_config.cache_dtype self.cache_dtype = self.cache_config.cache_dtype
......
...@@ -4,6 +4,7 @@ ...@@ -4,6 +4,7 @@
import math import math
from dataclasses import dataclass from dataclasses import dataclass
from typing import ClassVar
import torch import torch
import torch._dynamo.decorators import torch._dynamo.decorators
...@@ -24,6 +25,7 @@ from vllm.attention.backends.abstract import ( ...@@ -24,6 +25,7 @@ from vllm.attention.backends.abstract import (
is_quantized_kv_cache, is_quantized_kv_cache,
) )
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.config.cache import CacheDType
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.batch_invariant import ( from vllm.model_executor.layers.batch_invariant import (
vllm_is_batch_invariant, vllm_is_batch_invariant,
...@@ -71,14 +73,12 @@ def pad_to_multiple(x: torch.Tensor, multiple: int, dim: int): ...@@ -71,14 +73,12 @@ def pad_to_multiple(x: torch.Tensor, multiple: int, dim: int):
class FlexAttentionBackend(AttentionBackend): class FlexAttentionBackend(AttentionBackend):
accept_output_buffer: bool = True accept_output_buffer: bool = True
supported_dtypes: ClassVar[list[torch.dtype]] = [
@classmethod torch.float16,
def get_supported_dtypes(cls) -> list[torch.dtype]: torch.bfloat16,
return [torch.float16, torch.bfloat16, torch.float32] torch.float32,
]
@classmethod supported_kv_cache_dtypes: ClassVar[list[CacheDType]] = ["auto"]
def validate_head_size(cls, head_size: int) -> None:
return # FlexAttention supports any head size
@staticmethod @staticmethod
def get_name() -> str: def get_name() -> str:
...@@ -106,6 +106,10 @@ class FlexAttentionBackend(AttentionBackend): ...@@ -106,6 +106,10 @@ class FlexAttentionBackend(AttentionBackend):
def use_cascade_attention(*args, **kwargs) -> bool: def use_cascade_attention(*args, **kwargs) -> bool:
return False return False
@classmethod
def get_supported_head_sizes(cls) -> list[int]:
return []
# @torch.compile(fullgraph=True, mode="reduce-overhead") # @torch.compile(fullgraph=True, mode="reduce-overhead")
def physical_to_logical_mapping( def physical_to_logical_mapping(
...@@ -720,7 +724,6 @@ class FlexAttentionImpl(AttentionImpl): ...@@ -720,7 +724,6 @@ class FlexAttentionImpl(AttentionImpl):
if kv_sharing_target_layer_name is not None: if kv_sharing_target_layer_name is not None:
raise NotImplementedError("FlexAttention does not support kv sharing yet.") raise NotImplementedError("FlexAttention does not support kv sharing yet.")
FlexAttentionBackend.validate_head_size(head_size)
if is_quantized_kv_cache(self.kv_cache_dtype): if is_quantized_kv_cache(self.kv_cache_dtype):
raise NotImplementedError( raise NotImplementedError(
"FlexAttention does not support quantized kv-cache. Yet" "FlexAttention does not support quantized kv-cache. Yet"
......
...@@ -308,25 +308,13 @@ class MLACommonBackend(AttentionBackend): ...@@ -308,25 +308,13 @@ class MLACommonBackend(AttentionBackend):
) -> tuple[int, ...]: ) -> tuple[int, ...]:
return (num_blocks, block_size, head_size) return (num_blocks, block_size, head_size)
@classmethod
def get_supported_dtypes(cls) -> list[torch.dtype]:
return [torch.float16, torch.bfloat16]
@classmethod @classmethod
def get_supported_head_sizes(cls) -> list[int]: def get_supported_head_sizes(cls) -> list[int]:
return [576] return [576]
@classmethod @classmethod
def validate_head_size(cls, head_size: int) -> None: def is_mla(cls) -> bool:
supported_head_sizes = cls.get_supported_head_sizes() return True
if head_size not in supported_head_sizes:
attn_type = cls.__name__.removesuffix("Backend")
raise ValueError(
f"Head size {head_size} is not supported by {attn_type}. "
f"Supported head sizes are: {supported_head_sizes}. "
"Set VLLM_ATTENTION_BACKEND=FLEX_ATTENTION to use "
"FlexAttention backend which supports all head sizes."
)
@dataclass @dataclass
...@@ -425,8 +413,10 @@ class MLACommonMetadata(Generic[D]): ...@@ -425,8 +413,10 @@ class MLACommonMetadata(Generic[D]):
) = None ) = None
def __post_init__(self): def __post_init__(self):
if self.head_dim is not None: if self.head_dim is not None and not MLACommonBackend.supports_head_size(
MLACommonBackend.validate_head_size(self.head_dim) self.head_dim
):
raise ValueError(f"Head dimension {self.head_dim} is not supported by MLA.")
M = TypeVar("M", bound=MLACommonMetadata) M = TypeVar("M", bound=MLACommonMetadata)
......
...@@ -13,7 +13,9 @@ from vllm.attention.backends.abstract import ( ...@@ -13,7 +13,9 @@ from vllm.attention.backends.abstract import (
MultipleOf, MultipleOf,
is_quantized_kv_cache, is_quantized_kv_cache,
) )
from vllm.config.cache import CacheDType
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.platforms.interface import DeviceCapability
from vllm.v1.attention.backends.mla.common import ( from vllm.v1.attention.backends.mla.common import (
MLACommonBackend, MLACommonBackend,
MLACommonImpl, MLACommonImpl,
...@@ -33,6 +35,14 @@ class CutlassMLAMetadataBuilder(MLACommonMetadataBuilder[MLACommonMetadata]): ...@@ -33,6 +35,14 @@ class CutlassMLAMetadataBuilder(MLACommonMetadataBuilder[MLACommonMetadata]):
class CutlassMLABackend(MLACommonBackend): class CutlassMLABackend(MLACommonBackend):
supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16]
supported_kernel_block_sizes: ClassVar[list[int | MultipleOf]] = [128]
supported_kv_cache_dtypes: ClassVar[list[CacheDType]] = [
"auto",
"fp8",
"fp8_e4m3",
]
@staticmethod @staticmethod
def get_name() -> str: def get_name() -> str:
return "CUTLASS_MLA" return "CUTLASS_MLA"
...@@ -45,9 +55,9 @@ class CutlassMLABackend(MLACommonBackend): ...@@ -45,9 +55,9 @@ class CutlassMLABackend(MLACommonBackend):
def get_builder_cls() -> type["CutlassMLAMetadataBuilder"]: def get_builder_cls() -> type["CutlassMLAMetadataBuilder"]:
return CutlassMLAMetadataBuilder return CutlassMLAMetadataBuilder
@staticmethod @classmethod
def get_supported_kernel_block_size() -> list[int | MultipleOf]: def supports_compute_capability(cls, capability: DeviceCapability) -> bool:
return [128] return capability.major == 10
class SM100Workspace: class SM100Workspace:
......
...@@ -10,6 +10,7 @@ from vllm import envs ...@@ -10,6 +10,7 @@ from vllm import envs
from vllm.attention.backends.abstract import ( from vllm.attention.backends.abstract import (
AttentionLayer, AttentionLayer,
AttentionType, AttentionType,
MultipleOf,
is_quantized_kv_cache, is_quantized_kv_cache,
) )
from vllm.attention.utils.fa_utils import ( from vllm.attention.utils.fa_utils import (
...@@ -17,10 +18,12 @@ from vllm.attention.utils.fa_utils import ( ...@@ -17,10 +18,12 @@ from vllm.attention.utils.fa_utils import (
get_flash_attn_version, get_flash_attn_version,
) )
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.config.cache import CacheDType
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.batch_invariant import ( from vllm.model_executor.layers.batch_invariant import (
vllm_is_batch_invariant, vllm_is_batch_invariant,
) )
from vllm.platforms.interface import DeviceCapability
from vllm.v1.attention.backends.mla.common import ( from vllm.v1.attention.backends.mla.common import (
MLACommonBackend, MLACommonBackend,
MLACommonDecodeMetadata, MLACommonDecodeMetadata,
...@@ -37,6 +40,10 @@ logger = init_logger(__name__) ...@@ -37,6 +40,10 @@ logger = init_logger(__name__)
class FlashAttnMLABackend(MLACommonBackend): class FlashAttnMLABackend(MLACommonBackend):
supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16]
supported_kernel_block_sizes: ClassVar[list[int | MultipleOf]] = [MultipleOf(16)]
supported_kv_cache_dtypes: ClassVar[list[CacheDType]] = ["auto"]
@staticmethod @staticmethod
def get_name() -> str: def get_name() -> str:
return "FLASH_ATTN_MLA" return "FLASH_ATTN_MLA"
...@@ -49,6 +56,26 @@ class FlashAttnMLABackend(MLACommonBackend): ...@@ -49,6 +56,26 @@ class FlashAttnMLABackend(MLACommonBackend):
def get_impl_cls() -> type["FlashAttnMLAImpl"]: def get_impl_cls() -> type["FlashAttnMLAImpl"]:
return FlashAttnMLAImpl return FlashAttnMLAImpl
@classmethod
def supports_compute_capability(cls, capability: DeviceCapability) -> bool:
return capability.major == 9
@classmethod
def supports_combination(
cls,
head_size: int,
dtype: torch.dtype,
kv_cache_dtype: CacheDType | None,
block_size: int,
use_mla: bool,
has_sink: bool,
use_sparse: bool,
device_capability: DeviceCapability,
) -> str | None:
if not flash_attn_supports_mla():
return "FlashAttention MLA not supported on this device"
return None
@dataclass @dataclass
class FlashAttnMLADecodeMetadata(MLACommonDecodeMetadata): class FlashAttnMLADecodeMetadata(MLACommonDecodeMetadata):
......
...@@ -6,8 +6,14 @@ from typing import ClassVar ...@@ -6,8 +6,14 @@ from typing import ClassVar
import torch import torch
from flashinfer.decode import trtllm_batch_decode_with_kv_cache_mla from flashinfer.decode import trtllm_batch_decode_with_kv_cache_mla
from vllm.attention.backends.abstract import AttentionLayer, AttentionType, MultipleOf from vllm.attention.backends.abstract import (
AttentionLayer,
AttentionType,
MultipleOf,
)
from vllm.config.cache import CacheDType
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.platforms.interface import DeviceCapability
from vllm.v1.attention.backends.mla.common import ( from vllm.v1.attention.backends.mla.common import (
MLACommonBackend, MLACommonBackend,
MLACommonImpl, MLACommonImpl,
...@@ -15,7 +21,7 @@ from vllm.v1.attention.backends.mla.common import ( ...@@ -15,7 +21,7 @@ from vllm.v1.attention.backends.mla.common import (
MLACommonMetadataBuilder, MLACommonMetadataBuilder,
QueryLenSupport, QueryLenSupport,
) )
from vllm.v1.attention.backends.utils import AttentionCGSupport from vllm.v1.attention.backends.utils import AttentionCGSupport, KVCacheLayoutType
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -28,6 +34,14 @@ class FlashInferMLAMetadataBuilder(MLACommonMetadataBuilder[MLACommonMetadata]): ...@@ -28,6 +34,14 @@ class FlashInferMLAMetadataBuilder(MLACommonMetadataBuilder[MLACommonMetadata]):
class FlashInferMLABackend(MLACommonBackend): class FlashInferMLABackend(MLACommonBackend):
supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16]
supported_kernel_block_sizes: ClassVar[list[int | MultipleOf]] = [32, 64]
supported_kv_cache_dtypes: ClassVar[list[CacheDType]] = [
"auto",
"fp8",
"fp8_e4m3",
]
@staticmethod @staticmethod
def get_name() -> str: def get_name() -> str:
return "FLASHINFER_MLA" return "FLASHINFER_MLA"
...@@ -41,8 +55,12 @@ class FlashInferMLABackend(MLACommonBackend): ...@@ -41,8 +55,12 @@ class FlashInferMLABackend(MLACommonBackend):
return FlashInferMLAMetadataBuilder return FlashInferMLAMetadataBuilder
@classmethod @classmethod
def get_supported_kernel_block_size(cls) -> list[int | MultipleOf]: def supports_compute_capability(cls, capability: DeviceCapability) -> bool:
return [32, 64] return capability.major == 10
@classmethod
def get_required_kv_cache_layout(cls) -> "KVCacheLayoutType | None":
return "HND"
g_fi_workspace = torch.zeros( g_fi_workspace = torch.zeros(
......
...@@ -13,10 +13,12 @@ from vllm.attention.ops.flashmla import ( ...@@ -13,10 +13,12 @@ from vllm.attention.ops.flashmla import (
is_flashmla_dense_supported, is_flashmla_dense_supported,
) )
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.config.cache import CacheDType
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.batch_invariant import ( from vllm.model_executor.layers.batch_invariant import (
vllm_is_batch_invariant, vllm_is_batch_invariant,
) )
from vllm.platforms.interface import DeviceCapability
from vllm.v1.attention.backends.mla.common import ( from vllm.v1.attention.backends.mla.common import (
MLACommonBackend, MLACommonBackend,
MLACommonDecodeMetadata, MLACommonDecodeMetadata,
...@@ -36,6 +38,14 @@ logger = init_logger(__name__) ...@@ -36,6 +38,14 @@ logger = init_logger(__name__)
class FlashMLABackend(MLACommonBackend): class FlashMLABackend(MLACommonBackend):
supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16]
supported_kernel_block_sizes: ClassVar[list[int | MultipleOf]] = [64]
supported_kv_cache_dtypes: ClassVar[list[CacheDType]] = [
"auto",
"fp8",
"fp8_e4m3",
]
@staticmethod @staticmethod
def get_name() -> str: def get_name() -> str:
return "FLASHMLA" return "FLASHMLA"
...@@ -48,9 +58,30 @@ class FlashMLABackend(MLACommonBackend): ...@@ -48,9 +58,30 @@ class FlashMLABackend(MLACommonBackend):
def get_impl_cls() -> type["FlashMLAImpl"]: def get_impl_cls() -> type["FlashMLAImpl"]:
return FlashMLAImpl return FlashMLAImpl
@staticmethod @classmethod
def get_supported_kernel_block_size() -> list[int | MultipleOf]: def supports_compute_capability(cls, capability: DeviceCapability) -> bool:
return [64] return capability.major in [9, 10]
@classmethod
def supports_combination(
cls,
head_size: int,
dtype: torch.dtype,
kv_cache_dtype: CacheDType | None,
block_size: int,
use_mla: bool,
has_sink: bool,
use_sparse: bool,
device_capability: DeviceCapability,
) -> str | None:
if use_sparse:
from vllm.attention.ops.flashmla import is_flashmla_sparse_supported
return is_flashmla_sparse_supported()[1]
else:
from vllm.attention.ops.flashmla import is_flashmla_dense_supported
return is_flashmla_dense_supported()[1]
@dataclass @dataclass
......
...@@ -10,6 +10,7 @@ from vllm import _custom_ops as ops ...@@ -10,6 +10,7 @@ from vllm import _custom_ops as ops
from vllm.attention.backends.abstract import ( from vllm.attention.backends.abstract import (
AttentionBackend, AttentionBackend,
AttentionLayer, AttentionLayer,
MultipleOf,
) )
from vllm.attention.backends.utils import get_mla_dims from vllm.attention.backends.utils import get_mla_dims
from vllm.attention.ops.flashmla import ( from vllm.attention.ops.flashmla import (
...@@ -18,8 +19,10 @@ from vllm.attention.ops.flashmla import ( ...@@ -18,8 +19,10 @@ from vllm.attention.ops.flashmla import (
get_mla_metadata, get_mla_metadata,
) )
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.config.cache import CacheDType
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.platforms.interface import DeviceCapability
from vllm.triton_utils import tl, triton from vllm.triton_utils import tl, triton
from vllm.utils.math_utils import cdiv from vllm.utils.math_utils import cdiv
from vllm.v1.attention.backends.mla.common import MLACommonBaseImpl from vllm.v1.attention.backends.mla.common import MLACommonBaseImpl
...@@ -51,6 +54,9 @@ structured as: ...@@ -51,6 +54,9 @@ structured as:
class FlashMLASparseBackend(AttentionBackend): class FlashMLASparseBackend(AttentionBackend):
accept_output_buffer: bool = True accept_output_buffer: bool = True
supported_dtypes: ClassVar[list[torch.dtype]] = [torch.bfloat16]
supported_kernel_block_sizes: ClassVar[list[int | MultipleOf]] = [64]
supported_kv_cache_dtypes: ClassVar[list[CacheDType]] = ["auto", "fp8_ds_mla"]
@staticmethod @staticmethod
def get_name() -> str: def get_name() -> str:
...@@ -64,6 +70,22 @@ class FlashMLASparseBackend(AttentionBackend): ...@@ -64,6 +70,22 @@ class FlashMLASparseBackend(AttentionBackend):
def get_impl_cls() -> type["FlashMLASparseImpl"]: def get_impl_cls() -> type["FlashMLASparseImpl"]:
return FlashMLASparseImpl return FlashMLASparseImpl
@classmethod
def get_supported_head_sizes(cls) -> list[int]:
return [576]
@classmethod
def is_mla(cls) -> bool:
return True
@classmethod
def is_sparse(cls) -> bool:
return True
@classmethod
def supports_compute_capability(cls, capability: DeviceCapability) -> bool:
return capability.major in [9, 10]
@staticmethod @staticmethod
def get_kv_cache_shape( def get_kv_cache_shape(
num_blocks: int, num_blocks: int,
...@@ -79,14 +101,6 @@ class FlashMLASparseBackend(AttentionBackend): ...@@ -79,14 +101,6 @@ class FlashMLASparseBackend(AttentionBackend):
else: else:
return (num_blocks, block_size, head_size) return (num_blocks, block_size, head_size)
@classmethod
def get_supported_dtypes(cls) -> list[torch.dtype]:
return [torch.bfloat16]
@classmethod
def get_supported_head_sizes(cls) -> list[int]:
return [576]
@dataclass @dataclass
class FlashMLASparseMetadata: class FlashMLASparseMetadata:
......
...@@ -23,6 +23,8 @@ logger = init_logger(__name__) ...@@ -23,6 +23,8 @@ logger = init_logger(__name__)
class DeepseekV32IndexerBackend(AttentionBackend): class DeepseekV32IndexerBackend(AttentionBackend):
supported_kernel_block_sizes: ClassVar[list[int | MultipleOf]] = [64]
@classmethod @classmethod
def get_supported_head_sizes(cls) -> list[int]: def get_supported_head_sizes(cls) -> list[int]:
return [32, 64, 128] return [32, 64, 128]
...@@ -46,10 +48,6 @@ class DeepseekV32IndexerBackend(AttentionBackend): ...@@ -46,10 +48,6 @@ class DeepseekV32IndexerBackend(AttentionBackend):
def get_kv_cache_stride_order() -> tuple[int, ...]: def get_kv_cache_stride_order() -> tuple[int, ...]:
return (0, 1, 2) return (0, 1, 2)
@classmethod
def get_supported_kernel_block_size(cls) -> list[int | MultipleOf]:
return [64]
@dataclass @dataclass
class DeepseekV32IndexerPrefillChunkMetadata: class DeepseekV32IndexerPrefillChunkMetadata:
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import ClassVar
import torch import torch
...@@ -12,11 +13,13 @@ from vllm.attention.backends.abstract import ( ...@@ -12,11 +13,13 @@ from vllm.attention.backends.abstract import (
) )
from vllm.attention.ops.triton_decode_attention import decode_attention_fwd from vllm.attention.ops.triton_decode_attention import decode_attention_fwd
from vllm.attention.ops.triton_flash_attention import triton_attention from vllm.attention.ops.triton_flash_attention import triton_attention
from vllm.config.cache import CacheDType
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.batch_invariant import ( from vllm.model_executor.layers.batch_invariant import (
vllm_is_batch_invariant, vllm_is_batch_invariant,
) )
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.platforms.interface import DeviceCapability
from vllm.triton_utils import HAS_TRITON from vllm.triton_utils import HAS_TRITON
from vllm.v1.attention.backends.mla.common import ( from vllm.v1.attention.backends.mla.common import (
MLACommonBackend, MLACommonBackend,
...@@ -28,6 +31,9 @@ logger = init_logger(__name__) ...@@ -28,6 +31,9 @@ logger = init_logger(__name__)
class TritonMLABackend(MLACommonBackend): class TritonMLABackend(MLACommonBackend):
supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16]
supported_kv_cache_dtypes: ClassVar[list[CacheDType]] = ["auto"]
@staticmethod @staticmethod
def get_name() -> str: def get_name() -> str:
return "TRITON_MLA" return "TRITON_MLA"
...@@ -36,6 +42,10 @@ class TritonMLABackend(MLACommonBackend): ...@@ -36,6 +42,10 @@ class TritonMLABackend(MLACommonBackend):
def get_impl_cls() -> type["TritonMLAImpl"]: def get_impl_cls() -> type["TritonMLAImpl"]:
return TritonMLAImpl return TritonMLAImpl
@classmethod
def supports_compute_capability(cls, capability: DeviceCapability) -> bool:
return True
class TritonMLAImpl(MLACommonImpl[MLACommonMetadata]): class TritonMLAImpl(MLACommonImpl[MLACommonMetadata]):
can_return_lse_for_decode: bool = True can_return_lse_for_decode: bool = True
......
...@@ -3,6 +3,7 @@ ...@@ -3,6 +3,7 @@
"""Attention layer with AiterFlashAttention.""" """Attention layer with AiterFlashAttention."""
from dataclasses import dataclass from dataclasses import dataclass
from typing import ClassVar
import torch import torch
...@@ -445,31 +446,13 @@ class AiterFlashAttentionMetadataBuilder( ...@@ -445,31 +446,13 @@ class AiterFlashAttentionMetadataBuilder(
class AiterFlashAttentionBackend(AttentionBackend): class AiterFlashAttentionBackend(AttentionBackend):
accept_output_buffer: bool = True accept_output_buffer: bool = True
supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16]
@classmethod supported_kernel_block_sizes: ClassVar[list[int | MultipleOf]] = [MultipleOf(16)]
def get_supported_dtypes(cls) -> list[torch.dtype]:
return [torch.float16, torch.bfloat16]
@classmethod @classmethod
def get_supported_head_sizes(cls) -> list[int]: def get_supported_head_sizes(cls) -> list[int]:
return [64, 128, 256] return [64, 128, 256]
@staticmethod
def get_supported_kernel_block_size() -> list[int | MultipleOf]:
return [MultipleOf(16)]
@classmethod
def validate_head_size(cls, head_size: int) -> None:
supported_head_sizes = cls.get_supported_head_sizes()
if head_size not in supported_head_sizes:
attn_type = cls.__name__.removesuffix("Backend")
raise ValueError(
f"Head size {head_size} is not supported by {attn_type}. "
f"Supported head sizes are: {supported_head_sizes}. "
"Set VLLM_ATTENTION_BACKEND=FLEX_ATTENTION to use "
"FlexAttention backend which supports all head sizes."
)
@staticmethod @staticmethod
def get_name() -> str: def get_name() -> str:
return "FLASH_ATTN" return "FLASH_ATTN"
...@@ -531,8 +514,6 @@ class AiterFlashAttentionImpl(AttentionImpl): ...@@ -531,8 +514,6 @@ class AiterFlashAttentionImpl(AttentionImpl):
assert self.num_heads % self.num_kv_heads == 0 assert self.num_heads % self.num_kv_heads == 0
self.num_queries_per_kv = self.num_heads // self.num_kv_heads self.num_queries_per_kv = self.num_heads // self.num_kv_heads
AiterFlashAttentionBackend.validate_head_size(head_size)
if attn_type != AttentionType.DECODER: if attn_type != AttentionType.DECODER:
raise NotImplementedError( raise NotImplementedError(
"Encoder self-attention and " "Encoder self-attention and "
......
...@@ -152,10 +152,7 @@ class RocmAttentionMetadataBuilder(AttentionMetadataBuilder[RocmAttentionMetadat ...@@ -152,10 +152,7 @@ class RocmAttentionMetadataBuilder(AttentionMetadataBuilder[RocmAttentionMetadat
class RocmAttentionBackend(AttentionBackend): class RocmAttentionBackend(AttentionBackend):
accept_output_buffer: bool = True accept_output_buffer: bool = True
supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16]
@classmethod
def get_supported_dtypes(cls) -> list[torch.dtype]:
return [torch.float16, torch.bfloat16]
@classmethod @classmethod
def get_supported_head_sizes(cls) -> list[int]: def get_supported_head_sizes(cls) -> list[int]:
...@@ -163,12 +160,11 @@ class RocmAttentionBackend(AttentionBackend): ...@@ -163,12 +160,11 @@ class RocmAttentionBackend(AttentionBackend):
@classmethod @classmethod
def validate_head_size(cls, head_size: int) -> None: def validate_head_size(cls, head_size: int) -> None:
supported_head_sizes = cls.get_supported_head_sizes() if not cls.supports_head_size(head_size):
if head_size not in supported_head_sizes:
attn_type = cls.__name__.removesuffix("Backend") attn_type = cls.__name__.removesuffix("Backend")
raise ValueError( raise ValueError(
f"Head size {head_size} is not supported by {attn_type}. " f"Head size {head_size} is not supported by {attn_type}. "
f"Supported head sizes are: {supported_head_sizes}. " f"Supported head sizes are: {cls.get_supported_head_sizes()}. "
"Set VLLM_ATTENTION_BACKEND=FLEX_ATTENTION to use " "Set VLLM_ATTENTION_BACKEND=FLEX_ATTENTION to use "
"FlexAttention backend which supports all head sizes." "FlexAttention backend which supports all head sizes."
) )
......
...@@ -4,7 +4,7 @@ ...@@ -4,7 +4,7 @@
import ast import ast
from dataclasses import dataclass from dataclasses import dataclass
from typing import Optional from typing import ClassVar, Optional
import torch import torch
...@@ -30,31 +30,13 @@ logger = init_logger(__name__) ...@@ -30,31 +30,13 @@ logger = init_logger(__name__)
class TreeAttentionBackend(AttentionBackend): class TreeAttentionBackend(AttentionBackend):
accept_output_buffer: bool = True accept_output_buffer: bool = True
supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16]
@classmethod supported_kernel_block_sizes: ClassVar[list[int | MultipleOf]] = [MultipleOf(16)]
def get_supported_dtypes(cls) -> list[torch.dtype]:
return [torch.float16, torch.bfloat16]
@classmethod @classmethod
def get_supported_head_sizes(cls) -> list[int]: def get_supported_head_sizes(cls) -> list[int]:
return [32, 64, 96, 128, 160, 192, 224, 256] return [32, 64, 96, 128, 160, 192, 224, 256]
@staticmethod
def get_supported_kernel_block_size() -> list[int | MultipleOf]:
return [MultipleOf(16)]
@classmethod
def validate_head_size(cls, head_size: int) -> None:
supported_head_sizes = cls.get_supported_head_sizes()
if head_size not in supported_head_sizes:
attn_type = cls.__name__.removesuffix("Backend")
raise ValueError(
f"Head size {head_size} is not supported by {attn_type}. "
f"Supported head sizes are: {supported_head_sizes}. "
"Set VLLM_ATTENTION_BACKEND=FLEX_ATTENTION to use "
"FlexAttention backend which supports all head sizes."
)
@staticmethod @staticmethod
def get_name() -> str: def get_name() -> str:
return "TREE_ATTN" return "TREE_ATTN"
...@@ -331,8 +313,6 @@ class TreeAttentionImpl(AttentionImpl): ...@@ -331,8 +313,6 @@ class TreeAttentionImpl(AttentionImpl):
else: else:
self.sliding_window = (sliding_window - 1, 0) self.sliding_window = (sliding_window - 1, 0)
TreeAttentionBackend.validate_head_size(head_size)
if attn_type != AttentionType.DECODER: if attn_type != AttentionType.DECODER:
raise NotImplementedError( raise NotImplementedError(
"Encoder self-attention and " "Encoder self-attention and "
......
...@@ -18,12 +18,14 @@ from vllm.attention.ops.triton_reshape_and_cache_flash import ( ...@@ -18,12 +18,14 @@ from vllm.attention.ops.triton_reshape_and_cache_flash import (
) )
from vllm.attention.ops.triton_unified_attention import unified_attention from vllm.attention.ops.triton_unified_attention import unified_attention
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.config.cache import CacheDType
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.quantization.utils.quant_utils import ( from vllm.model_executor.layers.quantization.utils.quant_utils import (
QuantKey, QuantKey,
kFp8StaticTensorSym, kFp8StaticTensorSym,
) )
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.platforms.interface import DeviceCapability
from vllm.v1.attention.backends.utils import ( from vllm.v1.attention.backends.utils import (
AttentionCGSupport, AttentionCGSupport,
AttentionMetadataBuilder, AttentionMetadataBuilder,
...@@ -147,25 +149,18 @@ class TritonAttentionMetadataBuilder(AttentionMetadataBuilder[TritonAttentionMet ...@@ -147,25 +149,18 @@ class TritonAttentionMetadataBuilder(AttentionMetadataBuilder[TritonAttentionMet
class TritonAttentionBackend(AttentionBackend): class TritonAttentionBackend(AttentionBackend):
accept_output_buffer: bool = True accept_output_buffer: bool = True
supported_dtypes: ClassVar[list[torch.dtype]] = [
@classmethod torch.float16,
def get_supported_dtypes(cls) -> list[torch.dtype]: torch.bfloat16,
return [torch.float16, torch.bfloat16, torch.float32] torch.float32,
]
@staticmethod supported_kernel_block_sizes: ClassVar[list[int | MultipleOf]] = [MultipleOf(16)]
def get_supported_kernel_block_size() -> list[int | MultipleOf]: supported_kv_cache_dtypes: ClassVar[list[CacheDType]] = [
return [MultipleOf(16)] "auto",
"fp8",
@classmethod "fp8_e4m3",
def validate_head_size(cls, head_size: int) -> None: "fp8_e5m2",
# Triton Attention supports any head size above 32 ]
if head_size < 32:
raise ValueError(
f"Head size {head_size} is not supported by TritonAttention."
f"Head sizes need to be larger or equal 32 for this backend. "
"Set VLLM_ATTENTION_BACKEND=FLEX_ATTENTION to use "
"FlexAttention backend which supports all head sizes."
)
@staticmethod @staticmethod
def get_name() -> str: def get_name() -> str:
...@@ -195,6 +190,18 @@ class TritonAttentionBackend(AttentionBackend): ...@@ -195,6 +190,18 @@ class TritonAttentionBackend(AttentionBackend):
def get_builder_cls() -> type["TritonAttentionMetadataBuilder"]: def get_builder_cls() -> type["TritonAttentionMetadataBuilder"]:
return TritonAttentionMetadataBuilder return TritonAttentionMetadataBuilder
@classmethod
def supports_head_size(cls, head_size: int) -> bool:
return head_size >= 32
@classmethod
def supports_sink(cls) -> bool:
return True
@classmethod
def supports_compute_capability(cls, capability: DeviceCapability) -> bool:
return True
class TritonAttentionImpl(AttentionImpl): class TritonAttentionImpl(AttentionImpl):
def fused_output_quant_supported(self, quant_key: QuantKey): def fused_output_quant_supported(self, quant_key: QuantKey):
...@@ -237,8 +244,6 @@ class TritonAttentionImpl(AttentionImpl): ...@@ -237,8 +244,6 @@ class TritonAttentionImpl(AttentionImpl):
self.num_queries_per_kv = self.num_heads // self.num_kv_heads self.num_queries_per_kv = self.num_heads // self.num_kv_heads
TritonAttentionBackend.validate_head_size(head_size)
if attn_type != AttentionType.DECODER: if attn_type != AttentionType.DECODER:
raise NotImplementedError( raise NotImplementedError(
"Encoder self-attention and " "Encoder self-attention and "
......
...@@ -3,7 +3,7 @@ ...@@ -3,7 +3,7 @@
"""Attention layer with XFormersAttention.""" """Attention layer with XFormersAttention."""
from dataclasses import dataclass from dataclasses import dataclass
from typing import Optional from typing import ClassVar, Optional
import torch import torch
...@@ -41,10 +41,8 @@ logger = init_logger(__name__) ...@@ -41,10 +41,8 @@ logger = init_logger(__name__)
class XFormersAttentionBackend(AttentionBackend): class XFormersAttentionBackend(AttentionBackend):
accept_output_buffer: bool = True accept_output_buffer: bool = True
supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16]
@classmethod supported_kernel_block_sizes: ClassVar[list[int | MultipleOf]] = [MultipleOf(16)]
def get_supported_dtypes(cls) -> list[torch.dtype]:
return [torch.float16, torch.bfloat16]
@classmethod @classmethod
def get_supported_head_sizes(cls) -> list[int]: def get_supported_head_sizes(cls) -> list[int]:
...@@ -80,22 +78,6 @@ class XFormersAttentionBackend(AttentionBackend): ...@@ -80,22 +78,6 @@ class XFormersAttentionBackend(AttentionBackend):
256, 256,
] ]
@staticmethod
def get_supported_kernel_block_size() -> list[int | MultipleOf]:
return [MultipleOf(16)]
@classmethod
def validate_head_size(cls, head_size: int) -> None:
supported_head_sizes = cls.get_supported_head_sizes()
if head_size not in supported_head_sizes:
attn_type = cls.__name__.removesuffix("Backend")
raise ValueError(
f"Head size {head_size} is not supported by {attn_type}. "
f"Supported head sizes are: {supported_head_sizes}. "
"Set VLLM_ATTENTION_BACKEND=FLEX_ATTENTION to use "
"FlexAttention backend which supports all head sizes."
)
@staticmethod @staticmethod
def get_name() -> str: def get_name() -> str:
return "XFORMERS" return "XFORMERS"
...@@ -305,8 +287,6 @@ class XFormersAttentionImpl(AttentionImpl): ...@@ -305,8 +287,6 @@ class XFormersAttentionImpl(AttentionImpl):
logits_soft_cap = 0 logits_soft_cap = 0
self.logits_soft_cap = logits_soft_cap self.logits_soft_cap = logits_soft_cap
XFormersAttentionBackend.validate_head_size(head_size)
if attn_type != AttentionType.DECODER: if attn_type != AttentionType.DECODER:
raise NotImplementedError( raise NotImplementedError(
"Encoder self-attention and " "Encoder self-attention and "
......
...@@ -150,11 +150,15 @@ class EagleProposer: ...@@ -150,11 +150,15 @@ class EagleProposer:
) )
# Determine allowed attention backends once during initialization. # Determine allowed attention backends once during initialization.
from vllm.attention.backends.registry import AttentionBackendEnum
self.allowed_attn_types: tuple | None = None self.allowed_attn_types: tuple | None = None
if current_platform.is_rocm(): if current_platform.is_rocm():
rocm_types = [TritonAttentionMetadata, FlashAttentionMetadata] rocm_types = [TritonAttentionMetadata, FlashAttentionMetadata]
# vllm.v1.attention.backends.rocm_aiter_fa is an optional backend # ROCM_AITER_FA is an optional backend
if find_spec("vllm.v1.attention.backends.rocm_aiter_fa"): if find_spec(
AttentionBackendEnum.ROCM_AITER_FA.get_path(include_classname=False)
):
from vllm.v1.attention.backends.rocm_aiter_fa import ( from vllm.v1.attention.backends.rocm_aiter_fa import (
AiterFlashAttentionMetadata, AiterFlashAttentionMetadata,
) )
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment