Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
73032f48
Commit
73032f48
authored
Mar 11, 2026
by
xuxz
Browse files
[PD]回退p2pncclconnector
parent
a997359c
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
540 additions
and
48 deletions
+540
-48
vllm/distributed/kv_transfer/kv_connector/v1/p2p/p2p_nccl_connector.py
...ted/kv_transfer/kv_connector/v1/p2p/p2p_nccl_connector.py
+269
-31
vllm/distributed/kv_transfer/kv_connector/v1/p2p/p2p_nccl_engine.py
...ibuted/kv_transfer/kv_connector/v1/p2p/p2p_nccl_engine.py
+270
-16
vllm/distributed/kv_transfer/kv_connector/v1/p2p/tensor_memory_pool.py
...ted/kv_transfer/kv_connector/v1/p2p/tensor_memory_pool.py
+1
-1
No files found.
vllm/distributed/kv_transfer/kv_connector/v1/p2p/p2p_nccl_connector.py
View file @
73032f48
...
...
@@ -6,17 +6,19 @@ from typing import TYPE_CHECKING, Any, Optional
import
regex
as
re
import
torch
import
os
from
vllm
import
envs
from
vllm.config
import
VllmConfig
from
vllm.distributed.kv_transfer.kv_connector.v1.base
import
(
KVConnectorBase_V1
,
KVConnectorMetadata
,
KVConnectorRole
)
from
vllm.distributed.kv_transfer.kv_connector.v1.p2p.p2p_nccl_engine
import
(
P2pNcclEngine
)
P2pNcclEngine
,
RemoteAddr
)
from
vllm.distributed.parallel_state
import
get_world_group
from
vllm.forward_context
import
get_forward_context
from
vllm.logger
import
init_logger
from
vllm.v1.attention.backends.mla.common
import
MLACommonMetadata
from
vllm.v1.core.sched.output
import
SchedulerOutput
from
vllm.distributed.parallel_state
import
get_pp_group
,
get_tp_group
,
get_dp_group
if
TYPE_CHECKING
:
from
vllm.attention.backends.abstract
import
AttentionMetadata
...
...
@@ -35,6 +37,7 @@ class ReqMeta:
token_ids
:
torch
.
Tensor
# Slot mappings, should have the same length as token_ids
slot_mapping
:
torch
.
Tensor
slot_mapping_device
:
torch
.
Tensor
=
None
@
staticmethod
def
make_meta
(
request_id
:
str
,
token_ids
:
list
[
int
],
block_ids
:
list
[
int
],
...
...
@@ -54,7 +57,7 @@ class ReqMeta:
slot_mapping
=
slot_mapping
,
)
@
dataclass
class
P2pNcclConnectorMetadata
(
KVConnectorMetadata
):
requests
:
list
[
ReqMeta
]
...
...
@@ -87,13 +90,77 @@ class P2pNcclConnector(KVConnectorBase_V1):
if
role
==
KVConnectorRole
.
WORKER
else
0
self
.
_local_rank
=
get_world_group
().
local_rank
\
if
role
==
KVConnectorRole
.
WORKER
else
0
self
.
_dp_rank
=
get_dp_group
().
rank_in_group
\
if
role
==
KVConnectorRole
.
WORKER
else
0
self
.
_pp_rank
=
get_pp_group
().
rank_in_group
\
if
role
==
KVConnectorRole
.
WORKER
else
0
self
.
_tp_rank
=
get_tp_group
().
rank_in_group
\
if
role
==
KVConnectorRole
.
WORKER
else
0
self
.
_dp_size
=
get_dp_group
().
world_size
\
if
role
==
KVConnectorRole
.
WORKER
else
0
self
.
_pp_size
=
get_pp_group
().
world_size
\
if
role
==
KVConnectorRole
.
WORKER
else
0
self
.
_tp_size
=
get_tp_group
().
world_size
\
if
role
==
KVConnectorRole
.
WORKER
else
0
self
.
p2p_nccl_engine
=
P2pNcclEngine
(
local_rank
=
self
.
_local_rank
,
config
=
self
.
config
,
hostname
=
""
,
port_offset
=
self
.
_rank
,
config
=
self
.
config
,
model_config
=
vllm_config
.
model_config
,
dp_rank
=
self
.
_dp_rank
,
pp_rank
=
self
.
_pp_rank
,
tp_rank
=
self
.
_tp_rank
,
dp_size
=
self
.
_dp_size
,
pp_size
=
self
.
_pp_size
,
tp_size
=
self
.
_tp_size
)
if
role
==
KVConnectorRole
.
WORKER
else
None
self
.
parallel_config
=
vllm_config
.
parallel_config
self
.
model_config
=
vllm_config
.
model_config
self
.
total_num_hidden_layers
=
getattr
(
self
.
model_config
.
hf_text_config
,
"num_hidden_layers"
,
0
)
self
.
pp_size
=
self
.
parallel_config
.
pipeline_parallel_size
self
.
tp_size
=
self
.
parallel_config
.
tensor_parallel_size
self
.
num_card
=
self
.
pp_size
*
self
.
tp_size
self
.
remote_tp_size
=
self
.
config
.
get_from_extra_config
(
"remote_tp_size"
,
self
.
tp_size
)
self
.
remote_pp_size
=
self
.
config
.
get_from_extra_config
(
"remote_pp_size"
,
self
.
pp_size
)
self
.
enable_asymmetric_p2p
=
self
.
config
.
get_from_extra_config
(
"enable_asymmetric_p2p"
,
False
)
self
.
remote_num_card
=
self
.
remote_tp_size
*
self
.
remote_pp_size
self
.
multiple_machines_d
=
1
if
self
.
remote_num_card
>
8
else
0
self
.
multiple_machines_p
=
1
if
self
.
num_card
>
8
else
0
if
self
.
is_producer
and
self
.
multiple_machines_p
==
1
:
self
.
ip_map
=
{}
self
.
duplicate_keys
=
[]
config_file
=
os
.
getenv
(
'IP_CONFIG_FILE'
)
if
not
config_file
:
print
(
"Warning: Please set the IPVNet FILE environment variable for cross machine recognition of the second IP address"
)
return
try
:
with
open
(
config_file
,
'r'
,
encoding
=
'utf-8'
)
as
file
:
for
line_num
,
line
in
enumerate
(
file
,
1
):
line
=
line
.
strip
()
if
line
and
not
line
.
startswith
(
'#'
):
ips
=
line
.
split
()
if
len
(
ips
)
==
2
:
first_ip
,
second_ip
=
ips
if
first_ip
not
in
self
.
ip_map
:
self
.
ip_map
[
first_ip
]
=
second_ip
else
:
print
(
f
"warning: num
{
line_num
}
Incorrect format :
{
line
}
"
)
except
Exception
as
e
:
print
(
f
"Error: Exception occurred while reading configuration file -
{
e
}
"
)
def
get_ip_value
(
self
,
key
):
return
self
.
ip_map
.
get
(
key
)
# ==============================
# Worker-side methods
...
...
@@ -116,13 +183,11 @@ class P2pNcclConnector(KVConnectorBase_V1):
# Only consumer/decode loads KV Cache
if
self
.
is_producer
:
return
assert
self
.
p2p_nccl_engine
is
not
None
attn_metadata
=
forward_context
.
attn_metadata
if
attn_metadata
is
None
:
return
def
inject_kv_into_layer
(
dst_kv_cache_layer
:
torch
.
Tensor
,
src_kv_cache
:
torch
.
Tensor
,
...
...
@@ -143,7 +208,7 @@ class P2pNcclConnector(KVConnectorBase_V1):
request_id (str): request id for log
"""
dst_kv_cache_layer_shape
=
dst_kv_cache_layer
.
shape
if
isinstance
(
attn_metadata
,
MLACommonMetadata
):
if
isinstance
(
attn_metadata
,
MLACommonMetadata
)
or
all
(
isinstance
(
value
,
MLACommonMetadata
)
for
value
in
attn_metadata
.
values
())
:
num_pages
=
dst_kv_cache_layer_shape
[
0
]
page_size
=
dst_kv_cache_layer_shape
[
1
]
dst_kv_cache_layer
=
dst_kv_cache_layer
.
reshape
(
...
...
@@ -193,20 +258,95 @@ class P2pNcclConnector(KVConnectorBase_V1):
# Load the KV for each request each layer
for
request
in
metadata
.
requests
:
for
layer_name
in
forward_context
.
no_compile_layers
:
attn_layer
=
forward_context
.
no_compile_layers
[
layer_name
]
kv_cache_layer
=
attn_layer
.
kv_cache
[
\
forward_context
.
virtual_engine
]
kv_cache
=
self
.
p2p_nccl_engine
.
recv_tensor
(
request
.
request_id
+
"#"
+
layer_name
)
layer
=
forward_context
.
no_compile_layers
[
layer_name
]
# Only process layers that have kv_cache
# attribute (attention layers) Skip non-attention
# layers like FusedMoE
kv_cache
=
getattr
(
layer
,
'kv_cache'
,
None
)
if
kv_cache
is
None
:
logger
.
warning
(
"🚧src_kv_cache is None, %s"
,
request
.
request_id
)
continue
inject_kv_into_layer
(
kv_cache_layer
,
kv_cache
,
request
.
slot_mapping
,
request
.
request_id
)
kv_cache_layer
=
kv_cache
[
\
forward_context
.
virtual_engine
]
if
not
envs
.
VLLM_P2P_ASYNC
:
kv_cache
=
self
.
p2p_nccl_engine
.
recv_tensor
(
request
.
request_id
+
"#"
+
layer_name
)
if
kv_cache
is
None
:
logger
.
warning
(
"🚧src_kv_cache is None, %s"
,
request
.
request_id
)
continue
inject_kv_into_layer
(
kv_cache_layer
,
kv_cache
,
request
.
slot_mapping
,
request
.
request_id
)
tensor_id
=
request
.
request_id
+
"#"
+
layer_name
if
tensor_id
in
self
.
p2p_nccl_engine
.
recv_store
:
tensor
=
self
.
p2p_nccl_engine
.
recv_store
.
pop
(
tensor_id
,
None
)
self
.
p2p_nccl_engine
.
send_request_id_to_tensor_ids
.
pop
(
request
.
request_id
,
None
)
self
.
p2p_nccl_engine
.
recv_request_id_to_tensor_ids
.
pop
(
request
.
request_id
,
None
)
addr
=
0
if
isinstance
(
tensor
,
tuple
):
addr
,
_
,
_
=
tensor
self
.
p2p_nccl_engine
.
pool
.
free
(
addr
)
else
:
dst_kv_cache_layer_shape
=
kv_cache_layer
.
shape
if
isinstance
(
attn_metadata
,
MLACommonMetadata
)
or
all
(
isinstance
(
value
,
MLACommonMetadata
)
for
value
in
attn_metadata
.
values
()):
num_pages
=
dst_kv_cache_layer_shape
[
0
]
page_size
=
dst_kv_cache_layer_shape
[
1
]
assert
kv_cache_layer
.
is_contiguous
()
dst_kv_cache_layer
=
kv_cache_layer
.
reshape
(
num_pages
*
page_size
,
-
1
)
else
:
num_pages
=
dst_kv_cache_layer_shape
[
1
]
page_size
=
dst_kv_cache_layer_shape
[
2
]
assert
kv_cache_layer
.
is_contiguous
()
dst_kv_cache_layer
=
kv_cache_layer
.
reshape
(
2
,
num_pages
*
page_size
,
-
1
)
inject_start_index
=
0
for
num
in
range
(
self
.
p2p_nccl_engine
.
tensor_split_num
):
kv_cache
=
self
.
p2p_nccl_engine
.
recv_tensor
(
request
.
request_id
+
"#"
+
layer_name
+
"#"
+
str
(
num
))
if
kv_cache
is
None
:
logger
.
warning
(
"🚧src_kv_cache is None, %s"
,
request
.
request_id
)
continue
if
isinstance
(
attn_metadata
,
MLACommonMetadata
)
or
all
(
isinstance
(
value
,
MLACommonMetadata
)
for
value
in
attn_metadata
.
values
()):
num_token
=
kv_cache
.
shape
[
0
]
if
len
(
request
.
slot_mapping
)
==
num_token
:
dst_kv_cache_layer
[
request
.
slot_mapping
,
...]
=
kv_cache
else
:
dst_kv_cache_layer
[
request
.
slot_mapping
[
inject_start_index
:
inject_start_index
+
num_token
],
...]
=
kv_cache
else
:
num_token
=
kv_cache
.
shape
[
1
]
if
len
(
request
.
slot_mapping
)
==
num_token
:
dst_kv_cache_layer
[:,
request
.
slot_mapping
,
...]
=
kv_cache
else
:
dst_kv_cache_layer
[:,
request
.
slot_mapping
[
inject_start_index
:
inject_start_index
+
num_token
],
...]
=
kv_cache
inject_start_index
+=
num_token
# inject_kv_into_layer(kv_cache_layer, kv_cache,
# request.slot_mapping, request.request_id)
tensor_id
=
request
.
request_id
+
"#"
+
layer_name
+
"#"
+
str
(
num
)
if
tensor_id
in
self
.
p2p_nccl_engine
.
recv_store
:
tensor
=
self
.
p2p_nccl_engine
.
recv_store
.
pop
(
tensor_id
,
None
)
self
.
p2p_nccl_engine
.
send_request_id_to_tensor_ids
.
pop
(
request
.
request_id
,
None
)
self
.
p2p_nccl_engine
.
recv_request_id_to_tensor_ids
.
pop
(
request
.
request_id
,
None
)
addr
=
0
if
isinstance
(
tensor
,
tuple
):
addr
,
_
,
_
=
tensor
self
.
p2p_nccl_engine
.
pool
.
free
(
addr
)
dst_kv_cache_layer
.
reshape
(
dst_kv_cache_layer_shape
)
def
wait_for_layer_load
(
self
,
layer_name
:
str
)
->
None
:
"""Blocking until the KV for a specific layer is loaded into vLLM's
...
...
@@ -238,6 +378,8 @@ class P2pNcclConnector(KVConnectorBase_V1):
assert
self
.
p2p_nccl_engine
is
not
None
is_mla
=
isinstance
(
attn_metadata
,
MLACommonMetadata
)
def
extract_kv_from_layer
(
layer
:
torch
.
Tensor
,
slot_mapping
:
torch
.
Tensor
,
...
...
@@ -246,7 +388,7 @@ class P2pNcclConnector(KVConnectorBase_V1):
Assume the shape of the layer is (2, num_pages, page_size, xxx)
if MLA is not used, and (num_pages, page_size, xxx) otherwise.
"""
"""
if
isinstance
(
attn_metadata
,
MLACommonMetadata
):
num_pages
,
page_size
=
layer
.
shape
[
0
],
layer
.
shape
[
1
]
return
layer
.
reshape
(
num_pages
*
page_size
,
-
1
)[
slot_mapping
,
...
...
@@ -257,18 +399,112 @@ class P2pNcclConnector(KVConnectorBase_V1):
connector_metadata
=
self
.
_get_connector_metadata
()
assert
isinstance
(
connector_metadata
,
P2pNcclConnectorMetadata
)
for
request
in
connector_metadata
.
requests
:
request_id
=
request
.
request_id
ip
,
port
=
self
.
parse_request_id
(
request_id
,
True
)
remote_address
=
ip
+
":"
+
str
(
port
+
self
.
_rank
)
kv_cache
=
extract_kv_from_layer
(
kv_layer
,
request
.
slot_mapping
)
self
.
p2p_nccl_engine
.
send_tensor
(
request_id
+
"#"
+
layer_name
,
kv_cache
,
remote_address
)
if
envs
.
VLLM_ENABLE_TBO
or
envs
.
VLLM_P2P_ASYNC
:
for
request
in
connector_metadata
.
requests
:
request_id
=
request
.
request_id
ip
,
port
=
self
.
parse_request_id
(
request_id
,
True
)
remote_address
=
ip
+
":"
+
str
(
port
+
self
.
_rank
)
slot_mapping
=
request
.
slot_mapping
if
request
.
slot_mapping_device
is
None
:
request
.
slot_mapping_device
=
\
request
.
slot_mapping
.
pin_memory
().
to
(
device
=
kv_layer
.
device
,
non_blocking
=
True
)
slot_mapping
=
request
.
slot_mapping_device
tbo_evt
=
torch
.
cuda
.
Event
(
enable_timing
=
False
)
tbo_evt
.
record
()
pp_rank
=
(
self
.
parallel_config
.
rank
//
self
.
parallel_config
.
tensor_parallel_size
)
%
\
self
.
parallel_config
.
pipeline_parallel_size
if
(
self
.
pp_size
==
1
):
self
.
p2p_nccl_engine
.
p2p_async_send_tensor
(
request_id
+
"#"
+
layer_name
,
(
kv_layer
,
slot_mapping
),
remote_address
,
tbo_evt
)
elif
(
self
.
pp_size
==
2
):
if
(
pp_rank
==
0
):
self
.
p2p_nccl_engine
.
p2p_async_send_tensor
(
request_id
+
"#"
+
layer_name
,
(
kv_layer
,
slot_mapping
),
remote_address
,
tbo_evt
)
self
.
p2p_nccl_engine
.
p2p_async_send_tensor
(
request_id
+
"#"
+
layer_name
,
(
kv_layer
,
slot_mapping
),
ip
+
":"
+
str
(
port
+
self
.
_rank
+
4
),
tbo_evt
)
else
:
self
.
p2p_nccl_engine
.
p2p_async_send_tensor
(
request_id
+
"#"
+
layer_name
,
(
kv_layer
,
slot_mapping
),
remote_address
,
tbo_evt
)
self
.
p2p_nccl_engine
.
p2p_async_send_tensor
(
request_id
+
"#"
+
layer_name
,
(
kv_layer
,
slot_mapping
),
ip
+
":"
+
str
(
port
+
self
.
_rank
-
4
),
tbo_evt
)
elif
(
self
.
pp_size
==
8
):
for
i
in
range
(
8
):
self
.
p2p_nccl_engine
.
p2p_async_send_tensor
(
request_id
+
"#"
+
layer_name
,
(
kv_layer
,
slot_mapping
),
ip
+
":"
+
str
(
port
+
i
),
tbo_evt
)
else
:
print
(
"Error: only suppprt pp1 pp2 pp8!!!!!!"
)
else
:
for
request
in
connector_metadata
.
requests
:
request_id
=
request
.
request_id
ip
,
port
=
self
.
parse_request_id
(
request_id
,
True
)
p_ip
,
p_port
=
self
.
parse_request_id
(
request_id
,
False
)
remote_address
=
ip
+
":"
+
str
(
port
+
self
.
_rank
)
# pd_pair_id = p_ip + ":" + p_port + "_" + ip + ":" + port
kv_cache
=
extract_kv_from_layer
(
kv_layer
,
request
.
slot_mapping
)
pp_rank
=
(
self
.
parallel_config
.
rank
//
self
.
parallel_config
.
tensor_parallel_size
)
%
self
.
parallel_config
.
pipeline_parallel_size
if
(
self
.
multiple_machines_p
and
self
.
multiple_machines_d
):
ip_second
=
self
.
get_ip_value
(
ip
)
if
(
self
.
pp_size
==
1
):
if
self
.
_rank
<
8
:
self
.
p2p_nccl_engine
.
send_tensor
(
request_id
+
"#"
+
layer_name
,
kv_cache
,
remote_address
)
self
.
p2p_nccl_engine
.
send_tensor
(
request_id
+
"#"
+
layer_name
,
kv_cache
,
str
(
ip_second
)
+
":"
+
str
(
port
+
self
.
_rank
+
8
))
elif
(
self
.
pp_size
==
2
):
if
(
pp_rank
==
0
):
self
.
p2p_nccl_engine
.
send_tensor
(
request_id
+
"#"
+
layer_name
,
kv_cache
,
remote_address
)
else
:
self
.
p2p_nccl_engine
.
send_tensor
(
request_id
+
"#"
+
layer_name
,
kv_cache
,
str
(
ip_second
)
+
":"
+
str
(
port
+
self
.
_rank
))
else
:
logger
.
error
(
"Error: multiple machines only suppprt pp1tp16 and pp2tp8!!!!!!"
)
elif
(
self
.
multiple_machines_p
and
not
self
.
multiple_machines_d
):
if
(
self
.
pp_size
==
2
):
remote_address
=
ip
+
":"
+
str
(
port
+
self
.
_tp_rank
)
self
.
p2p_nccl_engine
.
send_tensor
(
request_id
+
"#"
+
layer_name
,
kv_cache
,
remote_address
)
else
:
logger
.
error
(
"Error: P multiple machines D machine only suppprt P:pp2tp8 D:tp8 !!!!!!"
)
elif
(
not
self
.
multiple_machines_p
and
not
self
.
multiple_machines_d
):
# remote_addr = RemoteAddr(pd_pair_id, remote_address, self._rank + self.num_card)
self
.
p2p_nccl_engine
.
send_tensor_new
(
request_id
,
layer_name
,
kv_cache
,
is_mla
)
# if (self.pp_size == 1):
# self.p2p_nccl_engine.send_tensor(request_id + "#" + layer_name,
# kv_cache, remote_address)
# elif (self.pp_size == 2):
# if (pp_rank == 0):
# self.p2p_nccl_engine.send_tensor(request_id + "#" + layer_name,
# kv_cache, remote_address)
# self.p2p_nccl_engine.send_tensor(request_id + "#" + layer_name,
# kv_cache, ip + ":" + str(port + self._rank + 4))
# else:
# self.p2p_nccl_engine.send_tensor(request_id + "#" + layer_name,
# kv_cache, remote_address)
# self.p2p_nccl_engine.send_tensor(request_id + "#" + layer_name,
# kv_cache, ip + ":" + str(port + self._rank - 4))
# elif (self.pp_size == 8):
# for i in range(8):
# self.p2p_nccl_engine.send_tensor(request_id + "#" + layer_name,
# kv_cache, ip + ":" + str(port + i))
# elif (self.enable_asymmetric_p2p):
# self.p2p_nccl_engine.send_tensor(request_id + "#" + layer_name,
# kv_cache, remote_address)
# else:
# logger.error("Error: P/D single machine only suppprt multiple tp:: (P: pp2tp4 D:tp8 P:pp8tp1 D:tp8) !!!!!!")
else
:
logger
.
error
(
"Error: not support!!!!!!"
)
def
wait_for_save
(
self
):
if
self
.
is_producer
:
assert
self
.
p2p_nccl_engine
is
not
None
self
.
p2p_nccl_engine
.
wait_for_sent
()
pass
# if self.is_producer:
# assert self.p2p_nccl_engine is not None
# self.p2p_nccl_engine.wait_for_sent()
def
get_finished
(
self
,
finished_req_ids
:
set
[
str
],
...
...
@@ -382,7 +618,9 @@ class P2pNcclConnector(KVConnectorBase_V1):
num_scheduled_tokens
=
(
scheduler_output
.
num_scheduled_tokens
)[
req_id
]
num_tokens
=
(
num_scheduled_tokens
+
num_computed_tokens
)
assert
req_id
in
self
.
chunked_prefill
# assert req_id in self.chunked_prefill
if
req_id
not
in
self
.
chunked_prefill
:
continue
block_ids
=
new_block_ids
[
0
]
if
not
resumed_from_preemption
:
block_ids
=
(
self
.
chunked_prefill
[
req_id
][
0
]
+
block_ids
)
...
...
@@ -482,4 +720,4 @@ class P2pNcclConnector(KVConnectorBase_V1):
for
i
,
(
s1
,
s2
)
in
enumerate
(
zip
(
shape1
,
shape2
))
if
i
!=
dim
):
raise
NotImplementedError
(
"Currently, only symmetric TP is supported. Asymmetric TP, PP,"
"and others will be supported in future PRs."
)
"and others will be supported in future PRs."
)
\ No newline at end of file
vllm/distributed/kv_transfer/kv_connector/v1/p2p/p2p_nccl_engine.py
View file @
73032f48
...
...
@@ -13,6 +13,7 @@ from typing import TYPE_CHECKING, Any, Optional
import
msgpack
import
torch
import
zmq
import
regex
from
vllm.config
import
KVTransferConfig
from
vllm.distributed.device_communicators.pynccl_wrapper
import
(
...
...
@@ -20,6 +21,13 @@ from vllm.distributed.device_communicators.pynccl_wrapper import (
from
vllm.distributed.kv_transfer.kv_connector.v1.p2p.tensor_memory_pool
import
(
# noqa: E501
TensorMemoryPool
)
from
vllm.utils
import
current_stream
,
get_ip
from
vllm
import
envs
from
vllm.distributed.parallel_state
import
get_pp_group
,
get_tp_group
from
dataclasses
import
dataclass
from
vllm.model_executor.models.utils
import
extract_layer_index
from
vllm.distributed.utils
import
get_pp_indices
from
vllm.config
import
ModelConfig
if
TYPE_CHECKING
:
from
vllm.forward_context
import
ForwardContext
...
...
@@ -28,6 +36,11 @@ logger = logging.getLogger(__name__)
DEFAULT_MEM_POOL_SIZE_GB
=
32
# @dataclass
# class SendQueueItem:
# tensor_id: str
# remote_address: str
# tensor: torch.Tensor
@
contextmanager
def
set_p2p_nccl_context
(
num_channels
:
str
):
...
...
@@ -59,17 +72,37 @@ def set_p2p_nccl_context(num_channels: str):
os
.
environ
.
pop
(
var
,
None
)
@
dataclass
class
RemoteAddr
:
pd_pair_id
:
str
=
""
zmq_address
:
str
=
""
comm_rank
:
int
=
0
class
P2pNcclEngine
:
def
__init__
(
self
,
local_rank
:
int
,
port_offset
:
int
,
config
:
KVTransferConfig
,
hostname
:
str
=
""
,
port_offset
:
int
=
0
,
model_config
:
ModelConfig
,
dp_rank
:
int
=
0
,
pp_rank
:
int
=
0
,
tp_rank
:
int
=
0
,
dp_size
:
int
=
0
,
pp_size
:
int
=
0
,
tp_size
:
int
=
0
,
library_path
:
Optional
[
str
]
=
None
)
->
None
:
self
.
config
=
config
self
.
model_config
=
model_config
self
.
rank
=
port_offset
self
.
local_rank
=
local_rank
self
.
dp_rank
=
dp_rank
self
.
pp_rank
=
pp_rank
self
.
tp_rank
=
tp_rank
self
.
dp_size
=
dp_size
self
.
pp_size
=
pp_size
self
.
tp_size
=
tp_size
self
.
device
=
torch
.
device
(
f
"cuda:
{
self
.
local_rank
}
"
)
self
.
nccl
=
NCCLLibrary
(
library_path
)
...
...
@@ -95,7 +128,7 @@ class P2pNcclEngine:
port
=
int
(
self
.
config
.
kv_port
)
+
port_offset
if
port
==
0
:
raise
ValueError
(
"Port cannot be 0"
)
self
.
_hostname
=
hostname
self
.
_hostname
=
get_ip
()
self
.
_port
=
port
# Each card corresponds to a ZMQ address.
...
...
@@ -128,6 +161,10 @@ class P2pNcclEngine:
self
.
send_stream
=
torch
.
cuda
.
Stream
()
self
.
recv_stream
=
torch
.
cuda
.
Stream
()
self
.
p2p_async_kv_tokens
=
envs
.
VLLM_P2P_BUF_TOKENS
self
.
p2p_async_buf
=
None
self
.
tensor_split_num
:
int
=
0
mem_pool_size_gb
=
self
.
config
.
get_from_extra_config
(
"mem_pool_size_gb"
,
DEFAULT_MEM_POOL_SIZE_GB
)
...
...
@@ -167,11 +204,16 @@ class P2pNcclEngine:
self
.
_listener_thread
.
start
()
self
.
_ping_thread
=
None
if
port_offset
==
0
and
self
.
proxy_address
!=
""
:
self
.
_ping_thread
=
threading
.
Thread
(
target
=
self
.
_ping
,
daemon
=
True
)
self
.
_ping_thread
.
start
()
if
self
.
multiple_machines
:
if
port_offset
==
0
and
self
.
proxy_address
!=
""
:
self
.
_ping_thread
=
threading
.
Thread
(
target
=
self
.
_ping
,
daemon
=
True
)
self
.
_ping_thread
.
start
()
else
:
if
self
.
proxy_address
!=
""
:
self
.
_ping_thread
=
threading
.
Thread
(
target
=
self
.
_ping_new
,
daemon
=
True
)
self
.
_ping_thread
.
start
()
logger
.
info
(
"💯P2pNcclEngine init, rank:%d, local_rank:%d, http_address:%s, "
"zmq_address:%s, proxy_address:%s, send_type:%s, buffer_size_"
...
...
@@ -179,6 +221,21 @@ class P2pNcclEngine:
self
.
http_address
,
self
.
zmq_address
,
self
.
proxy_address
,
self
.
send_type
,
self
.
buffer_size_threshold
,
self
.
nccl_num_channels
)
def
_create_connect_new
(
self
,
remote_address
:
typing
.
Optional
[
str
]
=
None
):
assert
remote_address
is
not
None
if
remote_address
not
in
self
.
socks
:
sock
=
self
.
context
.
socket
(
zmq
.
DEALER
)
sock
.
setsockopt
(
zmq
.
SNDHWM
,
10000
)
sock
.
setsockopt
(
zmq
.
RCVHWM
,
5000
)
sock
.
setsockopt
(
zmq
.
LINGER
,
0
)
sock
.
setsockopt
(
zmq
.
TCP_KEEPALIVE
,
1
)
sock
.
setsockopt_string
(
zmq
.
IDENTITY
,
f
"P-
{
self
.
zmq_address
}
"
)
sock
.
connect
(
f
"tcp://
{
remote_address
}
"
)
self
.
socks
[
remote_address
]
=
sock
return
self
.
socks
[
remote_address
]
def
_create_connect
(
self
,
remote_address
:
typing
.
Optional
[
str
]
=
None
):
assert
remote_address
is
not
None
if
remote_address
not
in
self
.
socks
:
...
...
@@ -206,11 +263,73 @@ class P2pNcclEngine:
return
self
.
socks
[
remote_address
],
self
.
comms
[
remote_address
]
def
get_send_queue_items
(
self
,
request_id
:
str
,
layer_name
:
str
,
tensor
:
torch
.
Tensor
,
is_mla
:
bool
)
->
list
[
any
]:
tensor_id
=
self
.
get_tensor_id
(
request_id
,
layer_name
)
remote_ip
,
remote_port
=
self
.
parse_request_id
(
request_id
,
True
)
p_ip
,
p_port
=
self
.
parse_request_id
(
request_id
,
False
)
pd_pair_id
=
p_ip
+
":"
+
str
(
p_port
)
+
"_"
+
remote_ip
+
":"
+
str
(
remote_port
)
if
not
self
.
enable_asymmetric_p2p
:
remote_address
=
remote_ip
+
":"
+
str
(
remote_port
+
self
.
rank
)
remote_addr
=
RemoteAddr
(
pd_pair_id
,
remote_address
,
self
.
rank
+
self
.
pp_size
*
self
.
tp_size
)
# logger.info(f"""+++++xiabo tensor_id:{tensor_id} request_id:{request_id} remote_address:{remote_address}""")
return
[(
tensor_id
,
remote_addr
,
tensor
)]
if
not
is_mla
:
logger
.
error
(
" P2PNCCL only support mla model symmetric PP/TP!!!!"
)
remote_pp_rank
=
self
.
compute_remote_pp_rank
(
layer_name
)
items
:
list
[
Any
]
=
[]
for
d_tp_rank
in
range
(
self
.
remote_tp_size
):
for
mul_tp
in
range
(
self
.
multp
):
if
self
.
tp_rank
+
mul_tp
*
self
.
tp_size
==
d_tp_rank
:
remote_port_offset
=
remote_pp_rank
*
self
.
remote_tp_size
+
d_tp_rank
remote_address
=
remote_ip
+
":"
+
str
(
remote_port
+
remote_port_offset
)
remote_addr
=
RemoteAddr
(
pd_pair_id
,
remote_address
,
remote_port_offset
+
self
.
pp_size
*
self
.
tp_size
)
logger
.
debug
(
"Wait to send::%s, tensor_shape:%s, "
"(pp=%d, tp=%d) -> remote_address=%s(pp=%d, tp=%d) comm_rank (%d -> %d)"
,
tensor_id
,
tensor
.
shape
,
self
.
pp_rank
,
self
.
tp_rank
,
remote_address
,
remote_pp_rank
,
self
.
rank
*
mul_tp
+
self
.
rank
,
self
.
rank
,
remote_port_offset
+
self
.
pp_size
*
self
.
tp_size
)
items
.
append
([
tensor_id
,
remote_addr
,
tensor
])
return
items
def
send_tensor_new
(
self
,
request_id
:
str
,
layer_name
:
str
,
tensor
:
torch
.
Tensor
,
is_mla
:
bool
=
False
,
)
->
bool
:
tensor_id
=
self
.
get_tensor_id
(
request_id
,
layer_name
)
if
self
.
send_type
==
"PUT"
:
return
all
(
self
.
_send_sync_new
(
item
)
for
item
in
self
.
get_send_queue_items
(
request_id
,
layer_name
,
tensor
,
is_mla
))
if
self
.
send_type
==
"PUT_ASYNC"
:
with
self
.
send_queue_cv
:
for
item
in
self
.
get_send_queue_items
(
request_id
,
layer_name
,
tensor
,
is_mla
):
self
.
send_queue
.
append
(
item
)
self
.
send_queue_cv
.
notify
()
return
True
if
self
.
send_type
==
"GET"
:
logger
.
error
(
" P2PNCCL new not support GET model, please set VLLM_P2PNCCL_NEW=0 use defalut model!!!!"
)
def
send_tensor
(
self
,
tensor_id
:
str
,
tensor
:
torch
.
Tensor
,
remote_address
:
typing
.
Optional
[
str
]
=
None
,
tbo_evt
=
None
,
)
->
bool
:
if
remote_address
is
None
:
with
self
.
recv_store_cv
:
...
...
@@ -250,6 +369,53 @@ class P2pNcclEngine:
self
.
buffer_size
/
self
.
buffer_size_threshold
*
100
)
return
True
def
p2p_async_send_tensor
(
self
,
tensor_id
:
str
,
tensor
:
torch
.
Tensor
,
remote_address
:
typing
.
Optional
[
str
]
=
None
,
tbo_evt
=
None
,
)
->
bool
:
if
remote_address
is
None
:
with
self
.
recv_store_cv
:
self
.
recv_store
[
tensor_id
]
=
tensor
self
.
recv_store_cv
.
notify
()
return
True
else
:
if
self
.
send_type
==
"PUT"
:
return
self
.
_send_sync
(
tensor_id
,
tensor
,
remote_address
)
elif
self
.
send_type
==
"PUT_ASYNC"
:
with
self
.
send_queue_cv
:
kv_layer
,
slot_mapping
=
tensor
# tesor (kv_layer, slot_mapping)
self
.
send_queue
.
append
([
tensor_id
,
remote_address
,
kv_layer
,
slot_mapping
,
tbo_evt
])
self
.
send_queue_cv
.
notify
()
else
:
# GET
with
self
.
send_store_cv
:
tensor_size
=
tensor
.
element_size
()
*
tensor
.
numel
()
while
(
self
.
buffer_size
+
tensor_size
>
self
.
buffer_size_threshold
):
oldest_tenser_id
=
next
(
iter
(
self
.
send_store
))
oldest_tenser
=
self
.
send_store
.
pop
(
oldest_tenser_id
)
oldest_tenser_size
=
oldest_tenser
.
element_size
(
)
*
oldest_tenser
.
numel
()
self
.
buffer_size
-=
oldest_tenser_size
logger
.
info
(
"⛔[GET]Send to %s, tensor_id:%s, tensor_size:%d,"
" buffer_size:%d, oldest_tenser_size:%d, rank:%d"
,
remote_address
,
tensor_id
,
tensor_size
,
self
.
buffer_size
,
oldest_tenser_size
,
self
.
rank
)
self
.
send_store
[
tensor_id
]
=
tensor
self
.
buffer_size
+=
tensor_size
logger
.
debug
(
"🔵[GET]Send to %s, tensor_id:%s, tensor_size:%d, "
"shape:%s, rank:%d, buffer_size:%d(%.2f%%)"
,
remote_address
,
tensor_id
,
tensor_size
,
tensor
.
shape
,
self
.
rank
,
self
.
buffer_size
,
self
.
buffer_size
/
self
.
buffer_size_threshold
*
100
)
return
True
def
recv_tensor
(
self
,
...
...
@@ -327,6 +493,8 @@ class P2pNcclEngine:
self
.
zmq_address
,
remote_address
.
decode
(),
rank
)
elif
data
[
"cmd"
]
==
"PUT"
:
tensor_id
=
data
[
"tensor_id"
]
if
"tensor_split_num"
in
data
:
self
.
tensor_split_num
=
data
[
"tensor_split_num"
]
try
:
with
torch
.
cuda
.
stream
(
self
.
recv_stream
):
tensor
=
torch
.
empty
(
data
[
"shape"
],
...
...
@@ -343,10 +511,6 @@ class P2pNcclEngine:
# Store Tensor in memory pool
addr
=
self
.
pool
.
store_tensor
(
tensor
)
tensor
=
(
addr
,
tensor
.
dtype
,
tensor
.
shape
)
logger
.
warning
(
"🔴[PUT]Recv Tensor, Out Of Threshold, "
"%s👈%s, data:%s, addr:%d"
,
self
.
zmq_address
,
remote_address
.
decode
(),
data
,
addr
)
else
:
self
.
buffer_size
+=
tensor_size
...
...
@@ -363,7 +527,56 @@ class P2pNcclEngine:
self
.
recv_store
[
tensor_id
]
=
tensor
self
.
_have_received_tensor_id
(
tensor_id
)
self
.
recv_store_cv
.
notify
()
elif
data
[
"cmd"
]
==
"PUT_NEW"
:
tensor_id
=
data
[
"tensor_id"
]
if
"tensor_split_num"
in
data
:
self
.
tensor_split_num
=
data
[
"tensor_split_num"
]
try
:
with
torch
.
cuda
.
stream
(
self
.
recv_stream
):
tensor
=
torch
.
empty
(
data
[
"shape"
],
dtype
=
getattr
(
torch
,
data
[
"dtype"
]),
device
=
self
.
device
)
self
.
router_socket
.
send_multipart
(
[
remote_address
,
b
"0"
])
# comm, rank = self.comms[remote_address.decode()]
# self._recv(comm, tensor, rank ^ 1, self.recv_stream)
comm
,
rank
=
self
.
comms
[
data
[
"pd_pair_id"
]]
self
.
_recv
(
comm
,
tensor
,
int
(
data
[
"comm_rank"
]),
self
.
recv_stream
)
tensor_size
=
tensor
.
element_size
()
*
tensor
.
numel
()
if
(
self
.
buffer_size
+
tensor_size
>
self
.
buffer_size_threshold
):
# Store Tensor in memory pool
addr
=
self
.
pool
.
store_tensor
(
tensor
)
tensor
=
(
addr
,
tensor
.
dtype
,
tensor
.
shape
)
else
:
self
.
buffer_size
+=
tensor_size
except
torch
.
cuda
.
OutOfMemoryError
:
self
.
router_socket
.
send_multipart
(
[
remote_address
,
b
"1"
])
tensor
=
None
logger
.
warning
(
"🔴[PUT]Recv Tensor, Out Of Memory, %s👈%s, "
"data:%s"
,
self
.
zmq_address
,
remote_address
.
decode
(),
data
)
with
self
.
recv_store_cv
:
self
.
recv_store
[
tensor_id
]
=
tensor
self
.
_have_received_tensor_id
(
tensor_id
)
self
.
recv_store_cv
.
notify
()
elif
data
[
"cmd"
]
==
"comm_init"
:
unique_id
=
self
.
nccl
.
unique_id_from_bytes
(
bytes
(
data
[
"unique_id"
]))
with
torch
.
cuda
.
device
(
self
.
device
):
rank
=
int
(
data
[
"rank"
])
world_size
=
int
(
data
[
"world_size"
])
with
set_p2p_nccl_context
(
self
.
nccl_num_channels
):
comm
:
ncclComm_t
=
self
.
nccl
.
ncclCommInitRank
(
world_size
,
unique_id
,
rank
)
self
.
comms
[
data
[
"pd_pair_id"
]]
=
(
comm
,
rank
)
logger
.
info
(
"🤝ncclCommInitRank Success, %s👈%s, MyRank:%s"
,
self
.
zmq_address
,
data
[
"pd_pair_id"
],
rank
)
elif
data
[
"cmd"
]
==
"GET"
:
tensor_id
=
data
[
"tensor_id"
]
with
self
.
send_store_cv
:
...
...
@@ -410,10 +623,21 @@ class P2pNcclEngine:
with
self
.
send_queue_cv
:
while
not
self
.
send_queue
:
self
.
send_queue_cv
.
wait
()
tensor_id
,
remote_address
,
tensor
=
self
.
send_queue
.
popleft
()
if
envs
.
VLLM_ENABLE_TBO
or
envs
.
VLLM_P2P_ASYNC
:
tensor_id
,
remote_address
,
kv_layer
,
slot_mapping
,
tbo_evt
=
self
.
send_queue
.
popleft
()
else
:
tensor_id
,
remote_address
,
tensor
=
self
.
send_queue
.
popleft
()
if
not
self
.
send_queue
:
self
.
send_queue_cv
.
notify
()
self
.
_send_sync
(
tensor_id
,
tensor
,
remote_address
)
if
(
envs
.
VLLM_ENABLE_TBO
or
envs
.
VLLM_P2P_ASYNC
)
and
tbo_evt
is
not
None
:
self
.
send_stream
.
wait_event
(
tbo_evt
)
self
.
_send_kv_p2p_sync
(
tensor_id
,
kv_layer
,
slot_mapping
,
remote_address
)
else
:
if
self
.
multiple_machines
:
self
.
_send_sync
(
tensor_id
,
tensor
,
remote_address
)
else
:
# logger.info(f"""=============xiabo tensor_id:{tensor_id} remote_address:{remote_address}""")
self
.
_send_sync_new
(
tensor_id
,
tensor
,
remote_address
)
def
wait_for_sent
(
self
):
if
self
.
send_type
==
"PUT_ASYNC"
:
...
...
@@ -518,7 +742,7 @@ class P2pNcclEngine:
"pd_pair_id"
:
remote_address
.
pd_pair_id
,
"comm_rank"
:
rank
}
#
logger.info(f"""_send_sync_new:{data}""")
logger
.
info
(
f
"""_send_sync_new:
{
data
}
"""
)
sock
.
send
(
msgpack
.
dumps
(
data
))
response
=
sock
.
recv
()
...
...
@@ -627,6 +851,36 @@ class P2pNcclEngine:
sock
.
send
(
msgpack
.
dumps
(
data
))
time
.
sleep
(
3
)
def
_ping_new
(
self
):
sock
=
self
.
context
.
socket
(
zmq
.
DEALER
)
sock
.
setsockopt_string
(
zmq
.
IDENTITY
,
self
.
zmq_address
)
logger
.
debug
(
"ping start, zmq_address:%s"
,
self
.
zmq_address
)
sock
.
connect
(
f
"tcp://
{
self
.
proxy_address
}
"
)
if
self
.
rank
==
0
:
data
=
{
"type"
:
"P_init"
if
self
.
config
.
is_kv_producer
else
"D_init"
,
"http_address"
:
self
.
http_address
,
"zmq_address"
:
self
.
zmq_address
,
"dp_size"
:
self
.
dp_size
,
"pp_size"
:
self
.
pp_size
,
"tp_size"
:
self
.
tp_size
}
# logger.info(f"""_ping data:{data}""")
sock
.
send
(
msgpack
.
dumps
(
data
))
data
=
{
"type"
:
"P"
if
self
.
config
.
is_kv_producer
else
"D"
,
"http_address"
:
self
.
http_address
,
"dp_rank"
:
self
.
dp_rank
,
"pp_rank"
:
self
.
pp_rank
,
"tp_rank"
:
self
.
tp_rank
,
"zmq_address"
:
self
.
zmq_address
}
# while True:
# logger.info(f"""_ping data:{data}""")
sock
.
send
(
msgpack
.
dumps
(
data
))
# time.sleep(3)
def
_send
(
self
,
comm
,
tensor
:
torch
.
Tensor
,
dst
:
int
,
stream
=
None
):
assert
tensor
.
device
==
self
.
device
,
(
f
"this nccl communicator is created to work on
{
self
.
device
}
, "
...
...
@@ -727,4 +981,4 @@ class P2pNcclEngine:
return
ip
,
port
raise
ValueError
(
f
"Request id
{
request_id
}
does not contain hostname and port"
)
f
"Request id
{
request_id
}
does not contain hostname and port"
)
\ No newline at end of file
vllm/distributed/kv_transfer/kv_connector/v1/p2p/tensor_memory_pool.py
View file @
73032f48
...
...
@@ -63,7 +63,7 @@ class TensorMemoryPool:
than min_block_size
"""
def
__init__
(
self
,
max_block_size
:
int
,
min_block_size
:
int
=
5
12
):
def
__init__
(
self
,
max_block_size
:
int
,
min_block_size
:
int
=
12
8
):
if
max_block_size
<=
0
or
min_block_size
<=
0
:
raise
ValueError
(
"Block sizes must be positive"
)
if
max_block_size
<
min_block_size
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment