Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
4612aad6
Commit
4612aad6
authored
Dec 27, 2025
by
Your Name
Browse files
[P/D][Feat]支持dp并行
parent
cd42bf87
Changes
6
Show whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
874 additions
and
311 deletions
+874
-311
examples/online_serving/disaggregated_serving_p2p_nccl_xpyd/disagg_proxy_p2p_nccl_xpyd_dp.py
...ed_serving_p2p_nccl_xpyd/disagg_proxy_p2p_nccl_xpyd_dp.py
+503
-0
vllm/distributed/kv_transfer/kv_connector/factory.py
vllm/distributed/kv_transfer/kv_connector/factory.py
+2
-1
vllm/distributed/kv_transfer/kv_connector/v1/p2p/p2p_nccl_connector.py
...ted/kv_transfer/kv_connector/v1/p2p/p2p_nccl_connector.py
+89
-106
vllm/distributed/kv_transfer/kv_connector/v1/p2p/p2p_nccl_engine.py
...ibuted/kv_transfer/kv_connector/v1/p2p/p2p_nccl_engine.py
+264
-201
vllm/v1/core/sched/scheduler.py
vllm/v1/core/sched/scheduler.py
+13
-3
vllm/v1/engine/core.py
vllm/v1/engine/core.py
+3
-0
No files found.
examples/online_serving/disaggregated_serving_p2p_nccl_xpyd/disagg_proxy_p2p_nccl_xpyd_dp.py
0 → 100644
View file @
4612aad6
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
os
import
socket
import
threading
import
uuid
import
aiohttp
import
msgpack
import
zmq
from
quart
import
Quart
,
make_response
,
request
from
dataclasses
import
dataclass
,
field
from
typing
import
Any
import
time
from
collections
import
deque
,
defaultdict
import
asyncio
from
vllm.distributed.device_communicators.pynccl_wrapper
import
NCCLLibrary
import
logging
logging
.
basicConfig
(
level
=
logging
.
INFO
,
format
=
"%(asctime)s - %(name)s - %(levelname)s - %(message)s"
)
logger
=
logging
.
getLogger
(
__name__
)
@
dataclass
class
Request
:
request_id
:
str
p_http_address
:
str
=
""
p_dp_rank
:
int
=
-
1
d_http_address
:
str
=
""
d_dp_rank
:
int
=
-
1
@
dataclass
class
Instance
:
ins_type
:
str
=
"P"
http_address
:
str
=
""
p_unique_id
:
bytes
=
b
""
dp_size
:
int
=
0
pp_size
:
int
=
0
tp_size
:
int
=
0
# [dp, pp, tp] : zmq_address
rank_table
:
dict
[
int
,
dict
[
int
,
dict
[
int
,
str
]]]
=
field
(
default_factory
=
lambda
:
defaultdict
(
lambda
:
defaultdict
(
dict
))
)
# [dp, pp, tp] : global rank
comm_rank_table
:
dict
[
int
,
dict
[
int
,
dict
[
int
,
int
]]]
=
field
(
default_factory
=
lambda
:
defaultdict
(
lambda
:
defaultdict
(
dict
))
)
def
count_rank_table_elements
(
self
):
count
=
0
for
first_dict
in
self
.
rank_table
.
values
():
for
second_dict
in
first_dict
.
values
():
count
+=
len
(
second_dict
)
return
count
def
is_ready
(
self
):
world_size
=
self
.
dp_size
*
self
.
pp_size
*
self
.
tp_size
inited_rank
=
self
.
count_rank_table_elements
()
all_ranks_ready
=
world_size
and
inited_rank
==
world_size
if
self
.
ins_type
==
"P"
:
logger
.
info
(
f
"""[Router] P is_ready? :
{
self
.
http_address
}
world_size =
{
world_size
}
inited_rank =
{
inited_rank
}
"""
)
return
all_ranks_ready
and
self
.
p_unique_id
!=
b
""
else
:
logger
.
info
(
f
"""[Router] D is_ready? :
{
self
.
http_address
}
world_size =
{
world_size
}
inited_rank =
{
inited_rank
}
"""
)
return
all_ranks_ready
count
=
0
prefill_instances
:
dict
[
str
,
Instance
]
=
{}
decode_instances
:
dict
[
str
,
Instance
]
=
{}
running_requests
:
dict
[
str
,
Request
]
=
{}
healthy_instances
:
dict
[
str
,
float
]
=
{}
pending_prefill_ins
:
list
[
str
]
=
[]
pending_decode_ins
:
list
[
str
]
=
[]
ready_prefill_ins
:
list
[
str
]
=
[]
ready_decode_ins
:
list
[
str
]
=
[]
pd_pair
:
dict
[
str
,
bytes
]
=
{}
router_nccl
=
NCCLLibrary
()
instance_cv
=
threading
.
Condition
()
request_cv
=
threading
.
Condition
()
health_cv
=
threading
.
Condition
()
request_queue_cv
=
threading
.
Condition
()
request_queue
:
deque
[
list
[
Any
]]
=
deque
()
def
_listen_for_register
(
poller
,
router_socket
):
while
True
:
socks
=
dict
(
poller
.
poll
())
if
router_socket
in
socks
:
remote_address
,
message
=
router_socket
.
recv_multipart
()
data
=
msgpack
.
loads
(
message
)
global
prefill_instances
global
instance_cv
global
decode_instances
if
data
[
"type"
]
==
"P_init"
:
with
instance_cv
:
if
data
[
"http_address"
]
not
in
prefill_instances
:
prefill_instances
[
data
[
"http_address"
]]
=
Instance
(
http_address
=
data
[
"http_address"
],
dp_size
=
int
(
data
[
"dp_size"
]),
pp_size
=
int
(
data
[
"pp_size"
]),
tp_size
=
int
(
data
[
"tp_size"
]))
continue
p_instance
=
prefill_instances
[
data
[
"http_address"
]]
p_instance
.
dp_size
=
int
(
data
[
"dp_size"
])
p_instance
.
pp_size
=
int
(
data
[
"pp_size"
])
p_instance
.
tp_size
=
int
(
data
[
"tp_size"
])
if
p_instance
.
is_ready
():
pending_prefill_ins
.
append
(
p_instance
.
http_address
)
logger
.
info
(
f
"""[Router] pending_prefill_ins appended
{
p_instance
.
http_address
}
"""
)
instance_cv
.
notify
()
elif
data
[
"type"
]
==
"D_init"
:
with
instance_cv
:
if
data
[
"http_address"
]
not
in
decode_instances
:
decode_instances
[
data
[
"http_address"
]]
=
Instance
(
ins_type
=
"D"
,
http_address
=
data
[
"http_address"
],
dp_size
=
int
(
data
[
"dp_size"
]),
pp_size
=
int
(
data
[
"pp_size"
]),
tp_size
=
int
(
data
[
"tp_size"
]))
continue
d_instance
=
decode_instances
[
data
[
"http_address"
]]
d_instance
.
dp_size
=
int
(
data
[
"dp_size"
])
d_instance
.
pp_size
=
int
(
data
[
"pp_size"
])
d_instance
.
tp_size
=
int
(
data
[
"tp_size"
])
if
d_instance
.
is_ready
():
pending_decode_ins
.
append
(
d_instance
.
http_address
)
logger
.
info
(
f
"""[Router] pending_decode_ins appended
{
d_instance
.
http_address
}
"""
)
instance_cv
.
notify
()
elif
data
[
"type"
]
==
"P_rank"
:
with
instance_cv
:
if
data
[
"http_address"
]
not
in
prefill_instances
:
prefill_instances
[
data
[
"http_address"
]]
=
Instance
(
http_address
=
data
[
"http_address"
])
p_instance
=
prefill_instances
[
data
[
"http_address"
]]
p_instance
.
rank_table
[
int
(
data
[
"dp_rank"
])][
int
(
data
[
"pp_rank"
])][
int
(
data
[
"tp_rank"
])]
=
data
[
"zmq_address"
]
if
p_instance
.
is_ready
():
pending_prefill_ins
.
append
(
p_instance
.
http_address
)
logger
.
info
(
f
"""[Router] pending_prefill_ins appended
{
p_instance
.
http_address
}
"""
)
instance_cv
.
notify
()
logger
.
info
(
f
"""[Router] add P rank [
{
data
[
"dp_rank"
]
}
,
{
data
[
"pp_rank"
]
}
,
{
data
[
"tp_rank"
]
}
] :
{
data
[
"zmq_address"
]
}
"""
)
elif
data
[
"type"
]
==
"D_rank"
:
with
instance_cv
:
if
data
[
"http_address"
]
not
in
decode_instances
:
decode_instances
[
data
[
"http_address"
]]
=
Instance
(
ins_type
=
"D"
,
http_address
=
data
[
"http_address"
])
d_instance
=
decode_instances
[
data
[
"http_address"
]]
d_instance
.
rank_table
[
int
(
data
[
"dp_rank"
])][
int
(
data
[
"pp_rank"
])][
int
(
data
[
"tp_rank"
])]
=
data
[
"zmq_address"
]
if
d_instance
.
is_ready
():
pending_decode_ins
.
append
(
d_instance
.
http_address
)
logger
.
info
(
f
"""[Router] pending_decode_ins appended
{
d_instance
.
http_address
}
"""
)
instance_cv
.
notify
()
logger
.
info
(
f
"""[Router] add D rank [
{
data
[
"dp_rank"
]
}
,
{
data
[
"pp_rank"
]
}
,
{
data
[
"tp_rank"
]
}
] :
{
data
[
"zmq_address"
]
}
"""
)
elif
data
[
"type"
]
==
"P_unique_id"
:
with
instance_cv
:
if
data
[
"http_address"
]
not
in
prefill_instances
:
prefill_instances
[
data
[
"http_address"
]]
=
Instance
(
http_address
=
data
[
"http_address"
])
p_instance
=
prefill_instances
[
data
[
"http_address"
]]
assert
isinstance
(
data
[
"unique_id"
],
bytes
)
p_instance
.
p_unique_id
=
data
[
"unique_id"
]
if
p_instance
.
is_ready
():
pending_prefill_ins
.
append
(
p_instance
.
http_address
)
logger
.
info
(
f
"""[Router] pending_prefill_ins appended
{
p_instance
.
http_address
}
"""
)
instance_cv
.
notify
()
logger
.
info
(
f
"""[Router] add P_unique_id
{
str
(
p_instance
.
p_unique_id
)
}
for
{
p_instance
.
http_address
}
"""
)
elif
data
[
"type"
]
==
"heartbeat"
:
global
healthy_instances
global
health_cv
with
health_cv
:
healthy_instances
[
data
[
"http_address"
]]
=
time
.
time
()
elif
data
[
"type"
]
==
"Req"
:
# logger.info(f"""[Router] recv Request {data["request_id"]} : {data["instance_type"]}""")
global
running_requests
global
request_cv
with
request_cv
:
if
data
[
"request_id"
]
in
running_requests
:
request
=
running_requests
[
data
[
"request_id"
]]
if
data
[
"instance_type"
]
==
"P"
:
request
.
p_http_address
=
data
[
"http_address"
]
request
.
p_dp_rank
=
int
(
data
[
"dp_rank"
])
elif
data
[
"instance_type"
]
==
"D"
:
request
.
d_http_address
=
data
[
"http_address"
]
request
.
d_dp_rank
=
int
(
data
[
"dp_rank"
])
assert
(
request
.
p_dp_rank
>=
0
and
request
.
d_dp_rank
>=
0
)
with
request_queue_cv
:
request_queue
.
append
(
request
)
# logger.info(f"""[Router] add Request {data["request_id"]} [{request.p_http_address}:{request.p_dp_rank}, {request.d_http_address}:{request.d_dp_rank}]""")
request_queue_cv
.
notify
()
else
:
if
data
[
"instance_type"
]
==
"P"
:
running_requests
[
data
[
"request_id"
]]
=
Request
(
request_id
=
data
[
"request_id"
],
p_http_address
=
data
[
"http_address"
],
p_dp_rank
=
int
(
data
[
"dp_rank"
]))
elif
data
[
"instance_type"
]
==
"D"
:
running_requests
[
data
[
"request_id"
]]
=
Request
(
request_id
=
data
[
"request_id"
],
d_http_address
=
data
[
"http_address"
],
d_dp_rank
=
int
(
data
[
"dp_rank"
]))
else
:
print
(
"Unexpected, Received message from %s, data: %s"
,
remote_address
,
data
,
)
zmq_context
=
None
sock_cache
:
dict
[
str
,
Any
]
=
{}
tp_mapping_of_pd_pair
:
dict
[
str
,
dict
[
int
,
list
[
str
]]]
=
{}
tp_comm_mapping_of_pd_pair
:
dict
[
str
,
dict
[
int
,
list
[
int
]]]
=
{}
active_p_tp_rank_of_pd_pair
:
dict
[
str
,
set
[
int
]]
=
{}
def
start_service_discovery
(
hostname
,
port
):
if
not
hostname
:
hostname
=
socket
.
gethostname
()
if
port
==
0
:
raise
ValueError
(
"Port cannot be 0"
)
global
zmq_context
zmq_context
=
zmq
.
Context
()
router_socket
=
zmq_context
.
socket
(
zmq
.
ROUTER
)
router_socket
.
bind
(
f
"tcp://
{
hostname
}
:
{
port
}
"
)
poller
=
zmq
.
Poller
()
poller
.
register
(
router_socket
,
zmq
.
POLLIN
)
_listener_thread
=
threading
.
Thread
(
target
=
_listen_for_register
,
args
=
[
poller
,
router_socket
],
daemon
=
True
)
_listener_thread
.
start
()
return
_listener_thread
def
dispatch_to_P
(
request
:
Request
):
global
prefill_instances
global
decode_instances
p_ins
=
prefill_instances
[
request
.
p_http_address
]
d_ins
=
decode_instances
[
request
.
d_http_address
]
global
zmq_context
global
sock_cache
pd_pair_id
=
p_ins
.
http_address
+
"_"
+
d_ins
.
http_address
p_dp_rank
=
request
.
p_dp_rank
d_dp_rank
=
request
.
d_dp_rank
tp_dst_id
=
pd_pair_id
+
"_"
+
str
(
d_dp_rank
)
assert
(
d_ins
.
pp_size
==
1
)
d_pp_rank
=
0
global
tp_mapping_of_pd_pair
global
tp_comm_mapping_of_pd_pair
global
active_p_tp_rank_of_pd_pair
if
tp_dst_id
not
in
active_p_tp_rank_of_pd_pair
:
p_active_tp_rank
=
set
()
p_tp_rank_to_dst
:
dict
[
int
,
list
[
str
]]
=
defaultdict
(
list
)
p_tp_rank_to_dst_comm
:
dict
[
int
,
list
[
int
]]
=
defaultdict
(
list
)
for
d_tp_rank
in
range
(
d_ins
.
tp_size
):
p_tp_rank
=
d_tp_rank
%
p_ins
.
tp_size
p_active_tp_rank
.
add
(
p_tp_rank
)
p_tp_rank_to_dst
[
p_tp_rank
].
append
(
d_ins
.
rank_table
[
d_dp_rank
][
d_pp_rank
][
d_tp_rank
])
p_tp_rank_to_dst_comm
[
p_tp_rank
].
append
(
d_ins
.
comm_rank_table
[
d_dp_rank
][
d_pp_rank
][
d_tp_rank
])
tp_mapping_of_pd_pair
[
tp_dst_id
]
=
p_tp_rank_to_dst
tp_comm_mapping_of_pd_pair
[
tp_dst_id
]
=
p_tp_rank_to_dst_comm
active_p_tp_rank_of_pd_pair
[
tp_dst_id
]
=
p_active_tp_rank
p_active_tp_rank
=
active_p_tp_rank_of_pd_pair
[
tp_dst_id
]
p_tp_rank_to_dst
=
tp_mapping_of_pd_pair
[
tp_dst_id
]
p_tp_rank_to_dst_comm
=
tp_comm_mapping_of_pd_pair
[
tp_dst_id
]
for
p_pp_rank
in
range
(
p_ins
.
pp_size
):
for
p_tp_rank
in
p_active_tp_rank
:
if
p_ins
.
rank_table
[
p_dp_rank
][
p_pp_rank
][
p_tp_rank
]
not
in
sock_cache
:
sock
=
zmq_context
.
socket
(
zmq
.
DEALER
)
sock
.
setsockopt_string
(
zmq
.
IDENTITY
,
"router"
)
sock
.
connect
(
f
"tcp://
{
p_ins
.
rank_table
[
p_dp_rank
][
p_pp_rank
][
p_tp_rank
]
}
"
)
sock_cache
[
p_ins
.
rank_table
[
p_dp_rank
][
p_pp_rank
][
p_tp_rank
]]
=
sock
data
=
{
"cmd"
:
"req_to_transfer"
,
"request_id"
:
request
.
request_id
,
"dst_num"
:
len
(
p_tp_rank_to_dst
[
p_tp_rank
]),
"pd_pair_id"
:
pd_pair_id
,
"remote_address"
:
p_tp_rank_to_dst
[
p_tp_rank
],
"remote_rank"
:
p_tp_rank_to_dst_comm
[
p_tp_rank
],
}
sock_cache
[
p_ins
.
rank_table
[
p_dp_rank
][
p_pp_rank
][
p_tp_rank
]].
send
(
msgpack
.
dumps
(
data
))
logger
.
info
(
f
"""[Router] dispatch Request
{
request
.
request_id
}
[
{
p_dp_rank
}
,
{
p_pp_rank
}
,
{
p_tp_rank
}
] -> [
{
d_dp_rank
}
,
{
d_pp_rank
}
]"""
)
for
p_tp_rank
in
range
(
p_ins
.
tp_size
):
if
p_tp_rank
not
in
p_active_tp_rank
:
for
p_pp_rank
in
range
(
p_ins
.
pp_size
):
if
p_ins
.
rank_table
[
p_dp_rank
][
p_pp_rank
][
p_tp_rank
]
not
in
sock_cache
:
sock
=
zmq_context
.
socket
(
zmq
.
DEALER
)
sock
.
setsockopt_string
(
zmq
.
IDENTITY
,
"router"
)
sock
.
connect
(
f
"tcp://
{
p_ins
.
rank_table
[
p_dp_rank
][
p_pp_rank
][
p_tp_rank
]
}
"
)
sock_cache
[
p_ins
.
rank_table
[
p_dp_rank
][
p_pp_rank
][
p_tp_rank
]]
=
sock
data
=
{
"cmd"
:
"req_not_to_transfer"
,
"request_id"
:
request
.
request_id
,
}
sock_cache
[
p_ins
.
rank_table
[
p_dp_rank
][
p_pp_rank
][
p_tp_rank
]].
send
(
msgpack
.
dumps
(
data
))
def
dp_dispatch
():
global
request_queue_cv
global
request_queue
while
True
:
with
request_queue_cv
:
while
not
request_queue
:
request_queue_cv
.
wait
()
request
=
request_queue
.
pop
()
dispatch_to_P
(
request
)
def
start_dp_dispatch
():
_thread
=
threading
.
Thread
(
target
=
dp_dispatch
,
daemon
=
True
)
_thread
.
start
()
return
_thread
AIOHTTP_TIMEOUT
=
aiohttp
.
ClientTimeout
(
total
=
6
*
60
*
60
)
app
=
Quart
(
__name__
)
def
random_uuid
()
->
str
:
return
str
(
uuid
.
uuid4
().
hex
)
def
unique_id_dispatch
(
prefill_instance
:
Instance
,
decode_instance
:
Instance
)
:
global
zmq_context
global
sock_cache
global
pd_pair
global
router_nccl
pd_pair_id
=
prefill_instance
.
http_address
+
"_"
+
decode_instance
.
http_address
if
pd_pair_id
in
pd_pair
:
logger
.
info
(
f
"""[Router] pd pair
{
pd_pair_id
}
already exist"""
)
return
logger
.
info
(
f
"""[Router] initing pd pair
{
pd_pair_id
}
"""
)
unique_id
=
router_nccl
.
ncclGetUniqueId
()
unique_id
=
bytes
(
unique_id
.
internal
)
p_rank_num
=
prefill_instance
.
dp_size
*
prefill_instance
.
pp_size
*
prefill_instance
.
tp_size
d_rank_num
=
decode_instance
.
dp_size
*
decode_instance
.
pp_size
*
decode_instance
.
tp_size
world_size
=
p_rank_num
+
d_rank_num
rank
=
0
for
dp_rank
in
range
(
prefill_instance
.
dp_size
):
for
pp_rank
in
range
(
prefill_instance
.
pp_size
):
for
tp_rank
in
range
(
prefill_instance
.
tp_size
):
if
prefill_instance
.
rank_table
[
dp_rank
][
pp_rank
][
tp_rank
]
not
in
sock_cache
:
sock
=
zmq_context
.
socket
(
zmq
.
DEALER
)
sock
.
setsockopt_string
(
zmq
.
IDENTITY
,
"router"
)
sock
.
connect
(
f
"tcp://
{
prefill_instance
.
rank_table
[
dp_rank
][
pp_rank
][
tp_rank
]
}
"
)
sock_cache
[
prefill_instance
.
rank_table
[
dp_rank
][
pp_rank
][
tp_rank
]]
=
sock
data
=
{
"cmd"
:
"comm_init"
,
"pd_pair_id"
:
pd_pair_id
,
"unique_id"
:
unique_id
,
"world_size"
:
world_size
,
"rank"
:
rank
}
sock_cache
[
prefill_instance
.
rank_table
[
dp_rank
][
pp_rank
][
tp_rank
]].
send
(
msgpack
.
dumps
(
data
))
prefill_instance
.
comm_rank_table
[
dp_rank
][
pp_rank
][
tp_rank
]
=
rank
rank
+=
1
logger
.
info
(
f
"""[Router] dispatch unique_id of pd pair
{
pd_pair_id
}
to [P] [
{
dp_rank
}
,
{
pp_rank
}
,
{
tp_rank
}
]"""
)
for
dp_rank
in
range
(
decode_instance
.
dp_size
):
for
pp_rank
in
range
(
decode_instance
.
pp_size
):
for
tp_rank
in
range
(
decode_instance
.
tp_size
):
if
decode_instance
.
rank_table
[
dp_rank
][
pp_rank
][
tp_rank
]
not
in
sock_cache
:
sock
=
zmq_context
.
socket
(
zmq
.
DEALER
)
sock
.
setsockopt_string
(
zmq
.
IDENTITY
,
"router"
)
sock
.
connect
(
f
"tcp://
{
decode_instance
.
rank_table
[
dp_rank
][
pp_rank
][
tp_rank
]
}
"
)
sock_cache
[
decode_instance
.
rank_table
[
dp_rank
][
pp_rank
][
tp_rank
]]
=
sock
data
=
{
"cmd"
:
"comm_init"
,
"pd_pair_id"
:
pd_pair_id
,
"unique_id"
:
unique_id
,
"world_size"
:
world_size
,
"rank"
:
rank
}
sock_cache
[
decode_instance
.
rank_table
[
dp_rank
][
pp_rank
][
tp_rank
]].
send
(
msgpack
.
dumps
(
data
))
decode_instance
.
comm_rank_table
[
dp_rank
][
pp_rank
][
tp_rank
]
=
rank
rank
+=
1
logger
.
info
(
f
"""[Router] dispatch unique_id of pd pair
{
pd_pair_id
}
to [D] [
{
dp_rank
}
,
{
pp_rank
}
,
{
tp_rank
}
]"""
)
pd_pair
[
pd_pair_id
]
=
unique_id
def
pd_pair_init
():
global
prefill_instances
global
decode_instances
global
pending_prefill_ins
global
pending_decode_ins
global
ready_prefill_ins
global
ready_decode_ins
global
instance_cv
while
True
:
with
instance_cv
:
while
len
(
pending_prefill_ins
)
==
0
and
len
(
pending_decode_ins
)
==
0
:
logger
.
info
(
f
"""[Router] pd_pair_init: waiting for instance_cv"""
)
instance_cv
.
wait
()
logger
.
info
(
f
"""[Router] pd_pair_init: instance_cv finished waiting"""
)
while
pending_prefill_ins
:
p_ins
=
pending_prefill_ins
[
0
]
logger
.
info
(
f
"""[Router] pd_pair_init: processing
{
p_ins
}
from pending_prefill_ins"""
)
for
d_ins
in
ready_decode_ins
:
unique_id_dispatch
(
prefill_instances
[
p_ins
],
decode_instances
[
d_ins
])
ready_prefill_ins
.
append
(
p_ins
)
pending_prefill_ins
.
remove
(
p_ins
)
while
pending_decode_ins
:
d_ins
=
pending_decode_ins
[
0
]
logger
.
info
(
f
"""[Router] pd_pair_init: processing
{
d_ins
}
from pending_decode_ins"""
)
for
p_ins
in
ready_prefill_ins
:
unique_id_dispatch
(
prefill_instances
[
p_ins
],
decode_instances
[
d_ins
])
ready_decode_ins
.
append
(
d_ins
)
pending_decode_ins
.
remove
(
d_ins
)
def
start_pd_pair_init
():
_thread
=
threading
.
Thread
(
target
=
pd_pair_init
,
daemon
=
True
)
_thread
.
start
()
return
_thread
async
def
forward_request
(
url
,
data
,
request_id
):
async
with
aiohttp
.
ClientSession
(
timeout
=
AIOHTTP_TIMEOUT
)
as
session
:
headers
=
{
"Authorization"
:
f
"Bearer
{
os
.
environ
.
get
(
'OPENAI_API_KEY'
)
}
"
,
"X-Request-Id"
:
request_id
,
}
async
with
session
.
post
(
url
=
url
,
json
=
data
,
headers
=
headers
)
as
response
:
if
response
.
status
==
200
:
if
True
:
async
for
chunk_bytes
in
response
.
content
.
iter_chunked
(
1024
):
yield
chunk_bytes
else
:
content
=
await
response
.
read
()
yield
content
@
app
.
route
(
"/v1/completions"
,
methods
=
[
"POST"
])
async
def
handle_request
():
try
:
original_request_data
=
await
request
.
get_json
()
prefill_request
=
original_request_data
.
copy
()
# change max_tokens = 1 to let it only do prefill
prefill_request
[
"max_tokens"
]
=
1
global
count
global
prefill_instances
global
instance_cv
with
instance_cv
:
prefill_list
=
list
(
prefill_instances
.
items
())
prefill_addr
,
prefill_instance
=
prefill_list
[
count
%
len
(
prefill_list
)]
global
decode_instances
with
instance_cv
:
decode_list
=
list
(
decode_instances
.
items
())
decode_addr
,
decode_instance
=
decode_list
[
count
%
len
(
decode_list
)]
# TODO: Init nccl comm : dispatch unique_id among PD pair ranks
global
pd_pair
if
prefill_instance
.
http_address
+
"_"
+
decode_instance
.
http_address
not
in
pd_pair
:
unique_id_dispatch
(
prefill_instance
,
decode_instance
)
logger
.
info
(
f
"handle_request count:
{
count
}
, [HTTP:
{
prefill_addr
}
, 👉 HTTP:
{
decode_addr
}
]"
)
count
+=
1
request_id
=
f
"
{
random_uuid
()
}
"
async
def
run_prefill
():
async
for
_
in
forward_request
(
f
"http://
{
prefill_addr
}
/v1/completions"
,
prefill_request
,
request_id
):
pass
prefill_task
=
asyncio
.
create_task
(
run_prefill
())
# return decode
generator
=
forward_request
(
f
"http://
{
decode_addr
}
/v1/completions"
,
original_request_data
,
request_id
)
response
=
await
make_response
(
generator
)
response
.
timeout
=
None
return
response
except
Exception
as
e
:
import
sys
import
traceback
exc_info
=
sys
.
exc_info
()
print
(
"Error occurred in disagg prefill proxy server"
)
print
(
e
)
print
(
""
.
join
(
traceback
.
format_exception
(
*
exc_info
)))
if
__name__
==
"__main__"
:
t
=
start_service_discovery
(
"0.0.0.0"
,
30001
)
t_1
=
start_dp_dispatch
()
t_3
=
start_pd_pair_init
()
app
.
run
(
host
=
"0.0.0.0"
,
port
=
10001
)
t
.
join
()
t_1
.
join
()
t_3
.
join
()
vllm/distributed/kv_transfer/kv_connector/factory.py
View file @
4612aad6
...
@@ -54,6 +54,7 @@ class KVConnectorFactory:
...
@@ -54,6 +54,7 @@ class KVConnectorFactory:
cls
,
cls
,
config
:
"VllmConfig"
,
config
:
"VllmConfig"
,
role
:
KVConnectorRole
,
role
:
KVConnectorRole
,
dp_rank
:
int
=
-
1
,
)
->
KVConnectorBase_V1
:
)
->
KVConnectorBase_V1
:
if
not
envs
.
VLLM_USE_V1
:
if
not
envs
.
VLLM_USE_V1
:
raise
ValueError
(
"Attempting to initialize a V1 Connector, "
raise
ValueError
(
"Attempting to initialize a V1 Connector, "
...
@@ -81,7 +82,7 @@ class KVConnectorFactory:
...
@@ -81,7 +82,7 @@ class KVConnectorFactory:
# - Co-locate with worker process
# - Co-locate with worker process
# - Should only be used inside the forward context & attention layer
# - Should only be used inside the forward context & attention layer
# We build separately to enforce strict separation
# We build separately to enforce strict separation
return
connector_cls
(
config
,
role
)
return
connector_cls
(
config
,
role
,
dp_rank
)
# Register various connectors here.
# Register various connectors here.
...
...
vllm/distributed/kv_transfer/kv_connector/v1/p2p/p2p_nccl_connector.py
View file @
4612aad6
...
@@ -6,19 +6,23 @@ from typing import TYPE_CHECKING, Any, Optional
...
@@ -6,19 +6,23 @@ from typing import TYPE_CHECKING, Any, Optional
import
regex
as
re
import
regex
as
re
import
torch
import
torch
import
os
import
os
from
vllm
import
envs
from
vllm
import
envs
from
vllm.config
import
VllmConfig
from
vllm.config
import
VllmConfig
from
vllm.distributed.kv_transfer.kv_connector.v1.base
import
(
from
vllm.distributed.kv_transfer.kv_connector.v1.base
import
(
KVConnectorBase_V1
,
KVConnectorMetadata
,
KVConnectorRole
)
KVConnectorBase_V1
,
KVConnectorMetadata
,
KVConnectorRole
)
from
vllm.distributed.kv_transfer.kv_connector.v1.p2p.p2p_nccl_engine
import
(
from
vllm.distributed.kv_transfer.kv_connector.v1.p2p.p2p_nccl_engine
import
(
P2pNcclEngine
)
P2pNcclEngine
,
RemoteAddr
)
from
vllm.distributed.parallel_state
import
get_world_group
from
vllm.distributed.parallel_state
import
get_world_group
,
get_dp_group
,
get_pp_group
,
get_tp_group
from
vllm.forward_context
import
get_forward_context
from
vllm.forward_context
import
get_forward_context
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.v1.attention.backends.mla.common
import
MLACommonMetadata
from
vllm.v1.attention.backends.mla.common
import
MLACommonMetadata
from
vllm.v1.core.sched.output
import
SchedulerOutput
from
vllm.v1.core.sched.output
import
SchedulerOutput
from
vllm.distributed.parallel_state
import
get_pp_group
,
get_tp_group
,
get_dp_group
from
vllm.distributed.parallel_state
import
get_pp_group
,
get_tp_group
import
zmq
import
msgpack
if
TYPE_CHECKING
:
if
TYPE_CHECKING
:
from
vllm.attention.backends.abstract
import
AttentionMetadata
from
vllm.attention.backends.abstract
import
AttentionMetadata
...
@@ -78,7 +82,7 @@ class P2pNcclConnectorMetadata(KVConnectorMetadata):
...
@@ -78,7 +82,7 @@ class P2pNcclConnectorMetadata(KVConnectorMetadata):
class
P2pNcclConnector
(
KVConnectorBase_V1
):
class
P2pNcclConnector
(
KVConnectorBase_V1
):
def
__init__
(
self
,
vllm_config
:
"VllmConfig"
,
role
:
KVConnectorRole
):
def
__init__
(
self
,
vllm_config
:
"VllmConfig"
,
role
:
KVConnectorRole
,
dp_rank
:
int
=
-
1
):
super
().
__init__
(
vllm_config
=
vllm_config
,
role
=
role
)
super
().
__init__
(
vllm_config
=
vllm_config
,
role
=
role
)
self
.
_block_size
=
vllm_config
.
cache_config
.
block_size
self
.
_block_size
=
vllm_config
.
cache_config
.
block_size
self
.
_requests_need_load
:
dict
[
str
,
Any
]
=
{}
self
.
_requests_need_load
:
dict
[
str
,
Any
]
=
{}
...
@@ -102,12 +106,17 @@ class P2pNcclConnector(KVConnectorBase_V1):
...
@@ -102,12 +106,17 @@ class P2pNcclConnector(KVConnectorBase_V1):
if
role
==
KVConnectorRole
.
WORKER
else
0
if
role
==
KVConnectorRole
.
WORKER
else
0
self
.
_tp_size
=
get_tp_group
().
world_size
\
self
.
_tp_size
=
get_tp_group
().
world_size
\
if
role
==
KVConnectorRole
.
WORKER
else
0
if
role
==
KVConnectorRole
.
WORKER
else
0
self
.
p2p_nccl_engine
=
P2pNcclEngine
(
self
.
p2p_nccl_engine
=
P2pNcclEngine
(
local_rank
=
self
.
_local_rank
,
local_rank
=
self
.
_local_rank
,
port_offset
=
self
.
_rank
,
config
=
self
.
config
,
config
=
self
.
config
,
model_config
=
vllm_config
.
model_config
,
hostname
=
""
,
port_offset
=
self
.
_rank
,
dp_rank
=
self
.
_dp_rank
,
pp_rank
=
self
.
_pp_rank
,
tp_rank
=
self
.
_tp_rank
,
dp_size
=
self
.
_dp_size
,
pp_size
=
self
.
_pp_size
,
tp_size
=
self
.
_tp_size
)
if
role
==
KVConnectorRole
.
WORKER
else
None
)
if
role
==
KVConnectorRole
.
WORKER
else
None
self
.
parallel_config
=
vllm_config
.
parallel_config
self
.
parallel_config
=
vllm_config
.
parallel_config
...
@@ -117,19 +126,9 @@ class P2pNcclConnector(KVConnectorBase_V1):
...
@@ -117,19 +126,9 @@ class P2pNcclConnector(KVConnectorBase_V1):
self
.
pp_size
=
self
.
parallel_config
.
pipeline_parallel_size
self
.
pp_size
=
self
.
parallel_config
.
pipeline_parallel_size
self
.
tp_size
=
self
.
parallel_config
.
tensor_parallel_size
self
.
tp_size
=
self
.
parallel_config
.
tensor_parallel_size
self
.
num_card
=
self
.
pp_size
*
self
.
tp_size
self
.
num_card
=
self
.
pp_size
*
self
.
tp_size
self
.
multiple_machines
=
1
if
self
.
num_card
>
8
else
0
self
.
remote_tp_size
=
self
.
config
.
get_from_extra_config
(
if
self
.
is_producer
and
self
.
multiple_machines
==
1
:
"remote_tp_size"
,
self
.
tp_size
)
self
.
remote_pp_size
=
self
.
config
.
get_from_extra_config
(
"remote_pp_size"
,
self
.
pp_size
)
self
.
enable_asymmetric_p2p
=
self
.
config
.
get_from_extra_config
(
"enable_asymmetric_p2p"
,
False
)
self
.
remote_num_card
=
self
.
remote_tp_size
*
self
.
remote_pp_size
self
.
multiple_machines_d
=
1
if
self
.
remote_num_card
>
8
else
0
self
.
multiple_machines_p
=
1
if
self
.
num_card
>
8
else
0
if
self
.
is_producer
and
self
.
multiple_machines_p
==
1
:
self
.
ip_map
=
{}
self
.
ip_map
=
{}
self
.
duplicate_keys
=
[]
self
.
duplicate_keys
=
[]
config_file
=
os
.
getenv
(
'IP_CONFIG_FILE'
)
config_file
=
os
.
getenv
(
'IP_CONFIG_FILE'
)
...
@@ -152,10 +151,38 @@ class P2pNcclConnector(KVConnectorBase_V1):
...
@@ -152,10 +151,38 @@ class P2pNcclConnector(KVConnectorBase_V1):
print
(
f
"Error: Exception occurred while reading configuration file -
{
e
}
"
)
print
(
f
"Error: Exception occurred while reading configuration file -
{
e
}
"
)
if
role
==
KVConnectorRole
.
SCHEDULER
:
self
.
dp_rank
=
dp_rank
proxy_ip
=
self
.
config
.
get_from_extra_config
(
"proxy_ip"
,
""
)
proxy_port
=
self
.
config
.
get_from_extra_config
(
"proxy_port"
,
""
)
if
proxy_ip
==
""
or
proxy_port
==
""
:
self
.
proxy_address
=
""
else
:
self
.
proxy_address
=
proxy_ip
+
":"
+
proxy_port
self
.
http_address
=
(
f
"
{
self
.
config
.
kv_connector_extra_config
[
'instance_ip'
]
}
:"
f
"
{
self
.
config
.
kv_connector_extra_config
[
'http_port'
]
}
"
)
self
.
context
=
zmq
.
Context
()
req_sock
=
self
.
context
.
socket
(
zmq
.
DEALER
)
req_sock
.
setsockopt_string
(
zmq
.
IDENTITY
,
f
"
{
self
.
http_address
}
_rank
{
self
.
dp_rank
}
"
)
req_sock
.
connect
(
f
"tcp://
{
self
.
proxy_address
}
"
)
self
.
req_sock
=
req_sock
def
get_ip_value
(
self
,
key
):
def
get_ip_value
(
self
,
key
):
return
self
.
ip_map
.
get
(
key
)
return
self
.
ip_map
.
get
(
key
)
def
register_req
(
self
,
request_id
:
str
)
:
data
=
{
"type"
:
"Req"
,
"instance_type"
:
"P"
if
self
.
config
.
is_kv_producer
else
"D"
,
"http_address"
:
self
.
http_address
,
"request_id"
:
request_id
,
"dp_rank"
:
self
.
dp_rank
}
self
.
req_sock
.
send
(
msgpack
.
dumps
(
data
))
# ==============================
# ==============================
# Worker-side methods
# Worker-side methods
# ==============================
# ==============================
...
@@ -304,7 +331,13 @@ class P2pNcclConnector(KVConnectorBase_V1):
...
@@ -304,7 +331,13 @@ class P2pNcclConnector(KVConnectorBase_V1):
2
,
num_pages
*
page_size
,
-
1
)
2
,
num_pages
*
page_size
,
-
1
)
inject_start_index
=
0
inject_start_index
=
0
for
num
in
range
(
self
.
p2p_nccl_engine
.
tensor_split_num
):
req_layer
=
f
"
{
request
.
request_id
}
#
{
layer_name
}
"
with
self
.
p2p_nccl_engine
.
recv_store_cv
:
while
req_layer
not
in
self
.
p2p_nccl_engine
.
recv_split_nums
:
self
.
p2p_nccl_engine
.
recv_store_cv
.
wait
()
split_num
=
self
.
p2p_nccl_engine
.
recv_split_nums
.
get
(
req_layer
)
for
num
in
range
(
split_num
):
kv_cache
=
self
.
p2p_nccl_engine
.
recv_tensor
(
kv_cache
=
self
.
p2p_nccl_engine
.
recv_tensor
(
request
.
request_id
+
"#"
+
layer_name
+
"#"
+
str
(
num
))
request
.
request_id
+
"#"
+
layer_name
+
"#"
+
str
(
num
))
...
@@ -332,6 +365,7 @@ class P2pNcclConnector(KVConnectorBase_V1):
...
@@ -332,6 +365,7 @@ class P2pNcclConnector(KVConnectorBase_V1):
# inject_kv_into_layer(kv_cache_layer, kv_cache,
# inject_kv_into_layer(kv_cache_layer, kv_cache,
# request.slot_mapping, request.request_id)
# request.slot_mapping, request.request_id)
tensor_id
=
request
.
request_id
+
"#"
+
layer_name
+
"#"
+
str
(
num
)
tensor_id
=
request
.
request_id
+
"#"
+
layer_name
+
"#"
+
str
(
num
)
if
tensor_id
in
self
.
p2p_nccl_engine
.
recv_store
:
if
tensor_id
in
self
.
p2p_nccl_engine
.
recv_store
:
tensor
=
self
.
p2p_nccl_engine
.
recv_store
.
pop
(
tensor_id
,
None
)
tensor
=
self
.
p2p_nccl_engine
.
recv_store
.
pop
(
tensor_id
,
None
)
self
.
p2p_nccl_engine
.
send_request_id_to_tensor_ids
.
pop
(
self
.
p2p_nccl_engine
.
send_request_id_to_tensor_ids
.
pop
(
...
@@ -375,8 +409,6 @@ class P2pNcclConnector(KVConnectorBase_V1):
...
@@ -375,8 +409,6 @@ class P2pNcclConnector(KVConnectorBase_V1):
assert
self
.
p2p_nccl_engine
is
not
None
assert
self
.
p2p_nccl_engine
is
not
None
is_mla
=
isinstance
(
attn_metadata
,
MLACommonMetadata
)
def
extract_kv_from_layer
(
def
extract_kv_from_layer
(
layer
:
torch
.
Tensor
,
layer
:
torch
.
Tensor
,
slot_mapping
:
torch
.
Tensor
,
slot_mapping
:
torch
.
Tensor
,
...
@@ -400,8 +432,6 @@ class P2pNcclConnector(KVConnectorBase_V1):
...
@@ -400,8 +432,6 @@ class P2pNcclConnector(KVConnectorBase_V1):
if
envs
.
VLLM_ENABLE_TBO
or
envs
.
VLLM_P2P_ASYNC
:
if
envs
.
VLLM_ENABLE_TBO
or
envs
.
VLLM_P2P_ASYNC
:
for
request
in
connector_metadata
.
requests
:
for
request
in
connector_metadata
.
requests
:
request_id
=
request
.
request_id
request_id
=
request
.
request_id
ip
,
port
=
self
.
parse_request_id
(
request_id
,
True
)
remote_address
=
ip
+
":"
+
str
(
port
+
self
.
_rank
)
slot_mapping
=
request
.
slot_mapping
slot_mapping
=
request
.
slot_mapping
if
request
.
slot_mapping_device
is
None
:
if
request
.
slot_mapping_device
is
None
:
request
.
slot_mapping_device
=
\
request
.
slot_mapping_device
=
\
...
@@ -409,91 +439,46 @@ class P2pNcclConnector(KVConnectorBase_V1):
...
@@ -409,91 +439,46 @@ class P2pNcclConnector(KVConnectorBase_V1):
slot_mapping
=
request
.
slot_mapping_device
slot_mapping
=
request
.
slot_mapping_device
tbo_evt
=
torch
.
cuda
.
Event
(
enable_timing
=
False
)
tbo_evt
=
torch
.
cuda
.
Event
(
enable_timing
=
False
)
tbo_evt
.
record
()
tbo_evt
.
record
()
pp_rank
=
(
self
.
parallel_config
.
rank
//
self
.
parallel_config
.
tensor_parallel_size
)
%
\
pending
=
False
self
.
parallel_config
.
pipeline_parallel_size
with
self
.
p2p_nccl_engine
.
req_status_cv
:
if
(
self
.
pp_size
==
1
):
if
request_id
not
in
self
.
p2p_nccl_engine
.
req_status
:
self
.
p2p_nccl_engine
.
p2p_async_send_tensor
(
request_id
+
"#"
+
layer_name
,
pending
=
True
(
kv_layer
,
slot_mapping
),
remote_address
,
tbo_evt
)
if
pending
:
elif
(
self
.
pp_size
==
2
):
self
.
p2p_nccl_engine
.
pending_tensor
(
request_id
,
layer_name
,
if
(
pp_rank
==
0
):
(
kv_layer
,
slot_mapping
),
tbo_evt
)
self
.
p2p_nccl_engine
.
p2p_async_send_tensor
(
request_id
+
"#"
+
layer_name
,
logger
.
info
(
"[%d] pending for request: %s layer: %s"
,
self
.
_rank
,
request_id
,
layer_name
)
(
kv_layer
,
slot_mapping
),
remote_address
,
tbo_evt
)
else
:
self
.
p2p_nccl_engine
.
p2p_async_send_tensor
(
request_id
+
"#"
+
layer_name
,
req_data
=
self
.
p2p_nccl_engine
.
req_status
[
request_id
]
(
kv_layer
,
slot_mapping
),
ip
+
":"
+
str
(
port
+
self
.
_rank
+
4
),
tbo_evt
)
assert
(
req_data
.
dst_num
==
len
(
req_data
.
zmq_address_and_comm_rank
))
else
:
for
i
in
range
(
req_data
.
dst_num
):
self
.
p2p_nccl_engine
.
p2p_async_send_tensor
(
request_id
+
"#"
+
layer_name
,
remote_addr
=
RemoteAddr
(
req_data
.
pd_pair_id
,
*
(
req_data
.
zmq_address_and_comm_rank
[
i
]))
(
kv_layer
,
slot_mapping
),
remote_address
,
tbo_evt
)
self
.
p2p_nccl_engine
.
p2p_async_send_tensor
(
request_id
+
"#"
+
layer_name
,
(
kv_layer
,
slot_mapping
),
ip
+
":"
+
str
(
port
+
self
.
_rank
-
4
),
tbo_evt
)
elif
(
self
.
pp_size
==
8
):
for
i
in
range
(
8
):
self
.
p2p_nccl_engine
.
p2p_async_send_tensor
(
request_id
+
"#"
+
layer_name
,
self
.
p2p_nccl_engine
.
p2p_async_send_tensor
(
request_id
+
"#"
+
layer_name
,
(
kv_layer
,
slot_mapping
),
ip
+
":"
+
str
(
port
+
i
),
tbo_evt
)
(
kv_layer
,
slot_mapping
),
remote_addr
,
tbo_evt
)
else
:
print
(
"Error: only suppprt pp1 pp2 pp8!!!!!!"
)
# self.p2p_nccl_engine.p2p_async_send_tensor(request_id + "#" + layer_name,
# (kv_layer, slot_mapping), remote_address, tbo_evt)
else
:
else
:
for
request
in
connector_metadata
.
requests
:
for
request
in
connector_metadata
.
requests
:
request_id
=
request
.
request_id
request_id
=
request
.
request_id
ip
,
port
=
self
.
parse_request_id
(
request_id
,
True
)
remote_address
=
ip
+
":"
+
str
(
port
+
self
.
_rank
)
kv_cache
=
extract_kv_from_layer
(
kv_layer
,
request
.
slot_mapping
)
kv_cache
=
extract_kv_from_layer
(
kv_layer
,
request
.
slot_mapping
)
pp_rank
=
(
self
.
parallel_config
.
rank
//
self
.
parallel_config
.
tensor_parallel_size
pending
=
False
)
%
self
.
parallel_config
.
pipeline_parallel_size
with
self
.
p2p_nccl_engine
.
req_status_cv
:
if
(
self
.
multiple_machines_p
and
self
.
multiple_machines_d
):
if
request_id
not
in
self
.
p2p_nccl_engine
.
req_status
:
ip_second
=
self
.
get_ip_value
(
ip
)
pending
=
True
if
(
self
.
pp_size
==
1
):
if
pending
:
if
self
.
_rank
<
8
:
self
.
p2p_nccl_engine
.
pending_tensor
(
request_id
,
layer_name
,
self
.
p2p_nccl_engine
.
send_tensor
(
request_id
+
"#"
+
layer_name
,
kv_cache
)
kv_cache
,
remote_address
)
logger
.
info
(
"[%d] pending for request: %s layer: %s"
,
self
.
_rank
,
request_id
,
layer_name
)
self
.
p2p_nccl_engine
.
send_tensor
(
request_id
+
"#"
+
layer_name
,
else
:
kv_cache
,
str
(
ip_second
)
+
":"
+
str
(
port
+
self
.
_rank
+
8
))
req_data
=
self
.
p2p_nccl_engine
.
req_status
[
request_id
]
elif
(
self
.
pp_size
==
2
):
assert
(
req_data
.
dst_num
==
len
(
req_data
.
zmq_address_and_comm_rank
))
if
(
pp_rank
==
0
):
for
i
in
range
(
req_data
.
dst_num
):
remote_addr
=
RemoteAddr
(
req_data
.
pd_pair_id
,
*
(
req_data
.
zmq_address_and_comm_rank
[
i
]))
self
.
p2p_nccl_engine
.
send_tensor
(
request_id
+
"#"
+
layer_name
,
self
.
p2p_nccl_engine
.
send_tensor
(
request_id
+
"#"
+
layer_name
,
kv_cache
,
remote_address
)
kv_cache
,
remote_addr
)
else
:
self
.
p2p_nccl_engine
.
send_tensor
(
request_id
+
"#"
+
layer_name
,
kv_cache
,
str
(
ip_second
)
+
":"
+
str
(
port
+
self
.
_rank
))
else
:
logger
.
error
(
"Error: multiple machines only suppprt pp1tp16 and pp2tp8!!!!!!"
)
elif
(
self
.
multiple_machines_p
and
not
self
.
multiple_machines_d
):
if
(
self
.
pp_size
==
2
):
remote_address
=
ip
+
":"
+
str
(
port
+
self
.
_tp_rank
)
self
.
p2p_nccl_engine
.
send_tensor
(
request_id
+
"#"
+
layer_name
,
kv_cache
,
remote_address
)
else
:
logger
.
error
(
"Error: P multiple machines D machine only suppprt P:pp2tp8 D:tp8 !!!!!!"
)
elif
(
not
self
.
multiple_machines_p
and
not
self
.
multiple_machines_d
):
self
.
p2p_nccl_engine
.
send_tensor_new
(
request_id
,
layer_name
,
kv_cache
,
is_mla
)
# if (self.pp_size == 1):
# self.p2p_nccl_engine.send_tensor(request_id + "#" + layer_name,
# kv_cache, remote_address)
# elif (self.pp_size == 2):
# if (pp_rank == 0):
# self.p2p_nccl_engine.send_tensor(request_id + "#" + layer_name,
# kv_cache, remote_address)
# self.p2p_nccl_engine.send_tensor(request_id + "#" + layer_name,
# kv_cache, ip + ":" + str(port + self._rank + 4))
# else:
# self.p2p_nccl_engine.send_tensor(request_id + "#" + layer_name,
# kv_cache, remote_address)
# self.p2p_nccl_engine.send_tensor(request_id + "#" + layer_name,
# kv_cache, ip + ":" + str(port + self._rank - 4))
# elif (self.pp_size == 8):
# for i in range(8):
# self.p2p_nccl_engine.send_tensor(request_id + "#" + layer_name,
# kv_cache, ip + ":" + str(port + i))
# elif (self.enable_asymmetric_p2p):
# self.p2p_nccl_engine.send_tensor(request_id + "#" + layer_name,
# kv_cache, remote_address)
# else:
# logger.error("Error: P/D single machine only suppprt multiple tp:: (P: pp2tp4 D:tp8 P:pp8tp1 D:tp8) !!!!!!")
else
:
logger
.
error
(
"Error: not support!!!!!!"
)
def
wait_for_save
(
self
):
def
wait_for_save
(
self
):
pass
pass
# if self.is_producer:
# if self.is_producer:
...
@@ -612,9 +597,7 @@ class P2pNcclConnector(KVConnectorBase_V1):
...
@@ -612,9 +597,7 @@ class P2pNcclConnector(KVConnectorBase_V1):
num_scheduled_tokens
=
(
num_scheduled_tokens
=
(
scheduler_output
.
num_scheduled_tokens
)[
req_id
]
scheduler_output
.
num_scheduled_tokens
)[
req_id
]
num_tokens
=
(
num_scheduled_tokens
+
num_computed_tokens
)
num_tokens
=
(
num_scheduled_tokens
+
num_computed_tokens
)
# assert req_id in self.chunked_prefill
assert
req_id
in
self
.
chunked_prefill
if
req_id
not
in
self
.
chunked_prefill
:
continue
block_ids
=
new_block_ids
[
0
]
block_ids
=
new_block_ids
[
0
]
if
not
resumed_from_preemption
:
if
not
resumed_from_preemption
:
block_ids
=
(
self
.
chunked_prefill
[
req_id
][
0
]
+
block_ids
)
block_ids
=
(
self
.
chunked_prefill
[
req_id
][
0
]
+
block_ids
)
...
...
vllm/distributed/kv_transfer/kv_connector/v1/p2p/p2p_nccl_engine.py
View file @
4612aad6
...
@@ -6,14 +6,14 @@ import os
...
@@ -6,14 +6,14 @@ import os
import
threading
import
threading
import
time
import
time
import
typing
import
typing
from
collections
import
deque
from
collections
import
deque
,
defaultdict
from
contextlib
import
contextmanager
from
contextlib
import
contextmanager
from
typing
import
TYPE_CHECKING
,
Any
,
Optional
from
typing
import
TYPE_CHECKING
,
Any
,
Optional
from
dataclasses
import
dataclass
,
field
import
msgpack
import
msgpack
import
torch
import
torch
import
zmq
import
zmq
import
regex
from
vllm.config
import
KVTransferConfig
from
vllm.config
import
KVTransferConfig
from
vllm.distributed.device_communicators.pynccl_wrapper
import
(
from
vllm.distributed.device_communicators.pynccl_wrapper
import
(
...
@@ -24,11 +24,6 @@ from vllm.utils import current_stream, get_ip
...
@@ -24,11 +24,6 @@ from vllm.utils import current_stream, get_ip
from
vllm
import
envs
from
vllm
import
envs
from
vllm.distributed.parallel_state
import
get_pp_group
,
get_tp_group
from
vllm.distributed.parallel_state
import
get_pp_group
,
get_tp_group
from
dataclasses
import
dataclass
from
vllm.model_executor.models.utils
import
extract_layer_index
from
vllm.distributed.utils
import
get_pp_indices
from
vllm.config
import
ModelConfig
if
TYPE_CHECKING
:
if
TYPE_CHECKING
:
from
vllm.forward_context
import
ForwardContext
from
vllm.forward_context
import
ForwardContext
...
@@ -36,11 +31,6 @@ logger = logging.getLogger(__name__)
...
@@ -36,11 +31,6 @@ logger = logging.getLogger(__name__)
DEFAULT_MEM_POOL_SIZE_GB
=
32
DEFAULT_MEM_POOL_SIZE_GB
=
32
# @dataclass
# class SendQueueItem:
# tensor_id: str
# remote_address: str
# tensor: torch.Tensor
@
contextmanager
@
contextmanager
def
set_p2p_nccl_context
(
num_channels
:
str
):
def
set_p2p_nccl_context
(
num_channels
:
str
):
...
@@ -71,44 +61,50 @@ def set_p2p_nccl_context(num_channels: str):
...
@@ -71,44 +61,50 @@ def set_p2p_nccl_context(num_channels: str):
else
:
else
:
os
.
environ
.
pop
(
var
,
None
)
os
.
environ
.
pop
(
var
,
None
)
@
dataclass
class
ReqKVDest
:
dst_num
:
int
=
0
pd_pair_id
:
str
=
""
zmq_address_and_comm_rank
:
list
[
tuple
[
str
,
int
]]
=
field
(
default_factory
=
list
)
@
dataclass
class
RemoteAddr
:
pd_pair_id
:
str
=
""
zmq_address
:
str
=
""
comm_rank
:
int
=
0
class
P2pNcclEngine
:
class
P2pNcclEngine
:
def
__init__
(
self
,
def
__init__
(
self
,
local_rank
:
int
,
local_rank
:
int
,
port_offset
:
int
,
config
:
KVTransferConfig
,
config
:
KVTransferConfig
,
model_config
:
ModelConfig
,
hostname
:
str
=
""
,
port_offset
:
int
=
0
,
dp_rank
:
int
=
0
,
pp_rank
:
int
=
0
,
tp_rank
:
int
=
0
,
dp_size
:
int
=
0
,
pp_size
:
int
=
0
,
tp_size
:
int
=
0
,
library_path
:
Optional
[
str
]
=
None
)
->
None
:
library_path
:
Optional
[
str
]
=
None
)
->
None
:
self
.
config
=
config
self
.
config
=
config
self
.
model_config
=
model_config
self
.
rank
=
port_offset
self
.
rank
=
port_offset
self
.
local_rank
=
local_rank
self
.
local_rank
=
local_rank
self
.
dp_rank
=
dp_rank
self
.
pp_rank
=
pp_rank
self
.
tp_rank
=
tp_rank
self
.
dp_size
=
dp_size
self
.
pp_size
=
pp_size
self
.
tp_size
=
tp_size
self
.
device
=
torch
.
device
(
f
"cuda:
{
self
.
local_rank
}
"
)
self
.
device
=
torch
.
device
(
f
"cuda:
{
self
.
local_rank
}
"
)
self
.
nccl
=
NCCLLibrary
(
library_path
)
self
.
nccl
=
NCCLLibrary
(
library_path
)
self
.
total_num_hidden_layers
=
getattr
(
self
.
model_config
.
hf_text_config
,
if
not
hostname
:
"num_hidden_layers"
,
0
)
hostname
=
get_ip
()
self
.
pp_rank
=
get_pp_group
().
rank_in_group
self
.
tp_rank
=
get_tp_group
().
rank_in_group
self
.
pp_size
=
get_pp_group
().
world_size
self
.
tp_size
=
get_tp_group
().
world_size
if
config
.
is_kv_producer
:
self
.
remote_tp_size
=
self
.
config
.
get_from_extra_config
(
"remote_tp_size"
,
1
)
self
.
remote_pp_size
=
self
.
config
.
get_from_extra_config
(
"remote_pp_size"
,
1
)
self
.
enable_asymmetric_p2p
=
self
.
config
.
get_from_extra_config
(
"enable_asymmetric_p2p"
,
False
)
if
self
.
remote_tp_size
%
self
.
tp_size
!=
0
:
logger
.
error
(
" the Prefill TP size must be less than or equal to the Decode TP size!!!!"
)
self
.
multp
=
int
(
self
.
remote_tp_size
/
self
.
tp_size
)
port
=
int
(
self
.
config
.
kv_port
)
+
port_offset
port
=
int
(
self
.
config
.
kv_port
)
+
port_offset
if
port
==
0
:
if
port
==
0
:
raise
ValueError
(
"Port cannot be 0"
)
raise
ValueError
(
"Port cannot be 0"
)
self
.
_hostname
=
get_ip
()
self
.
_hostname
=
hostname
self
.
_port
=
port
self
.
_port
=
port
# Each card corresponds to a ZMQ address.
# Each card corresponds to a ZMQ address.
...
@@ -116,7 +112,7 @@ class P2pNcclEngine:
...
@@ -116,7 +112,7 @@ class P2pNcclEngine:
# The `http_port` must be consistent with the port of OpenAI.
# The `http_port` must be consistent with the port of OpenAI.
self
.
http_address
=
(
self
.
http_address
=
(
f
"
{
self
.
_hostname
}
:"
f
"
{
self
.
config
.
kv_connector_extra_config
[
'instance_ip'
]
}
:"
f
"
{
self
.
config
.
kv_connector_extra_config
[
'http_port'
]
}
"
)
f
"
{
self
.
config
.
kv_connector_extra_config
[
'http_port'
]
}
"
)
# If `proxy_ip` or `proxy_port` is `""`,
# If `proxy_ip` or `proxy_port` is `""`,
...
@@ -128,16 +124,27 @@ class P2pNcclEngine:
...
@@ -128,16 +124,27 @@ class P2pNcclEngine:
else
:
else
:
self
.
proxy_address
=
proxy_ip
+
":"
+
proxy_port
self
.
proxy_address
=
proxy_ip
+
":"
+
proxy_port
self
.
kv_cache_layer_num
=
0
self
.
context
=
zmq
.
Context
()
self
.
context
=
zmq
.
Context
()
self
.
router_socket
=
self
.
context
.
socket
(
zmq
.
ROUTER
)
self
.
router_socket
=
self
.
context
.
socket
(
zmq
.
ROUTER
)
self
.
router_socket
.
setsockopt
(
zmq
.
RCVHWM
,
10000
)
self
.
router_socket
.
setsockopt
(
zmq
.
SNDHWM
,
5000
)
self
.
router_socket
.
setsockopt
(
zmq
.
LINGER
,
0
)
self
.
router_socket
.
setsockopt
(
zmq
.
ROUTER_MANDATORY
,
1
)
self
.
router_socket
.
setsockopt
(
zmq
.
TCP_KEEPALIVE
,
1
)
self
.
router_socket
.
bind
(
f
"tcp://
{
self
.
zmq_address
}
"
)
self
.
router_socket
.
bind
(
f
"tcp://
{
self
.
zmq_address
}
"
)
self
.
poller
=
zmq
.
Poller
()
self
.
poller
=
zmq
.
Poller
()
self
.
poller
.
register
(
self
.
router_socket
,
zmq
.
POLLIN
)
self
.
poller
.
register
(
self
.
router_socket
,
zmq
.
POLLIN
)
self
.
req_status
:
dict
[
str
,
ReqKVDest
]
=
{}
self
.
req_status_cv
=
threading
.
Condition
()
self
.
send_store_cv
=
threading
.
Condition
()
self
.
send_store_cv
=
threading
.
Condition
()
self
.
send_queue_cv
=
threading
.
Condition
()
self
.
send_queue_cv
=
threading
.
Condition
()
self
.
recv_store_cv
=
threading
.
Condition
()
self
.
recv_store_cv
=
threading
.
Condition
()
self
.
pending_queue_cv
=
threading
.
Condition
()
self
.
send_stream
=
torch
.
cuda
.
Stream
()
self
.
send_stream
=
torch
.
cuda
.
Stream
()
self
.
recv_stream
=
torch
.
cuda
.
Stream
()
self
.
recv_stream
=
torch
.
cuda
.
Stream
()
...
@@ -145,6 +152,7 @@ class P2pNcclEngine:
...
@@ -145,6 +152,7 @@ class P2pNcclEngine:
self
.
p2p_async_kv_tokens
=
envs
.
VLLM_P2P_BUF_TOKENS
self
.
p2p_async_kv_tokens
=
envs
.
VLLM_P2P_BUF_TOKENS
self
.
p2p_async_buf
=
None
self
.
p2p_async_buf
=
None
self
.
tensor_split_num
:
int
=
0
self
.
tensor_split_num
:
int
=
0
self
.
recv_split_nums
:
dict
[
str
,
int
]
=
{}
mem_pool_size_gb
=
self
.
config
.
get_from_extra_config
(
mem_pool_size_gb
=
self
.
config
.
get_from_extra_config
(
"mem_pool_size_gb"
,
DEFAULT_MEM_POOL_SIZE_GB
)
"mem_pool_size_gb"
,
DEFAULT_MEM_POOL_SIZE_GB
)
...
@@ -161,11 +169,16 @@ class P2pNcclEngine:
...
@@ -161,11 +169,16 @@ class P2pNcclEngine:
# PUT or PUT_ASYNC
# PUT or PUT_ASYNC
# tensor_id: torch.Tensor
# tensor_id: torch.Tensor
self
.
send_queue
:
deque
[
list
[
Any
]]
=
deque
()
self
.
send_queue
:
deque
[
list
[
Any
]]
=
deque
()
self
.
pending_queue
:
dict
[
str
,
list
[
list
[
Any
]]]
=
defaultdict
(
list
)
self
.
requests_to_release
:
dict
[
str
,
bool
]
=
{}
self
.
send_request_id_to_tensor_ids
:
dict
[
str
,
set
[
str
]]
=
{}
self
.
send_request_id_to_tensor_ids
:
dict
[
str
,
set
[
str
]]
=
{}
if
self
.
send_type
==
"PUT_ASYNC"
:
if
self
.
send_type
==
"PUT_ASYNC"
:
self
.
_send_thread
=
threading
.
Thread
(
target
=
self
.
_send_async
,
self
.
_send_thread
=
threading
.
Thread
(
target
=
self
.
_send_async
,
daemon
=
True
)
daemon
=
True
)
self
.
_send_thread
.
start
()
self
.
_send_thread
.
start
()
self
.
_pending_check_thread
=
threading
.
Thread
(
target
=
self
.
_pending_check
,
daemon
=
True
)
self
.
_pending_check_thread
.
start
()
# tensor_id: torch.Tensor/(addr, dtype, shape)
# tensor_id: torch.Tensor/(addr, dtype, shape)
self
.
recv_store
:
dict
[
str
,
Any
]
=
{}
self
.
recv_store
:
dict
[
str
,
Any
]
=
{}
...
@@ -184,7 +197,7 @@ class P2pNcclEngine:
...
@@ -184,7 +197,7 @@ class P2pNcclEngine:
self
.
_listener_thread
.
start
()
self
.
_listener_thread
.
start
()
self
.
_ping_thread
=
None
self
.
_ping_thread
=
None
if
port_offset
==
0
and
self
.
proxy_address
!=
""
:
if
self
.
proxy_address
!=
""
:
self
.
_ping_thread
=
threading
.
Thread
(
target
=
self
.
_ping
,
self
.
_ping_thread
=
threading
.
Thread
(
target
=
self
.
_ping
,
daemon
=
True
)
daemon
=
True
)
self
.
_ping_thread
.
start
()
self
.
_ping_thread
.
start
()
...
@@ -198,92 +211,24 @@ class P2pNcclEngine:
...
@@ -198,92 +211,24 @@ class P2pNcclEngine:
def
_create_connect
(
self
,
remote_address
:
typing
.
Optional
[
str
]
=
None
):
def
_create_connect
(
self
,
remote_address
:
typing
.
Optional
[
str
]
=
None
):
assert
remote_address
is
not
None
assert
remote_address
is
not
None
if
remote_address
not
in
self
.
socks
:
if
remote_address
not
in
self
.
socks
:
sock
=
self
.
context
.
socket
(
zmq
.
DEALER
)
sock
=
self
.
context
.
socket
(
zmq
.
DEALER
)
sock
.
setsockopt_string
(
zmq
.
IDENTITY
,
self
.
zmq_address
)
sock
.
setsockopt
(
zmq
.
SNDHWM
,
10000
)
sock
.
setsockopt
(
zmq
.
RCVHWM
,
5000
)
sock
.
setsockopt
(
zmq
.
LINGER
,
0
)
sock
.
setsockopt
(
zmq
.
TCP_KEEPALIVE
,
1
)
sock
.
setsockopt_string
(
zmq
.
IDENTITY
,
f
"P-
{
self
.
zmq_address
}
"
)
sock
.
connect
(
f
"tcp://
{
remote_address
}
"
)
sock
.
connect
(
f
"tcp://
{
remote_address
}
"
)
self
.
socks
[
remote_address
]
=
sock
self
.
socks
[
remote_address
]
=
sock
if
remote_address
in
self
.
comms
:
logger
.
info
(
"👋comm exists, remote_address:%s, comms:%s"
,
remote_address
,
self
.
comms
)
return
sock
,
self
.
comms
[
remote_address
]
unique_id
=
self
.
nccl
.
ncclGetUniqueId
()
data
=
{
"cmd"
:
"NEW"
,
"unique_id"
:
bytes
(
unique_id
.
internal
)}
sock
.
send
(
msgpack
.
dumps
(
data
))
with
torch
.
cuda
.
device
(
self
.
device
):
rank
=
0
with
set_p2p_nccl_context
(
self
.
nccl_num_channels
):
comm
:
ncclComm_t
=
self
.
nccl
.
ncclCommInitRank
(
2
,
unique_id
,
rank
)
self
.
comms
[
remote_address
]
=
(
comm
,
rank
)
logger
.
info
(
"🤝ncclCommInitRank Success, %s👉%s, MyRank: %s"
,
self
.
zmq_address
,
remote_address
,
rank
)
return
self
.
socks
[
remote_address
],
self
.
comms
[
remote_address
]
return
self
.
socks
[
remote_address
]
def
get_send_queue_items
(
self
,
request_id
:
str
,
layer_name
:
str
,
tensor
:
torch
.
Tensor
,
is_mla
:
bool
)
->
list
[
any
]:
tensor_id
=
self
.
get_tensor_id
(
request_id
,
layer_name
)
remote_ip
,
remote_port
=
self
.
parse_request_id
(
request_id
,
True
)
if
not
self
.
enable_asymmetric_p2p
:
remote_address
=
remote_ip
+
":"
+
str
(
remote_port
+
self
.
rank
)
return
[(
tensor_id
,
remote_address
,
tensor
)]
if
not
is_mla
:
logger
.
error
(
" P2PNCCL only support mla model symmetric PP/TP!!!!"
)
remote_pp_rank
=
self
.
compute_remote_pp_rank
(
layer_name
)
items
:
list
[
Any
]
=
[]
up_down
=
1
# remote_tp_rank = self.tp_rank * self.multp
for
d_tp_rank
in
range
(
self
.
remote_tp_size
):
for
mul_tp
in
range
(
self
.
multp
):
if
self
.
tp_rank
+
mul_tp
*
self
.
tp_size
==
d_tp_rank
:
remote_port_offset
=
remote_pp_rank
*
self
.
remote_tp_size
+
d_tp_rank
remote_address
=
remote_ip
+
":"
+
str
(
remote_port
+
remote_port_offset
)
logger
.
debug
(
"📥 [PUT] Wait to send: tensor_id:%s, tensor_shape:%s, "
"(pp=%d, tp=%d) -> remote_address=%s(pp=%d, tp=%d)"
,
tensor_id
,
tensor
.
shape
,
self
.
pp_rank
,
self
.
tp_rank
,
remote_address
,
remote_pp_rank
,
self
.
rank
*
mul_tp
+
self
.
rank
)
items
.
append
([
tensor_id
,
remote_address
,
tensor
])
return
items
def
send_tensor_new
(
self
,
request_id
:
str
,
layer_name
:
str
,
tensor
:
torch
.
Tensor
,
is_mla
:
bool
=
False
,
)
->
bool
:
tensor_id
=
self
.
get_tensor_id
(
request_id
,
layer_name
)
if
self
.
send_type
==
"PUT"
:
return
all
(
self
.
send_sync
(
item
)
for
item
in
self
.
get_send_queue_items
(
request_id
,
layer_name
,
tensor
,
is_mla
))
if
self
.
send_type
==
"PUT_ASYNC"
:
with
self
.
send_queue_cv
:
for
item
in
self
.
get_send_queue_items
(
request_id
,
layer_name
,
tensor
,
is_mla
):
self
.
send_queue
.
append
(
item
)
self
.
send_queue_cv
.
notify
()
return
True
if
self
.
send_type
==
"GET"
:
logger
.
error
(
" P2PNCCL new not support GET model, please set VLLM_P2PNCCL_NEW=0 use defalut model!!!!"
)
def
send_tensor
(
def
send_tensor
(
self
,
self
,
tensor_id
:
str
,
tensor_id
:
str
,
tensor
:
torch
.
Tensor
,
tensor
:
torch
.
Tensor
,
remote_address
:
typing
.
Optional
[
str
]
=
None
,
remote_address
:
typing
.
Optional
[
RemoteAddr
]
=
None
,
tbo_evt
=
None
,
)
->
bool
:
)
->
bool
:
if
remote_address
is
None
:
if
remote_address
is
None
:
with
self
.
recv_store_cv
:
with
self
.
recv_store_cv
:
...
@@ -310,7 +255,7 @@ class P2pNcclEngine:
...
@@ -310,7 +255,7 @@ class P2pNcclEngine:
logger
.
info
(
logger
.
info
(
"⛔[GET]Send to %s, tensor_id:%s, tensor_size:%d,"
"⛔[GET]Send to %s, tensor_id:%s, tensor_size:%d,"
" buffer_size:%d, oldest_tenser_size:%d, rank:%d"
,
" buffer_size:%d, oldest_tenser_size:%d, rank:%d"
,
remote_address
,
tensor_id
,
tensor_size
,
remote_address
.
zmq_address
,
tensor_id
,
tensor_size
,
self
.
buffer_size
,
oldest_tenser_size
,
self
.
rank
)
self
.
buffer_size
,
oldest_tenser_size
,
self
.
rank
)
self
.
send_store
[
tensor_id
]
=
tensor
self
.
send_store
[
tensor_id
]
=
tensor
...
@@ -318,7 +263,7 @@ class P2pNcclEngine:
...
@@ -318,7 +263,7 @@ class P2pNcclEngine:
logger
.
debug
(
logger
.
debug
(
"🔵[GET]Send to %s, tensor_id:%s, tensor_size:%d, "
"🔵[GET]Send to %s, tensor_id:%s, tensor_size:%d, "
"shape:%s, rank:%d, buffer_size:%d(%.2f%%)"
,
"shape:%s, rank:%d, buffer_size:%d(%.2f%%)"
,
remote_address
,
tensor_id
,
tensor_size
,
tensor
.
shape
,
remote_address
.
zmq_address
,
tensor_id
,
tensor_size
,
tensor
.
shape
,
self
.
rank
,
self
.
buffer_size
,
self
.
rank
,
self
.
buffer_size
,
self
.
buffer_size
/
self
.
buffer_size_threshold
*
100
)
self
.
buffer_size
/
self
.
buffer_size_threshold
*
100
)
...
@@ -328,13 +273,14 @@ class P2pNcclEngine:
...
@@ -328,13 +273,14 @@ class P2pNcclEngine:
self
,
self
,
tensor_id
:
str
,
tensor_id
:
str
,
tensor
:
torch
.
Tensor
,
tensor
:
torch
.
Tensor
,
remote_address
:
typing
.
Optional
[
st
r
]
=
None
,
remote_address
:
typing
.
Optional
[
RemoteAdd
r
]
=
None
,
tbo_evt
=
None
,
tbo_evt
=
None
,
)
->
bool
:
)
->
bool
:
if
remote_address
is
None
:
if
remote_address
is
None
:
with
self
.
recv_store_cv
:
with
self
.
recv_store_cv
:
self
.
recv_store
[
tensor_id
]
=
tensor
self
.
recv_store
[
tensor_id
]
=
tensor
self
.
recv_store_cv
.
notify
()
# self.recv_store_cv.notify()
self
.
recv_store_cv
.
notify_all
()
return
True
return
True
else
:
else
:
if
self
.
send_type
==
"PUT"
:
if
self
.
send_type
==
"PUT"
:
...
@@ -343,7 +289,7 @@ class P2pNcclEngine:
...
@@ -343,7 +289,7 @@ class P2pNcclEngine:
with
self
.
send_queue_cv
:
with
self
.
send_queue_cv
:
kv_layer
,
slot_mapping
=
tensor
# tesor (kv_layer, slot_mapping)
kv_layer
,
slot_mapping
=
tensor
# tesor (kv_layer, slot_mapping)
self
.
send_queue
.
append
([
tensor_id
,
remote_address
,
kv_layer
,
slot_mapping
,
tbo_evt
])
self
.
send_queue
.
append
([
tensor_id
,
remote_address
,
kv_layer
,
slot_mapping
,
tbo_evt
])
self
.
send_queue_cv
.
notify
()
self
.
send_queue_cv
.
notify
_all
()
else
:
# GET
else
:
# GET
with
self
.
send_store_cv
:
with
self
.
send_store_cv
:
tensor_size
=
tensor
.
element_size
()
*
tensor
.
numel
()
tensor_size
=
tensor
.
element_size
()
*
tensor
.
numel
()
...
@@ -357,7 +303,7 @@ class P2pNcclEngine:
...
@@ -357,7 +303,7 @@ class P2pNcclEngine:
logger
.
info
(
logger
.
info
(
"⛔[GET]Send to %s, tensor_id:%s, tensor_size:%d,"
"⛔[GET]Send to %s, tensor_id:%s, tensor_size:%d,"
" buffer_size:%d, oldest_tenser_size:%d, rank:%d"
,
" buffer_size:%d, oldest_tenser_size:%d, rank:%d"
,
remote_address
,
tensor_id
,
tensor_size
,
remote_address
.
zmq_address
,
tensor_id
,
tensor_size
,
self
.
buffer_size
,
oldest_tenser_size
,
self
.
rank
)
self
.
buffer_size
,
oldest_tenser_size
,
self
.
rank
)
self
.
send_store
[
tensor_id
]
=
tensor
self
.
send_store
[
tensor_id
]
=
tensor
...
@@ -365,12 +311,62 @@ class P2pNcclEngine:
...
@@ -365,12 +311,62 @@ class P2pNcclEngine:
logger
.
debug
(
logger
.
debug
(
"🔵[GET]Send to %s, tensor_id:%s, tensor_size:%d, "
"🔵[GET]Send to %s, tensor_id:%s, tensor_size:%d, "
"shape:%s, rank:%d, buffer_size:%d(%.2f%%)"
,
"shape:%s, rank:%d, buffer_size:%d(%.2f%%)"
,
remote_address
,
tensor_id
,
tensor_size
,
tensor
.
shape
,
remote_address
.
zmq_address
,
tensor_id
,
tensor_size
,
tensor
.
shape
,
self
.
rank
,
self
.
buffer_size
,
self
.
rank
,
self
.
buffer_size
,
self
.
buffer_size
/
self
.
buffer_size_threshold
*
100
)
self
.
buffer_size
/
self
.
buffer_size_threshold
*
100
)
return
True
return
True
# TODO : support p2p async
def
pending_tensor
(
self
,
reuqest_id
:
str
,
layer_name
:
str
,
tensor
:
torch
.
Tensor
,
tbo_evt
=
None
,
)
->
bool
:
with
self
.
pending_queue_cv
:
self
.
pending_queue
[
reuqest_id
].
append
([
layer_name
,
tensor
,
tbo_evt
])
self
.
pending_queue_cv
.
notify
()
return
True
def
unpending_tensor
(
self
,
request_id
:
str
,
req_data
:
ReqKVDest
,
)
->
bool
:
with
self
.
pending_queue_cv
:
tensor_list
=
self
.
pending_queue
.
pop
(
request_id
)
if
request_id
in
self
.
requests_to_release
:
self
.
requests_to_release
[
request_id
]
=
True
logger
.
info
(
"[%d] unpending request: %s"
,
self
.
rank
,
request_id
)
if
req_data
.
dst_num
<=
0
:
return
False
for
layer_name
,
tensor
,
tbo_evt
in
tensor_list
:
for
i
in
range
(
req_data
.
dst_num
)
:
remote_addr
=
RemoteAddr
(
req_data
.
pd_pair_id
,
*
(
req_data
.
zmq_address_and_comm_rank
[
i
]))
if
(
envs
.
VLLM_ENABLE_TBO
or
envs
.
VLLM_P2P_ASYNC
)
and
tbo_evt
is
not
None
:
self
.
p2p_async_send_tensor
(
request_id
+
"#"
+
layer_name
,
tensor
,
remote_addr
,
tbo_evt
)
else
:
self
.
send_tensor
(
request_id
+
"#"
+
layer_name
,
tensor
,
remote_addr
)
return
True
def
_pending_check
(
self
)
:
while
True
:
with
self
.
pending_queue_cv
:
while
not
self
.
pending_queue
:
self
.
pending_queue_cv
.
wait
()
pending_queue
=
self
.
pending_queue
.
copy
()
for
request_id
in
pending_queue
:
with
self
.
req_status_cv
:
if
request_id
not
in
self
.
req_status
:
continue
req_data
=
self
.
req_status
[
request_id
]
assert
(
len
(
req_data
.
zmq_address_and_comm_rank
)
==
req_data
.
dst_num
)
self
.
unpending_tensor
(
request_id
,
req_data
)
def
recv_tensor
(
def
recv_tensor
(
self
,
self
,
tensor_id
:
str
,
tensor_id
:
str
,
...
@@ -407,6 +403,7 @@ class P2pNcclEngine:
...
@@ -407,6 +403,7 @@ class P2pNcclEngine:
self
.
_create_connect
(
remote_address
)
self
.
_create_connect
(
remote_address
)
sock
=
self
.
socks
[
remote_address
]
sock
=
self
.
socks
[
remote_address
]
# TODO: self.comms has changed along with PUT mode
comm
,
rank
=
self
.
comms
[
remote_address
]
comm
,
rank
=
self
.
comms
[
remote_address
]
data
=
{
"cmd"
:
"GET"
,
"tensor_id"
:
tensor_id
}
data
=
{
"cmd"
:
"GET"
,
"tensor_id"
:
tensor_id
}
...
@@ -429,26 +426,23 @@ class P2pNcclEngine:
...
@@ -429,26 +426,23 @@ class P2pNcclEngine:
def
_listen_for_requests
(
self
):
def
_listen_for_requests
(
self
):
while
True
:
while
True
:
socks
=
dict
(
self
.
poller
.
poll
())
socks
=
dict
(
self
.
poller
.
poll
(
5000
))
if
self
.
router_socket
in
socks
:
if
self
.
router_socket
in
socks
:
remote_address
,
message
=
self
.
router_socket
.
recv_multipart
()
remote_address
,
message
=
self
.
router_socket
.
recv_multipart
()
data
=
msgpack
.
loads
(
message
)
data
=
msgpack
.
loads
(
message
)
if
data
[
"cmd"
]
==
"NEW"
:
if
data
[
"cmd"
]
==
"NEW"
:
unique_id
=
self
.
nccl
.
unique_id_from_bytes
(
logger
.
info
(
f
"unexpected message from
{
remote_address
.
decode
()
}
"
)
bytes
(
data
[
"unique_id"
]))
with
torch
.
cuda
.
device
(
self
.
device
):
rank
=
1
with
set_p2p_nccl_context
(
self
.
nccl_num_channels
):
comm
:
ncclComm_t
=
self
.
nccl
.
ncclCommInitRank
(
2
,
unique_id
,
rank
)
self
.
comms
[
remote_address
.
decode
()]
=
(
comm
,
rank
)
logger
.
info
(
"🤝ncclCommInitRank Success, %s👈%s, MyRank:%s"
,
self
.
zmq_address
,
remote_address
.
decode
(),
rank
)
elif
data
[
"cmd"
]
==
"PUT"
:
elif
data
[
"cmd"
]
==
"PUT"
:
tensor_id
=
data
[
"tensor_id"
]
tensor_id
=
data
[
"tensor_id"
]
if
"tensor_split_num"
in
data
:
if
"tensor_split_num"
in
data
:
self
.
tensor_split_num
=
data
[
"tensor_split_num"
]
# self.tensor_split_num = data["tensor_split_num"]
parts
=
tensor_id
.
split
(
'#'
)
request_id
=
parts
[
0
]
layer_name
=
parts
[
1
]
req_layer
=
f
"
{
request_id
}
#
{
layer_name
}
"
self
.
recv_split_nums
[
req_layer
]
=
data
[
"tensor_split_num"
]
with
self
.
recv_store_cv
:
self
.
recv_store_cv
.
notify_all
()
try
:
try
:
with
torch
.
cuda
.
stream
(
self
.
recv_stream
):
with
torch
.
cuda
.
stream
(
self
.
recv_stream
):
tensor
=
torch
.
empty
(
data
[
"shape"
],
tensor
=
torch
.
empty
(
data
[
"shape"
],
...
@@ -457,8 +451,8 @@ class P2pNcclEngine:
...
@@ -457,8 +451,8 @@ class P2pNcclEngine:
device
=
self
.
device
)
device
=
self
.
device
)
self
.
router_socket
.
send_multipart
(
self
.
router_socket
.
send_multipart
(
[
remote_address
,
b
"0"
])
[
remote_address
,
b
"0"
])
comm
,
rank
=
self
.
comms
[
remote_address
.
decode
()
]
comm
,
rank
=
self
.
comms
[
data
[
"pd_pair_id"
]
]
self
.
_recv
(
comm
,
tensor
,
rank
^
1
,
self
.
recv_stream
)
self
.
_recv
(
comm
,
tensor
,
int
(
data
[
"comm_rank"
])
,
self
.
recv_stream
)
tensor_size
=
tensor
.
element_size
()
*
tensor
.
numel
()
tensor_size
=
tensor
.
element_size
()
*
tensor
.
numel
()
if
(
self
.
buffer_size
+
tensor_size
if
(
self
.
buffer_size
+
tensor_size
>
self
.
buffer_size_threshold
):
>
self
.
buffer_size_threshold
):
...
@@ -480,7 +474,8 @@ class P2pNcclEngine:
...
@@ -480,7 +474,8 @@ class P2pNcclEngine:
with
self
.
recv_store_cv
:
with
self
.
recv_store_cv
:
self
.
recv_store
[
tensor_id
]
=
tensor
self
.
recv_store
[
tensor_id
]
=
tensor
self
.
_have_received_tensor_id
(
tensor_id
)
self
.
_have_received_tensor_id
(
tensor_id
)
self
.
recv_store_cv
.
notify
()
#self.recv_store_cv.notify()
self
.
recv_store_cv
.
notify_all
()
elif
data
[
"cmd"
]
==
"GET"
:
elif
data
[
"cmd"
]
==
"GET"
:
tensor_id
=
data
[
"tensor_id"
]
tensor_id
=
data
[
"tensor_id"
]
...
@@ -503,9 +498,35 @@ class P2pNcclEngine:
...
@@ -503,9 +498,35 @@ class P2pNcclEngine:
[
remote_address
,
msgpack
.
dumps
(
data
)])
[
remote_address
,
msgpack
.
dumps
(
data
)])
if
data
[
"ret"
]
==
0
:
if
data
[
"ret"
]
==
0
:
# TODO: self.comms has changed along with PUT mode
comm
,
rank
=
self
.
comms
[
remote_address
.
decode
()]
comm
,
rank
=
self
.
comms
[
remote_address
.
decode
()]
self
.
_send
(
comm
,
tensor
.
to
(
self
.
device
),
rank
^
1
,
self
.
_send
(
comm
,
tensor
.
to
(
self
.
device
),
rank
^
1
,
self
.
send_stream
)
self
.
send_stream
)
elif
data
[
"cmd"
]
==
"comm_init"
:
unique_id
=
self
.
nccl
.
unique_id_from_bytes
(
bytes
(
data
[
"unique_id"
]))
with
torch
.
cuda
.
device
(
self
.
device
):
rank
=
int
(
data
[
"rank"
])
world_size
=
int
(
data
[
"world_size"
])
with
set_p2p_nccl_context
(
self
.
nccl_num_channels
):
comm
:
ncclComm_t
=
self
.
nccl
.
ncclCommInitRank
(
world_size
,
unique_id
,
rank
)
self
.
comms
[
data
[
"pd_pair_id"
]]
=
(
comm
,
rank
)
logger
.
info
(
"🤝ncclCommInitRank Success, %s👈%s, MyRank:%s"
,
self
.
zmq_address
,
data
[
"pd_pair_id"
],
rank
)
elif
data
[
"cmd"
]
==
"req_to_transfer"
:
with
self
.
req_status_cv
:
assert
(
data
[
"request_id"
]
not
in
self
.
req_status
)
self
.
req_status
[
data
[
"request_id"
]]
=
ReqKVDest
(
dst_num
=
int
(
data
[
"dst_num"
]),
pd_pair_id
=
data
[
"pd_pair_id"
],
zmq_address_and_comm_rank
=
list
(
zip
(
data
[
"remote_address"
],
data
[
"remote_rank"
])))
self
.
req_status_cv
.
notify_all
()
elif
data
[
"cmd"
]
==
"req_not_to_transfer"
:
with
self
.
req_status_cv
:
self
.
req_status
[
data
[
"request_id"
]]
=
ReqKVDest
(
dst_num
=
0
)
self
.
req_status_cv
.
notify_all
()
else
:
else
:
logger
.
warning
(
logger
.
warning
(
"🚧Unexpected, Received message from %s, data:%s"
,
"🚧Unexpected, Received message from %s, data:%s"
,
...
@@ -533,7 +554,7 @@ class P2pNcclEngine:
...
@@ -533,7 +554,7 @@ class P2pNcclEngine:
else
:
else
:
tensor_id
,
remote_address
,
tensor
=
self
.
send_queue
.
popleft
()
tensor_id
,
remote_address
,
tensor
=
self
.
send_queue
.
popleft
()
if
not
self
.
send_queue
:
if
not
self
.
send_queue
:
self
.
send_queue_cv
.
notify
()
self
.
send_queue_cv
.
notify
_all
()
if
(
envs
.
VLLM_ENABLE_TBO
or
envs
.
VLLM_P2P_ASYNC
)
and
tbo_evt
is
not
None
:
if
(
envs
.
VLLM_ENABLE_TBO
or
envs
.
VLLM_P2P_ASYNC
)
and
tbo_evt
is
not
None
:
self
.
send_stream
.
wait_event
(
tbo_evt
)
self
.
send_stream
.
wait_event
(
tbo_evt
)
self
.
_send_kv_p2p_sync
(
tensor_id
,
kv_layer
,
slot_mapping
,
remote_address
)
self
.
_send_kv_p2p_sync
(
tensor_id
,
kv_layer
,
slot_mapping
,
remote_address
)
...
@@ -551,12 +572,13 @@ class P2pNcclEngine:
...
@@ -551,12 +572,13 @@ class P2pNcclEngine:
"🚧[PUT_ASYNC]It took %.3fms to wait for the send_queue"
"🚧[PUT_ASYNC]It took %.3fms to wait for the send_queue"
" to be empty, rank:%d"
,
duration
*
1000
,
self
.
rank
)
" to be empty, rank:%d"
,
duration
*
1000
,
self
.
rank
)
# TODO : support p2p async
def
_send_kv_p2p_sync
(
self
,
tensor_id
:
str
,
kv_layer
:
torch
.
Tensor
,
def
_send_kv_p2p_sync
(
self
,
tensor_id
:
str
,
kv_layer
:
torch
.
Tensor
,
slot_mapping
:
torch
.
Tensor
,
remote_address
:
str
)
->
bool
:
slot_mapping
:
torch
.
Tensor
,
remote_address
:
str
)
->
bool
:
if
remote_address
not
in
self
.
socks
:
if
remote_address
.
zmq_address
not
in
self
.
socks
:
self
.
_create_connect
(
remote_address
)
self
.
_create_connect
(
remote_address
.
zmq_address
)
sock
=
self
.
socks
[
remote_address
]
sock
=
self
.
socks
[
remote_address
.
zmq_address
]
comm
,
rank
=
self
.
comms
[
remote_address
]
comm
,
rank
=
self
.
comms
[
remote_address
.
pd_pair_id
]
is_mla
=
(
kv_layer
.
ndim
==
3
)
is_mla
=
(
kv_layer
.
ndim
==
3
)
hidden_dim
=
kv_layer
.
shape
[
-
1
]
hidden_dim
=
kv_layer
.
shape
[
-
1
]
...
@@ -600,20 +622,22 @@ class P2pNcclEngine:
...
@@ -600,20 +622,22 @@ class P2pNcclEngine:
"pack_idx"
:
pack_idx
,
"pack_idx"
:
pack_idx
,
"tensor_split_num"
:
pack_num
,
"tensor_split_num"
:
pack_num
,
"shape"
:
tx_shape
,
"shape"
:
tx_shape
,
"dtype"
:
str
(
kv_layer
.
dtype
).
replace
(
"torch."
,
""
)
"dtype"
:
str
(
kv_layer
.
dtype
).
replace
(
"torch."
,
""
),
"pd_pair_id"
:
remote_address
.
pd_pair_id
,
"comm_rank"
:
rank
}
}
sock
.
send
(
msgpack
.
dumps
(
header
))
sock
.
send
(
msgpack
.
dumps
(
header
))
response
=
sock
.
recv
()
response
=
sock
.
recv
()
if
response
!=
b
"0"
:
if
response
!=
b
"0"
:
logger
.
error
(
logger
.
error
(
"🔴Send Tensor Failed | %s 👉 %s | Rank:%s | shape:%s | size:%.4f GB | response:%s"
,
"🔴Send Tensor Failed | %s 👉 %s | Rank:%s | shape:%s | size:%.4f GB | response:%s"
,
self
.
zmq_address
,
remote_address
,
rank
,
self
.
zmq_address
,
remote_address
.
zmq_address
,
rank
,
tuple
(
send_tensor
.
shape
),
send_tensor
.
element_size
()
*
send_tensor
.
numel
()
/
1024
**
3
,
tuple
(
send_tensor
.
shape
),
send_tensor
.
element_size
()
*
send_tensor
.
numel
()
/
1024
**
3
,
response
.
decode
()
response
.
decode
()
)
)
return
False
return
False
self
.
_send
(
comm
,
send_tensor
,
r
ank
^
1
,
self
.
send_stream
)
self
.
_send
(
comm
,
send_tensor
,
r
emote_address
.
comm_rank
,
self
.
send_stream
)
if
self
.
send_type
==
"PUT_ASYNC"
:
if
self
.
send_type
==
"PUT_ASYNC"
:
self
.
_have_sent_tensor_id
(
tensor_id
)
self
.
_have_sent_tensor_id
(
tensor_id
)
...
@@ -624,20 +648,22 @@ class P2pNcclEngine:
...
@@ -624,20 +648,22 @@ class P2pNcclEngine:
self
,
self
,
tensor_id
:
str
,
tensor_id
:
str
,
tensor
:
torch
.
Tensor
,
tensor
:
torch
.
Tensor
,
remote_address
:
typing
.
Optional
[
st
r
]
=
None
,
remote_address
:
typing
.
Optional
[
RemoteAdd
r
]
=
None
,
)
->
bool
:
)
->
bool
:
if
remote_address
is
None
:
if
remote_address
is
None
:
return
False
return
False
if
remote_address
not
in
self
.
socks
:
if
remote_address
.
zmq_address
not
in
self
.
socks
:
self
.
_create_connect
(
remote_address
)
self
.
_create_connect
(
remote_address
.
zmq_address
)
sock
=
self
.
socks
[
remote_address
]
sock
=
self
.
socks
[
remote_address
.
zmq_address
]
comm
,
rank
=
self
.
comms
[
remote_address
]
comm
,
rank
=
self
.
comms
[
remote_address
.
pd_pair_id
]
data
=
{
data
=
{
"cmd"
:
"PUT"
,
"cmd"
:
"PUT"
,
"tensor_id"
:
tensor_id
,
"tensor_id"
:
tensor_id
,
"shape"
:
tensor
.
shape
,
"shape"
:
tensor
.
shape
,
"dtype"
:
str
(
tensor
.
dtype
).
replace
(
"torch."
,
""
)
"dtype"
:
str
(
tensor
.
dtype
).
replace
(
"torch."
,
""
),
"pd_pair_id"
:
remote_address
.
pd_pair_id
,
"comm_rank"
:
rank
}
}
sock
.
send
(
msgpack
.
dumps
(
data
))
sock
.
send
(
msgpack
.
dumps
(
data
))
...
@@ -646,12 +672,12 @@ class P2pNcclEngine:
...
@@ -646,12 +672,12 @@ class P2pNcclEngine:
logger
.
error
(
logger
.
error
(
"🔴Send Tensor, Peer Out Of Memory/Threshold, %s 👉 %s, "
"🔴Send Tensor, Peer Out Of Memory/Threshold, %s 👉 %s, "
"MyRank:%s, data:%s, tensor:%s, size:%fGB, response:%s"
,
"MyRank:%s, data:%s, tensor:%s, size:%fGB, response:%s"
,
self
.
zmq_address
,
remote_address
,
rank
,
data
,
tensor
.
shape
,
self
.
zmq_address
,
remote_address
.
zmq_address
,
rank
,
data
,
tensor
.
shape
,
tensor
.
element_size
()
*
tensor
.
numel
()
/
1024
**
3
,
tensor
.
element_size
()
*
tensor
.
numel
()
/
1024
**
3
,
response
.
decode
())
response
.
decode
())
return
False
return
False
self
.
_send
(
comm
,
tensor
.
to
(
self
.
device
),
r
ank
^
1
,
self
.
send_stream
)
self
.
_send
(
comm
,
tensor
.
to
(
self
.
device
),
r
emote_address
.
comm_rank
,
self
.
send_stream
)
if
self
.
send_type
==
"PUT_ASYNC"
:
if
self
.
send_type
==
"PUT_ASYNC"
:
self
.
_have_sent_tensor_id
(
tensor_id
)
self
.
_have_sent_tensor_id
(
tensor_id
)
...
@@ -673,20 +699,46 @@ class P2pNcclEngine:
...
@@ -673,20 +699,46 @@ class P2pNcclEngine:
"""
"""
# Clear the buffer upon request completion.
# Clear the buffer upon request completion.
# for request_id in finished_req_ids:
# for layer_name in forward_context.no_compile_layers:
# tensor_id = request_id + "#" + layer_name
# if tensor_id in self.recv_store:
# with self.recv_store_cv:
# tensor = self.recv_store.pop(tensor_id, None)
# self.send_request_id_to_tensor_ids.pop(
# request_id, None)
# self.recv_request_id_to_tensor_ids.pop(
# request_id, None)
# addr = 0
# if isinstance(tensor, tuple):
# addr, _, _ = tensor
# self.pool.free(addr)
requests_to_release
:
list
[
str
]
=
[]
with
self
.
pending_queue_cv
:
for
request_id
,
release
in
self
.
requests_to_release
.
items
():
if
release
:
requests_to_release
.
append
(
request_id
)
self
.
requests_to_release
.
pop
(
request_id
)
for
request_id
in
finished_req_ids
:
for
request_id
in
finished_req_ids
:
for
layer_name
in
forward_context
.
no_compile_layers
:
with
self
.
pending_queue_cv
:
tensor_id
=
request_id
+
"#"
+
layer_name
if
request_id
in
self
.
pending_queue
:
if
tensor_id
in
self
.
recv_store
:
self
.
requests_to_release
[
request_id
]
=
False
logger
.
info
(
"[%d] pending request: %s"
,
self
.
rank
,
request_id
)
continue
requests_to_release
.
append
(
request_id
)
for
request_id
in
requests_to_release
:
ids
=
self
.
recv_request_id_to_tensor_ids
.
pop
(
request_id
,
set
())
with
self
.
recv_store_cv
:
with
self
.
recv_store_cv
:
for
tensor_id
in
ids
:
tensor
=
self
.
recv_store
.
pop
(
tensor_id
,
None
)
tensor
=
self
.
recv_store
.
pop
(
tensor_id
,
None
)
self
.
send_request_id_to_tensor_ids
.
pop
(
request_id
,
None
)
self
.
recv_request_id_to_tensor_ids
.
pop
(
request_id
,
None
)
addr
=
0
if
isinstance
(
tensor
,
tuple
):
if
isinstance
(
tensor
,
tuple
):
addr
,
_
,
_
=
tensor
addr
,
_
,
_
=
tensor
self
.
pool
.
free
(
addr
)
self
.
pool
.
free
(
addr
)
self
.
send_request_id_to_tensor_ids
.
pop
(
request_id
,
None
)
# TODO:Retrieve requests that have already sent the KV cache.
# TODO:Retrieve requests that have already sent the KV cache.
finished_sending
:
set
[
str
]
=
set
()
finished_sending
:
set
[
str
]
=
set
()
...
@@ -694,19 +746,64 @@ class P2pNcclEngine:
...
@@ -694,19 +746,64 @@ class P2pNcclEngine:
# TODO:Retrieve requests that have already received the KV cache.
# TODO:Retrieve requests that have already received the KV cache.
finished_recving
:
set
[
str
]
=
set
()
finished_recving
:
set
[
str
]
=
set
()
if
self
.
kv_cache_layer_num
==
0
:
for
layer_name
in
forward_context
.
no_compile_layers
:
layer
=
forward_context
.
no_compile_layers
[
layer_name
]
kv_cache
=
getattr
(
layer
,
'kv_cache'
,
None
)
if
kv_cache
is
None
:
continue
self
.
kv_cache_layer_num
+=
1
with
self
.
recv_store_cv
:
for
req
in
self
.
recv_request_id_to_tensor_ids
:
if
len
(
self
.
recv_request_id_to_tensor_ids
[
req
])
==
self
.
kv_cache_layer_num
:
finished_recving
.
add
(
req
)
return
finished_sending
or
None
,
finished_recving
or
None
return
finished_sending
or
None
,
finished_recving
or
None
def
_ping
(
self
):
def
_ping
(
self
):
sock
=
self
.
context
.
socket
(
zmq
.
DEALER
)
sock
=
self
.
context
.
socket
(
zmq
.
DEALER
)
sock
.
setsockopt_string
(
zmq
.
IDENTITY
,
self
.
zmq_address
)
sock
.
setsockopt_string
(
zmq
.
IDENTITY
,
f
"
{
self
.
zmq_address
}
_ping"
)
logger
.
debug
(
"ping start, zmq_address:%s"
,
self
.
zmq_address
)
logger
.
debug
(
"ping start, zmq_address:%s"
,
self
.
zmq_address
)
sock
.
connect
(
f
"tcp://
{
self
.
proxy_address
}
"
)
sock
.
connect
(
f
"tcp://
{
self
.
proxy_address
}
"
)
if
self
.
rank
==
0
:
data
=
{
data
=
{
"type"
:
"P"
if
self
.
config
.
is_kv_producer
else
"D"
,
"type"
:
"P
_init
"
if
self
.
config
.
is_kv_producer
else
"D
_init
"
,
"http_address"
:
self
.
http_address
,
"http_address"
:
self
.
http_address
,
"zmq_address"
:
self
.
zmq_address
"dp_size"
:
self
.
dp_size
,
"pp_size"
:
self
.
pp_size
,
"tp_size"
:
self
.
tp_size
}
}
sock
.
send
(
msgpack
.
dumps
(
data
))
data
=
{
"type"
:
"P_rank"
if
self
.
config
.
is_kv_producer
else
"D_rank"
,
"http_address"
:
self
.
http_address
,
"dp_rank"
:
self
.
dp_rank
,
"pp_rank"
:
self
.
pp_rank
,
"tp_rank"
:
self
.
tp_rank
,
"zmq_address"
:
self
.
zmq_address
}
sock
.
send
(
msgpack
.
dumps
(
data
))
if
self
.
rank
!=
0
:
return
if
self
.
config
.
is_kv_producer
:
unique_id
=
self
.
nccl
.
ncclGetUniqueId
()
data
=
{
"type"
:
"P_unique_id"
,
"http_address"
:
self
.
http_address
,
"unique_id"
:
bytes
(
unique_id
.
internal
)
}
sock
.
send
(
msgpack
.
dumps
(
data
))
while
True
:
while
True
:
data
=
{
"type"
:
"heartbeat"
,
"http_address"
:
self
.
http_address
,
}
sock
.
send
(
msgpack
.
dumps
(
data
))
sock
.
send
(
msgpack
.
dumps
(
data
))
time
.
sleep
(
3
)
time
.
sleep
(
3
)
...
@@ -740,40 +837,6 @@ class P2pNcclEngine:
...
@@ -740,40 +837,6 @@ class P2pNcclEngine:
self
.
_listener_thread
.
join
()
self
.
_listener_thread
.
join
()
if
self
.
send_type
==
"PUT_ASYNC"
:
if
self
.
send_type
==
"PUT_ASYNC"
:
self
.
_send_thread
.
join
()
self
.
_send_thread
.
join
()
self
.
_pending_check_thread
.
join
()
if
self
.
_ping_thread
is
not
None
:
if
self
.
_ping_thread
is
not
None
:
self
.
_ping_thread
.
join
()
self
.
_ping_thread
.
join
()
def
compute_remote_pp_rank
(
self
,
layer_name
:
str
)
->
int
:
current_layer_idx
=
extract_layer_index
(
layer_name
)
for
d_pp_rank
in
range
(
self
.
remote_pp_size
):
start
,
end
=
get_pp_indices
(
self
.
total_num_hidden_layers
,
d_pp_rank
,
self
.
remote_pp_size
)
logger
.
info
(
f
"""compute_remote_pp_rank : current_layer_idx:
{
current_layer_idx
}
start:
{
start
}
end:
{
end
}
"""
)
if
(
current_layer_idx
==
self
.
total_num_hidden_layers
):
return
self
.
remote_pp_size
-
1
if
start
<=
current_layer_idx
<
end
:
return
d_pp_rank
return
-
1
@
staticmethod
def
get_tensor_id
(
request_id
:
str
,
layer_name
:
str
)
->
str
:
return
request_id
+
"#"
+
layer_name
@
staticmethod
def
parse_request_id
(
request_id
:
str
,
is_prefill
=
True
)
->
tuple
[
str
,
int
]:
# Regular expression to match the string hostname and integer port
if
is_prefill
:
pattern
=
r
"___decode_addr_(.*):(\d+)"
else
:
pattern
=
r
"___prefill_addr_(.*):(\d+)___"
# Use re.search to find the pattern in the request_id
match
=
regex
.
search
(
pattern
,
request_id
)
if
match
:
# Extract the ranks
ip
=
match
.
group
(
1
)
port
=
int
(
match
.
group
(
2
))
return
ip
,
port
raise
ValueError
(
f
"Request id
{
request_id
}
does not contain hostname and port"
)
\ No newline at end of file
vllm/v1/core/sched/scheduler.py
View file @
4612aad6
...
@@ -86,7 +86,8 @@ class Scheduler(SchedulerInterface):
...
@@ -86,7 +86,8 @@ class Scheduler(SchedulerInterface):
"Multiple KV cache groups are not currently supported "
"Multiple KV cache groups are not currently supported "
"with KV connectors"
)
"with KV connectors"
)
self
.
connector
=
KVConnectorFactory
.
create_connector_v1
(
self
.
connector
=
KVConnectorFactory
.
create_connector_v1
(
config
=
self
.
vllm_config
,
role
=
KVConnectorRole
.
SCHEDULER
)
config
=
self
.
vllm_config
,
role
=
KVConnectorRole
.
SCHEDULER
,
dp_rank
=
self
.
parallel_config
.
data_parallel_rank
)
self
.
kv_event_publisher
=
EventPublisherFactory
.
create
(
self
.
kv_event_publisher
=
EventPublisherFactory
.
create
(
self
.
kv_events_config
,
self
.
kv_events_config
,
...
@@ -371,8 +372,10 @@ class Scheduler(SchedulerInterface):
...
@@ -371,8 +372,10 @@ class Scheduler(SchedulerInterface):
break
break
request
=
self
.
waiting
.
peek_request
()
request
=
self
.
waiting
.
peek_request
()
if
request
.
is_finished
():
if
self
.
connector
and
not
self
.
connector
.
is_producer
and
request
.
request_id
not
in
self
.
finished_recving_kv_req_ids
:
self
.
waiting
.
pop_request
()
self
.
waiting
.
pop_request
()
skipped_waiting_requests
.
prepend_request
(
request
)
continue
continue
# KVTransfer: skip request if still waiting for remote kvs.
# KVTransfer: skip request if still waiting for remote kvs.
if
request
.
status
==
RequestStatus
.
WAITING_FOR_REMOTE_KVS
:
if
request
.
status
==
RequestStatus
.
WAITING_FOR_REMOTE_KVS
:
...
@@ -457,6 +460,7 @@ class Scheduler(SchedulerInterface):
...
@@ -457,6 +460,7 @@ class Scheduler(SchedulerInterface):
# pooling requests to be chunked
# pooling requests to be chunked
if
not
self
.
scheduler_config
.
chunked_prefill_enabled
and
\
if
not
self
.
scheduler_config
.
chunked_prefill_enabled
and
\
num_new_tokens
>
token_budget
:
num_new_tokens
>
token_budget
:
break
self
.
waiting
.
pop_request
()
self
.
waiting
.
pop_request
()
skipped_waiting_requests
.
prepend_request
(
request
)
skipped_waiting_requests
.
prepend_request
(
request
)
continue
continue
...
@@ -668,6 +672,11 @@ class Scheduler(SchedulerInterface):
...
@@ -668,6 +672,11 @@ class Scheduler(SchedulerInterface):
break
break
request
=
self
.
waiting
.
peek_request
()
request
=
self
.
waiting
.
peek_request
()
if
self
.
connector
and
not
self
.
connector
.
is_producer
and
request
.
request_id
not
in
self
.
finished_recving_kv_req_ids
:
self
.
waiting
.
pop_request
()
skipped_waiting_requests
.
prepend_request
(
request
)
continue
# KVTransfer: skip request if still waiting for remote kvs.
# KVTransfer: skip request if still waiting for remote kvs.
if
request
.
status
==
RequestStatus
.
WAITING_FOR_REMOTE_KVS
:
if
request
.
status
==
RequestStatus
.
WAITING_FOR_REMOTE_KVS
:
is_ready
=
self
.
_update_waiting_for_remote_kv
(
request
)
is_ready
=
self
.
_update_waiting_for_remote_kv
(
request
)
...
@@ -751,6 +760,7 @@ class Scheduler(SchedulerInterface):
...
@@ -751,6 +760,7 @@ class Scheduler(SchedulerInterface):
# pooling requests to be chunked
# pooling requests to be chunked
if
not
self
.
scheduler_config
.
chunked_prefill_enabled
and
\
if
not
self
.
scheduler_config
.
chunked_prefill_enabled
and
\
num_new_tokens
>
token_budget
:
num_new_tokens
>
token_budget
:
break
self
.
waiting
.
pop_request
()
self
.
waiting
.
pop_request
()
skipped_waiting_requests
.
prepend_request
(
request
)
skipped_waiting_requests
.
prepend_request
(
request
)
continue
continue
...
@@ -1311,7 +1321,7 @@ class Scheduler(SchedulerInterface):
...
@@ -1311,7 +1321,7 @@ class Scheduler(SchedulerInterface):
request
.
num_nans_in_logits
=
num_nans_in_logits
[
req_id
]
request
.
num_nans_in_logits
=
num_nans_in_logits
[
req_id
]
# Add newly generated spec token ids to the request.
# Add newly generated spec token ids to the request.
if
spec_token_ids
is
not
None
:
if
spec_token_ids
is
not
None
and
(
self
.
connector
is
None
or
not
self
.
connector
.
is_producer
)
:
if
self
.
structured_output_manager
.
should_advance
(
request
):
if
self
.
structured_output_manager
.
should_advance
(
request
):
metadata
=
request
.
structured_output_request
metadata
=
request
.
structured_output_request
# Needs to happen after new_token_ids are accepted.
# Needs to happen after new_token_ids are accepted.
...
...
vllm/v1/engine/core.py
View file @
4612aad6
...
@@ -763,6 +763,9 @@ class EngineCoreProc(EngineCore):
...
@@ -763,6 +763,9 @@ class EngineCoreProc(EngineCore):
# Push to input queue for core busy loop.
# Push to input queue for core busy loop.
self
.
input_queue
.
put_nowait
((
request_type
,
request
))
self
.
input_queue
.
put_nowait
((
request_type
,
request
))
if
isinstance
(
request
,
EngineCoreRequest
)
and
self
.
scheduler
.
connector
is
not
None
:
if
request_type
==
EngineCoreRequestType
.
ADD
:
self
.
scheduler
.
connector
.
register_req
(
request
.
request_id
)
def
process_output_sockets
(
self
,
output_paths
:
list
[
str
],
def
process_output_sockets
(
self
,
output_paths
:
list
[
str
],
coord_output_path
:
Optional
[
str
],
coord_output_path
:
Optional
[
str
],
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment