benchmark_latency.py 6.41 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
"""Benchmark the latency of processing a single batch of requests."""
3

4
import argparse
5
import dataclasses
6
import json
7
import os
8
import time
9
from pathlib import Path
10
from typing import Any, Optional
11
12
13

import numpy as np
import torch
14
from benchmark_utils import convert_to_pytorch_benchmark_format, write_to_json
15
from tqdm import tqdm
16

Woosuk Kwon's avatar
Woosuk Kwon committed
17
from vllm import LLM, SamplingParams
18
from vllm.engine.arg_utils import EngineArgs
19
from vllm.inputs import PromptType
20
from vllm.sampling_params import BeamSearchParams
21
from vllm.utils import FlexibleArgumentParser
22
23


24
def save_to_pytorch_benchmark_format(args: argparse.Namespace,
                                     results: dict[str, Any]) -> None:
    """Write the latency results a second time via the PyTorch benchmark helper.

    The per-iteration ``results["latencies"]`` become the ``latency`` metric;
    ``avg_latency`` and ``percentiles`` are attached as extra info.  The
    converted records are written next to ``args.output_json``, replacing its
    extension with ``.pytorch.json``.  Nothing is written when the conversion
    helper returns no records.

    Args:
        args: Parsed CLI arguments; ``args.output_json`` must be a path.
        results: Dict holding ``latencies``, ``avg_latency`` and
            ``percentiles`` as produced by ``main``.
    """
    latency_metrics = {"latency": results["latencies"]}
    summary_info = {
        key: results[key]
        for key in ("avg_latency", "percentiles")
    }
    records = convert_to_pytorch_benchmark_format(
        args=args,
        metrics=latency_metrics,
        extra_info=summary_info,
    )
    # The helper may return an empty result; in that case skip the file.
    if not records:
        return
    stem, _extension = os.path.splitext(args.output_json)
    write_to_json(f"{stem}.pytorch.json", records)
34
35


36
def main(args: argparse.Namespace):
    """Run the single-batch latency benchmark described by *args*.

    Builds an ``LLM`` from the engine CLI arguments and creates
    ``args.batch_size`` random prompts of ``args.input_len`` token ids.  It
    then runs ``args.num_iters_warmup`` warmup passes, followed by one of:

    * with ``--profile``: a single pass under ``torch.profiler``, whose trace
      is saved to ``args.profile_result_dir`` (or a timestamped default
      directory);
    * otherwise: ``args.num_iters`` timed passes.  Their average and
      percentile latencies are printed and, if ``args.output_json`` is set,
      written as JSON (plus a PyTorch-benchmark-format copy).

    Raises:
        ValueError: if ``args.num_iters`` < 1 when not profiling, or if the
            model's ``max_model_len`` is smaller than
            ``args.input_len + args.output_len``.
    """
    print(args)

    # Fail fast, before the (expensive) model load: percentiles of an empty
    # latency list are undefined.
    if not args.profile and args.num_iters < 1:
        raise ValueError(
            f"--num-iters must be at least 1 (got {args.num_iters}).")

    engine_args = EngineArgs.from_cli_args(args)

    # NOTE(woosuk): If the request cannot be processed in a single batch,
    # the engine will automatically process the request in multiple batches.
    llm = LLM(**dataclasses.asdict(engine_args))

    # Explicit check instead of ``assert`` so it still runs under
    # ``python -O``.
    max_model_len = llm.llm_engine.model_config.max_model_len
    required_len = args.input_len + args.output_len
    if max_model_len < required_len:
        raise ValueError(
            "Please ensure that max_model_len is greater than or equal to"
            " the sum of input_len and output_len "
            f"(max_model_len={max_model_len}, "
            f"input_len + output_len={required_len}).")

    sampling_params = SamplingParams(
        n=args.n,
        temperature=1.0,
        top_p=1.0,
        ignore_eos=True,
        max_tokens=args.output_len,
        detokenize=not args.disable_detokenize,
    )
    print(sampling_params)

    # Random token ids in [0, 10000), one row per prompt in the batch.
    dummy_prompt_token_ids = np.random.randint(10000,
                                               size=(args.batch_size,
                                                     args.input_len))
    dummy_prompts: list[PromptType] = [{
        "prompt_token_ids": batch
    } for batch in dummy_prompt_token_ids.tolist()]

    def llm_generate():
        """Process the whole dummy batch once (sampling or beam search)."""
        if not args.use_beam_search:
            llm.generate(dummy_prompts,
                         sampling_params=sampling_params,
                         use_tqdm=False)
        else:
            llm.beam_search(
                dummy_prompts,
                BeamSearchParams(
                    beam_width=args.n,
                    max_tokens=args.output_len,
                    ignore_eos=True,
                ),
            )

    def run_to_completion(profile_dir: Optional[str] = None):
        """Run one batch.

        Returns the wall-clock latency in seconds, or ``None`` when
        ``profile_dir`` is given (the run is profiled instead of timed).
        """
        if profile_dir:
            with torch.profiler.profile(
                    activities=[
                        torch.profiler.ProfilerActivity.CPU,
                        torch.profiler.ProfilerActivity.CUDA,
                    ],
                    on_trace_ready=torch.profiler.tensorboard_trace_handler(
                        str(profile_dir)),
            ) as p:
                llm_generate()
            print(p.key_averages().table(sort_by="self_cuda_time_total"))
            return None

        start_time = time.perf_counter()
        llm_generate()
        end_time = time.perf_counter()
        return end_time - start_time

    print("Warming up...")
    for _ in tqdm(range(args.num_iters_warmup), desc="Warmup iterations"):
        run_to_completion(profile_dir=None)

    if args.profile:
        profile_dir = args.profile_result_dir
        if not profile_dir:
            profile_dir = (Path(".") / "vllm_benchmark_result" /
                           f"latency_result_{time.time()}")
        print(f"Profiling (results will be saved to '{profile_dir}')...")
        run_to_completion(profile_dir=str(profile_dir))
        return

    # Benchmark.
    latencies = np.array([
        run_to_completion(profile_dir=None)
        for _ in tqdm(range(args.num_iters), desc="Profiling iterations")
    ])
    percentages = [10, 25, 50, 75, 90, 99]
    percentiles = np.percentile(latencies, percentages)
    avg_latency = np.mean(latencies)
    print(f"Avg latency: {avg_latency} seconds")
    for percentage, percentile in zip(percentages, percentiles):
        print(f"{percentage}% percentile latency: {percentile} seconds")

    # Output JSON results if specified
    if args.output_json:
        results = {
            "avg_latency": float(avg_latency),
            "latencies": latencies.tolist(),
            "percentiles": dict(zip(percentages, percentiles.tolist())),
        }
        with open(args.output_json, "w") as f:
            json.dump(results, f, indent=4)
        save_to_pytorch_benchmark_format(args, results)
133

134

135
if __name__ == "__main__":
    parser = FlexibleArgumentParser(
        description=("Benchmark the latency of processing a single batch of "
                     "requests till completion."))

    # Benchmark-specific flags as (flag, add_argument kwargs) pairs.  They are
    # registered in this exact order so ``--help`` lists them as before; the
    # engine's own flags are appended afterwards by EngineArgs.add_cli_args.
    benchmark_options = [
        ("--input-len", dict(type=int, default=32)),
        ("--output-len", dict(type=int, default=128)),
        ("--batch-size", dict(type=int, default=8)),
        ("--n",
         dict(type=int,
              default=1,
              help="Number of generated sequences per prompt.")),
        ("--use-beam-search", dict(action="store_true")),
        ("--num-iters-warmup",
         dict(type=int,
              default=10,
              help="Number of iterations to run for warmup.")),
        ("--num-iters",
         dict(type=int, default=30, help="Number of iterations to run.")),
        ("--profile",
         dict(action="store_true",
              help="profile the generation process of a single batch")),
        ("--profile-result-dir",
         dict(type=str,
              default=None,
              help=("path to save the pytorch profiler output. Can be "
                    "visualized with ui.perfetto.dev or Tensorboard."))),
        ("--output-json",
         dict(type=str,
              default=None,
              help="Path to save the latency results in JSON format.")),
        ("--disable-detokenize",
         dict(action="store_true",
              help=("Do not detokenize responses (i.e. do not include "
                    "detokenization time in the latency measurement)"))),
    ]
    for flag, options in benchmark_options:
        parser.add_argument(flag, **options)

    parser = EngineArgs.add_cli_args(parser)
    args = parser.parse_args()
    main(args)