Unverified Commit 6f547829 authored by Seiji Eicher's avatar Seiji Eicher Committed by GitHub
Browse files

Use `aiohttp` connection pool for benchmarking (#21981)


Signed-off-by: default avatarSeiji Eicher <seiji@anyscale.com>
parent 6a39ba85
...@@ -50,6 +50,7 @@ class RequestFuncOutput: ...@@ -50,6 +50,7 @@ class RequestFuncOutput:
async def async_request_openai_completions( async def async_request_openai_completions(
request_func_input: RequestFuncInput, request_func_input: RequestFuncInput,
session: aiohttp.ClientSession,
pbar: Optional[tqdm] = None, pbar: Optional[tqdm] = None,
) -> RequestFuncOutput: ) -> RequestFuncOutput:
"""The async request function for the OpenAI Completions API. """The async request function for the OpenAI Completions API.
...@@ -66,96 +67,94 @@ async def async_request_openai_completions( ...@@ -66,96 +67,94 @@ async def async_request_openai_completions(
("completions", "profile") ("completions", "profile")
), "OpenAI Completions API URL must end with 'completions' or 'profile'." ), "OpenAI Completions API URL must end with 'completions' or 'profile'."
async with aiohttp.ClientSession(trust_env=True, payload = {
timeout=AIOHTTP_TIMEOUT) as session: "model": request_func_input.model_name \
payload = { if request_func_input.model_name else request_func_input.model,
"model": request_func_input.model_name \ "prompt": request_func_input.prompt,
if request_func_input.model_name else request_func_input.model, "temperature": 0.0,
"prompt": request_func_input.prompt, "repetition_penalty": 1.0,
"temperature": 0.0, "max_tokens": request_func_input.output_len,
"repetition_penalty": 1.0, "logprobs": request_func_input.logprobs,
"max_tokens": request_func_input.output_len, "stream": True,
"logprobs": request_func_input.logprobs, "stream_options": {
"stream": True, "include_usage": True,
"stream_options": { },
"include_usage": True, }
}, if request_func_input.ignore_eos:
} payload["ignore_eos"] = request_func_input.ignore_eos
if request_func_input.ignore_eos: if request_func_input.extra_body:
payload["ignore_eos"] = request_func_input.ignore_eos payload.update(request_func_input.extra_body)
if request_func_input.extra_body: headers = {
payload.update(request_func_input.extra_body) "Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}"
headers = { }
"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}"
} output = RequestFuncOutput()
output.prompt_len = request_func_input.prompt_len
output = RequestFuncOutput()
output.prompt_len = request_func_input.prompt_len generated_text = ""
st = time.perf_counter()
generated_text = "" most_recent_timestamp = st
st = time.perf_counter() try:
most_recent_timestamp = st async with session.post(url=api_url, json=payload,
try: headers=headers) as response:
async with session.post(url=api_url, json=payload, if response.status == 200:
headers=headers) as response: first_chunk_received = False
if response.status == 200: async for chunk_bytes in response.content:
first_chunk_received = False chunk_bytes = chunk_bytes.strip()
async for chunk_bytes in response.content: if not chunk_bytes:
chunk_bytes = chunk_bytes.strip() continue
if not chunk_bytes: chunk_bytes = chunk_bytes.decode("utf-8")
continue # NOTE: SSE comments (often used as pings) start with
chunk_bytes = chunk_bytes.decode("utf-8") # a colon. These are not JSON data payload and should
# NOTE: SSE comments (often used as pings) start with # be skipped.
# a colon. These are not JSON data payload and should if chunk_bytes.startswith(":"):
# be skipped. continue
if chunk_bytes.startswith(":"):
continue chunk = chunk_bytes.removeprefix("data: ")
chunk = chunk_bytes.removeprefix("data: ") if chunk != "[DONE]":
data = json.loads(chunk)
if chunk != "[DONE]":
data = json.loads(chunk) # NOTE: Some completion API might have a last
# usage summary response without a token so we
# NOTE: Some completion API might have a last # want to check a token was generated
# usage summary response without a token so we if choices := data.get("choices"):
# want to check a token was generated # Note that text could be empty here
if choices := data.get("choices"): # e.g. for special tokens
# Note that text could be empty here text = choices[0].get("text")
# e.g. for special tokens timestamp = time.perf_counter()
text = choices[0].get("text") # First token
timestamp = time.perf_counter() if not first_chunk_received:
# First token first_chunk_received = True
if not first_chunk_received: ttft = time.perf_counter() - st
first_chunk_received = True output.ttft = ttft
ttft = time.perf_counter() - st
output.ttft = ttft
# Decoding phase # Decoding phase
else: else:
output.itl.append(timestamp - output.itl.append(timestamp -
most_recent_timestamp) most_recent_timestamp)
most_recent_timestamp = timestamp most_recent_timestamp = timestamp
generated_text += text or "" generated_text += text or ""
elif usage := data.get("usage"): elif usage := data.get("usage"):
output.output_tokens = usage.get( output.output_tokens = usage.get(
"completion_tokens") "completion_tokens")
if first_chunk_received: if first_chunk_received:
output.success = True output.success = True
else:
output.success = False
output.error = (
"Never received a valid chunk to calculate TTFT."
"This response will be marked as failed!")
output.generated_text = generated_text
output.latency = most_recent_timestamp - st
else: else:
output.error = response.reason or ""
output.success = False output.success = False
except Exception: output.error = (
output.success = False "Never received a valid chunk to calculate TTFT."
exc_info = sys.exc_info() "This response will be marked as failed!")
output.error = "".join(traceback.format_exception(*exc_info)) output.generated_text = generated_text
output.latency = most_recent_timestamp - st
else:
output.error = response.reason or ""
output.success = False
except Exception:
output.success = False
exc_info = sys.exc_info()
output.error = "".join(traceback.format_exception(*exc_info))
if pbar: if pbar:
pbar.update(1) pbar.update(1)
...@@ -164,45 +163,158 @@ async def async_request_openai_completions( ...@@ -164,45 +163,158 @@ async def async_request_openai_completions(
async def async_request_openai_chat_completions( async def async_request_openai_chat_completions(
request_func_input: RequestFuncInput, request_func_input: RequestFuncInput,
session: aiohttp.ClientSession,
pbar: Optional[tqdm] = None, pbar: Optional[tqdm] = None,
) -> RequestFuncOutput: ) -> RequestFuncOutput:
api_url = request_func_input.api_url api_url = request_func_input.api_url
assert api_url.endswith(("chat/completions", "profile")), ( assert api_url.endswith(("chat/completions", "profile")), (
"OpenAI Chat Completions API URL must end with 'chat/completions'.") "OpenAI Chat Completions API URL must end with 'chat/completions'.")
async with aiohttp.ClientSession(trust_env=True, content = [{"type": "text", "text": request_func_input.prompt}]
timeout=AIOHTTP_TIMEOUT) as session: if request_func_input.multi_modal_content:
content = [{"type": "text", "text": request_func_input.prompt}] content.append(request_func_input.multi_modal_content)
if request_func_input.multi_modal_content: payload = {
content.append(request_func_input.multi_modal_content) "model":
payload = { request_func_input.model_name
"model": if request_func_input.model_name else request_func_input.model,
request_func_input.model_name "messages": [
if request_func_input.model_name else request_func_input.model, {
"messages": [ "role": "user",
{ "content": content
"role": "user",
"content": content
},
],
"temperature":
0.0,
"max_completion_tokens":
request_func_input.output_len,
"stream":
True,
"stream_options": {
"include_usage": True,
}, },
} ],
if request_func_input.ignore_eos: "temperature":
payload["ignore_eos"] = request_func_input.ignore_eos 0.0,
if request_func_input.extra_body: "max_completion_tokens":
payload.update(request_func_input.extra_body) request_func_input.output_len,
headers = { "stream":
"Content-Type": "application/json", True,
"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}", "stream_options": {
} "include_usage": True,
},
}
if request_func_input.ignore_eos:
payload["ignore_eos"] = request_func_input.ignore_eos
if request_func_input.extra_body:
payload.update(request_func_input.extra_body)
headers = {
"Content-Type": "application/json",
"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}",
}
output = RequestFuncOutput()
output.prompt_len = request_func_input.prompt_len
generated_text = ""
ttft = 0.0
st = time.perf_counter()
most_recent_timestamp = st
try:
async with session.post(url=api_url, json=payload,
headers=headers) as response:
if response.status == 200:
async for chunk_bytes in response.content:
chunk_bytes = chunk_bytes.strip()
if not chunk_bytes:
continue
chunk_bytes = chunk_bytes.decode("utf-8")
# NOTE: SSE comments (often used as pings) start with
# a colon. These are not JSON data payload and should
# be skipped.
if chunk_bytes.startswith(":"):
continue
chunk = chunk_bytes.removeprefix("data: ")
if chunk != "[DONE]":
timestamp = time.perf_counter()
data = json.loads(chunk)
if choices := data.get("choices"):
content = choices[0]["delta"].get("content")
# First token
if ttft == 0.0:
ttft = timestamp - st
output.ttft = ttft
# Decoding phase
else:
output.itl.append(timestamp -
most_recent_timestamp)
generated_text += content or ""
elif usage := data.get("usage"):
output.output_tokens = usage.get(
"completion_tokens")
most_recent_timestamp = timestamp
output.generated_text = generated_text
output.success = True
output.latency = most_recent_timestamp - st
else:
output.error = response.reason or ""
output.success = False
except Exception:
output.success = False
exc_info = sys.exc_info()
output.error = "".join(traceback.format_exception(*exc_info))
if pbar:
pbar.update(1)
return output
async def async_request_openai_audio(
request_func_input: RequestFuncInput,
session: aiohttp.ClientSession,
pbar: Optional[tqdm] = None,
) -> RequestFuncOutput:
# Lazy import without PlaceholderModule to avoid vllm dep.
import soundfile
api_url = request_func_input.api_url
assert api_url.endswith(("transcriptions", "translations")), (
"OpenAI Chat Completions API URL must end with 'transcriptions' ")
"or `translations`."
content = [{"type": "text", "text": request_func_input.prompt}]
payload = {
"model":
request_func_input.model_name
if request_func_input.model_name else request_func_input.model,
"temperature":
0.0,
"max_completion_tokens":
request_func_input.output_len,
"stream":
True,
"language":
"en",
# Flattened due to multipart/form-data
"stream_include_usage":
True,
"stream_continuous_usage_stats":
True,
}
if request_func_input.extra_body:
payload.update(request_func_input.extra_body)
headers = {
"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}",
}
# Send audio file
def to_bytes(y, sr):
buffer = io.BytesIO()
soundfile.write(buffer, y, sr, format="WAV")
buffer.seek(0)
return buffer
with to_bytes(*request_func_input.multi_modal_content["audio"]) as f:
form = aiohttp.FormData()
form.add_field("file", f, content_type="audio/wav")
for key, value in payload.items():
form.add_field(key, str(value))
output = RequestFuncOutput() output = RequestFuncOutput()
output.prompt_len = request_func_input.prompt_len output.prompt_len = request_func_input.prompt_len
...@@ -212,28 +324,24 @@ async def async_request_openai_chat_completions( ...@@ -212,28 +324,24 @@ async def async_request_openai_chat_completions(
st = time.perf_counter() st = time.perf_counter()
most_recent_timestamp = st most_recent_timestamp = st
try: try:
async with session.post(url=api_url, json=payload, async with session.post(url=api_url,
data=form,
headers=headers) as response: headers=headers) as response:
if response.status == 200: if response.status == 200:
async for chunk_bytes in response.content: async for chunk_bytes in response.content:
chunk_bytes = chunk_bytes.strip() chunk_bytes = chunk_bytes.strip()
if not chunk_bytes: if not chunk_bytes:
continue continue
chunk_bytes = chunk_bytes.decode("utf-8")
# NOTE: SSE comments (often used as pings) start with
# a colon. These are not JSON data payload and should
# be skipped.
if chunk_bytes.startswith(":"):
continue
chunk = chunk_bytes.removeprefix("data: ")
chunk = chunk_bytes.decode("utf-8").removeprefix(
"data: ")
if chunk != "[DONE]": if chunk != "[DONE]":
timestamp = time.perf_counter() timestamp = time.perf_counter()
data = json.loads(chunk) data = json.loads(chunk)
if choices := data.get("choices"): if choices := data.get("choices"):
content = choices[0]["delta"].get("content") content = choices[0]["delta"].get(
"content")
# First token # First token
if ttft == 0.0: if ttft == 0.0:
ttft = timestamp - st ttft = timestamp - st
...@@ -241,8 +349,8 @@ async def async_request_openai_chat_completions( ...@@ -241,8 +349,8 @@ async def async_request_openai_chat_completions(
# Decoding phase # Decoding phase
else: else:
output.itl.append(timestamp - output.itl.append(
most_recent_timestamp) timestamp - most_recent_timestamp)
generated_text += content or "" generated_text += content or ""
elif usage := data.get("usage"): elif usage := data.get("usage"):
...@@ -267,117 +375,6 @@ async def async_request_openai_chat_completions( ...@@ -267,117 +375,6 @@ async def async_request_openai_chat_completions(
return output return output
async def async_request_openai_audio(
request_func_input: RequestFuncInput,
pbar: Optional[tqdm] = None,
) -> RequestFuncOutput:
# Lazy import without PlaceholderModule to avoid vllm dep.
import soundfile
api_url = request_func_input.api_url
assert api_url.endswith(("transcriptions", "translations")), (
"OpenAI Chat Completions API URL must end with 'transcriptions' ")
"or `translations`."
async with aiohttp.ClientSession(trust_env=True,
timeout=AIOHTTP_TIMEOUT) as session:
content = [{"type": "text", "text": request_func_input.prompt}]
payload = {
"model":
request_func_input.model_name
if request_func_input.model_name else request_func_input.model,
"temperature":
0.0,
"max_completion_tokens":
request_func_input.output_len,
"stream":
True,
"language":
"en",
# Flattened due to multipart/form-data
"stream_include_usage":
True,
"stream_continuous_usage_stats":
True,
}
if request_func_input.extra_body:
payload.update(request_func_input.extra_body)
headers = {
"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}",
}
# Send audio file
def to_bytes(y, sr):
buffer = io.BytesIO()
soundfile.write(buffer, y, sr, format="WAV")
buffer.seek(0)
return buffer
with to_bytes(*request_func_input.multi_modal_content["audio"]) as f:
form = aiohttp.FormData()
form.add_field("file", f, content_type="audio/wav")
for key, value in payload.items():
form.add_field(key, str(value))
output = RequestFuncOutput()
output.prompt_len = request_func_input.prompt_len
generated_text = ""
ttft = 0.0
st = time.perf_counter()
most_recent_timestamp = st
try:
async with session.post(url=api_url,
data=form,
headers=headers) as response:
if response.status == 200:
async for chunk_bytes in response.content:
chunk_bytes = chunk_bytes.strip()
if not chunk_bytes:
continue
chunk = chunk_bytes.decode("utf-8").removeprefix(
"data: ")
if chunk != "[DONE]":
timestamp = time.perf_counter()
data = json.loads(chunk)
if choices := data.get("choices"):
content = choices[0]["delta"].get(
"content")
# First token
if ttft == 0.0:
ttft = timestamp - st
output.ttft = ttft
# Decoding phase
else:
output.itl.append(
timestamp - most_recent_timestamp)
generated_text += content or ""
elif usage := data.get("usage"):
output.output_tokens = usage.get(
"completion_tokens")
most_recent_timestamp = timestamp
output.generated_text = generated_text
output.success = True
output.latency = most_recent_timestamp - st
else:
output.error = response.reason or ""
output.success = False
except Exception:
output.success = False
exc_info = sys.exc_info()
output.error = "".join(traceback.format_exception(*exc_info))
if pbar:
pbar.update(1)
return output
# TODO: Add more request functions for different API protocols. # TODO: Add more request functions for different API protocols.
ASYNC_REQUEST_FUNCS = { ASYNC_REQUEST_FUNCS = {
"vllm": async_request_openai_completions, "vllm": async_request_openai_completions,
......
...@@ -14,6 +14,7 @@ from .endpoint_request_func import RequestFuncInput, RequestFuncOutput ...@@ -14,6 +14,7 @@ from .endpoint_request_func import RequestFuncInput, RequestFuncOutput
async def wait_for_endpoint( async def wait_for_endpoint(
request_func, request_func,
test_input: RequestFuncInput, test_input: RequestFuncInput,
session: aiohttp.ClientSession,
timeout_seconds: int = 600, timeout_seconds: int = 600,
retry_interval: int = 5, retry_interval: int = 5,
) -> RequestFuncOutput: ) -> RequestFuncOutput:
...@@ -55,7 +56,8 @@ async def wait_for_endpoint( ...@@ -55,7 +56,8 @@ async def wait_for_endpoint(
# ping the endpoint using request_func # ping the endpoint using request_func
try: try:
output = await request_func(request_func_input=test_input) output = await request_func(
request_func_input=test_input, session=session)
if output.success: if output.success:
pbar.close() pbar.close()
return output return output
......
...@@ -28,6 +28,7 @@ from dataclasses import dataclass ...@@ -28,6 +28,7 @@ from dataclasses import dataclass
from datetime import datetime from datetime import datetime
from typing import Any, Literal, Optional from typing import Any, Literal, Optional
import aiohttp
import numpy as np import numpy as np
from tqdm.asyncio import tqdm from tqdm.asyncio import tqdm
from transformers import PreTrainedTokenizerBase from transformers import PreTrainedTokenizerBase
...@@ -338,6 +339,24 @@ async def benchmark( ...@@ -338,6 +339,24 @@ async def benchmark(
else: else:
raise ValueError(f"Unknown endpoint_type: {endpoint_type}") raise ValueError(f"Unknown endpoint_type: {endpoint_type}")
# Reuses connections across requests to reduce TLS handshake overhead.
connector = aiohttp.TCPConnector(
limit=max_concurrency or 0,
limit_per_host=max_concurrency or 0,
ttl_dns_cache=300,
use_dns_cache=True,
keepalive_timeout=60,
enable_cleanup_closed=True,
force_close=False,
ssl=("https://" in api_url),
)
session = aiohttp.ClientSession(
connector=connector,
trust_env=True,
timeout=aiohttp.ClientTimeout(total=6 * 60 * 60),
)
print("Starting initial single prompt test run...") print("Starting initial single prompt test run...")
test_prompt, test_prompt_len, test_output_len, test_mm_content = ( test_prompt, test_prompt_len, test_output_len, test_mm_content = (
input_requests[0].prompt, input_requests[0].prompt,
...@@ -361,7 +380,11 @@ async def benchmark( ...@@ -361,7 +380,11 @@ async def benchmark(
) )
test_output = await wait_for_endpoint( test_output = await wait_for_endpoint(
request_func, test_input, timeout_seconds=ready_check_timeout_sec) request_func,
test_input,
session,
timeout_seconds=ready_check_timeout_sec,
)
if not test_output.success: if not test_output.success:
raise ValueError( raise ValueError(
"Initial test run failed - Please make sure benchmark arguments " "Initial test run failed - Please make sure benchmark arguments "
...@@ -386,7 +409,8 @@ async def benchmark( ...@@ -386,7 +409,8 @@ async def benchmark(
multi_modal_content=test_mm_content, multi_modal_content=test_mm_content,
ignore_eos=ignore_eos, ignore_eos=ignore_eos,
extra_body=extra_body) extra_body=extra_body)
profile_output = await request_func(request_func_input=profile_input) profile_output = await request_func(
request_func_input=profile_input, session=session)
if profile_output.success: if profile_output.success:
print("Profiler started") print("Profiler started")
...@@ -412,12 +436,14 @@ async def benchmark( ...@@ -412,12 +436,14 @@ async def benchmark(
semaphore = (asyncio.Semaphore(max_concurrency) semaphore = (asyncio.Semaphore(max_concurrency)
if max_concurrency else None) if max_concurrency else None)
async def limited_request_func(request_func_input, pbar): async def limited_request_func(request_func_input, session, pbar):
if semaphore is None: if semaphore is None:
return await request_func(request_func_input=request_func_input, return await request_func(request_func_input=request_func_input,
session=session,
pbar=pbar) pbar=pbar)
async with semaphore: async with semaphore:
return await request_func(request_func_input=request_func_input, return await request_func(request_func_input=request_func_input,
session=session,
pbar=pbar) pbar=pbar)
benchmark_start_time = time.perf_counter() benchmark_start_time = time.perf_counter()
...@@ -469,6 +495,7 @@ async def benchmark( ...@@ -469,6 +495,7 @@ async def benchmark(
tasks.append( tasks.append(
asyncio.create_task( asyncio.create_task(
limited_request_func(request_func_input=request_func_input, limited_request_func(request_func_input=request_func_input,
session=session,
pbar=pbar))) pbar=pbar)))
outputs: list[RequestFuncOutput] = await asyncio.gather(*tasks) outputs: list[RequestFuncOutput] = await asyncio.gather(*tasks)
...@@ -580,9 +607,12 @@ async def benchmark( ...@@ -580,9 +607,12 @@ async def benchmark(
output_len=test_output_len, output_len=test_output_len,
logprobs=logprobs, logprobs=logprobs,
) )
profile_output = await request_func(request_func_input=profile_input) profile_output = await request_func(
request_func_input=profile_input, session=session)
if profile_output.success: if profile_output.success:
print("Profiler stopped") print("Profiler stopped")
await session.close()
return result return result
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment