Unverified Commit ad0297d1 authored by Nick Hill's avatar Nick Hill Committed by GitHub
Browse files

[Misc] Support passing multiple request ids at once to `AsyncLLM.abort()` (#22944)


Signed-off-by: default avatarNick Hill <nhill@redhat.com>
parent 236b864e
...@@ -212,6 +212,79 @@ async def test_abort( ...@@ -212,6 +212,79 @@ async def test_abort(
assert not engine.output_processor.has_unfinished_requests() assert not engine.output_processor.has_unfinished_requests()
@pytest.mark.parametrize(
"output_kind", [RequestOutputKind.DELTA, RequestOutputKind.FINAL_ONLY])
@pytest.mark.asyncio
async def test_multi_abort(
monkeypatch: pytest.MonkeyPatch,
output_kind: RequestOutputKind,
):
with monkeypatch.context() as m, ExitStack() as after:
m.setenv("VLLM_USE_V1", "1")
with set_default_torch_num_threads(1):
engine = AsyncLLM.from_engine_args(TEXT_ENGINE_ARGS)
after.callback(engine.shutdown)
NUM_REQUESTS = 50
NUM_EXPECTED_TOKENS = 100
NUM_EXPECTED_TOKENS_LONG = 50000
REQUEST_IDS_TO_ABORT = [5, 10, 15, 20, 25]
PARALLEL_SAMPLE_REQ_IDS = [5, 15, 30, 35]
request_ids = [f"request-{i}" for i in range(NUM_REQUESTS)]
# Create concurrent requests.
tasks: list[asyncio.Task] = []
for idx, request_id in enumerate(request_ids):
max_tokens = (NUM_EXPECTED_TOKENS_LONG if
(idx
in REQUEST_IDS_TO_ABORT) else NUM_EXPECTED_TOKENS)
n = 3 if idx in PARALLEL_SAMPLE_REQ_IDS else 1
tasks.append(
asyncio.create_task(
generate(engine, request_id, TEXT_PROMPT, output_kind,
max_tokens, n)))
# Let requests start
await asyncio.sleep(0.5)
# Use multi-abort to abort multiple requests at once
abort_request_ids = [request_ids[i] for i in REQUEST_IDS_TO_ABORT]
await engine.abort(abort_request_ids)
# Wait for all tasks to complete
results = await asyncio.gather(*tasks, return_exceptions=True)
# Verify results
for idx, result in enumerate(results):
if idx in REQUEST_IDS_TO_ABORT:
# Aborted requests should return partial results
assert isinstance(
result, tuple
), f"Request {idx} should have completed with partial results"
num_generated_tokens, request_id = result
# Should have generated some tokens before abort
assert num_generated_tokens > 0, (
f"Aborted request "
f"{request_id} should have generated some tokens")
else:
# Non-aborted requests should complete normally
assert isinstance(
result,
tuple), f"Request {idx} should have completed successfully"
num_generated_tokens, request_id = result
n = 3 if idx in PARALLEL_SAMPLE_REQ_IDS else 1
expected_tokens = NUM_EXPECTED_TOKENS * n
assert num_generated_tokens == expected_tokens, (
f"{request_id} generated {num_generated_tokens} but "
f"expected {expected_tokens}")
# Make sure all aborted requests were cleaned up
assert not engine.output_processor.has_unfinished_requests()
@pytest.mark.parametrize("n", [1, 3]) @pytest.mark.parametrize("n", [1, 3])
@pytest.mark.parametrize( @pytest.mark.parametrize(
"engine_args,prompt", "engine_args,prompt",
...@@ -460,7 +533,9 @@ async def test_abort_final_output( ...@@ -460,7 +533,9 @@ async def test_abort_final_output(
token_count = sum( token_count = sum(
len(output.outputs[0].token_ids) for output in outputs) len(output.outputs[0].token_ids) for output in outputs)
assert token_count > 0 assert token_count > 0
assert len(final_output.outputs[0].token_ids) == 0 # This would ordinarily be 0, but could end up > 0 if the
# final abort is coalesced with another chunk in the output queue.
assert len(final_output.outputs[0].token_ids) >= 0
else: else:
# For FINAL_ONLY, we should only get the final output # For FINAL_ONLY, we should only get the final output
assert len(outputs) == 0 assert len(outputs) == 0
......
...@@ -998,7 +998,7 @@ class AsyncLLMEngine(EngineClient): ...@@ -998,7 +998,7 @@ class AsyncLLMEngine(EngineClient):
await self.abort(request_id) await self.abort(request_id)
raise raise
async def abort(self, request_id: str) -> None: async def abort(self, request_id: Union[str, Iterable[str]]) -> None:
"""Abort a request. """Abort a request.
Abort a submitted request. If the request is finished or not found, Abort a submitted request. If the request is finished or not found,
...@@ -1007,6 +1007,9 @@ class AsyncLLMEngine(EngineClient): ...@@ -1007,6 +1007,9 @@ class AsyncLLMEngine(EngineClient):
Args: Args:
request_id: The unique id of the request. request_id: The unique id of the request.
""" """
if not isinstance(request_id, str):
raise RuntimeError("Only single-request abort supported in"
" deprecated V0")
if not self.is_running: if not self.is_running:
raise AsyncEngineDeadError( raise AsyncEngineDeadError(
"Background loop is not running. If it was running, " "Background loop is not running. If it was running, "
......
...@@ -5,8 +5,8 @@ import asyncio ...@@ -5,8 +5,8 @@ import asyncio
import copy import copy
import pickle import pickle
from contextlib import contextmanager, suppress from contextlib import contextmanager, suppress
from typing import (Any, AsyncGenerator, Dict, Iterator, List, Mapping, from typing import (Any, AsyncGenerator, Dict, Iterable, Iterator, List,
Optional, Union, cast) Mapping, Optional, Union, cast)
import cloudpickle import cloudpickle
import psutil import psutil
...@@ -404,9 +404,13 @@ class MQLLMEngineClient(EngineClient): ...@@ -404,9 +404,13 @@ class MQLLMEngineClient(EngineClient):
error_message="Unable to start RPC Server", error_message="Unable to start RPC Server",
socket=socket) socket=socket)
async def abort(self, request_id: str): async def abort(self, request_id: Union[str, Iterable[str]]):
"""Send an ABORT_REQUEST signal to the RPC Server""" """Send an ABORT_REQUEST signal to the RPC Server"""
if not isinstance(request_id, str):
raise RuntimeError("Only single-request abort supported in"
" deprecated V0")
with suppress(MQClientClosedError): with suppress(MQClientClosedError):
await self._send_one_way_rpc_request( await self._send_one_way_rpc_request(
request=RPCAbortRequest(request_id), socket=self.input_socket) request=RPCAbortRequest(request_id), socket=self.input_socket)
......
...@@ -3,7 +3,7 @@ ...@@ -3,7 +3,7 @@
import asyncio import asyncio
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import AsyncGenerator, Mapping, Optional from typing import AsyncGenerator, Iterable, Mapping, Optional, Union
from vllm.beam_search import BeamSearchSequence, create_sort_beams_key_function from vllm.beam_search import BeamSearchSequence, create_sort_beams_key_function
from vllm.config import DecodingConfig, ModelConfig, VllmConfig from vllm.config import DecodingConfig, ModelConfig, VllmConfig
...@@ -229,11 +229,12 @@ class EngineClient(ABC): ...@@ -229,11 +229,12 @@ class EngineClient(ABC):
... ...
@abstractmethod @abstractmethod
async def abort(self, request_id: str) -> None: async def abort(self, request_id: Union[str, Iterable[str]]) -> None:
"""Abort a request. """Abort a request.
Args: Args:
request_id: The unique id of the request. request_id: The unique id of the request,
or an iterable of such ids.
""" """
... ...
......
...@@ -1315,6 +1315,11 @@ def common_broadcastable_dtype(dtypes: Collection[torch.dtype]): ...@@ -1315,6 +1315,11 @@ def common_broadcastable_dtype(dtypes: Collection[torch.dtype]):
) )
def as_list(maybe_list: Iterable[T]) -> list[T]:
"""Convert iterable to list, unless it's already a list."""
return maybe_list if isinstance(maybe_list, list) else list(maybe_list)
# `collections` helpers # `collections` helpers
def is_list_of( def is_list_of(
value: object, value: object,
......
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import asyncio import asyncio
import time import time
from collections.abc import AsyncGenerator, Mapping from collections.abc import AsyncGenerator, Iterable, Mapping
from copy import copy from copy import copy
from typing import Any, Optional, Union from typing import Any, Optional, Union
...@@ -27,7 +27,8 @@ from vllm.transformers_utils.config import ( ...@@ -27,7 +27,8 @@ from vllm.transformers_utils.config import (
from vllm.transformers_utils.tokenizer import AnyTokenizer from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs
from vllm.usage.usage_lib import UsageContext from vllm.usage.usage_lib import UsageContext
from vllm.utils import Device, cancel_task_threadsafe, cdiv, deprecate_kwargs from vllm.utils import (Device, as_list, cancel_task_threadsafe, cdiv,
deprecate_kwargs)
from vllm.v1.engine import EngineCoreRequest from vllm.v1.engine import EngineCoreRequest
from vllm.v1.engine.core_client import EngineCoreClient from vllm.v1.engine.core_client import EngineCoreClient
from vllm.v1.engine.exceptions import EngineDeadError, EngineGenerateError from vllm.v1.engine.exceptions import EngineDeadError, EngineGenerateError
...@@ -431,14 +432,16 @@ class AsyncLLM(EngineClient): ...@@ -431,14 +432,16 @@ class AsyncLLM(EngineClient):
self.output_handler = asyncio.create_task(output_handler()) self.output_handler = asyncio.create_task(output_handler())
async def abort(self, request_id: str) -> None: async def abort(self, request_id: Union[str, Iterable[str]]) -> None:
"""Abort RequestId in OutputProcessor and EngineCore.""" """Abort RequestId in OutputProcessor and EngineCore."""
request_ids = self.output_processor.abort_requests((request_id, )) request_ids = (request_id, ) if isinstance(
await self.engine_core.abort_requests_async(request_ids) request_id, str) else as_list(request_id)
all_request_ids = self.output_processor.abort_requests(request_ids)
await self.engine_core.abort_requests_async(all_request_ids)
if self.log_requests: if self.log_requests:
logger.info("Aborted request %s.", request_id) logger.info("Aborted request(s) %s.", ",".join(request_ids))
async def encode( async def encode(
self, self,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment