Unverified Commit 89a385d7 authored by Aaron Hao's avatar Aaron Hao Committed by GitHub
Browse files

[Feat][RL] Pause and Resume with keep requests for single engine (#32351)


Signed-off-by: default avatarahao-anyscale <ahao@anyscale.com>
Signed-off-by: default avatarAaron Hao <ahao@anyscale.com>
Co-authored-by: default avatarRobert Shaw <114415538+robertgshaw2-redhat@users.noreply.github.com>
parent 4a2d00ea
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Test for pause/resume with keep mode.
This test uses concurrent tasks to verify the engine truly stops generating
during pause:
1. Generator task: continuously generates and logs time between tokens
2. Controller task: sends pause/resume commands
If the engine properly pauses, we should see a gap in token timestamps
matching the pause duration.
"""
import asyncio
import time
from vllm import SamplingParams
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.v1.engine.async_llm import AsyncLLM
PAUSE_DURATION = 3.0 # seconds
async def main():
# Create engine with a small model
engine_args = AsyncEngineArgs(
model="facebook/opt-125m",
enforce_eager=True,
)
engine = AsyncLLM.from_engine_args(engine_args)
prompt = "Write a story about a dragon. Once upon a time"
sampling_params = SamplingParams(max_tokens=30, ignore_eos=True)
# Track token arrival times
token_times: list[tuple[int, float]] = [] # (token_count, timestamp)
pause_time: float = 0
resume_time: float = 0
pause_token_idx: int = 0 # Index in token_times when pause occurred
async def generator_task():
"""Generate tokens and record timestamps."""
async for output in engine.generate(
request_id="test-req",
prompt=prompt,
sampling_params=sampling_params,
):
token_count = len(output.outputs[0].token_ids)
token_times.append((token_count, time.monotonic()))
print(
f"Token {token_count} arrived:"
f"T={token_times[-1][1] - token_times[0][1]:.3f}s"
)
return output
async def controller_task():
"""Pause and resume the engine after some tokens generated."""
nonlocal pause_time, resume_time, pause_token_idx
# Wait for some tokens to be generated
while len(token_times) < 5:
await asyncio.sleep(0.01)
print(f"\nPausing engine (keep mode) at token {len(token_times)}")
pause_time = time.monotonic()
await engine.pause_generation(mode="keep")
pause_token_idx = len(token_times)
print(f"Paused! Sleeping for {PAUSE_DURATION}s...")
# Sleep while paused - no tokens should be generated during this time
await asyncio.sleep(PAUSE_DURATION)
print("Resuming engine...")
await engine.resume_generation()
resume_time = time.monotonic()
print("Resumed!\n")
# Run both tasks concurrently
gen_task = asyncio.create_task(generator_task())
ctrl_task = asyncio.create_task(controller_task())
final_output, _ = await asyncio.gather(gen_task, ctrl_task)
# Verify the pause actually stopped generation.
# The gap after the pause token should be approximately the sleep duration.
pause_gap = token_times[pause_token_idx][1] - token_times[pause_token_idx - 1][1]
print(
f"\nGap after pause (token {pause_token_idx - 1} -> {pause_token_idx}): "
f"{pause_gap:.3f}s"
)
if pause_gap >= PAUSE_DURATION * 0.9:
print(f"✓ Test passed! Engine paused for ~{pause_gap:.1f}s")
else:
print(
f"✗ Test failed! Expected ~{PAUSE_DURATION}s gap after pause, "
f"got {pause_gap:.3f}s"
)
raise AssertionError("Engine did not properly pause")
# Verify request completed
assert final_output.finished, "Request should have finished"
assert len(final_output.outputs[0].token_ids) == 30, "Should have all tokens"
engine.shutdown()
if __name__ == "__main__":
asyncio.run(main())
...@@ -2,6 +2,7 @@ ...@@ -2,6 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import asyncio import asyncio
import time
from contextlib import ExitStack from contextlib import ExitStack
from unittest.mock import MagicMock from unittest.mock import MagicMock
...@@ -661,3 +662,301 @@ async def collect_outputs( ...@@ -661,3 +662,301 @@ async def collect_outputs(
outputs_list.append(output) outputs_list.append(output)
final_output = output final_output = output
return final_output return final_output
# =============================================================================
# Pause/Resume Tests
# =============================================================================
@pytest.mark.asyncio
async def test_pause_resume_basic():
"""Test basic pause/resume flag behavior and idempotency.
Tests:
- pause_generation sets the paused flag
- resume_generation clears the paused flag
- calling pause when already paused is a no-op
- calling resume when not paused is safe
- all pause modes work with no requests in flight
- rapid pause/resume cycles don't break the engine
"""
with ExitStack() as after:
with set_default_torch_num_threads(1):
engine = AsyncLLM.from_engine_args(TEXT_ENGINE_ARGS)
after.callback(engine.shutdown)
# Initially not paused
assert not await engine.is_paused()
# Resume when not paused should be safe
await engine.resume_generation()
assert not await engine.is_paused()
# Pause sets flag
await engine.pause_generation(mode="abort")
assert await engine.is_paused()
# Pause when already paused is a no-op
await engine.pause_generation(mode="abort")
assert await engine.is_paused()
# Resume clears flag
await engine.resume_generation()
assert not await engine.is_paused()
# Test all modes with no requests in flight
for mode in ("abort", "wait", "keep"):
await engine.pause_generation(mode=mode)
# "keep" only freezes the scheduler; it does not set _paused
if mode != "keep":
assert await engine.is_paused()
await engine.resume_generation()
assert not await engine.is_paused()
# Concurrent pause/resume race conditions - should not deadlock or raise
await asyncio.gather(
engine.pause_generation(mode="abort"),
engine.resume_generation(),
engine.pause_generation(mode="abort"),
engine.resume_generation(),
)
# Ensure we end in a known state
await engine.resume_generation()
assert not await engine.is_paused()
# Engine should still work after all cycles
sampling_params = SamplingParams(max_tokens=5)
async for out in engine.generate(
request_id="post-cycles",
prompt=TEXT_PROMPT,
sampling_params=sampling_params,
):
pass
assert out.finished
@pytest.mark.asyncio
async def test_pause_abort():
"""Test that mode='abort' aborts in-flight requests immediately."""
with ExitStack() as after:
with set_default_torch_num_threads(1):
engine = AsyncLLM.from_engine_args(TEXT_ENGINE_ARGS)
after.callback(engine.shutdown)
# Start a long-running request
sampling_params = SamplingParams(max_tokens=1000, ignore_eos=True)
outputs: list[RequestOutput] = []
async def gen():
async for out in engine.generate(
request_id="test-abort-pause",
prompt=TEXT_PROMPT,
sampling_params=sampling_params,
):
outputs.append(out)
return outputs[-1] if outputs else None
# Start generation task
gen_task = asyncio.create_task(gen())
# Wait for some tokens to be generated
while len(outputs) < 3:
await asyncio.sleep(0.01)
# Pause with abort mode
await engine.pause_generation(mode="abort")
# Wait for task to complete (should be aborted)
final_output = await gen_task
# Request should be finished (aborted)
assert final_output is not None
assert final_output.finished
assert final_output.outputs[0].finish_reason == "abort"
# Also test that new requests are blocked while paused, then resume
assert await engine.is_paused()
request_completed = False
async def gen_blocked():
nonlocal request_completed
async for out in engine.generate(
request_id="test-blocked",
prompt=TEXT_PROMPT,
sampling_params=SamplingParams(max_tokens=5),
):
pass
request_completed = True
return out
# Start a request (should block)
gen_task2 = asyncio.create_task(gen_blocked())
# Wait a bit - request should not have completed
await asyncio.sleep(0.3)
assert not request_completed, "Request should be blocked while paused"
# Resume
await engine.resume_generation()
# Now request should complete
final_output2 = await asyncio.wait_for(gen_task2, timeout=10.0)
assert request_completed
assert final_output2.finished
@pytest.mark.asyncio
async def test_pause_wait():
"""Test that mode='wait' waits for in-flight requests to complete."""
with ExitStack() as after:
with set_default_torch_num_threads(1):
engine = AsyncLLM.from_engine_args(TEXT_ENGINE_ARGS)
after.callback(engine.shutdown)
# Start a request - use fewer tokens since wait mode waits for completion
sampling_params = SamplingParams(max_tokens=10, ignore_eos=True)
got_first_token = asyncio.Event()
request_completed = False
async def gen():
nonlocal request_completed
async for out in engine.generate(
request_id="test-wait",
prompt=TEXT_PROMPT,
sampling_params=sampling_params,
):
got_first_token.set()
request_completed = True
return out
# Start generation
gen_task = asyncio.create_task(gen())
# Wait for generation to start (event-driven)
await asyncio.wait_for(got_first_token.wait(), timeout=30.0)
# Pause with wait mode - should wait for request to finish
await engine.pause_generation(mode="wait")
# By now the request should be done (wait mode waits for completion)
assert request_completed, "Request should have completed during wait"
final_output = gen_task.result()
assert final_output.finished
# Should complete normally, not aborted
assert final_output.outputs[0].finish_reason != "eos"
@pytest.mark.asyncio
async def test_pause_keep_single_request():
"""Test that mode='keep' freezes a single request and resumes with timing gap."""
with ExitStack() as after:
with set_default_torch_num_threads(1):
engine = AsyncLLM.from_engine_args(TEXT_ENGINE_ARGS)
after.callback(engine.shutdown)
sampling_params = SamplingParams(max_tokens=30, ignore_eos=True)
token_times: list[tuple[int, float]] = []
pause_duration = 5.0
pause_token_idx = 0
async def generator_task():
"""Generate tokens and record timestamps."""
async for output in engine.generate(
request_id="test-keep-single",
prompt=TEXT_PROMPT,
sampling_params=sampling_params,
):
token_count = len(output.outputs[0].token_ids)
token_times.append((token_count, time.monotonic()))
return output
async def controller_task():
"""Pause and resume the engine."""
nonlocal pause_token_idx
# Wait for some tokens (event-driven, handles slow token generation)
while len(token_times) < 5:
await asyncio.sleep(0.01)
# Pause with keep mode
await engine.pause_generation(mode="keep")
pause_token_idx = len(token_times)
# Sleep while paused
await asyncio.sleep(pause_duration)
# Resume
await engine.resume_generation()
# Run both tasks with timeout for slow generation
gen_task = asyncio.create_task(generator_task())
ctrl_task = asyncio.create_task(controller_task())
final_output, _ = await asyncio.wait_for(
asyncio.gather(gen_task, ctrl_task), timeout=60.0
)
# Request should complete with all tokens
assert final_output.finished
assert len(final_output.outputs[0].token_ids) == 30
# Check the gap at the recorded pause index matches the pause duration
pause_gap = (
token_times[pause_token_idx][1] - token_times[pause_token_idx - 1][1]
)
assert pause_gap >= pause_duration * 0.8, (
f"Expected gap of ~{pause_duration}s after pause, got {pause_gap:.3f}s"
)
@pytest.mark.asyncio
async def test_pause_keep_multi_request():
"""Test that mode='keep' freezes multiple concurrent requests and all resume."""
with ExitStack() as after:
with set_default_torch_num_threads(1):
engine = AsyncLLM.from_engine_args(TEXT_ENGINE_ARGS)
after.callback(engine.shutdown)
num_requests = 3
sampling_params = SamplingParams(max_tokens=10, ignore_eos=True)
completed_requests: list[str] = []
any_token_generated = asyncio.Event()
async def gen_multi(request_id: str):
async for out in engine.generate(
request_id=request_id,
prompt=TEXT_PROMPT,
sampling_params=sampling_params,
):
any_token_generated.set()
completed_requests.append(request_id)
return out
# Start multiple requests
tasks = [
asyncio.create_task(gen_multi(f"req-multi-{i}"))
for i in range(num_requests)
]
# Wait for at least one token across any request (event-driven)
await asyncio.wait_for(any_token_generated.wait(), timeout=30.0)
# Pause with keep mode
await engine.pause_generation(mode="keep")
# Wait while paused
await asyncio.sleep(0.5)
# Resume
await engine.resume_generation()
# All requests should complete
results = await asyncio.wait_for(asyncio.gather(*tasks), timeout=60.0)
assert len(completed_requests) == num_requests
for result in results:
assert result.finished
assert len(result.outputs[0].token_ids) == 10
...@@ -3,7 +3,7 @@ ...@@ -3,7 +3,7 @@
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from collections.abc import AsyncGenerator, Iterable, Mapping from collections.abc import AsyncGenerator, Iterable, Mapping
from typing import Any from typing import TYPE_CHECKING, Any
from vllm.config import ModelConfig, VllmConfig from vllm.config import ModelConfig, VllmConfig
from vllm.distributed.weight_transfer.base import ( from vllm.distributed.weight_transfer.base import (
...@@ -22,6 +22,9 @@ from vllm.tasks import SupportedTask ...@@ -22,6 +22,9 @@ from vllm.tasks import SupportedTask
from vllm.v1.engine import EngineCoreRequest from vllm.v1.engine import EngineCoreRequest
from vllm.v1.engine.input_processor import InputProcessor from vllm.v1.engine.input_processor import InputProcessor
if TYPE_CHECKING:
from vllm.v1.engine import PauseMode
class EngineClient(ABC): class EngineClient(ABC):
"""Protocol class for Clients to Engine""" """Protocol class for Clients to Engine"""
...@@ -158,16 +161,22 @@ class EngineClient(ABC): ...@@ -158,16 +161,22 @@ class EngineClient(ABC):
async def pause_generation( async def pause_generation(
self, self,
*, *,
mode: "PauseMode" = "abort",
wait_for_inflight_requests: bool = False, wait_for_inflight_requests: bool = False,
clear_cache: bool = True, clear_cache: bool = True,
) -> None: ) -> None:
"""Pause new generation/encoding requests. """Pause new generation/encoding requests.
Args: Args:
wait_for_inflight_requests: When ``True`` waits for in-flight requests mode: How to handle in-flight requests:
to finish before pausing. When ``False`` (default), aborts in-flight - ``"abort"``: Abort all in-flight requests immediately
requests immediately. and return partial results with "abort" reason (default).
clear_cache: Whether to clear KV and prefix caches after draining. - ``"wait"``: Wait for in-flight requests to complete.
- ``"keep"``: Freeze requests in queue; they resume on
:meth:`resume_generation`.
wait_for_inflight_requests: DEPRECATED. Use ``mode="wait"`` instead.
clear_cache: DEPRECATED. Whether to clear KV and prefix caches
after draining.
""" """
... ...
......
...@@ -3,6 +3,7 @@ ...@@ -3,6 +3,7 @@
import json import json
from http import HTTPStatus from http import HTTPStatus
from typing import Annotated
from fastapi import APIRouter, FastAPI, HTTPException, Query, Request from fastapi import APIRouter, FastAPI, HTTPException, Query, Request
from fastapi.responses import JSONResponse from fastapi.responses import JSONResponse
...@@ -14,6 +15,7 @@ from vllm.distributed.weight_transfer.base import ( ...@@ -14,6 +15,7 @@ from vllm.distributed.weight_transfer.base import (
) )
from vllm.engine.protocol import EngineClient from vllm.engine.protocol import EngineClient
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.v1.engine import PauseMode
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -28,24 +30,29 @@ router = APIRouter() ...@@ -28,24 +30,29 @@ router = APIRouter()
@router.post("/pause") @router.post("/pause")
async def pause_generation( async def pause_generation(
raw_request: Request, raw_request: Request,
mode: Annotated[PauseMode, Query()] = "abort",
wait_for_inflight_requests: bool = Query(False), wait_for_inflight_requests: bool = Query(False),
clear_cache: bool = Query(True), clear_cache: Annotated[bool, Query()] = True,
) -> JSONResponse: ) -> JSONResponse:
"""Pause generation requests to allow weight updates. """Pause generation requests to allow weight updates.
Args: Args:
wait_for_inflight_requests: When ``True`` waits for in-flight mode: How to handle in-flight requests:
requests to finish before pausing. When ``False`` (default), - ``"abort"``: Abort all in-flight requests immediately (default).
aborts any in-flight requests immediately. - ``"wait"``: Wait for in-flight requests to complete.
clear_cache: Whether to clear KV/prefix caches after draining. - ``"keep"``: Freeze requests in queue; they resume on /resume.
wait_for_inflight_requests: DEPRECATED. Use ``mode="wait"`` instead.
clear_cache: DEPRECATED. Whether to clear KV/prefix caches after
draining. Ignored when mode="keep".
""" """
engine = engine_client(raw_request) engine = engine_client(raw_request)
try: try:
await engine.pause_generation( await engine.pause_generation(
wait_for_inflight_requests=wait_for_inflight_requests, mode=mode,
clear_cache=clear_cache, clear_cache=clear_cache,
wait_for_inflight_requests=wait_for_inflight_requests,
) )
return JSONResponse( return JSONResponse(
content={"status": "paused"}, content={"status": "paused"},
......
...@@ -4,7 +4,7 @@ ...@@ -4,7 +4,7 @@
import enum import enum
import time import time
from collections.abc import Mapping from collections.abc import Mapping
from typing import Any from typing import Any, Literal
import msgspec import msgspec
import numpy as np import numpy as np
...@@ -18,6 +18,12 @@ from vllm.v1.metrics.stats import SchedulerStats ...@@ -18,6 +18,12 @@ from vllm.v1.metrics.stats import SchedulerStats
from vllm.v1.outputs import LogprobsLists, LogprobsTensors from vllm.v1.outputs import LogprobsLists, LogprobsTensors
from vllm.v1.serial_utils import UtilityResult from vllm.v1.serial_utils import UtilityResult
# Type for pause_generation mode parameter.
# - "abort": Abort all in-flight requests immediately (default).
# - "wait": Wait for in-flight requests to complete before pausing.
# - "keep": Freeze requests in queue; they resume on resume_generation().
PauseMode = Literal["abort", "wait", "keep"]
# These are possible values of RequestOutput.finish_reason, # These are possible values of RequestOutput.finish_reason,
# so form part of the external API. # so form part of the external API.
FINISH_REASON_STRINGS = ("stop", "length", "abort", "error") FINISH_REASON_STRINGS = ("stop", "length", "abort", "error")
......
...@@ -38,7 +38,7 @@ from vllm.transformers_utils.config import maybe_register_config_serialize_by_va ...@@ -38,7 +38,7 @@ from vllm.transformers_utils.config import maybe_register_config_serialize_by_va
from vllm.usage.usage_lib import UsageContext from vllm.usage.usage_lib import UsageContext
from vllm.utils.async_utils import cancel_task_threadsafe from vllm.utils.async_utils import cancel_task_threadsafe
from vllm.utils.collection_utils import as_list from vllm.utils.collection_utils import as_list
from vllm.v1.engine import EngineCoreRequest from vllm.v1.engine import EngineCoreRequest, PauseMode
from vllm.v1.engine.core_client import EngineCoreClient from vllm.v1.engine.core_client import EngineCoreClient
from vllm.v1.engine.exceptions import EngineDeadError, EngineGenerateError from vllm.v1.engine.exceptions import EngineDeadError, EngineGenerateError
from vllm.v1.engine.input_processor import InputProcessor from vllm.v1.engine.input_processor import InputProcessor
...@@ -170,6 +170,7 @@ class AsyncLLM(EngineClient): ...@@ -170,6 +170,7 @@ class AsyncLLM(EngineClient):
# Pause / resume state for async RL workflows. # Pause / resume state for async RL workflows.
self._pause_cond = asyncio.Condition() self._pause_cond = asyncio.Condition()
self._paused = False self._paused = False
self._client_count = client_count
self.output_handler: asyncio.Task | None = None self.output_handler: asyncio.Task | None = None
try: try:
...@@ -728,7 +729,8 @@ class AsyncLLM(EngineClient): ...@@ -728,7 +729,8 @@ class AsyncLLM(EngineClient):
async def pause_generation( async def pause_generation(
self, self,
*, *,
wait_for_inflight_requests: bool = False, mode: PauseMode = "abort",
wait_for_inflight_requests: bool | None = None,
clear_cache: bool = True, clear_cache: bool = True,
) -> None: ) -> None:
""" """
...@@ -737,27 +739,52 @@ class AsyncLLM(EngineClient): ...@@ -737,27 +739,52 @@ class AsyncLLM(EngineClient):
New generation/encoding requests are blocked until resume. New generation/encoding requests are blocked until resume.
Args: Args:
wait_for_inflight_requests: When ``True`` waits for in-flight mode: How to handle in-flight requests:
requests to finish before pausing. When ``False`` (default), - ``"abort"``: Abort all in-flight requests immediately
immediately aborts any in-flight requests. (default).
- ``"wait"``: Wait for in-flight requests to complete.
- ``"keep"``: Freeze requests in queue; they resume on
:meth:`resume_generation`.
wait_for_inflight_requests: DEPRECATED: use mode argument.
Whether to wait for in-flight requests to complete before pausing.
clear_cache: Whether to clear KV cache and prefix cache after clear_cache: Whether to clear KV cache and prefix cache after
draining. Set to ``False`` to preserve cache for faster resume. draining. Set to ``False`` to preserve cache for faster resume.
Default is ``True`` (clear caches). Default is ``True`` (clear caches).
""" """
if wait_for_inflight_requests:
warnings.warn(
"The `wait_for_inflight_requests` parameter in "
"`AsyncLLM.pause_generation()` is deprecated. "
"Please use `mode` argument instead.",
DeprecationWarning,
stacklevel=2,
)
mode = "wait"
if mode == "keep":
# Freeze requests in the scheduler - they will resume on
# resume_generation().
await self.engine_core.pause_scheduler_async()
else:
if self._client_count > 1:
raise NotImplementedError(
"pause_generation is not supported with --api-server-count > 1"
" when mode is not 'keep'"
)
async with self._pause_cond: async with self._pause_cond:
if self._paused: if not self._paused:
return
self._paused = True self._paused = True
if not wait_for_inflight_requests: if mode == "abort":
request_ids = list(self.output_processor.request_states.keys()) request_ids = list(self.output_processor.request_states.keys())
if request_ids: if request_ids:
await self.abort(request_ids, internal=True) await self.abort(request_ids, internal=True)
elif mode == "wait":
# Wait for running requests to drain before clearing cache.
if self.output_processor.has_unfinished_requests(): if self.output_processor.has_unfinished_requests():
await self.output_processor.wait_for_requests_to_drain() await self.output_processor.wait_for_requests_to_drain()
else:
raise ValueError(f"Invalid mode: {mode}")
# Clear cache # Clear cache
if clear_cache: if clear_cache:
...@@ -769,6 +796,7 @@ class AsyncLLM(EngineClient): ...@@ -769,6 +796,7 @@ class AsyncLLM(EngineClient):
"""Resume generation after :meth:`pause_generation`.""" """Resume generation after :meth:`pause_generation`."""
async with self._pause_cond: async with self._pause_cond:
await self.engine_core.resume_scheduler_async()
self._paused = False self._paused = False
self._pause_cond.notify_all() # Wake up all waiting requests self._pause_cond.notify_all() # Wake up all waiting requests
......
...@@ -209,6 +209,10 @@ class EngineCore: ...@@ -209,6 +209,10 @@ class EngineCore:
self.async_scheduling = vllm_config.scheduler_config.async_scheduling self.async_scheduling = vllm_config.scheduler_config.async_scheduling
self.aborts_queue = queue.Queue[list[str]]() self.aborts_queue = queue.Queue[list[str]]()
# Pause state for "keep" mode - freezes requests in queue.
self._scheduler_paused = False
# Mark the startup heap as static so that it's ignored by GC. # Mark the startup heap as static so that it's ignored by GC.
# Reduces pause times of oldest generation collections. # Reduces pause times of oldest generation collections.
freeze_gc_heap() freeze_gc_heap()
...@@ -322,6 +326,20 @@ class EngineCore: ...@@ -322,6 +326,20 @@ class EngineCore:
# (i.e. client-aborted vs stop criteria met). # (i.e. client-aborted vs stop criteria met).
self.scheduler.finish_requests(request_ids, RequestStatus.FINISHED_ABORTED) self.scheduler.finish_requests(request_ids, RequestStatus.FINISHED_ABORTED)
def pause_scheduler(self) -> None:
"""Pause the scheduler, keeping requests frozen in queue.
Requests are kept frozen in queue and can be resumed later.
"""
self._scheduler_paused = True
def resume_scheduler(self) -> None:
"""Resume the scheduler after a pause.
Resumes processing of frozen requests in the queue.
"""
self._scheduler_paused = False
@contextmanager @contextmanager
def log_error_detail(self, scheduler_output: SchedulerOutput): def log_error_detail(self, scheduler_output: SchedulerOutput):
"""Execute the model and log detailed info on failure.""" """Execute the model and log detailed info on failure."""
...@@ -375,6 +393,10 @@ class EngineCore: ...@@ -375,6 +393,10 @@ class EngineCore:
was executed. was executed.
""" """
# If paused, don't schedule any work.
if self._scheduler_paused:
return {}, False
# Check for any requests remaining in the scheduler - unfinished, # Check for any requests remaining in the scheduler - unfinished,
# or finished and not yet removed from the batch. # or finished and not yet removed from the batch.
if not self.scheduler.has_requests(): if not self.scheduler.has_requests():
...@@ -425,6 +447,10 @@ class EngineCore: ...@@ -425,6 +447,10 @@ class EngineCore:
batch in the job queue is finished. batch in the job queue is finished.
3. Update the scheduler from the output. 3. Update the scheduler from the output.
""" """
# If paused, don't schedule any work.
if self._scheduler_paused:
return {}, False
batch_queue = self.batch_queue batch_queue = self.batch_queue
assert batch_queue is not None assert batch_queue is not None
...@@ -1007,6 +1033,7 @@ class EngineCoreProc(EngineCore): ...@@ -1007,6 +1033,7 @@ class EngineCoreProc(EngineCore):
not self.engines_running not self.engines_running
and not self.scheduler.has_requests() and not self.scheduler.has_requests()
and not self.batch_queue and not self.batch_queue
and not self._scheduler_paused
): ):
if self.input_queue.empty(): if self.input_queue.empty():
# Drain aborts queue; all aborts are also processed via input_queue. # Drain aborts queue; all aborts are also processed via input_queue.
......
...@@ -105,7 +105,7 @@ class EngineCoreClient(ABC): ...@@ -105,7 +105,7 @@ class EngineCoreClient(ABC):
client_addresses: dict[str, str] | None = None, client_addresses: dict[str, str] | None = None,
client_count: int = 1, client_count: int = 1,
client_index: int = 0, client_index: int = 0,
) -> "MPClient": ) -> "AsyncMPClient":
parallel_config = vllm_config.parallel_config parallel_config = vllm_config.parallel_config
client_args = ( client_args = (
vllm_config, vllm_config,
...@@ -976,6 +976,16 @@ class AsyncMPClient(MPClient): ...@@ -976,6 +976,16 @@ class AsyncMPClient(MPClient):
if request_ids and not self.resources.engine_dead: if request_ids and not self.resources.engine_dead:
await self._send_input(EngineCoreRequestType.ABORT, request_ids) await self._send_input(EngineCoreRequestType.ABORT, request_ids)
async def pause_scheduler_async(self) -> None:
"""Pause the scheduler, keeping requests frozen in queue.
Blocks until the EngineCore acknowledges the pause.
"""
await self.call_utility_async("pause_scheduler")
async def resume_scheduler_async(self) -> None:
"""Resume the scheduler after a pause."""
await self.call_utility_async("resume_scheduler")
async def profile_async(self, is_start: bool = True) -> None: async def profile_async(self, is_start: bool = True) -> None:
await self.call_utility_async("profile", is_start) await self.call_utility_async("profile", is_start)
...@@ -1188,6 +1198,18 @@ class DPAsyncMPClient(AsyncMPClient): ...@@ -1188,6 +1198,18 @@ class DPAsyncMPClient(AsyncMPClient):
def get_core_engine_for_request(self, request: EngineCoreRequest): def get_core_engine_for_request(self, request: EngineCoreRequest):
return self.core_engine return self.core_engine
async def pause_scheduler_async(self) -> None:
"""Pause the scheduler, keeping requests frozen in queue."""
raise NotImplementedError(
"pause_scheduler_async is not yet supported for data parallel"
)
async def resume_scheduler_async(self) -> None:
"""Resume the scheduler after a pause."""
raise NotImplementedError(
"resume_scheduler_async is not yet supported for data parallel"
)
class DPLBAsyncMPClient(DPAsyncMPClient): class DPLBAsyncMPClient(DPAsyncMPClient):
"""Asyncio-compatible client for multi-proc, multi-engine (data parallel) """Asyncio-compatible client for multi-proc, multi-engine (data parallel)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment