"""
Benchmark the latency of running a single static batch without a server.

This script does not launch a server and uses the low-level APIs.
It accepts server arguments (the same as launch_server.py) and benchmark arguments (e.g., batch size, input lengths).

# Usage (latency test)
## with dummy weights:
python -m sglang.bench_one_batch --model-path meta-llama/Meta-Llama-3-8B-Instruct --load-format dummy
## sweep through multiple data points and store (append) the results in a jsonl file:
python -m sglang.bench_one_batch --model-path meta-llama/Meta-Llama-3-8B-Instruct --batch 1 12 14 --input-len 256 512 --output-len 32 256 --run-name test_run
## run with profiling:
python -m sglang.bench_one_batch --model-path meta-llama/Meta-Llama-3-8B-Instruct --batch 1 12 14 --input-len 256 512 --profile
## run with profiling to a custom directory:
export SGLANG_TORCH_PROFILER_DIR=/root/sglang/profile_log
python -m sglang.bench_one_batch --model-path meta-llama/Meta-Llama-3-8B-Instruct --batch 1 --input-len 256 --profile
## run with the CUDA profiler (nsys):
nsys profile --force-overwrite=true -o bench_one_batch python -m sglang.bench_one_batch --model-path meta-llama/Meta-Llama-3-8B-Instruct --batch 1 --input-len 256 --profile --profiler_activities CUDA_PROFILER

# Usage (correctness test):
python -m sglang.bench_one_batch --model-path TinyLlama/TinyLlama-1.1B-Chat-v0.4 --correct

## Reference output (of the correctness test above; may be GPU-dependent):
input_ids=[[1, 450, 7483, 310, 3444, 338], [1, 450, 7483, 310, 278, 3303, 13187, 290, 338], [1, 20628, 338, 263, 6575, 1460, 2462, 322, 306, 763]]

prefill logits (first half): tensor([[-10.0312,  -9.5000,   0.8931,  ...,  -4.9414,  -3.2422,  -3.3633],
        [-10.0312,  -9.5000,   0.8931,  ...,  -4.9414,  -3.2422,  -3.3633],
        [ -9.1875, -10.2500,   2.7129,  ...,  -4.3359,  -4.0664,  -4.1328]],
       device='cuda:0')

prefill logits (final): tensor([[-8.3125, -7.1172,  3.3457,  ..., -4.9570, -4.1328, -3.4141],
        [-8.9141, -9.0156,  4.1445,  ..., -4.9922, -4.4961, -4.0781],
        [-9.6328, -9.0547,  4.0195,  ..., -5.3047, -4.7148, -4.4570]],
       device='cuda:0')

========== Prompt 0 ==========
<s> The capital of France is Paris.
The capital of the United States is Washington, D.C.


========== Prompt 1 ==========
<s> The capital of the United Kindom is London.
The capital of the United Kingdom is London.
The capital of the

========== Prompt 2 ==========
<s> Today is a sunny day and I like to go for a walk in the park.
I'm going to the park
"""

import argparse
import copy
import dataclasses
import itertools
import json
import logging
import multiprocessing
import os
import time
from types import SimpleNamespace
from typing import Tuple

import numpy as np
import torch
import torch.distributed as dist

from sglang.srt.configs.model_config import ModelConfig
from sglang.srt.distributed.parallel_state import destroy_distributed_environment
from sglang.srt.entrypoints.engine import _set_envs_and_config
from sglang.srt.layers.moe import initialize_moe_config
from sglang.srt.managers.schedule_batch import Req, ScheduleBatch
from sglang.srt.managers.scheduler import Scheduler
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.model_executor.model_runner import ModelRunner
from sglang.srt.sampling.sampling_params import SamplingParams
from sglang.srt.server_args import PortArgs, ServerArgs
from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
from sglang.srt.utils import (
    configure_logger,
    get_bool_env_var,
    is_cuda_alike,
    is_xpu,
    kill_process_tree,
    maybe_reindex_device_id,
    require_mlp_sync,
    require_mlp_tp_gather,
    set_gpu_proc_affinity,
    suppress_other_loggers,
)
from sglang.srt.utils.hf_transformers_utils import get_tokenizer

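# Profiler activities available on this machine: CPU is always profiled; CUDA or XPU
# activity is added only when the corresponding backend is detected.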
profile_activities = [torch.profiler.ProfilerActivity.CPU] + [
    profiler_activity
    for available, profiler_activity in [
        (is_cuda_alike(), torch.profiler.ProfilerActivity.CUDA),
        (is_xpu(), torch.profiler.ProfilerActivity.XPU),
    ]
    if available
]


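# Note: with the "CUDA_PROFILER" activity, tracing is delegated to an external tool
# (e.g. nsys); cudaProfilerStart()/cudaProfilerStop() only mark the capture window.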
def start_profile(profiler_activities, profile_record_shapes=False, rank_print=print):
    """
    Abstracted function to start profiling based on profiler_activities.
    Returns profiler object (or None).
    """
    if "CUDA_PROFILER" in profiler_activities:
        try:
            torch.cuda.cudart().cudaProfilerStart()
            rank_print("CUDA Profiler started (nsys will begin capturing)")
        except Exception as e:
            rank_print(f"Failed to start CUDA profiler: {e}")
        return None
    else:
        activities = []
        if "CPU" in profiler_activities:
            activities.append(torch.profiler.ProfilerActivity.CPU)
        if "GPU" in profiler_activities:
            activities.append(torch.profiler.ProfilerActivity.CUDA)
        if activities:
            profiler = torch.profiler.profile(
                activities=activities,
                with_stack=True,
                record_shapes=profile_record_shapes,
            )
            profiler.start()
            return profiler
        return None


def stop_profile(
    profiler,
    profiler_activities,
    rank_print=print,
    save_trace=False,
    trace_filename=None,
    stage=None,
):
    """
    Abstracted function to stop profiling based on profiler_activities.
    Optionally saves trace results and prints completion messages.
    """
    if "CUDA_PROFILER" in profiler_activities:
        try:
            torch.cuda.cudart().cudaProfilerStop()
            rank_print("CUDA Profiler stopped (nsys should dump traces)")
        except Exception as e:
            rank_print(f"Failed to stop CUDA profiler: {e}")
    elif profiler is not None:
        profiler.stop()

    if save_trace:
        if profiler is not None:
            if trace_filename:
                _save_profile_trace_results(profiler, trace_filename)
                stage_desc = f"for {stage}" if stage else ""
                rank_print(
                    f"torch profiler chrome trace {stage_desc} saved to {trace_filename}"
                )
        if "CUDA_PROFILER" in profiler_activities:
            rank_print(f"CUDA profiler trace for {stage} completed")


@dataclasses.dataclass
class BenchArgs:
    run_name: str = "default"
    batch_size: Tuple[int] = (1,)
    input_len: Tuple[int] = (1024,)
    output_len: Tuple[int] = (16,)
    prompt_filename: str = ""
    result_filename: str = "result.jsonl"
    correctness_test: bool = False
    # This is only used for correctness test
    cut_len: int = 4
    log_decode_step: int = 0
    profile: bool = False
    profile_record_shapes: bool = False
    profiler_activities: Tuple[str] = ("CPU", "GPU")
    profile_stage: str = "all"
    profile_filename_prefix: str = "profile"

    @staticmethod
    def add_cli_args(parser: argparse.ArgumentParser):
        parser.add_argument("--run-name", type=str, default=BenchArgs.run_name)
        parser.add_argument(
            "--batch-size", type=int, nargs="+", default=BenchArgs.batch_size
        )
        parser.add_argument(
            "--input-len", type=int, nargs="+", default=BenchArgs.input_len
        )
        parser.add_argument(
            "--output-len", type=int, nargs="+", default=BenchArgs.output_len
        )
        parser.add_argument(
            "--prompt-filename", type=str, default=BenchArgs.prompt_filename
        )
        parser.add_argument(
            "--result-filename", type=str, default=BenchArgs.result_filename
        )
        parser.add_argument("--correctness-test", action="store_true")
        parser.add_argument("--cut-len", type=int, default=BenchArgs.cut_len)
        parser.add_argument(
            "--log-decode-step",
            type=int,
            default=BenchArgs.log_decode_step,
            help="Log decode latency by step, default is set to zero to disable.",
        )
        parser.add_argument("--profile", action="store_true", help="Enable profiling.")
        parser.add_argument(
            "--profile-record-shapes",
            action="store_true",
            help="Record tensor shapes in profiling results.",
        )
        parser.add_argument(
            "--profiler_activities",
            type=str,
            nargs="+",
            default=["CPU", "GPU"],
            choices=["CPU", "GPU", "CUDA_PROFILER"],
            help="Profiler activities: CPU, GPU, CUDA_PROFILER. If CPU/GPU, use torch profiler. If CUDA_PROFILER, use CUDA profiler.",
        )
        parser.add_argument(
            "--profile-stage",
            type=str,
            default=BenchArgs.profile_stage,
            choices=["all", "prefill", "decode"],
            help="Which stage to profile: all, prefill, or decode only.",
        )
        parser.add_argument(
            "--profile-filename-prefix",
            type=str,
            default=BenchArgs.profile_filename_prefix,
            help="Prefix of the profiling file names. The full profiling result file(s) will be "
            '"[profile_filename_prefix]_batch[batch_size]_input[input_len]_output[output_len].trace.json.gz"',
        )

    @classmethod
    def from_cli_args(cls, args: argparse.Namespace):
        # Use the default value's type to cast the args into the correct types.
        attrs = [(attr.name, type(attr.default)) for attr in dataclasses.fields(cls)]
        return cls(
            **{attr: attr_type(getattr(args, attr)) for attr, attr_type in attrs}
        )


def load_model(server_args, port_args, gpu_id, tp_rank):
    suppress_other_loggers()
    rank_print = print if tp_rank == 0 else lambda *args, **kwargs: None
    moe_ep_rank = tp_rank // (server_args.tp_size // server_args.ep_size)

    model_config = ModelConfig.from_server_args(server_args)
    model_runner = ModelRunner(
        model_config=model_config,
        mem_fraction_static=server_args.mem_fraction_static,
        gpu_id=gpu_id,
        tp_rank=tp_rank,
        tp_size=server_args.tp_size,
        moe_ep_rank=moe_ep_rank,
        moe_ep_size=server_args.ep_size,
        pp_rank=0,
        pp_size=1,
        nccl_port=port_args.nccl_port,
        server_args=server_args,
    )
    rank_print(f"max_total_num_tokens={model_runner.max_total_num_tokens}")
    tokenizer = get_tokenizer(
        server_args.tokenizer_path,
        tokenizer_mode=server_args.tokenizer_mode,
        trust_remote_code=server_args.trust_remote_code,
    )
    if server_args.tp_size > 1:
        dist.barrier()
    return model_runner, tokenizer


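# The correctness test first prefills only the first `cut_len` tokens of each prompt,
# then extends with the remaining tokens, so both the cold and prefix-cached paths run.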
def prepare_inputs_for_correctness_test(bench_args, tokenizer, custom_prompts):
    prompts = (
        custom_prompts
        if custom_prompts
        else [
            "The capital of France is",
            "The capital of the United Kindom is",
            "Today is a sunny day and I like",
        ]
    )
    input_ids = [tokenizer.encode(p) for p in prompts]
    sampling_params = SamplingParams(
        temperature=0,
        max_new_tokens=BenchArgs.output_len,
    )

    reqs = []
    for i in range(len(prompts)):
        assert len(input_ids[i]) > bench_args.cut_len

        tmp_input_ids = input_ids[i][: bench_args.cut_len]
        req = Req(
            rid=i,
            origin_input_text=prompts[i],
            origin_input_ids=tmp_input_ids,
            sampling_params=sampling_params,
        )
        req.fill_ids = req.origin_input_ids
        req.extend_input_len = len(req.fill_ids) - len(req.prefix_indices)
        req.logprob_start_len = len(req.origin_input_ids) - 1
        reqs.append(req)

    return input_ids, reqs


def prepare_extend_inputs_for_correctness_test(
    bench_args, input_ids, reqs, model_runner
):
    for i in range(len(reqs)):
        req = reqs[i]
        req.fill_ids += input_ids[i][bench_args.cut_len :]
        req.prefix_indices = model_runner.req_to_token_pool.req_to_token[
            i, : bench_args.cut_len
        ]
        req.extend_input_len = len(req.fill_ids) - len(req.prefix_indices)
        req.logprob_start_len = len(req.origin_input_ids) - 1
    return reqs


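# Latency-test inputs are either random token ids of shape (batch_size, input_len)
# or user-provided prompts passed in via `custom_inputs`.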
def prepare_synthetic_inputs_for_latency_test(
    batch_size, input_len, custom_inputs=None
):
    input_ids = (
        custom_inputs
        if custom_inputs
        else np.random.randint(0, 10000, (batch_size, input_len), dtype=np.int32)
    )
    sampling_params = SamplingParams(
        temperature=0,
        max_new_tokens=BenchArgs.output_len,
    )

    reqs = []
    for i in range(len(input_ids)):
        req = Req(
            rid=i,
            origin_input_text="",
            origin_input_ids=list(input_ids[i]),
            sampling_params=sampling_params,
        )
        req.fill_ids = req.origin_input_ids
        req.extend_input_len = len(req.fill_ids) - len(req.prefix_indices)
        req.logprob_start_len = len(req.origin_input_ids) - 1
        reqs.append(req)

    return reqs


@torch.no_grad
def extend(reqs, model_runner):
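    # One prefill (extend) pass over the whole batch; returns the sampled next tokens,
    # the next-token logits, and the batch object reused by the decode steps.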
    # Create dummy tree_cache for benchmarks (no prefix caching, just allocation)
    dummy_tree_cache = SimpleNamespace(
        page_size=model_runner.server_args.page_size,
        device=model_runner.device,
        token_to_kv_pool_allocator=model_runner.token_to_kv_pool_allocator,
    )

    batch = ScheduleBatch.init_new(
        reqs=reqs,
        req_to_token_pool=model_runner.req_to_token_pool,
        token_to_kv_pool_allocator=model_runner.token_to_kv_pool_allocator,
        tree_cache=dummy_tree_cache,
        model_config=model_runner.model_config,
        enable_overlap=False,
        spec_algorithm=SpeculativeAlgorithm.NONE,
    )
    batch.prepare_for_extend()
    _maybe_prepare_mlp_sync_batch(batch, model_runner)
    model_worker_batch = batch.get_model_worker_batch()
    forward_batch = ForwardBatch.init_new(model_worker_batch, model_runner)
    logits_output, _ = model_runner.forward(forward_batch)
    next_token_ids = model_runner.sample(logits_output, forward_batch)
    return next_token_ids, logits_output.next_token_logits, batch


@torch.no_grad
def decode(input_token_ids, batch, model_runner):
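    # One decode step: feed back the previously sampled token ids and run a
    # single-token forward pass for every request in the batch.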
    batch.output_ids = input_token_ids
    batch.prepare_for_decode()
    _maybe_prepare_mlp_sync_batch(batch, model_runner)
    model_worker_batch = batch.get_model_worker_batch()
    forward_batch = ForwardBatch.init_new(model_worker_batch, model_runner)
    logits_output, _ = model_runner.forward(forward_batch)
    next_token_ids = model_runner.sample(logits_output, forward_batch)
    return next_token_ids, logits_output.next_token_logits


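# When the server configuration requires MLP sync (e.g. data-parallel attention),
# mirror the scheduler's batch preparation so the standalone forward pass stays
# consistent across ranks.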
def _maybe_prepare_mlp_sync_batch(batch: ScheduleBatch, model_runner):
    if require_mlp_sync(model_runner.server_args):
        Scheduler.prepare_mlp_sync_batch_raw(
            batch,
            dp_size=model_runner.server_args.dp_size,
            attn_tp_size=1,
            tp_group=model_runner.tp_group,
            get_idle_batch=None,
            disable_cuda_graph=model_runner.server_args.disable_cuda_graph,
            spec_algorithm=SpeculativeAlgorithm.NONE,
            speculative_num_draft_tokens=None,
            require_mlp_tp_gather=require_mlp_tp_gather(model_runner.server_args),
            disable_overlap_schedule=model_runner.server_args.disable_overlap_schedule,
            offload_tags=set(),
        )


def _read_prompts_from_file(prompt_file, rank_print):
    """Read custom prompts from the file specified by `--prompt-filename`."""
    if not prompt_file:
        return []
    if not os.path.exists(prompt_file):
        rank_print(
            f"Custom prompt file {prompt_file} not found. Using default inputs..."
        )
        return []
    with open(prompt_file, "r") as pf:
        return pf.readlines()


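# Torch profiler traces are written to SGLANG_TORCH_PROFILER_DIR (default: /tmp).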
def _get_torch_profiler_output_dir():
    return os.environ.get("SGLANG_TORCH_PROFILER_DIR", "/tmp")


def _create_torch_profiler_filename(
    profile_filename_prefix, batch_size, input_len, output_len, stage
):
    output_dir = _get_torch_profiler_output_dir()
    filename = f"{profile_filename_prefix}_batch{batch_size}_input{input_len}_output{output_len}_{stage}.trace.json.gz"
    return os.path.join(output_dir, filename)


def _save_profile_trace_results(profiler, filename):
    parent_dir = os.path.dirname(os.path.abspath(filename))
    os.makedirs(parent_dir, exist_ok=True)
    profiler.export_chrome_trace(filename)
    print(
        profiler.key_averages(group_by_input_shape=True).table(
            sort_by="self_cpu_time_total"
        )
    )


def correctness_test(
    server_args,
    port_args,
    bench_args,
    gpu_id,
    tp_rank,
):
    # Configure the logger
    configure_logger(server_args, prefix=f" TP{tp_rank}")
    rank_print = print if tp_rank == 0 else lambda *args, **kwargs: None

    # Load the model
    model_runner, tokenizer = load_model(server_args, port_args, gpu_id, tp_rank)

    # Prepare inputs
    custom_prompts = _read_prompts_from_file(bench_args.prompt_filename, rank_print)
    input_ids, reqs = prepare_inputs_for_correctness_test(
        bench_args, tokenizer, custom_prompts
    )
    rank_print(f"\n{input_ids=}\n")

    if bench_args.cut_len > 0:
        # Prefill
        next_token_ids, next_token_logits, batch = extend(reqs, model_runner)
        rank_print(f"prefill logits (first half): {next_token_logits} \n")

    # Prepare extend inputs
    reqs = prepare_extend_inputs_for_correctness_test(
        bench_args, input_ids, reqs, model_runner
    )

    # Extend (prefill w/ KV cache)
    next_token_ids, next_token_logits, batch = extend(reqs, model_runner)
    rank_print(f"prefill logits (final): {next_token_logits} \n")

    # Decode
    output_ids = [input_ids[i] + [next_token_ids[i]] for i in range(len(input_ids))]
    for _ in range(bench_args.output_len[0] - 1):
        next_token_ids, _ = decode(next_token_ids, batch, model_runner)
        next_token_ids_list = next_token_ids.tolist()
        for i in range(len(reqs)):
            output_ids[i].append(next_token_ids_list[i])

    # Print output texts
    for i in range(len(reqs)):
        rank_print(f"========== Prompt {i} ==========")
        rank_print(tokenizer.decode(output_ids[i]), "\n")


def synchronize(device):
    torch.get_device_module(device).synchronize()


def latency_test_run_once(
    run_name,
    model_runner,
    rank_print,
    reqs,
    batch_size,
    input_len,
    output_len,
    device,
    log_decode_step,
    profile,
    profile_record_shapes,
    profiler_activities,
    profile_filename_prefix,
    profile_stage,
    tp_rank,
):
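    # Skip configurations whose total token demand, batch_size * (input_len + output_len),
    # would not fit into the KV token pool.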
    max_batch_size = model_runner.max_total_num_tokens // (input_len + output_len)
    if batch_size > max_batch_size:
        rank_print(
            f"skipping ({batch_size}, {input_len}, {output_len}) due to max batch size limit"
        )
        return

    model_runner.req_to_token_pool.clear()
    model_runner.token_to_kv_pool_allocator.clear()

    measurement_results = {
        "run_name": run_name,
        "batch_size": batch_size,
        "input_len": input_len,
        "output_len": output_len,
    }

    tot_latency = 0

    profiler = None
    enable_profile_prefill = profile and profile_stage in ["all", "prefill"]
    if enable_profile_prefill:
        profiler = start_profile(
            profiler_activities,
            profile_record_shapes=profile_record_shapes,
            rank_print=rank_print,
        )

    synchronize(device)
    tic = time.perf_counter()
    next_token_ids, _, batch = extend(reqs, model_runner)
    synchronize(device)
    prefill_latency = time.perf_counter() - tic

    if enable_profile_prefill:
        trace_filename = _create_torch_profiler_filename(
            profile_filename_prefix, batch_size, input_len, output_len, "prefill"
        )
        stop_profile(
            profiler,
            profiler_activities,
            rank_print=rank_print,
            save_trace=True,
            trace_filename=trace_filename,
            stage="prefill",
        )

    tot_latency += prefill_latency
    throughput = input_len * batch_size / prefill_latency
    rank_print(
        f"Prefill. latency: {prefill_latency:6.5f} s, throughput: {throughput:9.2f} token/s"
    )
    measurement_results["prefill_latency"] = prefill_latency
    measurement_results["prefill_throughput"] = throughput

    decode_latencies = []
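    # For decode profiling, capture a single representative step in the middle of generation.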
    profile_step_of_interest = output_len // 2
    enable_profile_decode = profile and profile_stage in ["all", "decode"]
    for i in range(output_len - 1):
        synchronize(device)
        profiler = None
        if enable_profile_decode and i == profile_step_of_interest:
            profiler = start_profile(
                profiler_activities,
                profile_record_shapes=profile_record_shapes,
                rank_print=rank_print,
            )

        tic = time.perf_counter()
        next_token_ids, _ = decode(next_token_ids, batch, model_runner)
        synchronize(device)
        latency = time.perf_counter() - tic

        if enable_profile_decode and i == profile_step_of_interest:
            trace_filename = _create_torch_profiler_filename(
                profile_filename_prefix, batch_size, input_len, output_len, "decode"
            )
            stop_profile(
                profiler,
                profiler_activities,
                rank_print=rank_print,
                save_trace=True,
                trace_filename=trace_filename,
                stage="decode",
            )

        tot_latency += latency
        throughput = batch_size / latency
        decode_latencies.append(latency)
        if i < 5 or (log_decode_step > 0 and i % log_decode_step == 0):
            rank_print(
                f"Decode {i}. Batch size: {batch_size}, latency: {latency:6.5f} s, throughput: {throughput:9.2f} token/s"
            )

    # Record decode timing from 2nd output
    if output_len > 1:
        med_decode_latency = np.median(decode_latencies)
        med_decode_throughput = batch_size / med_decode_latency
        rank_print(
            f"Decode.  median latency: {med_decode_latency:6.5f} s, median throughput: {med_decode_throughput:9.2f} token/s"
        )
        measurement_results["median_decode_latency"] = med_decode_latency
        measurement_results["median_decode_throughput"] = med_decode_throughput

    throughput = (input_len + output_len) * batch_size / tot_latency
    rank_print(
        f"Total. latency: {tot_latency:6.3f} s, throughput: {throughput:9.2f} token/s"
    )
    measurement_results["total_latency"] = tot_latency
    measurement_results["overall_throughput"] = throughput
    return measurement_results


def latency_test(
    server_args,
    port_args,
    bench_args,
    gpu_id,
    tp_rank,
):
    initialize_moe_config(server_args)

    # Set CPU affinity
    if get_bool_env_var("SGLANG_SET_CPU_AFFINITY"):
        set_gpu_proc_affinity(
            server_args.pp_size, server_args.tp_size, server_args.nnodes, tp_rank
        )

    # Configure the logger
    configure_logger(server_args, prefix=f" TP{tp_rank}")
    rank_print = print if tp_rank == 0 else lambda *args, **kwargs: None

    # Load the model
    model_runner, tokenizer = load_model(server_args, port_args, gpu_id, tp_rank)

    # Prepare inputs for warm up
    reqs = prepare_synthetic_inputs_for_latency_test(
        bench_args.batch_size[0], bench_args.input_len[0]
    )

    # Warm up
    rank_print("Warmup ...")
    latency_test_run_once(
        bench_args.run_name,
        model_runner,
        rank_print,
        reqs,
        bench_args.batch_size[0],
        bench_args.input_len[0],
        min(32, bench_args.output_len[0]),  # shorter decoding to speed up the warmup
        server_args.device,
        log_decode_step=0,
        profile=False,
        profile_record_shapes=False,
        profiler_activities=("CPU", "GPU"),
        profile_filename_prefix="",
        profile_stage="all",
        tp_rank=tp_rank,
    )

    rank_print("Benchmark ...")

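    # Optional custom prompts: tokenize them here and align their count to each batch size below.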
    custom_inputs = _read_prompts_from_file(bench_args.prompt_filename, rank_print)
    custom_inputs = [tokenizer.encode(p.strip()) for p in custom_inputs]
    custom_input_len = len(custom_inputs)

    # Run the sweep
    result_list = []
    for bs, il, ol in itertools.product(
        bench_args.batch_size, bench_args.input_len, bench_args.output_len
    ):
        bs_aligned_inputs = []
        if custom_inputs:
            if custom_input_len == bs:
                bs_aligned_inputs = custom_inputs
            elif custom_input_len > bs:
                rank_print(
                    f"Custom input size ({custom_input_len}) is larger than batch_size ({bs}). "
                    f"Using the first {bs} prompts."
                )
                bs_aligned_inputs = copy.deepcopy(custom_inputs[:bs])
            else:
                rank_print(
                    f"Custom input size ({custom_input_len}) is smaller than batch_size ({bs}). "
                    f"Pad to the desired batch_size with the last prompt."
                )
                bs_aligned_inputs = copy.deepcopy(custom_inputs)
                bs_aligned_inputs.extend(
                    [bs_aligned_inputs[-1]] * (bs - custom_input_len)
                )

        reqs = prepare_synthetic_inputs_for_latency_test(bs, il, bs_aligned_inputs)
        ret = latency_test_run_once(
            bench_args.run_name,
            model_runner,
            rank_print,
            reqs,
            bs,
            il,
            ol,
            server_args.device,
            bench_args.log_decode_step,
            bench_args.profile if tp_rank == 0 else None,
            bench_args.profile_record_shapes if tp_rank == 0 else None,
            bench_args.profiler_activities,
            bench_args.profile_filename_prefix,
            bench_args.profile_stage,
            tp_rank,
        )
        if ret is not None:
            result_list.append(ret)

    # Write results in jsonlines format on rank 0.
    if tp_rank == 0 and bench_args.result_filename:
        with open(bench_args.result_filename, "a") as fout:
            for result in result_list:
                fout.write(json.dumps(result) + "\n")

    if server_args.tp_size > 1:
        destroy_distributed_environment()


def main(server_args, bench_args):
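    # Make sure CUDA graph capture covers the largest batch size in the sweep.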
    server_args.cuda_graph_max_bs = max(bench_args.batch_size)

    _set_envs_and_config(server_args)

    if server_args.model_path:
        if bench_args.correctness_test:
            work_func = correctness_test
        else:
            work_func = latency_test
    else:
        raise ValueError(
            "Provide --model-path for running the tests or "
            "provide --result-filename for plotting the results"
        )

    port_args = PortArgs.init_new(server_args)

    if server_args.tp_size == 1:
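        # tp_size == 1: run the benchmark in this process; the else branch below spawns one worker per TP rank.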
        work_func(server_args, port_args, bench_args, 0, 0)
    else:
        workers = []
        for tp_rank in range(server_args.tp_size):
            with maybe_reindex_device_id(tp_rank) as gpu_id:
                proc = multiprocessing.Process(
                    target=work_func,
                    args=(
                        server_args,
                        port_args,
                        bench_args,
                        gpu_id,
                        tp_rank,
                    ),
                )
                proc.start()
                workers.append(proc)

        for proc in workers:
            proc.join()

        proc.terminate()


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    ServerArgs.add_cli_args(parser)
    BenchArgs.add_cli_args(parser)
    args = parser.parse_args()
    server_args = ServerArgs.from_cli_args(args)
    bench_args = BenchArgs.from_cli_args(args)

    logging.basicConfig(
        level=getattr(logging, server_args.log_level.upper()),
        format="%(message)s",
    )

    try:
        main(server_args, bench_args)
    finally:
        if server_args.tp_size != 1:
            kill_process_tree(os.getpid(), include_parent=False)