benchmark_throughput.py 14.8 KB
Newer Older
1
"""Benchmark offline inference throughput."""
2
3
4
5
import argparse
import json
import random
import time
6
from typing import List, Optional, Tuple
7

8
import torch
9
from tqdm import tqdm
10
11
from transformers import (AutoModelForCausalLM, AutoTokenizer,
                          PreTrainedTokenizerBase)
12

13
14
from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS

15
16
17
18
19

def sample_requests(
    dataset_path: str,
    num_requests: int,
    tokenizer: PreTrainedTokenizerBase,
    fixed_output_len: Optional[int],
) -> List[Tuple[str, int, int]]:
    """Sample benchmark requests from a JSON conversation dataset.

    The file must hold a JSON list of records, each with a
    ``"conversations"`` list whose entries carry a ``"value"`` string.
    Records with fewer than two turns are dropped. For the rest, the first
    turn is the prompt and the second is the reference completion. The pairs
    are shuffled with the global ``random`` state. They are then taken in
    order, skipping those whose token lengths fall outside the limits below,
    until ``num_requests`` requests are collected or the data runs out.

    Args:
        dataset_path: Path to the JSON dataset file.
        num_requests: Maximum number of requests to return.
        tokenizer: Tokenizer used to measure prompt/completion token counts.
        fixed_output_len: If not ``None``, used as every request's output
            length instead of the tokenized completion length.

    Returns:
        A list of ``(prompt, prompt_len, output_len)`` tuples. It may be
        shorter than ``num_requests`` if too few pairs pass the filters.

    Raises:
        ValueError: If ``fixed_output_len`` is given and is less than 4.
    """
    if fixed_output_len is not None and fixed_output_len < 4:
        raise ValueError("output_len too small")

    # Load the dataset. Conversation data routinely contains non-ASCII text,
    # so decode explicitly as UTF-8 instead of relying on the locale default.
    with open(dataset_path, encoding="utf-8") as f:
        dataset = json.load(f)
    # Keep only the first two turns of conversations that have at least two.
    dataset = [(data["conversations"][0]["value"],
                data["conversations"][1]["value"]) for data in dataset
               if len(data["conversations"]) >= 2]

    # Shuffle the dataset.
    random.shuffle(dataset)

    # Filter out sequences that are too long or too short.
    filtered_dataset: List[Tuple[str, int, int]] = []
    for prompt, completion in dataset:
        if len(filtered_dataset) == num_requests:
            break

        prompt_len = len(tokenizer(prompt).input_ids)
        if fixed_output_len is None:
            output_len = len(tokenizer(completion).input_ids)
        else:
            # The completion length is overridden, so don't tokenize it.
            output_len = fixed_output_len

        if prompt_len < 4 or output_len < 4:
            # Prune too short sequences.
            continue
        if prompt_len > 1024 or prompt_len + output_len > 2048:
            # Prune too long sequences.
            continue
        filtered_dataset.append((prompt, prompt_len, output_len))

    return filtered_dataset


def run_vllm(
    requests: List[Tuple[str, int, int]],
    model: str,
    tokenizer: str,
    quantization: Optional[str],
    tensor_parallel_size: int,
    seed: int,
    n: int,
    use_beam_search: bool,
    trust_remote_code: bool,
    dtype: str,
    max_model_len: Optional[int],
    enforce_eager: bool,
    kv_cache_dtype: str,
    quantization_param_path: Optional[str],
    device: str,
    enable_prefix_caching: bool,
    enable_chunked_prefill: bool,
    max_num_batched_tokens: Optional[int],
    gpu_memory_utilization: float = 0.9,
    download_dir: Optional[str] = None,
) -> float:
    """Time offline generation of ``requests`` with the vLLM engine.

    An ``LLM`` instance is built from the engine options and every request
    is enqueued before the clock starts. Only the engine run is timed, so
    model loading and request submission are excluded from the result.

    Args:
        requests: ``(prompt, prompt_len, output_len)`` tuples. ``prompt_len``
            is unused here; ``output_len`` becomes the request's
            ``max_tokens``.
        model: Model name or path passed to ``LLM``.
        tokenizer: Tokenizer name or path passed to ``LLM``.
        n: Number of sequences generated per prompt.
        use_beam_search: Use beam search (temperature 0.0) instead of
            sampling at temperature 1.0.
        max_num_batched_tokens: Forwarded to ``LLM``. It is ``None`` when
            not set on the CLI; presumably the engine then picks its own
            value.
        The remaining arguments are forwarded unchanged to ``LLM``.

    Returns:
        Wall-clock seconds spent running the engine over all requests.
    """
    from vllm import LLM, SamplingParams
    llm = LLM(
        model=model,
        tokenizer=tokenizer,
        quantization=quantization,
        tensor_parallel_size=tensor_parallel_size,
        seed=seed,
        trust_remote_code=trust_remote_code,
        dtype=dtype,
        max_model_len=max_model_len,
        gpu_memory_utilization=gpu_memory_utilization,
        enforce_eager=enforce_eager,
        kv_cache_dtype=kv_cache_dtype,
        quantization_param_path=quantization_param_path,
        device=device,
        enable_prefix_caching=enable_prefix_caching,
        download_dir=download_dir,
        enable_chunked_prefill=enable_chunked_prefill,
        max_num_batched_tokens=max_num_batched_tokens,
    )

    # Add the requests to the engine.
    for prompt, _, output_len in requests:
        sampling_params = SamplingParams(
            n=n,
            temperature=0.0 if use_beam_search else 1.0,
            top_p=1.0,
            use_beam_search=use_beam_search,
            # Don't stop at EOS, so each request runs to its max_tokens.
            ignore_eos=True,
            max_tokens=output_len,
        )
        # FIXME(woosuk): Do not use internal method.
        llm._add_request(
            prompt=prompt,
            prompt_token_ids=None,
            sampling_params=sampling_params,
        )

    start = time.perf_counter()
    # FIXME(woosuk): Do not use internal method.
    llm._run_engine(use_tqdm=True)
    end = time.perf_counter()
    return end - start


def run_hf(
    requests: List[Tuple[str, int, int]],
    model: str,
    tokenizer: PreTrainedTokenizerBase,
    n: int,
    use_beam_search: bool,
    max_batch_size: int,
    trust_remote_code: bool,
) -> float:
    """Time generation of ``requests`` with HuggingFace ``generate``.

    The model is loaded in float16 and moved to CUDA. Requests are then
    grouped, in order, into batches of at most ``max_batch_size``. A batch
    is closed early if adding the next request would push (longest prompt
    + longest output) in the batch above 2048 tokens. The timed region
    covers tokenization, generation, and decoding of every batch, but not
    model loading.

    Side effect: for ``llama`` models, ``tokenizer.pad_token`` is set to
    ``tokenizer.eos_token`` so batches can be padded.

    Args:
        requests: ``(prompt, prompt_len, output_len)`` tuples.
        model: Model name or path for ``AutoModelForCausalLM``.
        tokenizer: Tokenizer used to encode prompts and decode outputs.
        n: Number of sequences returned per prompt.
        use_beam_search: Must be ``False``; beam search is not supported.
        max_batch_size: Maximum number of prompts per ``generate`` call.
        trust_remote_code: Forwarded to ``from_pretrained``.

    Returns:
        Wall-clock seconds spent processing all batches.
    """
    assert not use_beam_search
    llm = AutoModelForCausalLM.from_pretrained(
        model, torch_dtype=torch.float16, trust_remote_code=trust_remote_code)
    if llm.config.model_type == "llama":
        # To enable padding in the HF backend.
        tokenizer.pad_token = tokenizer.eos_token
    llm = llm.cuda()

    pbar = tqdm(total=len(requests))
    # Close the progress bar even if generation raises. Closing happens
    # after the clock stops, so it does not affect the measurement.
    try:
        start = time.perf_counter()
        batch: List[str] = []
        max_prompt_len = 0
        max_output_len = 0
        last_index = len(requests) - 1
        for i, (prompt, prompt_len, output_len) in enumerate(requests):
            # Add the prompt to the batch.
            batch.append(prompt)
            max_prompt_len = max(max_prompt_len, prompt_len)
            max_output_len = max(max_output_len, output_len)
            if len(batch) < max_batch_size and i != last_index:
                # Check if we can add more requests to the batch.
                _, next_prompt_len, next_output_len = requests[i + 1]
                if (max(max_prompt_len, next_prompt_len) +
                        max(max_output_len, next_output_len)) <= 2048:
                    # We can add more requests to the batch.
                    continue

            # Generate the sequences.
            input_ids = tokenizer(batch, return_tensors="pt",
                                  padding=True).input_ids
            llm_outputs = llm.generate(
                input_ids=input_ids.cuda(),
                do_sample=not use_beam_search,
                num_return_sequences=n,
                temperature=1.0,
                top_p=1.0,
                use_cache=True,
                max_new_tokens=max_output_len,
            )
            # Include the decoding time.
            tokenizer.batch_decode(llm_outputs, skip_special_tokens=True)
            pbar.update(len(batch))

            # Clear the batch.
            batch = []
            max_prompt_len = 0
            max_output_len = 0
        end = time.perf_counter()
    finally:
        pbar.close()
    return end - start


def run_mii(
    requests: List[Tuple[str, int, int]],
    model: str,
    tensor_parallel_size: int,
    output_len: Optional[int],
) -> float:
    """Time batch generation of ``requests`` with DeepSpeed-MII.

    Starts an MII deployment of ``model`` with ``mii.serve`` and sends all
    prompts in a single ``generate`` call. Only that call is timed. The
    deployment is then shut down through
    ``mii.client(model).terminate_server()``, including when generation
    raises, so a failed run does not leave it running.

    Args:
        requests: ``(prompt, prompt_len, output_len)`` tuples; only the
            prompts are used.
        model: Model name or path to serve.
        tensor_parallel_size: Passed to ``serve`` as ``tensor_parallel``.
        output_len: Passed as ``max_new_tokens`` for every prompt. Per-request
            output lengths are ignored. NOTE(review): ``main`` passes
            ``args.output_len``, which is ``None`` when ``--dataset`` is used
            without ``--output-len``; how MII treats that is not visible
            here. Confirm.

    Returns:
        Wall-clock seconds spent in ``generate``.
    """
    from mii import client, serve

    prompts = [prompt for prompt, _, _ in requests]

    llm = serve(model, tensor_parallel=tensor_parallel_size)
    try:
        start = time.perf_counter()
        llm.generate(prompts, max_new_tokens=output_len)
        end = time.perf_counter()
    finally:
        # Use a separate name so the imported `client` factory isn't shadowed.
        mii_client = client(model)
        mii_client.terminate_server()
    return end - start


def main(args: argparse.Namespace):
    """Build the request list, run the chosen backend, and print throughput.

    Seeds the global ``random`` module with ``args.seed`` and loads the
    tokenizer. If no dataset is given, it creates ``args.num_prompts``
    identical synthetic requests; otherwise it samples requests from
    ``args.dataset``. It then times the backend selected by ``args.backend``
    and prints requests/s and tokens/s.

    Raises:
        ValueError: If ``args.backend`` is not one of vllm, hf, mii.
    """
    print(args)
    random.seed(args.seed)

    # Sample the requests.
    tokenizer = AutoTokenizer.from_pretrained(
        args.tokenizer, trust_remote_code=args.trust_remote_code)
    if args.dataset is None:
        # Synthesize a prompt with the given input length.
        # NOTE(review): this assumes "hi" encodes to one token and the
        # tokenizer adds one special token, so the real prompt length is only
        # approximately args.input_len. Confirm per tokenizer.
        prompt = "hi" * (args.input_len - 1)
        requests = [(prompt, args.input_len, args.output_len)
                    for _ in range(args.num_prompts)]
    else:
        requests = sample_requests(args.dataset, args.num_prompts, tokenizer,
                                   args.output_len)

    if args.backend == "vllm":
        elapsed_time = run_vllm(
            requests, args.model, args.tokenizer, args.quantization,
            args.tensor_parallel_size, args.seed, args.n, args.use_beam_search,
            args.trust_remote_code, args.dtype, args.max_model_len,
            args.enforce_eager, args.kv_cache_dtype,
            args.quantization_param_path, args.device,
            args.enable_prefix_caching, args.enable_chunked_prefill,
            args.max_num_batched_tokens, args.gpu_memory_utilization,
            args.download_dir)
    elif args.backend == "hf":
        assert args.tensor_parallel_size == 1
        elapsed_time = run_hf(requests, args.model, tokenizer, args.n,
                              args.use_beam_search, args.hf_max_batch_size,
                              args.trust_remote_code)
    elif args.backend == "mii":
        elapsed_time = run_mii(requests, args.model, args.tensor_parallel_size,
                               args.output_len)
    else:
        raise ValueError(f"Unknown backend: {args.backend}")

    # Each backend generates `args.n` sequences per prompt (vLLM's
    # SamplingParams(n=...), HF's num_return_sequences; MII requires n == 1),
    # so every request produces `n * output_len` output tokens.
    total_num_tokens = sum(prompt_len + args.n * output_len
                           for _, prompt_len, output_len in requests)
    print(f"Throughput: {len(requests) / elapsed_time:.2f} requests/s, "
          f"{total_num_tokens / elapsed_time:.2f} tokens/s")


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Benchmark the throughput.")
    parser.add_argument("--backend",
                        type=str,
                        choices=["vllm", "hf", "mii"],
                        default="vllm")
    parser.add_argument("--dataset",
                        type=str,
                        default=None,
                        help="Path to the dataset.")
    parser.add_argument("--input-len",
                        type=int,
                        default=None,
                        help="Input prompt length for each request")
    parser.add_argument("--output-len",
                        type=int,
                        default=None,
                        help="Output length for each request. Overrides the "
                        "output length from the dataset.")
    parser.add_argument("--model", type=str, default="facebook/opt-125m")
    parser.add_argument("--tokenizer", type=str, default=None)
    parser.add_argument('--quantization',
                        '-q',
                        choices=[*QUANTIZATION_METHODS, None],
                        default=None)
    parser.add_argument("--tensor-parallel-size", "-tp", type=int, default=1)
    parser.add_argument("--n",
                        type=int,
                        default=1,
                        help="Number of generated sequences per prompt.")
    parser.add_argument("--use-beam-search", action="store_true")
    parser.add_argument("--num-prompts",
                        type=int,
                        default=1000,
                        help="Number of prompts to process.")
    parser.add_argument("--seed", type=int, default=0)
    parser.add_argument("--hf-max-batch-size",
                        type=int,
                        default=None,
                        help="Maximum batch size for HF backend.")
    parser.add_argument('--trust-remote-code',
                        action='store_true',
                        help='trust remote code from huggingface')
    parser.add_argument(
        '--max-model-len',
        type=int,
        default=None,
        help='Maximum length of a sequence (including prompt and output). '
        'If None, will be derived from the model.')
    parser.add_argument(
        '--dtype',
        type=str,
        default='auto',
        choices=['auto', 'half', 'float16', 'bfloat16', 'float', 'float32'],
        help='data type for model weights and activations. '
        'The "auto" option will use FP16 precision '
        'for FP32 and FP16 models, and BF16 precision '
        'for BF16 models.')
    parser.add_argument('--gpu-memory-utilization',
                        type=float,
                        default=0.9,
                        help='the fraction of GPU memory to be used for '
                        'the model executor, which can range from 0 to 1. '
                        'If unspecified, will use the default value of 0.9.')
    parser.add_argument("--enforce-eager",
                        action="store_true",
                        help="enforce eager execution")
    parser.add_argument(
        "--kv-cache-dtype",
        type=str,
        choices=["auto", "fp8"],
        default="auto",
        help=
        'Data type for kv cache storage. If "auto", will use model data type. '
        'FP8_E5M2 (without scaling) is only supported on cuda version greater '
        'than 11.8. On ROCm (AMD GPU), FP8_E4M3 is instead supported for '
        'common inference criteria.')
    parser.add_argument(
        '--quantization-param-path',
        type=str,
        default=None,
        help='Path to the JSON file containing the KV cache scaling factors. '
        'This should generally be supplied, when KV cache dtype is FP8. '
        'Otherwise, KV cache scaling factors default to 1.0, which may cause '
        'accuracy issues. FP8_E5M2 (without scaling) is only supported on '
        'cuda version greater than 11.8. On ROCm (AMD GPU), FP8_E4M3 is '
        'instead supported for common inference criteria.')
    parser.add_argument(
        "--device",
        type=str,
        default="cuda",
        choices=["cuda", "cpu"],
        help='device type for vLLM execution, supporting CUDA and CPU.')
    parser.add_argument(
        "--enable-prefix-caching",
        action='store_true',
        help="enable automatic prefix caching for vLLM backend.")
    parser.add_argument("--enable-chunked-prefill",
                        action='store_true',
                        help="enable chunked prefill for vLLM backend.")
    parser.add_argument('--max-num-batched-tokens',
                        type=int,
                        default=None,
                        help='maximum number of batched tokens per '
                        'iteration')
    parser.add_argument('--download-dir',
                        type=str,
                        default=None,
                        help='directory to download and load the weights, '
                        'default to the default cache dir of huggingface')
    args = parser.parse_args()
    if args.tokenizer is None:
        args.tokenizer = args.model
    # Validate with parser.error rather than `assert`: asserts are stripped
    # under `python -O`, which would let an invalid combination reach main().
    if args.dataset is None:
        if args.input_len is None or args.output_len is None:
            parser.error("--input-len and --output-len are required when "
                         "--dataset is not given.")
    elif args.input_len is not None:
        parser.error("--input-len cannot be combined with --dataset.")

    if args.backend == "vllm":
        if args.hf_max_batch_size is not None:
            raise ValueError("HF max batch size is only for HF backend.")
    elif args.backend == "hf":
        if args.hf_max_batch_size is None:
            raise ValueError("HF max batch size is required for HF backend.")
        if args.quantization is not None:
            raise ValueError("Quantization is only for vLLM backend.")
    elif args.backend == "mii":
        if args.dtype != "auto":
            raise ValueError("dtype must be auto for MII backend.")
        if args.n != 1:
            raise ValueError("n must be 1 for MII backend.")
        if args.use_beam_search:
            raise ValueError("Beam search is not supported for MII backend.")
        if args.quantization is not None:
            raise ValueError("Quantization is only for vLLM backend.")
        if args.hf_max_batch_size is not None:
            raise ValueError("HF max batch size is only for HF backend.")
        if args.tokenizer != args.model:
            raise ValueError("Tokenizer must be the same as the model for MII "
                             "backend.")
    main(args)