"tests/vscode:/vscode.git/clone" did not exist on "99dac099ab5205d40bfaf5cf5652884b8764a400"
Unverified Commit 58fcc854 authored by Adam Lugowski's avatar Adam Lugowski Committed by GitHub
Browse files

[Frontend] Add progress reporting to run_batch.py (#8060)


Co-authored-by: default avatarAdam Lugowski <adam.lugowski@parasail.io>
parent 08287ef6
import asyncio import asyncio
from io import StringIO from io import StringIO
from typing import Awaitable, Callable, List from typing import Awaitable, Callable, List, Optional
import aiohttp import aiohttp
import torch
from prometheus_client import start_http_server from prometheus_client import start_http_server
from tqdm import tqdm
from vllm.engine.arg_utils import AsyncEngineArgs, nullable_str from vllm.engine.arg_utils import AsyncEngineArgs, nullable_str
from vllm.engine.async_llm_engine import AsyncLLMEngine from vllm.engine.async_llm_engine import AsyncLLMEngine
...@@ -78,6 +80,38 @@ def parse_args(): ...@@ -78,6 +80,38 @@ def parse_args():
return parser.parse_args() return parser.parse_args()
# explicitly use pure text format, with a newline at the end
# this makes it impossible to see the animation in the progress bar
# but will avoid messing up with ray or multiprocessing, which wraps
# each line of output with some prefix.
_BAR_FORMAT = "{desc}: {percentage:3.0f}% Completed | {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}]\n" # noqa: E501
class BatchProgressTracker:
def __init__(self):
self._total = 0
self._pbar: Optional[tqdm] = None
def submitted(self):
self._total += 1
def completed(self):
if self._pbar:
self._pbar.update()
def pbar(self) -> tqdm:
enable_tqdm = not torch.distributed.is_initialized(
) or torch.distributed.get_rank() == 0
self._pbar = tqdm(total=self._total,
unit="req",
desc="Running batch",
mininterval=5,
disable=not enable_tqdm,
bar_format=_BAR_FORMAT)
return self._pbar
async def read_file(path_or_url: str) -> str: async def read_file(path_or_url: str) -> str:
if path_or_url.startswith("http://") or path_or_url.startswith("https://"): if path_or_url.startswith("http://") or path_or_url.startswith("https://"):
async with aiohttp.ClientSession() as session, \ async with aiohttp.ClientSession() as session, \
...@@ -102,7 +136,8 @@ async def write_file(path_or_url: str, data: str) -> None: ...@@ -102,7 +136,8 @@ async def write_file(path_or_url: str, data: str) -> None:
async def run_request(serving_engine_func: Callable, async def run_request(serving_engine_func: Callable,
request: BatchRequestInput) -> BatchRequestOutput: request: BatchRequestInput,
tracker: BatchProgressTracker) -> BatchRequestOutput:
response = await serving_engine_func(request.body) response = await serving_engine_func(request.body)
if isinstance(response, (ChatCompletionResponse, EmbeddingResponse)): if isinstance(response, (ChatCompletionResponse, EmbeddingResponse)):
...@@ -125,6 +160,7 @@ async def run_request(serving_engine_func: Callable, ...@@ -125,6 +160,7 @@ async def run_request(serving_engine_func: Callable,
else: else:
raise ValueError("Request must not be sent in stream mode") raise ValueError("Request must not be sent in stream mode")
tracker.completed()
return batch_output return batch_output
...@@ -164,6 +200,9 @@ async def main(args): ...@@ -164,6 +200,9 @@ async def main(args):
request_logger=request_logger, request_logger=request_logger,
) )
tracker = BatchProgressTracker()
logger.info("Reading batch from %s...", args.input_file)
# Submit all requests in the file to the engine "concurrently". # Submit all requests in the file to the engine "concurrently".
response_futures: List[Awaitable[BatchRequestOutput]] = [] response_futures: List[Awaitable[BatchRequestOutput]] = []
for request_json in (await read_file(args.input_file)).strip().split("\n"): for request_json in (await read_file(args.input_file)).strip().split("\n"):
...@@ -178,15 +217,18 @@ async def main(args): ...@@ -178,15 +217,18 @@ async def main(args):
if request.url == "/v1/chat/completions": if request.url == "/v1/chat/completions":
response_futures.append( response_futures.append(
run_request(openai_serving_chat.create_chat_completion, run_request(openai_serving_chat.create_chat_completion,
request)) request, tracker))
tracker.submitted()
elif request.url == "/v1/embeddings": elif request.url == "/v1/embeddings":
response_futures.append( response_futures.append(
run_request(openai_serving_embedding.create_embedding, run_request(openai_serving_embedding.create_embedding, request,
request)) tracker))
tracker.submitted()
else: else:
raise ValueError("Only /v1/chat/completions and /v1/embeddings are" raise ValueError("Only /v1/chat/completions and /v1/embeddings are"
"supported in the batch endpoint.") "supported in the batch endpoint.")
with tracker.pbar():
responses = await asyncio.gather(*response_futures) responses = await asyncio.gather(*response_futures)
output_buffer = StringIO() output_buffer = StringIO()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment