Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
bcb2ba6c
Commit
bcb2ba6c
authored
Apr 08, 2026
by
xiabo
Committed by
zhangzbb
Apr 08, 2026
Browse files
[FEATURE] DuSwiftConnector support glm5 model PD (attention sparse_attn_indexer layer_name change )
parent
a05d749e
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
12 additions
and
8 deletions
+12
-8
vllm/distributed/kv_transfer/kv_connector/v1/du/du_swift_connector.py
...uted/kv_transfer/kv_connector/v1/du/du_swift_connector.py
+3
-3
vllm/model_executor/layers/sparse_attn_indexer.py
vllm/model_executor/layers/sparse_attn_indexer.py
+9
-5
No files found.
vllm/distributed/kv_transfer/kv_connector/v1/du/du_swift_connector.py
View file @
bcb2ba6c
...
@@ -209,7 +209,7 @@ class DuSwiftConnector(KVConnectorBase_V1):
...
@@ -209,7 +209,7 @@ class DuSwiftConnector(KVConnectorBase_V1):
request_id (str): request id for log
request_id (str): request id for log
"""
"""
dst_kv_cache_layer_shape
=
dst_kv_cache_layer
.
shape
dst_kv_cache_layer_shape
=
dst_kv_cache_layer
.
shape
if
isinstance
(
attn_metadata
,
MLACommonMetadata
)
or
all
(
isinstance
(
value
,
MLACommonMetadata
)
for
value
in
attn_metadata
.
values
()):
if
isinstance
(
attn_metadata
,
MLACommonMetadata
)
or
all
(
isinstance
(
value
,
MLACommonMetadata
)
for
value
in
attn_metadata
.
values
())
or
dst_kv_cache_layer
.
ndim
==
3
:
num_pages
=
dst_kv_cache_layer_shape
[
0
]
num_pages
=
dst_kv_cache_layer_shape
[
0
]
page_size
=
dst_kv_cache_layer_shape
[
1
]
page_size
=
dst_kv_cache_layer_shape
[
1
]
dst_kv_cache_layer
=
dst_kv_cache_layer
.
reshape
(
dst_kv_cache_layer
=
dst_kv_cache_layer
.
reshape
(
...
@@ -379,7 +379,7 @@ class DuSwiftConnector(KVConnectorBase_V1):
...
@@ -379,7 +379,7 @@ class DuSwiftConnector(KVConnectorBase_V1):
assert
self
.
du_swift_engine
is
not
None
assert
self
.
du_swift_engine
is
not
None
is_mla
=
isinstance
(
attn_metadata
,
MLACommonMetadata
)
is_mla
=
isinstance
(
attn_metadata
,
MLACommonMetadata
)
or
kv_layer
.
ndim
==
3
def
extract_kv_from_layer
(
def
extract_kv_from_layer
(
layer
:
torch
.
Tensor
,
layer
:
torch
.
Tensor
,
...
@@ -390,7 +390,7 @@ class DuSwiftConnector(KVConnectorBase_V1):
...
@@ -390,7 +390,7 @@ class DuSwiftConnector(KVConnectorBase_V1):
Assume the shape of the layer is (2, num_pages, page_size, xxx)
Assume the shape of the layer is (2, num_pages, page_size, xxx)
if MLA is not used, and (num_pages, page_size, xxx) otherwise.
if MLA is not used, and (num_pages, page_size, xxx) otherwise.
"""
"""
if
isinstance
(
attn_metadata
,
MLACommonMetadata
):
if
isinstance
(
attn_metadata
,
MLACommonMetadata
)
or
layer
.
ndim
==
3
:
num_pages
,
page_size
=
layer
.
shape
[
0
],
layer
.
shape
[
1
]
num_pages
,
page_size
=
layer
.
shape
[
0
],
layer
.
shape
[
1
]
return
layer
.
reshape
(
num_pages
*
page_size
,
-
1
)[
slot_mapping
,
return
layer
.
reshape
(
num_pages
*
page_size
,
-
1
)[
slot_mapping
,
...]
...]
...
...
vllm/model_executor/layers/sparse_attn_indexer.py
View file @
bcb2ba6c
...
@@ -20,6 +20,10 @@ from vllm.v1.attention.ops.rocm_aiter_mla_sparse import indexer_k_bf16_cache_tri
...
@@ -20,6 +20,10 @@ from vllm.v1.attention.ops.rocm_aiter_mla_sparse import indexer_k_bf16_cache_tri
from
vllm.v1.worker.workspace
import
current_workspace_manager
from
vllm.v1.worker.workspace
import
current_workspace_manager
from
lightop
import
op
,
gemmopt
from
lightop
import
op
,
gemmopt
from
vllm.attention.utils.kv_transfer_utils
import
(
maybe_transfer_kv_layer
,
)
if
current_platform
.
is_cuda_alike
():
if
current_platform
.
is_cuda_alike
():
from
vllm
import
_custom_ops
as
ops
from
vllm
import
_custom_ops
as
ops
elif
current_platform
.
is_xpu
():
elif
current_platform
.
is_xpu
():
...
@@ -27,10 +31,10 @@ elif current_platform.is_xpu():
...
@@ -27,10 +31,10 @@ elif current_platform.is_xpu():
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
@
maybe_transfer_kv_layer
def
sparse_attn_indexer
(
def
sparse_attn_indexer
(
hidden_states
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
k_cache_prefix
:
str
,
layer_name
:
str
,
kv_cache
:
torch
.
Tensor
,
kv_cache
:
torch
.
Tensor
,
q_fp8
:
torch
.
Tensor
,
q_fp8
:
torch
.
Tensor
,
k
:
torch
.
Tensor
,
k
:
torch
.
Tensor
,
...
@@ -56,7 +60,7 @@ def sparse_attn_indexer(
...
@@ -56,7 +60,7 @@ def sparse_attn_indexer(
)
)
return
sparse_attn_indexer_fake
(
return
sparse_attn_indexer_fake
(
hidden_states
,
hidden_states
,
k_cache_prefix
,
layer_name
,
kv_cache
,
kv_cache
,
q_fp8
,
q_fp8
,
k
,
k
,
...
@@ -69,7 +73,7 @@ def sparse_attn_indexer(
...
@@ -69,7 +73,7 @@ def sparse_attn_indexer(
total_seq_lens
,
total_seq_lens
,
topk_indices_buffer
,
topk_indices_buffer
,
)
)
attn_metadata
=
attn_metadata
[
k_cache_prefix
]
attn_metadata
=
attn_metadata
[
layer_name
]
assert
isinstance
(
attn_metadata
,
DeepseekV32IndexerMetadata
)
assert
isinstance
(
attn_metadata
,
DeepseekV32IndexerMetadata
)
slot_mapping
=
attn_metadata
.
slot_mapping
slot_mapping
=
attn_metadata
.
slot_mapping
has_decode
=
attn_metadata
.
num_decodes
>
0
has_decode
=
attn_metadata
.
num_decodes
>
0
...
@@ -282,7 +286,7 @@ def sparse_attn_indexer(
...
@@ -282,7 +286,7 @@ def sparse_attn_indexer(
def
sparse_attn_indexer_fake
(
def
sparse_attn_indexer_fake
(
hidden_states
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
k_cache_prefix
:
str
,
layer_name
:
str
,
kv_cache
:
torch
.
Tensor
,
kv_cache
:
torch
.
Tensor
,
q_fp8
:
torch
.
Tensor
,
q_fp8
:
torch
.
Tensor
,
k
:
torch
.
Tensor
,
k
:
torch
.
Tensor
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment