Unverified commit ce75e153, authored by samzong, committed by GitHub
Browse files

refactor(benchmarks): add type annotations to wait_for_endpoint parameters (#25218)


Signed-off-by: samzong <samzong.lu@gmail.com>
parent aed16879
......@@ -8,8 +8,9 @@ import os
import sys
import time
import traceback
from collections.abc import Awaitable
from dataclasses import dataclass, field
from typing import Optional, Union
from typing import Optional, Protocol, Union
import aiohttp
from tqdm.asyncio import tqdm
......@@ -92,6 +93,16 @@ class RequestFuncOutput:
start_time: float = 0.0
class RequestFunc(Protocol):
    """Structural type describing an async request function.

    A callable satisfies this protocol when it accepts a
    ``RequestFuncInput``, an ``aiohttp.ClientSession`` and an optional
    ``tqdm`` progress bar, and returns an awaitable that resolves to a
    ``RequestFuncOutput``. It is used as the value type of
    ``ASYNC_REQUEST_FUNCS``.
    """

    def __call__(
        self,
        request_func_input: RequestFuncInput,
        session: aiohttp.ClientSession,
        pbar: Optional[tqdm] = None,
    ) -> Awaitable[RequestFuncOutput]:
        # Protocol member: declares the signature only, no implementation.
        ...
async def async_request_openai_completions(
request_func_input: RequestFuncInput,
session: aiohttp.ClientSession,
......@@ -507,7 +518,7 @@ async def async_request_openai_embeddings(
# TODO: Add more request functions for different API protocols.
ASYNC_REQUEST_FUNCS = {
ASYNC_REQUEST_FUNCS: dict[str, RequestFunc] = {
"vllm": async_request_openai_completions,
"openai": async_request_openai_completions,
"openai-chat": async_request_openai_chat_completions,
......
......@@ -8,11 +8,12 @@ import time
import aiohttp
from tqdm.asyncio import tqdm
from .endpoint_request_func import RequestFuncInput, RequestFuncOutput
from .endpoint_request_func import (RequestFunc, RequestFuncInput,
RequestFuncOutput)
async def wait_for_endpoint(
request_func,
request_func: RequestFunc,
test_input: RequestFuncInput,
session: aiohttp.ClientSession,
timeout_seconds: int = 600,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment