"""
Benchmark the latency of running a single batch with a server.

This script launches a server and uses the HTTP interface.
It accepts server arguments (the same as launch_server.py) and benchmark arguments (e.g., batch size, input lengths).

Usage:
python3 -m sglang.bench_server_latency --model meta-llama/Meta-Llama-3.1-8B --batch-size 1 16 64 --input-len 1024 --output-len 8

python3 -m sglang.bench_server_latency --model None --base-url http://localhost:30000 --batch-size 16 --input-len 1024 --output-len 8
"""

import argparse
import dataclasses
import itertools
import json
import multiprocessing
import time
from typing import Tuple

import numpy as np
import requests

from sglang.srt.server import launch_server
from sglang.srt.server_args import ServerArgs
from sglang.srt.utils import kill_child_process


@dataclasses.dataclass
class BenchArgs:
    """Benchmark-specific CLI arguments (server options are handled by ServerArgs)."""

    run_name: str = "default"
    # Variable-length tuples: one benchmark case is run per
    # (batch_size, input_len, output_len) combination.
    batch_size: Tuple[int, ...] = (1,)
    input_len: Tuple[int, ...] = (1024,)
    output_len: Tuple[int, ...] = (16,)
    result_filename: str = "result.jsonl"
    # If set, benchmark an already-running server instead of launching one.
    base_url: str = ""
    skip_warmup: bool = False

    @staticmethod
    def add_cli_args(parser: argparse.ArgumentParser):
        """Register the benchmark arguments on an existing parser."""
        parser.add_argument("--run-name", type=str, default=BenchArgs.run_name)
        parser.add_argument(
            "--batch-size", type=int, nargs="+", default=BenchArgs.batch_size
        )
        parser.add_argument(
            "--input-len", type=int, nargs="+", default=BenchArgs.input_len
        )
        parser.add_argument(
            "--output-len", type=int, nargs="+", default=BenchArgs.output_len
        )
        parser.add_argument(
            "--result-filename", type=str, default=BenchArgs.result_filename
        )
        parser.add_argument("--base-url", type=str, default=BenchArgs.base_url)
        parser.add_argument("--skip-warmup", action="store_true")

    @classmethod
    def from_cli_args(cls, args: argparse.Namespace):
        """Build a BenchArgs from parsed args.

        Each value is cast to the type of the field's default, so e.g. the
        lists produced by nargs="+" become tuples again.
        """
        attrs = [(attr.name, type(attr.default)) for attr in dataclasses.fields(cls)]
        return cls(
            **{attr: attr_type(getattr(args, attr)) for attr, attr_type in attrs}
        )


def launch_server_internal(server_args):
    """Run the server in this process; target for multiprocessing.Process.

    Always reaps child processes on the way out, whether the server exits
    normally or raises.
    """
    try:
        launch_server(server_args)
    finally:
        # NOTE: the original `except Exception as e: raise e` was a no-op
        # re-raise that only truncated the traceback; try/finally suffices.
        kill_child_process()


def launch_server_process(server_args: ServerArgs):
    """Launch the server in a subprocess and block until it answers HTTP.

    Polls GET /v1/models every 10 s for up to 600 s.

    Returns:
        (proc, base_url): the server process and its root URL.

    Raises:
        TimeoutError: if the server does not become healthy in time.
    """
    proc = multiprocessing.Process(target=launch_server_internal, args=(server_args,))
    proc.start()
    base_url = f"http://{server_args.host}:{server_args.port}"
    timeout = 600

    start_time = time.time()
    while time.time() - start_time < timeout:
        try:
            headers = {
                "Content-Type": "application/json; charset=utf-8",
            }
            # Per-request timeout so one hung connection cannot stall the
            # health-check loop past the overall 600 s budget.
            response = requests.get(
                f"{base_url}/v1/models", headers=headers, timeout=10
            )
            if response.status_code == 200:
                return proc, base_url
        except requests.RequestException:
            pass
        time.sleep(10)
    raise TimeoutError("Server failed to start within the timeout period.")


def run_one_case(
    url: str,
    batch_size: int,
    input_len: int,
    output_len: int,
    run_name: str,
    result_filename: str,
):
    """Send one /generate request and report latency and throughput.

    Builds a batch of random prompt token ids, issues a single greedy
    (temperature 0) generation request, prints the metrics, and appends a
    JSON line to `result_filename` when it is non-empty.
    """
    # One row of random token ids per request in the batch.
    prompts = [
        [int(tok) for tok in np.random.randint(0, high=16384, size=(input_len,))]
        for _ in range(batch_size)
    ]

    start = time.time()
    resp = requests.post(
        url + "/generate",
        json={
            "input_ids": prompts,
            "sampling_params": {
                "temperature": 0,
                "max_new_tokens": output_len,
                "ignore_eos": True,
            },
        },
    )
    elapsed = time.time() - start

    # Parse the body so malformed responses surface here.
    _ = resp.json()
    decode_tps = batch_size * output_len / elapsed
    total_tps = batch_size * (input_len + output_len) / elapsed

    print(f"batch size: {batch_size}")
    print(f"latency: {elapsed:.2f} s")
    print(f"output throughput: {decode_tps:.2f} token/s")
    print(f"(input + output) throughput: {total_tps:.2f} token/s")

    if result_filename:
        record = {
            "run_name": run_name,
            "batch_size": batch_size,
            "input_len": input_len,
            "output_len": output_len,
            "latency": round(elapsed, 4),
            "output_throughput": round(decode_tps, 2),
            "overall_throughput": round(total_tps, 2),
        }
        with open(result_filename, "a") as fout:
            fout.write(json.dumps(record) + "\n")


def run_benchmark(server_args: ServerArgs, bench_args: BenchArgs):
    """Run every (batch_size, input_len, output_len) combination against a server.

    Uses bench_args.base_url when given; otherwise launches a server and
    kills it (and its children) when the benchmark finishes or fails.
    """
    if bench_args.base_url:
        proc, base_url = None, bench_args.base_url
    else:
        proc, base_url = launch_server_process(server_args)

    try:
        # Warmup: results are discarded (empty run_name / result_filename).
        # Kept inside the try so a warmup failure still kills the server.
        if not bench_args.skip_warmup:
            run_one_case(
                base_url,
                batch_size=16,
                input_len=1024,
                output_len=16,
                run_name="",
                result_filename="",
            )

        # benchmark
        for bs, il, ol in itertools.product(
            bench_args.batch_size, bench_args.input_len, bench_args.output_len
        ):
            run_one_case(
                base_url,
                bs,
                il,
                ol,
                bench_args.run_name,
                bench_args.result_filename,
            )
    finally:
        # Only kill the server if we launched it ourselves.
        if proc:
            kill_child_process(proc.pid, include_self=True)

    print(f"\nResults are saved to {bench_args.result_filename}")


if __name__ == "__main__":
    # Combine server options and benchmark options into a single CLI.
    arg_parser = argparse.ArgumentParser()
    ServerArgs.add_cli_args(arg_parser)
    BenchArgs.add_cli_args(arg_parser)
    parsed = arg_parser.parse_args()

    run_benchmark(
        ServerArgs.from_cli_args(parsed), BenchArgs.from_cli_args(parsed)
    )