Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
dbf0da81
Unverified
Commit
dbf0da81
authored
Feb 24, 2026
by
Nick Hill
Committed by
GitHub
Feb 24, 2026
Browse files
[Core] Cleanup engine pause/sleep logic (#34528)
Signed-off-by:
Nick Hill
<
nickhill123@gmail.com
>
parent
3bbb2046
Changes
10
Hide whitespace changes
Inline
Side-by-side
Showing
10 changed files
with
302 additions
and
198 deletions
+302
-198
tests/v1/distributed/test_async_llm_dp.py
tests/v1/distributed/test_async_llm_dp.py
+120
-46
tests/v1/engine/test_engine_core_client.py
tests/v1/engine/test_engine_core_client.py
+7
-12
vllm/engine/protocol.py
vllm/engine/protocol.py
+1
-1
vllm/entrypoints/llm.py
vllm/entrypoints/llm.py
+41
-47
vllm/entrypoints/serve/sleep/api_router.py
vllm/entrypoints/serve/sleep/api_router.py
+2
-1
vllm/v1/engine/async_llm.py
vllm/v1/engine/async_llm.py
+9
-4
vllm/v1/engine/core.py
vllm/v1/engine/core.py
+108
-62
vllm/v1/engine/core_client.py
vllm/v1/engine/core_client.py
+11
-8
vllm/v1/engine/llm_engine.py
vllm/v1/engine/llm_engine.py
+3
-3
vllm/v1/engine/output_processor.py
vllm/v1/engine/output_processor.py
+0
-14
No files found.
tests/v1/distributed/test_async_llm_dp.py
View file @
dbf0da81
...
...
@@ -3,8 +3,10 @@
import
asyncio
import
os
import
time
from
contextlib
import
ExitStack
from
dataclasses
import
dataclass
from
typing
import
Any
import
pytest
...
...
@@ -187,24 +189,33 @@ async def test_load(
# =============================================================================
# DP Pause/Resume Tests
# =============================================================================
# When expert_parallel=False: uses non-MoE model (DP replicas as separate engines).
# When expert_parallel=True: uses MoE model + EP (DPEngineCoreProc, sync pause path).
DP_PAUSE_MODEL
=
"hmellor/tiny-random-LlamaForCausalLM"
DP_PAUSE_MODEL_MOE
=
"ibm-research/PowerMoE-3b"
DP_PAUSE_PROMPT
=
"This is a test of data parallel pause"
def
_get_dp_pause_engine_args
(
expert_parallel
:
bool
)
->
AsyncEngineArgs
:
"""Engine args for DP pause tests: MoE+EP when expert_parallel else small Llama."""
model
=
DP_PAUSE_MODEL_MOE
if
expert_parallel
else
DP_PAUSE_MODEL
return
AsyncEngineArgs
(
model
=
model
,
enforce_eager
=
True
,
tensor_parallel_size
=
int
(
os
.
getenv
(
"TP_SIZE"
,
1
)),
data_parallel_size
=
DP_SIZE
,
data_parallel_backend
=
"mp"
,
enable_expert_parallel
=
expert_parallel
,
)
@
pytest
.
mark
.
asyncio
async
def
test_dp_pause_resume_basic
():
@
pytest
.
mark
.
parametrize
(
"expert_parallel"
,
[
False
,
True
])
async
def
test_dp_pause_resume_basic
(
expert_parallel
:
bool
):
"""Pausing from the client (one call) pauses all DP ranks; resume clears it."""
if
current_platform
.
is_rocm
():
pytest
.
skip
(
"DP pause tests use mp backend only"
)
with
ExitStack
()
as
after
:
engine_args
=
AsyncEngineArgs
(
model
=
DP_PAUSE_MODEL
,
enforce_eager
=
True
,
tensor_parallel_size
=
int
(
os
.
getenv
(
"TP_SIZE"
,
1
)),
data_parallel_size
=
DP_SIZE
,
data_parallel_backend
=
"mp"
,
)
engine_args
=
_get_dp_pause_engine_args
(
expert_parallel
)
engine
=
AsyncLLM
.
from_engine_args
(
engine_args
)
after
.
callback
(
engine
.
shutdown
)
...
...
@@ -226,18 +237,11 @@ async def test_dp_pause_resume_basic():
@
pytest
.
mark
.
asyncio
async
def
test_dp_pause_abort
():
@
pytest
.
mark
.
parametrize
(
"expert_parallel"
,
[
False
,
True
])
async
def
test_dp_pause_abort
(
expert_parallel
:
bool
):
"""Pause with abort from one client aborts in-flight requests on all DP ranks."""
if
current_platform
.
is_rocm
():
pytest
.
skip
(
"DP pause tests use mp backend only"
)
with
ExitStack
()
as
after
:
engine_args
=
AsyncEngineArgs
(
model
=
DP_PAUSE_MODEL
,
enforce_eager
=
True
,
tensor_parallel_size
=
int
(
os
.
getenv
(
"TP_SIZE"
,
1
)),
data_parallel_size
=
DP_SIZE
,
data_parallel_backend
=
"mp"
,
)
engine_args
=
_get_dp_pause_engine_args
(
expert_parallel
)
engine
=
AsyncLLM
.
from_engine_args
(
engine_args
)
after
.
callback
(
engine
.
shutdown
)
...
...
@@ -286,41 +290,111 @@ async def test_dp_pause_abort():
@
pytest
.
mark
.
asyncio
async
def
test_dp_pause_keep_then_resume
():
"""Pause with keep queues new requests; resume allows them to run."""
if
current_platform
.
is_rocm
():
pytest
.
skip
(
"DP pause tests use mp backend only"
)
@
pytest
.
mark
.
parametrize
(
"expert_parallel"
,
[
False
,
True
])
async
def
test_dp_pause_keep_then_resume
(
expert_parallel
:
bool
):
"""Start generation, pause after a few tokens (keep mode), resume; verify gap."""
pause_duration
=
2.0
min_tokens_before_pause
=
3
with
ExitStack
()
as
after
:
engine_args
=
AsyncEngineArgs
(
model
=
DP_PAUSE_MODEL
,
enforce_eager
=
True
,
tensor_parallel_size
=
int
(
os
.
getenv
(
"TP_SIZE"
,
1
)),
data_parallel_size
=
DP_SIZE
,
data_parallel_backend
=
"mp"
,
)
engine_args
=
_get_dp_pause_engine_args
(
expert_parallel
)
engine
=
AsyncLLM
.
from_engine_args
(
engine_args
)
after
.
callback
(
engine
.
shutdown
)
await
engine
.
pause_generation
(
mode
=
"keep"
)
assert
await
engine
.
is_paused
()
request_done
=
asyncio
.
Event
()
sampling_params
=
SamplingParams
(
max_tokens
=
15
,
ignore_eos
=
True
)
token_times
:
list
[
tuple
[
int
,
float
]]
=
[]
pause_token_idx
=
0
async
def
gen
():
async
for
out
in
engine
.
generate
(
request_id
=
"queued-keep"
,
async
def
generator_task
():
nonlocal
pause_token_idx
out
=
None
async
for
output
in
engine
.
generate
(
request_id
=
"keep-resume-req"
,
prompt
=
DP_PAUSE_PROMPT
,
sampling_params
=
S
ampling
P
arams
(
max_tokens
=
5
)
,
sampling_params
=
s
ampling
_p
arams
,
):
pass
request_done
.
set
()
token_count
=
len
(
output
.
outputs
[
0
].
token_ids
)
token_times
.
append
((
token_count
,
time
.
monotonic
()))
out
=
output
return
out
task
=
asyncio
.
create_task
(
gen
())
await
asyncio
.
sleep
(
0.2
)
assert
not
request_done
.
is_set
()
async
def
controller_task
():
nonlocal
pause_token_idx
while
len
(
token_times
)
<
min_tokens_before_pause
:
await
asyncio
.
sleep
(
0.01
)
await
engine
.
pause_generation
(
mode
=
"keep"
)
await
asyncio
.
sleep
(
pause_duration
)
pause_token_idx
=
len
(
token_times
)
await
engine
.
resume_generation
()
gen_task
=
asyncio
.
create_task
(
generator_task
())
ctrl_task
=
asyncio
.
create_task
(
controller_task
())
final_output
,
_
=
await
asyncio
.
gather
(
gen_task
,
ctrl_task
)
assert
final_output
is
not
None
and
final_output
.
finished
assert
await
engine
.
is_paused
()
is
False
assert
pause_token_idx
>=
min_tokens_before_pause
if
pause_token_idx
>
0
and
pause_token_idx
<
len
(
token_times
):
pause_gap
=
(
token_times
[
pause_token_idx
][
1
]
-
token_times
[
pause_token_idx
-
1
][
1
]
)
assert
pause_gap
>=
pause_duration
*
0.8
,
(
f
"Expected gap ~
{
pause_duration
}
s after pause, got
{
pause_gap
:.
3
f
}
s"
)
@
pytest
.
mark
.
asyncio
async
def
test_dp_pause_keep_race_staggered_engines
():
"""Race: send pause(keep) to engine 0, then add two requests,
then pause(keep) to engine 1. Ensures no deadlock when pause
requests are staggered and requests arrive in between."""
if
DP_SIZE
!=
2
:
pytest
.
skip
(
"test_dp_pause_keep_race_staggered_engines requires DP_SIZE=2"
)
with
ExitStack
()
as
after
:
engine_args
=
_get_dp_pause_engine_args
(
expert_parallel
=
True
)
engine
=
AsyncLLM
.
from_engine_args
(
engine_args
)
after
.
callback
(
engine
.
shutdown
)
client
=
engine
.
engine_core
original_call_utility
=
client
.
call_utility_async
mid_pause_tasks
:
list
[
asyncio
.
Task
]
=
[]
async
def
staggered_pause_keep
(
method
:
str
,
*
args
)
->
Any
:
if
method
!=
"pause_scheduler"
or
not
args
or
args
[
0
]
!=
"keep"
:
return
await
original_call_utility
(
method
,
*
args
)
# Send pause(keep) to engine 0 first
await
client
.
_call_utility_async
(
method
,
*
args
,
engine
=
client
.
core_engines
[
0
]
)
# In the middle: send two requests (race window)
sp
=
SamplingParams
(
max_tokens
=
5
,
ignore_eos
=
True
)
async
def
consume_gen
(
req_id
:
str
)
->
None
:
async
for
_
in
engine
.
generate
(
request_id
=
req_id
,
prompt
=
DP_PAUSE_PROMPT
,
sampling_params
=
sp
,
):
pass
t1
=
asyncio
.
create_task
(
consume_gen
(
"race-1"
))
t2
=
asyncio
.
create_task
(
consume_gen
(
"race-2"
))
mid_pause_tasks
.
extend
([
t1
,
t2
])
await
asyncio
.
sleep
(
3
)
# Then send pause(keep) to engine 1
result
=
await
client
.
_call_utility_async
(
method
,
*
args
,
engine
=
client
.
core_engines
[
1
]
)
return
result
client
.
call_utility_async
=
staggered_pause_keep
await
engine
.
pause_generation
(
mode
=
"keep"
)
assert
await
engine
.
is_paused
()
await
engine
.
resume_generation
()
final
=
await
asyncio
.
wait_for
(
task
,
timeout
=
10.0
)
assert
final
.
finished
assert
not
await
engine
.
is_paused
()
# Let the two requests we sent mid-pause complete
await
asyncio
.
gather
(
*
mid_pause_tasks
)
tests/v1/engine/test_engine_core_client.py
View file @
dbf0da81
...
...
@@ -280,20 +280,15 @@ def echo_dc_nested(
def
future_echo
(
self
,
value
:
Any
,
num_wait_loops
:
int
=
2
)
->
Future
:
"""Utility that returns a Future completed
by a per_step_hook after
num_wait_loops engine steps
(tests deferred utility path).
"""Utility that returns a Future completed
once the engine is idle
(tests deferred utility path).
"""
future
:
Future
=
Future
()
remaining
=
[
num_wait_loops
]
def
_step
(
engine
:
EngineCore
)
->
bool
:
remaining
[
0
]
-=
1
if
remaining
[
0
]
<=
0
:
future
.
set_result
(
value
)
return
True
# remove hook
return
False
def
idle
(
engine
:
EngineCore
):
future
.
set_result
(
value
)
self
.
per_step_hooks
.
add
(
_step
)
self
.
_idle_state_callbacks
.
append
(
idle
)
return
future
...
...
@@ -832,8 +827,8 @@ async def test_engine_core_client_future_utility_async(
monkeypatch
:
pytest
.
MonkeyPatch
,
subprocess_future_echo_patch
,
):
"""Test that a utility returning a Future
(
complete
d by a per_step_hook
after N steps) completes when the future is done
(engine uses add_done_callback).
"""Test that a utility returning a Future complete
s when the future is done
(engine uses add_done_callback).
"""
with
monkeypatch
.
context
()
as
m
:
m
.
setattr
(
EngineCore
,
"future_echo"
,
future_echo
,
raising
=
False
)
...
...
vllm/engine/protocol.py
View file @
dbf0da81
...
...
@@ -148,7 +148,7 @@ class EngineClient(ABC):
...
@
abstractmethod
async
def
sleep
(
self
,
level
:
int
=
1
)
->
None
:
async
def
sleep
(
self
,
level
:
int
=
1
,
mode
:
"PauseMode"
=
"abort"
)
->
None
:
"""Sleep the engine"""
...
...
...
vllm/entrypoints/llm.py
View file @
dbf0da81
...
...
@@ -87,6 +87,7 @@ from vllm.usage.usage_lib import UsageContext
from
vllm.utils.counter
import
Counter
from
vllm.utils.mistral
import
is_mistral_tokenizer
from
vllm.utils.tqdm_utils
import
maybe_tqdm
from
vllm.v1.engine
import
PauseMode
from
vllm.v1.engine.llm_engine
import
LLMEngine
from
vllm.v1.sample.logits_processor
import
LogitsProcessor
...
...
@@ -441,8 +442,7 @@ class LLM:
A list of `RequestOutput` objects containing the
generated completions in the same order as the input prompts.
"""
model_config
=
self
.
model_config
runner_type
=
model_config
.
runner_type
runner_type
=
self
.
model_config
.
runner_type
if
runner_type
!=
"generate"
:
raise
ValueError
(
"LLM.generate() is only supported for generative models. "
...
...
@@ -489,46 +489,22 @@ class LLM:
Returns:
A list of request IDs for the enqueued requests.
"""
model_config
=
self
.
model_config
runner_type
=
model_config
.
runner_type
runner_type
=
self
.
model_config
.
runner_type
if
runner_type
!=
"generate"
:
raise
ValueError
(
"LLM.enqueue() is only supported for generative models."
)
if
sampling_params
is
None
:
sampling_params
=
self
.
get_default_sampling_params
()
# Use the same preprocessing as _run_completion
seq_prompts
=
prompt_to_seq
(
prompts
)
seq_params
=
self
.
_params_to_seq
(
sampling_params
,
len
(
seq_prompts
))
seq_lora_requests
=
self
.
_lora_request_to_seq
(
lora_request
,
len
(
seq_prompts
))
seq_tok_kwargs
=
[
merge_kwargs
(
tokenization_kwargs
,
dict
(
truncate_prompt_tokens
=
param
.
truncate_prompt_tokens
),
)
for
param
in
seq_params
]
seq_priority
=
self
.
_priority_to_seq
(
priority
,
len
(
prompts
))
request_ids
=
self
.
_render_and_add_requests
(
prompts
=
(
self
.
_preprocess_cmpl_one
(
prompt
,
tok_kwargs
)
for
prompt
,
tok_kwargs
in
zip
(
maybe_tqdm
(
seq_prompts
,
use_tqdm
=
use_tqdm
,
desc
=
"Rendering prompts"
,
),
seq_tok_kwargs
,
)
),
params
=
seq_params
,
lora_requests
=
seq_lora_requests
,
priorities
=
seq_priority
,
return
self
.
_add_completion_requests
(
prompts
=
prompts
,
params
=
sampling_params
,
use_tqdm
=
use_tqdm
,
lora_request
=
lora_request
,
priority
=
priority
,
tokenization_kwargs
=
tokenization_kwargs
,
)
return
request_ids
@
overload
def
wait_for_completion
(
self
,
...
...
@@ -1659,7 +1635,7 @@ class LLM:
reset_running_requests
,
reset_connector
)
def
sleep
(
self
,
level
:
int
=
1
):
def
sleep
(
self
,
level
:
int
=
1
,
mode
:
PauseMode
=
"abort"
):
"""
Put the engine to sleep. The engine should not process any requests.
The caller should guarantee that no requests are being processed
...
...
@@ -1679,10 +1655,10 @@ class LLM:
a different model or update the model, where
previous model weights are not needed. It reduces
CPU memory pressure.
mode: How to handle any existing requests, can be "abort", "wait",
or "keep".
"""
if
level
>
0
:
self
.
reset_prefix_cache
()
self
.
llm_engine
.
sleep
(
level
=
level
)
self
.
llm_engine
.
sleep
(
level
=
level
,
mode
=
mode
)
def
wake_up
(
self
,
tags
:
list
[
str
]
|
None
=
None
):
"""
...
...
@@ -1759,19 +1735,18 @@ class LLM:
return
[
0
]
*
num_requests
def
_
run
_completion
(
def
_
add
_completion
_requests
(
self
,
prompts
:
PromptType
|
Sequence
[
PromptType
],
params
:
SamplingParams
|
PoolingParams
|
Sequence
[
SamplingParams
|
PoolingParams
],
output_type
:
type
[
_O
],
*
,
use_tqdm
:
bool
|
Callable
[...,
tqdm
]
=
True
,
lora_request
:
Sequence
[
LoRARequest
]
|
LoRARequest
|
None
=
None
,
priority
:
list
[
int
]
|
None
=
None
,
tokenization_kwargs
:
dict
[
str
,
Any
]
|
None
=
None
,
):
)
->
list
[
str
]
:
seq_prompts
=
prompt_to_seq
(
prompts
)
seq_params
=
self
.
_params_to_seq
(
params
,
len
(
seq_prompts
))
seq_lora_requests
=
self
.
_lora_request_to_seq
(
lora_request
,
len
(
seq_prompts
))
...
...
@@ -1784,25 +1759,44 @@ class LLM:
]
seq_priority
=
self
.
_priority_to_seq
(
priority
,
len
(
prompts
))
return
self
.
_render_and_
run
_requests
(
return
self
.
_render_and_
add
_requests
(
prompts
=
(
self
.
_preprocess_cmpl_one
(
prompt
,
tok_kwargs
)
for
prompt
,
tok_kwargs
in
zip
(
maybe_tqdm
(
seq_prompts
,
use_tqdm
=
use_tqdm
,
desc
=
"Rendering prompts"
,
seq_prompts
,
use_tqdm
=
use_tqdm
,
desc
=
"Rendering prompts"
),
seq_tok_kwargs
,
)
),
params
=
seq_params
,
output_type
=
output_type
,
use_tqdm
=
use_tqdm
,
lora_requests
=
seq_lora_requests
,
priorities
=
seq_priority
,
)
def
_run_completion
(
self
,
prompts
:
PromptType
|
Sequence
[
PromptType
],
params
:
SamplingParams
|
PoolingParams
|
Sequence
[
SamplingParams
|
PoolingParams
],
output_type
:
type
[
_O
],
*
,
use_tqdm
:
bool
|
Callable
[...,
tqdm
]
=
True
,
lora_request
:
Sequence
[
LoRARequest
]
|
LoRARequest
|
None
=
None
,
priority
:
list
[
int
]
|
None
=
None
,
tokenization_kwargs
:
dict
[
str
,
Any
]
|
None
=
None
,
):
self
.
_add_completion_requests
(
prompts
=
prompts
,
params
=
params
,
use_tqdm
=
use_tqdm
,
lora_request
=
lora_request
,
priority
=
priority
,
tokenization_kwargs
=
tokenization_kwargs
,
)
return
self
.
_run_engine
(
use_tqdm
=
use_tqdm
,
output_type
=
output_type
)
def
_run_chat
(
self
,
messages
:
list
[
ChatCompletionMessageParam
]
...
...
vllm/entrypoints/serve/sleep/api_router.py
View file @
dbf0da81
...
...
@@ -23,7 +23,8 @@ router = APIRouter()
async
def
sleep
(
raw_request
:
Request
):
# get POST params
level
=
raw_request
.
query_params
.
get
(
"level"
,
"1"
)
await
engine_client
(
raw_request
).
sleep
(
int
(
level
))
mode
=
raw_request
.
query_params
.
get
(
"mode"
,
"abort"
)
await
engine_client
(
raw_request
).
sleep
(
int
(
level
),
mode
)
# FIXME: in v0 with frontend multiprocessing, the sleep command
# is sent but does not finish yet when we return a response.
return
Response
(
status_code
=
200
)
...
...
vllm/v1/engine/async_llm.py
View file @
dbf0da81
...
...
@@ -753,6 +753,13 @@ class AsyncLLM(EngineClient):
)
mode
=
"wait"
await
self
.
engine_core
.
pause_scheduler_async
(
mode
=
mode
,
clear_cache
=
clear_cache
)
# Small sleep to help ensure that final outputs from any in-flight requests are
# returned prior to this method returning. These outputs come out of the engine
# prior to the wait-for-idle completion event, but involve additional async
# tasks in output processing.
# Note that this is not required for correctness, just more intuitive ordering
# of events from caller's pov.
await
asyncio
.
sleep
(
0.02
)
async
def
resume_generation
(
self
)
->
None
:
"""Resume generation after :meth:`pause_generation`."""
...
...
@@ -890,10 +897,8 @@ class AsyncLLM(EngineClient):
async
def
reset_encoder_cache
(
self
)
->
None
:
await
self
.
engine_core
.
reset_encoder_cache_async
()
async
def
sleep
(
self
,
level
:
int
=
1
)
->
None
:
if
level
>
0
:
await
self
.
reset_prefix_cache
()
await
self
.
engine_core
.
sleep_async
(
level
)
async
def
sleep
(
self
,
level
:
int
=
1
,
mode
:
PauseMode
=
"abort"
)
->
None
:
await
self
.
engine_core
.
sleep_async
(
level
,
mode
)
if
self
.
logger_manager
is
not
None
:
self
.
logger_manager
.
record_sleep_state
(
1
,
level
)
...
...
vllm/v1/engine/core.py
View file @
dbf0da81
...
...
@@ -9,6 +9,7 @@ from collections import defaultdict, deque
from
collections.abc
import
Callable
,
Generator
from
concurrent.futures
import
Future
from
contextlib
import
ExitStack
,
contextmanager
from
functools
import
partial
from
inspect
import
isclass
,
signature
from
logging
import
DEBUG
from
typing
import
Any
,
TypeVar
,
cast
...
...
@@ -211,7 +212,7 @@ class EngineCore:
self
.
aborts_queue
=
queue
.
Queue
[
list
[
str
]]()
self
.
per_step_hoo
ks
:
s
e
t
[
Callable
]
=
set
()
self
.
_idle_state_callbac
ks
:
li
st
[
Callable
]
=
[]
# Mark the startup heap as static so that it's ignored by GC.
# Reduces pause times of oldest generation collections.
...
...
@@ -592,21 +593,51 @@ class EngineCore:
# Reset the GPU model runner's encoder cache (physical storage)
self
.
model_executor
.
reset_encoder_cache
()
def
_reset_caches
(
self
,
reset_running_requests
=
True
)
->
None
:
self
.
reset_prefix_cache
(
reset_running_requests
=
reset_running_requests
)
self
.
reset_mm_cache
()
self
.
reset_encoder_cache
()
def
pause_scheduler
(
self
,
mode
:
PauseMode
=
"abort"
,
clear_cache
:
bool
=
True
)
->
Future
[
Any
]
|
None
:
"""Pause scheduling. No-op in base EngineCore; overridden in EngineCoreProc."""
)
->
Future
|
None
:
"""Pause generation; behavior depends on mode.
All pause modes queue new adds -- "abort" and "keep" skip step();
"wait" allows step() so in-flight requests can drain.
- ``abort``: Set PAUSED_NEW, abort all requests, wait for abort
outputs to be sent (when running with output_queue), optionally
clear caches, then complete the returned Future.
- ``wait``: Set PAUSED_NEW (queue adds, keep stepping); when drained,
optionally clear caches, then complete the returned Future.
- ``keep``: Set PAUSED_ALL; return a Future that completes when the
output queue is empty.
"""
if
mode
not
in
(
"keep"
,
"abort"
,
"wait"
):
raise
ValueError
(
f
"Invalid pause mode:
{
mode
}
"
)
if
mode
==
"wait"
:
raise
ValueError
(
"'wait' mode can't be used in inproc-engine mode"
)
if
mode
==
"abort"
:
self
.
scheduler
.
finish_requests
(
None
,
RequestStatus
.
FINISHED_ABORTED
)
pause_state
=
PauseState
.
PAUSED_ALL
if
mode
==
"keep"
else
PauseState
.
PAUSED_NEW
self
.
scheduler
.
set_pause_state
(
pause_state
)
if
clear_cache
:
self
.
_reset_caches
()
return
None
def
resume_scheduler
(
self
)
->
None
:
"""Resume scheduling. No-op in base EngineCore; overridden in EngineCoreProc."""
"""Resume the scheduler and flush any requests queued while paused."""
self
.
scheduler
.
set_pause_state
(
PauseState
.
UNPAUSED
)
def
is_scheduler_paused
(
self
)
->
bool
:
"""Return whether the scheduler is in any pause state. False in base EngineCore
and overridden in EngineCoreProc."""
return
False
"""Return whether the scheduler is in any pause state."""
return
self
.
scheduler
.
pause_state
!=
PauseState
.
UNPAUSED
def
sleep
(
self
,
level
:
int
=
1
)
:
def
sleep
(
self
,
level
:
int
=
1
,
mode
:
PauseMode
=
"abort"
)
->
None
|
Future
:
"""Put the engine to sleep at the specified level.
Args:
...
...
@@ -615,13 +646,34 @@ class EngineCore:
but not processed. No GPU memory changes.
- Level 1: Offload model weights to CPU, discard KV cache.
- Level 2: Discard all GPU memory.
mode: Pause mode - how to deal with any existing requests, see
documentation of pause_scheduler method.
"""
if
level
==
0
:
# Level 0: Just pause scheduling, don't touch GPU
self
.
pause_scheduler
()
else
:
# Level 1+: Delegate to executor for GPU memory management
self
.
model_executor
.
sleep
(
level
)
# Pause scheduler before sleeping.
clear_prefix_cache
=
level
>=
1
pause_future
=
self
.
pause_scheduler
(
mode
=
mode
,
clear_cache
=
clear_prefix_cache
)
if
level
<
1
:
return
pause_future
# Level 1+: Delegate to executor for GPU memory management
model_executor
=
self
.
model_executor
if
pause_future
is
None
:
model_executor
.
sleep
(
level
)
return
None
future
=
Future
[
Any
]()
def
pause_complete
(
f
:
Future
):
try
:
f
.
result
()
# propagate any exception
future
.
set_result
(
model_executor
.
sleep
(
level
))
except
Exception
as
e
:
future
.
set_exception
(
e
)
logger
.
info
(
"Waiting for in-flight requests to complete before sleeping..."
)
pause_future
.
add_done_callback
(
pause_complete
)
return
future
def
wake_up
(
self
,
tags
:
list
[
str
]
|
None
=
None
):
"""Wake up the engine from sleep.
...
...
@@ -630,17 +682,15 @@ class EngineCore:
tags: Tags to wake up. Use ["scheduling"] for level 0 wake up.
"""
if
tags
is
not
None
and
"scheduling"
in
tags
:
# Level 0 wake up: Resume scheduling
self
.
resume_scheduler
()
# Remove "scheduling" from tags if there are other tags to process
remaining_tags
=
[
t
for
t
in
tags
if
t
!=
"scheduling"
]
if
remaining_tags
:
self
.
model_executor
.
wake_up
(
remaining_tags
)
else
:
# Full wake up
self
.
resume_scheduler
()
# Remove "scheduling" from tags if there are other tags to process.
tags
=
[
t
for
t
in
tags
if
t
!=
"scheduling"
]
if
tags
is
None
or
tags
:
self
.
model_executor
.
wake_up
(
tags
)
# Resume scheduling (applies to all levels)
self
.
resume_scheduler
()
def
is_sleeping
(
self
)
->
bool
:
"""Check if engine is sleeping at any level."""
return
self
.
is_scheduler_paused
()
or
self
.
model_executor
.
is_sleeping
...
...
@@ -1038,6 +1088,14 @@ class EngineCoreProc(EngineCore):
def
_init_data_parallel
(
self
,
vllm_config
:
VllmConfig
):
pass
def
has_work
(
self
)
->
bool
:
"""Returns true if the engine should be stepped."""
return
(
self
.
engines_running
or
self
.
scheduler
.
has_requests
()
or
bool
(
self
.
batch_queue
)
)
def
run_busy_loop
(
self
):
"""Core busy loop of the EngineCore."""
...
...
@@ -1047,19 +1105,14 @@ class EngineCoreProc(EngineCore):
self
.
_process_input_queue
()
# 2) Step the engine core and return the outputs.
self
.
_process_engine_step
()
# 3) Run any per-step hooks.
self
.
_process_per_step_hooks
()
def
_process_input_queue
(
self
):
"""Exits when an engine step needs to be performed."""
waited
=
False
while
(
not
self
.
engines_running
and
not
self
.
scheduler
.
has_requests
()
and
not
self
.
batch_queue
and
not
self
.
per_step_hooks
):
while
not
self
.
has_work
():
# Notify callbacks waiting for engine to become idle.
self
.
_notify_idle_state_callbacks
()
if
self
.
input_queue
.
empty
():
# Drain aborts queue; all aborts are also processed via input_queue.
with
self
.
aborts_queue
.
mutex
:
...
...
@@ -1098,12 +1151,10 @@ class EngineCoreProc(EngineCore):
return
model_executed
def
_process_per_step_hooks
(
self
)
->
None
:
if
self
.
per_step_hooks
:
for
hook
in
list
(
self
.
per_step_hooks
):
finished
=
hook
(
self
)
if
finished
:
self
.
per_step_hooks
.
discard
(
hook
)
def
_notify_idle_state_callbacks
(
self
)
->
None
:
while
self
.
_idle_state_callbacks
:
callback
=
self
.
_idle_state_callbacks
.
pop
()
callback
(
self
)
def
_handle_client_request
(
self
,
request_type
:
EngineCoreRequestType
,
request
:
Any
...
...
@@ -1377,19 +1428,10 @@ class EngineCoreProc(EngineCore):
if
mode
not
in
(
"keep"
,
"abort"
,
"wait"
):
raise
ValueError
(
f
"Invalid pause mode:
{
mode
}
"
)
future
:
Future
[
Any
]
=
Future
()
def
wait_until_idle
(
engine
:
"EngineCoreProc"
)
->
bool
:
scheduler
=
engine
.
scheduler
out_queue
=
engine
.
output_queue
if
scheduler
.
has_requests
()
or
engine
.
batch_queue
or
not
out_queue
.
empty
():
return
False
def
engine_idle_callback
(
engine
:
"EngineCoreProc"
,
future
:
Future
[
Any
])
->
None
:
if
clear_cache
:
engine
.
reset_prefix_cache
(
reset_running_requests
=
True
)
engine
.
reset_mm_cache
()
engine
.
reset_encoder_cache
()
engine
.
_reset_caches
()
future
.
set_result
(
None
)
return
True
if
mode
==
"abort"
:
aborted_reqs
=
self
.
scheduler
.
finish_requests
(
...
...
@@ -1399,12 +1441,17 @@ class EngineCoreProc(EngineCore):
pause_state
=
PauseState
.
PAUSED_ALL
if
mode
==
"keep"
else
PauseState
.
PAUSED_NEW
self
.
scheduler
.
set_pause_state
(
pause_state
)
if
not
wait_until_idle
(
self
):
self
.
per_step_hooks
.
add
(
wait_until_idle
)
return
future
return
None
if
not
self
.
has_work
():
if
clear_cache
:
self
.
_reset_caches
()
return
None
future
=
Future
[
Any
]()
self
.
_idle_state_callbacks
.
append
(
partial
(
engine_idle_callback
,
future
=
future
))
return
future
def
_send_abort_outputs
(
self
,
aborted_reqs
:
list
[
tuple
[
str
,
int
]])
->
None
:
# TODO(nick) this will be moved inside the scheduler
if
aborted_reqs
:
# Map client_index to list of request_ids that belong to that client.
by_client
=
defaultdict
[
int
,
set
[
str
]](
set
)
...
...
@@ -1418,14 +1465,6 @@ class EngineCoreProc(EngineCore):
eco
=
EngineCoreOutputs
(
finished_requests
=
req_ids
,
outputs
=
outputs
)
self
.
output_queue
.
put_nowait
((
client_index
,
eco
))
def
resume_scheduler
(
self
)
->
None
:
"""Resume the scheduler and flush any requests queued while paused."""
self
.
scheduler
.
set_pause_state
(
PauseState
.
UNPAUSED
)
def
is_scheduler_paused
(
self
)
->
bool
:
"""Return whether the scheduler is in any pause state."""
return
self
.
scheduler
.
pause_state
!=
PauseState
.
UNPAUSED
class
DPEngineCoreProc
(
EngineCoreProc
):
"""ZMQ-wrapper for running EngineCore in background process
...
...
@@ -1481,6 +1520,7 @@ class DPEngineCoreProc(EngineCoreProc):
stateless_destroy_torch_distributed_process_group
(
dp_group
)
def
add_request
(
self
,
request
:
Request
,
request_wave
:
int
=
0
):
super
().
add_request
(
request
,
request_wave
)
if
self
.
has_coordinator
and
request_wave
!=
self
.
current_wave
:
if
request_wave
>
self
.
current_wave
:
self
.
current_wave
=
request_wave
...
...
@@ -1491,7 +1531,13 @@ class DPEngineCoreProc(EngineCoreProc):
(
-
1
,
EngineCoreOutputs
(
start_wave
=
self
.
current_wave
))
)
super
().
add_request
(
request
,
request_wave
)
def
resume_scheduler
(
self
):
super
().
resume_scheduler
()
if
not
self
.
engines_running
and
self
.
scheduler
.
has_unfinished_requests
():
# Wake up other DP engines.
self
.
output_queue
.
put_nowait
(
(
-
1
,
EngineCoreOutputs
(
start_wave
=
self
.
current_wave
))
)
def
_handle_client_request
(
self
,
request_type
:
EngineCoreRequestType
,
request
:
Any
...
...
@@ -1532,8 +1578,8 @@ class DPEngineCoreProc(EngineCoreProc):
# 2) Step the engine core.
executed
=
self
.
_process_engine_step
()
self
.
_maybe_publish_request_counts
()
local_unfinished_reqs
=
self
.
scheduler
.
has_unfinished_requests
()
local_unfinished_reqs
=
self
.
scheduler
.
has_unfinished_requests
()
if
not
executed
:
if
not
local_unfinished_reqs
and
not
self
.
engines_running
:
# All engines are idle.
...
...
vllm/v1/engine/core_client.py
View file @
dbf0da81
...
...
@@ -150,7 +150,7 @@ class EngineCoreClient(ABC):
def
reset_encoder_cache
(
self
)
->
None
:
raise
NotImplementedError
def
sleep
(
self
,
level
:
int
=
1
)
->
None
:
def
sleep
(
self
,
level
:
int
=
1
,
mode
:
PauseMode
=
"abort"
)
->
None
:
raise
NotImplementedError
def
wake_up
(
self
,
tags
:
list
[
str
]
|
None
=
None
)
->
None
:
...
...
@@ -227,7 +227,7 @@ class EngineCoreClient(ABC):
async
def
reset_encoder_cache_async
(
self
)
->
None
:
raise
NotImplementedError
async
def
sleep_async
(
self
,
level
:
int
=
1
)
->
None
:
async
def
sleep_async
(
self
,
level
:
int
=
1
,
mode
:
PauseMode
=
"abort"
)
->
None
:
raise
NotImplementedError
async
def
wake_up_async
(
self
,
tags
:
list
[
str
]
|
None
=
None
)
->
None
:
...
...
@@ -314,8 +314,11 @@ class InprocClient(EngineCoreClient):
def
reset_encoder_cache
(
self
)
->
None
:
self
.
engine_core
.
reset_encoder_cache
()
def
sleep
(
self
,
level
:
int
=
1
)
->
None
:
self
.
engine_core
.
sleep
(
level
)
def
sleep
(
self
,
level
:
int
=
1
,
mode
:
PauseMode
=
"abort"
)
->
None
:
if
mode
==
"wait"
:
raise
ValueError
(
"'wait' pause mode is not supported in inproc-engine mode"
)
result
=
self
.
engine_core
.
sleep
(
level
,
mode
)
assert
result
is
None
def
wake_up
(
self
,
tags
:
list
[
str
]
|
None
=
None
)
->
None
:
self
.
engine_core
.
wake_up
(
tags
)
...
...
@@ -796,8 +799,8 @@ class SyncMPClient(MPClient):
def
pin_lora
(
self
,
lora_id
:
int
)
->
bool
:
return
self
.
call_utility
(
"pin_lora"
,
lora_id
)
def
sleep
(
self
,
level
:
int
=
1
)
->
None
:
self
.
call_utility
(
"sleep"
,
level
)
def
sleep
(
self
,
level
:
int
=
1
,
mode
:
PauseMode
=
"abort"
)
->
None
:
self
.
call_utility
(
"sleep"
,
level
,
mode
)
def
wake_up
(
self
,
tags
:
list
[
str
]
|
None
=
None
)
->
None
:
self
.
call_utility
(
"wake_up"
,
tags
)
...
...
@@ -1009,8 +1012,8 @@ class AsyncMPClient(MPClient):
async
def
reset_encoder_cache_async
(
self
)
->
None
:
await
self
.
call_utility_async
(
"reset_encoder_cache"
)
async
def
sleep_async
(
self
,
level
:
int
=
1
)
->
None
:
await
self
.
call_utility_async
(
"sleep"
,
level
)
async
def
sleep_async
(
self
,
level
:
int
=
1
,
mode
:
PauseMode
=
"abort"
)
->
None
:
await
self
.
call_utility_async
(
"sleep"
,
level
,
mode
)
async
def
wake_up_async
(
self
,
tags
:
list
[
str
]
|
None
=
None
)
->
None
:
await
self
.
call_utility_async
(
"wake_up"
,
tags
)
...
...
vllm/v1/engine/llm_engine.py
View file @
dbf0da81
...
...
@@ -28,7 +28,7 @@ from vllm.tasks import SupportedTask
from
vllm.tokenizers
import
TokenizerLike
from
vllm.tracing
import
init_tracer
from
vllm.usage.usage_lib
import
UsageContext
from
vllm.v1.engine
import
EngineCoreRequest
from
vllm.v1.engine
import
EngineCoreRequest
,
PauseMode
from
vllm.v1.engine.core_client
import
EngineCoreClient
from
vllm.v1.engine.input_processor
import
InputProcessor
from
vllm.v1.engine.output_processor
import
OutputProcessor
...
...
@@ -355,8 +355,8 @@ class LLMEngine:
"""
self
.
engine_core
.
reset_encoder_cache
()
def
sleep
(
self
,
level
:
int
=
1
):
self
.
engine_core
.
sleep
(
level
)
def
sleep
(
self
,
level
:
int
=
1
,
mode
:
PauseMode
=
"abort"
):
self
.
engine_core
.
sleep
(
level
,
mode
)
if
self
.
logger_manager
is
not
None
:
self
.
logger_manager
.
record_sleep_state
(
1
,
level
)
...
...
vllm/v1/engine/output_processor.py
View file @
dbf0da81
...
...
@@ -429,8 +429,6 @@ class OutputProcessor:
self
.
external_req_ids
:
defaultdict
[
str
,
list
[
str
]]
=
defaultdict
(
list
)
self
.
lora_states
=
LoRARequestStates
(
log_stats
)
self
.
tracing_enabled
=
tracing_enabled
self
.
_requests_drained
=
asyncio
.
Event
()
self
.
_requests_drained
.
set
()
def
get_num_unfinished_requests
(
self
):
return
len
(
self
.
request_states
)
...
...
@@ -438,11 +436,6 @@ class OutputProcessor:
def
has_unfinished_requests
(
self
)
->
bool
:
return
len
(
self
.
request_states
)
>
0
async
def
wait_for_requests_to_drain
(
self
)
->
None
:
if
not
self
.
request_states
:
return
await
self
.
_requests_drained
.
wait
()
def
propagate_error
(
self
,
e
:
Exception
):
"""Propagate error to all generate() tasks."""
...
...
@@ -510,8 +503,6 @@ class OutputProcessor:
child_reqs
=
self
.
abort_requests
(
child_reqs
,
internal
=
True
)
request_ids_to_abort
.
extend
(
child_reqs
)
self
.
parent_requests
.
pop
(
request_id
,
None
)
if
not
self
.
request_states
:
self
.
_requests_drained
.
set
()
return
request_ids_to_abort
def
add_request
(
...
...
@@ -538,8 +529,6 @@ class OutputProcessor:
log_stats
=
self
.
log_stats
,
stream_interval
=
self
.
stream_interval
,
)
if
self
.
_requests_drained
.
is_set
():
self
.
_requests_drained
.
clear
()
self
.
request_states
[
request_id
]
=
req_state
if
parent_req
:
self
.
parent_requests
[
parent_req
.
request_id
]
=
parent_req
...
...
@@ -706,9 +695,6 @@ class OutputProcessor:
if
parent_req
and
not
parent_req
.
child_requests
:
self
.
parent_requests
.
pop
(
parent_req
.
request_id
,
None
)
if
not
self
.
request_states
:
self
.
_requests_drained
.
set
()
def
update_scheduler_stats
(
self
,
scheduler_stats
:
SchedulerStats
|
None
):
self
.
lora_states
.
update_scheduler_stats
(
scheduler_stats
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment