Unverified Commit 4b4a67f8 authored by zhyncs, committed by GitHub

feat: support TRT LLM benchmark and multiple benchmarks (#670)

parent 0ac94c36
@@ -19,6 +19,7 @@ import traceback
import warnings
from argparse import ArgumentParser as FlexibleArgumentParser
from dataclasses import dataclass, field
from datetime import datetime
from typing import AsyncGenerator, List, Optional, Tuple, Union
import aiohttp
@@ -59,6 +60,72 @@ def remove_prefix(text: str, prefix: str) -> str:
    return text[len(prefix) :] if text.startswith(prefix) else text


# TRT LLM does not support ignore_eos
# https://github.com/triton-inference-server/tensorrtllm_backend/issues/505
async def async_request_trt_llm(
    request_func_input: RequestFuncInput,
    pbar: Optional[tqdm] = None,
) -> RequestFuncOutput:
    api_url = request_func_input.api_url
    assert api_url.endswith("generate_stream")

    async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session:
        assert not request_func_input.use_beam_search
        assert request_func_input.best_of == 1
        payload = {
            "accumulate_tokens": True,
            "text_input": request_func_input.prompt,
            "temperature": 0.0,
            "top_p": 1.0,
            "max_tokens": request_func_input.output_len,
            "stream": True,
        }
        output = RequestFuncOutput()
        output.prompt_len = request_func_input.prompt_len

        ttft = 0.0
        st = time.perf_counter()
        most_recent_timestamp = st
        try:
            async with session.post(url=api_url, json=payload) as response:
                if response.status == 200:
                    async for chunk_bytes in response.content:
                        chunk_bytes = chunk_bytes.strip()
                        if not chunk_bytes:
                            continue

                        # Strip the SSE "data:" framing before parsing JSON
                        chunk = remove_prefix(chunk_bytes.decode("utf-8"), "data:")

                        data = json.loads(chunk)
                        output.generated_text += data["text_output"]
                        timestamp = time.perf_counter()
                        # First token
                        if ttft == 0.0:
                            ttft = time.perf_counter() - st
                            output.ttft = ttft
                        # Decoding phase
                        else:
                            output.itl.append(timestamp - most_recent_timestamp)

                        most_recent_timestamp = timestamp

                    output.latency = most_recent_timestamp - st
                    output.success = True
                else:
                    output.error = response.reason or ""
                    output.success = False
        except Exception:
            output.success = False
            exc_info = sys.exc_info()
            output.error = "".join(traceback.format_exception(*exc_info))

    if pbar:
        pbar.update(1)
    return output
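
# Illustrative sketch (not part of this change): generate_stream emits
# SSE-style lines such as `data: {"text_output": "..."}`, so the handler
# above strips the "data:" framing and parses the remainder as JSON. The
# chunk value below is hypothetical:
#
#     chunk = remove_prefix('data: {"text_output": " world"}', "data:")
#     data = json.loads(chunk)            # {'text_output': ' world'}
#     output.generated_text += data["text_output"]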
# set ignore_eos to True by default
async def async_request_openai_completions(
    request_func_input: RequestFuncInput,
@@ -167,6 +234,7 @@ ASYNC_REQUEST_FUNCS = {
"sglang": async_request_openai_completions,
"vllm": async_request_openai_completions,
"lmdeploy": async_request_openai_completions,
"trt": async_request_trt_llm,
}
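
# The benchmark dispatches on this mapping by backend name; hypothetical usage:
#
#     request_func = ASYNC_REQUEST_FUNCS["trt"]  # -> async_request_trt_llm
#     output = await request_func(request_func_input, pbar)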
@@ -449,6 +517,7 @@ async def benchmark(
    input_requests: List[Tuple[str, int, int]],
    request_rate: float,
    disable_tqdm: bool,
    enable_multi: bool,
):
    if backend in ASYNC_REQUEST_FUNCS:
        request_func = ASYNC_REQUEST_FUNCS[backend]
@@ -542,6 +611,37 @@ async def benchmark(
print("{:<40} {:<10.2f}".format("P99 ITL (ms):", metrics.p99_itl_ms))
print("=" * 50)
if enable_multi:
if (
metrics.median_ttft_ms is not None
and metrics.mean_itl_ms is not None
and metrics.output_throughput is not None
):
result = {
"dataset_name": args.dataset_name,
"request_rate": request_rate,
"median_ttft": metrics.median_ttft_ms,
"median_itl": metrics.mean_itl_ms,
"output_token_throughput": metrics.output_throughput,
"sharegpt_output_len": args.sharegpt_output_len,
"random_input_len": args.random_input_len,
"random_output_len": args.random_output_len,
}
else:
print(f"Error running benchmark for request rate: {request_rate}")
print("-" * 30)
# Determine output file name
if args.output_file:
output_file_name = args.output_file
else:
now = datetime.now().strftime("%m%d%H")
output_file_name = f"{args.backend}_{now}.jsonl"
# Append results to a JSONL file
with open(output_file_name, "a") as file:
file.write(json.dumps(result) + "\n")
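    # A minimal sketch (illustrative, with a hypothetical file name) of reading
    # the JSONL rows written above back for later analysis:
    #
    #     with open("sglang_0101.jsonl") as f:
    #         rows = [json.loads(line) for line in f]
    #     throughputs = [r["output_token_throughput"] for r in rows]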
    result = {
        "duration": benchmark_duration,
        "completed": metrics.completed,
@@ -572,6 +672,11 @@ async def benchmark(
    return result
def parse_request_rate_range(request_rate_range):
    start, stop, step = map(int, request_rate_range.split(","))
    return list(range(start, stop, step))
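
# Example (values illustrative): the default "2,34,2" expands to
# [2, 4, 6, ..., 32]; as with range(), the stop value is exclusive.
#
#     parse_request_rate_range("2,34,2")  # -> [2, 4, 6, ..., 32]
#     parse_request_rate_range("1,5,1")   # -> [1, 2, 3, 4]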
def fire(args: argparse.Namespace):
    random.seed(args.seed)
    np.random.seed(args.seed)
@@ -581,6 +686,7 @@ def fire(args: argparse.Namespace):
"sglang": 30000,
"lmdeploy": 23333,
"vllm": 8000,
"trt": 8000,
}.get(args.backend, 30000)
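    # e.g. --backend trt resolves to port 8000 above, while an unrecognized
    # backend falls back to the 30000 default supplied to .get()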
    api_url = (
@@ -594,6 +700,16 @@ def fire(args: argparse.Namespace):
else f"http://{args.host}:{args.port}/v1/models"
)
if args.backend == "trt":
api_url = (
f"{args.base_url}/v2/models/ensemble/generate_stream"
if args.base_url
else f"http://{args.host}:{args.port}/v2/models/ensemble/generate_stream"
)
if args.model is None:
print("Please provide a model using `--model` when using `trt` backend.")
sys.exit(1)
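    # For example (hypothetical host/port), --host 127.0.0.1 --port 8000 makes
    # the TRT endpoint resolve to:
    #     http://127.0.0.1:8000/v2/models/ensemble/generate_stream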
    if args.model is None:
        try:
            response = requests.get(model_url)
@@ -637,17 +753,35 @@ def fire(args: argparse.Namespace):
    else:
        raise ValueError(f"Unknown dataset: {args.dataset_name}")
    if args.multi:
        request_rates = parse_request_rate_range(args.request_rate_range)
        for rate in request_rates:
            asyncio.run(
                benchmark(
                    backend=backend,
                    api_url=api_url,
                    model_id=model_id,
                    tokenizer=tokenizer,
                    input_requests=input_requests,
                    request_rate=rate,
                    disable_tqdm=args.disable_tqdm,
                    enable_multi=args.multi,
                )
            )
    else:
        asyncio.run(
            benchmark(
                backend=backend,
                api_url=api_url,
                model_id=model_id,
                tokenizer=tokenizer,
                input_requests=input_requests,
                request_rate=args.request_rate,
                disable_tqdm=args.disable_tqdm,
                enable_multi=args.multi,
            )
        )
# to avoid relying on SGLang's components
@@ -751,6 +885,18 @@ if __name__ == "__main__":
action="store_true",
help="Specify to disable tqdm progress bar.",
)
    parser.add_argument(
        "--multi",
        action="store_true",
        help="Use a request rate range rather than a single value.",
    )
    parser.add_argument(
        "--request-rate-range",
        type=str,
        default="2,34,2",
        help="Range of request rates in the format start,stop,step. Default is 2,34,2.",
    )
    parser.add_argument("--output-file", type=str, help="Output JSONL file name.")
    set_ulimit()
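    # Example invocations (illustrative; script and model names are hypothetical):
    #
    #     python3 bench_serving.py --backend trt --model my-model
    #     python3 bench_serving.py --backend sglang --multi \
    #         --request-rate-range 2,34,2 --output-file results.jsonl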