"""
Benchmark the latency of running a single static batch.
This script does not launch a server and uses the low-level APIs.
It accepts arguments similar to those of launch_server.py.

# Usage (latency test)
## with dummy weights:
python -m sglang.bench_latency --model-path meta-llama/Meta-Llama-3-8B-Instruct --load-format dummy
## sweep through multiple data points and store (append) the results in a jsonl file:
python -m sglang.bench_latency --model-path meta-llama/Meta-Llama-3-8B-Instruct --batch 1 12 14 --input-len 256 512 --output-len 32 256 --result-filename out.jsonl
## do some changes, and store the results under a different run_name:
python -m sglang.bench_latency --model-path meta-llama/Meta-Llama-3-8B-Instruct --batch 1 12 14 --input-len 256 512 --output-len 32 256 --result-filename out.jsonl --run-name after
## plot the results as a series of lines:
python -m sglang.bench_latency --result-filename out.jsonl --graph-sql="select run_name, batch_size, prefill_throughput from results"

# Usage (correctness test):
python -m sglang.bench_latency --model-path TinyLlama/TinyLlama-1.1B-Chat-v0.4 --correctness-test

## Reference output (of the correctness test above, can be gpu dependent):
input_ids=[[1, 450, 7483, 310, 3444, 338], [1, 450, 7483, 310, 278, 3303, 13187, 290, 338], [1, 20628, 338, 263, 6575, 1460, 2462, 322, 306, 763]]

prefill logits (first half): tensor([[-10.0312,  -9.5000,   0.8931,  ...,  -4.9414,  -3.2422,  -3.3633],
        [-10.0312,  -9.5000,   0.8931,  ...,  -4.9414,  -3.2422,  -3.3633],
        [ -9.1875, -10.2500,   2.7129,  ...,  -4.3359,  -4.0664,  -4.1328]],
       device='cuda:0')

prefill logits (final): tensor([[-8.3125, -7.1172,  3.3457,  ..., -4.9570, -4.1328, -3.4141],
        [-8.9141, -9.0156,  4.1445,  ..., -4.9922, -4.4961, -4.0781],
        [-9.6328, -9.0547,  4.0195,  ..., -5.3047, -4.7148, -4.4570]],
       device='cuda:0')

========== Prompt 0 ==========
<s> The capital of France is Paris.
The capital of the United States is Washington, D.C.


========== Prompt 1 ==========
<s> The capital of the United Kindom is London.
The capital of the United Kingdom is London.
The capital of the

========== Prompt 2 ==========
<s> Today is a sunny day and I like to go for a walk in the park.
I'm going to the park
"""

import argparse
import dataclasses
import itertools
import json
import logging
import multiprocessing
import os
import sqlite3
import time
from typing import Tuple

import numpy as np
import pandas as pd
import torch
import torch.distributed as dist

from sglang.srt.configs.model_config import ModelConfig
from sglang.srt.hf_transformers_utils import get_tokenizer
from sglang.srt.managers.schedule_batch import Req, ScheduleBatch
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.model_executor.model_runner import ModelRunner
from sglang.srt.sampling.sampling_params import SamplingParams
from sglang.srt.server import _set_envs_and_config
from sglang.srt.server_args import PortArgs, ServerArgs
from sglang.srt.utils import (
    configure_logger,
    kill_child_process,
    suppress_other_loggers,
)


@dataclasses.dataclass
class BenchArgs:
    run_name: str = "before"
    batch_size: Tuple[int] = (1,)
    input_len: Tuple[int] = (1024,)
    output_len: Tuple[int] = (16,)
    result_filename: str = ""
    correctness_test: bool = False
    # Only used for the correctness test (number of prompt tokens prefilled before the extend step)
    cut_len: int = 4
    # Plotting args
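    # Note: the SQL must select exactly three columns, <series, x, y>
    # (see the assertion in plot_latency_test).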
    graph_sql: str = (
        "select run_name, batch_size, prefill_throughput from results where run_name='before'"
    )
    graph_filename: str = "out.png"

    @staticmethod
    def add_cli_args(parser: argparse.ArgumentParser):
        parser.add_argument("--run-name", type=str, default=BenchArgs.run_name)
        parser.add_argument(
            "--batch-size", type=int, nargs="+", default=BenchArgs.batch_size
        )
        parser.add_argument(
            "--input-len", type=int, nargs="+", default=BenchArgs.input_len
        )
        parser.add_argument(
            "--output-len", type=int, nargs="+", default=BenchArgs.output_len
        )
        parser.add_argument(
            "--result-filename", type=str, default=BenchArgs.result_filename
        )
        parser.add_argument("--correctness-test", action="store_true")
        parser.add_argument("--cut-len", type=int, default=BenchArgs.cut_len)
        # graphing
        parser.add_argument("--graph-sql", type=str, default=BenchArgs.graph_sql)
        parser.add_argument(
            "--graph-filename", type=str, default=BenchArgs.graph_filename
        )

    @classmethod
    def from_cli_args(cls, args: argparse.Namespace):
        # use the default value's type to cast the args into correct types.
        attrs = [(attr.name, type(attr.default)) for attr in dataclasses.fields(cls)]
        return cls(
            **{attr: attr_type(getattr(args, attr)) for attr, attr_type in attrs}
        )


def load_model(server_args, port_args, tp_rank):
    suppress_other_loggers()
    rank_print = print if tp_rank == 0 else lambda *args, **kwargs: None

    model_config = ModelConfig(
        server_args.model_path,
        server_args.trust_remote_code,
        context_length=server_args.context_length,
        model_override_args=json.loads(server_args.json_model_override_args),
    )
    model_runner = ModelRunner(
        model_config=model_config,
        mem_fraction_static=server_args.mem_fraction_static,
        gpu_id=tp_rank,
        tp_rank=tp_rank,
        tp_size=server_args.tp_size,
        nccl_port=port_args.nccl_port,
        server_args=server_args,
    )
    rank_print(f"max_total_num_tokens={model_runner.max_total_num_tokens}")
    tokenizer = get_tokenizer(
        server_args.tokenizer_path,
        tokenizer_mode=server_args.tokenizer_mode,
        trust_remote_code=server_args.trust_remote_code,
    )
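    # Make sure every tensor-parallel rank has finished loading before returning.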
    if server_args.tp_size > 1:
        dist.barrier()
    return model_runner, tokenizer


def prepare_inputs_for_correctness_test(bench_args, tokenizer):
    prompts = [
        "The capital of France is",
        "The capital of the United Kindom is",
        "Today is a sunny day and I like",
    ]
    input_ids = [tokenizer.encode(p) for p in prompts]
    sampling_params = SamplingParams(
        temperature=0,
        max_new_tokens=BenchArgs.output_len,
    )

    reqs = []
    for i in range(len(prompts)):
        assert len(input_ids[i]) > bench_args.cut_len

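        # Prefill only the first cut_len tokens for now; the remaining prompt tokens are
        # fed later via prepare_extend_inputs_for_correctness_test to exercise the extend path.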
        tmp_input_ids = input_ids[i][: bench_args.cut_len]
        req = Req(
            rid=i,
            origin_input_text=prompts[i],
            origin_input_ids=tmp_input_ids,
            sampling_params=sampling_params,
        )
        req.prefix_indices = []
        req.fill_ids = req.origin_input_ids
        req.extend_input_len = len(req.fill_ids) - len(req.prefix_indices)
        reqs.append(req)

    return input_ids, reqs


def prepare_extend_inputs_for_correctness_test(
    bench_args, input_ids, reqs, model_runner
):
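    # The first cut_len tokens were already prefilled; reuse their KV-cache slots as the
    # prefix and extend each request with the rest of its original prompt.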
    for i in range(len(reqs)):
        req = reqs[i]
        req.fill_ids += input_ids[i][bench_args.cut_len :]
        req.prefix_indices = model_runner.req_to_token_pool.req_to_token[
            i, : bench_args.cut_len
        ]
        req.extend_input_len = len(req.fill_ids) - len(req.prefix_indices)
    return reqs


def prepare_synthetic_inputs_for_latency_test(batch_size, input_len):
    input_ids = np.ones((batch_size, input_len), dtype=np.int32)
    sampling_params = SamplingParams(
        temperature=0,
        max_new_tokens=BenchArgs.output_len,
    )

    reqs = []
    for i in range(len(input_ids)):
        req = Req(
            rid=i,
            origin_input_text="",
            origin_input_ids=list(input_ids[i]),
            sampling_params=sampling_params,
        )
        req.prefix_indices = []
        req.fill_ids = req.origin_input_ids
        req.extend_input_len = len(req.fill_ids) - len(req.prefix_indices)
        reqs.append(req)

    return reqs


@torch.inference_mode()
def extend(reqs, model_runner):
    batch = ScheduleBatch.init_new(
        reqs=reqs,
        req_to_token_pool=model_runner.req_to_token_pool,
        token_to_kv_pool=model_runner.token_to_kv_pool,
        tree_cache=None,
    )
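    # Allocate request and KV-cache slots, then build the metadata for a prefill (extend) pass.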
    batch.prepare_for_extend(model_runner.model_config.vocab_size)
    model_worker_batch = batch.get_model_worker_batch()
    forward_batch = ForwardBatch.init_new(model_worker_batch, model_runner)
    logits_output = model_runner.forward(forward_batch)
    next_token_ids = model_runner.sample(logits_output, forward_batch).tolist()
    return next_token_ids, logits_output.next_token_logits, batch


@torch.inference_mode()
def decode(input_token_ids, batch, model_runner):
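    # Append the last sampled token of each request and set up a one-token decode step.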
    batch.prepare_for_decode(input_token_ids)
    model_worker_batch = batch.get_model_worker_batch()
    forward_batch = ForwardBatch.init_new(model_worker_batch, model_runner)
    logits_output = model_runner.forward(forward_batch)
    next_token_ids = model_runner.sample(logits_output, forward_batch).tolist()
    return next_token_ids, logits_output.next_token_logits


def correctness_test(
    server_args,
    port_args,
    bench_args,
    tp_rank,
):
    rank_print = print if tp_rank == 0 else lambda *args, **kwargs: None

    # Load the model
    model_runner, tokenizer = load_model(server_args, port_args, tp_rank)

    # Prepare inputs
    input_ids, reqs = prepare_inputs_for_correctness_test(bench_args, tokenizer)
    rank_print(f"\n{input_ids=}\n")

    if bench_args.cut_len > 0:
        # Prefill
        next_token_ids, next_token_logits, batch = extend(reqs, model_runner)
        rank_print(f"prefill logits (first half): {next_token_logits} \n")

    # Prepare extend inputs
    reqs = prepare_extend_inputs_for_correctness_test(
        bench_args, input_ids, reqs, model_runner
    )

    # Extend
    next_token_ids, next_token_logits, batch = extend(reqs, model_runner)
    rank_print(f"prefill logits (final): {next_token_logits} \n")

    # Decode
    output_ids = [input_ids[i] + [next_token_ids[i]] for i in range(len(input_ids))]
    for _ in range(bench_args.output_len[0] - 1):
        next_token_ids, _ = decode(next_token_ids, batch, model_runner)
        for i in range(len(reqs)):
            output_ids[i].append(next_token_ids[i])

    # Print
    for i in range(len(reqs)):
        rank_print(f"========== Prompt {i} ==========")
        rank_print(tokenizer.decode(output_ids[i]), "\n")


def latency_test_run_once(
    run_name, model_runner, rank_print, reqs, batch_size, input_len, output_len
):
    max_batch_size = model_runner.max_total_num_tokens // (input_len + output_len)
    if batch_size > max_batch_size:
        rank_print(
            f"skipping ({batch_size}, {input_len}, {output_len}) due to max batch size limit"
        )
        return

    # Clear the pools.
    model_runner.req_to_token_pool.clear()
    model_runner.token_to_kv_pool.clear()

    measurement_results = {
        "run_name": run_name,
        "batch_size": batch_size,
        "input_len": input_len,
        "output_len": output_len,
    }

    tot_latency = 0

    # Prefill
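    # Synchronize before and after the forward pass so the timer only measures completed GPU work.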
    torch.cuda.synchronize()
    tic = time.time()
    next_token_ids, _, batch = extend(reqs, model_runner)
    torch.cuda.synchronize()
    prefill_latency = time.time() - tic
    tot_latency += prefill_latency
    throughput = input_len * batch_size / prefill_latency
    rank_print(
        f"Prefill. latency: {prefill_latency:6.5f} s, throughput: {throughput:9.2f} token/s"
    )
    measurement_results["prefill_latency"] = prefill_latency
    measurement_results["prefill_throughput"] = throughput

    # Decode
    decode_latencies = []
    for i in range(output_len - 1):
        torch.cuda.synchronize()
        tic = time.time()
        next_token_ids, _ = decode(next_token_ids, batch, model_runner)
        torch.cuda.synchronize()
        latency = time.time() - tic
        tot_latency += latency
        throughput = batch_size / latency
        decode_latencies.append(latency)
        if i < 5:
            rank_print(
                f"Decode.  latency: {latency:6.5f} s, throughput: {throughput:9.2f} token/s"
            )

    # Record decode timing from the 2nd output token onward (the 1st token comes from prefill)
    if output_len > 1:
        med_decode_latency = np.median(decode_latencies)
        med_decode_throughput = batch_size / med_decode_latency
        rank_print(
            f"Decode.  median latency: {med_decode_latency:6.5f} s, median throughput: {med_decode_throughput:9.2f} token/s"
        )
        measurement_results["median_decode_latency"] = med_decode_latency
        measurement_results["median_decode_throughput"] = med_decode_throughput

    throughput = (input_len + output_len) * batch_size / tot_latency
    rank_print(
        f"Total. latency: {tot_latency:6.3f} s, throughput: {throughput:9.2f} token/s"
    )
    measurement_results["total_latency"] = tot_latency
    measurement_results["total_throughput"] = throughput
    return measurement_results


def latency_test(
    server_args,
    port_args,
    bench_args,
    tp_rank,
):
    configure_logger(server_args, prefix=f" TP{tp_rank}")
    rank_print = print if tp_rank == 0 else lambda *args, **kwargs: None

    # Load the model
    model_runner, tokenizer = load_model(server_args, port_args, tp_rank)

    # Prepare inputs for warm up
    reqs = prepare_synthetic_inputs_for_latency_test(
        bench_args.batch_size[0], bench_args.input_len[0]
    )

    # Warm up
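    # (a short run so one-time initialization cost does not skew the measured numbers)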
    rank_print("Warmup ...")
    latency_test_run_once(
        bench_args.run_name,
        model_runner,
        rank_print,
        reqs,
        bench_args.batch_size[0],
        bench_args.input_len[0],
        8,  # shorter decoding to speed up the warmup
    )
    rank_print("Benchmark ...")

    # Run the sweep
    result_list = []
    for bs, il, ol in itertools.product(
        bench_args.batch_size, bench_args.input_len, bench_args.output_len
    ):
        reqs = prepare_synthetic_inputs_for_latency_test(bs, il)
        ret = latency_test_run_once(
            bench_args.run_name, model_runner, rank_print, reqs, bs, il, ol
        )
        if ret is not None:
            result_list.append(ret)

    # Write results in jsonlines format on rank 0.
    if tp_rank == 0 and bench_args.result_filename:
        import jsonlines

        with jsonlines.open(bench_args.result_filename, "a") as f:
            f.write_all(result_list)


def plot_latency_test(
    server_args,
    bench_args,
    tp_rank,
):
    assert tp_rank == 0

    # read the jsonl file and put in sqlite
    df = pd.read_json(bench_args.result_filename, lines=True)
    conn = sqlite3.connect(":memory:")
    cur = conn.cursor()

    # get the columns and their types
    column_names = list(df.iloc[0].keys())
    type_dict = {
        str: "TEXT",
        np.int64: "INTEGER",
        np.float64: "FLOAT",
    }
    column_types = [type_dict[type(i)] for i in list(df.iloc[0])]

    # create the table
    cur.execute(
        f"""
        CREATE TABLE IF NOT EXISTS results (
            {", ".join([f"{name} {type}" for name, type in zip(column_names, column_types)])}
        )
    """
    )
    conn.commit()

    # write the results to DB
    df.to_sql("results", conn, if_exists="replace", index=False)
    conn.commit()

    # read it back using sql
    df = pd.read_sql_query(bench_args.graph_sql, conn)
    conn.close()

    # plot it and save to a file
    import matplotlib.pyplot as plt

    assert (
        len(df.columns) == 3
    ), f"The sql should have fetched <series, x, y> columns, not {df.columns}"
    for label in df[df.columns[0]].unique():
        q = f"{df.columns[0]}=='{label}'"
        series = df.query(q)
        plt.plot(series[df.columns[1]], series[df.columns[2]], label=q, marker="o")
    plt.xlabel(df.columns[1])
    plt.ylabel(df.columns[2])
    plt.legend()
    plt.savefig(bench_args.graph_filename, dpi=300)

    # if in kitty, just dump it to the terminal
    if os.environ.get("TERM") == "xterm-kitty":
        os.system(
            f"kitty icat --use-window-size 1,1,600,600 {bench_args.graph_filename}"
        )


def main(server_args, bench_args):
    _set_envs_and_config(server_args)

    if server_args.model_path:
        if bench_args.correctness_test:
            work_func = correctness_test
        else:
            work_func = latency_test
    elif os.path.isfile(bench_args.result_filename):
        assert bench_args.graph_filename, "please provide a filename for the graph"
        work_func = plot_latency_test
    else:
        raise ValueError(
            "Provide --model-path for running the tests or "
            "provide --result-filename for plotting the results"
        )

    port_args = PortArgs.init_new(server_args)

    if server_args.tp_size == 1:
        work_func(server_args, port_args, bench_args, 0)
    else:
        workers = []
        for tp_rank in range(server_args.tp_size):
            proc = multiprocessing.Process(
                target=work_func,
                args=(
                    server_args,
                    port_args,
                    bench_args,
                    tp_rank,
                ),
            )
            proc.start()
            workers.append(proc)

        for proc in workers:
            proc.join()

        proc.terminate()


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    ServerArgs.add_cli_args(parser)
    BenchArgs.add_cli_args(parser)
    args = parser.parse_args()
    server_args = ServerArgs.from_cli_args(args)
    bench_args = BenchArgs.from_cli_args(args)

    logging.basicConfig(
        level=getattr(logging, server_args.log_level.upper()),
        format="%(message)s",
    )

    try:
        main(server_args, bench_args)
    except Exception as e:
        raise e
    finally:
        kill_child_process(os.getpid(), including_parent=False)