Unverified Commit dddbff46 authored by Aaron Hao's avatar Aaron Hao Committed by GitHub
Browse files

[Core] Move pause and resume functions into engine (#34125)


Signed-off-by: ahao-anyscale <ahao@anyscale.com>
Signed-off-by: Aaron Hao <ahao@anyscale.com>
Signed-off-by: hao-aaron <ahao@anyscale.com>
Signed-off-by: Nick Hill <nickhill123@gmail.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Co-authored-by: Nick Hill <nickhill123@gmail.com>
parent 47e9b63e
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Test pause/resume with Data Parallel (DP) via HTTP API.
This example demonstrates coordinated pause/resume across multiple DP ranks.
The pause synchronizes across all DP engines via all-reduce.
Prerequisites:
Start a vLLM server with data parallelism:
$ VLLM_SERVER_DEV_MODE=1 vllm serve facebook/opt-125m \
--enforce-eager \
--data-parallel-size 4 \
--tensor-parallel-size 1
Then run this script:
$ python data_parallel_pause_resume.py
The test verifies pause works by:
1. Starting a streaming generation request
2. Pausing the server mid-generation
3. Sleeping for PAUSE_DURATION seconds
4. Resuming the server
5. Verifying there was a gap in token generation matching the pause duration
"""
import argparse
import threading
import time
import requests
from openai import OpenAI
BASE_URL = "http://localhost:8000"
MODEL_NAME = "facebook/opt-125m"
PAUSE_DURATION = 3.0
def pause_generation(base_url: str, mode: str = "keep") -> None:
    """Ask the server to pause generation.

    Issues ``POST {base_url}/pause`` with ``mode`` as a query parameter and
    propagates an error if the server answers with a failure status.
    """
    requests.post(
        f"{base_url}/pause",
        params={"mode": mode},
        timeout=60,
    ).raise_for_status()
    print("Server paused")
def resume_generation(base_url: str) -> None:
    """Ask the server to resume generation.

    Issues ``POST {base_url}/resume`` and propagates an error if the server
    answers with a failure status.
    """
    requests.post(f"{base_url}/resume", timeout=60).raise_for_status()
    print("Server resumed")
def main():
    """Stream a completion, pause/resume the server mid-stream, and check the gap.

    A generator thread streams tokens and timestamps each one; once enough
    tokens have arrived it signals a controller thread, which pauses the
    server, sleeps for ``PAUSE_DURATION`` seconds and resumes it. Afterwards
    the time gap between the tokens on either side of the pause point is
    compared against ``PAUSE_DURATION`` and the verdict is printed.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--base-url", default=BASE_URL)
    parser.add_argument("--model", default=MODEL_NAME)
    args = parser.parse_args()

    client = OpenAI(
        base_url=f"{args.base_url}/v1",
        api_key="EMPTY",
    )

    prompt = "Write a long story about a dragon. Once upon a time"
    # Number of streamed tokens to observe before triggering the pause.
    min_tokens_before_pause = 5
    token_times: list[float] = []
    pause_token_idx = 0
    pause_triggered = threading.Event()

    def generator_thread():
        """Stream tokens and record timestamps."""
        try:
            stream = client.completions.create(
                model=args.model,
                prompt=prompt,
                max_tokens=50,
                stream=True,
            )
            for chunk in stream:
                if chunk.choices[0].text:
                    token_times.append(time.monotonic())
                    token_count = len(token_times)
                    print(f"Token {token_count}: {chunk.choices[0].text!r}")

                    # Signal controller after some tokens
                    if (
                        token_count >= min_tokens_before_pause
                        and not pause_triggered.is_set()
                    ):
                        pause_triggered.set()
        finally:
            # Always release the controller, even if the stream ended early
            # or raised; otherwise it would block forever on the event and
            # the script would hang in ctrl_thread.join().
            pause_triggered.set()

    def controller_thread():
        """Pause and resume the server."""
        nonlocal pause_token_idx

        # Wait for some tokens
        pause_triggered.wait()
        if len(token_times) < min_tokens_before_pause:
            # The event was only set by the generator's cleanup path.
            print("\nGeneration ended before the pause point; skipping pause.")
            return

        print(f"\nPausing server (keep mode) at token {len(token_times)}...")
        pause_generation(args.base_url, mode="keep")
        pause_token_idx = len(token_times)
        print(f"Sleeping for {PAUSE_DURATION}s...")
        time.sleep(PAUSE_DURATION)

        print("Resuming server...")
        resume_generation(args.base_url)
        print("Resumed!\n")

    # Run both threads
    gen_thread = threading.Thread(target=generator_thread)
    ctrl_thread = threading.Thread(target=controller_thread)
    gen_thread.start()
    ctrl_thread.start()
    gen_thread.join()
    ctrl_thread.join()

    # Check gap at the pause point
    if pause_token_idx == 0:
        # Without this guard, index -1 would silently wrap to the last token.
        print("Test failed! The server was never paused.")
    elif pause_token_idx < len(token_times):
        pause_gap = token_times[pause_token_idx] - token_times[pause_token_idx - 1]
        print(
            f"\nGap after pause (token {pause_token_idx} -> "
            f"{pause_token_idx + 1}): {pause_gap:.3f}s"
        )
        if pause_gap >= PAUSE_DURATION * 0.9:
            print("Test passed! Pause synchronized across DP ranks.")
        else:
            print(f"Test failed! Expected ~{PAUSE_DURATION}s gap, got {pause_gap:.3f}s")
    else:
        print("Test failed! No tokens were generated after resuming.")
if __name__ == "__main__":
main()
......@@ -12,6 +12,7 @@ from vllm import SamplingParams
from vllm.config import VllmConfig
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.inputs import PromptType
from vllm.outputs import RequestOutput
from vllm.platforms import current_platform
from vllm.sampling_params import RequestOutputKind
from vllm.v1.engine.async_llm import AsyncLLM
......@@ -181,3 +182,145 @@ async def test_load(
assert slogger.finished_req_count > NUM_REQUESTS // (DP_SIZE + 1), (
f"requests are imbalanced: {stats_loggers}"
)
# =============================================================================
# DP Pause/Resume Tests
# =============================================================================
DP_PAUSE_MODEL = "hmellor/tiny-random-LlamaForCausalLM"
DP_PAUSE_PROMPT = "This is a test of data parallel pause"
@pytest.mark.asyncio
async def test_dp_pause_resume_basic():
    """One client-side pause call pauses the DP engine; resume clears it again."""
    if current_platform.is_rocm():
        pytest.skip("DP pause tests use mp backend only")

    with ExitStack() as cleanup:
        engine = AsyncLLM.from_engine_args(
            AsyncEngineArgs(
                model=DP_PAUSE_MODEL,
                enforce_eager=True,
                tensor_parallel_size=int(os.getenv("TP_SIZE", 1)),
                data_parallel_size=DP_SIZE,
                data_parallel_backend="mp",
            )
        )
        cleanup.callback(engine.shutdown)

        # unpaused -> paused -> unpaused
        assert not await engine.is_paused()
        await engine.pause_generation(mode="abort")
        assert await engine.is_paused()
        await engine.resume_generation()
        assert not await engine.is_paused()

        # A fresh request must still run to completion once resumed.
        stream = engine.generate(
            request_id="after-resume",
            prompt=DP_PAUSE_PROMPT,
            sampling_params=SamplingParams(max_tokens=5),
        )
        async for result in stream:
            pass
        assert result.finished
@pytest.mark.asyncio
async def test_dp_pause_abort():
    """Pause with abort from one client aborts in-flight requests on all DP ranks."""
    if current_platform.is_rocm():
        pytest.skip("DP pause tests use mp backend only")

    with ExitStack() as after:
        engine_args = AsyncEngineArgs(
            model=DP_PAUSE_MODEL,
            enforce_eager=True,
            tensor_parallel_size=int(os.getenv("TP_SIZE", 1)),
            data_parallel_size=DP_SIZE,
            data_parallel_backend="mp",
        )
        engine = AsyncLLM.from_engine_args(engine_args)
        after.callback(engine.shutdown)

        # Start several requests so they are distributed across ranks
        sampling_params = SamplingParams(max_tokens=500, ignore_eos=True)
        num_requests = 4
        outputs_by_id: dict[str, list[RequestOutput]] = {}

        async def gen(rid: str):
            """Collect every output of one request; return the last (or None)."""
            out_list: list[RequestOutput] = []
            outputs_by_id[rid] = out_list
            async for out in engine.generate(
                request_id=rid,
                prompt=DP_PAUSE_PROMPT,
                sampling_params=sampling_params,
            ):
                out_list.append(out)
            return out_list[-1] if out_list else None

        tasks = [asyncio.create_task(gen(f"req-{i}")) for i in range(num_requests)]

        async def wait_for_tokens():
            """Poll until at least one request has produced two outputs."""
            while not any(len(o) >= 2 for o in outputs_by_id.values()):
                await asyncio.sleep(0.02)

        # Bound the wait so a request that never produces output makes the
        # test fail with a timeout instead of spinning forever.
        await asyncio.wait_for(wait_for_tokens(), timeout=30.0)

        await engine.pause_generation(mode="abort")

        finals = await asyncio.gather(*tasks)
        for i, final in enumerate(finals):
            assert final is not None, f"req-{i} had no output"
            assert final.finished
            assert final.outputs[0].finish_reason == "abort"

        assert await engine.is_paused()
        await engine.resume_generation()
        assert not await engine.is_paused()

        # New request completes after resume
        async for out in engine.generate(
            request_id="after-abort",
            prompt=DP_PAUSE_PROMPT,
            sampling_params=SamplingParams(max_tokens=5),
        ):
            pass
        assert out.finished
        assert not engine.output_processor.has_unfinished_requests()
@pytest.mark.asyncio
async def test_dp_pause_keep_then_resume():
    """A request submitted during a "keep" pause only finishes after resume."""
    if current_platform.is_rocm():
        pytest.skip("DP pause tests use mp backend only")

    with ExitStack() as cleanup:
        engine = AsyncLLM.from_engine_args(
            AsyncEngineArgs(
                model=DP_PAUSE_MODEL,
                enforce_eager=True,
                tensor_parallel_size=int(os.getenv("TP_SIZE", 1)),
                data_parallel_size=DP_SIZE,
                data_parallel_backend="mp",
            )
        )
        cleanup.callback(engine.shutdown)

        await engine.pause_generation(mode="keep")
        assert await engine.is_paused()

        finished = asyncio.Event()

        async def run_request():
            """Drain the request's output stream and return the last output."""
            async for result in engine.generate(
                request_id="queued-keep",
                prompt=DP_PAUSE_PROMPT,
                sampling_params=SamplingParams(max_tokens=5),
            ):
                pass
            finished.set()
            return result

        pending = asyncio.create_task(run_request())

        # While paused, the request must not have completed.
        await asyncio.sleep(0.2)
        assert not finished.is_set()

        await engine.resume_generation()
        last_output = await asyncio.wait_for(pending, timeout=10.0)
        assert last_output.finished
        assert not await engine.is_paused()
......@@ -708,9 +708,7 @@ async def test_pause_resume_basic():
# Test all modes with no requests in flight
for mode in ("abort", "wait", "keep"):
await engine.pause_generation(mode=mode)
# "keep" only freezes the scheduler; it does not set _paused
if mode != "keep":
assert await engine.is_paused()
assert await engine.is_paused()
await engine.resume_generation()
assert not await engine.is_paused()
......@@ -808,6 +806,53 @@ async def test_pause_abort():
assert final_output2.finished
@pytest.mark.asyncio
async def test_pause_then_abort_queued_request():
    """Abort a request that was submitted while the engine was paused.

    The client must be notified of the abort, and the request must not
    produce any tokens after the engine is resumed.
    """
    with ExitStack() as cleanup:
        with set_default_torch_num_threads(1):
            engine = AsyncLLM.from_engine_args(TEXT_ENGINE_ARGS)
        cleanup.callback(engine.shutdown)

        request_id = "abort-queued-request"
        sampling_params = SamplingParams(max_tokens=20, ignore_eos=True)
        collected: list[RequestOutput] = []

        # Pause before submitting, so the request arrives at a paused engine.
        await engine.pause_generation(mode="keep")
        assert await engine.is_paused()

        async def consume():
            """Collect all outputs of the request; return the last (or None)."""
            stream = engine.generate(
                request_id=request_id,
                prompt=TEXT_PROMPT,
                sampling_params=sampling_params,
            )
            async for item in stream:
                collected.append(item)
            return collected[-1] if collected else None

        consumer = asyncio.create_task(consume())

        # Give the request time to reach the paused engine.
        await asyncio.sleep(0.2)

        # Abort while it is still held back by the pause.
        await engine.abort(request_id, internal=False)

        # Resume so the engine can process and deliver the abort output.
        await engine.resume_generation()

        last = await asyncio.wait_for(consumer, timeout=10.0)
        assert last is not None
        assert last.finished
        assert last.outputs[0].finish_reason == "abort"
        # Request was never run, so no tokens
        assert len(last.outputs[0].token_ids) == 0
@pytest.mark.asyncio
async def test_pause_wait():
"""Test that mode='wait' waits for in-flight requests to complete."""
......
......@@ -8,6 +8,7 @@ import os
import signal
import time
import uuid
from concurrent.futures import Future
from dataclasses import dataclass
from threading import Thread
from types import SimpleNamespace
......@@ -278,6 +279,24 @@ def echo_dc_nested(
return structures.get(structure_type, val)
def future_echo(self, value: Any, num_wait_loops: int = 2) -> Future:
    """Return a Future that a per-step hook resolves to ``value``.

    The hook is added to ``self.per_step_hooks``. Each invocation counts
    down from ``num_wait_loops``; once the counter reaches zero or below,
    the Future is completed and the hook asks to be removed. Used to test
    the deferred (Future-returning) utility path.
    """
    pending: Future = Future()
    steps_left = num_wait_loops

    def _step(engine: "EngineCore") -> bool:
        nonlocal steps_left
        steps_left -= 1
        if steps_left > 0:
            return False
        pending.set_result(value)
        return True  # remove hook

    self.per_step_hooks.add(_step)
    return pending
# --- Fixtures for subprocess patching ---
# These create sitecustomize.py files that patch EngineCore in spawned
# subprocesses. This is necessary because ROCm requires 'spawn' multiprocessing
......@@ -383,6 +402,28 @@ def subprocess_echo_dc_nested_patch(monkeypatch, tmp_path):
)
@pytest.fixture
def subprocess_future_echo_patch(monkeypatch, tmp_path):
    """Write a sitecustomize.py that attaches ``future_echo`` to EngineCore.

    The file is placed in ``tmp_path``, which is prepended to ``PYTHONPATH``
    so that spawned subprocesses pick it up.
    """
    patch_lines = [
        "from concurrent.futures import Future",
        "from typing import Any",
        "",
        "from vllm.v1.engine.core import EngineCore",
        inspect.getsource(future_echo),
        "EngineCore.future_echo = future_echo",
    ]
    (tmp_path / "sitecustomize.py").write_text("\n".join(patch_lines))

    # Keep any pre-existing PYTHONPATH entries after the patch directory.
    search_path = [str(tmp_path)]
    inherited = os.getenv("PYTHONPATH")
    if inherited:
        search_path.append(inherited)
    monkeypatch.setenv("PYTHONPATH", os.pathsep.join(search_path))
@create_new_process_for_each_test()
@pytest.mark.parametrize("multiprocessing_mode", [True, False])
def test_engine_core_client(
......@@ -786,6 +827,48 @@ async def test_engine_core_client_util_method_nested_structures(
client.shutdown()
@pytest.mark.asyncio(loop_scope="function")
async def test_engine_core_client_future_utility_async(
    monkeypatch: pytest.MonkeyPatch,
    subprocess_future_echo_patch,
):
    """A utility that returns a Future yields its result once the Future is done.

    ``future_echo`` completes its Future from a per-step hook after a given
    number of engine steps; the awaited utility call must return that value.
    """
    with monkeypatch.context() as m:
        m.setattr(EngineCore, "future_echo", future_echo, raising=False)

        vllm_config = EngineArgs(
            model=MODEL_NAME, enforce_eager=True
        ).create_engine_config(usage_context=UsageContext.UNKNOWN_CONTEXT)
        executor_class = Executor.get_class(vllm_config)

        with set_default_torch_num_threads(1):
            client = EngineCoreClient.make_client(
                multiprocess_mode=True,
                asyncio_mode=True,
                vllm_config=vllm_config,
                executor_class=executor_class,
                log_stats=True,
            )

        try:
            core_client: AsyncMPClient = client

            # Completes after 2 engine steps (num_wait_loops=2)
            echoed = await core_client.call_utility_async(
                "future_echo", "future_result", 2
            )
            assert echoed == "future_result"

            # None is a valid result (num_wait_loops=0 → completes on first step)
            echoed = await core_client.call_utility_async("future_echo", None, 0)
            assert echoed is None
        finally:
            client.shutdown()
@pytest.mark.parametrize(
"multiprocessing_mode,publisher_config",
[(True, "tcp"), (False, "inproc")],
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import enum
from abc import ABC, abstractmethod
from collections.abc import Iterable
from typing import TYPE_CHECKING
......@@ -18,6 +19,20 @@ if TYPE_CHECKING:
from vllm.v1.structured_output import StructuredOutputManager
class PauseState(enum.IntEnum):
    """Scheduler pause state.

    - UNPAUSED: Normal operation.
    - PAUSED_NEW: No new requests are scheduled; requests already in
      running state are scheduled.
    - PAUSED_ALL: No requests are scheduled.
    """

    UNPAUSED = 0
    PAUSED_NEW = 1
    PAUSED_ALL = 2
class SchedulerInterface(ABC):
@abstractmethod
def __init__(
......@@ -120,11 +135,11 @@ class SchedulerInterface(ABC):
@abstractmethod
def finish_requests(
self,
request_ids: str | Iterable[str],
request_ids: str | Iterable[str] | None,
finished_status: "RequestStatus",
) -> None:
) -> list[tuple[str, int]]:
"""Finish the requests in the scheduler's internal queue. If the request
is not in the queue, this method will do nothing.
is not in the queue, this method will do nothing for that request.
This method is called in two cases:
1. When the request is aborted by the client.
......@@ -132,8 +147,12 @@ class SchedulerInterface(ABC):
de-tokenizing its generated tokens.
Args:
request_ids: A single or a list of request IDs.
request_ids: A single or a list of request IDs, or None to finish all.
finished_status: The finished status of the given requests.
Returns:
Tuple of (req_id, client_index) for requests that were aborted. Will not
include any that were already finished.
"""
raise NotImplementedError
......@@ -167,6 +186,16 @@ class SchedulerInterface(ABC):
not yet returned in SchedulerOutputs."""
return self.has_unfinished_requests() or self.has_finished_requests()
@property
@abstractmethod
def pause_state(self) -> PauseState:
"""Current pause state of the scheduler."""
raise NotImplementedError
@abstractmethod
def set_pause_state(self, pause_state: PauseState) -> None:
raise NotImplementedError
@abstractmethod
def reset_prefix_cache(
self, reset_running_requests: bool = False, reset_connector: bool = False
......
......@@ -38,7 +38,7 @@ from vllm.v1.core.encoder_cache_manager import (
)
from vllm.v1.core.kv_cache_manager import KVCacheBlocks, KVCacheManager
from vllm.v1.core.kv_cache_metrics import KVCacheMetricsCollector
from vllm.v1.core.sched.interface import SchedulerInterface
from vllm.v1.core.sched.interface import PauseState, SchedulerInterface
from vllm.v1.core.sched.output import (
CachedRequestData,
GrammarOutput,
......@@ -271,6 +271,8 @@ class Scheduler(SchedulerInterface):
vllm_config=self.vllm_config,
)
self._pause_state: PauseState = PauseState.UNPAUSED
def _mamba_block_aligned_split(
self,
request: Request,
......@@ -341,6 +343,10 @@ class Scheduler(SchedulerInterface):
req_to_new_blocks: dict[str, KVCacheBlocks] = {}
num_scheduled_tokens: dict[str, int] = {}
token_budget = self.max_num_scheduled_tokens
if self._pause_state == PauseState.PAUSED_ALL:
# Do not schedule any requests when paused.
token_budget = 0
# Encoder-related.
scheduled_encoder_inputs: dict[str, list[int]] = {}
encoder_compute_budget = self.max_num_encoder_input_tokens
......@@ -530,12 +536,12 @@ class Scheduler(SchedulerInterface):
)
assert len(scheduled_loras) <= self.lora_config.max_loras
# Use a temporary RequestQueue to collect requests that need to be
# skipped and put back at the head of the waiting queue later
skipped_waiting_requests = create_request_queue(self.policy)
# Next, schedule the WAITING requests.
if not preempted_reqs:
if not preempted_reqs and self._pause_state == PauseState.UNPAUSED:
# Use a temporary RequestQueue to collect requests that need to be
# skipped and put back at the head of the waiting queue later
skipped_waiting_requests = create_request_queue(self.policy)
while self.waiting and token_budget > 0:
if len(self.running) == self.max_num_running_reqs:
break
......@@ -802,9 +808,10 @@ class Scheduler(SchedulerInterface):
self.encoder_cache_manager.allocate(request, i)
if self.ec_connector is not None:
self.ec_connector.update_state_after_alloc(request, i)
# Put back any skipped requests at the head of the waiting queue
if skipped_waiting_requests:
self.waiting.prepend_requests(skipped_waiting_requests)
# Put back any skipped requests at the head of the waiting queue
if skipped_waiting_requests:
self.waiting.prepend_requests(skipped_waiting_requests)
# Check if the scheduling constraints are satisfied.
total_num_scheduled_tokens = sum(num_scheduled_tokens.values())
......@@ -1672,18 +1679,26 @@ class Scheduler(SchedulerInterface):
request.record_event(EngineCoreEventType.QUEUED)
def finish_requests(
self, request_ids: str | Iterable[str], finished_status: RequestStatus
) -> None:
self, request_ids: str | Iterable[str] | None, finished_status: RequestStatus
) -> list[tuple[str, int]]:
"""Handles the finish signal from outside the scheduler.
For example, the API server can abort a request when the client
disconnects.
If request_ids is None, all requests will be finished.
Returns:
Tuple of (req_id, client_index) for requests that were aborted. Will not
include any that were already finished.
"""
assert RequestStatus.is_finished(finished_status)
if isinstance(request_ids, str):
request_ids = (request_ids,)
else:
elif request_ids is not None:
request_ids = set(request_ids)
else:
request_ids = self.requests.keys()
running_requests_to_remove = set()
waiting_requests_to_remove = []
......@@ -1723,6 +1738,8 @@ class Scheduler(SchedulerInterface):
request.status = finished_status
self._free_request(request, delay_free_blocks=delay_free_blocks)
return [(r.request_id, r.client_index) for r in valid_requests]
def _free_request(
self, request: Request, delay_free_blocks: bool = False
) -> dict[str, Any] | None:
......@@ -1746,7 +1763,18 @@ class Scheduler(SchedulerInterface):
self.kv_cache_manager.free(request)
del self.requests[request.request_id]
@property
def pause_state(self) -> PauseState:
return self._pause_state
def set_pause_state(self, pause_state: PauseState) -> None:
self._pause_state = pause_state
def get_num_unfinished_requests(self) -> int:
    """Return the number of unfinished requests visible in the current pause state.

    - PAUSED_ALL: always 0.
    - PAUSED_NEW: only the running requests are counted.
    - otherwise: running requests plus waiting requests, excluding those
      waiting for streaming input.
    """
    state = self._pause_state
    if state == PauseState.PAUSED_ALL:
        return 0
    running = len(self.running)
    if state == PauseState.PAUSED_NEW:
        return running
    waiting = len(self.waiting) - self.num_waiting_for_streaming_input
    return waiting + running
......
......@@ -172,9 +172,6 @@ class AsyncLLM(EngineClient):
)
self.logger_manager.log_engine_initialized()
# Pause / resume state for async RL workflows.
self._pause_cond = asyncio.Condition()
self._paused = False
self._client_count = client_count
self.output_handler: asyncio.Task | None = None
......@@ -387,10 +384,6 @@ class AsyncLLM(EngineClient):
# to handle startup failure gracefully in the OpenAI server.
self._run_output_handler()
# Respect pause state before accepting new requests.
async with self._pause_cond:
await self._pause_cond.wait_for(lambda: not self._paused)
# Create a new output collector for the request.
queue = RequestOutputCollector(params.output_kind, request.request_id)
......@@ -741,7 +734,9 @@ class AsyncLLM(EngineClient):
"""
Pause generation to allow model weight updates.
New generation/encoding requests are blocked until resume.
All mode handling (abort / wait / keep) and cache clearing is done
in the engine. New generation/encoding requests will not be scheduled
until resume is called.
Args:
mode: How to handle in-flight requests:
......@@ -751,11 +746,8 @@ class AsyncLLM(EngineClient):
- ``"keep"``: Freeze requests in queue; they resume on
:meth:`resume_generation`.
wait_for_inflight_requests: DEPRECATED: use mode argument.
Whether to wait for in-flight requests to complete before pausing.
clear_cache: Whether to clear KV cache and prefix cache after
draining. Set to ``False`` to preserve cache for faster resume.
Default is ``True`` (clear caches).
"""
if wait_for_inflight_requests:
warnings.warn(
......@@ -766,50 +758,15 @@ class AsyncLLM(EngineClient):
stacklevel=2,
)
mode = "wait"
if mode == "keep":
# Freeze requests in the scheduler - they will resume on
# resume_generation().
await self.engine_core.pause_scheduler_async()
else:
if self._client_count > 1:
raise NotImplementedError(
"pause_generation is not supported with --api-server-count > 1"
" when mode is not 'keep'"
)
async with self._pause_cond:
if not self._paused:
self._paused = True
if mode == "abort":
request_ids = list(self.output_processor.request_states.keys())
if request_ids:
await self.abort(request_ids, internal=True)
elif mode == "wait":
if self.output_processor.has_unfinished_requests():
await self.output_processor.wait_for_requests_to_drain()
else:
raise ValueError(f"Invalid mode: {mode}")
# Clear cache
if clear_cache:
await self.reset_prefix_cache(reset_running_requests=True)
await self.reset_mm_cache()
await self.reset_encoder_cache()
await self.engine_core.pause_scheduler_async(mode=mode, clear_cache=clear_cache)
async def resume_generation(self) -> None:
"""Resume generation after :meth:`pause_generation`."""
async with self._pause_cond:
await self.engine_core.resume_scheduler_async()
self._paused = False
self._pause_cond.notify_all() # Wake up all waiting requests
await self.engine_core.resume_scheduler_async()
async def is_paused(self) -> bool:
"""Return whether the engine is currently paused."""
async with self._pause_cond:
return self._paused
return await self.engine_core.is_scheduler_paused_async()
async def encode(
self,
......
......@@ -5,7 +5,7 @@ import queue
import signal
import threading
import time
from collections import deque
from collections import defaultdict, deque
from collections.abc import Callable, Generator
from concurrent.futures import Future
from contextlib import ExitStack, contextmanager
......@@ -40,7 +40,7 @@ from vllm.v1.core.kv_cache_utils import (
get_request_block_hasher,
init_none_hash,
)
from vllm.v1.core.sched.interface import SchedulerInterface
from vllm.v1.core.sched.interface import PauseState, SchedulerInterface
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.engine import (
EngineCoreOutput,
......@@ -48,6 +48,7 @@ from vllm.v1.engine import (
EngineCoreRequest,
EngineCoreRequestType,
FinishReason,
PauseMode,
ReconfigureDistributedRequest,
ReconfigureRankType,
UtilityOutput,
......@@ -210,8 +211,7 @@ class EngineCore:
self.aborts_queue = queue.Queue[list[str]]()
# Pause state for "keep" mode - freezes requests in queue.
self._scheduler_paused = False
self.per_step_hooks: set[Callable] = set()
# Mark the startup heap as static so that it's ignored by GC.
# Reduces pause times of oldest generation collections.
......@@ -326,20 +326,6 @@ class EngineCore:
# (i.e. client-aborted vs stop criteria met).
self.scheduler.finish_requests(request_ids, RequestStatus.FINISHED_ABORTED)
def pause_scheduler(self) -> None:
"""Pause the scheduler, keeping requests frozen in queue.
Requests are kept frozen in queue and can be resumed later.
"""
self._scheduler_paused = True
def resume_scheduler(self) -> None:
"""Resume the scheduler after a pause.
Resumes processing of frozen requests in the queue.
"""
self._scheduler_paused = False
@contextmanager
def log_error_detail(self, scheduler_output: SchedulerOutput):
"""Execute the model and log detailed info on failure."""
......@@ -393,10 +379,6 @@ class EngineCore:
was executed.
"""
# If paused, don't schedule any work.
if self._scheduler_paused:
return {}, False
# Check for any requests remaining in the scheduler - unfinished,
# or finished and not yet removed from the batch.
if not self.scheduler.has_requests():
......@@ -447,9 +429,6 @@ class EngineCore:
batch in the job queue is finished.
3. Update the scheduler from the output.
"""
# If paused, don't schedule any work.
if self._scheduler_paused:
return {}, False
batch_queue = self.batch_queue
assert batch_queue is not None
......@@ -613,6 +592,20 @@ class EngineCore:
# Reset the GPU model runner's encoder cache (physical storage)
self.model_executor.reset_encoder_cache()
def pause_scheduler(
    self, mode: PauseMode = "abort", clear_cache: bool = True
) -> Future[Any] | None:
    """Pause scheduling. No-op in base EngineCore; overridden in EngineCoreProc.

    Both arguments are accepted for signature compatibility and ignored
    here; this base implementation always returns ``None``.
    """
    return None
def resume_scheduler(self) -> None:
"""Resume scheduling. No-op in base EngineCore; overridden in EngineCoreProc."""
def is_scheduler_paused(self) -> bool:
"""Return whether the scheduler is in any pause state. False in base EngineCore
and overridden in EngineCoreProc."""
return False
def sleep(self, level: int = 1):
"""Put the engine to sleep at the specified level.
......@@ -650,7 +643,7 @@ class EngineCore:
def is_sleeping(self) -> bool:
"""Check if engine is sleeping at any level."""
return self._scheduler_paused or self.model_executor.is_sleeping
return self.is_scheduler_paused() or self.model_executor.is_sleeping
def execute_dummy_batch(self):
self.model_executor.execute_dummy_batch()
......@@ -1053,13 +1046,9 @@ class EngineCoreProc(EngineCore):
# 1) Poll the input queue until there is work to do.
self._process_input_queue()
# 2) Step the engine core and return the outputs.
# Skip if scheduling is paused (level 0 sleep)
if not self._scheduler_paused:
self._process_engine_step()
else:
# When scheduling is paused, still need to check for wake up
# by processing any utility requests that might resume scheduling
pass
self._process_engine_step()
# 3) Run any per-step hooks.
self._process_per_step_hooks()
def _process_input_queue(self):
"""Exits when an engine step needs to be performed."""
......@@ -1067,9 +1056,9 @@ class EngineCoreProc(EngineCore):
waited = False
while (
not self.engines_running
and (not self.scheduler.has_requests() or self._scheduler_paused)
and not self.scheduler.has_requests()
and not self.batch_queue
and not self._scheduler_paused
and not self.per_step_hooks
):
if self.input_queue.empty():
# Drain aborts queue; all aborts are also processed via input_queue.
......@@ -1109,6 +1098,13 @@ class EngineCoreProc(EngineCore):
return model_executed
def _process_per_step_hooks(self) -> None:
if self.per_step_hooks:
for hook in list(self.per_step_hooks):
finished = hook(self)
if finished:
self.per_step_hooks.discard(hook)
def _handle_client_request(
self, request_type: EngineCoreRequestType, request: Any
) -> None:
......@@ -1122,18 +1118,14 @@ class EngineCoreProc(EngineCore):
elif request_type == EngineCoreRequestType.UTILITY:
client_idx, call_id, method_name, args = request
output = UtilityOutput(call_id)
try:
method = getattr(self, method_name)
result = method(*self._convert_msgspec_args(method, args))
output.result = UtilityResult(result)
except BaseException as e:
logger.exception("Invocation of %s method failed", method_name)
output.failure_message = (
f"Call to {method_name} method failed: {str(e)}"
)
self.output_queue.put_nowait(
(client_idx, EngineCoreOutputs(utility_output=output))
# Lazily look-up utility method so that failure will be handled/returned.
get_result = lambda: (method := getattr(self, method_name)) and method(
*self._convert_msgspec_args(method, args)
)
enqueue_output = lambda out: self.output_queue.put_nowait(
(client_idx, EngineCoreOutputs(utility_output=out))
)
self._invoke_utility_method(method_name, get_result, output, enqueue_output)
elif request_type == EngineCoreRequestType.EXECUTOR_FAILED:
raise RuntimeError("Executor failed.")
else:
......@@ -1141,6 +1133,25 @@ class EngineCoreProc(EngineCore):
"Unrecognized input request type encountered: %s", request_type
)
@staticmethod
def _invoke_utility_method(
    name: str, get_result: Callable, output: UtilityOutput, enqueue_output: Callable
):
    """Run a utility call and enqueue its outcome.

    ``get_result`` is called to obtain the result. If it is a ``Future``,
    handling is deferred: a done-callback re-enters this method with the
    future's ``result`` accessor, and nothing is enqueued for now.
    Otherwise the result is stored on ``output``; if an ``Exception`` was
    raised, it is logged and a failure message is stored instead. In both
    non-deferred cases ``output`` is passed to ``enqueue_output``.
    """
    try:
        outcome = get_result()
        if isinstance(outcome, Future):
            # Defer utility output handling until future completion.
            def _on_done(done_future: Future) -> None:
                EngineCoreProc._invoke_utility_method(
                    name, done_future.result, output, enqueue_output
                )

            outcome.add_done_callback(_on_done)
            return
        output.result = UtilityResult(outcome)
    except Exception as e:
        logger.exception("Invocation of %s method failed", name)
        output.failure_message = f"Call to {name} method failed: {str(e)}"
    enqueue_output(output)
@staticmethod
def _convert_msgspec_args(method, args):
"""If a provided arg type doesn't match corresponding target method
......@@ -1347,6 +1358,74 @@ class EngineCoreProc(EngineCore):
)
)
def pause_scheduler(
    self, mode: PauseMode = "abort", clear_cache: bool = True
) -> Future | None:
    """Pause generation; behavior depends on mode.

    - ``abort``: finish all requests in the scheduler with
      ``FINISHED_ABORTED``, send abort outputs to their clients, then set
      the scheduler to ``PAUSED_NEW``.
    - ``wait``: set the scheduler to ``PAUSED_NEW``.
    - ``keep``: set the scheduler to ``PAUSED_ALL``.

    After the state is set, the engine is considered idle once the
    scheduler reports no requests, the batch queue is empty and the output
    queue is empty. At that point the prefix, multimodal and encoder caches
    are reset if ``clear_cache`` is true.

    Returns:
        ``None`` if the engine is already idle when called; otherwise a
        Future that is completed (with ``None``) by a per-step hook once
        the engine becomes idle.

    Raises:
        ValueError: If ``mode`` is not ``"keep"``, ``"abort"`` or ``"wait"``.
    """
    if mode not in ("keep", "abort", "wait"):
        raise ValueError(f"Invalid pause mode: {mode}")

    future: Future[Any] = Future()

    def wait_until_idle(engine: "EngineCoreProc") -> bool:
        # Returns True once idle (hook can be removed), False to keep polling.
        scheduler = engine.scheduler
        out_queue = engine.output_queue
        if scheduler.has_requests() or engine.batch_queue or not out_queue.empty():
            return False
        if clear_cache:
            engine.reset_prefix_cache(reset_running_requests=True)
            engine.reset_mm_cache()
            engine.reset_encoder_cache()
        future.set_result(None)
        return True

    if mode == "abort":
        aborted_reqs = self.scheduler.finish_requests(
            None, RequestStatus.FINISHED_ABORTED
        )
        self._send_abort_outputs(aborted_reqs)

    pause_state = PauseState.PAUSED_ALL if mode == "keep" else PauseState.PAUSED_NEW
    self.scheduler.set_pause_state(pause_state)
    # Check once synchronously; only register the hook if not yet idle.
    if not wait_until_idle(self):
        self.per_step_hooks.add(wait_until_idle)
        return future
    return None
def _send_abort_outputs(self, aborted_reqs: list[tuple[str, int]]) -> None:
    """Enqueue abort outputs for the given requests, one batch per client.

    ``aborted_reqs`` holds ``(req_id, client_index)`` pairs. Request ids are
    grouped by client index, and for each client a single
    ``EngineCoreOutputs`` with one ABORT output per request is put on the
    output queue. Does nothing when the list is empty.
    """
    if not aborted_reqs:
        return

    # Map client_index to the set of request ids that belong to that client.
    ids_per_client = defaultdict[int, set[str]](set)
    for request_id, client_idx in aborted_reqs:
        ids_per_client[client_idx].add(request_id)

    for client_idx, request_ids in ids_per_client.items():
        abort_outputs = [
            EngineCoreOutput(request_id, [], finish_reason=FinishReason.ABORT)
            for request_id in request_ids
        ]
        batch = EngineCoreOutputs(
            finished_requests=request_ids, outputs=abort_outputs
        )
        self.output_queue.put_nowait((client_idx, batch))
def resume_scheduler(self) -> None:
    """Resume scheduling by setting the scheduler back to ``UNPAUSED``."""
    self.scheduler.set_pause_state(PauseState.UNPAUSED)
def is_scheduler_paused(self) -> bool:
"""Return whether the scheduler is in any pause state."""
return self.scheduler.pause_state != PauseState.UNPAUSED
class DPEngineCoreProc(EngineCoreProc):
"""ZMQ-wrapper for running EngineCore in background process
......@@ -1450,10 +1529,6 @@ class DPEngineCoreProc(EngineCoreProc):
# 1) Poll the input queue until there is work to do.
self._process_input_queue()
# Skip processing if scheduling is paused (level 0 sleep)
if self._scheduler_paused:
continue
# 2) Step the engine core.
executed = self._process_engine_step()
self._maybe_publish_request_counts()
......
......@@ -36,6 +36,7 @@ from vllm.v1.engine import (
EngineCoreOutputs,
EngineCoreRequest,
EngineCoreRequestType,
PauseMode,
ReconfigureDistributedRequest,
ReconfigureRankType,
UtilityOutput,
......@@ -979,16 +980,17 @@ class AsyncMPClient(MPClient):
if request_ids and not self.resources.engine_dead:
await self._send_input(EngineCoreRequestType.ABORT, request_ids)
async def pause_scheduler_async(self) -> None:
"""Pause the scheduler, keeping requests frozen in queue.
Blocks until the EngineCore acknowledges the pause.
"""
await self.call_utility_async("pause_scheduler")
async def pause_scheduler_async(
    self, mode: PauseMode = "abort", clear_cache: bool = True
) -> None:
    """Invoke the engine core's ``pause_scheduler`` utility and await the call.

    Args:
        mode: Pause mode, forwarded unchanged to the engine core.
        clear_cache: Cache-clearing flag, forwarded unchanged to the
            engine core.
    """
    await self.call_utility_async("pause_scheduler", mode, clear_cache)
async def resume_scheduler_async(self) -> None:
"""Resume the scheduler after a pause."""
await self.call_utility_async("resume_scheduler")
async def is_scheduler_paused_async(self) -> bool:
    """Return the result of the engine core's ``is_scheduler_paused`` utility."""
    return await self.call_utility_async("is_scheduler_paused")
async def profile_async(
self, is_start: bool = True, profile_prefix: str | None = None
) -> None:
......@@ -1203,18 +1205,6 @@ class DPAsyncMPClient(AsyncMPClient):
def get_core_engine_for_request(self, request: EngineCoreRequest):
return self.core_engine
async def pause_scheduler_async(self) -> None:
"""Pause the scheduler, keeping requests frozen in queue."""
raise NotImplementedError(
"pause_scheduler_async is not yet supported for data parallel"
)
async def resume_scheduler_async(self) -> None:
"""Resume the scheduler after a pause."""
raise NotImplementedError(
"resume_scheduler_async is not yet supported for data parallel"
)
class DPLBAsyncMPClient(DPAsyncMPClient):
"""Asyncio-compatible client for multi-proc, multi-engine (data parallel)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment