Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
e8ee2a78
Unverified
Commit
e8ee2a78
authored
Apr 24, 2026
by
Jiangyun Zhu
Committed by
GitHub
Apr 24, 2026
Browse files
[Attention] use diff kv backend for mimo v2 flash (#40045)
Signed-off-by:
zjy0516
<
riverclouds.zhu@qq.com
>
parent
2ec18f5d
Changes
8
Show whitespace changes
Inline
Side-by-side
Showing
8 changed files
with
112 additions
and
24 deletions
+112
-24
docs/design/attention_backends.md
docs/design/attention_backends.md
+1
-1
tools/pre_commit/generate_attention_backend_docs.py
tools/pre_commit/generate_attention_backend_docs.py
+41
-8
vllm/model_executor/layers/attention/attention.py
vllm/model_executor/layers/attention/attention.py
+1
-0
vllm/model_executor/models/mimo_v2_flash.py
vllm/model_executor/models/mimo_v2_flash.py
+14
-8
vllm/v1/attention/backends/fa_utils.py
vllm/v1/attention/backends/fa_utils.py
+22
-3
vllm/v1/attention/backends/flash_attn_diffkv.py
vllm/v1/attention/backends/flash_attn_diffkv.py
+18
-4
vllm/v1/kv_cache_interface.py
vllm/v1/kv_cache_interface.py
+14
-0
vllm/vllm_flash_attn/flash_attn_interface.py
vllm/vllm_flash_attn/flash_attn_interface.py
+1
-0
No files found.
docs/design/attention_backends.md
View file @
e8ee2a78
...
@@ -172,7 +172,7 @@ Priority is **1 = highest** (tried first).
...
@@ -172,7 +172,7 @@ Priority is **1 = highest** (tried first).
|
`FLASHINFER`
| TRTLLM† | fp16, bf16 |
`auto`
,
`float16`
,
`bfloat16`
,
`fp8`
,
`fp8_e4m3`
,
`fp8_e5m2`
| 16, 32, 64 | 64, 128, 256 | ✅ | ❌ | ✅ | Decoder | 10.x |
|
`FLASHINFER`
| TRTLLM† | fp16, bf16 |
`auto`
,
`float16`
,
`bfloat16`
,
`fp8`
,
`fp8_e4m3`
,
`fp8_e5m2`
| 16, 32, 64 | 64, 128, 256 | ✅ | ❌ | ✅ | Decoder | 10.x |
|
`FLASH_ATTN`
| FA2
*
| fp16, bf16 |
`auto`
,
`float16`
,
`bfloat16`
| %16 | Any | ❌ | ❌ | ✅ | All | ≥8.0 |
|
`FLASH_ATTN`
| FA2
*
| fp16, bf16 |
`auto`
,
`float16`
,
`bfloat16`
| %16 | Any | ❌ | ❌ | ✅ | All | ≥8.0 |
|
`FLASH_ATTN`
| FA3
*
| fp16, bf16 |
`auto`
,
`float16`
,
`bfloat16`
,
`fp8`
,
`fp8_e4m3`
,
`fp8_e5m2`
| %16 | Any | ✅ | ❌ | ✅ | All | 9.x |
|
`FLASH_ATTN`
| FA3
*
| fp16, bf16 |
`auto`
,
`float16`
,
`bfloat16`
,
`fp8`
,
`fp8_e4m3`
,
`fp8_e5m2`
| %16 | Any | ✅ | ❌ | ✅ | All | 9.x |
|
`FLASH_ATTN`
| FA4
*
| fp16, bf16 |
`auto`
,
`float16`
,
`bfloat16`
| %16 | Any |
❌
| ❌ | ✅ | All | ≥10.0 |
|
`FLASH_ATTN`
| FA4
*
| fp16, bf16 |
`auto`
,
`float16`
,
`bfloat16`
| %16 | Any |
✅
| ❌ | ✅ | All | ≥10.0 |
|
`FLASH_ATTN_DIFFKV`
| | fp16, bf16 |
`auto`
| Any | Any | ❌ | ❌ | ✅ | Decoder | Any |
|
`FLASH_ATTN_DIFFKV`
| | fp16, bf16 |
`auto`
| Any | Any | ❌ | ❌ | ✅ | Decoder | Any |
|
`FLEX_ATTENTION`
| | fp16, bf16, fp32 |
`auto`
,
`float16`
,
`bfloat16`
| Any | Any | ❌ | ✅ | ❌ | Decoder, Encoder Only | Any |
|
`FLEX_ATTENTION`
| | fp16, bf16, fp32 |
`auto`
,
`float16`
,
`bfloat16`
| Any | Any | ❌ | ✅ | ❌ | Decoder, Encoder Only | Any |
|
`ROCM_AITER_FA`
| | fp16, bf16 |
`auto`
,
`float16`
,
`bfloat16`
,
`fp8`
,
`fp8_e4m3`
,
`fp8_e5m2`
| 16, 32 | 64, 128, 256 | ❌ | ❌ | ❌ | Decoder | N/A |
|
`ROCM_AITER_FA`
| | fp16, bf16 |
`auto`
,
`float16`
,
`bfloat16`
,
`fp8`
,
`fp8_e4m3`
,
`fp8_e5m2`
| 16, 32 | 64, 128, 256 | ❌ | ❌ | ❌ | Decoder | N/A |
...
...
tools/pre_commit/generate_attention_backend_docs.py
View file @
e8ee2a78
...
@@ -634,9 +634,10 @@ def parse_flash_attn_features() -> dict[str, dict[str, Any]]:
...
@@ -634,9 +634,10 @@ def parse_flash_attn_features() -> dict[str, dict[str, Any]]:
except
Exception
:
except
Exception
:
return
{}
return
{}
# Analyze the functions to determine FA3-specific features
# Analyze the functions to determine FA3
/FA4
-specific features
fa3_supports_fp8
=
False
fa3_supports_fp8
=
False
fa3_supports_sinks
=
False
fa3_supports_sinks
=
False
fa4_supports_sinks
=
False
fa3_compute_cap
:
str
|
None
=
None
fa3_compute_cap
:
str
|
None
=
None
fa4_compute_cap
:
str
|
None
=
None
fa4_compute_cap
:
str
|
None
=
None
...
@@ -656,17 +657,49 @@ def parse_flash_attn_features() -> dict[str, dict[str, Any]]:
...
@@ -656,17 +657,49 @@ def parse_flash_attn_features() -> dict[str, dict[str, Any]]:
fa3_supports_fp8
=
True
fa3_supports_fp8
=
True
break
break
# Check flash_attn_supports_sinks - looks for `get_flash_attn_version() == 3`
# Check flash_attn_supports_sinks - looks for `fa_version == 3/4`
# or `get_flash_attn_version() == 3/4` (also accepts `in (3, 4)`)
if
node
.
name
==
"flash_attn_supports_sinks"
:
if
node
.
name
==
"flash_attn_supports_sinks"
:
for
n
in
ast
.
walk
(
node
):
for
n
in
ast
.
walk
(
node
):
if
(
if
(
isinstance
(
n
,
ast
.
Compare
)
isinstance
(
n
,
ast
.
Compare
)
and
isinstance
(
n
.
left
,
ast
.
Call
)
and
len
(
n
.
ops
)
==
1
and
isinstance
(
n
.
ops
[
0
],
ast
.
Eq
)
and
isinstance
(
n
.
comparators
[
0
],
ast
.
Constant
)
):
is_version_compare
=
(
isinstance
(
n
.
left
,
ast
.
Name
)
and
n
.
left
.
id
==
"fa_version"
)
or
(
isinstance
(
n
.
left
,
ast
.
Call
)
and
isinstance
(
n
.
left
.
func
,
ast
.
Name
)
and
isinstance
(
n
.
left
.
func
,
ast
.
Name
)
and
n
.
left
.
func
.
id
==
"get_flash_attn_version"
and
n
.
left
.
func
.
id
==
"get_flash_attn_version"
)
if
is_version_compare
:
val
=
n
.
comparators
[
0
].
value
if
val
==
3
:
fa3_supports_sinks
=
True
elif
val
==
4
:
fa4_supports_sinks
=
True
elif
(
isinstance
(
n
,
ast
.
Compare
)
and
len
(
n
.
ops
)
==
1
and
isinstance
(
n
.
ops
[
0
],
ast
.
In
)
and
isinstance
(
n
.
comparators
[
0
],
(
ast
.
Tuple
,
ast
.
List
,
ast
.
Set
))
):
):
is_version_compare
=
(
isinstance
(
n
.
left
,
ast
.
Name
)
and
n
.
left
.
id
==
"fa_version"
)
or
(
isinstance
(
n
.
left
,
ast
.
Call
)
and
isinstance
(
n
.
left
.
func
,
ast
.
Name
)
and
n
.
left
.
func
.
id
==
"get_flash_attn_version"
)
if
is_version_compare
:
for
elt
in
n
.
comparators
[
0
].
elts
:
if
isinstance
(
elt
,
ast
.
Constant
):
if
elt
.
value
==
3
:
fa3_supports_sinks
=
True
fa3_supports_sinks
=
True
break
elif
elt
.
value
==
4
:
fa4_supports_sinks
=
True
# Check get_flash_attn_version for FA3/FA4 compute capability
# Check get_flash_attn_version for FA3/FA4 compute capability
if
node
.
name
==
"get_flash_attn_version"
:
if
node
.
name
==
"get_flash_attn_version"
:
...
@@ -731,7 +764,7 @@ def parse_flash_attn_features() -> dict[str, dict[str, Any]]:
...
@@ -731,7 +764,7 @@ def parse_flash_attn_features() -> dict[str, dict[str, Any]]:
"fa4"
:
{
"fa4"
:
{
"compute_capability"
:
fa4_compute_cap
,
"compute_capability"
:
fa4_compute_cap
,
"supports_fp8"
:
False
,
"supports_fp8"
:
False
,
"supports_sink"
:
False
,
"supports_sink"
:
fa4_supports_sinks
,
},
},
}
}
...
...
vllm/model_executor/layers/attention/attention.py
View file @
e8ee2a78
...
@@ -597,6 +597,7 @@ class Attention(nn.Module, AttentionLayerBase):
...
@@ -597,6 +597,7 @@ class Attention(nn.Module, AttentionLayerBase):
block_size
=
block_size
,
block_size
=
block_size
,
num_kv_heads
=
self
.
num_kv_heads
,
num_kv_heads
=
self
.
num_kv_heads
,
head_size
=
self
.
head_size
,
head_size
=
self
.
head_size
,
head_size_v
=
self
.
head_size_v
,
dtype
=
self
.
kv_cache_torch_dtype
,
dtype
=
self
.
kv_cache_torch_dtype
,
kv_quant_mode
=
quant_mode
,
kv_quant_mode
=
quant_mode
,
sliding_window
=
self
.
sliding_window
,
sliding_window
=
self
.
sliding_window
,
...
...
vllm/model_executor/models/mimo_v2_flash.py
View file @
e8ee2a78
...
@@ -46,6 +46,9 @@ from vllm.model_executor.model_loader.weight_utils import (
...
@@ -46,6 +46,9 @@ from vllm.model_executor.model_loader.weight_utils import (
from
vllm.model_executor.models.utils
import
sequence_parallel_chunk
from
vllm.model_executor.models.utils
import
sequence_parallel_chunk
from
vllm.sequence
import
IntermediateTensors
from
vllm.sequence
import
IntermediateTensors
from
vllm.v1.attention.backend
import
AttentionType
from
vllm.v1.attention.backend
import
AttentionType
from
vllm.v1.attention.backends.flash_attn_diffkv
import
(
FlashAttentionDiffKVBackend
,
)
from
.interfaces
import
MixtureOfExperts
,
SupportsPP
from
.interfaces
import
MixtureOfExperts
,
SupportsPP
from
.utils
import
(
from
.utils
import
(
...
@@ -287,6 +290,15 @@ class MiMoV2Attention(nn.Module):
...
@@ -287,6 +290,15 @@ class MiMoV2Attention(nn.Module):
)
)
sliding_window
=
sliding_window_size
if
sliding_window_size
>
-
1
else
None
sliding_window
=
sliding_window_size
if
sliding_window_size
>
-
1
else
None
# Use DiffKV backend when V has a different head dim than K
if
self
.
v_head_dim
!=
self
.
head_dim
:
FlashAttentionDiffKVBackend
.
set_head_size_v
(
self
.
v_head_dim
)
attn_backend
=
FlashAttentionDiffKVBackend
logger
.
info_once
(
"Using FlashAttentionDiffKVBackend for attention."
)
else
:
attn_backend
=
None
self
.
attn
=
Attention
(
self
.
attn
=
Attention
(
self
.
num_heads
,
self
.
num_heads
,
self
.
head_dim
,
self
.
head_dim
,
...
@@ -298,6 +310,8 @@ class MiMoV2Attention(nn.Module):
...
@@ -298,6 +310,8 @@ class MiMoV2Attention(nn.Module):
attn_type
=
AttentionType
.
DECODER
,
attn_type
=
AttentionType
.
DECODER
,
prefix
=
f
"
{
prefix
}
.attn"
,
prefix
=
f
"
{
prefix
}
.attn"
,
sinks
=
self
.
attention_sink_bias
,
sinks
=
self
.
attention_sink_bias
,
attn_backend
=
attn_backend
,
head_size_v
=
self
.
v_head_dim
,
)
)
def
forward
(
def
forward
(
...
@@ -313,16 +327,8 @@ class MiMoV2Attention(nn.Module):
...
@@ -313,16 +327,8 @@ class MiMoV2Attention(nn.Module):
if
self
.
v_scale
is
not
None
:
if
self
.
v_scale
is
not
None
:
v
=
v
*
self
.
v_scale
v
=
v
*
self
.
v_scale
v
=
v
.
view
(
-
1
,
self
.
num_kv_heads
,
self
.
v_head_dim
)
v
=
torch
.
nn
.
functional
.
pad
(
v
,
[
0
,
self
.
head_dim
-
self
.
v_head_dim
],
value
=
0
)
v
=
v
.
view
(
-
1
,
self
.
num_kv_heads
*
self
.
head_dim
)
attn_output
=
self
.
attn
(
q
,
k
,
v
)
attn_output
=
self
.
attn
(
q
,
k
,
v
)
attn_output
=
attn_output
.
view
(
-
1
,
self
.
num_heads
,
self
.
head_dim
)[
...,
:
self
.
v_head_dim
].
reshape
(
-
1
,
self
.
num_heads
*
self
.
v_head_dim
)
output
,
_
=
self
.
o_proj
(
attn_output
)
output
,
_
=
self
.
o_proj
(
attn_output
)
return
output
return
output
...
...
vllm/v1/attention/backends/fa_utils.py
View file @
e8ee2a78
...
@@ -54,7 +54,10 @@ elif current_platform.is_rocm():
...
@@ -54,7 +54,10 @@ elif current_platform.is_rocm():
def
get_flash_attn_version
(
def
get_flash_attn_version
(
requires_alibi
:
bool
=
False
,
head_size
:
int
|
None
=
None
requires_alibi
:
bool
=
False
,
head_size
:
int
|
None
=
None
,
head_size_v
:
int
|
None
=
None
,
has_sinks
:
bool
=
False
,
)
->
int
|
None
:
)
->
int
|
None
:
if
current_platform
.
is_xpu
():
if
current_platform
.
is_xpu
():
return
2
return
2
...
@@ -112,6 +115,23 @@ def get_flash_attn_version(
...
@@ -112,6 +115,23 @@ def get_flash_attn_version(
)
)
fa_version
=
2
fa_version
=
2
# The FA3 kernel rejects s_aux (sinks) when hdim != hdim_v; upgrade to
# FA4 on SM90 when available.
if
(
fa_version
==
3
and
has_sinks
and
head_size
is
not
None
and
head_size_v
is
not
None
and
head_size
!=
head_size_v
and
device_capability
.
major
==
9
and
is_fa_version_supported
(
4
)
):
logger
.
info_once
(
"Diff-KV with sinks: upgrading FlashAttention 3 -> 4"
,
scope
=
"local"
,
)
fa_version
=
4
# FA4 currently uses batch-shape-dependent scheduling
# FA4 currently uses batch-shape-dependent scheduling
# heuristics on SM100+, which breaks batch invariance.
# heuristics on SM100+, which breaks batch invariance.
if
envs
.
VLLM_BATCH_INVARIANT
and
fa_version
==
4
:
if
envs
.
VLLM_BATCH_INVARIANT
and
fa_version
==
4
:
...
@@ -180,8 +200,7 @@ def flash_attn_supports_quant_query_input() -> bool:
...
@@ -180,8 +200,7 @@ def flash_attn_supports_quant_query_input() -> bool:
def
flash_attn_supports_sinks
()
->
bool
:
def
flash_attn_supports_sinks
()
->
bool
:
if
current_platform
.
is_xpu
():
if
current_platform
.
is_xpu
():
return
True
return
True
else
:
return
get_flash_attn_version
()
in
(
3
,
4
)
return
get_flash_attn_version
()
==
3
def
flash_attn_supports_mla
():
def
flash_attn_supports_mla
():
...
...
vllm/v1/attention/backends/flash_attn_diffkv.py
View file @
e8ee2a78
...
@@ -6,14 +6,16 @@ import torch
...
@@ -6,14 +6,16 @@ import torch
from
vllm.utils.torch_utils
import
is_quantized_kv_cache
from
vllm.utils.torch_utils
import
is_quantized_kv_cache
from
vllm.v1.attention.backend
import
AttentionType
from
vllm.v1.attention.backend
import
AttentionType
from
vllm.v1.attention.backends.fa_utils
import
is_flash_attn_varlen_func_available
from
vllm.v1.attention.backends.fa_utils
import
(
get_flash_attn_version
,
is_flash_attn_varlen_func_available
,
)
from
vllm.v1.attention.ops.triton_reshape_and_cache_flash
import
(
from
vllm.v1.attention.ops.triton_reshape_and_cache_flash
import
(
triton_reshape_and_cache_flash_diffkv
,
triton_reshape_and_cache_flash_diffkv
,
)
)
if
is_flash_attn_varlen_func_available
():
if
is_flash_attn_varlen_func_available
():
from
vllm.v1.attention.backends.fa_utils
import
flash_attn_varlen_func
from
vllm.v1.attention.backends.fa_utils
import
flash_attn_varlen_func
from
vllm.logger
import
init_logger
from
vllm.v1.attention.backends.utils
import
get_kv_cache_layout
from
vllm.v1.attention.backends.utils
import
get_kv_cache_layout
from
.flash_attn
import
(
from
.flash_attn
import
(
...
@@ -23,8 +25,6 @@ from .flash_attn import (
...
@@ -23,8 +25,6 @@ from .flash_attn import (
cascade_attention
,
cascade_attention
,
)
)
logger
=
init_logger
(
__name__
)
class
FlashAttentionDiffKVBackend
(
FlashAttentionBackend
):
class
FlashAttentionDiffKVBackend
(
FlashAttentionBackend
):
# Default to 128 for this backend
# Default to 128 for this backend
...
@@ -86,6 +86,20 @@ class FlashAttentionDiffKVBackend(FlashAttentionBackend):
...
@@ -86,6 +86,20 @@ class FlashAttentionDiffKVBackend(FlashAttentionBackend):
class
FlashAttentionDiffKVImpl
(
FlashAttentionImpl
):
class
FlashAttentionDiffKVImpl
(
FlashAttentionImpl
):
vllm_flash_attn_version
:
int
|
None
def
__init__
(
self
,
*
args
,
**
kwargs
)
->
None
:
super
().
__init__
(
*
args
,
**
kwargs
)
# Re-derive the FA version with diff-kv context so that
# get_flash_attn_version can apply the FA3 -> FA4 upgrade rule
# for sinks + hdim != hdim_v.
self
.
vllm_flash_attn_version
=
get_flash_attn_version
(
requires_alibi
=
self
.
alibi_slopes
is
not
None
,
head_size
=
self
.
head_size
,
head_size_v
=
FlashAttentionDiffKVBackend
.
head_size_v
,
has_sinks
=
self
.
sinks
is
not
None
,
)
def
do_kv_cache_update
(
def
do_kv_cache_update
(
self
,
self
,
layer
:
torch
.
nn
.
Module
,
layer
:
torch
.
nn
.
Module
,
...
...
vllm/v1/kv_cache_interface.py
View file @
e8ee2a78
...
@@ -356,6 +356,20 @@ class ChunkedLocalAttentionSpec(AttentionSpec):
...
@@ -356,6 +356,20 @@ class ChunkedLocalAttentionSpec(AttentionSpec):
@
dataclass
(
frozen
=
True
,
kw_only
=
True
)
@
dataclass
(
frozen
=
True
,
kw_only
=
True
)
class
SlidingWindowSpec
(
AttentionSpec
):
class
SlidingWindowSpec
(
AttentionSpec
):
sliding_window
:
int
sliding_window
:
int
head_size_v
:
int
=
None
# type: ignore[assignment]
def
__post_init__
(
self
):
if
self
.
head_size_v
is
None
:
object
.
__setattr__
(
self
,
"head_size_v"
,
self
.
head_size
)
@
property
def
real_page_size_bytes
(
self
)
->
int
:
return
(
self
.
block_size
*
self
.
num_kv_heads
*
(
self
.
head_size
+
self
.
head_size_v
)
*
get_dtype_size
(
self
.
dtype
)
)
def
max_memory_usage_bytes
(
self
,
vllm_config
:
VllmConfig
)
->
int
:
def
max_memory_usage_bytes
(
self
,
vllm_config
:
VllmConfig
)
->
int
:
assert
vllm_config
.
parallel_config
.
decode_context_parallel_size
==
1
,
(
assert
vllm_config
.
parallel_config
.
decode_context_parallel_size
==
1
,
(
...
...
vllm/vllm_flash_attn/flash_attn_interface.py
View file @
e8ee2a78
...
@@ -387,6 +387,7 @@ def flash_attn_varlen_func(
...
@@ -387,6 +387,7 @@ def flash_attn_varlen_func(
num_splits
=
num_splits
,
num_splits
=
num_splits
,
return_lse
=
return_softmax_lse
,
return_lse
=
return_softmax_lse
,
out
=
out
,
out
=
out
,
learnable_sink
=
s_aux
,
)
)
else
:
else
:
raise
ValueError
(
f
"Unsupported FA version:
{
fa_version
}
"
)
raise
ValueError
(
f
"Unsupported FA version:
{
fa_version
}
"
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment