Unverified Commit c75363fb authored by Nick Hill's avatar Nick Hill Committed by GitHub
Browse files

[BugFix] Avoid premature async generator exit and raise all exception variations (#7698)

parent dd3fa0e4
import asyncio import asyncio
import os import os
from asyncio import CancelledError
from dataclasses import dataclass from dataclasses import dataclass
from typing import Optional
import pytest import pytest
import pytest_asyncio
import torch import torch
from vllm import SamplingParams from vllm import SamplingParams
from vllm.config import ParallelConfig from vllm.config import ParallelConfig
from vllm.engine.async_llm_engine import AsyncEngineArgs, AsyncLLMEngine from vllm.engine.async_llm_engine import AsyncEngineArgs, AsyncLLMEngine
from vllm.outputs import RequestOutput as RealRequestOutput
from ..conftest import cleanup
from ..utils import wait_for_gpu_memory_to_clear from ..utils import wait_for_gpu_memory_to_clear
...@@ -118,15 +123,38 @@ async def test_new_requests_event(): ...@@ -118,15 +123,38 @@ async def test_new_requests_event():
os.environ.pop("VLLM_ALLOW_ENGINE_USE_RAY") os.environ.pop("VLLM_ALLOW_ENGINE_USE_RAY")
def test_asyncio_run(): def start_engine():
wait_for_gpu_memory_to_clear( wait_for_gpu_memory_to_clear(
devices=list(range(torch.cuda.device_count())), devices=list(range(torch.cuda.device_count())),
threshold_bytes=2 * 2**30, threshold_bytes=2 * 2**30,
timeout_s=60, timeout_s=60,
) )
engine = AsyncLLMEngine.from_engine_args( return AsyncLLMEngine.from_engine_args(
AsyncEngineArgs(model="facebook/opt-125m")) AsyncEngineArgs(model="facebook/opt-125m", enforce_eager=True))
@pytest_asyncio.fixture(scope="module")
async def async_engine():
engine = await asyncio.get_event_loop().run_in_executor(executor=None,
func=start_engine)
try:
yield engine
finally:
engine.shutdown_background_loop()
del engine
await asyncio.sleep(0.1)
cleanup()
@pytest.fixture()
def should_do_global_cleanup_after_test(request) -> bool:
# So we can share the async engine fixture between these tests
return False
@pytest.mark.asyncio(scope="module")
async def test_asyncio_run(async_engine):
async def run(prompt: str): async def run(prompt: str):
sampling_params = SamplingParams( sampling_params = SamplingParams(
...@@ -134,17 +162,64 @@ def test_asyncio_run(): ...@@ -134,17 +162,64 @@ def test_asyncio_run():
max_tokens=32, max_tokens=32,
) )
async for output in engine.generate(prompt, async for output in async_engine.generate(prompt,
sampling_params, sampling_params,
request_id=prompt): request_id=prompt):
final_output = output final_output = output
return final_output return final_output
async def generate(): results = await asyncio.gather(
return await asyncio.gather( run("test0"),
run("test0"), run("test1"),
run("test1"), )
)
results = asyncio.run(generate())
assert len(results) == 2 assert len(results) == 2
@pytest.mark.asyncio(scope="module")
async def test_cancellation(async_engine):
sampling_params = SamplingParams(
temperature=0,
min_tokens=10,
max_tokens=10,
)
i = 0
with pytest.raises(CancelledError):
async for output in async_engine.generate("test2",
sampling_params,
request_id="test2"):
assert not output.finished
i += 1
if i == 5:
await async_engine.abort("test2")
assert i == 5
@pytest.mark.asyncio(scope="module")
async def test_delayed_generator(async_engine):
sampling_params = SamplingParams(
temperature=0,
min_tokens=10,
max_tokens=10,
)
stream = async_engine.generate("test3",
sampling_params,
request_id="test3")
i = 0
final_output: Optional[RealRequestOutput] = None
async for output in stream:
final_output = output
if i == 0:
# wait for generation to complete before consuming
# the remaining messages
await asyncio.sleep(1)
if i < 9:
assert not output.finished
i += 1
assert i == 10
assert final_output is not None
assert len(final_output.outputs[0].token_ids) == 10
assert final_output.finished
...@@ -2,8 +2,8 @@ import asyncio ...@@ -2,8 +2,8 @@ import asyncio
import time import time
from dataclasses import dataclass from dataclasses import dataclass
from functools import partial from functools import partial
from typing import (AsyncGenerator, Callable, Dict, Iterable, List, Mapping, from typing import (Any, AsyncGenerator, Callable, Dict, Iterable, List,
Optional, Set, Tuple, Type, Union) Mapping, Optional, Set, Tuple, Type, Union)
import torch import torch
from typing_extensions import assert_never from typing_extensions import assert_never
...@@ -85,9 +85,8 @@ class AsyncStream: ...@@ -85,9 +85,8 @@ class AsyncStream:
def put(self, item: Union[RequestOutput, EmbeddingRequestOutput, def put(self, item: Union[RequestOutput, EmbeddingRequestOutput,
Exception]) -> None: Exception]) -> None:
if self._finished: if not self._finished:
return self._queue.put_nowait(item)
self._queue.put_nowait(item)
def finish( def finish(
self, self,
...@@ -96,7 +95,7 @@ class AsyncStream: ...@@ -96,7 +95,7 @@ class AsyncStream:
if not self._finished: if not self._finished:
self._finished = True self._finished = True
self._queue.put_nowait( self._queue.put_nowait(
exception if exception is not None else STOP_ITERATION) exception if self._is_raisable(exception) else STOP_ITERATION)
@property @property
def finished(self) -> bool: def finished(self) -> bool:
...@@ -106,9 +105,9 @@ class AsyncStream: ...@@ -106,9 +105,9 @@ class AsyncStream:
self self
) -> AsyncGenerator[Union[RequestOutput, EmbeddingRequestOutput], None]: ) -> AsyncGenerator[Union[RequestOutput, EmbeddingRequestOutput], None]:
try: try:
while not self._finished: while True:
result = await self._queue.get() result = await self._queue.get()
if isinstance(result, Exception): if self._is_raisable(result):
if result == STOP_ITERATION: if result == STOP_ITERATION:
return return
raise result raise result
...@@ -117,6 +116,12 @@ class AsyncStream: ...@@ -117,6 +116,12 @@ class AsyncStream:
self._cancel(self.request_id) self._cancel(self.request_id)
raise asyncio.CancelledError from None raise asyncio.CancelledError from None
@staticmethod
def _is_raisable(value: Any):
return isinstance(value, BaseException) or \
(isinstance(value, type) and \
issubclass(value, BaseException))
class RequestTracker: class RequestTracker:
"""Synchronous abstraction for tracking requests.""" """Synchronous abstraction for tracking requests."""
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment