Unverified Commit 53985645 authored by min-xu-et, committed by GitHub

latency test enhancement - part 1 (#909)

parent 70cc0749
@@ -21,7 +21,7 @@ dependencies = [

 [project.optional-dependencies]
 srt = ["aiohttp", "fastapi", "hf_transfer", "huggingface_hub", "interegular", "packaging", "pillow",
-       "psutil", "pydantic", "torch", "uvicorn", "uvloop", "zmq", "vllm==0.5.3.post1", "outlines>=0.0.44", "python-multipart"]
+       "psutil", "pydantic", "torch", "uvicorn", "uvloop", "zmq", "vllm==0.5.3.post1", "outlines>=0.0.44", "python-multipart", "jsonlines"]
 openai = ["openai>=1.0", "tiktoken"]
 anthropic = ["anthropic>=0.20.0"]
 litellm = ["litellm>=1.0.0"]
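Note on the dependency change: jsonlines is now part of the srt extra, so a fresh install of the package picks it up automatically; in an already-installed environment it can also be added directly (assuming pip is available):

pip install jsonlines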
...

 """
 Benchmark the latency of a given model. It accepts arguments similar to those of launch_server.py.
-# Usage (latency test):
+# Usage (latency test) with dummy weights:
 python -m sglang.bench_latency --model-path meta-llama/Meta-Llama-3-8B-Instruct --load-format dummy
 # Usage (correctness test):
 python -m sglang.bench_latency --model-path TinyLlama/TinyLlama-1.1B-Chat-v0.4 --correct
-### Reference output:
+### Reference output (of the correctness test above, can be gpu dependent):
 prefill logits (first half) tensor([[-10.0312, -9.5000, 0.8936, ..., -4.9414, -3.2402, -3.3633],
         [-10.0312, -9.5000, 0.8936, ..., -4.9414, -3.2402, -3.3633],
         [ -9.1875, -10.2500, 2.7109, ..., -4.3359, -4.0664, -4.1328]],
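With this change the latency test also understands the new flags introduced below: --batch-size accepts one or more values (for now only the first one is actually benchmarked) and --result-filename appends each run's measurements as JSON lines. A hypothetical invocation combining them (the result filename is illustrative):

python -m sglang.bench_latency --model-path meta-llama/Meta-Llama-3-8B-Instruct --load-format dummy --batch-size 1 8 16 --result-filename out.jsonl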
@@ -31,7 +31,9 @@ import dataclasses
 import logging
 import multiprocessing
 import time
+from typing import Tuple

+import jsonlines
 import numpy as np
 import torch
 import torch.distributed as dist
@@ -47,25 +49,34 @@ from sglang.srt.utils import suppress_other_loggers

 @dataclasses.dataclass
 class BenchArgs:
-    batch_size: int = 1
+    batch_size: Tuple[int] = (1,)
     input_len: int = 1024
     output_len: int = 4
+    result_filename: str = ""
     correctness_test: bool = False
     # This is only used for correctness test
     cut_len: int = 4

     @staticmethod
     def add_cli_args(parser: argparse.ArgumentParser):
-        parser.add_argument("--batch-size", type=int, default=BenchArgs.batch_size)
+        parser.add_argument(
+            "--batch-size", type=int, nargs="+", default=BenchArgs.batch_size
+        )
         parser.add_argument("--input-len", type=int, default=BenchArgs.input_len)
         parser.add_argument("--output-len", type=int, default=BenchArgs.output_len)
+        parser.add_argument(
+            "--result-filename", type=str, default=BenchArgs.result_filename
+        )
         parser.add_argument("--correctness-test", action="store_true")
         parser.add_argument("--cut-len", type=int, default=BenchArgs.cut_len)

     @classmethod
     def from_cli_args(cls, args: argparse.Namespace):
-        attrs = [attr.name for attr in dataclasses.fields(cls)]
-        return cls(**{attr: getattr(args, attr) for attr in attrs})
+        # use the default value's type to cast the args into the correct types.
+        attrs = [(attr.name, type(attr.default)) for attr in dataclasses.fields(cls)]
+        return cls(
+            **{attr: attr_type(getattr(args, attr)) for attr, attr_type in attrs}
+        )


 def load_model(server_args, tp_rank):
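A note on the reworked from_cli_args above: because --batch-size is now parsed with nargs="+", argparse hands back a plain list, and the dataclass recovers each field's intended type from its default value and casts the parsed value with it, so the list becomes a tuple again while scalar fields keep their types. A minimal sketch of that mechanism with a trimmed-down stand-in dataclass (the parsed values below are illustrative):

import dataclasses

@dataclasses.dataclass
class Args:  # stand-in for BenchArgs with only a few fields
    batch_size: tuple = (1,)
    input_len: int = 1024
    result_filename: str = ""

# roughly what argparse would produce for: --batch-size 1 8 16 --result-filename out.jsonl
parsed = {"batch_size": [1, 8, 16], "input_len": 1024, "result_filename": "out.jsonl"}

# cast each value with the type of the field's default, as from_cli_args does
fields = [(f.name, type(f.default)) for f in dataclasses.fields(Args)]
args = Args(**{name: typ(parsed[name]) for name, typ in fields})
print(args.batch_size)  # (1, 8, 16) -- the list was cast back to a tuple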
@@ -93,7 +104,7 @@ def load_model(server_args, tp_rank):
     return model_runner, tokenizer


-def prepare_inputs(bench_args, tokenizer):
+def prepare_inputs_for_correctness_test(bench_args, tokenizer):
     prompts = [
         "The capital of France is",
         "The capital of the United Kingdom is",
@@ -119,7 +130,9 @@ def prepare_inputs(bench_args, tokenizer):
     return input_ids, reqs


-def prepare_extend_inputs(bench_args, input_ids, reqs, model_runner):
+def prepare_extend_inputs_for_correctness_test(
+    bench_args, input_ids, reqs, model_runner
+):
     for i in range(len(reqs)):
         req = reqs[i]
         req.input_ids += input_ids[i][bench_args.cut_len :]
@@ -129,8 +142,8 @@ def prepare_extend_inputs(bench_args, input_ids, reqs, model_runner):
     return reqs


-def prepare_synthetic_inputs(bench_args, tokenizer):
-    input_ids = np.ones((bench_args.batch_size, bench_args.input_len), dtype=np.int32)
+def prepare_synthetic_inputs_for_latency_test(batch_size, input_len):
+    input_ids = np.ones((batch_size, input_len), dtype=np.int32)
     sampling_params = SamplingParams(
         temperature=0,
         max_new_tokens=BenchArgs.output_len,
@@ -179,7 +192,7 @@ def correctness_test(
     model_runner, tokenizer = load_model(server_args, tp_rank)

     # Prepare inputs
-    input_ids, reqs = prepare_inputs(bench_args, tokenizer)
+    input_ids, reqs = prepare_inputs_for_correctness_test(bench_args, tokenizer)

     if bench_args.cut_len > 0:
         # Prefill
@@ -187,7 +200,9 @@ def correctness_test(
     rank_print("prefill logits (first half)", next_token_logits)

     # Prepare extend inputs
-    reqs = prepare_extend_inputs(bench_args, input_ids, reqs, model_runner)
+    reqs = prepare_extend_inputs_for_correctness_test(
+        bench_args, input_ids, reqs, model_runner
+    )

     # Extend
     next_token_ids, next_token_logits, batch = extend(reqs, model_runner)
@@ -218,8 +233,13 @@ def latency_test(
         f"max_batch_size={model_runner.max_total_num_tokens // (bench_args.input_len + bench_args.output_len)}"
     )

+    # To make this PR easier to review, for now, only do the first element in batch_size tuple.
+    bench_args.batch_size = bench_args.batch_size[0]
+
     # Prepare inputs
-    reqs = prepare_synthetic_inputs(bench_args, tokenizer)
+    reqs = prepare_synthetic_inputs_for_latency_test(
+        bench_args.batch_size, bench_args.input_len
+    )

     def clear():
         model_runner.req_to_token_pool.clear()
@@ -227,6 +247,11 @@ def latency_test(

     @torch.inference_mode()
     def run_once(output_len):
+        measurement_results = {
+            "batch_size": bench_args.batch_size,
+            "output_len": output_len,
+        }
+
         # Prefill
         torch.cuda.synchronize()
         tot_latency = 0
@@ -239,6 +264,8 @@ def latency_test(
         rank_print(
             f"Prefill. latency: {prefill_latency:6.5f} s, throughput: {throughput:9.2f} token/s"
         )
+        measurement_results["prefill_latency"] = prefill_latency
+        measurement_results["prefill_throughput"] = throughput

         # Decode
         for i in range(output_len):
@@ -258,6 +285,8 @@ def latency_test(
         rank_print(
             f"Decode. avg latency: {avg_decode_latency:6.5f} s, avg throughput: {avg_decode_throughput:9.2f} token/s"
         )
+        measurement_results["avg_decode_latency"] = avg_decode_latency
+        measurement_results["avg_decode_throughput"] = avg_decode_throughput

         throughput = (
             (bench_args.input_len + bench_args.output_len)
@@ -267,13 +296,22 @@ def latency_test(
         rank_print(
             f"Total. latency: {tot_latency:6.3f} s, throughput: {throughput:9.2f} token/s"
         )
+        measurement_results["total_latency"] = tot_latency
+        measurement_results["total_throughput"] = throughput
+        return measurement_results

     # Warm up
     run_once(4)
     clear()

     # Run again
-    run_once(bench_args.output_len)
+    result_list = []
+    result_list.append(run_once(bench_args.output_len))
+
+    # Write results in jsonlines format.
+    if bench_args.result_filename:
+        with jsonlines.open(bench_args.result_filename, "a") as f:
+            f.write_all(result_list)


 def main(server_args, bench_args):
...
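The records produced by run_once() are appended to --result-filename as one JSON object per line, carrying the keys shown in the diff above (batch_size, output_len, prefill_latency, prefill_throughput, avg_decode_latency, avg_decode_throughput, total_latency, total_throughput). A minimal way to read them back with the same jsonlines package, assuming results were written to an illustrative out.jsonl:

import jsonlines

with jsonlines.open("out.jsonl") as reader:
    for record in reader:
        # e.g. compare total throughput across runs
        print(record["batch_size"], record["total_throughput"])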