Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
f24b2de3
Unverified
Commit
f24b2de3
authored
Feb 20, 2026
by
Wei Zhao
Committed by
GitHub
Feb 20, 2026
Browse files
[Test] Add FP8 KV Cache Testing for MLA Backends (#34473)
Signed-off-by:
wzhao18
<
wzhao18.sz@gmail.com
>
parent
fac1507f
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
68 additions
and
27 deletions
+68
-27
tests/v1/attention/test_mla_backends.py
tests/v1/attention/test_mla_backends.py
+68
-27
No files found.
tests/v1/attention/test_mla_backends.py
View file @
f24b2de3
...
...
@@ -19,8 +19,13 @@ from tests.v1.attention.utils import (
)
from
vllm
import
_custom_ops
as
ops
from
vllm.config.vllm
import
set_current_vllm_config
from
vllm.model_executor.layers.attention.mla_attention
import
QueryLenSupport
from
vllm.model_executor.layers.attention.mla_attention
import
(
QueryLenSupport
,
_DecodeConcatQuantFP8
,
)
from
vllm.model_executor.layers.attention_layer_base
import
AttentionLayerBase
from
vllm.model_executor.layers.quantization.utils.quant_utils
import
GroupShape
from
vllm.platforms
import
current_platform
from
vllm.utils.math_utils
import
cdiv
from
vllm.utils.torch_utils
import
STR_DTYPE_TO_TORCH_DTYPE
from
vllm.v1.attention.backend
import
CommonAttentionMetadata
...
...
@@ -50,6 +55,7 @@ if not flash_attn_supports_mla():
if
not
is_flashmla_dense_supported
()[
0
]:
BACKENDS_TO_TEST
.
remove
(
AttentionBackendEnum
.
FLASHMLA
)
SPEC_DECODE_BACKENDS
=
[]
for
backend
in
BACKENDS_TO_TEST
:
builder_cls
,
_
=
try_get_attention_backend
(
backend
)
...
...
@@ -144,9 +150,8 @@ def create_and_prepopulate_kv_cache(
common_attn_metadata: Common attention metadata
randomize_blocks: Whether to randomly permute blocks
or use sequential order
kv_cache_dtype: Optional kv cache dtype string. When set to
"fp8_ds_mla" the cache is populated using the
fp8 DeepSeek MLA layout via concat_and_cache_mla.
kv_cache_dtype: Optional kv cache dtype string. For fp8 cache dtype,
the cache is populated via concat_and_cache_mla.
scale: Scaling factor forwarded to concat_and_cache_mla when the
fp8 cache layout is requested.
...
...
@@ -163,18 +168,21 @@ def create_and_prepopulate_kv_cache(
block_table
=
common_attn_metadata
.
block_table_tensor
slot_mapping
=
common_attn_metadata
.
slot_mapping
fp8_attention
=
kv_cache_dtype
and
kv_cache_dtype
.
startswith
(
"fp8"
)
use_fp8_ds_mla
=
kv_cache_dtype
==
"fp8_ds_mla"
if
fp8_attention
:
if
use_fp8_ds_mla
:
if
not
kv_c_contexts
:
raise
ValueError
(
"kv_c_contexts cannot be empty when using fp8_ds_mla cache dtype"
)
kv_lora_rank
=
kv_c_contexts
[
0
].
shape
[
-
1
]
rope_dim
=
k_pe_contexts
[
0
].
shape
[
-
1
]
entry_size
=
kv_lora_rank
+
4
*
4
+
2
*
rope_dim
# 4 * 4: 4 float32 scale values for 128-element tiles
# 2 * rope_dim: 16-bit RoPE values
kv_entry_size
=
kv_lora_rank
+
4
*
4
+
2
*
rope_dim
else
:
kv_entry_size
=
head_size
kv_cache
=
torch
.
zeros
(
num_blocks
,
block_size
,
entry_size
,
dtype
=
torch
.
uint8
,
device
=
device
num_blocks
,
block_size
,
kv_
entry_size
,
dtype
=
torch
.
uint8
,
device
=
device
)
scale_tensor
=
(
scale
...
...
@@ -201,14 +209,14 @@ def create_and_prepopulate_kv_cache(
start
=
start_block_idx
*
block_size
if
use_fp8_ds_mla
:
if
fp8_attention
:
slots
=
torch
.
arange
(
context_len
,
device
=
device
,
dtype
=
torch
.
long
)
+
start
ops
.
concat_and_cache_mla
(
kv_c_context
,
k_pe_context
.
squeeze
(
1
),
kv_cache
,
slots
,
kv_cache_dtype
=
"fp8_ds_mla"
,
kv_cache_dtype
=
kv_cache_dtype
,
scale
=
scale_tensor
,
)
else
:
...
...
@@ -329,8 +337,9 @@ class MockSparseMLAAttentionLayer:
output
:
torch
.
Tensor
,
)
->
torch
.
Tensor
:
"""Forward for sparse MLA - uses forward_mqa for all tokens."""
# Write to KV cache
kv_cache_dtype
=
getattr
(
self
.
impl
,
"kv_cache_dtype"
,
"auto"
)
# Write to KV cache
if
kv_cache
.
numel
()
>
0
:
ops
.
concat_and_cache_mla
(
kv_c
,
...
...
@@ -426,6 +435,12 @@ class MockMLAAttentionLayer(AttentionLayerBase):
self
.
_k_scale_float
=
1.0
self
.
_v_scale_float
=
1.0
self
.
_decode_concat_quant_fp8_op
=
_DecodeConcatQuantFP8
(
static
=
True
,
group_shape
=
GroupShape
.
PER_TENSOR
,
compile_native
=
True
,
)
def
get_attn_backend
(
self
):
raise
NotImplementedError
...
...
@@ -443,16 +458,21 @@ class MockMLAAttentionLayer(AttentionLayerBase):
)
->
torch
.
Tensor
:
"""Replicates MLAAttention.forward_impl logic for testing."""
# Write to KV cache
kv_cache_dtype
=
getattr
(
self
.
impl
,
"kv_cache_dtype"
,
"auto"
)
fp8_attention
=
kv_cache_dtype
.
startswith
(
"fp8"
)
if
kv_cache
.
numel
()
>
0
:
ops
.
concat_and_cache_mla
(
kv_c
,
k_pe
.
squeeze
(
1
),
kv_cache
,
attn_metadata
.
slot_mapping
.
flatten
(),
kv_cache_dtype
=
"auto"
,
kv_cache_dtype
=
kv_cache_dtype
,
scale
=
self
.
_k_scale
,
)
if
fp8_attention
and
kv_cache_dtype
!=
"fp8_ds_mla"
:
kv_cache
=
kv_cache
.
view
(
current_platform
.
fp8_dtype
())
# Determine decode vs prefill split
num_decode_tokens
=
attn_metadata
.
num_decode_tokens
or
0
has_decode
=
(
attn_metadata
.
num_decodes
or
0
)
>
0
...
...
@@ -491,7 +511,13 @@ class MockMLAAttentionLayer(AttentionLayerBase):
# Convert from (N, B, L) to (B, N, L)
mqa_ql_nope
=
mqa_ql_nope
.
transpose
(
0
,
1
)
# Pass as tuple to forward_mqa
if
fp8_attention
and
self
.
impl
.
supports_quant_query_input
:
assert
mqa_ql_nope
.
shape
[
0
]
==
mqa_q_pe
.
shape
[
0
]
assert
mqa_ql_nope
.
shape
[
1
]
==
mqa_q_pe
.
shape
[
1
]
mqa_q
=
self
.
_decode_concat_quant_fp8_op
(
mqa_ql_nope
,
mqa_q_pe
,
self
.
_q_scale
)
else
:
mqa_q
=
(
mqa_ql_nope
,
mqa_q_pe
)
attn_out
,
_
=
self
.
impl
.
forward_mqa
(
mqa_q
,
kv_cache
,
attn_metadata
,
self
)
...
...
@@ -526,6 +552,7 @@ def run_attention_backend(
qk_rope_head_dim
:
int
,
v_head_dim
:
int
,
mock_kv_b_proj
,
kv_cache_dtype
:
str
=
"auto"
,
)
->
torch
.
Tensor
:
"""Run attention computation using the specified backend's AttentionImpl."""
...
...
@@ -550,7 +577,7 @@ def run_attention_backend(
num_kv_heads
=
num_kv_heads
,
alibi_slopes
=
None
,
sliding_window
=
None
,
kv_cache_dtype
=
"auto"
,
kv_cache_dtype
=
kv_cache_dtype
,
logits_soft_cap
=
None
,
attn_type
=
"decoder"
,
kv_sharing_target_layer_name
=
None
,
...
...
@@ -630,12 +657,14 @@ def run_attention_backend(
)
@
pytest
.
mark
.
parametrize
(
"model"
,
[
"deepseek-ai/DeepSeek-R1"
])
@
pytest
.
mark
.
parametrize
(
"tensor_parallel_size"
,
[
1
,
4
,
8
,
16
])
@
pytest
.
mark
.
parametrize
(
"kv_cache_dtype"
,
[
"auto"
,
"fp8"
,
"fp8_e4m3"
])
def
test_backend_correctness
(
default_vllm_config
,
dist_init
,
batch_spec_name
:
str
,
model
:
str
,
tensor_parallel_size
:
int
,
kv_cache_dtype
:
str
,
):
"""
Test that all backends produce similar outputs to a reference implementation
...
...
@@ -658,9 +687,18 @@ def test_backend_correctness(
head counts.
"""
# Filter backends to those that support the requested kv_cache_dtype
backends_to_test
=
[
b
for
b
in
BACKENDS_TO_TEST
if
kv_cache_dtype
in
b
.
get_class
().
supported_kv_cache_dtypes
]
if
not
backends_to_test
:
pytest
.
skip
(
f
"No backends support kv_cache_dtype=
{
kv_cache_dtype
}
"
)
batch_spec
=
BATCH_SPECS
[
batch_spec_name
]
is_spec_decode_test
=
batch_spec_name
.
startswith
(
"spec_decode"
)
unique_block_sizes
=
sorted
(
set
(
BACKEND_BLOCK_SIZES
.
values
()
))
unique_block_sizes
=
sorted
(
set
(
BACKEND_BLOCK_SIZES
[
b
]
for
b
in
backends_to_test
))
default_block_size
=
unique_block_sizes
[
0
]
required_blocks
=
sum
(
(
seq_len
+
default_block_size
-
1
)
//
default_block_size
...
...
@@ -694,6 +732,7 @@ def test_backend_correctness(
block_size
=
default_block_size
,
hf_config_override
=
hf_config_override
,
)
vllm_config
.
cache_config
.
cache_dtype
=
kv_cache_dtype
# For spec decode tests, add a speculative_config to set the reorder_batch_threshold
if
is_spec_decode_test
:
...
...
@@ -751,7 +790,7 @@ def test_backend_correctness(
kv_b_proj_weight
=
torch
.
cat
([
W_UK
,
W_UV
],
dim
=-
1
)
for
i
,
backend
in
enumerate
(
BACKENDS_TO_TEST
):
for
i
,
backend
in
enumerate
(
backends_to_test
):
all_sdpa_outputs
.
append
([])
for
i
in
range
(
batch_size
):
...
...
@@ -785,7 +824,7 @@ def test_backend_correctness(
# pipeline (MHA-style). This ensures the reference implementation
# matches each backend's actual decode/prefill pipeline path.
is_decode
=
[]
for
backend_idx
,
backend
in
enumerate
(
BACKENDS_TO_TEST
):
for
backend_idx
,
backend
in
enumerate
(
backends_to_test
):
builder_cls
,
_
=
try_get_attention_backend
(
backend
)
if
is_spec_decode_test
:
query_len_support
=
getattr
(
...
...
@@ -885,7 +924,7 @@ def test_backend_correctness(
sdpa_out_i_prefill
=
sdpa_out_i_prefill
.
transpose
(
1
,
2
).
squeeze
(
0
)
sdpa_out_i_prefill
=
sdpa_out_i_prefill
.
flatten
(
start_dim
=-
2
)
for
backend_idx
,
backend
in
enumerate
(
BACKENDS_TO_TEST
):
for
backend_idx
,
backend
in
enumerate
(
backends_to_test
):
if
is_decode
[
backend_idx
]:
all_sdpa_outputs
[
backend_idx
].
append
(
sdpa_out_i_decode
)
else
:
...
...
@@ -905,7 +944,7 @@ def test_backend_correctness(
kv_c_vllm
=
torch
.
cat
(
all_kv_c_vllm
,
dim
=
0
)
k_pe_vllm
=
torch
.
cat
(
all_k_pe_vllm
,
dim
=
0
)
sdpa_outputs
=
{}
for
backend_idx
,
backend
in
enumerate
(
BACKENDS_TO_TEST
):
for
backend_idx
,
backend
in
enumerate
(
backends_to_test
):
sdpa_outputs
[
backend
]
=
torch
.
cat
(
all_sdpa_outputs
[
backend_idx
],
dim
=
0
)
# Create mock kv_b_proj using the same weights as reference implementation
...
...
@@ -973,12 +1012,13 @@ def test_backend_correctness(
num_blocks
=
num_blocks_for_size
,
common_attn_metadata
=
common_attn_metadata
,
randomize_blocks
=
True
,
kv_cache_dtype
=
kv_cache_dtype
,
)
kv_cache_per_block_size
[
block_size
]
=
kv_cache
# 4. Run vLLM backends and compare
failures
=
[]
for
backend_idx
,
backend_name
in
enumerate
(
BACKENDS_TO_TEST
):
for
backend_idx
,
backend_name
in
enumerate
(
backends_to_test
):
# Skip backends that don't support spec decode for spec decode tests
if
is_spec_decode_test
and
backend_name
not
in
SPEC_DECODE_BACKENDS
:
continue
...
...
@@ -997,7 +1037,7 @@ def test_backend_correctness(
head_size
=
vllm_config
.
model_config
.
get_head_size
(),
dtype
=
vllm_config
.
model_config
.
dtype
,
sliding_window
=
vllm_config
.
model_config
.
get_sliding_window
(),
cache_dtype_str
=
vllm_config
.
cache_config
.
cache_dtype
,
cache_dtype_str
=
kv_
cache_dtype
,
)
backend_output
=
run_attention_backend
(
...
...
@@ -1016,6 +1056,7 @@ def test_backend_correctness(
qk_rope_head_dim
,
v_head_dim
,
mock_kv_b_proj
,
kv_cache_dtype
=
kv_cache_dtype
,
)
# Use backend_idx to get the correct SDPA output for this backend
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment