Unverified commit dbf0da81, authored by Nick Hill, committed by GitHub
Browse files

[Core] Cleanup engine pause/sleep logic (#34528)


Signed-off-by: Nick Hill <nickhill123@gmail.com>
parent 3bbb2046
......@@ -3,8 +3,10 @@
import asyncio
import os
import time
from contextlib import ExitStack
from dataclasses import dataclass
from typing import Any
import pytest
......@@ -187,24 +189,33 @@ async def test_load(
# =============================================================================
# DP Pause/Resume Tests
# =============================================================================
# When expert_parallel=False: uses non-MoE model (DP replicas as separate engines).
# When expert_parallel=True: uses MoE model + EP (DPEngineCoreProc, sync pause path).
DP_PAUSE_MODEL = "hmellor/tiny-random-LlamaForCausalLM"
DP_PAUSE_MODEL_MOE = "ibm-research/PowerMoE-3b"
DP_PAUSE_PROMPT = "This is a test of data parallel pause"
@pytest.mark.asyncio
async def test_dp_pause_resume_basic():
"""Pausing from the client (one call) pauses all DP ranks; resume clears it."""
if current_platform.is_rocm():
pytest.skip("DP pause tests use mp backend only")
with ExitStack() as after:
engine_args = AsyncEngineArgs(
model=DP_PAUSE_MODEL,
def _get_dp_pause_engine_args(expert_parallel: bool) -> AsyncEngineArgs:
    """Build the AsyncEngineArgs shared by the DP pause/resume tests.

    Args:
        expert_parallel: When True, select DP_PAUSE_MODEL_MOE and enable
            expert parallelism; otherwise select DP_PAUSE_MODEL.

    Returns:
        Engine args with eager mode enforced, the "mp" data-parallel
        backend, DP_SIZE data-parallel ranks, and a tensor-parallel size
        read from the TP_SIZE environment variable (1 when unset).
    """
    if expert_parallel:
        model_name = DP_PAUSE_MODEL_MOE
    else:
        model_name = DP_PAUSE_MODEL

    # TP_SIZE lets the same tests be run with different tensor-parallel sizes.
    tp_size = int(os.getenv("TP_SIZE", 1))

    return AsyncEngineArgs(
        model=model_name,
        enforce_eager=True,
        tensor_parallel_size=tp_size,
        data_parallel_size=DP_SIZE,
        data_parallel_backend="mp",
        enable_expert_parallel=expert_parallel,
    )
@pytest.mark.asyncio
@pytest.mark.parametrize("expert_parallel", [False, True])
async def test_dp_pause_resume_basic(expert_parallel: bool):
"""Pausing from the client (one call) pauses all DP ranks; resume clears it."""
with ExitStack() as after:
engine_args = _get_dp_pause_engine_args(expert_parallel)
engine = AsyncLLM.from_engine_args(engine_args)
after.callback(engine.shutdown)
......@@ -226,18 +237,11 @@ async def test_dp_pause_resume_basic():
@pytest.mark.asyncio
async def test_dp_pause_abort():
@pytest.mark.parametrize("expert_parallel", [False, True])
async def test_dp_pause_abort(expert_parallel: bool):
"""Pause with abort from one client aborts in-flight requests on all DP ranks."""
if current_platform.is_rocm():
pytest.skip("DP pause tests use mp backend only")
with ExitStack() as after:
engine_args = AsyncEngineArgs(
model=DP_PAUSE_MODEL,
enforce_eager=True,
tensor_parallel_size=int(os.getenv("TP_SIZE", 1)),
data_parallel_size=DP_SIZE,
data_parallel_backend="mp",
)
engine_args = _get_dp_pause_engine_args(expert_parallel)
engine = AsyncLLM.from_engine_args(engine_args)
after.callback(engine.shutdown)
......@@ -286,41 +290,111 @@ async def test_dp_pause_abort():
@pytest.mark.asyncio
async def test_dp_pause_keep_then_resume():
"""Pause with keep queues new requests; resume allows them to run."""
if current_platform.is_rocm():
pytest.skip("DP pause tests use mp backend only")
@pytest.mark.parametrize("expert_parallel", [False, True])
async def test_dp_pause_keep_then_resume(expert_parallel: bool):
"""Start generation, pause after a few tokens (keep mode), resume; verify gap."""
pause_duration = 2.0
min_tokens_before_pause = 3
with ExitStack() as after:
engine_args = AsyncEngineArgs(
model=DP_PAUSE_MODEL,
enforce_eager=True,
tensor_parallel_size=int(os.getenv("TP_SIZE", 1)),
data_parallel_size=DP_SIZE,
data_parallel_backend="mp",
)
engine_args = _get_dp_pause_engine_args(expert_parallel)
engine = AsyncLLM.from_engine_args(engine_args)
after.callback(engine.shutdown)
sampling_params = SamplingParams(max_tokens=15, ignore_eos=True)
token_times: list[tuple[int, float]] = []
pause_token_idx = 0
async def generator_task():
nonlocal pause_token_idx
out = None
async for output in engine.generate(
request_id="keep-resume-req",
prompt=DP_PAUSE_PROMPT,
sampling_params=sampling_params,
):
token_count = len(output.outputs[0].token_ids)
token_times.append((token_count, time.monotonic()))
out = output
return out
async def controller_task():
nonlocal pause_token_idx
while len(token_times) < min_tokens_before_pause:
await asyncio.sleep(0.01)
await engine.pause_generation(mode="keep")
assert await engine.is_paused()
await asyncio.sleep(pause_duration)
pause_token_idx = len(token_times)
await engine.resume_generation()
request_done = asyncio.Event()
gen_task = asyncio.create_task(generator_task())
ctrl_task = asyncio.create_task(controller_task())
final_output, _ = await asyncio.gather(gen_task, ctrl_task)
async def gen():
async for out in engine.generate(
request_id="queued-keep",
assert final_output is not None and final_output.finished
assert await engine.is_paused() is False
assert pause_token_idx >= min_tokens_before_pause
if pause_token_idx > 0 and pause_token_idx < len(token_times):
pause_gap = (
token_times[pause_token_idx][1] - token_times[pause_token_idx - 1][1]
)
assert pause_gap >= pause_duration * 0.8, (
f"Expected gap ~{pause_duration}s after pause, got {pause_gap:.3f}s"
)
@pytest.mark.asyncio
async def test_dp_pause_keep_race_staggered_engines():
"""Race: send pause(keep) to engine 0, then add two requests,
then pause(keep) to engine 1. Ensures no deadlock when pause
requests are staggered and requests arrive in between."""
if DP_SIZE != 2:
pytest.skip("test_dp_pause_keep_race_staggered_engines requires DP_SIZE=2")
with ExitStack() as after:
engine_args = _get_dp_pause_engine_args(expert_parallel=True)
engine = AsyncLLM.from_engine_args(engine_args)
after.callback(engine.shutdown)
client = engine.engine_core
original_call_utility = client.call_utility_async
mid_pause_tasks: list[asyncio.Task] = []
async def staggered_pause_keep(method: str, *args) -> Any:
if method != "pause_scheduler" or not args or args[0] != "keep":
return await original_call_utility(method, *args)
# Send pause(keep) to engine 0 first
await client._call_utility_async(
method, *args, engine=client.core_engines[0]
)
# In the middle: send two requests (race window)
sp = SamplingParams(max_tokens=5, ignore_eos=True)
async def consume_gen(req_id: str) -> None:
async for _ in engine.generate(
request_id=req_id,
prompt=DP_PAUSE_PROMPT,
sampling_params=SamplingParams(max_tokens=5),
sampling_params=sp,
):
pass
request_done.set()
return out
task = asyncio.create_task(gen())
await asyncio.sleep(0.2)
assert not request_done.is_set()
t1 = asyncio.create_task(consume_gen("race-1"))
t2 = asyncio.create_task(consume_gen("race-2"))
mid_pause_tasks.extend([t1, t2])
await asyncio.sleep(3)
# Then send pause(keep) to engine 1
result = await client._call_utility_async(
method, *args, engine=client.core_engines[1]
)
return result
client.call_utility_async = staggered_pause_keep
await engine.pause_generation(mode="keep")
assert await engine.is_paused()
await engine.resume_generation()
final = await asyncio.wait_for(task, timeout=10.0)
assert final.finished
assert not await engine.is_paused()
# Let the two requests we sent mid-pause complete
await asyncio.gather(*mid_pause_tasks)
......@@ -280,20 +280,15 @@ def echo_dc_nested(
def future_echo(self, value: Any, num_wait_loops: int = 2) -> Future:
"""Utility that returns a Future completed by a per_step_hook after
num_wait_loops engine steps (tests deferred utility path).
"""Utility that returns a Future completed once the engine is idle
(tests deferred utility path).
"""
future: Future = Future()
remaining = [num_wait_loops]
def _step(engine: EngineCore) -> bool:
remaining[0] -= 1
if remaining[0] <= 0:
def idle(engine: EngineCore):
future.set_result(value)
return True # remove hook
return False
self.per_step_hooks.add(_step)
self._idle_state_callbacks.append(idle)
return future
......@@ -832,8 +827,8 @@ async def test_engine_core_client_future_utility_async(
monkeypatch: pytest.MonkeyPatch,
subprocess_future_echo_patch,
):
"""Test that a utility returning a Future (completed by a per_step_hook
after N steps) completes when the future is done (engine uses add_done_callback).
"""Test that a utility returning a Future completes when the future is done
(engine uses add_done_callback).
"""
with monkeypatch.context() as m:
m.setattr(EngineCore, "future_echo", future_echo, raising=False)
......
......@@ -148,7 +148,7 @@ class EngineClient(ABC):
...
@abstractmethod
async def sleep(self, level: int = 1) -> None:
async def sleep(self, level: int = 1, mode: "PauseMode" = "abort") -> None:
"""Sleep the engine"""
...
......
......@@ -87,6 +87,7 @@ from vllm.usage.usage_lib import UsageContext
from vllm.utils.counter import Counter
from vllm.utils.mistral import is_mistral_tokenizer
from vllm.utils.tqdm_utils import maybe_tqdm
from vllm.v1.engine import PauseMode
from vllm.v1.engine.llm_engine import LLMEngine
from vllm.v1.sample.logits_processor import LogitsProcessor
......@@ -441,8 +442,7 @@ class LLM:
A list of `RequestOutput` objects containing the
generated completions in the same order as the input prompts.
"""
model_config = self.model_config
runner_type = model_config.runner_type
runner_type = self.model_config.runner_type
if runner_type != "generate":
raise ValueError(
"LLM.generate() is only supported for generative models. "
......@@ -489,46 +489,22 @@ class LLM:
Returns:
A list of request IDs for the enqueued requests.
"""
model_config = self.model_config
runner_type = model_config.runner_type
runner_type = self.model_config.runner_type
if runner_type != "generate":
raise ValueError("LLM.enqueue() is only supported for generative models.")
if sampling_params is None:
sampling_params = self.get_default_sampling_params()
# Use the same preprocessing as _run_completion
seq_prompts = prompt_to_seq(prompts)
seq_params = self._params_to_seq(sampling_params, len(seq_prompts))
seq_lora_requests = self._lora_request_to_seq(lora_request, len(seq_prompts))
seq_tok_kwargs = [
merge_kwargs(
tokenization_kwargs,
dict(truncate_prompt_tokens=param.truncate_prompt_tokens),
)
for param in seq_params
]
seq_priority = self._priority_to_seq(priority, len(prompts))
request_ids = self._render_and_add_requests(
prompts=(
self._preprocess_cmpl_one(prompt, tok_kwargs)
for prompt, tok_kwargs in zip(
maybe_tqdm(
seq_prompts,
return self._add_completion_requests(
prompts=prompts,
params=sampling_params,
use_tqdm=use_tqdm,
desc="Rendering prompts",
),
seq_tok_kwargs,
)
),
params=seq_params,
lora_requests=seq_lora_requests,
priorities=seq_priority,
lora_request=lora_request,
priority=priority,
tokenization_kwargs=tokenization_kwargs,
)
return request_ids
@overload
def wait_for_completion(
self,
......@@ -1659,7 +1635,7 @@ class LLM:
reset_running_requests, reset_connector
)
def sleep(self, level: int = 1):
def sleep(self, level: int = 1, mode: PauseMode = "abort"):
"""
Put the engine to sleep. The engine should not process any requests.
The caller should guarantee that no requests are being processed
......@@ -1679,10 +1655,10 @@ class LLM:
a different model or update the model, where
previous model weights are not needed. It reduces
CPU memory pressure.
mode: How to handle any existing requests, can be "abort", "wait",
or "keep".
"""
if level > 0:
self.reset_prefix_cache()
self.llm_engine.sleep(level=level)
self.llm_engine.sleep(level=level, mode=mode)
def wake_up(self, tags: list[str] | None = None):
"""
......@@ -1759,19 +1735,18 @@ class LLM:
return [0] * num_requests
def _run_completion(
def _add_completion_requests(
self,
prompts: PromptType | Sequence[PromptType],
params: SamplingParams
| PoolingParams
| Sequence[SamplingParams | PoolingParams],
output_type: type[_O],
*,
use_tqdm: bool | Callable[..., tqdm] = True,
lora_request: Sequence[LoRARequest] | LoRARequest | None = None,
priority: list[int] | None = None,
tokenization_kwargs: dict[str, Any] | None = None,
):
) -> list[str]:
seq_prompts = prompt_to_seq(prompts)
seq_params = self._params_to_seq(params, len(seq_prompts))
seq_lora_requests = self._lora_request_to_seq(lora_request, len(seq_prompts))
......@@ -1784,25 +1759,44 @@ class LLM:
]
seq_priority = self._priority_to_seq(priority, len(prompts))
return self._render_and_run_requests(
return self._render_and_add_requests(
prompts=(
self._preprocess_cmpl_one(prompt, tok_kwargs)
for prompt, tok_kwargs in zip(
maybe_tqdm(
seq_prompts,
use_tqdm=use_tqdm,
desc="Rendering prompts",
seq_prompts, use_tqdm=use_tqdm, desc="Rendering prompts"
),
seq_tok_kwargs,
)
),
params=seq_params,
output_type=output_type,
use_tqdm=use_tqdm,
lora_requests=seq_lora_requests,
priorities=seq_priority,
)
def _run_completion(
    self,
    prompts: PromptType | Sequence[PromptType],
    params: SamplingParams
    | PoolingParams
    | Sequence[SamplingParams | PoolingParams],
    output_type: type[_O],
    *,
    use_tqdm: bool | Callable[..., tqdm] = True,
    lora_request: Sequence[LoRARequest] | LoRARequest | None = None,
    priority: list[int] | None = None,
    tokenization_kwargs: dict[str, Any] | None = None,
):
    """Enqueue completion requests, then run the engine.

    Args:
        prompts: One prompt or a sequence of prompts.
        params: One params object, or a sequence of them.
        output_type: Forwarded to ``_run_engine``.
        use_tqdm: Forwarded to both the enqueue step and ``_run_engine``.
        lora_request: Forwarded to ``_add_completion_requests``.
        priority: Forwarded to ``_add_completion_requests``.
        tokenization_kwargs: Forwarded to ``_add_completion_requests``.

    Returns:
        Whatever ``_run_engine`` returns for the given ``output_type``.
    """
    # Step 1: enqueue. The value returned by the enqueue helper is not
    # used here.
    enqueue_kwargs = dict(
        prompts=prompts,
        params=params,
        use_tqdm=use_tqdm,
        lora_request=lora_request,
        priority=priority,
        tokenization_kwargs=tokenization_kwargs,
    )
    self._add_completion_requests(**enqueue_kwargs)

    # Step 2: run the engine and hand its result back.
    outputs = self._run_engine(use_tqdm=use_tqdm, output_type=output_type)
    return outputs
def _run_chat(
self,
messages: list[ChatCompletionMessageParam]
......
......@@ -23,7 +23,8 @@ router = APIRouter()
async def sleep(raw_request: Request):
# get POST params
level = raw_request.query_params.get("level", "1")
await engine_client(raw_request).sleep(int(level))
mode = raw_request.query_params.get("mode", "abort")
await engine_client(raw_request).sleep(int(level), mode)
# FIXME: in v0 with frontend multiprocessing, the sleep command
# is sent but does not finish yet when we return a response.
return Response(status_code=200)
......
......@@ -753,6 +753,13 @@ class AsyncLLM(EngineClient):
)
mode = "wait"
await self.engine_core.pause_scheduler_async(mode=mode, clear_cache=clear_cache)
# Small sleep to help ensure that final outputs from any in-flight requests are
# returned prior to this method returning. These outputs come out of the engine
# prior to the wait-for-idle completion event, but involve additional async
# tasks in output processing.
# Note that this is not required for correctness, just more intuitive ordering
# of events from caller's pov.
await asyncio.sleep(0.02)
async def resume_generation(self) -> None:
"""Resume generation after :meth:`pause_generation`."""
......@@ -890,10 +897,8 @@ class AsyncLLM(EngineClient):
async def reset_encoder_cache(self) -> None:
await self.engine_core.reset_encoder_cache_async()
async def sleep(self, level: int = 1) -> None:
if level > 0:
await self.reset_prefix_cache()
await self.engine_core.sleep_async(level)
async def sleep(self, level: int = 1, mode: PauseMode = "abort") -> None:
    """Put the engine core to sleep and record the new sleep state.

    Args:
        level: Sleep level, forwarded to the engine core.
        mode: Pause mode, forwarded to the engine core.
    """
    await self.engine_core.sleep_async(level, mode)

    stat_loggers = self.logger_manager
    if stat_loggers is None:
        return
    stat_loggers.record_sleep_state(1, level)
......
......@@ -9,6 +9,7 @@ from collections import defaultdict, deque
from collections.abc import Callable, Generator
from concurrent.futures import Future
from contextlib import ExitStack, contextmanager
from functools import partial
from inspect import isclass, signature
from logging import DEBUG
from typing import Any, TypeVar, cast
......@@ -211,7 +212,7 @@ class EngineCore:
self.aborts_queue = queue.Queue[list[str]]()
self.per_step_hooks: set[Callable] = set()
self._idle_state_callbacks: list[Callable] = []
# Mark the startup heap as static so that it's ignored by GC.
# Reduces pause times of oldest generation collections.
......@@ -592,21 +593,51 @@ class EngineCore:
# Reset the GPU model runner's encoder cache (physical storage)
self.model_executor.reset_encoder_cache()
def _reset_caches(self, reset_running_requests=True) -> None:
self.reset_prefix_cache(reset_running_requests=reset_running_requests)
self.reset_mm_cache()
self.reset_encoder_cache()
def pause_scheduler(
self, mode: PauseMode = "abort", clear_cache: bool = True
) -> Future[Any] | None:
"""Pause scheduling. No-op in base EngineCore; overridden in EngineCoreProc."""
) -> Future | None:
"""Pause generation; behavior depends on mode.
All pause modes queue new adds -- "abort" and "keep" skip step();
"wait" allows step() so in-flight requests can drain.
- ``abort``: Set PAUSED_NEW, abort all requests, wait for abort
outputs to be sent (when running with output_queue), optionally
clear caches, then complete the returned Future.
- ``wait``: Set PAUSED_NEW (queue adds, keep stepping); when drained,
optionally clear caches, then complete the returned Future.
- ``keep``: Set PAUSED_ALL; return a Future that completes when the
output queue is empty.
"""
if mode not in ("keep", "abort", "wait"):
raise ValueError(f"Invalid pause mode: {mode}")
if mode == "wait":
raise ValueError("'wait' mode can't be used in inproc-engine mode")
if mode == "abort":
self.scheduler.finish_requests(None, RequestStatus.FINISHED_ABORTED)
pause_state = PauseState.PAUSED_ALL if mode == "keep" else PauseState.PAUSED_NEW
self.scheduler.set_pause_state(pause_state)
if clear_cache:
self._reset_caches()
return None
def resume_scheduler(self) -> None:
"""Resume scheduling. No-op in base EngineCore; overridden in EngineCoreProc."""
"""Resume the scheduler and flush any requests queued while paused."""
self.scheduler.set_pause_state(PauseState.UNPAUSED)
def is_scheduler_paused(self) -> bool:
"""Return whether the scheduler is in any pause state. False in base EngineCore
and overridden in EngineCoreProc."""
return False
"""Return whether the scheduler is in any pause state."""
return self.scheduler.pause_state != PauseState.UNPAUSED
def sleep(self, level: int = 1):
def sleep(self, level: int = 1, mode: PauseMode = "abort") -> None | Future:
"""Put the engine to sleep at the specified level.
Args:
......@@ -615,13 +646,34 @@ class EngineCore:
but not processed. No GPU memory changes.
- Level 1: Offload model weights to CPU, discard KV cache.
- Level 2: Discard all GPU memory.
mode: Pause mode - how to deal with any existing requests, see
documentation of pause_scheduler method.
"""
if level == 0:
# Level 0: Just pause scheduling, don't touch GPU
self.pause_scheduler()
else:
# Pause scheduler before sleeping.
clear_prefix_cache = level >= 1
pause_future = self.pause_scheduler(mode=mode, clear_cache=clear_prefix_cache)
if level < 1:
return pause_future
# Level 1+: Delegate to executor for GPU memory management
self.model_executor.sleep(level)
model_executor = self.model_executor
if pause_future is None:
model_executor.sleep(level)
return None
future = Future[Any]()
def pause_complete(f: Future):
try:
f.result() # propagate any exception
future.set_result(model_executor.sleep(level))
except Exception as e:
future.set_exception(e)
logger.info("Waiting for in-flight requests to complete before sleeping...")
pause_future.add_done_callback(pause_complete)
return future
def wake_up(self, tags: list[str] | None = None):
"""Wake up the engine from sleep.
......@@ -630,17 +682,15 @@ class EngineCore:
tags: Tags to wake up. Use ["scheduling"] for level 0 wake up.
"""
if tags is not None and "scheduling" in tags:
# Level 0 wake up: Resume scheduling
self.resume_scheduler()
# Remove "scheduling" from tags if there are other tags to process
remaining_tags = [t for t in tags if t != "scheduling"]
if remaining_tags:
self.model_executor.wake_up(remaining_tags)
else:
# Full wake up
self.resume_scheduler()
# Remove "scheduling" from tags if there are other tags to process.
tags = [t for t in tags if t != "scheduling"]
if tags is None or tags:
self.model_executor.wake_up(tags)
# Resume scheduling (applies to all levels)
self.resume_scheduler()
def is_sleeping(self) -> bool:
    """Report whether the engine is sleeping at any level.

    Returns:
        The result of ``is_scheduler_paused()`` when it is truthy;
        otherwise the model executor's ``is_sleeping`` attribute.
    """
    scheduler_paused = self.is_scheduler_paused()
    if scheduler_paused:
        return scheduler_paused
    return self.model_executor.is_sleeping
......@@ -1038,6 +1088,14 @@ class EngineCoreProc(EngineCore):
def _init_data_parallel(self, vllm_config: VllmConfig):
pass
def has_work(self) -> bool:
    """Tell the busy loop whether the engine should be stepped.

    Returns:
        A truthy value when ``engines_running`` is set, when the scheduler
        reports requests, or when ``batch_queue`` is non-empty; False
        otherwise.
    """
    # Checks are evaluated left to right and short-circuit, so the
    # scheduler is only queried when engines_running is falsy.
    busy = self.engines_running or self.scheduler.has_requests()
    return busy or bool(self.batch_queue)
def run_busy_loop(self):
"""Core busy loop of the EngineCore."""
......@@ -1047,19 +1105,14 @@ class EngineCoreProc(EngineCore):
self._process_input_queue()
# 2) Step the engine core and return the outputs.
self._process_engine_step()
# 3) Run any per-step hooks.
self._process_per_step_hooks()
def _process_input_queue(self):
"""Exits when an engine step needs to be performed."""
waited = False
while (
not self.engines_running
and not self.scheduler.has_requests()
and not self.batch_queue
and not self.per_step_hooks
):
while not self.has_work():
# Notify callbacks waiting for engine to become idle.
self._notify_idle_state_callbacks()
if self.input_queue.empty():
# Drain aborts queue; all aborts are also processed via input_queue.
with self.aborts_queue.mutex:
......@@ -1098,12 +1151,10 @@ class EngineCoreProc(EngineCore):
return model_executed
def _process_per_step_hooks(self) -> None:
if self.per_step_hooks:
for hook in list(self.per_step_hooks):
finished = hook(self)
if finished:
self.per_step_hooks.discard(hook)
def _notify_idle_state_callbacks(self) -> None:
while self._idle_state_callbacks:
callback = self._idle_state_callbacks.pop()
callback(self)
def _handle_client_request(
self, request_type: EngineCoreRequestType, request: Any
......@@ -1377,19 +1428,10 @@ class EngineCoreProc(EngineCore):
if mode not in ("keep", "abort", "wait"):
raise ValueError(f"Invalid pause mode: {mode}")
future: Future[Any] = Future()
def wait_until_idle(engine: "EngineCoreProc") -> bool:
scheduler = engine.scheduler
out_queue = engine.output_queue
if scheduler.has_requests() or engine.batch_queue or not out_queue.empty():
return False
def engine_idle_callback(engine: "EngineCoreProc", future: Future[Any]) -> None:
if clear_cache:
engine.reset_prefix_cache(reset_running_requests=True)
engine.reset_mm_cache()
engine.reset_encoder_cache()
engine._reset_caches()
future.set_result(None)
return True
if mode == "abort":
aborted_reqs = self.scheduler.finish_requests(
......@@ -1399,12 +1441,17 @@ class EngineCoreProc(EngineCore):
pause_state = PauseState.PAUSED_ALL if mode == "keep" else PauseState.PAUSED_NEW
self.scheduler.set_pause_state(pause_state)
if not wait_until_idle(self):
self.per_step_hooks.add(wait_until_idle)
return future
if not self.has_work():
if clear_cache:
self._reset_caches()
return None
future = Future[Any]()
self._idle_state_callbacks.append(partial(engine_idle_callback, future=future))
return future
def _send_abort_outputs(self, aborted_reqs: list[tuple[str, int]]) -> None:
# TODO(nick) this will be moved inside the scheduler
if aborted_reqs:
# Map client_index to list of request_ids that belong to that client.
by_client = defaultdict[int, set[str]](set)
......@@ -1418,14 +1465,6 @@ class EngineCoreProc(EngineCore):
eco = EngineCoreOutputs(finished_requests=req_ids, outputs=outputs)
self.output_queue.put_nowait((client_index, eco))
def resume_scheduler(self) -> None:
"""Resume the scheduler and flush any requests queued while paused."""
self.scheduler.set_pause_state(PauseState.UNPAUSED)
def is_scheduler_paused(self) -> bool:
"""Return whether the scheduler is in any pause state."""
return self.scheduler.pause_state != PauseState.UNPAUSED
class DPEngineCoreProc(EngineCoreProc):
"""ZMQ-wrapper for running EngineCore in background process
......@@ -1481,6 +1520,7 @@ class DPEngineCoreProc(EngineCoreProc):
stateless_destroy_torch_distributed_process_group(dp_group)
def add_request(self, request: Request, request_wave: int = 0):
super().add_request(request, request_wave)
if self.has_coordinator and request_wave != self.current_wave:
if request_wave > self.current_wave:
self.current_wave = request_wave
......@@ -1491,7 +1531,13 @@ class DPEngineCoreProc(EngineCoreProc):
(-1, EngineCoreOutputs(start_wave=self.current_wave))
)
super().add_request(request, request_wave)
def resume_scheduler(self):
    """Resume scheduling and, if needed, signal a new wave.

    After delegating to the parent implementation, a
    ``start_wave`` message for ``current_wave`` is queued on
    ``output_queue`` (client index -1) when the engines are not currently
    running but the scheduler still has unfinished requests.
    """
    super().resume_scheduler()

    # Nothing to signal if engines are already running or no work remains.
    if self.engines_running or not self.scheduler.has_unfinished_requests():
        return

    # Wake up other DP engines.
    wake_msg = EngineCoreOutputs(start_wave=self.current_wave)
    self.output_queue.put_nowait((-1, wake_msg))
def _handle_client_request(
self, request_type: EngineCoreRequestType, request: Any
......@@ -1532,8 +1578,8 @@ class DPEngineCoreProc(EngineCoreProc):
# 2) Step the engine core.
executed = self._process_engine_step()
self._maybe_publish_request_counts()
local_unfinished_reqs = self.scheduler.has_unfinished_requests()
local_unfinished_reqs = self.scheduler.has_unfinished_requests()
if not executed:
if not local_unfinished_reqs and not self.engines_running:
# All engines are idle.
......
......@@ -150,7 +150,7 @@ class EngineCoreClient(ABC):
def reset_encoder_cache(self) -> None:
raise NotImplementedError
def sleep(self, level: int = 1) -> None:
def sleep(self, level: int = 1, mode: PauseMode = "abort") -> None:
raise NotImplementedError
def wake_up(self, tags: list[str] | None = None) -> None:
......@@ -227,7 +227,7 @@ class EngineCoreClient(ABC):
async def reset_encoder_cache_async(self) -> None:
raise NotImplementedError
async def sleep_async(self, level: int = 1) -> None:
async def sleep_async(self, level: int = 1, mode: PauseMode = "abort") -> None:
raise NotImplementedError
async def wake_up_async(self, tags: list[str] | None = None) -> None:
......@@ -314,8 +314,11 @@ class InprocClient(EngineCoreClient):
def reset_encoder_cache(self) -> None:
self.engine_core.reset_encoder_cache()
def sleep(self, level: int = 1) -> None:
self.engine_core.sleep(level)
def sleep(self, level: int = 1, mode: PauseMode = "abort") -> None:
    """Put the in-process engine core to sleep.

    Args:
        level: Sleep level, forwarded to the engine core.
        mode: Pause mode, forwarded to the engine core. "wait" is rejected.

    Raises:
        ValueError: If ``mode`` is "wait".
    """
    wait_requested = mode == "wait"
    if wait_requested:
        raise ValueError("'wait' pause mode is not supported in inproc-engine mode")

    # Keep the call outside the assert so it still runs when asserts are
    # stripped (python -O).
    outcome = self.engine_core.sleep(level, mode)
    assert outcome is None
def wake_up(self, tags: list[str] | None = None) -> None:
self.engine_core.wake_up(tags)
......@@ -796,8 +799,8 @@ class SyncMPClient(MPClient):
def pin_lora(self, lora_id: int) -> bool:
return self.call_utility("pin_lora", lora_id)
def sleep(self, level: int = 1) -> None:
self.call_utility("sleep", level)
def sleep(self, level: int = 1, mode: PauseMode = "abort") -> None:
self.call_utility("sleep", level, mode)
def wake_up(self, tags: list[str] | None = None) -> None:
self.call_utility("wake_up", tags)
......@@ -1009,8 +1012,8 @@ class AsyncMPClient(MPClient):
async def reset_encoder_cache_async(self) -> None:
await self.call_utility_async("reset_encoder_cache")
async def sleep_async(self, level: int = 1) -> None:
await self.call_utility_async("sleep", level)
async def sleep_async(self, level: int = 1, mode: PauseMode = "abort") -> None:
await self.call_utility_async("sleep", level, mode)
async def wake_up_async(self, tags: list[str] | None = None) -> None:
await self.call_utility_async("wake_up", tags)
......
......@@ -28,7 +28,7 @@ from vllm.tasks import SupportedTask
from vllm.tokenizers import TokenizerLike
from vllm.tracing import init_tracer
from vllm.usage.usage_lib import UsageContext
from vllm.v1.engine import EngineCoreRequest
from vllm.v1.engine import EngineCoreRequest, PauseMode
from vllm.v1.engine.core_client import EngineCoreClient
from vllm.v1.engine.input_processor import InputProcessor
from vllm.v1.engine.output_processor import OutputProcessor
......@@ -355,8 +355,8 @@ class LLMEngine:
"""
self.engine_core.reset_encoder_cache()
def sleep(self, level: int = 1):
self.engine_core.sleep(level)
def sleep(self, level: int = 1, mode: PauseMode = "abort"):
    """Put the engine core to sleep and record the new sleep state.

    Args:
        level: Sleep level, forwarded to the engine core.
        mode: Pause mode, forwarded to the engine core.
    """
    self.engine_core.sleep(level, mode)

    stat_loggers = self.logger_manager
    if stat_loggers is None:
        return
    stat_loggers.record_sleep_state(1, level)
......
......@@ -429,8 +429,6 @@ class OutputProcessor:
self.external_req_ids: defaultdict[str, list[str]] = defaultdict(list)
self.lora_states = LoRARequestStates(log_stats)
self.tracing_enabled = tracing_enabled
self._requests_drained = asyncio.Event()
self._requests_drained.set()
def get_num_unfinished_requests(self):
return len(self.request_states)
......@@ -438,11 +436,6 @@ class OutputProcessor:
def has_unfinished_requests(self) -> bool:
return len(self.request_states) > 0
async def wait_for_requests_to_drain(self) -> None:
if not self.request_states:
return
await self._requests_drained.wait()
def propagate_error(self, e: Exception):
"""Propagate error to all generate() tasks."""
......@@ -510,8 +503,6 @@ class OutputProcessor:
child_reqs = self.abort_requests(child_reqs, internal=True)
request_ids_to_abort.extend(child_reqs)
self.parent_requests.pop(request_id, None)
if not self.request_states:
self._requests_drained.set()
return request_ids_to_abort
def add_request(
......@@ -538,8 +529,6 @@ class OutputProcessor:
log_stats=self.log_stats,
stream_interval=self.stream_interval,
)
if self._requests_drained.is_set():
self._requests_drained.clear()
self.request_states[request_id] = req_state
if parent_req:
self.parent_requests[parent_req.request_id] = parent_req
......@@ -706,9 +695,6 @@ class OutputProcessor:
if parent_req and not parent_req.child_requests:
self.parent_requests.pop(parent_req.request_id, None)
if not self.request_states:
self._requests_drained.set()
def update_scheduler_stats(self, scheduler_stats: SchedulerStats | None):
self.lora_states.update_scheduler_stats(scheduler_stats)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment