Unverified Commit 175afed3 authored by Lianmin Zheng's avatar Lianmin Zheng Committed by GitHub
Browse files

Improve benchmark scripts (#1672)

parent 4a292f67
......@@ -6,6 +6,8 @@ It accepts arguments similar to those of launch_server.py.
Usage:
python3 -m sglang.bench_server_latency --model meta-llama/Meta-Llama-3.1-8B --batch-size 1 16 64 --input-len 1024 --output-len 8
python3 -m sglang.bench_server_latency --model None --base-url http://localhost:30000 --batch-size 16 --input-len 1024 --output-len 8
"""
import argparse
......@@ -15,7 +17,7 @@ import json
import multiprocessing
import os
import time
from typing import Tuple
from typing import Optional, Tuple
import numpy as np
import requests
......@@ -32,6 +34,8 @@ class BenchArgs:
input_len: Tuple[int] = (1024,)
output_len: Tuple[int] = (16,)
result_filename: str = "result.jsonl"
base_url: str = ""
skip_warmup: bool = False
@staticmethod
def add_cli_args(parser: argparse.ArgumentParser):
......@@ -48,6 +52,8 @@ class BenchArgs:
parser.add_argument(
"--result-filename", type=str, default=BenchArgs.result_filename
)
parser.add_argument("--base-url", type=str, default=BenchArgs.base_url)
parser.add_argument("--skip-warmup", action="store_true")
@classmethod
def from_cli_args(cls, args: argparse.Namespace):
......@@ -139,17 +145,21 @@ def run_one_case(
def run_benchmark(server_args: ServerArgs, bench_args: BenchArgs):
proc, base_url = launch_server_process(server_args)
if bench_args.base_url:
proc, base_url = None, bench_args.base_url
else:
proc, base_url = launch_server_process(server_args)
# warmup
run_one_case(
base_url,
batch_size=16,
input_len=1024,
output_len=16,
run_name="",
result_filename="",
)
if not bench_args.skip_warmup:
run_one_case(
base_url,
batch_size=16,
input_len=1024,
output_len=16,
run_name="",
result_filename="",
)
# benchmark
try:
......@@ -165,7 +175,8 @@ def run_benchmark(server_args: ServerArgs, bench_args: BenchArgs):
bench_args.result_filename,
)
finally:
kill_child_process(proc.pid)
if proc:
kill_child_process(proc.pid)
print(f"\nResults are saved to {bench_args.result_filename}")
......
......@@ -222,6 +222,85 @@ async def async_request_openai_completions(
return output
async def async_request_sglang_generate(
request_func_input: RequestFuncInput,
pbar: Optional[tqdm] = None,
) -> RequestFuncOutput:
api_url = request_func_input.api_url
prompt = request_func_input.prompt
async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session:
payload = {
"text": prompt,
"sampling_params": {
"temperature": 0.0,
"max_new_tokens": request_func_input.output_len,
"ignore_eos": not args.disable_ignore_eos,
},
"stream": not args.disable_stream,
**request_func_input.extra_request_body,
}
headers = {}
output = RequestFuncOutput()
output.prompt_len = request_func_input.prompt_len
generated_text = ""
ttft = 0.0
st = time.perf_counter()
most_recent_timestamp = st
try:
async with session.post(
url=api_url, json=payload, headers=headers
) as response:
if response.status == 200:
async for chunk_bytes in response.content:
chunk_bytes = chunk_bytes.strip()
if not chunk_bytes:
continue
# print(chunk_bytes)
chunk = remove_prefix(chunk_bytes.decode("utf-8"), "data: ")
latency = time.perf_counter() - st
if chunk == "[DONE]":
pass
else:
data = json.loads(chunk)
# NOTE: Some completion API might have a last
# usage summary response without a token so we
# want to check a token was generated
if data["text"]:
timestamp = time.perf_counter()
# First token
if ttft == 0.0:
ttft = time.perf_counter() - st
output.ttft = ttft
# Decoding phase
else:
output.itl.append(timestamp - most_recent_timestamp)
most_recent_timestamp = timestamp
generated_text = data["text"]
output.generated_text = generated_text
output.success = True
output.latency = latency
output.output_len = request_func_input.output_len
else:
output.error = response.reason or ""
output.success = False
except Exception:
output.success = False
exc_info = sys.exc_info()
output.error = "".join(traceback.format_exception(*exc_info))
if pbar:
pbar.update(1)
return output
async def async_request_gserver(
request_func_input: RequestFuncInput,
pbar: Optional[tqdm] = None,
......@@ -264,7 +343,9 @@ def get_tokenizer(
ASYNC_REQUEST_FUNCS = {
"sglang": async_request_openai_completions,
"sglang": async_request_sglang_generate,
"sglang-native": async_request_sglang_generate,
"sglang-oai": async_request_openai_completions,
"vllm": async_request_openai_completions,
"lmdeploy": async_request_openai_completions,
"trt": async_request_trt_llm,
......@@ -387,6 +468,8 @@ def sample_sharegpt_requests(
continue
filtered_dataset.append((prompt, prompt_len, output_len))
print(f"#Input tokens: {np.sum([x[1] for x in filtered_dataset])}")
print(f"#Output tokens: {np.sum([x[2] for x in filtered_dataset])}")
return filtered_dataset
......@@ -784,24 +867,33 @@ def run_benchmark(args_: argparse.Namespace):
if args.port is None:
args.port = {
"sglang": 30000,
"sglang-native": 30000,
"sglang-oai": 30000,
"lmdeploy": 23333,
"vllm": 8000,
"trt": 8000,
"gserver": 9988,
}.get(args.backend, 30000)
api_url = (
f"{args.base_url}/v1/completions"
if args.base_url
else f"http://{args.host}:{args.port}/v1/completions"
)
model_url = (
f"{args.base_url}/v1/models"
if args.base_url
else f"http://{args.host}:{args.port}/v1/models"
)
if args.backend == "trt":
if args.backend in ["sglang", "sglang-native"]:
api_url = (
f"{args.base_url}/generate"
if args.base_url
else f"http://{args.host}:{args.port}/generate"
)
elif args.backend in ["sglang-oai", "vllm", "lmdeploy"]:
api_url = (
f"{args.base_url}/v1/completions"
if args.base_url
else f"http://{args.host}:{args.port}/v1/completions"
)
elif args.backend == "trt":
api_url = (
f"{args.base_url}/v2/models/ensemble/generate_stream"
if args.base_url
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment