"""
Benchmark the latency of a given model. It accepts arguments similar to those of launch_server.py.

# Usage (latency test) with dummy weights:
python -m sglang.bench_latency --model-path meta-llama/Meta-Llama-3-8B-Instruct --load-format dummy

# Usage (correctness test):
python -m sglang.bench_latency --model-path TinyLlama/TinyLlama-1.1B-Chat-v0.4 --correctness-test

### Reference output (of the correctness test above; may be GPU-dependent):
prefill logits (first half) tensor([[-10.0312,  -9.5000,   0.8936,  ...,  -4.9414,  -3.2402,  -3.3633],
        [-10.0312,  -9.5000,   0.8936,  ...,  -4.9414,  -3.2402,  -3.3633],
        [ -9.1875, -10.2500,   2.7109,  ...,  -4.3359,  -4.0664,  -4.1328]],
       device='cuda:0', dtype=torch.float16)
prefill logits (final) tensor([[-8.3203, -7.1211,  3.3379,  ..., -4.9570, -4.1328, -3.4141],
        [-8.9062, -9.0156,  4.1445,  ..., -4.9922, -4.4961, -4.0742],
        [-9.6328, -9.0547,  4.0117,  ..., -5.3047, -4.7148, -4.4609]],
       device='cuda:0', dtype=torch.float16)
<s> The capital of France is.
The capital of the United States is Washington, D.C.

<s> The capital of the United Kindom is.
The capital of the United Kingdom is London.
The capital of the
<s> Today is a sunny day and I like go for a walk in the park.
I'm going to the park
"""

import argparse
import dataclasses
import logging
import multiprocessing
import time
34
from typing import Tuple
35

36
import jsonlines
37
38
39
40
41
import numpy as np
import torch
import torch.distributed as dist

from sglang.srt.hf_transformers_utils import get_tokenizer
42
from sglang.srt.managers.schedule_batch import Batch, ForwardMode, Req
43
from sglang.srt.model_config import ModelConfig
44
from sglang.srt.model_executor.model_runner import ModelRunner
45
46
47
48
49
50
51
from sglang.srt.sampling_params import SamplingParams
from sglang.srt.server_args import ServerArgs
from sglang.srt.utils import suppress_other_loggers


@dataclasses.dataclass
class BenchArgs:
    """Benchmark-specific command-line options (server options live in ServerArgs).

    Each field corresponds to a flag registered by :meth:`add_cli_args`.
    """

    # One or more batch sizes (``--batch-size`` accepts several values).
    batch_size: Tuple[int, ...] = (1,)
    # Number of prompt tokens per synthetic request in the latency test.
    input_len: int = 1024
    # Number of decode steps to run after the prefill.
    output_len: int = 4
    # When non-empty, latency results are appended to this file as JSON Lines.
    result_filename: str = ""
    # Run the correctness test instead of the latency test.
    correctness_test: bool = False
    # Number of prompt tokens prefilled in the first stage of the correctness
    # test. This is only used for the correctness test.
    cut_len: int = 4

    @staticmethod
    def add_cli_args(parser: argparse.ArgumentParser):
        """Register the benchmark flags on ``parser``."""
        parser.add_argument(
            "--batch-size", type=int, nargs="+", default=BenchArgs.batch_size
        )
        parser.add_argument("--input-len", type=int, default=BenchArgs.input_len)
        parser.add_argument("--output-len", type=int, default=BenchArgs.output_len)
        parser.add_argument(
            "--result-filename", type=str, default=BenchArgs.result_filename
        )
        parser.add_argument("--correctness-test", action="store_true")
        parser.add_argument("--cut-len", type=int, default=BenchArgs.cut_len)

    @classmethod
    def from_cli_args(cls, args: argparse.Namespace):
        """Build an instance from a namespace produced by ``add_cli_args``.

        Every value is cast with the type of the corresponding field's default,
        so e.g. the list produced by ``nargs="+"`` becomes a tuple.
        """
        attrs = [(attr.name, type(attr.default)) for attr in dataclasses.fields(cls)]
        return cls(
            **{attr: attr_type(getattr(args, attr)) for attr, attr_type in attrs}
        )


def load_model(server_args, tp_rank, *, nccl_port=28888):
    """Load the model runner and tokenizer for one tensor-parallel rank.

    Args:
        server_args: parsed ServerArgs. Reads ``model_path``,
            ``mem_fraction_static``, ``tp_size``, ``tokenizer_path``,
            ``tokenizer_mode`` and ``trust_remote_code``.
        tp_rank: tensor-parallel rank of this process; also used as the GPU id.
        nccl_port: port handed to ``ModelRunner`` (defaults to the previously
            hard-coded 28888).

    Returns:
        ``(model_runner, tokenizer)``.
    """
    suppress_other_loggers()
    # Only rank 0 prints, so multi-process runs do not duplicate output.
    rank_print = print if tp_rank == 0 else lambda *args, **kwargs: None

    model_config = ModelConfig(path=server_args.model_path)
    model_runner = ModelRunner(
        model_config=model_config,
        mem_fraction_static=server_args.mem_fraction_static,
        gpu_id=tp_rank,
        tp_rank=tp_rank,
        tp_size=server_args.tp_size,
        # NOTE(review): presumably used for the NCCL rendezvous between TP
        # ranks; concurrent multi-GPU runs on one host may need distinct ports.
        nccl_port=nccl_port,
        server_args=server_args,
    )
    rank_print(f"max_total_num_tokens={model_runner.max_total_num_tokens}")
    tokenizer = get_tokenizer(
        server_args.tokenizer_path,
        tokenizer_mode=server_args.tokenizer_mode,
        trust_remote_code=server_args.trust_remote_code,
    )
    # Synchronize all TP ranks before any of them starts running the test.
    if server_args.tp_size > 1:
        dist.barrier()
    return model_runner, tokenizer


def prepare_inputs_for_correctness_test(bench_args, tokenizer):
    """Tokenize the fixed correctness prompts and build first-stage requests.

    Each request holds only the first ``bench_args.cut_len`` tokens of its
    prompt; the remainder is appended later by
    ``prepare_extend_inputs_for_correctness_test``.

    Args:
        bench_args: BenchArgs; reads ``cut_len`` and ``output_len``.
        tokenizer: tokenizer with an ``encode`` method.

    Returns:
        ``(input_ids, reqs)``: the full token ids of every prompt, and one
        ``Req`` per prompt.

    Raises:
        ValueError: if a prompt is not longer than ``bench_args.cut_len`` tokens.
    """
    # NOTE: "Kindom" is misspelled, but the reference output in the module
    # docstring was produced with this exact prompt; keep them in sync.
    prompts = [
        "The capital of France is",
        "The capital of the United Kindom is",
        "Today is a sunny day and I like",
    ]
    input_ids = [tokenizer.encode(p) for p in prompts]
    sampling_params = SamplingParams(
        temperature=0,
        # Honour --output-len rather than the class-level default.
        max_new_tokens=bench_args.output_len,
    )

    reqs = []
    for rid, (prompt, ids) in enumerate(zip(prompts, input_ids)):
        if len(ids) <= bench_args.cut_len:
            raise ValueError(
                f"cut_len={bench_args.cut_len} must be smaller than the length "
                f"of prompt {rid} ({len(ids)} tokens)"
            )
        req = Req(
            rid=rid,
            origin_input_text=prompt,
            origin_input_ids=ids[: bench_args.cut_len],
        )
        req.prefix_indices = []
        req.sampling_params = sampling_params
        # Same list object as origin_input_ids: an in-place extension of
        # input_ids later also extends origin_input_ids.
        req.input_ids = req.origin_input_ids
        reqs.append(req)

    return input_ids, reqs


def prepare_extend_inputs_for_correctness_test(
    bench_args, input_ids, reqs, model_runner
):
    """Turn first-stage requests into second-stage (extend) requests, in place.

    For request ``i``, the prompt tokens past ``bench_args.cut_len`` are
    appended to ``req.input_ids``, and ``req.prefix_indices`` is set to the
    first ``cut_len`` entries of row ``i`` of
    ``model_runner.req_to_token_pool.req_to_token``.

    Args:
        bench_args: BenchArgs; reads ``cut_len``.
        input_ids: full token ids of every prompt, indexed like ``reqs``.
        reqs: requests built by ``prepare_inputs_for_correctness_test``.
        model_runner: runner whose ``req_to_token_pool.req_to_token`` is read.

    Returns:
        ``reqs`` (the same list, mutated in place).
    """
    cut_len = bench_args.cut_len
    for i, req in enumerate(reqs):
        # In-place extension: when ``reqs`` came from
        # prepare_inputs_for_correctness_test, req.input_ids aliases
        # req.origin_input_ids, so both grow.
        req.input_ids += input_ids[i][cut_len:]
        # Assumes request i occupied pool row i during the first-stage
        # prefill (requests allocated in order on fresh pools) — TODO confirm.
        req.prefix_indices = model_runner.req_to_token_pool.req_to_token[
            i, :cut_len
        ]
    return reqs


def prepare_synthetic_inputs_for_latency_test(batch_size, input_len, output_len=None):
    """Build ``batch_size`` dummy greedy requests of ``input_len`` tokens each.

    Every prompt token is id 1; only the shape matters for a latency run.

    Args:
        batch_size: number of requests.
        input_len: number of prompt tokens per request.
        output_len: ``max_new_tokens`` for the sampling params; defaults to
            ``BenchArgs.output_len`` when not given.

    Returns:
        A list of ``Req`` objects.
    """
    if output_len is None:
        output_len = BenchArgs.output_len
    input_ids = np.ones((batch_size, input_len), dtype=np.int32)
    sampling_params = SamplingParams(
        temperature=0,
        max_new_tokens=output_len,
    )

    reqs = []
    for rid, row in enumerate(input_ids):
        req = Req(rid=rid, origin_input_text="", origin_input_ids=list(row))
        req.prefix_indices = []
        req.sampling_params = sampling_params
        req.input_ids = req.origin_input_ids
        reqs.append(req)

    return reqs


def extend(reqs, model_runner):
    """Run one extend (prefill) forward pass over ``reqs`` and sample a token each.

    Returns:
        ``(next_token_ids, next_token_logits, batch)``. The batch is returned so
        that ``decode`` can continue generating from it.
    """
    batch = Batch.init_new(
        reqs=reqs,
        req_to_token_pool=model_runner.req_to_token_pool,
        token_to_kv_pool=model_runner.token_to_kv_pool,
        # The benchmark does not use a prefix tree cache.
        tree_cache=None,
    )
    batch.prepare_for_extend(model_runner.model_config.vocab_size, None)
    output = model_runner.forward(batch, ForwardMode.EXTEND)
    next_token_ids = batch.sample(output.next_token_logits)
    return next_token_ids, output.next_token_logits, batch


def decode(input_token_ids, batch, model_runner):
    """Run one decode step of ``batch`` and sample the next token per request.

    ``input_token_ids`` (the previously sampled tokens) is copied to host
    memory with ``.cpu().numpy()`` before ``batch.prepare_for_decode``.

    Returns:
        ``(next_token_ids, next_token_logits)``.
    """
    batch.prepare_for_decode(input_token_ids.cpu().numpy())
    output = model_runner.forward(batch, ForwardMode.DECODE)
    next_token_ids = batch.sample(output.next_token_logits)
    return next_token_ids, output.next_token_logits


@torch.inference_mode()
def correctness_test(
    server_args,
    bench_args,
    tp_rank,
):
    """Run the fixed prompts through a two-stage prefill plus greedy decode.

    When ``bench_args.cut_len > 0``, the first ``cut_len`` tokens of each
    prompt are prefilled first. The remainder is then prefilled as an extend
    on top of that prefix, followed by ``bench_args.output_len`` decode steps.
    Rank 0 prints the logits and decoded text, to be compared with the
    reference output in the module docstring.
    """
    rank_print = print if tp_rank == 0 else lambda *args, **kwargs: None

    # Load the model
    model_runner, tokenizer = load_model(server_args, tp_rank)

    # Prepare inputs
    input_ids, reqs = prepare_inputs_for_correctness_test(bench_args, tokenizer)

    if bench_args.cut_len > 0:
        # Prefill the first cut_len tokens of every prompt.
        next_token_ids, next_token_logits, batch = extend(reqs, model_runner)
        rank_print("prefill logits (first half)", next_token_logits)

    # Prepare extend inputs
    reqs = prepare_extend_inputs_for_correctness_test(
        bench_args, input_ids, reqs, model_runner
    )

    # Extend: prefill the remaining prompt tokens, reusing the first-stage
    # prefix via req.prefix_indices.
    next_token_ids, next_token_logits, batch = extend(reqs, model_runner)
    rank_print("prefill logits (final)", next_token_logits)

    # Decode. Sampled ids are converted to Python ints once per step, so the
    # tokenizer later receives plain ints rather than per-element tensors.
    output_ids = [
        ids + [token] for ids, token in zip(input_ids, next_token_ids.tolist())
    ]
    for _step in range(bench_args.output_len):
        next_token_ids, _ = decode(next_token_ids, batch, model_runner)
        for ids, token in zip(output_ids, next_token_ids.tolist()):
            ids.append(token)

    # Print
    for ids in output_ids:
        rank_print(tokenizer.decode(ids))


@torch.inference_mode()
def latency_test_run_once(
    model_runner, rank_print, reqs, batch_size, input_len, output_len
):
    """Time one prefill of ``reqs`` followed by ``output_len`` decode steps.

    Both token pools are cleared first, so every run starts from empty pools.
    Each timed step is bracketed by ``torch.cuda.synchronize()``, so the
    interval covers the GPU work and not only the kernel launches.

    Args:
        model_runner: the loaded model runner.
        rank_print: ``print`` on rank 0, a no-op on other ranks.
        reqs: requests to prefill.
        batch_size: number of requests; used in the throughput math.
        input_len: prompt length per request; used in the throughput math.
        output_len: number of decode steps to time.

    Returns:
        A dict with the run parameters plus prefill / average-decode / total
        latencies (seconds) and throughputs (tokens per second). The
        ``avg_decode_*`` entries are ``None`` when ``output_len`` is 0.
    """
    # Clear the pools.
    model_runner.req_to_token_pool.clear()
    model_runner.token_to_kv_pool.clear()

    measurement_results = {
        "run_name": "before",
        "batch_size": batch_size,
        "input_len": input_len,
        "output_len": output_len,
    }

    tot_latency = 0

    # Prefill. perf_counter is monotonic and high-resolution; time.time() can
    # jump if the wall clock is adjusted during a measurement.
    torch.cuda.synchronize()
    tic = time.perf_counter()
    next_token_ids, _, batch = extend(reqs, model_runner)
    torch.cuda.synchronize()
    prefill_latency = time.perf_counter() - tic
    tot_latency += prefill_latency
    throughput = input_len * batch_size / prefill_latency
    rank_print(
        f"Prefill. latency: {prefill_latency:6.5f} s, throughput: {throughput:9.2f} token/s"
    )
    measurement_results["prefill_latency"] = prefill_latency
    measurement_results["prefill_throughput"] = throughput

    # Decode
    for i in range(output_len):
        torch.cuda.synchronize()
        tic = time.perf_counter()
        next_token_ids, _ = decode(next_token_ids, batch, model_runner)
        torch.cuda.synchronize()
        latency = time.perf_counter() - tic
        tot_latency += latency
        throughput = batch_size / latency
        # Print only the first few steps individually to keep the log short.
        if i < 5:
            rank_print(
                f"Decode.  latency: {latency:6.5f} s, throughput: {throughput:9.2f} token/s"
            )

    if output_len > 0:
        avg_decode_latency = (tot_latency - prefill_latency) / output_len
        avg_decode_throughput = batch_size / avg_decode_latency
        rank_print(
            f"Decode.  avg latency: {avg_decode_latency:6.5f} s, avg throughput: {avg_decode_throughput:9.2f} token/s"
        )
    else:
        # No decode steps were run; avoid dividing by zero.
        avg_decode_latency = None
        avg_decode_throughput = None
    measurement_results["avg_decode_latency"] = avg_decode_latency
    measurement_results["avg_decode_throughput"] = avg_decode_throughput

    throughput = (input_len + output_len) * batch_size / tot_latency
    rank_print(
        f"Total. latency: {tot_latency:6.3f} s, throughput: {throughput:9.2f} token/s"
    )
    measurement_results["total_latency"] = tot_latency
    measurement_results["total_throughput"] = throughput
    return measurement_results


def latency_test(
    server_args,
    bench_args,
    tp_rank,
):
    """Benchmark prefill/decode latency on synthetic inputs for every batch size.

    One warm-up run (first batch size, 4 decode steps) is followed by one
    measured run per entry of ``bench_args.batch_size``. Each measured run
    uses freshly built requests. If ``bench_args.result_filename`` is set,
    the results are appended to it in JSON Lines format. ``bench_args`` is
    not modified.

    Raises:
        ValueError: if ``bench_args.batch_size`` is empty.
    """
    rank_print = print if tp_rank == 0 else lambda *args, **kwargs: None

    # ``batch_size`` normally holds one or more sizes from ``--batch-size``;
    # a bare int is accepted for programmatic callers.
    batch_sizes = bench_args.batch_size
    if isinstance(batch_sizes, int):
        batch_sizes = (batch_sizes,)
    if not batch_sizes:
        raise ValueError("bench_args.batch_size must contain at least one value")

    # Load the model
    model_runner, _ = load_model(server_args, tp_rank)
    rank_print(
        f"max_batch_size={model_runner.max_total_num_tokens // (bench_args.input_len + bench_args.output_len)}"
    )

    # Warm up
    warmup_reqs = prepare_synthetic_inputs_for_latency_test(
        batch_sizes[0], bench_args.input_len
    )
    latency_test_run_once(
        model_runner, rank_print, warmup_reqs, batch_sizes[0], bench_args.input_len, 4
    )

    # Measured runs, one per requested batch size.
    result_list = []
    for batch_size in batch_sizes:
        reqs = prepare_synthetic_inputs_for_latency_test(
            batch_size, bench_args.input_len
        )
        result_list.append(
            latency_test_run_once(
                model_runner,
                rank_print,
                reqs,
                batch_size,
                bench_args.input_len,
                bench_args.output_len,
            )
        )

    # Write results in jsonlines format.
    if bench_args.result_filename:
        with jsonlines.open(bench_args.result_filename, "a") as f:
            f.write_all(result_list)


def main(server_args, bench_args):
    """Run the correctness or latency test, in-process or one process per TP rank.

    With ``tp_size == 1`` the test runs in this process as rank 0. Otherwise
    one ``multiprocessing.Process`` is started per rank and all are joined.

    Raises:
        RuntimeError: if any TP worker process exits with a non-zero code.
    """
    print(bench_args)

    work_func = correctness_test if bench_args.correctness_test else latency_test

    if server_args.tp_size == 1:
        work_func(server_args, bench_args, 0)
        return

    workers = []
    try:
        for tp_rank in range(server_args.tp_size):
            proc = multiprocessing.Process(
                target=work_func,
                args=(
                    server_args,
                    bench_args,
                    tp_rank,
                ),
            )
            proc.start()
            workers.append(proc)

        for proc in workers:
            proc.join()
    finally:
        # On an early exit (exception or Ctrl-C while starting or joining),
        # do not leave worker processes running in the background.
        for proc in workers:
            if proc.is_alive():
                proc.terminate()

    failed = [
        (rank, proc.exitcode)
        for rank, proc in enumerate(workers)
        if proc.exitcode != 0
    ]
    if failed:
        raise RuntimeError(f"TP worker(s) failed as (rank, exitcode): {failed}")

if __name__ == "__main__":
    # Server flags and benchmark flags share one parser.
    parser = argparse.ArgumentParser()
    ServerArgs.add_cli_args(parser)
    BenchArgs.add_cli_args(parser)
    args = parser.parse_args()

    server_args = ServerArgs.from_cli_args(args)
    bench_args = BenchArgs.from_cli_args(args)

    logging.basicConfig(
        level=getattr(logging, server_args.log_level.upper()),
        format="%(message)s",
    )

    main(server_args, bench_args)