# benchmark_prioritization.py
# SPDX-License-Identifier: Apache-2.0
2
3
"""Benchmark offline prioritization."""
import argparse
4
import dataclasses
5
6
7
import json
import random
import time
8
from typing import Optional
9
10
11

from transformers import AutoTokenizer, PreTrainedTokenizerBase

12
13
from vllm.engine.arg_utils import EngineArgs
from vllm.utils import FlexibleArgumentParser
14
15


16
17
18
19
20
def get_random_flag():
    """Return a random priority of 0 or 1, each with probability 1/2.

    Exactly one value is drawn from the module-level ``random`` generator per
    call, so the sequence of flags is reproducible after ``random.seed(...)``.
    """
    draw = random.random()
    # draw < 0.5 -> 0, otherwise 1.
    return int(draw >= 0.5)


21
22
23
24
25
def sample_requests(
    dataset_path: str,
    num_requests: int,
    tokenizer: PreTrainedTokenizerBase,
    fixed_output_len: Optional[int],
) -> list[tuple[str, int, int, int]]:
    """Sample up to ``num_requests`` requests from a conversation JSON file.

    The file must hold a JSON list whose entries have a ``"conversations"``
    list of turns, each turn carrying a ``"value"`` field. The first turn is
    used as the prompt and the second as the reference completion.
    Conversations with fewer than two turns are dropped. The rest are shuffled
    with the module-level ``random`` generator and then filtered by token
    length.

    Args:
        dataset_path: Path to the JSON dataset.
        num_requests: Maximum number of requests to return. Fewer are returned
            if the dataset runs out of conversations that pass the filters.
        tokenizer: Tokenizer used to count prompt/completion tokens.
        fixed_output_len: If not None, used as every request's output length
            instead of the tokenized completion length.

    Returns:
        A list of ``(prompt, prompt_len, output_len, priority)`` tuples, where
        ``priority`` is the 0/1 value returned by ``get_random_flag()``.
        Requests whose prompt or output is shorter than 4 tokens are pruned.
        So are requests whose prompt exceeds 1024 tokens or whose
        prompt+output exceeds 2048 tokens.

    Raises:
        ValueError: If ``fixed_output_len`` is given and smaller than 4.
    """
    if fixed_output_len is not None and fixed_output_len < 4:
        raise ValueError("output_len too small")

    # Load the dataset. JSON text is UTF-8 (RFC 8259), so read it as such
    # instead of relying on the platform's default locale encoding.
    with open(dataset_path, encoding="utf-8") as f:
        dataset = json.load(f)
    # Keep only the first two turns of conversations with at least two turns.
    dataset = [(data["conversations"][0]["value"],
                data["conversations"][1]["value"]) for data in dataset
               if len(data["conversations"]) >= 2]

    # Shuffle the dataset.
    random.shuffle(dataset)

    # Filter out sequences that are too long or too short.
    filtered_dataset: list[tuple[str, int, int, int]] = []
    for prompt, completion in dataset:
        if len(filtered_dataset) == num_requests:
            break

        prompt_len = len(tokenizer(prompt).input_ids)
        if fixed_output_len is None:
            output_len = len(tokenizer(completion).input_ids)
        else:
            # The completion's length is overridden, so skip tokenizing it.
            output_len = fixed_output_len
        if prompt_len < 4 or output_len < 4:
            # Prune too short sequences.
            continue
        if prompt_len > 1024 or prompt_len + output_len > 2048:
            # Prune too long sequences.
            continue

        # Assign a random priority only to requests that survive filtering.
        priority = get_random_flag()
        filtered_dataset.append((prompt, prompt_len, output_len, priority))

    return filtered_dataset


def run_vllm(
    requests: list[tuple[str, int, int, int]],
    n: int,
    engine_args: EngineArgs,
    disable_detokenize: bool = False,
) -> float:
    """Run all ``requests`` through an offline vLLM ``LLM`` and time it.

    Args:
        requests: ``(prompt, prompt_len, output_len, priority)`` tuples.
            Each priority is forwarded to ``llm.generate`` through its
            ``priority`` argument.
        n: Number of sequences to generate per prompt.
        engine_args: Engine configuration; converted with
            ``dataclasses.asdict`` and passed as keyword arguments to ``LLM``.
        disable_detokenize: If True, every ``SamplingParams`` is built with
            ``detokenize=False``, keeping detokenization out of the timing.

    Returns:
        Wall-clock seconds spent inside ``llm.generate``. Engine construction
        is not included.
    """
    from vllm import LLM, SamplingParams
    llm = LLM(**dataclasses.asdict(engine_args))

    max_model_len = llm.llm_engine.model_config.max_model_len
    assert all(
        max_model_len >= (request[1] + request[2])
        for request in requests), (
            "Please ensure that max_model_len is greater than the sum of"
            " input_len and output_len for all requests.")

    # Add the requests to the engine.
    prompts = []
    sampling_params = []
    priority = []
    for prompt, _, output_len, _priority in requests:
        prompts.append(prompt)
        priority.append(_priority)
        sampling_params.append(
            SamplingParams(
                n=n,
                temperature=1.0,
                top_p=1.0,
                # Force exactly output_len tokens so throughput is comparable.
                ignore_eos=True,
                max_tokens=output_len,
                detokenize=not disable_detokenize,
            ))

    start = time.perf_counter()
    llm.generate(prompts, sampling_params, priority=priority, use_tqdm=True)
    end = time.perf_counter()
    return end - start


def main(args: argparse.Namespace):
    """Build the request set, benchmark it, and report throughput.

    Requests are synthesized when ``args.dataset`` is None. Otherwise they
    are sampled from the dataset with ``sample_requests``. Throughput is
    printed, and also written as JSON to ``args.output_json`` when that is
    set.

    Raises:
        ValueError: If ``args.backend`` is anything other than ``"vllm"``.
    """
    print(args)
    random.seed(args.seed)

    # Sample the requests.
    if args.dataset is None:
        # Synthesize a prompt with the given input length.
        prompt = "hi" * (args.input_len - 1)
        requests = [(prompt, args.input_len, args.output_len,
                     get_random_flag()) for _ in range(args.num_prompts)]
    else:
        # The tokenizer is only used to measure dataset prompts, so it is
        # not loaded for synthetic requests.
        tokenizer = AutoTokenizer.from_pretrained(
            args.tokenizer, trust_remote_code=args.trust_remote_code)
        requests = sample_requests(args.dataset, args.num_prompts, tokenizer,
                                   args.output_len)

    if args.backend == "vllm":
        elapsed_time = run_vllm(requests, args.n,
                                EngineArgs.from_cli_args(args),
                                args.disable_detokenize)
    else:
        raise ValueError(f"Unknown backend: {args.backend}")

    # NOTE: counts each request's prompt + output tokens once, independent
    # of args.n.
    total_num_tokens = sum(prompt_len + output_len
                           for _, prompt_len, output_len, _ in requests)
    requests_per_second = len(requests) / elapsed_time
    tokens_per_second = total_num_tokens / elapsed_time
    print(f"Throughput: {requests_per_second:.2f} requests/s, "
          f"{tokens_per_second:.2f} tokens/s")

    # Output JSON results if specified
    if args.output_json:
        results = {
            "elapsed_time": elapsed_time,
            "num_requests": len(requests),
            "total_num_tokens": total_num_tokens,
            "requests_per_second": requests_per_second,
            "tokens_per_second": tokens_per_second,
        }
        with open(args.output_json, "w", encoding="utf-8") as f:
            json.dump(results, f, indent=4)


if __name__ == "__main__":
    parser = FlexibleArgumentParser(description="Benchmark the throughput.")
    # Only the vLLM backend is implemented by main(); advertising other
    # backends would just defer the failure until after request sampling.
    parser.add_argument("--backend",
                        type=str,
                        choices=["vllm"],
                        default="vllm")
    parser.add_argument("--dataset",
                        type=str,
                        default=None,
                        help="Path to the dataset.")
    parser.add_argument("--input-len",
                        type=int,
                        default=None,
                        help="Input prompt length for each request")
    parser.add_argument("--output-len",
                        type=int,
                        default=None,
                        help="Output length for each request. Overrides the "
                        "output length from the dataset.")
    parser.add_argument("--n",
                        type=int,
                        default=1,
                        help="Number of generated sequences per prompt.")
    parser.add_argument("--num-prompts",
                        type=int,
                        default=200,
                        help="Number of prompts to process.")
    parser.add_argument(
        '--output-json',
        type=str,
        default=None,
        help='Path to save the throughput results in JSON format.')
    parser.add_argument(
        '--disable-detokenize',
        action='store_true',
        help=("Do not detokenize responses (i.e. do not include "
              "detokenization time in the latency measurement)"),
    )

    parser = EngineArgs.add_cli_args(parser)
    args = parser.parse_args()
    if args.tokenizer is None:
        args.tokenizer = args.model
    # Validate with parser.error rather than assert: asserts are stripped
    # under `python -O`, which would let bad arguments crash later instead.
    if args.dataset is None:
        if args.input_len is None or args.output_len is None:
            parser.error("--input-len and --output-len are required when "
                         "--dataset is not given.")
    elif args.input_len is not None:
        parser.error("--input-len cannot be combined with --dataset.")

    main(args)