Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
1ecb8be9
Commit
1ecb8be9
authored
Mar 02, 2026
by
xuxz
Browse files
[PD]类与函数数据结构优化参数调整 && 支持小模型pd推理
parent
161789cb
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
59 additions
and
67 deletions
+59
-67
vllm/distributed/kv_transfer/kv_connector/v1/du/du_swift_connector.py
...uted/kv_transfer/kv_connector/v1/du/du_swift_connector.py
+59
-67
No files found.
vllm/distributed/kv_transfer/kv_connector/v1/du/du_swift_connector.py
View file @
1ecb8be9
...
...
@@ -36,28 +36,20 @@ logger = init_logger(__name__)
class
ReqMeta
:
# Request Id
request_id
:
str
# Request
token
s
token
_ids
:
torch
.
Tensor
#
Slot mappings, should have the same length as
token
_id
s
slot_mapping
:
torch
.
Tensor
slot_mapping_device
:
torch
.
Tensor
=
None
# Request
block id
s
block
_ids
:
torch
.
Tensor
#
Request num
tokens
num_tokens
:
int
@
staticmethod
def
make_meta
(
request_id
:
str
,
token_ids
:
list
[
int
],
block_ids
:
list
[
int
],
block_size
:
int
)
->
"ReqMeta"
:
valid_num_tokens
=
len
(
token_ids
)
token_ids_tensor
=
torch
.
tensor
(
token_ids
)
block_ids_tensor
=
torch
.
tensor
(
block_ids
)
num_blocks
=
block_ids_tensor
.
shape
[
0
]
block_offsets
=
torch
.
arange
(
0
,
block_size
)
slot_mapping
=
block_offsets
.
reshape
((
1
,
block_size
))
+
\
block_ids_tensor
.
reshape
((
num_blocks
,
1
))
*
block_size
slot_mapping
=
slot_mapping
.
flatten
()[:
valid_num_tokens
]
return
ReqMeta
(
request_id
=
request_id
,
token_ids
=
token
_ids_tensor
,
slot_mapping
=
slot_mapping
,
block_ids
=
block
_ids_tensor
,
num_tokens
=
len
(
token_ids
)
,
)
...
...
@@ -222,9 +214,9 @@ class DuSwiftConnector(KVConnectorBase_V1):
if
attn_metadata
is
None
:
return
def
inject_kv_into_layer
(
dst_kv_cache_
layer
:
torch
.
Tensor
,
src_
kv_cache
:
torch
.
Tensor
,
s
lo
t_mapping
:
torch
.
Tensor
,
layer
:
torch
.
Tensor
,
kv_cache
:
torch
.
Tensor
,
b
lo
ck_ids
:
torch
.
Tensor
,
request_id
:
str
,
)
->
None
:
"""Inject the KV cache into the layer.
...
...
@@ -240,45 +232,39 @@ class DuSwiftConnector(KVConnectorBase_V1):
[num_tokens].
request_id (str): request id for log
"""
dst_kv_cache_layer_shape
=
dst_kv_cache_layer
.
shape
if
isinstance
(
attn_metadata
,
MLACommonMetadata
)
or
all
(
isinstance
(
value
,
MLACommonMetadata
)
for
value
in
attn_metadata
.
values
()):
num_pages
=
dst_kv_cache_layer_shape
[
0
]
page_size
=
dst_kv_cache_layer_shape
[
1
]
dst_kv_cache_layer
=
dst_kv_cache_layer
.
reshape
(
num_pages
*
page_size
,
-
1
)
self
.
check_tensors_except_dim
(
dst_kv_cache_layer
,
src_kv_cache
,
0
)
num_token
=
src_kv_cache
.
shape
[
0
]
if
len
(
slot_mapping
)
==
num_token
:
dst_kv_cache_layer
[
slot_mapping
,
...]
=
src_kv_cache
if
(
isinstance
(
attn_metadata
,
MLACommonMetadata
)
or
all
(
isinstance
(
value
,
MLACommonMetadata
)
for
value
in
attn_metadata
.
values
())
or
(
not
isinstance
(
layer
,
tuple
))):
# MLA or FlashInfer
num_block
=
kv_cache
.
shape
[
0
]
self
.
check_tensors_except_dim
(
layer
,
kv_cache
,
0
)
if
len
(
block_ids
)
==
num_block
:
layer
[
block_ids
,
...]
=
kv_cache
else
:
dst_kv_cache_layer
[
slot_mapping
[:
num_token
],
...]
=
src_kv_cache
logger
.
warning
(
"🚧src_kv_cache does not match, num_slot:%d, "
"num_token:%d, request_id:%s"
,
len
(
slot_mapping
),
num_token
,
request_id
)
layer
[
block_ids
[:
num_block
],
...]
=
kv_cache
dst_kv_cache_layer
.
reshape
(
dst_kv_cache_layer_shape
)
logger
.
warning
(
"🚧kv_cache does not match, block_ids:%d, "
"num_block:%d, request_id:%s"
,
len
(
block_ids
),
num_block
,
request_id
)
#elif layer.shape[0] == 2: # FlashAttention
else
:
num_pages
=
dst_kv_cache_layer_shape
[
1
]
page_size
=
dst_kv_cache_layer_shape
[
2
]
dst_kv_cache_layer
=
dst_kv_cache_layer
.
reshape
(
2
,
num_pages
*
page_size
,
-
1
)
self
.
check_tensors_except_dim
(
dst_kv_cache_layer
,
src_kv_cache
,
1
)
num_token
=
src_kv_cache
.
shape
[
1
]
if
len
(
slot_mapping
)
==
num_token
:
dst_kv_cache_layer
[:,
slot_mapping
,
...]
=
src_kv_cache
num_block
=
kv_cache
.
shape
[
1
]
#self.check_tensors_except_dim(layer, kv_cache, 1)
if
len
(
block_ids
)
==
num_block
:
#layer[:, block_ids, ...] = kv_cache
k_
=
kv_cache
[
0
].
permute
(
0
,
2
,
1
,
3
)
v_
=
kv_cache
[
1
].
permute
(
0
,
2
,
3
,
1
)
layer
[
0
][
block_ids
,
...]
=
k_
layer
[
1
][
block_ids
,
...]
=
v_
else
:
dst_kv_cache_layer
[:,
slot_mapping
[:
num_token
],
...]
=
src_kv_cache
#layer[:, block_ids[:num_block], ...] = kv_cache
k_
=
kv_cache
[
0
].
permute
(
0
,
2
,
1
,
3
)
v_
=
kv_cache
[
1
].
permute
(
0
,
2
,
3
,
1
)
layer
[
0
][
block_ids
[:
num_block
],
...]
=
k_
layer
[
1
][
block_ids
[:
num_block
],
...]
=
v_
logger
.
warning
(
"🚧src_kv_cache does not match, num_slot:%d, "
"num_token:%d, request_id:%s"
,
len
(
slot_mapping
),
num_token
,
request_id
)
dst_kv_cache_layer
.
reshape
(
dst_kv_cache_layer_shape
)
"🚧kv_cache does not match, block_ids:%d, "
"num_block:%d, request_id:%s"
,
len
(
block_ids
),
num_block
,
request_id
)
# Get the metadata
metadata
:
KVConnectorMetadata
=
\
...
...
@@ -300,18 +286,17 @@ class DuSwiftConnector(KVConnectorBase_V1):
if
kv_cache
is
None
:
continue
kv_cache_layer
=
kv_cache
[
\
forward_context
.
virtual_engine
]
layer
=
kv_cache
[
forward_context
.
virtual_engine
]
kv_cache
=
self
.
du_swift_engine
.
recv_tensor
(
request
.
request_id
+
"#"
+
layer_name
)
if
kv_cache
is
None
:
logger
.
warning
(
"🚧src_kv_cache is None, %s"
,
request
.
request_id
)
logger
.
warning
(
"🚧kv_cache is None, %s"
,
request
.
request_id
)
continue
inject_kv_into_layer
(
kv_cache_layer
,
kv_cache
,
request
.
slot_mapping
,
request
.
request_id
)
inject_kv_into_layer
(
layer
,
kv_cache
,
request
.
block_ids
,
request
.
request_id
)
tensor_id
=
request
.
request_id
+
"#"
+
layer_name
if
tensor_id
in
self
.
du_swift_engine
.
recv_store
:
tensor
=
self
.
du_swift_engine
.
recv_store
.
pop
(
tensor_id
,
None
)
...
...
@@ -359,20 +344,27 @@ class DuSwiftConnector(KVConnectorBase_V1):
def
extract_kv_from_layer
(
layer
:
torch
.
Tensor
,
s
lo
t_mapping
:
torch
.
Tensor
,
b
lo
ck_ids
:
torch
.
Tensor
,
)
->
torch
.
Tensor
:
"""Extract the KV cache from the layer.
Assume the shape of the layer is (2, num_pages, page_size, xxx)
if MLA is not used, and (num_pages, page_size, xxx) otherwise.
"""
if
isinstance
(
attn_metadata
,
MLACommonMetadata
):
num_pages
,
page_size
=
layer
.
shape
[
0
],
layer
.
shape
[
1
]
return
layer
.
reshape
(
num_pages
*
page_size
,
-
1
)[
slot_mapping
,
...]
num_pages
,
page_size
=
layer
.
shape
[
1
],
layer
.
shape
[
2
]
return
layer
.
reshape
(
2
,
num_pages
*
page_size
,
-
1
)[:,
slot_mapping
,
...]
if
(
isinstance
(
attn_metadata
,
MLACommonMetadata
)
or
not
isinstance
(
layer
,
tuple
)):
# MLA or FlashInfer
return
layer
[
block_ids
,
...]
#if layer.shape[0] == 2: # FlashAttention
# return layer[:, block_ids, ...]
else
:
k
=
layer
[
0
]
#(num_blocks, num_kv_heads, block_size, head_size)
v
=
layer
[
1
]
#(num_blocks, num_kv_heads, head_size, block_size)
k
=
k
.
permute
(
0
,
2
,
1
,
3
)
v
=
v
.
permute
(
0
,
3
,
1
,
2
)
kv
=
torch
.
stack
([
k
,
v
],
dim
=
0
).
contiguous
()
return
kv
[:,
block_ids
,
...]
connector_metadata
=
self
.
_get_connector_metadata
()
assert
isinstance
(
connector_metadata
,
DuSwiftConnectorMetadata
)
...
...
@@ -380,7 +372,7 @@ class DuSwiftConnector(KVConnectorBase_V1):
for
request
in
connector_metadata
.
requests
:
request_id
=
request
.
request_id
kv_cache
=
extract_kv_from_layer
(
kv_layer
,
request
.
s
lo
t_mapping
)
kv_cache
=
extract_kv_from_layer
(
kv_layer
,
request
.
b
lo
ck_ids
)
pending
=
False
with
self
.
du_swift_engine
.
req_status_cv
:
if
request_id
not
in
self
.
du_swift_engine
.
req_status
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment