Unverified commit 4b4a67f8 authored by zhyncs, committed by GitHub

feat: support TRT LLM benchmark and multiple benchmarks (#670)

parent 0ac94c36
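
This adds a `trt` backend that benchmarks TensorRT-LLM through Triton's `generate_stream` endpoint, plus a `--multi` mode that sweeps a range of request rates and appends the per-rate metrics to a JSONL file.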
@@ -19,6 +19,7 @@ import traceback
import warnings
from argparse import ArgumentParser as FlexibleArgumentParser
from dataclasses import dataclass, field
from datetime import datetime
from typing import AsyncGenerator, List, Optional, Tuple, Union

import aiohttp
@@ -59,6 +60,72 @@ def remove_prefix(text: str, prefix: str) -> str:
    return text[len(prefix) :] if text.startswith(prefix) else text

# TRT LLM does not support ignore_eos
# https://github.com/triton-inference-server/tensorrtllm_backend/issues/505
async def async_request_trt_llm(
    request_func_input: RequestFuncInput,
    pbar: Optional[tqdm] = None,
) -> RequestFuncOutput:
    api_url = request_func_input.api_url
    assert api_url.endswith("generate_stream")

    async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session:
        assert not request_func_input.use_beam_search
        assert request_func_input.best_of == 1
        payload = {
            "accumulate_tokens": True,
            "text_input": request_func_input.prompt,
            "temperature": 0.0,
            "top_p": 1.0,
            "max_tokens": request_func_input.output_len,
            "stream": True,
        }
        output = RequestFuncOutput()
        output.prompt_len = request_func_input.prompt_len

        ttft = 0.0
        st = time.perf_counter()
        most_recent_timestamp = st
        try:
            async with session.post(url=api_url, json=payload) as response:
                if response.status == 200:
                    async for chunk_bytes in response.content:
                        chunk_bytes = chunk_bytes.strip()
                        if not chunk_bytes:
                            continue

                        chunk = remove_prefix(chunk_bytes.decode("utf-8"), "data:")
                        data = json.loads(chunk)
                        output.generated_text += data["text_output"]
                        timestamp = time.perf_counter()
                        # First token
                        if ttft == 0.0:
                            ttft = time.perf_counter() - st
                            output.ttft = ttft
                        # Decoding phase
                        else:
                            output.itl.append(timestamp - most_recent_timestamp)

                        most_recent_timestamp = timestamp

                    output.latency = most_recent_timestamp - st
                    output.success = True
                else:
                    output.error = response.reason or ""
                    output.success = False
        except Exception:
            output.success = False
            exc_info = sys.exc_info()
            output.error = "".join(traceback.format_exception(*exc_info))

    if pbar:
        pbar.update(1)
    return output
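
A minimal sketch of exercising this request function on its own; the `RequestFuncInput` fields are assumed from the other backends in this file, and the endpoint and model name are illustrative:

import asyncio

# Sketch only: one streaming request against a local Triton
# TensorRT-LLM server, printing the collected metrics afterward.
sample = RequestFuncInput(
    prompt="Hello, world",
    api_url="http://localhost:8000/v2/models/ensemble/generate_stream",
    prompt_len=4,
    output_len=32,
    model="ensemble",  # illustrative; the trt path requires --model
    best_of=1,
    use_beam_search=False,
)
result = asyncio.run(async_request_trt_llm(sample))
print(result.success, result.ttft, result.latency, len(result.itl))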

# set ignore_eos True by default
async def async_request_openai_completions(
    request_func_input: RequestFuncInput,
@@ -167,6 +234,7 @@ ASYNC_REQUEST_FUNCS = {
    "sglang": async_request_openai_completions,
    "vllm": async_request_openai_completions,
    "lmdeploy": async_request_openai_completions,
    "trt": async_request_trt_llm,
}
@@ -449,6 +517,7 @@ async def benchmark(
    input_requests: List[Tuple[str, int, int]],
    request_rate: float,
    disable_tqdm: bool,
    enable_multi: bool,
):
    if backend in ASYNC_REQUEST_FUNCS:
        request_func = ASYNC_REQUEST_FUNCS[backend]
@@ -542,6 +611,37 @@ async def benchmark(
    print("{:<40} {:<10.2f}".format("P99 ITL (ms):", metrics.p99_itl_ms))
    print("=" * 50)

    if enable_multi:
        if (
            metrics.median_ttft_ms is not None
            and metrics.median_itl_ms is not None
            and metrics.output_throughput is not None
        ):
            result = {
                "dataset_name": args.dataset_name,
                "request_rate": request_rate,
                "median_ttft": metrics.median_ttft_ms,
                "median_itl": metrics.median_itl_ms,
                "output_token_throughput": metrics.output_throughput,
                "sharegpt_output_len": args.sharegpt_output_len,
                "random_input_len": args.random_input_len,
                "random_output_len": args.random_output_len,
            }

            # Determine the output file name
            if args.output_file:
                output_file_name = args.output_file
            else:
                now = datetime.now().strftime("%m%d%H")
                output_file_name = f"{args.backend}_{now}.jsonl"

            # Append the result to a JSONL file; writing only in the success
            # branch avoids referencing `result` before assignment.
            with open(output_file_name, "a") as file:
                file.write(json.dumps(result) + "\n")
        else:
            print(f"Error running benchmark for request rate: {request_rate}")
            print("-" * 30)

    result = {
        "duration": benchmark_duration,
        "completed": metrics.completed,
@@ -572,6 +672,11 @@ async def benchmark(
    return result
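
In `--multi` mode each run appends one JSON object per request rate, so a sweep can be reloaded later; a small readback sketch (the file name is assumed, following the default `{backend}_{MMDDHH}.jsonl` pattern):

import json

# File name assumed: backend "sglang" run on 07/12 at 16:00
# would produce "sglang_071216.jsonl" under the default pattern.
with open("sglang_071216.jsonl") as f:
    rows = [json.loads(line) for line in f]

rates = [r["request_rate"] for r in rows]
ttfts = [r["median_ttft"] for r in rows]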

def parse_request_rate_range(request_rate_range):
    start, stop, step = map(int, request_rate_range.split(","))
    return list(range(start, stop, step))
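
The helper has `range()` semantics, so the stop value is exclusive:

# The default "2,34,2" expands to [2, 4, ..., 32].
assert parse_request_rate_range("2,34,2") == list(range(2, 34, 2))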

def fire(args: argparse.Namespace):
    random.seed(args.seed)
    np.random.seed(args.seed)
@@ -581,6 +686,7 @@ def fire(args: argparse.Namespace):
        "sglang": 30000,
        "lmdeploy": 23333,
        "vllm": 8000,
        "trt": 8000,
    }.get(args.backend, 30000)

    api_url = (
@@ -594,6 +700,16 @@ def fire(args: argparse.Namespace):
        else f"http://{args.host}:{args.port}/v1/models"
    )

    if args.backend == "trt":
        api_url = (
            f"{args.base_url}/v2/models/ensemble/generate_stream"
            if args.base_url
            else f"http://{args.host}:{args.port}/v2/models/ensemble/generate_stream"
        )
        if args.model is None:
            print("Please provide a model via `--model` when using the `trt` backend.")
            sys.exit(1)

    if args.model is None:
        try:
            response = requests.get(model_url)
@@ -637,17 +753,35 @@ def fire(args: argparse.Namespace):
    else:
        raise ValueError(f"Unknown dataset: {args.dataset_name}")

    if args.multi:
        request_rates = parse_request_rate_range(args.request_rate_range)

        for rate in request_rates:
            asyncio.run(
                benchmark(
                    backend=backend,
                    api_url=api_url,
                    model_id=model_id,
                    tokenizer=tokenizer,
                    input_requests=input_requests,
                    request_rate=rate,
                    disable_tqdm=args.disable_tqdm,
                    enable_multi=args.multi,
                )
            )
    else:
        asyncio.run(
            benchmark(
                backend=backend,
                api_url=api_url,
                model_id=model_id,
                tokenizer=tokenizer,
                input_requests=input_requests,
                request_rate=args.request_rate,
                disable_tqdm=args.disable_tqdm,
                enable_multi=args.multi,
            )
        )

# to avoid relying on SGLang's components
@@ -751,6 +885,18 @@ if __name__ == "__main__":
        action="store_true",
        help="Specify to disable tqdm progress bar.",
    )
    parser.add_argument(
        "--multi",
        action="store_true",
        help="Use a range of request rates rather than a single value.",
    )
    parser.add_argument(
        "--request-rate-range",
        type=str,
        default="2,34,2",
        help="Range of request rates in the format start,stop,step. Default is 2,34,2.",
    )
    parser.add_argument("--output-file", type=str, help="Output JSONL file name.")
    set_ulimit()
...