bench_one_batch.py 22.9 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
"""
Benchmark the latency of running a single static batch without a server.

This script does not launch a server and uses the low-level APIs.
It accepts server arguments (the same as launch_server.py) and benchmark arguments (e.g., batch size, input lengths).

# Usage (latency test)
## with dummy weights:
python -m sglang.bench_one_batch --model-path meta-llama/Meta-Llama-3-8B-Instruct --load-format dummy
## sweep through multiple data points and store (append) the results in a jsonl file:
python -m sglang.bench_one_batch --model-path meta-llama/Meta-Llama-3-8B-Instruct --batch 1 12 14 --input-len 256 512 --output-len 32 256 --run-name test_run
## run with profiling:
python -m sglang.bench_one_batch --model-path meta-llama/Meta-Llama-3-8B-Instruct --batch 1 12 14 --input-len 256 512 --profile

# Usage (correctness test):
python -m sglang.bench_one_batch --model-path TinyLlama/TinyLlama-1.1B-Chat-v0.4 --correctness-test

## Reference output (of the correctness test above, can be gpu dependent):
input_ids=[[1, 450, 7483, 310, 3444, 338], [1, 450, 7483, 310, 278, 3303, 13187, 290, 338], [1, 20628, 338, 263, 6575, 1460, 2462, 322, 306, 763]]

prefill logits (first half): tensor([[-10.0312,  -9.5000,   0.8931,  ...,  -4.9414,  -3.2422,  -3.3633],
        [-10.0312,  -9.5000,   0.8931,  ...,  -4.9414,  -3.2422,  -3.3633],
        [ -9.1875, -10.2500,   2.7129,  ...,  -4.3359,  -4.0664,  -4.1328]],
       device='cuda:0')

prefill logits (final): tensor([[-8.3125, -7.1172,  3.3457,  ..., -4.9570, -4.1328, -3.4141],
        [-8.9141, -9.0156,  4.1445,  ..., -4.9922, -4.4961, -4.0781],
        [-9.6328, -9.0547,  4.0195,  ..., -5.3047, -4.7148, -4.4570]],
       device='cuda:0')

========== Prompt 0 ==========
<s> The capital of France is Paris.
The capital of the United States is Washington, D.C.


========== Prompt 1 ==========
<s> The capital of the United Kindom is London.
The capital of the United Kingdom is London.
The capital of the

========== Prompt 2 ==========
<s> Today is a sunny day and I like to go for a walk in the park.
I'm going to the park
"""

import argparse
46
import copy
47
48
49
50
51
import dataclasses
import itertools
import json
import logging
import multiprocessing
52
import os
53
import time
54
from types import SimpleNamespace
55
56
57
58
59
60
61
from typing import Tuple

import numpy as np
import torch
import torch.distributed as dist

from sglang.srt.configs.model_config import ModelConfig
Lianmin Zheng's avatar
Lianmin Zheng committed
62
from sglang.srt.distributed.parallel_state import destroy_distributed_environment
63
from sglang.srt.entrypoints.engine import _set_envs_and_config
64
from sglang.srt.layers.moe import initialize_moe_config
65
from sglang.srt.managers.schedule_batch import Req, ScheduleBatch
66
from sglang.srt.managers.scheduler import Scheduler
67
68
69
70
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.model_executor.model_runner import ModelRunner
from sglang.srt.sampling.sampling_params import SamplingParams
from sglang.srt.server_args import PortArgs, ServerArgs
71
from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
72
73
74
from sglang.srt.utils import (
    configure_logger,
    get_bool_env_var,
Huaiyu, Zheng's avatar
Huaiyu, Zheng committed
75
76
    is_cuda_alike,
    is_xpu,
77
    kill_process_tree,
78
79
    require_mlp_sync,
    require_mlp_tp_gather,
80
81
82
    set_gpu_proc_affinity,
    suppress_other_loggers,
)
83
from sglang.srt.utils.hf_transformers_utils import get_tokenizer
84

Huaiyu, Zheng's avatar
Huaiyu, Zheng committed
85
86
87
88
89
90
91
92
93
# Torch profiler activities to record: CPU is always traced, and the CUDA /
# XPU activities are added only when that accelerator backend is available.
profile_activities = [torch.profiler.ProfilerActivity.CPU]
if is_cuda_alike():
    profile_activities.append(torch.profiler.ProfilerActivity.CUDA)
if is_xpu():
    profile_activities.append(torch.profiler.ProfilerActivity.XPU)

94
95
96
97
98
99
100

@dataclasses.dataclass
class BenchArgs:
    """Benchmark-specific command-line arguments.

    ``batch_size``, ``input_len`` and ``output_len`` are sequences; the
    latency test runs every combination of them.
    """

    run_name: str = "default"
    batch_size: Tuple[int, ...] = (1,)
    input_len: Tuple[int, ...] = (1024,)
    output_len: Tuple[int, ...] = (16,)
    # Optional text file with one prompt per line (empty string = unused).
    prompt_filename: str = ""
    # Latency results are appended to this file as JSON lines.
    result_filename: str = "result.jsonl"
    correctness_test: bool = False
    # This is only used for correctness test
    cut_len: int = 4
    # 0 disables per-step decode logging beyond the first few steps.
    log_decode_step: int = 0
    profile: bool = False
    profile_record_shapes: bool = False
    profile_filename_prefix: str = "profile"

    @staticmethod
    def add_cli_args(parser: argparse.ArgumentParser):
        """Register all benchmark arguments on *parser*."""
        parser.add_argument(
            "--run-name",
            type=str,
            default=BenchArgs.run_name,
            help="Name stored in every result record.",
        )
        parser.add_argument(
            "--batch-size",
            type=int,
            nargs="+",
            default=BenchArgs.batch_size,
            help="Batch size(s) to benchmark; every combination with "
            "--input-len and --output-len is run.",
        )
        parser.add_argument(
            "--input-len",
            type=int,
            nargs="+",
            default=BenchArgs.input_len,
            help="Input length(s) in tokens for synthetic prompts.",
        )
        parser.add_argument(
            "--output-len",
            type=int,
            nargs="+",
            default=BenchArgs.output_len,
            help="Output length(s) in tokens. The correctness test uses only "
            "the first value.",
        )
        parser.add_argument(
            "--prompt-filename",
            type=str,
            default=BenchArgs.prompt_filename,
            help="Optional text file with one prompt per line, used instead of "
            "the default inputs.",
        )
        parser.add_argument(
            "--result-filename",
            type=str,
            default=BenchArgs.result_filename,
            help="JSON-lines file that latency results are appended to.",
        )
        parser.add_argument(
            "--correctness-test",
            action="store_true",
            help="Run the correctness test instead of the latency benchmark.",
        )
        parser.add_argument(
            "--cut-len",
            type=int,
            default=BenchArgs.cut_len,
            help="Correctness test only: number of leading prompt tokens "
            "prefilled before the remaining tokens are extended.",
        )
        parser.add_argument(
            "--log-decode-step",
            type=int,
            default=BenchArgs.log_decode_step,
            help="Log decode latency by step, default is set to zero to disable.",
        )
        parser.add_argument(
            "--profile", action="store_true", help="Use Torch Profiler."
        )
        parser.add_argument(
            "--profile-record-shapes",
            action="store_true",
            help="Record tensor shapes in profiling results.",
        )
        parser.add_argument(
            "--profile-filename-prefix",
            type=str,
            default=BenchArgs.profile_filename_prefix,
            help="Prefix of the profiling file names. The full profiling result file(s) will be "
            '"[profile_filename_prefix]_batch[batch_size]_input[input_len]_output[output_len]'
            '_[prefill|decode].trace.json.gz"',
        )

    @classmethod
    def from_cli_args(cls, args: argparse.Namespace):
        """Build a ``BenchArgs`` from a parsed argparse namespace.

        Each value is cast with the type of the field's default, so the lists
        produced by ``nargs="+"`` become tuples.
        """
        # use the default value's type to cast the args into correct types.
        attrs = [(attr.name, type(attr.default)) for attr in dataclasses.fields(cls)]
        return cls(
            **{attr: attr_type(getattr(args, attr)) for attr, attr_type in attrs}
        )


def load_model(server_args, port_args, tp_rank):
    """Create the ModelRunner and tokenizer for one tensor-parallel rank.

    Only rank 0 prints. When ``tp_size > 1`` every rank waits at a
    ``dist.barrier()`` before returning.

    Returns:
        ``(model_runner, tokenizer)``.
    """
    suppress_other_loggers()

    def rank_print(*args, **kwargs):
        if tp_rank == 0:
            print(*args, **kwargs)

    # TP ranks map onto MoE expert-parallel ranks in contiguous chunks of
    # tp_size // ep_size ranks each.
    ranks_per_ep = server_args.tp_size // server_args.ep_size
    moe_ep_rank = tp_rank // ranks_per_ep

    config = ModelConfig.from_server_args(server_args)
    runner = ModelRunner(
        model_config=config,
        mem_fraction_static=server_args.mem_fraction_static,
        gpu_id=tp_rank,
        tp_rank=tp_rank,
        tp_size=server_args.tp_size,
        moe_ep_rank=moe_ep_rank,
        moe_ep_size=server_args.ep_size,
        pp_rank=0,
        pp_size=1,
        nccl_port=port_args.nccl_port,
        server_args=server_args,
    )
    rank_print(f"max_total_num_tokens={runner.max_total_num_tokens}")

    tok = get_tokenizer(
        server_args.tokenizer_path,
        tokenizer_mode=server_args.tokenizer_mode,
        trust_remote_code=server_args.trust_remote_code,
    )
    if server_args.tp_size > 1:
        dist.barrier()
    return runner, tok


192
193
194
195
196
197
198
199
200
201
def prepare_inputs_for_correctness_test(bench_args, tokenizer, custom_prompts):
    """Tokenize prompts and build requests holding only their first cut_len tokens.

    Args:
        bench_args: benchmark arguments; ``cut_len`` and ``output_len[0]`` are used.
        tokenizer: object whose ``encode`` maps a prompt string to token ids.
        custom_prompts: prompts to use; built-in defaults are used when empty.

    Returns:
        ``(input_ids, reqs)``: the full token ids of every prompt, and one
        ``Req`` per prompt whose input is truncated to ``cut_len`` tokens.

    Raises:
        ValueError: if a prompt has ``cut_len`` tokens or fewer, which would
            leave nothing for the second (extend) prefill.
    """
    prompts = custom_prompts or [
        "The capital of France is",
        "The capital of the United Kindom is",
        "Today is a sunny day and I like",
    ]
    input_ids = [tokenizer.encode(p) for p in prompts]
    sampling_params = SamplingParams(
        temperature=0,
        # Must be an int: use this run's output length, matching the decode
        # loop in correctness_test (not the tuple-valued class default).
        max_new_tokens=bench_args.output_len[0],
    )

    reqs = []
    for i, (prompt, ids) in enumerate(zip(prompts, input_ids)):
        if len(ids) <= bench_args.cut_len:
            raise ValueError(
                f"Prompt {i} has {len(ids)} tokens; it must be longer than "
                f"--cut-len ({bench_args.cut_len})."
            )
        req = Req(
            rid=i,
            origin_input_text=prompt,
            origin_input_ids=ids[: bench_args.cut_len],
            sampling_params=sampling_params,
        )
        # fill_ids deliberately aliases origin_input_ids; the later extend
        # step appends to it in place.
        req.fill_ids = req.origin_input_ids
        req.extend_input_len = len(req.fill_ids) - len(req.prefix_indices)
        req.logprob_start_len = len(req.origin_input_ids) - 1
        reqs.append(req)

    return input_ids, reqs


def prepare_extend_inputs_for_correctness_test(
    bench_args, input_ids, reqs, model_runner
):
    """Extend each request from its first cut_len tokens to the full prompt.

    For request ``row``, the first ``cut_len`` KV slots are read from row
    ``row`` of ``req_to_token`` and used as the prefix (presumably filled by
    the preceding cut_len prefill). Returns the same, mutated ``reqs`` list.
    """
    cut = bench_args.cut_len
    req_to_token = model_runner.req_to_token_pool.req_to_token
    for row, req in enumerate(reqs):
        # ``+=`` mutates fill_ids in place. When fill_ids aliases
        # origin_input_ids (as prepare_inputs_for_correctness_test sets up),
        # origin_input_ids grows too, which affects logprob_start_len below.
        req.fill_ids += input_ids[row][cut:]
        req.prefix_indices = req_to_token[row, :cut]
        req.extend_input_len = len(req.fill_ids) - len(req.prefix_indices)
        req.logprob_start_len = len(req.origin_input_ids) - 1
    return reqs


241
242
243
244
245
246
247
248
def prepare_synthetic_inputs_for_latency_test(
    batch_size, input_len, custom_inputs=None
):
    """Build one full-prefill request per input sequence for the latency test.

    Args:
        batch_size: number of random sequences to generate (ignored when
            ``custom_inputs`` is non-empty).
        input_len: length of each random sequence (ignored when
            ``custom_inputs`` is non-empty).
        custom_inputs: optional list of token-id sequences used instead of
            random ids.

    Returns:
        A list of ``Req`` objects with no cached prefix.
    """
    if custom_inputs:
        input_ids = custom_inputs
    else:
        # Random ids in [0, 10000); assumes the vocab has >= 10000 entries.
        input_ids = np.random.randint(
            0, 10000, (batch_size, input_len), dtype=np.int32
        )
    sampling_params = SamplingParams(
        temperature=0,
        # Must be an int, not the tuple-valued BenchArgs.output_len default.
        max_new_tokens=BenchArgs.output_len[0],
    )

    reqs = []
    for i, ids in enumerate(input_ids):
        req = Req(
            rid=i,
            origin_input_text="",
            origin_input_ids=list(ids),
            sampling_params=sampling_params,
        )
        req.fill_ids = req.origin_input_ids
        req.extend_input_len = len(req.fill_ids) - len(req.prefix_indices)
        req.logprob_start_len = len(req.origin_input_ids) - 1
        reqs.append(req)

    return reqs


@torch.no_grad
def extend(reqs, model_runner):
    """Run one prefill (extend) forward pass over *reqs* and sample a token each.

    Returns:
        ``(next_token_ids, next_token_logits, batch)``. The returned
        ``ScheduleBatch`` is the one later advanced by ``decode``.
    """
    # Benchmarks need no prefix caching, only KV allocation, so a minimal
    # stand-in object takes the place of a real tree cache.
    batch = ScheduleBatch.init_new(
        reqs=reqs,
        req_to_token_pool=model_runner.req_to_token_pool,
        token_to_kv_pool_allocator=model_runner.token_to_kv_pool_allocator,
        tree_cache=SimpleNamespace(
            page_size=1,
            device=model_runner.device,
            token_to_kv_pool_allocator=model_runner.token_to_kv_pool_allocator,
        ),
        model_config=model_runner.model_config,
        enable_overlap=False,
        spec_algorithm=SpeculativeAlgorithm.NONE,
    )
    batch.prepare_for_extend()
    _maybe_prepare_mlp_sync_batch(batch, model_runner)

    fwd = ForwardBatch.init_new(batch.get_model_worker_batch(), model_runner)
    logits_output, _ = model_runner.forward(fwd)
    sampled_ids = model_runner.sample(logits_output, fwd)
    return sampled_ids, logits_output.next_token_logits, batch


@torch.no_grad
def decode(input_token_ids, batch, model_runner):
    """Advance *batch* by one decode step, feeding it *input_token_ids*.

    Returns:
        ``(next_token_ids, next_token_logits)`` for this step.
    """
    batch.output_ids = input_token_ids
    batch.prepare_for_decode()
    _maybe_prepare_mlp_sync_batch(batch, model_runner)

    worker_batch = batch.get_model_worker_batch()
    fwd = ForwardBatch.init_new(worker_batch, model_runner)
    logits_output, _ = model_runner.forward(fwd)
    return model_runner.sample(logits_output, fwd), logits_output.next_token_logits


309
310
311
def _maybe_prepare_mlp_sync_batch(batch: ScheduleBatch, model_runner):
    """Run the scheduler's MLP-sync batch preparation when the config needs it.

    This is a no-op unless ``require_mlp_sync(model_runner.server_args)`` is true.
    """
    srv_args = model_runner.server_args
    if not require_mlp_sync(srv_args):
        return
    Scheduler.prepare_mlp_sync_batch_raw(
        batch,
        dp_size=srv_args.dp_size,
        attn_tp_size=1,
        tp_group=model_runner.tp_group,
        get_idle_batch=None,
        disable_cuda_graph=srv_args.disable_cuda_graph,
        spec_algorithm=SpeculativeAlgorithm.NONE,
        speculative_num_draft_tokens=None,
        require_mlp_tp_gather=require_mlp_tp_gather(srv_args),
        disable_overlap_schedule=srv_args.disable_overlap_schedule,
    )


325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
def _read_prompts_from_file(prompt_file, rank_print):
    """Read custom prompts from the file specified by `--prompt-filename`."""
    if not prompt_file:
        return []
    if not os.path.exists(prompt_file):
        rank_print(
            f"Custom prompt file {prompt_file} not found. Using default inputs..."
        )
        return []
    with open(prompt_file, "r") as pf:
        return pf.readlines()


def _save_profile_trace_results(profiler, filename):
    parent_dir = os.path.dirname(os.path.abspath(filename))
    os.makedirs(parent_dir, exist_ok=True)
    profiler.export_chrome_trace(filename)
    print(
        profiler.key_averages(group_by_input_shape=True).table(
            sort_by="self_cpu_time_total"
        )
    )


349
350
351
352
353
354
355
356
357
358
359
360
361
362
def correctness_test(
    server_args,
    port_args,
    bench_args,
    tp_rank,
):
    """Check generation end-to-end with a two-stage prefill and greedy decode.

    Each prompt is first prefilled with only its first ``cut_len`` tokens.
    It is then extended with the remaining tokens on top of that prefix,
    and finally decoded for ``output_len[0] - 1`` more steps. Logits and
    decoded texts are printed on rank 0 for comparison with a reference.
    """
    # Configure the logger
    configure_logger(server_args, prefix=f" TP{tp_rank}")
    rank_print = print if tp_rank == 0 else lambda *args, **kwargs: None

    # Load the model
    model_runner, tokenizer = load_model(server_args, port_args, tp_rank)

    # Prepare inputs
    custom_prompts = _read_prompts_from_file(bench_args.prompt_filename, rank_print)
    input_ids, reqs = prepare_inputs_for_correctness_test(
        bench_args, tokenizer, custom_prompts
    )
    rank_print(f"\n{input_ids=}\n")

    if bench_args.cut_len > 0:
        # Prefill only the first cut_len tokens of every prompt.
        _, next_token_logits, _ = extend(reqs, model_runner)
        rank_print(f"prefill logits (first half): {next_token_logits} \n")

    # Prepare extend inputs
    reqs = prepare_extend_inputs_for_correctness_test(
        bench_args, input_ids, reqs, model_runner
    )

    # Extend (prefill w/ KV cache)
    next_token_ids, next_token_logits, batch = extend(reqs, model_runner)
    rank_print(f"prefill logits (final): {next_token_logits} \n")

    # Decode. Sampled ids are converted with tolist() so every output list
    # holds plain Python ints (not tensor elements) before decoding to text.
    output_ids = [
        ids + [tok] for ids, tok in zip(input_ids, next_token_ids.tolist())
    ]
    for _ in range(bench_args.output_len[0] - 1):
        next_token_ids, _ = decode(next_token_ids, batch, model_runner)
        for ids, tok in zip(output_ids, next_token_ids.tolist()):
            ids.append(tok)

    # Print output texts
    for i, ids in enumerate(output_ids):
        rank_print(f"========== Prompt {i} ==========")
        rank_print(tokenizer.decode(ids), "\n")

    # Tear down the process group, as latency_test does.
    if server_args.tp_size > 1:
        destroy_distributed_environment()


def synchronize(device):
398
    torch.get_device_module(device).synchronize()
399
400
401


def _start_torch_profiler(record_shapes):
    """Create and start a torch profiler over ``profile_activities``."""
    profiler = torch.profiler.profile(
        activities=profile_activities,
        with_stack=True,
        record_shapes=record_shapes,
    )
    profiler.start()
    return profiler


def latency_test_run_once(
    run_name,
    model_runner,
    rank_print,
    reqs,
    batch_size,
    input_len,
    output_len,
    device,
    log_decode_step,
    profile,
    profile_record_shapes,
    profile_filename_prefix,
):
    """Time one prefill plus ``output_len - 1`` decode steps for a single batch.

    When ``profile`` is truthy, the prefill and one middle decode step are
    traced. Traces are saved as
    ``<prefix>_batch<bs>_input<il>_output<ol>_{prefill,decode}.trace.json.gz``.

    Returns:
        A dict of measurements (run parameters, prefill/decode/total latency
        and throughput), or ``None`` if ``batch_size`` requests of
        ``input_len + output_len`` tokens exceed
        ``model_runner.max_total_num_tokens``.
    """
    max_batch_size = model_runner.max_total_num_tokens // (input_len + output_len)
    if batch_size > max_batch_size:
        rank_print(
            f"skipping ({batch_size}, {input_len}, {output_len}) due to max batch size limit"
        )
        return None

    # Clear the pools.
    model_runner.req_to_token_pool.clear()
    model_runner.token_to_kv_pool_allocator.clear()

    measurement_results = {
        "run_name": run_name,
        "batch_size": batch_size,
        "input_len": input_len,
        "output_len": output_len,
    }
    tot_latency = 0
    trace_prefix = (
        f"{profile_filename_prefix}_batch{batch_size}"
        f"_input{input_len}_output{output_len}"
    )

    profiler = _start_torch_profiler(profile_record_shapes) if profile else None

    # Prefill
    synchronize(device)
    tic = time.perf_counter()
    next_token_ids, _, batch = extend(reqs, model_runner)
    synchronize(device)
    prefill_latency = time.perf_counter() - tic
    tot_latency += prefill_latency
    throughput = input_len * batch_size / prefill_latency
    rank_print(
        f"Prefill. latency: {prefill_latency:6.5f} s, throughput: {throughput:9.2f} token/s"
    )
    measurement_results["prefill_latency"] = prefill_latency
    measurement_results["prefill_throughput"] = throughput

    if profile:
        profiler.stop()
        trace_filename = f"{trace_prefix}_prefill.trace.json.gz"
        _save_profile_trace_results(profiler, trace_filename)
        rank_print(f"torch profiler chrome trace for prefill saved to {trace_filename}")

    # Decode steps are indexed 0 .. output_len - 2. Use an integer step index
    # near the middle that is always in range, so odd output lengths (and
    # output_len == 2) get a decode trace too.
    decode_profile_step = min(output_len // 2, output_len - 2)
    decode_latencies = []
    for i in range(output_len - 1):
        synchronize(device)
        profile_this_step = bool(profile) and i == decode_profile_step
        if profile_this_step:
            profiler = _start_torch_profiler(profile_record_shapes)

        tic = time.perf_counter()
        next_token_ids, _ = decode(next_token_ids, batch, model_runner)
        synchronize(device)
        latency = time.perf_counter() - tic
        tot_latency += latency
        throughput = batch_size / latency
        decode_latencies.append(latency)
        if i < 5 or (log_decode_step > 0 and i % log_decode_step == 0):
            rank_print(
                f"Decode {i}. Batch size: {batch_size}, latency: {latency:6.5f} s, throughput: {throughput:9.2f} token/s"
            )

        if profile_this_step:
            profiler.stop()
            trace_filename = f"{trace_prefix}_decode.trace.json.gz"
            _save_profile_trace_results(profiler, trace_filename)
            rank_print(
                f"torch profiler chrome trace for decoding 1 token saved to {trace_filename}"
            )

    # Record decode timing from 2nd output
    if output_len > 1:
        med_decode_latency = np.median(decode_latencies)
        med_decode_throughput = batch_size / med_decode_latency
        rank_print(
            f"Decode.  median latency: {med_decode_latency:6.5f} s, median throughput: {med_decode_throughput:9.2f} token/s"
        )
        measurement_results["median_decode_latency"] = med_decode_latency
        measurement_results["median_decode_throughput"] = med_decode_throughput

    throughput = (input_len + output_len) * batch_size / tot_latency
    rank_print(
        f"Total. latency: {tot_latency:6.3f} s, throughput: {throughput:9.2f} token/s"
    )
    measurement_results["total_latency"] = tot_latency
    measurement_results["overall_throughput"] = throughput
    return measurement_results


def _align_custom_inputs_to_batch_size(custom_inputs, bs, rank_print):
    """Return exactly ``bs`` token-id sequences taken from ``custom_inputs``.

    Extra prompts are dropped, and missing ones are filled by repeating the
    last prompt. Returns ``[]`` when ``custom_inputs`` is empty, so the caller
    falls back to random inputs.
    """
    if not custom_inputs:
        return []
    num_inputs = len(custom_inputs)
    if num_inputs == bs:
        return custom_inputs
    if num_inputs > bs:
        rank_print(
            f"Custom input size ({num_inputs}) is larger than batch_size ({bs}). "
            f"Using the first {bs} prompts."
        )
        return copy.deepcopy(custom_inputs[:bs])
    rank_print(
        f"Custom input size ({num_inputs}) is smaller than batch_size ({bs}). "
        "Pad to the desired batch_size with the last prompt."
    )
    aligned = copy.deepcopy(custom_inputs)
    aligned.extend([aligned[-1]] * (bs - num_inputs))
    return aligned


def latency_test(
    server_args,
    port_args,
    bench_args,
    tp_rank,
):
    """Warm up, then benchmark every (batch_size, input_len, output_len) combination.

    Runs that are not skipped produce a result dict. Rank 0 appends these
    results as JSON lines to ``bench_args.result_filename``.
    """
    initialize_moe_config(server_args)

    # Set CPU affinity
    if get_bool_env_var("SGLANG_SET_CPU_AFFINITY"):
        set_gpu_proc_affinity(
            server_args.pp_size, server_args.tp_size, server_args.nnodes, tp_rank
        )

    # Configure the logger
    configure_logger(server_args, prefix=f" TP{tp_rank}")
    rank_print = print if tp_rank == 0 else lambda *args, **kwargs: None

    # Load the model
    model_runner, tokenizer = load_model(server_args, port_args, tp_rank)

    # Prepare inputs for warm up
    reqs = prepare_synthetic_inputs_for_latency_test(
        bench_args.batch_size[0], bench_args.input_len[0]
    )

    # Warm up
    rank_print("Warmup ...")
    latency_test_run_once(
        bench_args.run_name,
        model_runner,
        rank_print,
        reqs,
        bench_args.batch_size[0],
        bench_args.input_len[0],
        min(32, bench_args.output_len[0]),  # shorter decoding to speed up the warmup
        server_args.device,
        log_decode_step=0,
        profile=False,
        profile_record_shapes=False,
        profile_filename_prefix="",  # not used
    )

    rank_print("Benchmark ...")

    custom_inputs = [
        tokenizer.encode(p.strip())
        for p in _read_prompts_from_file(bench_args.prompt_filename, rank_print)
    ]

    # Run the sweep
    result_list = []
    for bs, il, ol in itertools.product(
        bench_args.batch_size, bench_args.input_len, bench_args.output_len
    ):
        bs_aligned_inputs = _align_custom_inputs_to_batch_size(
            custom_inputs, bs, rank_print
        )
        reqs = prepare_synthetic_inputs_for_latency_test(bs, il, bs_aligned_inputs)
        ret = latency_test_run_once(
            bench_args.run_name,
            model_runner,
            rank_print,
            reqs,
            bs,
            il,
            ol,
            server_args.device,
            bench_args.log_decode_step,
            # Only rank 0 profiles.
            bench_args.profile if tp_rank == 0 else None,
            bench_args.profile_record_shapes if tp_rank == 0 else None,
            bench_args.profile_filename_prefix,
        )
        if ret is not None:
            result_list.append(ret)

    # Write results in jsonlines format on rank 0.
    if tp_rank == 0 and bench_args.result_filename:
        with open(bench_args.result_filename, "a") as fout:
            for result in result_list:
                fout.write(json.dumps(result) + "\n")

    if server_args.tp_size > 1:
        destroy_distributed_environment()

617
618

def main(server_args, bench_args):
    """Run the correctness test or latency benchmark.

    With ``tp_size == 1`` the test runs in this process. Otherwise one
    process is spawned per tensor-parallel rank and all are joined.

    Raises:
        ValueError: if ``--model-path`` is not provided.
    """
    # Fail fast before touching any global environment/config.
    if not server_args.model_path:
        raise ValueError("Provide --model-path for running the tests.")

    server_args.cuda_graph_max_bs = max(bench_args.batch_size)
    _set_envs_and_config(server_args)

    work_func = correctness_test if bench_args.correctness_test else latency_test
    port_args = PortArgs.init_new(server_args)

    if server_args.tp_size == 1:
        work_func(server_args, port_args, bench_args, 0)
        return

    workers = []
    for tp_rank in range(server_args.tp_size):
        proc = multiprocessing.Process(
            target=work_func,
            args=(
                server_args,
                port_args,
                bench_args,
                tp_rank,
            ),
        )
        proc.start()
        workers.append(proc)

    for proc in workers:
        proc.join()

    # Defensive cleanup applied to every worker (a no-op for processes that
    # have already exited after join()).
    for proc in workers:
        proc.terminate()


if __name__ == "__main__":
    cli_parser = argparse.ArgumentParser()
    for register_cli_args in (ServerArgs.add_cli_args, BenchArgs.add_cli_args):
        register_cli_args(cli_parser)
    cli_args = cli_parser.parse_args()
    server_args = ServerArgs.from_cli_args(cli_args)
    bench_args = BenchArgs.from_cli_args(cli_args)

    log_level = getattr(logging, server_args.log_level.upper())
    logging.basicConfig(level=log_level, format="%(message)s")

    try:
        main(server_args, bench_args)
    finally:
        # With multiple TP ranks, kill leftover child processes; presumably
        # include_parent=False spares this process itself.
        if server_args.tp_size != 1:
            kill_process_tree(os.getpid(), include_parent=False)