Unverified commit dbf0da81, authored by Nick Hill, committed by GitHub
Browse files

[Core] Cleanup engine pause/sleep logic (#34528)


Signed-off-by: Nick Hill <nickhill123@gmail.com>
parent 3bbb2046
...@@ -3,8 +3,10 @@ ...@@ -3,8 +3,10 @@
import asyncio import asyncio
import os import os
import time
from contextlib import ExitStack from contextlib import ExitStack
from dataclasses import dataclass from dataclasses import dataclass
from typing import Any
import pytest import pytest
...@@ -187,24 +189,33 @@ async def test_load( ...@@ -187,24 +189,33 @@ async def test_load(
# ============================================================================= # =============================================================================
# DP Pause/Resume Tests # DP Pause/Resume Tests
# ============================================================================= # =============================================================================
# When expert_parallel=False: uses non-MoE model (DP replicas as separate engines).
# When expert_parallel=True: uses MoE model + EP (DPEngineCoreProc, sync pause path).
DP_PAUSE_MODEL = "hmellor/tiny-random-LlamaForCausalLM" DP_PAUSE_MODEL = "hmellor/tiny-random-LlamaForCausalLM"
DP_PAUSE_MODEL_MOE = "ibm-research/PowerMoE-3b"
DP_PAUSE_PROMPT = "This is a test of data parallel pause" DP_PAUSE_PROMPT = "This is a test of data parallel pause"
def _get_dp_pause_engine_args(expert_parallel: bool) -> AsyncEngineArgs:
    """Build the AsyncEngineArgs shared by the DP pause/resume tests.

    Args:
        expert_parallel: When true, select DP_PAUSE_MODEL_MOE and enable
            expert parallelism; otherwise select DP_PAUSE_MODEL.

    Returns:
        Engine args using eager mode, the "mp" data-parallel backend,
        DP_SIZE data-parallel replicas, and a tensor-parallel size read
        from the TP_SIZE environment variable (1 when unset).
    """
    if expert_parallel:
        model_name = DP_PAUSE_MODEL_MOE
    else:
        model_name = DP_PAUSE_MODEL
    tp_size = int(os.getenv("TP_SIZE", 1))
    return AsyncEngineArgs(
        model=model_name,
        enforce_eager=True,
        tensor_parallel_size=tp_size,
        data_parallel_size=DP_SIZE,
        data_parallel_backend="mp",
        enable_expert_parallel=expert_parallel,
    )
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_dp_pause_resume_basic(): @pytest.mark.parametrize("expert_parallel", [False, True])
async def test_dp_pause_resume_basic(expert_parallel: bool):
"""Pausing from the client (one call) pauses all DP ranks; resume clears it.""" """Pausing from the client (one call) pauses all DP ranks; resume clears it."""
if current_platform.is_rocm():
pytest.skip("DP pause tests use mp backend only")
with ExitStack() as after: with ExitStack() as after:
engine_args = AsyncEngineArgs( engine_args = _get_dp_pause_engine_args(expert_parallel)
model=DP_PAUSE_MODEL,
enforce_eager=True,
tensor_parallel_size=int(os.getenv("TP_SIZE", 1)),
data_parallel_size=DP_SIZE,
data_parallel_backend="mp",
)
engine = AsyncLLM.from_engine_args(engine_args) engine = AsyncLLM.from_engine_args(engine_args)
after.callback(engine.shutdown) after.callback(engine.shutdown)
...@@ -226,18 +237,11 @@ async def test_dp_pause_resume_basic(): ...@@ -226,18 +237,11 @@ async def test_dp_pause_resume_basic():
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_dp_pause_abort(): @pytest.mark.parametrize("expert_parallel", [False, True])
async def test_dp_pause_abort(expert_parallel: bool):
"""Pause with abort from one client aborts in-flight requests on all DP ranks.""" """Pause with abort from one client aborts in-flight requests on all DP ranks."""
if current_platform.is_rocm():
pytest.skip("DP pause tests use mp backend only")
with ExitStack() as after: with ExitStack() as after:
engine_args = AsyncEngineArgs( engine_args = _get_dp_pause_engine_args(expert_parallel)
model=DP_PAUSE_MODEL,
enforce_eager=True,
tensor_parallel_size=int(os.getenv("TP_SIZE", 1)),
data_parallel_size=DP_SIZE,
data_parallel_backend="mp",
)
engine = AsyncLLM.from_engine_args(engine_args) engine = AsyncLLM.from_engine_args(engine_args)
after.callback(engine.shutdown) after.callback(engine.shutdown)
...@@ -286,41 +290,111 @@ async def test_dp_pause_abort(): ...@@ -286,41 +290,111 @@ async def test_dp_pause_abort():
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_dp_pause_keep_then_resume(): @pytest.mark.parametrize("expert_parallel", [False, True])
"""Pause with keep queues new requests; resume allows them to run.""" async def test_dp_pause_keep_then_resume(expert_parallel: bool):
if current_platform.is_rocm(): """Start generation, pause after a few tokens (keep mode), resume; verify gap."""
pytest.skip("DP pause tests use mp backend only")
pause_duration = 2.0
min_tokens_before_pause = 3
with ExitStack() as after: with ExitStack() as after:
engine_args = AsyncEngineArgs( engine_args = _get_dp_pause_engine_args(expert_parallel)
model=DP_PAUSE_MODEL,
enforce_eager=True,
tensor_parallel_size=int(os.getenv("TP_SIZE", 1)),
data_parallel_size=DP_SIZE,
data_parallel_backend="mp",
)
engine = AsyncLLM.from_engine_args(engine_args) engine = AsyncLLM.from_engine_args(engine_args)
after.callback(engine.shutdown) after.callback(engine.shutdown)
await engine.pause_generation(mode="keep") sampling_params = SamplingParams(max_tokens=15, ignore_eos=True)
assert await engine.is_paused() token_times: list[tuple[int, float]] = []
pause_token_idx = 0
request_done = asyncio.Event()
async def gen(): async def generator_task():
async for out in engine.generate( nonlocal pause_token_idx
request_id="queued-keep", out = None
async for output in engine.generate(
request_id="keep-resume-req",
prompt=DP_PAUSE_PROMPT, prompt=DP_PAUSE_PROMPT,
sampling_params=SamplingParams(max_tokens=5), sampling_params=sampling_params,
): ):
pass token_count = len(output.outputs[0].token_ids)
request_done.set() token_times.append((token_count, time.monotonic()))
out = output
return out return out
task = asyncio.create_task(gen()) async def controller_task():
await asyncio.sleep(0.2) nonlocal pause_token_idx
assert not request_done.is_set() while len(token_times) < min_tokens_before_pause:
await asyncio.sleep(0.01)
await engine.pause_generation(mode="keep")
await asyncio.sleep(pause_duration)
pause_token_idx = len(token_times)
await engine.resume_generation()
gen_task = asyncio.create_task(generator_task())
ctrl_task = asyncio.create_task(controller_task())
final_output, _ = await asyncio.gather(gen_task, ctrl_task)
assert final_output is not None and final_output.finished
assert await engine.is_paused() is False
assert pause_token_idx >= min_tokens_before_pause
if pause_token_idx > 0 and pause_token_idx < len(token_times):
pause_gap = (
token_times[pause_token_idx][1] - token_times[pause_token_idx - 1][1]
)
assert pause_gap >= pause_duration * 0.8, (
f"Expected gap ~{pause_duration}s after pause, got {pause_gap:.3f}s"
)
@pytest.mark.asyncio
async def test_dp_pause_keep_race_staggered_engines():
"""Race: send pause(keep) to engine 0, then add two requests,
then pause(keep) to engine 1. Ensures no deadlock when pause
requests are staggered and requests arrive in between."""
if DP_SIZE != 2:
pytest.skip("test_dp_pause_keep_race_staggered_engines requires DP_SIZE=2")
with ExitStack() as after:
engine_args = _get_dp_pause_engine_args(expert_parallel=True)
engine = AsyncLLM.from_engine_args(engine_args)
after.callback(engine.shutdown)
client = engine.engine_core
original_call_utility = client.call_utility_async
mid_pause_tasks: list[asyncio.Task] = []
async def staggered_pause_keep(method: str, *args) -> Any:
if method != "pause_scheduler" or not args or args[0] != "keep":
return await original_call_utility(method, *args)
# Send pause(keep) to engine 0 first
await client._call_utility_async(
method, *args, engine=client.core_engines[0]
)
# In the middle: send two requests (race window)
sp = SamplingParams(max_tokens=5, ignore_eos=True)
async def consume_gen(req_id: str) -> None:
async for _ in engine.generate(
request_id=req_id,
prompt=DP_PAUSE_PROMPT,
sampling_params=sp,
):
pass
t1 = asyncio.create_task(consume_gen("race-1"))
t2 = asyncio.create_task(consume_gen("race-2"))
mid_pause_tasks.extend([t1, t2])
await asyncio.sleep(3)
# Then send pause(keep) to engine 1
result = await client._call_utility_async(
method, *args, engine=client.core_engines[1]
)
return result
client.call_utility_async = staggered_pause_keep
await engine.pause_generation(mode="keep")
assert await engine.is_paused()
await engine.resume_generation() await engine.resume_generation()
final = await asyncio.wait_for(task, timeout=10.0)
assert final.finished
assert not await engine.is_paused() assert not await engine.is_paused()
# Let the two requests we sent mid-pause complete
await asyncio.gather(*mid_pause_tasks)
...@@ -280,20 +280,15 @@ def echo_dc_nested( ...@@ -280,20 +280,15 @@ def echo_dc_nested(
def future_echo(self, value: Any, num_wait_loops: int = 2) -> Future: def future_echo(self, value: Any, num_wait_loops: int = 2) -> Future:
"""Utility that returns a Future completed by a per_step_hook after """Utility that returns a Future completed once the engine is idle
num_wait_loops engine steps (tests deferred utility path). (tests deferred utility path).
""" """
future: Future = Future() future: Future = Future()
remaining = [num_wait_loops]
def _step(engine: EngineCore) -> bool: def idle(engine: EngineCore):
remaining[0] -= 1 future.set_result(value)
if remaining[0] <= 0:
future.set_result(value)
return True # remove hook
return False
self.per_step_hooks.add(_step) self._idle_state_callbacks.append(idle)
return future return future
...@@ -832,8 +827,8 @@ async def test_engine_core_client_future_utility_async( ...@@ -832,8 +827,8 @@ async def test_engine_core_client_future_utility_async(
monkeypatch: pytest.MonkeyPatch, monkeypatch: pytest.MonkeyPatch,
subprocess_future_echo_patch, subprocess_future_echo_patch,
): ):
"""Test that a utility returning a Future (completed by a per_step_hook """Test that a utility returning a Future completes when the future is done
after N steps) completes when the future is done (engine uses add_done_callback). (engine uses add_done_callback).
""" """
with monkeypatch.context() as m: with monkeypatch.context() as m:
m.setattr(EngineCore, "future_echo", future_echo, raising=False) m.setattr(EngineCore, "future_echo", future_echo, raising=False)
......
...@@ -148,7 +148,7 @@ class EngineClient(ABC): ...@@ -148,7 +148,7 @@ class EngineClient(ABC):
... ...
@abstractmethod @abstractmethod
async def sleep(self, level: int = 1) -> None: async def sleep(self, level: int = 1, mode: "PauseMode" = "abort") -> None:
"""Sleep the engine""" """Sleep the engine"""
... ...
......
...@@ -87,6 +87,7 @@ from vllm.usage.usage_lib import UsageContext ...@@ -87,6 +87,7 @@ from vllm.usage.usage_lib import UsageContext
from vllm.utils.counter import Counter from vllm.utils.counter import Counter
from vllm.utils.mistral import is_mistral_tokenizer from vllm.utils.mistral import is_mistral_tokenizer
from vllm.utils.tqdm_utils import maybe_tqdm from vllm.utils.tqdm_utils import maybe_tqdm
from vllm.v1.engine import PauseMode
from vllm.v1.engine.llm_engine import LLMEngine from vllm.v1.engine.llm_engine import LLMEngine
from vllm.v1.sample.logits_processor import LogitsProcessor from vllm.v1.sample.logits_processor import LogitsProcessor
...@@ -441,8 +442,7 @@ class LLM: ...@@ -441,8 +442,7 @@ class LLM:
A list of `RequestOutput` objects containing the A list of `RequestOutput` objects containing the
generated completions in the same order as the input prompts. generated completions in the same order as the input prompts.
""" """
model_config = self.model_config runner_type = self.model_config.runner_type
runner_type = model_config.runner_type
if runner_type != "generate": if runner_type != "generate":
raise ValueError( raise ValueError(
"LLM.generate() is only supported for generative models. " "LLM.generate() is only supported for generative models. "
...@@ -489,46 +489,22 @@ class LLM: ...@@ -489,46 +489,22 @@ class LLM:
Returns: Returns:
A list of request IDs for the enqueued requests. A list of request IDs for the enqueued requests.
""" """
model_config = self.model_config runner_type = self.model_config.runner_type
runner_type = model_config.runner_type
if runner_type != "generate": if runner_type != "generate":
raise ValueError("LLM.enqueue() is only supported for generative models.") raise ValueError("LLM.enqueue() is only supported for generative models.")
if sampling_params is None: if sampling_params is None:
sampling_params = self.get_default_sampling_params() sampling_params = self.get_default_sampling_params()
# Use the same preprocessing as _run_completion return self._add_completion_requests(
seq_prompts = prompt_to_seq(prompts) prompts=prompts,
seq_params = self._params_to_seq(sampling_params, len(seq_prompts)) params=sampling_params,
seq_lora_requests = self._lora_request_to_seq(lora_request, len(seq_prompts)) use_tqdm=use_tqdm,
seq_tok_kwargs = [ lora_request=lora_request,
merge_kwargs( priority=priority,
tokenization_kwargs, tokenization_kwargs=tokenization_kwargs,
dict(truncate_prompt_tokens=param.truncate_prompt_tokens),
)
for param in seq_params
]
seq_priority = self._priority_to_seq(priority, len(prompts))
request_ids = self._render_and_add_requests(
prompts=(
self._preprocess_cmpl_one(prompt, tok_kwargs)
for prompt, tok_kwargs in zip(
maybe_tqdm(
seq_prompts,
use_tqdm=use_tqdm,
desc="Rendering prompts",
),
seq_tok_kwargs,
)
),
params=seq_params,
lora_requests=seq_lora_requests,
priorities=seq_priority,
) )
return request_ids
@overload @overload
def wait_for_completion( def wait_for_completion(
self, self,
...@@ -1659,7 +1635,7 @@ class LLM: ...@@ -1659,7 +1635,7 @@ class LLM:
reset_running_requests, reset_connector reset_running_requests, reset_connector
) )
def sleep(self, level: int = 1): def sleep(self, level: int = 1, mode: PauseMode = "abort"):
""" """
Put the engine to sleep. The engine should not process any requests. Put the engine to sleep. The engine should not process any requests.
The caller should guarantee that no requests are being processed The caller should guarantee that no requests are being processed
...@@ -1679,10 +1655,10 @@ class LLM: ...@@ -1679,10 +1655,10 @@ class LLM:
a different model or update the model, where a different model or update the model, where
previous model weights are not needed. It reduces previous model weights are not needed. It reduces
CPU memory pressure. CPU memory pressure.
mode: How to handle any existing requests, can be "abort", "wait",
or "keep".
""" """
if level > 0: self.llm_engine.sleep(level=level, mode=mode)
self.reset_prefix_cache()
self.llm_engine.sleep(level=level)
def wake_up(self, tags: list[str] | None = None): def wake_up(self, tags: list[str] | None = None):
""" """
...@@ -1759,19 +1735,18 @@ class LLM: ...@@ -1759,19 +1735,18 @@ class LLM:
return [0] * num_requests return [0] * num_requests
def _run_completion( def _add_completion_requests(
self, self,
prompts: PromptType | Sequence[PromptType], prompts: PromptType | Sequence[PromptType],
params: SamplingParams params: SamplingParams
| PoolingParams | PoolingParams
| Sequence[SamplingParams | PoolingParams], | Sequence[SamplingParams | PoolingParams],
output_type: type[_O],
*, *,
use_tqdm: bool | Callable[..., tqdm] = True, use_tqdm: bool | Callable[..., tqdm] = True,
lora_request: Sequence[LoRARequest] | LoRARequest | None = None, lora_request: Sequence[LoRARequest] | LoRARequest | None = None,
priority: list[int] | None = None, priority: list[int] | None = None,
tokenization_kwargs: dict[str, Any] | None = None, tokenization_kwargs: dict[str, Any] | None = None,
): ) -> list[str]:
seq_prompts = prompt_to_seq(prompts) seq_prompts = prompt_to_seq(prompts)
seq_params = self._params_to_seq(params, len(seq_prompts)) seq_params = self._params_to_seq(params, len(seq_prompts))
seq_lora_requests = self._lora_request_to_seq(lora_request, len(seq_prompts)) seq_lora_requests = self._lora_request_to_seq(lora_request, len(seq_prompts))
...@@ -1784,25 +1759,44 @@ class LLM: ...@@ -1784,25 +1759,44 @@ class LLM:
] ]
seq_priority = self._priority_to_seq(priority, len(prompts)) seq_priority = self._priority_to_seq(priority, len(prompts))
return self._render_and_run_requests( return self._render_and_add_requests(
prompts=( prompts=(
self._preprocess_cmpl_one(prompt, tok_kwargs) self._preprocess_cmpl_one(prompt, tok_kwargs)
for prompt, tok_kwargs in zip( for prompt, tok_kwargs in zip(
maybe_tqdm( maybe_tqdm(
seq_prompts, seq_prompts, use_tqdm=use_tqdm, desc="Rendering prompts"
use_tqdm=use_tqdm,
desc="Rendering prompts",
), ),
seq_tok_kwargs, seq_tok_kwargs,
) )
), ),
params=seq_params, params=seq_params,
output_type=output_type,
use_tqdm=use_tqdm,
lora_requests=seq_lora_requests, lora_requests=seq_lora_requests,
priorities=seq_priority, priorities=seq_priority,
) )
def _run_completion(
    self,
    prompts: PromptType | Sequence[PromptType],
    params: SamplingParams
    | PoolingParams
    | Sequence[SamplingParams | PoolingParams],
    output_type: type[_O],
    *,
    use_tqdm: bool | Callable[..., tqdm] = True,
    lora_request: Sequence[LoRARequest] | LoRARequest | None = None,
    priority: list[int] | None = None,
    tokenization_kwargs: dict[str, Any] | None = None,
):
    """Enqueue completion requests, then run the engine and return its outputs.

    The prompts, params and per-request options are forwarded unchanged to
    ``self._add_completion_requests``; the request IDs it returns are not
    used here. The result of ``self._run_engine`` (called with the same
    ``use_tqdm`` and the given ``output_type``) is returned as-is.
    """
    enqueue_kwargs = {
        "prompts": prompts,
        "params": params,
        "use_tqdm": use_tqdm,
        "lora_request": lora_request,
        "priority": priority,
        "tokenization_kwargs": tokenization_kwargs,
    }
    self._add_completion_requests(**enqueue_kwargs)

    outputs = self._run_engine(use_tqdm=use_tqdm, output_type=output_type)
    return outputs
def _run_chat( def _run_chat(
self, self,
messages: list[ChatCompletionMessageParam] messages: list[ChatCompletionMessageParam]
......
...@@ -23,7 +23,8 @@ router = APIRouter() ...@@ -23,7 +23,8 @@ router = APIRouter()
async def sleep(raw_request: Request): async def sleep(raw_request: Request):
# get POST params # get POST params
level = raw_request.query_params.get("level", "1") level = raw_request.query_params.get("level", "1")
await engine_client(raw_request).sleep(int(level)) mode = raw_request.query_params.get("mode", "abort")
await engine_client(raw_request).sleep(int(level), mode)
# FIXME: in v0 with frontend multiprocessing, the sleep command # FIXME: in v0 with frontend multiprocessing, the sleep command
# is sent but does not finish yet when we return a response. # is sent but does not finish yet when we return a response.
return Response(status_code=200) return Response(status_code=200)
......
...@@ -753,6 +753,13 @@ class AsyncLLM(EngineClient): ...@@ -753,6 +753,13 @@ class AsyncLLM(EngineClient):
) )
mode = "wait" mode = "wait"
await self.engine_core.pause_scheduler_async(mode=mode, clear_cache=clear_cache) await self.engine_core.pause_scheduler_async(mode=mode, clear_cache=clear_cache)
# Small sleep to help ensure that final outputs from any in-flight requests are
# returned prior to this method returning. These outputs come out of the engine
# prior to the wait-for-idle completion event, but involve additional async
# tasks in output processing.
# Note that this is not required for correctness, just more intuitive ordering
# of events from caller's pov.
await asyncio.sleep(0.02)
async def resume_generation(self) -> None: async def resume_generation(self) -> None:
"""Resume generation after :meth:`pause_generation`.""" """Resume generation after :meth:`pause_generation`."""
...@@ -890,10 +897,8 @@ class AsyncLLM(EngineClient): ...@@ -890,10 +897,8 @@ class AsyncLLM(EngineClient):
async def reset_encoder_cache(self) -> None: async def reset_encoder_cache(self) -> None:
await self.engine_core.reset_encoder_cache_async() await self.engine_core.reset_encoder_cache_async()
async def sleep(self, level: int = 1) -> None: async def sleep(self, level: int = 1, mode: PauseMode = "abort") -> None:
if level > 0: await self.engine_core.sleep_async(level, mode)
await self.reset_prefix_cache()
await self.engine_core.sleep_async(level)
if self.logger_manager is not None: if self.logger_manager is not None:
self.logger_manager.record_sleep_state(1, level) self.logger_manager.record_sleep_state(1, level)
......
...@@ -9,6 +9,7 @@ from collections import defaultdict, deque ...@@ -9,6 +9,7 @@ from collections import defaultdict, deque
from collections.abc import Callable, Generator from collections.abc import Callable, Generator
from concurrent.futures import Future from concurrent.futures import Future
from contextlib import ExitStack, contextmanager from contextlib import ExitStack, contextmanager
from functools import partial
from inspect import isclass, signature from inspect import isclass, signature
from logging import DEBUG from logging import DEBUG
from typing import Any, TypeVar, cast from typing import Any, TypeVar, cast
...@@ -211,7 +212,7 @@ class EngineCore: ...@@ -211,7 +212,7 @@ class EngineCore:
self.aborts_queue = queue.Queue[list[str]]() self.aborts_queue = queue.Queue[list[str]]()
self.per_step_hooks: set[Callable] = set() self._idle_state_callbacks: list[Callable] = []
# Mark the startup heap as static so that it's ignored by GC. # Mark the startup heap as static so that it's ignored by GC.
# Reduces pause times of oldest generation collections. # Reduces pause times of oldest generation collections.
...@@ -592,21 +593,51 @@ class EngineCore: ...@@ -592,21 +593,51 @@ class EngineCore:
# Reset the GPU model runner's encoder cache (physical storage) # Reset the GPU model runner's encoder cache (physical storage)
self.model_executor.reset_encoder_cache() self.model_executor.reset_encoder_cache()
def _reset_caches(self, reset_running_requests=True) -> None:
self.reset_prefix_cache(reset_running_requests=reset_running_requests)
self.reset_mm_cache()
self.reset_encoder_cache()
def pause_scheduler( def pause_scheduler(
self, mode: PauseMode = "abort", clear_cache: bool = True self, mode: PauseMode = "abort", clear_cache: bool = True
) -> Future[Any] | None: ) -> Future | None:
"""Pause scheduling. No-op in base EngineCore; overridden in EngineCoreProc.""" """Pause generation; behavior depends on mode.
All pause modes queue new adds -- "abort" and "keep" skip step();
"wait" allows step() so in-flight requests can drain.
- ``abort``: Set PAUSED_NEW, abort all requests, wait for abort
outputs to be sent (when running with output_queue), optionally
clear caches, then complete the returned Future.
- ``wait``: Set PAUSED_NEW (queue adds, keep stepping); when drained,
optionally clear caches, then complete the returned Future.
- ``keep``: Set PAUSED_ALL; return a Future that completes when the
output queue is empty.
"""
if mode not in ("keep", "abort", "wait"):
raise ValueError(f"Invalid pause mode: {mode}")
if mode == "wait":
raise ValueError("'wait' mode can't be used in inproc-engine mode")
if mode == "abort":
self.scheduler.finish_requests(None, RequestStatus.FINISHED_ABORTED)
pause_state = PauseState.PAUSED_ALL if mode == "keep" else PauseState.PAUSED_NEW
self.scheduler.set_pause_state(pause_state)
if clear_cache:
self._reset_caches()
return None return None
def resume_scheduler(self) -> None: def resume_scheduler(self) -> None:
"""Resume scheduling. No-op in base EngineCore; overridden in EngineCoreProc.""" """Resume the scheduler and flush any requests queued while paused."""
self.scheduler.set_pause_state(PauseState.UNPAUSED)
def is_scheduler_paused(self) -> bool: def is_scheduler_paused(self) -> bool:
"""Return whether the scheduler is in any pause state. False in base EngineCore """Return whether the scheduler is in any pause state."""
and overridden in EngineCoreProc.""" return self.scheduler.pause_state != PauseState.UNPAUSED
return False
def sleep(self, level: int = 1): def sleep(self, level: int = 1, mode: PauseMode = "abort") -> None | Future:
"""Put the engine to sleep at the specified level. """Put the engine to sleep at the specified level.
Args: Args:
...@@ -615,13 +646,34 @@ class EngineCore: ...@@ -615,13 +646,34 @@ class EngineCore:
but not processed. No GPU memory changes. but not processed. No GPU memory changes.
- Level 1: Offload model weights to CPU, discard KV cache. - Level 1: Offload model weights to CPU, discard KV cache.
- Level 2: Discard all GPU memory. - Level 2: Discard all GPU memory.
mode: Pause mode - how to deal with any existing requests, see
documentation of pause_scheduler method.
""" """
if level == 0:
# Level 0: Just pause scheduling, don't touch GPU # Pause scheduler before sleeping.
self.pause_scheduler() clear_prefix_cache = level >= 1
else: pause_future = self.pause_scheduler(mode=mode, clear_cache=clear_prefix_cache)
# Level 1+: Delegate to executor for GPU memory management if level < 1:
self.model_executor.sleep(level) return pause_future
# Level 1+: Delegate to executor for GPU memory management
model_executor = self.model_executor
if pause_future is None:
model_executor.sleep(level)
return None
future = Future[Any]()
def pause_complete(f: Future):
try:
f.result() # propagate any exception
future.set_result(model_executor.sleep(level))
except Exception as e:
future.set_exception(e)
logger.info("Waiting for in-flight requests to complete before sleeping...")
pause_future.add_done_callback(pause_complete)
return future
def wake_up(self, tags: list[str] | None = None): def wake_up(self, tags: list[str] | None = None):
"""Wake up the engine from sleep. """Wake up the engine from sleep.
...@@ -630,17 +682,15 @@ class EngineCore: ...@@ -630,17 +682,15 @@ class EngineCore:
tags: Tags to wake up. Use ["scheduling"] for level 0 wake up. tags: Tags to wake up. Use ["scheduling"] for level 0 wake up.
""" """
if tags is not None and "scheduling" in tags: if tags is not None and "scheduling" in tags:
# Level 0 wake up: Resume scheduling # Remove "scheduling" from tags if there are other tags to process.
self.resume_scheduler() tags = [t for t in tags if t != "scheduling"]
# Remove "scheduling" from tags if there are other tags to process
remaining_tags = [t for t in tags if t != "scheduling"] if tags is None or tags:
if remaining_tags:
self.model_executor.wake_up(remaining_tags)
else:
# Full wake up
self.resume_scheduler()
self.model_executor.wake_up(tags) self.model_executor.wake_up(tags)
# Resume scheduling (applies to all levels)
self.resume_scheduler()
def is_sleeping(self) -> bool: def is_sleeping(self) -> bool:
"""Check if engine is sleeping at any level.""" """Check if engine is sleeping at any level."""
return self.is_scheduler_paused() or self.model_executor.is_sleeping return self.is_scheduler_paused() or self.model_executor.is_sleeping
...@@ -1038,6 +1088,14 @@ class EngineCoreProc(EngineCore): ...@@ -1038,6 +1088,14 @@ class EngineCoreProc(EngineCore):
def _init_data_parallel(self, vllm_config: VllmConfig): def _init_data_parallel(self, vllm_config: VllmConfig):
pass pass
def has_work(self) -> bool:
    """Report whether the engine should be stepped.

    Checks, in order and stopping at the first truthy one: the
    ``engines_running`` flag, ``scheduler.has_requests()``, and whether
    ``batch_queue`` is non-empty.
    """
    running = self.engines_running
    if running:
        return running
    pending = self.scheduler.has_requests()
    if pending:
        return pending
    return bool(self.batch_queue)
def run_busy_loop(self): def run_busy_loop(self):
"""Core busy loop of the EngineCore.""" """Core busy loop of the EngineCore."""
...@@ -1047,19 +1105,14 @@ class EngineCoreProc(EngineCore): ...@@ -1047,19 +1105,14 @@ class EngineCoreProc(EngineCore):
self._process_input_queue() self._process_input_queue()
# 2) Step the engine core and return the outputs. # 2) Step the engine core and return the outputs.
self._process_engine_step() self._process_engine_step()
# 3) Run any per-step hooks.
self._process_per_step_hooks()
def _process_input_queue(self): def _process_input_queue(self):
"""Exits when an engine step needs to be performed.""" """Exits when an engine step needs to be performed."""
waited = False waited = False
while ( while not self.has_work():
not self.engines_running # Notify callbacks waiting for engine to become idle.
and not self.scheduler.has_requests() self._notify_idle_state_callbacks()
and not self.batch_queue
and not self.per_step_hooks
):
if self.input_queue.empty(): if self.input_queue.empty():
# Drain aborts queue; all aborts are also processed via input_queue. # Drain aborts queue; all aborts are also processed via input_queue.
with self.aborts_queue.mutex: with self.aborts_queue.mutex:
...@@ -1098,12 +1151,10 @@ class EngineCoreProc(EngineCore): ...@@ -1098,12 +1151,10 @@ class EngineCoreProc(EngineCore):
return model_executed return model_executed
def _process_per_step_hooks(self) -> None: def _notify_idle_state_callbacks(self) -> None:
if self.per_step_hooks: while self._idle_state_callbacks:
for hook in list(self.per_step_hooks): callback = self._idle_state_callbacks.pop()
finished = hook(self) callback(self)
if finished:
self.per_step_hooks.discard(hook)
def _handle_client_request( def _handle_client_request(
self, request_type: EngineCoreRequestType, request: Any self, request_type: EngineCoreRequestType, request: Any
...@@ -1377,19 +1428,10 @@ class EngineCoreProc(EngineCore): ...@@ -1377,19 +1428,10 @@ class EngineCoreProc(EngineCore):
if mode not in ("keep", "abort", "wait"): if mode not in ("keep", "abort", "wait"):
raise ValueError(f"Invalid pause mode: {mode}") raise ValueError(f"Invalid pause mode: {mode}")
future: Future[Any] = Future() def engine_idle_callback(engine: "EngineCoreProc", future: Future[Any]) -> None:
def wait_until_idle(engine: "EngineCoreProc") -> bool:
scheduler = engine.scheduler
out_queue = engine.output_queue
if scheduler.has_requests() or engine.batch_queue or not out_queue.empty():
return False
if clear_cache: if clear_cache:
engine.reset_prefix_cache(reset_running_requests=True) engine._reset_caches()
engine.reset_mm_cache()
engine.reset_encoder_cache()
future.set_result(None) future.set_result(None)
return True
if mode == "abort": if mode == "abort":
aborted_reqs = self.scheduler.finish_requests( aborted_reqs = self.scheduler.finish_requests(
...@@ -1399,12 +1441,17 @@ class EngineCoreProc(EngineCore): ...@@ -1399,12 +1441,17 @@ class EngineCoreProc(EngineCore):
pause_state = PauseState.PAUSED_ALL if mode == "keep" else PauseState.PAUSED_NEW pause_state = PauseState.PAUSED_ALL if mode == "keep" else PauseState.PAUSED_NEW
self.scheduler.set_pause_state(pause_state) self.scheduler.set_pause_state(pause_state)
if not wait_until_idle(self): if not self.has_work():
self.per_step_hooks.add(wait_until_idle) if clear_cache:
return future self._reset_caches()
return None return None
future = Future[Any]()
self._idle_state_callbacks.append(partial(engine_idle_callback, future=future))
return future
def _send_abort_outputs(self, aborted_reqs: list[tuple[str, int]]) -> None: def _send_abort_outputs(self, aborted_reqs: list[tuple[str, int]]) -> None:
# TODO(nick) this will be moved inside the scheduler
if aborted_reqs: if aborted_reqs:
# Map client_index to list of request_ids that belong to that client. # Map client_index to list of request_ids that belong to that client.
by_client = defaultdict[int, set[str]](set) by_client = defaultdict[int, set[str]](set)
...@@ -1418,14 +1465,6 @@ class EngineCoreProc(EngineCore): ...@@ -1418,14 +1465,6 @@ class EngineCoreProc(EngineCore):
eco = EngineCoreOutputs(finished_requests=req_ids, outputs=outputs) eco = EngineCoreOutputs(finished_requests=req_ids, outputs=outputs)
self.output_queue.put_nowait((client_index, eco)) self.output_queue.put_nowait((client_index, eco))
def resume_scheduler(self) -> None:
"""Resume the scheduler and flush any requests queued while paused."""
self.scheduler.set_pause_state(PauseState.UNPAUSED)
def is_scheduler_paused(self) -> bool:
"""Return whether the scheduler is in any pause state."""
return self.scheduler.pause_state != PauseState.UNPAUSED
class DPEngineCoreProc(EngineCoreProc): class DPEngineCoreProc(EngineCoreProc):
"""ZMQ-wrapper for running EngineCore in background process """ZMQ-wrapper for running EngineCore in background process
...@@ -1481,6 +1520,7 @@ class DPEngineCoreProc(EngineCoreProc): ...@@ -1481,6 +1520,7 @@ class DPEngineCoreProc(EngineCoreProc):
stateless_destroy_torch_distributed_process_group(dp_group) stateless_destroy_torch_distributed_process_group(dp_group)
def add_request(self, request: Request, request_wave: int = 0): def add_request(self, request: Request, request_wave: int = 0):
super().add_request(request, request_wave)
if self.has_coordinator and request_wave != self.current_wave: if self.has_coordinator and request_wave != self.current_wave:
if request_wave > self.current_wave: if request_wave > self.current_wave:
self.current_wave = request_wave self.current_wave = request_wave
...@@ -1491,7 +1531,13 @@ class DPEngineCoreProc(EngineCoreProc): ...@@ -1491,7 +1531,13 @@ class DPEngineCoreProc(EngineCoreProc):
(-1, EngineCoreOutputs(start_wave=self.current_wave)) (-1, EngineCoreOutputs(start_wave=self.current_wave))
) )
super().add_request(request, request_wave) def resume_scheduler(self):
super().resume_scheduler()
if not self.engines_running and self.scheduler.has_unfinished_requests():
# Wake up other DP engines.
self.output_queue.put_nowait(
(-1, EngineCoreOutputs(start_wave=self.current_wave))
)
def _handle_client_request( def _handle_client_request(
self, request_type: EngineCoreRequestType, request: Any self, request_type: EngineCoreRequestType, request: Any
...@@ -1532,8 +1578,8 @@ class DPEngineCoreProc(EngineCoreProc): ...@@ -1532,8 +1578,8 @@ class DPEngineCoreProc(EngineCoreProc):
# 2) Step the engine core. # 2) Step the engine core.
executed = self._process_engine_step() executed = self._process_engine_step()
self._maybe_publish_request_counts() self._maybe_publish_request_counts()
local_unfinished_reqs = self.scheduler.has_unfinished_requests()
local_unfinished_reqs = self.scheduler.has_unfinished_requests()
if not executed: if not executed:
if not local_unfinished_reqs and not self.engines_running: if not local_unfinished_reqs and not self.engines_running:
# All engines are idle. # All engines are idle.
......
...@@ -150,7 +150,7 @@ class EngineCoreClient(ABC): ...@@ -150,7 +150,7 @@ class EngineCoreClient(ABC):
def reset_encoder_cache(self) -> None: def reset_encoder_cache(self) -> None:
raise NotImplementedError raise NotImplementedError
def sleep(self, level: int = 1) -> None: def sleep(self, level: int = 1, mode: PauseMode = "abort") -> None:
raise NotImplementedError raise NotImplementedError
def wake_up(self, tags: list[str] | None = None) -> None: def wake_up(self, tags: list[str] | None = None) -> None:
...@@ -227,7 +227,7 @@ class EngineCoreClient(ABC): ...@@ -227,7 +227,7 @@ class EngineCoreClient(ABC):
async def reset_encoder_cache_async(self) -> None: async def reset_encoder_cache_async(self) -> None:
raise NotImplementedError raise NotImplementedError
async def sleep_async(self, level: int = 1) -> None: async def sleep_async(self, level: int = 1, mode: PauseMode = "abort") -> None:
raise NotImplementedError raise NotImplementedError
async def wake_up_async(self, tags: list[str] | None = None) -> None: async def wake_up_async(self, tags: list[str] | None = None) -> None:
...@@ -314,8 +314,11 @@ class InprocClient(EngineCoreClient): ...@@ -314,8 +314,11 @@ class InprocClient(EngineCoreClient):
def reset_encoder_cache(self) -> None: def reset_encoder_cache(self) -> None:
self.engine_core.reset_encoder_cache() self.engine_core.reset_encoder_cache()
def sleep(self, level: int = 1) -> None: def sleep(self, level: int = 1, mode: PauseMode = "abort") -> None:
self.engine_core.sleep(level) if mode == "wait":
raise ValueError("'wait' pause mode is not supported in inproc-engine mode")
result = self.engine_core.sleep(level, mode)
assert result is None
def wake_up(self, tags: list[str] | None = None) -> None: def wake_up(self, tags: list[str] | None = None) -> None:
self.engine_core.wake_up(tags) self.engine_core.wake_up(tags)
...@@ -796,8 +799,8 @@ class SyncMPClient(MPClient): ...@@ -796,8 +799,8 @@ class SyncMPClient(MPClient):
def pin_lora(self, lora_id: int) -> bool: def pin_lora(self, lora_id: int) -> bool:
return self.call_utility("pin_lora", lora_id) return self.call_utility("pin_lora", lora_id)
def sleep(self, level: int = 1) -> None: def sleep(self, level: int = 1, mode: PauseMode = "abort") -> None:
self.call_utility("sleep", level) self.call_utility("sleep", level, mode)
def wake_up(self, tags: list[str] | None = None) -> None: def wake_up(self, tags: list[str] | None = None) -> None:
self.call_utility("wake_up", tags) self.call_utility("wake_up", tags)
...@@ -1009,8 +1012,8 @@ class AsyncMPClient(MPClient): ...@@ -1009,8 +1012,8 @@ class AsyncMPClient(MPClient):
async def reset_encoder_cache_async(self) -> None: async def reset_encoder_cache_async(self) -> None:
await self.call_utility_async("reset_encoder_cache") await self.call_utility_async("reset_encoder_cache")
async def sleep_async(self, level: int = 1) -> None: async def sleep_async(self, level: int = 1, mode: PauseMode = "abort") -> None:
await self.call_utility_async("sleep", level) await self.call_utility_async("sleep", level, mode)
async def wake_up_async(self, tags: list[str] | None = None) -> None: async def wake_up_async(self, tags: list[str] | None = None) -> None:
await self.call_utility_async("wake_up", tags) await self.call_utility_async("wake_up", tags)
......
...@@ -28,7 +28,7 @@ from vllm.tasks import SupportedTask ...@@ -28,7 +28,7 @@ from vllm.tasks import SupportedTask
from vllm.tokenizers import TokenizerLike from vllm.tokenizers import TokenizerLike
from vllm.tracing import init_tracer from vllm.tracing import init_tracer
from vllm.usage.usage_lib import UsageContext from vllm.usage.usage_lib import UsageContext
from vllm.v1.engine import EngineCoreRequest from vllm.v1.engine import EngineCoreRequest, PauseMode
from vllm.v1.engine.core_client import EngineCoreClient from vllm.v1.engine.core_client import EngineCoreClient
from vllm.v1.engine.input_processor import InputProcessor from vllm.v1.engine.input_processor import InputProcessor
from vllm.v1.engine.output_processor import OutputProcessor from vllm.v1.engine.output_processor import OutputProcessor
...@@ -355,8 +355,8 @@ class LLMEngine: ...@@ -355,8 +355,8 @@ class LLMEngine:
""" """
self.engine_core.reset_encoder_cache() self.engine_core.reset_encoder_cache()
def sleep(self, level: int = 1): def sleep(self, level: int = 1, mode: PauseMode = "abort"):
self.engine_core.sleep(level) self.engine_core.sleep(level, mode)
if self.logger_manager is not None: if self.logger_manager is not None:
self.logger_manager.record_sleep_state(1, level) self.logger_manager.record_sleep_state(1, level)
......
...@@ -429,8 +429,6 @@ class OutputProcessor: ...@@ -429,8 +429,6 @@ class OutputProcessor:
self.external_req_ids: defaultdict[str, list[str]] = defaultdict(list) self.external_req_ids: defaultdict[str, list[str]] = defaultdict(list)
self.lora_states = LoRARequestStates(log_stats) self.lora_states = LoRARequestStates(log_stats)
self.tracing_enabled = tracing_enabled self.tracing_enabled = tracing_enabled
self._requests_drained = asyncio.Event()
self._requests_drained.set()
def get_num_unfinished_requests(self): def get_num_unfinished_requests(self):
return len(self.request_states) return len(self.request_states)
...@@ -438,11 +436,6 @@ class OutputProcessor: ...@@ -438,11 +436,6 @@ class OutputProcessor:
def has_unfinished_requests(self) -> bool: def has_unfinished_requests(self) -> bool:
return len(self.request_states) > 0 return len(self.request_states) > 0
async def wait_for_requests_to_drain(self) -> None:
if not self.request_states:
return
await self._requests_drained.wait()
def propagate_error(self, e: Exception): def propagate_error(self, e: Exception):
"""Propagate error to all generate() tasks.""" """Propagate error to all generate() tasks."""
...@@ -510,8 +503,6 @@ class OutputProcessor: ...@@ -510,8 +503,6 @@ class OutputProcessor:
child_reqs = self.abort_requests(child_reqs, internal=True) child_reqs = self.abort_requests(child_reqs, internal=True)
request_ids_to_abort.extend(child_reqs) request_ids_to_abort.extend(child_reqs)
self.parent_requests.pop(request_id, None) self.parent_requests.pop(request_id, None)
if not self.request_states:
self._requests_drained.set()
return request_ids_to_abort return request_ids_to_abort
def add_request( def add_request(
...@@ -538,8 +529,6 @@ class OutputProcessor: ...@@ -538,8 +529,6 @@ class OutputProcessor:
log_stats=self.log_stats, log_stats=self.log_stats,
stream_interval=self.stream_interval, stream_interval=self.stream_interval,
) )
if self._requests_drained.is_set():
self._requests_drained.clear()
self.request_states[request_id] = req_state self.request_states[request_id] = req_state
if parent_req: if parent_req:
self.parent_requests[parent_req.request_id] = parent_req self.parent_requests[parent_req.request_id] = parent_req
...@@ -706,9 +695,6 @@ class OutputProcessor: ...@@ -706,9 +695,6 @@ class OutputProcessor:
if parent_req and not parent_req.child_requests: if parent_req and not parent_req.child_requests:
self.parent_requests.pop(parent_req.request_id, None) self.parent_requests.pop(parent_req.request_id, None)
if not self.request_states:
self._requests_drained.set()
def update_scheduler_stats(self, scheduler_stats: SchedulerStats | None): def update_scheduler_stats(self, scheduler_stats: SchedulerStats | None):
self.lora_states.update_scheduler_stats(scheduler_stats) self.lora_states.update_scheduler_stats(scheduler_stats)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment