Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
82af928c
Unverified
Commit
82af928c
authored
Oct 14, 2025
by
Matthew Bonanni
Committed by
GitHub
Oct 14, 2025
Browse files
[Attention][Spec Decode] FlashMLA spec decode support (#26541)
Signed-off-by:
Matthew Bonanni
<
mbonanni@redhat.com
>
parent
87efc681
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
214 additions
and
91 deletions
+214
-91
tests/v1/attention/test_mla_backends.py
tests/v1/attention/test_mla_backends.py
+156
-71
vllm/v1/attention/backends/mla/common.py
vllm/v1/attention/backends/mla/common.py
+37
-12
vllm/v1/attention/backends/mla/flashattn_mla.py
vllm/v1/attention/backends/mla/flashattn_mla.py
+3
-2
vllm/v1/attention/backends/mla/flashinfer_mla.py
vllm/v1/attention/backends/mla/flashinfer_mla.py
+2
-4
vllm/v1/attention/backends/mla/flashmla.py
vllm/v1/attention/backends/mla/flashmla.py
+16
-2
No files found.
tests/v1/attention/test_mla_backends.py
View file @
82af928c
# SPDX-License-Identifier: Apache-2.0
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Tests for v1 MLA backends without GPUModelRunner dependency."""
"""Tests for v1 MLA backends without GPUModelRunner dependency.
Known Issues:
- FLASH_ATTN_MLA backend occasionally produces NaN values in
test_backend_correctness[mixed_small] when run after
test_backend_correctness[small_prefill], but passes when run alone.
"""
import
pytest
import
pytest
import
torch
import
torch
...
@@ -14,6 +20,8 @@ from tests.v1.attention.utils import (
...
@@ -14,6 +20,8 @@ from tests.v1.attention.utils import (
)
)
from
vllm
import
_custom_ops
as
ops
from
vllm
import
_custom_ops
as
ops
from
vllm.attention.backends.registry
import
_Backend
from
vllm.attention.backends.registry
import
_Backend
from
vllm.attention.ops.flashmla
import
is_flashmla_dense_supported
from
vllm.config.vllm
import
set_current_vllm_config
from
vllm.utils
import
STR_DTYPE_TO_TORCH_DTYPE
,
cdiv
from
vllm.utils
import
STR_DTYPE_TO_TORCH_DTYPE
,
cdiv
from
vllm.v1.attention.backends.utils
import
CommonAttentionMetadata
from
vllm.v1.attention.backends.utils
import
CommonAttentionMetadata
from
vllm.v1.kv_cache_interface
import
FullAttentionSpec
from
vllm.v1.kv_cache_interface
import
FullAttentionSpec
...
@@ -29,6 +37,10 @@ BACKENDS_TO_TEST = [
...
@@ -29,6 +37,10 @@ BACKENDS_TO_TEST = [
if
not
torch
.
cuda
.
is_available
()
or
torch
.
cuda
.
get_device_properties
(
0
).
major
<
10
:
if
not
torch
.
cuda
.
is_available
()
or
torch
.
cuda
.
get_device_properties
(
0
).
major
<
10
:
BACKENDS_TO_TEST
.
remove
(
_Backend
.
CUTLASS_MLA
)
BACKENDS_TO_TEST
.
remove
(
_Backend
.
CUTLASS_MLA
)
# Remove FLASHMLA from the list if not supported
if
not
is_flashmla_dense_supported
()[
0
]:
BACKENDS_TO_TEST
.
remove
(
_Backend
.
FLASHMLA
)
torch
.
manual_seed
(
42
)
torch
.
manual_seed
(
42
)
...
@@ -66,6 +78,12 @@ BATCH_SPECS = {
...
@@ -66,6 +78,12 @@ BATCH_SPECS = {
"large_prefill"
:
BatchSpec
(
seq_lens
=
[
4096
]
*
8
,
query_lens
=
[
32
]
*
8
),
"large_prefill"
:
BatchSpec
(
seq_lens
=
[
4096
]
*
8
,
query_lens
=
[
32
]
*
8
),
"single_decode"
:
BatchSpec
(
seq_lens
=
[
1024
],
query_lens
=
[
1
]),
"single_decode"
:
BatchSpec
(
seq_lens
=
[
1024
],
query_lens
=
[
1
]),
"single_prefill"
:
BatchSpec
(
seq_lens
=
[
1024
],
query_lens
=
[
64
]),
"single_prefill"
:
BatchSpec
(
seq_lens
=
[
1024
],
query_lens
=
[
64
]),
"spec_decode_small"
:
BatchSpec
(
seq_lens
=
[
128
,
256
,
512
,
1024
],
query_lens
=
[
4
,
4
,
4
,
4
]
),
"spec_decode_medium"
:
BatchSpec
(
seq_lens
=
[
512
,
1024
,
2048
,
512
,
1024
,
2048
],
query_lens
=
[
8
,
8
,
8
,
8
,
8
,
8
]
),
}
}
...
@@ -239,61 +257,64 @@ def run_attention_backend(
...
@@ -239,61 +257,64 @@ def run_attention_backend(
builder_cls
,
impl_cls
=
try_get_attention_backend
(
backend
)
builder_cls
,
impl_cls
=
try_get_attention_backend
(
backend
)
# Build metadata
# Set the current vllm config so that get_current_vllm_config() works
builder
=
builder_cls
(
kv_cache_spec
,
layer_names
,
vllm_config
,
device
)
# in the backend implementations
attn_metadata
=
builder
.
build
(
with
set_current_vllm_config
(
vllm_config
):
common_prefix_len
=
0
,
# Build metadata
common_attn_metadata
=
common_attn_metadata
,
builder
=
builder_cls
(
kv_cache_spec
,
layer_names
,
vllm_config
,
device
)
)
attn_metadata
=
builder
.
build
(
common_prefix_len
=
0
,
common_attn_metadata
=
common_attn_metadata
,
)
# Instantiate MLA implementation
# Instantiate MLA implementation
num_heads
=
vllm_config
.
model_config
.
get_num_attention_heads
(
num_heads
=
vllm_config
.
model_config
.
get_num_attention_heads
(
vllm_config
.
parallel_config
vllm_config
.
parallel_config
)
)
num_kv_heads
=
vllm_config
.
model_config
.
get_num_kv_heads
(
num_kv_heads
=
vllm_config
.
model_config
.
get_num_kv_heads
(
vllm_config
.
parallel_config
vllm_config
.
parallel_config
)
)
head_size
=
vllm_config
.
model_config
.
get_head_size
()
head_size
=
vllm_config
.
model_config
.
get_head_size
()
scale
=
1.0
/
(
head_size
**
0.5
)
scale
=
1.0
/
(
head_size
**
0.5
)
impl
=
impl_cls
(
impl
=
impl_cls
(
num_heads
=
num_heads
,
num_heads
=
num_heads
,
head_size
=
head_size
,
head_size
=
head_size
,
scale
=
scale
,
scale
=
scale
,
num_kv_heads
=
num_kv_heads
,
num_kv_heads
=
num_kv_heads
,
alibi_slopes
=
None
,
alibi_slopes
=
None
,
sliding_window
=
None
,
sliding_window
=
None
,
kv_cache_dtype
=
"auto"
,
kv_cache_dtype
=
"auto"
,
logits_soft_cap
=
None
,
logits_soft_cap
=
None
,
attn_type
=
"decoder"
,
attn_type
=
"decoder"
,
kv_sharing_target_layer_name
=
None
,
kv_sharing_target_layer_name
=
None
,
q_lora_rank
=
None
,
q_lora_rank
=
None
,
kv_lora_rank
=
kv_lora_rank
,
kv_lora_rank
=
kv_lora_rank
,
qk_nope_head_dim
=
qk_nope_head_dim
,
qk_nope_head_dim
=
qk_nope_head_dim
,
qk_rope_head_dim
=
qk_rope_head_dim
,
qk_rope_head_dim
=
qk_rope_head_dim
,
qk_head_dim
=
qk_nope_head_dim
+
qk_rope_head_dim
,
qk_head_dim
=
qk_nope_head_dim
+
qk_rope_head_dim
,
v_head_dim
=
v_head_dim
,
v_head_dim
=
v_head_dim
,
kv_b_proj
=
mock_kv_b_proj
,
kv_b_proj
=
mock_kv_b_proj
,
)
)
# Process weights to create W_UK_T and W_UV attributes needed by MLA
# Process weights to create W_UK_T and W_UV attributes needed by MLA
act_dtype
=
_convert_dtype_to_torch
(
vllm_config
.
model_config
.
dtype
)
act_dtype
=
_convert_dtype_to_torch
(
vllm_config
.
model_config
.
dtype
)
impl
.
process_weights_after_loading
(
act_dtype
)
impl
.
process_weights_after_loading
(
act_dtype
)
# Create mock layer and output buffer
# Create mock layer and output buffer
mock_layer
=
MockAttentionLayer
(
device
)
mock_layer
=
MockAttentionLayer
(
device
)
num_tokens
=
query
.
shape
[
0
]
num_tokens
=
query
.
shape
[
0
]
output
=
torch
.
empty
(
output
=
torch
.
empty
(
num_tokens
,
num_heads
*
v_head_dim
,
dtype
=
query
.
dtype
,
device
=
query
.
device
num_tokens
,
num_heads
*
v_head_dim
,
dtype
=
query
.
dtype
,
device
=
query
.
device
)
)
# Run forward pass
# Run forward pass
# NOTE: The query, key, and value are already shaped correctly
# NOTE: The query, key, and value are already shaped correctly
# in the calling test function.
# in the calling test function.
output
=
impl
.
forward
(
output
=
impl
.
forward
(
mock_layer
,
query
,
kv_c
,
k_pe
,
kv_cache
,
attn_metadata
,
output
=
output
mock_layer
,
query
,
kv_c
,
k_pe
,
kv_cache
,
attn_metadata
,
output
=
output
)
)
return
output
return
output
@
pytest
.
mark
.
parametrize
(
@
pytest
.
mark
.
parametrize
(
...
@@ -309,6 +330,8 @@ def run_attention_backend(
...
@@ -309,6 +330,8 @@ def run_attention_backend(
"large_prefill"
,
"large_prefill"
,
"single_decode"
,
"single_decode"
,
"single_prefill"
,
"single_prefill"
,
"spec_decode_small"
,
"spec_decode_medium"
,
],
],
)
)
@
pytest
.
mark
.
parametrize
(
"model"
,
[
"deepseek-ai/DeepSeek-V2-Lite-Chat"
])
@
pytest
.
mark
.
parametrize
(
"model"
,
[
"deepseek-ai/DeepSeek-V2-Lite-Chat"
])
...
@@ -328,10 +351,39 @@ def test_backend_correctness(dist_init, batch_spec_name: str, model: str):
...
@@ -328,10 +351,39 @@ def test_backend_correctness(dist_init, batch_spec_name: str, model: str):
simulated paged KV cache.
simulated paged KV cache.
5. Comparing the vLLM backend's output to the ground-truth SDPA output.
5. Comparing the vLLM backend's output to the ground-truth SDPA output.
"""
"""
from
vllm.v1.attention.backends.mla.common
import
QueryLenSupport
batch_spec
=
BATCH_SPECS
[
batch_spec_name
]
batch_spec
=
BATCH_SPECS
[
batch_spec_name
]
is_spec_decode_test
=
batch_spec_name
.
startswith
(
"spec_decode"
)
spec_decode_backends
=
{
_Backend
.
FLASH_ATTN_MLA
,
_Backend
.
FLASHMLA
}
block_size
=
16
required_blocks
=
sum
(
(
seq_len
+
block_size
-
1
)
//
block_size
for
seq_len
in
batch_spec
.
seq_lens
)
# Add 1 for null block at index 0, and some buffer
num_gpu_blocks
=
required_blocks
+
1
+
100
vllm_config
=
create_vllm_config
(
vllm_config
=
create_vllm_config
(
model_name
=
model
,
max_model_len
=
max
(
batch_spec
.
seq_lens
),
num_gpu_blocks
=
2048
model_name
=
model
,
max_model_len
=
max
(
batch_spec
.
seq_lens
),
num_gpu_blocks
=
num_gpu_blocks
,
block_size
=
block_size
,
)
)
# For spec decode tests, add a speculative_config to set the reorder_batch_threshold
if
is_spec_decode_test
:
from
vllm.config
import
SpeculativeConfig
# Get the query length from the batch spec (they should all be uniform)
query_len
=
batch_spec
.
query_lens
[
0
]
# Set num_speculative_tokens to query_len - 1
# (since threshold is 1 + num_spec_tokens)
# Use ngram method which doesn't require a draft model
vllm_config
.
speculative_config
=
SpeculativeConfig
(
method
=
"ngram"
,
num_speculative_tokens
=
query_len
-
1
)
device
=
torch
.
device
(
"cuda:0"
)
device
=
torch
.
device
(
"cuda:0"
)
kv_cache_spec
=
create_standard_kv_cache_spec
(
vllm_config
)
kv_cache_spec
=
create_standard_kv_cache_spec
(
vllm_config
)
...
@@ -395,11 +447,37 @@ def test_backend_correctness(dist_init, batch_spec_name: str, model: str):
...
@@ -395,11 +447,37 @@ def test_backend_correctness(dist_init, batch_spec_name: str, model: str):
# K_PE (rope component): [s_len, 1, qk_rope_head_dim]
# K_PE (rope component): [s_len, 1, qk_rope_head_dim]
k_pe_full
=
torch
.
randn
(
s_len
,
1
,
qk_rope_head_dim
,
dtype
=
dtype
,
device
=
device
)
k_pe_full
=
torch
.
randn
(
s_len
,
1
,
qk_rope_head_dim
,
dtype
=
dtype
,
device
=
device
)
# Determine if this is decode or prefill
# Determine if this sequence uses the decode pipeline or prefill
# pipeline for each backend
# NOTE: For spec decode tests with uniform query_len > 1, backends that
# support spec decode (FLASH_ATTN_MLA with varlen support, FLASHMLA with
# uniform support) will use the decode pipeline (MQA-style), while
# backends that only support single-token queries will use the prefill
# pipeline (MHA-style). This ensures the reference implementation
# matches each backend's actual decode/prefill pipeline path.
is_decode
=
[]
is_decode
=
[]
for
i
,
backend
in
enumerate
(
BACKENDS_TO_TEST
):
for
backend_idx
,
backend
in
enumerate
(
BACKENDS_TO_TEST
):
builder_cls
,
_
=
try_get_attention_backend
(
backend
)
builder_cls
,
_
=
try_get_attention_backend
(
backend
)
is_decode
.
append
(
q_len
<=
builder_cls
.
reorder_batch_threshold
)
if
is_spec_decode_test
:
query_len_support
=
getattr
(
builder_cls
,
"query_len_support"
,
QueryLenSupport
.
SINGLE_ONLY
)
supports_spec
=
query_len_support
!=
QueryLenSupport
.
SINGLE_ONLY
is_decode
.
append
(
supports_spec
)
else
:
threshold
=
getattr
(
builder_cls
,
"reorder_batch_threshold"
,
None
)
query_len_support
=
getattr
(
builder_cls
,
"query_len_support"
,
QueryLenSupport
.
SINGLE_ONLY
)
within_threshold
=
q_len
<=
threshold
if
threshold
else
False
if
(
within_threshold
and
query_len_support
==
QueryLenSupport
.
UNIFORM
and
i
>
0
):
first_q_len
=
query_lens
[
0
]
within_threshold
=
q_len
==
first_q_len
is_decode
.
append
(
within_threshold
)
# Split q into nope and rope components
# Split q into nope and rope components
q_nope
,
q_pe
=
q_c
.
split
([
qk_nope_head_dim
,
qk_rope_head_dim
],
dim
=-
1
)
q_nope
,
q_pe
=
q_c
.
split
([
qk_nope_head_dim
,
qk_rope_head_dim
],
dim
=-
1
)
...
@@ -478,11 +556,11 @@ def test_backend_correctness(dist_init, batch_spec_name: str, model: str):
...
@@ -478,11 +556,11 @@ def test_backend_correctness(dist_init, batch_spec_name: str, model: str):
sdpa_out_i_prefill
=
sdpa_out_i_prefill
.
transpose
(
1
,
2
).
squeeze
(
0
)
sdpa_out_i_prefill
=
sdpa_out_i_prefill
.
transpose
(
1
,
2
).
squeeze
(
0
)
sdpa_out_i_prefill
=
sdpa_out_i_prefill
.
flatten
(
start_dim
=-
2
)
sdpa_out_i_prefill
=
sdpa_out_i_prefill
.
flatten
(
start_dim
=-
2
)
for
i
,
backend
in
enumerate
(
BACKENDS_TO_TEST
):
for
backend_idx
,
backend
in
enumerate
(
BACKENDS_TO_TEST
):
if
is_decode
[
i
]:
if
is_decode
[
backend_idx
]:
all_sdpa_outputs
[
i
].
append
(
sdpa_out_i_decode
)
all_sdpa_outputs
[
backend_idx
].
append
(
sdpa_out_i_decode
)
else
:
else
:
all_sdpa_outputs
[
i
].
append
(
sdpa_out_i_prefill
)
all_sdpa_outputs
[
backend_idx
].
append
(
sdpa_out_i_prefill
)
# Inputs for vLLM MLA backends are just the new tokens
# Inputs for vLLM MLA backends are just the new tokens
all_q_vllm
.
append
(
q_c
)
all_q_vllm
.
append
(
q_c
)
...
@@ -497,9 +575,9 @@ def test_backend_correctness(dist_init, batch_spec_name: str, model: str):
...
@@ -497,9 +575,9 @@ def test_backend_correctness(dist_init, batch_spec_name: str, model: str):
query_vllm
=
torch
.
cat
(
all_q_vllm
,
dim
=
0
)
query_vllm
=
torch
.
cat
(
all_q_vllm
,
dim
=
0
)
kv_c_vllm
=
torch
.
cat
(
all_kv_c_vllm
,
dim
=
0
)
kv_c_vllm
=
torch
.
cat
(
all_kv_c_vllm
,
dim
=
0
)
k_pe_vllm
=
torch
.
cat
(
all_k_pe_vllm
,
dim
=
0
)
k_pe_vllm
=
torch
.
cat
(
all_k_pe_vllm
,
dim
=
0
)
sdpa_outputs
=
[]
sdpa_outputs
=
{}
for
i
,
backend
in
enumerate
(
BACKENDS_TO_TEST
):
for
backend_idx
,
backend
in
enumerate
(
BACKENDS_TO_TEST
):
sdpa_outputs
.
append
(
torch
.
cat
(
all_sdpa_outputs
[
i
],
dim
=
0
)
)
sdpa_outputs
[
backend
]
=
torch
.
cat
(
all_sdpa_outputs
[
backend_idx
],
dim
=
0
)
# Create mock kv_b_proj using the same weights as reference implementation
# Create mock kv_b_proj using the same weights as reference implementation
from
vllm.model_executor.layers.linear
import
ColumnParallelLinear
from
vllm.model_executor.layers.linear
import
ColumnParallelLinear
...
@@ -516,7 +594,7 @@ def test_backend_correctness(dist_init, batch_spec_name: str, model: str):
...
@@ -516,7 +594,7 @@ def test_backend_correctness(dist_init, batch_spec_name: str, model: str):
kv_b_proj_weight
=
kv_b_proj_weight
.
view
(
kv_b_proj_weight
=
kv_b_proj_weight
.
view
(
kv_lora_rank
,
num_q_heads
*
(
qk_nope_head_dim
+
v_head_dim
)
kv_lora_rank
,
num_q_heads
*
(
qk_nope_head_dim
+
v_head_dim
)
)
)
mock_kv_b_proj
.
weight
=
torch
.
nn
.
Parameter
(
kv_b_proj_weight
.
T
)
mock_kv_b_proj
.
weight
=
torch
.
nn
.
Parameter
(
kv_b_proj_weight
.
T
,
requires_grad
=
False
)
# Create metadata using original batch spec
# Create metadata using original batch spec
common_attn_metadata
=
create_common_attn_metadata
(
common_attn_metadata
=
create_common_attn_metadata
(
...
@@ -537,7 +615,11 @@ def test_backend_correctness(dist_init, batch_spec_name: str, model: str):
...
@@ -537,7 +615,11 @@ def test_backend_correctness(dist_init, batch_spec_name: str, model: str):
)
)
# 4. Run vLLM backends and compare
# 4. Run vLLM backends and compare
for
i
,
backend_name
in
enumerate
(
BACKENDS_TO_TEST
):
for
backend_idx
,
backend_name
in
enumerate
(
BACKENDS_TO_TEST
):
# Skip backends that don't support spec decode for spec decode tests
if
is_spec_decode_test
and
backend_name
not
in
spec_decode_backends
:
continue
backend_output
=
run_attention_backend
(
backend_output
=
run_attention_backend
(
backend_name
,
backend_name
,
kv_cache_spec
,
kv_cache_spec
,
...
@@ -556,14 +638,17 @@ def test_backend_correctness(dist_init, batch_spec_name: str, model: str):
...
@@ -556,14 +638,17 @@ def test_backend_correctness(dist_init, batch_spec_name: str, model: str):
mock_kv_b_proj
,
mock_kv_b_proj
,
)
)
# Look up the expected SDPA output for this backend (keyed by backend name)
expected_output
=
sdpa_outputs
[
backend_name
]
# Check shape and dtype consistency
# Check shape and dtype consistency
assert
backend_output
.
shape
==
sdpa
_output
s
[
i
]
.
shape
,
(
assert
backend_output
.
shape
==
expected
_output
.
shape
,
(
f
"[
{
backend_name
}
] shape
{
backend_output
.
shape
}
!= "
f
"[
{
backend_name
}
] shape
{
backend_output
.
shape
}
!= "
f
"SDPA shape
{
sdpa
_output
s
[
i
]
.
shape
}
"
f
"SDPA shape
{
expected
_output
.
shape
}
"
)
)
assert
backend_output
.
dtype
==
sdpa
_output
s
[
i
]
.
dtype
,
(
assert
backend_output
.
dtype
==
expected
_output
.
dtype
,
(
f
"[
{
backend_name
}
] dtype
{
backend_output
.
dtype
}
!= "
f
"[
{
backend_name
}
] dtype
{
backend_output
.
dtype
}
!= "
f
"SDPA dtype
{
sdpa
_output
s
[
i
]
.
dtype
}
"
f
"SDPA dtype
{
expected
_output
.
dtype
}
"
)
)
assert
torch
.
isfinite
(
backend_output
).
all
(),
(
assert
torch
.
isfinite
(
backend_output
).
all
(),
(
...
@@ -574,12 +659,12 @@ def test_backend_correctness(dist_init, batch_spec_name: str, model: str):
...
@@ -574,12 +659,12 @@ def test_backend_correctness(dist_init, batch_spec_name: str, model: str):
rtol
=
1e-2
rtol
=
1e-2
atol
=
5e-1
atol
=
5e-1
max_diff
=
torch
.
max
(
torch
.
abs
(
backend_output
-
sdpa
_output
s
[
i
]
)).
item
()
max_diff
=
torch
.
max
(
torch
.
abs
(
backend_output
-
expected
_output
)).
item
()
max_rel_diff
=
torch
.
max
(
max_rel_diff
=
torch
.
max
(
torch
.
abs
(
backend_output
-
sdpa
_output
s
[
i
]
)
/
torch
.
abs
(
sdpa
_output
s
[
i
]
)
torch
.
abs
(
backend_output
-
expected
_output
)
/
torch
.
abs
(
expected
_output
)
).
item
()
).
item
()
all_close
=
torch
.
allclose
(
all_close
=
torch
.
allclose
(
backend_output
,
sdpa
_output
s
[
i
]
,
rtol
=
rtol
,
atol
=
atol
backend_output
,
expected
_output
,
rtol
=
rtol
,
atol
=
atol
)
)
assert
all_close
,
(
assert
all_close
,
(
...
...
vllm/v1/attention/backends/mla/common.py
View file @
82af928c
...
@@ -190,6 +190,7 @@ return curr_o @ W_O
...
@@ -190,6 +190,7 @@ return curr_o @ W_O
import
functools
import
functools
from
abc
import
abstractmethod
from
abc
import
abstractmethod
from
dataclasses
import
dataclass
,
field
from
dataclasses
import
dataclass
,
field
from
enum
import
Enum
from
typing
import
ClassVar
,
Generic
,
TypeVar
from
typing
import
ClassVar
,
Generic
,
TypeVar
import
torch
import
torch
...
@@ -227,6 +228,24 @@ from vllm.v1.attention.backends.utils import (
...
@@ -227,6 +228,24 @@ from vllm.v1.attention.backends.utils import (
)
)
from
vllm.v1.kv_cache_interface
import
AttentionSpec
from
vllm.v1.kv_cache_interface
import
AttentionSpec
class QueryLenSupport(Enum):
    """Query-length capability of an attention backend's decode pipeline.

    Members:
        SINGLE_ONLY: the decode pipeline only supports single-token queries
            (query_len=1).
        UNIFORM: the decode pipeline supports uniform multi-token queries
            (all requests must have the same query_len > 1).
        VARLEN: the decode pipeline supports variable-length queries
            (mixed query lengths in the same batch).
    """

    SINGLE_ONLY = "single_only"
    UNIFORM = "uniform"
    VARLEN = "varlen"
try
:
try
:
from
vllm.vllm_flash_attn
import
flash_attn_varlen_func
from
vllm.vllm_flash_attn
import
flash_attn_varlen_func
...
@@ -460,19 +479,18 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
...
@@ -460,19 +479,18 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
understand this class
understand this class
"""
"""
# Whether the backend supports reordering the batch such that
# Defines the level of query length support for this backend.
# short sequences (i.e. verification for speculative decoding) are
# - SINGLE_ONLY: Only single-token queries (no spec decode support)
# classified as decode requests.
# - UNIFORM: Supports uniform multi-token queries (spec decode with uniform lengths)
# If True, this will increase `reorder_batch_threshold` (below) when
# - VARLEN: Supports variable-length queries (spec decode with mixed lengths)
# speculative decoding is enabled, and set `require_uniform=True` when
# If set to UNIFORM or VARLEN, this will increase `reorder_batch_threshold` when
# reordering the batch. Non-uniform decode requests will
# speculative decoding is enabled.
# fall back to prefill in this case.
query_len_support
:
ClassVar
[
QueryLenSupport
]
=
QueryLenSupport
.
SINGLE_ONLY
supports_uniform_spec_as_decode
:
ClassVar
[
bool
]
=
False
# The threshold for reordering the batch into decode and prefill requests.
# The threshold for reordering the batch into decode and prefill requests.
# If > 1, the batch will be reordered such that requests with
# If > 1, the batch will be reordered such that requests with
# query length <= threshold are classified as decode requests.
# query length <= threshold are classified as decode requests.
# Use `
supports_uniform_spec_as_decode
` (above) to set this automatically
# Use `
query_len_support
` (above) to set this automatically
# when speculative decoding is enabled.
# when speculative decoding is enabled.
reorder_batch_threshold
:
int
=
1
reorder_batch_threshold
:
int
=
1
...
@@ -599,11 +617,18 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
...
@@ -599,11 +617,18 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
device
=
device
,
device
=
device
,
)
)
supports_spec_
as_
decode
=
self
.
supports_uniform_spec_as_decode
supports_spec_decode
=
self
.
query_len_support
!=
QueryLenSupport
.
SINGLE_ONLY
self
.
_init_reorder_batch_threshold
(
self
.
_init_reorder_batch_threshold
(
self
.
reorder_batch_threshold
,
supports_spec_
as_
decode
self
.
reorder_batch_threshold
,
supports_spec_decode
)
)
# Validate consistency between query_len_support and reorder_batch_threshold
if
self
.
query_len_support
==
QueryLenSupport
.
SINGLE_ONLY
:
assert
self
.
reorder_batch_threshold
==
1
,
(
f
"reorder_batch_threshold must be 1 when query_len_support is "
f
"SINGLE_ONLY, got
{
self
.
reorder_batch_threshold
}
"
)
def
_build_fi_prefill_wrappers
(
self
,
prefill
:
FlashInferPrefillMetadata
):
def
_build_fi_prefill_wrappers
(
self
,
prefill
:
FlashInferPrefillMetadata
):
qo_indptr
=
prefill
.
query_start_loc
qo_indptr
=
prefill
.
query_start_loc
...
@@ -745,7 +770,7 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
...
@@ -745,7 +770,7 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
split_decodes_and_prefills
(
split_decodes_and_prefills
(
common_attn_metadata
,
common_attn_metadata
,
decode_threshold
=
self
.
reorder_batch_threshold
,
decode_threshold
=
self
.
reorder_batch_threshold
,
require_uniform
=
self
.
supports_uniform_spec_as_decode
,
require_uniform
=
(
self
.
query_len_support
!=
QueryLenSupport
.
VARLEN
)
,
)
)
)
)
...
...
vllm/v1/attention/backends/mla/flashattn_mla.py
View file @
82af928c
...
@@ -24,6 +24,7 @@ from vllm.v1.attention.backends.mla.common import (
...
@@ -24,6 +24,7 @@ from vllm.v1.attention.backends.mla.common import (
MLACommonImpl
,
MLACommonImpl
,
MLACommonMetadata
,
MLACommonMetadata
,
MLACommonMetadataBuilder
,
MLACommonMetadataBuilder
,
QueryLenSupport
,
)
)
from
vllm.v1.attention.backends.utils
import
AttentionCGSupport
from
vllm.v1.attention.backends.utils
import
AttentionCGSupport
from
vllm.v1.kv_cache_interface
import
AttentionSpec
from
vllm.v1.kv_cache_interface
import
AttentionSpec
...
@@ -66,8 +67,8 @@ class FlashAttnMLAMetadata(MLACommonMetadata[FlashAttnMLADecodeMetadata]):
...
@@ -66,8 +67,8 @@ class FlashAttnMLAMetadata(MLACommonMetadata[FlashAttnMLADecodeMetadata]):
class
FlashAttnMLAMetadataBuilder
(
MLACommonMetadataBuilder
[
FlashAttnMLAMetadata
]):
class
FlashAttnMLAMetadataBuilder
(
MLACommonMetadataBuilder
[
FlashAttnMLAMetadata
]):
cudagraph_support
:
ClassVar
[
AttentionCGSupport
]
=
AttentionCGSupport
.
UNIFORM_BATCH
cudagraph_support
:
ClassVar
[
AttentionCGSupport
]
=
AttentionCGSupport
.
UNIFORM_BATCH
query_len_support
:
ClassVar
[
QueryLenSupport
]
=
QueryLenSupport
.
VARLEN
reorder_batch_threshold
:
int
=
512
reorder_batch_threshold
:
int
=
512
# process small prefills with decode pathway
def
__init__
(
def
__init__
(
self
,
self
,
...
...
vllm/v1/attention/backends/mla/flashinfer_mla.py
View file @
82af928c
...
@@ -13,6 +13,7 @@ from vllm.v1.attention.backends.mla.common import (
...
@@ -13,6 +13,7 @@ from vllm.v1.attention.backends.mla.common import (
MLACommonImpl
,
MLACommonImpl
,
MLACommonMetadata
,
MLACommonMetadata
,
MLACommonMetadataBuilder
,
MLACommonMetadataBuilder
,
QueryLenSupport
,
)
)
from
vllm.v1.attention.backends.utils
import
AttentionCGSupport
from
vllm.v1.attention.backends.utils
import
AttentionCGSupport
...
@@ -22,11 +23,8 @@ FLASHINFER_MLA_WORKSPACE_BUFFER_SIZE = 128 * 1024 * 1024
...
@@ -22,11 +23,8 @@ FLASHINFER_MLA_WORKSPACE_BUFFER_SIZE = 128 * 1024 * 1024
class
FlashInferMLAMetadataBuilder
(
MLACommonMetadataBuilder
[
MLACommonMetadata
]):
class
FlashInferMLAMetadataBuilder
(
MLACommonMetadataBuilder
[
MLACommonMetadata
]):
# enable spec-as-decode optimization
supports_uniform_spec_as_decode
:
ClassVar
[
bool
]
=
True
# enable full CUDA Graph support for decode-only capture
cudagraph_support
:
ClassVar
[
AttentionCGSupport
]
=
AttentionCGSupport
.
UNIFORM_BATCH
cudagraph_support
:
ClassVar
[
AttentionCGSupport
]
=
AttentionCGSupport
.
UNIFORM_BATCH
query_len_support
:
ClassVar
[
QueryLenSupport
]
=
QueryLenSupport
.
UNIFORM
class
FlashInferMLABackend
(
MLACommonBackend
):
class
FlashInferMLABackend
(
MLACommonBackend
):
...
...
vllm/v1/attention/backends/mla/flashmla.py
View file @
82af928c
...
@@ -20,8 +20,13 @@ from vllm.v1.attention.backends.mla.common import (
...
@@ -20,8 +20,13 @@ from vllm.v1.attention.backends.mla.common import (
MLACommonImpl
,
MLACommonImpl
,
MLACommonMetadata
,
MLACommonMetadata
,
MLACommonMetadataBuilder
,
MLACommonMetadataBuilder
,
QueryLenSupport
,
)
from
vllm.v1.attention.backends.utils
import
(
AttentionCGSupport
,
reshape_attn_output_for_spec_decode
,
reshape_query_for_spec_decode
,
)
)
from
vllm.v1.attention.backends.utils
import
AttentionCGSupport
from
vllm.v1.kv_cache_interface
import
AttentionSpec
from
vllm.v1.kv_cache_interface
import
AttentionSpec
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
...
@@ -62,6 +67,9 @@ class FlashMLAMetadata(MLACommonMetadata[FlashMLADecodeMetadata]):
...
@@ -62,6 +67,9 @@ class FlashMLAMetadata(MLACommonMetadata[FlashMLADecodeMetadata]):
class
FlashMLAMetadataBuilder
(
MLACommonMetadataBuilder
[
FlashMLAMetadata
]):
class
FlashMLAMetadataBuilder
(
MLACommonMetadataBuilder
[
FlashMLAMetadata
]):
cudagraph_support
:
ClassVar
[
AttentionCGSupport
]
=
AttentionCGSupport
.
UNIFORM_BATCH
cudagraph_support
:
ClassVar
[
AttentionCGSupport
]
=
AttentionCGSupport
.
UNIFORM_BATCH
query_len_support
:
ClassVar
[
QueryLenSupport
]
=
QueryLenSupport
.
UNIFORM
reorder_batch_threshold
:
int
=
512
# process small prefills with decode pathway
# ^ TODO(matt): tune this
def
__init__
(
def
__init__
(
self
,
self
,
...
@@ -216,8 +224,12 @@ class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]):
...
@@ -216,8 +224,12 @@ class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]):
q
=
torch
.
cat
(
q
,
dim
=-
1
)
q
=
torch
.
cat
(
q
,
dim
=-
1
)
assert
isinstance
(
q
,
torch
.
Tensor
)
assert
isinstance
(
q
,
torch
.
Tensor
)
num_decodes
=
attn_metadata
.
num_decodes
q
=
reshape_query_for_spec_decode
(
q
,
num_decodes
)
o
,
lse
=
flash_mla_with_kvcache
(
o
,
lse
=
flash_mla_with_kvcache
(
q
=
q
.
unsqueeze
(
1
),
# Add seqlen dim of 1 (decode)
q
=
q
,
k_cache
=
kv_c_and_k_pe_cache
.
unsqueeze
(
-
2
),
# Add head dim of 1
k_cache
=
kv_c_and_k_pe_cache
.
unsqueeze
(
-
2
),
# Add head dim of 1
block_table
=
attn_metadata
.
decode
.
block_table
,
block_table
=
attn_metadata
.
decode
.
block_table
,
cache_seqlens
=
attn_metadata
.
decode
.
seq_lens
,
cache_seqlens
=
attn_metadata
.
decode
.
seq_lens
,
...
@@ -230,4 +242,6 @@ class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]):
...
@@ -230,4 +242,6 @@ class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]):
descale_k
=
layer
.
_k_scale
.
reshape
(
1
),
descale_k
=
layer
.
_k_scale
.
reshape
(
1
),
)
)
o
=
reshape_attn_output_for_spec_decode
(
o
)
return
o
,
lse
return
o
,
lse
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment