Unverified Commit dfec7fca authored by Lianmin Zheng, committed by GitHub

Rename sglang.bench_latency to sglang.bench_one_batch (#2118)

parent 8048c28c
......@@ -118,7 +118,7 @@ jobs:
timeout-minutes: 10
run: |
cd test/srt
python3 -m unittest test_bench_latency.TestBenchLatency.test_default
python3 -m unittest test_bench_one_batch.TestBenchOneBatch.test_default
- name: Benchmark online latency
timeout-minutes: 10
......@@ -194,7 +194,7 @@ jobs:
timeout-minutes: 10
run: |
cd test/srt
python3 -m unittest test_bench_latency.TestBenchLatency.test_moe_default
python3 -m unittest test_bench_one_batch.TestBenchOneBatch.test_moe_default
accuracy-test-1-gpu:
if: github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request'
......
# Benchmark and Profiling
## Benchmark
- Benchmark a single static batch by running the following command without launching a server. The arguments are the same as for `launch_server.py`. Note that this is not a dynamic batching server, so it may run out of memory for a batch size that a real server can handle. A real server truncates the prefill into several batches, while this unit test does not. For accurate large batch testing, consider using `sglang.bench_serving`.
- Benchmark the latency of running a single static batch without a server. The arguments are the same as for `launch_server.py`.
Note that this is a simplified test script without a dynamic batching server, so it may run out of memory for a batch size that a real server can handle. A real server truncates the prefill into several batches, while this simplified script does not. Results for each tested configuration can be appended to a JSONL file via `--result-filename`; see the loading sketch after this list.
```
python -m sglang.bench_latency --model-path meta-llama/Meta-Llama-3-8B-Instruct --batch 32 --input-len 256 --output-len 32
python -m sglang.bench_one_batch --model-path meta-llama/Meta-Llama-3.1-8B-Instruct --batch 32 --input-len 256 --output-len 32
```
- Benchmark online serving. Launch a server first and run the following command.
- Benchmark offline processing. This script will start an offline engine and run the benchmark.
```
python3 -m sglang.bench_offline_throughput --model-path meta-llama/Meta-Llama-3.1-8B-Instruct --num-prompts 10
```
- Benchmark online serving. Please use `sglang.launch_server` to launch a server first and run the following command.
```
python3 -m sglang.bench_serving --backend sglang --num-prompts 10
```
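Each `sglang.bench_one_batch` run appends one JSON object per tested `(batch size, input length, output length)` combination to the file given by `--result-filename` (default `result.jsonl`). A minimal loading sketch, assuming `pandas` is installed and that the column names match the measurement keys written by the script:
```
import pandas as pd

# Load the JSON-lines results appended by sglang.bench_one_batch.
df = pd.read_json("result.jsonl", lines=True)

# Compare prefill and overall throughput across the swept configurations.
cols = ["run_name", "batch_size", "input_len", "output_len",
        "prefill_throughput", "overall_throughput"]
print(df[cols])
```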
......@@ -23,7 +28,7 @@ apt update
apt install nsight-systems-cli
```
1. To profile a single batch, use `nsys profile --trace-fork-before-exec=true --cuda-graph-trace=node python3 -m sglang.bench_latency --model meta-llama/Meta-Llama-3-8B --batch-size 64 --input-len 512`
1. To profile a single batch, use `nsys profile --trace-fork-before-exec=true --cuda-graph-trace=node python3 -m sglang.bench_one_batch --model meta-llama/Meta-Llama-3-8B --batch-size 64 --input-len 512`
2. To profile a server, e.g.
......@@ -33,7 +38,7 @@ apt install nsight-systems-cli
nsys profile --trace-fork-before-exec=true --cuda-graph-trace=node -o sglang.out --delay 60 --duration 70 python3 -m sglang.launch_server --model-path meta-llama/Llama-3.1-8B-Instruct --disable-radix-cache
# client
python3 -m sglang.bench_serving --backend sglang --num-prompts 6000 --dataset-name random --random-input 4096 --random-output 2048
python3 -m sglang.bench_serving --backend sglang --num-prompts 1000 --dataset-name random --random-input 1024 --random-output 512
```
3. Use NVTX, e.g.
......
......@@ -59,7 +59,7 @@ For interactive debugging, you can compare the outputs of huggingface/transforme
The following two commands should give the same text output and very similar prefill logits.
- Get the reference output by `python3 scripts/playground/reference_hf.py --model [new model]`
- Get the SGLang output by `python3 -m sglang.bench_latency --correct --model [new model]`
- Get the SGLang output by `python3 -m sglang.bench_one_batch --correct --model [new model]`
#### Add the model to the test suite
To ensure the new model stays well maintained, add it to the test suite.
......
......@@ -59,7 +59,7 @@ drun -p 30000:30000 \
python3 -m sglang.launch_server --model-path meta-llama/Llama-3.1-8B-Instruct --host 0.0.0.0 --port 30000
# Until the flashinfer backend is available, --attention-backend triton and --sampling-backend pytorch are set by default
drun v0.3.5.post2-rocm620 python3 -m sglang.bench_latency --batch-size 32 --input 1024 --output 128 --model amd/Meta-Llama-3.1-8B-Instruct-FP8-KV --tp 8 --quantization fp8
drun v0.3.5.post2-rocm620 python3 -m sglang.bench_one_batch --batch-size 32 --input 1024 --output 128 --model amd/Meta-Llama-3.1-8B-Instruct-FP8-KV --tp 8 --quantization fp8
```
## Method 4: Using docker compose
......
......@@ -16,10 +16,13 @@ classifiers = [
dependencies = ["requests", "tqdm", "numpy", "IPython"]
[project.optional-dependencies]
runtime_common = ["aiohttp", "decord", "fastapi", "hf_transfer", "huggingface_hub", "interegular",
"orjson", "packaging", "pillow", "prometheus-client>=0.20.0", "psutil", "pydantic", "python-multipart",
"torchao", "uvicorn", "uvloop", "pyzmq>=25.1.2",
"outlines>=0.0.44,<0.1.0", "modelscope"]
runtime_common = ["aiohttp", "decord", "fastapi",
"hf_transfer", "huggingface_hub", "interegular",
"orjson", "outlines>=0.0.44,<0.1.0",
"packaging", "pillow", "prometheus-client>=0.20.0",
"psutil", "pydantic", "python-multipart",
"pyzmq>=25.1.2", "torchao", "uvicorn", "uvloop",
"modelscope"]
srt = ["sglang[runtime_common]", "torch", "vllm>=0.6.3.post1"]
# HIP (Heterogeneous-computing Interface for Portability) for AMD
......
......@@ -4,9 +4,11 @@
- `srt`: The backend engine for running local models. (SRT = SGLang Runtime).
- `test`: The test utilities.
- `api.py`: The public APIs.
- `bench_latency.py`: Benchmark the latency of running a single static batch.
- `bench_server_latency.py`: Benchmark the latency of serving a single batch with a real server.
- `bench_offline_throughput.py`: Benchmark the throughput in the offline mode.
- `bench_one_batch.py`: Benchmark the latency of running a single static batch without a server.
- `bench_one_batch_server.py`: Benchmark the latency of running a single batch with a server.
- `bench_serving.py`: Benchmark online serving with dynamic requests.
- `check_env.py`: Check the environment variables.
- `global_config.py`: The global configs and constants.
- `launch_server.py`: The entry point for launching the local server.
- `utils.py`: Common utilities.
"""
Benchmark the latency of running a single static batch.
This script does not launch a server and uses the low-level APIs.
It accepts arguments similar to those of launch_server.py.
# Usage (latency test)
## with dummy weights:
python -m sglang.bench_latency --model-path meta-llama/Meta-Llama-3-8B-Instruct --load-format dummy
## sweep through multiple data points and store (append) the results in a jsonl file:
python -m sglang.bench_latency --model-path meta-llama/Meta-Llama-3-8B-Instruct --batch 1 12 14 --input-len 256 512 --output-len 32 256 --result-filename out.jsonl
## do some changes, and store the results under a different run_name:
python -m sglang.bench_latency --model-path meta-llama/Meta-Llama-3-8B-Instruct --batch 1 12 14 --input-len 256 512 --output-len 32 256 --result-filename out.jsonl --run-name after
## plot the results in series of lines:
python -m sglang.bench_latency --result-filename out.jsonl --graph-sql="select run_name, batch_size, prefill_throughput from results"
# Usage (correctness test):
python -m sglang.bench_latency --model-path TinyLlama/TinyLlama-1.1B-Chat-v0.4 --correct
## Reference output (of the correctness test above, can be gpu dependent):
input_ids=[[1, 450, 7483, 310, 3444, 338], [1, 450, 7483, 310, 278, 3303, 13187, 290, 338], [1, 20628, 338, 263, 6575, 1460, 2462, 322, 306, 763]]
prefill logits (first half): tensor([[-10.0312, -9.5000, 0.8931, ..., -4.9414, -3.2422, -3.3633],
[-10.0312, -9.5000, 0.8931, ..., -4.9414, -3.2422, -3.3633],
[ -9.1875, -10.2500, 2.7129, ..., -4.3359, -4.0664, -4.1328]],
device='cuda:0')
prefill logits (final): tensor([[-8.3125, -7.1172, 3.3457, ..., -4.9570, -4.1328, -3.4141],
[-8.9141, -9.0156, 4.1445, ..., -4.9922, -4.4961, -4.0781],
[-9.6328, -9.0547, 4.0195, ..., -5.3047, -4.7148, -4.4570]],
device='cuda:0')
========== Prompt 0 ==========
<s> The capital of France is Paris.
The capital of the United States is Washington, D.C.
========== Prompt 1 ==========
<s> The capital of the United Kindom is London.
The capital of the United Kingdom is London.
The capital of the
========== Prompt 2 ==========
<s> Today is a sunny day and I like to go for a walk in the park.
I'm going to the park
"""
import argparse
import dataclasses
import itertools
import json
import logging
import multiprocessing
import os
import sqlite3
import time
from typing import Tuple
import numpy as np
import pandas as pd
import torch
import torch.distributed as dist
from sglang.srt.configs.model_config import ModelConfig
from sglang.srt.hf_transformers_utils import get_tokenizer
from sglang.srt.managers.schedule_batch import Req, ScheduleBatch
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.model_executor.model_runner import ModelRunner
from sglang.srt.sampling.sampling_params import SamplingParams
from sglang.srt.server import _set_envs_and_config
from sglang.srt.server_args import PortArgs, ServerArgs
from sglang.srt.utils import (
configure_logger,
kill_child_process,
suppress_other_loggers,
)
@dataclasses.dataclass
class BenchArgs:
run_name: str = "before"
batch_size: Tuple[int] = (1,)
input_len: Tuple[int] = (1024,)
output_len: Tuple[int] = (16,)
result_filename: str = ""
correctness_test: bool = False
    # This is only used for the correctness test
cut_len: int = 4
# Plotting args
graph_sql: str = (
"select run_name, batch_size, prefill_throughput from results where run_name='before'"
)
graph_filename: str = "out.png"
@staticmethod
def add_cli_args(parser: argparse.ArgumentParser):
parser.add_argument("--run-name", type=str, default=BenchArgs.run_name)
parser.add_argument(
"--batch-size", type=int, nargs="+", default=BenchArgs.batch_size
)
parser.add_argument(
"--input-len", type=int, nargs="+", default=BenchArgs.input_len
)
parser.add_argument(
"--output-len", type=int, nargs="+", default=BenchArgs.output_len
)
parser.add_argument(
"--result-filename", type=str, default=BenchArgs.result_filename
)
parser.add_argument("--correctness-test", action="store_true")
parser.add_argument("--cut-len", type=int, default=BenchArgs.cut_len)
# graphing
parser.add_argument("--graph-sql", type=str, default=BenchArgs.graph_sql)
parser.add_argument(
"--graph-filename", type=str, default=BenchArgs.graph_filename
)
@classmethod
def from_cli_args(cls, args: argparse.Namespace):
        # Use the default value's type to cast the args into the correct types.
attrs = [(attr.name, type(attr.default)) for attr in dataclasses.fields(cls)]
return cls(
**{attr: attr_type(getattr(args, attr)) for attr, attr_type in attrs}
)
def load_model(server_args, port_args, tp_rank):
suppress_other_loggers()
rank_print = print if tp_rank == 0 else lambda *args, **kwargs: None
model_config = ModelConfig(
server_args.model_path,
trust_remote_code=server_args.trust_remote_code,
context_length=server_args.context_length,
model_override_args=server_args.json_model_override_args,
)
model_runner = ModelRunner(
model_config=model_config,
mem_fraction_static=server_args.mem_fraction_static,
gpu_id=tp_rank,
tp_rank=tp_rank,
tp_size=server_args.tp_size,
nccl_port=port_args.nccl_port,
server_args=server_args,
)
rank_print(f"max_total_num_tokens={model_runner.max_total_num_tokens}")
tokenizer = get_tokenizer(
server_args.tokenizer_path,
tokenizer_mode=server_args.tokenizer_mode,
trust_remote_code=server_args.trust_remote_code,
)
if server_args.tp_size > 1:
dist.barrier()
return model_runner, tokenizer
def prepare_inputs_for_correctness_test(bench_args, tokenizer):
prompts = [
"The capital of France is",
"The capital of the United Kindom is",
"Today is a sunny day and I like",
]
input_ids = [tokenizer.encode(p) for p in prompts]
sampling_params = SamplingParams(
temperature=0,
max_new_tokens=BenchArgs.output_len,
)
reqs = []
for i in range(len(prompts)):
assert len(input_ids[i]) > bench_args.cut_len
tmp_input_ids = input_ids[i][: bench_args.cut_len]
req = Req(
rid=i,
origin_input_text=prompts[i],
origin_input_ids=tmp_input_ids,
sampling_params=sampling_params,
)
req.prefix_indices = []
req.fill_ids = req.origin_input_ids
req.extend_input_len = len(req.fill_ids) - len(req.prefix_indices)
reqs.append(req)
return input_ids, reqs
def prepare_extend_inputs_for_correctness_test(
bench_args, input_ids, reqs, model_runner
):
for i in range(len(reqs)):
req = reqs[i]
req.fill_ids += input_ids[i][bench_args.cut_len :]
req.prefix_indices = model_runner.req_to_token_pool.req_to_token[
i, : bench_args.cut_len
]
req.extend_input_len = len(req.fill_ids) - len(req.prefix_indices)
return reqs
def prepare_synthetic_inputs_for_latency_test(batch_size, input_len):
input_ids = np.ones((batch_size, input_len), dtype=np.int32)
sampling_params = SamplingParams(
temperature=0,
max_new_tokens=BenchArgs.output_len,
)
reqs = []
for i in range(len(input_ids)):
req = Req(
rid=i,
origin_input_text="",
origin_input_ids=list(input_ids[i]),
sampling_params=sampling_params,
)
req.prefix_indices = []
req.fill_ids = req.origin_input_ids
req.extend_input_len = len(req.fill_ids) - len(req.prefix_indices)
reqs.append(req)
return reqs
@torch.no_grad
def extend(reqs, model_runner):
batch = ScheduleBatch.init_new(
reqs=reqs,
req_to_token_pool=model_runner.req_to_token_pool,
token_to_kv_pool=model_runner.token_to_kv_pool,
tree_cache=None,
model_config=model_runner.model_config,
)
batch.prepare_for_extend()
model_worker_batch = batch.get_model_worker_batch()
forward_batch = ForwardBatch.init_new(model_worker_batch, model_runner)
logits_output = model_runner.forward(forward_batch)
next_token_ids = model_runner.sample(logits_output, forward_batch)
return next_token_ids, logits_output.next_token_logits, batch
@torch.no_grad
def decode(input_token_ids, batch, model_runner):
batch.output_ids = input_token_ids
batch.prepare_for_decode()
model_worker_batch = batch.get_model_worker_batch()
forward_batch = ForwardBatch.init_new(model_worker_batch, model_runner)
logits_output = model_runner.forward(forward_batch)
next_token_ids = model_runner.sample(logits_output, forward_batch)
return next_token_ids, logits_output.next_token_logits
def correctness_test(
server_args,
port_args,
bench_args,
tp_rank,
):
configure_logger(server_args, prefix=f" TP{tp_rank}")
rank_print = print if tp_rank == 0 else lambda *args, **kwargs: None
# Load the model
model_runner, tokenizer = load_model(server_args, port_args, tp_rank)
# Prepare inputs
input_ids, reqs = prepare_inputs_for_correctness_test(bench_args, tokenizer)
rank_print(f"\n{input_ids=}\n")
if bench_args.cut_len > 0:
# Prefill
next_token_ids, next_token_logits, batch = extend(reqs, model_runner)
rank_print(f"prefill logits (first half): {next_token_logits} \n")
# Prepare extend inputs
reqs = prepare_extend_inputs_for_correctness_test(
bench_args, input_ids, reqs, model_runner
)
# Extend
next_token_ids, next_token_logits, batch = extend(reqs, model_runner)
rank_print(f"prefill logits (final): {next_token_logits} \n")
# Decode
output_ids = [input_ids[i] + [next_token_ids[i]] for i in range(len(input_ids))]
for _ in range(bench_args.output_len[0] - 1):
next_token_ids, _ = decode(next_token_ids, batch, model_runner)
next_token_ids_list = next_token_ids.tolist()
for i in range(len(reqs)):
output_ids[i].append(next_token_ids_list[i])
# Print
for i in range(len(reqs)):
rank_print(f"========== Prompt {i} ==========")
rank_print(tokenizer.decode(output_ids[i]), "\n")
def synchronize(device):
if device == "cuda":
torch.cuda.synchronize()
elif device == "xpu":
torch.xpu.synchronize()
def latency_test_run_once(
run_name, model_runner, rank_print, reqs, batch_size, input_len, output_len, device
):
max_batch_size = model_runner.max_total_num_tokens // (input_len + output_len)
if batch_size > max_batch_size:
rank_print(
f"skipping ({batch_size}, {input_len}, {output_len}) due to max batch size limit"
)
return
# Clear the pools.
model_runner.req_to_token_pool.clear()
model_runner.token_to_kv_pool.clear()
measurement_results = {
"run_name": run_name,
"batch_size": batch_size,
"input_len": input_len,
"output_len": output_len,
}
tot_latency = 0
# Prefill
synchronize(device)
tic = time.time()
next_token_ids, _, batch = extend(reqs, model_runner)
synchronize(device)
prefill_latency = time.time() - tic
tot_latency += prefill_latency
throughput = input_len * batch_size / prefill_latency
rank_print(
f"Prefill. latency: {prefill_latency:6.5f} s, throughput: {throughput:9.2f} token/s"
)
measurement_results["prefill_latency"] = prefill_latency
measurement_results["prefill_throughput"] = throughput
# Decode
decode_latencies = []
for i in range(output_len - 1):
synchronize(device)
tic = time.time()
next_token_ids, _ = decode(next_token_ids, batch, model_runner)
synchronize(device)
latency = time.time() - tic
tot_latency += latency
throughput = batch_size / latency
decode_latencies.append(latency)
if i < 5:
rank_print(
f"Decode. latency: {latency:6.5f} s, throughput: {throughput:9.2f} token/s"
)
# record decode timing from 2nd output
if output_len > 1:
med_decode_latency = np.median(decode_latencies)
med_decode_throughput = batch_size / med_decode_latency
rank_print(
f"Decode. median latency: {med_decode_latency:6.5f} s, median throughput: {med_decode_throughput:9.2f} token/s"
)
measurement_results["median_decode_latency"] = med_decode_latency
measurement_results["median_decode_throughput"] = med_decode_throughput
throughput = (input_len + output_len) * batch_size / tot_latency
rank_print(
f"Total. latency: {tot_latency:6.3f} s, throughput: {throughput:9.2f} token/s"
)
measurement_results["total_latency"] = tot_latency
measurement_results["total_throughput"] = throughput
return measurement_results
def latency_test(
server_args,
port_args,
bench_args,
tp_rank,
):
configure_logger(server_args, prefix=f" TP{tp_rank}")
rank_print = print if tp_rank == 0 else lambda *args, **kwargs: None
# Load the model
model_runner, tokenizer = load_model(server_args, port_args, tp_rank)
# Prepare inputs for warm up
reqs = prepare_synthetic_inputs_for_latency_test(
bench_args.batch_size[0], bench_args.input_len[0]
)
# Warm up
rank_print("Warmup ...")
latency_test_run_once(
bench_args.run_name,
model_runner,
rank_print,
reqs,
bench_args.batch_size[0],
bench_args.input_len[0],
8, # shorter decoding to speed up the warmup
server_args.device,
)
rank_print("Benchmark ...")
# Run the sweep
result_list = []
for bs, il, ol in itertools.product(
bench_args.batch_size, bench_args.input_len, bench_args.output_len
):
reqs = prepare_synthetic_inputs_for_latency_test(bs, il)
ret = latency_test_run_once(
bench_args.run_name,
model_runner,
rank_print,
reqs,
bs,
il,
ol,
server_args.device,
)
if ret is not None:
result_list.append(ret)
# Write results in jsonlines format on rank 0.
if tp_rank == 0 and bench_args.result_filename:
import jsonlines
with jsonlines.open(bench_args.result_filename, "a") as f:
f.write_all(result_list)
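# Load the JSONL results into an in-memory SQLite table named `results`,
# run the query from --graph-sql (expected to return <series, x, y> columns),
# and plot one line per series, saving the figure to --graph-filename.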
def plot_latency_test(
server_args,
bench_args,
tp_rank,
):
assert tp_rank == 0
# read the jsonl file and put in sqlite
df = pd.read_json(bench_args.result_filename, lines=True)
conn = sqlite3.connect(":memory:")
cur = conn.cursor()
# get the columns and their types
column_names = list(df.iloc[0].keys())
type_dict = {
str: "TEXT",
np.int64: "INTEGER",
np.float64: "FLOAT",
}
column_types = [type_dict[type(i)] for i in list(df.iloc[0])]
# create the table
cur.execute(
f"""
CREATE TABLE IF NOT EXISTS results (
{", ".join([f"{name} {type}" for name, type in zip(column_names, column_types)])}
)
"""
)
conn.commit()
# write the results to DB
df.to_sql("results", conn, if_exists="replace", index=False)
conn.commit()
# read it back using sql
df = pd.read_sql_query(bench_args.graph_sql, conn)
conn.close()
# plot it and save to a file
import matplotlib.pyplot as plt
assert (
len(df.columns) == 3
), f"The sql should have fetched <series, x, y> columns, not {df.columns}"
for label in df[df.columns[0]].unique():
q = f"{df.columns[0]}=='{label}'"
series = df.query(q)
plt.plot(series[df.columns[1]], series[df.columns[2]], label=q, marker="o")
plt.xlabel(df.columns[1])
plt.ylabel(df.columns[2])
plt.legend()
plt.savefig(bench_args.graph_filename, dpi=300)
# if in kitty, just dump it to the terminal
if os.environ["TERM"] == "xterm-kitty":
os.system(
f"kitty icat --use-window-size 1,1,600,600 {bench_args.graph_filename}"
)
def main(server_args, bench_args):
_set_envs_and_config(server_args)
if server_args.model_path:
if bench_args.correctness_test:
work_func = correctness_test
else:
work_func = latency_test
elif os.path.isfile(bench_args.result_filename):
assert bench_args.graph_filename, "please provide a filename for the graph"
work_func = plot_latency_test
else:
raise ValueError(
"Provide --model-path for running the tests or "
"provide --result-filename for plotting the results"
)
port_args = PortArgs.init_new(server_args)
if server_args.tp_size == 1:
work_func(server_args, port_args, bench_args, 0)
else:
workers = []
for tp_rank in range(server_args.tp_size):
proc = multiprocessing.Process(
target=work_func,
args=(
server_args,
port_args,
bench_args,
tp_rank,
),
)
proc.start()
workers.append(proc)
for proc in workers:
proc.join()
proc.terminate()
if __name__ == "__main__":
parser = argparse.ArgumentParser()
ServerArgs.add_cli_args(parser)
BenchArgs.add_cli_args(parser)
args = parser.parse_args()
server_args = ServerArgs.from_cli_args(args)
bench_args = BenchArgs.from_cli_args(args)
logging.basicConfig(
level=getattr(logging, server_args.log_level.upper()),
format="%(message)s",
)
try:
main(server_args, bench_args)
except Exception as e:
raise e
finally:
kill_child_process()
raise ValueError("bench_latency.py has been renamed to bench_one_batch.py")
"""
Benchmark the throughput of using the offline LLM engine.
This script does not launch a server.
Benchmark the throughput in the offline mode.
It accepts server arguments (the same as launch_server.py) and benchmark arguments (the same as bench_serving.py).
# Usage
## Sharegpt dataset with default args
python -m sglang.bench_offline_throughput --model-path meta-llama/Meta-Llama-3.1-8B-Instruct
python -m sglang.bench_offline_throughput --model-path meta-llama/Meta-Llama-3.1-8B-Instruct --num-prompts 10
## Random dataset with default args
python -m sglang.bench_offline_throughput --model-path meta-llama/Meta-Llama-3.1-8B-Instruct --dataset-name random
## Shared prefix dataset with default args
python -m sglang.bench_offline_throughput --model-path meta-llama/Meta-Llama-3.1-8B-Instruct --dataset-name generated-shared-prefix
## Sharegpt dataset on runtime backend
python -m sglang.bench_offline_throughput --model-path meta-llama/Meta-Llama-3.1-8B-Instruct --backend runtime
python -m sglang.bench_offline_throughput --model-path meta-llama/Meta-Llama-3.1-8B-Instruct --dataset-name random --random-input 1024 --random-output 1024
"""
import argparse
......
"""
Benchmark the latency of running a single static batch without a server.
This script does not launch a server and uses the low-level APIs.
It accepts server arguments (the same as launch_server.py) and benchmark arguments (e.g., batch size, input lengths).
# Usage (latency test)
## with dummy weights:
python -m sglang.bench_one_batch --model-path meta-llama/Meta-Llama-3-8B-Instruct --load-format dummy
## sweep through multiple data points and store (append) the results in a jsonl file:
python -m sglang.bench_one_batch --model-path meta-llama/Meta-Llama-3-8B-Instruct --batch 1 12 14 --input-len 256 512 --output-len 32 256 --run-name test_run
# Usage (correctness test):
python -m sglang.bench_one_batch --model-path TinyLlama/TinyLlama-1.1B-Chat-v0.4 --correct
## Reference output (of the correctness test above, can be gpu dependent):
input_ids=[[1, 450, 7483, 310, 3444, 338], [1, 450, 7483, 310, 278, 3303, 13187, 290, 338], [1, 20628, 338, 263, 6575, 1460, 2462, 322, 306, 763]]
prefill logits (first half): tensor([[-10.0312, -9.5000, 0.8931, ..., -4.9414, -3.2422, -3.3633],
[-10.0312, -9.5000, 0.8931, ..., -4.9414, -3.2422, -3.3633],
[ -9.1875, -10.2500, 2.7129, ..., -4.3359, -4.0664, -4.1328]],
device='cuda:0')
prefill logits (final): tensor([[-8.3125, -7.1172, 3.3457, ..., -4.9570, -4.1328, -3.4141],
[-8.9141, -9.0156, 4.1445, ..., -4.9922, -4.4961, -4.0781],
[-9.6328, -9.0547, 4.0195, ..., -5.3047, -4.7148, -4.4570]],
device='cuda:0')
========== Prompt 0 ==========
<s> The capital of France is Paris.
The capital of the United States is Washington, D.C.
========== Prompt 1 ==========
<s> The capital of the United Kindom is London.
The capital of the United Kingdom is London.
The capital of the
========== Prompt 2 ==========
<s> Today is a sunny day and I like to go for a walk in the park.
I'm going to the park
"""
import argparse
import dataclasses
import itertools
import json
import logging
import multiprocessing
import time
from typing import Tuple
import numpy as np
import torch
import torch.distributed as dist
from sglang.srt.configs.model_config import ModelConfig
from sglang.srt.hf_transformers_utils import get_tokenizer
from sglang.srt.managers.schedule_batch import Req, ScheduleBatch
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.model_executor.model_runner import ModelRunner
from sglang.srt.sampling.sampling_params import SamplingParams
from sglang.srt.server import _set_envs_and_config
from sglang.srt.server_args import PortArgs, ServerArgs
from sglang.srt.utils import (
configure_logger,
kill_child_process,
suppress_other_loggers,
)
@dataclasses.dataclass
class BenchArgs:
run_name: str = "default"
batch_size: Tuple[int] = (1,)
input_len: Tuple[int] = (1024,)
output_len: Tuple[int] = (16,)
result_filename: str = "result.jsonl"
correctness_test: bool = False
    # This is only used for the correctness test
cut_len: int = 4
@staticmethod
def add_cli_args(parser: argparse.ArgumentParser):
parser.add_argument("--run-name", type=str, default=BenchArgs.run_name)
parser.add_argument(
"--batch-size", type=int, nargs="+", default=BenchArgs.batch_size
)
parser.add_argument(
"--input-len", type=int, nargs="+", default=BenchArgs.input_len
)
parser.add_argument(
"--output-len", type=int, nargs="+", default=BenchArgs.output_len
)
parser.add_argument(
"--result-filename", type=str, default=BenchArgs.result_filename
)
parser.add_argument("--correctness-test", action="store_true")
parser.add_argument("--cut-len", type=int, default=BenchArgs.cut_len)
@classmethod
def from_cli_args(cls, args: argparse.Namespace):
        # Use the default value's type to cast the args into the correct types.
attrs = [(attr.name, type(attr.default)) for attr in dataclasses.fields(cls)]
return cls(
**{attr: attr_type(getattr(args, attr)) for attr, attr_type in attrs}
)
def load_model(server_args, port_args, tp_rank):
suppress_other_loggers()
rank_print = print if tp_rank == 0 else lambda *args, **kwargs: None
model_config = ModelConfig(
server_args.model_path,
trust_remote_code=server_args.trust_remote_code,
context_length=server_args.context_length,
model_override_args=server_args.json_model_override_args,
)
model_runner = ModelRunner(
model_config=model_config,
mem_fraction_static=server_args.mem_fraction_static,
gpu_id=tp_rank,
tp_rank=tp_rank,
tp_size=server_args.tp_size,
nccl_port=port_args.nccl_port,
server_args=server_args,
)
rank_print(f"max_total_num_tokens={model_runner.max_total_num_tokens}")
tokenizer = get_tokenizer(
server_args.tokenizer_path,
tokenizer_mode=server_args.tokenizer_mode,
trust_remote_code=server_args.trust_remote_code,
)
if server_args.tp_size > 1:
dist.barrier()
return model_runner, tokenizer
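# Build requests for the correctness test: each prompt is tokenized and
# truncated to its first `cut_len` tokens so that the remainder can be fed
# in a second extend pass that reuses the cached prefix.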
def prepare_inputs_for_correctness_test(bench_args, tokenizer):
prompts = [
"The capital of France is",
"The capital of the United Kindom is",
"Today is a sunny day and I like",
]
input_ids = [tokenizer.encode(p) for p in prompts]
sampling_params = SamplingParams(
temperature=0,
max_new_tokens=BenchArgs.output_len,
)
reqs = []
for i in range(len(prompts)):
assert len(input_ids[i]) > bench_args.cut_len
tmp_input_ids = input_ids[i][: bench_args.cut_len]
req = Req(
rid=i,
origin_input_text=prompts[i],
origin_input_ids=tmp_input_ids,
sampling_params=sampling_params,
)
req.prefix_indices = []
req.fill_ids = req.origin_input_ids
req.extend_input_len = len(req.fill_ids) - len(req.prefix_indices)
reqs.append(req)
return input_ids, reqs
def prepare_extend_inputs_for_correctness_test(
bench_args, input_ids, reqs, model_runner
):
for i in range(len(reqs)):
req = reqs[i]
req.fill_ids += input_ids[i][bench_args.cut_len :]
req.prefix_indices = model_runner.req_to_token_pool.req_to_token[
i, : bench_args.cut_len
]
req.extend_input_len = len(req.fill_ids) - len(req.prefix_indices)
return reqs
def prepare_synthetic_inputs_for_latency_test(batch_size, input_len):
input_ids = np.ones((batch_size, input_len), dtype=np.int32)
sampling_params = SamplingParams(
temperature=0,
max_new_tokens=BenchArgs.output_len,
)
reqs = []
for i in range(len(input_ids)):
req = Req(
rid=i,
origin_input_text="",
origin_input_ids=list(input_ids[i]),
sampling_params=sampling_params,
)
req.prefix_indices = []
req.fill_ids = req.origin_input_ids
req.extend_input_len = len(req.fill_ids) - len(req.prefix_indices)
reqs.append(req)
return reqs
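# Run one prefill (extend) forward pass over the whole batch and sample the
# next token for every request.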
@torch.no_grad
def extend(reqs, model_runner):
batch = ScheduleBatch.init_new(
reqs=reqs,
req_to_token_pool=model_runner.req_to_token_pool,
token_to_kv_pool=model_runner.token_to_kv_pool,
tree_cache=None,
model_config=model_runner.model_config,
)
batch.prepare_for_extend()
model_worker_batch = batch.get_model_worker_batch()
forward_batch = ForwardBatch.init_new(model_worker_batch, model_runner)
logits_output = model_runner.forward(forward_batch)
next_token_ids = model_runner.sample(logits_output, forward_batch)
return next_token_ids, logits_output.next_token_logits, batch
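# Run one decode step: feed the previously sampled tokens back in and sample
# the next token for every request, reusing the KV cache held by the batch.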
@torch.no_grad
def decode(input_token_ids, batch, model_runner):
batch.output_ids = input_token_ids
batch.prepare_for_decode()
model_worker_batch = batch.get_model_worker_batch()
forward_batch = ForwardBatch.init_new(model_worker_batch, model_runner)
logits_output = model_runner.forward(forward_batch)
next_token_ids = model_runner.sample(logits_output, forward_batch)
return next_token_ids, logits_output.next_token_logits
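# Correctness test: prefill the truncated prompts, extend with the remaining
# tokens, then decode greedily so the text output can be compared against the
# huggingface reference (scripts/playground/reference_hf.py).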
def correctness_test(
server_args,
port_args,
bench_args,
tp_rank,
):
# Configure the logger
configure_logger(server_args, prefix=f" TP{tp_rank}")
rank_print = print if tp_rank == 0 else lambda *args, **kwargs: None
# Load the model
model_runner, tokenizer = load_model(server_args, port_args, tp_rank)
# Prepare inputs
input_ids, reqs = prepare_inputs_for_correctness_test(bench_args, tokenizer)
rank_print(f"\n{input_ids=}\n")
if bench_args.cut_len > 0:
# Prefill
next_token_ids, next_token_logits, batch = extend(reqs, model_runner)
rank_print(f"prefill logits (first half): {next_token_logits} \n")
# Prepare extend inputs
reqs = prepare_extend_inputs_for_correctness_test(
bench_args, input_ids, reqs, model_runner
)
# Extend (prefill w/ KV cache)
next_token_ids, next_token_logits, batch = extend(reqs, model_runner)
rank_print(f"prefill logits (final): {next_token_logits} \n")
# Decode
output_ids = [input_ids[i] + [next_token_ids[i]] for i in range(len(input_ids))]
for _ in range(bench_args.output_len[0] - 1):
next_token_ids, _ = decode(next_token_ids, batch, model_runner)
next_token_ids_list = next_token_ids.tolist()
for i in range(len(reqs)):
output_ids[i].append(next_token_ids_list[i])
# Print output texts
for i in range(len(reqs)):
rank_print(f"========== Prompt {i} ==========")
rank_print(tokenizer.decode(output_ids[i]), "\n")
def synchronize(device):
if device == "cuda":
torch.cuda.synchronize()
elif device == "xpu":
torch.xpu.synchronize()
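# Measure a single configuration: one timed prefill pass followed by
# (output_len - 1) timed decode steps. Reports prefill/decode/total latency
# and throughput; returns None if the batch would not fit in the token pool.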
def latency_test_run_once(
run_name, model_runner, rank_print, reqs, batch_size, input_len, output_len, device
):
max_batch_size = model_runner.max_total_num_tokens // (input_len + output_len)
if batch_size > max_batch_size:
rank_print(
f"skipping ({batch_size}, {input_len}, {output_len}) due to max batch size limit"
)
return
# Clear the pools.
model_runner.req_to_token_pool.clear()
model_runner.token_to_kv_pool.clear()
measurement_results = {
"run_name": run_name,
"batch_size": batch_size,
"input_len": input_len,
"output_len": output_len,
}
tot_latency = 0
# Prefill
synchronize(device)
tic = time.time()
next_token_ids, _, batch = extend(reqs, model_runner)
synchronize(device)
prefill_latency = time.time() - tic
tot_latency += prefill_latency
throughput = input_len * batch_size / prefill_latency
rank_print(
f"Prefill. latency: {prefill_latency:6.5f} s, throughput: {throughput:9.2f} token/s"
)
measurement_results["prefill_latency"] = prefill_latency
measurement_results["prefill_throughput"] = throughput
# Decode
decode_latencies = []
for i in range(output_len - 1):
synchronize(device)
tic = time.time()
next_token_ids, _ = decode(next_token_ids, batch, model_runner)
synchronize(device)
latency = time.time() - tic
tot_latency += latency
throughput = batch_size / latency
decode_latencies.append(latency)
if i < 5:
rank_print(
f"Decode. latency: {latency:6.5f} s, throughput: {throughput:9.2f} token/s"
)
# Record decode timing from 2nd output
if output_len > 1:
med_decode_latency = np.median(decode_latencies)
med_decode_throughput = batch_size / med_decode_latency
rank_print(
f"Decode. median latency: {med_decode_latency:6.5f} s, median throughput: {med_decode_throughput:9.2f} token/s"
)
measurement_results["median_decode_latency"] = med_decode_latency
measurement_results["median_decode_throughput"] = med_decode_throughput
throughput = (input_len + output_len) * batch_size / tot_latency
rank_print(
f"Total. latency: {tot_latency:6.3f} s, throughput: {throughput:9.2f} token/s"
)
measurement_results["total_latency"] = tot_latency
measurement_results["overall_throughput"] = throughput
return measurement_results
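# Sweep the Cartesian product of --batch-size, --input-len, and --output-len
# after a short warmup run, and append the per-configuration results to the
# JSONL file given by --result-filename (rank 0 only).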
def latency_test(
server_args,
port_args,
bench_args,
tp_rank,
):
# Configure the logger
configure_logger(server_args, prefix=f" TP{tp_rank}")
rank_print = print if tp_rank == 0 else lambda *args, **kwargs: None
# Load the model
model_runner, tokenizer = load_model(server_args, port_args, tp_rank)
# Prepare inputs for warm up
reqs = prepare_synthetic_inputs_for_latency_test(
bench_args.batch_size[0], bench_args.input_len[0]
)
# Warm up
rank_print("Warmup ...")
latency_test_run_once(
bench_args.run_name,
model_runner,
rank_print,
reqs,
bench_args.batch_size[0],
bench_args.input_len[0],
8, # shorter decoding to speed up the warmup
server_args.device,
)
rank_print("Benchmark ...")
# Run the sweep
result_list = []
for bs, il, ol in itertools.product(
bench_args.batch_size, bench_args.input_len, bench_args.output_len
):
reqs = prepare_synthetic_inputs_for_latency_test(bs, il)
ret = latency_test_run_once(
bench_args.run_name,
model_runner,
rank_print,
reqs,
bs,
il,
ol,
server_args.device,
)
if ret is not None:
result_list.append(ret)
# Write results in jsonlines format on rank 0.
if tp_rank == 0 and bench_args.result_filename:
with open(bench_args.result_filename, "a") as fout:
for result in result_list:
fout.write(json.dumps(result) + "\n")
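# Dispatch to the correctness test or the latency sweep, spawning one process
# per tensor-parallel rank when tp_size > 1.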
def main(server_args, bench_args):
_set_envs_and_config(server_args)
if server_args.model_path:
if bench_args.correctness_test:
work_func = correctness_test
else:
work_func = latency_test
else:
        raise ValueError("Provide --model-path for running the tests.")
port_args = PortArgs.init_new(server_args)
if server_args.tp_size == 1:
work_func(server_args, port_args, bench_args, 0)
else:
workers = []
for tp_rank in range(server_args.tp_size):
proc = multiprocessing.Process(
target=work_func,
args=(
server_args,
port_args,
bench_args,
tp_rank,
),
)
proc.start()
workers.append(proc)
for proc in workers:
proc.join()
proc.terminate()
if __name__ == "__main__":
parser = argparse.ArgumentParser()
ServerArgs.add_cli_args(parser)
BenchArgs.add_cli_args(parser)
args = parser.parse_args()
server_args = ServerArgs.from_cli_args(args)
bench_args = BenchArgs.from_cli_args(args)
logging.basicConfig(
level=getattr(logging, server_args.log_level.upper()),
format="%(message)s",
)
try:
main(server_args, bench_args)
except Exception as e:
raise e
finally:
kill_child_process()
"""
Benchmark the latency of serving a single batch with a real server.
Benchmark the latency of running a single batch with a server.
This script launches a server and uses the HTTP interface.
It accepts arguments similar to those of launch_server.py.
It accepts server arguments (the same as launch_server.py) and benchmark arguments (e.g., batch size, input lengths).
Usage:
python3 -m sglang.bench_server_latency --model meta-llama/Meta-Llama-3.1-8B --batch-size 1 16 64 --input-len 1024 --output-len 8
python3 -m sglang.bench_server_latency --model None --base-url http://localhost:30000 --batch-size 16 --input-len 1024 --output-len 8
......
......@@ -15,24 +15,21 @@ PACKAGE_LIST = [
"flashinfer",
"triton",
"transformers",
"requests",
"tqdm",
"torchao",
"numpy",
"aiohttp",
"fastapi",
"hf_transfer",
"huggingface_hub",
"interegular",
"packaging",
"PIL",
"psutil",
"pydantic",
"multipart",
"zmq",
"uvicorn",
"uvloop",
"zmq",
"vllm",
"outlines",
"multipart",
"openai",
"tiktoken",
"anthropic",
......
......@@ -30,10 +30,10 @@ device_mesh = torch.distributed.init_device_mesh("cuda", (tp_size,))
tensor_parallel(model, device_mesh)
```
An end-to-end example can be found in `python/sglang/bench_latency.py`.
An end-to-end example can be found in `python/sglang/bench_one_batch.py`.
You can run it with the following command:
```bash
$ python3 -m sglang.bench_latency --correct \
$ python3 -m sglang.bench_one_batch --correct \
--model meta-llama/Meta-Llama-3-8B \
--json-model-override-args '{"architectures": ["TorchNativeLlamaForCausalLM"]}' \
--tensor-parallel-size 2 \
......
......@@ -579,11 +579,11 @@ def run_bench_serving(
return res
def run_bench_latency(model, other_args):
def run_bench_one_batch(model, other_args):
command = [
"python3",
"-m",
"sglang.bench_latency",
"sglang.bench_one_batch",
"--model-path",
model,
"--batch-size",
......
......@@ -4,19 +4,19 @@ from sglang.test.test_utils import (
DEFAULT_MODEL_NAME_FOR_TEST,
DEFAULT_MOE_MODEL_NAME_FOR_TEST,
is_in_ci,
run_bench_latency,
run_bench_one_batch,
)
class TestBenchLatency(unittest.TestCase):
class TestBenchOneBatch(unittest.TestCase):
def test_default(self):
output_throughput = run_bench_latency(DEFAULT_MODEL_NAME_FOR_TEST, [])
output_throughput = run_bench_one_batch(DEFAULT_MODEL_NAME_FOR_TEST, [])
if is_in_ci():
self.assertGreater(output_throughput, 135)
def test_moe_default(self):
output_throughput = run_bench_latency(
output_throughput = run_bench_one_batch(
DEFAULT_MOE_MODEL_NAME_FOR_TEST, ["--tp", "2"]
)
......
import unittest
from sglang.test.test_utils import is_in_ci, run_bench_latency
from sglang.test.test_utils import is_in_ci, run_bench_one_batch
class TestTorchTP(unittest.TestCase):
def test_torch_native_llama(self):
output_throughput = run_bench_latency(
output_throughput = run_bench_one_batch(
"meta-llama/Meta-Llama-3-8B",
[
"--tp",
......
......@@ -14,13 +14,13 @@ from sglang.test.test_utils import (
DEFAULT_URL_FOR_TEST,
is_in_ci,
popen_launch_server,
run_bench_latency,
run_bench_one_batch,
)
class TestTritonAttnBackend(unittest.TestCase):
def test_latency(self):
output_throughput = run_bench_latency(
output_throughput = run_bench_one_batch(
DEFAULT_MODEL_NAME_FOR_TEST,
[
"--attention-backend",
......