Unverified Commit ebf69964 authored by min-xu-et, committed by GitHub

latency test enhancement - final part (#921)

parent 141e8c71
@@ -20,14 +20,16 @@ dependencies = [
]
[project.optional-dependencies]
srt = ["aiohttp", "fastapi", "hf_transfer", "huggingface_hub", "interegular", "jsonlines",
srt = ["aiohttp", "fastapi", "hf_transfer", "huggingface_hub", "interegular",
"packaging", "pillow", "psutil", "pydantic", "python-multipart",
"torch", "uvicorn", "uvloop", "zmq",
"vllm==0.5.3.post1", "outlines>=0.0.44"]
openai = ["openai>=1.0", "tiktoken"]
anthropic = ["anthropic>=0.20.0"]
litellm = ["litellm>=1.0.0"]
test = ["jsonlines", "matplotlib", "pandas"]
all = ["sglang[srt]", "sglang[openai]", "sglang[anthropic]", "sglang[litellm]"]
dev = ["sglang[all]", "sglang[test]"]
[project.urls]
"Homepage" = "https://github.com/sgl-project/sglang"
......
"""
Benchmark the latency of a given model. It accepts arguments similar to those of launch_server.py.
# Usage (latency test) with dummy weights:
# Usage (latency test)
## with dummy weights:
python -m sglang.bench_latency --model-path meta-llama/Meta-Llama-3-8B-Instruct --load-format dummy
## sweep through multiple data points and store (append) the results in a jsonl file:
python -m sglang.bench_latency --model-path meta-llama/Meta-Llama-3-8B-Instruct --batch 1 12 14 --input-len 256 512 --output-len 32 256 --result-filename out.jsonl
## make some changes, then store the results under a different run_name:
python -m sglang.bench_latency --model-path meta-llama/Meta-Llama-3-8B-Instruct --batch 1 12 14 --input-len 256 512 --output-len 32 256 --result-filename out.jsonl --run-name after
## plot the results as a series of lines:
python -m sglang.bench_latency --result-filename out.jsonl --graph-sql="select run_name, batch_size, prefill_throughput from results"
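## note: plotting only reads the result file, so --model-path can be omitted for this command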
# Usage (correctness test):
python -m sglang.bench_latency --model-path TinyLlama/TinyLlama-1.1B-Chat-v0.4 --correct
### Reference output (of the correctness test above, can be gpu dependent):
## Reference output (of the correctness test above; may be GPU-dependent):
prefill logits (first half) tensor([[-10.0312, -9.5000, 0.8936, ..., -4.9414, -3.2402, -3.3633],
[-10.0312, -9.5000, 0.8936, ..., -4.9414, -3.2402, -3.3633],
[ -9.1875, -10.2500, 2.7109, ..., -4.3359, -4.0664, -4.1328]],
@@ -28,13 +36,16 @@ I'm going to the park
import argparse
import dataclasses
import itertools
import logging
import multiprocessing
import os
import sqlite3
import time
from typing import Tuple
import jsonlines
import numpy as np
import pandas as pd
import torch
import torch.distributed as dist
@@ -49,26 +60,42 @@ from sglang.srt.utils import suppress_other_loggers
@dataclasses.dataclass
class BenchArgs:
run_name: str = "before"
batch_size: Tuple[int] = (1,)
input_len: int = 1024
output_len: int = 4
input_len: Tuple[int] = (1024,)
output_len: Tuple[int] = (4,)
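# batch_size, input_len and output_len are tuples; latency_test sweeps their full cross product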
result_filename: str = ""
correctness_test: bool = False
# This is only used for the correctness test
cut_len: int = 4
# Plotting args
graph_sql: str = (
"select run_name, batch_size, prefill_throughput from results where run_name='before'"
)
graph_filename: str = "out.png"
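# graph_sql must select exactly three columns (series label, x, y); see the assert in plot_latency_test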
@staticmethod
def add_cli_args(parser: argparse.ArgumentParser):
parser.add_argument("--run-name", type=str, default=BenchArgs.run_name)
parser.add_argument(
"--batch-size", type=int, nargs="+", default=BenchArgs.batch_size
)
parser.add_argument("--input-len", type=int, default=BenchArgs.input_len)
parser.add_argument("--output-len", type=int, default=BenchArgs.output_len)
parser.add_argument(
"--input-len", type=int, nargs="+", default=BenchArgs.input_len
)
parser.add_argument(
"--output-len", type=int, nargs="+", default=BenchArgs.output_len
)
parser.add_argument(
"--result-filename", type=str, default=BenchArgs.result_filename
)
parser.add_argument("--correctness-test", action="store_true")
parser.add_argument("--cut-len", type=int, default=BenchArgs.cut_len)
# graphing
parser.add_argument("--graph-sql", type=str, default=BenchArgs.graph_sql)
parser.add_argument(
"--graph-filename", type=str, default=BenchArgs.graph_filename
)
@classmethod
def from_cli_args(cls, args: argparse.Namespace):
@@ -222,15 +249,21 @@ def correctness_test(
@torch.inference_mode()
def latency_test_run_once(
model_runner, rank_print, reqs, batch_size, input_len, output_len
run_name, model_runner, rank_print, reqs, batch_size, input_len, output_len
):
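# Skip configurations whose token demand, batch_size * (input_len + output_len),
# exceeds the KV cache capacity (max_total_num_tokens).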
max_batch_size = model_runner.max_total_num_tokens // (input_len + output_len)
if batch_size > max_batch_size:
rank_print(
f"skipping ({batch_size}, {input_len}, {output_len}) due to max batch size limit"
)
return
# Clear the pools.
model_runner.req_to_token_pool.clear()
model_runner.token_to_kv_pool.clear()
measurement_results = {
"run_name": "before",
"run_name": run_name,
"batch_size": batch_size,
"input_len": input_len,
"output_len": output_len,
@@ -291,49 +324,119 @@ def latency_test(
# Load the model
model_runner, tokenizer = load_model(server_args, tp_rank)
rank_print(
f"max_batch_size={model_runner.max_total_num_tokens // (bench_args.input_len + bench_args.output_len)}"
)
# To make this PR easier to review, for now only use the first element of the batch_size tuple.
bench_args.batch_size = bench_args.batch_size[0]
# Prepare inputs
# Prepare inputs for warm up
reqs = prepare_synthetic_inputs_for_latency_test(
bench_args.batch_size, bench_args.input_len
bench_args.batch_size[0], bench_args.input_len[0]
)
# Warm up
latency_test_run_once(
model_runner, rank_print, reqs, bench_args.batch_size, bench_args.input_len, 4
bench_args.run_name,
model_runner,
rank_print,
reqs,
bench_args.batch_size[0],
bench_args.input_len[0],
4, # shorter decoding to speed up the warmup
)
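# The warm-up return value is discarded; the run only initializes the runtime before the timed sweep.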
# Run again
# Run the sweep
result_list = []
result_list.append(
latency_test_run_once(
model_runner,
rank_print,
reqs,
bench_args.batch_size,
bench_args.input_len,
bench_args.output_len,
for bs, il, ol in itertools.product(
bench_args.batch_size, bench_args.input_len, bench_args.output_len
):
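# each (bs, il, ol) combination is benchmarked once; e.g. the docstring example
# (--batch 1 12 14 --input-len 256 512 --output-len 32 256) yields 3 * 2 * 2 = 12 runs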
reqs = prepare_synthetic_inputs_for_latency_test(bs, il)
ret = latency_test_run_once(
bench_args.run_name, model_runner, rank_print, reqs, bs, il, ol
)
)
if ret is not None:
result_list.append(ret)
# Write results in jsonlines format on rank 0.
if tp_rank == 0 and bench_args.result_filename:
import jsonlines
# Write results in jsonlines format.
if bench_args.result_filename:
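# each jsonl record holds the keys of measurement_results (run_name, batch_size,
# input_len, output_len and metrics such as prefill_throughput)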
with jsonlines.open(bench_args.result_filename, "a") as f:
f.write_all(result_list)
def plot_latency_test(
server_args,
bench_args,
tp_rank,
):
assert tp_rank == 0
# read the jsonl file and put in sqlite
df = pd.read_json(bench_args.result_filename, lines=True)
conn = sqlite3.connect(":memory:")
cur = conn.cursor()
# get the columns and their types
column_names = list(df.iloc[0].keys())
type_dict = {
str: "TEXT",
np.int64: "INTEGER",
np.float64: "FLOAT",
}
column_types = [type_dict[type(i)] for i in list(df.iloc[0])]
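# only str, int64 and float64 columns are mapped; any other dtype would raise a KeyError here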
# create the table
cur.execute(
f"""
CREATE TABLE IF NOT EXISTS results (
{", ".join([f"{name} {type}" for name, type in zip(column_names, column_types)])}
)
"""
)
conn.commit()
# write the results to DB
df.to_sql("results", conn, if_exists="replace", index=False)
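# note: to_sql with if_exists="replace" drops and recreates the table, so the
# CREATE TABLE above mainly documents the expected schema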
conn.commit()
# read it back using sql
df = pd.read_sql_query(bench_args.graph_sql, conn)
conn.close()
# plot it and save to a file
import matplotlib.pyplot as plt
assert (
len(df.columns) == 3
), f"The SQL must select exactly three columns <series, x, y>, not {df.columns}"
for label in df[df.columns[0]].unique():
q = f"{df.columns[0]}=='{label}'"
series = df.query(q)
plt.plot(series[df.columns[1]], series[df.columns[2]], label=q, marker="o")
plt.xlabel(df.columns[1])
plt.ylabel(df.columns[2])
plt.legend()
plt.savefig(bench_args.graph_filename, dpi=300)
# if in kitty, just dump it to the terminal
if os.environ["TERM"] == "xterm-kitty":
os.system(
f"kitty icat --use-window-size 1,1,600,600 {bench_args.graph_filename}"
)
def main(server_args, bench_args):
print(bench_args)
if bench_args.correctness_test:
work_func = correctness_test
if server_args.model_path:
if bench_args.correctness_test:
work_func = correctness_test
else:
work_func = latency_test
elif os.path.isfile(bench_args.result_filename):
assert bench_args.graph_filename, "please provide a filename for the graph"
work_func = plot_latency_test
else:
work_func = latency_test
raise ValueError(
"Provide --model-path for running the tests or "
"provide --result-filename for plotting the results"
)
if server_args.tp_size == 1:
work_func(server_args, bench_args, 0)
@@ -361,6 +464,11 @@ if __name__ == "__main__":
parser = argparse.ArgumentParser()
ServerArgs.add_cli_args(parser)
BenchArgs.add_cli_args(parser)
# For this script, --model-path is not required (it can be omitted when only plotting results)
assert (
parser._actions[1].option_strings[0] == "--model-path"
), "options changed, this code need to be updated"
parser._actions[1].required = False
args = parser.parse_args()
server_args = ServerArgs.from_cli_args(args)
......