Unverified Commit ebf69964 authored by min-xu-et, committed by GitHub

latency test enhancement - final part (#921)

parent 141e8c71
@@ -20,14 +20,16 @@ dependencies = [
]
[project.optional-dependencies]
-srt = ["aiohttp", "fastapi", "hf_transfer", "huggingface_hub", "interegular", "jsonlines",
+srt = ["aiohttp", "fastapi", "hf_transfer", "huggingface_hub", "interegular",
       "packaging", "pillow", "psutil", "pydantic", "python-multipart",
       "torch", "uvicorn", "uvloop", "zmq",
       "vllm==0.5.3.post1", "outlines>=0.0.44"]
openai = ["openai>=1.0", "tiktoken"]
anthropic = ["anthropic>=0.20.0"]
litellm = ["litellm>=1.0.0"]
+test = ["jsonlines", "matplotlib", "pandas"]
all = ["sglang[srt]", "sglang[openai]", "sglang[anthropic]", "sglang[litellm]"]
+dev = ["sglang[all]", "sglang[test]"]
[project.urls]
"Homepage" = "https://github.com/sgl-project/sglang"
...
""" """
Benchmark the latency of a given model. It accepts arguments similar to those of launch_server.py. Benchmark the latency of a given model. It accepts arguments similar to those of launch_server.py.
# Usage (latency test) with dummy weights: # Usage (latency test)
## with dummy weights:
python -m sglang.bench_latency --model-path meta-llama/Meta-Llama-3-8B-Instruct --load-format dummy python -m sglang.bench_latency --model-path meta-llama/Meta-Llama-3-8B-Instruct --load-format dummy
## sweep through multiple data points and store (append) the results in a jsonl file:
python -m sglang.bench_latency --model-path meta-llama/Meta-Llama-3-8B-Instruct --batch 1 12 14 --input-len 256 512 --output-len 32 256 --result-filename out.jsonl
## do some changes, and store the results under a different run_name:
python -m sglang.bench_latency --model-path meta-llama/Meta-Llama-3-8B-Instruct --batch 1 12 14 --input-len 256 512 --output-len 32 256 --result-filename out.jsonl --run-name after
## plot the results in series of lines:
python -m sglang.bench_latency --result-filename out.jsonl --graph-sql="select run_name, batch_size, prefill_throughput from results"
# Usage (correctness test): # Usage (correctness test):
python -m sglang.bench_latency --model-path TinyLlama/TinyLlama-1.1B-Chat-v0.4 --correct python -m sglang.bench_latency --model-path TinyLlama/TinyLlama-1.1B-Chat-v0.4 --correct
### Reference output (of the correctness test above, can be gpu dependent): ## Reference output (of the correctness test above, can be gpu dependent):
prefill logits (first half) tensor([[-10.0312, -9.5000, 0.8936, ..., -4.9414, -3.2402, -3.3633], prefill logits (first half) tensor([[-10.0312, -9.5000, 0.8936, ..., -4.9414, -3.2402, -3.3633],
[-10.0312, -9.5000, 0.8936, ..., -4.9414, -3.2402, -3.3633], [-10.0312, -9.5000, 0.8936, ..., -4.9414, -3.2402, -3.3633],
[ -9.1875, -10.2500, 2.7109, ..., -4.3359, -4.0664, -4.1328]], [ -9.1875, -10.2500, 2.7109, ..., -4.3359, -4.0664, -4.1328]],
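
For context: the sweep commands above append one JSON record per (batch_size, input_len, output_len) point to the jsonl file, and the plot command reads the same file back. A minimal sketch of that round trip, with made-up field values and only the keys visible in this diff:

# Sketch only: the values are invented; the real records are written by
# latency_test_run_once() and may contain additional measurement keys.
import jsonlines
import pandas as pd

records = [
    {"run_name": "before", "batch_size": 1, "input_len": 256, "output_len": 32,
     "prefill_throughput": 1234.5},
    {"run_name": "after", "batch_size": 1, "input_len": 256, "output_len": 32,
     "prefill_throughput": 1350.0},
]
with jsonlines.open("out.jsonl", mode="a") as f:
    f.write_all(records)

df = pd.read_json("out.jsonl", lines=True)  # what the plotting path loads
print(df)
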
@@ -28,13 +36,16 @@ I'm going to the park
import argparse
import dataclasses
+import itertools
import logging
import multiprocessing
+import os
+import sqlite3
import time
from typing import Tuple

+import jsonlines
import numpy as np
+import pandas as pd
import torch
import torch.distributed as dist
@@ -49,26 +60,42 @@ from sglang.srt.utils import suppress_other_loggers

@dataclasses.dataclass
class BenchArgs:
+    run_name: str = "before"
    batch_size: Tuple[int] = (1,)
-    input_len: int = 1024
-    output_len: int = 4
+    input_len: Tuple[int] = (1024,)
+    output_len: Tuple[int] = (4,)
    result_filename: str = ""
    correctness_test: bool = False
    # This is only used for correctness test
    cut_len: int = 4
+    # Plotting args
+    graph_sql: str = (
+        "select run_name, batch_size, prefill_throughput from results where run_name='before'"
+    )
+    graph_filename: str = "out.png"

    @staticmethod
    def add_cli_args(parser: argparse.ArgumentParser):
+        parser.add_argument("--run-name", type=str, default=BenchArgs.run_name)
        parser.add_argument(
            "--batch-size", type=int, nargs="+", default=BenchArgs.batch_size
        )
-        parser.add_argument("--input-len", type=int, default=BenchArgs.input_len)
-        parser.add_argument("--output-len", type=int, default=BenchArgs.output_len)
+        parser.add_argument(
+            "--input-len", type=int, nargs="+", default=BenchArgs.input_len
+        )
+        parser.add_argument(
+            "--output-len", type=int, nargs="+", default=BenchArgs.output_len
+        )
        parser.add_argument(
            "--result-filename", type=str, default=BenchArgs.result_filename
        )
        parser.add_argument("--correctness-test", action="store_true")
        parser.add_argument("--cut-len", type=int, default=BenchArgs.cut_len)
+        # graphing
+        parser.add_argument("--graph-sql", type=str, default=BenchArgs.graph_sql)
+        parser.add_argument(
+            "--graph-filename", type=str, default=BenchArgs.graph_filename
+        )

    @classmethod
    def from_cli_args(cls, args: argparse.Namespace):
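
The new --input-len and --output-len flags accept multiple values (nargs="+"), just like --batch-size already did. A quick illustrative parse, assuming the BenchArgs class defined above is in scope:

# Illustration only, assuming BenchArgs from this file.
import argparse

parser = argparse.ArgumentParser()
BenchArgs.add_cli_args(parser)
args = parser.parse_args(
    ["--batch-size", "1", "12", "14", "--input-len", "256", "512", "--output-len", "32", "256"]
)
# argparse returns lists of ints for nargs="+"; the dataclass defaults stay tuples.
print(args.batch_size, args.input_len, args.output_len)  # [1, 12, 14] [256, 512] [32, 256]
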
@@ -222,15 +249,21 @@ def correctness_test(

@torch.inference_mode()
def latency_test_run_once(
-    model_runner, rank_print, reqs, batch_size, input_len, output_len
+    run_name, model_runner, rank_print, reqs, batch_size, input_len, output_len
):
+    max_batch_size = model_runner.max_total_num_tokens // (input_len + output_len)
+    if batch_size > max_batch_size:
+        rank_print(
+            f"skipping ({batch_size}, {input_len}, {output_len}) due to max batch size limit"
+        )
+        return

    # Clear the pools.
    model_runner.req_to_token_pool.clear()
    model_runner.token_to_kv_pool.clear()

    measurement_results = {
-        "run_name": "before",
+        "run_name": run_name,
        "batch_size": batch_size,
        "input_len": input_len,
        "output_len": output_len,
@@ -291,49 +324,119 @@ def latency_test(

    # Load the model
    model_runner, tokenizer = load_model(server_args, tp_rank)
-    rank_print(
-        f"max_batch_size={model_runner.max_total_num_tokens // (bench_args.input_len + bench_args.output_len)}"
-    )
-    # To make this PR easier to review, for now, only do the first element in batch_size tuple.
-    bench_args.batch_size = bench_args.batch_size[0]

-    # Prepare inputs
+    # Prepare inputs for warm up
    reqs = prepare_synthetic_inputs_for_latency_test(
-        bench_args.batch_size, bench_args.input_len
+        bench_args.batch_size[0], bench_args.input_len[0]
    )

    # Warm up
    latency_test_run_once(
-        model_runner, rank_print, reqs, bench_args.batch_size, bench_args.input_len, 4
+        bench_args.run_name,
+        model_runner,
+        rank_print,
+        reqs,
+        bench_args.batch_size[0],
+        bench_args.input_len[0],
+        4,  # shorter decoding to speed up the warmup
    )

-    # Run again
+    # Run the sweep
    result_list = []
-    result_list.append(
-        latency_test_run_once(
-            model_runner,
-            rank_print,
-            reqs,
-            bench_args.batch_size,
-            bench_args.input_len,
-            bench_args.output_len,
-        )
-    )
+    for bs, il, ol in itertools.product(
+        bench_args.batch_size, bench_args.input_len, bench_args.output_len
+    ):
+        reqs = prepare_synthetic_inputs_for_latency_test(bs, il)
+        ret = latency_test_run_once(
+            bench_args.run_name, model_runner, rank_print, reqs, bs, il, ol
+        )
+        if ret is not None:
+            result_list.append(ret)

-    # Write results in jsonlines format on rank 0.
-    if tp_rank == 0 and bench_args.result_filename:
-        import jsonlines
+    # Write results in jsonlines format.
+    if bench_args.result_filename:
        with jsonlines.open(bench_args.result_filename, "a") as f:
            f.write_all(result_list)

+def plot_latency_test(
+    server_args,
+    bench_args,
+    tp_rank,
+):
+    assert tp_rank == 0
+
+    # read the jsonl file and put in sqlite
+    df = pd.read_json(bench_args.result_filename, lines=True)
+    conn = sqlite3.connect(":memory:")
+    cur = conn.cursor()
+
+    # get the columns and their types
+    column_names = list(df.iloc[0].keys())
+    type_dict = {
+        str: "TEXT",
+        np.int64: "INTEGER",
+        np.float64: "FLOAT",
+    }
+    column_types = [type_dict[type(i)] for i in list(df.iloc[0])]
+
+    # create the table
+    cur.execute(
+        f"""
+        CREATE TABLE IF NOT EXISTS results (
+            {", ".join([f"{name} {type}" for name, type in zip(column_names, column_types)])}
+        )
+        """
+    )
+    conn.commit()
+
+    # write the results to DB
+    df.to_sql("results", conn, if_exists="replace", index=False)
+    conn.commit()
+
+    # read it back using sql
+    df = pd.read_sql_query(bench_args.graph_sql, conn)
+    conn.close()
+
+    # plot it and save to a file
+    import matplotlib.pyplot as plt
+
+    assert (
+        len(df.columns) == 3
+    ), f"The sql should have fetched <series, x, y> columns, not {df.columns}"
+    for label in df[df.columns[0]].unique():
+        q = f"{df.columns[0]}=='{label}'"
+        series = df.query(q)
+        plt.plot(series[df.columns[1]], series[df.columns[2]], label=q, marker="o")
+    plt.xlabel(df.columns[1])
+    plt.ylabel(df.columns[2])
+    plt.legend()
+    plt.savefig(bench_args.graph_filename, dpi=300)
+
+    # if in kitty, just dump it to the terminal
+    if os.environ.get("TERM") == "xterm-kitty":
+        os.system(
+            f"kitty icat --use-window-size 1,1,600,600 {bench_args.graph_filename}"
+        )

def main(server_args, bench_args):
+    print(bench_args)
-    if bench_args.correctness_test:
-        work_func = correctness_test
+    if server_args.model_path:
+        if bench_args.correctness_test:
+            work_func = correctness_test
+        else:
+            work_func = latency_test
+    elif os.path.isfile(bench_args.result_filename):
+        assert bench_args.graph_filename, "please provide a filename for the graph"
+        work_func = plot_latency_test
    else:
-        work_func = latency_test
+        raise ValueError(
+            "Provide --model-path for running the tests or "
+            "provide --result-filename for plotting the results"
+        )

    if server_args.tp_size == 1:
        work_func(server_args, bench_args, 0)
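
The heart of plot_latency_test above is a round trip from the jsonl results into an in-memory SQLite table and back out through --graph-sql; since df.to_sql(..., if_exists="replace") recreates the table from the DataFrame, the hand-written CREATE TABLE mainly documents the intended column types. A condensed, self-contained sketch of the same flow with invented data:

# Condensed sketch of the jsonl -> sqlite -> SQL -> matplotlib path; the data
# and query are illustrative, not output from a real benchmark run.
import sqlite3

import matplotlib.pyplot as plt
import pandas as pd

df = pd.DataFrame({
    "run_name": ["before", "before", "after", "after"],
    "batch_size": [1, 12, 1, 12],
    "prefill_throughput": [1200.0, 9800.0, 1300.0, 10500.0],  # made-up numbers
})

conn = sqlite3.connect(":memory:")
df.to_sql("results", conn, if_exists="replace", index=False)
plot_df = pd.read_sql_query(
    "select run_name, batch_size, prefill_throughput from results", conn
)
conn.close()

# <series, x, y>: one line per run_name, batch_size on x, throughput on y.
for label in plot_df["run_name"].unique():
    series = plot_df[plot_df["run_name"] == label]
    plt.plot(series["batch_size"], series["prefill_throughput"], label=label, marker="o")
plt.xlabel("batch_size")
plt.ylabel("prefill_throughput")
plt.legend()
plt.savefig("out.png", dpi=300)
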
@@ -361,6 +464,11 @@ if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    ServerArgs.add_cli_args(parser)
    BenchArgs.add_cli_args(parser)
+    # For this script, model-path is not required
+    assert (
+        parser._actions[1].option_strings[0] == "--model-path"
+    ), "options changed, this code needs to be updated"
+    parser._actions[1].required = False
    args = parser.parse_args()
    server_args = ServerArgs.from_cli_args(args)
...
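
The assertion in the __main__ block pins --model-path to index 1 of parser._actions; an index-free variant (a sketch, not what the commit does) could look the option up by name instead:

# Sketch of an index-free alternative; it still touches argparse internals
# (parser._actions), just without assuming the option's position.
for action in parser._actions:
    if "--model-path" in action.option_strings:
        action.required = False
        break
else:
    raise RuntimeError("--model-path option not found; options changed?")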