# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Benchmark offline prioritization."""

import argparse
import dataclasses
import json
import random
import time
from typing import Optional

from transformers import AutoTokenizer, PreTrainedTokenizerBase

from vllm.engine.arg_utils import EngineArgs
from vllm.utils import FlexibleArgumentParser


# Select an equiprobable random priority (0 or 1).
def get_random_flag():
    return 0 if random.random() < 0.5 else 1


def sample_requests(
    dataset_path: str,
    num_requests: int,
    tokenizer: PreTrainedTokenizerBase,
    fixed_output_len: Optional[int],
) -> list[tuple[str, int, int, int]]:
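    """Sample ShareGPT-style conversations and attach a random 0/1 priority.

    Returns a list of (prompt, prompt_len, output_len, priority) tuples;
    output_len is taken from the dataset completion unless fixed_output_len
    is given.
    """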
    if fixed_output_len is not None and fixed_output_len < 4:
        raise ValueError("output_len too small")

    # Load the dataset.
    with open(dataset_path) as f:
        dataset = json.load(f)
    # Filter out the conversations with less than 2 turns.
    dataset = [data for data in dataset if len(data["conversations"]) >= 2]
    # Only keep the first two turns of each conversation.
    dataset = [
        (data["conversations"][0]["value"], data["conversations"][1]["value"])
        for data in dataset
    ]

    # Shuffle the dataset.
    random.shuffle(dataset)

    # Filter out sequences that are too long or too short
    filtered_dataset: list[tuple[str, int, int, int]] = []
    for i in range(len(dataset)):
        if len(filtered_dataset) == num_requests:
            break

        # Tokenize the prompts and completions.
        prompt = dataset[i][0]
        prompt_token_ids = tokenizer(prompt).input_ids
        completion = dataset[i][1]
        completion_token_ids = tokenizer(completion).input_ids
        prompt_len = len(prompt_token_ids)
        output_len = (
            len(completion_token_ids) if fixed_output_len is None else fixed_output_len
        )
        if prompt_len < 4 or output_len < 4:
            # Prune too short sequences.
            continue
        if prompt_len > 1024 or prompt_len + output_len > 2048:
            # Prune too long sequences.
            continue

        priority = get_random_flag()

        filtered_dataset.append((prompt, prompt_len, output_len, priority))

    return filtered_dataset


def run_vllm(
    requests: list[tuple[str, int, int, int]],
    n: int,
    engine_args: EngineArgs,
    disable_detokenize: bool = False,
) -> float:
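    """Build an LLM from engine_args, submit all requests (with their
    per-request priorities) in a single generate() call, and return the
    elapsed wall-clock time in seconds.
    """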
    from vllm import LLM, SamplingParams

    llm = LLM(**dataclasses.asdict(engine_args))

    assert all(
        llm.llm_engine.model_config.max_model_len >= (request[1] + request[2])
        for request in requests
    ), (
        "Please ensure that max_model_len is greater than the sum of"
        " input_len and output_len for all requests."
    )

    # Add the requests to the engine.
    prompts = []
    sampling_params = []
    priority = []
    for prompt, _, output_len, _priority in requests:
        prompts.append(prompt)
        priority.append(_priority)
        sampling_params.append(
            SamplingParams(
                n=n,
                temperature=1.0,
                top_p=1.0,
                ignore_eos=True,
                max_tokens=output_len,
                detokenize=not disable_detokenize,
            )
        )

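    # The per-request priorities only affect ordering when the engine is run
    # with priority scheduling (e.g. --scheduling-policy priority in recent
    # vLLM releases; treat the exact flag name as an assumption).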
    start = time.perf_counter()
    llm.generate(prompts, sampling_params, priority=priority, use_tqdm=True)
    end = time.perf_counter()
    return end - start


def main(args: argparse.Namespace):
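    """Build the request set (synthetic or sampled from a dataset), run the
    selected backend, and report throughput.
    """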
    print(args)
    random.seed(args.seed)

    # Sample the requests.
    tokenizer = AutoTokenizer.from_pretrained(
        args.tokenizer, trust_remote_code=args.trust_remote_code
    )
    if args.dataset is None:
        # Synthesize a prompt with the given input length.
        prompt = "hi" * (args.input_len - 1)
        requests = [
            (prompt, args.input_len, args.output_len, get_random_flag())
            for _ in range(args.num_prompts)
        ]
    else:
        requests = sample_requests(
            args.dataset, args.num_prompts, tokenizer, args.output_len
        )

    if args.backend == "vllm":
        elapsed_time = run_vllm(
            requests, args.n, EngineArgs.from_cli_args(args), args.disable_detokenize
        )
    else:
        raise ValueError(f"Unknown backend: {args.backend}")
    total_num_tokens = sum(
        prompt_len + output_len for _, prompt_len, output_len, priority in requests
    )
    print(
        f"Throughput: {len(requests) / elapsed_time:.2f} requests/s, "
        f"{total_num_tokens / elapsed_time:.2f} tokens/s"
    )

    # Output JSON results if specified
    if args.output_json:
        results = {
            "elapsed_time": elapsed_time,
            "num_requests": len(requests),
            "total_num_tokens": total_num_tokens,
            "requests_per_second": len(requests) / elapsed_time,
            "tokens_per_second": total_num_tokens / elapsed_time,
        }
        with open(args.output_json, "w") as f:
            json.dump(results, f, indent=4)


def create_argument_parser():
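    """Create the CLI parser for this benchmark, including vLLM engine args."""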
    parser = FlexibleArgumentParser(
        description="Benchmark offline prioritization throughput."
    )
    parser.add_argument(
        "--backend", type=str, choices=["vllm", "hf", "mii"], default="vllm"
    )
    parser.add_argument(
        "--dataset", type=str, default=None, help="Path to the dataset."
    )
    parser.add_argument(
        "--input-len",
        type=int,
        default=None,
        help="Input prompt length for each request",
    )
    parser.add_argument(
        "--output-len",
        type=int,
        default=None,
        help="Output length for each request. Overrides the "
        "output length from the dataset.",
    )
    parser.add_argument(
        "--n", type=int, default=1, help="Number of generated sequences per prompt."
    )
    parser.add_argument(
        "--num-prompts", type=int, default=200, help="Number of prompts to process."
    )
    parser.add_argument(
        "--output-json",
        type=str,
        default=None,
        help="Path to save the throughput results in JSON format.",
    )
    parser.add_argument(
        "--disable-detokenize",
        action="store_true",
        help=(
            "Do not detokenize responses (i.e. do not include "
            "detokenization time in the latency measurement)"
        ),
    )

    parser = EngineArgs.add_cli_args(parser)

    return parser


if __name__ == "__main__":
    parser = create_argument_parser()
    args = parser.parse_args()
    if args.tokenizer is None:
        args.tokenizer = args.model
    if args.dataset is None:
        assert args.input_len is not None
        assert args.output_len is not None
    else:
        assert args.input_len is None

    main(args)