Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
558f0907
Unverified
Commit
558f0907
authored
Sep 07, 2025
by
youkaichao
Committed by
GitHub
Sep 07, 2025
Browse files
[attention][DCP] use AttentionImpl.need_to_return_lse_for_decode (#24372)
Signed-off-by:
youkaichao
<
youkaichao@gmail.com
>
parent
4172235a
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
38 additions
and
9 deletions
+38
-9
vllm/attention/backends/abstract.py
vllm/attention/backends/abstract.py
+26
-0
vllm/v1/attention/backends/mla/common.py
vllm/v1/attention/backends/mla/common.py
+0
-4
vllm/v1/attention/backends/mla/flashmla.py
vllm/v1/attention/backends/mla/flashmla.py
+2
-0
vllm/v1/worker/gpu_model_runner.py
vllm/v1/worker/gpu_model_runner.py
+10
-5
No files found.
vllm/attention/backends/abstract.py
View file @
558f0907
...
@@ -257,6 +257,32 @@ class AttentionLayer(Protocol):
...
@@ -257,6 +257,32 @@ class AttentionLayer(Protocol):
class
AttentionImpl
(
ABC
,
Generic
[
T
]):
class
AttentionImpl
(
ABC
,
Generic
[
T
]):
# Whether the attention impl can return the softmax lse for decode.
# Some features like decode context parallelism require the softmax lse.
can_return_lse_for_decode
:
bool
=
False
# some attention backends might not always want to return lse
# even if they can return lse (for efficiency reasons)
need_to_return_lse_for_decode
:
bool
=
False
dcp_world_size
:
int
dcp_rank
:
int
def
__new__
(
cls
,
*
args
,
**
kwargs
):
# use __new__ so that all subclasses will call this
self
=
super
().
__new__
(
cls
)
try
:
from
vllm.distributed.parallel_state
import
get_dcp_group
self
.
dcp_world_size
=
get_dcp_group
().
world_size
self
.
dcp_rank
=
get_dcp_group
().
rank_in_group
except
AssertionError
:
# DCP might not be initialized in testing
self
.
dcp_world_size
=
1
self
.
dcp_rank
=
0
self
.
need_to_return_lse_for_decode
=
self
.
dcp_world_size
>
1
\
and
self
.
can_return_lse_for_decode
return
self
@
abstractmethod
@
abstractmethod
def
__init__
(
def
__init__
(
self
,
self
,
...
...
vllm/v1/attention/backends/mla/common.py
View file @
558f0907
...
@@ -1592,10 +1592,6 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
...
@@ -1592,10 +1592,6 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
# recorect dcp attn_out with lse.
# recorect dcp attn_out with lse.
if
self
.
dcp_world_size
>
1
:
if
self
.
dcp_world_size
>
1
:
assert
lse
is
not
None
,
(
"For a mla backend want to enable"
"DCP, it is mandatory that the corresponding decode attn"
"kernel return the softmax lse."
)
attn_out
=
cp_lse_ag_out_rs
(
attn_out
,
lse
,
get_dcp_group
())
attn_out
=
cp_lse_ag_out_rs
(
attn_out
,
lse
,
get_dcp_group
())
# v_up projection
# v_up projection
...
...
vllm/v1/attention/backends/mla/flashmla.py
View file @
558f0907
...
@@ -133,6 +133,8 @@ class FlashMLAMetadataBuilder(MLACommonMetadataBuilder[FlashMLAMetadata]):
...
@@ -133,6 +133,8 @@ class FlashMLAMetadataBuilder(MLACommonMetadataBuilder[FlashMLAMetadata]):
class
FlashMLAImpl
(
MLACommonImpl
[
FlashMLAMetadata
]):
class
FlashMLAImpl
(
MLACommonImpl
[
FlashMLAMetadata
]):
can_return_lse_for_decode
:
bool
=
True
def
__init__
(
def
__init__
(
self
,
self
,
num_heads
:
int
,
num_heads
:
int
,
...
...
vllm/v1/worker/gpu_model_runner.py
View file @
558f0907
...
@@ -56,7 +56,6 @@ from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler,
...
@@ -56,7 +56,6 @@ from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler,
GiB_bytes
,
LazyLoader
,
cdiv
,
check_use_alibi
,
GiB_bytes
,
LazyLoader
,
cdiv
,
check_use_alibi
,
get_dtype_size
,
is_pin_memory_available
,
round_up
,
get_dtype_size
,
is_pin_memory_available
,
round_up
,
supports_dynamo
)
supports_dynamo
)
from
vllm.v1.attention.backends.mla.flashmla
import
FlashMLABackend
from
vllm.v1.attention.backends.utils
import
(
from
vllm.v1.attention.backends.utils
import
(
AttentionCGSupport
,
AttentionMetadataBuilder
,
CommonAttentionMetadata
,
AttentionCGSupport
,
AttentionMetadataBuilder
,
CommonAttentionMetadata
,
create_fast_prefill_custom_backend
,
create_fast_prefill_custom_backend
,
...
@@ -3405,10 +3404,16 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
...
@@ -3405,10 +3404,16 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
copy_kv_blocks
)
copy_kv_blocks
)
if
self
.
dcp_world_size
>
1
:
if
self
.
dcp_world_size
>
1
:
assert
self
.
attn_groups
[
0
][
0
].
backend
is
FlashMLABackend
,
(
layer_names
=
self
.
attn_groups
[
0
][
0
].
layer_names
"DCP only support flashmla now."
layers
=
get_layers_from_vllm_config
(
self
.
vllm_config
,
"For a mla backend want to enable DCP, it is mandatory that the"
AttentionLayerBase
,
"corresponding decode attn kernel return the softmax lse."
)
layer_names
)
for
layer
in
layers
.
values
():
assert
layer
.
impl
.
need_to_return_lse_for_decode
,
(
"DCP requires attention impls to return"
" the softmax lse for decode, but the impl "
f
"
{
layer
.
impl
.
__class__
.
__name__
}
"
"does not return the softmax lse for decode."
)
def
may_add_encoder_only_layers_to_kv_cache_config
(
self
)
->
None
:
def
may_add_encoder_only_layers_to_kv_cache_config
(
self
)
->
None
:
"""
"""
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment