# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Benchmark offline prioritization."""

import argparse
import dataclasses
import json
import random
import time
from typing import Optional

from transformers import AutoTokenizer, PreTrainedTokenizerBase

from vllm.engine.arg_utils import EngineArgs
from vllm.utils import FlexibleArgumentParser


# Select an equiprobable random priority (0 or 1).
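# How the engine orders priorities (e.g. whether lower values are served
# first) is determined by its scheduler policy, not by this script.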
def get_random_flag():
    return 0 if random.random() < 0.5 else 1


def sample_requests(
    dataset_path: str,
    num_requests: int,
    tokenizer: PreTrainedTokenizerBase,
    fixed_output_len: Optional[int],
) -> list[tuple[str, int, int, int]]:
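    """Sample chat prompts from a ShareGPT-style JSON dataset.

    Returns up to num_requests tuples of
    (prompt, prompt_len, output_len, priority), where priority is the random
    0/1 flag produced by get_random_flag().
    """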
    if fixed_output_len is not None and fixed_output_len < 4:
        raise ValueError("output_len too small")

    # Load the dataset.
    with open(dataset_path) as f:
        dataset = json.load(f)
    # Filter out conversations with fewer than two turns.
    dataset = [data for data in dataset if len(data["conversations"]) >= 2]
    # Only keep the first two turns of each conversation.
    dataset = [
        (data["conversations"][0]["value"], data["conversations"][1]["value"])
        for data in dataset
    ]

    # Shuffle the dataset.
    random.shuffle(dataset)

    # Filter out sequences that are too long or too short.
    filtered_dataset: list[tuple[str, int, int, int]] = []
    for i in range(len(dataset)):
        if len(filtered_dataset) == num_requests:
            break

        # Tokenize the prompts and completions.
        prompt = dataset[i][0]
        prompt_token_ids = tokenizer(prompt).input_ids
        completion = dataset[i][1]
        completion_token_ids = tokenizer(completion).input_ids
        prompt_len = len(prompt_token_ids)
        output_len = (
            len(completion_token_ids) if fixed_output_len is None else fixed_output_len
        )
        if prompt_len < 4 or output_len < 4:
            # Prune too short sequences.
            continue
        if prompt_len > 1024 or prompt_len + output_len > 2048:
            # Prune too long sequences.
            continue

        priority = get_random_flag()

        filtered_dataset.append((prompt, prompt_len, output_len, priority))

    return filtered_dataset


def run_vllm(
    requests: list[tuple[str, int, int, int]],
    n: int,
    engine_args: EngineArgs,
    disable_detokenize: bool = False,
) -> float:
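    """Generate all requests with an offline vLLM engine and return the
    elapsed wall-clock time in seconds."""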
    from vllm import LLM, SamplingParams

    llm = LLM(**dataclasses.asdict(engine_args))

    assert all(
        llm.llm_engine.model_config.max_model_len >= (request[1] + request[2])
        for request in requests
    ), (
        "Please ensure that max_model_len is greater than the sum of"
        " input_len and output_len for all requests."
    )

    # Add the requests to the engine.
    prompts = []
    sampling_params = []
    priority = []
    for prompt, _, output_len, _priority in requests:
        prompts.append(prompt)
        priority.append(_priority)
        sampling_params.append(
            SamplingParams(
                n=n,
                temperature=1.0,
                top_p=1.0,
                ignore_eos=True,
                max_tokens=output_len,
                detokenize=not disable_detokenize,
            )
        )

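    # Per-request priorities take effect only when the engine runs a
    # priority-aware scheduler (e.g. --scheduling-policy priority in recent
    # vLLM versions).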
    start = time.perf_counter()
    llm.generate(prompts, sampling_params, priority=priority, use_tqdm=True)
    end = time.perf_counter()
    return end - start


def main(args: argparse.Namespace):
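    """Sample (or synthesize) requests, run the selected backend, and report
    throughput."""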
    print(args)
    random.seed(args.seed)

    # Sample the requests.
    tokenizer = AutoTokenizer.from_pretrained(
        args.tokenizer, trust_remote_code=args.trust_remote_code
    )
    if args.dataset is None:
        # Synthesize a prompt with the given input length.
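        # (This assumes the tokenizer maps the repeated "hi" text to roughly
        # args.input_len tokens; the exact count depends on the tokenizer.)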
        prompt = "hi" * (args.input_len - 1)
        requests = [
            (prompt, args.input_len, args.output_len, get_random_flag())
            for _ in range(args.num_prompts)
        ]
    else:
        requests = sample_requests(
            args.dataset, args.num_prompts, tokenizer, args.output_len
        )

    if args.backend == "vllm":
        elapsed_time = run_vllm(
            requests, args.n, EngineArgs.from_cli_args(args), args.disable_detokenize
        )
    else:
        raise ValueError(f"Unknown backend: {args.backend}")
    total_num_tokens = sum(
        prompt_len + output_len for _, prompt_len, output_len, priority in requests
    )
    print(
        f"Throughput: {len(requests) / elapsed_time:.2f} requests/s, "
        f"{total_num_tokens / elapsed_time:.2f} tokens/s"
    )

    # Output JSON results if specified
    if args.output_json:
        results = {
            "elapsed_time": elapsed_time,
            "num_requests": len(requests),
            "total_num_tokens": total_num_tokens,
            "requests_per_second": len(requests) / elapsed_time,
            "tokens_per_second": total_num_tokens / elapsed_time,
        }
        with open(args.output_json, "w") as f:
            json.dump(results, f, indent=4)


def create_argument_parser():
    parser = FlexibleArgumentParser(description="Benchmark the throughput.")
    parser.add_argument(
        "--backend",
        type=str,
        choices=["vllm", "hf", "mii"],
        default="vllm",
        help="Backend to benchmark (only 'vllm' is implemented here).",
    )
    parser.add_argument(
        "--dataset",
        type=str,
        default=None,
        help="Path to a ShareGPT-style JSON dataset. If omitted, prompts "
        "are synthesized from --input-len.",
    )
    parser.add_argument(
        "--input-len",
        type=int,
        default=None,
        help="Input prompt length for each request",
    )
    parser.add_argument(
        "--output-len",
        type=int,
        default=None,
        help="Output length for each request. Overrides the "
        "output length from the dataset.",
    )
    parser.add_argument(
        "--n", type=int, default=1, help="Number of generated sequences per prompt."
    )
    parser.add_argument(
        "--num-prompts", type=int, default=200, help="Number of prompts to process."
    )
    parser.add_argument(
        "--output-json",
        type=str,
        default=None,
        help="Path to save the throughput results in JSON format.",
    )
    parser.add_argument(
        "--disable-detokenize",
        action="store_true",
        help=(
            "Do not detokenize responses (i.e. do not include "
            "detokenization time in the latency measurement)"
        ),
    )

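    # Engine-level flags (e.g. --model, --tokenizer, --seed,
    # --trust-remote-code) are added by EngineArgs.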
    parser = EngineArgs.add_cli_args(parser)

    return parser


if __name__ == "__main__":
    parser = create_argument_parser()
    args = parser.parse_args()
    if args.tokenizer is None:
        args.tokenizer = args.model
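    # Synthetic-prompt mode needs explicit lengths; dataset mode measures the
    # input length itself, so --input-len must not be set.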
    if args.dataset is None:
        assert args.input_len is not None
        assert args.output_len is not None
    else:
        assert args.input_len is None

    main(args)