"""
Benchmark the latency of running a single batch with a server.

This script launches a server and uses the HTTP interface.
It accepts server arguments (the same as launch_server.py) and benchmark arguments (e.g., batch size, input lengths).

Usage:
python3 -m sglang.bench_one_batch_server --model meta-llama/Meta-Llama-3.1-8B --batch-size 1 16 64 --input-len 1024 --output-len 8

python3 -m sglang.bench_one_batch_server --model None --base-url http://localhost:30000 --batch-size 16 --input-len 1024 --output-len 8
python3 -m sglang.bench_one_batch_server --model None --base-url http://localhost:30000 --batch-size 16 --input-len 1024 --output-len 8 --show-report --profile --profile-by-stage
"""

import argparse
import dataclasses
import itertools
import json
import multiprocessing
import os
import time
from typing import Tuple

import requests

from sglang.bench_serving import get_tokenizer, sample_random_requests
from sglang.profiler import run_profile
from sglang.srt.entrypoints.http_server import launch_server
from sglang.srt.server_args import ServerArgs
from sglang.srt.utils import is_blackwell, kill_process_tree
from sglang.test.test_utils import is_in_ci, write_github_step_summary


@dataclasses.dataclass
class BenchArgs:
    run_name: str = "default"
    batch_size: Tuple[int, ...] = (1,)
    input_len: Tuple[int, ...] = (1024,)
    output_len: Tuple[int, ...] = (16,)
    temperature: float = 0.0
    return_logprob: bool = False
    client_stream_interval: int = 1
    input_len_step_percentage: float = 0.0
    result_filename: str = "result.jsonl"
    base_url: str = ""
    skip_warmup: bool = False
    show_report: bool = False
    profile: bool = False
    profile_by_stage: bool = False

    @staticmethod
    def add_cli_args(parser: argparse.ArgumentParser):
        parser.add_argument("--run-name", type=str, default=BenchArgs.run_name)
        parser.add_argument(
            "--batch-size", type=int, nargs="+", default=BenchArgs.batch_size
        )
        parser.add_argument(
            "--input-len", type=int, nargs="+", default=BenchArgs.input_len
        )
        parser.add_argument(
            "--output-len", type=int, nargs="+", default=BenchArgs.output_len
        )
        parser.add_argument("--temperature", type=float, default=BenchArgs.temperature)
        parser.add_argument("--return-logprob", action="store_true")
        parser.add_argument(
            "--client-stream-interval",
            type=int,
            default=BenchArgs.client_stream_interval,
        )
        parser.add_argument(
            "--input-len-step-percentage",
            type=float,
            default=BenchArgs.input_len_step_percentage,
        )
        parser.add_argument(
            "--result-filename", type=str, default=BenchArgs.result_filename
        )
        parser.add_argument("--base-url", type=str, default=BenchArgs.base_url)
        parser.add_argument("--skip-warmup", action="store_true")
        parser.add_argument("--show-report", action="store_true")
        parser.add_argument("--profile", action="store_true")
        parser.add_argument("--profile-by-stage", action="store_true")

    @classmethod
    def from_cli_args(cls, args: argparse.Namespace):
        # use the default value's type to cast the args into correct types.
        attrs = [(attr.name, type(attr.default)) for attr in dataclasses.fields(cls)]
        return cls(
            **{attr: attr_type(getattr(args, attr)) for attr, attr_type in attrs}
        )


def launch_server_internal(server_args):
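    """Run launch_server in this process and kill any child processes it spawned on exit."""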
    try:
        launch_server(server_args)
    except Exception as e:
        raise e
    finally:
        kill_process_tree(os.getpid(), include_parent=False)


def launch_server_process(server_args: ServerArgs):
    proc = multiprocessing.Process(target=launch_server_internal, args=(server_args,))
    proc.start()
    base_url = f"http://{server_args.host}:{server_args.port}"
    timeout = 600

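    # Poll /v1/models until the server responds with HTTP 200, or give up after `timeout` seconds.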
    start_time = time.time()
    while time.time() - start_time < timeout:
        try:
            headers = {
                "Content-Type": "application/json; charset=utf-8",
            }
            response = requests.get(f"{base_url}/v1/models", headers=headers)
            if response.status_code == 200:
                return proc, base_url
        except requests.RequestException:
            pass
        time.sleep(10)
    raise TimeoutError("Server failed to start within the timeout period.")


def run_one_case(
    url: str,
    batch_size: int,
    input_len: int,
    output_len: int,
    temperature: float,
    return_logprob: bool,
    stream_interval: int,
    input_len_step_percentage: float,
    run_name: str,
    result_filename: str,
    tokenizer,
    profile: bool = False,
    profile_by_stage: bool = False,
):
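    """Run one (batch_size, input_len, output_len) case and report latency, TTFT, and throughput."""
    # Clear server-side caches (e.g., cached prefixes) so earlier cases do not skew this run.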
    requests.post(url + "/flush_cache")
    input_requests = sample_random_requests(
        input_len=input_len,
        output_len=output_len,
        num_prompts=batch_size,
        range_ratio=1.0,
        tokenizer=tokenizer,
        dataset_path="",
        random_sample=True,
        return_text=False,
    )
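
    # Hard-coded toggle for benchmarking structured outputs: when enabled, a json_schema is
    # attached to the sampling parameters below (in the default path it stays None).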
    use_structured_outputs = False
    if use_structured_outputs:
        texts = []
        for _ in range(batch_size):
            texts.append(
                "Human: What is the capital city of france? can you give as many trivial information as possible about that city? answer in json.\n"
                * 50
                + "Assistant:"
            )
        json_schema = "$$ANY$$"
    else:
        json_schema = None

    profile_link = None
    if profile:
        profile_link: str = run_profile(
            url, 3, ["CPU", "GPU"], None, None, profile_by_stage
        )

    tic = time.perf_counter()
    response = requests.post(
        url + "/generate",
        json={
            "input_ids": [req.prompt for req in input_requests],
            "sampling_params": {
                "temperature": temperature,
                "max_new_tokens": output_len,
                "ignore_eos": True,
                "json_schema": json_schema,
                "stream_interval": stream_interval,
            },
            "return_logprob": return_logprob,
            "stream": True,
        },
        stream=True,
    )

    # The TTFT of the last request in the batch
    ttft = 0.0
    for chunk in response.iter_lines(decode_unicode=False):
        chunk = chunk.decode("utf-8")
        if chunk and chunk.startswith("data:"):
            if chunk == "data: [DONE]":
                break
            data = json.loads(chunk[5:].strip("\n"))
            if "error" in data:
                raise RuntimeError(f"Request has failed. {data}.")

            assert (
                data["meta_info"]["finish_reason"] is None
                or data["meta_info"]["finish_reason"]["type"] == "length"
            )
            if data["meta_info"]["completion_tokens"] == 1:
                ttft = time.perf_counter() - tic

    latency = time.perf_counter() - tic
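
    # input throughput = prompt tokens / TTFT (prefill phase), output throughput = generated
    # tokens / remaining decode time, overall throughput = all tokens / end-to-end latency.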
    input_throughput = batch_size * input_len / ttft
    output_throughput = batch_size * output_len / (latency - ttft)
    overall_throughput = batch_size * (input_len + output_len) / latency

    server_info = requests.get(url + "/get_server_info").json()
    acc_length = server_info["internal_states"][0].get("avg_spec_accept_length", None)
    last_gen_throughput = server_info["internal_states"][0]["last_gen_throughput"]

    print(f"batch size: {batch_size}")
    print(f"input_len: {input_len}")
    print(f"output_len: {output_len}")
    print(f"latency: {latency:.2f} s")
    print(f"ttft: {ttft:.2f} s")
    print(f"last generation throughput: {last_gen_throughput:.2f} tok/s")
    print(f"input throughput: {input_throughput:.2f} tok/s")
    if output_len != 1:
        print(f"output throughput: {output_throughput:.2f} tok/s")

    if result_filename:
        with open(result_filename, "a") as fout:
            res = {
                "run_name": run_name,
                "batch_size": batch_size,
                "input_len": input_len,
                "output_len": output_len,
                "latency": round(latency, 4),
                "output_throughput": round(output_throughput, 2),
                "overall_throughput": round(overall_throughput, 2),
                "last_gen_throughput": round(last_gen_throughput, 2),
            }
            fout.write(json.dumps(res) + "\n")

    return (
        batch_size,
        latency,
        ttft,
        input_throughput,
        output_throughput,
        overall_throughput,
        last_gen_throughput,
        acc_length,
        profile_link if profile else None,
    )


def run_benchmark(server_args: ServerArgs, bench_args: BenchArgs):
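    """Launch a server (unless --base-url points at one already running), run every
    (batch_size, input_len, output_len) combination, and optionally print a report."""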
    if bench_args.base_url:
        proc, base_url = None, bench_args.base_url
    else:
        proc, base_url = launch_server_process(server_args)
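
    # Ask the server for its tokenizer path; deployments that report per-prefill-server info
    # (disaggregated prefill/decode) expose it under "prefill" instead of at the top level.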
    server_info = requests.get(base_url + "/get_server_info").json()
    if "tokenizer_path" in server_info:
        tokenizer_path = server_info["tokenizer_path"]
    elif "prefill" in server_info:
        tokenizer_path = server_info["prefill"][0]["tokenizer_path"]
    else:
        raise RuntimeError("Could not find a tokenizer_path in the server info response.")
    tokenizer = get_tokenizer(tokenizer_path)

    # warmup
    if not bench_args.skip_warmup:
        print("=" * 8 + " Warmup Begin " + "=" * 8)
        run_one_case(
            base_url,
            batch_size=16,
            input_len=1024,
            output_len=16,
            temperature=bench_args.temperature,
            return_logprob=bench_args.return_logprob,
            stream_interval=bench_args.client_stream_interval,
            input_len_step_percentage=bench_args.input_len_step_percentage,
            run_name="",
            result_filename="",
            tokenizer=tokenizer,
        )
        print("=" * 8 + " Warmup End   " + "=" * 8 + "\n")

    # benchmark
    result = []
    bench_result = []
    try:
        for bs, il, ol in itertools.product(
            bench_args.batch_size, bench_args.input_len, bench_args.output_len
        ):
            result.append(
                run_one_case(
                    base_url,
                    bs,
                    il,
                    ol,
                    temperature=bench_args.temperature,
                    return_logprob=bench_args.return_logprob,
                    stream_interval=bench_args.client_stream_interval,
                    input_len_step_percentage=bench_args.input_len_step_percentage,
                    run_name=bench_args.run_name,
                    result_filename=bench_args.result_filename,
                    tokenizer=tokenizer,
                )
            )
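
        # When --profile is set, re-run every case with the profiler attached, keep only the
        # returned trace link, and splice it into the matching result row collected above.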
        if bench_args.profile:
            try:
                for bs, il, ol in itertools.product(
                    bench_args.batch_size, bench_args.input_len, bench_args.output_len
                ):
                    bench_result.append(
                        (
                            run_one_case(
                                base_url,
                                bs,
                                il,
                                ol,
                                temperature=bench_args.temperature,
                                return_logprob=bench_args.return_logprob,
                                stream_interval=bench_args.client_stream_interval,
                                input_len_step_percentage=bench_args.input_len_step_percentage,
                                run_name=bench_args.run_name,
                                result_filename=bench_args.result_filename,
                                tokenizer=tokenizer,
                                profile=bench_args.profile,
                                profile_by_stage=bench_args.profile_by_stage,
                            )[-1],
                        )
                    )
                result = [t1[:-1] + t2 for t1, t2 in zip(result, bench_result)]
            except Exception as e:
                print(f"Error profiling, there will be no profile trace dump: {e}")
    finally:
        if proc:
            kill_process_tree(proc.pid)

    print(f"\nResults are saved to {bench_args.result_filename}")

    if not bench_args.show_report:
        return

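    # Build a markdown summary table: one row per case with latency, throughput, speculative
    # accept length, ITL, and rough $/1M-token cost estimates (plus a profile link if any).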
    summary = (
        f"\nInput lens: {bench_args.input_len}. Output lens: {bench_args.output_len}.\n"
    )
    summary += "| batch size | latency (s) | input throughput (tok/s)  | output throughput (tok/s) | acc length | ITL (ms) | input cost ($/1M) | output cost ($/1M) |"

    if bench_args.profile:
        summary += " profile |"

    summary += "\n"
    summary += "| ---------- | ----------- | ------------------------- | ------------------------- | ---------- | -------- | ----------------- | ------------------ |"

    if bench_args.profile:
        summary += "-------------|"
    summary += "\n"

    for (
        batch_size,
        latency,
        ttft,
        input_throughput,
        output_throughput,
        overall_throughput,
        last_gen_throughput,
        acc_length,
        trace_link,
    ) in result:
        if is_blackwell():
            hourly_cost_per_gpu = 4  # $4/hour for one B200
        else:
            hourly_cost_per_gpu = 2  # $2/hour for one H100

        hourly_cost = hourly_cost_per_gpu * server_args.tp_size
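
        # ITL (ms) = batch_size / output_throughput * 1000, i.e. the average gap between
        # output tokens for a single request. Cost columns convert the hourly GPU price into
        # dollars per 1M tokens, derating input throughput by `input_util` (70%).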
        input_util = 0.7
        accept_length = round(acc_length, 2) if acc_length is not None else "n/a"
        line = (
            f"| {batch_size} | "
            f"{latency:.2f} | "
            f"{input_throughput:.2f} | "
            f"{output_throughput:.2f} | "
            f"{accept_length} | "
            f"{1 / (output_throughput/batch_size) * 1000:.2f} | "
            f"{1e6 / (input_throughput * input_util) / 3600 * hourly_cost:.2f} | "
            f"{1e6 / output_throughput / 3600 * hourly_cost:.2f} |"
        )
        if trace_link:
            line += f" [Profile]({trace_link}) |"
        line += "\n"
        summary += line

    # print metrics table
    print(summary)

    if is_in_ci():
        write_github_step_summary(summary)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    ServerArgs.add_cli_args(parser)
    BenchArgs.add_cli_args(parser)
    args = parser.parse_args()
    server_args = ServerArgs.from_cli_args(args)
    bench_args = BenchArgs.from_cli_args(args)

    run_benchmark(server_args, bench_args)