Unverified Commit dddbff46 authored by Aaron Hao's avatar Aaron Hao Committed by GitHub
Browse files

[Core] Move pause and resume functions into engine (#34125)


Signed-off-by: ahao-anyscale <ahao@anyscale.com>
Signed-off-by: Aaron Hao <ahao@anyscale.com>
Signed-off-by: hao-aaron <ahao@anyscale.com>
Signed-off-by: Nick Hill <nickhill123@gmail.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Co-authored-by: Nick Hill <nickhill123@gmail.com>
parent 47e9b63e
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Test pause/resume with Data Parallel (DP) via HTTP API.
This example demonstrates coordinated pause/resume across multiple DP ranks.
The pause synchronizes across all DP engines via all-reduce.
Prerequisites:
Start a vLLM server with data parallelism:
$ VLLM_SERVER_DEV_MODE=1 vllm serve facebook/opt-125m \
--enforce-eager \
--data-parallel-size 4 \
--tensor-parallel-size 1
Then run this script:
$ python data_parallel_pause_resume.py
The test verifies pause works by:
1. Starting a streaming generation request
2. Pausing the server mid-generation
3. Sleeping for PAUSE_DURATION seconds
4. Resuming the server
5. Verifying there was a gap in token generation matching the pause duration
"""
import argparse
import threading
import time
import requests
from openai import OpenAI
BASE_URL = "http://localhost:8000"
MODEL_NAME = "facebook/opt-125m"
PAUSE_DURATION = 3.0
def pause_generation(base_url: str, mode: str = "keep") -> None:
    """Ask the server to pause generation by POSTing to ``/pause``.

    Args:
        base_url: Root URL of the server, without the ``/pause`` suffix.
        mode: Value sent as the ``mode`` query parameter.

    Raises:
        requests.HTTPError: If the server replies with an error status.
    """
    query = {"mode": mode}
    resp = requests.post(f"{base_url}/pause", params=query, timeout=60)
    resp.raise_for_status()
    print("Server paused")
def resume_generation(base_url: str) -> None:
    """Ask the server to resume generation by POSTing to ``/resume``.

    Args:
        base_url: Root URL of the server, without the ``/resume`` suffix.

    Raises:
        requests.HTTPError: If the server replies with an error status.
    """
    resp = requests.post(f"{base_url}/resume", timeout=60)
    resp.raise_for_status()
    print("Server resumed")
def main():
    """Run the pause/resume check against an already-running server.

    A generator thread streams a completion and timestamps every non-empty
    chunk. A controller thread waits until a few tokens have arrived, pauses
    the server in ``keep`` mode, sleeps for ``PAUSE_DURATION`` seconds and
    resumes it. Afterwards the gap between the last token seen before the
    pause and the first token seen after it is compared to ``PAUSE_DURATION``.

    Command-line options ``--base-url`` and ``--model`` override ``BASE_URL``
    and ``MODEL_NAME``. The outcome is reported with ``print``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--base-url", default=BASE_URL)
    parser.add_argument("--model", default=MODEL_NAME)
    args = parser.parse_args()

    client = OpenAI(
        base_url=f"{args.base_url}/v1",
        api_key="EMPTY",
    )

    prompt = "Write a long story about a dragon. Once upon a time"
    # Number of streamed tokens to wait for before pausing the server.
    min_tokens_before_pause = 5
    token_times: list[float] = []
    pause_token_idx = 0
    pause_triggered = threading.Event()

    def generator_thread():
        """Stream tokens and record timestamps."""
        try:
            stream = client.completions.create(
                model=args.model,
                prompt=prompt,
                max_tokens=50,
                stream=True,
            )
            for chunk in stream:
                if chunk.choices[0].text:
                    token_times.append(time.monotonic())
                    token_count = len(token_times)
                    print(f"Token {token_count}: {chunk.choices[0].text!r}")
                    # Signal controller after some tokens
                    if (
                        token_count >= min_tokens_before_pause
                        and not pause_triggered.is_set()
                    ):
                        pause_triggered.set()
        finally:
            # Always release the controller, even when the stream ends early
            # or raises; otherwise it would wait on the event forever and
            # main() would never return.
            pause_triggered.set()

    def controller_thread():
        """Pause and resume the server."""
        nonlocal pause_token_idx
        # Wait for some tokens
        pause_triggered.wait()
        if len(token_times) < min_tokens_before_pause:
            # The event was set by the generator's cleanup path, not by
            # reaching the token threshold: there is nothing left to pause.
            print("Stream ended before enough tokens arrived; skipping pause.")
            return

        print(f"\nPausing server (keep mode) at token {len(token_times)}...")
        pause_generation(args.base_url, mode="keep")
        pause_token_idx = len(token_times)
        try:
            print(f"Sleeping for {PAUSE_DURATION}s...")
            time.sleep(PAUSE_DURATION)
        finally:
            # The pause succeeded, so always try to resume; do not leave the
            # server paused if the sleep is interrupted.
            print("Resuming server...")
            resume_generation(args.base_url)
        print("Resumed!\n")

    # Run both threads
    gen_thread = threading.Thread(target=generator_thread)
    ctrl_thread = threading.Thread(target=controller_thread)
    gen_thread.start()
    ctrl_thread.start()
    gen_thread.join()
    ctrl_thread.join()

    # Check gap at the pause point
    if pause_token_idx == 0:
        # No pause was recorded; indexing token_times[pause_token_idx - 1]
        # would wrap around to the last token and report a bogus gap.
        print("Test failed! The server was never paused.")
    elif pause_token_idx < len(token_times):
        pause_gap = token_times[pause_token_idx] - token_times[pause_token_idx - 1]
        print(
            f"\nGap after pause (token {pause_token_idx} -> "
            f"{pause_token_idx + 1}): {pause_gap:.3f}s"
        )
        if pause_gap >= PAUSE_DURATION * 0.9:
            print("Test passed! Pause synchronized across DP ranks.")
        else:
            print(f"Test failed! Expected ~{PAUSE_DURATION}s gap, got {pause_gap:.3f}s")
    else:
        print("Test failed! No tokens were generated after resuming.")
if __name__ == "__main__":
main()
...@@ -12,6 +12,7 @@ from vllm import SamplingParams ...@@ -12,6 +12,7 @@ from vllm import SamplingParams
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.engine.arg_utils import AsyncEngineArgs from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.inputs import PromptType from vllm.inputs import PromptType
from vllm.outputs import RequestOutput
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.sampling_params import RequestOutputKind from vllm.sampling_params import RequestOutputKind
from vllm.v1.engine.async_llm import AsyncLLM from vllm.v1.engine.async_llm import AsyncLLM
...@@ -181,3 +182,145 @@ async def test_load( ...@@ -181,3 +182,145 @@ async def test_load(
assert slogger.finished_req_count > NUM_REQUESTS // (DP_SIZE + 1), ( assert slogger.finished_req_count > NUM_REQUESTS // (DP_SIZE + 1), (
f"requests are imbalanced: {stats_loggers}" f"requests are imbalanced: {stats_loggers}"
) )
# =============================================================================
# DP Pause/Resume Tests
# =============================================================================
DP_PAUSE_MODEL = "hmellor/tiny-random-LlamaForCausalLM"
DP_PAUSE_PROMPT = "This is a test of data parallel pause"
@pytest.mark.asyncio
async def test_dp_pause_resume_basic():
    """A single client-side pause call pauses every DP rank, and resume undoes it.

    Also checks that a fresh request still runs to completion after resuming.
    """
    if current_platform.is_rocm():
        pytest.skip("DP pause tests use mp backend only")

    with ExitStack() as after:
        engine = AsyncLLM.from_engine_args(
            AsyncEngineArgs(
                model=DP_PAUSE_MODEL,
                enforce_eager=True,
                tensor_parallel_size=int(os.getenv("TP_SIZE", 1)),
                data_parallel_size=DP_SIZE,
                data_parallel_backend="mp",
            )
        )
        after.callback(engine.shutdown)

        # Starts out unpaused.
        paused_initially = await engine.is_paused()
        assert not paused_initially

        # Pause -> reported as paused.
        await engine.pause_generation(mode="abort")
        paused_after_pause = await engine.is_paused()
        assert paused_after_pause

        # Resume -> reported as unpaused again.
        await engine.resume_generation()
        paused_after_resume = await engine.is_paused()
        assert not paused_after_resume

        # The engine must still be able to serve a request after the resume.
        stream = engine.generate(
            request_id="after-resume",
            prompt=DP_PAUSE_PROMPT,
            sampling_params=SamplingParams(max_tokens=5),
        )
        async for out in stream:
            pass
        assert out.finished
@pytest.mark.asyncio
async def test_dp_pause_abort():
    """Pause with abort from one client aborts in-flight requests on all DP ranks.

    Several long requests are started, the test waits until at least one of
    them has produced a couple of outputs, then pauses with ``mode="abort"``
    and checks that every request finished with ``finish_reason == "abort"``.
    Finally it resumes and verifies that a new request completes.
    """
    if current_platform.is_rocm():
        pytest.skip("DP pause tests use mp backend only")

    with ExitStack() as after:
        engine_args = AsyncEngineArgs(
            model=DP_PAUSE_MODEL,
            enforce_eager=True,
            tensor_parallel_size=int(os.getenv("TP_SIZE", 1)),
            data_parallel_size=DP_SIZE,
            data_parallel_backend="mp",
        )
        engine = AsyncLLM.from_engine_args(engine_args)
        after.callback(engine.shutdown)

        # Start several requests so they are distributed across ranks
        sampling_params = SamplingParams(max_tokens=500, ignore_eos=True)
        num_requests = 4
        outputs_by_id: dict[str, list[RequestOutput]] = {}

        async def gen(rid: str):
            out_list: list[RequestOutput] = []
            outputs_by_id[rid] = out_list
            async for out in engine.generate(
                request_id=rid,
                prompt=DP_PAUSE_PROMPT,
                sampling_params=sampling_params,
            ):
                out_list.append(out)
            return out_list[-1] if out_list else None

        tasks = [asyncio.create_task(gen(f"req-{i}")) for i in range(num_requests)]

        async def wait_for_some_tokens() -> None:
            """Poll until any request has at least two outputs."""
            while not any(len(o) >= 2 for o in outputs_by_id.values()):
                if all(t.done() for t in tasks):
                    # Every task already ended (e.g. with an exception);
                    # stop polling so the gather below can surface it.
                    return
                await asyncio.sleep(0.02)

        # Wait for some tokens on at least one request. Bounded so that a
        # broken engine fails the test instead of spinning forever.
        await asyncio.wait_for(wait_for_some_tokens(), timeout=60.0)

        await engine.pause_generation(mode="abort")
        finals = await asyncio.wait_for(asyncio.gather(*tasks), timeout=30.0)
        for i, final in enumerate(finals):
            assert final is not None, f"req-{i} had no output"
            assert final.finished
            assert final.outputs[0].finish_reason == "abort"

        assert await engine.is_paused()
        await engine.resume_generation()
        assert not await engine.is_paused()

        # New request completes after resume
        async for out in engine.generate(
            request_id="after-abort",
            prompt=DP_PAUSE_PROMPT,
            sampling_params=SamplingParams(max_tokens=5),
        ):
            pass
        assert out.finished
        assert not engine.output_processor.has_unfinished_requests()
@pytest.mark.asyncio
async def test_dp_pause_keep_then_resume():
    """A request submitted while paused in keep mode only finishes after resume."""
    if current_platform.is_rocm():
        pytest.skip("DP pause tests use mp backend only")

    with ExitStack() as after:
        engine = AsyncLLM.from_engine_args(
            AsyncEngineArgs(
                model=DP_PAUSE_MODEL,
                enforce_eager=True,
                tensor_parallel_size=int(os.getenv("TP_SIZE", 1)),
                data_parallel_size=DP_SIZE,
                data_parallel_backend="mp",
            )
        )
        after.callback(engine.shutdown)

        # Pause before submitting anything.
        await engine.pause_generation(mode="keep")
        assert await engine.is_paused()

        finished = asyncio.Event()

        async def run_queued_request():
            """Drive one request to completion and flag when it is done."""
            stream = engine.generate(
                request_id="queued-keep",
                prompt=DP_PAUSE_PROMPT,
                sampling_params=SamplingParams(max_tokens=5),
            )
            async for out in stream:
                pass
            finished.set()
            return out

        queued_task = asyncio.create_task(run_queued_request())

        # While paused the request must not complete.
        await asyncio.sleep(0.2)
        assert not finished.is_set()

        # After resuming it must run to completion within the timeout.
        await engine.resume_generation()
        final = await asyncio.wait_for(queued_task, timeout=10.0)
        assert final.finished
        assert not await engine.is_paused()
...@@ -708,8 +708,6 @@ async def test_pause_resume_basic(): ...@@ -708,8 +708,6 @@ async def test_pause_resume_basic():
# Test all modes with no requests in flight # Test all modes with no requests in flight
for mode in ("abort", "wait", "keep"): for mode in ("abort", "wait", "keep"):
await engine.pause_generation(mode=mode) await engine.pause_generation(mode=mode)
# "keep" only freezes the scheduler; it does not set _paused
if mode != "keep":
assert await engine.is_paused() assert await engine.is_paused()
await engine.resume_generation() await engine.resume_generation()
assert not await engine.is_paused() assert not await engine.is_paused()
...@@ -808,6 +806,53 @@ async def test_pause_abort(): ...@@ -808,6 +806,53 @@ async def test_pause_abort():
assert final_output2.finished assert final_output2.finished
@pytest.mark.asyncio
async def test_pause_then_abort_queued_request():
    """Abort a request that was submitted while the engine was paused.

    The client must be notified with an ``abort`` finish reason, and the
    request must not produce any tokens once the engine is resumed.
    """
    with ExitStack() as after:
        with set_default_torch_num_threads(1):
            engine = AsyncLLM.from_engine_args(TEXT_ENGINE_ARGS)
        after.callback(engine.shutdown)

        request_id = "abort-queued-request"
        params = SamplingParams(max_tokens=20, ignore_eos=True)
        outputs: list[RequestOutput] = []

        # Pause up front so the request below is submitted while paused.
        await engine.pause_generation(mode="keep")
        assert await engine.is_paused()

        async def collect_outputs():
            """Consume the request's stream and return its last output."""
            stream = engine.generate(
                request_id=request_id,
                prompt=TEXT_PROMPT,
                sampling_params=params,
            )
            async for out in stream:
                outputs.append(out)
            if not outputs:
                return None
            return outputs[-1]

        gen_task = asyncio.create_task(collect_outputs())

        # Leave time for the request to reach the engine before aborting it.
        await asyncio.sleep(0.2)
        await engine.abort(request_id, internal=False)

        # Resume so the engine can process and deliver the abort output.
        await engine.resume_generation()

        final_output = await asyncio.wait_for(gen_task, timeout=10.0)
        assert final_output is not None
        assert final_output.finished

        completion = final_output.outputs[0]
        assert completion.finish_reason == "abort"
        # The request never ran, so it has no generated tokens.
        assert len(completion.token_ids) == 0
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_pause_wait(): async def test_pause_wait():
"""Test that mode='wait' waits for in-flight requests to complete.""" """Test that mode='wait' waits for in-flight requests to complete."""
......
...@@ -8,6 +8,7 @@ import os ...@@ -8,6 +8,7 @@ import os
import signal import signal
import time import time
import uuid import uuid
from concurrent.futures import Future
from dataclasses import dataclass from dataclasses import dataclass
from threading import Thread from threading import Thread
from types import SimpleNamespace from types import SimpleNamespace
...@@ -278,6 +279,24 @@ def echo_dc_nested( ...@@ -278,6 +279,24 @@ def echo_dc_nested(
return structures.get(structure_type, val) return structures.get(structure_type, val)
def future_echo(self, value: Any, num_wait_loops: int = 2) -> Future:
    """Return a Future that a per-step hook resolves to ``value``.

    A hook is added to ``self.per_step_hooks``; each time it is called it
    counts down from ``num_wait_loops`` and sets the Future's result once the
    count reaches zero or below (exercises the deferred utility path).
    """
    result_future: Future = Future()
    steps_left = num_wait_loops

    def _on_step(engine: EngineCore) -> bool:
        nonlocal steps_left
        steps_left -= 1
        if steps_left > 0:
            return False
        result_future.set_result(value)
        return True  # remove hook

    self.per_step_hooks.add(_on_step)
    return result_future
# --- Fixtures for subprocess patching --- # --- Fixtures for subprocess patching ---
# These create sitecustomize.py files that patch EngineCore in spawned # These create sitecustomize.py files that patch EngineCore in spawned
# subprocesses. This is necessary because ROCm requires 'spawn' multiprocessing # subprocesses. This is necessary because ROCm requires 'spawn' multiprocessing
...@@ -383,6 +402,28 @@ def subprocess_echo_dc_nested_patch(monkeypatch, tmp_path): ...@@ -383,6 +402,28 @@ def subprocess_echo_dc_nested_patch(monkeypatch, tmp_path):
) )
@pytest.fixture
def subprocess_future_echo_patch(monkeypatch, tmp_path):
    """Write a sitecustomize.py that attaches future_echo to EngineCore.

    The file is created in ``tmp_path``, which is then put at the front of
    ``PYTHONPATH`` so that spawned subprocesses pick it up.
    """
    source_lines = [
        "from concurrent.futures import Future",
        "from typing import Any",
        "",
        "from vllm.v1.engine.core import EngineCore",
        inspect.getsource(future_echo),
        "EngineCore.future_echo = future_echo",
    ]
    (tmp_path / "sitecustomize.py").write_text("\n".join(source_lines))

    # Prepend tmp_path, keeping any non-empty PYTHONPATH that is already set.
    path_entries = [str(tmp_path)]
    inherited = os.getenv("PYTHONPATH")
    if inherited:
        path_entries.append(inherited)
    monkeypatch.setenv("PYTHONPATH", os.pathsep.join(path_entries))
@create_new_process_for_each_test() @create_new_process_for_each_test()
@pytest.mark.parametrize("multiprocessing_mode", [True, False]) @pytest.mark.parametrize("multiprocessing_mode", [True, False])
def test_engine_core_client( def test_engine_core_client(
...@@ -786,6 +827,48 @@ async def test_engine_core_client_util_method_nested_structures( ...@@ -786,6 +827,48 @@ async def test_engine_core_client_util_method_nested_structures(
client.shutdown() client.shutdown()
@pytest.mark.asyncio(loop_scope="function")
async def test_engine_core_client_future_utility_async(
    monkeypatch: pytest.MonkeyPatch,
    subprocess_future_echo_patch,
):
    """A utility that returns a Future resolves once that Future is done.

    ``future_echo`` completes its Future from a per-step hook after a given
    number of engine steps; the async client call must return that value.
    """
    with monkeypatch.context() as m:
        m.setattr(EngineCore, "future_echo", future_echo, raising=False)

        vllm_config = EngineArgs(
            model=MODEL_NAME, enforce_eager=True
        ).create_engine_config(usage_context=UsageContext.UNKNOWN_CONTEXT)
        executor_class = Executor.get_class(vllm_config)

        with set_default_torch_num_threads(1):
            client = EngineCoreClient.make_client(
                multiprocess_mode=True,
                asyncio_mode=True,
                vllm_config=vllm_config,
                executor_class=executor_class,
                log_stats=True,
            )

        try:
            core_client: AsyncMPClient = client

            # Resolved after 2 engine steps (num_wait_loops=2).
            echoed = await core_client.call_utility_async(
                "future_echo", "future_result", 2
            )
            assert echoed == "future_result"

            # None is a legitimate result; num_wait_loops=0 resolves on the
            # first step.
            echoed_none = await core_client.call_utility_async(
                "future_echo", None, 0
            )
            assert echoed_none is None
        finally:
            client.shutdown()
@pytest.mark.parametrize( @pytest.mark.parametrize(
"multiprocessing_mode,publisher_config", "multiprocessing_mode,publisher_config",
[(True, "tcp"), (False, "inproc")], [(True, "tcp"), (False, "inproc")],
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import enum
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from collections.abc import Iterable from collections.abc import Iterable
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
...@@ -18,6 +19,20 @@ if TYPE_CHECKING: ...@@ -18,6 +19,20 @@ if TYPE_CHECKING:
from vllm.v1.structured_output import StructuredOutputManager from vllm.v1.structured_output import StructuredOutputManager
class PauseState(enum.IntEnum):
    """Scheduler pause state.

    - UNPAUSED: Normal operation.
    - PAUSED_NEW: No new requests are scheduled; requests already in the
      running state are still scheduled.
    - PAUSED_ALL: No requests are scheduled.
    """

    UNPAUSED = 0
    PAUSED_NEW = 1
    PAUSED_ALL = 2
class SchedulerInterface(ABC): class SchedulerInterface(ABC):
@abstractmethod @abstractmethod
def __init__( def __init__(
...@@ -120,11 +135,11 @@ class SchedulerInterface(ABC): ...@@ -120,11 +135,11 @@ class SchedulerInterface(ABC):
@abstractmethod @abstractmethod
def finish_requests( def finish_requests(
self, self,
request_ids: str | Iterable[str], request_ids: str | Iterable[str] | None,
finished_status: "RequestStatus", finished_status: "RequestStatus",
) -> None: ) -> list[tuple[str, int]]:
"""Finish the requests in the scheduler's internal queue. If the request """Finish the requests in the scheduler's internal queue. If the request
is not in the queue, this method will do nothing. is not in the queue, this method will do nothing for that request.
This method is called in two cases: This method is called in two cases:
1. When the request is aborted by the client. 1. When the request is aborted by the client.
...@@ -132,8 +147,12 @@ class SchedulerInterface(ABC): ...@@ -132,8 +147,12 @@ class SchedulerInterface(ABC):
de-tokenizing its generated tokens. de-tokenizing its generated tokens.
Args: Args:
request_ids: A single or a list of request IDs. request_ids: A single or a list of request IDs, or None to finish all.
finished_status: The finished status of the given requests. finished_status: The finished status of the given requests.
Returns:
List of (req_id, client_index) tuples for requests that were aborted. Will not
include any that were already finished.
""" """
raise NotImplementedError raise NotImplementedError
...@@ -167,6 +186,16 @@ class SchedulerInterface(ABC): ...@@ -167,6 +186,16 @@ class SchedulerInterface(ABC):
not yet returned in SchedulerOutputs.""" not yet returned in SchedulerOutputs."""
return self.has_unfinished_requests() or self.has_finished_requests() return self.has_unfinished_requests() or self.has_finished_requests()
@property
@abstractmethod
def pause_state(self) -> PauseState:
    """Current pause state of the scheduler.

    Abstract read-only property; concrete schedulers must override it.
    """
    raise NotImplementedError
@abstractmethod
def set_pause_state(self, pause_state: PauseState) -> None:
    """Set the pause state of the scheduler.

    Args:
        pause_state: The new pause state.
    """
    raise NotImplementedError
@abstractmethod @abstractmethod
def reset_prefix_cache( def reset_prefix_cache(
self, reset_running_requests: bool = False, reset_connector: bool = False self, reset_running_requests: bool = False, reset_connector: bool = False
......
...@@ -38,7 +38,7 @@ from vllm.v1.core.encoder_cache_manager import ( ...@@ -38,7 +38,7 @@ from vllm.v1.core.encoder_cache_manager import (
) )
from vllm.v1.core.kv_cache_manager import KVCacheBlocks, KVCacheManager from vllm.v1.core.kv_cache_manager import KVCacheBlocks, KVCacheManager
from vllm.v1.core.kv_cache_metrics import KVCacheMetricsCollector from vllm.v1.core.kv_cache_metrics import KVCacheMetricsCollector
from vllm.v1.core.sched.interface import SchedulerInterface from vllm.v1.core.sched.interface import PauseState, SchedulerInterface
from vllm.v1.core.sched.output import ( from vllm.v1.core.sched.output import (
CachedRequestData, CachedRequestData,
GrammarOutput, GrammarOutput,
...@@ -271,6 +271,8 @@ class Scheduler(SchedulerInterface): ...@@ -271,6 +271,8 @@ class Scheduler(SchedulerInterface):
vllm_config=self.vllm_config, vllm_config=self.vllm_config,
) )
self._pause_state: PauseState = PauseState.UNPAUSED
def _mamba_block_aligned_split( def _mamba_block_aligned_split(
self, self,
request: Request, request: Request,
...@@ -341,6 +343,10 @@ class Scheduler(SchedulerInterface): ...@@ -341,6 +343,10 @@ class Scheduler(SchedulerInterface):
req_to_new_blocks: dict[str, KVCacheBlocks] = {} req_to_new_blocks: dict[str, KVCacheBlocks] = {}
num_scheduled_tokens: dict[str, int] = {} num_scheduled_tokens: dict[str, int] = {}
token_budget = self.max_num_scheduled_tokens token_budget = self.max_num_scheduled_tokens
if self._pause_state == PauseState.PAUSED_ALL:
# Do not schedule any requests when paused.
token_budget = 0
# Encoder-related. # Encoder-related.
scheduled_encoder_inputs: dict[str, list[int]] = {} scheduled_encoder_inputs: dict[str, list[int]] = {}
encoder_compute_budget = self.max_num_encoder_input_tokens encoder_compute_budget = self.max_num_encoder_input_tokens
...@@ -530,12 +536,12 @@ class Scheduler(SchedulerInterface): ...@@ -530,12 +536,12 @@ class Scheduler(SchedulerInterface):
) )
assert len(scheduled_loras) <= self.lora_config.max_loras assert len(scheduled_loras) <= self.lora_config.max_loras
# Next, schedule the WAITING requests.
if not preempted_reqs and self._pause_state == PauseState.UNPAUSED:
# Use a temporary RequestQueue to collect requests that need to be # Use a temporary RequestQueue to collect requests that need to be
# skipped and put back at the head of the waiting queue later # skipped and put back at the head of the waiting queue later
skipped_waiting_requests = create_request_queue(self.policy) skipped_waiting_requests = create_request_queue(self.policy)
# Next, schedule the WAITING requests.
if not preempted_reqs:
while self.waiting and token_budget > 0: while self.waiting and token_budget > 0:
if len(self.running) == self.max_num_running_reqs: if len(self.running) == self.max_num_running_reqs:
break break
...@@ -802,6 +808,7 @@ class Scheduler(SchedulerInterface): ...@@ -802,6 +808,7 @@ class Scheduler(SchedulerInterface):
self.encoder_cache_manager.allocate(request, i) self.encoder_cache_manager.allocate(request, i)
if self.ec_connector is not None: if self.ec_connector is not None:
self.ec_connector.update_state_after_alloc(request, i) self.ec_connector.update_state_after_alloc(request, i)
# Put back any skipped requests at the head of the waiting queue # Put back any skipped requests at the head of the waiting queue
if skipped_waiting_requests: if skipped_waiting_requests:
self.waiting.prepend_requests(skipped_waiting_requests) self.waiting.prepend_requests(skipped_waiting_requests)
...@@ -1672,18 +1679,26 @@ class Scheduler(SchedulerInterface): ...@@ -1672,18 +1679,26 @@ class Scheduler(SchedulerInterface):
request.record_event(EngineCoreEventType.QUEUED) request.record_event(EngineCoreEventType.QUEUED)
def finish_requests( def finish_requests(
self, request_ids: str | Iterable[str], finished_status: RequestStatus self, request_ids: str | Iterable[str] | None, finished_status: RequestStatus
) -> None: ) -> list[tuple[str, int]]:
"""Handles the finish signal from outside the scheduler. """Handles the finish signal from outside the scheduler.
For example, the API server can abort a request when the client For example, the API server can abort a request when the client
disconnects. disconnects.
If request_ids is None, all requests will be finished.
Returns:
List of (req_id, client_index) tuples for requests that were aborted. Will not
include any that were already finished.
""" """
assert RequestStatus.is_finished(finished_status) assert RequestStatus.is_finished(finished_status)
if isinstance(request_ids, str): if isinstance(request_ids, str):
request_ids = (request_ids,) request_ids = (request_ids,)
else: elif request_ids is not None:
request_ids = set(request_ids) request_ids = set(request_ids)
else:
request_ids = self.requests.keys()
running_requests_to_remove = set() running_requests_to_remove = set()
waiting_requests_to_remove = [] waiting_requests_to_remove = []
...@@ -1723,6 +1738,8 @@ class Scheduler(SchedulerInterface): ...@@ -1723,6 +1738,8 @@ class Scheduler(SchedulerInterface):
request.status = finished_status request.status = finished_status
self._free_request(request, delay_free_blocks=delay_free_blocks) self._free_request(request, delay_free_blocks=delay_free_blocks)
return [(r.request_id, r.client_index) for r in valid_requests]
def _free_request( def _free_request(
self, request: Request, delay_free_blocks: bool = False self, request: Request, delay_free_blocks: bool = False
) -> dict[str, Any] | None: ) -> dict[str, Any] | None:
...@@ -1746,7 +1763,18 @@ class Scheduler(SchedulerInterface): ...@@ -1746,7 +1763,18 @@ class Scheduler(SchedulerInterface):
self.kv_cache_manager.free(request) self.kv_cache_manager.free(request)
del self.requests[request.request_id] del self.requests[request.request_id]
@property
def pause_state(self) -> PauseState:
    """Return the pause state currently stored on this scheduler."""
    return self._pause_state
def set_pause_state(self, pause_state: PauseState) -> None:
    """Store ``pause_state`` as this scheduler's current pause state."""
    self._pause_state = pause_state
def get_num_unfinished_requests(self) -> int: def get_num_unfinished_requests(self) -> int:
if self._pause_state == PauseState.PAUSED_ALL:
return 0
if self._pause_state == PauseState.PAUSED_NEW:
return len(self.running)
num_waiting = len(self.waiting) - self.num_waiting_for_streaming_input num_waiting = len(self.waiting) - self.num_waiting_for_streaming_input
return num_waiting + len(self.running) return num_waiting + len(self.running)
......
...@@ -172,9 +172,6 @@ class AsyncLLM(EngineClient): ...@@ -172,9 +172,6 @@ class AsyncLLM(EngineClient):
) )
self.logger_manager.log_engine_initialized() self.logger_manager.log_engine_initialized()
# Pause / resume state for async RL workflows.
self._pause_cond = asyncio.Condition()
self._paused = False
self._client_count = client_count self._client_count = client_count
self.output_handler: asyncio.Task | None = None self.output_handler: asyncio.Task | None = None
...@@ -387,10 +384,6 @@ class AsyncLLM(EngineClient): ...@@ -387,10 +384,6 @@ class AsyncLLM(EngineClient):
# to handle startup failure gracefully in the OpenAI server. # to handle startup failure gracefully in the OpenAI server.
self._run_output_handler() self._run_output_handler()
# Respect pause state before accepting new requests.
async with self._pause_cond:
await self._pause_cond.wait_for(lambda: not self._paused)
# Create a new output collector for the request. # Create a new output collector for the request.
queue = RequestOutputCollector(params.output_kind, request.request_id) queue = RequestOutputCollector(params.output_kind, request.request_id)
...@@ -741,7 +734,9 @@ class AsyncLLM(EngineClient): ...@@ -741,7 +734,9 @@ class AsyncLLM(EngineClient):
""" """
Pause generation to allow model weight updates. Pause generation to allow model weight updates.
New generation/encoding requests are blocked until resume. All mode handling (abort / wait / keep) and cache clearing is done
in the engine. New generation/encoding requests will not be scheduled
until resume is called.
Args: Args:
mode: How to handle in-flight requests: mode: How to handle in-flight requests:
...@@ -751,11 +746,8 @@ class AsyncLLM(EngineClient): ...@@ -751,11 +746,8 @@ class AsyncLLM(EngineClient):
- ``"keep"``: Freeze requests in queue; they resume on - ``"keep"``: Freeze requests in queue; they resume on
:meth:`resume_generation`. :meth:`resume_generation`.
wait_for_inflight_requests: DEPRECATED: use mode argument. wait_for_inflight_requests: DEPRECATED: use mode argument.
Whether to wait for in-flight requests to complete before pausing.
clear_cache: Whether to clear KV cache and prefix cache after clear_cache: Whether to clear KV cache and prefix cache after
draining. Set to ``False`` to preserve cache for faster resume. draining. Set to ``False`` to preserve cache for faster resume.
Default is ``True`` (clear caches).
""" """
if wait_for_inflight_requests: if wait_for_inflight_requests:
warnings.warn( warnings.warn(
...@@ -766,50 +758,15 @@ class AsyncLLM(EngineClient): ...@@ -766,50 +758,15 @@ class AsyncLLM(EngineClient):
stacklevel=2, stacklevel=2,
) )
mode = "wait" mode = "wait"
await self.engine_core.pause_scheduler_async(mode=mode, clear_cache=clear_cache)
if mode == "keep":
# Freeze requests in the scheduler - they will resume on
# resume_generation().
await self.engine_core.pause_scheduler_async()
else:
if self._client_count > 1:
raise NotImplementedError(
"pause_generation is not supported with --api-server-count > 1"
" when mode is not 'keep'"
)
async with self._pause_cond:
if not self._paused:
self._paused = True
if mode == "abort":
request_ids = list(self.output_processor.request_states.keys())
if request_ids:
await self.abort(request_ids, internal=True)
elif mode == "wait":
if self.output_processor.has_unfinished_requests():
await self.output_processor.wait_for_requests_to_drain()
else:
raise ValueError(f"Invalid mode: {mode}")
# Clear cache
if clear_cache:
await self.reset_prefix_cache(reset_running_requests=True)
await self.reset_mm_cache()
await self.reset_encoder_cache()
async def resume_generation(self) -> None: async def resume_generation(self) -> None:
"""Resume generation after :meth:`pause_generation`.""" """Resume generation after :meth:`pause_generation`."""
async with self._pause_cond:
await self.engine_core.resume_scheduler_async() await self.engine_core.resume_scheduler_async()
self._paused = False
self._pause_cond.notify_all() # Wake up all waiting requests
async def is_paused(self) -> bool: async def is_paused(self) -> bool:
"""Return whether the engine is currently paused.""" """Return whether the engine is currently paused."""
return await self.engine_core.is_scheduler_paused_async()
async with self._pause_cond:
return self._paused
async def encode( async def encode(
self, self,
......
...@@ -5,7 +5,7 @@ import queue ...@@ -5,7 +5,7 @@ import queue
import signal import signal
import threading import threading
import time import time
from collections import deque from collections import defaultdict, deque
from collections.abc import Callable, Generator from collections.abc import Callable, Generator
from concurrent.futures import Future from concurrent.futures import Future
from contextlib import ExitStack, contextmanager from contextlib import ExitStack, contextmanager
...@@ -40,7 +40,7 @@ from vllm.v1.core.kv_cache_utils import ( ...@@ -40,7 +40,7 @@ from vllm.v1.core.kv_cache_utils import (
get_request_block_hasher, get_request_block_hasher,
init_none_hash, init_none_hash,
) )
from vllm.v1.core.sched.interface import SchedulerInterface from vllm.v1.core.sched.interface import PauseState, SchedulerInterface
from vllm.v1.core.sched.output import SchedulerOutput from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.engine import ( from vllm.v1.engine import (
EngineCoreOutput, EngineCoreOutput,
...@@ -48,6 +48,7 @@ from vllm.v1.engine import ( ...@@ -48,6 +48,7 @@ from vllm.v1.engine import (
EngineCoreRequest, EngineCoreRequest,
EngineCoreRequestType, EngineCoreRequestType,
FinishReason, FinishReason,
PauseMode,
ReconfigureDistributedRequest, ReconfigureDistributedRequest,
ReconfigureRankType, ReconfigureRankType,
UtilityOutput, UtilityOutput,
...@@ -210,8 +211,7 @@ class EngineCore: ...@@ -210,8 +211,7 @@ class EngineCore:
self.aborts_queue = queue.Queue[list[str]]() self.aborts_queue = queue.Queue[list[str]]()
# Pause state for "keep" mode - freezes requests in queue. self.per_step_hooks: set[Callable] = set()
self._scheduler_paused = False
# Mark the startup heap as static so that it's ignored by GC. # Mark the startup heap as static so that it's ignored by GC.
# Reduces pause times of oldest generation collections. # Reduces pause times of oldest generation collections.
...@@ -326,20 +326,6 @@ class EngineCore: ...@@ -326,20 +326,6 @@ class EngineCore:
# (i.e. client-aborted vs stop criteria met). # (i.e. client-aborted vs stop criteria met).
self.scheduler.finish_requests(request_ids, RequestStatus.FINISHED_ABORTED) self.scheduler.finish_requests(request_ids, RequestStatus.FINISHED_ABORTED)
def pause_scheduler(self) -> None:
"""Pause the scheduler, keeping requests frozen in queue.
Requests are kept frozen in queue and can be resumed later.
"""
self._scheduler_paused = True
def resume_scheduler(self) -> None:
"""Resume the scheduler after a pause.
Resumes processing of frozen requests in the queue.
"""
self._scheduler_paused = False
@contextmanager @contextmanager
def log_error_detail(self, scheduler_output: SchedulerOutput): def log_error_detail(self, scheduler_output: SchedulerOutput):
"""Execute the model and log detailed info on failure.""" """Execute the model and log detailed info on failure."""
...@@ -393,10 +379,6 @@ class EngineCore: ...@@ -393,10 +379,6 @@ class EngineCore:
was executed. was executed.
""" """
# If paused, don't schedule any work.
if self._scheduler_paused:
return {}, False
# Check for any requests remaining in the scheduler - unfinished, # Check for any requests remaining in the scheduler - unfinished,
# or finished and not yet removed from the batch. # or finished and not yet removed from the batch.
if not self.scheduler.has_requests(): if not self.scheduler.has_requests():
...@@ -447,9 +429,6 @@ class EngineCore: ...@@ -447,9 +429,6 @@ class EngineCore:
batch in the job queue is finished. batch in the job queue is finished.
3. Update the scheduler from the output. 3. Update the scheduler from the output.
""" """
# If paused, don't schedule any work.
if self._scheduler_paused:
return {}, False
batch_queue = self.batch_queue batch_queue = self.batch_queue
assert batch_queue is not None assert batch_queue is not None
...@@ -613,6 +592,20 @@ class EngineCore: ...@@ -613,6 +592,20 @@ class EngineCore:
# Reset the GPU model runner's encoder cache (physical storage) # Reset the GPU model runner's encoder cache (physical storage)
self.model_executor.reset_encoder_cache() self.model_executor.reset_encoder_cache()
def pause_scheduler(
    self,
    mode: PauseMode = "abort",
    clear_cache: bool = True,
) -> Future[Any] | None:
    """Pause scheduling (base-class placeholder).

    The base EngineCore does not implement pausing: both arguments are
    accepted for signature compatibility and ignored. EngineCoreProc
    overrides this method with the real behaviour.

    Args:
        mode: Requested pause mode; unused here.
        clear_cache: Whether caches should be cleared; unused here.

    Returns:
        Always ``None``.
    """
    return None
def resume_scheduler(self) -> None:
    """Resume scheduling (base-class placeholder).

    Does nothing in the base EngineCore; EngineCoreProc overrides it.
    """
    return None
def is_scheduler_paused(self) -> bool:
    """Report whether the scheduler is in any pause state.

    The base EngineCore never pauses, so this is always ``False``.
    EngineCoreProc overrides it with a real check.
    """
    paused = False
    return paused
def sleep(self, level: int = 1): def sleep(self, level: int = 1):
"""Put the engine to sleep at the specified level. """Put the engine to sleep at the specified level.
...@@ -650,7 +643,7 @@ class EngineCore: ...@@ -650,7 +643,7 @@ class EngineCore:
def is_sleeping(self) -> bool: def is_sleeping(self) -> bool:
"""Check if engine is sleeping at any level.""" """Check if engine is sleeping at any level."""
return self._scheduler_paused or self.model_executor.is_sleeping return self.is_scheduler_paused() or self.model_executor.is_sleeping
def execute_dummy_batch(self): def execute_dummy_batch(self):
self.model_executor.execute_dummy_batch() self.model_executor.execute_dummy_batch()
...@@ -1053,13 +1046,9 @@ class EngineCoreProc(EngineCore): ...@@ -1053,13 +1046,9 @@ class EngineCoreProc(EngineCore):
# 1) Poll the input queue until there is work to do. # 1) Poll the input queue until there is work to do.
self._process_input_queue() self._process_input_queue()
# 2) Step the engine core and return the outputs. # 2) Step the engine core and return the outputs.
# Skip if scheduling is paused (level 0 sleep)
if not self._scheduler_paused:
self._process_engine_step() self._process_engine_step()
else: # 3) Run any per-step hooks.
# When scheduling is paused, still need to check for wake up self._process_per_step_hooks()
# by processing any utility requests that might resume scheduling
pass
def _process_input_queue(self): def _process_input_queue(self):
"""Exits when an engine step needs to be performed.""" """Exits when an engine step needs to be performed."""
...@@ -1067,9 +1056,9 @@ class EngineCoreProc(EngineCore): ...@@ -1067,9 +1056,9 @@ class EngineCoreProc(EngineCore):
waited = False waited = False
while ( while (
not self.engines_running not self.engines_running
and (not self.scheduler.has_requests() or self._scheduler_paused) and not self.scheduler.has_requests()
and not self.batch_queue and not self.batch_queue
and not self._scheduler_paused and not self.per_step_hooks
): ):
if self.input_queue.empty(): if self.input_queue.empty():
# Drain aborts queue; all aborts are also processed via input_queue. # Drain aborts queue; all aborts are also processed via input_queue.
...@@ -1109,6 +1098,13 @@ class EngineCoreProc(EngineCore): ...@@ -1109,6 +1098,13 @@ class EngineCoreProc(EngineCore):
return model_executed return model_executed
def _process_per_step_hooks(self) -> None:
    """Invoke every registered per-step hook once.

    Each hook is called with this engine as its only argument. A hook
    that returns a truthy value is considered finished and is removed
    from ``self.per_step_hooks``; the others stay registered.
    """
    if not self.per_step_hooks:
        return
    # Iterate over a snapshot: the set is mutated while we walk it.
    for step_hook in tuple(self.per_step_hooks):
        if step_hook(self):
            self.per_step_hooks.discard(step_hook)
def _handle_client_request( def _handle_client_request(
self, request_type: EngineCoreRequestType, request: Any self, request_type: EngineCoreRequestType, request: Any
) -> None: ) -> None:
...@@ -1122,18 +1118,14 @@ class EngineCoreProc(EngineCore): ...@@ -1122,18 +1118,14 @@ class EngineCoreProc(EngineCore):
elif request_type == EngineCoreRequestType.UTILITY: elif request_type == EngineCoreRequestType.UTILITY:
client_idx, call_id, method_name, args = request client_idx, call_id, method_name, args = request
output = UtilityOutput(call_id) output = UtilityOutput(call_id)
try: # Lazily look-up utility method so that failure will be handled/returned.
method = getattr(self, method_name) get_result = lambda: (method := getattr(self, method_name)) and method(
result = method(*self._convert_msgspec_args(method, args)) *self._convert_msgspec_args(method, args)
output.result = UtilityResult(result)
except BaseException as e:
logger.exception("Invocation of %s method failed", method_name)
output.failure_message = (
f"Call to {method_name} method failed: {str(e)}"
) )
self.output_queue.put_nowait( enqueue_output = lambda out: self.output_queue.put_nowait(
(client_idx, EngineCoreOutputs(utility_output=output)) (client_idx, EngineCoreOutputs(utility_output=out))
) )
self._invoke_utility_method(method_name, get_result, output, enqueue_output)
elif request_type == EngineCoreRequestType.EXECUTOR_FAILED: elif request_type == EngineCoreRequestType.EXECUTOR_FAILED:
raise RuntimeError("Executor failed.") raise RuntimeError("Executor failed.")
else: else:
...@@ -1141,6 +1133,25 @@ class EngineCoreProc(EngineCore): ...@@ -1141,6 +1133,25 @@ class EngineCoreProc(EngineCore):
"Unrecognized input request type encountered: %s", request_type "Unrecognized input request type encountered: %s", request_type
) )
@staticmethod
def _invoke_utility_method(
name: str, get_result: Callable, output: UtilityOutput, enqueue_output: Callable
):
try:
result = get_result()
if isinstance(result, Future):
# Defer utility output handling until future completion.
callback = lambda future: EngineCoreProc._invoke_utility_method(
name, future.result, output, enqueue_output
)
result.add_done_callback(callback)
return
output.result = UtilityResult(result)
except Exception as e:
logger.exception("Invocation of %s method failed", name)
output.failure_message = f"Call to {name} method failed: {str(e)}"
enqueue_output(output)
@staticmethod @staticmethod
def _convert_msgspec_args(method, args): def _convert_msgspec_args(method, args):
"""If a provided arg type doesn't match corresponding target method """If a provided arg type doesn't match corresponding target method
...@@ -1347,6 +1358,74 @@ class EngineCoreProc(EngineCore): ...@@ -1347,6 +1358,74 @@ class EngineCoreProc(EngineCore):
) )
) )
def pause_scheduler(
    self, mode: PauseMode = "abort", clear_cache: bool = True
) -> Future | None:
    """Pause scheduling and report when the engine has become idle.

    The scheduler is switched to ``PauseState.PAUSED_ALL`` for ``"keep"``
    and to ``PauseState.PAUSED_NEW`` for ``"abort"`` and ``"wait"``. This
    method never changes the pause state again afterwards; what each state
    means for stepping is decided by the scheduler (not visible here).

    Per-mode behaviour of this method:

    - ``abort``: finish every request with ``FINISHED_ABORTED`` and enqueue
      abort outputs for them, then set ``PAUSED_NEW``.
    - ``wait``: set ``PAUSED_NEW`` without aborting anything.
    - ``keep``: set ``PAUSED_ALL`` without aborting anything.

    In all modes the engine counts as idle once the scheduler reports no
    requests, the batch queue is empty and the output queue is empty.

    Args:
        mode: One of ``"abort"``, ``"wait"`` or ``"keep"``.
        clear_cache: If true, reset the prefix cache (including running
            requests), the multimodal cache and the encoder cache once the
            engine is idle.

    Returns:
        ``None`` if the engine was already idle when called (any requested
        cache reset has already run). Otherwise a ``Future`` that resolves
        to ``None`` from a per-step hook once the engine becomes idle.

    Raises:
        ValueError: If ``mode`` is not a recognised pause mode.
    """
    if mode not in ("keep", "abort", "wait"):
        raise ValueError(f"Invalid pause mode: {mode}")

    future: Future[Any] = Future()

    def wait_until_idle(engine: "EngineCoreProc") -> bool:
        """Complete ``future`` if the engine is idle; return True when done."""
        scheduler = engine.scheduler
        out_queue = engine.output_queue
        if scheduler.has_requests() or engine.batch_queue or not out_queue.empty():
            # Still work in flight or outputs not yet sent; try again next step.
            return False
        if clear_cache:
            engine.reset_prefix_cache(reset_running_requests=True)
            engine.reset_mm_cache()
            engine.reset_encoder_cache()
        future.set_result(None)
        return True

    if mode == "abort":
        # NOTE(review): passing None appears to select every request;
        # confirm against the scheduler's finish_requests contract.
        aborted_reqs = self.scheduler.finish_requests(
            None, RequestStatus.FINISHED_ABORTED
        )
        self._send_abort_outputs(aborted_reqs)

    pause_state = PauseState.PAUSED_ALL if mode == "keep" else PauseState.PAUSED_NEW
    self.scheduler.set_pause_state(pause_state)

    # Fast path: already idle, so there is nothing for the caller to wait on.
    if not wait_until_idle(self):
        # Re-check after every engine step until the engine has drained.
        self.per_step_hooks.add(wait_until_idle)
        return future
    return None
def _send_abort_outputs(self, aborted_reqs: list[tuple[str, int]]) -> None:
    """Enqueue abort-finished outputs for the given requests, one message per client.

    Args:
        aborted_reqs: ``(request_id, client_index)`` pairs. Nothing is
            enqueued when this is empty.
    """
    if not aborted_reqs:
        return

    # Group request ids by the client they belong to.
    ids_by_client: defaultdict[int, set[str]] = defaultdict(set)
    for request_id, client_idx in aborted_reqs:
        ids_by_client[client_idx].add(request_id)

    for client_idx, request_ids in ids_by_client.items():
        abort_outputs = []
        for request_id in request_ids:
            abort_outputs.append(
                EngineCoreOutput(request_id, [], finish_reason=FinishReason.ABORT)
            )
        message = EngineCoreOutputs(
            finished_requests=request_ids, outputs=abort_outputs
        )
        self.output_queue.put_nowait((client_idx, message))
def resume_scheduler(self) -> None:
    """Switch the scheduler back to ``PauseState.UNPAUSED``.

    This method only sets the pause state; flushing of requests queued
    while paused is presumably handled by the scheduler (confirm there).

    NOTE: ``self.per_step_hooks`` is not touched here, so an idle hook
    registered by an earlier ``pause_scheduler`` call stays registered.
    """
    scheduler = self.scheduler
    scheduler.set_pause_state(PauseState.UNPAUSED)
def is_scheduler_paused(self) -> bool:
    """Report whether the scheduler's pause state is anything but ``UNPAUSED``."""
    current_state = self.scheduler.pause_state
    return current_state != PauseState.UNPAUSED
class DPEngineCoreProc(EngineCoreProc): class DPEngineCoreProc(EngineCoreProc):
"""ZMQ-wrapper for running EngineCore in background process """ZMQ-wrapper for running EngineCore in background process
...@@ -1450,10 +1529,6 @@ class DPEngineCoreProc(EngineCoreProc): ...@@ -1450,10 +1529,6 @@ class DPEngineCoreProc(EngineCoreProc):
# 1) Poll the input queue until there is work to do. # 1) Poll the input queue until there is work to do.
self._process_input_queue() self._process_input_queue()
# Skip processing if scheduling is paused (level 0 sleep)
if self._scheduler_paused:
continue
# 2) Step the engine core. # 2) Step the engine core.
executed = self._process_engine_step() executed = self._process_engine_step()
self._maybe_publish_request_counts() self._maybe_publish_request_counts()
......
...@@ -36,6 +36,7 @@ from vllm.v1.engine import ( ...@@ -36,6 +36,7 @@ from vllm.v1.engine import (
EngineCoreOutputs, EngineCoreOutputs,
EngineCoreRequest, EngineCoreRequest,
EngineCoreRequestType, EngineCoreRequestType,
PauseMode,
ReconfigureDistributedRequest, ReconfigureDistributedRequest,
ReconfigureRankType, ReconfigureRankType,
UtilityOutput, UtilityOutput,
...@@ -979,16 +980,17 @@ class AsyncMPClient(MPClient): ...@@ -979,16 +980,17 @@ class AsyncMPClient(MPClient):
if request_ids and not self.resources.engine_dead: if request_ids and not self.resources.engine_dead:
await self._send_input(EngineCoreRequestType.ABORT, request_ids) await self._send_input(EngineCoreRequestType.ABORT, request_ids)
async def pause_scheduler_async(self) -> None: async def pause_scheduler_async(
"""Pause the scheduler, keeping requests frozen in queue. self, mode: PauseMode = "abort", clear_cache: bool = True
Blocks until the EngineCore acknowledges the pause. ) -> None:
""" await self.call_utility_async("pause_scheduler", mode, clear_cache)
await self.call_utility_async("pause_scheduler")
async def resume_scheduler_async(self) -> None: async def resume_scheduler_async(self) -> None:
"""Resume the scheduler after a pause."""
await self.call_utility_async("resume_scheduler") await self.call_utility_async("resume_scheduler")
async def is_scheduler_paused_async(self) -> bool:
    """Ask the engine core, via the ``is_scheduler_paused`` utility call,
    whether its scheduler is paused, and return the answer."""
    paused = await self.call_utility_async("is_scheduler_paused")
    return paused
async def profile_async( async def profile_async(
self, is_start: bool = True, profile_prefix: str | None = None self, is_start: bool = True, profile_prefix: str | None = None
) -> None: ) -> None:
...@@ -1203,18 +1205,6 @@ class DPAsyncMPClient(AsyncMPClient): ...@@ -1203,18 +1205,6 @@ class DPAsyncMPClient(AsyncMPClient):
def get_core_engine_for_request(self, request: EngineCoreRequest): def get_core_engine_for_request(self, request: EngineCoreRequest):
return self.core_engine return self.core_engine
async def pause_scheduler_async(self) -> None:
"""Pause the scheduler, keeping requests frozen in queue."""
raise NotImplementedError(
"pause_scheduler_async is not yet supported for data parallel"
)
async def resume_scheduler_async(self) -> None:
"""Resume the scheduler after a pause."""
raise NotImplementedError(
"resume_scheduler_async is not yet supported for data parallel"
)
class DPLBAsyncMPClient(DPAsyncMPClient): class DPLBAsyncMPClient(DPAsyncMPClient):
"""Asyncio-compatible client for multi-proc, multi-engine (data parallel) """Asyncio-compatible client for multi-proc, multi-engine (data parallel)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment