Unverified commit b30dfa03, authored by Matthew Bonanni, committed by GitHub
Browse files

[Attention] Refactor CUDA attention backend selection logic (#24794)


Signed-off-by: Matthew Bonanni <mbonanni@redhat.com>
Signed-off-by: Matthew Bonanni <mbonanni001@gmail.com>
Co-authored-by: Luka Govedič <ProExpertProg@users.noreply.github.com>
parent 2e78150d
......@@ -15,16 +15,15 @@ from vllm.utils import DEFAULT_MAX_NUM_BATCHED_TOKENS
from .interface import Platform, PlatformEnum
if TYPE_CHECKING:
from vllm.attention.backends.registry import _Backend
from vllm.config import ModelConfig, VllmConfig
from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.config import VllmConfig
from vllm.config.cache import BlockSize
from vllm.pooling_params import PoolingParams
else:
BlockSize = None
ModelConfig = None
VllmConfig = None
PoolingParams = None
_Backend = None
AttentionBackendEnum = None
logger = init_logger(__name__)
......@@ -54,7 +53,7 @@ class TpuPlatform(Platform):
@classmethod
def get_attn_backend_cls(
cls,
selected_backend: "_Backend",
selected_backend: "AttentionBackendEnum",
head_size: int,
dtype: torch.dtype,
kv_cache_dtype: str | None,
......@@ -64,17 +63,17 @@ class TpuPlatform(Platform):
has_sink,
use_sparse,
) -> str:
from vllm.attention.backends.registry import _Backend
from vllm.attention.backends.registry import AttentionBackendEnum
if use_sparse:
raise NotImplementedError("Sparse Attention is not supported on TPU.")
if selected_backend != _Backend.PALLAS:
if selected_backend != AttentionBackendEnum.PALLAS:
logger.info("Cannot use %s backend on TPU.", selected_backend)
if not use_v1:
raise ValueError("TPU backend only supports V1.")
logger.info("Using Pallas V1 backend.")
return "vllm.v1.attention.backends.pallas.PallasAttentionBackend"
return AttentionBackendEnum.PALLAS.get_path()
@classmethod
def set_device(cls, device: torch.device) -> None:
......
......@@ -14,12 +14,11 @@ from vllm.utils import DEFAULT_MAX_NUM_BATCHED_TOKENS
from .interface import DeviceCapability, Platform, PlatformEnum
if TYPE_CHECKING:
from vllm.attention.backends.registry import _Backend
from vllm.config import ModelConfig, VllmConfig
from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.config import VllmConfig
else:
ModelConfig = None
VllmConfig = None
_Backend = None
AttentionBackendEnum = None
logger = init_logger(__name__)
......@@ -44,7 +43,7 @@ class XPUPlatform(Platform):
@classmethod
def get_attn_backend_cls(
cls,
selected_backend: "_Backend",
selected_backend: "AttentionBackendEnum",
head_size: int,
dtype: torch.dtype,
kv_cache_dtype: str | None,
......@@ -62,18 +61,19 @@ class XPUPlatform(Platform):
"only NHD layout is supported by XPU attention kernels."
)
from vllm.attention.backends.registry import _Backend
from vllm.attention.backends.registry import AttentionBackendEnum
if use_sparse:
raise NotImplementedError("Sparse Attention is not supported on XPU.")
TRITON_ATTN = "vllm.v1.attention.backends.triton_attn.TritonAttentionBackend" # noqa: E501
FLASH_ATTN = "vllm.v1.attention.backends.flash_attn.FlashAttentionBackend" # noqa: E501
if selected_backend == _Backend.TRITON_ATTN:
use_v1 = envs.VLLM_USE_V1
if not use_v1:
raise ValueError("XPU backend only supports V1.")
if selected_backend == AttentionBackendEnum.TRITON_ATTN:
logger.info_once("Using Triton backend.")
return TRITON_ATTN
elif selected_backend == _Backend.FLASH_ATTN:
return AttentionBackendEnum.TRITON_ATTN.get_path()
elif selected_backend == AttentionBackendEnum.FLASH_ATTN:
logger.info_once("Using Flash Attention backend.")
return FLASH_ATTN
return AttentionBackendEnum.FLASH_ATTN.get_path()
elif selected_backend:
raise ValueError(
f"Invalid attention backend for {cls.device_name}, "
......@@ -81,7 +81,7 @@ class XPUPlatform(Platform):
)
logger.info("Using Flash Attention backend.")
return "vllm.v1.attention.backends.flash_attn.FlashAttentionBackend"
return AttentionBackendEnum.FLASH_ATTN.get_path()
@classmethod
def set_device(cls, device: torch.device) -> None:
......@@ -113,10 +113,10 @@ class XPUPlatform(Platform):
return device_props.total_memory
@classmethod
def get_vit_attn_backend(cls, head_size: int, dtype: torch.dtype) -> _Backend:
from vllm.attention.backends.registry import _Backend
return _Backend.FLASH_ATTN
def get_vit_attn_backend(
cls, head_size: int, dtype: torch.dtype
) -> AttentionBackendEnum:
return AttentionBackendEnum.FLASH_ATTN
@classmethod
def inference_mode(cls):
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from dataclasses import dataclass
from typing import Optional
from typing import ClassVar, Optional
import numpy as np
import torch
......@@ -40,23 +40,16 @@ logger = init_logger(__name__)
class TorchSDPABackend(AttentionBackend):
accept_output_buffer: bool = False
supported_dtypes: ClassVar[list[torch.dtype]] = [
torch.float16,
torch.bfloat16,
torch.float32,
]
@classmethod
def get_supported_dtypes(cls) -> list[torch.dtype]:
return [torch.float16, torch.bfloat16, torch.float32]
@classmethod
def validate_head_size(cls, head_size: int) -> None:
def get_supported_head_sizes(cls) -> list[int]:
attn_impl = _get_paged_attn_impl()
is_valid, supported_head_sizes = attn_impl.validate_head_size(head_size)
if not is_valid:
attn_type = cls.__name__.removesuffix("Backend")
raise ValueError(
f"Head size {head_size} is not supported by {attn_type}. "
f"Supported head sizes are: {supported_head_sizes}. "
"Set VLLM_ATTENTION_BACKEND=FLEX_ATTENTION to use "
"FlexAttention backend which supports all head sizes."
)
return attn_impl.get_supported_head_sizes()
@staticmethod
def get_name() -> str:
......@@ -759,9 +752,8 @@ def _make_sliding_window_bias(
class _PagedAttention:
@staticmethod
def validate_head_size(head_size: int) -> tuple[bool, list[int]]:
SUPPORT_HS = [32, 64, 80, 96, 112, 128, 192, 256]
return head_size in SUPPORT_HS, SUPPORT_HS
def get_supported_head_sizes() -> list[int]:
return [32, 64, 80, 96, 112, 128, 192, 256]
@staticmethod
def get_kv_cache_shape(
......@@ -861,8 +853,8 @@ class _PagedAttention:
class _IPEXPagedAttention(_PagedAttention):
@staticmethod
def validate_head_size(head_size: int) -> tuple[bool, list[int]]:
return True, []
def get_supported_head_sizes() -> list[int]:
return []
@staticmethod
def split_kv_cache(
......
......@@ -3,6 +3,7 @@
"""Attention layer with FlashAttention."""
from dataclasses import dataclass
from typing import ClassVar
import numpy as np
import torch
......@@ -32,11 +33,13 @@ if is_flash_attn_varlen_func_available():
reshape_and_cache_flash,
)
from vllm.config import VllmConfig, get_layers_from_vllm_config
from vllm.config.cache import CacheDType
from vllm.distributed.parallel_state import get_dcp_group
from vllm.logger import init_logger
from vllm.model_executor.layers.batch_invariant import (
vllm_is_batch_invariant,
)
from vllm.platforms.interface import DeviceCapability
from vllm.utils.math_utils import cdiv
from vllm.v1.attention.backends.utils import (
AttentionCGSupport,
......@@ -52,34 +55,12 @@ logger = init_logger(__name__)
class FlashAttentionBackend(AttentionBackend):
accept_output_buffer: bool = True
@classmethod
def get_supported_dtypes(cls) -> list[torch.dtype]:
return [torch.float16, torch.bfloat16]
@classmethod
def get_supported_head_sizes(cls) -> list[int]:
return [32, 64, 96, 128, 160, 192, 224, 256]
@staticmethod
def get_supported_kernel_block_size() -> list[int | MultipleOf]:
supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16]
# NOTE(tdoublep): while in principle, FA supports
# MultipleOf(16), these are the block sizes that do not
# suffer from the NaN propagation problem described here:
# https://github.com/Dao-AILab/flash-attention/issues/1974
return [16, 32, 64]
@classmethod
def validate_head_size(cls, head_size: int) -> None:
supported_head_sizes = cls.get_supported_head_sizes()
if head_size not in supported_head_sizes:
attn_type = cls.__name__.removesuffix("Backend")
raise ValueError(
f"Head size {head_size} is not supported by {attn_type}. "
f"Supported head sizes are: {supported_head_sizes}. "
"Set VLLM_ATTENTION_BACKEND=FLEX_ATTENTION to use "
"FlexAttention backend which supports all head sizes."
)
supported_kernel_block_sizes: ClassVar[list[int | MultipleOf]] = [16, 32, 64]
@staticmethod
def get_name() -> str:
......@@ -125,6 +106,38 @@ class FlashAttentionBackend(AttentionBackend):
else:
raise ValueError(f"Unrecognized FP8 dtype: {kv_cache_dtype}")
@classmethod
def get_supported_head_sizes(cls) -> list[int]:
    """Return the attention head sizes this backend declares as supported."""
    return [32, 64, 96, 128, 160, 192, 224, 256]
@classmethod
def supports_kv_cache_dtype(cls, kv_cache_dtype: CacheDType | None) -> bool:
    """Report whether ``kv_cache_dtype`` can be used with this backend.

    Args:
        kv_cache_dtype: The requested KV-cache dtype name, or ``None``.

    Returns:
        ``True`` for ``None`` and for ``"auto"``. For any name starting with
        ``"fp8"`` the answer is delegated to ``flash_attn_supports_fp8()``.
        ``False`` for every other name.
    """
    if kv_cache_dtype is None:
        return True
    if kv_cache_dtype.startswith("fp8"):
        # FP8 availability is decided by the FlashAttention helper, not here.
        return flash_attn_supports_fp8()
    return kv_cache_dtype == "auto"
@classmethod
def supports_compute_capability(cls, capability: DeviceCapability) -> bool:
    """Return ``True`` when ``capability`` is at least ``DeviceCapability(8, 0)``."""
    return capability >= DeviceCapability(8, 0)
@classmethod
def supports_combination(
    cls,
    head_size: int,
    dtype: torch.dtype,
    kv_cache_dtype: CacheDType | None,
    block_size: int,
    use_mla: bool,
    has_sink: bool,
    use_sparse: bool,
    device_capability: DeviceCapability,
) -> str | None:
    """Check feature combinations that the per-feature checks cannot express.

    Only ``has_sink`` and ``device_capability`` are inspected here; the
    remaining parameters are accepted but not read by this implementation.

    Returns:
        A human-readable reason string when the combination is rejected,
        otherwise ``None``.
    """
    if has_sink and device_capability < DeviceCapability(9, 0):
        return "sink not supported on compute capability < 9.0"
    return None
@dataclass
class FlashAttentionMetadata:
......@@ -481,8 +494,6 @@ class FlashAttentionImpl(AttentionImpl):
self.num_queries_per_kv = self.num_heads // self.num_kv_heads
FlashAttentionBackend.validate_head_size(head_size)
self.attn_type = attn_type
self.vllm_flash_attn_version = get_flash_attn_version()
# Cache the batch invariant result for use in forward passes
......
......@@ -23,6 +23,7 @@ from vllm.attention.backends.abstract import (
MultipleOf,
)
from vllm.config import CUDAGraphMode, VllmConfig
from vllm.config.cache import CacheDType
from vllm.logger import init_logger
from vllm.model_executor.layers.batch_invariant import (
vllm_is_batch_invariant,
......@@ -33,6 +34,7 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
kNvfp4Quant,
)
from vllm.platforms import current_platform
from vllm.platforms.interface import DeviceCapability
from vllm.triton_utils import tl, triton
from vllm.utils.flashinfer import (
can_use_trtllm_attention,
......@@ -45,6 +47,7 @@ from vllm.v1.attention.backends.utils import (
AttentionCGSupport,
AttentionMetadataBuilder,
CommonAttentionMetadata,
KVCacheLayoutType,
get_kv_cache_layout,
get_per_layer_parameters,
infer_global_hyperparameters,
......@@ -158,34 +161,17 @@ def trtllm_prefill_attn_kvfp8_dequant(
class FlashInferBackend(AttentionBackend):
accept_output_buffer: bool = True
@classmethod
def get_supported_dtypes(cls) -> list[torch.dtype]:
return [torch.float16, torch.bfloat16]
@classmethod
def get_supported_head_sizes(cls) -> list[int]:
# https://github.com/flashinfer-ai/flashinfer/blob/3d55c71a62052c590c130897d3a3db49b14fcc34/include/flashinfer/utils.cuh#L157
return [64, 128, 256]
@staticmethod
def get_supported_kernel_block_size() -> list[int | MultipleOf]:
supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16]
# Note: Not sure for all platforms,
# but on Blackwell, only support a page size of
# 16, 32, 64
return [16, 32, 64]
@classmethod
def validate_head_size(cls, head_size: int) -> None:
supported_head_sizes = cls.get_supported_head_sizes()
if head_size not in supported_head_sizes:
attn_type = cls.__name__.removesuffix("Backend")
raise ValueError(
f"Head size {head_size} is not supported by {attn_type}. "
f"Supported head sizes are: {supported_head_sizes}. "
"Set VLLM_ATTENTION_BACKEND=FLEX_ATTENTION to use "
"FlexAttention backend which supports all head sizes."
)
supported_kernel_block_sizes: ClassVar[list[int | MultipleOf]] = [16, 32, 64]
supported_kv_cache_dtypes: ClassVar[list[CacheDType]] = [
"auto",
"fp8",
"fp8_e4m3",
"fp8_e5m2",
]
@staticmethod
def get_name() -> str:
......@@ -231,6 +217,26 @@ class FlashInferBackend(AttentionBackend):
else:
raise ValueError(f"Unrecognized FP8 dtype: {kv_cache_dtype}")
@classmethod
def get_supported_head_sizes(cls) -> list[int]:
    """Return the attention head sizes this backend declares as supported."""
    # https://github.com/flashinfer-ai/flashinfer/blob/3d55c71a62052c590c130897d3a3db49b14fcc34/include/flashinfer/utils.cuh#L157
    return [64, 128, 256]
@classmethod
def supports_compute_capability(cls, capability: DeviceCapability) -> bool:
    """Return ``True`` when ``capability`` lies in the inclusive range
    ``DeviceCapability(7, 5)`` .. ``DeviceCapability(12, 1)``."""
    lowest = DeviceCapability(7, 5)
    highest = DeviceCapability(12, 1)
    return capability >= lowest and capability <= highest
@classmethod
def get_required_kv_cache_layout(cls) -> KVCacheLayoutType | None:
    """Return the KV-cache layout this backend requires on the current device.

    Returns:
        ``"HND"`` when the current platform reports a device capability whose
        major version is 10; ``None`` otherwise, including when the platform
        reports no capability at all.
    """
    from vllm.platforms import current_platform

    capability = current_platform.get_device_capability()
    if capability is not None and capability.major == 10:
        return "HND"
    return None
@dataclass
class FlashInferMetadata:
......@@ -328,7 +334,6 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
)
self.num_kv_heads = self.kv_cache_spec.num_kv_heads
self.head_dim = self.kv_cache_spec.head_size
FlashInferBackend.validate_head_size(self.head_dim)
self.page_size = self.kv_cache_spec.block_size
self.cache_dtype = self.cache_config.cache_dtype
......
......@@ -4,6 +4,7 @@
import math
from dataclasses import dataclass
from typing import ClassVar
import torch
import torch._dynamo.decorators
......@@ -24,6 +25,7 @@ from vllm.attention.backends.abstract import (
is_quantized_kv_cache,
)
from vllm.config import VllmConfig
from vllm.config.cache import CacheDType
from vllm.logger import init_logger
from vllm.model_executor.layers.batch_invariant import (
vllm_is_batch_invariant,
......@@ -71,14 +73,12 @@ def pad_to_multiple(x: torch.Tensor, multiple: int, dim: int):
class FlexAttentionBackend(AttentionBackend):
accept_output_buffer: bool = True
@classmethod
def get_supported_dtypes(cls) -> list[torch.dtype]:
return [torch.float16, torch.bfloat16, torch.float32]
@classmethod
def validate_head_size(cls, head_size: int) -> None:
return # FlexAttention supports any head size
supported_dtypes: ClassVar[list[torch.dtype]] = [
torch.float16,
torch.bfloat16,
torch.float32,
]
supported_kv_cache_dtypes: ClassVar[list[CacheDType]] = ["auto"]
@staticmethod
def get_name() -> str:
......@@ -106,6 +106,10 @@ class FlexAttentionBackend(AttentionBackend):
def use_cascade_attention(*args, **kwargs) -> bool:
return False
@classmethod
def get_supported_head_sizes(cls) -> list[int]:
    """Return an empty list of supported head sizes.

    NOTE(review): an empty list presumably means "no restriction on head
    size" to the caller -- confirm against the base-class contract.
    """
    return []
# @torch.compile(fullgraph=True, mode="reduce-overhead")
def physical_to_logical_mapping(
......@@ -720,7 +724,6 @@ class FlexAttentionImpl(AttentionImpl):
if kv_sharing_target_layer_name is not None:
raise NotImplementedError("FlexAttention does not support kv sharing yet.")
FlexAttentionBackend.validate_head_size(head_size)
if is_quantized_kv_cache(self.kv_cache_dtype):
raise NotImplementedError(
"FlexAttention does not support quantized kv-cache. Yet"
......
......@@ -308,25 +308,13 @@ class MLACommonBackend(AttentionBackend):
) -> tuple[int, ...]:
return (num_blocks, block_size, head_size)
@classmethod
def get_supported_dtypes(cls) -> list[torch.dtype]:
return [torch.float16, torch.bfloat16]
@classmethod
def get_supported_head_sizes(cls) -> list[int]:
return [576]
@classmethod
def validate_head_size(cls, head_size: int) -> None:
supported_head_sizes = cls.get_supported_head_sizes()
if head_size not in supported_head_sizes:
attn_type = cls.__name__.removesuffix("Backend")
raise ValueError(
f"Head size {head_size} is not supported by {attn_type}. "
f"Supported head sizes are: {supported_head_sizes}. "
"Set VLLM_ATTENTION_BACKEND=FLEX_ATTENTION to use "
"FlexAttention backend which supports all head sizes."
)
def is_mla(cls) -> bool:
return True
@dataclass
......@@ -425,8 +413,10 @@ class MLACommonMetadata(Generic[D]):
) = None
def __post_init__(self):
if self.head_dim is not None:
MLACommonBackend.validate_head_size(self.head_dim)
if self.head_dim is not None and not MLACommonBackend.supports_head_size(
self.head_dim
):
raise ValueError(f"Head dimension {self.head_dim} is not supported by MLA.")
M = TypeVar("M", bound=MLACommonMetadata)
......
......@@ -13,7 +13,9 @@ from vllm.attention.backends.abstract import (
MultipleOf,
is_quantized_kv_cache,
)
from vllm.config.cache import CacheDType
from vllm.logger import init_logger
from vllm.platforms.interface import DeviceCapability
from vllm.v1.attention.backends.mla.common import (
MLACommonBackend,
MLACommonImpl,
......@@ -33,6 +35,14 @@ class CutlassMLAMetadataBuilder(MLACommonMetadataBuilder[MLACommonMetadata]):
class CutlassMLABackend(MLACommonBackend):
supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16]
supported_kernel_block_sizes: ClassVar[list[int | MultipleOf]] = [128]
supported_kv_cache_dtypes: ClassVar[list[CacheDType]] = [
"auto",
"fp8",
"fp8_e4m3",
]
@staticmethod
def get_name() -> str:
return "CUTLASS_MLA"
......@@ -45,9 +55,9 @@ class CutlassMLABackend(MLACommonBackend):
def get_builder_cls() -> type["CutlassMLAMetadataBuilder"]:
return CutlassMLAMetadataBuilder
@staticmethod
def get_supported_kernel_block_size() -> list[int | MultipleOf]:
return [128]
@classmethod
def supports_compute_capability(cls, capability: DeviceCapability) -> bool:
    """Return ``True`` only when the capability's major version is 10."""
    return capability.major == 10
class SM100Workspace:
......
......@@ -10,6 +10,7 @@ from vllm import envs
from vllm.attention.backends.abstract import (
AttentionLayer,
AttentionType,
MultipleOf,
is_quantized_kv_cache,
)
from vllm.attention.utils.fa_utils import (
......@@ -17,10 +18,12 @@ from vllm.attention.utils.fa_utils import (
get_flash_attn_version,
)
from vllm.config import VllmConfig
from vllm.config.cache import CacheDType
from vllm.logger import init_logger
from vllm.model_executor.layers.batch_invariant import (
vllm_is_batch_invariant,
)
from vllm.platforms.interface import DeviceCapability
from vllm.v1.attention.backends.mla.common import (
MLACommonBackend,
MLACommonDecodeMetadata,
......@@ -37,6 +40,10 @@ logger = init_logger(__name__)
class FlashAttnMLABackend(MLACommonBackend):
supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16]
supported_kernel_block_sizes: ClassVar[list[int | MultipleOf]] = [MultipleOf(16)]
supported_kv_cache_dtypes: ClassVar[list[CacheDType]] = ["auto"]
@staticmethod
def get_name() -> str:
return "FLASH_ATTN_MLA"
......@@ -49,6 +56,26 @@ class FlashAttnMLABackend(MLACommonBackend):
def get_impl_cls() -> type["FlashAttnMLAImpl"]:
return FlashAttnMLAImpl
@classmethod
def supports_compute_capability(cls, capability: DeviceCapability) -> bool:
    """Return ``True`` only when the capability's major version is 9."""
    return capability.major == 9
@classmethod
def supports_combination(
    cls,
    head_size: int,
    dtype: torch.dtype,
    kv_cache_dtype: CacheDType | None,
    block_size: int,
    use_mla: bool,
    has_sink: bool,
    use_sparse: bool,
    device_capability: DeviceCapability,
) -> str | None:
    """Check whether FlashAttention MLA is usable at all.

    None of the parameters are read; the decision rests entirely on
    ``flash_attn_supports_mla()``.

    Returns:
        A reason string when ``flash_attn_supports_mla()`` is falsy,
        otherwise ``None``.
    """
    if not flash_attn_supports_mla():
        return "FlashAttention MLA not supported on this device"
    return None
@dataclass
class FlashAttnMLADecodeMetadata(MLACommonDecodeMetadata):
......
......@@ -6,8 +6,14 @@ from typing import ClassVar
import torch
from flashinfer.decode import trtllm_batch_decode_with_kv_cache_mla
from vllm.attention.backends.abstract import AttentionLayer, AttentionType, MultipleOf
from vllm.attention.backends.abstract import (
AttentionLayer,
AttentionType,
MultipleOf,
)
from vllm.config.cache import CacheDType
from vllm.logger import init_logger
from vllm.platforms.interface import DeviceCapability
from vllm.v1.attention.backends.mla.common import (
MLACommonBackend,
MLACommonImpl,
......@@ -15,7 +21,7 @@ from vllm.v1.attention.backends.mla.common import (
MLACommonMetadataBuilder,
QueryLenSupport,
)
from vllm.v1.attention.backends.utils import AttentionCGSupport
from vllm.v1.attention.backends.utils import AttentionCGSupport, KVCacheLayoutType
logger = init_logger(__name__)
......@@ -28,6 +34,14 @@ class FlashInferMLAMetadataBuilder(MLACommonMetadataBuilder[MLACommonMetadata]):
class FlashInferMLABackend(MLACommonBackend):
supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16]
supported_kernel_block_sizes: ClassVar[list[int | MultipleOf]] = [32, 64]
supported_kv_cache_dtypes: ClassVar[list[CacheDType]] = [
"auto",
"fp8",
"fp8_e4m3",
]
@staticmethod
def get_name() -> str:
return "FLASHINFER_MLA"
......@@ -41,8 +55,12 @@ class FlashInferMLABackend(MLACommonBackend):
return FlashInferMLAMetadataBuilder
@classmethod
def get_supported_kernel_block_size(cls) -> list[int | MultipleOf]:
return [32, 64]
def supports_compute_capability(cls, capability: DeviceCapability) -> bool:
return capability.major == 10
@classmethod
def get_required_kv_cache_layout(cls) -> "KVCacheLayoutType | None":
    """Return the KV-cache layout this backend requires: always ``"HND"``."""
    return "HND"
g_fi_workspace = torch.zeros(
......
......@@ -13,10 +13,12 @@ from vllm.attention.ops.flashmla import (
is_flashmla_dense_supported,
)
from vllm.config import VllmConfig
from vllm.config.cache import CacheDType
from vllm.logger import init_logger
from vllm.model_executor.layers.batch_invariant import (
vllm_is_batch_invariant,
)
from vllm.platforms.interface import DeviceCapability
from vllm.v1.attention.backends.mla.common import (
MLACommonBackend,
MLACommonDecodeMetadata,
......@@ -36,6 +38,14 @@ logger = init_logger(__name__)
class FlashMLABackend(MLACommonBackend):
supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16]
supported_kernel_block_sizes: ClassVar[list[int | MultipleOf]] = [64]
supported_kv_cache_dtypes: ClassVar[list[CacheDType]] = [
"auto",
"fp8",
"fp8_e4m3",
]
@staticmethod
def get_name() -> str:
return "FLASHMLA"
......@@ -48,9 +58,30 @@ class FlashMLABackend(MLACommonBackend):
def get_impl_cls() -> type["FlashMLAImpl"]:
return FlashMLAImpl
@staticmethod
def get_supported_kernel_block_size() -> list[int | MultipleOf]:
return [64]
@classmethod
def supports_compute_capability(cls, capability: DeviceCapability) -> bool:
    """Return ``True`` when the capability's major version is 9 or 10."""
    return capability.major in [9, 10]
@classmethod
def supports_combination(
    cls,
    head_size: int,
    dtype: torch.dtype,
    kv_cache_dtype: CacheDType | None,
    block_size: int,
    use_mla: bool,
    has_sink: bool,
    use_sparse: bool,
    device_capability: DeviceCapability,
) -> str | None:
    """Ask the FlashMLA ops module whether the requested mode is available.

    Only ``use_sparse`` is read: it selects between the sparse and the dense
    support probe. The second element of the probe's result is returned
    unchanged.

    NOTE(review): that second element is presumably a reason string, or
    ``None`` when supported -- confirm against ``vllm.attention.ops.flashmla``.
    """
    if use_sparse:
        from vllm.attention.ops.flashmla import is_flashmla_sparse_supported

        return is_flashmla_sparse_supported()[1]

    from vllm.attention.ops.flashmla import is_flashmla_dense_supported

    return is_flashmla_dense_supported()[1]
@dataclass
......
......@@ -10,6 +10,7 @@ from vllm import _custom_ops as ops
from vllm.attention.backends.abstract import (
AttentionBackend,
AttentionLayer,
MultipleOf,
)
from vllm.attention.backends.utils import get_mla_dims
from vllm.attention.ops.flashmla import (
......@@ -18,8 +19,10 @@ from vllm.attention.ops.flashmla import (
get_mla_metadata,
)
from vllm.config import VllmConfig
from vllm.config.cache import CacheDType
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.platforms.interface import DeviceCapability
from vllm.triton_utils import tl, triton
from vllm.utils.math_utils import cdiv
from vllm.v1.attention.backends.mla.common import MLACommonBaseImpl
......@@ -51,6 +54,9 @@ structured as:
class FlashMLASparseBackend(AttentionBackend):
accept_output_buffer: bool = True
supported_dtypes: ClassVar[list[torch.dtype]] = [torch.bfloat16]
supported_kernel_block_sizes: ClassVar[list[int | MultipleOf]] = [64]
supported_kv_cache_dtypes: ClassVar[list[CacheDType]] = ["auto", "fp8_ds_mla"]
@staticmethod
def get_name() -> str:
......@@ -64,6 +70,22 @@ class FlashMLASparseBackend(AttentionBackend):
def get_impl_cls() -> type["FlashMLASparseImpl"]:
return FlashMLASparseImpl
@classmethod
def get_supported_head_sizes(cls) -> list[int]:
    """Return the attention head sizes this backend declares as supported."""
    return [576]
@classmethod
def is_mla(cls) -> bool:
    """Return ``True``: this backend identifies itself as an MLA backend."""
    return True
@classmethod
def is_sparse(cls) -> bool:
    """Return ``True``: this backend identifies itself as a sparse backend."""
    return True
@classmethod
def supports_compute_capability(cls, capability: DeviceCapability) -> bool:
    """Return ``True`` when the capability's major version is 9 or 10."""
    return capability.major in [9, 10]
@staticmethod
def get_kv_cache_shape(
num_blocks: int,
......@@ -79,14 +101,6 @@ class FlashMLASparseBackend(AttentionBackend):
else:
return (num_blocks, block_size, head_size)
@classmethod
def get_supported_dtypes(cls) -> list[torch.dtype]:
return [torch.bfloat16]
@classmethod
def get_supported_head_sizes(cls) -> list[int]:
return [576]
@dataclass
class FlashMLASparseMetadata:
......
......@@ -23,6 +23,8 @@ logger = init_logger(__name__)
class DeepseekV32IndexerBackend(AttentionBackend):
supported_kernel_block_sizes: ClassVar[list[int | MultipleOf]] = [64]
@classmethod
def get_supported_head_sizes(cls) -> list[int]:
    """Return the attention head sizes this backend declares as supported."""
    return [32, 64, 128]
......@@ -46,10 +48,6 @@ class DeepseekV32IndexerBackend(AttentionBackend):
def get_kv_cache_stride_order() -> tuple[int, ...]:
return (0, 1, 2)
@classmethod
def get_supported_kernel_block_size(cls) -> list[int | MultipleOf]:
return [64]
@dataclass
class DeepseekV32IndexerPrefillChunkMetadata:
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import ClassVar
import torch
......@@ -12,11 +13,13 @@ from vllm.attention.backends.abstract import (
)
from vllm.attention.ops.triton_decode_attention import decode_attention_fwd
from vllm.attention.ops.triton_flash_attention import triton_attention
from vllm.config.cache import CacheDType
from vllm.logger import init_logger
from vllm.model_executor.layers.batch_invariant import (
vllm_is_batch_invariant,
)
from vllm.platforms import current_platform
from vllm.platforms.interface import DeviceCapability
from vllm.triton_utils import HAS_TRITON
from vllm.v1.attention.backends.mla.common import (
MLACommonBackend,
......@@ -28,6 +31,9 @@ logger = init_logger(__name__)
class TritonMLABackend(MLACommonBackend):
supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16]
supported_kv_cache_dtypes: ClassVar[list[CacheDType]] = ["auto"]
@staticmethod
def get_name() -> str:
return "TRITON_MLA"
......@@ -36,6 +42,10 @@ class TritonMLABackend(MLACommonBackend):
def get_impl_cls() -> type["TritonMLAImpl"]:
return TritonMLAImpl
@classmethod
def supports_compute_capability(cls, capability: DeviceCapability) -> bool:
    """Return ``True`` unconditionally; ``capability`` is not inspected."""
    return True
class TritonMLAImpl(MLACommonImpl[MLACommonMetadata]):
can_return_lse_for_decode: bool = True
......
......@@ -3,6 +3,7 @@
"""Attention layer with AiterFlashAttention."""
from dataclasses import dataclass
from typing import ClassVar
import torch
......@@ -445,31 +446,13 @@ class AiterFlashAttentionMetadataBuilder(
class AiterFlashAttentionBackend(AttentionBackend):
accept_output_buffer: bool = True
@classmethod
def get_supported_dtypes(cls) -> list[torch.dtype]:
return [torch.float16, torch.bfloat16]
supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16]
supported_kernel_block_sizes: ClassVar[list[int | MultipleOf]] = [MultipleOf(16)]
@classmethod
def get_supported_head_sizes(cls) -> list[int]:
return [64, 128, 256]
@staticmethod
def get_supported_kernel_block_size() -> list[int | MultipleOf]:
return [MultipleOf(16)]
@classmethod
def validate_head_size(cls, head_size: int) -> None:
supported_head_sizes = cls.get_supported_head_sizes()
if head_size not in supported_head_sizes:
attn_type = cls.__name__.removesuffix("Backend")
raise ValueError(
f"Head size {head_size} is not supported by {attn_type}. "
f"Supported head sizes are: {supported_head_sizes}. "
"Set VLLM_ATTENTION_BACKEND=FLEX_ATTENTION to use "
"FlexAttention backend which supports all head sizes."
)
@staticmethod
def get_name() -> str:
return "FLASH_ATTN"
......@@ -531,8 +514,6 @@ class AiterFlashAttentionImpl(AttentionImpl):
assert self.num_heads % self.num_kv_heads == 0
self.num_queries_per_kv = self.num_heads // self.num_kv_heads
AiterFlashAttentionBackend.validate_head_size(head_size)
if attn_type != AttentionType.DECODER:
raise NotImplementedError(
"Encoder self-attention and "
......
......@@ -152,10 +152,7 @@ class RocmAttentionMetadataBuilder(AttentionMetadataBuilder[RocmAttentionMetadat
class RocmAttentionBackend(AttentionBackend):
accept_output_buffer: bool = True
@classmethod
def get_supported_dtypes(cls) -> list[torch.dtype]:
return [torch.float16, torch.bfloat16]
supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16]
@classmethod
def get_supported_head_sizes(cls) -> list[int]:
......@@ -163,12 +160,11 @@ class RocmAttentionBackend(AttentionBackend):
@classmethod
def validate_head_size(cls, head_size: int) -> None:
supported_head_sizes = cls.get_supported_head_sizes()
if head_size not in supported_head_sizes:
if not cls.supports_head_size(head_size):
attn_type = cls.__name__.removesuffix("Backend")
raise ValueError(
f"Head size {head_size} is not supported by {attn_type}. "
f"Supported head sizes are: {supported_head_sizes}. "
f"Supported head sizes are: {cls.get_supported_head_sizes()}. "
"Set VLLM_ATTENTION_BACKEND=FLEX_ATTENTION to use "
"FlexAttention backend which supports all head sizes."
)
......
......@@ -4,7 +4,7 @@
import ast
from dataclasses import dataclass
from typing import Optional
from typing import ClassVar, Optional
import torch
......@@ -30,31 +30,13 @@ logger = init_logger(__name__)
class TreeAttentionBackend(AttentionBackend):
accept_output_buffer: bool = True
@classmethod
def get_supported_dtypes(cls) -> list[torch.dtype]:
return [torch.float16, torch.bfloat16]
supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16]
supported_kernel_block_sizes: ClassVar[list[int | MultipleOf]] = [MultipleOf(16)]
@classmethod
def get_supported_head_sizes(cls) -> list[int]:
return [32, 64, 96, 128, 160, 192, 224, 256]
@staticmethod
def get_supported_kernel_block_size() -> list[int | MultipleOf]:
return [MultipleOf(16)]
@classmethod
def validate_head_size(cls, head_size: int) -> None:
supported_head_sizes = cls.get_supported_head_sizes()
if head_size not in supported_head_sizes:
attn_type = cls.__name__.removesuffix("Backend")
raise ValueError(
f"Head size {head_size} is not supported by {attn_type}. "
f"Supported head sizes are: {supported_head_sizes}. "
"Set VLLM_ATTENTION_BACKEND=FLEX_ATTENTION to use "
"FlexAttention backend which supports all head sizes."
)
@staticmethod
def get_name() -> str:
return "TREE_ATTN"
......@@ -331,8 +313,6 @@ class TreeAttentionImpl(AttentionImpl):
else:
self.sliding_window = (sliding_window - 1, 0)
TreeAttentionBackend.validate_head_size(head_size)
if attn_type != AttentionType.DECODER:
raise NotImplementedError(
"Encoder self-attention and "
......
......@@ -18,12 +18,14 @@ from vllm.attention.ops.triton_reshape_and_cache_flash import (
)
from vllm.attention.ops.triton_unified_attention import unified_attention
from vllm.config import VllmConfig
from vllm.config.cache import CacheDType
from vllm.logger import init_logger
from vllm.model_executor.layers.quantization.utils.quant_utils import (
QuantKey,
kFp8StaticTensorSym,
)
from vllm.platforms import current_platform
from vllm.platforms.interface import DeviceCapability
from vllm.v1.attention.backends.utils import (
AttentionCGSupport,
AttentionMetadataBuilder,
......@@ -147,25 +149,18 @@ class TritonAttentionMetadataBuilder(AttentionMetadataBuilder[TritonAttentionMet
class TritonAttentionBackend(AttentionBackend):
accept_output_buffer: bool = True
@classmethod
def get_supported_dtypes(cls) -> list[torch.dtype]:
return [torch.float16, torch.bfloat16, torch.float32]
@staticmethod
def get_supported_kernel_block_size() -> list[int | MultipleOf]:
return [MultipleOf(16)]
@classmethod
def validate_head_size(cls, head_size: int) -> None:
# Triton Attention supports any head size above 32
if head_size < 32:
raise ValueError(
f"Head size {head_size} is not supported by TritonAttention."
f"Head sizes need to be larger or equal 32 for this backend. "
"Set VLLM_ATTENTION_BACKEND=FLEX_ATTENTION to use "
"FlexAttention backend which supports all head sizes."
)
supported_dtypes: ClassVar[list[torch.dtype]] = [
torch.float16,
torch.bfloat16,
torch.float32,
]
supported_kernel_block_sizes: ClassVar[list[int | MultipleOf]] = [MultipleOf(16)]
supported_kv_cache_dtypes: ClassVar[list[CacheDType]] = [
"auto",
"fp8",
"fp8_e4m3",
"fp8_e5m2",
]
@staticmethod
def get_name() -> str:
......@@ -195,6 +190,18 @@ class TritonAttentionBackend(AttentionBackend):
def get_builder_cls() -> type["TritonAttentionMetadataBuilder"]:
return TritonAttentionMetadataBuilder
@classmethod
def supports_head_size(cls, head_size: int) -> bool:
    """Return ``True`` when ``head_size`` is 32 or larger."""
    return head_size >= 32
@classmethod
def supports_sink(cls) -> bool:
    """Return ``True``: this backend declares support for attention sinks."""
    return True
@classmethod
def supports_compute_capability(cls, capability: DeviceCapability) -> bool:
    """Return ``True`` unconditionally; ``capability`` is not inspected."""
    return True
class TritonAttentionImpl(AttentionImpl):
def fused_output_quant_supported(self, quant_key: QuantKey):
......@@ -237,8 +244,6 @@ class TritonAttentionImpl(AttentionImpl):
self.num_queries_per_kv = self.num_heads // self.num_kv_heads
TritonAttentionBackend.validate_head_size(head_size)
if attn_type != AttentionType.DECODER:
raise NotImplementedError(
"Encoder self-attention and "
......
......@@ -3,7 +3,7 @@
"""Attention layer with XFormersAttention."""
from dataclasses import dataclass
from typing import Optional
from typing import ClassVar, Optional
import torch
......@@ -41,10 +41,8 @@ logger = init_logger(__name__)
class XFormersAttentionBackend(AttentionBackend):
accept_output_buffer: bool = True
@classmethod
def get_supported_dtypes(cls) -> list[torch.dtype]:
return [torch.float16, torch.bfloat16]
supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16]
supported_kernel_block_sizes: ClassVar[list[int | MultipleOf]] = [MultipleOf(16)]
@classmethod
def get_supported_head_sizes(cls) -> list[int]:
......@@ -80,22 +78,6 @@ class XFormersAttentionBackend(AttentionBackend):
256,
]
@staticmethod
def get_supported_kernel_block_size() -> list[int | MultipleOf]:
return [MultipleOf(16)]
@classmethod
def validate_head_size(cls, head_size: int) -> None:
supported_head_sizes = cls.get_supported_head_sizes()
if head_size not in supported_head_sizes:
attn_type = cls.__name__.removesuffix("Backend")
raise ValueError(
f"Head size {head_size} is not supported by {attn_type}. "
f"Supported head sizes are: {supported_head_sizes}. "
"Set VLLM_ATTENTION_BACKEND=FLEX_ATTENTION to use "
"FlexAttention backend which supports all head sizes."
)
@staticmethod
def get_name() -> str:
return "XFORMERS"
......@@ -305,8 +287,6 @@ class XFormersAttentionImpl(AttentionImpl):
logits_soft_cap = 0
self.logits_soft_cap = logits_soft_cap
XFormersAttentionBackend.validate_head_size(head_size)
if attn_type != AttentionType.DECODER:
raise NotImplementedError(
"Encoder self-attention and "
......
......@@ -150,11 +150,15 @@ class EagleProposer:
)
# Determine allowed attention backends once during initialization.
from vllm.attention.backends.registry import AttentionBackendEnum
self.allowed_attn_types: tuple | None = None
if current_platform.is_rocm():
rocm_types = [TritonAttentionMetadata, FlashAttentionMetadata]
# vllm.v1.attention.backends.rocm_aiter_fa is an optional backend
if find_spec("vllm.v1.attention.backends.rocm_aiter_fa"):
# ROCM_AITER_FA is an optional backend
if find_spec(
AttentionBackendEnum.ROCM_AITER_FA.get_path(include_classname=False)
):
from vllm.v1.attention.backends.rocm_aiter_fa import (
AiterFlashAttentionMetadata,
)
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment