Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
4612aad6
"vscode:/vscode.git/clone" did not exist on "68b0a6c1baf32221fa8ed98a941f611cbff2cdb1"
Commit
4612aad6
authored
Dec 27, 2025
by
Your Name
Browse files
[P/D][Feat]支持dp并行
parent
cd42bf87
Changes
6
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
874 additions
and
311 deletions
+874
-311
examples/online_serving/disaggregated_serving_p2p_nccl_xpyd/disagg_proxy_p2p_nccl_xpyd_dp.py
...ed_serving_p2p_nccl_xpyd/disagg_proxy_p2p_nccl_xpyd_dp.py
+503
-0
vllm/distributed/kv_transfer/kv_connector/factory.py
vllm/distributed/kv_transfer/kv_connector/factory.py
+2
-1
vllm/distributed/kv_transfer/kv_connector/v1/p2p/p2p_nccl_connector.py
...ted/kv_transfer/kv_connector/v1/p2p/p2p_nccl_connector.py
+89
-106
vllm/distributed/kv_transfer/kv_connector/v1/p2p/p2p_nccl_engine.py
...ibuted/kv_transfer/kv_connector/v1/p2p/p2p_nccl_engine.py
+264
-201
vllm/v1/core/sched/scheduler.py
vllm/v1/core/sched/scheduler.py
+13
-3
vllm/v1/engine/core.py
vllm/v1/engine/core.py
+3
-0
No files found.
examples/online_serving/disaggregated_serving_p2p_nccl_xpyd/disagg_proxy_p2p_nccl_xpyd_dp.py
0 → 100644
View file @
4612aad6
This diff is collapsed.
Click to expand it.
vllm/distributed/kv_transfer/kv_connector/factory.py
View file @
4612aad6
...
...
@@ -54,6 +54,7 @@ class KVConnectorFactory:
cls
,
config
:
"VllmConfig"
,
role
:
KVConnectorRole
,
dp_rank
:
int
=
-
1
,
)
->
KVConnectorBase_V1
:
if
not
envs
.
VLLM_USE_V1
:
raise
ValueError
(
"Attempting to initialize a V1 Connector, "
...
...
@@ -81,7 +82,7 @@ class KVConnectorFactory:
# - Co-locate with worker process
# - Should only be used inside the forward context & attention layer
# We build separately to enforce strict separation
return
connector_cls
(
config
,
role
)
return
connector_cls
(
config
,
role
,
dp_rank
)
# Register various connectors here.
...
...
vllm/distributed/kv_transfer/kv_connector/v1/p2p/p2p_nccl_connector.py
View file @
4612aad6
...
...
@@ -6,19 +6,23 @@ from typing import TYPE_CHECKING, Any, Optional
import
regex
as
re
import
torch
import
os
from
vllm
import
envs
from
vllm.config
import
VllmConfig
from
vllm.distributed.kv_transfer.kv_connector.v1.base
import
(
KVConnectorBase_V1
,
KVConnectorMetadata
,
KVConnectorRole
)
from
vllm.distributed.kv_transfer.kv_connector.v1.p2p.p2p_nccl_engine
import
(
P2pNcclEngine
)
from
vllm.distributed.parallel_state
import
get_world_group
P2pNcclEngine
,
RemoteAddr
)
from
vllm.distributed.parallel_state
import
get_world_group
,
get_dp_group
,
get_pp_group
,
get_tp_group
from
vllm.forward_context
import
get_forward_context
from
vllm.logger
import
init_logger
from
vllm.v1.attention.backends.mla.common
import
MLACommonMetadata
from
vllm.v1.core.sched.output
import
SchedulerOutput
from
vllm.distributed.parallel_state
import
get_pp_group
,
get_tp_group
,
get_dp_group
from
vllm.distributed.parallel_state
import
get_pp_group
,
get_tp_group
import
zmq
import
msgpack
if
TYPE_CHECKING
:
from
vllm.attention.backends.abstract
import
AttentionMetadata
...
...
@@ -78,7 +82,7 @@ class P2pNcclConnectorMetadata(KVConnectorMetadata):
class
P2pNcclConnector
(
KVConnectorBase_V1
):
def
__init__
(
self
,
vllm_config
:
"VllmConfig"
,
role
:
KVConnectorRole
):
def
__init__
(
self
,
vllm_config
:
"VllmConfig"
,
role
:
KVConnectorRole
,
dp_rank
:
int
=
-
1
):
super
().
__init__
(
vllm_config
=
vllm_config
,
role
=
role
)
self
.
_block_size
=
vllm_config
.
cache_config
.
block_size
self
.
_requests_need_load
:
dict
[
str
,
Any
]
=
{}
...
...
@@ -102,12 +106,17 @@ class P2pNcclConnector(KVConnectorBase_V1):
if
role
==
KVConnectorRole
.
WORKER
else
0
self
.
_tp_size
=
get_tp_group
().
world_size
\
if
role
==
KVConnectorRole
.
WORKER
else
0
self
.
p2p_nccl_engine
=
P2pNcclEngine
(
local_rank
=
self
.
_local_rank
,
port_offset
=
self
.
_rank
,
config
=
self
.
config
,
model_config
=
vllm_config
.
model_config
,
hostname
=
""
,
port_offset
=
self
.
_rank
,
dp_rank
=
self
.
_dp_rank
,
pp_rank
=
self
.
_pp_rank
,
tp_rank
=
self
.
_tp_rank
,
dp_size
=
self
.
_dp_size
,
pp_size
=
self
.
_pp_size
,
tp_size
=
self
.
_tp_size
)
if
role
==
KVConnectorRole
.
WORKER
else
None
self
.
parallel_config
=
vllm_config
.
parallel_config
...
...
@@ -117,19 +126,9 @@ class P2pNcclConnector(KVConnectorBase_V1):
self
.
pp_size
=
self
.
parallel_config
.
pipeline_parallel_size
self
.
tp_size
=
self
.
parallel_config
.
tensor_parallel_size
self
.
num_card
=
self
.
pp_size
*
self
.
tp_size
self
.
multiple_machines
=
1
if
self
.
num_card
>
8
else
0
self
.
remote_tp_size
=
self
.
config
.
get_from_extra_config
(
"remote_tp_size"
,
self
.
tp_size
)
self
.
remote_pp_size
=
self
.
config
.
get_from_extra_config
(
"remote_pp_size"
,
self
.
pp_size
)
self
.
enable_asymmetric_p2p
=
self
.
config
.
get_from_extra_config
(
"enable_asymmetric_p2p"
,
False
)
self
.
remote_num_card
=
self
.
remote_tp_size
*
self
.
remote_pp_size
self
.
multiple_machines_d
=
1
if
self
.
remote_num_card
>
8
else
0
self
.
multiple_machines_p
=
1
if
self
.
num_card
>
8
else
0
if
self
.
is_producer
and
self
.
multiple_machines_p
==
1
:
if
self
.
is_producer
and
self
.
multiple_machines
==
1
:
self
.
ip_map
=
{}
self
.
duplicate_keys
=
[]
config_file
=
os
.
getenv
(
'IP_CONFIG_FILE'
)
...
...
@@ -152,10 +151,38 @@ class P2pNcclConnector(KVConnectorBase_V1):
print
(
f
"Error: Exception occurred while reading configuration file -
{
e
}
"
)
if
role
==
KVConnectorRole
.
SCHEDULER
:
self
.
dp_rank
=
dp_rank
proxy_ip
=
self
.
config
.
get_from_extra_config
(
"proxy_ip"
,
""
)
proxy_port
=
self
.
config
.
get_from_extra_config
(
"proxy_port"
,
""
)
if
proxy_ip
==
""
or
proxy_port
==
""
:
self
.
proxy_address
=
""
else
:
self
.
proxy_address
=
proxy_ip
+
":"
+
proxy_port
self
.
http_address
=
(
f
"
{
self
.
config
.
kv_connector_extra_config
[
'instance_ip'
]
}
:"
f
"
{
self
.
config
.
kv_connector_extra_config
[
'http_port'
]
}
"
)
self
.
context
=
zmq
.
Context
()
req_sock
=
self
.
context
.
socket
(
zmq
.
DEALER
)
req_sock
.
setsockopt_string
(
zmq
.
IDENTITY
,
f
"
{
self
.
http_address
}
_rank
{
self
.
dp_rank
}
"
)
req_sock
.
connect
(
f
"tcp://
{
self
.
proxy_address
}
"
)
self
.
req_sock
=
req_sock
def
get_ip_value
(
self
,
key
):
return
self
.
ip_map
.
get
(
key
)
def
register_req
(
self
,
request_id
:
str
)
:
data
=
{
"type"
:
"Req"
,
"instance_type"
:
"P"
if
self
.
config
.
is_kv_producer
else
"D"
,
"http_address"
:
self
.
http_address
,
"request_id"
:
request_id
,
"dp_rank"
:
self
.
dp_rank
}
self
.
req_sock
.
send
(
msgpack
.
dumps
(
data
))
# ==============================
# Worker-side methods
# ==============================
...
...
@@ -304,7 +331,13 @@ class P2pNcclConnector(KVConnectorBase_V1):
2
,
num_pages
*
page_size
,
-
1
)
inject_start_index
=
0
for
num
in
range
(
self
.
p2p_nccl_engine
.
tensor_split_num
):
req_layer
=
f
"
{
request
.
request_id
}
#
{
layer_name
}
"
with
self
.
p2p_nccl_engine
.
recv_store_cv
:
while
req_layer
not
in
self
.
p2p_nccl_engine
.
recv_split_nums
:
self
.
p2p_nccl_engine
.
recv_store_cv
.
wait
()
split_num
=
self
.
p2p_nccl_engine
.
recv_split_nums
.
get
(
req_layer
)
for
num
in
range
(
split_num
):
kv_cache
=
self
.
p2p_nccl_engine
.
recv_tensor
(
request
.
request_id
+
"#"
+
layer_name
+
"#"
+
str
(
num
))
...
...
@@ -332,6 +365,7 @@ class P2pNcclConnector(KVConnectorBase_V1):
# inject_kv_into_layer(kv_cache_layer, kv_cache,
# request.slot_mapping, request.request_id)
tensor_id
=
request
.
request_id
+
"#"
+
layer_name
+
"#"
+
str
(
num
)
if
tensor_id
in
self
.
p2p_nccl_engine
.
recv_store
:
tensor
=
self
.
p2p_nccl_engine
.
recv_store
.
pop
(
tensor_id
,
None
)
self
.
p2p_nccl_engine
.
send_request_id_to_tensor_ids
.
pop
(
...
...
@@ -375,8 +409,6 @@ class P2pNcclConnector(KVConnectorBase_V1):
assert
self
.
p2p_nccl_engine
is
not
None
is_mla
=
isinstance
(
attn_metadata
,
MLACommonMetadata
)
def
extract_kv_from_layer
(
layer
:
torch
.
Tensor
,
slot_mapping
:
torch
.
Tensor
,
...
...
@@ -400,8 +432,6 @@ class P2pNcclConnector(KVConnectorBase_V1):
if
envs
.
VLLM_ENABLE_TBO
or
envs
.
VLLM_P2P_ASYNC
:
for
request
in
connector_metadata
.
requests
:
request_id
=
request
.
request_id
ip
,
port
=
self
.
parse_request_id
(
request_id
,
True
)
remote_address
=
ip
+
":"
+
str
(
port
+
self
.
_rank
)
slot_mapping
=
request
.
slot_mapping
if
request
.
slot_mapping_device
is
None
:
request
.
slot_mapping_device
=
\
...
...
@@ -409,91 +439,46 @@ class P2pNcclConnector(KVConnectorBase_V1):
slot_mapping
=
request
.
slot_mapping_device
tbo_evt
=
torch
.
cuda
.
Event
(
enable_timing
=
False
)
tbo_evt
.
record
()
pp_rank
=
(
self
.
parallel_config
.
rank
//
self
.
parallel_config
.
tensor_parallel_size
)
%
\
self
.
parallel_config
.
pipeline_parallel_size
if
(
self
.
pp_size
==
1
):
self
.
p2p_nccl_engine
.
p2p_async_send_tensor
(
request_id
+
"#"
+
layer_name
,
(
kv_layer
,
slot_mapping
),
remote_address
,
tbo_evt
)
elif
(
self
.
pp_size
==
2
):
if
(
pp_rank
==
0
):
self
.
p2p_nccl_engine
.
p2p_async_send_tensor
(
request_id
+
"#"
+
layer_name
,
(
kv_layer
,
slot_mapping
),
remote_address
,
tbo_evt
)
self
.
p2p_nccl_engine
.
p2p_async_send_tensor
(
request_id
+
"#"
+
layer_name
,
(
kv_layer
,
slot_mapping
),
ip
+
":"
+
str
(
port
+
self
.
_rank
+
4
),
tbo_evt
)
else
:
self
.
p2p_nccl_engine
.
p2p_async_send_tensor
(
request_id
+
"#"
+
layer_name
,
(
kv_layer
,
slot_mapping
),
remote_address
,
tbo_evt
)
self
.
p2p_nccl_engine
.
p2p_async_send_tensor
(
request_id
+
"#"
+
layer_name
,
(
kv_layer
,
slot_mapping
),
ip
+
":"
+
str
(
port
+
self
.
_rank
-
4
),
tbo_evt
)
elif
(
self
.
pp_size
==
8
):
for
i
in
range
(
8
):
pending
=
False
with
self
.
p2p_nccl_engine
.
req_status_cv
:
if
request_id
not
in
self
.
p2p_nccl_engine
.
req_status
:
pending
=
True
if
pending
:
self
.
p2p_nccl_engine
.
pending_tensor
(
request_id
,
layer_name
,
(
kv_layer
,
slot_mapping
),
tbo_evt
)
logger
.
info
(
"[%d] pending for request: %s layer: %s"
,
self
.
_rank
,
request_id
,
layer_name
)
else
:
req_data
=
self
.
p2p_nccl_engine
.
req_status
[
request_id
]
assert
(
req_data
.
dst_num
==
len
(
req_data
.
zmq_address_and_comm_rank
))
for
i
in
range
(
req_data
.
dst_num
):
remote_addr
=
RemoteAddr
(
req_data
.
pd_pair_id
,
*
(
req_data
.
zmq_address_and_comm_rank
[
i
]))
self
.
p2p_nccl_engine
.
p2p_async_send_tensor
(
request_id
+
"#"
+
layer_name
,
(
kv_layer
,
slot_mapping
),
ip
+
":"
+
str
(
port
+
i
),
tbo_evt
)
else
:
print
(
"Error: only suppprt pp1 pp2 pp8!!!!!!"
)
(
kv_layer
,
slot_mapping
),
remote_addr
,
tbo_evt
)
# self.p2p_nccl_engine.p2p_async_send_tensor(request_id + "#" + layer_name,
# (kv_layer, slot_mapping), remote_address, tbo_evt)
else
:
for
request
in
connector_metadata
.
requests
:
request_id
=
request
.
request_id
ip
,
port
=
self
.
parse_request_id
(
request_id
,
True
)
remote_address
=
ip
+
":"
+
str
(
port
+
self
.
_rank
)
kv_cache
=
extract_kv_from_layer
(
kv_layer
,
request
.
slot_mapping
)
pp_rank
=
(
self
.
parallel_config
.
rank
//
self
.
parallel_config
.
tensor_parallel_size
)
%
self
.
parallel_config
.
pipeline_parallel_size
if
(
self
.
multiple_machines_p
and
self
.
multiple_machines_d
):
ip_second
=
self
.
get_ip_value
(
ip
)
if
(
self
.
pp_size
==
1
):
if
self
.
_rank
<
8
:
self
.
p2p_nccl_engine
.
send_tensor
(
request_id
+
"#"
+
layer_name
,
kv_cache
,
remote_address
)
self
.
p2p_nccl_engine
.
send_tensor
(
request_id
+
"#"
+
layer_name
,
kv_cache
,
str
(
ip_second
)
+
":"
+
str
(
port
+
self
.
_rank
+
8
))
elif
(
self
.
pp_size
==
2
):
if
(
pp_rank
==
0
):
self
.
p2p_nccl_engine
.
send_tensor
(
request_id
+
"#"
+
layer_name
,
kv_cache
,
remote_address
)
else
:
self
.
p2p_nccl_engine
.
send_tensor
(
request_id
+
"#"
+
layer_name
,
kv_cache
,
str
(
ip_second
)
+
":"
+
str
(
port
+
self
.
_rank
))
else
:
logger
.
error
(
"Error: multiple machines only suppprt pp1tp16 and pp2tp8!!!!!!"
)
elif
(
self
.
multiple_machines_p
and
not
self
.
multiple_machines_d
):
if
(
self
.
pp_size
==
2
):
remote_address
=
ip
+
":"
+
str
(
port
+
self
.
_tp_rank
)
pending
=
False
with
self
.
p2p_nccl_engine
.
req_status_cv
:
if
request_id
not
in
self
.
p2p_nccl_engine
.
req_status
:
pending
=
True
if
pending
:
self
.
p2p_nccl_engine
.
pending_tensor
(
request_id
,
layer_name
,
kv_cache
)
logger
.
info
(
"[%d] pending for request: %s layer: %s"
,
self
.
_rank
,
request_id
,
layer_name
)
else
:
req_data
=
self
.
p2p_nccl_engine
.
req_status
[
request_id
]
assert
(
req_data
.
dst_num
==
len
(
req_data
.
zmq_address_and_comm_rank
))
for
i
in
range
(
req_data
.
dst_num
):
remote_addr
=
RemoteAddr
(
req_data
.
pd_pair_id
,
*
(
req_data
.
zmq_address_and_comm_rank
[
i
]))
self
.
p2p_nccl_engine
.
send_tensor
(
request_id
+
"#"
+
layer_name
,
kv_cache
,
remote_address
)
else
:
logger
.
error
(
"Error: P multiple machines D machine only suppprt P:pp2tp8 D:tp8 !!!!!!"
)
elif
(
not
self
.
multiple_machines_p
and
not
self
.
multiple_machines_d
):
self
.
p2p_nccl_engine
.
send_tensor_new
(
request_id
,
layer_name
,
kv_cache
,
is_mla
)
# if (self.pp_size == 1):
# self.p2p_nccl_engine.send_tensor(request_id + "#" + layer_name,
# kv_cache, remote_address)
# elif (self.pp_size == 2):
# if (pp_rank == 0):
# self.p2p_nccl_engine.send_tensor(request_id + "#" + layer_name,
# kv_cache, remote_address)
# self.p2p_nccl_engine.send_tensor(request_id + "#" + layer_name,
# kv_cache, ip + ":" + str(port + self._rank + 4))
# else:
# self.p2p_nccl_engine.send_tensor(request_id + "#" + layer_name,
# kv_cache, remote_address)
# self.p2p_nccl_engine.send_tensor(request_id + "#" + layer_name,
# kv_cache, ip + ":" + str(port + self._rank - 4))
# elif (self.pp_size == 8):
# for i in range(8):
# self.p2p_nccl_engine.send_tensor(request_id + "#" + layer_name,
# kv_cache, ip + ":" + str(port + i))
# elif (self.enable_asymmetric_p2p):
# self.p2p_nccl_engine.send_tensor(request_id + "#" + layer_name,
# kv_cache, remote_address)
# else:
# logger.error("Error: P/D single machine only suppprt multiple tp:: (P: pp2tp4 D:tp8 P:pp8tp1 D:tp8) !!!!!!")
else
:
logger
.
error
(
"Error: not support!!!!!!"
)
kv_cache
,
remote_addr
)
def
wait_for_save
(
self
):
pass
# if self.is_producer:
...
...
@@ -612,9 +597,7 @@ class P2pNcclConnector(KVConnectorBase_V1):
num_scheduled_tokens
=
(
scheduler_output
.
num_scheduled_tokens
)[
req_id
]
num_tokens
=
(
num_scheduled_tokens
+
num_computed_tokens
)
# assert req_id in self.chunked_prefill
if
req_id
not
in
self
.
chunked_prefill
:
continue
assert
req_id
in
self
.
chunked_prefill
block_ids
=
new_block_ids
[
0
]
if
not
resumed_from_preemption
:
block_ids
=
(
self
.
chunked_prefill
[
req_id
][
0
]
+
block_ids
)
...
...
vllm/distributed/kv_transfer/kv_connector/v1/p2p/p2p_nccl_engine.py
View file @
4612aad6
This diff is collapsed.
Click to expand it.
vllm/v1/core/sched/scheduler.py
View file @
4612aad6
...
...
@@ -86,7 +86,8 @@ class Scheduler(SchedulerInterface):
"Multiple KV cache groups are not currently supported "
"with KV connectors"
)
self
.
connector
=
KVConnectorFactory
.
create_connector_v1
(
config
=
self
.
vllm_config
,
role
=
KVConnectorRole
.
SCHEDULER
)
config
=
self
.
vllm_config
,
role
=
KVConnectorRole
.
SCHEDULER
,
dp_rank
=
self
.
parallel_config
.
data_parallel_rank
)
self
.
kv_event_publisher
=
EventPublisherFactory
.
create
(
self
.
kv_events_config
,
...
...
@@ -371,8 +372,10 @@ class Scheduler(SchedulerInterface):
break
request
=
self
.
waiting
.
peek_request
()
if
request
.
is_finished
():
if
self
.
connector
and
not
self
.
connector
.
is_producer
and
request
.
request_id
not
in
self
.
finished_recving_kv_req_ids
:
self
.
waiting
.
pop_request
()
skipped_waiting_requests
.
prepend_request
(
request
)
continue
# KVTransfer: skip request if still waiting for remote kvs.
if
request
.
status
==
RequestStatus
.
WAITING_FOR_REMOTE_KVS
:
...
...
@@ -457,6 +460,7 @@ class Scheduler(SchedulerInterface):
# pooling requests to be chunked
if
not
self
.
scheduler_config
.
chunked_prefill_enabled
and
\
num_new_tokens
>
token_budget
:
break
self
.
waiting
.
pop_request
()
skipped_waiting_requests
.
prepend_request
(
request
)
continue
...
...
@@ -668,6 +672,11 @@ class Scheduler(SchedulerInterface):
break
request
=
self
.
waiting
.
peek_request
()
if
self
.
connector
and
not
self
.
connector
.
is_producer
and
request
.
request_id
not
in
self
.
finished_recving_kv_req_ids
:
self
.
waiting
.
pop_request
()
skipped_waiting_requests
.
prepend_request
(
request
)
continue
# KVTransfer: skip request if still waiting for remote kvs.
if
request
.
status
==
RequestStatus
.
WAITING_FOR_REMOTE_KVS
:
is_ready
=
self
.
_update_waiting_for_remote_kv
(
request
)
...
...
@@ -751,6 +760,7 @@ class Scheduler(SchedulerInterface):
# pooling requests to be chunked
if
not
self
.
scheduler_config
.
chunked_prefill_enabled
and
\
num_new_tokens
>
token_budget
:
break
self
.
waiting
.
pop_request
()
skipped_waiting_requests
.
prepend_request
(
request
)
continue
...
...
@@ -1311,7 +1321,7 @@ class Scheduler(SchedulerInterface):
request
.
num_nans_in_logits
=
num_nans_in_logits
[
req_id
]
# Add newly generated spec token ids to the request.
if
spec_token_ids
is
not
None
:
if
spec_token_ids
is
not
None
and
(
self
.
connector
is
None
or
not
self
.
connector
.
is_producer
)
:
if
self
.
structured_output_manager
.
should_advance
(
request
):
metadata
=
request
.
structured_output_request
# Needs to happen after new_token_ids are accepted.
...
...
vllm/v1/engine/core.py
View file @
4612aad6
...
...
@@ -763,6 +763,9 @@ class EngineCoreProc(EngineCore):
# Push to input queue for core busy loop.
self
.
input_queue
.
put_nowait
((
request_type
,
request
))
if
isinstance
(
request
,
EngineCoreRequest
)
and
self
.
scheduler
.
connector
is
not
None
:
if
request_type
==
EngineCoreRequestType
.
ADD
:
self
.
scheduler
.
connector
.
register_req
(
request
.
request_id
)
def
process_output_sockets
(
self
,
output_paths
:
list
[
str
],
coord_output_path
:
Optional
[
str
],
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment