Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
f444c05c
Unverified
Commit
f444c05c
authored
Mar 12, 2026
by
Matthew Bonanni
Committed by
GitHub
Mar 12, 2026
Browse files
[Attention] Use FA4 for MLA prefill (#34732)
Signed-off-by:
Matthew Bonanni
<
mbonanni@redhat.com
>
parent
85199f96
Changes
9
Hide whitespace changes
Inline
Side-by-side
Showing
9 changed files
with
413 additions
and
78 deletions
+413
-78
benchmarks/attention_benchmarks/benchmark.py
benchmarks/attention_benchmarks/benchmark.py
+103
-29
benchmarks/attention_benchmarks/common.py
benchmarks/attention_benchmarks/common.py
+2
-0
benchmarks/attention_benchmarks/configs/mla_prefill.yaml
benchmarks/attention_benchmarks/configs/mla_prefill.yaml
+89
-25
benchmarks/attention_benchmarks/configs/mla_sparse_prefill.yaml
...arks/attention_benchmarks/configs/mla_sparse_prefill.yaml
+62
-0
benchmarks/attention_benchmarks/mla_runner.py
benchmarks/attention_benchmarks/mla_runner.py
+143
-14
cmake/external_projects/vllm_flash_attn.cmake
cmake/external_projects/vllm_flash_attn.cmake
+1
-1
vllm/config/attention.py
vllm/config/attention.py
+2
-2
vllm/model_executor/layers/attention/mla_attention.py
vllm/model_executor/layers/attention/mla_attention.py
+8
-7
vllm/v1/attention/backends/fa_utils.py
vllm/v1/attention/backends/fa_utils.py
+3
-0
No files found.
benchmarks/attention_benchmarks/benchmark.py
View file @
f444c05c
...
...
@@ -59,7 +59,9 @@ def run_mla_benchmark(config: BenchmarkConfig, **kwargs) -> BenchmarkResult:
"""Run MLA benchmark with appropriate backend."""
from
mla_runner
import
run_mla_benchmark
as
run_mla
return
run_mla
(
config
.
backend
,
config
,
**
kwargs
)
return
run_mla
(
config
.
backend
,
config
,
prefill_backend
=
config
.
prefill_backend
,
**
kwargs
)
def
run_benchmark
(
config
:
BenchmarkConfig
,
**
kwargs
)
->
BenchmarkResult
:
...
...
@@ -440,14 +442,21 @@ def main():
# Backend selection
parser
.
add_argument
(
"--backends"
,
"--decode-backends"
,
nargs
=
"+"
,
help
=
"
B
ackends to benchmark (flash, triton, flashinfer, cutlass_mla, "
help
=
"
Decode b
ackends to benchmark (flash, triton, flashinfer, cutlass_mla, "
"flashinfer_mla, flashattn_mla, flashmla)"
,
)
parser
.
add_argument
(
"--backend"
,
help
=
"Single backend (alternative to --backends)"
,
)
parser
.
add_argument
(
"--prefill-backends"
,
nargs
=
"+"
,
help
=
"Prefill backends to compare (fa2, fa3, fa4). "
"Uses the first decode backend for impl construction."
,
)
# Batch specifications
parser
.
add_argument
(
...
...
@@ -502,7 +511,7 @@ def main():
# Override args with YAML values, but CLI args take precedence
# Check if CLI provided backends (they would be non-None and not default)
cli_backends_provided
=
args
.
backend
s
is
not
None
or
args
.
backend
is
not
None
cli_backends_provided
=
args
.
backend
is
not
None
or
args
.
backend
s
is
not
None
# Backend(s) - only use YAML if CLI didn't specify
if
not
cli_backends_provided
:
...
...
@@ -512,6 +521,12 @@ def main():
elif
"backends"
in
yaml_config
:
args
.
backends
=
yaml_config
[
"backends"
]
args
.
backend
=
None
elif
"decode_backends"
in
yaml_config
:
args
.
backends
=
yaml_config
[
"decode_backends"
]
args
.
backend
=
None
# Prefill backends (e.g., ["fa3", "fa4"])
args
.
prefill_backends
=
yaml_config
.
get
(
"prefill_backends"
,
None
)
# Check for special modes
if
"mode"
in
yaml_config
:
...
...
@@ -613,7 +628,10 @@ def main():
# Determine backends
backends
=
args
.
backends
or
([
args
.
backend
]
if
args
.
backend
else
[
"flash"
])
prefill_backends
=
getattr
(
args
,
"prefill_backends"
,
None
)
console
.
print
(
f
"Backends:
{
', '
.
join
(
backends
)
}
"
)
if
prefill_backends
:
console
.
print
(
f
"Prefill backends:
{
', '
.
join
(
prefill_backends
)
}
"
)
console
.
print
(
f
"Batch specs:
{
', '
.
join
(
args
.
batch_specs
)
}
"
)
console
.
print
()
...
...
@@ -850,37 +868,93 @@ def main():
else
:
# Normal mode: compare backends
total
=
len
(
backends
)
*
len
(
args
.
batch_specs
)
decode_results
=
[]
prefill_results
=
[]
with
tqdm
(
total
=
total
,
desc
=
"Benchmarking"
)
as
pbar
:
for
spec
in
args
.
batch_specs
:
for
backend
in
backends
:
config
=
BenchmarkConfig
(
backend
=
backend
,
batch_spec
=
spec
,
num_layers
=
args
.
num_layers
,
head_dim
=
args
.
head_dim
,
num_q_heads
=
args
.
num_q_heads
,
num_kv_heads
=
args
.
num_kv_heads
,
block_size
=
args
.
block_size
,
device
=
args
.
device
,
repeats
=
args
.
repeats
,
warmup_iters
=
args
.
warmup_iters
,
profile_memory
=
args
.
profile_memory
,
)
# Run decode backend comparison
if
not
prefill_backends
:
# No prefill backends specified: compare decode backends as before
total
=
len
(
backends
)
*
len
(
args
.
batch_specs
)
result
=
run_benchmark
(
config
)
all_results
.
append
(
result
)
with
tqdm
(
total
=
total
,
desc
=
"Benchmarking"
)
as
pbar
:
for
spec
in
args
.
batch_specs
:
for
backend
in
backends
:
config
=
BenchmarkConfig
(
backend
=
backend
,
batch_spec
=
spec
,
num_layers
=
args
.
num_layers
,
head_dim
=
args
.
head_dim
,
num_q_heads
=
args
.
num_q_heads
,
num_kv_heads
=
args
.
num_kv_heads
,
block_size
=
args
.
block_size
,
device
=
args
.
device
,
repeats
=
args
.
repeats
,
warmup_iters
=
args
.
warmup_iters
,
profile_memory
=
args
.
profile_memory
,
)
if
not
result
.
success
:
console
.
print
(
f
"[red]Error
{
backend
}
{
spec
}
:
{
result
.
error
}
[/]"
)
result
=
run_benchmark
(
config
)
decode_results
.
append
(
result
)
pbar
.
update
(
1
)
if
not
result
.
success
:
console
.
print
(
f
"[red]Error
{
backend
}
{
spec
}
:
{
result
.
error
}
[/]"
)
# Display results
console
.
print
(
"
\n
[bold green]Results:[/]"
)
formatter
=
ResultsFormatter
(
console
)
formatter
.
print_table
(
all_results
,
backends
)
pbar
.
update
(
1
)
console
.
print
(
"
\n
[bold green]Results:[/]"
)
formatter
=
ResultsFormatter
(
console
)
formatter
.
print_table
(
decode_results
,
backends
)
# Run prefill backend comparison
if
prefill_backends
:
# Use first decode backend for impl construction
decode_backend
=
backends
[
0
]
total
=
len
(
prefill_backends
)
*
len
(
args
.
batch_specs
)
console
.
print
(
f
"[yellow]Prefill comparison mode: "
f
"using
{
decode_backend
}
for decode impl[/]"
)
with
tqdm
(
total
=
total
,
desc
=
"Prefill benchmarking"
)
as
pbar
:
for
spec
in
args
.
batch_specs
:
for
pb
in
prefill_backends
:
config
=
BenchmarkConfig
(
backend
=
decode_backend
,
batch_spec
=
spec
,
num_layers
=
args
.
num_layers
,
head_dim
=
args
.
head_dim
,
num_q_heads
=
args
.
num_q_heads
,
num_kv_heads
=
args
.
num_kv_heads
,
block_size
=
args
.
block_size
,
device
=
args
.
device
,
repeats
=
args
.
repeats
,
warmup_iters
=
args
.
warmup_iters
,
profile_memory
=
args
.
profile_memory
,
prefill_backend
=
pb
,
)
result
=
run_benchmark
(
config
)
# Label result with prefill backend name for display
labeled_config
=
replace
(
result
.
config
,
backend
=
pb
)
result
=
replace
(
result
,
config
=
labeled_config
)
prefill_results
.
append
(
result
)
if
not
result
.
success
:
console
.
print
(
f
"[red]Error
{
pb
}
{
spec
}
:
{
result
.
error
}
[/]"
)
pbar
.
update
(
1
)
console
.
print
(
"
\n
[bold green]Prefill Backend Results:[/]"
)
formatter
=
ResultsFormatter
(
console
)
formatter
.
print_table
(
prefill_results
,
prefill_backends
,
compare_to_fastest
=
True
)
all_results
=
decode_results
+
prefill_results
# Save results
if
all_results
:
...
...
benchmarks/attention_benchmarks/common.py
View file @
f444c05c
...
...
@@ -77,6 +77,7 @@ class MockKVBProj:
self
.
qk_nope_head_dim
=
qk_nope_head_dim
self
.
v_head_dim
=
v_head_dim
self
.
out_dim
=
qk_nope_head_dim
+
v_head_dim
self
.
weight
=
torch
.
empty
(
0
,
dtype
=
torch
.
bfloat16
)
def
__call__
(
self
,
x
:
torch
.
Tensor
)
->
tuple
[
torch
.
Tensor
]:
"""
...
...
@@ -213,6 +214,7 @@ class BenchmarkConfig:
use_cuda_graphs
:
bool
=
False
# MLA-specific
prefill_backend
:
str
|
None
=
None
kv_lora_rank
:
int
|
None
=
None
qk_nope_head_dim
:
int
|
None
=
None
qk_rope_head_dim
:
int
|
None
=
None
...
...
benchmarks/attention_benchmarks/configs/mla_prefill.yaml
View file @
f444c05c
# MLA prefill-only benchmark configuration for sparse backends
# MLA prefill backend comparison
#
# Compares all available MLA prefill backends:
# FA backends: fa2, fa3, fa4 (FlashAttention versions)
# Non-FA: flashinfer, cudnn, trtllm (Blackwell-only, require flashinfer)
#
# Uses cutlass_mla as the decode backend for impl construction
# (only the prefill path is exercised).
#
# Backends that aren't available on the current platform will report errors
# in the results table (e.g., fa3 on Blackwell, cudnn without artifactory).
#
# Usage:
# python benchmark.py --config configs/mla_prefill.yaml
description
:
"
MLA
prefill
backend
comparison"
model
:
name
:
"
deepseek-v3"
...
...
@@ -12,20 +27,25 @@ model:
v_head_dim
:
128
block_size
:
128
# Model parameter sweep: simulate tensor parallelism by varying num_q_heads
# TP=1: 128 heads, TP=2: 64 heads, TP=4: 32 heads, TP=8: 16 heads
model_parameter_sweep
:
param_name
:
"
num_q_heads"
values
:
[
128
,
64
,
32
,
16
]
label_format
:
"
{backend}_{value}h"
# model:
# name: "deepseek-v2-lite"
# num_layers: 27
# num_q_heads: 16
# num_kv_heads: 1
# head_dim: 576
# kv_lora_rank: 512
# qk_nope_head_dim: 128
# qk_rope_head_dim: 64
# v_head_dim: 128
# block_size: 128
batch_specs
:
# Pure prefill
-
"
1
q512"
-
"
1
q1k"
-
"
1
q2k"
-
"
1
q4k"
-
"
1
q8k"
-
"
q512"
-
"
q1k"
-
"
q2k"
-
"
q4k"
-
"
q8k"
# Batched pure prefill
-
"
2q512"
...
...
@@ -44,19 +64,63 @@ batch_specs:
-
"
8q4k"
-
"
8q8k"
# Extend
-
"
1q512s4k"
-
"
1q512s8k"
-
"
1q1ks8k"
-
"
1q2ks8k"
-
"
1q2ks16k"
-
"
1q4ks16k"
# Chunked prefill / extend
# Short context
-
"
q128s1k"
-
"
q256s2k"
-
"
q512s4k"
-
"
q1ks4k"
-
"
q2ks8k"
-
"
2q128s1k"
-
"
2q256s2k"
-
"
2q512s4k"
-
"
2q1ks4k"
-
"
2q2ks8k"
-
"
4q128s1k"
-
"
4q256s2k"
-
"
4q512s4k"
-
"
4q1ks4k"
-
"
4q2ks8k"
-
"
8q128s1k"
-
"
8q256s2k"
-
"
8q512s4k"
-
"
8q1ks4k"
# Medium context
-
"
q128s16k"
-
"
q512s16k"
-
"
q1ks16k"
-
"
q2ks16k"
-
"
2q128s16k"
-
"
2q512s16k"
-
"
2q1ks16k"
-
"
2q2ks16k"
-
"
4q128s16k"
-
"
4q512s16k"
-
"
4q1ks16k"
-
"
4q2ks16k"
# Long context
-
"
q128s64k"
-
"
q512s64k"
-
"
q1ks64k"
-
"
q2ks64k"
-
"
2q128s64k"
-
"
2q512s64k"
-
"
2q1ks64k"
-
"
2q2ks64k"
decode_backends
:
-
CUTLASS_MLA
backends
:
-
FLASHMLA_SPARSE
-
FLASHINFER_MLA_SPARSE
prefill_backends
:
-
fa2
-
fa3
-
fa4
-
flashinfer
-
cudnn
-
trtllm
device
:
"
cuda:0"
repeats
:
10
warmup_iters
:
3
profile_memory
:
true
repeats
:
20
warmup_iters
:
5
benchmarks/attention_benchmarks/configs/mla_sparse_prefill.yaml
0 → 100644
View file @
f444c05c
# MLA prefill-only benchmark configuration for sparse backends
model
:
name
:
"
deepseek-v3"
num_layers
:
60
num_q_heads
:
128
num_kv_heads
:
1
head_dim
:
576
kv_lora_rank
:
512
qk_nope_head_dim
:
128
qk_rope_head_dim
:
64
v_head_dim
:
128
block_size
:
128
# Model parameter sweep: simulate tensor parallelism by varying num_q_heads
# TP=1: 128 heads, TP=2: 64 heads, TP=4: 32 heads, TP=8: 16 heads
model_parameter_sweep
:
param_name
:
"
num_q_heads"
values
:
[
128
,
64
,
32
,
16
]
label_format
:
"
{backend}_{value}h"
batch_specs
:
# Pure prefill
-
"
1q512"
-
"
1q1k"
-
"
1q2k"
-
"
1q4k"
-
"
1q8k"
# Batched pure prefill
-
"
2q512"
-
"
2q1k"
-
"
2q2k"
-
"
2q4k"
-
"
2q8k"
-
"
4q512"
-
"
4q1k"
-
"
4q2k"
-
"
4q4k"
-
"
4q8k"
-
"
8q512"
-
"
8q1k"
-
"
8q2k"
-
"
8q4k"
-
"
8q8k"
# Extend
-
"
1q512s4k"
-
"
1q512s8k"
-
"
1q1ks8k"
-
"
1q2ks8k"
-
"
1q2ks16k"
-
"
1q4ks16k"
backends
:
-
FLASHMLA_SPARSE
-
FLASHINFER_MLA_SPARSE
device
:
"
cuda:0"
repeats
:
10
warmup_iters
:
3
profile_memory
:
true
benchmarks/attention_benchmarks/mla_runner.py
View file @
f444c05c
...
...
@@ -62,6 +62,7 @@ def create_minimal_vllm_config(
max_num_seqs
:
int
=
256
,
mla_dims
:
dict
|
None
=
None
,
index_topk
:
int
|
None
=
None
,
prefill_backend
:
str
|
None
=
None
,
)
->
VllmConfig
:
"""
Create minimal VllmConfig for MLA benchmarks.
...
...
@@ -75,6 +76,9 @@ def create_minimal_vllm_config(
setup_mla_dims(model_name)
index_topk: Optional topk value for sparse MLA backends. If provided,
the config will include index_topk for sparse attention.
prefill_backend: Prefill backend name (e.g., "fa3", "fa4", "flashinfer",
"cudnn", "trtllm"). Configures the attention config to
force the specified prefill backend.
Returns:
VllmConfig for benchmarking
...
...
@@ -163,7 +167,7 @@ def create_minimal_vllm_config(
compilation_config
=
CompilationConfig
()
return
VllmConfig
(
vllm_config
=
VllmConfig
(
model_config
=
model_config
,
cache_config
=
cache_config
,
parallel_config
=
parallel_config
,
...
...
@@ -171,9 +175,84 @@ def create_minimal_vllm_config(
compilation_config
=
compilation_config
,
)
if
prefill_backend
is
not
None
:
prefill_cfg
=
get_prefill_backend_config
(
prefill_backend
)
if
prefill_cfg
[
"flash_attn_version"
]
is
not
None
:
vllm_config
.
attention_config
.
flash_attn_version
=
prefill_cfg
[
"flash_attn_version"
]
vllm_config
.
attention_config
.
disable_flashinfer_prefill
=
prefill_cfg
[
"disable_flashinfer_prefill"
]
vllm_config
.
attention_config
.
use_cudnn_prefill
=
prefill_cfg
[
"use_cudnn_prefill"
]
vllm_config
.
attention_config
.
use_trtllm_ragged_deepseek_prefill
=
prefill_cfg
[
"use_trtllm_ragged_deepseek_prefill"
]
return
vllm_config
# ============================================================================
# Backend Configuration
# Prefill Backend Configuration
# ============================================================================
# Maps prefill backend names to attention config overrides.
# FA backends set flash_attn_version and disable non-FA paths.
# Non-FA backends enable their specific path and disable others.
_PREFILL_BACKEND_CONFIG
:
dict
[
str
,
dict
]
=
{
"fa2"
:
{
"flash_attn_version"
:
2
,
"disable_flashinfer_prefill"
:
True
,
"use_cudnn_prefill"
:
False
,
"use_trtllm_ragged_deepseek_prefill"
:
False
,
},
"fa3"
:
{
"flash_attn_version"
:
3
,
"disable_flashinfer_prefill"
:
True
,
"use_cudnn_prefill"
:
False
,
"use_trtllm_ragged_deepseek_prefill"
:
False
,
},
"fa4"
:
{
"flash_attn_version"
:
4
,
"disable_flashinfer_prefill"
:
True
,
"use_cudnn_prefill"
:
False
,
"use_trtllm_ragged_deepseek_prefill"
:
False
,
},
"flashinfer"
:
{
"flash_attn_version"
:
None
,
"disable_flashinfer_prefill"
:
False
,
"use_cudnn_prefill"
:
False
,
"use_trtllm_ragged_deepseek_prefill"
:
False
,
},
"cudnn"
:
{
"flash_attn_version"
:
None
,
"disable_flashinfer_prefill"
:
True
,
"use_cudnn_prefill"
:
True
,
"use_trtllm_ragged_deepseek_prefill"
:
False
,
},
"trtllm"
:
{
"flash_attn_version"
:
None
,
"disable_flashinfer_prefill"
:
True
,
"use_cudnn_prefill"
:
False
,
"use_trtllm_ragged_deepseek_prefill"
:
True
,
},
}
def
get_prefill_backend_config
(
prefill_backend
:
str
)
->
dict
:
"""Get attention config overrides for a prefill backend."""
if
prefill_backend
not
in
_PREFILL_BACKEND_CONFIG
:
raise
ValueError
(
f
"Unknown prefill backend:
{
prefill_backend
!
r
}
. "
f
"Available:
{
list
(
_PREFILL_BACKEND_CONFIG
.
keys
())
}
"
)
return
_PREFILL_BACKEND_CONFIG
[
prefill_backend
]
# ============================================================================
# Decode Backend Configuration
# ============================================================================
...
...
@@ -203,6 +282,7 @@ def _get_backend_config(backend: str) -> dict:
Returns:
Dict with backend configuration
"""
from
vllm.v1.attention.backend
import
MultipleOf
from
vllm.v1.attention.backends.registry
import
AttentionBackendEnum
try
:
...
...
@@ -219,8 +299,8 @@ def _get_backend_config(backend: str) -> dict:
block_sizes
=
backend_class
.
get_supported_kernel_block_sizes
()
# Use first supported block size (backends typically support one for MLA)
block_size
=
block_sizes
[
0
]
if
block_sizes
else
None
if
hasattr
(
block_size
,
"value"
):
#
Handle MultipleOf enum
if
isinstance
(
block_size
,
MultipleOf
):
#
No fixed block size; fall back to config value
block_size
=
None
# Check if sparse via class method if available
...
...
@@ -676,16 +756,11 @@ def _run_single_benchmark(
if
is_sparse
and
indexer
is
not
None
:
indexer
.
fill_random_indices
(
total_q
,
max_kv_len
)
# Determine which forward method to use
if
is_sparse
:
# Sparse backends use forward_mqa
# Determine which forward method to use based on metadata
if
metadata
.
decode
is
not
None
:
forward_fn
=
lambda
:
impl
.
forward_mqa
(
decode_inputs
,
kv_cache
,
metadata
,
layer
)
elif
metadata
.
decode
is
not
None
:
forward_fn
=
lambda
:
impl
.
_forward_decode
(
decode_inputs
,
kv_cache
,
metadata
,
layer
)
elif
metadata
.
prefill
is
not
None
:
forward_fn
=
lambda
:
impl
.
_
forward_
prefill
(
forward_fn
=
lambda
:
impl
.
forward_
mha
(
prefill_inputs
[
"q"
],
prefill_inputs
[
"k_c_normed"
],
prefill_inputs
[
"k_pe"
],
...
...
@@ -732,6 +807,7 @@ def _run_mla_benchmark_batched(
backend
:
str
,
configs_with_params
:
list
[
tuple
],
# [(config, threshold, num_splits), ...]
index_topk
:
int
=
2048
,
prefill_backend
:
str
|
None
=
None
,
)
->
list
[
BenchmarkResult
]:
"""
Unified batched MLA benchmark runner for all backends.
...
...
@@ -743,11 +819,13 @@ def _run_mla_benchmark_batched(
to avoid setup/teardown overhead.
Args:
backend: Backend name
backend: Backend name
(decode backend used for impl construction)
configs_with_params: List of (config, threshold, num_splits) tuples
- threshold: reorder_batch_threshold (FlashAttn/FlashMLA only)
- num_splits: num_kv_splits (CUTLASS only)
index_topk: Topk value for sparse MLA backends (default 2048)
prefill_backend: Prefill backend name (e.g., "fa3", "fa4").
When set, forces the specified FlashAttention version for prefill.
Returns:
List of BenchmarkResult objects
...
...
@@ -780,11 +858,25 @@ def _run_mla_benchmark_batched(
block_size
=
block_size
,
mla_dims
=
mla_dims
,
# Use custom dims from config or default
index_topk
=
index_topk
if
is_sparse
else
None
,
prefill_backend
=
prefill_backend
,
)
results
=
[]
with
set_current_vllm_config
(
vllm_config
):
# Clear cached prefill backend detection functions so they re-evaluate
# with the current VllmConfig. These are @functools.cache decorated and
# would otherwise return stale results from a previous backend's config.
from
vllm.model_executor.layers.attention.mla_attention
import
(
use_cudnn_prefill
,
use_flashinfer_prefill
,
use_trtllm_ragged_deepseek_prefill
,
)
use_flashinfer_prefill
.
cache_clear
()
use_cudnn_prefill
.
cache_clear
()
use_trtllm_ragged_deepseek_prefill
.
cache_clear
()
# Create backend impl, layer, builder, and indexer (reused across benchmarks)
impl
,
layer
,
builder_instance
,
indexer
=
_create_backend_impl
(
backend_cfg
,
...
...
@@ -794,6 +886,38 @@ def _run_mla_benchmark_batched(
index_topk
=
index_topk
if
is_sparse
else
None
,
)
# Verify the actual prefill backend matches what was requested
if
prefill_backend
is
not
None
:
prefill_cfg
=
get_prefill_backend_config
(
prefill_backend
)
fa_version
=
prefill_cfg
[
"flash_attn_version"
]
if
fa_version
is
not
None
:
# FA backend: verify the impl's FA version
actual_fa_version
=
getattr
(
impl
,
"vllm_flash_attn_version"
,
None
)
if
actual_fa_version
!=
fa_version
:
raise
RuntimeError
(
f
"Prefill backend '
{
prefill_backend
}
' requested FA "
f
"version
{
fa_version
}
, but the impl is using FA "
f
"version
{
actual_fa_version
}
. Check "
f
"vllm/v1/attention/backends/fa_utils.py."
)
else
:
# Non-FA backend: verify the builder picked the right path
expected_flags
=
{
"flashinfer"
:
"_use_fi_prefill"
,
"cudnn"
:
"_use_cudnn_prefill"
,
"trtllm"
:
"_use_trtllm_ragged_prefill"
,
}
flag_name
=
expected_flags
.
get
(
prefill_backend
)
if
flag_name
and
not
getattr
(
builder_instance
,
flag_name
,
False
):
raise
RuntimeError
(
f
"Prefill backend '
{
prefill_backend
}
' was requested "
f
"but the metadata builder did not enable it. This "
f
"usually means a dependency is missing (e.g., "
f
"flashinfer not installed) or the platform doesn't "
f
"support it."
)
# Run each benchmark with the shared impl
for
config
,
threshold
,
num_splits
in
configs_with_params
:
# Set threshold for this benchmark (FlashAttn/FlashMLA only)
...
...
@@ -844,6 +968,7 @@ def run_mla_benchmark(
reorder_batch_threshold
:
int
|
None
=
None
,
num_kv_splits
:
int
|
None
=
None
,
index_topk
:
int
=
2048
,
prefill_backend
:
str
|
None
=
None
,
)
->
BenchmarkResult
|
list
[
BenchmarkResult
]:
"""
Unified MLA benchmark runner for all backends.
...
...
@@ -861,6 +986,8 @@ def run_mla_benchmark(
(single config mode only)
num_kv_splits: Number of KV splits for CUTLASS (single config mode only)
index_topk: Topk value for sparse MLA backends (default 2048)
prefill_backend: Prefill backend name (e.g., "fa3", "fa4").
When set, forces the specified FlashAttention version for prefill.
Returns:
BenchmarkResult (single mode) or list of BenchmarkResult (batched mode)
...
...
@@ -884,7 +1011,9 @@ def run_mla_benchmark(
return_single
=
True
# Use unified batched execution
results
=
_run_mla_benchmark_batched
(
backend
,
configs_with_params
,
index_topk
)
results
=
_run_mla_benchmark_batched
(
backend
,
configs_with_params
,
index_topk
,
prefill_backend
=
prefill_backend
)
# Return single result or list based on input
return
results
[
0
]
if
return_single
else
results
cmake/external_projects/vllm_flash_attn.cmake
View file @
f444c05c
...
...
@@ -39,7 +39,7 @@ else()
FetchContent_Declare
(
vllm-flash-attn
GIT_REPOSITORY https://github.com/vllm-project/flash-attention.git
GIT_TAG 14
0c00c0241bb60cc6e44e7c1be9998d4b20d8d2
GIT_TAG 14
88682bb545f7d020e958a33116b1419d1cfc83
GIT_PROGRESS TRUE
# Don't share the vllm-flash-attn build between build types
BINARY_DIR
${
CMAKE_BINARY_DIR
}
/vllm-flash-attn
...
...
vllm/config/attention.py
View file @
f444c05c
...
...
@@ -30,14 +30,14 @@ class AttentionConfig:
use_cudnn_prefill
:
bool
=
False
"""Whether to use cudnn prefill."""
use_trtllm_ragged_deepseek_prefill
:
bool
=
Tru
e
use_trtllm_ragged_deepseek_prefill
:
bool
=
Fals
e
"""Whether to use TRTLLM ragged deepseek prefill."""
use_trtllm_attention
:
bool
|
None
=
None
"""If set to True/False, use or don't use the TRTLLM attention backend
in flashinfer. If None, auto-detect the attention backend in flashinfer."""
disable_flashinfer_prefill
:
bool
=
Fals
e
disable_flashinfer_prefill
:
bool
=
Tru
e
"""Whether to disable flashinfer prefill."""
disable_flashinfer_q_quantization
:
bool
=
False
...
...
vllm/model_executor/layers/attention/mla_attention.py
View file @
f444c05c
...
...
@@ -1282,8 +1282,6 @@ def is_deepseek_r1_mla_compatible(vllm_config: VllmConfig) -> bool:
@
functools
.
cache
def
use_flashinfer_prefill
()
->
bool
:
# For blackwell default to flashinfer prefill if it's available since
# it is faster than FA2.
from
vllm.config
import
get_current_vllm_config
vllm_config
=
get_current_vllm_config
()
...
...
@@ -2154,13 +2152,16 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
# For MLA the v head dim is smaller than qk head dim so we pad out
# v with 0s to match the qk head dim for attention backends that do
# not support different headdims
#
We don't need to pad V if we are on a hopper system with FA3
# not support different headdims
.
#
FA3 on Hopper (SM90) and FA4 natively handle diff headdims.
device_capability
=
current_platform
.
get_device_capability
()
self
.
_pad_v
=
self
.
vllm_flash_attn_version
is
None
or
not
(
self
.
vllm_flash_attn_version
==
3
and
device_capability
is
not
None
and
device_capability
[
0
]
==
9
(
self
.
vllm_flash_attn_version
==
3
and
device_capability
is
not
None
and
device_capability
[
0
]
==
9
)
or
self
.
vllm_flash_attn_version
==
4
)
self
.
dcp_world_size
:
int
=
-
1
...
...
vllm/v1/attention/backends/fa_utils.py
View file @
f444c05c
...
...
@@ -125,11 +125,14 @@ def get_flash_attn_version(
# FA4 on SM100 (Blackwell) has TMEM capacity limits that restrict
# supported head dimensions.
# See: https://github.com/Dao-AILab/flash-attention/issues/1959
# Exception: hdim 192 is supported for MLA's diff-headdim case
# (qk=192, v=128), added upstream in commits 1a15733e/1b36ab19.
if
(
fa_version
==
4
and
device_capability
.
major
>=
10
and
head_size
is
not
None
and
head_size
>
128
and
head_size
!=
192
):
logger
.
warning_once
(
"FA4 on Blackwell does not support head_size=%d due to TMEM "
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment