Unverified Commit 53985645 authored by min-xu-et, committed by GitHub

latency test enhancement - part 1 (#909)

parent 70cc0749
@@ -21,7 +21,7 @@ dependencies = [
 [project.optional-dependencies]
 srt = ["aiohttp", "fastapi", "hf_transfer", "huggingface_hub", "interegular", "packaging", "pillow",
-       "psutil", "pydantic", "torch", "uvicorn", "uvloop", "zmq", "vllm==0.5.3.post1", "outlines>=0.0.44", "python-multipart"]
+       "psutil", "pydantic", "torch", "uvicorn", "uvloop", "zmq", "vllm==0.5.3.post1", "outlines>=0.0.44", "python-multipart", "jsonlines"]
 openai = ["openai>=1.0", "tiktoken"]
 anthropic = ["anthropic>=0.20.0"]
 litellm = ["litellm>=1.0.0"]
......
 """
 Benchmark the latency of a given model. It accepts arguments similar to those of launch_server.py.

-# Usage (latency test):
+# Usage (latency test) with dummy weights:
 python -m sglang.bench_latency --model-path meta-llama/Meta-Llama-3-8B-Instruct --load-format dummy

 # Usage (correctness test):
 python -m sglang.bench_latency --model-path TinyLlama/TinyLlama-1.1B-Chat-v0.4 --correct

-### Reference output:
+### Reference output (of the correctness test above, can be gpu dependent):
 prefill logits (first half) tensor([[-10.0312, -9.5000, 0.8936, ..., -4.9414, -3.2402, -3.3633],
         [-10.0312, -9.5000, 0.8936, ..., -4.9414, -3.2402, -3.3633],
         [ -9.1875, -10.2500, 2.7109, ..., -4.3359, -4.0664, -4.1328]],
@@ -31,7 +31,9 @@ import dataclasses
 import logging
 import multiprocessing
 import time
+from typing import Tuple
 
+import jsonlines
 import numpy as np
 import torch
 import torch.distributed as dist
@@ -47,25 +49,34 @@ from sglang.srt.utils import suppress_other_loggers
 @dataclasses.dataclass
 class BenchArgs:
-    batch_size: int = 1
+    batch_size: Tuple[int] = (1,)
     input_len: int = 1024
     output_len: int = 4
+    result_filename: str = ""
     correctness_test: bool = False
     # This is only used for correctness test
     cut_len: int = 4
 
     @staticmethod
     def add_cli_args(parser: argparse.ArgumentParser):
-        parser.add_argument("--batch-size", type=int, default=BenchArgs.batch_size)
+        parser.add_argument(
+            "--batch-size", type=int, nargs="+", default=BenchArgs.batch_size
+        )
         parser.add_argument("--input-len", type=int, default=BenchArgs.input_len)
         parser.add_argument("--output-len", type=int, default=BenchArgs.output_len)
+        parser.add_argument(
+            "--result-filename", type=str, default=BenchArgs.result_filename
+        )
         parser.add_argument("--correctness-test", action="store_true")
         parser.add_argument("--cut-len", type=int, default=BenchArgs.cut_len)
 
     @classmethod
     def from_cli_args(cls, args: argparse.Namespace):
-        attrs = [attr.name for attr in dataclasses.fields(cls)]
-        return cls(**{attr: getattr(args, attr) for attr in attrs})
+        # Use the default value's type to cast the args into the correct types.
+        attrs = [(attr.name, type(attr.default)) for attr in dataclasses.fields(cls)]
+        return cls(
+            **{attr: attr_type(getattr(args, attr)) for attr, attr_type in attrs}
+        )
 
 
 def load_model(server_args, tp_rank):
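
The rewritten `from_cli_args` relies on a small trick: `nargs="+"` makes argparse hand back a list, and casting each parsed value with the type of the matching dataclass field's default restores the declared field types (so the list becomes a tuple again). A minimal standalone sketch of the same pattern; `DemoArgs` and its fields are illustrative, not part of the PR:

```python
import argparse
import dataclasses
from typing import Tuple


@dataclasses.dataclass
class DemoArgs:
    batch_size: Tuple[int] = (1,)
    result_filename: str = ""

    @classmethod
    def from_cli_args(cls, args: argparse.Namespace):
        # Cast each parsed value with the type of its field's default, so the
        # list argparse builds for nargs="+" becomes a tuple again.
        attrs = [(f.name, type(f.default)) for f in dataclasses.fields(cls)]
        return cls(**{name: typ(getattr(args, name)) for name, typ in attrs})


parser = argparse.ArgumentParser()
parser.add_argument("--batch-size", type=int, nargs="+", default=DemoArgs.batch_size)
parser.add_argument("--result-filename", type=str, default=DemoArgs.result_filename)
args = DemoArgs.from_cli_args(parser.parse_args(["--batch-size", "1", "8", "32"]))
print(args)  # DemoArgs(batch_size=(1, 8, 32), result_filename='')
```
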
@@ -93,7 +104,7 @@ def load_model(server_args, tp_rank):
     return model_runner, tokenizer
 
 
-def prepare_inputs(bench_args, tokenizer):
+def prepare_inputs_for_correctness_test(bench_args, tokenizer):
     prompts = [
         "The capital of France is",
         "The capital of the United Kindom is",
@@ -119,7 +130,9 @@ def prepare_inputs(bench_args, tokenizer):
     return input_ids, reqs
 
 
-def prepare_extend_inputs(bench_args, input_ids, reqs, model_runner):
+def prepare_extend_inputs_for_correctness_test(
+    bench_args, input_ids, reqs, model_runner
+):
     for i in range(len(reqs)):
         req = reqs[i]
         req.input_ids += input_ids[i][bench_args.cut_len :]
@@ -129,8 +142,8 @@
     return reqs
 
 
-def prepare_synthetic_inputs(bench_args, tokenizer):
-    input_ids = np.ones((bench_args.batch_size, bench_args.input_len), dtype=np.int32)
+def prepare_synthetic_inputs_for_latency_test(batch_size, input_len):
+    input_ids = np.ones((batch_size, input_len), dtype=np.int32)
     sampling_params = SamplingParams(
         temperature=0,
         max_new_tokens=BenchArgs.output_len,
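
With the tokenizer dropped from its signature, the synthetic-input helper is pure numpy: every request in the batch is the same prompt, token id 1 repeated `input_len` times. A quick sketch of just the array it builds (the shape is the interesting part; the `Req` and `SamplingParams` wiring is omitted):

```python
import numpy as np


def synthetic_input_ids(batch_size: int, input_len: int) -> np.ndarray:
    # One row per request, every token id set to 1, mirroring the helper above.
    return np.ones((batch_size, input_len), dtype=np.int32)


print(synthetic_input_ids(8, 1024).shape)  # (8, 1024)
```
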
@@ -179,7 +192,7 @@ def correctness_test(
     model_runner, tokenizer = load_model(server_args, tp_rank)
 
     # Prepare inputs
-    input_ids, reqs = prepare_inputs(bench_args, tokenizer)
+    input_ids, reqs = prepare_inputs_for_correctness_test(bench_args, tokenizer)
 
     if bench_args.cut_len > 0:
         # Prefill
@@ -187,7 +200,9 @@ def correctness_test(
         rank_print("prefill logits (first half)", next_token_logits)
 
     # Prepare extend inputs
-    reqs = prepare_extend_inputs(bench_args, input_ids, reqs, model_runner)
+    reqs = prepare_extend_inputs_for_correctness_test(
+        bench_args, input_ids, reqs, model_runner
+    )
 
     # Extend
     next_token_ids, next_token_logits, batch = extend(reqs, model_runner)
@@ -218,8 +233,13 @@ def latency_test(
         f"max_batch_size={model_runner.max_total_num_tokens // (bench_args.input_len + bench_args.output_len)}"
     )
 
+    # To keep this PR easy to review, only use the first element of the batch_size tuple for now.
+    bench_args.batch_size = bench_args.batch_size[0]
+
     # Prepare inputs
-    reqs = prepare_synthetic_inputs(bench_args, tokenizer)
+    reqs = prepare_synthetic_inputs_for_latency_test(
+        bench_args.batch_size, bench_args.input_len
+    )
 
     def clear():
         model_runner.req_to_token_pool.clear()
@@ -227,6 +247,11 @@ def latency_test(
     @torch.inference_mode()
     def run_once(output_len):
+        measurement_results = {
+            "batch_size": bench_args.batch_size,
+            "output_len": output_len,
+        }
+
         # Prefill
         torch.cuda.synchronize()
         tot_latency = 0
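
The latencies recorded into `measurement_results` are wall-clock timings around GPU work, which is why `torch.cuda.synchronize()` brackets each phase: CUDA kernel launches are asynchronous, so timing without a sync would mostly measure launch overhead. A generic sketch of that pattern (the general technique, not the PR's exact code):

```python
import time

import torch


def timed(fn, *args):
    # Drain already-queued GPU work so it is not billed to this call.
    torch.cuda.synchronize()
    start = time.time()
    out = fn(*args)
    # Wait for the kernels launched by fn to finish before reading the clock.
    torch.cuda.synchronize()
    return out, time.time() - start
```
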
@@ -239,6 +264,8 @@ def latency_test(
         rank_print(
             f"Prefill. latency: {prefill_latency:6.5f} s, throughput: {throughput:9.2f} token/s"
         )
+        measurement_results["prefill_latency"] = prefill_latency
+        measurement_results["prefill_throughput"] = throughput
 
         # Decode
         for i in range(output_len):
@@ -258,6 +285,8 @@ def latency_test(
         rank_print(
             f"Decode. avg latency: {avg_decode_latency:6.5f} s, avg throughput: {avg_decode_throughput:9.2f} token/s"
         )
+        measurement_results["avg_decode_latency"] = avg_decode_latency
+        measurement_results["avg_decode_throughput"] = avg_decode_throughput
 
         throughput = (
             (bench_args.input_len + bench_args.output_len)
@@ -267,13 +296,22 @@ def latency_test(
         rank_print(
             f"Total. latency: {tot_latency:6.3f} s, throughput: {throughput:9.2f} token/s"
         )
+        measurement_results["total_latency"] = tot_latency
+        measurement_results["total_throughput"] = throughput
+        return measurement_results
 
     # Warm up
     run_once(4)
     clear()
 
     # Run again
-    run_once(bench_args.output_len)
+    result_list = []
+    result_list.append(run_once(bench_args.output_len))
+
+    # Write results in jsonlines format.
+    if bench_args.result_filename:
+        with jsonlines.open(bench_args.result_filename, "a") as f:
+            f.write_all(result_list)
 
 
 def main(server_args, bench_args):
......
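
Because the result file is opened in append mode, repeated benchmark invocations accumulate one JSON object per line rather than overwriting earlier results. Reading them back is symmetric; a small sketch with illustrative values and filename (the keys match the `measurement_results` entries added above):

```python
import jsonlines

# Append one record per run, as bench_latency now does.
with jsonlines.open("results.jsonl", mode="a") as writer:
    writer.write_all([{"batch_size": 1, "output_len": 4, "total_latency": 0.42}])

# Later: iterate over all accumulated runs.
with jsonlines.open("results.jsonl") as reader:
    for record in reader:
        print(record["batch_size"], record["total_latency"])
```
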