Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
61ba33d5
Commit
61ba33d5
authored
Apr 10, 2026
by
xuxz
Committed by
xuxz
Apr 10, 2026
Browse files
[PD][Feat]支持pd分离dp并行
parent
ce47a56e
Changes
7
Show whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
2385 additions
and
1 deletion
+2385
-1
examples/online_serving/disaggregated_serving_p2p_nccl_xpyd/disagg_proxy_p2p_nccl_xpyd_dp.py
...ed_serving_p2p_nccl_xpyd/disagg_proxy_p2p_nccl_xpyd_dp.py
+499
-0
vllm/distributed/kv_transfer/kv_connector/factory.py
vllm/distributed/kv_transfer/kv_connector/factory.py
+9
-0
vllm/distributed/kv_transfer/kv_connector/v1/du/du_swift_connector_dp.py
...d/kv_transfer/kv_connector/v1/du/du_swift_connector_dp.py
+772
-0
vllm/distributed/kv_transfer/kv_connector/v1/du/du_swift_engine_dp.py
...uted/kv_transfer/kv_connector/v1/du/du_swift_engine_dp.py
+1089
-0
vllm/envs.py
vllm/envs.py
+3
-0
vllm/v1/core/sched/scheduler.py
vllm/v1/core/sched/scheduler.py
+7
-1
vllm/v1/engine/core.py
vllm/v1/engine/core.py
+6
-0
No files found.
examples/online_serving/disaggregated_serving_p2p_nccl_xpyd/disagg_proxy_p2p_nccl_xpyd_dp.py
0 → 100644
View file @
61ba33d5
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
os
import
socket
import
threading
import
uuid
import
aiohttp
import
msgpack
import
zmq
from
typing
import
Any
from
quart
import
Quart
,
make_response
,
request
from
dataclasses
import
dataclass
,
field
from
vllm.distributed.device_communicators.pynccl_wrapper
import
NCCLLibrary
import
time
import
asyncio
from
collections
import
deque
,
defaultdict
import
logging
# Configure the root logger once at import time so every router thread logs
# with a timestamp, logger name and level.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s")
# Module-level logger used by all router functions in this file.
logger = logging.getLogger(__name__)
@dataclass
class Request:
    """Pairing record for one request seen by the router.

    A record starts with only one side filled in and is completed when the
    other side (prefill or decode) reports the same ``request_id``.
    """

    # Identifier shared by the prefill and the decode side of the request.
    request_id: str
    # HTTP address of the prefill instance handling the request ("" = unknown).
    p_http_address: str = ""
    # Data-parallel rank on the prefill instance (-1 = not reported yet).
    p_dp_rank: int = -1
    # HTTP address of the decode instance handling the request ("" = unknown).
    d_http_address: str = ""
    # Data-parallel rank on the decode instance (-1 = not reported yet).
    d_dp_rank: int = -1
@dataclass
class Instance:
    """Registration state of one prefill ("P") or decode ("D") instance.

    The router fills this in incrementally: an ``*_init`` message supplies
    the parallel sizes, and one message per rank supplies that rank's ZMQ
    address. The instance is usable once every rank has registered.
    """

    # "P" for prefill, "D" for decode.
    ins_type: str = "P"
    # "ip:port" of the instance's HTTP endpoint; also its key in the router.
    http_address: str = ""
    # ZMQ address reported by the instance's init message.
    zmq_address: str = ""
    p_unique_id: bytes = b""
    # Parallel sizes; 0 means the init message has not arrived yet.
    dp_size: int = 0
    pp_size: int = 0
    tp_size: int = 0
    # [dp, pp, tp] : zmq_address
    rank_table: dict[int, dict[int, dict[int, str]]] = field(
        default_factory=lambda: defaultdict(lambda: defaultdict(dict)))
    # [dp, pp, tp] : global rank
    comm_rank_table: dict[int, dict[int, dict[int, int]]] = field(
        default_factory=lambda: defaultdict(lambda: defaultdict(dict)))

    def count_rank_table_elements(self) -> int:
        """Return how many (dp, pp, tp) ranks have registered an address."""
        return sum(
            len(by_tp)
            for by_pp in self.rank_table.values()
            for by_tp in by_pp.values())

    def is_ready(self) -> bool:
        """Return True once the sizes are known and every rank registered.

        Always returns a real ``bool``; an instance whose sizes are still 0
        is never ready.
        """
        world_size = self.dp_size * self.pp_size * self.tp_size
        inited_rank = self.count_rank_table_elements()
        # Anything that is not a prefill instance is logged as "D".
        role = "P" if self.ins_type == "P" else "D"
        logger.info(
            "[Router] %s is_ready? : %s world_size = %s inited_rank = %s",
            role, self.http_address, world_size, inited_rank)
        return bool(world_size) and inited_rank == world_size
# Number of requests accepted so far; handle_request uses it modulo the
# instance count to pick instances round-robin.
count = 0
# Registered instances keyed by their HTTP address.
prefill_instances: dict[str, Instance] = {}
decode_instances: dict[str, Instance] = {}
# Requests for which at least one side (P or D) has reported, by request id.
running_requests: dict[str, Request] = {}
# HTTP address -> time.time() of the last "heartbeat" message.
healthy_instances: dict[str, float] = {}
# HTTP addresses of fully registered instances that still await pair setup.
pending_prefill_ins: list[str] = []
pending_decode_ins: list[str] = []
# HTTP addresses of instances that pd_pair_init has already processed.
ready_prefill_ins: list[str] = []
ready_decode_ins: list[str] = []
# "<prefill http>_<decode http>" -> NCCL unique id bytes sent to that pair.
pd_pair: dict[str, bytes] = {}
# NCCL library handle used to generate unique ids (loaded at import time).
router_nccl = NCCLLibrary()
# Guards the instance dicts and the pending lists; notified when an instance
# becomes ready.
instance_cv = threading.Condition()
# Guards running_requests.
request_cv = threading.Condition()
# Guards healthy_instances.
health_cv = threading.Condition()
# Guards request_queue; notified when a request is enqueued.
request_queue_cv = threading.Condition()
# Fully paired requests waiting to be dispatched by dp_dispatch.
request_queue: deque[Request] = deque()
# ZMQ address -> DEALER socket connected to it.
sock_cache: dict[str, Any] = {}
def _instance_for(instances, ins_type, http_address):
    """Return the Instance stored under ``http_address``, creating it if new."""
    instance = instances.get(http_address)
    if instance is None:
        instance = Instance(ins_type=ins_type, http_address=http_address)
        instances[http_address] = instance
    return instance


def _queue_if_ready(instance, pending, pending_name):
    """Queue a fully registered instance for pair setup.

    Must be called with ``instance_cv`` held; wakes ``pd_pair_init``.
    """
    if instance.is_ready():
        pending.append(instance.http_address)
        logger.info(
            f"""[Router] {pending_name} appended {instance.http_address} ZMQ: {instance.zmq_address}"""
        )
        instance_cv.notify()


def _register_rank(data, instances, pending, pending_name):
    """Handle a per-rank "P"/"D" message by recording the rank's address."""
    ins_type = data["type"]
    dp_rank, pp_rank, tp_rank = (
        data["dp_rank"], data["pp_rank"], data["tp_rank"])
    zmq_address = data["zmq_address"]
    with instance_cv:
        instance = _instance_for(instances, ins_type, data["http_address"])
        instance.rank_table[int(dp_rank)][int(pp_rank)][int(tp_rank)] = \
            zmq_address
        _queue_if_ready(instance, pending, pending_name)
        logger.info(
            f"""[Router] add {ins_type} rank [{dp_rank}, {pp_rank}, {tp_rank}] : {zmq_address}"""
        )


def _register_init(data, instances, pending, pending_name, ins_type):
    """Handle a "P_init"/"D_init" message carrying the parallel sizes."""
    http_address = data["http_address"]
    with instance_cv:
        is_new = http_address not in instances
        instance = _instance_for(instances, ins_type, http_address)
        instance.dp_size = int(data["dp_size"])
        instance.pp_size = int(data["pp_size"])
        instance.tp_size = int(data["tp_size"])
        instance.zmq_address = data["zmq_address"]
        if is_new:
            # No rank can have registered before the instance existed, so
            # there is nothing to check yet.
            return
        _queue_if_ready(instance, pending, pending_name)


def _register_request(data):
    """Handle a "Req" message; enqueue the request once P and D both reported.

    The pairing record is removed from ``running_requests`` when the request
    is enqueued, so the table does not grow without bound.
    """
    request_id = data["request_id"]
    side = data["instance_type"]
    if side not in ("P", "D"):
        logger.warning("[Router] Req %s has unknown instance_type %r",
                       request_id, side)
        return
    dp_rank = int(data["dp_rank"])
    if dp_rank < 0:
        # A negative rank could never complete the pair; report it instead
        # of leaving a record that waits forever.
        logger.error("[Router] Req %s from %s has invalid dp_rank %d",
                     request_id, side, dp_rank)
        return
    with request_cv:
        record = running_requests.get(request_id)
        if record is None:
            record = Request(request_id=request_id)
            running_requests[request_id] = record
        if side == "P":
            record.p_http_address = data["http_address"]
            record.p_dp_rank = dp_rank
        else:
            record.d_http_address = data["http_address"]
            record.d_dp_rank = dp_rank
        if record.p_dp_rank < 0 or record.d_dp_rank < 0:
            # Still waiting for the other side.
            return
        del running_requests[request_id]
        with request_queue_cv:
            request_queue.append(record)
            request_queue_cv.notify()


def _listen_for_register(poller, router_socket):
    """Receive and apply registration messages forever.

    Runs on the service-discovery thread. Each message is a msgpack dict
    whose ``"type"`` is one of ``"P"``, ``"D"``, ``"P_init"``, ``"D_init"``,
    ``"heartbeat"`` or ``"Req"``. A message that cannot be processed is
    logged and skipped so that one bad sender cannot stop discovery.

    Args:
        poller: zmq poller on which ``router_socket`` is registered.
        router_socket: bound ROUTER socket the instances send to.
    """
    while True:
        socks = dict(poller.poll())
        if router_socket not in socks:
            continue
        try:
            remote_address, message = router_socket.recv_multipart()
            # data: {"type": "P", "http_address": "ip:port",
            #        "zmq_address": "ip:port"}
            data = msgpack.loads(message)
            msg_type = data["type"]
            if msg_type == "P":
                _register_rank(data, prefill_instances, pending_prefill_ins,
                               "pending_prefill_ins")
            elif msg_type == "D":
                _register_rank(data, decode_instances, pending_decode_ins,
                               "pending_decode_ins")
            elif msg_type == "P_init":
                _register_init(data, prefill_instances, pending_prefill_ins,
                               "pending_prefill_ins", "P")
            elif msg_type == "D_init":
                _register_init(data, decode_instances, pending_decode_ins,
                               "pending_decode_ins", "D")
            elif msg_type == "heartbeat":
                with health_cv:
                    healthy_instances[data["http_address"]] = time.time()
            elif msg_type == "Req":
                _register_request(data)
            else:
                logger.warning(
                    "Unexpected, Received message from %s, data: %s",
                    remote_address, data)
        except Exception:
            # Thread boundary: log and keep listening.
            logger.exception(
                "[Router] failed to process a registration message")
# Shared zmq.Context; created by start_service_discovery.
zmq_context = None
# The three caches below are filled by dispatch_to_P and keyed by
# "<prefill http>_<decode http>_<decode dp rank>".
# prefill tp rank -> ZMQ addresses of the decode ranks it sends to.
tp_mapping_of_pd_pair: dict[str, dict[int, list[str]]] = {}
# prefill tp rank -> communicator ranks of the decode ranks it sends to.
tp_comm_mapping_of_pd_pair: dict[str, dict[int, list[int]]] = {}
# prefill tp ranks that have at least one decode destination.
active_p_tp_rank_of_pd_pair: dict[str, set[int]] = {}
def start_service_discovery(hostname, port):
    """Bind the registration ROUTER socket and start the listener thread.

    Args:
        hostname: interface to bind; a falsy value falls back to
            ``socket.gethostname()``.
        port: TCP port to bind; must not be 0.

    Returns:
        The started daemon thread running ``_listen_for_register``.

    Raises:
        ValueError: if ``port`` is 0.
    """
    global zmq_context

    bind_host = hostname or socket.gethostname()
    if port == 0:
        raise ValueError("Port cannot be 0")

    # The context is shared with the functions that open DEALER sockets.
    zmq_context = zmq.Context()
    router_socket = zmq_context.socket(zmq.ROUTER)
    router_socket.bind(f"tcp://{bind_host}:{port}")

    poller = zmq.Poller()
    poller.register(router_socket, zmq.POLLIN)

    listener = threading.Thread(
        target=_listen_for_register,
        args=[poller, router_socket],
        daemon=True,
    )
    listener.start()
    return listener
def _dealer_socket_for(address):
    """Return the cached DEALER socket for ``address``, connecting on first use."""
    sock = sock_cache.get(address)
    if sock is None:
        sock = zmq_context.socket(zmq.DEALER)
        sock.setsockopt_string(zmq.IDENTITY, "router")
        sock.connect(f"tcp://{address}")
        sock_cache[address] = sock
    return sock


def dispatch_to_P(request: Request):
    """Tell every rank of the request's prefill dp group where to send KV.

    Decode tp rank ``d`` is served by prefill tp rank
    ``d % prefill.tp_size``. Prefill tp ranks with at least one destination
    receive ``req_to_transfer``; the remaining tp ranks receive
    ``req_not_to_transfer``. The tp mapping is computed once per
    (prefill, decode, decode dp rank) and then reused.

    Args:
        request: fully paired request (both addresses and dp ranks set).

    Raises:
        ValueError: if the decode instance uses pipeline parallelism.
        KeyError: if an instance or rank referenced by the request is not
            registered.
    """
    p_ins = prefill_instances[request.p_http_address]
    d_ins = decode_instances[request.d_http_address]

    pd_pair_id = p_ins.http_address + "_" + d_ins.http_address
    p_dp_rank = request.p_dp_rank
    d_dp_rank = request.d_dp_rank
    tp_dst_id = pd_pair_id + "_" + str(d_dp_rank)

    # Explicit check instead of ``assert`` so it survives ``python -O``.
    if d_ins.pp_size != 1:
        raise ValueError(
            f"decode instance {d_ins.http_address} must have pp_size == 1, "
            f"got {d_ins.pp_size}")
    d_pp_rank = 0

    if tp_dst_id not in active_p_tp_rank_of_pd_pair:
        active_ranks = set()
        dst_addresses: dict[int, list[str]] = defaultdict(list)
        dst_comm_ranks: dict[int, list[int]] = defaultdict(list)
        for d_tp_rank in range(d_ins.tp_size):
            p_tp_rank = d_tp_rank % p_ins.tp_size
            active_ranks.add(p_tp_rank)
            dst_addresses[p_tp_rank].append(
                d_ins.rank_table[d_dp_rank][d_pp_rank][d_tp_rank])
            dst_comm_ranks[p_tp_rank].append(
                d_ins.comm_rank_table[d_dp_rank][d_pp_rank][d_tp_rank])
        tp_mapping_of_pd_pair[tp_dst_id] = dst_addresses
        tp_comm_mapping_of_pd_pair[tp_dst_id] = dst_comm_ranks
        active_p_tp_rank_of_pd_pair[tp_dst_id] = active_ranks

    p_active_tp_rank = active_p_tp_rank_of_pd_pair[tp_dst_id]
    p_tp_rank_to_dst = tp_mapping_of_pd_pair[tp_dst_id]
    p_tp_rank_to_dst_comm = tp_comm_mapping_of_pd_pair[tp_dst_id]

    # Ranks that must transfer KV to the decode side.
    for p_pp_rank in range(p_ins.pp_size):
        for p_tp_rank in p_active_tp_rank:
            address = p_ins.rank_table[p_dp_rank][p_pp_rank][p_tp_rank]
            data = {
                "cmd": "req_to_transfer",
                "request_id": request.request_id,
                "dst_num": len(p_tp_rank_to_dst[p_tp_rank]),
                "pd_pair_id": pd_pair_id,
                "remote_address": p_tp_rank_to_dst[p_tp_rank],
                "remote_rank": p_tp_rank_to_dst_comm[p_tp_rank],
            }
            _dealer_socket_for(address).send(msgpack.dumps(data))
            logger.info(
                f"""[Router] dispatch Request {request.request_id} [{p_dp_rank}, {p_pp_rank}, {p_tp_rank}] -> [{d_dp_rank}, {d_pp_rank}]"""
            )

    # Remaining tp ranks are told explicitly that they have nothing to send.
    for p_tp_rank in range(p_ins.tp_size):
        if p_tp_rank in p_active_tp_rank:
            continue
        for p_pp_rank in range(p_ins.pp_size):
            address = p_ins.rank_table[p_dp_rank][p_pp_rank][p_tp_rank]
            data = {
                "cmd": "req_not_to_transfer",
                "request_id": request.request_id,
            }
            _dealer_socket_for(address).send(msgpack.dumps(data))
def dp_dispatch():
    """Dispatch paired requests to their prefill ranks, oldest first.

    Runs forever on the dispatch thread. Requests are taken from the left
    end of ``request_queue`` (FIFO); the producer appends on the right. A
    failure while dispatching one request is logged and does not stop the
    loop.
    """
    while True:
        with request_queue_cv:
            while not request_queue:
                request_queue_cv.wait()
            # popleft(): pop() would serve the newest request first and let
            # older ones starve under load.
            next_request = request_queue.popleft()
        # Dispatch outside the lock so the listener can keep enqueueing.
        try:
            dispatch_to_P(next_request)
        except Exception:
            logger.exception("[Router] failed to dispatch Request %s",
                             next_request.request_id)
def start_dp_dispatch():
    """Start ``dp_dispatch`` on a daemon thread and return that thread."""
    worker = threading.Thread(target=dp_dispatch, daemon=True)
    worker.start()
    return worker
# Total timeout for one proxied HTTP request: 6 hours, in seconds.
AIOHTTP_TIMEOUT = aiohttp.ClientTimeout(total=6 * 60 * 60)
# Quart application that serves the proxy's HTTP routes.
app = Quart(__name__)
def random_uuid() -> str:
    """Return a random UUID4 as 32 lowercase hex characters (no dashes)."""
    fresh_id = uuid.uuid4()
    return fresh_id.hex
async def forward_request(url, data, request_id):
    """POST ``data`` as JSON to ``url`` and stream the response body.

    Args:
        url: full URL of the upstream completions endpoint.
        data: JSON-serialisable request payload.
        request_id: value sent in the ``X-Request-Id`` header.

    Yields:
        Chunks of the response body (up to 1024 bytes each) when the
        upstream answers 200. For any other status nothing is yielded and a
        warning is logged.
    """
    headers = {
        "Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}",
        "X-Request-Id": request_id,
    }
    async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session:
        async with session.post(url=url, json=data,
                                headers=headers) as response:
            if response.status != 200:
                # Make the failure visible; the caller still gets an empty
                # stream, as before.
                logger.warning(
                    "[Router] %s answered HTTP %s for request %s",
                    url, response.status, request_id)
                return
            async for chunk_bytes in response.content.iter_chunked(1024):
                yield chunk_bytes
def unique_id_dispatch(prefill_instance: Instance, decode_instance: Instance):
    """Create one NCCL communicator spanning a prefill/decode pair.

    Generates a fresh NCCL unique id and sends a ``comm_init`` command to
    every rank of both instances. Prefill ranks are numbered first, decode
    ranks continue after them, both in (dp, pp, tp) order. The assigned rank
    is stored in each instance's ``comm_rank_table``. Does nothing if the
    pair was already initialised.

    NOTE(review): ``comm_rank_table`` lives on the instance, but the ranks
    assigned here are specific to this pair. Pairing one decode instance
    with prefill instances of different sizes looks like it would overwrite
    earlier values - confirm.
    """
    pair_key = prefill_instance.http_address + "_" + decode_instance.http_address
    if pair_key in pd_pair:
        logger.info(f"""[Router] pd pair {pair_key} already exist""")
        return
    logger.info(f"""[Router] initing pd pair {pair_key}""")

    nccl_id = bytes(router_nccl.ncclGetUniqueId().internal)
    total_ranks = (
        prefill_instance.dp_size * prefill_instance.pp_size
        * prefill_instance.tp_size
        + decode_instance.dp_size * decode_instance.pp_size
        * decode_instance.tp_size
    )

    next_rank = 0
    for tag, instance in (("P", prefill_instance), ("D", decode_instance)):
        for dp_rank in range(instance.dp_size):
            for pp_rank in range(instance.pp_size):
                for tp_rank in range(instance.tp_size):
                    address = instance.rank_table[dp_rank][pp_rank][tp_rank]
                    if address not in sock_cache:
                        dealer = zmq_context.socket(zmq.DEALER)
                        dealer.setsockopt_string(zmq.IDENTITY, "router")
                        dealer.connect(f"tcp://{address}")
                        sock_cache[address] = dealer
                    payload = {
                        "cmd": "comm_init",
                        "pd_pair_id": pair_key,
                        "unique_id": nccl_id,
                        "world_size": total_ranks,
                        "rank": next_rank,
                    }
                    sock_cache[address].send(msgpack.dumps(payload))
                    instance.comm_rank_table[dp_rank][pp_rank][tp_rank] = \
                        next_rank
                    next_rank += 1
                    logger.info(
                        f"""[Router] dispatch unique_id of pd pair {pair_key} to [{tag}] [{dp_rank}, {pp_rank}, {tp_rank}]"""
                    )

    # Record the pair only after every rank has been contacted.
    pd_pair[pair_key] = nccl_id
def pd_pair_init():
    """Pair every newly ready instance with the ready instances of the other kind.

    Runs forever on its own thread. It sleeps on ``instance_cv`` until a
    pending list is non-empty, then drains the pending prefill instances
    followed by the pending decode instances. Each drained instance is
    paired via ``unique_id_dispatch`` with all ready instances of the
    opposite kind and then moved to its ready list.
    """
    while True:
        with instance_cv:
            while not pending_prefill_ins and not pending_decode_ins:
                logger.info("[Router] pd_pair_init: waiting for instance_cv")
                instance_cv.wait()
                logger.info(
                    "[Router] pd_pair_init: instance_cv finished waiting")

            while pending_prefill_ins:
                new_prefill = pending_prefill_ins[0]
                logger.info(
                    f"""[Router] pd_pair_init: processing {new_prefill} from pending_prefill_ins"""
                )
                for known_decode in ready_decode_ins:
                    unique_id_dispatch(prefill_instances[new_prefill],
                                       decode_instances[known_decode])
                # Move to the ready list only after pairing succeeded.
                ready_prefill_ins.append(new_prefill)
                pending_prefill_ins.remove(new_prefill)

            while pending_decode_ins:
                new_decode = pending_decode_ins[0]
                logger.info(
                    f"""[Router] pd_pair_init: processing {new_decode} from pending_decode_ins"""
                )
                for known_prefill in ready_prefill_ins:
                    unique_id_dispatch(prefill_instances[known_prefill],
                                       decode_instances[new_decode])
                ready_decode_ins.append(new_decode)
                pending_decode_ins.remove(new_decode)
def start_pd_pair_init():
    """Start ``pd_pair_init`` on a daemon thread and return that thread."""
    worker = threading.Thread(target=pd_pair_init, daemon=True)
    worker.start()
    return worker
# Strong references to in-flight prefill tasks. The event loop keeps only
# weak references, so an unreferenced task may be garbage-collected before
# it finishes.
_inflight_prefill_tasks: set = set()


@app.route("/v1/completions", methods=["POST"])
async def handle_request():
    """Proxy one completion request through a prefill/decode pair.

    Picks a prefill and a decode instance round-robin. The request is sent
    to the prefill instance in the background with ``max_tokens=1``, and the
    decode instance's response is streamed back to the client. Both upstream
    calls share one ``X-Request-Id``.

    Returns:
        The streamed decode response, a 503 response when no prefill or
        decode instance is registered, or a 500 response on any other error.
    """
    global count
    try:
        original_request_data = await request.get_json()
        prefill_request = original_request_data.copy()
        # change max_tokens = 1 to let it only do prefill
        prefill_request["max_tokens"] = 1

        with instance_cv:
            prefill_list = list(prefill_instances.items())
            decode_list = list(decode_instances.items())
        if not prefill_list or not decode_list:
            return await make_response(
                "No prefill/decode instance registered", 503)

        prefill_addr, prefill_instance = prefill_list[
            count % len(prefill_list)]
        decode_addr, decode_instance = decode_list[count % len(decode_list)]

        pair_key = (prefill_instance.http_address + "_" +
                    decode_instance.http_address)
        if pair_key not in pd_pair:
            raise RuntimeError("Selected PD pair was not inited")

        logger.info(
            f"handle_request count: {count}, [HTTP:{prefill_addr}, 👉 HTTP:{decode_addr}]"
        )
        count += 1

        request_id = f"{random_uuid()}"

        async def run_prefill():
            # Drain the prefill response; only its side effect matters.
            async for _ in forward_request(
                    f"http://{prefill_addr}/v1/completions",
                    prefill_request, request_id):
                pass

        prefill_task = asyncio.create_task(run_prefill())
        _inflight_prefill_tasks.add(prefill_task)
        prefill_task.add_done_callback(_inflight_prefill_tasks.discard)

        # return decode
        generator = forward_request(
            f"http://{decode_addr}/v1/completions",
            original_request_data, request_id)
        response = await make_response(generator)
        response.timeout = None
        return response
    except Exception:
        # Handler boundary: log with traceback and answer with an explicit
        # error instead of returning None.
        logger.exception("Error occurred in disagg prefill proxy server")
        return await make_response(
            "Internal error in disagg prefill proxy server", 500)
if __name__ == "__main__":
    # Start the three background daemon threads: registration listener,
    # PD pair initialiser and request dispatcher.
    t = start_service_discovery("0.0.0.0", 30001)
    t_1 = start_pd_pair_init()
    t_2 = start_dp_dispatch()
    # Serve the HTTP proxy; the joins below are reached only after
    # app.run() returns.
    app.run(host="0.0.0.0", port=10001)
    t.join()
    t_1.join()
    t_2.join()
vllm/distributed/kv_transfer/kv_connector/factory.py
View file @
61ba33d5
...
...
@@ -2,6 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
importlib
from
vllm
import
envs
from
collections.abc
import
Callable
from
typing
import
TYPE_CHECKING
,
Optional
,
cast
...
...
@@ -45,6 +46,7 @@ class KVConnectorFactory:
config
:
"VllmConfig"
,
role
:
KVConnectorRole
,
kv_cache_config
:
Optional
[
"KVCacheConfig"
]
=
None
,
dp_rank
:
int
=
-
1
,
)
->
KVConnectorBase
:
kv_transfer_config
=
config
.
kv_transfer_config
if
kv_transfer_config
is
None
:
...
...
@@ -77,6 +79,8 @@ class KVConnectorFactory:
if
compat_sig
:
# Old signature: __init__(self, vllm_config, role)
return
connector_cls
(
config
,
role
)
elif
envs
.
VLLM_USE_DP_CONNECTOR
:
return
connector_cls
(
config
,
role
,
kv_cache_config
,
dp_rank
)
else
:
# New signature: __init__(self, vllm_config, role, kv_cache_config)
return
connector_cls
(
config
,
role
,
kv_cache_config
)
...
...
@@ -160,6 +164,11 @@ KVConnectorFactory.register_connector(
"vllm.distributed.kv_transfer.kv_connector.v1.du.du_swift_connector"
,
"DuSwiftConnector"
)
KVConnectorFactory
.
register_connector
(
"DuSwiftConnectorDp"
,
"vllm.distributed.kv_transfer.kv_connector.v1.du.du_swift_connector_dp"
,
"DuSwiftConnectorDp"
)
KVConnectorFactory
.
register_connector
(
"LMCacheConnectorV1"
,
"vllm.distributed.kv_transfer.kv_connector.v1.lmcache_connector"
,
...
...
vllm/distributed/kv_transfer/kv_connector/v1/du/du_swift_connector_dp.py
0 → 100644
View file @
61ba33d5
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
dataclasses
import
dataclass
from
typing
import
TYPE_CHECKING
,
Any
,
Optional
import
regex
as
re
import
torch
import
os
from
vllm
import
envs
from
vllm.config
import
VllmConfig
from
vllm.distributed.kv_transfer.kv_connector.v1.base
import
(
KVConnectorBase_V1
,
KVConnectorMetadata
,
KVConnectorRole
)
from
vllm.distributed.kv_transfer.kv_connector.v1.du.du_swift_engine_dp
import
(
DuSwiftEngineDp
,
RemoteAddr
)
from
vllm.distributed.parallel_state
import
get_world_group
from
vllm.forward_context
import
get_forward_context
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.attention.mla_attention
import
MLACommonMetadata
from
vllm.v1.core.sched.output
import
SchedulerOutput
from
vllm.distributed.parallel_state
import
get_pp_group
,
get_tp_group
,
get_dp_group
import
zmq
import
msgpack
if
TYPE_CHECKING
:
from
vllm.v1.attention.backend
import
AttentionMetadata
from
vllm.forward_context
import
ForwardContext
from
vllm.v1.kv_cache_interface
import
KVCacheConfig
from
vllm.v1.core.kv_cache_manager
import
KVCacheBlocks
from
vllm.v1.request
import
Request
logger
=
init_logger
(
__name__
)
@dataclass
class ReqMeta:
    """Per-request metadata handed from the scheduler to the worker."""

    # Request Id
    request_id: str
    # Request tokens
    token_ids: torch.Tensor
    # Slot mappings, should have the same length as token_ids
    slot_mapping: torch.Tensor
    # Not set by make_meta; None until assigned elsewhere.
    slot_mapping_device: Optional[torch.Tensor] = None

    @staticmethod
    def make_meta(request_id: str, token_ids: list[int], block_ids: list[int],
                  block_size: int) -> "ReqMeta":
        """Build the metadata for one request.

        Token ``i`` is assigned slot
        ``block_ids[i // block_size] * block_size + i % block_size``.

        Args:
            request_id: id of the request.
            token_ids: token ids of the request.
            block_ids: KV-cache block ids allocated to the request, in order.
            block_size: number of token slots per block.

        Returns:
            A ``ReqMeta`` whose tensors are int64, including for empty
            inputs. ``slot_mapping`` has at most ``len(token_ids)`` entries,
            fewer if the blocks do not cover all tokens.
        """
        valid_num_tokens = len(token_ids)
        # Explicit dtype: torch.tensor([]) would otherwise be float32.
        token_ids_tensor = torch.tensor(token_ids, dtype=torch.long)
        block_ids_tensor = torch.tensor(block_ids, dtype=torch.long)
        num_blocks = block_ids_tensor.shape[0]
        block_offsets = torch.arange(0, block_size)
        # Broadcast (1, block_size) offsets over (num_blocks, 1) block bases.
        slot_mapping = (
            block_offsets.reshape((1, block_size))
            + block_ids_tensor.reshape((num_blocks, 1)) * block_size
        )
        slot_mapping = slot_mapping.flatten()[:valid_num_tokens]
        return ReqMeta(
            request_id=request_id,
            token_ids=token_ids_tensor,
            slot_mapping=slot_mapping,
        )
@dataclass
class DuSwiftConnectorMetadata(KVConnectorMetadata):
    """Connector metadata: the list of per-request ``ReqMeta`` entries."""

    requests: list[ReqMeta]

    def __init__(self):
        # Every instance starts with its own empty list.
        self.requests = []

    def add_request(
        self,
        request_id: str,
        token_ids: list[int],
        block_ids: list[int],
        block_size: int,
    ) -> None:
        """Build a ``ReqMeta`` from the arguments and append it to ``requests``."""
        meta = ReqMeta.make_meta(
            request_id=request_id,
            token_ids=token_ids,
            block_ids=block_ids,
            block_size=block_size,
        )
        self.requests.append(meta)
class
DuSwiftConnectorDp
(
KVConnectorBase_V1
):
def
__init__
(
self
,
vllm_config
:
"VllmConfig"
,
role
:
KVConnectorRole
,
kv_cache_config
:
Optional
[
"KVCacheConfig"
]
=
None
,
dp_rank
:
int
=
-
1
):
super
().
__init__
(
vllm_config
=
vllm_config
,
role
=
role
,
kv_cache_config
=
kv_cache_config
,)
self
.
_block_size
=
vllm_config
.
cache_config
.
block_size
self
.
_requests_need_load
:
dict
[
str
,
Any
]
=
{}
self
.
config
=
vllm_config
.
kv_transfer_config
self
.
is_producer
=
self
.
config
.
is_kv_producer
self
.
chunked_prefill
:
dict
[
str
,
Any
]
=
{}
self
.
_rank
=
get_world_group
().
rank
\
if
role
==
KVConnectorRole
.
WORKER
else
0
self
.
_local_rank
=
get_world_group
().
local_rank
\
if
role
==
KVConnectorRole
.
WORKER
else
0
self
.
_dp_rank
=
get_dp_group
().
rank_in_group
\
if
role
==
KVConnectorRole
.
WORKER
else
0
self
.
_pp_rank
=
get_pp_group
().
rank_in_group
\
if
role
==
KVConnectorRole
.
WORKER
else
0
self
.
_tp_rank
=
get_tp_group
().
rank_in_group
\
if
role
==
KVConnectorRole
.
WORKER
else
0
self
.
_dp_size
=
get_dp_group
().
world_size
\
if
role
==
KVConnectorRole
.
WORKER
else
0
self
.
_pp_size
=
get_pp_group
().
world_size
\
if
role
==
KVConnectorRole
.
WORKER
else
0
self
.
_tp_size
=
get_tp_group
().
world_size
\
if
role
==
KVConnectorRole
.
WORKER
else
0
self
.
du_swift_engine
=
DuSwiftEngineDp
(
local_rank
=
self
.
_local_rank
,
port_offset
=
self
.
_rank
,
config
=
self
.
config
,
model_config
=
vllm_config
.
model_config
,
dp_rank
=
self
.
_dp_rank
,
pp_rank
=
self
.
_pp_rank
,
tp_rank
=
self
.
_tp_rank
,
dp_size
=
self
.
_dp_size
,
pp_size
=
self
.
_pp_size
,
tp_size
=
self
.
_tp_size
)
if
role
==
KVConnectorRole
.
WORKER
else
None
self
.
parallel_config
=
vllm_config
.
parallel_config
self
.
model_config
=
vllm_config
.
model_config
self
.
total_num_hidden_layers
=
getattr
(
self
.
model_config
.
hf_text_config
,
"num_hidden_layers"
,
0
)
self
.
pp_size
=
self
.
parallel_config
.
pipeline_parallel_size
self
.
tp_size
=
self
.
parallel_config
.
tensor_parallel_size
self
.
num_card
=
self
.
pp_size
*
self
.
tp_size
self
.
remote_tp_size
=
self
.
config
.
get_from_extra_config
(
"remote_tp_size"
,
self
.
tp_size
)
self
.
remote_pp_size
=
self
.
config
.
get_from_extra_config
(
"remote_pp_size"
,
self
.
pp_size
)
self
.
enable_asymmetric_p2p
=
self
.
config
.
get_from_extra_config
(
"enable_asymmetric_p2p"
,
False
)
self
.
remote_num_card
=
self
.
remote_tp_size
*
self
.
remote_pp_size
self
.
multiple_machines_d
=
1
if
self
.
remote_num_card
>
8
else
0
self
.
multiple_machines_p
=
1
if
self
.
num_card
>
8
else
0
if
self
.
is_producer
and
self
.
multiple_machines_p
==
1
:
self
.
ip_map
=
{}
self
.
duplicate_keys
=
[]
config_file
=
os
.
getenv
(
'IP_CONFIG_FILE'
)
if
not
config_file
:
print
(
"Warning: Please set the IPVNet FILE environment variable for cross machine recognition of the second IP address"
)
return
try
:
with
open
(
config_file
,
'r'
,
encoding
=
'utf-8'
)
as
file
:
for
line_num
,
line
in
enumerate
(
file
,
1
):
line
=
line
.
strip
()
if
line
and
not
line
.
startswith
(
'#'
):
ips
=
line
.
split
()
if
len
(
ips
)
==
2
:
first_ip
,
second_ip
=
ips
if
first_ip
not
in
self
.
ip_map
:
self
.
ip_map
[
first_ip
]
=
second_ip
else
:
print
(
f
"warning: num
{
line_num
}
Incorrect format :
{
line
}
"
)
except
Exception
as
e
:
print
(
f
"Error: Exception occurred while reading configuration file -
{
e
}
"
)
if
role
==
KVConnectorRole
.
SCHEDULER
:
self
.
dp_rank
=
dp_rank
proxy_ip
=
self
.
config
.
get_from_extra_config
(
"proxy_ip"
,
""
)
proxy_port
=
self
.
config
.
get_from_extra_config
(
"proxy_port"
,
""
)
if
proxy_ip
==
""
or
proxy_port
==
""
:
self
.
proxy_address
=
""
else
:
self
.
proxy_address
=
proxy_ip
+
":"
+
proxy_port
self
.
http_address
=
(
f
"
{
self
.
config
.
kv_connector_extra_config
[
'instance_ip'
]
}
:"
f
"
{
self
.
config
.
kv_connector_extra_config
[
'http_port'
]
}
"
)
self
.
context
=
zmq
.
Context
()
req_sock
=
self
.
context
.
socket
(
zmq
.
DEALER
)
req_sock
.
setsockopt_string
(
zmq
.
IDENTITY
,
f
"
{
self
.
http_address
}
_rank
{
self
.
dp_rank
}
"
)
req_sock
.
connect
(
f
"tcp://
{
self
.
proxy_address
}
"
)
self
.
req_sock
=
req_sock
def
get_ip_value
(
self
,
key
):
return
self
.
ip_map
.
get
(
key
)
def
register_req
(
self
,
request_id
:
str
)
:
data
=
{
"type"
:
"Req"
,
"instance_type"
:
"P"
if
self
.
config
.
is_kv_producer
else
"D"
,
"http_address"
:
self
.
http_address
,
"request_id"
:
request_id
,
"dp_rank"
:
self
.
dp_rank
}
self
.
req_sock
.
send
(
msgpack
.
dumps
(
data
))
# ==============================
# Worker-side methods
# ==============================
def start_load_kv(self, forward_context: "ForwardContext",
                  **kwargs) -> None:
    """Start loading the KV cache from the connector buffer to vLLM's
    paged KV buffer.

    Runs only on the consumer (decode) side. For every request in the
    connector metadata and every entry of
    ``forward_context.no_compile_layers`` that exposes a ``kv_cache``
    attribute, the tensor obtained from ``du_swift_engine.recv_tensor`` is
    written into that layer's paged cache at ``request.slot_mapping``.
    After a tensor has been written, its entry is removed from the
    engine's ``recv_store`` and, if that entry was a tuple, its first
    element is handed to ``du_swift_engine.pool.free``.

    Args:
        forward_context (ForwardContext): the forward context.
        **kwargs: additional arguments for the load operation (not read
            by this method).

    Note:
        The number of elements in kv_caches and layer_names should be
        the same.
    """
    # Only consumer/decode loads KV Cache
    if self.is_producer:
        return

    assert self.du_swift_engine is not None

    attn_metadata = forward_context.attn_metadata
    if attn_metadata is None:
        return

    def inject_kv_into_layer(
        dst_kv_cache_layer: torch.Tensor,
        src_kv_cache: torch.Tensor,
        slot_mapping: torch.Tensor,
        request_id: str,
    ) -> None:
        """Inject the KV cache into the layer.

        Args:
            dst_kv_cache_layer (torch.Tensor): the destination KV cache
                layer. In shape [2, num_pages, page_size, xxx] if not
                using MLA, [num_pages, page_size, xxx] otherwise.
            src_kv_cache (torch.Tensor): the source KV cache. In shape
                [2, num_tokens, xxx] if not using MLA, [num_tokens, xxx]
                otherwise.
            slot_mapping (torch.Tensor): the slot mapping. In shape
                [num_tokens].
            request_id (str): request id for log
        """
        dst_kv_cache_layer_shape = dst_kv_cache_layer.shape
        # MLA layout is selected when the metadata (or every value of the
        # metadata mapping) is MLACommonMetadata, or when the destination
        # tensor is 3-D.
        # NOTE(review): `.values()` is called whenever attn_metadata itself
        # is not MLACommonMetadata, so it is presumably a dict in that
        # case -- confirm against the caller.
        if isinstance(attn_metadata, MLACommonMetadata) or all(
                isinstance(value, MLACommonMetadata)
                for value in attn_metadata.values()
        ) or dst_kv_cache_layer.ndim == 3:
            num_pages = dst_kv_cache_layer_shape[0]
            page_size = dst_kv_cache_layer_shape[1]
            # Flatten (num_pages, page_size) into one slot axis so that
            # slot_mapping can index it directly.
            dst_kv_cache_layer = dst_kv_cache_layer.reshape(
                num_pages * page_size, -1)
            self.check_tensors_except_dim(dst_kv_cache_layer, src_kv_cache,
                                          0)
            num_token = src_kv_cache.shape[0]
            if len(slot_mapping) == num_token:
                dst_kv_cache_layer[slot_mapping, ...] = src_kv_cache
            else:
                # Fewer/more slots than received tokens: write only the
                # first num_token slots and report the mismatch.
                dst_kv_cache_layer[slot_mapping[:num_token],
                                   ...] = src_kv_cache
                logger.warning(
                    "🚧src_kv_cache does not match, num_slot:%d, "
                    "num_token:%d, request_id:%s", len(slot_mapping),
                    num_token, request_id)
            # NOTE(review): the reshape result is discarded, so this
            # statement has no effect on the caller's tensor.
            dst_kv_cache_layer.reshape(dst_kv_cache_layer_shape)
        else:
            num_pages = dst_kv_cache_layer_shape[1]
            page_size = dst_kv_cache_layer_shape[2]
            # Keep the leading axis of size 2 and flatten the page axes.
            dst_kv_cache_layer = dst_kv_cache_layer.reshape(
                2, num_pages * page_size, -1)
            self.check_tensors_except_dim(dst_kv_cache_layer, src_kv_cache,
                                          1)
            num_token = src_kv_cache.shape[1]
            if len(slot_mapping) == num_token:
                dst_kv_cache_layer[:, slot_mapping, ...] = src_kv_cache
            else:
                dst_kv_cache_layer[:, slot_mapping[:num_token],
                                   ...] = src_kv_cache
                logger.warning(
                    "🚧src_kv_cache does not match, num_slot:%d, "
                    "num_token:%d, request_id:%s", len(slot_mapping),
                    num_token, request_id)
            # NOTE(review): result discarded, as in the MLA branch.
            dst_kv_cache_layer.reshape(dst_kv_cache_layer_shape)

    # Get the metadata
    metadata: KVConnectorMetadata = \
        self._get_connector_metadata()
    assert isinstance(metadata, DuSwiftConnectorMetadata)

    # NOTE(review): with assertions enabled the isinstance check above
    # already rejects None, so this guard cannot trigger.
    if metadata is None:
        return

    # Load the KV for each request each layer
    for request in metadata.requests:
        for layer_name in forward_context.no_compile_layers:
            layer = forward_context.no_compile_layers[layer_name]

            # Only process layers that have kv_cache
            # attribute (attention layers) Skip non-attention
            # layers like FusedMoE
            kv_cache = getattr(layer, 'kv_cache', None)
            if kv_cache is None:
                continue

            kv_cache_layer = kv_cache[ \
                forward_context.virtual_engine]

            if not envs.VLLM_P2P_ASYNC:
                # Synchronous path: one tensor per (request, layer).
                kv_cache = self.du_swift_engine.recv_tensor(
                    request.request_id + "#" + layer_name)

                if kv_cache is None:
                    logger.warning("🚧src_kv_cache is None, %s",
                                   request.request_id)
                    continue

                inject_kv_into_layer(kv_cache_layer, kv_cache,
                                     request.slot_mapping,
                                     request.request_id)

                # Drop the engine's bookkeeping for the tensor that was
                # just written.
                # NOTE(review): recv_store is read and modified here
                # without taking recv_store_cv -- confirm that is safe.
                tensor_id = request.request_id + "#" + layer_name
                if tensor_id in self.du_swift_engine.recv_store:
                    tensor = self.du_swift_engine.recv_store.pop(
                        tensor_id, None)
                    self.du_swift_engine.send_request_id_to_tensor_ids.pop(
                        request.request_id, None)
                    self.du_swift_engine.recv_request_id_to_tensor_ids.pop(
                        request.request_id, None)
                    # NOTE(review): nesting of the statements below was
                    # reconstructed; if they actually sit outside this
                    # `if`, `tensor` can be stale or unbound -- confirm.
                    addr = 0
                    # A tuple entry is unpacked as (addr, _, _); its
                    # address is handed back to the engine's pool.
                    if isinstance(tensor, tuple):
                        addr, _, _ = tensor
                        self.du_swift_engine.pool.free(addr)
            else:
                # Async path: the layer's KV arrives as
                # `tensor_split_num` pieces, keyed "<request>#<layer>#<i>".
                dst_kv_cache_layer_shape = kv_cache_layer.shape
                if isinstance(attn_metadata, MLACommonMetadata) or all(
                        isinstance(value, MLACommonMetadata)
                        for value in attn_metadata.values()):
                    num_pages = dst_kv_cache_layer_shape[0]
                    page_size = dst_kv_cache_layer_shape[1]
                    assert kv_cache_layer.is_contiguous()
                    dst_kv_cache_layer = kv_cache_layer.reshape(
                        num_pages * page_size, -1)
                else:
                    num_pages = dst_kv_cache_layer_shape[1]
                    page_size = dst_kv_cache_layer_shape[2]
                    assert kv_cache_layer.is_contiguous()
                    dst_kv_cache_layer = kv_cache_layer.reshape(
                        2, num_pages * page_size, -1)

                # Offset into request.slot_mapping of the next piece.
                inject_start_index = 0
                for num in range(self.du_swift_engine.tensor_split_num):
                    kv_cache = self.du_swift_engine.recv_tensor(
                        request.request_id + "#" + layer_name + "#" +
                        str(num))

                    if kv_cache is None:
                        # NOTE(review): a missing piece is skipped without
                        # advancing inject_start_index, so later pieces
                        # are written at the missing piece's slots.
                        logger.warning("🚧src_kv_cache is None, %s",
                                       request.request_id)
                        continue

                    if isinstance(attn_metadata,
                                  MLACommonMetadata) or all(
                                      isinstance(value, MLACommonMetadata)
                                      for value in attn_metadata.values()):
                        num_token = kv_cache.shape[0]
                        if len(request.slot_mapping) == num_token:
                            dst_kv_cache_layer[request.slot_mapping,
                                               ...] = kv_cache
                        else:
                            dst_kv_cache_layer[request.slot_mapping[
                                inject_start_index:inject_start_index +
                                num_token], ...] = kv_cache
                    else:
                        num_token = kv_cache.shape[1]
                        if len(request.slot_mapping) == num_token:
                            dst_kv_cache_layer[:, request.slot_mapping,
                                               ...] = kv_cache
                        else:
                            dst_kv_cache_layer[:, request.slot_mapping[
                                inject_start_index:inject_start_index +
                                num_token], ...] = kv_cache
                    inject_start_index += num_token
                    # inject_kv_into_layer(kv_cache_layer, kv_cache,
                    #     request.slot_mapping, request.request_id)

                    # Same bookkeeping cleanup as the synchronous path,
                    # per received piece.
                    tensor_id = request.request_id + "#" + layer_name + \
                        "#" + str(num)
                    if tensor_id in self.du_swift_engine.recv_store:
                        tensor = self.du_swift_engine.recv_store.pop(
                            tensor_id, None)
                        self.du_swift_engine.send_request_id_to_tensor_ids.pop(
                            request.request_id, None)
                        self.du_swift_engine.recv_request_id_to_tensor_ids.pop(
                            request.request_id, None)
                        addr = 0
                        if isinstance(tensor, tuple):
                            addr, _, _ = tensor
                            self.du_swift_engine.pool.free(addr)
                # NOTE(review): result discarded; no effect.
                dst_kv_cache_layer.reshape(dst_kv_cache_layer_shape)
def wait_for_layer_load(self, layer_name: str) -> None:
    """Block until the KV for one layer is loaded into the paged buffer.

    This hook exists for layer-by-layer pipelining. In this connector it
    does nothing and returns immediately.

    Args:
        layer_name: the name of that layer
    """
    return None
def save_kv_layer(self, layer_name: str, kv_layer: torch.Tensor,
                  attn_metadata: "AttentionMetadata", **kwargs) -> None:
    """Start saving the KV cache of the layer from vLLM's paged buffer
    to the connector.

    Runs only on the producer (prefill) side. Two paths exist:

    * ``envs.VLLM_ENABLE_TBO`` or ``envs.VLLM_P2P_ASYNC`` set: the whole
      paged layer plus the request's slot mapping is handed to
      ``du_swift_engine.p2p_async_send_tensor``; the destination address
      is derived from the request id and ``pp_size``.
    * otherwise: the request's slots are gathered from the layer and sent
      to every destination recorded for the request in
      ``du_swift_engine.req_status``. If the request has no entry there
      yet, the tensor is handed to ``du_swift_engine.pending_tensor``.

    Args:
        layer_name (str): the name of the layer.
        kv_layer (torch.Tensor): the paged KV buffer of the current
            layer in vLLM.
        attn_metadata (AttentionMetadata): the attention metadata.
        **kwargs: additional arguments for the save operation (not read
            by this method).
    """
    # Only producer/prefill saves KV Cache
    if not self.is_producer:
        return

    assert self.du_swift_engine is not None

    # NOTE(review): `is_mla` is not read anywhere below; only the
    # commented-out legacy routing used it.
    is_mla = isinstance(attn_metadata, MLACommonMetadata)

    def extract_kv_from_layer(
        layer: torch.Tensor,
        slot_mapping: torch.Tensor,
    ) -> torch.Tensor:
        """Extract the KV cache from the layer.

        Assume the shape of the layer is (2, num_pages, page_size, xxx)
        if MLA is not used, and (num_pages, page_size, xxx) otherwise.
        """
        # The layout test uses the enclosing `kv_layer`, not the `layer`
        # argument; the only call below passes kv_layer as `layer`.
        if isinstance(attn_metadata,
                      MLACommonMetadata) or kv_layer.ndim == 3:
            num_pages, page_size = layer.shape[0], layer.shape[1]
            return layer.reshape(num_pages * page_size, -1)[slot_mapping,
                                                            ...]
        num_pages, page_size = layer.shape[1], layer.shape[2]
        return layer.reshape(2, num_pages * page_size, -1)[:, slot_mapping,
                                                           ...]

    connector_metadata = self._get_connector_metadata()
    assert isinstance(connector_metadata, DuSwiftConnectorMetadata)

    if envs.VLLM_ENABLE_TBO or envs.VLLM_P2P_ASYNC:
        for request in connector_metadata.requests:
            request_id = request.request_id
            # The decode-side address is embedded in the request id; each
            # rank targets base port + its own rank.
            ip, port = self.parse_request_id(request_id, True)
            remote_address = ip + ":" + str(port + self._rank)
            slot_mapping = request.slot_mapping
            # Copy the slot mapping to the layer's device once per request
            # and cache it on the request object for later layers.
            if request.slot_mapping_device is None:
                request.slot_mapping_device = \
                    request.slot_mapping.pin_memory().to(
                        device=kv_layer.device, non_blocking=True)
            slot_mapping = request.slot_mapping_device
            # Event recorded on the current stream and passed to the
            # engine together with the tensors.
            tbo_evt = torch.cuda.Event(enable_timing=False)
            tbo_evt.record()
            pp_rank = (self.parallel_config.rank //
                       self.parallel_config.tensor_parallel_size) % \
                self.parallel_config.pipeline_parallel_size
            if (self.pp_size == 1):
                self.du_swift_engine.p2p_async_send_tensor(
                    request_id + "#" + layer_name,
                    (kv_layer, slot_mapping), remote_address, tbo_evt)
            elif (self.pp_size == 2):
                # Each rank sends to two destinations: its own port and
                # the port 4 above (pp_rank 0) or 4 below (otherwise).
                # NOTE(review): the offset 4 presumably encodes a tp=4
                # layout -- confirm.
                if (pp_rank == 0):
                    self.du_swift_engine.p2p_async_send_tensor(
                        request_id + "#" + layer_name,
                        (kv_layer, slot_mapping), remote_address, tbo_evt)
                    self.du_swift_engine.p2p_async_send_tensor(
                        request_id + "#" + layer_name,
                        (kv_layer, slot_mapping),
                        ip + ":" + str(port + self._rank + 4), tbo_evt)
                else:
                    self.du_swift_engine.p2p_async_send_tensor(
                        request_id + "#" + layer_name,
                        (kv_layer, slot_mapping), remote_address, tbo_evt)
                    self.du_swift_engine.p2p_async_send_tensor(
                        request_id + "#" + layer_name,
                        (kv_layer, slot_mapping),
                        ip + ":" + str(port + self._rank - 4), tbo_evt)
            elif (self.pp_size == 8):
                # Send to the eight ports port .. port + 7.
                for i in range(8):
                    self.du_swift_engine.p2p_async_send_tensor(
                        request_id + "#" + layer_name,
                        (kv_layer, slot_mapping), ip + ":" + str(port + i),
                        tbo_evt)
            else:
                # NOTE(review): unsupported pp_size is only printed to
                # stdout (not logged) and nothing is sent.
                print("Error: only suppprt pp1 pp2 pp8!!!!!!")
    else:
        for request in connector_metadata.requests:
            request_id = request.request_id
            # NOTE: destination routing used to be computed here from the
            # request id, pp_size and the multi-machine flags; that logic
            # was left as commented-out code and has been replaced by the
            # req_status lookup below.
            kv_cache = extract_kv_from_layer(kv_layer, request.slot_mapping)

            # Check under the lock whether destinations for this request
            # are already known.
            pending = False
            with self.du_swift_engine.req_status_cv:
                if request_id not in self.du_swift_engine.req_status:
                    pending = True
            if pending:
                # No destination recorded yet: hand the tensor to the
                # engine's pending path.
                self.du_swift_engine.pending_tensor(request_id, layer_name,
                                                    kv_cache)
                logger.info("[%d] pending for request: %s layer: %s",
                            self._rank, request_id, layer_name)
            else:
                # NOTE(review): req_status is read here without holding
                # req_status_cv -- confirm entries cannot be removed
                # concurrently.
                req_data = self.du_swift_engine.req_status[request_id]
                assert (req_data.dst_num == len(
                    req_data.zmq_address_and_comm_rank))
                # One send per recorded (zmq_address, comm_rank) pair.
                for i in range(req_data.dst_num):
                    remote_addr = RemoteAddr(
                        req_data.pd_pair_id,
                        *(req_data.zmq_address_and_comm_rank[i]))
                    self.du_swift_engine.send_tensor(
                        request_id + "#" + layer_name, kv_cache,
                        remote_addr)
def wait_for_save(self):
    """Do nothing; saving is not awaited here.

    A producer-side wait on ``du_swift_engine.wait_for_sent()`` used to
    live here as commented-out code and is currently disabled.
    """
    return None
def get_finished(
        self, finished_req_ids: set[str],
        **kwargs) -> tuple[Optional[set[str]], Optional[set[str]]]:
    """
    Notifies worker-side connector ids of requests that have
    finished generating tokens.

    The call is forwarded to ``du_swift_engine.get_finished`` together
    with the current forward context.

    Returns:
        ids of requests that have finished asynchronous transfer,
        tuple of (sending/saving ids, recving/loading ids).
        The finished saves/sends req ids must belong to a set provided in a
        call to this method (this call or a prior one).
    """
    assert self.du_swift_engine is not None

    current_context = get_forward_context()
    engine = self.du_swift_engine
    return engine.get_finished(finished_req_ids, current_context)
# ==============================
# Scheduler-side methods
# ==============================
def get_num_new_matched_tokens(
    self,
    request: "Request",
    num_computed_tokens: int,
) -> tuple[int, bool]:
    """
    Get number of new tokens that can be loaded from the
    external KV cache beyond the num_computed_tokens.

    A producer never loads, so it reports 0. A consumer reports every
    prompt token except the last one, minus what is already computed,
    floored at 0.

    Args:
        request (Request): the request object.
        num_computed_tokens (int): the number of locally
            computed tokens for this request

    Returns:
        A pair ``(num_tokens, False)``: the number of tokens that can be
        loaded from the external KV cache beyond what is already
        computed, and a flag that is always ``False`` here.
    """
    if self.is_producer:
        return 0, False

    remaining = len(request.prompt_token_ids) - 1 - num_computed_tokens
    return max(remaining, 0), False
def update_state_after_alloc(self, request: "Request",
                             blocks: "KVCacheBlocks",
                             num_external_tokens: int):
    """
    Update KVConnector state after block allocation.

    On a consumer with at least one external token to load, remember the
    request together with the first entry of ``blocks.get_block_ids()``
    in ``_requests_need_load``. Otherwise nothing is recorded.
    """
    if self.is_producer or num_external_tokens <= 0:
        return

    first_group_block_ids = blocks.get_block_ids()[0]
    self._requests_need_load[request.request_id] = (request,
                                                    first_group_block_ids)
def build_connector_meta(
    self,
    scheduler_output: SchedulerOutput,
) -> KVConnectorMetadata:
    """Build the connector metadata for this step.

    This function should NOT modify any fields in the scheduler_output.
    Also, calling this function will reset the state of the connector.

    Producer side: a request is added to the metadata only in the step
    whose scheduled tokens complete its prompt. Until then its block ids
    and prompt token ids are accumulated in ``self.chunked_prefill``.

    Consumer side: new requests recorded in ``self._requests_need_load``
    are added, as are cached requests that were resumed from preemption
    and are recorded there.

    Args:
        scheduler_output (SchedulerOutput): the scheduler output object.

    Returns:
        A ``DuSwiftConnectorMetadata`` holding the requests selected for
        this step.
    """
    meta = DuSwiftConnectorMetadata()

    for new_req in scheduler_output.scheduled_new_reqs:
        if self.is_producer:
            num_scheduled_tokens = (
                scheduler_output.num_scheduled_tokens)[new_req.req_id]
            # Tokens covered once this step has run.
            num_tokens = num_scheduled_tokens + new_req.num_computed_tokens
            # the request's prompt is chunked prefill
            if num_tokens < len(new_req.prompt_token_ids):
                # 'CachedRequestData' has no attribute 'prompt_token_ids'
                # so the prompt is stashed here for the later chunks.
                self.chunked_prefill[new_req.req_id] = (
                    new_req.block_ids[0], new_req.prompt_token_ids)
                continue
            # the request's prompt is not chunked prefill
            meta.add_request(request_id=new_req.req_id,
                             token_ids=new_req.prompt_token_ids,
                             block_ids=new_req.block_ids[0],
                             block_size=self._block_size)
            continue
        if new_req.req_id in self._requests_need_load:
            meta.add_request(request_id=new_req.req_id,
                             token_ids=new_req.prompt_token_ids,
                             block_ids=new_req.block_ids[0],
                             block_size=self._block_size)
            self._requests_need_load.pop(new_req.req_id)

    cached_reqs = scheduler_output.scheduled_cached_reqs
    for i, req_id in enumerate(cached_reqs.req_ids):
        num_computed_tokens = cached_reqs.num_computed_tokens[i]
        new_block_ids = cached_reqs.new_block_ids[i]
        resumed_from_preemption = req_id in cached_reqs.resumed_req_ids

        if self.is_producer:
            num_scheduled_tokens = (
                scheduler_output.num_scheduled_tokens)[req_id]
            num_tokens = (num_scheduled_tokens + num_computed_tokens)
            # assert req_id in self.chunked_prefill
            # Cached requests that are not mid chunked-prefill are
            # ignored on the producer.
            if req_id not in self.chunked_prefill:
                continue
            # NOTE(review): assumes new_block_ids is subscriptable (not
            # None) for every cached request reaching this point --
            # confirm against the scheduler.
            block_ids = new_block_ids[0]
            if not resumed_from_preemption:
                # Not resumed: new_block_ids holds only this step's
                # blocks, so prepend the ones accumulated so far.
                block_ids = (self.chunked_prefill[req_id][0] + block_ids)
            prompt_token_ids = self.chunked_prefill[req_id][1]
            # the request's prompt is chunked prefill again
            if num_tokens < len(prompt_token_ids):
                self.chunked_prefill[req_id] = (block_ids,
                                                prompt_token_ids)
                continue
            # the request's prompt is all prefilled finally
            meta.add_request(request_id=req_id,
                             token_ids=prompt_token_ids,
                             block_ids=block_ids,
                             block_size=self._block_size)
            self.chunked_prefill.pop(req_id, None)
            continue

        # NOTE(rob): here we rely on the resumed requests being
        # the first N requests in the list scheduled_cache_reqs.
        if not resumed_from_preemption:
            break
        if req_id in self._requests_need_load:
            request, _ = self._requests_need_load.pop(req_id)
            total_tokens = num_computed_tokens + 1
            token_ids = request.all_token_ids[:total_tokens]

            # NOTE(rob): For resumed req, new_block_ids is all
            # of the block_ids for the request.
            block_ids = new_block_ids[0]

            meta.add_request(request_id=req_id,
                             token_ids=token_ids,
                             block_ids=block_ids,
                             block_size=self._block_size)

    # Requests loaded asynchronously are not in the scheduler_output.
    # (Adding the remaining _requests_need_load entries to the metadata
    # is disabled; they are dropped by the clear() below.)
    self._requests_need_load.clear()
    return meta
def request_finished(
    self,
    request: "Request",
    block_ids: list[int],
) -> tuple[bool, Optional[dict[str, Any]]]:
    """
    Called when a request has finished, before its blocks are freed.

    Any chunked-prefill state kept for the request is discarded.

    Returns:
        ``(False, None)``. The first element would be True if the request
        were being saved/sent asynchronously and its blocks had to stay
        allocated until the request_id is returned from get_finished().
        The second element is the optional KVTransferParams to be
        included in the request outputs returned by the engine.
    """
    finished_id = request.request_id
    self.chunked_prefill.pop(finished_id, None)

    return (False, None)
# ==============================
# Static methods
# ==============================
@
staticmethod
def
parse_request_id
(
request_id
:
str
,
is_prefill
=
True
)
->
tuple
[
str
,
int
]:
# Regular expression to match the string hostname and integer port
if
is_prefill
:
pattern
=
r
"___decode_addr_(.*):(\d+)"
else
:
pattern
=
r
"___prefill_addr_(.*):(\d+)___"
# Use re.search to find the pattern in the request_id
match
=
re
.
search
(
pattern
,
request_id
)
if
match
:
# Extract the ranks
ip
=
match
.
group
(
1
)
port
=
int
(
match
.
group
(
2
))
return
ip
,
port
raise
ValueError
(
f
"Request id
{
request_id
}
does not contain hostname and port"
)
@
staticmethod
def
check_tensors_except_dim
(
tensor1
,
tensor2
,
dim
):
shape1
=
tensor1
.
size
()
shape2
=
tensor2
.
size
()
if
len
(
shape1
)
!=
len
(
shape2
)
or
not
all
(
s1
==
s2
for
i
,
(
s1
,
s2
)
in
enumerate
(
zip
(
shape1
,
shape2
))
if
i
!=
dim
):
raise
NotImplementedError
(
"Currently, only symmetric TP is supported. Asymmetric TP, PP,"
"and others will be supported in future PRs."
)
vllm/distributed/kv_transfer/kv_connector/v1/du/du_swift_engine_dp.py
0 → 100644
View file @
61ba33d5
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
logging
import
os
import
threading
import
time
import
typing
from
collections
import
deque
,
defaultdict
from
contextlib
import
contextmanager
from
typing
import
TYPE_CHECKING
,
Any
,
Optional
import
msgpack
import
torch
import
zmq
import
regex
from
vllm.config
import
KVTransferConfig
from
vllm.distributed.device_communicators.pynccl_wrapper
import
(
NCCLLibrary
,
buffer_type
,
cudaStream_t
,
ncclComm_t
,
ncclDataTypeEnum
)
from
vllm.distributed.kv_transfer.kv_connector.v1.du.tensor_memory_pool
import
(
# noqa: E501
TensorMemoryPool
)
from
vllm.utils.torch_utils
import
current_stream
from
vllm.utils.network_utils
import
get_ip
from
vllm
import
envs
from
vllm.distributed.parallel_state
import
get_pp_group
,
get_tp_group
from
dataclasses
import
dataclass
,
field
from
vllm.model_executor.models.utils
import
extract_layer_index
from
vllm.distributed.utils
import
get_pp_indices
from
vllm.config
import
ModelConfig
if
TYPE_CHECKING
:
from
vllm.forward_context
import
ForwardContext
logger
=
logging
.
getLogger
(
__name__
)
# Default size, in GiB, of the engine's TensorMemoryPool; used when the
# "mem_pool_size_gb" key is absent from the connector's extra config.
DEFAULT_MEM_POOL_SIZE_GB = 32
# @dataclass
# class SendQueueItem:
# tensor_id: str
# remote_address: str
# tensor: torch.Tensor
@contextmanager
def set_du_swift_context(num_channels: str):
    """Temporarily override NCCL environment variables.

    Inside the ``with`` body ``NCCL_MAX_NCHANNELS`` and
    ``NCCL_MIN_NCHANNELS`` are set to ``num_channels`` and
    ``NCCL_CUMEM_ENABLE`` is set to ``'1'``. On exit, including when the
    body raises, every variable listed in ``env_vars`` is restored to its
    previous value, or removed if it was previously unset.

    Args:
        num_channels: channel count to export. ``os.environ`` accepts
            only strings, so a non-str value (for example an int taken
            from configuration) is converted with ``str()`` instead of
            raising ``TypeError``.
    """
    env_vars = [
        'NCCL_MAX_NCHANNELS',
        'NCCL_MIN_NCHANNELS',
        'NCCL_CUMEM_ENABLE',
        'NCCL_BUFFSIZE',
        'NCCL_PROTO',  # LL,LL128,SIMPLE
        'NCCL_ALGO',  # RING,TREE
    ]

    # Snapshot taken before anything is modified; None marks "unset".
    original_values: dict[str, Any] = {
        var: os.environ.get(var)
        for var in env_vars
    }

    logger.info("set_du_swift_context, original_values: %s",
                original_values)

    channels = str(num_channels)
    try:
        os.environ['NCCL_MAX_NCHANNELS'] = channels
        os.environ['NCCL_MIN_NCHANNELS'] = channels
        os.environ['NCCL_CUMEM_ENABLE'] = '1'
        yield
    finally:
        for var, previous in original_values.items():
            if previous is None:
                os.environ.pop(var, None)
            else:
                os.environ[var] = previous
@dataclass
class ReqKVDest:
    """Destinations to which one request's KV tensors are sent.

    Instances are the values of the engine's ``req_status`` mapping,
    keyed by request id.
    """
    # Number of destinations; the producer asserts that it equals
    # len(zmq_address_and_comm_rank) before sending.
    dst_num: int = 0
    # Identifier of the prefill/decode pair, passed through to RemoteAddr.
    pd_pair_id: str = ""
    # One (zmq_address, comm_rank) pair per destination.
    zmq_address_and_comm_rank: list[tuple[str, int]] = field(
        default_factory=list)
@dataclass
class RemoteAddr:
    """Address of one remote peer that tensors are sent to."""
    # Identifier of the prefill/decode pair this peer belongs to.
    pd_pair_id: str = ""
    # "<ip>:<port>" of the peer's ZMQ socket.
    zmq_address: str = ""
    # Rank number associated with the peer.
    # NOTE(review): presumably its rank inside the shared NCCL
    # communicator -- confirm.
    comm_rank: int = 0
class
DuSwiftEngineDp
:
def __init__(self,
             local_rank: int,
             port_offset: int,
             config: KVTransferConfig,
             model_config: ModelConfig,
             dp_rank: int = 0,
             pp_rank: int = 0,
             tp_rank: int = 0,
             dp_size: int = 0,
             pp_size: int = 0,
             tp_size: int = 0,
             library_path: Optional[str] = None) -> None:
    """Create the engine: sockets, buffers and background threads.

    Args:
        local_rank: index of the CUDA device used (``cuda:<local_rank>``).
        port_offset: added to ``config.kv_port`` to obtain this engine's
            ZMQ port; also stored as ``self.rank``.
        config: KV-transfer configuration, including the extra-config
            keys read below.
        model_config: model configuration; only
            ``hf_text_config.num_hidden_layers`` is read here.
        dp_rank: data-parallel rank, stored as given.
        pp_rank, tp_rank, pp_size, tp_size: stored, then overwritten
            with the values reported by ``get_pp_group()`` and
            ``get_tp_group()``.
        dp_size: data-parallel size, stored as given.
        library_path: passed to ``NCCLLibrary``.

    Raises:
        ValueError: if the computed ZMQ port is 0.
    """
    self.config = config
    self.model_config = model_config
    self.rank = port_offset
    self.local_rank = local_rank
    self.dp_rank = dp_rank
    self.pp_rank = pp_rank
    self.tp_rank = tp_rank
    self.dp_size = dp_size
    self.pp_size = pp_size
    self.tp_size = tp_size
    self.device = torch.device(f"cuda:{self.local_rank}")
    self.nccl = NCCLLibrary(library_path)
    self.total_num_hidden_layers = getattr(
        self.model_config.hf_text_config, "num_hidden_layers", 0)
    # NOTE(review): these four assignments replace the constructor
    # arguments of the same name stored above.
    self.pp_rank = get_pp_group().rank_in_group
    self.tp_rank = get_tp_group().rank_in_group
    self.pp_size = get_pp_group().world_size
    self.tp_size = get_tp_group().world_size

    # Remote (decode-side) topology is only configured on a producer, so
    # remote_tp_size, remote_pp_size, enable_asymmetric_p2p and multp do
    # not exist as attributes on a consumer.
    if config.is_kv_producer:
        self.remote_tp_size = self.config.get_from_extra_config(
            "remote_tp_size", 1)
        self.remote_pp_size = self.config.get_from_extra_config(
            "remote_pp_size", 1)
        self.enable_asymmetric_p2p = self.config.get_from_extra_config(
            "enable_asymmetric_p2p", False)
        # NOTE(review): a remote TP size that is not a multiple of the
        # local TP size is only logged; construction continues.
        if self.remote_tp_size % self.tp_size != 0:
            logger.error(
                " the Prefill TP size must be less than or equal to the Decode TP size!!!!"
            )
        # Number of remote TP ranks served by each local TP rank.
        self.multp = int(self.remote_tp_size / self.tp_size)

    self.multiple_machines = self.config.get_from_extra_config(
        "enable_multiple_machines", False)
    self.instance_ip = self.config.get_from_extra_config(
        "instance_ip", None)
    # An explicit instance_ip turns the multi-machine mode off.
    if self.instance_ip:
        self.multiple_machines = False

    port = int(self.config.kv_port) + port_offset
    if port == 0:
        raise ValueError("Port cannot be 0")
    self._hostname = get_ip()
    self._port = port

    # Each card corresponds to a ZMQ address.
    self.zmq_address = f"{self._hostname}:{self._port}"

    # The `http_port` must be consistent with the port of OpenAI.
    if self.instance_ip:
        self.http_address = (
            f"{self.config.kv_connector_extra_config['instance_ip']}:"
            f"{self.config.kv_connector_extra_config['http_port']}")
    else:
        self.http_address = (
            f"{self._hostname}:"
            f"{self.config.kv_connector_extra_config['http_port']}")

    # If `proxy_ip` or `proxy_port` is `""`,
    # then the ping thread will not be enabled.
    proxy_ip = self.config.get_from_extra_config("proxy_ip", "")
    proxy_port = self.config.get_from_extra_config("proxy_port", "")
    if proxy_ip == "" or proxy_port == "":
        self.proxy_address = ""
    else:
        # NOTE(review): string concatenation; requires proxy_port to be
        # configured as a string.
        self.proxy_address = proxy_ip + ":" + proxy_port

    self.kv_cache_layer_num = 0

    # ROUTER socket on which this engine receives peer messages.
    self.context = zmq.Context()
    self.router_socket = self.context.socket(zmq.ROUTER)
    self.router_socket.setsockopt(zmq.RCVHWM, 10000)
    self.router_socket.setsockopt(zmq.SNDHWM, 5000)
    self.router_socket.setsockopt(zmq.LINGER, 0)
    self.router_socket.setsockopt(zmq.ROUTER_MANDATORY, 1)
    self.router_socket.setsockopt(zmq.TCP_KEEPALIVE, 1)
    self.router_socket.bind(f"tcp://{self.zmq_address}")

    self.poller = zmq.Poller()
    self.poller.register(self.router_socket, zmq.POLLIN)

    # request_id -> destinations of that request's KV tensors.
    self.req_status: dict[str, ReqKVDest] = {}
    self.req_status_cv = threading.Condition()
    self.send_store_cv = threading.Condition()
    self.send_queue_cv = threading.Condition()
    self.recv_store_cv = threading.Condition()
    self.pending_queue_cv = threading.Condition()

    self.send_stream = torch.cuda.Stream()
    self.recv_stream = torch.cuda.Stream()
    self.p2p_async_kv_tokens = envs.VLLM_P2P_BUF_TOKENS
    self.p2p_async_buf = None
    self.tensor_split_num: int = 0

    mem_pool_size_gb = self.config.get_from_extra_config(
        "mem_pool_size_gb", DEFAULT_MEM_POOL_SIZE_GB)
    self.pool = TensorMemoryPool(max_block_size=int(mem_pool_size_gb) *
                                 1024**3)  # GiB -> bytes

    # The sending type includes three mutually exclusive options:
    # PUT, GET, PUT_ASYNC.
    self.send_type = self.config.get_from_extra_config("send_type", "PUT")
    # NOTE(review): indentation was not visible for this section; the
    # extent of the `else` branch below is a reconstruction -- confirm
    # which of these attributes and threads are created in GET mode.
    if self.send_type == "GET":
        # tensor_id: torch.Tensor
        self.send_store: dict[str, torch.Tensor] = {}
    else:
        # PUT or PUT_ASYNC
        # tensor_id: torch.Tensor
        self.send_queue: deque[list[Any]] = deque()
        # Keyed by str, each value a list of queued items.
        self.pending_queue: dict[str, list[list[Any]]] = defaultdict(list)
        self.requests_to_release: dict[str, bool] = {}
        self.send_request_id_to_tensor_ids: dict[str, set[str]] = {}
        if self.send_type == "PUT_ASYNC":
            self._send_thread = threading.Thread(target=self._send_async,
                                                 daemon=True)
            self._send_thread.start()
        self._pending_check_thread = threading.Thread(
            target=self._pending_check, daemon=True)
        self._pending_check_thread.start()

    # tensor_id: torch.Tensor/(addr, dtype, shape)
    self.recv_store: dict[str, Any] = {}
    self.recv_request_id_to_tensor_ids: dict[str, set[str]] = {}
    self.socks: dict[str, Any] = {}  # remote_address: client socket
    self.comms: dict[str, Any] = {}  # remote_address: (ncclComm_t, rank)

    self.buffer_size = 0
    self.buffer_size_threshold = float(self.config.kv_buffer_size)

    self.nccl_num_channels = self.config.get_from_extra_config(
        "nccl_num_channels", "8")

    self._listener_thread = threading.Thread(
        target=self._listen_for_requests, daemon=True)
    self._listener_thread.start()

    # The ping thread runs only when a proxy address is configured. In
    # multi-machine mode only the engine with port_offset 0 runs it
    # (target `_ping`); otherwise every engine runs `_ping_new`.
    self._ping_thread = None
    if self.multiple_machines:
        if port_offset == 0 and self.proxy_address != "":
            self._ping_thread = threading.Thread(target=self._ping,
                                                 daemon=True)
            self._ping_thread.start()
    else:
        if self.proxy_address != "":
            self._ping_thread = threading.Thread(target=self._ping_new,
                                                 daemon=True)
            self._ping_thread.start()

    logger.info(
        "💯DuSwiftEngine init, rank:%d, local_rank:%d, http_address:%s, "
        "zmq_address:%s, proxy_address:%s, send_type:%s, buffer_size_"
        "threshold:%.2f, nccl_num_channels:%s", self.rank, self.local_rank,
        self.http_address, self.zmq_address, self.proxy_address,
        self.send_type, self.buffer_size_threshold, self.nccl_num_channels)
def
_create_connect_new
(
self
,
remote_address
:
typing
.
Optional
[
str
]
=
None
):
assert
remote_address
is
not
None
if
remote_address
not
in
self
.
socks
:
sock
=
self
.
context
.
socket
(
zmq
.
DEALER
)
sock
.
setsockopt
(
zmq
.
SNDHWM
,
10000
)
sock
.
setsockopt
(
zmq
.
RCVHWM
,
5000
)
sock
.
setsockopt
(
zmq
.
LINGER
,
0
)
sock
.
setsockopt
(
zmq
.
TCP_KEEPALIVE
,
1
)
sock
.
setsockopt_string
(
zmq
.
IDENTITY
,
f
"P-
{
self
.
zmq_address
}
"
)
sock
.
connect
(
f
"tcp://
{
remote_address
}
"
)
self
.
socks
[
remote_address
]
=
sock
return
self
.
socks
[
remote_address
]
def
_create_connect
(
self
,
remote_address
:
typing
.
Optional
[
str
]
=
None
):
assert
remote_address
is
not
None
if
remote_address
not
in
self
.
socks
:
sock
=
self
.
context
.
socket
(
zmq
.
DEALER
)
sock
.
setsockopt_string
(
zmq
.
IDENTITY
,
self
.
zmq_address
)
sock
.
connect
(
f
"tcp://
{
remote_address
}
"
)
self
.
socks
[
remote_address
]
=
sock
if
remote_address
in
self
.
comms
:
logger
.
info
(
"👋comm exists, remote_address:%s, comms:%s"
,
remote_address
,
self
.
comms
)
return
sock
,
self
.
comms
[
remote_address
]
unique_id
=
self
.
nccl
.
ncclGetUniqueId
()
data
=
{
"cmd"
:
"NEW"
,
"unique_id"
:
bytes
(
unique_id
.
internal
)}
sock
.
send
(
msgpack
.
dumps
(
data
))
with
torch
.
cuda
.
device
(
self
.
device
):
rank
=
0
with
set_du_swift_context
(
self
.
nccl_num_channels
):
comm
:
ncclComm_t
=
self
.
nccl
.
ncclCommInitRank
(
2
,
unique_id
,
rank
)
self
.
comms
[
remote_address
]
=
(
comm
,
rank
)
logger
.
info
(
"🤝ncclCommInitRank Success, %s👉%s, MyRank: %s"
,
self
.
zmq_address
,
remote_address
,
rank
)
return
self
.
socks
[
remote_address
],
self
.
comms
[
remote_address
]
def get_send_queue_items(self, request_id: str, layer_name: str,
                         tensor: torch.Tensor, is_mla: bool) -> list[Any]:
    """Build the send-queue entries for one layer's KV tensor.

    Each entry is ``[tensor_id, RemoteAddr, tensor]``.

    * Symmetric mode (``enable_asymmetric_p2p`` false): one entry,
      addressed to the decode base port plus this engine's ``rank``.
    * Asymmetric mode: one entry for every remote TP rank
      ``tp_rank + k * tp_size`` (``k`` in ``range(multp)``) that is below
      ``remote_tp_size``, on the remote PP rank returned by
      ``compute_remote_pp_rank(layer_name)``.

    Args:
        request_id: request id carrying the prefill and decode addresses.
        layer_name: name of the layer the tensor belongs to.
        tensor: the KV tensor to send.
        is_mla: whether the model uses MLA. Asymmetric mode logs an error
            for non-MLA models but still builds the entries.

    Returns:
        The list of entries, in ascending remote TP rank order.
    """
    tensor_id = self.get_tensor_id(request_id, layer_name)
    remote_ip, remote_port = self.parse_request_id(request_id, True)
    p_ip, p_port = self.parse_request_id(request_id, False)
    pd_pair_id = (p_ip + ":" + str(p_port) + "_" + remote_ip + ":" +
                  str(remote_port))
    # Remote comm ranks are numbered after all local (pp, tp) ranks.
    local_world_size = self.pp_size * self.tp_size

    if not self.enable_asymmetric_p2p:
        remote_address = remote_ip + ":" + str(remote_port + self.rank)
        remote_addr = RemoteAddr(pd_pair_id, remote_address,
                                 self.rank + local_world_size)
        # A list, matching the entries built below and in send_tensor.
        return [[tensor_id, remote_addr, tensor]]

    if not is_mla:
        logger.error(" DuSwift only support mla model symmetric PP/TP!!!!")

    remote_pp_rank = self.compute_remote_pp_rank(layer_name)
    items: list[Any] = []
    # Compute each target rank directly instead of scanning all remote
    # ranks for a match; the resulting set and order are the same.
    for mul_tp in range(self.multp):
        d_tp_rank = self.tp_rank + mul_tp * self.tp_size
        if d_tp_rank >= self.remote_tp_size:
            continue
        remote_port_offset = remote_pp_rank * self.remote_tp_size + d_tp_rank
        remote_address = remote_ip + ":" + str(remote_port +
                                               remote_port_offset)
        remote_comm_rank = remote_port_offset + local_world_size
        remote_addr = RemoteAddr(pd_pair_id, remote_address,
                                 remote_comm_rank)
        # The remote "tp" value logged is the destination TP rank.
        logger.debug(
            "Wait to send::%s, tensor_shape:%s, "
            "(pp=%d, tp=%d) -> remote_address=%s(pp=%d, tp=%d) comm_rank (%d -> %d)",
            tensor_id, tensor.shape, self.pp_rank, self.tp_rank,
            remote_address, remote_pp_rank, d_tp_rank, self.rank,
            remote_comm_rank)
        items.append([tensor_id, remote_addr, tensor])
    return items
def send_tensor_new(
    self,
    request_id: str,
    layer_name: str,
    tensor: torch.Tensor,
    is_mla: bool = False,
) -> bool:
    """Send one layer's KV tensor to every destination of the request.

    Args:
        request_id: id of the request the tensor belongs to.
        layer_name: name of the layer the tensor belongs to.
        tensor: the KV tensor to send.
        is_mla: forwarded to ``get_send_queue_items``.

    Returns:
        * ``"PUT"``: True only if every synchronous send succeeded;
          sending stops at the first failure.
        * ``"PUT_ASYNC"``: True once the items are queued.
        * ``"GET"`` or any other send type: False. The original fell
          through and returned ``None`` despite the ``bool`` annotation.
    """
    if self.send_type == "PUT":
        return all(
            self._send_sync_new(item)
            for item in self.get_send_queue_items(request_id, layer_name,
                                                  tensor, is_mla))

    if self.send_type == "PUT_ASYNC":
        # Build the items before taking the lock so the queue is held
        # only for the append itself.
        items = self.get_send_queue_items(request_id, layer_name, tensor,
                                          is_mla)
        with self.send_queue_cv:
            self.send_queue.extend(items)
            self.send_queue_cv.notify()
        return True

    if self.send_type == "GET":
        logger.error(
            " DuSwift new not support GET model, please set VLLM_P2PNCCL_NEW=0 use defalut model!!!!"
        )
    return False
def send_tensor(
    self,
    tensor_id: str,
    tensor: torch.Tensor,
    remote_address: typing.Optional[RemoteAddr] = None,
    tbo_evt=None,
) -> bool:
    """Hand a tensor to the transport selected by ``self.send_type``.

    Args:
        tensor_id: Key under which the tensor is stored or announced.
        tensor: Tensor to transfer.
        remote_address: Destination. When None the tensor is placed directly
            in the local ``recv_store`` and waiters are notified.
        tbo_evt: Not used by this method.

    Returns:
        The result of ``_send_sync`` for ``PUT``; True in every other case.
    """
    if remote_address is None:
        with self.recv_store_cv:
            self.recv_store[tensor_id] = tensor
            self.recv_store_cv.notify()
        return True

    if self.send_type == "PUT":
        return self._send_sync(tensor_id, tensor, remote_address)

    if self.send_type == "PUT_ASYNC":
        with self.send_queue_cv:
            self.send_queue.append([tensor_id, remote_address, tensor])
            self.send_queue_cv.notify()
        return True

    # GET: keep the tensor in send_store until the peer asks for it.
    with self.send_store_cv:
        tensor_size = tensor.element_size() * tensor.numel()
        # Evict in insertion order until the new tensor fits. The emptiness
        # check matters: a single tensor larger than the threshold would
        # otherwise call next() on an empty iterator and raise StopIteration.
        while (self.send_store and self.buffer_size + tensor_size
               > self.buffer_size_threshold):
            oldest_tenser_id = next(iter(self.send_store))
            oldest_tenser = self.send_store.pop(oldest_tenser_id)
            oldest_tenser_size = (oldest_tenser.element_size() *
                                  oldest_tenser.numel())
            self.buffer_size -= oldest_tenser_size
            logger.info(
                "⛔[GET]Send to %s, tensor_id:%s, tensor_size:%d,"
                " buffer_size:%d, oldest_tenser_size:%d, rank:%d",
                remote_address.zmq_address, tensor_id, tensor_size,
                self.buffer_size, oldest_tenser_size, self.rank)

        self.send_store[tensor_id] = tensor
        self.buffer_size += tensor_size
        logger.debug(
            "🔵[GET]Send to %s, tensor_id:%s, tensor_size:%d, "
            "shape:%s, rank:%d, buffer_size:%d(%.2f%%)",
            remote_address.zmq_address, tensor_id, tensor_size, tensor.shape,
            self.rank, self.buffer_size,
            self.buffer_size / self.buffer_size_threshold * 100)
    return True
def p2p_async_send_tensor(
    self,
    tensor_id: str,
    tensor: torch.Tensor,
    remote_address: typing.Optional[str] = None,
    tbo_evt=None,
) -> bool:
    """Variant of ``send_tensor`` that queues a KV layer plus slot mapping.

    Args:
        tensor_id: Key under which the payload is stored or announced.
        tensor: Payload. On the ``PUT_ASYNC`` path it is unpacked as a
            ``(kv_layer, slot_mapping)`` pair; the other paths treat it as a
            single tensor.
        remote_address: Destination. When None the payload is placed
            directly in the local ``recv_store``.
        tbo_evt: Event queued alongside the payload on the ``PUT_ASYNC``
            path; the sender thread waits on it before transmitting.

    Returns:
        The result of ``_send_sync`` for ``PUT``; True in every other case.
    """
    if remote_address is None:
        with self.recv_store_cv:
            self.recv_store[tensor_id] = tensor
            self.recv_store_cv.notify()
        return True

    if self.send_type == "PUT":
        return self._send_sync(tensor_id, tensor, remote_address)

    if self.send_type == "PUT_ASYNC":
        with self.send_queue_cv:
            # tensor is a (kv_layer, slot_mapping) pair on this path.
            kv_layer, slot_mapping = tensor
            self.send_queue.append(
                [tensor_id, remote_address, kv_layer, slot_mapping, tbo_evt])
            self.send_queue_cv.notify()
        return True

    # GET: keep the tensor in send_store until the peer asks for it.
    with self.send_store_cv:
        tensor_size = tensor.element_size() * tensor.numel()
        # Evict in insertion order until the new tensor fits. The emptiness
        # check matters: a single tensor larger than the threshold would
        # otherwise call next() on an empty iterator and raise StopIteration.
        while (self.send_store and self.buffer_size + tensor_size
               > self.buffer_size_threshold):
            oldest_tenser_id = next(iter(self.send_store))
            oldest_tenser = self.send_store.pop(oldest_tenser_id)
            oldest_tenser_size = (oldest_tenser.element_size() *
                                  oldest_tenser.numel())
            self.buffer_size -= oldest_tenser_size
            logger.info(
                "⛔[GET]Send to %s, tensor_id:%s, tensor_size:%d,"
                " buffer_size:%d, oldest_tenser_size:%d, rank:%d",
                remote_address, tensor_id, tensor_size, self.buffer_size,
                oldest_tenser_size, self.rank)

        self.send_store[tensor_id] = tensor
        self.buffer_size += tensor_size
        logger.debug(
            "🔵[GET]Send to %s, tensor_id:%s, tensor_size:%d, "
            "shape:%s, rank:%d, buffer_size:%d(%.2f%%)", remote_address,
            tensor_id, tensor_size, tensor.shape, self.rank,
            self.buffer_size,
            self.buffer_size / self.buffer_size_threshold * 100)
    return True
def pending_tensor(
    self,
    reuqest_id: str,
    layer_name: str,
    tensor: torch.Tensor,
    tbo_evt=None,
) -> bool:
    """Park a layer tensor until the request's destinations are known.

    Appends ``[layer_name, tensor, tbo_evt]`` to the request's entry in
    ``self.pending_queue`` and notifies waiters on ``pending_queue_cv``.

    Returns:
        Always True.
    """
    entry = [layer_name, tensor, tbo_evt]
    queue_cv = self.pending_queue_cv
    with queue_cv:
        per_request = self.pending_queue[reuqest_id]
        per_request.append(entry)
        queue_cv.notify()
    return True
def unpending_tensor(
    self,
    request_id: str,
    req_data: ReqKVDest,
) -> bool:
    """Release a request's parked tensors and send them to its destinations.

    Pops the request from ``self.pending_queue``; if the request is tracked
    in ``self.requests_to_release`` it is flagged True there. Each parked
    ``(layer_name, tensor, tbo_evt)`` entry is then sent to each of the
    ``req_data.dst_num`` destinations.

    Returns:
        False when ``req_data.dst_num`` is not positive (nothing is sent),
        True otherwise.
    """
    with self.pending_queue_cv:
        parked_entries = self.pending_queue.pop(request_id)
        if request_id in self.requests_to_release:
            self.requests_to_release[request_id] = True
    logger.info("[%d] unpending request: %s", self.rank, request_id)

    if req_data.dst_num <= 0:
        return False

    for layer_name, tensor, tbo_evt in parked_entries:
        for dst_index in range(req_data.dst_num):
            address_and_rank = req_data.zmq_address_and_comm_rank[dst_index]
            remote_addr = RemoteAddr(req_data.pd_pair_id, *address_and_rank)
            event_path = ((envs.VLLM_ENABLE_TBO or envs.VLLM_P2P_ASYNC)
                          and tbo_evt is not None)
            if event_path:
                self.p2p_async_send_tensor(request_id + "#" + layer_name,
                                           tensor, remote_addr, tbo_evt)
            else:
                self.send_tensor(request_id + "#" + layer_name, tensor,
                                 remote_addr)
    return True
def _pending_check(self):
    """Thread loop: flush parked tensors once their destinations are known.

    Waits until ``self.pending_queue`` is non-empty, snapshots it, and for
    every request that already has an entry in ``self.req_status`` calls
    ``unpending_tensor``. Requests without a status entry are left parked.
    Never returns.
    """
    while True:
        with self.pending_queue_cv:
            while not self.pending_queue:
                self.pending_queue_cv.wait()
            # Iterate over a copy: unpending_tensor pops from the live dict.
            snapshot = self.pending_queue.copy()

        for request_id in snapshot:
            with self.req_status_cv:
                destination_known = request_id in self.req_status
                if destination_known:
                    req_data = self.req_status[request_id]
                    assert (len(
                        req_data.zmq_address_and_comm_rank) == req_data.dst_num
                            )
            if destination_known:
                self.unpending_tensor(request_id, req_data)
def recv_tensor(
    self,
    tensor_id: str,
    remote_address: typing.Optional[str] = None,
) -> torch.Tensor:
    """Obtain the tensor identified by ``tensor_id``.

    For ``PUT`` / ``PUT_ASYNC`` this blocks until the listener has placed an
    entry for ``tensor_id`` in ``self.recv_store``. A tuple entry is loaded
    back through ``self.pool.load_tensor``; a stored None is logged and
    returned as None. For ``GET`` the tensor is requested from
    ``remote_address`` and received over the communicator.

    Returns:
        The tensor, or None when the stored entry is None, when no remote
        address is given in GET mode, or when the peer answers with a
        non-zero ``ret``.
    """
    if self.send_type in ("PUT", "PUT_ASYNC"):
        wait_began = time.time()
        with self.recv_store_cv:
            while tensor_id not in self.recv_store:
                self.recv_store_cv.wait()
            received = self.recv_store[tensor_id]

        if received is None:
            elapsed = time.time() - wait_began
            logger.warning(
                "🔴[PUT]Recv From %s, tensor_id:%s, duration:%.3fms, "
                "rank:%d", remote_address, tensor_id, elapsed * 1000,
                self.rank)
            return received

        if isinstance(received, tuple):
            # Entry is an (addr, dtype, shape) record written by the listener
            # when it stored the tensor in the pool.
            addr, dtype, shape = received
            return self.pool.load_tensor(addr, dtype, shape, self.device)

        self.buffer_size -= (received.element_size() * received.numel())
        return received

    # GET
    if remote_address is None:
        return None

    if remote_address not in self.socks:
        self._create_connect(remote_address)

    sock = self.socks[remote_address]
    comm, rank = self.comms[remote_address]

    request = {"cmd": "GET", "tensor_id": tensor_id}
    sock.send(msgpack.dumps(request))

    reply = msgpack.loads(sock.recv())
    if reply["ret"] != 0:
        logger.warning("🔴[GET]Recv From %s, tensor_id: %s, ret: %d",
                       remote_address, tensor_id, reply["ret"])
        return None

    result = torch.empty(reply["shape"],
                         dtype=getattr(torch, reply["dtype"]),
                         device=self.device)
    self._recv(comm, result, rank ^ 1, self.recv_stream)
    return result
def _listen_for_requests(self):
    """Thread loop: serve control messages arriving on ``self.router_socket``.

    Polls with a 5000 ms timeout and dispatches on ``data["cmd"]``:

    * ``"NEW"``: only logged as unexpected.
    * ``"PUT"`` / ``"PUT_NEW"``: allocate a tensor from the announced
      shape/dtype, acknowledge with ``b"0"``, receive the payload, then
      publish it in ``self.recv_store`` (``b"1"`` and a None entry on CUDA
      out-of-memory).
    * ``"comm_init"``: create a communicator and register it in
      ``self.comms`` under ``data["pd_pair_id"]``.
    * ``"req_to_transfer"`` / ``"req_not_to_transfer"``: record the
      request's destinations in ``self.req_status``.
    * ``"GET"``: answer with tensor metadata and send the tensor if it is
      present in ``self.send_store``.

    Never returns.
    """
    while True:
        socks = dict(self.poller.poll(5000))
        if self.router_socket in socks:
            remote_address, message = self.router_socket.recv_multipart()
            data = msgpack.loads(message)
            if data["cmd"] == "NEW":
                logger.info(
                    f"unexpected message from {remote_address.decode()}")
            elif data["cmd"] == "PUT":
                tensor_id = data["tensor_id"]
                if "tensor_split_num" in data:
                    self.tensor_split_num = data["tensor_split_num"]
                try:
                    with torch.cuda.stream(self.recv_stream):
                        tensor = torch.empty(data["shape"],
                                             dtype=getattr(
                                                 torch, data["dtype"]),
                                             device=self.device)
                    # Acknowledge only after the allocation succeeded; the
                    # sender checks for b"0" before it starts transmitting.
                    self.router_socket.send_multipart(
                        [remote_address, b"0"])
                    # Communicator is looked up by the sender's identity.
                    comm, rank = self.comms[remote_address.decode()]
                    self._recv(comm, tensor, rank ^ 1, self.recv_stream)
                    tensor_size = tensor.element_size() * tensor.numel()
                    if (self.buffer_size + tensor_size
                            > self.buffer_size_threshold):
                        # Over the threshold: store the tensor in the memory
                        # pool and keep only an (addr, dtype, shape) record.
                        addr = self.pool.store_tensor(tensor)
                        tensor = (addr, tensor.dtype, tensor.shape)
                    else:
                        self.buffer_size += tensor_size
                except torch.cuda.OutOfMemoryError:
                    # Tell the sender not to transmit; waiters get None.
                    self.router_socket.send_multipart(
                        [remote_address, b"1"])
                    tensor = None
                    logger.warning(
                        "🔴[PUT]Recv Tensor, Out Of Memory, %s👈%s, "
                        "data:%s", self.zmq_address,
                        remote_address.decode(), data)
                with self.recv_store_cv:
                    self.recv_store[tensor_id] = tensor
                    self._have_received_tensor_id(tensor_id)
                    self.recv_store_cv.notify()
            elif data["cmd"] == "PUT_NEW":
                tensor_id = data["tensor_id"]
                if "tensor_split_num" in data:
                    self.tensor_split_num = data["tensor_split_num"]
                try:
                    with torch.cuda.stream(self.recv_stream):
                        tensor = torch.empty(data["shape"],
                                             dtype=getattr(
                                                 torch, data["dtype"]),
                                             device=self.device)
                    self.router_socket.send_multipart(
                        [remote_address, b"0"])
                    # Unlike "PUT", the communicator is keyed by pd_pair_id
                    # and the source rank is taken from the message itself
                    # (the rank stored alongside the communicator is unused).
                    comm, rank = self.comms[data["pd_pair_id"]]
                    self._recv(comm, tensor, int(data["comm_rank"]),
                               self.recv_stream)
                    tensor_size = tensor.element_size() * tensor.numel()
                    if (self.buffer_size + tensor_size
                            > self.buffer_size_threshold):
                        # Over the threshold: store the tensor in the memory
                        # pool and keep only an (addr, dtype, shape) record.
                        addr = self.pool.store_tensor(tensor)
                        tensor = (addr, tensor.dtype, tensor.shape)
                    else:
                        self.buffer_size += tensor_size
                except torch.cuda.OutOfMemoryError:
                    self.router_socket.send_multipart(
                        [remote_address, b"1"])
                    tensor = None
                    logger.warning(
                        "🔴[PUT]Recv Tensor, Out Of Memory, %s👈%s, "
                        "data:%s", self.zmq_address,
                        remote_address.decode(), data)
                with self.recv_store_cv:
                    self.recv_store[tensor_id] = tensor
                    self._have_received_tensor_id(tensor_id)
                    self.recv_store_cv.notify()
            elif data["cmd"] == "comm_init":
                unique_id = self.nccl.unique_id_from_bytes(
                    bytes(data["unique_id"]))
                with torch.cuda.device(self.device):
                    rank = int(data["rank"])
                    world_size = int(data["world_size"])
                    with set_du_swift_context(self.nccl_num_channels):
                        comm: ncclComm_t = self.nccl.ncclCommInitRank(
                            world_size, unique_id, rank)
                    self.comms[data["pd_pair_id"]] = (comm, rank)
                    logger.info(
                        "🤝ncclCommInitRank Success, %s👈%s, MyRank:%s",
                        self.zmq_address, data["pd_pair_id"], rank)
            elif data["cmd"] == "req_to_transfer":
                with self.req_status_cv:
                    # A request's destinations may only be announced once.
                    assert (data["request_id"] not in self.req_status)
                    self.req_status[data["request_id"]] = ReqKVDest(
                        dst_num=int(data["dst_num"]),
                        pd_pair_id=data["pd_pair_id"],
                        zmq_address_and_comm_rank=list(
                            zip(data["remote_address"],
                                data["remote_rank"])))
                    self.req_status_cv.notify_all()
            elif data["cmd"] == "req_not_to_transfer":
                with self.req_status_cv:
                    # dst_num=0 makes unpending_tensor drop the parked
                    # tensors without sending them.
                    self.req_status[data["request_id"]] = ReqKVDest(
                        dst_num=0)
                    self.req_status_cv.notify_all()
            elif data["cmd"] == "GET":
                tensor_id = data["tensor_id"]
                with self.send_store_cv:
                    tensor = self.send_store.pop(tensor_id, None)
                    if tensor is not None:
                        data = {
                            "ret": 0,
                            "shape": tensor.shape,
                            "dtype": str(tensor.dtype).replace("torch.", "")
                        }
                        # LRU: re-insert so the entry moves to the newest
                        # position of the insertion-ordered store.
                        self.send_store[tensor_id] = tensor
                        self._have_sent_tensor_id(tensor_id)
                    else:
                        data = {"ret": 1}
                self.router_socket.send_multipart(
                    [remote_address, msgpack.dumps(data)])
                if data["ret"] == 0:
                    comm, rank = self.comms[remote_address.decode()]
                    self._send(comm, tensor.to(self.device), rank ^ 1,
                               self.send_stream)
            else:
                logger.warning(
                    "🚧Unexpected, Received message from %s, data:%s",
                    remote_address, data)
def _have_sent_tensor_id(self, tensor_id: str):
    """Record ``tensor_id`` under its request in the sent-tensor index.

    The request id is the part of ``tensor_id`` before the first ``#``.
    """
    request_id, _, _ = tensor_id.partition('#')
    sent_ids = self.send_request_id_to_tensor_ids.setdefault(
        request_id, set())
    sent_ids.add(tensor_id)
def _have_received_tensor_id(self, tensor_id: str):
    """Record ``tensor_id`` under its request in the received-tensor index.

    The request id is the part of ``tensor_id`` before the first ``#``.
    """
    request_id, _, _ = tensor_id.partition('#')
    received_ids = self.recv_request_id_to_tensor_ids.setdefault(
        request_id, set())
    received_ids.add(tensor_id)
def _send_async(self):
    """Thread loop: drain ``self.send_queue`` and transmit each item.

    Two item layouts are queued by the producers in this class:

    * ``[tensor_id, remote_address, kv_layer, slot_mapping, tbo_evt]`` from
      ``p2p_async_send_tensor``;
    * ``(tensor_id, remote_address, tensor)`` from ``send_tensor`` and
      ``send_tensor_new``.

    The layout is detected from the item itself rather than from environment
    flags, because both kinds can be queued in the same process: unpacking
    by flag raised ValueError on a mismatched item and ended this thread.
    Never returns.
    """
    while True:
        with self.send_queue_cv:
            while not self.send_queue:
                self.send_queue_cv.wait()
            item = self.send_queue.popleft()
            if not self.send_queue:
                # Wake wait_for_sent(), which waits for an empty queue.
                self.send_queue_cv.notify()

        if len(item) == 5:
            tensor_id, remote_address, kv_layer, slot_mapping, tbo_evt = item
            if tbo_evt is not None:
                # Order the send stream after the work recorded by the event.
                self.send_stream.wait_event(tbo_evt)
            self._send_kv_p2p_sync(tensor_id, kv_layer, slot_mapping,
                                   remote_address)
            continue

        tensor_id, remote_address, tensor = item
        if self.multiple_machines:
            self._send_sync(tensor_id, tensor, remote_address)
        else:
            self._send_sync_new(tensor_id, tensor, remote_address)
def wait_for_sent(self):
    """Block until the asynchronous send queue is empty.

    Only acts when ``self.send_type`` is ``"PUT_ASYNC"``; for any other send
    type it returns immediately.
    """
    if self.send_type != "PUT_ASYNC":
        return

    wait_began = time.time()
    with self.send_queue_cv:
        while self.send_queue:
            self.send_queue_cv.wait()
    elapsed = time.time() - wait_began
    logger.debug(
        "🚧[PUT_ASYNC]It took %.3fms to wait for the send_queue"
        " to be empty, rank:%d", elapsed * 1000, self.rank)
def _send_kv_p2p_sync(self, tensor_id: str, kv_layer: torch.Tensor,
                      slot_mapping: torch.Tensor,
                      remote_address: str) -> bool:
    """Gather the slots in ``slot_mapping`` from ``kv_layer`` and send them.

    The selected rows are copied into the reusable staging buffer
    ``self.p2p_async_buf`` in packs of at most ``self.p2p_async_kv_tokens``
    slots. Each pack is announced with a ``"PUT"`` header whose tensor id is
    ``f"{tensor_id}#{pack_idx}"`` and is sent once the peer answers ``b"0"``.

    Args:
        tensor_id: Base id; the pack index is appended per pack.
        kv_layer: KV cache tensor. A 3-dim tensor is handled as MLA (rows
            selected on dim 0 after flattening the first two dims);
            otherwise the first dim is assumed to be the K/V pair of size 2
            and dims 1 and 2 are flattened.
        slot_mapping: 1-D index tensor of the slots to send.
        remote_address: Key into ``self.socks`` / ``self.comms``.

    Returns:
        False as soon as the peer rejects a pack, True otherwise.
    """
    if remote_address not in self.socks:
        self._create_connect(remote_address)

    sock = self.socks[remote_address]
    comm, rank = self.comms[remote_address]

    is_mla = (kv_layer.ndim == 3)
    hidden_dim = kv_layer.shape[-1]
    tokens_per_pack = self.p2p_async_kv_tokens

    if self.p2p_async_buf is None:
        # Staging buffer is allocated lazily on first use and then reused.
        if is_mla:
            buf_shape = (tokens_per_pack, hidden_dim)
        else:
            buf_shape = (2, tokens_per_pack, hidden_dim)
        self.p2p_async_buf = torch.empty(buf_shape,
                                         dtype=kv_layer.dtype,
                                         device=kv_layer.device)

    total_slots = slot_mapping.shape[0]
    # Ceiling division; yields 0 packs for an empty slot mapping.
    pack_num = (total_slots - 1) // tokens_per_pack + 1
    self.tensor_split_num = pack_num

    # The flattened view does not depend on the pack, so build it once.
    if is_mla:
        num_pages, page_size = kv_layer.shape[0], kv_layer.shape[1]
        data = kv_layer.reshape(num_pages * page_size, -1)
        select_dim = 0
    else:
        num_pages, page_size = kv_layer.shape[1], kv_layer.shape[2]
        data = kv_layer.reshape(2, num_pages * page_size, -1)
        select_dim = 1

    with torch.cuda.stream(self.send_stream):
        for pack_idx in range(pack_num):
            start = pack_idx * tokens_per_pack
            end = min((pack_idx + 1) * tokens_per_pack, total_slots)
            sub_index = slot_mapping[start:end]

            if is_mla:
                send_tensor = self.p2p_async_buf[:end - start]
                tx_shape = (end - start, hidden_dim)
            else:
                send_tensor = self.p2p_async_buf[:, :end - start]
                tx_shape = (2, end - start, hidden_dim)
            torch.index_select(data,
                               dim=select_dim,
                               index=sub_index,
                               out=send_tensor)

            header = {
                "cmd": "PUT",
                # The pack index is appended to the tensor id.
                "tensor_id": tensor_id + "#" + str(pack_idx),
                "pack_idx": pack_idx,
                "tensor_split_num": pack_num,
                "shape": tx_shape,
                "dtype": str(kv_layer.dtype).replace("torch.", "")
            }
            sock.send(msgpack.dumps(header))

            response = sock.recv()
            if response != b"0":
                logger.error(
                    "🔴Send Tensor Failed | %s 👉 %s | Rank:%s | shape:%s | size:%.4f GB | response:%s",
                    self.zmq_address, remote_address, rank,
                    tuple(send_tensor.shape),
                    send_tensor.element_size() * send_tensor.numel() /
                    1024**3, response.decode())
                return False

            # _send transmits numel() elements starting at data_ptr(), so the
            # payload must be contiguous. A partial non-MLA pack
            # (buf[:, :k] with k < tokens_per_pack) is a strided view;
            # contiguous() is a no-op for already-contiguous tensors.
            self._send(comm, send_tensor.contiguous(), rank ^ 1,
                       self.send_stream)

    if self.send_type == "PUT_ASYNC":
        self._have_sent_tensor_id(tensor_id)

    return True
def _send_sync_new(
    self,
    tensor_id: str,
    tensor: torch.Tensor,
    remote_address: typing.Optional[RemoteAddr] = None,
) -> bool:
    """Announce a tensor with a ``"PUT_NEW"`` header and send it.

    The socket is looked up by ``remote_address.zmq_address`` (connecting on
    first use) and the communicator by ``remote_address.pd_pair_id``. The
    tensor is sent to ``remote_address.comm_rank`` once the peer answers
    ``b"0"``.

    Returns:
        False when ``remote_address`` is None or the peer rejects the
        header, True after the tensor has been sent.
    """
    if remote_address is None:
        return False

    zmq_address = remote_address.zmq_address
    if zmq_address not in self.socks:
        self._create_connect_new(zmq_address)

    sock = self.socks[zmq_address]
    comm, rank = self.comms[remote_address.pd_pair_id]

    header = {
        "cmd": "PUT_NEW",
        "tensor_id": tensor_id,
        "shape": tensor.shape,
        "dtype": str(tensor.dtype).replace("torch.", ""),
        "pd_pair_id": remote_address.pd_pair_id,
        # Our own rank in the communicator; the receiver uses it as source.
        "comm_rank": rank
    }
    logger.info(f"""_send_sync_new:{header}""")
    sock.send(msgpack.dumps(header))

    response = sock.recv()
    if response != b"0":
        size_gb = tensor.element_size() * tensor.numel() / 1024**3
        logger.error(
            "🔴Send Tensor, Peer Out Of Memory/Threshold, %s 👉 %s, "
            "MyRank:%s, data:%s, tensor:%s, size:%fGB, response:%s",
            self.zmq_address, remote_address.zmq_address, rank, header,
            tensor.shape, size_gb, response.decode())
        return False

    self._send(comm, tensor.to(self.device), remote_address.comm_rank,
               self.send_stream)

    if self.send_type == "PUT_ASYNC":
        self._have_sent_tensor_id(tensor_id)

    return True
def _send_sync(
    self,
    tensor_id: str,
    tensor: torch.Tensor,
    remote_address: typing.Optional[str] = None,
) -> bool:
    """Announce a tensor with a ``"PUT"`` header and send it.

    The socket and communicator are both looked up by ``remote_address``
    (connecting on first use). The tensor is sent to peer rank ``rank ^ 1``
    once the peer answers ``b"0"``.

    Returns:
        False when ``remote_address`` is None or the peer rejects the
        header, True after the tensor has been sent.
    """
    if remote_address is None:
        return False

    if remote_address not in self.socks:
        self._create_connect(remote_address)

    sock = self.socks[remote_address]
    comm, rank = self.comms[remote_address]

    header = {
        "cmd": "PUT",
        "tensor_id": tensor_id,
        "shape": tensor.shape,
        "dtype": str(tensor.dtype).replace("torch.", "")
    }
    sock.send(msgpack.dumps(header))

    response = sock.recv()
    if response != b"0":
        size_gb = tensor.element_size() * tensor.numel() / 1024**3
        logger.error(
            "🔴Send Tensor, Peer Out Of Memory/Threshold, %s 👉 %s, "
            "MyRank:%s, data:%s, tensor:%s, size:%fGB, response:%s",
            self.zmq_address, remote_address, rank, header, tensor.shape,
            size_gb, response.decode())
        return False

    self._send(comm, tensor.to(self.device), rank ^ 1, self.send_stream)

    if self.send_type == "PUT_ASYNC":
        self._have_sent_tensor_id(tensor_id)

    return True
def get_finished(
        self, finished_req_ids: set[str], forward_context: "ForwardContext"
) -> tuple[Optional[set[str]], Optional[set[str]]]:
    """
    Notifies worker-side connector ids of requests that have
    finished generating tokens.

    Side effects: buffers belonging to released requests are dropped from
    ``recv_store`` (pool-backed entries are freed) and from both
    request-to-tensor-id indexes. A finished request that still has tensors
    in ``pending_queue`` is not released now; it is remembered in
    ``requests_to_release`` and released by a later call, after
    ``unpending_tensor`` has flagged it.

    Returns:
        ids of requests that have finished asynchronous transfer,
        tuple of (sending/saving ids, recving/loading ids).
        The finished saves/sends req ids must belong to a set provided in a
        call to this method (this call or a prior one).
        Each element is None when its set is empty; the sending set is
        currently always empty.
    """
    # Clear the buffer upon request completion.
    requests_to_release: list[str] = []

    with self.pending_queue_cv:
        # Collect first, delete afterwards: removing keys while iterating
        # .items() raises "dictionary changed size during iteration".
        unpended = [
            request_id
            for request_id, release in self.requests_to_release.items()
            if release
        ]
        for request_id in unpended:
            del self.requests_to_release[request_id]
    requests_to_release.extend(unpended)

    for request_id in finished_req_ids:
        with self.pending_queue_cv:
            if request_id in self.pending_queue:
                # Tensors are still parked; defer the release.
                self.requests_to_release[request_id] = False
                logger.info("[%d] pending request: %s", self.rank,
                            request_id)
                continue
        requests_to_release.append(request_id)

    for request_id in requests_to_release:
        ids = self.recv_request_id_to_tensor_ids.pop(request_id, set())
        with self.recv_store_cv:
            for tensor_id in ids:
                tensor = self.recv_store.pop(tensor_id, None)
                if isinstance(tensor, tuple):
                    # Pool-backed entry: (addr, dtype, shape).
                    addr, _, _ = tensor
                    self.pool.free(addr)
        self.send_request_id_to_tensor_ids.pop(request_id, None)

    # TODO:Retrieve requests that have already sent the KV cache.
    finished_sending: set[str] = set()

    # TODO:Retrieve requests that have already received the KV cache.
    finished_recving: set[str] = set()

    if self.kv_cache_layer_num == 0:
        # Count, once, the layers that expose a non-None ``kv_cache``.
        for layer in forward_context.no_compile_layers.values():
            if getattr(layer, 'kv_cache', None) is not None:
                self.kv_cache_layer_num += 1

    with self.recv_store_cv:
        # A request counts as received once one tensor id per KV layer has
        # been recorded for it.
        for req, tensor_ids in self.recv_request_id_to_tensor_ids.items():
            if len(tensor_ids) == self.kv_cache_layer_num:
                finished_recving.add(req)

    return finished_sending or None, finished_recving or None
def _ping(self):
    """Thread loop: announce this instance to the proxy every 3 seconds.

    Opens a DEALER socket identified by ``self.zmq_address``, connects it to
    ``self.proxy_address`` and keeps sending a msgpack-encoded record with
    the instance type (``"P"`` for a KV producer, otherwise ``"D"``) and its
    HTTP and ZMQ addresses. Never returns.
    """
    sock = self.context.socket(zmq.DEALER)
    sock.setsockopt_string(zmq.IDENTITY, self.zmq_address)
    logger.debug("ping start, zmq_address:%s", self.zmq_address)
    sock.connect(f"tcp://{self.proxy_address}")

    if self.config.is_kv_producer:
        instance_type = "P"
    else:
        instance_type = "D"
    heartbeat = {
        "type": instance_type,
        "http_address": self.http_address,
        "zmq_address": self.zmq_address
    }
    while True:
        sock.send(msgpack.dumps(heartbeat))
        time.sleep(3)
def _ping_new(self):
    """Register this rank with the proxy once (no periodic heartbeat).

    Opens a DEALER socket identified by ``self.zmq_address`` and connects it
    to ``self.proxy_address``. Rank 0 first sends an ``"P_init"`` /
    ``"D_init"`` record carrying the DP/PP/TP sizes; every rank then sends a
    ``"P"`` / ``"D"`` record carrying its own DP/PP/TP rank.
    """
    sock = self.context.socket(zmq.DEALER)
    sock.setsockopt_string(zmq.IDENTITY, self.zmq_address)
    logger.debug("ping start, zmq_address:%s", self.zmq_address)
    sock.connect(f"tcp://{self.proxy_address}")

    if self.rank == 0:
        # Topology announcement, sent by the first rank only.
        init_record = {
            "type": "P_init" if self.config.is_kv_producer else "D_init",
            "http_address": self.http_address,
            "zmq_address": self.zmq_address,
            "dp_size": self.dp_size,
            "pp_size": self.pp_size,
            "tp_size": self.tp_size
        }
        sock.send(msgpack.dumps(init_record))

    rank_record = {
        "type": "P" if self.config.is_kv_producer else "D",
        "http_address": self.http_address,
        "dp_rank": self.dp_rank,
        "pp_rank": self.pp_rank,
        "tp_rank": self.tp_rank,
        "zmq_address": self.zmq_address
    }
    sock.send(msgpack.dumps(rank_record))
def _send(self, comm, tensor: torch.Tensor, dst: int, stream=None):
    """Send ``tensor`` to rank ``dst`` over ``comm`` and wait for completion.

    ``numel()`` elements starting at ``tensor.data_ptr()`` are handed to
    ``ncclSend``; the stream is synchronized before returning.

    Args:
        comm: Communicator handle passed through to ``ncclSend``.
        tensor: Tensor to send; must live on ``self.device``.
        dst: Destination rank within ``comm``.
        stream: Stream to issue the send on; defaults to ``current_stream()``.

    Raises:
        AssertionError: If ``tensor`` is not on ``self.device``.
    """
    assert tensor.device == self.device, (
        f"this nccl communicator is created to work on {self.device}, "
        f"but the input tensor is on {tensor.device}")

    active_stream = current_stream() if stream is None else stream

    with torch.cuda.stream(active_stream):
        send_buffer = buffer_type(tensor.data_ptr())
        element_count = tensor.numel()
        nccl_dtype = ncclDataTypeEnum.from_torch(tensor.dtype)
        self.nccl.ncclSend(send_buffer, element_count, nccl_dtype, dst, comm,
                           cudaStream_t(active_stream.cuda_stream))
    active_stream.synchronize()
def _recv(self, comm, tensor: torch.Tensor, src: int, stream=None):
    """Receive into ``tensor`` from rank ``src`` and wait for completion.

    ``numel()`` elements are written starting at ``tensor.data_ptr()`` by
    ``ncclRecv``; the stream is synchronized before returning.

    Args:
        comm: Communicator handle passed through to ``ncclRecv``.
        tensor: Pre-allocated destination; must live on ``self.device``.
        src: Source rank within ``comm``.
        stream: Stream to issue the receive on; defaults to
            ``current_stream()``.

    Raises:
        AssertionError: If ``tensor`` is not on ``self.device``.
    """
    assert tensor.device == self.device, (
        f"this nccl communicator is created to work on {self.device}, "
        f"but the input tensor is on {tensor.device}")

    active_stream = current_stream() if stream is None else stream

    with torch.cuda.stream(active_stream):
        recv_buffer = buffer_type(tensor.data_ptr())
        element_count = tensor.numel()
        nccl_dtype = ncclDataTypeEnum.from_torch(tensor.dtype)
        self.nccl.ncclRecv(recv_buffer, element_count, nccl_dtype, src, comm,
                           cudaStream_t(active_stream.cuda_stream))
    active_stream.synchronize()
def close(self) -> None:
    """Join the engine's worker threads.

    Joins, in order: the listener thread, the sender thread (only for
    ``PUT_ASYNC``), the pending-check thread, and the ping thread if one was
    created.
    """
    # NOTE(review): the pending-check join is assumed to be unconditional;
    # confirm it is not meant to sit inside the PUT_ASYNC branch.
    workers = [self._listener_thread]
    if self.send_type == "PUT_ASYNC":
        workers.append(self._send_thread)
    workers.append(self._pending_check_thread)
    if self._ping_thread is not None:
        workers.append(self._ping_thread)

    for worker in workers:
        worker.join()
def get_pp_indices_d(self, num_hidden_layers: int, pp_rank: int,
                     pp_size: int) -> tuple[int, int]:
    """Return the ``[start, end)`` layer range owned by a pipeline rank.

    The split comes from ``envs.VLLM_PP_LAYER_PARTITION_D`` (comma-separated
    layer counts, one per rank) when it is set. Otherwise the layers are
    divided evenly, and any remainder is handed out one layer each to the
    ranks counted back from the second-to-last.

    Raises:
        ValueError: If the partition string is not a list of integers, has a
            different number of entries than ``pp_size``, or does not sum to
            ``num_hidden_layers``.
    """
    partition_list_str = envs.VLLM_PP_LAYER_PARTITION_D

    if partition_list_str is None:
        base_layers, extra_layers = divmod(num_hidden_layers, pp_size)
        partitions = [base_layers] * pp_size
        if extra_layers:
            # Offsets start at 2, so the last rank never gets an extra layer.
            for back_offset in range(2, extra_layers + 2):
                partitions[-back_offset] += 1
            logger.info(
                "Hidden layers were unevenly partitioned: [%s]. "
                "This can be manually overridden using the "
                "VLLM_PP_LAYER_PARTITION_D environment variable",
                ",".join(map(str, partitions)))
    else:
        try:
            partitions = list(map(int, partition_list_str.split(",")))
        except ValueError as err:
            raise ValueError("Invalid partition string: {}".format(
                partition_list_str)) from err
        if len(partitions) != pp_size:
            raise ValueError(
                f"{len(partitions)=} does not match {pp_size=}.")
        if sum(partitions) != num_hidden_layers:
            raise ValueError(
                f"{sum(partitions)=} does not match {num_hidden_layers=}.")

    start_layer = sum(partitions[:pp_rank])
    return (start_layer, start_layer + partitions[pp_rank])
def compute_remote_pp_rank(self, layer_name: str) -> int:
    """Find the remote pipeline rank whose layer range contains this layer.

    The layer index is taken from ``layer_name`` with ``extract_layer_index``
    and compared against the range ``get_pp_indices_d`` reports for each of
    the ``self.remote_pp_size`` ranks. An index equal to
    ``self.total_num_hidden_layers`` (one past the last hidden layer) maps to
    the last remote rank.

    Returns:
        The matching remote pipeline rank, or -1 when no range matches.
    """
    layer_idx = extract_layer_index(layer_name)
    for candidate_rank in range(self.remote_pp_size):
        first_layer, past_last_layer = self.get_pp_indices_d(
            self.total_num_hidden_layers, candidate_rank,
            self.remote_pp_size)
        if layer_idx == self.total_num_hidden_layers:
            return self.remote_pp_size - 1
        if first_layer <= layer_idx < past_last_layer:
            return candidate_rank
    return -1
@staticmethod
def get_tensor_id(request_id: str, layer_name: str) -> str:
    """Build the tensor key ``"<request_id>#<layer_name>"``."""
    return "#".join((request_id, layer_name))
@staticmethod
def parse_request_id(request_id: str, is_prefill=True) -> tuple[str, int]:
    """Extract a ``(host, port)`` pair embedded in a request id.

    Args:
        request_id: Request id carrying ``___decode_addr_<host>:<port>``
            and/or ``___prefill_addr_<host>:<port>___`` markers.
        is_prefill: When True the decode address is extracted, otherwise the
            prefill address.

    Raises:
        ValueError: If the selected marker is not present in ``request_id``.
    """
    # Regular expression to match the string hostname and integer port
    pattern = (r"___decode_addr_(.*):(\d+)"
               if is_prefill else r"___prefill_addr_(.*):(\d+)___")

    found = regex.search(pattern, request_id)
    if not found:
        raise ValueError(
            f"Request id {request_id} does not contain hostname and port")

    host, port_text = found.group(1), found.group(2)
    return host, int(port_text)
vllm/envs.py
View file @
61ba33d5
...
...
@@ -1841,6 +1841,9 @@ environment_variables: dict[str, Callable[[], Any]] = {
# vllm will use rmsquant fused op
"USE_FUSED_RMS_QUANT"
:
lambda
:
bool
(
int
(
os
.
getenv
(
"USE_FUSED_RMS_QUANT"
,
"0"
))),
# vllm will use the DP (data-parallel) connector
"VLLM_USE_DP_CONNECTOR"
:
lambda
:
bool
(
int
(
os
.
getenv
(
"VLLM_USE_DP_CONNECTOR"
,
"0"
))),
# vllm PD separation will use asynchronous P2P transfer
"VLLM_P2P_ASYNC"
:
lambda
:
bool
(
int
(
os
.
getenv
(
"VLLM_P2P_ASYNC"
,
"0"
))),
...
...
vllm/v1/core/sched/scheduler.py
View file @
61ba33d5
...
...
@@ -121,7 +121,7 @@ class Scheduler(SchedulerInterface):
config
=
self
.
vllm_config
,
role
=
KVConnectorRole
.
SCHEDULER
,
kv_cache_config
=
self
.
kv_cache_config
,
)
dp_rank
=
self
.
parallel_config
.
data_parallel_rank
)
if
self
.
log_stats
:
self
.
connector_prefix_cache_stats
=
PrefixCacheStats
()
kv_load_failure_policy
=
(
...
...
@@ -556,6 +556,12 @@ class Scheduler(SchedulerInterface):
+
len
(
scheduled_running_reqs
)
>=
max_batch_running
):
break
request
=
self
.
waiting
.
peek_request
()
if
self
.
connector
and
not
self
.
connector
.
is_producer
and
\
request
.
request_id
not
in
self
.
finished_recving_kv_req_ids
and
\
envs
.
VLLM_USE_DP_CONNECTOR
:
self
.
waiting
.
pop_request
()
skipped_waiting_requests
.
prepend_request
(
request
)
continue
# KVTransfer: skip request if still waiting for remote kvs.
if
request
.
status
==
RequestStatus
.
WAITING_FOR_REMOTE_KVS
:
is_ready
=
self
.
_update_waiting_for_remote_kv
(
request
)
...
...
vllm/v1/engine/core.py
View file @
61ba33d5
...
...
@@ -66,6 +66,7 @@ from vllm.v1.serial_utils import MsgpackDecoder, MsgpackEncoder
from
vllm.v1.structured_output
import
StructuredOutputManager
from
vllm.v1.utils
import
compute_iteration_details
from
vllm.version
import
__version__
as
VLLM_VERSION
from
vllm
import
envs
logger
=
init_logger
(
__name__
)
...
...
@@ -1155,6 +1156,11 @@ class EngineCoreProc(EngineCore):
# Push to input queue for core busy loop.
self
.
input_queue
.
put_nowait
((
request_type
,
request
))
if
isinstance
(
request
,
tuple
)
and
self
.
scheduler
.
connector
is
not
None
\
and
envs
.
VLLM_USE_DP_CONNECTOR
:
req
,
_
=
request
if
request_type
==
EngineCoreRequestType
.
ADD
:
self
.
scheduler
.
connector
.
register_req
(
req
.
request_id
)
def
process_output_sockets
(
self
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment