Unverified Commit 6f547829 authored by Seiji Eicher's avatar Seiji Eicher Committed by GitHub
Browse files

Use `aiohttp` connection pool for benchmarking (#21981)


Signed-off-by: default avatarSeiji Eicher <seiji@anyscale.com>
parent 6a39ba85
...@@ -14,6 +14,7 @@ from .endpoint_request_func import RequestFuncInput, RequestFuncOutput ...@@ -14,6 +14,7 @@ from .endpoint_request_func import RequestFuncInput, RequestFuncOutput
async def wait_for_endpoint( async def wait_for_endpoint(
request_func, request_func,
test_input: RequestFuncInput, test_input: RequestFuncInput,
session: aiohttp.ClientSession,
timeout_seconds: int = 600, timeout_seconds: int = 600,
retry_interval: int = 5, retry_interval: int = 5,
) -> RequestFuncOutput: ) -> RequestFuncOutput:
...@@ -55,7 +56,8 @@ async def wait_for_endpoint( ...@@ -55,7 +56,8 @@ async def wait_for_endpoint(
# ping the endpoint using request_func # ping the endpoint using request_func
try: try:
output = await request_func(request_func_input=test_input) output = await request_func(
request_func_input=test_input, session=session)
if output.success: if output.success:
pbar.close() pbar.close()
return output return output
......
...@@ -28,6 +28,7 @@ from dataclasses import dataclass ...@@ -28,6 +28,7 @@ from dataclasses import dataclass
from datetime import datetime from datetime import datetime
from typing import Any, Literal, Optional from typing import Any, Literal, Optional
import aiohttp
import numpy as np import numpy as np
from tqdm.asyncio import tqdm from tqdm.asyncio import tqdm
from transformers import PreTrainedTokenizerBase from transformers import PreTrainedTokenizerBase
...@@ -338,6 +339,24 @@ async def benchmark( ...@@ -338,6 +339,24 @@ async def benchmark(
else: else:
raise ValueError(f"Unknown endpoint_type: {endpoint_type}") raise ValueError(f"Unknown endpoint_type: {endpoint_type}")
# Reuses connections across requests to reduce TLS handshake overhead.
connector = aiohttp.TCPConnector(
limit=max_concurrency or 0,
limit_per_host=max_concurrency or 0,
ttl_dns_cache=300,
use_dns_cache=True,
keepalive_timeout=60,
enable_cleanup_closed=True,
force_close=False,
ssl=("https://" in api_url),
)
session = aiohttp.ClientSession(
connector=connector,
trust_env=True,
timeout=aiohttp.ClientTimeout(total=6 * 60 * 60),
)
print("Starting initial single prompt test run...") print("Starting initial single prompt test run...")
test_prompt, test_prompt_len, test_output_len, test_mm_content = ( test_prompt, test_prompt_len, test_output_len, test_mm_content = (
input_requests[0].prompt, input_requests[0].prompt,
...@@ -361,7 +380,11 @@ async def benchmark( ...@@ -361,7 +380,11 @@ async def benchmark(
) )
test_output = await wait_for_endpoint( test_output = await wait_for_endpoint(
request_func, test_input, timeout_seconds=ready_check_timeout_sec) request_func,
test_input,
session,
timeout_seconds=ready_check_timeout_sec,
)
if not test_output.success: if not test_output.success:
raise ValueError( raise ValueError(
"Initial test run failed - Please make sure benchmark arguments " "Initial test run failed - Please make sure benchmark arguments "
...@@ -386,7 +409,8 @@ async def benchmark( ...@@ -386,7 +409,8 @@ async def benchmark(
multi_modal_content=test_mm_content, multi_modal_content=test_mm_content,
ignore_eos=ignore_eos, ignore_eos=ignore_eos,
extra_body=extra_body) extra_body=extra_body)
profile_output = await request_func(request_func_input=profile_input) profile_output = await request_func(
request_func_input=profile_input, session=session)
if profile_output.success: if profile_output.success:
print("Profiler started") print("Profiler started")
...@@ -412,12 +436,14 @@ async def benchmark( ...@@ -412,12 +436,14 @@ async def benchmark(
semaphore = (asyncio.Semaphore(max_concurrency) semaphore = (asyncio.Semaphore(max_concurrency)
if max_concurrency else None) if max_concurrency else None)
async def limited_request_func(request_func_input, pbar): async def limited_request_func(request_func_input, session, pbar):
if semaphore is None: if semaphore is None:
return await request_func(request_func_input=request_func_input, return await request_func(request_func_input=request_func_input,
session=session,
pbar=pbar) pbar=pbar)
async with semaphore: async with semaphore:
return await request_func(request_func_input=request_func_input, return await request_func(request_func_input=request_func_input,
session=session,
pbar=pbar) pbar=pbar)
benchmark_start_time = time.perf_counter() benchmark_start_time = time.perf_counter()
...@@ -469,6 +495,7 @@ async def benchmark( ...@@ -469,6 +495,7 @@ async def benchmark(
tasks.append( tasks.append(
asyncio.create_task( asyncio.create_task(
limited_request_func(request_func_input=request_func_input, limited_request_func(request_func_input=request_func_input,
session=session,
pbar=pbar))) pbar=pbar)))
outputs: list[RequestFuncOutput] = await asyncio.gather(*tasks) outputs: list[RequestFuncOutput] = await asyncio.gather(*tasks)
...@@ -580,9 +607,12 @@ async def benchmark( ...@@ -580,9 +607,12 @@ async def benchmark(
output_len=test_output_len, output_len=test_output_len,
logprobs=logprobs, logprobs=logprobs,
) )
profile_output = await request_func(request_func_input=profile_input) profile_output = await request_func(
request_func_input=profile_input, session=session)
if profile_output.success: if profile_output.success:
print("Profiler stopped") print("Profiler stopped")
await session.close()
return result return result
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment