Unverified Commit 066209a0 authored by Nicolò Lucchesi's avatar Nicolò Lucchesi Committed by GitHub
Browse files

[Attention] Refactor FA `block_size` limitations to hybrid models only (#29084)


Signed-off-by: default avatarNickLucche <nlucches@redhat.com>
parent 5f7209a7
......@@ -61,7 +61,7 @@ for backend in BACKENDS_TO_TEST:
BACKEND_BLOCK_SIZES = {}
for backend in BACKENDS_TO_TEST:
supported_sizes = backend.get_class().supported_kernel_block_sizes
supported_sizes = backend.get_class().get_supported_kernel_block_sizes()
if supported_sizes:
default_size = supported_sizes[0]
block_size = (
......
......@@ -185,7 +185,9 @@ def _make_mock_backend_for_kernel_block_size(
supported_sizes: list[int | MultipleOf],
):
class _MockBackend:
supported_kernel_block_sizes = supported_sizes
@staticmethod
def get_supported_kernel_block_sizes():
return supported_sizes
return _MockBackend()
......
......@@ -46,9 +46,12 @@ class AttentionBackend(ABC):
# makes sure the output tensor is allocated inside the cudagraph.
accept_output_buffer: bool = False
supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16]
supported_kernel_block_sizes: ClassVar[list[int | MultipleOf]] = [MultipleOf(1)]
supported_kv_cache_dtypes: ClassVar[list["CacheDType"]] = ["auto"]
@staticmethod
def get_supported_kernel_block_sizes() -> list[int | MultipleOf]:
return [MultipleOf(1)]
@staticmethod
@abstractmethod
def get_name() -> str:
......@@ -142,10 +145,11 @@ class AttentionBackend(ABC):
if block_size not in valid_sizes:
return False
if not cls.supported_kernel_block_sizes:
supported_kernel_block_sizes = cls.get_supported_kernel_block_sizes()
if not supported_kernel_block_sizes:
return True
for supported_size in cls.supported_kernel_block_sizes:
for supported_size in supported_kernel_block_sizes:
if isinstance(supported_size, MultipleOf):
supported_size = supported_size.base
# With hybrid_blocks feature, the framework-level block size
......
......@@ -32,7 +32,7 @@ if is_flash_attn_varlen_func_available():
get_scheduler_metadata,
reshape_and_cache_flash,
)
from vllm.config import VllmConfig, get_layers_from_vllm_config
from vllm.config import VllmConfig, get_current_vllm_config, get_layers_from_vllm_config
from vllm.config.cache import CacheDType
from vllm.distributed.parallel_state import get_dcp_group
from vllm.logger import init_logger
......@@ -56,11 +56,26 @@ logger = init_logger(__name__)
class FlashAttentionBackend(AttentionBackend):
accept_output_buffer: bool = True
supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16]
@staticmethod
def get_supported_kernel_block_sizes() -> list[int | MultipleOf]:
vllm_config = get_current_vllm_config()
model_config = vllm_config.model_config
cache_config = vllm_config.cache_config
if (
model_config
and model_config.is_hybrid
and (
cache_config.mamba_ssm_cache_dtype == "float32"
or cache_config.mamba_cache_dtype == "float32"
)
):
# NOTE(tdoublep): while in principle, FA supports
# MultipleOf(16), these are the block sizes that do not
# suffer from the NaN propagation problem described here:
# https://github.com/Dao-AILab/flash-attention/issues/1974
supported_kernel_block_sizes: ClassVar[list[int | MultipleOf]] = [16, 32, 64]
return [16, 32, 64]
return [MultipleOf(16)]
@staticmethod
def get_name() -> str:
......
......@@ -16,7 +16,6 @@ from flashinfer import (
from flashinfer.decode import _get_range_buf, trtllm_batch_decode_with_kv_cache
from flashinfer.prefill import trtllm_batch_context_with_kv_cache
from flashinfer.utils import FP4Tensor
from typing_extensions import override
from vllm import envs
from vllm.attention.backends.abstract import (
......@@ -275,10 +274,6 @@ class BatchDCPPrefillWrapper:
class FlashInferBackend(AttentionBackend):
accept_output_buffer: bool = True
supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16]
# Note: Not sure for all platforms,
# but on Blackwell, only support a page size of
# 16, 32, 64
supported_kernel_block_sizes: ClassVar[list[int | MultipleOf]] = [16, 32, 64]
supported_kv_cache_dtypes: ClassVar[list[CacheDType]] = [
"auto",
"fp8",
......@@ -286,6 +281,12 @@ class FlashInferBackend(AttentionBackend):
"fp8_e5m2",
]
@staticmethod
def get_supported_kernel_block_sizes() -> list[int | MultipleOf]:
# Note: Not sure for all platforms, but on Blackwell,
# only support a page size of 16, 32, 64.
return [16, 32, 64]
@staticmethod
def get_name() -> str:
return "FLASHINFER"
......@@ -566,7 +567,6 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
)
@classmethod
@override
def get_cudagraph_support(
cls: type["FlashInferMetadataBuilder"],
vllm_config: VllmConfig,
......
......@@ -36,13 +36,16 @@ class CutlassMLAMetadataBuilder(MLACommonMetadataBuilder[MLACommonMetadata]):
class CutlassMLABackend(MLACommonBackend):
supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16]
supported_kernel_block_sizes: ClassVar[list[int | MultipleOf]] = [128]
supported_kv_cache_dtypes: ClassVar[list[CacheDType]] = [
"auto",
"fp8",
"fp8_e4m3",
]
@staticmethod
def get_supported_kernel_block_sizes() -> list[int | MultipleOf]:
return [128]
@staticmethod
def get_name() -> str:
return "CUTLASS_MLA"
......
......@@ -41,9 +41,12 @@ logger = init_logger(__name__)
class FlashAttnMLABackend(MLACommonBackend):
supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16]
supported_kernel_block_sizes: ClassVar[list[int | MultipleOf]] = [MultipleOf(16)]
supported_kv_cache_dtypes: ClassVar[list[CacheDType]] = ["auto"]
@staticmethod
def get_supported_kernel_block_sizes() -> list[int | MultipleOf]:
return [MultipleOf(16)]
@staticmethod
def get_name() -> str:
return "FLASH_ATTN_MLA"
......
......@@ -35,13 +35,16 @@ class FlashInferMLAMetadataBuilder(MLACommonMetadataBuilder[MLACommonMetadata]):
class FlashInferMLABackend(MLACommonBackend):
supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16]
supported_kernel_block_sizes: ClassVar[list[int | MultipleOf]] = [32, 64]
supported_kv_cache_dtypes: ClassVar[list[CacheDType]] = [
"auto",
"fp8",
"fp8_e4m3",
]
@staticmethod
def get_supported_kernel_block_sizes() -> list[int | MultipleOf]:
return [32, 64]
@staticmethod
def get_name() -> str:
return "FLASHINFER_MLA"
......
......@@ -39,13 +39,16 @@ logger = init_logger(__name__)
class FlashMLABackend(MLACommonBackend):
supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16]
supported_kernel_block_sizes: ClassVar[list[int | MultipleOf]] = [64]
supported_kv_cache_dtypes: ClassVar[list[CacheDType]] = [
"auto",
"fp8",
"fp8_e4m3",
]
@staticmethod
def get_supported_kernel_block_sizes() -> list[int | MultipleOf]:
return [64]
@staticmethod
def get_name() -> str:
return "FLASHMLA"
......
......@@ -55,9 +55,12 @@ structured as:
class FlashMLASparseBackend(AttentionBackend):
accept_output_buffer: bool = True
supported_dtypes: ClassVar[list[torch.dtype]] = [torch.bfloat16]
supported_kernel_block_sizes: ClassVar[list[int | MultipleOf]] = [64]
supported_kv_cache_dtypes: ClassVar[list[CacheDType]] = ["auto", "fp8_ds_mla"]
@staticmethod
def get_supported_kernel_block_sizes() -> list[int | MultipleOf]:
return [64]
@staticmethod
def get_name() -> str:
return "FLASHMLA_SPARSE"
......
......@@ -24,9 +24,9 @@ logger = init_logger(__name__)
class DeepseekV32IndexerBackend(AttentionBackend):
supported_kernel_block_sizes: ClassVar[list[int | MultipleOf]] = [
1 if current_platform.is_rocm() else 64
]
@staticmethod
def get_supported_kernel_block_sizes() -> list[int | MultipleOf]:
return [1 if current_platform.is_rocm() else 64]
@classmethod
def get_supported_head_sizes(cls) -> list[int]:
......
......@@ -21,7 +21,9 @@ from vllm.v1.kv_cache_interface import AttentionSpec
class AiterMLABackend(MLACommonBackend):
supported_kernel_block_sizes: ClassVar[list[int | MultipleOf]] = [1]
@staticmethod
def get_supported_kernel_block_sizes() -> list[int | MultipleOf]:
return [1]
@staticmethod
def get_name() -> str:
......
......@@ -447,7 +447,10 @@ class AiterFlashAttentionMetadataBuilder(
class AiterFlashAttentionBackend(AttentionBackend):
accept_output_buffer: bool = True
supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16]
supported_kernel_block_sizes: ClassVar[list[int | MultipleOf]] = [MultipleOf(16)]
@staticmethod
def get_supported_kernel_block_sizes() -> list[int | MultipleOf]:
return [MultipleOf(16)]
@classmethod
def get_supported_head_sizes(cls) -> list[int]:
......
......@@ -31,7 +31,10 @@ logger = init_logger(__name__)
class TreeAttentionBackend(AttentionBackend):
accept_output_buffer: bool = True
supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16]
supported_kernel_block_sizes: ClassVar[list[int | MultipleOf]] = [MultipleOf(16)]
@staticmethod
def get_supported_kernel_block_sizes() -> list[int | MultipleOf]:
return [MultipleOf(16)]
@classmethod
def get_supported_head_sizes(cls) -> list[int]:
......
......@@ -154,7 +154,6 @@ class TritonAttentionBackend(AttentionBackend):
torch.bfloat16,
torch.float32,
]
supported_kernel_block_sizes: ClassVar[list[int | MultipleOf]] = [MultipleOf(16)]
supported_kv_cache_dtypes: ClassVar[list[CacheDType]] = [
"auto",
"fp8",
......@@ -162,6 +161,10 @@ class TritonAttentionBackend(AttentionBackend):
"fp8_e5m2",
]
@staticmethod
def get_supported_kernel_block_sizes() -> list[int | MultipleOf]:
return [MultipleOf(16)]
@staticmethod
def get_name() -> str:
return "TRITON_ATTN"
......
......@@ -42,7 +42,10 @@ logger = init_logger(__name__)
class XFormersAttentionBackend(AttentionBackend):
accept_output_buffer: bool = True
supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16]
supported_kernel_block_sizes: ClassVar[list[int | MultipleOf]] = [MultipleOf(16)]
@staticmethod
def get_supported_kernel_block_sizes() -> list[int | MultipleOf]:
return [MultipleOf(16)]
@classmethod
def get_supported_head_sizes(cls) -> list[int]:
......
......@@ -4618,7 +4618,7 @@ class GPUModelRunner(
"""
for backend in backends:
is_supported = False
for supported_size in backend.supported_kernel_block_sizes:
for supported_size in backend.get_supported_kernel_block_sizes():
if isinstance(supported_size, int):
if block_size == supported_size:
is_supported = True
......@@ -4649,7 +4649,7 @@ class GPUModelRunner(
all_int_supported_sizes = set(
supported_size
for backend in backends
for supported_size in backend.supported_kernel_block_sizes
for supported_size in backend.get_supported_kernel_block_sizes()
if isinstance(supported_size, int)
)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment