benchmark_throughput.py 23.7 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
"""Benchmark offline inference throughput."""
3
import argparse
4
import dataclasses
5
import json
6
import os
7
8
import random
import time
9
import warnings
10
from typing import Any, Optional, Union
11

12
import torch
13
import uvloop
14
15
16
from benchmark_dataset import (BurstGPTDataset, HuggingFaceDataset,
                               RandomDataset, SampleRequest, ShareGPTDataset,
                               SonnetDataset, VisionArenaDataset)
17
from benchmark_utils import convert_to_pytorch_benchmark_format, write_to_json
18
from tqdm import tqdm
19
20
from transformers import (AutoModelForCausalLM, AutoTokenizer,
                          PreTrainedTokenizerBase)
21

22
from vllm.engine.arg_utils import AsyncEngineArgs, EngineArgs
23
24
from vllm.entrypoints.openai.api_server import (
    build_async_engine_client_from_engine_args)
25
from vllm.inputs import TextPrompt, TokensPrompt
26
from vllm.lora.request import LoRARequest
27
from vllm.outputs import RequestOutput
28
from vllm.sampling_params import BeamSearchParams
29
from vllm.utils import FlexibleArgumentParser, merge_async_iterators
30

31

Woosuk Kwon's avatar
Woosuk Kwon committed
32
def run_vllm(
    requests: list[SampleRequest],
    n: int,
    engine_args: EngineArgs,
    disable_detokenize: bool = False,
) -> tuple[float, Optional[list[RequestOutput]]]:
    """Benchmark offline generation with the synchronous ``LLM`` class.

    Args:
        requests: Sampled benchmark requests. A request's ``prompt`` is either
            a text prompt or a dict carrying ``"prompt_token_ids"``.
        n: Number of sequences generated per prompt (used as the beam width
            on the beam-search path).
        engine_args: Engine configuration used to construct the ``LLM``.
        disable_detokenize: If True, outputs are not detokenized, so
            detokenization time is excluded from the measurement.

    Returns:
        ``(elapsed_seconds, outputs)``. ``outputs`` is the list returned by
        ``LLM.generate`` on the default path and ``None`` on the beam-search
        path.

    Raises:
        AssertionError: If some request's ``prompt_len + expected_output_len``
            exceeds the model's ``max_model_len`` (or, on the beam-search
            path, if LoRA is enabled or output lengths differ).
    """
    from vllm import LLM, SamplingParams
    llm = LLM(**dataclasses.asdict(engine_args))
    assert all(
        llm.llm_engine.model_config.max_model_len >= (
            request.prompt_len + request.expected_output_len)
        for request in requests), (
            "Please ensure that max_model_len is greater than the sum of"
            " prompt_len and expected_output_len for all requests.")
    # Add the requests to the engine.
    prompts: list[Union[TextPrompt, TokensPrompt]] = []
    sampling_params: list[SamplingParams] = []
    for request in requests:
        # Pre-tokenized prompts arrive as a dict with "prompt_token_ids".
        # Check the type first: on a plain str, `in` is a substring test and
        # would misroute a text prompt that happens to contain that word.
        if (isinstance(request.prompt, dict)
                and "prompt_token_ids" in request.prompt):
            prompts.append(
                TokensPrompt(
                    prompt_token_ids=request.prompt["prompt_token_ids"],
                    multi_modal_data=request.multi_modal_data))
        else:
            prompts.append(
                TextPrompt(prompt=request.prompt,
                           multi_modal_data=request.multi_modal_data))
        sampling_params.append(
            SamplingParams(
                n=n,
                temperature=1.0,
                top_p=1.0,
                ignore_eos=True,
                max_tokens=request.expected_output_len,
                detokenize=not disable_detokenize,
            ))
    lora_requests: Optional[list[LoRARequest]] = None
    if engine_args.enable_lora:
        lora_requests = [request.lora_request for request in requests]

    # Beam search is hard-wired off; the else-branch benchmarks
    # ``LLM.beam_search`` instead of ``LLM.generate`` when enabled.
    use_beam_search = False

    outputs = None
    if not use_beam_search:
        start = time.perf_counter()
        outputs = llm.generate(prompts,
                               sampling_params,
                               lora_request=lora_requests,
                               use_tqdm=True)
        end = time.perf_counter()
    else:
        assert lora_requests is None, "BeamSearch API does not support LoRA"
        prompts = [request.prompt for request in requests]
        # output_len should be the same for all requests.
        # SampleRequest is accessed by attribute (it is not subscriptable).
        output_len = requests[0].expected_output_len
        for request in requests:
            assert request.expected_output_len == output_len
        start = time.perf_counter()
        llm.beam_search(
            prompts,
            BeamSearchParams(
                beam_width=n,
                max_tokens=output_len,
                ignore_eos=True,
            ))
        end = time.perf_counter()
    return end - start, outputs


def run_vllm_chat(
        requests: list[SampleRequest],
        n: int,
        engine_args: EngineArgs,
        disable_detokenize: bool = False) -> tuple[float, list[RequestOutput]]:
    """Time ``LLM.chat`` over every sampled request.

    Intended ONLY for multimodal models: ``LLM.chat`` applies the model's chat
    formatting and handles multimodal message content. Text-only models
    should be benchmarked with ``run_vllm()`` instead.

    Args:
        requests: Sampled benchmark requests; each ``prompt`` is passed to
            ``LLM.chat`` unchanged.
        n: Number of sequences generated per prompt.
        engine_args: Engine configuration used to construct the ``LLM``.
        disable_detokenize: If True, skip detokenizing outputs.

    Returns:
        ``(elapsed_seconds, outputs)`` where ``outputs`` is what ``LLM.chat``
        returned.
    """
    from vllm import LLM, SamplingParams
    llm = LLM(**dataclasses.asdict(engine_args))

    assert all(
        llm.llm_engine.model_config.max_model_len >= (
            req.prompt_len + req.expected_output_len)
        for req in requests), (
            "Please ensure that max_model_len is greater than the sum of "
            "prompt_len and expected_output_len for all requests.")

    prompts = [req.prompt for req in requests]
    sampling_params: list[SamplingParams] = [
        SamplingParams(
            n=n,
            temperature=1.0,
            top_p=1.0,
            ignore_eos=True,
            max_tokens=req.expected_output_len,
            detokenize=not disable_detokenize,
        ) for req in requests
    ]

    t_begin = time.perf_counter()
    outputs = llm.chat(prompts, sampling_params, use_tqdm=True)
    t_finish = time.perf_counter()
    return t_finish - t_begin, outputs
135
136


137
async def run_vllm_async(
    requests: list[SampleRequest],
    n: int,
    engine_args: AsyncEngineArgs,
    disable_frontend_multiprocessing: bool = False,
    disable_detokenize: bool = False,
) -> float:
    """Benchmark generation through the async engine client.

    All requests are submitted up front, then their output streams are
    drained together; only wall-clock time is measured.

    Args:
        requests: Sampled benchmark requests. A request's ``prompt`` is either
            a text prompt or a dict carrying ``"prompt_token_ids"``.
        n: Number of sequences generated per prompt.
        engine_args: Async engine configuration.
        disable_frontend_multiprocessing: Forwarded to
            ``build_async_engine_client_from_engine_args``.
        disable_detokenize: If True, outputs are not detokenized.

    Returns:
        Elapsed seconds from first submission until every stream finished.

    Raises:
        AssertionError: If some request's ``prompt_len + expected_output_len``
            exceeds the model's ``max_model_len``.
    """
    from vllm import SamplingParams

    async with build_async_engine_client_from_engine_args(
            engine_args, disable_frontend_multiprocessing) as llm:
        assert all(
            llm.model_config.max_model_len >= (request.prompt_len +
                                               request.expected_output_len)
            for request in requests), (
                "Please ensure that max_model_len is greater than the sum of"
                " prompt_len and expected_output_len for all requests.")

        # Add the requests to the engine.
        prompts: list[Union[TextPrompt, TokensPrompt]] = []
        sampling_params: list[SamplingParams] = []
        lora_requests: list[Optional[LoRARequest]] = []
        for request in requests:
            # Only a dict can carry pre-tokenized ids; on a plain str the
            # `in` operator would be a substring test and misroute prompts.
            if (isinstance(request.prompt, dict)
                    and "prompt_token_ids" in request.prompt):
                prompts.append(
                    TokensPrompt(
                        prompt_token_ids=request.prompt["prompt_token_ids"],
                        multi_modal_data=request.multi_modal_data))
            else:
                prompts.append(
                    TextPrompt(prompt=request.prompt,
                               multi_modal_data=request.multi_modal_data))
            sampling_params.append(
                SamplingParams(
                    n=n,
                    temperature=1.0,
                    top_p=1.0,
                    ignore_eos=True,
                    max_tokens=request.expected_output_len,
                    detokenize=not disable_detokenize,
                ))
            lora_requests.append(request.lora_request)

        generators = []
        start = time.perf_counter()
        for i, (prompt, sp,
                lr) in enumerate(zip(prompts, sampling_params, lora_requests)):
            generator = llm.generate(prompt,
                                     sp,
                                     lora_request=lr,
                                     request_id=f"test{i}")
            generators.append(generator)
        all_gens = merge_async_iterators(*generators)
        # Drain every stream; the yielded items are not needed.
        async for _ in all_gens:
            pass
        end = time.perf_counter()
        return end - start


193
def run_hf(
    requests: list[SampleRequest],
    model: str,
    tokenizer: PreTrainedTokenizerBase,
    n: int,
    max_batch_size: int,
    trust_remote_code: bool,
    disable_detokenize: bool = False,
) -> float:
    """Benchmark generation throughput with a Hugging Face transformers model.

    Requests are packed greedily into batches of at most ``max_batch_size``
    prompts. A batch is also closed once adding the next request would push
    (longest prompt + longest output) in the batch past 2048 tokens.

    Args:
        requests: Sampled benchmark requests.
        model: Model name or path passed to ``from_pretrained``.
        tokenizer: Tokenizer used to encode prompts (and decode outputs).
        n: ``num_return_sequences`` for each prompt.
        max_batch_size: Maximum number of prompts per ``generate`` call.
        trust_remote_code: Forwarded to ``from_pretrained``.
        disable_detokenize: If True, skip ``batch_decode`` so decoding time
            is not measured.

    Returns:
        Wall-clock seconds spent generating all batches (model loading and
        ``.cuda()`` transfer are excluded).
    """
    llm = AutoModelForCausalLM.from_pretrained(
        model, torch_dtype=torch.float16, trust_remote_code=trust_remote_code)
    if llm.config.model_type == "llama":
        # To enable padding in the HF backend.
        tokenizer.pad_token = tokenizer.eos_token
    llm = llm.cuda()

    pbar = tqdm(total=len(requests))
    start = time.perf_counter()
    batch: list[str] = []
    max_prompt_len = 0
    max_output_len = 0
    last_index = len(requests) - 1
    for i, request in enumerate(requests):
        # SampleRequest is a dataclass: read fields by name. Tuple-unpacking
        # it raises TypeError because dataclass instances are not iterable.
        batch.append(request.prompt)
        max_prompt_len = max(max_prompt_len, request.prompt_len)
        max_output_len = max(max_output_len, request.expected_output_len)
        if len(batch) < max_batch_size and i != last_index:
            # Check if we can add more requests to the batch.
            next_request = requests[i + 1]
            padded_len = (max(max_prompt_len, next_request.prompt_len) +
                          max(max_output_len,
                              next_request.expected_output_len))
            if padded_len <= 2048:
                # We can add more requests to the batch.
                continue

        # Generate the sequences.
        input_ids = tokenizer(batch, return_tensors="pt",
                              padding=True).input_ids
        llm_outputs = llm.generate(
            input_ids=input_ids.cuda(),
            do_sample=True,
            num_return_sequences=n,
            temperature=1.0,
            top_p=1.0,
            use_cache=True,
            max_new_tokens=max_output_len,
        )
        if not disable_detokenize:
            # Include the decoding time.
            tokenizer.batch_decode(llm_outputs, skip_special_tokens=True)
        pbar.update(len(batch))

        # Clear the batch.
        batch = []
        max_prompt_len = 0
        max_output_len = 0
    end = time.perf_counter()
    pbar.close()
    return end - start


253
def run_mii(
    requests: list[SampleRequest],
    model: str,
    tensor_parallel_size: int,
    output_len: int,
) -> float:
    """Benchmark generation with a DeepSpeed-MII server.

    Starts an MII deployment for ``model``, times a single ``generate`` call
    over all prompts, then terminates the server.

    Args:
        requests: Sampled benchmark requests; only ``prompt`` is used.
        model: Model name or path to serve.
        tensor_parallel_size: Forwarded as MII's ``tensor_parallel``.
        output_len: ``max_new_tokens`` applied to every prompt.

    Returns:
        Wall-clock seconds spent in ``generate``.
    """
    from mii import client, serve
    prompts = [request.prompt for request in requests]
    llm = serve(model, tensor_parallel=tensor_parallel_size)
    try:
        start = time.perf_counter()
        llm.generate(prompts, max_new_tokens=output_len)
        end = time.perf_counter()
    finally:
        # Always shut the server down, even if generation fails, so a failed
        # run does not leave a serving process behind. A distinct name keeps
        # the imported `client` factory from being shadowed.
        mii_client = client(model)
        mii_client.terminate_server()
    return end - start


271
def save_to_pytorch_benchmark_format(args: argparse.Namespace,
                                     results: dict[str, Any]) -> None:
    """Write ``results`` in PyTorch benchmark format next to the JSON output.

    Throughput figures are reported as metrics. Elapsed time, request count
    and token count are attached as extra info. When the converter returns
    records, they are written to ``<output_json stem>.pytorch.json``.
    """
    metric_keys = ("requests_per_second", "tokens_per_second")
    info_keys = ("elapsed_time", "num_requests", "total_num_tokens")
    metrics = {key: [results[key]] for key in metric_keys}
    extra_info = {key: results[key] for key in info_keys}
    pt_records = convert_to_pytorch_benchmark_format(
        args=args,
        metrics=metrics,
        extra_info=extra_info,
    )
    if not pt_records:
        return
    # Don't use json suffix here as we don't want CI to pick it up
    stem, _suffix = os.path.splitext(args.output_json)
    write_to_json(f"{stem}.pytorch.json", pt_records)
287
288


289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
def get_requests(args, tokenizer):
    """Build the dataset selected by ``args`` and sample benchmark requests.

    Args:
        args: Parsed CLI namespace (dataset name/path, seed, lengths, LoRA
            and dataset-specific options).
        tokenizer: Tokenizer handed to the dataset sampler.

    Returns:
        Whatever the chosen dataset's ``sample()`` returns (the sampled
        requests).

    Raises:
        ValueError: For an unknown dataset name, or an ``hf`` dataset with a
            backend other than ``vllm-chat``.
        AssertionError: For the sonnet dataset when the tokenizer has no
            chat template.
    """
    # Common parameters for all dataset types.
    common_kwargs = {
        "dataset_path": args.dataset_path,
        "random_seed": args.seed,
    }
    sample_kwargs = {
        "tokenizer": tokenizer,
        "lora_path": args.lora_path,
        "max_loras": args.max_loras,
        "num_requests": args.num_prompts,
        "input_len": args.input_len,
        "output_len": args.output_len,
    }
    if args.dataset_path is None or args.dataset_name == "random":
        sample_kwargs["range_ratio"] = args.random_range_ratio
        sample_kwargs["prefix_len"] = args.prefix_len
        dataset_cls = RandomDataset
    elif args.dataset_name == "sharegpt":
        dataset_cls = ShareGPTDataset
        if args.backend == "vllm-chat":
            sample_kwargs["enable_multimodal_chat"] = True
    elif args.dataset_name == "sonnet":
        # Recent transformers releases removed `default_chat_template`;
        # look it up defensively so a missing template hits the intended
        # assertion instead of an AttributeError.
        assert (getattr(tokenizer, "chat_template", None)
                or getattr(tokenizer, "default_chat_template", None)), (
                    "Tokenizer/model must have chat template for sonnet "
                    "dataset.")
        dataset_cls = SonnetDataset
        sample_kwargs["prefix_len"] = args.prefix_len
        sample_kwargs["return_prompt_formatted"] = True
    elif args.dataset_name == "burstgpt":
        dataset_cls = BurstGPTDataset
    elif args.dataset_name == "hf":
        if args.backend != "vllm-chat":
            raise ValueError(
                "hf datasets only are supported by vllm-chat backend")
        # Choose between VisionArenaDataset and HuggingFaceDataset based on
        # provided parameters.
        dataset_cls = (VisionArenaDataset if args.dataset_path
                       == VisionArenaDataset.VISION_ARENA_DATASET_PATH
                       and args.hf_subset is None else HuggingFaceDataset)
        common_kwargs['dataset_subset'] = args.hf_subset
        common_kwargs['dataset_split'] = args.hf_split
        sample_kwargs["enable_multimodal_chat"] = True
    else:
        raise ValueError(f"Unknown dataset name: {args.dataset_name}")
    # Remove None values so each dataset's sample() falls back to its own
    # defaults.
    sample_kwargs = {k: v for k, v in sample_kwargs.items() if v is not None}
    return dataset_cls(**common_kwargs).sample(**sample_kwargs)


339
def main(args: argparse.Namespace):
    """Run the throughput benchmark described by ``args`` and report it.

    Samples requests and dispatches them to the selected backend. It then
    prints request/token throughput and, when ``args.output_json`` is set,
    writes the results as JSON plus a PyTorch-benchmark-format copy.

    Raises:
        ValueError: If ``args.backend`` is not a known backend.
    """
    if args.seed is None:
        args.seed = 0
    print(args)
    random.seed(args.seed)
    # Sample the requests.
    tokenizer = AutoTokenizer.from_pretrained(
        args.tokenizer, trust_remote_code=args.trust_remote_code)
    requests = get_requests(args, tokenizer)
    is_multi_modal = any(request.multi_modal_data is not None
                         for request in requests)
    request_outputs: Optional[list[RequestOutput]] = None
    if args.backend == "vllm":
        if args.async_engine:
            elapsed_time = uvloop.run(
                run_vllm_async(
                    requests,
                    args.n,
                    AsyncEngineArgs.from_cli_args(args),
                    args.disable_frontend_multiprocessing,
                    args.disable_detokenize,
                ))
        else:
            elapsed_time, request_outputs = run_vllm(
                requests, args.n, EngineArgs.from_cli_args(args),
                args.disable_detokenize)
    elif args.backend == "hf":
        assert args.tensor_parallel_size == 1
        elapsed_time = run_hf(requests, args.model, tokenizer, args.n,
                              args.hf_max_batch_size, args.trust_remote_code,
                              args.disable_detokenize)
    elif args.backend == "mii":
        elapsed_time = run_mii(requests, args.model, args.tensor_parallel_size,
                               args.output_len)
    elif args.backend == "vllm-chat":
        elapsed_time, request_outputs = run_vllm_chat(
            requests, args.n, EngineArgs.from_cli_args(args),
            args.disable_detokenize)
    else:
        raise ValueError(f"Unknown backend: {args.backend}")

    if request_outputs:
        # Note: with the vllm and vllm-chat backends,
        # we have request_outputs, which we use to count tokens.
        total_prompt_tokens = 0
        total_output_tokens = 0
        for ro in request_outputs:
            if not isinstance(ro, RequestOutput):
                continue
            total_prompt_tokens += len(
                ro.prompt_token_ids) if ro.prompt_token_ids else 0
            total_output_tokens += sum(
                len(o.token_ids) for o in ro.outputs if o)
        total_num_tokens = total_prompt_tokens + total_output_tokens
    else:
        # No per-request outputs (async vLLM, hf, mii): estimate from the
        # sampled lengths. Each prompt is counted once, but every request
        # generates `args.n` sequences (SamplingParams(n=n) /
        # num_return_sequences=n), so scale the expected outputs by n.
        total_prompt_tokens = sum(r.prompt_len for r in requests)
        total_output_tokens = args.n * sum(r.expected_output_len
                                           for r in requests)
        total_num_tokens = total_prompt_tokens + total_output_tokens

    if is_multi_modal and args.backend != "vllm-chat":
        print("\033[91mWARNING\033[0m: Multi-modal request with "
              f"{args.backend} backend detected. The "
              "following metrics are not accurate because image tokens are not"
              " counted. See vllm-project/vllm/issues/9778 for details.")
        # TODO(vllm-project/vllm/issues/9778): Count multi-modal token length.
        # vllm-chat backend counts the image tokens now

    print(f"Throughput: {len(requests) / elapsed_time:.2f} requests/s, "
          f"{total_num_tokens / elapsed_time:.2f} total tokens/s, "
          f"{total_output_tokens / elapsed_time:.2f} output tokens/s")
    print(f"Total num prompt tokens:  {total_prompt_tokens}")
    print(f"Total num output tokens:  {total_output_tokens}")

    # Output JSON results if specified
    if args.output_json:
        results = {
            "elapsed_time": elapsed_time,
            "num_requests": len(requests),
            "total_num_tokens": total_num_tokens,
            "requests_per_second": len(requests) / elapsed_time,
            "tokens_per_second": total_num_tokens / elapsed_time,
        }
        with open(args.output_json, "w") as f:
            json.dump(results, f, indent=4)
        save_to_pytorch_benchmark_format(args, results)
425

426

427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
def validate_args(args):
    """
    Validate command-line arguments.

    Mutates ``args`` in place:
    - copies the deprecated ``--dataset`` into ``dataset_path``;
    - defaults ``tokenizer`` to ``model``;
    - switches ``dataset_name`` to ``'random'`` when no dataset path is given.

    It warns about options that the chosen dataset ignores.

    Raises:
        ValueError: On unsupported backend/option combinations.
    """

    # === Deprecation and Defaulting ===
    if args.dataset is not None:
        warnings.warn(
            "The '--dataset' argument will be deprecated in the next release. "
            "Please use '--dataset-name' and '--dataset-path' instead.",
            stacklevel=2)
        args.dataset_path = args.dataset

    if not getattr(args, "tokenizer", None):
        args.tokenizer = args.model

    # === Backend Validation ===
    valid_backends = {"vllm", "hf", "mii", "vllm-chat"}
    if args.backend not in valid_backends:
        raise ValueError(f"Unsupported backend: {args.backend}")

    # === Dataset Configuration ===
    if not args.dataset and not args.dataset_path:
        print(
            "When dataset path is not set, it will default to random dataset")
        args.dataset_name = 'random'
        if args.input_len is None:
            raise ValueError("input_len must be provided for a random dataset")

    # === Dataset Name Specific Checks ===
    # --hf-subset and --hf-split: only used
    # when dataset_name is 'hf'
    # (Messages use implicit string concatenation; a backslash continuation
    # inside the literal would embed the next line's indentation.)
    if args.dataset_name != "hf" and (
            getattr(args, "hf_subset", None) is not None
            or getattr(args, "hf_split", None) is not None):
        warnings.warn(
            "--hf-subset and --hf-split will be ignored "
            "since --dataset-name is not 'hf'.",
            stacklevel=2)
    elif args.dataset_name == "hf" and args.backend != "vllm-chat":
        raise ValueError(
            "When --dataset-name is 'hf', backend must be 'vllm-chat'")

    # --random-range-ratio: only used when dataset_name is 'random'
    if args.dataset_name != 'random' and args.random_range_ratio is not None:
        warnings.warn(
            "--random-range-ratio will be ignored since "
            "--dataset-name is not 'random'.",
            stacklevel=2)

    # --prefix-len: only used when dataset_name is 'random', 'sonnet', or not
    # set.
    if args.dataset_name not in {"random", "sonnet", None
                                 } and args.prefix_len is not None:
        warnings.warn(
            "--prefix-len will be ignored since --dataset-name "
            "is not 'random', 'sonnet', or not set.",
            stacklevel=2)

    # === LoRA Settings ===
    if getattr(args, "enable_lora", False) and args.backend != "vllm":
        raise ValueError(
            "LoRA benchmarking is only supported for vLLM backend")
    if getattr(args, "enable_lora", False) and args.lora_path is None:
        raise ValueError("LoRA path must be provided when enable_lora is True")

    # === Backend-specific Validations ===
    if args.backend == "hf" and args.hf_max_batch_size is None:
        raise ValueError("HF max batch size is required for HF backend")
    if args.backend != "hf" and args.hf_max_batch_size is not None:
        raise ValueError("HF max batch size is only for HF backend.")

    if args.backend in {"hf", "mii"} and getattr(args, "quantization",
                                                 None) is not None:
        raise ValueError("Quantization is only for vLLM backend.")

    if args.backend == "mii" and args.dtype != "auto":
        raise ValueError("dtype must be auto for MII backend.")
    if args.backend == "mii" and args.n != 1:
        raise ValueError("n must be 1 for MII backend.")
    if args.backend == "mii" and args.tokenizer != args.model:
        raise ValueError(
            "Tokenizer must be the same as the model for MII backend.")


509
if __name__ == "__main__":
    # Build the CLI: benchmark-specific options first, then every vLLM
    # engine option via AsyncEngineArgs.add_cli_args.
    parser = FlexibleArgumentParser(description="Benchmark the throughput.")
    parser.add_argument("--backend",
                        type=str,
                        choices=["vllm", "hf", "mii", "vllm-chat"],
                        default="vllm")
    parser.add_argument(
        "--dataset-name",
        type=str,
        choices=["sharegpt", "random", "sonnet", "burstgpt", "hf"],
        help="Name of the dataset to benchmark on.",
        default="sharegpt")
    parser.add_argument(
        "--dataset",
        type=str,
        default=None,
        help="Path to the ShareGPT dataset, will be deprecated in "
        "the next release. The dataset is expected to "
        "be a json in form of list[dict[..., conversations: "
        "list[dict[..., value: <prompt_or_response>]]]]")
    parser.add_argument("--dataset-path",
                        type=str,
                        default=None,
                        help="Path to the dataset")
    parser.add_argument("--input-len",
                        type=int,
                        default=None,
                        help="Input prompt length for each request")
    parser.add_argument("--output-len",
                        type=int,
                        default=None,
                        help="Output length for each request. Overrides the "
                        "output length from the dataset.")
    parser.add_argument("--n",
                        type=int,
                        default=1,
                        help="Number of generated sequences per prompt.")
    parser.add_argument("--num-prompts",
                        type=int,
                        default=1000,
                        help="Number of prompts to process.")
    parser.add_argument("--hf-max-batch-size",
                        type=int,
                        default=None,
                        help="Maximum batch size for HF backend.")
    parser.add_argument(
        '--output-json',
        type=str,
        default=None,
        help='Path to save the throughput results in JSON format.')
    parser.add_argument("--async-engine",
                        action='store_true',
                        default=False,
                        help="Use vLLM async engine rather than LLM class.")
    parser.add_argument("--disable-frontend-multiprocessing",
                        action='store_true',
                        default=False,
                        help="Disable decoupled async engine frontend.")
    parser.add_argument(
        "--disable-detokenize",
        action="store_true",
        help=("Do not detokenize the response (i.e. do not include "
              "detokenization time in the measurement)"))
    # LoRA
    parser.add_argument(
        "--lora-path",
        type=str,
        default=None,
        help="Path to the lora adapters to use. This can be an absolute path, "
        "a relative path, or a Hugging Face model identifier.")
    parser.add_argument("--prefix-len",
                        type=int,
                        default=None,
                        help="Number of prefix tokens per request. "
                        "This is for the RandomDataset and SonnetDataset")
    # random dataset
    parser.add_argument(
        "--random-range-ratio",
        type=float,
        default=None,
        help="Range of sampled ratio of input/output length, "
        "used only for RandomDataset.",
    )
    # hf dataset
    parser.add_argument("--hf-subset",
                        type=str,
                        default=None,
                        help="Subset of the HF dataset.")
    parser.add_argument("--hf-split",
                        type=str,
                        default=None,
                        help="Split of the HF dataset.")

    parser = AsyncEngineArgs.add_cli_args(parser)
    args = parser.parse_args()
    if args.tokenizer is None:
        args.tokenizer = args.model
    validate_args(args)
    main(args)