Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
379689d5
Unverified
Commit
379689d5
authored
Mar 07, 2026
by
Wei Zhao
Committed by
GitHub
Mar 07, 2026
Browse files
[Perf] Support FP8 KV cache for Flashinfer MLA Sparse (#35891)
parent
a6be75db
Changes
8
Show whitespace changes
Inline
Side-by-side
Showing
8 changed files
with
89 additions
and
17 deletions
+89
-17
docs/design/attention_backends.md
docs/design/attention_backends.md
+1
-1
tests/v1/attention/test_mla_backends.py
tests/v1/attention/test_mla_backends.py
+18
-2
tests/v1/attention/test_sparse_mla_backends.py
tests/v1/attention/test_sparse_mla_backends.py
+11
-1
tools/pre_commit/generate_attention_backend_docs.py
tools/pre_commit/generate_attention_backend_docs.py
+15
-1
vllm/model_executor/layers/attention/mla_attention.py
vllm/model_executor/layers/attention/mla_attention.py
+30
-5
vllm/model_executor/models/config.py
vllm/model_executor/models/config.py
+0
-7
vllm/v1/attention/backends/mla/flashinfer_mla_sparse.py
vllm/v1/attention/backends/mla/flashinfer_mla_sparse.py
+7
-0
vllm/v1/attention/backends/mla/flashmla_sparse.py
vllm/v1/attention/backends/mla/flashmla_sparse.py
+7
-0
No files found.
docs/design/attention_backends.md
View file @
379689d5
...
...
@@ -206,7 +206,7 @@ configuration.
|---------|--------|-----------|-------------|------------|------|--------|-----------|-----|-----------------|--------------|
|
`CUTLASS_MLA`
| fp16, bf16 |
`auto`
,
`bfloat16`
,
`fp8`
,
`fp8_e4m3`
| 128 | Any | ❌ | ❌ | ❌ | ✅ | Decoder | 10.x |
|
`FLASHINFER_MLA`
| fp16, bf16 |
`auto`
,
`bfloat16`
,
`fp8`
,
`fp8_e4m3`
| 32, 64 | Any | ❌ | ❌ | ❌ | ❌ | Decoder | 10.x |
|
`FLASHINFER_MLA_SPARSE`
| fp16, bf16 |
`auto`
,
`bfloat16`
| 32, 64 | 576 | ❌ | ✅ | ❌ | ❌ | Decoder | 10.x |
|
`FLASHINFER_MLA_SPARSE`
| fp16, bf16 |
`auto`
,
`bfloat16`
,
`fp8`
,
`fp8_e4m3`
| 32, 64 | 576 | ❌ | ✅ | ❌ | ❌ | Decoder | 10.x |
|
`FLASHMLA`
| fp16, bf16 |
`auto`
,
`bfloat16`
,
`fp8`
,
`fp8_e4m3`
| 64 | Any | ❌ | ❌ | ❌ | ✅ | Decoder | 9.x-10.x |
|
`FLASHMLA_SPARSE`
| bf16 |
`auto`
,
`bfloat16`
,
`fp8_ds_mla`
| 64 | 576 | ❌ | ✅ | ❌ | ❌ | Decoder | 9.x-10.x |
|
`FLASH_ATTN_MLA`
| fp16, bf16 |
`auto`
,
`bfloat16`
| %16 | Any | ❌ | ❌ | ❌ | ✅ | Decoder | 9.x |
...
...
tests/v1/attention/test_mla_backends.py
View file @
379689d5
...
...
@@ -327,6 +327,12 @@ class MockSparseMLAAttentionLayer:
self
.
_k_scale_float
=
1.0
self
.
_v_scale_float
=
1.0
self
.
_decode_concat_quant_fp8_op
=
_DecodeConcatQuantFP8
(
static
=
True
,
group_shape
=
GroupShape
.
PER_TENSOR
,
compile_native
=
True
,
)
def
forward_impl
(
self
,
q
:
torch
.
Tensor
,
...
...
@@ -338,6 +344,7 @@ class MockSparseMLAAttentionLayer:
)
->
torch
.
Tensor
:
"""Forward for sparse MLA - uses forward_mqa for all tokens."""
kv_cache_dtype
=
getattr
(
self
.
impl
,
"kv_cache_dtype"
,
"auto"
)
fp8_attention
=
kv_cache_dtype
.
startswith
(
"fp8"
)
# Write to KV cache
if
kv_cache
.
numel
()
>
0
:
...
...
@@ -350,6 +357,9 @@ class MockSparseMLAAttentionLayer:
scale
=
self
.
_k_scale
,
)
if
fp8_attention
and
kv_cache_dtype
!=
"fp8_ds_mla"
:
kv_cache
=
kv_cache
.
view
(
current_platform
.
fp8_dtype
())
num_tokens
=
q
.
shape
[
0
]
# Sparse MLA uses forward_mqa for all tokens
...
...
@@ -367,7 +377,13 @@ class MockSparseMLAAttentionLayer:
# Convert from (N, B, L) to (B, N, L)
mqa_ql_nope
=
mqa_ql_nope
.
transpose
(
0
,
1
)
# Pass as tuple to forward_mqa
if
fp8_attention
and
self
.
impl
.
supports_quant_query_input
:
assert
mqa_ql_nope
.
shape
[
0
]
==
mqa_q_pe
.
shape
[
0
]
assert
mqa_ql_nope
.
shape
[
1
]
==
mqa_q_pe
.
shape
[
1
]
mqa_q
=
self
.
_decode_concat_quant_fp8_op
(
mqa_ql_nope
,
mqa_q_pe
,
self
.
_q_scale
)
else
:
mqa_q
=
(
mqa_ql_nope
,
mqa_q_pe
)
attn_out
,
_
=
self
.
impl
.
forward_mqa
(
mqa_q
,
kv_cache
,
attn_metadata
,
self
)
...
...
tests/v1/attention/test_sparse_mla_backends.py
View file @
379689d5
...
...
@@ -191,6 +191,16 @@ def test_sparse_backend_decode_correctness(
if
kv_cache_dtype
not
in
backend_cls
.
supported_kv_cache_dtypes
:
pytest
.
skip
(
f
"
{
backend_cls
.
get_name
()
}
does not support
{
kv_cache_dtype
}
"
)
if
(
backend_cls
==
FlashMLASparseBackend
and
kv_cache_dtype
.
startswith
(
"fp8"
)
and
kv_cache_dtype
!=
"fp8_ds_mla"
):
pytest
.
skip
(
"FlashMLA Sparse Attention backend fp8 only supports "
"fp8_ds_mla kv-cache dtype"
)
supported_block_sizes
=
backend_cls
.
get_supported_kernel_block_sizes
()
if
block_size
not
in
supported_block_sizes
:
pytest
.
skip
(
...
...
@@ -419,7 +429,7 @@ def test_sparse_backend_decode_correctness(
num_blocks
=
vllm_config
.
cache_config
.
num_gpu_blocks
,
common_attn_metadata
=
common_attn_metadata
,
randomize_blocks
=
False
,
kv_cache_dtype
=
kv_cache_dtype
if
use_fp8_ds_mla_quantization
else
"auto"
,
kv_cache_dtype
=
kv_cache_dtype
,
scale
=
kv_cache_scale
,
)
...
...
tools/pre_commit/generate_attention_backend_docs.py
View file @
379689d5
...
...
@@ -49,6 +49,11 @@ MLA_ATTENTION_FILE = (
# Backends to skip during doc generation
SKIP_BACKENDS
=
{
"CUSTOM"
,
"TORCH_SDPA"
}
BACKEND_KV_DTYPE_EXCLUDES
:
dict
[
str
,
set
[
str
]]
=
{
# fp8 is an alias for fp8_ds_mla for FlashMLA Sparse
"FLASHMLA_SPARSE"
:
{
"fp8"
},
}
def
is_relevant_file
(
filepath
:
str
)
->
bool
:
"""Check if a file matches any of the relevant patterns."""
...
...
@@ -546,10 +551,19 @@ def analyze_backend(backend_name: str, class_path: str) -> dict[str, Any] | None
tree
,
impl_class_name
,
"can_return_lse_for_decode"
,
False
,
file_path
)
kv_cache_dtypes
=
parse_kv_cache_dtypes
(
class_node
)
if
backend_name
in
BACKEND_KV_DTYPE_EXCLUDES
:
excluded
=
BACKEND_KV_DTYPE_EXCLUDES
[
backend_name
]
kv_cache_dtypes
=
", "
.
join
(
d
for
d
in
(
d
.
strip
()
for
d
in
kv_cache_dtypes
.
split
(
","
))
if
d
not
in
excluded
)
return
{
"name"
:
backend_name
,
"dtypes"
:
parse_supported_dtypes
(
class_node
),
"kv_cache_dtypes"
:
parse_
kv_cache_dtypes
(
class_node
)
,
"kv_cache_dtypes"
:
kv_cache_dtypes
,
"block_sizes"
:
parse_block_sizes
(
class_node
),
"head_sizes"
:
parse_head_sizes
(
class_node
),
"attn_types"
:
parse_attention_types
(
class_node
),
...
...
vllm/model_executor/layers/attention/mla_attention.py
View file @
379689d5
...
...
@@ -331,11 +331,6 @@ class MLAAttention(nn.Module, AttentionLayerBase):
calculate_kv_scales
=
False
self
.
quant_config
=
quant_config
# Initialize KV cache quantization attributes
self
.
kv_cache_dtype
=
kv_cache_dtype
self
.
calculate_kv_scales
=
calculate_kv_scales
_init_kv_cache_quant
(
self
,
quant_config
,
prefix
)
dtype
=
torch
.
get_default_dtype
()
self
.
attn_backend
=
get_attn_backend
(
self
.
head_size
,
...
...
@@ -347,6 +342,36 @@ class MLAAttention(nn.Module, AttentionLayerBase):
num_heads
=
self
.
num_heads
,
)
# FlashMLA Sparse Attention fp8 backend uses "fp8_ds_mla" kv-cache format
# Automatically convert fp8 kv-cache format to "fp8_ds_mla"
if
(
self
.
attn_backend
.
get_name
()
==
"FLASHMLA_SPARSE"
and
kv_cache_dtype
.
startswith
(
"fp8"
)
and
kv_cache_dtype
!=
"fp8_ds_mla"
):
assert
cache_config
is
not
None
cache_config
.
cache_dtype
=
"fp8_ds_mla"
kv_cache_dtype
=
"fp8_ds_mla"
logger
.
info_once
(
"Using DeepSeek's fp8_ds_mla KV cache format. To use standard "
"fp8 kv-cache format, please set `--attention-backend "
"FLASHINFER_MLA_SPARSE`"
)
if
(
self
.
attn_backend
.
get_name
()
==
"FLASHINFER_MLA_SPARSE"
and
kv_cache_dtype
.
startswith
(
"fp8"
)
):
logger
.
info_once
(
"Using standard fp8 KV cache format. To use DeepSeek's fp8_ds_mla "
"KV cache format, please set `--attention-backend FLASHMLA_SPARSE`"
)
# Initialize KV cache quantization attributes
self
.
kv_cache_dtype
=
kv_cache_dtype
self
.
calculate_kv_scales
=
calculate_kv_scales
_init_kv_cache_quant
(
self
,
quant_config
,
prefix
)
if
(
cache_config
is
not
None
and
cache_config
.
enable_prefix_caching
...
...
vllm/model_executor/models/config.py
View file @
379689d5
...
...
@@ -31,20 +31,13 @@ class VerifyAndUpdateConfig:
class
DeepseekV32ForCausalLM
(
VerifyAndUpdateConfig
):
@
classmethod
def
verify_and_update_config
(
cls
,
vllm_config
:
"VllmConfig"
)
->
None
:
"""
Updated fp8 cache to custom "fp8_ds_mla" format for DeepSeekV32
"""
hf_config
=
vllm_config
.
model_config
.
hf_config
# Mirror the check in vllm/model_executor/models/deepseek_v2.py
is_v32
=
hasattr
(
hf_config
,
"index_topk"
)
assert
is_v32
# For DeepSeekV3.2, a custom fp8 format is used when fp8 kv-cache is enabled.
cache_config
=
vllm_config
.
cache_config
if
cache_config
.
cache_dtype
.
startswith
(
"fp8"
):
cache_config
.
cache_dtype
=
"fp8_ds_mla"
logger
.
info
(
"Using custom fp8 kv-cache format for DeepSeekV3.2"
)
if
cache_config
.
cache_dtype
==
"bfloat16"
:
cache_config
.
cache_dtype
=
"auto"
logger
.
info
(
"Using bfloat16 kv-cache for DeepSeekV3.2"
)
...
...
vllm/v1/attention/backends/mla/flashinfer_mla_sparse.py
View file @
379689d5
...
...
@@ -63,6 +63,8 @@ class FlashInferMLASparseBackend(AttentionBackend):
supported_kv_cache_dtypes
:
ClassVar
[
list
[
CacheDType
]]
=
[
"auto"
,
"bfloat16"
,
"fp8"
,
"fp8_e4m3"
,
]
@
staticmethod
...
...
@@ -304,6 +306,11 @@ class FlashInferMLASparseImpl(SparseMLAAttentionImpl[FlashInferMLASparseMetadata
self
.
bmm1_scale
:
float
|
None
=
None
self
.
bmm2_scale
:
float
|
None
=
None
# fp8 query quantization is required when using fp8 kv_cache,
# as the TRTLLM-GEN sparse MLA kernel requires matching dtypes
# for query and kv_cache (mixed bf16+fp8 is not supported).
self
.
supports_quant_query_input
=
True
def
forward_mqa
(
self
,
q
:
torch
.
Tensor
|
tuple
[
torch
.
Tensor
,
torch
.
Tensor
],
...
...
vllm/v1/attention/backends/mla/flashmla_sparse.py
View file @
379689d5
...
...
@@ -83,6 +83,7 @@ class FlashMLASparseBackend(AttentionBackend):
"auto"
,
"bfloat16"
,
"fp8_ds_mla"
,
"fp8"
,
# alias for fp8_ds_mla
]
@
staticmethod
...
...
@@ -567,6 +568,12 @@ class FlashMLASparseImpl(SparseMLAAttentionImpl[FlashMLASparseMetadata]):
)
self
.
fp8_decode_padded_heads
=
self
.
_compute_fp8_decode_padded_heads
(
num_heads
)
if
kv_cache_dtype
.
startswith
(
"fp8"
):
assert
kv_cache_dtype
==
"fp8_ds_mla"
,
(
"FlashMLA Sparse Attention backend fp8 only supports "
"fp8_ds_mla kv-cache dtype"
)
if
kv_cache_dtype
==
"fp8_ds_mla"
:
# Reserve workspace during initialization
vllm_config
=
get_current_vllm_config
()
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment