"vscode:/vscode.git/clone" did not exist on "e2ed49f81feba446d4f56eb568b494767f803b83"
benchmark_latency.py 6 KB
Newer Older
1
2
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Benchmark the latency of processing a single batch of requests."""

import argparse
import dataclasses
import json
import os
import time
from typing import Any, Optional

import numpy as np
from tqdm import tqdm

import vllm.envs as envs
from benchmark_utils import convert_to_pytorch_benchmark_format, write_to_json
from vllm import LLM, SamplingParams
from vllm.engine.arg_utils import EngineArgs
from vllm.inputs import PromptType
from vllm.sampling_params import BeamSearchParams
from vllm.utils import FlexibleArgumentParser


def save_to_pytorch_benchmark_format(
    args: argparse.Namespace, results: dict[str, Any]
) -> None:
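    """Additionally emit the results as PyTorch-benchmark-format records,
    written alongside the --output-json file."""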
    pt_records = convert_to_pytorch_benchmark_format(
        args=args,
        metrics={"latency": results["latencies"]},
        extra_info={k: results[k] for k in ["avg_latency", "percentiles"]},
    )
    if pt_records:
        pt_file = f"{os.path.splitext(args.output_json)[0]}.pytorch.json"
        write_to_json(pt_file, pt_records)


def main(args: argparse.Namespace):
    print(args)

    engine_args = EngineArgs.from_cli_args(args)

    # NOTE(woosuk): If the request cannot be processed in a single batch,
    # the engine will automatically process the request in multiple batches.
    llm = LLM(**dataclasses.asdict(engine_args))
    assert llm.llm_engine.model_config.max_model_len >= (
        args.input_len + args.output_len
    ), (
        "Please ensure that max_model_len is greater than"
        " the sum of input_len and output_len."
    )

    sampling_params = SamplingParams(
        n=args.n,
        temperature=1.0,
        top_p=1.0,
        ignore_eos=True,
        max_tokens=args.output_len,
        detokenize=not args.disable_detokenize,
    )
    print(sampling_params)
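
    # Build synthetic prompts from random token IDs so no dataset or
    # tokenization is required for the benchmark.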
    dummy_prompt_token_ids = np.random.randint(
        10000, size=(args.batch_size, args.input_len)
    )
    dummy_prompts: list[PromptType] = [
        {"prompt_token_ids": batch} for batch in dummy_prompt_token_ids.tolist()
    ]

    def llm_generate():
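        """Run a single generation pass over the dummy prompts,
        using either regular sampling or beam search."""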
        if not args.use_beam_search:
            llm.generate(dummy_prompts, sampling_params=sampling_params, use_tqdm=False)
        else:
            llm.beam_search(
                dummy_prompts,
                BeamSearchParams(
                    beam_width=args.n,
                    max_tokens=args.output_len,
                    ignore_eos=True,
                ),
            )

    def run_to_completion(profile_dir: Optional[str] = None):
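        """Time one full generation pass, or capture a torch profiler trace
        when a profile_dir is given (no latency is returned in that case)."""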
        if profile_dir:
            llm.start_profile()
            llm_generate()
            llm.stop_profile()
        else:
            start_time = time.perf_counter()
            llm_generate()
            end_time = time.perf_counter()
            latency = end_time - start_time
            return latency

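    # Warm-up runs keep one-time startup costs (e.g. compilation or graph
    # capture) out of the measured latencies.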
    print("Warming up...")
    for _ in tqdm(range(args.num_iters_warmup), desc="Warmup iterations"):
        run_to_completion(profile_dir=None)

    if args.profile:
        profile_dir = envs.VLLM_TORCH_PROFILER_DIR
        print(f"Profiling (results will be saved to '{profile_dir}')...")
        run_to_completion(profile_dir=profile_dir)
        return

    # Benchmark.
    latencies = []
    for _ in tqdm(range(args.num_iters), desc="Benchmark iterations"):
        latencies.append(run_to_completion(profile_dir=None))
    latencies = np.array(latencies)
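    # Summarize the per-iteration end-to-end latencies with the mean and
    # selected percentiles.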
    percentages = [10, 25, 50, 75, 90, 99]
    percentiles = np.percentile(latencies, percentages)
    print(f"Avg latency: {np.mean(latencies)} seconds")
    for percentage, percentile in zip(percentages, percentiles):
        print(f"{percentage}% percentile latency: {percentile} seconds")

    # Output JSON results if specified
    if args.output_json:
        results = {
            "avg_latency": np.mean(latencies),
            "latencies": latencies.tolist(),
            "percentiles": dict(zip(percentages, percentiles.tolist())),
        }
        with open(args.output_json, "w") as f:
            json.dump(results, f, indent=4)
        save_to_pytorch_benchmark_format(args, results)


def create_argument_parser():
    parser = FlexibleArgumentParser(
        description="Benchmark the latency of processing a single batch of "
        "requests till completion."
    )
    parser.add_argument("--input-len", type=int, default=32)
    parser.add_argument("--output-len", type=int, default=128)
    parser.add_argument("--batch-size", type=int, default=8)
    parser.add_argument(
        "--n",
        type=int,
        default=1,
        help="Number of generated sequences per prompt.",
    )
    parser.add_argument("--use-beam-search", action="store_true")
    parser.add_argument(
        "--num-iters-warmup",
        type=int,
        default=10,
        help="Number of iterations to run for warmup.",
    )
    parser.add_argument(
        "--num-iters", type=int, default=30, help="Number of iterations to run."
    )
    parser.add_argument(
        "--profile",
        action="store_true",
        help="profile the generation process of a single batch",
    )
    parser.add_argument(
        "--output-json",
        type=str,
        default=None,
        help="Path to save the latency results in JSON format.",
    )
    parser.add_argument(
        "--disable-detokenize",
        action="store_true",
        help=(
            "Do not detokenize responses (i.e. do not include "
            "detokenization time in the latency measurement)"
        ),
    )

    parser = EngineArgs.add_cli_args(parser)
    # V1 enables prefix caching by default, which skews the latency numbers,
    # so we disable it by default here.
    parser.set_defaults(enable_prefix_caching=False)

    return parser


if __name__ == "__main__":
    parser = create_argument_parser()
    args = parser.parse_args()
    if args.profile and not envs.VLLM_TORCH_PROFILER_DIR:
        raise OSError(
            "The environment variable 'VLLM_TORCH_PROFILER_DIR' is not set. "
            "Please set it to a valid path to use torch profiler."
        )
    main(args)