Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
294cc23a
Commit
294cc23a
authored
Jan 04, 2026
by
xiabo
Browse files
解决pd分离非对称切分通信组过多问题
parent
84e5aba2
Changes
4
Show whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
473 additions
and
47 deletions
+473
-47
examples/online_serving/disaggregated_serving_p2p_nccl_xpyd/disagg_proxy_p2p_nccl_xpyd.py
...gated_serving_p2p_nccl_xpyd/disagg_proxy_p2p_nccl_xpyd.py
+241
-19
examples/online_serving/disaggregated_serving_p2p_nccl_xpyd/disagg_proxy_p2p_nccl_xpyd_mult_mac.py
...ving_p2p_nccl_xpyd/disagg_proxy_p2p_nccl_xpyd_mult_mac.py
+155
-0
vllm/distributed/kv_transfer/kv_connector/v1/p2p/p2p_nccl_engine.py
...ibuted/kv_transfer/kv_connector/v1/p2p/p2p_nccl_engine.py
+72
-28
vllm/envs.py
vllm/envs.py
+5
-0
No files found.
examples/online_serving/disaggregated_serving_p2p_nccl_xpyd/disagg_proxy_p2p_nccl_xpyd.py
View file @
294cc23a
...
...
@@ -9,15 +9,82 @@ import uuid
import
aiohttp
import
msgpack
import
zmq
from
typing
import
Any
from
quart
import
Quart
,
make_response
,
request
from
dataclasses
import
dataclass
,
field
from
vllm.distributed.device_communicators.pynccl_wrapper
import
NCCLLibrary
from
collections
import
deque
,
defaultdict
import
logging
logging
.
basicConfig
(
level
=
logging
.
INFO
,
format
=
"%(asctime)s - %(name)s - %(levelname)s - %(message)s"
)
logger
=
logging
.
getLogger
(
__name__
)
# @dataclass
# class Request:
# request_id: str
# p_http_address: str = ""
# p_dp_rank: int = -1
# d_http_address: str = ""
# d_dp_rank: int = -1
@
dataclass
class
Instance
:
ins_type
:
str
=
"P"
http_address
:
str
=
""
zmq_address
:
str
=
""
p_unique_id
:
bytes
=
b
""
dp_size
:
int
=
0
pp_size
:
int
=
0
tp_size
:
int
=
0
# [dp, pp, tp] : zmq_address
rank_table
:
dict
[
int
,
dict
[
int
,
dict
[
int
,
str
]]]
=
field
(
default_factory
=
lambda
:
defaultdict
(
lambda
:
defaultdict
(
dict
))
)
# [dp, pp, tp] : global rank
comm_rank_table
:
dict
[
int
,
dict
[
int
,
dict
[
int
,
int
]]]
=
field
(
default_factory
=
lambda
:
defaultdict
(
lambda
:
defaultdict
(
dict
))
)
def
count_rank_table_elements
(
self
):
count
=
0
for
first_dict
in
self
.
rank_table
.
values
():
for
second_dict
in
first_dict
.
values
():
count
+=
len
(
second_dict
)
return
count
def
is_ready
(
self
):
world_size
=
self
.
dp_size
*
self
.
pp_size
*
self
.
tp_size
inited_rank
=
self
.
count_rank_table_elements
()
all_ranks_ready
=
world_size
and
inited_rank
==
world_size
if
self
.
ins_type
==
"P"
:
logger
.
info
(
f
"""[Router] P is_ready? :
{
self
.
http_address
}
world_size =
{
world_size
}
inited_rank =
{
inited_rank
}
"""
)
# return all_ranks_ready and self.p_unique_id != b""
return
all_ranks_ready
else
:
logger
.
info
(
f
"""[Router] D is_ready? :
{
self
.
http_address
}
world_size =
{
world_size
}
inited_rank =
{
inited_rank
}
"""
)
return
all_ranks_ready
count
=
0
prefill_instances
:
dict
[
str
,
str
]
=
{}
# http_address: zmq_address
decode_instances
:
dict
[
str
,
str
]
=
{}
# http_address: zmq_address
# prefill_instances: dict[str, str] = {} # http_address: zmq_address
# decode_instances: dict[str, str] = {} # http_address: zmq_address
prefill_instances
:
dict
[
str
,
Instance
]
=
{}
decode_instances
:
dict
[
str
,
Instance
]
=
{}
pending_prefill_ins
:
list
[
str
]
=
[]
pending_decode_ins
:
list
[
str
]
=
[]
ready_prefill_ins
:
list
[
str
]
=
[]
ready_decode_ins
:
list
[
str
]
=
[]
pd_pair
:
dict
[
str
,
bytes
]
=
{}
router_nccl
=
NCCLLibrary
()
prefill_cv
=
threading
.
Condition
()
decode_cv
=
threading
.
Condition
()
instance_cv
=
threading
.
Condition
()
sock_cache
:
dict
[
str
,
Any
]
=
{}
def
_listen_for_register
(
poller
,
router_socket
):
while
True
:
...
...
@@ -27,16 +94,61 @@ def _listen_for_register(poller, router_socket):
# data: {"type": "P", "http_address": "ip:port",
# "zmq_address": "ip:port"}
data
=
msgpack
.
loads
(
message
)
if
data
[
"type"
]
==
"P"
:
global
prefill_instances
global
prefill_cv
with
prefill_cv
:
prefill_instances
[
data
[
"http_address"
]]
=
data
[
"zmq_address"
]
elif
data
[
"type"
]
==
"D"
:
global
instance_cv
global
decode_instances
global
decode_cv
with
decode_cv
:
decode_instances
[
data
[
"http_address"
]]
=
data
[
"zmq_address"
]
if
data
[
"type"
]
==
"P"
:
with
instance_cv
:
if
data
[
"http_address"
]
not
in
prefill_instances
:
prefill_instances
[
data
[
"http_address"
]]
=
Instance
(
http_address
=
data
[
"http_address"
])
p_instance
=
prefill_instances
[
data
[
"http_address"
]]
p_instance
.
rank_table
[
int
(
data
[
"dp_rank"
])][
int
(
data
[
"pp_rank"
])][
int
(
data
[
"tp_rank"
])]
=
data
[
"zmq_address"
]
if
p_instance
.
is_ready
():
pending_prefill_ins
.
append
(
p_instance
.
http_address
)
logger
.
info
(
f
"""[Router] pending_prefill_ins appended
{
p_instance
.
http_address
}
ZMQ:
{
p_instance
.
zmq_address
}
"""
)
instance_cv
.
notify
()
logger
.
info
(
f
"""[Router] add P rank [
{
data
[
"dp_rank"
]
}
,
{
data
[
"pp_rank"
]
}
,
{
data
[
"tp_rank"
]
}
] :
{
data
[
"zmq_address"
]
}
"""
)
elif
data
[
"type"
]
==
"D"
:
with
instance_cv
:
if
data
[
"http_address"
]
not
in
decode_instances
:
decode_instances
[
data
[
"http_address"
]]
=
Instance
(
ins_type
=
"D"
,
http_address
=
data
[
"http_address"
])
d_instance
=
decode_instances
[
data
[
"http_address"
]]
d_instance
.
rank_table
[
int
(
data
[
"dp_rank"
])][
int
(
data
[
"pp_rank"
])][
int
(
data
[
"tp_rank"
])]
=
data
[
"zmq_address"
]
if
d_instance
.
is_ready
():
pending_decode_ins
.
append
(
d_instance
.
http_address
)
logger
.
info
(
f
"""[Router] pending_decode_ins appended
{
d_instance
.
http_address
}
ZMQ:
{
d_instance
.
zmq_address
}
"""
)
instance_cv
.
notify
()
logger
.
info
(
f
"""[Router] add D rank [
{
data
[
"dp_rank"
]
}
,
{
data
[
"pp_rank"
]
}
,
{
data
[
"tp_rank"
]
}
] :
{
data
[
"zmq_address"
]
}
"""
)
elif
data
[
"type"
]
==
"P_init"
:
with
instance_cv
:
if
data
[
"http_address"
]
not
in
prefill_instances
:
prefill_instances
[
data
[
"http_address"
]]
=
Instance
(
http_address
=
data
[
"http_address"
],
dp_size
=
int
(
data
[
"dp_size"
]),
pp_size
=
int
(
data
[
"pp_size"
]),
tp_size
=
int
(
data
[
"tp_size"
]))
prefill_instances
[
data
[
"http_address"
]].
zmq_address
=
data
[
"zmq_address"
]
continue
p_instance
=
prefill_instances
[
data
[
"http_address"
]]
p_instance
.
dp_size
=
int
(
data
[
"dp_size"
])
p_instance
.
pp_size
=
int
(
data
[
"pp_size"
])
p_instance
.
tp_size
=
int
(
data
[
"tp_size"
])
p_instance
.
zmq_address
=
data
[
"zmq_address"
]
if
p_instance
.
is_ready
():
pending_prefill_ins
.
append
(
p_instance
.
http_address
)
logger
.
info
(
f
"""[Router] pending_prefill_ins appended
{
p_instance
.
http_address
}
ZMQ:
{
p_instance
.
zmq_address
}
"""
)
instance_cv
.
notify
()
elif
data
[
"type"
]
==
"D_init"
:
with
instance_cv
:
if
data
[
"http_address"
]
not
in
decode_instances
:
decode_instances
[
data
[
"http_address"
]]
=
Instance
(
ins_type
=
"D"
,
http_address
=
data
[
"http_address"
],
dp_size
=
int
(
data
[
"dp_size"
]),
pp_size
=
int
(
data
[
"pp_size"
]),
tp_size
=
int
(
data
[
"tp_size"
]))
decode_instances
[
data
[
"http_address"
]].
zmq_address
=
data
[
"zmq_address"
]
continue
d_instance
=
decode_instances
[
data
[
"http_address"
]]
d_instance
.
dp_size
=
int
(
data
[
"dp_size"
])
d_instance
.
pp_size
=
int
(
data
[
"pp_size"
])
d_instance
.
tp_size
=
int
(
data
[
"tp_size"
])
d_instance
.
zmq_address
=
data
[
"zmq_address"
]
if
d_instance
.
is_ready
():
pending_decode_ins
.
append
(
d_instance
.
http_address
)
logger
.
info
(
f
"""[Router] pending_decode_ins appended
{
d_instance
.
http_address
}
ZMQ:
{
d_instance
.
zmq_address
}
"""
)
instance_cv
.
notify
()
else
:
print
(
"Unexpected, Received message from %s, data: %s"
,
...
...
@@ -44,6 +156,7 @@ def _listen_for_register(poller, router_socket):
data
,
)
zmq_context
=
None
def
start_service_discovery
(
hostname
,
port
):
if
not
hostname
:
...
...
@@ -51,8 +164,11 @@ def start_service_discovery(hostname, port):
if
port
==
0
:
raise
ValueError
(
"Port cannot be 0"
)
context
=
zmq
.
Context
()
router_socket
=
context
.
socket
(
zmq
.
ROUTER
)
# context = zmq.Context()
# router_socket = context.socket(zmq.ROUTER)
global
zmq_context
zmq_context
=
zmq
.
Context
()
router_socket
=
zmq_context
.
socket
(
zmq
.
ROUTER
)
router_socket
.
bind
(
f
"tcp://
{
hostname
}
:
{
port
}
"
)
poller
=
zmq
.
Poller
()
...
...
@@ -90,6 +206,109 @@ async def forward_request(url, data, request_id):
yield
content
def
unique_id_dispatch
(
prefill_instance
:
str
,
decode_instance
:
str
)
:
global
zmq_context
global
sock_cache
global
router_nccl
global
pd_pair
pd_pair_id
=
prefill_instance
.
zmq_address
+
"_"
+
decode_instance
.
zmq_address
if
pd_pair_id
in
pd_pair
:
logger
.
info
(
f
"""[Router] pd pair
{
pd_pair_id
}
already exist"""
)
return
logger
.
info
(
f
"""[Router] initing pd pair
{
pd_pair_id
}
"""
)
unique_id
=
router_nccl
.
ncclGetUniqueId
()
unique_id
=
bytes
(
unique_id
.
internal
)
rank
=
0
p_rank_num
=
prefill_instance
.
dp_size
*
prefill_instance
.
pp_size
*
prefill_instance
.
tp_size
d_rank_num
=
decode_instance
.
dp_size
*
decode_instance
.
pp_size
*
decode_instance
.
tp_size
world_size
=
p_rank_num
+
d_rank_num
for
dp_rank
in
range
(
prefill_instance
.
dp_size
):
for
pp_rank
in
range
(
prefill_instance
.
pp_size
):
for
tp_rank
in
range
(
prefill_instance
.
tp_size
):
if
prefill_instance
.
rank_table
[
dp_rank
][
pp_rank
][
tp_rank
]
not
in
sock_cache
:
sock
=
zmq_context
.
socket
(
zmq
.
DEALER
)
sock
.
setsockopt_string
(
zmq
.
IDENTITY
,
"router"
)
sock
.
connect
(
f
"tcp://
{
prefill_instance
.
rank_table
[
dp_rank
][
pp_rank
][
tp_rank
]
}
"
)
sock_cache
[
prefill_instance
.
rank_table
[
dp_rank
][
pp_rank
][
tp_rank
]]
=
sock
data
=
{
"cmd"
:
"comm_init"
,
"pd_pair_id"
:
pd_pair_id
,
"unique_id"
:
unique_id
,
"world_size"
:
world_size
,
"rank"
:
rank
}
sock_cache
[
prefill_instance
.
rank_table
[
dp_rank
][
pp_rank
][
tp_rank
]].
send
(
msgpack
.
dumps
(
data
))
prefill_instance
.
comm_rank_table
[
dp_rank
][
pp_rank
][
tp_rank
]
=
rank
rank
+=
1
logger
.
info
(
f
"""[Router] dispatch unique_id of pd pair
{
pd_pair_id
}
to [P] [
{
dp_rank
}
,
{
pp_rank
}
,
{
tp_rank
}
]"""
)
for
dp_rank
in
range
(
decode_instance
.
dp_size
):
for
pp_rank
in
range
(
decode_instance
.
pp_size
):
for
tp_rank
in
range
(
decode_instance
.
tp_size
):
if
decode_instance
.
rank_table
[
dp_rank
][
pp_rank
][
tp_rank
]
not
in
sock_cache
:
sock
=
zmq_context
.
socket
(
zmq
.
DEALER
)
sock
.
setsockopt_string
(
zmq
.
IDENTITY
,
"router"
)
sock
.
connect
(
f
"tcp://
{
decode_instance
.
rank_table
[
dp_rank
][
pp_rank
][
tp_rank
]
}
"
)
sock_cache
[
decode_instance
.
rank_table
[
dp_rank
][
pp_rank
][
tp_rank
]]
=
sock
data
=
{
"cmd"
:
"comm_init"
,
"pd_pair_id"
:
pd_pair_id
,
"unique_id"
:
unique_id
,
"world_size"
:
world_size
,
"rank"
:
rank
}
sock_cache
[
decode_instance
.
rank_table
[
dp_rank
][
pp_rank
][
tp_rank
]].
send
(
msgpack
.
dumps
(
data
))
decode_instance
.
comm_rank_table
[
dp_rank
][
pp_rank
][
tp_rank
]
=
rank
rank
+=
1
logger
.
info
(
f
"""[Router] dispatch unique_id of pd pair
{
pd_pair_id
}
to [D] [
{
dp_rank
}
,
{
pp_rank
}
,
{
tp_rank
}
]"""
)
pd_pair
[
pd_pair_id
]
=
unique_id
def
pd_pair_init
():
global
prefill_instances
global
decode_instances
global
pending_prefill_ins
global
pending_decode_ins
global
ready_prefill_ins
global
ready_decode_ins
global
instance_cv
while
True
:
with
instance_cv
:
while
len
(
pending_prefill_ins
)
==
0
and
len
(
pending_decode_ins
)
==
0
:
logger
.
info
(
f
"""[Router] pd_pair_init: waiting for instance_cv"""
)
instance_cv
.
wait
()
logger
.
info
(
f
"""[Router] pd_pair_init: instance_cv finished waiting"""
)
while
pending_prefill_ins
:
p_ins
=
pending_prefill_ins
[
0
]
logger
.
info
(
f
"""[Router] pd_pair_init: processing
{
p_ins
}
from pending_prefill_ins"""
)
for
d_ins
in
ready_decode_ins
:
unique_id_dispatch
(
prefill_instances
[
p_ins
],
decode_instances
[
d_ins
])
ready_prefill_ins
.
append
(
p_ins
)
pending_prefill_ins
.
remove
(
p_ins
)
while
pending_decode_ins
:
d_ins
=
pending_decode_ins
[
0
]
logger
.
info
(
f
"""[Router] pd_pair_init: processing
{
d_ins
}
from pending_decode_ins"""
)
for
p_ins
in
ready_prefill_ins
:
unique_id_dispatch
(
prefill_instances
[
p_ins
],
decode_instances
[
d_ins
])
ready_decode_ins
.
append
(
d_ins
)
pending_decode_ins
.
remove
(
d_ins
)
def
start_pd_pair_init
():
_thread
=
threading
.
Thread
(
target
=
pd_pair_init
,
daemon
=
True
)
_thread
.
start
()
return
_thread
@
app
.
route
(
"/v1/completions"
,
methods
=
[
"POST"
])
async
def
handle_request
():
try
:
...
...
@@ -104,24 +323,25 @@ async def handle_request():
global
prefill_cv
with
prefill_cv
:
prefill_list
=
list
(
prefill_instances
.
items
())
prefill_addr
,
prefill_
zmq_addr
=
prefill_list
[
count
%
len
(
prefill_list
)]
prefill_addr
,
prefill_
instance
=
prefill_list
[
count
%
len
(
prefill_list
)]
global
decode_instances
global
decode_cv
with
decode_cv
:
decode_list
=
list
(
decode_instances
.
items
())
decode_addr
,
decode_
zmq_addr
=
decode_list
[
count
%
len
(
decode_list
)]
decode_addr
,
decode_
instance
=
decode_list
[
count
%
len
(
decode_list
)]
print
(
f
"handle_request count:
{
count
}
, [HTTP:
{
prefill_addr
}
, "
f
"ZMQ:
{
prefill_zmq_addr
}
] 👉 [HTTP:
{
decode_addr
}
, "
f
"ZMQ:
{
decode_zmq_addr
}
]"
f
"ZMQ:
{
prefill_
instance
.
zmq_addr
ess
}
] 👉 [HTTP:
{
decode_addr
}
, "
f
"ZMQ:
{
decode_
instance
.
zmq_addr
ess
}
]"
)
count
+=
1
request_id
=
(
f
"___prefill_addr_
{
prefill_zmq_addr
}
___decode_addr_"
f
"
{
decode_zmq_addr
}
_
{
random_uuid
()
}
"
f
"___prefill_addr_
{
prefill_
instance
.
zmq_addr
ess
}
___decode_addr_"
f
"
{
decode_
instance
.
zmq_addr
ess
}
_
{
random_uuid
()
}
"
)
# finish prefill
...
...
@@ -151,5 +371,7 @@ async def handle_request():
if
__name__
==
"__main__"
:
t
=
start_service_discovery
(
"0.0.0.0"
,
30001
)
t_1
=
start_pd_pair_init
()
app
.
run
(
host
=
"0.0.0.0"
,
port
=
10001
)
t
.
join
()
t_1
.
join
()
examples/online_serving/disaggregated_serving_p2p_nccl_xpyd/disagg_proxy_p2p_nccl_xpyd_mult_mac.py
0 → 100644
View file @
294cc23a
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
os
import
socket
import
threading
import
uuid
import
aiohttp
import
msgpack
import
zmq
from
quart
import
Quart
,
make_response
,
request
count
=
0
prefill_instances
:
dict
[
str
,
str
]
=
{}
# http_address: zmq_address
decode_instances
:
dict
[
str
,
str
]
=
{}
# http_address: zmq_address
prefill_cv
=
threading
.
Condition
()
decode_cv
=
threading
.
Condition
()
def
_listen_for_register
(
poller
,
router_socket
):
while
True
:
socks
=
dict
(
poller
.
poll
())
if
router_socket
in
socks
:
remote_address
,
message
=
router_socket
.
recv_multipart
()
# data: {"type": "P", "http_address": "ip:port",
# "zmq_address": "ip:port"}
data
=
msgpack
.
loads
(
message
)
if
data
[
"type"
]
==
"P"
:
global
prefill_instances
global
prefill_cv
with
prefill_cv
:
prefill_instances
[
data
[
"http_address"
]]
=
data
[
"zmq_address"
]
elif
data
[
"type"
]
==
"D"
:
global
decode_instances
global
decode_cv
with
decode_cv
:
decode_instances
[
data
[
"http_address"
]]
=
data
[
"zmq_address"
]
else
:
print
(
"Unexpected, Received message from %s, data: %s"
,
remote_address
,
data
,
)
def
start_service_discovery
(
hostname
,
port
):
if
not
hostname
:
hostname
=
socket
.
gethostname
()
if
port
==
0
:
raise
ValueError
(
"Port cannot be 0"
)
context
=
zmq
.
Context
()
router_socket
=
context
.
socket
(
zmq
.
ROUTER
)
router_socket
.
bind
(
f
"tcp://
{
hostname
}
:
{
port
}
"
)
poller
=
zmq
.
Poller
()
poller
.
register
(
router_socket
,
zmq
.
POLLIN
)
_listener_thread
=
threading
.
Thread
(
target
=
_listen_for_register
,
args
=
[
poller
,
router_socket
],
daemon
=
True
)
_listener_thread
.
start
()
return
_listener_thread
AIOHTTP_TIMEOUT
=
aiohttp
.
ClientTimeout
(
total
=
6
*
60
*
60
)
app
=
Quart
(
__name__
)
def
random_uuid
()
->
str
:
return
str
(
uuid
.
uuid4
().
hex
)
async
def
forward_request
(
url
,
data
,
request_id
):
async
with
aiohttp
.
ClientSession
(
timeout
=
AIOHTTP_TIMEOUT
)
as
session
:
headers
=
{
"Authorization"
:
f
"Bearer
{
os
.
environ
.
get
(
'OPENAI_API_KEY'
)
}
"
,
"X-Request-Id"
:
request_id
,
}
async
with
session
.
post
(
url
=
url
,
json
=
data
,
headers
=
headers
)
as
response
:
if
response
.
status
==
200
:
if
True
:
async
for
chunk_bytes
in
response
.
content
.
iter_chunked
(
1024
):
yield
chunk_bytes
else
:
content
=
await
response
.
read
()
yield
content
@
app
.
route
(
"/v1/completions"
,
methods
=
[
"POST"
])
async
def
handle_request
():
try
:
original_request_data
=
await
request
.
get_json
()
prefill_request
=
original_request_data
.
copy
()
# change max_tokens = 1 to let it only do prefill
prefill_request
[
"max_tokens"
]
=
1
global
count
global
prefill_instances
global
prefill_cv
with
prefill_cv
:
prefill_list
=
list
(
prefill_instances
.
items
())
prefill_addr
,
prefill_zmq_addr
=
prefill_list
[
count
%
len
(
prefill_list
)]
global
decode_instances
global
decode_cv
with
decode_cv
:
decode_list
=
list
(
decode_instances
.
items
())
decode_addr
,
decode_zmq_addr
=
decode_list
[
count
%
len
(
decode_list
)]
print
(
f
"handle_request count:
{
count
}
, [HTTP:
{
prefill_addr
}
, "
f
"ZMQ:
{
prefill_zmq_addr
}
] 👉 [HTTP:
{
decode_addr
}
, "
f
"ZMQ:
{
decode_zmq_addr
}
]"
)
count
+=
1
request_id
=
(
f
"___prefill_addr_
{
prefill_zmq_addr
}
___decode_addr_"
f
"
{
decode_zmq_addr
}
_
{
random_uuid
()
}
"
)
# finish prefill
async
for
_
in
forward_request
(
f
"http://
{
prefill_addr
}
/v1/completions"
,
prefill_request
,
request_id
):
continue
# return decode
generator
=
forward_request
(
f
"http://
{
decode_addr
}
/v1/completions"
,
original_request_data
,
request_id
)
response
=
await
make_response
(
generator
)
response
.
timeout
=
None
return
response
except
Exception
as
e
:
import
sys
import
traceback
exc_info
=
sys
.
exc_info
()
print
(
"Error occurred in disagg prefill proxy server"
)
print
(
e
)
print
(
""
.
join
(
traceback
.
format_exception
(
*
exc_info
)))
if
__name__
==
"__main__"
:
t
=
start_service_discovery
(
"0.0.0.0"
,
30001
)
app
.
run
(
host
=
"0.0.0.0"
,
port
=
10001
)
t
.
join
()
vllm/distributed/kv_transfer/kv_connector/v1/p2p/p2p_nccl_engine.py
View file @
294cc23a
...
...
@@ -123,7 +123,8 @@ class P2pNcclEngine:
if
self
.
remote_tp_size
%
self
.
tp_size
!=
0
:
logger
.
error
(
" the Prefill TP size must be less than or equal to the Decode TP size!!!!"
)
self
.
multp
=
int
(
self
.
remote_tp_size
/
self
.
tp_size
)
self
.
multiple_machines
=
self
.
config
.
get_from_extra_config
(
"enable_multiple_machines"
,
False
)
port
=
int
(
self
.
config
.
kv_port
)
+
port_offset
if
port
==
0
:
raise
ValueError
(
"Port cannot be 0"
)
...
...
@@ -203,12 +204,16 @@ class P2pNcclEngine:
self
.
_listener_thread
.
start
()
self
.
_ping_thread
=
None
#
if
port_offset == 0 and self.proxy_address != ""
:
if
self
.
proxy_address
!=
""
:
if
self
.
multiple_machines
:
if
port_offset
==
0
and
self
.
proxy_address
!=
""
:
self
.
_ping_thread
=
threading
.
Thread
(
target
=
self
.
_ping
,
daemon
=
True
)
self
.
_ping_thread
.
start
()
else
:
if
self
.
proxy_address
!=
""
:
self
.
_ping_thread
=
threading
.
Thread
(
target
=
self
.
_ping_new
,
daemon
=
True
)
self
.
_ping_thread
.
start
()
logger
.
info
(
"💯P2pNcclEngine init, rank:%d, local_rank:%d, http_address:%s, "
"zmq_address:%s, proxy_address:%s, send_type:%s, buffer_size_"
...
...
@@ -267,7 +272,7 @@ class P2pNcclEngine:
p_ip
,
p_port
=
self
.
parse_request_id
(
request_id
,
False
)
pd_pair_id
=
p_ip
+
":"
+
str
(
p_port
)
+
"_"
+
remote_ip
+
":"
+
str
(
remote_port
)
# remote_port = 22001
if
not
self
.
enable_asymmetric_p2p
:
remote_address
=
remote_ip
+
":"
+
str
(
remote_port
+
self
.
rank
)
remote_addr
=
RemoteAddr
(
pd_pair_id
,
remote_address
,
self
.
rank
+
self
.
pp_size
*
self
.
tp_size
)
...
...
@@ -279,7 +284,7 @@ class P2pNcclEngine:
remote_pp_rank
=
self
.
compute_remote_pp_rank
(
layer_name
)
items
:
list
[
Any
]
=
[]
# remote_tp_rank = self.tp_rank * self.multp
for
d_tp_rank
in
range
(
self
.
remote_tp_size
):
for
mul_tp
in
range
(
self
.
multp
):
if
self
.
tp_rank
+
mul_tp
*
self
.
tp_size
==
d_tp_rank
:
...
...
@@ -306,7 +311,7 @@ class P2pNcclEngine:
if
self
.
send_type
==
"PUT"
:
return
all
(
self
.
send_sync
(
item
)
for
item
in
self
.
get_send_queue_items
(
self
.
_
send_sync
_new
(
item
)
for
item
in
self
.
get_send_queue_items
(
request_id
,
layer_name
,
tensor
,
is_mla
))
if
self
.
send_type
==
"PUT_ASYNC"
:
...
...
@@ -627,6 +632,9 @@ class P2pNcclEngine:
if
(
envs
.
VLLM_ENABLE_TBO
or
envs
.
VLLM_P2P_ASYNC
)
and
tbo_evt
is
not
None
:
self
.
send_stream
.
wait_event
(
tbo_evt
)
self
.
_send_kv_p2p_sync
(
tensor_id
,
kv_layer
,
slot_mapping
,
remote_address
)
else
:
if
self
.
multiple_machines
:
self
.
_send_sync
(
tensor_id
,
tensor
,
remote_address
)
else
:
# logger.info(f"""=============xiabo tensor_id:{tensor_id} remote_address:{remote_address}""")
self
.
_send_sync_new
(
tensor_id
,
tensor
,
remote_address
)
...
...
@@ -734,7 +742,7 @@ class P2pNcclEngine:
"pd_pair_id"
:
remote_address
.
pd_pair_id
,
"comm_rank"
:
rank
}
#
logger.info(f"""_send_sync_new:{data}""")
logger
.
info
(
f
"""_send_sync_new:
{
data
}
"""
)
sock
.
send
(
msgpack
.
dumps
(
data
))
response
=
sock
.
recv
()
...
...
@@ -830,18 +838,20 @@ class P2pNcclEngine:
return
finished_sending
or
None
,
finished_recving
or
None
def
_ping
(
self
):
# sock = self.context.socket(zmq.DEALER)
# sock.setsockopt_string(zmq.IDENTITY, self.zmq_address)
# logger.debug("ping start, zmq_address:%s", self.zmq_address)
# sock.connect(f"tcp://{self.proxy_address}")
# data = {
# "type": "P" if self.config.is_kv_producer else "D",
# "http_address": self.http_address,
# "zmq_address": self.zmq_address
# }
# while True:
# sock.send(msgpack.dumps(data))
# time.sleep(3)
sock
=
self
.
context
.
socket
(
zmq
.
DEALER
)
sock
.
setsockopt_string
(
zmq
.
IDENTITY
,
self
.
zmq_address
)
logger
.
debug
(
"ping start, zmq_address:%s"
,
self
.
zmq_address
)
sock
.
connect
(
f
"tcp://
{
self
.
proxy_address
}
"
)
data
=
{
"type"
:
"P"
if
self
.
config
.
is_kv_producer
else
"D"
,
"http_address"
:
self
.
http_address
,
"zmq_address"
:
self
.
zmq_address
}
while
True
:
sock
.
send
(
msgpack
.
dumps
(
data
))
time
.
sleep
(
3
)
def
_ping_new
(
self
):
sock
=
self
.
context
.
socket
(
zmq
.
DEALER
)
sock
.
setsockopt_string
(
zmq
.
IDENTITY
,
self
.
zmq_address
)
logger
.
debug
(
"ping start, zmq_address:%s"
,
self
.
zmq_address
)
...
...
@@ -856,7 +866,7 @@ class P2pNcclEngine:
"pp_size"
:
self
.
pp_size
,
"tp_size"
:
self
.
tp_size
}
logger
.
info
(
f
"""_ping data:
{
data
}
"""
)
#
logger.info(f"""_ping data:{data}""")
sock
.
send
(
msgpack
.
dumps
(
data
))
data
=
{
"type"
:
"P"
if
self
.
config
.
is_kv_producer
else
"D"
,
...
...
@@ -867,7 +877,7 @@ class P2pNcclEngine:
"zmq_address"
:
self
.
zmq_address
}
# while True:
logger
.
info
(
f
"""_ping data:
{
data
}
"""
)
#
logger.info(f"""_ping data:{data}""")
sock
.
send
(
msgpack
.
dumps
(
data
))
# time.sleep(3)
...
...
@@ -904,10 +914,44 @@ class P2pNcclEngine:
if
self
.
_ping_thread
is
not
None
:
self
.
_ping_thread
.
join
()
def
get_pp_indices_d
(
self
,
num_hidden_layers
:
int
,
pp_rank
:
int
,
pp_size
:
int
)
->
tuple
[
int
,
int
]:
partition_list_str
=
envs
.
VLLM_PP_LAYER_PARTITION_D
if
partition_list_str
is
not
None
:
try
:
partitions
=
[
int
(
layer
)
for
layer
in
partition_list_str
.
split
(
","
)
]
except
ValueError
as
err
:
raise
ValueError
(
"Invalid partition string: {}"
.
format
(
partition_list_str
))
from
err
if
len
(
partitions
)
!=
pp_size
:
raise
ValueError
(
f
"
{
len
(
partitions
)
=
}
does not match
{
pp_size
=
}
."
)
if
sum
(
partitions
)
!=
num_hidden_layers
:
raise
ValueError
(
f
"
{
sum
(
partitions
)
=
}
does not match
{
num_hidden_layers
=
}
."
)
else
:
layers_per_partition
=
num_hidden_layers
//
pp_size
partitions
=
[
layers_per_partition
for
_
in
range
(
pp_size
)]
if
remaining_layers
:
=
num_hidden_layers
%
pp_size
:
for
i
in
range
(
2
,
remaining_layers
+
2
):
partitions
[
-
i
]
+=
1
logger
.
info
(
"Hidden layers were unevenly partitioned: [%s]. "
"This can be manually overridden using the "
"VLLM_PP_LAYER_PARTITION_D environment variable"
,
","
.
join
(
str
(
p
)
for
p
in
partitions
))
start_layer
=
sum
(
partitions
[:
pp_rank
])
end_layer
=
start_layer
+
partitions
[
pp_rank
]
return
(
start_layer
,
end_layer
)
def
compute_remote_pp_rank
(
self
,
layer_name
:
str
)
->
int
:
current_layer_idx
=
extract_layer_index
(
layer_name
)
for
d_pp_rank
in
range
(
self
.
remote_pp_size
):
start
,
end
=
get_pp_indices
(
self
.
total_num_hidden_layers
,
d_pp_rank
,
self
.
remote_pp_size
)
start
,
end
=
self
.
get_pp_indices
_d
(
self
.
total_num_hidden_layers
,
d_pp_rank
,
self
.
remote_pp_size
)
# logger.info(f"""compute_remote_pp_rank : current_layer_idx:{current_layer_idx} start:{start} end:{end}""")
if
(
current_layer_idx
==
self
.
total_num_hidden_layers
):
return
self
.
remote_pp_size
-
1
...
...
vllm/envs.py
View file @
294cc23a
...
...
@@ -42,6 +42,7 @@ if TYPE_CHECKING:
VLLM_USE_FLASHINFER_SAMPLER
:
Optional
[
bool
]
=
None
VLLM_FLASHINFER_FORCE_TENSOR_CORES
:
bool
=
False
VLLM_PP_LAYER_PARTITION
:
Optional
[
str
]
=
None
VLLM_PP_LAYER_PARTITION_D
:
Optional
[
str
]
=
None
VLLM_CPU_KVCACHE_SPACE
:
int
=
0
VLLM_CPU_OMP_THREADS_BIND
:
str
=
""
VLLM_CPU_NUM_OF_RESERVED_CPU
:
int
=
0
...
...
@@ -487,6 +488,10 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_PP_LAYER_PARTITION"
:
lambda
:
os
.
getenv
(
"VLLM_PP_LAYER_PARTITION"
,
None
),
# Pipeline stage partition strategy
"VLLM_PP_LAYER_PARTITION_D"
:
lambda
:
os
.
getenv
(
"VLLM_PP_LAYER_PARTITION_D"
,
None
),
# (CPU backend only) CPU key-value cache space.
# default is 4 GiB
"VLLM_CPU_KVCACHE_SPACE"
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment