benchmark_throughput.py 17.6 KB
Newer Older
1
"""Benchmark offline inference throughput."""
2
3
4
5
import argparse
import json
import random
import time
6
from typing import List, Optional, Tuple
7

8
import torch
9
from tqdm import tqdm
10
11
from transformers import (AutoModelForCausalLM, AutoTokenizer,
                          PreTrainedTokenizerBase)
12

13
from vllm.engine.arg_utils import EngineArgs
14
from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
15
from vllm.utils import FlexibleArgumentParser
16

17
18
19
20
21

def sample_requests(
    dataset_path: str,
    num_requests: int,
    tokenizer: PreTrainedTokenizerBase,
    fixed_output_len: Optional[int],
) -> List[Tuple[str, int, int]]:
    """Sample benchmark requests from a conversation-style JSON dataset.

    Records whose ``"conversations"`` list has fewer than two turns are
    dropped. The first two turns of every remaining record are used as a
    (prompt, completion) pair. The pairs are shuffled with the module-level
    ``random`` generator, tokenized, and filtered by length until
    ``num_requests`` requests have been collected.

    Args:
        dataset_path: Path to a JSON file holding a list of records, each
            with a ``"conversations"`` list whose items carry a ``"value"``.
        num_requests: Maximum number of requests to return. Fewer are
            returned if the dataset runs out of qualifying pairs.
        tokenizer: Tokenizer used to measure prompt/completion lengths.
        fixed_output_len: If not None, used as every request's output
            length instead of the tokenized completion length.

    Returns:
        A list of ``(prompt, prompt_len, output_len)`` tuples.

    Raises:
        ValueError: If ``fixed_output_len`` is given and smaller than 4.
    """
    if fixed_output_len is not None and fixed_output_len < 4:
        raise ValueError("output_len too small")

    # Load the dataset. JSON text is UTF-8; don't depend on the locale.
    with open(dataset_path, encoding="utf-8") as f:
        dataset = json.load(f)
    # Drop conversations with fewer than 2 turns and keep only the first
    # two turns of each remaining conversation.
    pairs = [(data["conversations"][0]["value"],
              data["conversations"][1]["value"]) for data in dataset
             if len(data["conversations"]) >= 2]

    # Shuffle the dataset.
    random.shuffle(pairs)

    # Filter out sequences that are too long or too short.
    filtered_dataset: List[Tuple[str, int, int]] = []
    for prompt, completion in pairs:
        if len(filtered_dataset) == num_requests:
            break

        prompt_len = len(tokenizer(prompt).input_ids)
        if fixed_output_len is None:
            output_len = len(tokenizer(completion).input_ids)
        else:
            # The completion's length is irrelevant when the output length
            # is fixed, so skip tokenizing it.
            output_len = fixed_output_len

        if prompt_len < 4 or output_len < 4:
            # Prune too short sequences.
            continue
        if prompt_len > 1024 or prompt_len + output_len > 2048:
            # Prune too long sequences.
            continue
        filtered_dataset.append((prompt, prompt_len, output_len))

    return filtered_dataset


def run_vllm(
    requests: List[Tuple[str, int, int]],
    model: str,
    tokenizer: str,
    quantization: Optional[str],
    tensor_parallel_size: int,
    seed: int,
    n: int,
    use_beam_search: bool,
    trust_remote_code: bool,
    dtype: str,
    max_model_len: Optional[int],
    enforce_eager: bool,
    kv_cache_dtype: str,
    quantization_param_path: Optional[str],
    device: str,
    enable_prefix_caching: bool,
    enable_chunked_prefill: bool,
    max_num_batched_tokens: int,
    distributed_executor_backend: Optional[str],
    gpu_memory_utilization: float = 0.9,
    num_scheduler_steps: int = 1,
    use_v2_block_manager: bool = False,
    download_dir: Optional[str] = None,
    load_format: str = EngineArgs.load_format,
) -> float:
    """Time offline generation of ``requests`` with a vLLM ``LLM`` engine.

    Every engine option is forwarded unchanged as a keyword argument to
    ``vllm.LLM``. Each request gets its own ``SamplingParams`` with
    ``ignore_eos=True`` and ``max_tokens`` set to that request's output
    length. Only the ``generate`` call is timed; engine construction is
    excluded.

    Returns:
        Wall-clock seconds spent in ``LLM.generate``.
    """
    from vllm import LLM, SamplingParams

    engine = LLM(
        model=model,
        tokenizer=tokenizer,
        quantization=quantization,
        tensor_parallel_size=tensor_parallel_size,
        seed=seed,
        trust_remote_code=trust_remote_code,
        dtype=dtype,
        max_model_len=max_model_len,
        gpu_memory_utilization=gpu_memory_utilization,
        enforce_eager=enforce_eager,
        kv_cache_dtype=kv_cache_dtype,
        quantization_param_path=quantization_param_path,
        device=device,
        enable_prefix_caching=enable_prefix_caching,
        download_dir=download_dir,
        enable_chunked_prefill=enable_chunked_prefill,
        # main() passes None here when --max-num-batched-tokens is not set.
        max_num_batched_tokens=max_num_batched_tokens,
        distributed_executor_backend=distributed_executor_backend,
        load_format=load_format,
        num_scheduler_steps=num_scheduler_steps,
        use_v2_block_manager=use_v2_block_manager,
    )

    # Temperature is 0.0 under beam search and 1.0 otherwise.
    sampling_temperature = 0.0 if use_beam_search else 1.0
    prompts: List[str] = [text for text, _, _ in requests]
    sampling_params: List[SamplingParams] = [
        SamplingParams(
            n=n,
            temperature=sampling_temperature,
            top_p=1.0,
            use_beam_search=use_beam_search,
            ignore_eos=True,
            max_tokens=max_new,
        ) for _, _, max_new in requests
    ]

    started_at = time.perf_counter()
    engine.generate(prompts, sampling_params, use_tqdm=True)
    return time.perf_counter() - started_at


def run_hf(
    requests: List[Tuple[str, int, int]],
    model: str,
    tokenizer: PreTrainedTokenizerBase,
    n: int,
    use_beam_search: bool,
    max_batch_size: int,
    trust_remote_code: bool,
) -> float:
    """Time generation of ``requests`` with a HuggingFace ``generate`` loop.

    The model is loaded in float16 and moved to CUDA. Requests are packed,
    in order, into batches of at most ``max_batch_size`` prompts. A batch
    is also closed early when admitting the next request would push
    (longest prompt + longest output) above 2048 tokens. Each batch is
    tokenized with padding, sampled for the batch's longest output length,
    and decoded. Tokenization and decoding are included in the timing.

    Args:
        requests: ``(prompt, prompt_len, output_len)`` tuples.
        model: Model name or path for ``AutoModelForCausalLM``.
        tokenizer: Tokenizer for the model. Its ``pad_token`` may be set
            to its ``eos_token`` (see below).
        n: Number of returned sequences per prompt.
        use_beam_search: Must be False; beam search is not supported here.
        max_batch_size: Maximum number of prompts per batch.
        trust_remote_code: Forwarded to ``from_pretrained``.

    Returns:
        Wall-clock seconds spent tokenizing, generating and decoding.

    Raises:
        ValueError: If ``use_beam_search`` is True.
    """
    if use_beam_search:
        raise ValueError("Beam search is not supported for HF backend.")

    llm = AutoModelForCausalLM.from_pretrained(
        model, torch_dtype=torch.float16, trust_remote_code=trust_remote_code)
    if llm.config.model_type == "llama" or tokenizer.pad_token is None:
        # Batched tokenization with ``padding=True`` needs a pad token.
        # Llama models always get EOS as the pad token (unchanged behavior).
        # Any other tokenizer without a pad token also falls back to EOS
        # instead of failing at the first batch.
        tokenizer.pad_token = tokenizer.eos_token
    llm = llm.cuda()

    pbar = tqdm(total=len(requests))
    start = time.perf_counter()
    batch: List[str] = []
    max_prompt_len = 0
    max_output_len = 0
    for i, (prompt, prompt_len, output_len) in enumerate(requests):
        # Add the prompt to the batch.
        batch.append(prompt)
        max_prompt_len = max(max_prompt_len, prompt_len)
        max_output_len = max(max_output_len, output_len)
        if len(batch) < max_batch_size and i != len(requests) - 1:
            # Admit the next request only if the longest prompt plus the
            # longest output (including it) stays within 2048 tokens.
            _, next_prompt_len, next_output_len = requests[i + 1]
            if (max(max_prompt_len, next_prompt_len) +
                    max(max_output_len, next_output_len)) <= 2048:
                # We can add more requests to the batch.
                continue

        # Generate the sequences.
        input_ids = tokenizer(batch, return_tensors="pt",
                              padding=True).input_ids
        llm_outputs = llm.generate(
            input_ids=input_ids.cuda(),
            # Beam search was rejected above, so always sample.
            do_sample=True,
            num_return_sequences=n,
            temperature=1.0,
            top_p=1.0,
            use_cache=True,
            max_new_tokens=max_output_len,
        )
        # Include the decoding time.
        tokenizer.batch_decode(llm_outputs, skip_special_tokens=True)
        pbar.update(len(batch))

        # Clear the batch.
        batch = []
        max_prompt_len = 0
        max_output_len = 0
    end = time.perf_counter()
    # Release the progress bar outside the timed region.
    pbar.close()
    return end - start


def run_mii(
    requests: List[Tuple[str, int, int]],
    model: str,
    tensor_parallel_size: int,
    output_len: int,
) -> float:
    """Time generation of ``requests`` with a DeepSpeed-MII server.

    A server for ``model`` is started with ``mii.serve``. Only the single
    ``generate`` call over all prompts is timed. The server is shut down
    afterwards even if generation raises, so a failed run does not leave
    it running.

    Args:
        requests: ``(prompt, prompt_len, output_len)`` tuples. Only the
            prompts are used.
        model: Model name or path passed to ``mii.serve``.
        tensor_parallel_size: Passed as ``tensor_parallel`` to ``mii.serve``.
        output_len: ``max_new_tokens`` applied to every prompt. Per-request
            output lengths are ignored.

    Returns:
        Wall-clock seconds spent in ``generate``.
    """
    from mii import client, serve
    llm = serve(model, tensor_parallel=tensor_parallel_size)
    try:
        prompts = [prompt for prompt, _, _ in requests]

        start = time.perf_counter()
        llm.generate(prompts, max_new_tokens=output_len)
        end = time.perf_counter()
    finally:
        # Use a distinct name so the imported ``client`` factory is not
        # shadowed by its own result.
        mii_client = client(model)
        mii_client.terminate_server()
    return end - start


def main(args: argparse.Namespace):
    """Run one throughput benchmark as configured by ``args``.

    The request list is either synthesized (the same repeated ``"hi"``
    prompt, ``args.num_prompts`` times) or sampled from ``args.dataset``.
    The selected backend is then timed. The request and token throughput
    is printed and, when ``args.output_json`` is set, also saved as JSON.

    Raises:
        ValueError: If ``args.backend`` is not "vllm", "hf" or "mii".
    """
    print(args)
    random.seed(args.seed)

    # Sample the requests.
    tokenizer = AutoTokenizer.from_pretrained(
        args.tokenizer, trust_remote_code=args.trust_remote_code)
    if args.dataset is not None:
        requests = sample_requests(args.dataset, args.num_prompts, tokenizer,
                                   args.output_len)
    else:
        # Synthesize a prompt with the given input length. The prompt is
        # not tokenized here; args.input_len is recorded as its length.
        synthetic_prompt = "hi" * (args.input_len - 1)
        requests = [(synthetic_prompt, args.input_len, args.output_len)
                    for _ in range(args.num_prompts)]

    backend = args.backend
    if backend == "vllm":
        elapsed_time = run_vllm(
            requests, args.model, args.tokenizer, args.quantization,
            args.tensor_parallel_size, args.seed, args.n, args.use_beam_search,
            args.trust_remote_code, args.dtype, args.max_model_len,
            args.enforce_eager, args.kv_cache_dtype,
            args.quantization_param_path, args.device,
            args.enable_prefix_caching, args.enable_chunked_prefill,
            args.max_num_batched_tokens, args.distributed_executor_backend,
            args.gpu_memory_utilization, args.num_scheduler_steps,
            args.use_v2_block_manager, args.download_dir, args.load_format)
    elif backend == "hf":
        # The HF backend does not support tensor parallelism.
        assert args.tensor_parallel_size == 1
        elapsed_time = run_hf(requests, args.model, tokenizer, args.n,
                              args.use_beam_search, args.hf_max_batch_size,
                              args.trust_remote_code)
    elif backend == "mii":
        elapsed_time = run_mii(requests, args.model, args.tensor_parallel_size,
                               args.output_len)
    else:
        raise ValueError(f"Unknown backend: {args.backend}")

    num_requests = len(requests)
    total_num_tokens = sum(prompt_len + output_len
                           for _, prompt_len, output_len in requests)
    requests_per_second = num_requests / elapsed_time
    tokens_per_second = total_num_tokens / elapsed_time
    print(f"Throughput: {requests_per_second:.2f} requests/s, "
          f"{tokens_per_second:.2f} tokens/s")

    # Output JSON results if specified
    if not args.output_json:
        return
    results = {
        "elapsed_time": elapsed_time,
        "num_requests": num_requests,
        "total_num_tokens": total_num_tokens,
        "requests_per_second": requests_per_second,
        "tokens_per_second": tokens_per_second,
    }
    with open(args.output_json, "w") as f:
        json.dump(results, f, indent=4)

if __name__ == "__main__":
    # Command-line entry point: parse and validate options, then run main().
    parser = FlexibleArgumentParser(description="Benchmark the throughput.")
    parser.add_argument("--backend",
                        type=str,
                        choices=["vllm", "hf", "mii"],
                        default="vllm")
    parser.add_argument("--dataset",
                        type=str,
                        default=None,
                        help="Path to the dataset.")
    parser.add_argument("--input-len",
                        type=int,
                        default=None,
                        help="Input prompt length for each request")
    parser.add_argument("--output-len",
                        type=int,
                        default=None,
                        help="Output length for each request. Overrides the "
                        "output length from the dataset.")
    parser.add_argument("--model", type=str, default="facebook/opt-125m")
    parser.add_argument("--tokenizer", type=str, default=None)
    parser.add_argument('--quantization',
                        '-q',
                        choices=[*QUANTIZATION_METHODS, None],
                        default=None)
    parser.add_argument("--tensor-parallel-size", "-tp", type=int, default=1)
    parser.add_argument("--n",
                        type=int,
                        default=1,
                        help="Number of generated sequences per prompt.")
    parser.add_argument("--use-beam-search", action="store_true")
    parser.add_argument("--num-prompts",
                        type=int,
                        default=1000,
                        help="Number of prompts to process.")
    parser.add_argument("--seed", type=int, default=0)
    parser.add_argument("--hf-max-batch-size",
                        type=int,
                        default=None,
                        help="Maximum batch size for HF backend.")
    parser.add_argument('--trust-remote-code',
                        action='store_true',
                        help='trust remote code from huggingface')
    parser.add_argument(
        '--max-model-len',
        type=int,
        default=None,
        help='Maximum length of a sequence (including prompt and output). '
        'If None, will be derived from the model.')
    parser.add_argument(
        '--dtype',
        type=str,
        default='auto',
        choices=['auto', 'half', 'float16', 'bfloat16', 'float', 'float32'],
        help='data type for model weights and activations. '
        'The "auto" option will use FP16 precision '
        'for FP32 and FP16 models, and BF16 precision '
        'for BF16 models.')
    parser.add_argument('--gpu-memory-utilization',
                        type=float,
                        default=0.9,
                        help='the fraction of GPU memory to be used for '
                        'the model executor, which can range from 0 to 1. '
                        'If unspecified, will use the default value of 0.9.')
    parser.add_argument("--enforce-eager",
                        action="store_true",
                        help="enforce eager execution")
    parser.add_argument(
        '--kv-cache-dtype',
        type=str,
        choices=['auto', 'fp8', 'fp8_e5m2', 'fp8_e4m3'],
        default="auto",
        help='Data type for kv cache storage. If "auto", will use model '
        'data type. CUDA 11.8+ supports fp8 (=fp8_e4m3) and fp8_e5m2. '
        'ROCm (AMD GPU) supports fp8 (=fp8_e4m3)')
    parser.add_argument(
        '--quantization-param-path',
        type=str,
        default=None,
        help='Path to the JSON file containing the KV cache scaling factors. '
        'This should generally be supplied, when KV cache dtype is FP8. '
        'Otherwise, KV cache scaling factors default to 1.0, which may cause '
        'accuracy issues. FP8_E5M2 (without scaling) is only supported on '
        'cuda version greater than 11.8. On ROCm (AMD GPU), FP8_E4M3 is '
        'instead supported for common inference criteria.')
    parser.add_argument(
        "--device",
        type=str,
        default="auto",
        choices=["auto", "cuda", "cpu", "openvino", "tpu", "xpu"],
        help='device type for vLLM execution, supporting CUDA, OpenVINO, '
        'CPU, TPU and XPU.')
    parser.add_argument(
        "--num-scheduler-steps",
        type=int,
        default=1,
        help="Maximum number of forward steps per scheduler call.")
    parser.add_argument("--use-v2-block-manager",
                        action='store_true',
                        help="Enable block manager v2.")
    parser.add_argument(
        "--enable-prefix-caching",
        action='store_true',
        help="Enable automatic prefix caching for vLLM backend.")
    parser.add_argument("--enable-chunked-prefill",
                        action='store_true',
                        help="enable chunked prefill for vLLM backend.")
    parser.add_argument('--max-num-batched-tokens',
                        type=int,
                        default=None,
                        help='maximum number of batched tokens per '
                        'iteration')
    parser.add_argument('--download-dir',
                        type=str,
                        default=None,
                        help='directory to download and load the weights, '
                        'default to the default cache dir of huggingface')
    parser.add_argument(
        '--output-json',
        type=str,
        default=None,
        help='Path to save the throughput results in JSON format.')
    parser.add_argument(
        '--distributed-executor-backend',
        choices=['ray', 'mp'],
        default=None,
        help='Backend to use for distributed serving. When more than 1 GPU '
        'is used, will be automatically set to "ray" if installed '
        'or "mp" (multiprocessing) otherwise.')
    parser.add_argument(
        '--load-format',
        type=str,
        default=EngineArgs.load_format,
        choices=[
            'auto', 'pt', 'safetensors', 'npcache', 'dummy', 'tensorizer',
            'bitsandbytes'
        ],
        help='The format of the model weights to load.\n\n'
        '* "auto" will try to load the weights in the safetensors format '
        'and fall back to the pytorch bin format if safetensors format '
        'is not available.\n'
        '* "pt" will load the weights in the pytorch bin format.\n'
        '* "safetensors" will load the weights in the safetensors format.\n'
        '* "npcache" will load the weights in pytorch format and store '
        'a numpy cache to speed up the loading.\n'
        '* "dummy" will initialize the weights with random values, '
        'which is mainly for profiling.\n'
        '* "tensorizer" will load the weights using tensorizer from '
        'CoreWeave. See the Tensorize vLLM Model script in the Examples '
        'section for more information.\n'
        '* "bitsandbytes" will load the weights using bitsandbytes '
        'quantization.\n')
    args = parser.parse_args()
    if args.tokenizer is None:
        args.tokenizer = args.model

    # Validate user input with parser.error rather than assert: asserts are
    # stripped under ``python -O`` and give no usage message.
    if args.dataset is None:
        if args.input_len is None:
            parser.error("--input-len is required when --dataset is not set.")
        if args.output_len is None:
            parser.error(
                "--output-len is required when --dataset is not set.")
    elif args.input_len is not None:
        parser.error("--input-len cannot be used together with --dataset.")

    if args.backend == "vllm":
        if args.hf_max_batch_size is not None:
            raise ValueError("HF max batch size is only for HF backend.")
    elif args.backend == "hf":
        if args.hf_max_batch_size is None:
            raise ValueError("HF max batch size is required for HF backend.")
        if args.quantization is not None:
            raise ValueError("Quantization is only for vLLM backend.")
        # Fail fast, before any model or tokenizer is loaded.
        if args.use_beam_search:
            raise ValueError("Beam search is not supported for HF backend.")
    elif args.backend == "mii":
        if args.dtype != "auto":
            raise ValueError("dtype must be auto for MII backend.")
        if args.n != 1:
            raise ValueError("n must be 1 for MII backend.")
        if args.use_beam_search:
            raise ValueError("Beam search is not supported for MII backend.")
        if args.quantization is not None:
            raise ValueError("Quantization is only for vLLM backend.")
        if args.hf_max_batch_size is not None:
            raise ValueError("HF max batch size is only for HF backend.")
        if args.tokenizer != args.model:
            raise ValueError("Tokenizer must be the same as the model for MII "
                             "backend.")
    main(args)