Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
2241085d
"tests/vscode:/vscode.git/clone" did not exist on "fb4e8bf442c53a211d297d31f0381f16c40b1240"
Commit
2241085d
authored
Dec 24, 2025
by
Your Name
Browse files
[PD][Feat]支持fa_pa kvcahe类型模型推理
parent
1871c26c
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
26 additions
and
9 deletions
+26
-9
vllm/distributed/kv_transfer/kv_connector/v1/p2p/p2p_nccl_connector.py
...ted/kv_transfer/kv_connector/v1/p2p/p2p_nccl_connector.py
+26
-9
No files found.
vllm/distributed/kv_transfer/kv_connector/v1/p2p/p2p_nccl_connector.py
View file @
2241085d
...
@@ -181,25 +181,34 @@ class P2pNcclConnector(KVConnectorBase_V1):
...
@@ -181,25 +181,34 @@ class P2pNcclConnector(KVConnectorBase_V1):
None. The function modifies `layer` in-place.
None. The function modifies `layer` in-place.
"""
"""
if
(
isinstance
(
attn_metadata
,
MLACommonMetadata
)
or
all
(
isinstance
(
value
,
MLACommonMetadata
)
for
value
in
attn_metadata
.
values
())
if
(
isinstance
(
attn_metadata
,
MLACommonMetadata
)
or
all
(
isinstance
(
value
,
MLACommonMetadata
)
for
value
in
attn_metadata
.
values
())
or
layer
.
shape
[
1
]
==
2
):
# MLA or FlashInfer
or
(
not
isinstance
(
layer
,
tuple
))
):
# MLA or FlashInfer
num_block
=
kv_cache
.
shape
[
0
]
num_block
=
kv_cache
.
shape
[
0
]
self
.
check_tensors_except_dim
(
layer
,
kv_cache
,
0
)
self
.
check_tensors_except_dim
(
layer
,
kv_cache
,
0
)
if
len
(
block_ids
)
==
num_block
:
if
len
(
block_ids
)
==
num_block
:
layer
[
block_ids
,
...]
=
kv_cache
layer
[
block_ids
,
...]
=
kv_cache
else
:
else
:
layer
[
block_ids
[:
num_block
],
...]
=
kv_cache
layer
[
block_ids
[:
num_block
],
...]
=
kv_cache
logger
.
warning
(
logger
.
warning
(
"🚧kv_cache does not match, block_ids:%d, "
"🚧kv_cache does not match, block_ids:%d, "
"num_block:%d, request_id:%s"
,
len
(
block_ids
),
"num_block:%d, request_id:%s"
,
len
(
block_ids
),
num_block
,
request_id
)
num_block
,
request_id
)
#elif layer.shape[0] == 2: # FlashAttention
el
if
layer
.
shape
[
0
]
==
2
:
# FlashAttention
el
se
:
num_block
=
kv_cache
.
shape
[
1
]
num_block
=
kv_cache
.
shape
[
1
]
self
.
check_tensors_except_dim
(
layer
,
kv_cache
,
1
)
#
self.check_tensors_except_dim(layer, kv_cache, 1)
if
len
(
block_ids
)
==
num_block
:
if
len
(
block_ids
)
==
num_block
:
layer
[:,
block_ids
,
...]
=
kv_cache
#layer[:, block_ids, ...] = kv_cache
k_
=
kv_cache
[
0
].
permute
(
0
,
2
,
1
,
3
)
v_
=
kv_cache
[
1
].
permute
(
0
,
2
,
3
,
1
)
layer
[
0
][
block_ids
,
...]
=
k_
layer
[
1
][
block_ids
,
...]
=
v_
else
:
else
:
layer
[:,
block_ids
[:
num_block
],
...]
=
kv_cache
#layer[:, block_ids[:num_block], ...] = kv_cache
k_
=
kv_cache
[
0
].
permute
(
0
,
2
,
1
,
3
)
v_
=
kv_cache
[
1
].
permute
(
0
,
2
,
3
,
1
)
layer
[
0
][
block_ids
[:
num_block
],
...]
=
k_
layer
[
1
][
block_ids
[:
num_block
],
...]
=
v_
logger
.
warning
(
logger
.
warning
(
"🚧kv_cache does not match, block_ids:%d, "
"🚧kv_cache does not match, block_ids:%d, "
"num_block:%d, request_id:%s"
,
len
(
block_ids
),
"num_block:%d, request_id:%s"
,
len
(
block_ids
),
...
@@ -304,11 +313,19 @@ class P2pNcclConnector(KVConnectorBase_V1):
...
@@ -304,11 +313,19 @@ class P2pNcclConnector(KVConnectorBase_V1):
Returns None if the layout is unsupported.
Returns None if the layout is unsupported.
"""
"""
if
(
isinstance
(
attn_metadata
,
MLACommonMetadata
)
if
(
isinstance
(
attn_metadata
,
MLACommonMetadata
)
or
layer
.
shape
[
1
]
==
2
):
# MLA or FlashInfer
or
not
isinstance
(
layer
,
tuple
)
):
# MLA or FlashInfer
return
layer
[
block_ids
,
...]
return
layer
[
block_ids
,
...]
if
layer
.
shape
[
0
]
==
2
:
# FlashAttention
#if layer.shape[0] == 2: # FlashAttention
return
layer
[:,
block_ids
,
...]
# return layer[:, block_ids, ...]
else
:
k
=
layer
[
0
]
#(num_blocks, num_kv_heads, block_size, head_size)
v
=
layer
[
1
]
#(num_blocks, num_kv_heads, head_size, block_size)
k
=
k
.
permute
(
0
,
2
,
1
,
3
)
v
=
v
.
permute
(
0
,
3
,
1
,
2
)
kv
=
torch
.
stack
([
k
,
v
],
dim
=
0
).
contiguous
()
return
kv
[:,
block_ids
,
...]
return
None
return
None
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment