Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
56cc2ac8
Commit
56cc2ac8
authored
Mar 09, 2026
by
xuxz
Browse files
[PD]修改代理
parent
56fef1c3
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
243 additions
and
56 deletions
+243
-56
examples/online_serving/disaggregated_serving_p2p_nccl_xpyd/disagg_proxy_p2p_nccl_xpyd.py
...gated_serving_p2p_nccl_xpyd/disagg_proxy_p2p_nccl_xpyd.py
+243
-56
No files found.
examples/online_serving/disaggregated_serving_p2p_nccl_xpyd/disagg_proxy_p2p_nccl_xpyd.py
View file @
56cc2ac8
...
@@ -4,35 +4,87 @@
...
@@ -4,35 +4,87 @@
import
os
import
os
import
socket
import
socket
import
threading
import
threading
import
time
import
uuid
import
uuid
from
typing
import
Any
import
aiohttp
import
aiohttp
import
msgpack
import
msgpack
import
zmq
import
zmq
from
typing
import
Any
from
quart
import
Quart
,
make_response
,
request
from
quart
import
Quart
,
make_response
,
request
from
dataclasses
import
dataclass
,
field
from
vllm.distributed.device_communicators.pynccl_wrapper
import
NCCLLibrary
from
collections
import
deque
,
defaultdict
import
logging
logging
.
basicConfig
(
level
=
logging
.
INFO
,
format
=
"%(asctime)s - %(name)s - %(levelname)s - %(message)s"
)
logger
=
logging
.
getLogger
(
__name__
)
# @dataclass
# class Request:
# request_id: str
# p_http_address: str = ""
# p_dp_rank: int = -1
# d_http_address: str = ""
# d_dp_rank: int = -1
@
dataclass
class
Instance
:
ins_type
:
str
=
"P"
http_address
:
str
=
""
zmq_address
:
str
=
""
p_unique_id
:
bytes
=
b
""
dp_size
:
int
=
0
pp_size
:
int
=
0
tp_size
:
int
=
0
# [dp, pp, tp] : zmq_address
rank_table
:
dict
[
int
,
dict
[
int
,
dict
[
int
,
str
]]]
=
field
(
default_factory
=
lambda
:
defaultdict
(
lambda
:
defaultdict
(
dict
))
)
# [dp, pp, tp] : global rank
comm_rank_table
:
dict
[
int
,
dict
[
int
,
dict
[
int
,
int
]]]
=
field
(
default_factory
=
lambda
:
defaultdict
(
lambda
:
defaultdict
(
dict
))
)
count
=
0
def
count_rank_table_elements
(
self
):
prefill_instances
:
dict
[
str
,
Any
]
=
{}
# http_address: (zmq_address, stamp)
count
=
0
decode_instances
:
dict
[
str
,
Any
]
=
{}
# http_address: (zmq_address, stamp)
for
first_dict
in
self
.
rank_table
.
values
():
for
second_dict
in
first_dict
.
values
():
count
+=
len
(
second_dict
)
return
count
def
is_ready
(
self
):
world_size
=
self
.
dp_size
*
self
.
pp_size
*
self
.
tp_size
inited_rank
=
self
.
count_rank_table_elements
()
all_ranks_ready
=
world_size
and
inited_rank
==
world_size
if
self
.
ins_type
==
"P"
:
logger
.
info
(
f
"""[Router] P is_ready? :
{
self
.
http_address
}
world_size =
{
world_size
}
inited_rank =
{
inited_rank
}
"""
)
# return all_ranks_ready and self.p_unique_id != b""
return
all_ranks_ready
else
:
logger
.
info
(
f
"""[Router] D is_ready? :
{
self
.
http_address
}
world_size =
{
world_size
}
inited_rank =
{
inited_rank
}
"""
)
return
all_ranks_ready
prefill_cv
=
threading
.
Condition
()
count
=
0
decode_cv
=
threading
.
Condition
()
# prefill_instances: dict[str, str] = {} # http_address: zmq_address
# decode_instances: dict[str, str] = {} # http_address: zmq_address
prefill_instances
:
dict
[
str
,
Instance
]
=
{}
decode_instances
:
dict
[
str
,
Instance
]
=
{}
DEFAULT_PING_SECONDS
=
5
pending_prefill_ins
:
list
[
str
]
=
[]
pending_decode_ins
:
list
[
str
]
=
[]
ready_prefill_ins
:
list
[
str
]
=
[]
ready_decode_ins
:
list
[
str
]
=
[]
pd_pair
:
dict
[
str
,
bytes
]
=
{}
router_nccl
=
NCCLLibrary
()
def
_remove_oldest_instances
(
instances
:
dict
[
str
,
Any
])
->
None
:
prefill_cv
=
threading
.
Condition
()
oldest_key
=
next
(
iter
(
instances
),
None
)
decode_cv
=
threading
.
Condition
()
while
oldest_key
is
not
None
:
instance_cv
=
threading
.
Condition
()
value
=
instances
[
oldest_key
]
if
value
[
1
]
>
time
.
time
():
break
print
(
f
"🔴Remove [HTTP:
{
oldest_key
}
, ZMQ:
{
value
[
0
]
}
, stamp:
{
value
[
1
]
}
]"
)
instances
.
pop
(
oldest_key
,
None
)
oldest_key
=
next
(
iter
(
instances
),
None
)
sock_cache
:
dict
[
str
,
Any
]
=
{}
def
_listen_for_register
(
poller
,
router_socket
):
def
_listen_for_register
(
poller
,
router_socket
):
while
True
:
while
True
:
...
@@ -42,47 +94,81 @@ def _listen_for_register(poller, router_socket):
...
@@ -42,47 +94,81 @@ def _listen_for_register(poller, router_socket):
# data: {"type": "P", "http_address": "ip:port",
# data: {"type": "P", "http_address": "ip:port",
# "zmq_address": "ip:port"}
# "zmq_address": "ip:port"}
data
=
msgpack
.
loads
(
message
)
data
=
msgpack
.
loads
(
message
)
global
prefill_instances
global
instance_cv
global
decode_instances
if
data
[
"type"
]
==
"P"
:
if
data
[
"type"
]
==
"P"
:
global
prefill_
instance
s
with
instance
_cv
:
global
prefill_cv
if
data
[
"http_address"
]
not
in
prefill_instances
:
with
prefill_
cv
:
prefill_
instances
[
data
[
"http_address"
]]
=
Instance
(
http_address
=
data
[
"http_address"
])
nod
e
=
prefill_instances
.
get
(
data
[
"http_address"
]
,
None
)
p_instanc
e
=
prefill_instances
[
data
[
"http_address"
]
]
p
refill
_instance
s
[
data
[
"http_address
"
]]
=
(
p_instance
.
rank_table
[
int
(
data
[
"dp_rank"
])][
int
(
data
[
"pp_rank"
])][
int
(
data
[
"tp_rank
"
]
)
]
=
data
[
"zmq_address"
]
data
[
"zmq_address"
],
if
p_instance
.
is_ready
():
time
.
time
()
+
DEFAULT_PING_SECONDS
,
pending_prefill_ins
.
append
(
p_instance
.
http_address
)
)
logger
.
info
(
f
"""[Router] pending_prefill_ins appended
{
p_instance
.
http_address
}
ZMQ:
{
p_instance
.
zmq_address
}
"""
)
_remove_oldest_instances
(
prefill_instances
)
instance_cv
.
notify
(
)
logger
.
info
(
f
"""[Router] add P rank [
{
data
[
"dp_rank"
]
}
,
{
data
[
"pp_rank"
]
}
,
{
data
[
"tp_rank"
]
}
] :
{
data
[
"zmq_address"
]
}
"""
)
elif
data
[
"type"
]
==
"D"
:
elif
data
[
"type"
]
==
"D"
:
global
decode_instances
with
instance_cv
:
global
decode_cv
if
data
[
"http_address"
]
not
in
decode_instances
:
with
decode_cv
:
decode_instances
[
data
[
"http_address"
]]
=
Instance
(
ins_type
=
"D"
,
http_address
=
data
[
"http_address"
])
node
=
decode_instances
.
get
(
data
[
"http_address"
],
None
)
d_instance
=
decode_instances
[
data
[
"http_address"
]]
decode_instances
[
data
[
"http_address"
]]
=
(
d_instance
.
rank_table
[
int
(
data
[
"dp_rank"
])][
int
(
data
[
"pp_rank"
])][
int
(
data
[
"tp_rank"
])]
=
data
[
"zmq_address"
]
data
[
"zmq_address"
],
if
d_instance
.
is_ready
():
time
.
time
()
+
DEFAULT_PING_SECONDS
,
pending_decode_ins
.
append
(
d_instance
.
http_address
)
)
logger
.
info
(
f
"""[Router] pending_decode_ins appended
{
d_instance
.
http_address
}
ZMQ:
{
d_instance
.
zmq_address
}
"""
)
_remove_oldest_instances
(
decode_instances
)
instance_cv
.
notify
()
logger
.
info
(
f
"""[Router] add D rank [
{
data
[
"dp_rank"
]
}
,
{
data
[
"pp_rank"
]
}
,
{
data
[
"tp_rank"
]
}
] :
{
data
[
"zmq_address"
]
}
"""
)
elif
data
[
"type"
]
==
"P_init"
:
with
instance_cv
:
if
data
[
"http_address"
]
not
in
prefill_instances
:
prefill_instances
[
data
[
"http_address"
]]
=
Instance
(
http_address
=
data
[
"http_address"
],
dp_size
=
int
(
data
[
"dp_size"
]),
pp_size
=
int
(
data
[
"pp_size"
]),
tp_size
=
int
(
data
[
"tp_size"
]))
prefill_instances
[
data
[
"http_address"
]].
zmq_address
=
data
[
"zmq_address"
]
continue
p_instance
=
prefill_instances
[
data
[
"http_address"
]]
p_instance
.
dp_size
=
int
(
data
[
"dp_size"
])
p_instance
.
pp_size
=
int
(
data
[
"pp_size"
])
p_instance
.
tp_size
=
int
(
data
[
"tp_size"
])
p_instance
.
zmq_address
=
data
[
"zmq_address"
]
if
p_instance
.
is_ready
():
pending_prefill_ins
.
append
(
p_instance
.
http_address
)
logger
.
info
(
f
"""[Router] pending_prefill_ins appended
{
p_instance
.
http_address
}
ZMQ:
{
p_instance
.
zmq_address
}
"""
)
instance_cv
.
notify
()
elif
data
[
"type"
]
==
"D_init"
:
with
instance_cv
:
if
data
[
"http_address"
]
not
in
decode_instances
:
decode_instances
[
data
[
"http_address"
]]
=
Instance
(
ins_type
=
"D"
,
http_address
=
data
[
"http_address"
],
dp_size
=
int
(
data
[
"dp_size"
]),
pp_size
=
int
(
data
[
"pp_size"
]),
tp_size
=
int
(
data
[
"tp_size"
]))
decode_instances
[
data
[
"http_address"
]].
zmq_address
=
data
[
"zmq_address"
]
continue
d_instance
=
decode_instances
[
data
[
"http_address"
]]
d_instance
.
dp_size
=
int
(
data
[
"dp_size"
])
d_instance
.
pp_size
=
int
(
data
[
"pp_size"
])
d_instance
.
tp_size
=
int
(
data
[
"tp_size"
])
d_instance
.
zmq_address
=
data
[
"zmq_address"
]
if
d_instance
.
is_ready
():
pending_decode_ins
.
append
(
d_instance
.
http_address
)
logger
.
info
(
f
"""[Router] pending_decode_ins appended
{
d_instance
.
http_address
}
ZMQ:
{
d_instance
.
zmq_address
}
"""
)
instance_cv
.
notify
()
else
:
else
:
print
(
print
(
"Unexpected, Received message from %s, data: %s"
,
"Unexpected, Received message from %s, data: %s"
,
remote_address
,
remote_address
,
data
,
data
,
)
)
return
if
node
is
None
:
print
(
f
"🔵Add [HTTP:
{
data
[
'http_address'
]
}
, ZMQ:
{
data
[
'zmq_address'
]
}
]"
)
zmq_context
=
None
def
start_service_discovery
(
hostname
,
port
):
def
start_service_discovery
(
hostname
,
port
):
if
not
hostname
:
if
not
hostname
:
hostname
=
socket
.
gethostname
()
hostname
=
socket
.
gethostname
()
if
port
==
0
:
if
port
==
0
:
raise
ValueError
(
"Port cannot be 0"
)
raise
ValueError
(
"Port cannot be 0"
)
context
=
zmq
.
Context
()
# context = zmq.Context()
router_socket
=
context
.
socket
(
zmq
.
ROUTER
)
# router_socket = context.socket(zmq.ROUTER)
global
zmq_context
zmq_context
=
zmq
.
Context
()
router_socket
=
zmq_context
.
socket
(
zmq
.
ROUTER
)
router_socket
.
bind
(
f
"tcp://
{
hostname
}
:
{
port
}
"
)
router_socket
.
bind
(
f
"tcp://
{
hostname
}
:
{
port
}
"
)
poller
=
zmq
.
Poller
()
poller
=
zmq
.
Poller
()
...
@@ -120,8 +206,110 @@ async def forward_request(url, data, request_id):
...
@@ -120,8 +206,110 @@ async def forward_request(url, data, request_id):
yield
content
yield
content
def _dispatch_comm_init(instance, role, pd_pair_id, unique_id, world_size,
                        first_rank):
    """Send one ``comm_init`` command to every rank of ``instance``.

    Ranks are visited in (dp, pp, tp) order and numbered consecutively
    starting at ``first_rank``. The number given to each rank is also
    recorded in ``instance.comm_rank_table``.

    Args:
        instance: The ``Instance`` whose ``rank_table`` maps
            [dp][pp][tp] to a ZMQ address.
        role: Label used only in the log line ("P" or "D").
        pd_pair_id: Identifier of the prefill/decode pair being set up.
        unique_id: Raw NCCL unique id bytes shared by the whole pair.
        world_size: Total number of ranks in the pair (P ranks + D ranks).
        first_rank: Communicator rank assigned to the first rank visited.

    Returns:
        The next unassigned communicator rank.

    Raises:
        KeyError: If a (dp, pp, tp) slot has no registered ZMQ address.
    """
    rank = first_rank
    for dp_rank in range(instance.dp_size):
        for pp_rank in range(instance.pp_size):
            for tp_rank in range(instance.tp_size):
                # Look the address up once instead of once per use.
                zmq_address = instance.rank_table[dp_rank][pp_rank][tp_rank]

                # One DEALER socket per remote address, reused across pairs.
                sock = sock_cache.get(zmq_address)
                if sock is None:
                    sock = zmq_context.socket(zmq.DEALER)
                    sock.setsockopt_string(zmq.IDENTITY, "router")
                    sock.connect(f"tcp://{zmq_address}")
                    sock_cache[zmq_address] = sock

                data = {
                    "cmd": "comm_init",
                    "pd_pair_id": pd_pair_id,
                    "unique_id": unique_id,
                    "world_size": world_size,
                    "rank": rank,
                }
                sock.send(msgpack.dumps(data))

                instance.comm_rank_table[dp_rank][pp_rank][tp_rank] = rank
                rank += 1
                logger.info(
                    f"""[Router] dispatch unique_id of pd pair {pd_pair_id} to [{role}] [{dp_rank}, {pp_rank}, {tp_rank}]"""
                )
    return rank


def unique_id_dispatch(prefill_instance: "Instance",
                       decode_instance: "Instance"):
    """Create an NCCL unique id for a P/D pair and send it to every rank.

    The pair is identified by ``"<prefill zmq_address>_<decode
    zmq_address>"``. If that id is already present in ``pd_pair`` the call
    only logs and returns. Otherwise a fresh unique id is obtained from
    ``router_nccl``, a ``comm_init`` message is sent to each prefill rank
    (communicator ranks ``0 .. P-1``) and then to each decode rank (ranks
    continuing from ``P``), and the id is stored in ``pd_pair``.

    Args:
        prefill_instance: Registered prefill ``Instance``.
        decode_instance: Registered decode ``Instance``.

    Returns:
        None.
    """
    pd_pair_id = prefill_instance.zmq_address + "_" + decode_instance.zmq_address
    if pd_pair_id in pd_pair:
        logger.info(f"""[Router] pd pair {pd_pair_id} already exist""")
        return

    logger.info(f"""[Router] initing pd pair {pd_pair_id} """)
    unique_id = bytes(router_nccl.ncclGetUniqueId().internal)

    p_rank_num = (prefill_instance.dp_size * prefill_instance.pp_size *
                  prefill_instance.tp_size)
    d_rank_num = (decode_instance.dp_size * decode_instance.pp_size *
                  decode_instance.tp_size)
    world_size = p_rank_num + d_rank_num

    # Prefill ranks are numbered first; decode ranks continue after them.
    next_rank = _dispatch_comm_init(prefill_instance, "P", pd_pair_id,
                                    unique_id, world_size, 0)
    _dispatch_comm_init(decode_instance, "D", pd_pair_id, unique_id,
                        world_size, next_rank)

    # Record the pair only after every rank has been sent its message.
    pd_pair[pd_pair_id] = unique_id
def pd_pair_init():
    """Pair newly ready instances with the already ready opposite side.

    Runs forever. Each iteration blocks on ``instance_cv`` until at least
    one address is queued in ``pending_prefill_ins`` or
    ``pending_decode_ins``. It then drains both queues while still holding
    the condition's lock:

    * each pending prefill instance is paired, through
      ``unique_id_dispatch``, with every instance in ``ready_decode_ins``
      and then moved to ``ready_prefill_ins``;
    * each pending decode instance is paired with every instance in
      ``ready_prefill_ins`` and then moved to ``ready_decode_ins``.

    An address that is queued again after it has already become ready is
    still offered to ``unique_id_dispatch`` (which skips existing pairs)
    but is not added to the ready list a second time.
    """
    while True:
        with instance_cv:
            # wait() releases the lock while blocked, so the registration
            # thread can append to the pending queues and notify.
            while not pending_prefill_ins and not pending_decode_ins:
                logger.info("[Router] pd_pair_init: waiting for instance_cv")
                instance_cv.wait()
            logger.info("[Router] pd_pair_init: instance_cv finished waiting")

            while pending_prefill_ins:
                p_ins = pending_prefill_ins.pop(0)
                logger.info(
                    f"""[Router] pd_pair_init: processing {p_ins} from pending_prefill_ins"""
                )
                for d_ins in ready_decode_ins:
                    unique_id_dispatch(prefill_instances[p_ins],
                                       decode_instances[d_ins])
                # The producer does not de-duplicate, so guard here to keep
                # the ready list (and the pairing loops) from growing.
                if p_ins not in ready_prefill_ins:
                    ready_prefill_ins.append(p_ins)

            while pending_decode_ins:
                d_ins = pending_decode_ins.pop(0)
                logger.info(
                    f"""[Router] pd_pair_init: processing {d_ins} from pending_decode_ins"""
                )
                for p_ins in ready_prefill_ins:
                    unique_id_dispatch(prefill_instances[p_ins],
                                       decode_instances[d_ins])
                if d_ins not in ready_decode_ins:
                    ready_decode_ins.append(d_ins)
def start_pd_pair_init():
    """Launch ``pd_pair_init`` on a daemon thread and return that thread."""
    # Daemon thread: it will not keep the process alive on its own.
    worker = threading.Thread(daemon=True, target=pd_pair_init)
    worker.start()
    return worker
@
app
.
route
(
"/v1/completions"
,
methods
=
[
"POST"
])
@
app
.
route
(
"/v1/completions"
,
methods
=
[
"POST"
])
@
app
.
route
(
"/v1/chat/completions"
,
methods
=
[
"POST"
])
async
def
handle_request
():
async
def
handle_request
():
try
:
try
:
original_request_data
=
await
request
.
get_json
()
original_request_data
=
await
request
.
get_json
()
...
@@ -129,45 +317,42 @@ async def handle_request():
...
@@ -129,45 +317,42 @@ async def handle_request():
prefill_request
=
original_request_data
.
copy
()
prefill_request
=
original_request_data
.
copy
()
# change max_tokens = 1 to let it only do prefill
# change max_tokens = 1 to let it only do prefill
prefill_request
[
"max_tokens"
]
=
1
prefill_request
[
"max_tokens"
]
=
1
if
"max_completion_tokens"
in
prefill_request
:
prefill_request
[
"max_completion_tokens"
]
=
1
global
count
global
count
global
prefill_instances
global
prefill_instances
global
prefill_cv
global
prefill_cv
with
prefill_cv
:
with
prefill_cv
:
prefill_list
=
list
(
prefill_instances
.
items
())
prefill_list
=
list
(
prefill_instances
.
items
())
prefill_addr
,
prefill_zmq_addr
=
prefill_list
[
count
%
len
(
prefill_list
)]
prefill_addr
,
prefill_instance
=
prefill_list
[
count
%
len
(
prefill_list
)]
prefill_zmq_addr
=
prefill_zmq_addr
[
0
]
global
decode_instances
global
decode_instances
global
decode_cv
global
decode_cv
with
decode_cv
:
with
decode_cv
:
decode_list
=
list
(
decode_instances
.
items
())
decode_list
=
list
(
decode_instances
.
items
())
decode_addr
,
decode_zmq_addr
=
decode_list
[
count
%
len
(
decode_list
)]
decode_addr
,
decode_instance
=
decode_list
[
count
%
len
(
decode_list
)]
decode_zmq_addr
=
decode_zmq_addr
[
0
]
print
(
print
(
f
"handle_request count:
{
count
}
, [HTTP:
{
prefill_addr
}
, "
f
"handle_request count:
{
count
}
, [HTTP:
{
prefill_addr
}
, "
f
"ZMQ:
{
prefill_zmq_addr
}
] 👉 [HTTP:
{
decode_addr
}
, "
f
"ZMQ:
{
prefill_
instance
.
zmq_addr
ess
}
] 👉 [HTTP:
{
decode_addr
}
, "
f
"ZMQ:
{
decode_zmq_addr
}
]"
f
"ZMQ:
{
decode_
instance
.
zmq_addr
ess
}
]"
)
)
count
+=
1
count
+=
1
request_id
=
(
request_id
=
(
f
"___prefill_addr_
{
prefill_zmq_addr
}
___decode_addr_"
f
"___prefill_addr_
{
prefill_
instance
.
zmq_addr
ess
}
___decode_addr_"
f
"
{
decode_zmq_addr
}
_
{
random_uuid
()
}
"
f
"
{
decode_
instance
.
zmq_addr
ess
}
_
{
random_uuid
()
}
"
)
)
# finish prefill
# finish prefill
async
for
_
in
forward_request
(
async
for
_
in
forward_request
(
f
"http://
{
prefill_addr
}
{
request
.
path
}
"
,
prefill_request
,
request_id
f
"http://
{
prefill_addr
}
/v1/completions
"
,
prefill_request
,
request_id
):
):
continue
continue
# return decode
# return decode
generator
=
forward_request
(
generator
=
forward_request
(
f
"http://
{
decode_addr
}
{
request
.
path
}
"
,
original_request_data
,
request_id
f
"http://
{
decode_addr
}
/v1/completions
"
,
original_request_data
,
request_id
)
)
response
=
await
make_response
(
generator
)
response
=
await
make_response
(
generator
)
response
.
timeout
=
None
response
.
timeout
=
None
...
@@ -186,5 +371,7 @@ async def handle_request():
...
@@ -186,5 +371,7 @@ async def handle_request():
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
t
=
start_service_discovery
(
"0.0.0.0"
,
30001
)
t
=
start_service_discovery
(
"0.0.0.0"
,
30001
)
t_1
=
start_pd_pair_init
()
app
.
run
(
host
=
"0.0.0.0"
,
port
=
10001
)
app
.
run
(
host
=
"0.0.0.0"
,
port
=
10001
)
t
.
join
()
t
.
join
()
t_1
.
join
()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment