Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
b98167fc
Commit
b98167fc
authored
Apr 09, 2026
by
xiabo
Committed by
zhangzbb
Apr 10, 2026
Browse files
同步官方v0.15.1的kvcache的处理方式。可以参照官方发pr:
https://github.com/vllm-project/vllm/pull/23536/changes
parent
ce47a56e
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
78 additions
and
78 deletions
+78
-78
vllm/distributed/kv_transfer/kv_connector/v1/du/du_swift_connector.py
...uted/kv_transfer/kv_connector/v1/du/du_swift_connector.py
+78
-78
No files found.
vllm/distributed/kv_transfer/kv_connector/v1/du/du_swift_connector.py
View file @
b98167fc
...
...
@@ -34,28 +34,20 @@ logger = init_logger(__name__)
class
ReqMeta
:
# Request Id
request_id
:
str
# Request tokens
token_ids
:
torch
.
Tensor
# Slot mappings, should have the same length as token_ids
slot_mapping
:
torch
.
Tensor
slot_mapping_device
:
torch
.
Tensor
=
None
# Request block ids
block_ids
:
torch
.
Tensor
# Request num tokens
num_tokens
:
int
@
staticmethod
def
make_meta
(
request_id
:
str
,
token_ids
:
list
[
int
],
block_ids
:
list
[
int
],
block_size
:
int
)
->
"ReqMeta"
:
valid_num_tokens
=
len
(
token_ids
)
token_ids_tensor
=
torch
.
tensor
(
token_ids
)
def
make_meta
(
request_id
:
str
,
token_ids
:
list
[
int
],
block_ids
:
list
[
int
],
block_size
:
int
)
->
"ReqMeta"
:
block_ids_tensor
=
torch
.
tensor
(
block_ids
)
num_blocks
=
block_ids_tensor
.
shape
[
0
]
block_offsets
=
torch
.
arange
(
0
,
block_size
)
slot_mapping
=
block_offsets
.
reshape
((
1
,
block_size
))
+
\
block_ids_tensor
.
reshape
((
num_blocks
,
1
))
*
block_size
slot_mapping
=
slot_mapping
.
flatten
()[:
valid_num_tokens
]
return
ReqMeta
(
request_id
=
request_id
,
token_ids
=
token
_ids_tensor
,
slot_mapping
=
slot_mapping
,
block_ids
=
block
_ids_tensor
,
num_tokens
=
len
(
token_ids
)
,
)
...
...
@@ -74,7 +66,8 @@ class DuSwiftConnectorMetadata(KVConnectorMetadata):
block_size
:
int
,
)
->
None
:
self
.
requests
.
append
(
ReqMeta
.
make_meta
(
request_id
,
token_ids
,
block_ids
,
block_size
))
ReqMeta
.
make_meta
(
request_id
,
token_ids
,
block_ids
,
block_size
)
)
class
DuSwiftConnector
(
KVConnectorBase_V1
):
...
...
@@ -190,63 +183,62 @@ class DuSwiftConnector(KVConnectorBase_V1):
if
attn_metadata
is
None
:
return
def
inject_kv_into_layer
(
dst_kv_cache_
layer
:
torch
.
Tensor
,
src_
kv_cache
:
torch
.
Tensor
,
s
lo
t_mapping
:
torch
.
Tensor
,
layer
:
torch
.
Tensor
,
kv_cache
:
torch
.
Tensor
,
b
lo
ck_ids
:
torch
.
Tensor
,
request_id
:
str
,
)
->
None
:
"""Inject the KV cache into the layer.
"""
Inject KV cache data into a given attention layer tensor.
This function updates `layer` in-place with values from `kv_cache`,
handling different backend layouts:
- MLA (Multi-head Latent Attention) or FlashInfer: KV tensors are
indexed along the first dimension.
- FlashAttention: KV tensors are indexed along the second
dimension.
If the number of provided block IDs does not match the number of KV
blocks, only the overlapping portion is updated, and a warning is
logged.
Args:
dst_kv_cache_layer (torch.Tensor): the destination KV cache
layer. In shape [2, num_pages, page_size, xxx] if not
using MLA, [num_pages, page_size, xxx] otherwise.
src_kv_cache (torch.Tensor): the source KV cache. In shape
[2, num_tokens, xxx] if not using MLA, [num_tokens, xxx]
otherwise.
slot_mapping (torch.Tensor): the slot mapping. In shape
[num_tokens].
request_id (str): request id for log
layer (torch.Tensor): The attention layer KV tensor to update.
kv_cache (torch.Tensor): The KV cache tensor to inject.
block_ids (torch.Tensor): Indices of the blocks to update.
request_id (str): Request identifier used for logging.
Returns:
None. The function modifies `layer` in-place.
"""
dst_kv_cache_layer_shape
=
dst_kv_cache_layer
.
shape
if
isinstance
(
attn_metadata
,
MLACommonMetadata
)
or
all
(
isinstance
(
value
,
MLACommonMetadata
)
for
value
in
attn_metadata
.
values
())
or
dst_kv_cache_layer
.
ndim
==
3
:
num_pages
=
dst_kv_cache_layer_shape
[
0
]
page_size
=
dst_kv_cache_layer_shape
[
1
]
dst_kv_cache_layer
=
dst_kv_cache_layer
.
reshape
(
num_pages
*
page_size
,
-
1
)
self
.
check_tensors_except_dim
(
dst_kv_cache_layer
,
src_kv_cache
,
0
)
num_token
=
src_kv_cache
.
shape
[
0
]
if
len
(
slot_mapping
)
==
num_token
:
dst_kv_cache_layer
[
slot_mapping
,
...]
=
src_kv_cache
if
isinstance
(
attn_metadata
,
MLACommonMetadata
)
or
all
(
isinstance
(
value
,
MLACommonMetadata
)
for
value
in
attn_metadata
.
values
())
or
layer
.
ndim
==
3
:
num_block
=
kv_cache
.
shape
[
0
]
self
.
check_tensors_except_dim
(
layer
,
kv_cache
,
0
)
if
len
(
block_ids
)
==
num_block
:
layer
[
block_ids
,
...]
=
kv_cache
else
:
dst_kv_cache_layer
[
slot_mapping
[:
num_token
],
...]
=
src_kv_cache
layer
[
block_ids
[:
num_block
],
...]
=
kv_cache
logger
.
warning
(
"🚧src_kv_cache does not match, num_slot:%d, "
"num_token:%d, request_id:%s"
,
len
(
slot_mapping
),
num_token
,
request_id
)
dst_kv_cache_layer
.
reshape
(
dst_kv_cache_layer_shape
)
"🚧kv_cache does not match, block_ids:%d, "
"num_block:%d, request_id:%s"
,
len
(
block_ids
),
num_block
,
request_id
,
)
else
:
num_pages
=
dst_kv_cache_layer_shape
[
1
]
page_size
=
dst_kv_cache_layer_shape
[
2
]
dst_kv_cache_layer
=
dst_kv_cache_layer
.
reshape
(
2
,
num_pages
*
page_size
,
-
1
)
self
.
check_tensors_except_dim
(
dst_kv_cache_layer
,
src_kv_cache
,
1
)
num_token
=
src_kv_cache
.
shape
[
1
]
if
len
(
slot_mapping
)
==
num_token
:
dst_kv_cache_layer
[:,
slot_mapping
,
...]
=
src_kv_cache
num_block
=
kv_cache
.
shape
[
1
]
self
.
check_tensors_except_dim
(
layer
,
kv_cache
,
1
)
if
len
(
block_ids
)
==
num_block
:
layer
[:,
block_ids
,
...]
=
kv_cache
else
:
dst_kv_cache_layer
[:,
slot_mapping
[:
num_token
],
...]
=
src_kv_cache
layer
[:,
block_ids
[:
num_block
],
...]
=
kv_cache
logger
.
warning
(
"🚧src_kv_cache does not match, num_slot:%d, "
"num_token:%d, request_id:%s"
,
len
(
slot_mapping
),
num_token
,
request_id
)
dst_kv_cache_layer
.
reshape
(
dst_kv_cache_layer_shape
)
"🚧kv_cache does not match, block_ids:%d, "
"num_block:%d, request_id:%s"
,
len
(
block_ids
),
num_block
,
request_id
,
)
# Get the metadata
metadata
:
KVConnectorMetadata
=
\
...
...
@@ -280,7 +272,7 @@ class DuSwiftConnector(KVConnectorBase_V1):
request
.
request_id
)
continue
inject_kv_into_layer
(
kv_cache_layer
,
kv_cache
,
request
.
s
lo
t_mapping
,
request
.
request_id
)
request
.
b
lo
ck_ids
,
request
.
request_id
)
tensor_id
=
request
.
request_id
+
"#"
+
layer_name
if
tensor_id
in
self
.
du_swift_engine
.
recv_store
:
tensor
=
self
.
du_swift_engine
.
recv_store
.
pop
(
tensor_id
,
None
)
...
...
@@ -383,20 +375,28 @@ class DuSwiftConnector(KVConnectorBase_V1):
def
extract_kv_from_layer
(
layer
:
torch
.
Tensor
,
s
lo
t_mapping
:
torch
.
Tensor
,
b
lo
ck_ids
:
torch
.
Tensor
,
)
->
torch
.
Tensor
:
"""Extract the KV cache from the layer.
"""
Extract KV cache slices from a given attention layer tensor.
This function handles multiple backend layouts:
- MLA (Multi-head Latent Attention) or FlashInfer: KV tensors are
indexed along the first dimension.
- FlashAttention: KV tensors are indexed along the second
dimension.
Args:
layer (torch.Tensor): The KV cache from the attention layer.
block_ids (torch.Tensor): Indices of blocks to extract.
Assume the shape of the layer is (2, num_pages, page_size, xxx)
if MLA is not used, and (num_pages, page_size, xxx) otherwise.
"""
Returns:
torch.Tensor: A tensor containing the extracted KV slices.
Both layouts handled here return a tensor; there is no None return path.
"""
if
isinstance
(
attn_metadata
,
MLACommonMetadata
)
or
layer
.
ndim
==
3
:
num_pages
,
page_size
=
layer
.
shape
[
0
],
layer
.
shape
[
1
]
return
layer
.
reshape
(
num_pages
*
page_size
,
-
1
)[
slot_mapping
,
...]
num_pages
,
page_size
=
layer
.
shape
[
1
],
layer
.
shape
[
2
]
return
layer
.
reshape
(
2
,
num_pages
*
page_size
,
-
1
)[:,
slot_mapping
,
...]
return
layer
[
block_ids
,
...]
return
layer
[:,
block_ids
,
...]
connector_metadata
=
self
.
_get_connector_metadata
()
assert
isinstance
(
connector_metadata
,
DuSwiftConnectorMetadata
)
...
...
@@ -443,7 +443,7 @@ class DuSwiftConnector(KVConnectorBase_V1):
p_ip
,
p_port
=
self
.
parse_request_id
(
request_id
,
False
)
remote_address
=
ip
+
":"
+
str
(
port
+
self
.
_rank
)
# pd_pair_id = p_ip + ":" + p_port + "_" + ip + ":" + port
kv_cache
=
extract_kv_from_layer
(
kv_layer
,
request
.
s
lo
t_mapping
)
kv_cache
=
extract_kv_from_layer
(
kv_layer
,
request
.
b
lo
ck_ids
)
pp_rank
=
(
self
.
parallel_config
.
rank
//
self
.
parallel_config
.
tensor_parallel_size
)
%
self
.
parallel_config
.
pipeline_parallel_size
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment