benchmark_latency.py 6.38 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
"""Benchmark the latency of processing a single batch of requests."""
3

4
import argparse
5
import dataclasses
6
import json
7
import os
8
import time
9
from pathlib import Path
10
from typing import Any, Optional
11
12
13

import numpy as np
import torch
14
from tqdm import tqdm
15

16
from benchmark_utils import convert_to_pytorch_benchmark_format, write_to_json
Woosuk Kwon's avatar
Woosuk Kwon committed
17
from vllm import LLM, SamplingParams
18
from vllm.engine.arg_utils import EngineArgs
19
from vllm.inputs import PromptType
20
from vllm.sampling_params import BeamSearchParams
21
from vllm.utils import FlexibleArgumentParser
22
23


24
25
26
def save_to_pytorch_benchmark_format(
    args: argparse.Namespace, results: dict[str, Any]
) -> None:
    """Additionally write ``results`` in the PyTorch benchmark format.

    The per-iteration ``latencies`` become the ``latency`` metric, while the
    ``avg_latency`` and ``percentiles`` summaries are attached as extra info.
    The output path is ``args.output_json`` with its extension replaced by
    ``.pytorch.json``. Nothing is written when the converter yields no
    records.
    """
    summary_keys = ("avg_latency", "percentiles")
    records = convert_to_pytorch_benchmark_format(
        args=args,
        metrics={"latency": results["latencies"]},
        extra_info={key: results[key] for key in summary_keys},
    )
    if not records:
        return
    base, _ext = os.path.splitext(args.output_json)
    write_to_json(base + ".pytorch.json", records)
35
36


37
def main(args: argparse.Namespace):
    """Run the single-batch latency benchmark configured by ``args``.

    Steps:
    1. Build an ``LLM`` from the engine CLI arguments.
    2. Create ``args.batch_size`` random prompts of ``args.input_len`` token
       ids each.
    3. Run ``args.num_iters_warmup`` untimed warmup iterations.
    4. Either profile a single iteration (``--profile``) or time
       ``args.num_iters`` iterations and print the average and selected
       percentile latencies.

    When ``--output-json`` is given, the timing results are written there as
    JSON and also saved in the PyTorch benchmark format.

    Raises:
        ValueError: if ``--num-iters`` is below 1 when not profiling, or if
            the model's ``max_model_len`` is smaller than
            ``input_len + output_len``.
    """
    print(args)

    # Validate before building the (expensive) engine. The percentile
    # computation below needs at least one timed iteration.
    if not args.profile and args.num_iters < 1:
        raise ValueError(f"--num-iters must be at least 1, got {args.num_iters}.")

    engine_args = EngineArgs.from_cli_args(args)

    # NOTE(woosuk): If the request cannot be processed in a single batch,
    # the engine will automatically process the request in multiple batches.
    llm = LLM(**dataclasses.asdict(engine_args))

    # Explicit check instead of ``assert`` so it is not stripped under -O.
    max_model_len = llm.llm_engine.model_config.max_model_len
    required_len = args.input_len + args.output_len
    if max_model_len < required_len:
        raise ValueError(
            "Please ensure that max_model_len is greater than or equal to"
            " the sum of input_len and output_len"
            f" (max_model_len={max_model_len}, input_len + output_len="
            f"{required_len})."
        )

    sampling_params = SamplingParams(
        n=args.n,
        temperature=1.0,
        top_p=1.0,
        ignore_eos=True,
        max_tokens=args.output_len,
        detokenize=not args.disable_detokenize,
    )
    print(sampling_params)

    # One prompt per batch entry, each made of ``input_len`` random token
    # ids drawn from [0, 10000).
    dummy_prompt_token_ids = np.random.randint(
        10000, size=(args.batch_size, args.input_len)
    )
    dummy_prompts: list[PromptType] = [
        {"prompt_token_ids": batch} for batch in dummy_prompt_token_ids.tolist()
    ]

    def llm_generate() -> None:
        """Process the whole dummy batch once, via sampling or beam search."""
        if not args.use_beam_search:
            llm.generate(dummy_prompts, sampling_params=sampling_params, use_tqdm=False)
        else:
            llm.beam_search(
                dummy_prompts,
                BeamSearchParams(
                    beam_width=args.n,
                    max_tokens=args.output_len,
                    ignore_eos=True,
                ),
            )

    def run_to_completion(profile_dir: Optional[str] = None) -> Optional[float]:
        """Run one iteration and return its wall-clock latency in seconds.

        When ``profile_dir`` is set, the iteration runs under the PyTorch
        profiler with a TensorBoard trace handler writing to ``profile_dir``.
        A summary table is printed, and ``None`` is returned instead of a
        latency.
        """
        if profile_dir:
            with torch.profiler.profile(
                activities=[
                    torch.profiler.ProfilerActivity.CPU,
                    torch.profiler.ProfilerActivity.CUDA,
                ],
                on_trace_ready=torch.profiler.tensorboard_trace_handler(
                    str(profile_dir)
                ),
            ) as p:
                llm_generate()
            print(p.key_averages().table(sort_by="self_cuda_time_total"))
            return None

        start_time = time.perf_counter()
        llm_generate()
        return time.perf_counter() - start_time

    print("Warming up...")
    for _ in tqdm(range(args.num_iters_warmup), desc="Warmup iterations"):
        run_to_completion(profile_dir=None)

    if args.profile:
        profile_dir = args.profile_result_dir
        if not profile_dir:
            # Convert to str so run_to_completion receives what it is
            # annotated to take. The printed path text is unchanged.
            profile_dir = str(
                Path(".") / "vllm_benchmark_result" / f"latency_result_{time.time()}"
            )
        print(f"Profiling (results will be saved to '{profile_dir}')...")
        run_to_completion(profile_dir=profile_dir)
        return

    # Benchmark.
    latencies = []
    for _ in tqdm(range(args.num_iters), desc="Profiling iterations"):
        latencies.append(run_to_completion(profile_dir=None))
    latencies = np.array(latencies)

    percentages = [10, 25, 50, 75, 90, 99]
    percentiles = np.percentile(latencies, percentages)
    # Plain Python float: same printed value, and JSON-serializable as-is.
    avg_latency = float(np.mean(latencies))
    print(f"Avg latency: {avg_latency} seconds")
    for percentage, percentile in zip(percentages, percentiles):
        print(f"{percentage}% percentile latency: {percentile} seconds")

    # Output JSON results if specified
    if args.output_json:
        results = {
            "avg_latency": avg_latency,
            "latencies": latencies.tolist(),
            "percentiles": dict(zip(percentages, percentiles.tolist())),
        }
        with open(args.output_json, "w") as f:
            json.dump(results, f, indent=4)
        save_to_pytorch_benchmark_format(args, results)
136

137

138
if __name__ == "__main__":
    # CLI entry point: benchmark-specific flags first, then every engine flag
    # contributed by EngineArgs.
    parser = FlexibleArgumentParser(
        description="Benchmark the latency of processing a single batch of "
        "requests till completion."
    )
    parser.add_argument(
        "--input-len",
        type=int,
        default=32,
        help="Number of random prompt tokens per request.",
    )
    parser.add_argument(
        "--output-len",
        type=int,
        default=128,
        help="Number of output tokens to generate per request (EOS is ignored).",
    )
    parser.add_argument(
        "--batch-size",
        type=int,
        default=8,
        help="Number of requests submitted together in the benchmarked batch.",
    )
    parser.add_argument(
        "--n",
        type=int,
        default=1,
        help="Number of generated sequences per prompt.",
    )
    parser.add_argument(
        "--use-beam-search",
        action="store_true",
        help="Use beam search (beam width = --n) instead of sampling.",
    )
    parser.add_argument(
        "--num-iters-warmup",
        type=int,
        default=10,
        help="Number of iterations to run for warmup.",
    )
    parser.add_argument(
        "--num-iters", type=int, default=30, help="Number of iterations to run."
    )
    parser.add_argument(
        "--profile",
        action="store_true",
        help="profile the generation process of a single batch",
    )
    parser.add_argument(
        "--profile-result-dir",
        type=str,
        default=None,
        help=(
            "path to save the pytorch profiler output. Can be visualized "
            "with ui.perfetto.dev or Tensorboard."
        ),
    )
    parser.add_argument(
        "--output-json",
        type=str,
        default=None,
        help="Path to save the latency results in JSON format.",
    )
    parser.add_argument(
        "--disable-detokenize",
        action="store_true",
        help=(
            "Do not detokenize responses (i.e. do not include "
            "detokenization time in the latency measurement)"
        ),
    )

    parser = EngineArgs.add_cli_args(parser)
    # V1 enables prefix caching by default which skews the latency
    # numbers. We need to disable prefix caching by default.
    parser.set_defaults(enable_prefix_caching=False)
    args = parser.parse_args()
    main(args)