Unverified commit 175afed3, authored by Lianmin Zheng and committed by GitHub

Improve benchmark scripts (#1672)

parent 4a292f67
@@ -6,6 +6,8 @@ It accepts arguments similar to those of launch_server.py.
 Usage:
 python3 -m sglang.bench_server_latency --model meta-llama/Meta-Llama-3.1-8B --batch-size 1 16 64 --input-len 1024 --output-len 8
+python3 -m sglang.bench_server_latency --model None --base-url http://localhost:30000 --batch-size 16 --input-len 1024 --output-len 8
 """

 import argparse
@@ -15,7 +17,7 @@ import json
 import multiprocessing
 import os
 import time
-from typing import Tuple
+from typing import Optional, Tuple

 import numpy as np
 import requests
@@ -32,6 +34,8 @@ class BenchArgs:
     input_len: Tuple[int] = (1024,)
     output_len: Tuple[int] = (16,)
     result_filename: str = "result.jsonl"
+    base_url: str = ""
+    skip_warmup: bool = False

     @staticmethod
     def add_cli_args(parser: argparse.ArgumentParser):
@@ -48,6 +52,8 @@ class BenchArgs:
         parser.add_argument(
             "--result-filename", type=str, default=BenchArgs.result_filename
         )
+        parser.add_argument("--base-url", type=str, default=BenchArgs.base_url)
+        parser.add_argument("--skip-warmup", action="store_true")

     @classmethod
     def from_cli_args(cls, args: argparse.Namespace):
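The two new options follow the script's existing pattern in which every CLI flag mirrors a BenchArgs field: argparse maps --base-url to the dest base_url, which matches the dataclass field declared above, so from_cli_args can copy parsed values over generically. Below is a minimal self-contained sketch of that pattern; the real from_cli_args body is not shown in this diff, so its implementation here is an assumption, and MiniArgs is a hypothetical stand-in for BenchArgs.

import argparse
import dataclasses


@dataclasses.dataclass
class MiniArgs:  # hypothetical stand-in for BenchArgs, trimmed to the two new fields
    base_url: str = ""
    skip_warmup: bool = False

    @staticmethod
    def add_cli_args(parser: argparse.ArgumentParser):
        parser.add_argument("--base-url", type=str, default=MiniArgs.base_url)
        parser.add_argument("--skip-warmup", action="store_true")

    @classmethod
    def from_cli_args(cls, args: argparse.Namespace):
        # Copy every attribute whose name matches a dataclass field, so adding
        # a flag only requires a new field plus a matching add_argument call.
        attrs = [field.name for field in dataclasses.fields(cls)]
        return cls(**{attr: getattr(args, attr) for attr in attrs})


parser = argparse.ArgumentParser()
MiniArgs.add_cli_args(parser)
print(MiniArgs.from_cli_args(parser.parse_args(["--base-url", "http://localhost:30000"])))
# MiniArgs(base_url='http://localhost:30000', skip_warmup=False)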
@@ -139,9 +145,13 @@ def run_one_case(

 def run_benchmark(server_args: ServerArgs, bench_args: BenchArgs):
-    proc, base_url = launch_server_process(server_args)
+    if bench_args.base_url:
+        proc, base_url = None, bench_args.base_url
+    else:
+        proc, base_url = launch_server_process(server_args)

     # warmup
-    run_one_case(
-        base_url,
-        batch_size=16,
+    if not bench_args.skip_warmup:
+        run_one_case(
+            base_url,
+            batch_size=16,
@@ -165,6 +175,7 @@ def run_benchmark(server_args: ServerArgs, bench_args: BenchArgs):
             bench_args.result_filename,
         )
     finally:
-        kill_child_process(proc.pid)
+        if proc:
+            kill_child_process(proc.pid)

     print(f"\nResults are saved to {bench_args.result_filename}")
...
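Taken together, --base-url and --skip-warmup let the latency benchmark reuse a server that is already running instead of launching (and later killing) one per invocation. For example, combining the usage string with the new flags (this assumes a server is already listening on localhost:30000):

python3 -m sglang.bench_server_latency --model None --base-url http://localhost:30000 --batch-size 1 16 64 --input-len 1024 --output-len 8 --skip-warmup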
@@ -222,6 +222,85 @@ async def async_request_openai_completions(
     return output


+async def async_request_sglang_generate(
+    request_func_input: RequestFuncInput,
+    pbar: Optional[tqdm] = None,
+) -> RequestFuncOutput:
+    api_url = request_func_input.api_url
+    prompt = request_func_input.prompt
+
+    async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session:
+        payload = {
+            "text": prompt,
+            "sampling_params": {
+                "temperature": 0.0,
+                "max_new_tokens": request_func_input.output_len,
+                "ignore_eos": not args.disable_ignore_eos,
+            },
+            "stream": not args.disable_stream,
+            **request_func_input.extra_request_body,
+        }
+        headers = {}
+
+        output = RequestFuncOutput()
+        output.prompt_len = request_func_input.prompt_len
+
+        generated_text = ""
+        ttft = 0.0
+        st = time.perf_counter()
+        most_recent_timestamp = st
+        try:
+            async with session.post(
+                url=api_url, json=payload, headers=headers
+            ) as response:
+                if response.status == 200:
+                    async for chunk_bytes in response.content:
+                        chunk_bytes = chunk_bytes.strip()
+                        if not chunk_bytes:
+                            continue
+                        # print(chunk_bytes)
+
+                        chunk = remove_prefix(chunk_bytes.decode("utf-8"), "data: ")
+                        latency = time.perf_counter() - st
+                        if chunk == "[DONE]":
+                            pass
+                        else:
+                            data = json.loads(chunk)
+
+                            # NOTE: Some completion APIs might have a last
+                            # usage summary response without a token, so we
+                            # want to check that a token was generated
+                            if data["text"]:
+                                timestamp = time.perf_counter()
+                                # First token
+                                if ttft == 0.0:
+                                    ttft = time.perf_counter() - st
+                                    output.ttft = ttft
+
+                                # Decoding phase
+                                else:
+                                    output.itl.append(timestamp - most_recent_timestamp)
+
+                                most_recent_timestamp = timestamp
+                                generated_text = data["text"]
+
+                    output.generated_text = generated_text
+                    output.success = True
+                    output.latency = latency
+                    output.output_len = request_func_input.output_len
+                else:
+                    output.error = response.reason or ""
+                    output.success = False
+        except Exception:
+            output.success = False
+            exc_info = sys.exc_info()
+            output.error = "".join(traceback.format_exception(*exc_info))
+
+    if pbar:
+        pbar.update(1)
+    return output
+
+
 async def async_request_gserver(
     request_func_input: RequestFuncInput,
     pbar: Optional[tqdm] = None,
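The new request function targets the server's native /generate route and parses a server-sent-events style stream in which each chunk line is prefixed with "data: " and the stream ends with "data: [DONE]". Below is a minimal synchronous sketch of that exchange; the endpoint URL, prompt, and sampling values are illustrative assumptions, and only the payload keys and the stream framing come from the code above.

import json

import requests

# Hypothetical local server; the payload mirrors the keys built in
# async_request_sglang_generate above.
payload = {
    "text": "The capital of France is",
    "sampling_params": {"temperature": 0.0, "max_new_tokens": 8, "ignore_eos": True},
    "stream": True,
}
with requests.post("http://localhost:30000/generate", json=payload, stream=True) as resp:
    for line in resp.iter_lines(decode_unicode=True):
        if not line:
            continue  # skip blank separator lines between events
        chunk = line[len("data: "):] if line.startswith("data: ") else line
        if chunk == "[DONE]":
            break
        # Each streamed chunk carries the cumulative generated text under "text".
        print(json.loads(chunk)["text"])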
@@ -264,7 +343,9 @@ def get_tokenizer(

 ASYNC_REQUEST_FUNCS = {
-    "sglang": async_request_openai_completions,
+    "sglang": async_request_sglang_generate,
+    "sglang-native": async_request_sglang_generate,
+    "sglang-oai": async_request_openai_completions,
     "vllm": async_request_openai_completions,
     "lmdeploy": async_request_openai_completions,
     "trt": async_request_trt_llm,
@@ -387,6 +468,8 @@ def sample_sharegpt_requests(
             continue
         filtered_dataset.append((prompt, prompt_len, output_len))

+    print(f"#Input tokens: {np.sum([x[1] for x in filtered_dataset])}")
+    print(f"#Output tokens: {np.sum([x[2] for x in filtered_dataset])}")
     return filtered_dataset
@@ -784,24 +867,33 @@ def run_benchmark(args_: argparse.Namespace):
     if args.port is None:
         args.port = {
             "sglang": 30000,
+            "sglang-native": 30000,
+            "sglang-oai": 30000,
             "lmdeploy": 23333,
             "vllm": 8000,
             "trt": 8000,
             "gserver": 9988,
         }.get(args.backend, 30000)

-    api_url = (
-        f"{args.base_url}/v1/completions"
-        if args.base_url
-        else f"http://{args.host}:{args.port}/v1/completions"
-    )
     model_url = (
         f"{args.base_url}/v1/models"
         if args.base_url
         else f"http://{args.host}:{args.port}/v1/models"
     )

-    if args.backend == "trt":
+    if args.backend in ["sglang", "sglang-native"]:
+        api_url = (
+            f"{args.base_url}/generate"
+            if args.base_url
+            else f"http://{args.host}:{args.port}/generate"
+        )
+    elif args.backend in ["sglang-oai", "vllm", "lmdeploy"]:
+        api_url = (
+            f"{args.base_url}/v1/completions"
+            if args.base_url
+            else f"http://{args.host}:{args.port}/v1/completions"
+        )
+    elif args.backend == "trt":
         api_url = (
             f"{args.base_url}/v2/models/ensemble/generate_stream"
             if args.base_url
...
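If the second file in this commit is sglang's serving benchmark script (the module path below is an assumption; the backend names come from the dispatch code above), then --backend sglang and its alias sglang-native now exercise the native /generate endpoint, while --backend sglang-oai keeps using the OpenAI-compatible /v1/completions route:

# native /generate endpoint (new behavior of --backend sglang)
python3 -m sglang.bench_serving --backend sglang

# OpenAI-compatible /v1/completions endpoint
python3 -m sglang.bench_serving --backend sglang-oai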