Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
066209a0
Unverified
Commit
066209a0
authored
Nov 22, 2025
by
Nicolò Lucchesi
Committed by
GitHub
Nov 22, 2025
Browse files
[Attention] Refactor FA `block_size` limitations to hybrid models only (#29084)
Signed-off-by:
NickLucche
<
nlucches@redhat.com
>
parent
5f7209a7
Changes
17
Hide whitespace changes
Inline
Side-by-side
Showing
17 changed files
with
82 additions
and
32 deletions
+82
-32
tests/v1/attention/test_mla_backends.py
tests/v1/attention/test_mla_backends.py
+1
-1
tests/v1/worker/test_gpu_model_runner.py
tests/v1/worker/test_gpu_model_runner.py
+3
-1
vllm/attention/backends/abstract.py
vllm/attention/backends/abstract.py
+7
-3
vllm/v1/attention/backends/flash_attn.py
vllm/v1/attention/backends/flash_attn.py
+21
-6
vllm/v1/attention/backends/flashinfer.py
vllm/v1/attention/backends/flashinfer.py
+6
-6
vllm/v1/attention/backends/mla/cutlass_mla.py
vllm/v1/attention/backends/mla/cutlass_mla.py
+4
-1
vllm/v1/attention/backends/mla/flashattn_mla.py
vllm/v1/attention/backends/mla/flashattn_mla.py
+4
-1
vllm/v1/attention/backends/mla/flashinfer_mla.py
vllm/v1/attention/backends/mla/flashinfer_mla.py
+4
-1
vllm/v1/attention/backends/mla/flashmla.py
vllm/v1/attention/backends/mla/flashmla.py
+4
-1
vllm/v1/attention/backends/mla/flashmla_sparse.py
vllm/v1/attention/backends/mla/flashmla_sparse.py
+4
-1
vllm/v1/attention/backends/mla/indexer.py
vllm/v1/attention/backends/mla/indexer.py
+3
-3
vllm/v1/attention/backends/mla/rocm_aiter_mla.py
vllm/v1/attention/backends/mla/rocm_aiter_mla.py
+3
-1
vllm/v1/attention/backends/rocm_aiter_fa.py
vllm/v1/attention/backends/rocm_aiter_fa.py
+4
-1
vllm/v1/attention/backends/tree_attn.py
vllm/v1/attention/backends/tree_attn.py
+4
-1
vllm/v1/attention/backends/triton_attn.py
vllm/v1/attention/backends/triton_attn.py
+4
-1
vllm/v1/attention/backends/xformers.py
vllm/v1/attention/backends/xformers.py
+4
-1
vllm/v1/worker/gpu_model_runner.py
vllm/v1/worker/gpu_model_runner.py
+2
-2
No files found.
tests/v1/attention/test_mla_backends.py
View file @
066209a0
...
...
@@ -61,7 +61,7 @@ for backend in BACKENDS_TO_TEST:
BACKEND_BLOCK_SIZES
=
{}
for
backend
in
BACKENDS_TO_TEST
:
supported_sizes
=
backend
.
get_class
().
supported_kernel_block_sizes
supported_sizes
=
backend
.
get_class
().
get_
supported_kernel_block_sizes
()
if
supported_sizes
:
default_size
=
supported_sizes
[
0
]
block_size
=
(
...
...
tests/v1/worker/test_gpu_model_runner.py
View file @
066209a0
...
...
@@ -185,7 +185,9 @@ def _make_mock_backend_for_kernel_block_size(
supported_sizes
:
list
[
int
|
MultipleOf
],
):
class
_MockBackend
:
supported_kernel_block_sizes
=
supported_sizes
@
staticmethod
def
get_supported_kernel_block_sizes
():
return
supported_sizes
return
_MockBackend
()
...
...
vllm/attention/backends/abstract.py
View file @
066209a0
...
...
@@ -46,9 +46,12 @@ class AttentionBackend(ABC):
# makes sure the output tensor is allocated inside the cudagraph.
accept_output_buffer
:
bool
=
False
supported_dtypes
:
ClassVar
[
list
[
torch
.
dtype
]]
=
[
torch
.
float16
,
torch
.
bfloat16
]
supported_kernel_block_sizes
:
ClassVar
[
list
[
int
|
MultipleOf
]]
=
[
MultipleOf
(
1
)]
supported_kv_cache_dtypes
:
ClassVar
[
list
[
"CacheDType"
]]
=
[
"auto"
]
@
staticmethod
def
get_supported_kernel_block_sizes
()
->
list
[
int
|
MultipleOf
]:
return
[
MultipleOf
(
1
)]
@
staticmethod
@
abstractmethod
def
get_name
()
->
str
:
...
...
@@ -142,10 +145,11 @@ class AttentionBackend(ABC):
if
block_size
not
in
valid_sizes
:
return
False
if
not
cls
.
supported_kernel_block_sizes
:
supported_kernel_block_sizes
=
cls
.
get_supported_kernel_block_sizes
()
if
not
supported_kernel_block_sizes
:
return
True
for
supported_size
in
cls
.
supported_kernel_block_sizes
:
for
supported_size
in
supported_kernel_block_sizes
:
if
isinstance
(
supported_size
,
MultipleOf
):
supported_size
=
supported_size
.
base
# With hybrid_blocks feature, the framework-level block size
...
...
vllm/v1/attention/backends/flash_attn.py
View file @
066209a0
...
...
@@ -32,7 +32,7 @@ if is_flash_attn_varlen_func_available():
get_scheduler_metadata
,
reshape_and_cache_flash
,
)
from
vllm.config
import
VllmConfig
,
get_layers_from_vllm_config
from
vllm.config
import
VllmConfig
,
get_current_vllm_config
,
get_layers_from_vllm_config
from
vllm.config.cache
import
CacheDType
from
vllm.distributed.parallel_state
import
get_dcp_group
from
vllm.logger
import
init_logger
...
...
@@ -56,11 +56,26 @@ logger = init_logger(__name__)
class
FlashAttentionBackend
(
AttentionBackend
):
accept_output_buffer
:
bool
=
True
supported_dtypes
:
ClassVar
[
list
[
torch
.
dtype
]]
=
[
torch
.
float16
,
torch
.
bfloat16
]
# NOTE(tdoublep): while in principle, FA supports
# MultipleOf(16), these are the block sizes that do not
# suffer from the NaN propagation problem described here:
# https://github.com/Dao-AILab/flash-attention/issues/1974
supported_kernel_block_sizes
:
ClassVar
[
list
[
int
|
MultipleOf
]]
=
[
16
,
32
,
64
]
@
staticmethod
def
get_supported_kernel_block_sizes
()
->
list
[
int
|
MultipleOf
]:
vllm_config
=
get_current_vllm_config
()
model_config
=
vllm_config
.
model_config
cache_config
=
vllm_config
.
cache_config
if
(
model_config
and
model_config
.
is_hybrid
and
(
cache_config
.
mamba_ssm_cache_dtype
==
"float32"
or
cache_config
.
mamba_cache_dtype
==
"float32"
)
):
# NOTE(tdoublep): while in principle, FA supports
# MultipleOf(16), these are the block sizes that do not
# suffer from the NaN propagation problem described here:
# https://github.com/Dao-AILab/flash-attention/issues/1974
return
[
16
,
32
,
64
]
return
[
MultipleOf
(
16
)]
@
staticmethod
def
get_name
()
->
str
:
...
...
vllm/v1/attention/backends/flashinfer.py
View file @
066209a0
...
...
@@ -16,7 +16,6 @@ from flashinfer import (
from
flashinfer.decode
import
_get_range_buf
,
trtllm_batch_decode_with_kv_cache
from
flashinfer.prefill
import
trtllm_batch_context_with_kv_cache
from
flashinfer.utils
import
FP4Tensor
from
typing_extensions
import
override
from
vllm
import
envs
from
vllm.attention.backends.abstract
import
(
...
...
@@ -275,10 +274,6 @@ class BatchDCPPrefillWrapper:
class
FlashInferBackend
(
AttentionBackend
):
accept_output_buffer
:
bool
=
True
supported_dtypes
:
ClassVar
[
list
[
torch
.
dtype
]]
=
[
torch
.
float16
,
torch
.
bfloat16
]
# Note: Not sure for all platforms,
# but on Blackwell, only support a page size of
# 16, 32, 64
supported_kernel_block_sizes
:
ClassVar
[
list
[
int
|
MultipleOf
]]
=
[
16
,
32
,
64
]
supported_kv_cache_dtypes
:
ClassVar
[
list
[
CacheDType
]]
=
[
"auto"
,
"fp8"
,
...
...
@@ -286,6 +281,12 @@ class FlashInferBackend(AttentionBackend):
"fp8_e5m2"
,
]
@
staticmethod
def
get_supported_kernel_block_sizes
()
->
list
[
int
|
MultipleOf
]:
# Note: Not sure for all platforms, but on Blackwell,
# only support a page size of 16, 32, 64.
return
[
16
,
32
,
64
]
@
staticmethod
def
get_name
()
->
str
:
return
"FLASHINFER"
...
...
@@ -566,7 +567,6 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
)
@
classmethod
@
override
def
get_cudagraph_support
(
cls
:
type
[
"FlashInferMetadataBuilder"
],
vllm_config
:
VllmConfig
,
...
...
vllm/v1/attention/backends/mla/cutlass_mla.py
View file @
066209a0
...
...
@@ -36,13 +36,16 @@ class CutlassMLAMetadataBuilder(MLACommonMetadataBuilder[MLACommonMetadata]):
class
CutlassMLABackend
(
MLACommonBackend
):
supported_dtypes
:
ClassVar
[
list
[
torch
.
dtype
]]
=
[
torch
.
float16
,
torch
.
bfloat16
]
supported_kernel_block_sizes
:
ClassVar
[
list
[
int
|
MultipleOf
]]
=
[
128
]
supported_kv_cache_dtypes
:
ClassVar
[
list
[
CacheDType
]]
=
[
"auto"
,
"fp8"
,
"fp8_e4m3"
,
]
@
staticmethod
def
get_supported_kernel_block_sizes
()
->
list
[
int
|
MultipleOf
]:
return
[
128
]
@
staticmethod
def
get_name
()
->
str
:
return
"CUTLASS_MLA"
...
...
vllm/v1/attention/backends/mla/flashattn_mla.py
View file @
066209a0
...
...
@@ -41,9 +41,12 @@ logger = init_logger(__name__)
class
FlashAttnMLABackend
(
MLACommonBackend
):
supported_dtypes
:
ClassVar
[
list
[
torch
.
dtype
]]
=
[
torch
.
float16
,
torch
.
bfloat16
]
supported_kernel_block_sizes
:
ClassVar
[
list
[
int
|
MultipleOf
]]
=
[
MultipleOf
(
16
)]
supported_kv_cache_dtypes
:
ClassVar
[
list
[
CacheDType
]]
=
[
"auto"
]
@
staticmethod
def
get_supported_kernel_block_sizes
()
->
list
[
int
|
MultipleOf
]:
return
[
MultipleOf
(
16
)]
@
staticmethod
def
get_name
()
->
str
:
return
"FLASH_ATTN_MLA"
...
...
vllm/v1/attention/backends/mla/flashinfer_mla.py
View file @
066209a0
...
...
@@ -35,13 +35,16 @@ class FlashInferMLAMetadataBuilder(MLACommonMetadataBuilder[MLACommonMetadata]):
class
FlashInferMLABackend
(
MLACommonBackend
):
supported_dtypes
:
ClassVar
[
list
[
torch
.
dtype
]]
=
[
torch
.
float16
,
torch
.
bfloat16
]
supported_kernel_block_sizes
:
ClassVar
[
list
[
int
|
MultipleOf
]]
=
[
32
,
64
]
supported_kv_cache_dtypes
:
ClassVar
[
list
[
CacheDType
]]
=
[
"auto"
,
"fp8"
,
"fp8_e4m3"
,
]
@
staticmethod
def
get_supported_kernel_block_sizes
()
->
list
[
int
|
MultipleOf
]:
return
[
32
,
64
]
@
staticmethod
def
get_name
()
->
str
:
return
"FLASHINFER_MLA"
...
...
vllm/v1/attention/backends/mla/flashmla.py
View file @
066209a0
...
...
@@ -39,13 +39,16 @@ logger = init_logger(__name__)
class
FlashMLABackend
(
MLACommonBackend
):
supported_dtypes
:
ClassVar
[
list
[
torch
.
dtype
]]
=
[
torch
.
float16
,
torch
.
bfloat16
]
supported_kernel_block_sizes
:
ClassVar
[
list
[
int
|
MultipleOf
]]
=
[
64
]
supported_kv_cache_dtypes
:
ClassVar
[
list
[
CacheDType
]]
=
[
"auto"
,
"fp8"
,
"fp8_e4m3"
,
]
@
staticmethod
def
get_supported_kernel_block_sizes
()
->
list
[
int
|
MultipleOf
]:
return
[
64
]
@
staticmethod
def
get_name
()
->
str
:
return
"FLASHMLA"
...
...
vllm/v1/attention/backends/mla/flashmla_sparse.py
View file @
066209a0
...
...
@@ -55,9 +55,12 @@ structured as:
class
FlashMLASparseBackend
(
AttentionBackend
):
accept_output_buffer
:
bool
=
True
supported_dtypes
:
ClassVar
[
list
[
torch
.
dtype
]]
=
[
torch
.
bfloat16
]
supported_kernel_block_sizes
:
ClassVar
[
list
[
int
|
MultipleOf
]]
=
[
64
]
supported_kv_cache_dtypes
:
ClassVar
[
list
[
CacheDType
]]
=
[
"auto"
,
"fp8_ds_mla"
]
@
staticmethod
def
get_supported_kernel_block_sizes
()
->
list
[
int
|
MultipleOf
]:
return
[
64
]
@
staticmethod
def
get_name
()
->
str
:
return
"FLASHMLA_SPARSE"
...
...
vllm/v1/attention/backends/mla/indexer.py
View file @
066209a0
...
...
@@ -24,9 +24,9 @@ logger = init_logger(__name__)
class
DeepseekV32IndexerBackend
(
AttentionBackend
):
supported_kernel_block_sizes
:
ClassVar
[
list
[
int
|
MultipleOf
]]
=
[
1
if
current_platform
.
is_rocm
()
else
64
]
@
staticmethod
def
get_supported_kernel_block_sizes
()
->
list
[
int
|
MultipleOf
]:
return
[
1
if
current_platform
.
is_rocm
()
else
64
]
@
classmethod
def
get_supported_head_sizes
(
cls
)
->
list
[
int
]:
...
...
vllm/v1/attention/backends/mla/rocm_aiter_mla.py
View file @
066209a0
...
...
@@ -21,7 +21,9 @@ from vllm.v1.kv_cache_interface import AttentionSpec
class
AiterMLABackend
(
MLACommonBackend
):
supported_kernel_block_sizes
:
ClassVar
[
list
[
int
|
MultipleOf
]]
=
[
1
]
@
staticmethod
def
get_supported_kernel_block_sizes
()
->
list
[
int
|
MultipleOf
]:
return
[
1
]
@
staticmethod
def
get_name
()
->
str
:
...
...
vllm/v1/attention/backends/rocm_aiter_fa.py
View file @
066209a0
...
...
@@ -447,7 +447,10 @@ class AiterFlashAttentionMetadataBuilder(
class
AiterFlashAttentionBackend
(
AttentionBackend
):
accept_output_buffer
:
bool
=
True
supported_dtypes
:
ClassVar
[
list
[
torch
.
dtype
]]
=
[
torch
.
float16
,
torch
.
bfloat16
]
supported_kernel_block_sizes
:
ClassVar
[
list
[
int
|
MultipleOf
]]
=
[
MultipleOf
(
16
)]
@
staticmethod
def
get_supported_kernel_block_sizes
()
->
list
[
int
|
MultipleOf
]:
return
[
MultipleOf
(
16
)]
@
classmethod
def
get_supported_head_sizes
(
cls
)
->
list
[
int
]:
...
...
vllm/v1/attention/backends/tree_attn.py
View file @
066209a0
...
...
@@ -31,7 +31,10 @@ logger = init_logger(__name__)
class
TreeAttentionBackend
(
AttentionBackend
):
accept_output_buffer
:
bool
=
True
supported_dtypes
:
ClassVar
[
list
[
torch
.
dtype
]]
=
[
torch
.
float16
,
torch
.
bfloat16
]
supported_kernel_block_sizes
:
ClassVar
[
list
[
int
|
MultipleOf
]]
=
[
MultipleOf
(
16
)]
@
staticmethod
def
get_supported_kernel_block_sizes
()
->
list
[
int
|
MultipleOf
]:
return
[
MultipleOf
(
16
)]
@
classmethod
def
get_supported_head_sizes
(
cls
)
->
list
[
int
]:
...
...
vllm/v1/attention/backends/triton_attn.py
View file @
066209a0
...
...
@@ -154,7 +154,6 @@ class TritonAttentionBackend(AttentionBackend):
torch
.
bfloat16
,
torch
.
float32
,
]
supported_kernel_block_sizes
:
ClassVar
[
list
[
int
|
MultipleOf
]]
=
[
MultipleOf
(
16
)]
supported_kv_cache_dtypes
:
ClassVar
[
list
[
CacheDType
]]
=
[
"auto"
,
"fp8"
,
...
...
@@ -162,6 +161,10 @@ class TritonAttentionBackend(AttentionBackend):
"fp8_e5m2"
,
]
@
staticmethod
def
get_supported_kernel_block_sizes
()
->
list
[
int
|
MultipleOf
]:
return
[
MultipleOf
(
16
)]
@
staticmethod
def
get_name
()
->
str
:
return
"TRITON_ATTN"
...
...
vllm/v1/attention/backends/xformers.py
View file @
066209a0
...
...
@@ -42,7 +42,10 @@ logger = init_logger(__name__)
class
XFormersAttentionBackend
(
AttentionBackend
):
accept_output_buffer
:
bool
=
True
supported_dtypes
:
ClassVar
[
list
[
torch
.
dtype
]]
=
[
torch
.
float16
,
torch
.
bfloat16
]
supported_kernel_block_sizes
:
ClassVar
[
list
[
int
|
MultipleOf
]]
=
[
MultipleOf
(
16
)]
@
staticmethod
def
get_supported_kernel_block_sizes
()
->
list
[
int
|
MultipleOf
]:
return
[
MultipleOf
(
16
)]
@
classmethod
def
get_supported_head_sizes
(
cls
)
->
list
[
int
]:
...
...
vllm/v1/worker/gpu_model_runner.py
View file @
066209a0
...
...
@@ -4618,7 +4618,7 @@ class GPUModelRunner(
"""
for
backend
in
backends
:
is_supported
=
False
for
supported_size
in
backend
.
supported_kernel_block_sizes
:
for
supported_size
in
backend
.
get_
supported_kernel_block_sizes
()
:
if
isinstance
(
supported_size
,
int
):
if
block_size
==
supported_size
:
is_supported
=
True
...
...
@@ -4649,7 +4649,7 @@ class GPUModelRunner(
all_int_supported_sizes
=
set
(
supported_size
for
backend
in
backends
for
supported_size
in
backend
.
supported_kernel_block_sizes
for
supported_size
in
backend
.
get_
supported_kernel_block_sizes
()
if
isinstance
(
supported_size
,
int
)
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment