Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
d48d8d6d
"vllm/vscode:/vscode.git/clone" did not exist on "cd8dfc6dfc832fc4bc8ea0c9b01ad92d677c75bb"
Commit
d48d8d6d
authored
Feb 25, 2026
by
Your Name
Browse files
[PD][Feat]支持pd模式下dp并行
parent
ffd26247
Changes
6
Show whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
391 additions
and
104 deletions
+391
-104
examples/online_serving/disaggregated_serving_p2p_nccl_xpyd/disagg_proxy_p2p_nccl_xpyd.py
...gated_serving_p2p_nccl_xpyd/disagg_proxy_p2p_nccl_xpyd.py
+155
-33
vllm/distributed/kv_transfer/kv_connector/factory.py
vllm/distributed/kv_transfer/kv_connector/factory.py
+2
-1
vllm/distributed/kv_transfer/kv_connector/v1/du/du_swift_connector.py
...uted/kv_transfer/kv_connector/v1/du/du_swift_connector.py
+83
-36
vllm/distributed/kv_transfer/kv_connector/v1/du/du_swift_engine.py
...ributed/kv_transfer/kv_connector/v1/du/du_swift_engine.py
+136
-32
vllm/v1/core/sched/scheduler.py
vllm/v1/core/sched/scheduler.py
+12
-2
vllm/v1/engine/core.py
vllm/v1/engine/core.py
+3
-0
No files found.
examples/online_serving/disaggregated_serving_p2p_nccl_xpyd/disagg_proxy_p2p_nccl_xpyd.py
View file @
d48d8d6d
...
@@ -13,6 +13,8 @@ from typing import Any
...
@@ -13,6 +13,8 @@ from typing import Any
from
quart
import
Quart
,
make_response
,
request
from
quart
import
Quart
,
make_response
,
request
from
dataclasses
import
dataclass
,
field
from
dataclasses
import
dataclass
,
field
from
vllm.distributed.device_communicators.pynccl_wrapper
import
NCCLLibrary
from
vllm.distributed.device_communicators.pynccl_wrapper
import
NCCLLibrary
import
time
import
asyncio
from
collections
import
deque
,
defaultdict
from
collections
import
deque
,
defaultdict
import
logging
import
logging
logging
.
basicConfig
(
logging
.
basicConfig
(
...
@@ -21,13 +23,13 @@ logging.basicConfig(
...
@@ -21,13 +23,13 @@ logging.basicConfig(
)
)
logger
=
logging
.
getLogger
(
__name__
)
logger
=
logging
.
getLogger
(
__name__
)
#
@dataclass
@
dataclass
#
class Request:
class
Request
:
#
request_id: str
request_id
:
str
#
p_http_address: str = ""
p_http_address
:
str
=
""
#
p_dp_rank: int = -1
p_dp_rank
:
int
=
-
1
#
d_http_address: str = ""
d_http_address
:
str
=
""
#
d_dp_rank: int = -1
d_dp_rank
:
int
=
-
1
@
dataclass
@
dataclass
class
Instance
:
class
Instance
:
...
@@ -60,17 +62,16 @@ class Instance:
...
@@ -60,17 +62,16 @@ class Instance:
all_ranks_ready
=
world_size
and
inited_rank
==
world_size
all_ranks_ready
=
world_size
and
inited_rank
==
world_size
if
self
.
ins_type
==
"P"
:
if
self
.
ins_type
==
"P"
:
logger
.
info
(
f
"""[Router] P is_ready? :
{
self
.
http_address
}
world_size =
{
world_size
}
inited_rank =
{
inited_rank
}
"""
)
logger
.
info
(
f
"""[Router] P is_ready? :
{
self
.
http_address
}
world_size =
{
world_size
}
inited_rank =
{
inited_rank
}
"""
)
# return all_ranks_ready and self.p_unique_id != b""
return
all_ranks_ready
return
all_ranks_ready
else
:
else
:
logger
.
info
(
f
"""[Router] D is_ready? :
{
self
.
http_address
}
world_size =
{
world_size
}
inited_rank =
{
inited_rank
}
"""
)
logger
.
info
(
f
"""[Router] D is_ready? :
{
self
.
http_address
}
world_size =
{
world_size
}
inited_rank =
{
inited_rank
}
"""
)
return
all_ranks_ready
return
all_ranks_ready
count
=
0
count
=
0
# prefill_instances: dict[str, str] = {} # http_address: zmq_address
# decode_instances: dict[str, str] = {} # http_address: zmq_address
prefill_instances
:
dict
[
str
,
Instance
]
=
{}
prefill_instances
:
dict
[
str
,
Instance
]
=
{}
decode_instances
:
dict
[
str
,
Instance
]
=
{}
decode_instances
:
dict
[
str
,
Instance
]
=
{}
running_requests
:
dict
[
str
,
Request
]
=
{}
healthy_instances
:
dict
[
str
,
float
]
=
{}
pending_prefill_ins
:
list
[
str
]
=
[]
pending_prefill_ins
:
list
[
str
]
=
[]
pending_decode_ins
:
list
[
str
]
=
[]
pending_decode_ins
:
list
[
str
]
=
[]
...
@@ -80,10 +81,13 @@ ready_decode_ins: list[str] = []
...
@@ -80,10 +81,13 @@ ready_decode_ins: list[str] = []
pd_pair
:
dict
[
str
,
bytes
]
=
{}
pd_pair
:
dict
[
str
,
bytes
]
=
{}
router_nccl
=
NCCLLibrary
()
router_nccl
=
NCCLLibrary
()
prefill_cv
=
threading
.
Condition
()
decode_cv
=
threading
.
Condition
()
instance_cv
=
threading
.
Condition
()
instance_cv
=
threading
.
Condition
()
request_cv
=
threading
.
Condition
()
health_cv
=
threading
.
Condition
()
request_queue_cv
=
threading
.
Condition
()
request_queue
:
deque
[
list
[
Any
]]
=
deque
()
sock_cache
:
dict
[
str
,
Any
]
=
{}
sock_cache
:
dict
[
str
,
Any
]
=
{}
def
_listen_for_register
(
poller
,
router_socket
):
def
_listen_for_register
(
poller
,
router_socket
):
...
@@ -149,6 +153,35 @@ def _listen_for_register(poller, router_socket):
...
@@ -149,6 +153,35 @@ def _listen_for_register(poller, router_socket):
pending_decode_ins
.
append
(
d_instance
.
http_address
)
pending_decode_ins
.
append
(
d_instance
.
http_address
)
logger
.
info
(
f
"""[Router] pending_decode_ins appended
{
d_instance
.
http_address
}
ZMQ:
{
d_instance
.
zmq_address
}
"""
)
logger
.
info
(
f
"""[Router] pending_decode_ins appended
{
d_instance
.
http_address
}
ZMQ:
{
d_instance
.
zmq_address
}
"""
)
instance_cv
.
notify
()
instance_cv
.
notify
()
elif
data
[
"type"
]
==
"heartbeat"
:
global
healthy_instances
global
health_cv
with
health_cv
:
healthy_instances
[
data
[
"http_address"
]]
=
time
.
time
()
elif
data
[
"type"
]
==
"Req"
:
# logger.info(f"""[Router] recv Request {data["request_id"]} : {data["instance_type"]}""")
global
running_requests
global
request_cv
with
request_cv
:
if
data
[
"request_id"
]
in
running_requests
:
request
=
running_requests
[
data
[
"request_id"
]]
if
data
[
"instance_type"
]
==
"P"
:
request
.
p_http_address
=
data
[
"http_address"
]
request
.
p_dp_rank
=
int
(
data
[
"dp_rank"
])
elif
data
[
"instance_type"
]
==
"D"
:
request
.
d_http_address
=
data
[
"http_address"
]
request
.
d_dp_rank
=
int
(
data
[
"dp_rank"
])
assert
(
request
.
p_dp_rank
>=
0
and
request
.
d_dp_rank
>=
0
)
with
request_queue_cv
:
request_queue
.
append
(
request
)
# logger.info(f"""[Router] add Request {data["request_id"]} [{request.p_http_address}:{request.p_dp_rank}, {request.d_http_address}:{request.d_dp_rank}]""")
request_queue_cv
.
notify
()
else
:
if
data
[
"instance_type"
]
==
"P"
:
running_requests
[
data
[
"request_id"
]]
=
Request
(
request_id
=
data
[
"request_id"
],
p_http_address
=
data
[
"http_address"
],
p_dp_rank
=
int
(
data
[
"dp_rank"
]))
elif
data
[
"instance_type"
]
==
"D"
:
running_requests
[
data
[
"request_id"
]]
=
Request
(
request_id
=
data
[
"request_id"
],
d_http_address
=
data
[
"http_address"
],
d_dp_rank
=
int
(
data
[
"dp_rank"
]))
else
:
else
:
print
(
print
(
"Unexpected, Received message from %s, data: %s"
,
"Unexpected, Received message from %s, data: %s"
,
...
@@ -157,6 +190,9 @@ def _listen_for_register(poller, router_socket):
...
@@ -157,6 +190,9 @@ def _listen_for_register(poller, router_socket):
)
)
zmq_context
=
None
zmq_context
=
None
tp_mapping_of_pd_pair
:
dict
[
str
,
dict
[
int
,
list
[
str
]]]
=
{}
tp_comm_mapping_of_pd_pair
:
dict
[
str
,
dict
[
int
,
list
[
int
]]]
=
{}
active_p_tp_rank_of_pd_pair
:
dict
[
str
,
set
[
int
]]
=
{}
def
start_service_discovery
(
hostname
,
port
):
def
start_service_discovery
(
hostname
,
port
):
if
not
hostname
:
if
not
hostname
:
...
@@ -180,6 +216,91 @@ def start_service_discovery(hostname, port):
...
@@ -180,6 +216,91 @@ def start_service_discovery(hostname, port):
_listener_thread
.
start
()
_listener_thread
.
start
()
return
_listener_thread
return
_listener_thread
def
dispatch_to_P
(
request
:
Request
):
global
prefill_instances
global
decode_instances
p_ins
=
prefill_instances
[
request
.
p_http_address
]
d_ins
=
decode_instances
[
request
.
d_http_address
]
global
zmq_context
global
sock_cache
pd_pair_id
=
p_ins
.
http_address
+
"_"
+
d_ins
.
http_address
p_dp_rank
=
request
.
p_dp_rank
d_dp_rank
=
request
.
d_dp_rank
tp_dst_id
=
pd_pair_id
+
"_"
+
str
(
d_dp_rank
)
assert
(
d_ins
.
pp_size
==
1
)
d_pp_rank
=
0
global
tp_mapping_of_pd_pair
global
tp_comm_mapping_of_pd_pair
global
active_p_tp_rank_of_pd_pair
if
tp_dst_id
not
in
active_p_tp_rank_of_pd_pair
:
p_active_tp_rank
=
set
()
p_tp_rank_to_dst
:
dict
[
int
,
list
[
str
]]
=
defaultdict
(
list
)
p_tp_rank_to_dst_comm
:
dict
[
int
,
list
[
int
]]
=
defaultdict
(
list
)
for
d_tp_rank
in
range
(
d_ins
.
tp_size
):
p_tp_rank
=
d_tp_rank
%
p_ins
.
tp_size
p_active_tp_rank
.
add
(
p_tp_rank
)
p_tp_rank_to_dst
[
p_tp_rank
].
append
(
d_ins
.
rank_table
[
d_dp_rank
][
d_pp_rank
][
d_tp_rank
])
p_tp_rank_to_dst_comm
[
p_tp_rank
].
append
(
d_ins
.
comm_rank_table
[
d_dp_rank
][
d_pp_rank
][
d_tp_rank
])
tp_mapping_of_pd_pair
[
tp_dst_id
]
=
p_tp_rank_to_dst
tp_comm_mapping_of_pd_pair
[
tp_dst_id
]
=
p_tp_rank_to_dst_comm
active_p_tp_rank_of_pd_pair
[
tp_dst_id
]
=
p_active_tp_rank
p_active_tp_rank
=
active_p_tp_rank_of_pd_pair
[
tp_dst_id
]
p_tp_rank_to_dst
=
tp_mapping_of_pd_pair
[
tp_dst_id
]
p_tp_rank_to_dst_comm
=
tp_comm_mapping_of_pd_pair
[
tp_dst_id
]
for
p_pp_rank
in
range
(
p_ins
.
pp_size
):
for
p_tp_rank
in
p_active_tp_rank
:
if
p_ins
.
rank_table
[
p_dp_rank
][
p_pp_rank
][
p_tp_rank
]
not
in
sock_cache
:
sock
=
zmq_context
.
socket
(
zmq
.
DEALER
)
sock
.
setsockopt_string
(
zmq
.
IDENTITY
,
"router"
)
sock
.
connect
(
f
"tcp://
{
p_ins
.
rank_table
[
p_dp_rank
][
p_pp_rank
][
p_tp_rank
]
}
"
)
sock_cache
[
p_ins
.
rank_table
[
p_dp_rank
][
p_pp_rank
][
p_tp_rank
]]
=
sock
data
=
{
"cmd"
:
"req_to_transfer"
,
"request_id"
:
request
.
request_id
,
"dst_num"
:
len
(
p_tp_rank_to_dst
[
p_tp_rank
]),
"pd_pair_id"
:
pd_pair_id
,
"remote_address"
:
p_tp_rank_to_dst
[
p_tp_rank
],
"remote_rank"
:
p_tp_rank_to_dst_comm
[
p_tp_rank
],
}
sock_cache
[
p_ins
.
rank_table
[
p_dp_rank
][
p_pp_rank
][
p_tp_rank
]].
send
(
msgpack
.
dumps
(
data
))
logger
.
info
(
f
"""[Router] dispatch Request
{
request
.
request_id
}
[
{
p_dp_rank
}
,
{
p_pp_rank
}
,
{
p_tp_rank
}
] -> [
{
d_dp_rank
}
,
{
d_pp_rank
}
]"""
)
for
p_tp_rank
in
range
(
p_ins
.
tp_size
):
if
p_tp_rank
not
in
p_active_tp_rank
:
for
p_pp_rank
in
range
(
p_ins
.
pp_size
):
if
p_ins
.
rank_table
[
p_dp_rank
][
p_pp_rank
][
p_tp_rank
]
not
in
sock_cache
:
sock
=
zmq_context
.
socket
(
zmq
.
DEALER
)
sock
.
setsockopt_string
(
zmq
.
IDENTITY
,
"router"
)
sock
.
connect
(
f
"tcp://
{
p_ins
.
rank_table
[
p_dp_rank
][
p_pp_rank
][
p_tp_rank
]
}
"
)
sock_cache
[
p_ins
.
rank_table
[
p_dp_rank
][
p_pp_rank
][
p_tp_rank
]]
=
sock
data
=
{
"cmd"
:
"req_not_to_transfer"
,
"request_id"
:
request
.
request_id
,
}
sock_cache
[
p_ins
.
rank_table
[
p_dp_rank
][
p_pp_rank
][
p_tp_rank
]].
send
(
msgpack
.
dumps
(
data
))
def
dp_dispatch
():
global
request_queue_cv
global
request_queue
while
True
:
with
request_queue_cv
:
while
not
request_queue
:
request_queue_cv
.
wait
()
request
=
request_queue
.
pop
()
dispatch_to_P
(
request
)
def
start_dp_dispatch
():
_thread
=
threading
.
Thread
(
target
=
dp_dispatch
,
daemon
=
True
)
_thread
.
start
()
return
_thread
AIOHTTP_TIMEOUT
=
aiohttp
.
ClientTimeout
(
total
=
6
*
60
*
60
)
AIOHTTP_TIMEOUT
=
aiohttp
.
ClientTimeout
(
total
=
6
*
60
*
60
)
...
@@ -206,14 +327,14 @@ async def forward_request(url, data, request_id):
...
@@ -206,14 +327,14 @@ async def forward_request(url, data, request_id):
yield
content
yield
content
def
unique_id_dispatch
(
prefill_instance
:
str
,
def
unique_id_dispatch
(
prefill_instance
:
Instance
,
decode_instance
:
str
)
:
decode_instance
:
Instance
)
:
global
zmq_context
global
zmq_context
global
sock_cache
global
sock_cache
global
router_nccl
global
router_nccl
global
pd_pair
global
pd_pair
pd_pair_id
=
prefill_instance
.
zmq
_address
+
"_"
+
decode_instance
.
zmq
_address
pd_pair_id
=
prefill_instance
.
http
_address
+
"_"
+
decode_instance
.
http
_address
if
pd_pair_id
in
pd_pair
:
if
pd_pair_id
in
pd_pair
:
logger
.
info
(
f
"""[Router] pd pair
{
pd_pair_id
}
already exist"""
)
logger
.
info
(
f
"""[Router] pd pair
{
pd_pair_id
}
already exist"""
)
...
@@ -320,35 +441,34 @@ async def handle_request():
...
@@ -320,35 +441,34 @@ async def handle_request():
global
count
global
count
global
prefill_instances
global
prefill_instances
global
prefill
_cv
global
instance
_cv
with
prefill
_cv
:
with
instance
_cv
:
prefill_list
=
list
(
prefill_instances
.
items
())
prefill_list
=
list
(
prefill_instances
.
items
())
prefill_addr
,
prefill_instance
=
prefill_list
[
count
%
len
(
prefill_list
)]
prefill_addr
,
prefill_instance
=
prefill_list
[
count
%
len
(
prefill_list
)]
global
decode_instances
global
decode_instances
global
decode_cv
with
instance_cv
:
with
decode_cv
:
decode_list
=
list
(
decode_instances
.
items
())
decode_list
=
list
(
decode_instances
.
items
())
decode_addr
,
decode_instance
=
decode_list
[
count
%
len
(
decode_list
)]
decode_addr
,
decode_instance
=
decode_list
[
count
%
len
(
decode_list
)]
print
(
global
pd_pair
f
"handle_request count:
{
count
}
, [HTTP:
{
prefill_addr
}
, "
if
prefill_instance
.
http_address
+
"_"
+
decode_instance
.
http_address
not
in
pd_pair
:
f
"ZMQ:
{
prefill_instance
.
zmq_address
}
] 👉 [HTTP:
{
decode_addr
}
, "
raise
RuntimeError
(
"Selected PD pair was not inited"
)
f
"ZMQ:
{
decode_instance
.
zmq_address
}
]"
logger
.
info
(
f
"handle_request count:
{
count
}
, [HTTP:
{
prefill_addr
}
, 👉 HTTP:
{
decode_addr
}
]"
)
)
count
+=
1
count
+=
1
request_id
=
(
request_id
=
f
"
{
random_uuid
()
}
"
f
"___prefill_addr_
{
prefill_instance
.
zmq_address
}
___decode_addr_"
f
"
{
decode_instance
.
zmq_address
}
_
{
random_uuid
()
}
"
)
# finish
prefill
async
def
run_
prefill
():
async
for
_
in
forward_request
(
async
for
_
in
forward_request
(
f
"http://
{
prefill_addr
}
/v1/completions"
,
prefill_request
,
request_id
f
"http://
{
prefill_addr
}
/v1/completions"
,
prefill_request
,
request_id
):
):
continue
pass
prefill_task
=
asyncio
.
create_task
(
run_prefill
())
# return decode
# return decode
generator
=
forward_request
(
generator
=
forward_request
(
...
@@ -372,6 +492,8 @@ async def handle_request():
...
@@ -372,6 +492,8 @@ async def handle_request():
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
t
=
start_service_discovery
(
"0.0.0.0"
,
30001
)
t
=
start_service_discovery
(
"0.0.0.0"
,
30001
)
t_1
=
start_pd_pair_init
()
t_1
=
start_pd_pair_init
()
t_2
=
start_dp_dispatch
()
app
.
run
(
host
=
"0.0.0.0"
,
port
=
10001
)
app
.
run
(
host
=
"0.0.0.0"
,
port
=
10001
)
t
.
join
()
t
.
join
()
t_1
.
join
()
t_1
.
join
()
t_2
.
join
()
vllm/distributed/kv_transfer/kv_connector/factory.py
View file @
d48d8d6d
...
@@ -54,6 +54,7 @@ class KVConnectorFactory:
...
@@ -54,6 +54,7 @@ class KVConnectorFactory:
cls
,
cls
,
config
:
"VllmConfig"
,
config
:
"VllmConfig"
,
role
:
KVConnectorRole
,
role
:
KVConnectorRole
,
dp_rank
:
int
=
-
1
,
)
->
KVConnectorBase_V1
:
)
->
KVConnectorBase_V1
:
if
not
envs
.
VLLM_USE_V1
:
if
not
envs
.
VLLM_USE_V1
:
raise
ValueError
(
"Attempting to initialize a V1 Connector, "
raise
ValueError
(
"Attempting to initialize a V1 Connector, "
...
@@ -81,7 +82,7 @@ class KVConnectorFactory:
...
@@ -81,7 +82,7 @@ class KVConnectorFactory:
# - Co-locate with worker process
# - Co-locate with worker process
# - Should only be used inside the forward context & attention layer
# - Should only be used inside the forward context & attention layer
# We build separately to enforce strict separation
# We build separately to enforce strict separation
return
connector_cls
(
config
,
role
)
return
connector_cls
(
config
,
role
,
dp_rank
)
# Register various connectors here.
# Register various connectors here.
...
...
vllm/distributed/kv_transfer/kv_connector/v1/du/du_swift_connector.py
View file @
d48d8d6d
...
@@ -20,6 +20,9 @@ from vllm.v1.attention.backends.mla.common import MLACommonMetadata
...
@@ -20,6 +20,9 @@ from vllm.v1.attention.backends.mla.common import MLACommonMetadata
from
vllm.v1.core.sched.output
import
SchedulerOutput
from
vllm.v1.core.sched.output
import
SchedulerOutput
from
vllm.distributed.parallel_state
import
get_pp_group
,
get_tp_group
,
get_dp_group
from
vllm.distributed.parallel_state
import
get_pp_group
,
get_tp_group
,
get_dp_group
import
zmq
import
msgpack
if
TYPE_CHECKING
:
if
TYPE_CHECKING
:
from
vllm.attention.backends.abstract
import
AttentionMetadata
from
vllm.attention.backends.abstract
import
AttentionMetadata
from
vllm.forward_context
import
ForwardContext
from
vllm.forward_context
import
ForwardContext
...
@@ -78,7 +81,7 @@ class DuSwiftConnectorMetadata(KVConnectorMetadata):
...
@@ -78,7 +81,7 @@ class DuSwiftConnectorMetadata(KVConnectorMetadata):
class
DuSwiftConnector
(
KVConnectorBase_V1
):
class
DuSwiftConnector
(
KVConnectorBase_V1
):
def
__init__
(
self
,
vllm_config
:
"VllmConfig"
,
role
:
KVConnectorRole
):
def
__init__
(
self
,
vllm_config
:
"VllmConfig"
,
role
:
KVConnectorRole
,
dp_rank
:
int
=
-
1
):
super
().
__init__
(
vllm_config
=
vllm_config
,
role
=
role
)
super
().
__init__
(
vllm_config
=
vllm_config
,
role
=
role
)
self
.
_block_size
=
vllm_config
.
cache_config
.
block_size
self
.
_block_size
=
vllm_config
.
cache_config
.
block_size
self
.
_requests_need_load
:
dict
[
str
,
Any
]
=
{}
self
.
_requests_need_load
:
dict
[
str
,
Any
]
=
{}
...
@@ -158,9 +161,39 @@ class DuSwiftConnector(KVConnectorBase_V1):
...
@@ -158,9 +161,39 @@ class DuSwiftConnector(KVConnectorBase_V1):
print
(
f
"Error: Exception occurred while reading configuration file -
{
e
}
"
)
print
(
f
"Error: Exception occurred while reading configuration file -
{
e
}
"
)
if
role
==
KVConnectorRole
.
SCHEDULER
:
self
.
dp_rank
=
dp_rank
proxy_ip
=
self
.
config
.
get_from_extra_config
(
"proxy_ip"
,
""
)
proxy_port
=
self
.
config
.
get_from_extra_config
(
"proxy_port"
,
""
)
if
proxy_ip
==
""
or
proxy_port
==
""
:
self
.
proxy_address
=
""
else
:
self
.
proxy_address
=
proxy_ip
+
":"
+
proxy_port
self
.
http_address
=
(
f
"
{
self
.
config
.
kv_connector_extra_config
[
'instance_ip'
]
}
:"
f
"
{
self
.
config
.
kv_connector_extra_config
[
'http_port'
]
}
"
)
self
.
context
=
zmq
.
Context
()
req_sock
=
self
.
context
.
socket
(
zmq
.
DEALER
)
req_sock
.
setsockopt_string
(
zmq
.
IDENTITY
,
f
"
{
self
.
http_address
}
_rank
{
self
.
dp_rank
}
"
)
req_sock
.
connect
(
f
"tcp://
{
self
.
proxy_address
}
"
)
self
.
req_sock
=
req_sock
def
get_ip_value
(
self
,
key
):
def
get_ip_value
(
self
,
key
):
return
self
.
ip_map
.
get
(
key
)
return
self
.
ip_map
.
get
(
key
)
def
register_req
(
self
,
request_id
:
str
)
:
data
=
{
"type"
:
"Req"
,
"instance_type"
:
"P"
if
self
.
config
.
is_kv_producer
else
"D"
,
"http_address"
:
self
.
http_address
,
"request_id"
:
request_id
,
"dp_rank"
:
self
.
dp_rank
}
self
.
req_sock
.
send
(
msgpack
.
dumps
(
data
))
# ==============================
# ==============================
# Worker-side methods
# Worker-side methods
...
@@ -438,43 +471,57 @@ class DuSwiftConnector(KVConnectorBase_V1):
...
@@ -438,43 +471,57 @@ class DuSwiftConnector(KVConnectorBase_V1):
else
:
else
:
for
request
in
connector_metadata
.
requests
:
for
request
in
connector_metadata
.
requests
:
request_id
=
request
.
request_id
request_id
=
request
.
request_id
ip
,
port
=
self
.
parse_request_id
(
request_id
,
True
)
#
ip, port = self.parse_request_id(request_id, True)
p_ip
,
p_port
=
self
.
parse_request_id
(
request_id
,
False
)
#
p_ip, p_port = self.parse_request_id(request_id, False)
remote_address
=
ip
+
":"
+
str
(
port
+
self
.
_rank
)
#
remote_address = ip + ":" + str(port + self._rank)
# pd_pair_id = p_ip + ":" + p_port + "_" + ip + ":" + port
# pd_pair_id = p_ip + ":" + p_port + "_" + ip + ":" + port
kv_cache
=
extract_kv_from_layer
(
kv_layer
,
request
.
slot_mapping
)
kv_cache
=
extract_kv_from_layer
(
kv_layer
,
request
.
slot_mapping
)
pp_rank
=
(
self
.
parallel_config
.
rank
//
self
.
parallel_config
.
tensor_parallel_size
# pp_rank = (self.parallel_config.rank // self.parallel_config.tensor_parallel_size
)
%
self
.
parallel_config
.
pipeline_parallel_size
# ) % self.parallel_config.pipeline_parallel_size
if
(
self
.
multiple_machines_p
and
self
.
multiple_machines_d
):
# if (self.multiple_machines_p and self.multiple_machines_d):
ip_second
=
self
.
get_ip_value
(
ip
)
# ip_second = self.get_ip_value(ip)
if
(
self
.
pp_size
==
1
):
# if (self.pp_size == 1):
if
self
.
_rank
<
8
:
# if self._rank < 8:
self
.
du_swift_engine
.
send_tensor
(
request_id
+
"#"
+
layer_name
,
# self.du_swift_engine.send_tensor(request_id + "#" + layer_name,
kv_cache
,
remote_address
)
# kv_cache, remote_address)
self
.
du_swift_engine
.
send_tensor
(
request_id
+
"#"
+
layer_name
,
# self.du_swift_engine.send_tensor(request_id + "#" + layer_name,
kv_cache
,
str
(
ip_second
)
+
":"
+
str
(
port
+
self
.
_rank
+
8
))
# kv_cache, str(ip_second) + ":" + str(port + self._rank + 8))
elif
(
self
.
pp_size
==
2
):
# elif (self.pp_size == 2):
if
(
pp_rank
==
0
):
# if (pp_rank == 0):
self
.
du_swift_engine
.
send_tensor
(
request_id
+
"#"
+
layer_name
,
# self.du_swift_engine.send_tensor(request_id + "#" + layer_name,
kv_cache
,
remote_address
)
# kv_cache, remote_address)
else
:
# else:
self
.
du_swift_engine
.
send_tensor
(
request_id
+
"#"
+
layer_name
,
# self.du_swift_engine.send_tensor(request_id + "#" + layer_name,
kv_cache
,
str
(
ip_second
)
+
":"
+
str
(
port
+
self
.
_rank
))
# kv_cache, str(ip_second) + ":" + str(port + self._rank))
else
:
# else:
logger
.
error
(
"Error: multiple machines only suppprt pp1tp16 and pp2tp8!!!!!!"
)
# logger.error("Error: multiple machines only suppprt pp1tp16 and pp2tp8!!!!!!")
elif
(
self
.
multiple_machines_p
and
not
self
.
multiple_machines_d
):
# elif (self.multiple_machines_p and not self.multiple_machines_d):
if
(
self
.
pp_size
==
2
):
# if (self.pp_size == 2):
remote_address
=
ip
+
":"
+
str
(
port
+
self
.
_tp_rank
)
# remote_address = ip + ":" + str(port + self._tp_rank)
pending
=
False
with
self
.
du_swift_engine
.
req_status_cv
:
if
request_id
not
in
self
.
du_swift_engine
.
req_status
:
pending
=
True
if
pending
:
self
.
du_swift_engine
.
pending_tensor
(
request_id
,
layer_name
,
kv_cache
)
logger
.
info
(
"[%d] pending for request: %s layer: %s"
,
self
.
_rank
,
request_id
,
layer_name
)
else
:
req_data
=
self
.
du_swift_engine
.
req_status
[
request_id
]
assert
(
req_data
.
dst_num
==
len
(
req_data
.
zmq_address_and_comm_rank
))
for
i
in
range
(
req_data
.
dst_num
):
remote_addr
=
RemoteAddr
(
req_data
.
pd_pair_id
,
*
(
req_data
.
zmq_address_and_comm_rank
[
i
]))
self
.
du_swift_engine
.
send_tensor
(
request_id
+
"#"
+
layer_name
,
self
.
du_swift_engine
.
send_tensor
(
request_id
+
"#"
+
layer_name
,
kv_cache
,
remote_address
)
kv_cache
,
remote_addr
)
else
:
# kv_cache, remote_address)
logger
.
error
(
"Error: P multiple machines D machine only suppprt P:pp2tp8 D:tp8 !!!!!!"
)
# else:
# logger.error("Error: P multiple machines D machine only suppprt P:pp2tp8 D:tp8 !!!!!!")
elif
(
not
self
.
multiple_machines_p
and
not
self
.
multiple_machines_d
):
#
elif (not self.multiple_machines_p and not self.multiple_machines_d):
# remote_addr = RemoteAddr(pd_pair_id, remote_address, self._rank + self.num_card)
#
# remote_addr = RemoteAddr(pd_pair_id, remote_address, self._rank + self.num_card)
self
.
du_swift_engine
.
send_tensor_new
(
request_id
,
layer_name
,
kv_cache
,
#
self.du_swift_engine.send_tensor_new(request_id, layer_name, kv_cache,
is_mla
)
#
is_mla)
# if (self.pp_size == 1):
# if (self.pp_size == 1):
# self.du_swift_engine.send_tensor(request_id + "#" + layer_name,
# self.du_swift_engine.send_tensor(request_id + "#" + layer_name,
# kv_cache, remote_address)
# kv_cache, remote_address)
...
@@ -498,8 +545,8 @@ class DuSwiftConnector(KVConnectorBase_V1):
...
@@ -498,8 +545,8 @@ class DuSwiftConnector(KVConnectorBase_V1):
# kv_cache, remote_address)
# kv_cache, remote_address)
# else:
# else:
# logger.error("Error: P/D single machine only suppprt multiple tp:: (P: pp2tp4 D:tp8 P:pp8tp1 D:tp8) !!!!!!")
# logger.error("Error: P/D single machine only suppprt multiple tp:: (P: pp2tp4 D:tp8 P:pp8tp1 D:tp8) !!!!!!")
else
:
#
else:
logger
.
error
(
"Error: not support!!!!!!"
)
#
logger.error("Error: not support!!!!!!")
def
wait_for_save
(
self
):
def
wait_for_save
(
self
):
pass
pass
# if self.is_producer:
# if self.is_producer:
...
...
vllm/distributed/kv_transfer/kv_connector/v1/du/du_swift_engine.py
View file @
d48d8d6d
...
@@ -6,9 +6,10 @@ import os
...
@@ -6,9 +6,10 @@ import os
import
threading
import
threading
import
time
import
time
import
typing
import
typing
from
collections
import
deque
from
collections
import
deque
,
defaultdict
from
contextlib
import
contextmanager
from
contextlib
import
contextmanager
from
typing
import
TYPE_CHECKING
,
Any
,
Optional
from
typing
import
TYPE_CHECKING
,
Any
,
Optional
from
dataclasses
import
dataclass
,
field
import
msgpack
import
msgpack
import
torch
import
torch
...
@@ -72,6 +73,12 @@ def set_du_swift_context(num_channels: str):
...
@@ -72,6 +73,12 @@ def set_du_swift_context(num_channels: str):
os
.
environ
.
pop
(
var
,
None
)
os
.
environ
.
pop
(
var
,
None
)
@
dataclass
class
ReqKVDest
:
dst_num
:
int
=
0
pd_pair_id
:
str
=
""
zmq_address_and_comm_rank
:
list
[
tuple
[
str
,
int
]]
=
field
(
default_factory
=
list
)
@
dataclass
@
dataclass
class
RemoteAddr
:
class
RemoteAddr
:
pd_pair_id
:
str
=
""
pd_pair_id
:
str
=
""
...
@@ -125,6 +132,10 @@ class DuSwiftEngine:
...
@@ -125,6 +132,10 @@ class DuSwiftEngine:
self
.
multp
=
int
(
self
.
remote_tp_size
/
self
.
tp_size
)
self
.
multp
=
int
(
self
.
remote_tp_size
/
self
.
tp_size
)
self
.
multiple_machines
=
self
.
config
.
get_from_extra_config
(
self
.
multiple_machines
=
self
.
config
.
get_from_extra_config
(
"enable_multiple_machines"
,
False
)
"enable_multiple_machines"
,
False
)
self
.
instance_ip
=
self
.
config
.
get_from_extra_config
(
"instance_ip"
,
None
)
if
self
.
instance_ip
:
self
.
multiple_machines
=
False
port
=
int
(
self
.
config
.
kv_port
)
+
port_offset
port
=
int
(
self
.
config
.
kv_port
)
+
port_offset
if
port
==
0
:
if
port
==
0
:
raise
ValueError
(
"Port cannot be 0"
)
raise
ValueError
(
"Port cannot be 0"
)
...
@@ -135,6 +146,11 @@ class DuSwiftEngine:
...
@@ -135,6 +146,11 @@ class DuSwiftEngine:
self
.
zmq_address
=
f
"
{
self
.
_hostname
}
:
{
self
.
_port
}
"
self
.
zmq_address
=
f
"
{
self
.
_hostname
}
:
{
self
.
_port
}
"
# The `http_port` must be consistent with the port of OpenAI.
# The `http_port` must be consistent with the port of OpenAI.
if
self
.
instance_ip
:
self
.
http_address
=
(
f
"
{
self
.
config
.
kv_connector_extra_config
[
'instance_ip'
]
}
:"
f
"
{
self
.
config
.
kv_connector_extra_config
[
'http_port'
]
}
"
)
else
:
self
.
http_address
=
(
self
.
http_address
=
(
f
"
{
self
.
_hostname
}
:"
f
"
{
self
.
_hostname
}
:"
f
"
{
self
.
config
.
kv_connector_extra_config
[
'http_port'
]
}
"
)
f
"
{
self
.
config
.
kv_connector_extra_config
[
'http_port'
]
}
"
)
...
@@ -148,16 +164,27 @@ class DuSwiftEngine:
...
@@ -148,16 +164,27 @@ class DuSwiftEngine:
else
:
else
:
self
.
proxy_address
=
proxy_ip
+
":"
+
proxy_port
self
.
proxy_address
=
proxy_ip
+
":"
+
proxy_port
self
.
kv_cache_layer_num
=
0
self
.
context
=
zmq
.
Context
()
self
.
context
=
zmq
.
Context
()
self
.
router_socket
=
self
.
context
.
socket
(
zmq
.
ROUTER
)
self
.
router_socket
=
self
.
context
.
socket
(
zmq
.
ROUTER
)
self
.
router_socket
.
setsockopt
(
zmq
.
RCVHWM
,
10000
)
self
.
router_socket
.
setsockopt
(
zmq
.
SNDHWM
,
5000
)
self
.
router_socket
.
setsockopt
(
zmq
.
LINGER
,
0
)
self
.
router_socket
.
setsockopt
(
zmq
.
ROUTER_MANDATORY
,
1
)
self
.
router_socket
.
setsockopt
(
zmq
.
TCP_KEEPALIVE
,
1
)
self
.
router_socket
.
bind
(
f
"tcp://
{
self
.
zmq_address
}
"
)
self
.
router_socket
.
bind
(
f
"tcp://
{
self
.
zmq_address
}
"
)
self
.
poller
=
zmq
.
Poller
()
self
.
poller
=
zmq
.
Poller
()
self
.
poller
.
register
(
self
.
router_socket
,
zmq
.
POLLIN
)
self
.
poller
.
register
(
self
.
router_socket
,
zmq
.
POLLIN
)
self
.
req_status
:
dict
[
str
,
ReqKVDest
]
=
{}
self
.
req_status_cv
=
threading
.
Condition
()
self
.
send_store_cv
=
threading
.
Condition
()
self
.
send_store_cv
=
threading
.
Condition
()
self
.
send_queue_cv
=
threading
.
Condition
()
self
.
send_queue_cv
=
threading
.
Condition
()
self
.
recv_store_cv
=
threading
.
Condition
()
self
.
recv_store_cv
=
threading
.
Condition
()
self
.
pending_queue_cv
=
threading
.
Condition
()
self
.
send_stream
=
torch
.
cuda
.
Stream
()
self
.
send_stream
=
torch
.
cuda
.
Stream
()
self
.
recv_stream
=
torch
.
cuda
.
Stream
()
self
.
recv_stream
=
torch
.
cuda
.
Stream
()
...
@@ -181,11 +208,16 @@ class DuSwiftEngine:
...
@@ -181,11 +208,16 @@ class DuSwiftEngine:
# PUT or PUT_ASYNC
# PUT or PUT_ASYNC
# tensor_id: torch.Tensor
# tensor_id: torch.Tensor
self
.
send_queue
:
deque
[
list
[
Any
]]
=
deque
()
self
.
send_queue
:
deque
[
list
[
Any
]]
=
deque
()
self
.
pending_queue
:
dict
[
str
,
list
[
list
[
Any
]]]
=
defaultdict
(
list
)
self
.
requests_to_release
:
dict
[
str
,
bool
]
=
{}
self
.
send_request_id_to_tensor_ids
:
dict
[
str
,
set
[
str
]]
=
{}
self
.
send_request_id_to_tensor_ids
:
dict
[
str
,
set
[
str
]]
=
{}
if
self
.
send_type
==
"PUT_ASYNC"
:
if
self
.
send_type
==
"PUT_ASYNC"
:
self
.
_send_thread
=
threading
.
Thread
(
target
=
self
.
_send_async
,
self
.
_send_thread
=
threading
.
Thread
(
target
=
self
.
_send_async
,
daemon
=
True
)
daemon
=
True
)
self
.
_send_thread
.
start
()
self
.
_send_thread
.
start
()
self
.
_pending_check_thread
=
threading
.
Thread
(
target
=
self
.
_pending_check
,
daemon
=
True
)
self
.
_pending_check_thread
.
start
()
# tensor_id: torch.Tensor/(addr, dtype, shape)
# tensor_id: torch.Tensor/(addr, dtype, shape)
self
.
recv_store
:
dict
[
str
,
Any
]
=
{}
self
.
recv_store
:
dict
[
str
,
Any
]
=
{}
...
@@ -328,7 +360,7 @@ class DuSwiftEngine:
...
@@ -328,7 +360,7 @@ class DuSwiftEngine:
self
,
self
,
tensor_id
:
str
,
tensor_id
:
str
,
tensor
:
torch
.
Tensor
,
tensor
:
torch
.
Tensor
,
remote_address
:
typing
.
Optional
[
st
r
]
=
None
,
remote_address
:
typing
.
Optional
[
RemoteAdd
r
]
=
None
,
tbo_evt
=
None
,
tbo_evt
=
None
,
)
->
bool
:
)
->
bool
:
if
remote_address
is
None
:
if
remote_address
is
None
:
...
@@ -356,7 +388,7 @@ class DuSwiftEngine:
...
@@ -356,7 +388,7 @@ class DuSwiftEngine:
logger
.
info
(
logger
.
info
(
"⛔[GET]Send to %s, tensor_id:%s, tensor_size:%d,"
"⛔[GET]Send to %s, tensor_id:%s, tensor_size:%d,"
" buffer_size:%d, oldest_tenser_size:%d, rank:%d"
,
" buffer_size:%d, oldest_tenser_size:%d, rank:%d"
,
remote_address
,
tensor_id
,
tensor_size
,
remote_address
.
zmq_address
,
tensor_id
,
tensor_size
,
self
.
buffer_size
,
oldest_tenser_size
,
self
.
rank
)
self
.
buffer_size
,
oldest_tenser_size
,
self
.
rank
)
self
.
send_store
[
tensor_id
]
=
tensor
self
.
send_store
[
tensor_id
]
=
tensor
...
@@ -364,7 +396,7 @@ class DuSwiftEngine:
...
@@ -364,7 +396,7 @@ class DuSwiftEngine:
logger
.
debug
(
logger
.
debug
(
"🔵[GET]Send to %s, tensor_id:%s, tensor_size:%d, "
"🔵[GET]Send to %s, tensor_id:%s, tensor_size:%d, "
"shape:%s, rank:%d, buffer_size:%d(%.2f%%)"
,
"shape:%s, rank:%d, buffer_size:%d(%.2f%%)"
,
remote_address
,
tensor_id
,
tensor_size
,
tensor
.
shape
,
remote_address
.
zmq_address
,
tensor_id
,
tensor_size
,
tensor
.
shape
,
self
.
rank
,
self
.
buffer_size
,
self
.
rank
,
self
.
buffer_size
,
self
.
buffer_size
/
self
.
buffer_size_threshold
*
100
)
self
.
buffer_size
/
self
.
buffer_size_threshold
*
100
)
...
@@ -417,6 +449,55 @@ class DuSwiftEngine:
...
@@ -417,6 +449,55 @@ class DuSwiftEngine:
return
True
return
True
def
pending_tensor
(
self
,
reuqest_id
:
str
,
layer_name
:
str
,
tensor
:
torch
.
Tensor
,
tbo_evt
=
None
,
)
->
bool
:
with
self
.
pending_queue_cv
:
self
.
pending_queue
[
reuqest_id
].
append
([
layer_name
,
tensor
,
tbo_evt
])
self
.
pending_queue_cv
.
notify
()
return
True
def
unpending_tensor
(
self
,
request_id
:
str
,
req_data
:
ReqKVDest
,
)
->
bool
:
with
self
.
pending_queue_cv
:
tensor_list
=
self
.
pending_queue
.
pop
(
request_id
)
if
request_id
in
self
.
requests_to_release
:
self
.
requests_to_release
[
request_id
]
=
True
logger
.
info
(
"[%d] unpending request: %s"
,
self
.
rank
,
request_id
)
if
req_data
.
dst_num
<=
0
:
return
False
for
layer_name
,
tensor
,
tbo_evt
in
tensor_list
:
for
i
in
range
(
req_data
.
dst_num
)
:
remote_addr
=
RemoteAddr
(
req_data
.
pd_pair_id
,
*
(
req_data
.
zmq_address_and_comm_rank
[
i
]))
if
(
envs
.
VLLM_ENABLE_TBO
or
envs
.
VLLM_P2P_ASYNC
)
and
tbo_evt
is
not
None
:
self
.
p2p_async_send_tensor
(
request_id
+
"#"
+
layer_name
,
tensor
,
remote_addr
,
tbo_evt
)
else
:
self
.
send_tensor
(
request_id
+
"#"
+
layer_name
,
tensor
,
remote_addr
)
return
True
def
_pending_check
(
self
)
:
while
True
:
with
self
.
pending_queue_cv
:
while
not
self
.
pending_queue
:
self
.
pending_queue_cv
.
wait
()
pending_queue
=
self
.
pending_queue
.
copy
()
for
request_id
in
pending_queue
:
with
self
.
req_status_cv
:
if
request_id
not
in
self
.
req_status
:
continue
req_data
=
self
.
req_status
[
request_id
]
assert
(
len
(
req_data
.
zmq_address_and_comm_rank
)
==
req_data
.
dst_num
)
self
.
unpending_tensor
(
request_id
,
req_data
)
def
recv_tensor
(
def
recv_tensor
(
self
,
self
,
tensor_id
:
str
,
tensor_id
:
str
,
...
@@ -475,22 +556,12 @@ class DuSwiftEngine:
...
@@ -475,22 +556,12 @@ class DuSwiftEngine:
def
_listen_for_requests
(
self
):
def
_listen_for_requests
(
self
):
while
True
:
while
True
:
socks
=
dict
(
self
.
poller
.
poll
())
socks
=
dict
(
self
.
poller
.
poll
(
5000
))
if
self
.
router_socket
in
socks
:
if
self
.
router_socket
in
socks
:
remote_address
,
message
=
self
.
router_socket
.
recv_multipart
()
remote_address
,
message
=
self
.
router_socket
.
recv_multipart
()
data
=
msgpack
.
loads
(
message
)
data
=
msgpack
.
loads
(
message
)
if
data
[
"cmd"
]
==
"NEW"
:
if
data
[
"cmd"
]
==
"NEW"
:
unique_id
=
self
.
nccl
.
unique_id_from_bytes
(
logger
.
info
(
f
"unexpected message from
{
remote_address
.
decode
()
}
"
)
bytes
(
data
[
"unique_id"
]))
with
torch
.
cuda
.
device
(
self
.
device
):
rank
=
1
with
set_du_swift_context
(
self
.
nccl_num_channels
):
comm
:
ncclComm_t
=
self
.
nccl
.
ncclCommInitRank
(
2
,
unique_id
,
rank
)
self
.
comms
[
remote_address
.
decode
()]
=
(
comm
,
rank
)
logger
.
info
(
"🤝ncclCommInitRank Success, %s👈%s, MyRank:%s"
,
self
.
zmq_address
,
remote_address
.
decode
(),
rank
)
elif
data
[
"cmd"
]
==
"PUT"
:
elif
data
[
"cmd"
]
==
"PUT"
:
tensor_id
=
data
[
"tensor_id"
]
tensor_id
=
data
[
"tensor_id"
]
if
"tensor_split_num"
in
data
:
if
"tensor_split_num"
in
data
:
...
@@ -577,6 +648,15 @@ class DuSwiftEngine:
...
@@ -577,6 +648,15 @@ class DuSwiftEngine:
logger
.
info
(
logger
.
info
(
"🤝ncclCommInitRank Success, %s👈%s, MyRank:%s"
,
"🤝ncclCommInitRank Success, %s👈%s, MyRank:%s"
,
self
.
zmq_address
,
data
[
"pd_pair_id"
],
rank
)
self
.
zmq_address
,
data
[
"pd_pair_id"
],
rank
)
elif
data
[
"cmd"
]
==
"req_to_transfer"
:
with
self
.
req_status_cv
:
assert
(
data
[
"request_id"
]
not
in
self
.
req_status
)
self
.
req_status
[
data
[
"request_id"
]]
=
ReqKVDest
(
dst_num
=
int
(
data
[
"dst_num"
]),
pd_pair_id
=
data
[
"pd_pair_id"
],
zmq_address_and_comm_rank
=
list
(
zip
(
data
[
"remote_address"
],
data
[
"remote_rank"
])))
self
.
req_status_cv
.
notify_all
()
elif
data
[
"cmd"
]
==
"req_not_to_transfer"
:
with
self
.
req_status_cv
:
self
.
req_status
[
data
[
"request_id"
]]
=
ReqKVDest
(
dst_num
=
0
)
self
.
req_status_cv
.
notify_all
()
elif
data
[
"cmd"
]
==
"GET"
:
elif
data
[
"cmd"
]
==
"GET"
:
tensor_id
=
data
[
"tensor_id"
]
tensor_id
=
data
[
"tensor_id"
]
with
self
.
send_store_cv
:
with
self
.
send_store_cv
:
...
@@ -814,20 +894,31 @@ class DuSwiftEngine:
...
@@ -814,20 +894,31 @@ class DuSwiftEngine:
"""
"""
# Clear the buffer upon request completion.
# Clear the buffer upon request completion.
requests_to_release
:
list
[
str
]
=
[]
with
self
.
pending_queue_cv
:
for
request_id
,
release
in
self
.
requests_to_release
.
items
():
if
release
:
requests_to_release
.
append
(
request_id
)
self
.
requests_to_release
.
pop
(
request_id
)
for
request_id
in
finished_req_ids
:
for
request_id
in
finished_req_ids
:
for
layer_name
in
forward_context
.
no_compile_layers
:
with
self
.
pending_queue_cv
:
tensor_id
=
request_id
+
"#"
+
layer_name
if
request_id
in
self
.
pending_queue
:
if
tensor_id
in
self
.
recv_store
:
self
.
requests_to_release
[
request_id
]
=
False
logger
.
info
(
"[%d] pending request: %s"
,
self
.
rank
,
request_id
)
continue
requests_to_release
.
append
(
request_id
)
for
request_id
in
requests_to_release
:
ids
=
self
.
recv_request_id_to_tensor_ids
.
pop
(
request_id
,
set
())
with
self
.
recv_store_cv
:
with
self
.
recv_store_cv
:
for
tensor_id
in
ids
:
tensor
=
self
.
recv_store
.
pop
(
tensor_id
,
None
)
tensor
=
self
.
recv_store
.
pop
(
tensor_id
,
None
)
self
.
send_request_id_to_tensor_ids
.
pop
(
request_id
,
None
)
self
.
recv_request_id_to_tensor_ids
.
pop
(
request_id
,
None
)
addr
=
0
if
isinstance
(
tensor
,
tuple
):
if
isinstance
(
tensor
,
tuple
):
addr
,
_
,
_
=
tensor
addr
,
_
,
_
=
tensor
self
.
pool
.
free
(
addr
)
self
.
pool
.
free
(
addr
)
self
.
send_request_id_to_tensor_ids
.
pop
(
request_id
,
None
)
# TODO:Retrieve requests that have already sent the KV cache.
# TODO:Retrieve requests that have already sent the KV cache.
finished_sending
:
set
[
str
]
=
set
()
finished_sending
:
set
[
str
]
=
set
()
...
@@ -835,6 +926,19 @@ class DuSwiftEngine:
...
@@ -835,6 +926,19 @@ class DuSwiftEngine:
# TODO:Retrieve requests that have already received the KV cache.
# TODO:Retrieve requests that have already received the KV cache.
finished_recving
:
set
[
str
]
=
set
()
finished_recving
:
set
[
str
]
=
set
()
if
self
.
kv_cache_layer_num
==
0
:
for
layer_name
in
forward_context
.
no_compile_layers
:
layer
=
forward_context
.
no_compile_layers
[
layer_name
]
kv_cache
=
getattr
(
layer
,
'kv_cache'
,
None
)
if
kv_cache
is
None
:
continue
self
.
kv_cache_layer_num
+=
1
with
self
.
recv_store_cv
:
for
req
in
self
.
recv_request_id_to_tensor_ids
:
if
len
(
self
.
recv_request_id_to_tensor_ids
[
req
])
==
self
.
kv_cache_layer_num
:
finished_recving
.
add
(
req
)
return
finished_sending
or
None
,
finished_recving
or
None
return
finished_sending
or
None
,
finished_recving
or
None
def
_ping
(
self
):
def
_ping
(
self
):
...
@@ -910,7 +1014,7 @@ class DuSwiftEngine:
...
@@ -910,7 +1014,7 @@ class DuSwiftEngine:
def
close
(
self
)
->
None
:
def
close
(
self
)
->
None
:
self
.
_listener_thread
.
join
()
self
.
_listener_thread
.
join
()
if
self
.
send_type
==
"PUT_ASYNC"
:
if
self
.
send_type
==
"PUT_ASYNC"
:
self
.
_
s
end_thread
.
join
()
self
.
_
p
end
ing_check
_thread
.
join
()
if
self
.
_ping_thread
is
not
None
:
if
self
.
_ping_thread
is
not
None
:
self
.
_ping_thread
.
join
()
self
.
_ping_thread
.
join
()
...
...
vllm/v1/core/sched/scheduler.py
View file @
d48d8d6d
...
@@ -86,7 +86,8 @@ class Scheduler(SchedulerInterface):
...
@@ -86,7 +86,8 @@ class Scheduler(SchedulerInterface):
"Multiple KV cache groups are not currently supported "
"Multiple KV cache groups are not currently supported "
"with KV connectors"
)
"with KV connectors"
)
self
.
connector
=
KVConnectorFactory
.
create_connector_v1
(
self
.
connector
=
KVConnectorFactory
.
create_connector_v1
(
config
=
self
.
vllm_config
,
role
=
KVConnectorRole
.
SCHEDULER
)
config
=
self
.
vllm_config
,
role
=
KVConnectorRole
.
SCHEDULER
,
dp_rank
=
self
.
parallel_config
.
data_parallel_rank
)
self
.
kv_event_publisher
=
EventPublisherFactory
.
create
(
self
.
kv_event_publisher
=
EventPublisherFactory
.
create
(
self
.
kv_events_config
,
self
.
kv_events_config
,
...
@@ -380,6 +381,10 @@ class Scheduler(SchedulerInterface):
...
@@ -380,6 +381,10 @@ class Scheduler(SchedulerInterface):
if
request
.
is_finished
():
if
request
.
is_finished
():
self
.
waiting
.
pop_request
()
self
.
waiting
.
pop_request
()
continue
continue
if
self
.
connector
and
not
self
.
connector
.
is_producer
and
request
.
request_id
not
in
self
.
finished_recving_kv_req_ids
:
self
.
waiting
.
pop_request
()
skipped_waiting_requests
.
prepend_request
(
request
)
continue
# KVTransfer: skip request if still waiting for remote kvs.
# KVTransfer: skip request if still waiting for remote kvs.
if
request
.
status
==
RequestStatus
.
WAITING_FOR_REMOTE_KVS
:
if
request
.
status
==
RequestStatus
.
WAITING_FOR_REMOTE_KVS
:
is_ready
=
self
.
_update_waiting_for_remote_kv
(
request
)
is_ready
=
self
.
_update_waiting_for_remote_kv
(
request
)
...
@@ -674,6 +679,11 @@ class Scheduler(SchedulerInterface):
...
@@ -674,6 +679,11 @@ class Scheduler(SchedulerInterface):
break
break
request
=
self
.
waiting
.
peek_request
()
request
=
self
.
waiting
.
peek_request
()
if
self
.
connector
and
not
self
.
connector
.
is_producer
and
request
.
request_id
not
in
self
.
finished_recving_kv_req_ids
:
self
.
waiting
.
pop_request
()
skipped_waiting_requests
.
prepend_request
(
request
)
continue
# KVTransfer: skip request if still waiting for remote kvs.
# KVTransfer: skip request if still waiting for remote kvs.
if
request
.
status
==
RequestStatus
.
WAITING_FOR_REMOTE_KVS
:
if
request
.
status
==
RequestStatus
.
WAITING_FOR_REMOTE_KVS
:
is_ready
=
self
.
_update_waiting_for_remote_kv
(
request
)
is_ready
=
self
.
_update_waiting_for_remote_kv
(
request
)
...
@@ -1326,7 +1336,7 @@ class Scheduler(SchedulerInterface):
...
@@ -1326,7 +1336,7 @@ class Scheduler(SchedulerInterface):
request
.
num_nans_in_logits
=
num_nans_in_logits
[
req_id
]
request
.
num_nans_in_logits
=
num_nans_in_logits
[
req_id
]
# Add newly generated spec token ids to the request.
# Add newly generated spec token ids to the request.
if
spec_token_ids
is
not
None
:
if
spec_token_ids
is
not
None
and
(
self
.
connector
is
None
or
not
self
.
connector
.
is_producer
)
:
if
self
.
structured_output_manager
.
should_advance
(
request
):
if
self
.
structured_output_manager
.
should_advance
(
request
):
metadata
=
request
.
structured_output_request
metadata
=
request
.
structured_output_request
# Needs to happen after new_token_ids are accepted.
# Needs to happen after new_token_ids are accepted.
...
...
vllm/v1/engine/core.py
View file @
d48d8d6d
...
@@ -763,6 +763,9 @@ class EngineCoreProc(EngineCore):
...
@@ -763,6 +763,9 @@ class EngineCoreProc(EngineCore):
# Push to input queue for core busy loop.
# Push to input queue for core busy loop.
self
.
input_queue
.
put_nowait
((
request_type
,
request
))
self
.
input_queue
.
put_nowait
((
request_type
,
request
))
if
isinstance
(
request
,
EngineCoreRequest
)
and
self
.
scheduler
.
connector
is
not
None
:
if
request_type
==
EngineCoreRequestType
.
ADD
:
self
.
scheduler
.
connector
.
register_req
(
request
.
request_id
)
def
process_output_sockets
(
self
,
output_paths
:
list
[
str
],
def
process_output_sockets
(
self
,
output_paths
:
list
[
str
],
coord_output_path
:
Optional
[
str
],
coord_output_path
:
Optional
[
str
],
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment