Unverified commit 7fd3949a, authored by Cyrus Leung, committed by GitHub
Browse files

[Frontend][Core] Move `merge_async_iterators` to utils (#4026)

parent 1096717a
import asyncio
import time
from typing import (AsyncGenerator, AsyncIterator, Callable, Dict, List,
Optional, Tuple)
......@@ -17,7 +16,7 @@ from vllm.logger import init_logger
from vllm.model_executor.guided_decoding import (
get_guided_decoding_logits_processor)
from vllm.outputs import RequestOutput
from vllm.utils import random_uuid
from vllm.utils import merge_async_iterators, random_uuid
logger = init_logger(__name__)
......@@ -50,41 +49,6 @@ def parse_prompt_format(prompt) -> Tuple[bool, list]:
return prompt_is_tokens, prompts
def merge_async_iterators(*iterators):
    """Merge multiple asynchronous iterators into a single iterator.

    This function handles the case where some iterators finish before
    others. Each yielded value is a tuple ``(i, item)`` where ``i`` is the
    index of the iterator that produced ``item``.

    One producer task per iterator is created immediately, so this must be
    called while an event loop is running. If any iterator raises an
    ``Exception``, it is re-raised from the merged iterator. When the merged
    iterator exits for any reason (exhaustion, error, or early close), the
    producer tasks that are still running are cancelled.
    """
    queue = asyncio.Queue()

    async def producer(i, iterator):
        try:
            async for item in iterator:
                await queue.put((i, item))
        except Exception as e:
            await queue.put(e)
        finally:
            # Always announce completion with a None marker. Without it, a
            # consumer already blocked in queue.get() would never wake up
            # when an iterator ends without producing anything further.
            # The queue is unbounded, so put_nowait cannot fail.
            queue.put_nowait(None)

    _tasks = [
        asyncio.create_task(producer(i, iterator))
        for i, iterator in enumerate(iterators)
    ]

    async def consumer():
        # Number of producers that have not yet sent their completion marker.
        pending = len(_tasks)
        try:
            while pending:
                item = await queue.get()
                if item is None:
                    pending -= 1
                    continue
                # Payloads are always (i, item) tuples, so a bare Exception
                # in the queue can only be a producer failure.
                if isinstance(item, Exception):
                    raise item
                yield item
            await asyncio.gather(*_tasks)
        finally:
            # Do not leave producers running once nobody consumes their
            # output; cancel() is a no-op for tasks that already finished.
            for task in _tasks:
                task.cancel()

    return consumer()
class OpenAIServingCompletion(OpenAIServing):
def __init__(self,
......
......@@ -9,8 +9,8 @@ import warnings
from collections import OrderedDict, defaultdict
from functools import lru_cache, partial
from platform import uname
from typing import (Any, Awaitable, Callable, Dict, Generic, Hashable, List,
Optional, Tuple, TypeVar, Union)
from typing import (Any, AsyncIterator, Awaitable, Callable, Dict, Generic,
Hashable, List, Optional, Tuple, TypeVar, Union)
import psutil
import torch
......@@ -181,6 +181,42 @@ def make_async(func: Callable[..., T]) -> Callable[..., Awaitable[T]]:
return _async_wrapper
def merge_async_iterators(
        *iterators: AsyncIterator[T]) -> AsyncIterator[Tuple[int, T]]:
    """Merge multiple asynchronous iterators into a single iterator.

    This function handles the case where some iterators finish before
    others. Each yielded value is a tuple ``(i, item)`` where ``i`` is the
    index of the iterator that produced ``item``.

    One producer task per iterator is created immediately, so this must be
    called while an event loop is running. If any iterator raises an
    ``Exception``, it is re-raised from the merged iterator. When the merged
    iterator exits for any reason (exhaustion, error, or early close), the
    producer tasks that are still running are cancelled.
    """
    # Queue entries are one of: an (index, item) payload, an Exception raised
    # by a source iterator, or None marking that one producer has finished.
    queue: asyncio.Queue[Optional[Union[Tuple[int, T], Exception]]] = (
        asyncio.Queue())

    async def producer(i: int, iterator: AsyncIterator[T]) -> None:
        try:
            async for item in iterator:
                await queue.put((i, item))
        except Exception as e:
            await queue.put(e)
        finally:
            # Always announce completion. Without this marker, a consumer
            # already blocked in queue.get() would never wake up when an
            # iterator ends without producing anything further.
            # The queue is unbounded, so put_nowait cannot fail.
            queue.put_nowait(None)

    _tasks = [
        asyncio.create_task(producer(i, iterator))
        for i, iterator in enumerate(iterators)
    ]

    async def consumer():
        # Number of producers that have not yet sent their completion marker.
        pending = len(_tasks)
        try:
            while pending:
                item = await queue.get()
                if item is None:
                    pending -= 1
                    continue
                if isinstance(item, Exception):
                    raise item
                yield item
            await asyncio.gather(*_tasks)
        finally:
            # Do not leave producers running once nobody consumes their
            # output; cancel() is a no-op for tasks that already finished.
            for task in _tasks:
                task.cancel()

    return consumer()
def get_ip() -> str:
host_ip = os.environ.get("HOST_IP")
if host_ip:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment