Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
1e013fa3
Unverified
Commit
1e013fa3
authored
Apr 22, 2025
by
Nick Hill
Committed by
GitHub
Apr 22, 2025
Browse files
[V1][DP] More robust DP/EP dummy request coordination (#16277)
Signed-off-by:
Nick Hill
<
nhill@redhat.com
>
parent
bc7c4d20
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
94 additions
and
57 deletions
+94
-57
tests/v1/test_async_llm_dp.py
tests/v1/test_async_llm_dp.py
+2
-2
vllm/v1/engine/__init__.py
vllm/v1/engine/__init__.py
+12
-3
vllm/v1/engine/core.py
vllm/v1/engine/core.py
+42
-21
vllm/v1/engine/core_client.py
vllm/v1/engine/core_client.py
+38
-31
No files found.
tests/v1/test_async_llm_dp.py
View file @
1e013fa3
...
...
@@ -101,9 +101,9 @@ async def test_load(output_kind: RequestOutputKind):
# the engines only synchronize stopping every N steps so
# allow a small amount of time here.
for
_
in
range
(
10
):
if
core_client
.
num_
engines_running
==
0
:
if
not
core_client
.
engines_running
:
break
await
asyncio
.
sleep
(
0.5
)
assert
core_client
.
num_
engines_running
==
0
assert
not
core_client
.
engines_running
assert
not
core_client
.
reqs_in_flight
vllm/v1/engine/__init__.py
View file @
1e013fa3
...
...
@@ -61,6 +61,11 @@ class EngineCoreRequest(
arrival_time
:
float
lora_request
:
Optional
[
LoRARequest
]
# Used in DP case to indicate which wave of requests this is expected to
# belong to, to cover a race condition where the request is sent before
# a wave finished notification is received.
current_wave
:
int
=
0
class
EngineCoreEventType
(
enum
.
IntEnum
):
"""The type of engine core request event."""
...
...
@@ -139,8 +144,12 @@ class EngineCoreOutputs(
utility_output
:
Optional
[
UtilityOutput
]
=
None
finished_requests
:
Optional
[
set
[
str
]]
=
None
# In DP case, used to signal that the engine is paused.
engine_paused
:
bool
=
False
# In DP case, used to signal that the current wave of requests
# has finished and the engines are paused.
wave_complete
:
Optional
[
int
]
=
None
# In DP case, used to signal that a request was received for an
# "old" wave, so the next wave needs to be started in other engines.
start_wave
:
Optional
[
int
]
=
None
def
__post_init__
(
self
):
if
self
.
timestamp
==
0.0
:
...
...
@@ -154,7 +163,7 @@ class EngineCoreRequestType(enum.Enum):
"""
ADD
=
b
'
\x00
'
ABORT
=
b
'
\x01
'
START_DP
=
b
'
\x02
'
START_DP
_WAVE
=
b
'
\x02
'
UTILITY
=
b
'
\x03
'
# Sentinel used within EngineCoreProc.
EXECUTOR_FAILED
=
b
'
\x04
'
vllm/v1/engine/core.py
View file @
1e013fa3
...
...
@@ -325,7 +325,7 @@ class EngineCoreProc(EngineCore):
self
.
step_fn
=
(
self
.
step
if
self
.
batch_queue
is
None
else
self
.
step_with_batch_queue
)
self
.
global_
un
f
in
ished_reqs
=
False
self
.
engines_r
un
n
in
g
=
False
# Background Threads and Queues for IO. These enable us to
# overlap ZMQ socket IO with GPU since they release the GIL,
...
...
@@ -410,8 +410,7 @@ class EngineCoreProc(EngineCore):
"""Exits when an engine step needs to be performed."""
waited
=
False
while
not
self
.
global_unfinished_reqs
and
not
(
self
.
scheduler
.
has_requests
()):
while
not
self
.
engines_running
and
not
(
self
.
scheduler
.
has_requests
()):
if
logger
.
isEnabledFor
(
DEBUG
)
and
self
.
input_queue
.
empty
():
logger
.
debug
(
"EngineCore waiting for work."
)
waited
=
True
...
...
@@ -419,10 +418,7 @@ class EngineCoreProc(EngineCore):
self
.
_handle_client_request
(
*
req
)
if
waited
:
logger
.
debug
(
"EngineCore loop active - local unfinished: %s, finished: %s."
,
self
.
scheduler
.
has_unfinished_requests
(),
self
.
scheduler
.
has_finished_requests
())
logger
.
debug
(
"EngineCore loop active."
)
# Handle any more client requests.
while
not
self
.
input_queue
.
empty
():
...
...
@@ -446,10 +442,6 @@ class EngineCoreProc(EngineCore):
self
.
add_request
(
request
)
elif
request_type
==
EngineCoreRequestType
.
ABORT
:
self
.
abort_requests
(
request
)
elif
request_type
==
EngineCoreRequestType
.
START_DP
:
if
not
self
.
global_unfinished_reqs
:
logger
.
debug
(
"EngineCore starting idle loop."
)
self
.
global_unfinished_reqs
=
True
elif
request_type
==
EngineCoreRequestType
.
UTILITY
:
call_id
,
method_name
,
args
=
request
output
=
UtilityOutput
(
call_id
)
...
...
@@ -548,9 +540,6 @@ class EngineCoreProc(EngineCore):
socket
.
send_multipart
(
buffers
,
copy
=
False
)
ENGINE_PAUSED_OUTPUTS
=
EngineCoreOutputs
(
engine_paused
=
True
)
class
DPEngineCoreProc
(
EngineCoreProc
):
"""ZMQ-wrapper for running EngineCore in background process
in a data parallel context."""
...
...
@@ -587,7 +576,9 @@ class DPEngineCoreProc(EngineCoreProc):
for
i
in
range
(
local_dp_rank
*
tp_size
,
(
local_dp_rank
+
1
)
*
tp_size
))
self
.
local_dp_rank
=
local_dp_rank
self
.
dp_group
=
vllm_config
.
parallel_config
.
stateless_init_dp_group
()
self
.
current_wave
=
0
# Initialize the engine after setting up environment.
super
().
__init__
(
input_path
,
output_path
,
vllm_config
,
executor_class
,
...
...
@@ -602,6 +593,31 @@ class DPEngineCoreProc(EngineCoreProc):
if
dp_group
:
=
getattr
(
self
,
"dp_group"
,
None
):
stateless_destroy_torch_distributed_process_group
(
dp_group
)
def
add_request
(
self
,
request
:
EngineCoreRequest
):
if
request
.
current_wave
!=
self
.
current_wave
:
if
request
.
current_wave
>
self
.
current_wave
:
self
.
current_wave
=
request
.
current_wave
elif
not
self
.
engines_running
:
# Request received for an already-completed wave, notify
# front-end that we need to start the next one.
self
.
output_queue
.
put_nowait
(
EngineCoreOutputs
(
start_wave
=
self
.
current_wave
))
super
().
add_request
(
request
)
def
_handle_client_request
(
self
,
request_type
:
EngineCoreRequestType
,
request
:
Any
)
->
None
:
if
request_type
==
EngineCoreRequestType
.
START_DP_WAVE
:
new_wave
:
int
=
request
if
new_wave
>=
self
.
current_wave
:
self
.
current_wave
=
new_wave
if
not
self
.
engines_running
:
logger
.
debug
(
"EngineCore starting idle loop for wave %d."
,
new_wave
)
self
.
engines_running
=
True
else
:
super
().
_handle_client_request
(
request_type
,
request
)
def
run_busy_loop
(
self
):
"""Core busy loop of the EngineCore for data parallel case."""
...
...
@@ -628,7 +644,7 @@ class DPEngineCoreProc(EngineCoreProc):
# up-to-date state is returned in the engine outputs.
self
.
_process_engine_step
()
if
not
self
.
global_
un
f
in
ished_reqs
:
if
not
self
.
engines_r
un
n
in
g
:
# All engines are idle.
continue
...
...
@@ -637,18 +653,23 @@ class DPEngineCoreProc(EngineCoreProc):
self
.
execute_dummy_batch
()
# 3) All-reduce operation to determine global unfinished reqs.
self
.
global_
un
f
in
ished_reqs
=
self
.
_has_global_unfinished_reqs
(
self
.
engines_r
un
n
in
g
=
self
.
_has_global_unfinished_reqs
(
local_unfinished_reqs
)
if
not
self
.
global_unfinished_reqs
:
# Notify client that we are pausing the loop.
self
.
output_queue
.
put_nowait
(
ENGINE_PAUSED_OUTPUTS
)
if
not
self
.
engines_running
:
if
self
.
local_dp_rank
==
0
:
# Notify client that we are pausing the loop.
logger
.
debug
(
"Wave %d finished, pausing engine loop."
,
self
.
current_wave
)
self
.
output_queue
.
put_nowait
(
EngineCoreOutputs
(
wave_complete
=
self
.
current_wave
))
self
.
current_wave
+=
1
def
_has_global_unfinished_reqs
(
self
,
local_unfinished
:
bool
)
->
bool
:
# Optimization - only perform finish-sync all-reduce every
16
steps.
# Optimization - only perform finish-sync all-reduce every
24
steps.
self
.
counter
+=
1
if
self
.
counter
!=
16
:
if
self
.
counter
!=
24
:
return
True
self
.
counter
=
0
...
...
vllm/v1/engine/core_client.py
View file @
1e013fa3
...
...
@@ -792,15 +792,12 @@ class DPAsyncMPClient(AsyncMPClient):
def
__init__
(
self
,
vllm_config
:
VllmConfig
,
executor_class
:
type
[
Executor
],
log_stats
:
bool
):
self
.
num_engines_running
=
0
self
.
current_wave
=
0
self
.
engines_running
=
False
self
.
reqs_in_flight
:
dict
[
str
,
CoreEngine
]
=
{}
super
().
__init__
(
vllm_config
,
executor_class
,
log_stats
)
# Control message used for triggering dp idle mode loop.
self
.
start_dp_msg
=
(
EngineCoreRequestType
.
START_DP
.
value
,
*
self
.
encoder
.
encode
(
None
))
assert
len
(
self
.
core_engines
)
>
1
def
_init_core_engines
(
...
...
@@ -829,23 +826,23 @@ class DPAsyncMPClient(AsyncMPClient):
# NOTE: text prompt is not needed in the core engine as it has been
# tokenized.
request
.
prompt
=
None
msg
=
(
EngineCoreRequestType
.
ADD
.
value
,
*
self
.
encoder
.
encode
(
request
))
request
.
current_wave
=
self
.
current_wave
chosen_engine
=
self
.
get_core_engine_for_request
()
self
.
reqs_in_flight
[
request
.
request_id
]
=
chosen_engine
chosen_engine
.
num_reqs_in_flight
+=
1
if
self
.
num_engines_running
>=
len
(
self
.
core_engines
):
await
self
.
_send_input_message
(
msg
,
chosen_engine
)
else
:
to_await
=
self
.
_send_input
(
EngineCoreRequestType
.
ADD
,
request
,
chosen_engine
)
if
not
self
.
engines_running
:
# Send request to chosen engine and dp start loop
# control message to all other engines.
self
.
num_
engines_running
+
=
len
(
self
.
core_engines
)
await
asyncio
.
gather
(
*
[
self
.
_send_input_message
(
msg
if
engine
is
chosen_engine
else
self
.
start_dp_msg
,
engine
)
for
engine
in
self
.
core_engines
])
self
.
engines_running
=
True
to_
await
=
asyncio
.
gather
(
to_await
,
# type: ignore[assignment]
*
self
.
_start_wave_coros
(
exclude_index
=
chosen_engine
.
index
))
await
to_await
self
.
_ensure_output_queue_task
()
...
...
@@ -860,21 +857,31 @@ class DPAsyncMPClient(AsyncMPClient):
if
engine
:
=
self
.
reqs_in_flight
.
pop
(
req_id
,
None
):
engine
.
num_reqs_in_flight
-=
1
if
outputs
.
engine_paused
:
assert
self
.
num_engines_running
>=
1
self
.
num_engines_running
-=
1
if
not
self
.
num_engines_running
and
self
.
reqs_in_flight
:
# If there are requests in flight here, they must have
# been sent after the engines paused. We must make
# sure to start the other engines:
self
.
num_engines_running
=
len
(
self
.
core_engines
)
coros
=
[
self
.
_send_input_message
(
self
.
start_dp_msg
,
engine
)
for
engine
in
self
.
core_engines
if
not
engine
.
num_reqs_in_flight
]
if
coros
:
await
asyncio
.
gather
(
*
coros
)
if
outputs
.
wave_complete
is
not
None
:
# Current wave is complete, move to next wave number
# and mark engines as paused.
if
self
.
current_wave
<=
outputs
.
wave_complete
:
self
.
current_wave
=
outputs
.
wave_complete
+
1
self
.
engines_running
=
False
elif
outputs
.
start_wave
is
not
None
and
(
outputs
.
start_wave
>
self
.
current_wave
or
(
outputs
.
start_wave
==
self
.
current_wave
and
not
self
.
engines_running
)):
# Engine received request for a non-current wave so we must ensure
# that other engines progress to the next wave.
self
.
current_wave
=
outputs
.
start_wave
self
.
engines_running
=
True
await
asyncio
.
gather
(
*
self
.
_start_wave_coros
(
exclude_index
=
outputs
.
engine_index
))
def
_start_wave_coros
(
self
,
exclude_index
:
int
)
->
list
[
Awaitable
[
None
]]:
logger
.
debug
(
"Sending start DP wave %d."
,
self
.
current_wave
)
return
[
self
.
_send_input
(
EngineCoreRequestType
.
START_DP_WAVE
,
self
.
current_wave
,
engine
)
for
engine
in
self
.
core_engines
if
engine
.
index
!=
exclude_index
]
async
def
abort_requests_async
(
self
,
request_ids
:
list
[
str
])
->
None
:
if
not
request_ids
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment