"""Benchmark offline inference throughput."""

import argparse
import json
import random
import time
from typing import List, Optional, Tuple

import numpy as np
import torch
from tqdm import tqdm
from transformers import (AutoModelForCausalLM, AutoTokenizer,
                          PreTrainedTokenizerBase)

from vllm.engine.arg_utils import EngineArgs
from vllm.inputs import PromptInputs
from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
from vllm.utils import FlexibleArgumentParser


def sample_requests(
    dataset_path: str,
    num_requests: int,
    tokenizer: "PreTrainedTokenizerBase",
    fixed_output_len: Optional[int],
    max_prompt_len: int = 1024,
    max_total_len: int = 2048,
) -> List[Tuple[str, int, int]]:
    """Sample up to ``num_requests`` (prompt, prompt_len, output_len) tuples.

    The dataset is a JSON list whose entries carry a ``"conversations"``
    list; the first turn is used as the prompt and the second as the
    reference completion. Conversations are shuffled with the module-level
    ``random`` state, so seed ``random`` first for reproducible samples.

    Args:
        dataset_path: Path to the JSON dataset file (decoded as UTF-8).
        num_requests: Maximum number of requests to return.
        tokenizer: Callable that maps a string to an object exposing
            ``input_ids`` (e.g. a Hugging Face tokenizer).
        fixed_output_len: If given, used as the output length of every
            request; otherwise the tokenized completion length is used.
        max_prompt_len: Requests whose prompt has more tokens than this
            are skipped.
        max_total_len: Requests whose prompt + output length exceeds this
            are skipped.

    Returns:
        At most ``num_requests`` tuples ``(prompt, prompt_len, output_len)``;
        fewer if the dataset does not contain enough valid conversations.

    Raises:
        ValueError: If ``fixed_output_len`` is smaller than 4.
    """
    if fixed_output_len is not None and fixed_output_len < 4:
        raise ValueError("output_len too small")

    # The dataset text is not ASCII-only; do not depend on the locale's
    # default encoding when reading it.
    with open(dataset_path, encoding="utf-8") as f:
        dataset = json.load(f)

    # Keep only the first two turns of conversations that have at least two.
    pairs = [(data["conversations"][0]["value"],
              data["conversations"][1]["value"]) for data in dataset
             if len(data["conversations"]) >= 2]
    random.shuffle(pairs)

    filtered_dataset: List[Tuple[str, int, int]] = []
    for prompt, completion in pairs:
        if len(filtered_dataset) == num_requests:
            break

        prompt_len = len(tokenizer(prompt).input_ids)
        if fixed_output_len is None:
            # Tokenize the completion only when its length is actually used.
            output_len = len(tokenizer(completion).input_ids)
        else:
            output_len = fixed_output_len

        if prompt_len < 4 or output_len < 4:
            # Prune too short sequences.
            continue
        if (prompt_len > max_prompt_len
                or prompt_len + output_len > max_total_len):
            # Prune too long sequences.
            continue
        filtered_dataset.append((prompt, prompt_len, output_len))

    return filtered_dataset


def run_vllm(
    warmup_requests: List[Tuple[str, int, int]],
    requests: List[Tuple[str, int, int]],
    model: str,
    tokenizer: str,
    quantization: Optional[str],
    tensor_parallel_size: int,
    seed: int,
    n: int,
    use_beam_search: bool,
    trust_remote_code: bool,
    dtype: str,
    max_model_len: Optional[int],
    enforce_eager: bool,
    kv_cache_dtype: str,
    quantization_param_path: Optional[str],
    device: str,
    enable_prefix_caching: bool,
    enable_chunked_prefill: bool,
    max_num_batched_tokens: int,
    distributed_executor_backend: Optional[str],
    gpu_memory_utilization: float = 0.9,
    download_dir: Optional[str] = None,
    load_format: str = EngineArgs.load_format,
    num_iters_warmup: int = 1,
) -> float:
    """Time one offline ``LLM.generate`` call over ``requests`` with vLLM.

    Builds an ``LLM`` engine from the given engine options, runs
    ``num_iters_warmup`` untimed passes over ``warmup_requests``, and then
    measures the wall-clock time of generating all ``requests``.

    Each request is ``(prompt, prompt_len, output_len)``. Only ``prompt``
    and ``output_len`` are used; ``output_len`` becomes ``max_tokens`` with
    ``ignore_eos=True`` so exactly that many tokens are requested.

    Returns:
        Elapsed seconds of the timed ``generate`` call.
    """
    from vllm import LLM, SamplingParams

    llm = LLM(
        model=model,
        tokenizer=tokenizer,
        quantization=quantization,
        tensor_parallel_size=tensor_parallel_size,
        seed=seed,
        trust_remote_code=trust_remote_code,
        dtype=dtype,
        max_model_len=max_model_len,
        gpu_memory_utilization=gpu_memory_utilization,
        enforce_eager=enforce_eager,
        kv_cache_dtype=kv_cache_dtype,
        quantization_param_path=quantization_param_path,
        device=device,
        enable_prefix_caching=enable_prefix_caching,
        download_dir=download_dir,
        enable_chunked_prefill=enable_chunked_prefill,
        max_num_batched_tokens=max_num_batched_tokens,
        distributed_executor_backend=distributed_executor_backend,
        load_format=load_format,
    )

    def _to_generate_inputs(
        reqs: List[Tuple[str, int, int]]
    ) -> Tuple[List[str], List[SamplingParams]]:
        """Split request tuples into the prompt / SamplingParams lists."""
        prompts: List[str] = []
        params: List[SamplingParams] = []
        for prompt, _, output_len in reqs:
            prompts.append(prompt)
            params.append(
                SamplingParams(
                    n=n,
                    temperature=0.0 if use_beam_search else 1.0,
                    top_p=1.0,
                    use_beam_search=use_beam_search,
                    ignore_eos=True,
                    max_tokens=output_len,
                ))
        return prompts, params

    prompts, sampling_params = _to_generate_inputs(requests)
    warmup_prompts, warmup_sampling_params = _to_generate_inputs(
        warmup_requests)

    # Untimed warmup passes. An empty warmup list or a non-positive
    # iteration count skips the warmup.
    if warmup_prompts and num_iters_warmup > 0:
        print("Warming up...")
        for _ in tqdm(range(num_iters_warmup), desc="Warmup iterations"):
            llm.generate(warmup_prompts, warmup_sampling_params,
                         use_tqdm=True)

    start = time.perf_counter()
    llm.generate(prompts, sampling_params, use_tqdm=True)
    end = time.perf_counter()
    return end - start


def run_hf(
    requests: List[Tuple[str, int, int]],
    model: str,
    tokenizer: PreTrainedTokenizerBase,
    n: int,
    use_beam_search: bool,
    max_batch_size: int,
    trust_remote_code: bool,
) -> float:
    """Time Hugging Face ``generate`` over ``requests`` in packed batches.

    Requests are grouped in order. A batch is closed when it reaches
    ``max_batch_size`` or when adding the next request would make the
    longest prompt plus the longest output in the batch exceed 2048 tokens.
    The measured time includes tokenization, generation and decoding.

    Returns:
        Elapsed seconds for processing all requests.
    """
    assert not use_beam_search

    llm = AutoModelForCausalLM.from_pretrained(
        model, torch_dtype=torch.float16, trust_remote_code=trust_remote_code)
    if llm.config.model_type == "llama" or tokenizer.pad_token is None:
        # ``padding=True`` below requires a pad token. LLaMA tokenizers (and
        # any tokenizer without one) get EOS reused as the pad token.
        tokenizer.pad_token = tokenizer.eos_token
    llm = llm.cuda()

    pbar = tqdm(total=len(requests))
    start = time.perf_counter()
    batch: List[str] = []
    max_prompt_len = 0
    max_output_len = 0
    last_index = len(requests) - 1
    for i, (prompt, prompt_len, output_len) in enumerate(requests):
        # Add the prompt to the batch.
        batch.append(prompt)
        max_prompt_len = max(max_prompt_len, prompt_len)
        max_output_len = max(max_output_len, output_len)
        if len(batch) < max_batch_size and i != last_index:
            # Check if we can add more requests to the batch.
            _, next_prompt_len, next_output_len = requests[i + 1]
            if (max(max_prompt_len, next_prompt_len) +
                    max(max_output_len, next_output_len)) <= 2048:
                # We can add more requests to the batch.
                continue

        # Generate the sequences.
        input_ids = tokenizer(batch, return_tensors="pt",
                              padding=True).input_ids
        llm_outputs = llm.generate(
            input_ids=input_ids.cuda(),
            do_sample=not use_beam_search,
            num_return_sequences=n,
            temperature=1.0,
            top_p=1.0,
            use_cache=True,
            max_new_tokens=max_output_len,
        )
        # Include the decoding time.
        tokenizer.batch_decode(llm_outputs, skip_special_tokens=True)
        pbar.update(len(batch))

        # Clear the batch.
        batch = []
        max_prompt_len = 0
        max_output_len = 0
    end = time.perf_counter()
    pbar.close()
    return end - start


def run_mii(
    requests: List[Tuple[str, int, int]],
    model: str,
    tensor_parallel_size: int,
    output_len: int,
) -> float:
    """Time one batched DeepSpeed-MII ``generate`` call over ``requests``.

    Starts MII via ``serve``, generates ``output_len`` new tokens for every
    prompt, and always asks the server to terminate afterwards.

    Returns:
        Elapsed seconds of the timed ``generate`` call.
    """
    from mii import client, serve
    llm = serve(model, tensor_parallel=tensor_parallel_size)
    try:
        prompts = [prompt for prompt, _, _ in requests]

        start = time.perf_counter()
        llm.generate(prompts, max_new_tokens=output_len)
        end = time.perf_counter()
    finally:
        # Terminate even if generation raised, so the server started by
        # ``serve`` is not left running.
        mii_client = client(model)
        mii_client.terminate_server()
    return end - start


def main(args: argparse.Namespace):
    """Build the requests, run the selected backend and report throughput.

    Prints requests/s and tokens/s (prompt + output tokens) and optionally
    writes the same numbers to ``args.output_json``.
    """
    print(args)
    random.seed(args.seed)

    tokenizer = AutoTokenizer.from_pretrained(
        args.tokenizer, trust_remote_code=args.trust_remote_code)

    # Small synthetic batch used only to warm up the vLLM engine. It is
    # defined regardless of --dataset so the vLLM backend works in both
    # modes.
    warmup_requests = [("hi" * 10, 10, 10)]

    # Sample the requests.
    if args.dataset is None:
        # Synthesize a prompt with the given input length.
        prompt = "hi" * (args.input_len - 1)
        requests = [(prompt, args.input_len, args.output_len)
                    for _ in range(args.num_prompts)]
    else:
        requests = sample_requests(args.dataset, args.num_prompts, tokenizer,
                                   args.output_len)

    if args.backend == "vllm":
        elapsed_time = run_vllm(
            warmup_requests, requests, args.model, args.tokenizer,
            args.quantization, args.tensor_parallel_size, args.seed, args.n,
            args.use_beam_search, args.trust_remote_code, args.dtype,
            args.max_model_len, args.enforce_eager, args.kv_cache_dtype,
            args.quantization_param_path, args.device,
            args.enable_prefix_caching, args.enable_chunked_prefill,
            args.max_num_batched_tokens, args.distributed_executor_backend,
            args.gpu_memory_utilization, args.download_dir, args.load_format,
            num_iters_warmup=args.num_iters_warmup)
    elif args.backend == "hf":
        assert args.tensor_parallel_size == 1
        elapsed_time = run_hf(requests, args.model, tokenizer, args.n,
                              args.use_beam_search, args.hf_max_batch_size,
                              args.trust_remote_code)
    elif args.backend == "mii":
        elapsed_time = run_mii(requests, args.model, args.tensor_parallel_size,
                               args.output_len)
    else:
        raise ValueError(f"Unknown backend: {args.backend}")

    total_num_tokens = sum(prompt_len + output_len
                           for _, prompt_len, output_len in requests)
    print(f"Throughput: {len(requests) / elapsed_time:.2f} requests/s, "
          f"{total_num_tokens / elapsed_time:.2f} tokens/s")

    # Output JSON results if specified
    if args.output_json:
        results = {
            "elapsed_time": elapsed_time,
            "num_requests": len(requests),
            "total_num_tokens": total_num_tokens,
            "requests_per_second": len(requests) / elapsed_time,
            "tokens_per_second": total_num_tokens / elapsed_time,
        }
        with open(args.output_json, "w") as f:
            json.dump(results, f, indent=4)


if __name__ == "__main__":
    parser = FlexibleArgumentParser(description="Benchmark the throughput.")
    parser.add_argument("--backend",
                        type=str,
                        choices=["vllm", "hf", "mii"],
                        default="vllm")
    parser.add_argument("--dataset",
                        type=str,
                        default=None,
                        help="Path to the dataset.")
    parser.add_argument("--input-len",
                        type=int,
                        default=None,
                        help="Input prompt length for each request")
    parser.add_argument("--output-len",
                        type=int,
                        default=None,
                        help="Output length for each request. Overrides the "
                        "output length from the dataset.")
    parser.add_argument("--model", type=str, default="facebook/opt-125m")
    parser.add_argument("--tokenizer", type=str, default=None)
    parser.add_argument('--quantization',
                        '-q',
                        choices=[*QUANTIZATION_METHODS, None],
                        default=None)
    parser.add_argument("--tensor-parallel-size", "-tp", type=int, default=1)
    parser.add_argument("--n",
                        type=int,
                        default=1,
                        help="Number of generated sequences per prompt.")
    parser.add_argument("--use-beam-search", action="store_true")
    parser.add_argument('--num-iters-warmup',
                        type=int,
                        default=1,
                        help='Number of iterations to run for warmup.')
    parser.add_argument("--num-prompts",
                        type=int,
                        default=1000,
                        help="Number of prompts to process.")
    parser.add_argument("--seed", type=int, default=0)
    parser.add_argument("--hf-max-batch-size",
                        type=int,
                        default=None,
                        help="Maximum batch size for HF backend.")
    parser.add_argument('--trust-remote-code',
                        action='store_true',
                        help='trust remote code from huggingface')
    parser.add_argument(
        '--max-model-len',
        type=int,
        default=None,
        help='Maximum length of a sequence (including prompt and output). '
        'If None, will be derived from the model.')
    parser.add_argument(
        '--dtype',
        type=str,
        default='auto',
        choices=['auto', 'half', 'float16', 'bfloat16', 'float', 'float32'],
        help='data type for model weights and activations. '
        'The "auto" option will use FP16 precision '
        'for FP32 and FP16 models, and BF16 precision '
        'for BF16 models.')
    parser.add_argument('--gpu-memory-utilization',
                        type=float,
                        default=0.9,
                        help='the fraction of GPU memory to be used for '
                        'the model executor, which can range from 0 to 1. '
                        'If unspecified, will use the default value of 0.9.')
    parser.add_argument("--enforce-eager",
                        action="store_true",
                        help="enforce eager execution")
    parser.add_argument(
        '--kv-cache-dtype',
        type=str,
        choices=['auto', 'fp8', 'fp8_e5m2', 'fp8_e4m3'],
        default="auto",
        help='Data type for kv cache storage. If "auto", will use model '
        'data type. CUDA 11.8+ supports fp8 (=fp8_e4m3) and fp8_e5m2. '
        'ROCm (AMD GPU) supports fp8 (=fp8_e4m3)')
    parser.add_argument(
        '--quantization-param-path',
        type=str,
        default=None,
        help='Path to the JSON file containing the KV cache scaling factors. '
        'This should generally be supplied, when KV cache dtype is FP8. '
        'Otherwise, KV cache scaling factors default to 1.0, which may cause '
        'accuracy issues. FP8_E5M2 (without scaling) is only supported on '
        'cuda version greater than 11.8. On ROCm (AMD GPU), FP8_E4M3 is '
        'instead supported for common inference criteria.')
    parser.add_argument(
        "--device",
        type=str,
        default="auto",
        choices=["auto", "cuda", "cpu", "openvino", "tpu", "xpu"],
        help='device type for vLLM execution, supporting CUDA, OpenVINO and '
        'CPU.')
    parser.add_argument(
        "--enable-prefix-caching",
        action='store_true',
        help="enable automatic prefix caching for vLLM backend.")
    parser.add_argument("--enable-chunked-prefill",
                        action='store_true',
                        help="enable chunked prefill for vLLM backend.")
    parser.add_argument('--max-num-batched-tokens',
                        type=int,
                        default=None,
                        help='maximum number of batched tokens per '
                        'iteration')
    parser.add_argument('--download-dir',
                        type=str,
                        default=None,
                        help='directory to download and load the weights, '
                        'default to the default cache dir of huggingface')
    parser.add_argument(
        '--output-json',
        type=str,
        default=None,
        help='Path to save the throughput results in JSON format.')
    parser.add_argument(
        '--distributed-executor-backend',
        choices=['ray', 'mp'],
        default=None,
        help='Backend to use for distributed serving. When more than 1 GPU '
        'is used, will be automatically set to "ray" if installed '
        'or "mp" (multiprocessing) otherwise.')
    parser.add_argument(
        '--load-format',
        type=str,
        default=EngineArgs.load_format,
        choices=[
            'auto', 'pt', 'safetensors', 'npcache', 'dummy', 'tensorizer',
            'bitsandbytes'
        ],
        help='The format of the model weights to load.\n\n'
        '* "auto" will try to load the weights in the safetensors format '
        'and fall back to the pytorch bin format if safetensors format '
        'is not available.\n'
        '* "pt" will load the weights in the pytorch bin format.\n'
        '* "safetensors" will load the weights in the safetensors format.\n'
        '* "npcache" will load the weights in pytorch format and store '
        'a numpy cache to speed up the loading.\n'
        '* "dummy" will initialize the weights with random values, '
        'which is mainly for profiling.\n'
        '* "tensorizer" will load the weights using tensorizer from '
        'CoreWeave. See the Tensorize vLLM Model script in the Examples '
        'section for more information.\n'
        '* "bitsandbytes" will load the weights using bitsandbytes '
        'quantization.\n')
    args = parser.parse_args()

    if args.tokenizer is None:
        args.tokenizer = args.model

    # Validate with parser.error (exits with a usage message) rather than
    # assert, which is stripped when Python runs with -O.
    if args.dataset is None:
        if args.input_len is None or args.output_len is None:
            parser.error("--input-len and --output-len are required when "
                         "--dataset is not given.")
    elif args.input_len is not None:
        parser.error("--input-len cannot be used together with --dataset.")

    if args.backend == "vllm":
        if args.hf_max_batch_size is not None:
            raise ValueError("HF max batch size is only for HF backend.")
    elif args.backend == "hf":
        if args.hf_max_batch_size is None:
            raise ValueError("HF max batch size is required for HF backend.")
        if args.quantization is not None:
            raise ValueError("Quantization is only for vLLM backend.")
    elif args.backend == "mii":
        if args.dtype != "auto":
            raise ValueError("dtype must be auto for MII backend.")
        if args.n != 1:
            raise ValueError("n must be 1 for MII backend.")
        if args.use_beam_search:
            raise ValueError("Beam search is not supported for MII backend.")
        if args.quantization is not None:
            raise ValueError("Quantization is only for vLLM backend.")
        if args.hf_max_batch_size is not None:
            raise ValueError("HF max batch size is only for HF backend.")
        if args.tokenizer != args.model:
            raise ValueError("Tokenizer must be the same as the model for MII "
                             "backend.")
    main(args)