Unverified Commit c9927c1a authored by Antoni Baum, committed by GitHub

Use queue for finished requests (#957)

parent fbd80ad4
@@ -156,8 +156,8 @@ class VllmRunner:
     ) -> List[Tuple[List[int], str]]:
         greedy_params = SamplingParams(temperature=0.0, max_tokens=max_tokens)
         outputs = self.generate(prompts, greedy_params)
-        return [(output_ids[0], output_str[0]) for output_ids, output_str in
-                outputs]
+        return [(output_ids[0], output_str[0])
+                for output_ids, output_str in outputs]
 
     def generate_beam_search(
         self,

 import asyncio
 import time
 from functools import partial
-from typing import Any, Dict, Iterable, List, Optional, Set, Type, Union
+from typing import Any, Dict, Iterable, List, Optional, Type, Union
 
 from vllm.config import ModelConfig
 from vllm.engine.arg_utils import AsyncEngineArgs
@@ -152,7 +152,7 @@ class AsyncLLMEngine:
         # Request id -> stream.
         self.request_streams: Dict[str, AsyncStream] = {}
-        self.finished_requests: Set[str] = set()
+        self.finished_requests: asyncio.Queue[str] = asyncio.Queue()
         self.background_loop = None
         if start_engine_loop:
             self.start_background_loop()
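For context, a standalone sketch (illustrative names, not vLLM code) of the non-blocking asyncio.Queue calls the patch relies on: put_nowait() never blocks on an unbounded queue, get_nowait() raises asyncio.QueueEmpty when nothing is queued, so reads are guarded with empty().

import asyncio


async def main() -> None:
    finished = asyncio.Queue()  # holds finished request ids (str)

    # Producers enqueue ids without awaiting; on an unbounded
    # Queue, put_nowait() is safe to call from any coroutine.
    finished.put_nowait("req-0")
    finished.put_nowait("req-1")

    # The consumer drains everything queued so far without blocking.
    drained = set()
    while not finished.empty():
        drained.add(finished.get_nowait())
    print(drained)  # {'req-0', 'req-1'}


asyncio.run(main())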
@@ -194,12 +194,14 @@ class AsyncLLMEngine:
             if self.log_requests:
                 logger.info(f"Finished request {request_id}.")
             self.request_streams[request_id].finish()
-            self.finished_requests.add(request_id)
+            self.finished_requests.put_nowait(request_id)
 
-        await self._engine_abort(self.finished_requests)
-        for request_id in self.finished_requests:
+        finished_request = set()
+        while not self.finished_requests.empty():
+            finished_request.add(self.finished_requests.get_nowait())
+        await self._engine_abort(finished_request)
+        for request_id in finished_request:
             del self.request_streams[request_id]
-        self.finished_requests.clear()
 
     async def _engine_abort(self, request_ids: Iterable[str]):
         if self.engine_use_ray:
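The behavioral fix in this hunk: the old code awaited _engine_abort() on the shared set and then cleared it unconditionally, so an id added while the await was suspended could be wiped without its stream ever being removed. Draining the queue into a local snapshot first means ids that arrive mid-await simply wait for the next loop iteration. A minimal sketch of the pattern, with assumed names (engine_step and abort are illustrative, not vLLM identifiers):

import asyncio


async def engine_step(finished: asyncio.Queue, abort) -> None:
    # Snapshot what has finished so far; the shared queue is never
    # iterated or cleared, so concurrent put_nowait() calls are safe.
    snapshot = set()
    while not finished.empty():
        snapshot.add(finished.get_nowait())
    await abort(snapshot)  # ids enqueued during this await stay queued


async def main() -> None:
    finished = asyncio.Queue()
    finished.put_nowait("req-0")

    async def abort(ids):
        finished.put_nowait("req-1")  # a request finishes mid-abort
        print("aborting:", ids)       # aborting: {'req-0'}

    await engine_step(finished, abort)
    print("still queued:", finished.get_nowait())  # still queued: req-1


asyncio.run(main())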
@@ -226,6 +228,8 @@ class AsyncLLMEngine:
                          f"sampling params: {sampling_params}, "
                          f"prompt token ids: {prompt_token_ids}.")
 
+        if request_id in self.request_streams:
+            raise KeyError(f"Request {request_id} already exists.")
+
         stream = AsyncStream(request_id)
         self.request_streams[request_id] = stream
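The new guard makes add_request() fail fast on a duplicate request id instead of silently replacing the existing stream. A toy illustration of the same check (streams and register are made-up stand-ins, not vLLM names):

streams = {}


def register(request_id: str) -> None:
    # Mirrors the guard above: refuse to overwrite a live stream.
    if request_id in streams:
        raise KeyError(f"Request {request_id} already exists.")
    streams[request_id] = object()  # stand-in for AsyncStream


register("req-0")
try:
    register("req-0")
except KeyError as exc:
    print(exc)  # 'Request req-0 already exists.'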
@@ -316,7 +320,7 @@ class AsyncLLMEngine:
             logger.info(f"Aborted request {request_id}.")
 
         self.request_streams[request_id].finish()
-        self.finished_requests.add(request_id)
+        self.finished_requests.put_nowait(request_id)
 
     async def get_model_config(self) -> ModelConfig:
         """Get the model configuration of the vLLM engine."""