"""Benchmark offline inference throughput."""
import argparse
import json
import random
import time
from typing import List, Optional, Tuple

import torch
from transformers import (AutoModelForCausalLM, AutoTokenizer,
                          PreTrainedTokenizerBase)
from tqdm import tqdm


def sample_requests(
    dataset_path: str,
    num_requests: int,
    tokenizer: PreTrainedTokenizerBase,
    fixed_output_len: Optional[int],
) -> List[Tuple[str, int, int]]:
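    """Sample `num_requests` requests from a JSON dataset of conversations.

    Returns (prompt, prompt_len, output_len) tuples. The tokenized length of
    each conversation's first reply is used as the output length unless
    `fixed_output_len` is given. Requests with prompts or outputs shorter
    than 4 tokens, prompts longer than 1024 tokens, or prompt + output
    lengths over 2048 tokens are pruned before sampling.
    """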
    if fixed_output_len is not None and fixed_output_len < 4:
        raise ValueError("output_len too small")

    # Load the dataset.
    with open(dataset_path) as f:
        dataset = json.load(f)
    # Filter out the conversations with fewer than two turns.
    dataset = [data for data in dataset if len(data["conversations"]) >= 2]
    # Only keep the first two turns of each conversation.
    dataset = [(data["conversations"][0]["value"],
                data["conversations"][1]["value"]) for data in dataset]

    # Tokenize the prompts and completions.
    prompts = [prompt for prompt, _ in dataset]
    prompt_token_ids = tokenizer(prompts).input_ids
    completions = [completion for _, completion in dataset]
    completion_token_ids = tokenizer(completions).input_ids
    tokenized_dataset = []
    for i in range(len(dataset)):
        output_len = len(completion_token_ids[i])
        if fixed_output_len is not None:
            output_len = fixed_output_len
        tokenized_dataset.append((prompts[i], prompt_token_ids[i], output_len))

    # Filter out too long sequences.
    filtered_dataset: List[Tuple[str, int, int]] = []
    for prompt, prompt_token_ids, output_len in tokenized_dataset:
        prompt_len = len(prompt_token_ids)
        if prompt_len < 4 or output_len < 4:
            # Prune too short sequences.
            continue
        if prompt_len > 1024 or prompt_len + output_len > 2048:
            # Prune too long sequences.
            continue
        filtered_dataset.append((prompt, prompt_len, output_len))

    # Sample the requests.
    sampled_requests = random.sample(filtered_dataset, num_requests)
    return sampled_requests


def run_vllm(
    requests: List[Tuple[str, int, int]],
    model: str,
    tokenizer: str,
    quantization: Optional[str],
    tensor_parallel_size: int,
    seed: int,
    n: int,
    use_beam_search: bool,
    trust_remote_code: bool,
    dtype: str,
    max_model_len: Optional[int],
    enforce_eager: bool,
    kv_cache_dtype: str,
) -> float:
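    """Benchmark the vLLM backend.

    Builds an LLM engine with the given configuration, enqueues every request
    with `ignore_eos=True` so exactly `output_len` tokens are generated, runs
    the engine to completion, and returns the elapsed wall-clock seconds.
    """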
    from vllm import LLM, SamplingParams
    llm = LLM(
        model=model,
        tokenizer=tokenizer,
        quantization=quantization,
        tensor_parallel_size=tensor_parallel_size,
        seed=seed,
        trust_remote_code=trust_remote_code,
        dtype=dtype,
        max_model_len=max_model_len,
        enforce_eager=enforce_eager,
        kv_cache_dtype=kv_cache_dtype,
    )

    # Add the requests to the engine.
    for prompt, _, output_len in requests:
        sampling_params = SamplingParams(
            n=n,
            temperature=0.0 if use_beam_search else 1.0,
            top_p=1.0,
            use_beam_search=use_beam_search,
            ignore_eos=True,
            max_tokens=output_len,
        )
        # FIXME(woosuk): Do not use internal method.
        llm._add_request(
            prompt=prompt,
            prompt_token_ids=None,
            sampling_params=sampling_params,
        )

    start = time.perf_counter()
    # FIXME(woosuk): Do not use internal method.
    llm._run_engine(use_tqdm=True)
    end = time.perf_counter()
    return end - start


def run_hf(
    requests: List[Tuple[str, int, int]],
    model: str,
    tokenizer: PreTrainedTokenizerBase,
    n: int,
    use_beam_search: bool,
    max_batch_size: int,
    trust_remote_code: bool,
) -> float:
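    """Benchmark the HuggingFace transformers backend.

    Greedily packs prompts into padded batches of at most `max_batch_size`
    requests (keeping max prompt length + max output length <= 2048), samples
    completions on a single GPU, and returns the elapsed wall-clock seconds.
    """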
    assert not use_beam_search
    llm = AutoModelForCausalLM.from_pretrained(
        model, torch_dtype=torch.float16, trust_remote_code=trust_remote_code)
    if llm.config.model_type == "llama":
        # To enable padding in the HF backend.
        tokenizer.pad_token = tokenizer.eos_token
    llm = llm.cuda()

    pbar = tqdm(total=len(requests))
    start = time.perf_counter()
    batch: List[str] = []
    max_prompt_len = 0
    max_output_len = 0
    for i in range(len(requests)):
        prompt, prompt_len, output_len = requests[i]
        # Add the prompt to the batch.
        batch.append(prompt)
        max_prompt_len = max(max_prompt_len, prompt_len)
        max_output_len = max(max_output_len, output_len)
        if len(batch) < max_batch_size and i != len(requests) - 1:
            # Check if we can add more requests to the batch.
            _, next_prompt_len, next_output_len = requests[i + 1]
            if (max(max_prompt_len, next_prompt_len) +
                    max(max_output_len, next_output_len)) <= 2048:
                # We can add more requests to the batch.
                continue

        # Generate the sequences.
        input_ids = tokenizer(batch, return_tensors="pt",
                              padding=True).input_ids
        llm_outputs = llm.generate(
            input_ids=input_ids.cuda(),
            do_sample=not use_beam_search,
            num_return_sequences=n,
            temperature=1.0,
            top_p=1.0,
            use_cache=True,
            max_new_tokens=max_output_len,
        )
        # Include the decoding time.
        tokenizer.batch_decode(llm_outputs, skip_special_tokens=True)
        pbar.update(len(batch))

        # Clear the batch.
        batch = []
        max_prompt_len = 0
        max_output_len = 0
    end = time.perf_counter()
    return end - start


def run_mii(
    requests: List[Tuple[str, int, int]],
    model: str,
    tensor_parallel_size: int,
    output_len: int,
) -> float:
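    """Benchmark the DeepSpeed-MII backend.

    Runs all prompts through a MII pipeline with a single fixed `output_len`
    and returns the elapsed wall-clock seconds.
    """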
    from mii import pipeline
    llm = pipeline(model, tensor_parallel=tensor_parallel_size)
    prompts = [prompt for prompt, _, _ in requests]

    start = time.perf_counter()
    llm(prompts, max_new_tokens=output_len)
    end = time.perf_counter()
    return end - start


def main(args: argparse.Namespace):
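    """Sample or synthesize requests, run the selected backend, and report
    request and token throughput."""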
    print(args)
    random.seed(args.seed)

    # Sample the requests.
    tokenizer = AutoTokenizer.from_pretrained(
        args.tokenizer, trust_remote_code=args.trust_remote_code)
    if args.dataset is None:
        # Synthesize a prompt with the given input length.
        prompt = "hi" * (args.input_len - 1)
        requests = [(prompt, args.input_len, args.output_len)
                    for _ in range(args.num_prompts)]
    else:
        requests = sample_requests(args.dataset, args.num_prompts, tokenizer,
                                   args.output_len)

    if args.backend == "vllm":
        elapsed_time = run_vllm(requests, args.model, args.tokenizer,
                                args.quantization, args.tensor_parallel_size,
                                args.seed, args.n, args.use_beam_search,
                                args.trust_remote_code, args.dtype,
                                args.max_model_len, args.enforce_eager,
                                args.kv_cache_dtype)
    elif args.backend == "hf":
        assert args.tensor_parallel_size == 1
        elapsed_time = run_hf(requests, args.model, tokenizer, args.n,
                              args.use_beam_search, args.hf_max_batch_size,
                              args.trust_remote_code)
    elif args.backend == "mii":
        elapsed_time = run_mii(requests, args.model, args.tensor_parallel_size,
                               args.output_len)
    else:
        raise ValueError(f"Unknown backend: {args.backend}")
    total_num_tokens = sum(prompt_len + output_len
                           for _, prompt_len, output_len in requests)
    print(f"Throughput: {len(requests) / elapsed_time:.2f} requests/s, "
          f"{total_num_tokens / elapsed_time:.2f} tokens/s")


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Benchmark the throughput.")
    parser.add_argument("--backend",
                        type=str,
                        choices=["vllm", "hf", "mii"],
                        default="vllm")
    parser.add_argument("--dataset",
                        type=str,
                        default=None,
                        help="Path to the dataset.")
    parser.add_argument("--input-len",
                        type=int,
                        default=None,
                        help="Input prompt length for each request")
    parser.add_argument("--output-len",
                        type=int,
                        default=None,
                        help="Output length for each request. Overrides the "
                        "output length from the dataset.")
    parser.add_argument("--model", type=str, default="facebook/opt-125m")
    parser.add_argument("--tokenizer", type=str, default=None)
    parser.add_argument('--quantization',
                        '-q',
                        choices=['awq', 'gptq', 'squeezellm', None],
                        default=None)
    parser.add_argument("--tensor-parallel-size", "-tp", type=int, default=1)
    parser.add_argument("--n",
                        type=int,
                        default=1,
                        help="Number of generated sequences per prompt.")
    parser.add_argument("--use-beam-search", action="store_true")
    parser.add_argument("--num-prompts",
                        type=int,
                        default=1000,
                        help="Number of prompts to process.")
    parser.add_argument("--seed", type=int, default=0)
    parser.add_argument("--hf-max-batch-size",
                        type=int,
                        default=None,
                        help="Maximum batch size for HF backend.")
    parser.add_argument('--trust-remote-code',
                        action='store_true',
                        help='Trust remote code from Hugging Face.')
    parser.add_argument(
        '--max-model-len',
        type=int,
        default=None,
        help='Maximum length of a sequence (including prompt and output). '
        'If None, will be derived from the model.')
    parser.add_argument(
        '--dtype',
        type=str,
        default='auto',
        choices=['auto', 'half', 'float16', 'bfloat16', 'float', 'float32'],
        help='Data type for model weights and activations. '
        'The "auto" option will use FP16 precision '
        'for FP32 and FP16 models, and BF16 precision '
        'for BF16 models.')
    parser.add_argument("--enforce-eager",
                        action="store_true",
                        help="enforce eager execution")
    parser.add_argument(
        "--kv-cache-dtype",
        type=str,
        choices=["auto", "fp8_e5m2"],
        default="auto",
        help=
        'Data type for kv cache storage. If "auto", will use model data type.')
    args = parser.parse_args()
    if args.tokenizer is None:
        args.tokenizer = args.model
    if args.dataset is None:
        assert args.input_len is not None
        assert args.output_len is not None
    else:
        assert args.input_len is None

    if args.backend == "vllm":
        if args.hf_max_batch_size is not None:
            raise ValueError("HF max batch size is only for HF backend.")
    elif args.backend == "hf":
        if args.hf_max_batch_size is None:
            raise ValueError("HF max batch size is required for HF backend.")
        if args.quantization is not None:
            raise ValueError("Quantization is only for vLLM backend.")
    elif args.backend == "mii":
        if args.dtype != "auto":
            raise ValueError("dtype must be auto for MII backend.")
        if args.n != 1:
            raise ValueError("n must be 1 for MII backend.")
        if args.use_beam_search:
            raise ValueError("Beam search is not supported for MII backend.")
        if args.quantization is not None:
            raise ValueError("Quantization is only for vLLM backend.")
        if args.hf_max_batch_size is not None:
            raise ValueError("HF max batch size is only for HF backend.")
        if args.tokenizer != args.model:
            raise ValueError("Tokenizer must be the same as the model for MII "
                             "backend.")
    main(args)