"tests/test_data/vscode:/vscode.git/clone" did not exist on "bdeacecd9484b1e7d3715440f73c83e1e69bf353"
bench_latency.py 16.7 KB
Newer Older
1
"""
Benchmark the latency of running a single static batch.
This script does not launch a server and uses the low-level APIs.
It accepts arguments similar to those of launch_server.py.

# Usage (latency test):
## with dummy weights:
python -m sglang.bench_latency --model-path meta-llama/Meta-Llama-3-8B-Instruct --load-format dummy
## sweep through multiple data points and store (append) the results in a jsonl file:
python -m sglang.bench_latency --model-path meta-llama/Meta-Llama-3-8B-Instruct --batch 1 12 14 --input-len 256 512 --output-len 32 256 --result-filename out.jsonl
## make some changes, and store the results under a different run_name:
python -m sglang.bench_latency --model-path meta-llama/Meta-Llama-3-8B-Instruct --batch 1 12 14 --input-len 256 512 --output-len 32 256 --result-filename out.jsonl --run-name after
## plot the results as a series of lines:
python -m sglang.bench_latency --result-filename out.jsonl --graph-sql="select run_name, batch_size, prefill_throughput from results"

# Usage (correctness test):
python -m sglang.bench_latency --model-path TinyLlama/TinyLlama-1.1B-Chat-v0.4 --correct

## Reference output (of the correctness test above; may be GPU dependent):
input_ids=[[1, 450, 7483, 310, 3444, 338], [1, 450, 7483, 310, 278, 3303, 13187, 290, 338], [1, 20628, 338, 263, 6575, 1460, 2462, 322, 306, 763]]

prefill logits (first half): tensor([[-10.0312,  -9.5000,   0.8931,  ...,  -4.9414,  -3.2422,  -3.3633],
        [-10.0312,  -9.5000,   0.8931,  ...,  -4.9414,  -3.2422,  -3.3633],
        [ -9.1875, -10.2500,   2.7129,  ...,  -4.3359,  -4.0664,  -4.1328]],
       device='cuda:0')

prefill logits (final): tensor([[-8.3125, -7.1172,  3.3457,  ..., -4.9570, -4.1328, -3.4141],
        [-8.9141, -9.0156,  4.1445,  ..., -4.9922, -4.4961, -4.0781],
        [-9.6328, -9.0547,  4.0195,  ..., -5.3047, -4.7148, -4.4570]],
       device='cuda:0')

========== Prompt 0 ==========
<s> The capital of France is Paris.
The capital of the United States is Washington, D.C.


========== Prompt 1 ==========
<s> The capital of the United Kindom is London.
The capital of the United Kingdom is London.
The capital of the

========== Prompt 2 ==========
<s> Today is a sunny day and I like to go for a walk in the park.
I'm going to the park
"""

import argparse
import dataclasses
import itertools
import logging
import multiprocessing
import os
import sqlite3
import time
from typing import Tuple

import numpy as np
import pandas as pd
import torch
import torch.distributed as dist

from sglang.srt.configs.model_config import ModelConfig
from sglang.srt.hf_transformers_utils import get_tokenizer
from sglang.srt.managers.schedule_batch import Req, ScheduleBatch
from sglang.srt.model_executor.model_runner import ModelRunner
from sglang.srt.sampling.sampling_params import SamplingParams
from sglang.srt.server import _set_envs_and_config
from sglang.srt.server_args import ServerArgs
from sglang.srt.utils import (
    configure_logger,
    kill_child_process,
    suppress_other_loggers,
)


@dataclasses.dataclass
class BenchArgs:
    run_name: str = "before"
    batch_size: Tuple[int] = (1,)
    input_len: Tuple[int] = (1024,)
    output_len: Tuple[int] = (16,)
    result_filename: str = ""
    correctness_test: bool = False
    # This is only used for correctness test
    cut_len: int = 4
    # Plotting args
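    # graph_sql must select exactly three columns: <series label, x, y>.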
    graph_sql: str = (
        "select run_name, batch_size, prefill_throughput from results where run_name='before'"
    )
    graph_filename: str = "out.png"

    @staticmethod
    def add_cli_args(parser: argparse.ArgumentParser):
        parser.add_argument("--run-name", type=str, default=BenchArgs.run_name)
        parser.add_argument(
            "--batch-size", type=int, nargs="+", default=BenchArgs.batch_size
        )
        parser.add_argument(
            "--input-len", type=int, nargs="+", default=BenchArgs.input_len
        )
        parser.add_argument(
            "--output-len", type=int, nargs="+", default=BenchArgs.output_len
        )
        parser.add_argument(
            "--result-filename", type=str, default=BenchArgs.result_filename
        )
        parser.add_argument("--correctness-test", action="store_true")
        parser.add_argument("--cut-len", type=int, default=BenchArgs.cut_len)
        # graphing
        parser.add_argument("--graph-sql", type=str, default=BenchArgs.graph_sql)
        parser.add_argument(
            "--graph-filename", type=str, default=BenchArgs.graph_filename
        )

    @classmethod
    def from_cli_args(cls, args: argparse.Namespace):
        # Use each field's default value type to cast the parsed args to the correct types.
        attrs = [(attr.name, type(attr.default)) for attr in dataclasses.fields(cls)]
        return cls(
            **{attr: attr_type(getattr(args, attr)) for attr, attr_type in attrs}
        )


def load_model(server_args, tp_rank):
    suppress_other_loggers()
    rank_print = print if tp_rank == 0 else lambda *args, **kwargs: None

    model_config = ModelConfig(
        server_args.model_path,
        server_args.trust_remote_code,
        context_length=server_args.context_length,
    )
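    # One GPU per tensor-parallel rank (gpu_id == tp_rank); nccl_port is hard-coded
    # since this standalone benchmark does not coordinate ports with a running server.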
    model_runner = ModelRunner(
        model_config=model_config,
        mem_fraction_static=server_args.mem_fraction_static,
        gpu_id=tp_rank,
        tp_rank=tp_rank,
        tp_size=server_args.tp_size,
        nccl_port=28888,
        server_args=server_args,
    )
    rank_print(f"max_total_num_tokens={model_runner.max_total_num_tokens}")
    tokenizer = get_tokenizer(
        server_args.tokenizer_path,
        tokenizer_mode=server_args.tokenizer_mode,
        trust_remote_code=server_args.trust_remote_code,
    )
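    # Make sure every tensor-parallel rank has finished loading before returning.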
    if server_args.tp_size > 1:
        dist.barrier()
    return model_runner, tokenizer


def prepare_inputs_for_correctness_test(bench_args, tokenizer):
    prompts = [
        "The capital of France is",
        "The capital of the United Kindom is",
        "Today is a sunny day and I like",
    ]
    input_ids = [tokenizer.encode(p) for p in prompts]
    sampling_params = SamplingParams(
        temperature=0,
        max_new_tokens=BenchArgs.output_len,
    )

    reqs = []
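    # Feed only the first cut_len tokens of each prompt for now; the remaining tokens
    # are appended later by prepare_extend_inputs_for_correctness_test so that the
    # second extend pass runs with a non-empty cached prefix.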
    for i in range(len(prompts)):
        assert len(input_ids[i]) > bench_args.cut_len

        tmp_input_ids = input_ids[i][: bench_args.cut_len]
        req = Req(rid=i, origin_input_text=prompts[i], origin_input_ids=tmp_input_ids)
        req.prefix_indices = []
        req.sampling_params = sampling_params
        req.fill_ids = req.origin_input_ids
        req.extend_input_len = len(req.fill_ids) - len(req.prefix_indices)
        reqs.append(req)

    return input_ids, reqs


def prepare_extend_inputs_for_correctness_test(
    bench_args, input_ids, reqs, model_runner
):
    for i in range(len(reqs)):
        req = reqs[i]
        req.fill_ids += input_ids[i][bench_args.cut_len :]
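        # The first cut_len tokens were already prefilled; point prefix_indices at
        # their slots in the req_to_token pool so only the new tokens are extended.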
        req.prefix_indices = model_runner.req_to_token_pool.req_to_token[
            i, : bench_args.cut_len
        ]
        req.extend_input_len = len(req.fill_ids) - len(req.prefix_indices)
    return reqs


def prepare_synthetic_inputs_for_latency_test(batch_size, input_len):
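    # The token content does not matter for a latency test, so fill every prompt
    # with a dummy token id.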
    input_ids = np.ones((batch_size, input_len), dtype=np.int32)
    sampling_params = SamplingParams(
        temperature=0,
        max_new_tokens=BenchArgs.output_len,
    )

    reqs = []
    for i in range(len(input_ids)):
        req = Req(rid=i, origin_input_text="", origin_input_ids=list(input_ids[i]))
        req.prefix_indices = []
        req.sampling_params = sampling_params
        req.fill_ids = req.origin_input_ids
        req.extend_input_len = len(req.fill_ids) - len(req.prefix_indices)
        reqs.append(req)

    return reqs


def extend(reqs, model_runner):
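    # Build a fresh ScheduleBatch, run one prefill/extend forward pass, and sample
    # the next token for every request.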
    batch = ScheduleBatch.init_new(
        reqs=reqs,
        req_to_token_pool=model_runner.req_to_token_pool,
        token_to_kv_pool=model_runner.token_to_kv_pool,
        tree_cache=None,
    )
    batch.prepare_for_extend(model_runner.model_config.vocab_size)
    logits_output = model_runner.forward(batch)
    next_token_ids = model_runner.sample(logits_output, batch).tolist()
    return next_token_ids, logits_output.next_token_logits, batch


def decode(input_token_ids, batch, model_runner):
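    # Append the previously sampled tokens to the batch and run a single decode step.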
    batch.prepare_for_decode(input_token_ids)
    logits_output = model_runner.forward(batch)
    next_token_ids = model_runner.sample(logits_output, batch).tolist()
    return next_token_ids, logits_output.next_token_logits


@torch.inference_mode()
def correctness_test(
    server_args,
    bench_args,
    tp_rank,
):
    rank_print = print if tp_rank == 0 else lambda *args, **kwargs: None

    # Load the model
    model_runner, tokenizer = load_model(server_args, tp_rank)

    # Prepare inputs
    input_ids, reqs = prepare_inputs_for_correctness_test(bench_args, tokenizer)
    rank_print(f"\n{input_ids=}\n")

    if bench_args.cut_len > 0:
        # Prefill
        next_token_ids, next_token_logits, batch = extend(reqs, model_runner)
        rank_print(f"prefill logits (first half): {next_token_logits} \n")

    # Prepare extend inputs
    reqs = prepare_extend_inputs_for_correctness_test(
        bench_args, input_ids, reqs, model_runner
    )

    # Extend
    next_token_ids, next_token_logits, batch = extend(reqs, model_runner)
    rank_print(f"prefill logits (final): {next_token_logits} \n")

    # Decode
    output_ids = [input_ids[i] + [next_token_ids[i]] for i in range(len(input_ids))]
    for _ in range(bench_args.output_len[0] - 1):
        next_token_ids, _ = decode(next_token_ids, batch, model_runner)
        for i in range(len(reqs)):
            output_ids[i].append(next_token_ids[i])

    # Print
    for i in range(len(reqs)):
        rank_print(f"========== Prompt {i} ==========")
        rank_print(tokenizer.decode(output_ids[i]), "\n")


@torch.inference_mode()
def latency_test_run_once(
    run_name, model_runner, rank_print, reqs, batch_size, input_len, output_len
):
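    # Skip configurations whose prompt and generated tokens cannot fit in the KV cache.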
    max_batch_size = model_runner.max_total_num_tokens // (input_len + output_len)
    if batch_size > max_batch_size:
        rank_print(
            f"skipping ({batch_size}, {input_len}, {output_len}) due to max batch size limit"
        )
        return

    # Clear the pools.
    model_runner.req_to_token_pool.clear()
    model_runner.token_to_kv_pool.clear()

    measurement_results = {
        "run_name": run_name,
        "batch_size": batch_size,
        "input_len": input_len,
        "output_len": output_len,
    }

    tot_latency = 0

    # Prefill
    torch.cuda.synchronize()
    tic = time.time()
    next_token_ids, _, batch = extend(reqs, model_runner)
    torch.cuda.synchronize()
    prefill_latency = time.time() - tic
    tot_latency += prefill_latency
    throughput = input_len * batch_size / prefill_latency
    rank_print(
        f"Prefill. latency: {prefill_latency:6.5f} s, throughput: {throughput:9.2f} token/s"
    )
    measurement_results["prefill_latency"] = prefill_latency
    measurement_results["prefill_throughput"] = throughput

    # Decode
    decode_latencies = []
    for i in range(output_len - 1):
        torch.cuda.synchronize()
        tic = time.time()
        next_token_ids, _ = decode(next_token_ids, batch, model_runner)
        torch.cuda.synchronize()
        latency = time.time() - tic
        tot_latency += latency
        throughput = batch_size / latency
        decode_latencies.append(latency)
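        # Print only the first few decode steps; the median over all steps is reported below.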
        if i < 5:
            rank_print(
                f"Decode.  latency: {latency:6.5f} s, throughput: {throughput:9.2f} token/s"
            )
    med_decode_latency = np.median(decode_latencies)
    med_decode_throughput = batch_size / med_decode_latency
    rank_print(
        f"Decode.  median latency: {med_decode_latency:6.5f} s, median throughput: {med_decode_throughput:9.2f} token/s"
    )
    measurement_results["median_decode_latency"] = med_decode_latency
    measurement_results["median_decode_throughput"] = med_decode_throughput

    throughput = (input_len + output_len) * batch_size / tot_latency
    rank_print(
        f"Total. latency: {tot_latency:6.3f} s, throughput: {throughput:9.2f} token/s"
    )
    measurement_results["total_latency"] = tot_latency
    measurement_results["total_throughput"] = throughput
    return measurement_results


def latency_test(
    server_args,
    bench_args,
    tp_rank,
):
    configure_logger(server_args, prefix=f" TP{tp_rank}")
    _set_envs_and_config(server_args)
    rank_print = print if tp_rank == 0 else lambda *args, **kwargs: None

    # Load the model
    model_runner, tokenizer = load_model(server_args, tp_rank)

    # Prepare inputs for warm up
    reqs = prepare_synthetic_inputs_for_latency_test(
        bench_args.batch_size[0], bench_args.input_len[0]
    )

    # Warm up
    rank_print("Warmup ...")
    latency_test_run_once(
        bench_args.run_name,
        model_runner,
        rank_print,
        reqs,
        bench_args.batch_size[0],
        bench_args.input_len[0],
        4,  # shorter decoding to speed up the warmup
    )
    rank_print("Benchmark ...")

    # Run the sweep
    result_list = []
    for bs, il, ol in itertools.product(
        bench_args.batch_size, bench_args.input_len, bench_args.output_len
    ):
        reqs = prepare_synthetic_inputs_for_latency_test(bs, il)
        ret = latency_test_run_once(
            bench_args.run_name, model_runner, rank_print, reqs, bs, il, ol
        )
        if ret is not None:
            result_list.append(ret)

    # Write results in jsonlines format on rank 0.
    if tp_rank == 0 and bench_args.result_filename:
        import jsonlines

        with jsonlines.open(bench_args.result_filename, "a") as f:
            f.write_all(result_list)


def plot_latency_test(
    server_args,
    bench_args,
    tp_rank,
):
    assert tp_rank == 0

    # read the jsonl file and put in sqlite
    df = pd.read_json(bench_args.result_filename, lines=True)
    conn = sqlite3.connect(":memory:")
    cur = conn.cursor()

    # get the columns and their types
    column_names = list(df.iloc[0].keys())
    type_dict = {
        str: "TEXT",
        np.int64: "INTEGER",
        np.float64: "FLOAT",
    }
    column_types = [type_dict[type(i)] for i in list(df.iloc[0])]

    # create the table
    cur.execute(
        f"""
        CREATE TABLE IF NOT EXISTS results (
            {", ".join([f"{name} {type}" for name, type in zip(column_names, column_types)])}
        )
    """
    )
    conn.commit()

    # write the results to DB
    df.to_sql("results", conn, if_exists="replace", index=False)
    conn.commit()

    # read it back using sql
    df = pd.read_sql_query(bench_args.graph_sql, conn)
    conn.close()

    # plot it and save to a file
    import matplotlib.pyplot as plt

    assert (
        len(df.columns) == 3
    ), f"The sql should have fetched <series, x, y> columns, not {df.columns}"
    for label in df[df.columns[0]].unique():
        q = f"{df.columns[0]}=='{label}'"
        series = df.query(q)
        plt.plot(series[df.columns[1]], series[df.columns[2]], label=q, marker="o")
    plt.xlabel(df.columns[1])
    plt.ylabel(df.columns[2])
    plt.legend()
    plt.savefig(bench_args.graph_filename, dpi=300)

    # if in kitty, just dump it to the terminal
    if os.environ.get("TERM") == "xterm-kitty":
        os.system(
            f"kitty icat --use-window-size 1,1,600,600 {bench_args.graph_filename}"
        )


def main(server_args, bench_args):
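    # Run a correctness or latency benchmark when a model path is given;
    # otherwise plot results from an existing result file.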
    if server_args.model_path:
        if bench_args.correctness_test:
            work_func = correctness_test
        else:
            work_func = latency_test
    elif os.path.isfile(bench_args.result_filename):
        assert bench_args.graph_filename, "please provide a filename for the graph"
        work_func = plot_latency_test
    else:
        raise ValueError(
            "Provide --model-path for running the tests or "
            "provide --result-filename for plotting the results"
        )

    if server_args.tp_size == 1:
        work_func(server_args, bench_args, 0)
    else:
        workers = []
        for tp_rank in range(server_args.tp_size):
            proc = multiprocessing.Process(
                target=work_func,
                args=(
                    server_args,
                    bench_args,
                    tp_rank,
                ),
            )
            proc.start()
            workers.append(proc)

        for proc in workers:
            proc.join()

        proc.terminate()


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    ServerArgs.add_cli_args(parser)
    BenchArgs.add_cli_args(parser)
    args = parser.parse_args()
    server_args = ServerArgs.from_cli_args(args)
    bench_args = BenchArgs.from_cli_args(args)

    logging.basicConfig(
        level=getattr(logging, server_args.log_level.upper()),
        format="%(message)s",
    )

    multiprocessing.set_start_method("spawn", force=True)

    try:
        main(server_args, bench_args)
    except Exception as e:
        raise e
    finally:
        kill_child_process(os.getpid(), including_parent=False)