Unverified Commit 45b6ef65 authored by Roger Wang's avatar Roger Wang Committed by GitHub
Browse files

feat(benchmarks): Add Prefix Caching Benchmark to Serving Benchmark (#3277)

parent 19569314
...@@ -23,8 +23,9 @@ wget https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/r ...@@ -23,8 +23,9 @@ wget https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/r
# wait for server to start, timeout after 600 seconds # wait for server to start, timeout after 600 seconds
timeout 600 bash -c 'until curl localhost:8000/v1/models; do sleep 1; done' || exit 1 timeout 600 bash -c 'until curl localhost:8000/v1/models; do sleep 1; done' || exit 1
python3 benchmarks/benchmark_serving.py \ python3 benchmarks/benchmark_serving.py \
--backend openai \ --backend vllm \
--dataset ./ShareGPT_V3_unfiltered_cleaned_split.json \ --dataset-name sharegpt \
--dataset-path ./ShareGPT_V3_unfiltered_cleaned_split.json \
--model meta-llama/Llama-2-7b-chat-hf \ --model meta-llama/Llama-2-7b-chat-hf \
--num-prompts 20 \ --num-prompts 20 \
--endpoint /v1/completions \ --endpoint /v1/completions \
......
import json import json
import os import os
import sys
import time import time
from dataclasses import dataclass import traceback
from typing import Optional from dataclasses import dataclass, field
from typing import List, Optional
import aiohttp import aiohttp
from tqdm.asyncio import tqdm from tqdm.asyncio import tqdm
...@@ -26,8 +28,11 @@ class RequestFuncOutput: ...@@ -26,8 +28,11 @@ class RequestFuncOutput:
generated_text: str = "" generated_text: str = ""
success: bool = False success: bool = False
latency: float = 0 latency: float = 0
ttft: float = 0 ttft: float = 0 # Time to first token
itl: List[float] = field(
default_factory=list) # List of inter-token latencies
prompt_len: int = 0 prompt_len: int = 0
error: str = ""
async def async_request_tgi( async def async_request_tgi(
...@@ -55,71 +60,38 @@ async def async_request_tgi( ...@@ -55,71 +60,38 @@ async def async_request_tgi(
ttft = 0 ttft = 0
st = time.perf_counter() st = time.perf_counter()
most_recent_timestamp = st
try: try:
async with session.post(url=api_url, json=payload) as response: async with session.post(url=api_url, json=payload) as response:
if response.status == 200: if response.status == 200:
async for data in response.content.iter_any(): async for chunk in response.content:
if ttft == 0: chunk = chunk.strip()
ttft = time.perf_counter() - st if not chunk:
output.ttft = ttft continue
output.latency = time.perf_counter() - st
body = remove_prefix(data.decode("utf-8"), "data:")
output.generated_text = json.loads(body)["generated_text"]
output.success = True
else:
output.success = False
except (aiohttp.ClientOSError, aiohttp.ServerDisconnectedError):
output.success = False
if pbar:
pbar.update(1)
return output
async def async_request_vllm(
request_func_input: RequestFuncInput,
pbar: Optional[tqdm] = None,
) -> RequestFuncOutput:
api_url = request_func_input.api_url
assert api_url.endswith("generate")
async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session: chunk = remove_prefix(chunk.decode("utf-8"), "data:")
payload = {
"prompt": request_func_input.prompt,
"n": 1,
"best_of": request_func_input.best_of,
"use_beam_search": request_func_input.use_beam_search,
"temperature": 0.0 if request_func_input.use_beam_search else 1.0,
"top_p": 1.0,
"max_tokens": request_func_input.output_len,
"ignore_eos": True,
"stream": True,
}
output = RequestFuncOutput()
output.prompt_len = request_func_input.prompt_len
ttft = 0 data = json.loads(chunk)
st = time.perf_counter() timestamp = time.perf_counter()
try: # First token
async with session.post(url=api_url, json=payload) as response:
if response.status == 200:
async for data in response.content.iter_any():
if ttft == 0: if ttft == 0:
ttft = time.perf_counter() - st ttft = time.perf_counter() - st
output.ttft = ttft output.ttft = ttft
output.latency = time.perf_counter() - st
# When streaming, '\0' is appended to the end of response. # Decoding phase
body = data.decode("utf-8").strip("\0") else:
output.generated_text = json.loads( output.itl.append(timestamp -
body)["text"][0][len(request_func_input.prompt):] most_recent_timestamp)
output.success = True
else: most_recent_timestamp = timestamp
output.success = False
except (aiohttp.ClientOSError, aiohttp.ServerDisconnectedError): output.latency = most_recent_timestamp - st
output.success = True
output.generated_text = data["generated_text"]
except Exception:
output.success = False output.success = False
exc_info = sys.exc_info()
output.error = "".join(traceback.format_exception(*exc_info))
if pbar: if pbar:
pbar.update(1) pbar.update(1)
...@@ -146,26 +118,45 @@ async def async_request_trt_llm( ...@@ -146,26 +118,45 @@ async def async_request_trt_llm(
} }
output = RequestFuncOutput() output = RequestFuncOutput()
output.prompt_len = request_func_input.prompt_len output.prompt_len = request_func_input.prompt_len
ttft = 0
ttft = 0
st = time.perf_counter() st = time.perf_counter()
most_recent_timestamp = st
try: try:
async with session.post(url=api_url, json=payload) as resp: async with session.post(url=api_url, json=payload) as response:
if resp.status == 200: if response.status == 200:
async for data in resp.content.iter_any(): async for chunk in response.content:
chunk = chunk.strip()
if not chunk:
continue
chunk = remove_prefix(chunk.decode("utf-8"), "data:")
data = json.loads(chunk)
timestamp = time.perf_counter()
# First token
if ttft == 0: if ttft == 0:
ttft = time.perf_counter() - st ttft = time.perf_counter() - st
output.ttft = ttft output.ttft = ttft
output.latency = time.perf_counter() - st
body = remove_prefix(data.decode("utf-8"), "data:") # Decoding phase
output.generated_text = json.loads(body)["text_output"] else:
output.itl.append(timestamp -
most_recent_timestamp)
most_recent_timestamp = timestamp
output.latency = most_recent_timestamp - st
output.generated_text = json.loads(data)["text_output"]
output.success = True output.success = True
else: else:
output.error = response.reason
output.success = False output.success = False
except (aiohttp.ClientOSError, aiohttp.ServerDisconnectedError): except Exception:
output.success = False output.success = False
exc_info = sys.exc_info()
output.error = "".join(traceback.format_exception(*exc_info))
if pbar: if pbar:
pbar.update(1) pbar.update(1)
...@@ -181,35 +172,35 @@ async def async_request_deepspeed_mii( ...@@ -181,35 +172,35 @@ async def async_request_deepspeed_mii(
assert not request_func_input.use_beam_search assert not request_func_input.use_beam_search
payload = { payload = {
"prompts": request_func_input.prompt, "prompt": request_func_input.prompt,
"max_new_tokens": request_func_input.output_len, "max_tokens": request_func_input.output_len,
"ignore_eos": True, "temperature": 0.01, # deepspeed-mii does not accept 0.0 temp.
"do_sample": True,
"temperature":
0.01, # deepspeed-mii does not accept 0.0 temperature.
"top_p": 1.0, "top_p": 1.0,
} }
output = RequestFuncOutput() output = RequestFuncOutput()
output.prompt_len = request_func_input.prompt_len output.prompt_len = request_func_input.prompt_len
# DeepSpeed-MII doesn't support streaming as of Jan 28 2024, # NOTE: DeepSpeed-MII doesn't support streaming as of Jan 28 2024,
# will use 0 as placeholder. # will use 0 as placeholder.
# https://github.com/microsoft/DeepSpeed-MII/pull/311 # See https://github.com/microsoft/DeepSpeed-MII/pull/311
output.ttft = 0 output.ttft = 0
st = time.perf_counter() st = time.perf_counter()
try: try:
async with session.post(url=request_func_input.api_url, async with session.post(url=request_func_input.api_url,
json=payload) as resp: json=payload) as response:
if resp.status == 200: if response.status == 200:
parsed_resp = await resp.json() parsed_resp = await response.json()
output.latency = time.perf_counter() - st output.latency = time.perf_counter() - st
output.generated_text = parsed_resp[0]["generated_text"] output.generated_text = parsed_resp["text"][0]
output.success = True output.success = True
else: else:
output.error = response.reason
output.success = False output.success = False
except (aiohttp.ClientOSError, aiohttp.ServerDisconnectedError): except Exception:
output.success = False output.success = False
exc_info = sys.exc_info()
output.error = "".join(traceback.format_exception(*exc_info))
if pbar: if pbar:
pbar.update(1) pbar.update(1)
...@@ -221,7 +212,9 @@ async def async_request_openai_completions( ...@@ -221,7 +212,9 @@ async def async_request_openai_completions(
pbar: Optional[tqdm] = None, pbar: Optional[tqdm] = None,
) -> RequestFuncOutput: ) -> RequestFuncOutput:
api_url = request_func_input.api_url api_url = request_func_input.api_url
assert api_url.endswith("v1/completions") assert api_url.endswith(
"v1/completions"
), "OpenAI Completions API URL must end with 'v1/completions'."
async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session: async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session:
assert not request_func_input.use_beam_search assert not request_func_input.use_beam_search
...@@ -243,15 +236,12 @@ async def async_request_openai_completions( ...@@ -243,15 +236,12 @@ async def async_request_openai_completions(
generated_text = "" generated_text = ""
ttft = 0 ttft = 0
st = time.perf_counter() st = time.perf_counter()
most_recent_timestamp = st
try: try:
async with session.post(url=api_url, json=payload, async with session.post(url=api_url, json=payload,
headers=headers) as response: headers=headers) as response:
if response.status == 200: if response.status == 200:
async for chunk in response.content: async for chunk in response.content:
if ttft == 0:
ttft = time.perf_counter() - st
output.ttft = ttft
chunk = chunk.strip() chunk = chunk.strip()
if not chunk: if not chunk:
continue continue
...@@ -260,16 +250,33 @@ async def async_request_openai_completions( ...@@ -260,16 +250,33 @@ async def async_request_openai_completions(
if chunk == "[DONE]": if chunk == "[DONE]":
latency = time.perf_counter() - st latency = time.perf_counter() - st
else: else:
body = json.loads(chunk) data = json.loads(chunk)
generated_text += body["choices"][0]["text"]
if data["choices"][0]["text"]:
timestamp = time.perf_counter()
# First token
if ttft == 0:
ttft = time.perf_counter() - st
output.ttft = ttft
# Decoding phase
# NOTE: Some completion API might have a last
# usage summary response without a token so we
# do not want to include as inter-token-latency
elif data.get("usage", None) is None:
output.itl.append(timestamp -
most_recent_timestamp)
most_recent_timestamp = timestamp
generated_text += data["choices"][0]["text"]
output.generated_text = generated_text output.generated_text = generated_text
output.success = True output.success = True
output.latency = latency output.latency = latency
else: except Exception:
output.success = False
except (aiohttp.ClientOSError, aiohttp.ServerDisconnectedError):
output.success = False output.success = False
exc_info = sys.exc_info()
output.error = "".join(traceback.format_exception(*exc_info))
if pbar: if pbar:
pbar.update(1) pbar.update(1)
...@@ -283,7 +290,7 @@ async def async_request_openai_chat_completions( ...@@ -283,7 +290,7 @@ async def async_request_openai_chat_completions(
api_url = request_func_input.api_url api_url = request_func_input.api_url
assert api_url.endswith( assert api_url.endswith(
"v1/chat/completions" "v1/chat/completions"
), "OpenAI Chat API URL must end with 'v1/chat/completions'." ), "OpenAI Chat Completions API URL must end with 'v1/chat/completions'."
async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session: async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session:
assert not request_func_input.use_beam_search assert not request_func_input.use_beam_search
...@@ -301,7 +308,7 @@ async def async_request_openai_chat_completions( ...@@ -301,7 +308,7 @@ async def async_request_openai_chat_completions(
} }
headers = { headers = {
"Content-Type": "application/json", "Content-Type": "application/json",
"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}" "Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}",
} }
output = RequestFuncOutput() output = RequestFuncOutput()
...@@ -310,15 +317,12 @@ async def async_request_openai_chat_completions( ...@@ -310,15 +317,12 @@ async def async_request_openai_chat_completions(
generated_text = "" generated_text = ""
ttft = 0 ttft = 0
st = time.perf_counter() st = time.perf_counter()
most_recent_timestamp = st
try: try:
async with session.post(url=api_url, json=payload, async with session.post(url=api_url, json=payload,
headers=headers) as response: headers=headers) as response:
if response.status == 200: if response.status == 200:
async for chunk in response.content: async for chunk in response.content:
if ttft == 0:
ttft = time.perf_counter() - st
output.ttft = ttft
chunk = chunk.strip() chunk = chunk.strip()
if not chunk: if not chunk:
continue continue
...@@ -327,18 +331,35 @@ async def async_request_openai_chat_completions( ...@@ -327,18 +331,35 @@ async def async_request_openai_chat_completions(
if chunk == "[DONE]": if chunk == "[DONE]":
latency = time.perf_counter() - st latency = time.perf_counter() - st
else: else:
body = json.loads(chunk) timestamp = time.perf_counter()
if "content" in body["choices"][0]["delta"]: data = json.loads(chunk)
generated_text += body["choices"][0]["delta"][
if "content" in data["choices"][0]["delta"]:
# First token
if ttft == 0:
ttft = time.perf_counter() - st
output.ttft = ttft
# Decoding phase
else:
output.itl.append(timestamp -
most_recent_timestamp)
generated_text += data["choices"][0]["delta"][
"content"] "content"]
most_recent_timestamp = timestamp
output.generated_text = generated_text output.generated_text = generated_text
output.success = True output.success = True
output.latency = latency output.latency = latency
else: else:
output.error = response.reason
output.success = False output.success = False
except (aiohttp.ClientOSError, aiohttp.ServerDisconnectedError): except Exception:
output.success = False output.success = False
exc_info = sys.exc_info()
output.error = "".join(traceback.format_exception(*exc_info))
if pbar: if pbar:
pbar.update(1) pbar.update(1)
...@@ -355,7 +376,8 @@ def remove_prefix(text: str, prefix: str) -> str: ...@@ -355,7 +376,8 @@ def remove_prefix(text: str, prefix: str) -> str:
ASYNC_REQUEST_FUNCS = { ASYNC_REQUEST_FUNCS = {
"tgi": async_request_tgi, "tgi": async_request_tgi,
"vllm": async_request_vllm, "vllm": async_request_openai_completions,
"lmdeploy": async_request_openai_completions,
"deepspeed-mii": async_request_deepspeed_mii, "deepspeed-mii": async_request_deepspeed_mii,
"openai": async_request_openai_completions, "openai": async_request_openai_completions,
"openai-chat": async_request_openai_chat_completions, "openai-chat": async_request_openai_chat_completions,
......
"""Benchmark online serving throughput. """Benchmark online serving throughput.
On the server side, run one of the following commands: On the server side, run one of the following commands:
(vLLM backend) vLLM OpenAI API server
python -m vllm.entrypoints.api_server \ python -m vllm.entrypoints.openai.api_server \
--model <your_model> --swap-space 16 \ --model <your_model> --swap-space 16 \
--disable-log-requests --disable-log-requests
...@@ -12,14 +12,19 @@ On the server side, run one of the following commands: ...@@ -12,14 +12,19 @@ On the server side, run one of the following commands:
On the client side, run: On the client side, run:
python benchmarks/benchmark_serving.py \ python benchmarks/benchmark_serving.py \
--backend <backend> \ --backend <backend> \
--model <your_model> --dataset <target_dataset> \ --model <your_model> \
--request-rate <request_rate> --dataset-name sharegpt \
--dataset-path <path to dataset> \
--request-rate <request_rate> \ # By default <request_rate> is inf
--num-prompts <num_prompts> # By default <num_prompts> is 1000
""" """
import argparse import argparse
import asyncio import asyncio
import json import json
import os
import random import random
import time import time
import warnings
from dataclasses import dataclass from dataclasses import dataclass
from datetime import datetime from datetime import datetime
from typing import AsyncGenerator, List, Tuple from typing import AsyncGenerator, List, Tuple
...@@ -49,7 +54,7 @@ class BenchmarkMetrics: ...@@ -49,7 +54,7 @@ class BenchmarkMetrics:
p99_tpot_ms: float p99_tpot_ms: float
def sample_requests( def sample_sharegpt_requests(
dataset_path: str, dataset_path: str,
num_requests: int, num_requests: int,
tokenizer: PreTrainedTokenizerBase, tokenizer: PreTrainedTokenizerBase,
...@@ -97,6 +102,73 @@ def sample_requests( ...@@ -97,6 +102,73 @@ def sample_requests(
return sampled_requests return sampled_requests
def sample_sonnet_requests(
dataset_path: str,
num_requests: int,
input_len: int,
output_len: int,
prefix_len: int,
tokenizer: PreTrainedTokenizerBase,
) -> List[Tuple[str, str, int, int]]:
assert input_len > prefix_len, "input_len must be greater than prefix_len."
# Load the dataset.
with open(dataset_path) as f:
poem_lines = f.readlines()
# Tokenize the poem lines.
poem_token_ids = tokenizer(poem_lines).input_ids
average_poem_len = sum(
len(token_ids) for token_ids in poem_token_ids) / len(poem_token_ids)
# Base prefix for all requests.
base_prompt = "Pick as many lines as you can from these poem lines:\n"
base_message = [{
"role": "user",
"content": base_prompt,
}]
base_prompt_formatted = tokenizer.apply_chat_template(
base_message, add_generation_prompt=True, tokenize=False)
base_prompt_offset = len(tokenizer(base_prompt_formatted).input_ids)
assert (input_len > base_prompt_offset
), f"Please set 'args.input-len' higher than {base_prompt_offset}."
num_input_lines = round(
(input_len - base_prompt_offset) / average_poem_len)
# First approximately `prefix_len` number of tokens in the
# prompt are fixed poem lines.
assert (
prefix_len > base_prompt_offset
), f"Please set 'args.prefix-len' higher than {base_prompt_offset}."
num_prefix_lines = round(
(prefix_len - base_prompt_offset) / average_poem_len)
prefix_lines = poem_lines[:num_prefix_lines]
# Sample the rest of lines per request.
sampled_requests: List[Tuple[str, int, int]] = []
for _ in range(num_requests):
sampled_lines = "".join(
prefix_lines +
random.sample(poem_lines, num_input_lines - num_prefix_lines))
prompt = f"{base_prompt}{sampled_lines}"
message = [
{
"role": "user",
"content": prompt,
},
]
prompt_formatted = tokenizer.apply_chat_template(
message, add_generation_prompt=True, tokenize=False)
prompt_len = len(tokenizer(prompt_formatted).input_ids)
sampled_requests.append(
(prompt, prompt_formatted, prompt_len, output_len))
return sampled_requests
async def get_request( async def get_request(
input_requests: List[Tuple[str, int, int]], input_requests: List[Tuple[str, int, int]],
request_rate: float, request_rate: float,
...@@ -119,37 +191,42 @@ def calculate_metrics( ...@@ -119,37 +191,42 @@ def calculate_metrics(
outputs: List[RequestFuncOutput], outputs: List[RequestFuncOutput],
dur_s: float, dur_s: float,
tokenizer: PreTrainedTokenizerBase, tokenizer: PreTrainedTokenizerBase,
) -> BenchmarkMetrics: ) -> Tuple[BenchmarkMetrics, List[int]]:
total_output = 0 actual_output_lens = []
total_input = 0 total_input = 0
completed = 0 completed = 0
per_token_latencies = [] tpots = []
ttfts = [] ttfts = []
for i in range(len(outputs)): for i in range(len(outputs)):
if outputs[i].success: if outputs[i].success:
output_len = len(tokenizer.encode(outputs[i].generated_text)) output_len = len(tokenizer(outputs[i].generated_text).input_ids)
total_output += output_len actual_output_lens.append(output_len)
total_input += input_requests[i][1] total_input += input_requests[i][1]
per_token_latencies.append(outputs[i].latency / output_len) if output_len > 1:
tpots.append(
(outputs[i].latency - outputs[i].ttft) / (output_len - 1))
ttfts.append(outputs[i].ttft) ttfts.append(outputs[i].ttft)
completed += 1 completed += 1
else:
actual_output_lens.append(0)
metrics = BenchmarkMetrics( metrics = BenchmarkMetrics(
completed=completed, completed=completed,
total_input=total_input, total_input=total_input,
total_output=total_output, total_output=sum(actual_output_lens),
request_throughput=completed / dur_s, request_throughput=completed / dur_s,
input_throughput=total_input / dur_s, input_throughput=total_input / dur_s,
output_throughput=total_output / dur_s, output_throughput=sum(actual_output_lens) / dur_s,
mean_ttft_ms=np.mean(ttfts) * 1000, mean_ttft_ms=np.mean(ttfts or 0) *
median_ttft_ms=np.median(ttfts) * 1000, 1000, # ttfts is empty if streaming is not supported by backend
p99_ttft_ms=np.percentile(ttfts, 99) * 1000, median_ttft_ms=np.median(ttfts or 0) * 1000,
mean_tpot_ms=np.mean(per_token_latencies) * 1000, p99_ttft_ms=np.percentile(ttfts or 0, 99) * 1000,
median_tpot_ms=np.median(per_token_latencies) * 1000, mean_tpot_ms=np.mean(tpots) * 1000,
p99_tpot_ms=np.percentile(per_token_latencies, 99) * 1000, median_tpot_ms=np.median(tpots) * 1000,
p99_tpot_ms=np.percentile(tpots, 99) * 1000,
) )
return metrics return metrics, actual_output_lens
async def benchmark( async def benchmark(
...@@ -189,40 +266,53 @@ async def benchmark( ...@@ -189,40 +266,53 @@ async def benchmark(
asyncio.create_task( asyncio.create_task(
request_func(request_func_input=request_func_input, request_func(request_func_input=request_func_input,
pbar=pbar))) pbar=pbar)))
outputs = await asyncio.gather(*tasks) outputs: List[RequestFuncOutput] = await asyncio.gather(*tasks)
if not disable_tqdm: if not disable_tqdm:
pbar.close() pbar.close()
benchmark_duration = time.perf_counter() - benchmark_start_time benchmark_duration = time.perf_counter() - benchmark_start_time
metrics = calculate_metrics( metrics, actual_output_lens = calculate_metrics(
input_requests=input_requests, input_requests=input_requests,
outputs=outputs, outputs=outputs,
dur_s=benchmark_duration, dur_s=benchmark_duration,
tokenizer=tokenizer, tokenizer=tokenizer,
) )
print(f"Successful requests: {metrics.completed}") print("{s:{c}^{n}}".format(s=' Serving Benchmark Result ', n=50, c='='))
print(f"Benchmark duration: {benchmark_duration:2f} s") print("{:<40} {:<10}".format("Successful requests:", metrics.completed))
print(f"Total input tokens: {metrics.total_input}") print("{:<40} {:<10.2f}".format("Benchmark duration (s):",
print(f"Total generated tokens: {metrics.total_output}") benchmark_duration))
print(f"Request throughput: {metrics.request_throughput:.2f} requests/s") print("{:<40} {:<10}".format("Total input tokens:", metrics.total_input))
print(f"Input token throughput: {metrics.input_throughput:.2f} tokens/s") print("{:<40} {:<10}".format("Total generated tokens:",
print(f"Output token throughput: {metrics.output_throughput:.2f} tokens/s") metrics.total_output))
print(f"Mean TTFT: {metrics.mean_ttft_ms:.2f} ms") print("{:<40} {:<10.2f}".format("Request throughput (req/s):",
print(f"Median TTFT: {metrics.median_ttft_ms:.2f} ms") metrics.request_throughput))
print(f"P99 TTFT: {metrics.p99_ttft_ms:.2f} ms") print("{:<40} {:<10.2f}".format("Input token throughput (tok/s):",
print(f"Mean TPOT: {metrics.mean_tpot_ms:.2f} ms") metrics.input_throughput))
print(f"Median TPOT: {metrics.median_tpot_ms:.2f} ms") print("{:<40} {:<10.2f}".format("Output token throughput (tok/s):",
print(f"P99 TPOT: {metrics.p99_tpot_ms:.2f} ms") metrics.output_throughput))
print("{s:{c}^{n}}".format(s='Time to First Token', n=50, c='-'))
print("{:<40} {:<10.2f}".format("Mean TTFT (ms):", metrics.mean_ttft_ms))
print("{:<40} {:<10.2f}".format("Median TTFT (ms):",
metrics.median_ttft_ms))
print("{:<40} {:<10.2f}".format("P99 TTFT (ms):", metrics.p99_ttft_ms))
print("{s:{c}^{n}}".format(s='Time per Output Token (excl. 1st token)',
n=50,
c='-'))
print("{:<40} {:<10.2f}".format("Mean TPOT (ms):", metrics.mean_tpot_ms))
print("{:<40} {:<10.2f}".format("Median TPOT (ms):",
metrics.median_tpot_ms))
print("{:<40} {:<10.2f}".format("P99 TPOT (ms):", metrics.p99_tpot_ms))
print("=" * 50)
result = { result = {
"duration": benchmark_duration, "duration": benchmark_duration,
"completed": metrics.completed, "completed": metrics.completed,
"total_input_tokens": metrics.total_input, "total_input_tokens": metrics.total_input,
"total_output_tokens": metrics.total_output, "total_output_tokens": metrics.total_output,
"request_inthroughput": metrics.request_throughput, "request_throughput": metrics.request_throughput,
"input_throughput": metrics.input_throughput, "input_throughput": metrics.input_throughput,
"output_throughput": metrics.output_throughput, "output_throughput": metrics.output_throughput,
"mean_ttft_ms": metrics.mean_ttft_ms, "mean_ttft_ms": metrics.mean_ttft_ms,
...@@ -230,7 +320,13 @@ async def benchmark( ...@@ -230,7 +320,13 @@ async def benchmark(
"p99_ttft_ms": metrics.p99_ttft_ms, "p99_ttft_ms": metrics.p99_ttft_ms,
"mean_tpot_ms": metrics.mean_tpot_ms, "mean_tpot_ms": metrics.mean_tpot_ms,
"median_tpot_ms": metrics.median_tpot_ms, "median_tpot_ms": metrics.median_tpot_ms,
"p99_tpot_ms": metrics.p99_tpot_ms "p99_tpot_ms": metrics.p99_tpot_ms,
"input_lens": [output.prompt_len for output in outputs],
"output_lens": actual_output_lens,
"ttfts": [output.ttft for output in outputs],
"itls": [output.itl for output in outputs],
"generated_texts": [output.generated_text for output in outputs],
"errors": [output.error for output in outputs],
} }
return result return result
...@@ -251,7 +347,58 @@ def main(args: argparse.Namespace): ...@@ -251,7 +347,58 @@ def main(args: argparse.Namespace):
tokenizer = get_tokenizer(tokenizer_id, tokenizer = get_tokenizer(tokenizer_id,
trust_remote_code=args.trust_remote_code) trust_remote_code=args.trust_remote_code)
input_requests = sample_requests(args.dataset, args.num_prompts, tokenizer)
if args.dataset is not None:
warnings.warn(
"The '--dataset' argument will be deprecated in the next "
"release. Please use '--dataset-name' and "
"'--dataset-path' in the future runs.",
stacklevel=2)
input_requests = sample_sharegpt_requests(
dataset_path=args.dataset,
num_requests=args.num_prompts,
tokenizer=tokenizer,
)
elif args.dataset_name == "sharegpt":
input_requests = sample_sharegpt_requests(
dataset_path=args.dataset_path,
num_requests=args.num_prompts,
tokenizer=tokenizer,
)
elif args.dataset_name == "sonnet":
# Do not format the prompt, pass to message directly
if args.backend == "openai-chat":
input_requests = sample_sonnet_requests(
dataset_path=args.dataset_path,
num_requests=args.num_prompts,
input_len=args.input_len,
output_len=args.output_len,
prefix_len=args.prefix_len,
tokenizer=tokenizer,
)
input_requests = [(prompt, prompt_len, output_len)
for prompt, prompt_formatted, prompt_len,
output_len in input_requests]
else:
assert (
tokenizer.chat_template or tokenizer.default_chat_template
), "Tokenizer/model must have chat template for sonnet dataset."
input_requests = sample_sonnet_requests(
dataset_path=args.dataset_path,
num_requests=args.num_prompts,
input_len=args.input_len,
output_len=args.output_len,
prefix_len=args.prefix_len,
tokenizer=tokenizer,
)
input_requests = [(prompt_formatted, prompt_len, output_len)
for prompt, prompt_formatted, prompt_len,
output_len in input_requests]
else:
raise ValueError(f"Unknown dataset: {args.dataset_name}")
benchmark_result = asyncio.run( benchmark_result = asyncio.run(
benchmark( benchmark(
...@@ -274,13 +421,23 @@ def main(args: argparse.Namespace): ...@@ -274,13 +421,23 @@ def main(args: argparse.Namespace):
current_dt = datetime.now().strftime("%Y%m%d-%H%M%S") current_dt = datetime.now().strftime("%Y%m%d-%H%M%S")
result_json["date"] = current_dt result_json["date"] = current_dt
result_json["backend"] = backend result_json["backend"] = backend
result_json["version"] = args.version
result_json["model_id"] = model_id result_json["model_id"] = model_id
result_json["tokenizer_id"] = tokenizer_id result_json["tokenizer_id"] = tokenizer_id
result_json["best_of"] = args.best_of result_json["best_of"] = args.best_of
result_json["use_beam_search"] = args.use_beam_search result_json["use_beam_search"] = args.use_beam_search
result_json["num_prompts"] = args.num_prompts result_json["num_prompts"] = args.num_prompts
# Metadata
if args.metadata:
for item in args.metadata:
if "=" in item:
kvstring = item.split("=")
result_json[kvstring[0].strip()] = kvstring[1].strip()
else:
raise ValueError(
"Invalid metadata format. Please use KEY=VALUE format."
)
# Traffic # Traffic
result_json["request_rate"] = ( result_json["request_rate"] = (
args.request_rate if args.request_rate < float("inf") else "inf") args.request_rate if args.request_rate < float("inf") else "inf")
...@@ -290,9 +447,9 @@ def main(args: argparse.Namespace): ...@@ -290,9 +447,9 @@ def main(args: argparse.Namespace):
# Save to file # Save to file
base_model_id = model_id.split("/")[-1] base_model_id = model_id.split("/")[-1]
file_name = ( file_name = f"{backend}-{args.request_rate}qps-{base_model_id}-{current_dt}.json" #noqa
f"{backend}-{args.request_rate}qps-{base_model_id}-{current_dt}.json" if args.result_dir:
) file_name = os.path.join(args.result_dir, file_name)
with open(file_name, "w") as outfile: with open(file_name, "w") as outfile:
json.dump(result_json, outfile) json.dump(result_json, outfile)
...@@ -306,12 +463,6 @@ if __name__ == "__main__": ...@@ -306,12 +463,6 @@ if __name__ == "__main__":
default="vllm", default="vllm",
choices=list(ASYNC_REQUEST_FUNCS.keys()), choices=list(ASYNC_REQUEST_FUNCS.keys()),
) )
parser.add_argument(
"--version",
type=str,
default="N/A",
help="Version of the serving backend/engine.",
)
parser.add_argument( parser.add_argument(
"--base-url", "--base-url",
type=str, type=str,
...@@ -323,12 +474,26 @@ if __name__ == "__main__": ...@@ -323,12 +474,26 @@ if __name__ == "__main__":
parser.add_argument( parser.add_argument(
"--endpoint", "--endpoint",
type=str, type=str,
default="/generate", default="/v1/completions",
help="API endpoint.", help="API endpoint.",
) )
parser.add_argument("--dataset", parser.add_argument(
"--dataset",
type=str,
default=None,
help="Path to the ShareGPT dataset, will be deprecated in the "
"next release.",
)
parser.add_argument(
"--dataset-name",
type=str,
default="sharegpt",
choices=["sharegpt", "sonnet"],
help="Name of the dataset to benchmark on.",
)
parser.add_argument("--dataset-path",
type=str, type=str,
required=True, default=None,
help="Path to the dataset.") help="Path to the dataset.")
parser.add_argument( parser.add_argument(
"--model", "--model",
...@@ -356,6 +521,27 @@ if __name__ == "__main__": ...@@ -356,6 +521,27 @@ if __name__ == "__main__":
default=1000, default=1000,
help="Number of prompts to process.", help="Number of prompts to process.",
) )
parser.add_argument(
"--sonnet-input-len",
type=int,
default=550,
help=
"Number of input tokens per request, used only for sonnet dataset.",
)
parser.add_argument(
"--sonnet-output-len",
type=int,
default=150,
help=
"Number of output tokens per request, used only for sonnet dataset.",
)
parser.add_argument(
"--sonnet-prefix-len",
type=int,
default=200,
help=
"Number of prefix tokens per request, used only for sonnet dataset.",
)
parser.add_argument( parser.add_argument(
"--request-rate", "--request-rate",
type=float, type=float,
...@@ -381,6 +567,21 @@ if __name__ == "__main__": ...@@ -381,6 +567,21 @@ if __name__ == "__main__":
action="store_true", action="store_true",
help="Specify to save benchmark results to a json file", help="Specify to save benchmark results to a json file",
) )
parser.add_argument(
"--metadata",
metavar="KEY=VALUE",
nargs="*",
help="Key-value pairs (e.g, --metadata version=0.3.3 tp=1) "
"for metadata of this run to be saved in the result JSON file "
"for record keeping purposes.",
)
parser.add_argument(
"--result-dir",
type=str,
default=None,
help="Specify directory to save benchmark json results."
"If not specified, results are saved in the current directory.",
)
args = parser.parse_args() args = parser.parse_args()
main(args) main(args)
This diff is collapsed.
...@@ -50,7 +50,7 @@ exclude = "vllm/model_executor/parallel_utils/|vllm/model_executor/models/" ...@@ -50,7 +50,7 @@ exclude = "vllm/model_executor/parallel_utils/|vllm/model_executor/models/"
[tool.codespell] [tool.codespell]
ignore-words-list = "dout, te, indicies" ignore-words-list = "dout, te, indicies"
skip = "./tests/prompts" skip = "./tests/prompts,./benchmarks/sonnet.txt"
[tool.isort] [tool.isort]
use_parentheses = true use_parentheses = true
......
...@@ -36,8 +36,8 @@ def test_contexted_kv_attention( ...@@ -36,8 +36,8 @@ def test_contexted_kv_attention(
torch.cuda.manual_seed(0) torch.cuda.manual_seed(0)
torch.set_default_device(device) torch.set_default_device(device)
# Need this, otherwise when we capture the graph the process for GPU 1 would # Need this, otherwise when we capture the graph the process
# run on both GPU0 and GPU1 and things would hang # for GPU 1 would run on both GPU0 and GPU1 and things would hang
# #
# see also similar issue: https://github.com/Dao-AILab/flash-attention/issues/523 # see also similar issue: https://github.com/Dao-AILab/flash-attention/issues/523
torch.cuda.set_device(device) torch.cuda.set_device(device)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment