Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
56cc2ac8
Commit
56cc2ac8
authored
Mar 09, 2026
by
xuxz
Browse files
[PD]修改代理
parent
56fef1c3
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
243 additions
and
56 deletions
+243
-56
examples/online_serving/disaggregated_serving_p2p_nccl_xpyd/disagg_proxy_p2p_nccl_xpyd.py
...gated_serving_p2p_nccl_xpyd/disagg_proxy_p2p_nccl_xpyd.py
+243
-56
No files found.
examples/online_serving/disaggregated_serving_p2p_nccl_xpyd/disagg_proxy_p2p_nccl_xpyd.py
View file @
56cc2ac8
...
...
@@ -4,35 +4,87 @@
import
os
import
socket
import
threading
import
time
import
uuid
from
typing
import
Any
import
aiohttp
import
msgpack
import
zmq
from
typing
import
Any
from
quart
import
Quart
,
make_response
,
request
from
dataclasses
import
dataclass
,
field
from
vllm.distributed.device_communicators.pynccl_wrapper
import
NCCLLibrary
from
collections
import
deque
,
defaultdict
import
logging
logging
.
basicConfig
(
level
=
logging
.
INFO
,
format
=
"%(asctime)s - %(name)s - %(levelname)s - %(message)s"
)
logger
=
logging
.
getLogger
(
__name__
)
# @dataclass
# class Request:
# request_id: str
# p_http_address: str = ""
# p_dp_rank: int = -1
# d_http_address: str = ""
# d_dp_rank: int = -1
@
dataclass
class
Instance
:
ins_type
:
str
=
"P"
http_address
:
str
=
""
zmq_address
:
str
=
""
p_unique_id
:
bytes
=
b
""
dp_size
:
int
=
0
pp_size
:
int
=
0
tp_size
:
int
=
0
# [dp, pp, tp] : zmq_address
rank_table
:
dict
[
int
,
dict
[
int
,
dict
[
int
,
str
]]]
=
field
(
default_factory
=
lambda
:
defaultdict
(
lambda
:
defaultdict
(
dict
))
)
# [dp, pp, tp] : global rank
comm_rank_table
:
dict
[
int
,
dict
[
int
,
dict
[
int
,
int
]]]
=
field
(
default_factory
=
lambda
:
defaultdict
(
lambda
:
defaultdict
(
dict
))
)
count
=
0
prefill_instances
:
dict
[
str
,
Any
]
=
{}
# http_address: (zmq_address, stamp)
decode_instances
:
dict
[
str
,
Any
]
=
{}
# http_address: (zmq_address, stamp)
def
count_rank_table_elements
(
self
):
count
=
0
for
first_dict
in
self
.
rank_table
.
values
():
for
second_dict
in
first_dict
.
values
():
count
+=
len
(
second_dict
)
return
count
def
is_ready
(
self
):
world_size
=
self
.
dp_size
*
self
.
pp_size
*
self
.
tp_size
inited_rank
=
self
.
count_rank_table_elements
()
all_ranks_ready
=
world_size
and
inited_rank
==
world_size
if
self
.
ins_type
==
"P"
:
logger
.
info
(
f
"""[Router] P is_ready? :
{
self
.
http_address
}
world_size =
{
world_size
}
inited_rank =
{
inited_rank
}
"""
)
# return all_ranks_ready and self.p_unique_id != b""
return
all_ranks_ready
else
:
logger
.
info
(
f
"""[Router] D is_ready? :
{
self
.
http_address
}
world_size =
{
world_size
}
inited_rank =
{
inited_rank
}
"""
)
return
all_ranks_ready
prefill_cv
=
threading
.
Condition
()
decode_cv
=
threading
.
Condition
()
count
=
0
# prefill_instances: dict[str, str] = {} # http_address: zmq_address
# decode_instances: dict[str, str] = {} # http_address: zmq_address
prefill_instances
:
dict
[
str
,
Instance
]
=
{}
decode_instances
:
dict
[
str
,
Instance
]
=
{}
DEFAULT_PING_SECONDS
=
5
pending_prefill_ins
:
list
[
str
]
=
[]
pending_decode_ins
:
list
[
str
]
=
[]
ready_prefill_ins
:
list
[
str
]
=
[]
ready_decode_ins
:
list
[
str
]
=
[]
pd_pair
:
dict
[
str
,
bytes
]
=
{}
router_nccl
=
NCCLLibrary
()
def
_remove_oldest_instances
(
instances
:
dict
[
str
,
Any
])
->
None
:
oldest_key
=
next
(
iter
(
instances
),
None
)
while
oldest_key
is
not
None
:
value
=
instances
[
oldest_key
]
if
value
[
1
]
>
time
.
time
():
break
print
(
f
"🔴Remove [HTTP:
{
oldest_key
}
, ZMQ:
{
value
[
0
]
}
, stamp:
{
value
[
1
]
}
]"
)
instances
.
pop
(
oldest_key
,
None
)
oldest_key
=
next
(
iter
(
instances
),
None
)
prefill_cv
=
threading
.
Condition
()
decode_cv
=
threading
.
Condition
()
instance_cv
=
threading
.
Condition
()
sock_cache
:
dict
[
str
,
Any
]
=
{}
def
_listen_for_register
(
poller
,
router_socket
):
while
True
:
...
...
@@ -42,38 +94,69 @@ def _listen_for_register(poller, router_socket):
# data: {"type": "P", "http_address": "ip:port",
# "zmq_address": "ip:port"}
data
=
msgpack
.
loads
(
message
)
if
data
[
"type"
]
==
"P"
:
global
prefill_instances
global
prefill_cv
with
prefill_cv
:
node
=
prefill_instances
.
get
(
data
[
"http_address"
],
None
)
prefill_instances
[
data
[
"http_address"
]]
=
(
data
[
"zmq_address"
],
time
.
time
()
+
DEFAULT_PING_SECONDS
,
)
_remove_oldest_instances
(
prefill_instances
)
elif
data
[
"type"
]
==
"D"
:
global
instance_cv
global
decode_instances
global
decode_cv
with
decode_cv
:
node
=
decode_instances
.
get
(
data
[
"http_address"
],
None
)
decode_instances
[
data
[
"http_address"
]]
=
(
data
[
"zmq_address"
],
time
.
time
()
+
DEFAULT_PING_SECONDS
,
)
_remove_oldest_instances
(
decode_instances
)
if
data
[
"type"
]
==
"P"
:
with
instance_cv
:
if
data
[
"http_address"
]
not
in
prefill_instances
:
prefill_instances
[
data
[
"http_address"
]]
=
Instance
(
http_address
=
data
[
"http_address"
])
p_instance
=
prefill_instances
[
data
[
"http_address"
]]
p_instance
.
rank_table
[
int
(
data
[
"dp_rank"
])][
int
(
data
[
"pp_rank"
])][
int
(
data
[
"tp_rank"
])]
=
data
[
"zmq_address"
]
if
p_instance
.
is_ready
():
pending_prefill_ins
.
append
(
p_instance
.
http_address
)
logger
.
info
(
f
"""[Router] pending_prefill_ins appended
{
p_instance
.
http_address
}
ZMQ:
{
p_instance
.
zmq_address
}
"""
)
instance_cv
.
notify
()
logger
.
info
(
f
"""[Router] add P rank [
{
data
[
"dp_rank"
]
}
,
{
data
[
"pp_rank"
]
}
,
{
data
[
"tp_rank"
]
}
] :
{
data
[
"zmq_address"
]
}
"""
)
elif
data
[
"type"
]
==
"D"
:
with
instance_cv
:
if
data
[
"http_address"
]
not
in
decode_instances
:
decode_instances
[
data
[
"http_address"
]]
=
Instance
(
ins_type
=
"D"
,
http_address
=
data
[
"http_address"
])
d_instance
=
decode_instances
[
data
[
"http_address"
]]
d_instance
.
rank_table
[
int
(
data
[
"dp_rank"
])][
int
(
data
[
"pp_rank"
])][
int
(
data
[
"tp_rank"
])]
=
data
[
"zmq_address"
]
if
d_instance
.
is_ready
():
pending_decode_ins
.
append
(
d_instance
.
http_address
)
logger
.
info
(
f
"""[Router] pending_decode_ins appended
{
d_instance
.
http_address
}
ZMQ:
{
d_instance
.
zmq_address
}
"""
)
instance_cv
.
notify
()
logger
.
info
(
f
"""[Router] add D rank [
{
data
[
"dp_rank"
]
}
,
{
data
[
"pp_rank"
]
}
,
{
data
[
"tp_rank"
]
}
] :
{
data
[
"zmq_address"
]
}
"""
)
elif
data
[
"type"
]
==
"P_init"
:
with
instance_cv
:
if
data
[
"http_address"
]
not
in
prefill_instances
:
prefill_instances
[
data
[
"http_address"
]]
=
Instance
(
http_address
=
data
[
"http_address"
],
dp_size
=
int
(
data
[
"dp_size"
]),
pp_size
=
int
(
data
[
"pp_size"
]),
tp_size
=
int
(
data
[
"tp_size"
]))
prefill_instances
[
data
[
"http_address"
]].
zmq_address
=
data
[
"zmq_address"
]
continue
p_instance
=
prefill_instances
[
data
[
"http_address"
]]
p_instance
.
dp_size
=
int
(
data
[
"dp_size"
])
p_instance
.
pp_size
=
int
(
data
[
"pp_size"
])
p_instance
.
tp_size
=
int
(
data
[
"tp_size"
])
p_instance
.
zmq_address
=
data
[
"zmq_address"
]
if
p_instance
.
is_ready
():
pending_prefill_ins
.
append
(
p_instance
.
http_address
)
logger
.
info
(
f
"""[Router] pending_prefill_ins appended
{
p_instance
.
http_address
}
ZMQ:
{
p_instance
.
zmq_address
}
"""
)
instance_cv
.
notify
()
elif
data
[
"type"
]
==
"D_init"
:
with
instance_cv
:
if
data
[
"http_address"
]
not
in
decode_instances
:
decode_instances
[
data
[
"http_address"
]]
=
Instance
(
ins_type
=
"D"
,
http_address
=
data
[
"http_address"
],
dp_size
=
int
(
data
[
"dp_size"
]),
pp_size
=
int
(
data
[
"pp_size"
]),
tp_size
=
int
(
data
[
"tp_size"
]))
decode_instances
[
data
[
"http_address"
]].
zmq_address
=
data
[
"zmq_address"
]
continue
d_instance
=
decode_instances
[
data
[
"http_address"
]]
d_instance
.
dp_size
=
int
(
data
[
"dp_size"
])
d_instance
.
pp_size
=
int
(
data
[
"pp_size"
])
d_instance
.
tp_size
=
int
(
data
[
"tp_size"
])
d_instance
.
zmq_address
=
data
[
"zmq_address"
]
if
d_instance
.
is_ready
():
pending_decode_ins
.
append
(
d_instance
.
http_address
)
logger
.
info
(
f
"""[Router] pending_decode_ins appended
{
d_instance
.
http_address
}
ZMQ:
{
d_instance
.
zmq_address
}
"""
)
instance_cv
.
notify
()
else
:
print
(
"Unexpected, Received message from %s, data: %s"
,
remote_address
,
data
,
)
return
if
node
is
None
:
print
(
f
"🔵Add [HTTP:
{
data
[
'http_address'
]
}
, ZMQ:
{
data
[
'zmq_address'
]
}
]"
)
zmq_context
=
None
def
start_service_discovery
(
hostname
,
port
):
if
not
hostname
:
...
...
@@ -81,8 +164,11 @@ def start_service_discovery(hostname, port):
if
port
==
0
:
raise
ValueError
(
"Port cannot be 0"
)
context
=
zmq
.
Context
()
router_socket
=
context
.
socket
(
zmq
.
ROUTER
)
# context = zmq.Context()
# router_socket = context.socket(zmq.ROUTER)
global
zmq_context
zmq_context
=
zmq
.
Context
()
router_socket
=
zmq_context
.
socket
(
zmq
.
ROUTER
)
router_socket
.
bind
(
f
"tcp://
{
hostname
}
:
{
port
}
"
)
poller
=
zmq
.
Poller
()
...
...
@@ -120,8 +206,110 @@ async def forward_request(url, data, request_id):
yield
content
def
unique_id_dispatch
(
prefill_instance
:
str
,
decode_instance
:
str
)
:
global
zmq_context
global
sock_cache
global
router_nccl
global
pd_pair
pd_pair_id
=
prefill_instance
.
zmq_address
+
"_"
+
decode_instance
.
zmq_address
if
pd_pair_id
in
pd_pair
:
logger
.
info
(
f
"""[Router] pd pair
{
pd_pair_id
}
already exist"""
)
return
logger
.
info
(
f
"""[Router] initing pd pair
{
pd_pair_id
}
"""
)
unique_id
=
router_nccl
.
ncclGetUniqueId
()
unique_id
=
bytes
(
unique_id
.
internal
)
rank
=
0
p_rank_num
=
prefill_instance
.
dp_size
*
prefill_instance
.
pp_size
*
prefill_instance
.
tp_size
d_rank_num
=
decode_instance
.
dp_size
*
decode_instance
.
pp_size
*
decode_instance
.
tp_size
world_size
=
p_rank_num
+
d_rank_num
for
dp_rank
in
range
(
prefill_instance
.
dp_size
):
for
pp_rank
in
range
(
prefill_instance
.
pp_size
):
for
tp_rank
in
range
(
prefill_instance
.
tp_size
):
if
prefill_instance
.
rank_table
[
dp_rank
][
pp_rank
][
tp_rank
]
not
in
sock_cache
:
sock
=
zmq_context
.
socket
(
zmq
.
DEALER
)
sock
.
setsockopt_string
(
zmq
.
IDENTITY
,
"router"
)
sock
.
connect
(
f
"tcp://
{
prefill_instance
.
rank_table
[
dp_rank
][
pp_rank
][
tp_rank
]
}
"
)
sock_cache
[
prefill_instance
.
rank_table
[
dp_rank
][
pp_rank
][
tp_rank
]]
=
sock
data
=
{
"cmd"
:
"comm_init"
,
"pd_pair_id"
:
pd_pair_id
,
"unique_id"
:
unique_id
,
"world_size"
:
world_size
,
"rank"
:
rank
}
sock_cache
[
prefill_instance
.
rank_table
[
dp_rank
][
pp_rank
][
tp_rank
]].
send
(
msgpack
.
dumps
(
data
))
prefill_instance
.
comm_rank_table
[
dp_rank
][
pp_rank
][
tp_rank
]
=
rank
rank
+=
1
logger
.
info
(
f
"""[Router] dispatch unique_id of pd pair
{
pd_pair_id
}
to [P] [
{
dp_rank
}
,
{
pp_rank
}
,
{
tp_rank
}
]"""
)
for
dp_rank
in
range
(
decode_instance
.
dp_size
):
for
pp_rank
in
range
(
decode_instance
.
pp_size
):
for
tp_rank
in
range
(
decode_instance
.
tp_size
):
if
decode_instance
.
rank_table
[
dp_rank
][
pp_rank
][
tp_rank
]
not
in
sock_cache
:
sock
=
zmq_context
.
socket
(
zmq
.
DEALER
)
sock
.
setsockopt_string
(
zmq
.
IDENTITY
,
"router"
)
sock
.
connect
(
f
"tcp://
{
decode_instance
.
rank_table
[
dp_rank
][
pp_rank
][
tp_rank
]
}
"
)
sock_cache
[
decode_instance
.
rank_table
[
dp_rank
][
pp_rank
][
tp_rank
]]
=
sock
data
=
{
"cmd"
:
"comm_init"
,
"pd_pair_id"
:
pd_pair_id
,
"unique_id"
:
unique_id
,
"world_size"
:
world_size
,
"rank"
:
rank
}
sock_cache
[
decode_instance
.
rank_table
[
dp_rank
][
pp_rank
][
tp_rank
]].
send
(
msgpack
.
dumps
(
data
))
decode_instance
.
comm_rank_table
[
dp_rank
][
pp_rank
][
tp_rank
]
=
rank
rank
+=
1
logger
.
info
(
f
"""[Router] dispatch unique_id of pd pair
{
pd_pair_id
}
to [D] [
{
dp_rank
}
,
{
pp_rank
}
,
{
tp_rank
}
]"""
)
pd_pair
[
pd_pair_id
]
=
unique_id
def
pd_pair_init
():
global
prefill_instances
global
decode_instances
global
pending_prefill_ins
global
pending_decode_ins
global
ready_prefill_ins
global
ready_decode_ins
global
instance_cv
while
True
:
with
instance_cv
:
while
len
(
pending_prefill_ins
)
==
0
and
len
(
pending_decode_ins
)
==
0
:
logger
.
info
(
f
"""[Router] pd_pair_init: waiting for instance_cv"""
)
instance_cv
.
wait
()
logger
.
info
(
f
"""[Router] pd_pair_init: instance_cv finished waiting"""
)
while
pending_prefill_ins
:
p_ins
=
pending_prefill_ins
[
0
]
logger
.
info
(
f
"""[Router] pd_pair_init: processing
{
p_ins
}
from pending_prefill_ins"""
)
for
d_ins
in
ready_decode_ins
:
unique_id_dispatch
(
prefill_instances
[
p_ins
],
decode_instances
[
d_ins
])
ready_prefill_ins
.
append
(
p_ins
)
pending_prefill_ins
.
remove
(
p_ins
)
while
pending_decode_ins
:
d_ins
=
pending_decode_ins
[
0
]
logger
.
info
(
f
"""[Router] pd_pair_init: processing
{
d_ins
}
from pending_decode_ins"""
)
for
p_ins
in
ready_prefill_ins
:
unique_id_dispatch
(
prefill_instances
[
p_ins
],
decode_instances
[
d_ins
])
ready_decode_ins
.
append
(
d_ins
)
pending_decode_ins
.
remove
(
d_ins
)
def
start_pd_pair_init
():
_thread
=
threading
.
Thread
(
target
=
pd_pair_init
,
daemon
=
True
)
_thread
.
start
()
return
_thread
@
app
.
route
(
"/v1/completions"
,
methods
=
[
"POST"
])
@
app
.
route
(
"/v1/chat/completions"
,
methods
=
[
"POST"
])
async
def
handle_request
():
try
:
original_request_data
=
await
request
.
get_json
()
...
...
@@ -129,45 +317,42 @@ async def handle_request():
prefill_request
=
original_request_data
.
copy
()
# change max_tokens = 1 to let it only do prefill
prefill_request
[
"max_tokens"
]
=
1
if
"max_completion_tokens"
in
prefill_request
:
prefill_request
[
"max_completion_tokens"
]
=
1
global
count
global
prefill_instances
global
prefill_cv
with
prefill_cv
:
prefill_list
=
list
(
prefill_instances
.
items
())
prefill_addr
,
prefill_zmq_addr
=
prefill_list
[
count
%
len
(
prefill_list
)]
prefill_zmq_addr
=
prefill_zmq_addr
[
0
]
prefill_addr
,
prefill_instance
=
prefill_list
[
count
%
len
(
prefill_list
)]
global
decode_instances
global
decode_cv
with
decode_cv
:
decode_list
=
list
(
decode_instances
.
items
())
decode_addr
,
decode_zmq_addr
=
decode_list
[
count
%
len
(
decode_list
)]
decode_zmq_addr
=
decode_zmq_addr
[
0
]
decode_addr
,
decode_instance
=
decode_list
[
count
%
len
(
decode_list
)]
print
(
f
"handle_request count:
{
count
}
, [HTTP:
{
prefill_addr
}
, "
f
"ZMQ:
{
prefill_zmq_addr
}
] 👉 [HTTP:
{
decode_addr
}
, "
f
"ZMQ:
{
decode_zmq_addr
}
]"
f
"ZMQ:
{
prefill_
instance
.
zmq_addr
ess
}
] 👉 [HTTP:
{
decode_addr
}
, "
f
"ZMQ:
{
decode_
instance
.
zmq_addr
ess
}
]"
)
count
+=
1
request_id
=
(
f
"___prefill_addr_
{
prefill_zmq_addr
}
___decode_addr_"
f
"
{
decode_zmq_addr
}
_
{
random_uuid
()
}
"
f
"___prefill_addr_
{
prefill_
instance
.
zmq_addr
ess
}
___decode_addr_"
f
"
{
decode_
instance
.
zmq_addr
ess
}
_
{
random_uuid
()
}
"
)
# finish prefill
async
for
_
in
forward_request
(
f
"http://
{
prefill_addr
}
{
request
.
path
}
"
,
prefill_request
,
request_id
f
"http://
{
prefill_addr
}
/v1/completions
"
,
prefill_request
,
request_id
):
continue
# return decode
generator
=
forward_request
(
f
"http://
{
decode_addr
}
{
request
.
path
}
"
,
original_request_data
,
request_id
f
"http://
{
decode_addr
}
/v1/completions
"
,
original_request_data
,
request_id
)
response
=
await
make_response
(
generator
)
response
.
timeout
=
None
...
...
@@ -186,5 +371,7 @@ async def handle_request():
if
__name__
==
"__main__"
:
t
=
start_service_discovery
(
"0.0.0.0"
,
30001
)
t_1
=
start_pd_pair_init
()
app
.
run
(
host
=
"0.0.0.0"
,
port
=
10001
)
t
.
join
()
t_1
.
join
()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment