# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Benchmark offline inference throughput."""

import argparse
import dataclasses
import json
import os
import random
import time
import warnings
12
from typing import Any
13
14
15
16

import torch
import uvloop
from tqdm import tqdm
17
from transformers import AutoModelForCausalLM, PreTrainedTokenizerBase
18
19
20
21
22
23

from vllm.benchmarks.datasets import (
    AIMODataset,
    BurstGPTDataset,
    ConversationDataset,
    InstructCoderDataset,
24
    MultiModalConversationDataset,
25
26
    PrefixRepetitionRandomDataset,
    RandomDataset,
27
28
    RandomDatasetForReranking,
    RandomMultiModalDataset,
29
30
31
32
    SampleRequest,
    ShareGPTDataset,
    SonnetDataset,
    VisionArenaDataset,
33
34
    add_random_dataset_base_args,
    add_random_multimodal_dataset_args,
35
36
)
from vllm.benchmarks.lib.utils import convert_to_pytorch_benchmark_format, write_to_json
37
38
39
40
41
from vllm.engine.arg_utils import AsyncEngineArgs, EngineArgs
from vllm.inputs import TextPrompt, TokensPrompt
from vllm.lora.request import LoRARequest
from vllm.outputs import RequestOutput
from vllm.sampling_params import BeamSearchParams
42
from vllm.tokenizers import TokenizerLike, get_tokenizer
43
from vllm.utils.async_utils import merge_async_iterators
44
45
46
47
48
49


def run_vllm(
    requests: list[SampleRequest],
    n: int,
    engine_args: EngineArgs,
    do_profile: bool,
    disable_detokenize: bool = False,
) -> tuple[float, list[RequestOutput] | None]:
    """Benchmark offline generation through the synchronous ``LLM`` API.

    Args:
        requests: Sampled benchmark requests, submitted as one batch.
        n: Sequences generated per prompt (used as the beam width on the
            beam-search path).
        engine_args: Configuration unpacked into the ``LLM`` constructor.
        do_profile: If true, the timed region is bracketed by
            ``llm.start_profile()`` / ``llm.stop_profile()``.
        disable_detokenize: Forwarded as ``detokenize=False`` to every
            ``SamplingParams`` so detokenization is excluded from timing.

    Returns:
        ``(elapsed_seconds, outputs)`` where ``outputs`` is the result of
        ``llm.generate``, or ``None`` on the beam-search path.
    """
    from vllm import LLM, SamplingParams

    llm = LLM(**dataclasses.asdict(engine_args))
    assert all(
        llm.llm_engine.model_config.max_model_len
        >= (req.prompt_len + req.expected_output_len)
        for req in requests
    ), (
        "Please ensure that max_model_len is greater than the sum of"
        " prompt_len and expected_output_len for all requests."
    )

    # Build one engine prompt and one SamplingParams per request.
    prompts: list[TextPrompt | TokensPrompt] = []
    sampling_params: list[SamplingParams] = []
    for req in requests:
        # A prompt containing "prompt_token_ids" is treated as pre-tokenized
        # input; anything else is wrapped as text.
        if "prompt_token_ids" in req.prompt:
            prompt = TokensPrompt(prompt_token_ids=req.prompt["prompt_token_ids"])
        else:
            prompt = TextPrompt(prompt=req.prompt)
        if req.multi_modal_data:
            assert isinstance(req.multi_modal_data, dict)
            prompt["multi_modal_data"] = req.multi_modal_data
        prompts.append(prompt)

        sampling_params.append(
            SamplingParams(
                n=n,
                temperature=1.0,
                top_p=1.0,
                ignore_eos=True,
                max_tokens=req.expected_output_len,
                detokenize=not disable_detokenize,
            )
        )

    lora_requests: list[LoRARequest] | None = (
        [req.lora_request for req in requests] if engine_args.enable_lora else None
    )

    # Beam search is hard-wired off; its branch is kept for reference only.
    use_beam_search = False
    outputs = None
    if use_beam_search:
        assert lora_requests is None, "BeamSearch API does not support LoRA"
        prompts = [req.prompt for req in requests]
        # output_len should be the same for all requests.
        output_len = requests[0].expected_output_len
        for req in requests:
            assert req.expected_output_len == output_len
        start = time.perf_counter()
        if do_profile:
            llm.start_profile()
        llm.beam_search(
            prompts,
            BeamSearchParams(
                beam_width=n,
                max_tokens=output_len,
                ignore_eos=True,
            ),
        )
        if do_profile:
            llm.stop_profile()
        end = time.perf_counter()
    else:
        start = time.perf_counter()
        if do_profile:
            llm.start_profile()
        outputs = llm.generate(
            prompts, sampling_params, lora_request=lora_requests, use_tqdm=True
        )
        if do_profile:
            llm.stop_profile()
        end = time.perf_counter()
    return end - start, outputs


def run_vllm_chat(
    requests: list[SampleRequest],
    n: int,
    engine_args: EngineArgs,
    do_profile: bool,
    disable_detokenize: bool = False,
) -> tuple[float, list[RequestOutput]]:
    """
    Run the vLLM chat benchmark through ``LLM.chat``.

    Use this ONLY for multimodal models: it routes inputs through chat
    formatting, which is what correctly handles multimodal data. For
    non-multimodal models, use run_vllm() instead.

    Each request's ``prompt`` is handed to ``LLM.chat`` unchanged. Only the
    chat call (plus the optional profiler start/stop) is timed.

    Returns:
        ``(elapsed_seconds, outputs)`` where ``outputs`` is whatever
        ``LLM.chat`` returned.
    """
    from vllm import LLM, SamplingParams

    llm = LLM(**dataclasses.asdict(engine_args))

    assert all(
        llm.llm_engine.model_config.max_model_len
        >= (req.prompt_len + req.expected_output_len)
        for req in requests
    ), (
        "Please ensure that max_model_len is greater than the sum of "
        "prompt_len and expected_output_len for all requests."
    )

    prompts = [req.prompt for req in requests]
    sampling_params: list[SamplingParams] = [
        SamplingParams(
            n=n,
            temperature=1.0,
            top_p=1.0,
            ignore_eos=True,
            max_tokens=req.expected_output_len,
            detokenize=not disable_detokenize,
        )
        for req in requests
    ]

    start = time.perf_counter()
    if do_profile:
        llm.start_profile()
    outputs = llm.chat(prompts, sampling_params, use_tqdm=True)
    if do_profile:
        llm.stop_profile()
    elapsed = time.perf_counter() - start
    return elapsed, outputs


async def run_vllm_async(
    requests: list[SampleRequest],
    n: int,
    engine_args: AsyncEngineArgs,
    do_profile: bool,
    disable_frontend_multiprocessing: bool = False,
    disable_detokenize: bool = False,
) -> float:
    """Benchmark generation through the async engine client.

    All requests are submitted up front and their merged output streams are
    drained. Only submission plus draining (and the optional profiler
    start/stop) falls inside the timed region.

    Args:
        requests: Sampled benchmark requests.
        n: Sequences generated per prompt.
        engine_args: Configuration for the async engine client.
        do_profile: If true, awaits ``start_profile``/``stop_profile``
            around the timed region.
        disable_frontend_multiprocessing: Forwarded to the client builder.
        disable_detokenize: Forwarded as ``detokenize=False`` to every
            ``SamplingParams``.

    Returns:
        Elapsed wall-clock seconds.
    """
    from vllm import SamplingParams
    from vllm.entrypoints.openai.api_server import (
        build_async_engine_client_from_engine_args,
    )

    async with build_async_engine_client_from_engine_args(
        engine_args,
        disable_frontend_multiprocessing=disable_frontend_multiprocessing,
    ) as llm:
        model_config = llm.model_config
        assert all(
            model_config.max_model_len
            >= (req.prompt_len + req.expected_output_len)
            for req in requests
        ), (
            "Please ensure that max_model_len is greater than the sum of"
            " prompt_len and expected_output_len for all requests."
        )

        # Prepare per-request prompt, sampling params and LoRA request.
        prompts: list[TextPrompt | TokensPrompt] = []
        sampling_params: list[SamplingParams] = []
        lora_requests: list[LoRARequest | None] = []
        for req in requests:
            if "prompt_token_ids" in req.prompt:
                prompt = TokensPrompt(prompt_token_ids=req.prompt["prompt_token_ids"])
            else:
                prompt = TextPrompt(prompt=req.prompt)
            if req.multi_modal_data:
                assert isinstance(req.multi_modal_data, dict)
                prompt["multi_modal_data"] = req.multi_modal_data
            sampling_params.append(
                SamplingParams(
                    n=n,
                    temperature=1.0,
                    top_p=1.0,
                    ignore_eos=True,
                    max_tokens=req.expected_output_len,
                    detokenize=not disable_detokenize,
                )
            )
            prompts.append(prompt)
            lora_requests.append(req.lora_request)

        start = time.perf_counter()
        if do_profile:
            await llm.start_profile()
        generators = [
            llm.generate(p, sp, lora_request=lr, request_id=f"test{idx}")
            for idx, (p, sp, lr) in enumerate(
                zip(prompts, sampling_params, lora_requests)
            )
        ]
        # Drain every stream; the generated outputs themselves are discarded.
        async for _idx, _res in merge_async_iterators(*generators):
            pass
        if do_profile:
            await llm.stop_profile()
        end = time.perf_counter()
        return end - start


def run_hf(
    requests: list[SampleRequest],
    model: str,
    tokenizer: TokenizerLike,
    n: int,
    max_batch_size: int,
    trust_remote_code: bool,
    disable_detokenize: bool = False,
) -> float:
    """Benchmark generation with a Hugging Face ``transformers`` model.

    Requests are processed in order and greedily packed into batches. A
    batch holds at most ``max_batch_size`` prompts. It is closed early once
    adding the next request would push ``max(prompt_len) +
    max(expected_output_len)`` over 2048.

    Args:
        requests: Sampled benchmark requests; ``prompt`` is tokenized here.
        model: Model name/path passed to ``AutoModelForCausalLM``.
        tokenizer: Must be a HF ``PreTrainedTokenizerBase``.
        n: ``num_return_sequences`` per prompt.
        max_batch_size: Upper bound on prompts per batch.
        trust_remote_code: Forwarded to ``from_pretrained``.
        disable_detokenize: Skip the (timed) ``batch_decode`` of outputs.

    Returns:
        Elapsed wall-clock seconds for batching, generation and decoding.
    """
    assert isinstance(tokenizer, PreTrainedTokenizerBase), (
        "the hf backend only supports HF tokenizers"
    )
    llm = AutoModelForCausalLM.from_pretrained(
        model, dtype=torch.float16, trust_remote_code=trust_remote_code
    )
    if llm.config.model_type == "llama" or tokenizer.pad_token is None:
        # Batching requires padding. Reuse EOS as the pad token for llama
        # (as before), and for any tokenizer that defines no pad token:
        # otherwise `padding=True` below raises.
        tokenizer.pad_token = tokenizer.eos_token
    llm = llm.cuda()

    batch: list[str] = []
    max_prompt_len = 0
    max_output_len = 0
    # The context manager guarantees the progress bar is closed.
    with tqdm(total=len(requests)) as pbar:
        start = time.perf_counter()
        for i, request in enumerate(requests):
            # Add the prompt to the batch.
            batch.append(request.prompt)
            max_prompt_len = max(max_prompt_len, request.prompt_len)
            max_output_len = max(max_output_len, request.expected_output_len)
            if len(batch) < max_batch_size and i != len(requests) - 1:
                # Check if we can add more requests to the batch.
                next_request = requests[i + 1]
                if (
                    max(max_prompt_len, next_request.prompt_len)
                    + max(max_output_len, next_request.expected_output_len)
                ) <= 2048:
                    # We can add more requests to the batch.
                    continue

            # Generate the sequences. Pass the attention mask explicitly so
            # padded positions are ignored. transformers cannot infer it
            # from input_ids when pad_token == eos_token.
            encoded = tokenizer(batch, return_tensors="pt", padding=True)
            attention_mask = encoded.get("attention_mask")
            llm_outputs = llm.generate(
                input_ids=encoded.input_ids.cuda(),
                attention_mask=(
                    attention_mask.cuda() if attention_mask is not None else None
                ),
                do_sample=True,
                num_return_sequences=n,
                temperature=1.0,
                top_p=1.0,
                use_cache=True,
                max_new_tokens=max_output_len,
            )
            if not disable_detokenize:
                # Include the decoding time.
                tokenizer.batch_decode(llm_outputs, skip_special_tokens=True)
            pbar.update(len(batch))

            # Clear the batch.
            batch = []
            max_prompt_len = 0
            max_output_len = 0
        end = time.perf_counter()
    return end - start


def save_to_pytorch_benchmark_format(
    args: argparse.Namespace, results: dict[str, Any]
) -> None:
    """Write throughput results as PyTorch-benchmark-format records.

    The output lands next to ``args.output_json`` with its extension replaced
    by ``.pytorch.json``. Nothing is written when the converter yields no
    records.
    """
    metric_names = ("requests_per_second", "tokens_per_second")
    info_keys = ("elapsed_time", "num_requests", "total_num_tokens")
    pt_records = convert_to_pytorch_benchmark_format(
        args=args,
        metrics={name: [results[name]] for name in metric_names},
        extra_info={key: results[key] for key in info_keys},
    )
    if not pt_records:
        return
    # Don't use json suffix here as we don't want CI to pick it up
    stem, _ext = os.path.splitext(args.output_json)
    write_to_json(f"{stem}.pytorch.json", pt_records)


def _random_arg_or_fallback(args, name):
    """Return ``args.random_<name>`` if it is set, else ``args.<name>``.

    The random-family datasets accept both ``--random-<name>`` and the
    generic ``--<name>`` flag. The ``random_``-prefixed value wins. A
    missing attribute is treated as ``None``.
    """
    value = getattr(args, f"random_{name}", None)
    return value if value is not None else getattr(args, name, None)


def get_requests(args, tokenizer):
    """Sample benchmark requests for the dataset selected on the command line.

    Args:
        args: Parsed CLI namespace (see ``add_cli_args``).
        tokenizer: Tokenizer forwarded to the dataset's ``sample`` call.

    Returns:
        The sampled requests, filtered to this process's data-parallel rank.

    Raises:
        ValueError: If ``args.dataset_name`` is unknown, or if
            ``--dataset-name hf`` is used with an unsupported dataset path.
        AssertionError: If the ``sonnet`` dataset is requested with a
            tokenizer that has no chat template.
    """
    # Common parameters for all dataset types.
    common_kwargs = {
        "dataset_path": args.dataset_path,
        "random_seed": args.seed,
    }
    sample_kwargs = {
        "tokenizer": tokenizer,
        "lora_path": args.lora_path,
        "max_loras": args.max_loras,
        "num_requests": args.num_prompts,
    }

    if args.dataset_name == "random" or (
        args.dataset_path is None
        and args.dataset_name not in {"prefix_repetition", "random-mm", "random-rerank"}
    ):
        dataset_cls = RandomDataset
        sample_kwargs["range_ratio"] = args.random_range_ratio
        # prefer random_* arguments, fall back to regular arguments
        for name in ("prefix_len", "input_len", "output_len"):
            sample_kwargs[name] = _random_arg_or_fallback(args, name)
    elif args.dataset_name == "sharegpt":
        dataset_cls = ShareGPTDataset
        if args.backend == "vllm-chat":
            sample_kwargs["enable_multimodal_chat"] = True
        if args.output_len is not None:
            sample_kwargs["output_len"] = args.output_len
    elif args.dataset_name == "sonnet":
        # `default_chat_template` was removed from recent transformers
        # tokenizers, so look both attributes up with getattr. That way a
        # missing template reports this message instead of AttributeError.
        chat_template = getattr(tokenizer, "chat_template", None) or getattr(
            tokenizer, "default_chat_template", None
        )
        assert chat_template, (
            "Tokenizer/model must have chat template for sonnet dataset."
        )
        dataset_cls = SonnetDataset
        sample_kwargs["prefix_len"] = args.prefix_len
        sample_kwargs["return_prompt_formatted"] = True
        if args.input_len is not None:
            sample_kwargs["input_len"] = args.input_len
        if args.output_len is not None:
            sample_kwargs["output_len"] = args.output_len
    elif args.dataset_name == "burstgpt":
        dataset_cls = BurstGPTDataset
    elif args.dataset_name == "hf":
        if args.output_len is not None:
            sample_kwargs["output_len"] = args.output_len
        if args.dataset_path in VisionArenaDataset.SUPPORTED_DATASET_PATHS:
            dataset_cls = VisionArenaDataset
            common_kwargs["dataset_subset"] = None
            common_kwargs["dataset_split"] = "train"
            sample_kwargs["enable_multimodal_chat"] = True
        elif args.dataset_path in InstructCoderDataset.SUPPORTED_DATASET_PATHS:
            dataset_cls = InstructCoderDataset
            common_kwargs["dataset_split"] = "train"
        elif args.dataset_path in MultiModalConversationDataset.SUPPORTED_DATASET_PATHS:
            dataset_cls = MultiModalConversationDataset
            common_kwargs["dataset_subset"] = args.hf_subset
            common_kwargs["dataset_split"] = args.hf_split
            sample_kwargs["enable_multimodal_chat"] = True
        elif args.dataset_path in ConversationDataset.SUPPORTED_DATASET_PATHS:
            dataset_cls = ConversationDataset
            common_kwargs["dataset_subset"] = args.hf_subset
            common_kwargs["dataset_split"] = args.hf_split
            sample_kwargs["enable_multimodal_chat"] = True
        elif args.dataset_path in AIMODataset.SUPPORTED_DATASET_PATHS:
            dataset_cls = AIMODataset
            common_kwargs["dataset_subset"] = None
            common_kwargs["dataset_split"] = "train"
        else:
            # Previously fell through with `dataset_cls` unbound
            # (UnboundLocalError) when validate_args had not run first.
            raise ValueError(f"{args.dataset_path} is not supported by hf dataset.")
    elif args.dataset_name == "prefix_repetition":
        dataset_cls = PrefixRepetitionRandomDataset
        sample_kwargs["prefix_len"] = args.prefix_repetition_prefix_len
        sample_kwargs["suffix_len"] = args.prefix_repetition_suffix_len
        sample_kwargs["num_prefixes"] = args.prefix_repetition_num_prefixes
        sample_kwargs["output_len"] = args.prefix_repetition_output_len
    elif args.dataset_name == "random-mm":
        dataset_cls = RandomMultiModalDataset
        # prefer random_* arguments, fall back to regular arguments
        sample_kwargs["input_len"] = _random_arg_or_fallback(args, "input_len")
        sample_kwargs["output_len"] = _random_arg_or_fallback(args, "output_len")
        sample_kwargs["prefix_len"] = _random_arg_or_fallback(args, "prefix_len")
        sample_kwargs["base_items_per_request"] = getattr(
            args, "random_mm_base_items_per_request", None
        )
        sample_kwargs["num_mm_items_range_ratio"] = getattr(
            args, "random_mm_num_mm_items_range_ratio", None
        )
        sample_kwargs["limit_mm_per_prompt"] = getattr(
            args, "random_mm_limit_mm_per_prompt", None
        )
        sample_kwargs["bucket_config"] = getattr(args, "random_mm_bucket_config", None)
        sample_kwargs["enable_multimodal_chat"] = True
        sample_kwargs["range_ratio"] = args.random_range_ratio
    elif args.dataset_name == "random-rerank":
        dataset_cls = RandomDatasetForReranking
        # prefer random_* arguments, fall back to regular arguments
        sample_kwargs["input_len"] = _random_arg_or_fallback(args, "input_len")
        sample_kwargs["output_len"] = _random_arg_or_fallback(args, "output_len")
        sample_kwargs["batchsize"] = getattr(args, "random_batch_size", 1)
        sample_kwargs["is_reranker"] = not getattr(args, "no_reranker", False)
        sample_kwargs["range_ratio"] = args.random_range_ratio
    else:
        raise ValueError(f"Unknown dataset name: {args.dataset_name}")

    # Remove None values so each dataset's own defaults apply.
    sample_kwargs = {k: v for k, v in sample_kwargs.items() if v is not None}
    requests = dataset_cls(**common_kwargs).sample(**sample_kwargs)
    requests = filter_requests_for_dp(requests, args.data_parallel_size)
    return requests


def filter_requests_for_dp(requests, data_parallel_size):
    """Return the round-robin share of ``requests`` owned by this DP rank.

    With ``data_parallel_size == 1`` the input is returned as-is. Otherwise
    the data-parallel rank is derived from the ``RANK`` and ``WORLD_SIZE``
    environment variables. Request ``i`` is kept when
    ``i % data_parallel_size`` equals that rank.
    """
    # Note(zhuohan): The way we get data_parallel_rank is hacky and only
    # works for external launcher mode. Should be cleaned up and deprecated
    # in the future with a better vLLM distributed process design.
    if data_parallel_size == 1:
        return requests

    rank = int(os.environ["RANK"])
    ranks_per_dp_group = int(os.environ["WORLD_SIZE"]) // data_parallel_size
    dp_rank = rank // ranks_per_dp_group
    selected = []
    for idx, req in enumerate(requests):
        if idx % data_parallel_size == dp_rank:
            selected.append(req)
    return selected


def validate_args(args):
    """
    Validate command-line arguments.

    Also normalizes ``args`` in place:
    - ``--dataset`` is copied into ``dataset_path``.
    - ``tokenizer`` defaults to ``model``.
    - ``dataset_name`` is switched to ``"random"`` when no dataset path is
      given for a dataset that needs one.

    Combinations where a flag is simply ignored only emit warnings.

    Raises:
        ValueError: For unsupported backends or datasets, or invalid option
            combinations.
        AssertionError: When an HF dataset is paired with the wrong backend.
    """

    # === Deprecation and Defaulting ===
    if args.dataset is not None:
        warnings.warn(
            "The '--dataset' argument will be deprecated in the next release. "
            "Please use '--dataset-name' and '--dataset-path' instead.",
            stacklevel=2,
        )
        args.dataset_path = args.dataset

    if not getattr(args, "tokenizer", None):
        args.tokenizer = args.model

    # === Backend Validation ===
    valid_backends = {"vllm", "hf", "mii", "vllm-chat"}
    if args.backend not in valid_backends:
        raise ValueError(f"Unsupported backend: {args.backend}")

    # === Dataset Configuration ===
    # prefix_repetition, random-mm and random-rerank generate their own data.
    # They need no --dataset-path (get_requests() handles them without one),
    # so they must not be rewritten to plain "random" here.
    if (
        not args.dataset
        and not args.dataset_path
        and args.dataset_name
        not in {"prefix_repetition", "random-mm", "random-rerank"}
    ):
        print("When dataset path is not set, it will default to random dataset")
        args.dataset_name = "random"
        random_input_len = getattr(args, "random_input_len", None)
        if args.input_len is None and random_input_len is None:
            raise ValueError(
                "Either --input-len or --random-input-len must be provided "
                "for a random dataset"
            )

    # === Dataset Name Specific Checks ===
    # --hf-subset and --hf-split: only used
    # when dataset_name is 'hf'
    if args.dataset_name != "hf" and (
        getattr(args, "hf_subset", None) is not None
        or getattr(args, "hf_split", None) is not None
    ):
        warnings.warn(
            "--hf-subset and --hf-split will be ignored "
            "since --dataset-name is not 'hf'.",
            stacklevel=2,
        )
    elif args.dataset_name == "hf":
        if args.dataset_path in (
            VisionArenaDataset.SUPPORTED_DATASET_PATHS.keys()
            | MultiModalConversationDataset.SUPPORTED_DATASET_PATHS
            | ConversationDataset.SUPPORTED_DATASET_PATHS
        ):
            assert args.backend == "vllm-chat", (
                f"{args.dataset_path} needs to use vllm-chat as the backend."
            )
        elif args.dataset_path in (
            InstructCoderDataset.SUPPORTED_DATASET_PATHS
            | AIMODataset.SUPPORTED_DATASET_PATHS
        ):
            assert args.backend == "vllm", (
                f"{args.dataset_path} needs to use vllm as the backend."
            )
        else:
            raise ValueError(f"{args.dataset_path} is not supported by hf dataset.")

    # --random-range-ratio: only used when dataset_name is 'random',
    # 'random-mm', or 'random-rerank'
    if (
        args.dataset_name not in {"random", "random-mm", "random-rerank"}
        and args.random_range_ratio is not None
    ):
        warnings.warn(
            "--random-range-ratio will be ignored since --dataset-name is not "
            "'random', 'random-mm', or 'random-rerank'.",
            stacklevel=2,
        )

    # --random-batch-size: only used when dataset_name is 'random-rerank'
    if (
        args.dataset_name != "random-rerank"
        and getattr(args, "random_batch_size", None) is not None
    ) and args.random_batch_size != 1:
        warnings.warn(
            "--random-batch-size will be ignored since "
            "--dataset-name is not 'random-rerank'.",
            stacklevel=2,
        )

    # --no-reranker: only used when dataset_name is 'random-rerank'
    if args.dataset_name != "random-rerank" and getattr(args, "no_reranker", False):
        warnings.warn(
            "--no-reranker will be ignored since "
            "--dataset-name is not 'random-rerank'.",
            stacklevel=2,
        )

    # --prefix-len: only used when dataset_name is 'random', 'random-mm',
    # 'sonnet', or not set. Its CLI default is 0, so only a non-zero value
    # means a prefix was actually requested. Testing `is not None` made
    # this warning fire on every run.
    if (
        args.dataset_name not in {"random", "random-mm", "sonnet", None}
        and args.prefix_len
    ):
        warnings.warn(
            "--prefix-len will be ignored since --dataset-name "
            "is not 'random', 'random-mm', 'sonnet', or not set.",
            stacklevel=2,
        )

    # === Random Dataset Argument Conflict Detection ===
    # Check for conflicts between regular and random arguments when using
    # random datasets
    if args.dataset_name in {"random", "random-mm", "random-rerank"}:
        random_input_len = getattr(args, "random_input_len", None)
        random_output_len = getattr(args, "random_output_len", None)
        random_prefix_len = getattr(args, "random_prefix_len", None)

        if args.input_len is not None and random_input_len is not None:
            warnings.warn(
                "Both --input-len and --random-input-len are specified. "
                "The random version (--random-input-len) will be preferred "
                "in this run.",
                stacklevel=2,
            )
        if args.output_len is not None and random_output_len is not None:
            warnings.warn(
                "Both --output-len and --random-output-len are specified. "
                "The random version (--random-output-len) will be preferred "
                "in this run.",
                stacklevel=2,
            )
        # --prefix-len defaults to 0, so treat 0 as "not specified".
        if args.prefix_len and random_prefix_len is not None:
            warnings.warn(
                "Both --prefix-len and --random-prefix-len are specified. "
                "The random version (--random-prefix-len) will be preferred "
                "in this run.",
                stacklevel=2,
            )

    # === LoRA Settings ===
    if getattr(args, "enable_lora", False) and args.backend != "vllm":
        raise ValueError("LoRA benchmarking is only supported for vLLM backend")
    if getattr(args, "enable_lora", False) and args.lora_path is None:
        raise ValueError("LoRA path must be provided when enable_lora is True")

    # === Backend-specific Validations ===
    if args.backend == "hf" and args.hf_max_batch_size is None:
        raise ValueError("HF max batch size is required for HF backend")
    if args.backend != "hf" and args.hf_max_batch_size is not None:
        raise ValueError("HF max batch size is only for HF backend.")

    if (
        args.backend in {"hf", "mii"}
        and getattr(args, "quantization", None) is not None
    ):
        raise ValueError("Quantization is only for vLLM backend.")

    if args.backend == "mii" and args.dtype != "auto":
        raise ValueError("dtype must be auto for MII backend.")
    if args.backend == "mii" and args.n != 1:
        raise ValueError("n must be 1 for MII backend.")
    if args.backend == "mii" and args.tokenizer != args.model:
        raise ValueError("Tokenizer must be the same as the model for MII backend.")

    if args.data_parallel_size > 1 and (
        args.distributed_executor_backend != "external_launcher" or args.async_engine
    ):
        # --data-parallel is not supported fully.
        # Old issue: https://github.com/vllm-project/vllm/issues/16222
        # Currently we only support data parallel with external launcher
        # mode (i.e., launch with torchrun).
        raise ValueError(
            "Data parallel is only supported with external launcher mode "
            "with synchronous engine in offline benchmark, "
            "please use benchmark serving instead"
        )


def add_cli_args(parser: argparse.ArgumentParser):
    """Register the throughput-benchmark CLI options on ``parser``.

    Adds backend/dataset selection, request shaping, output, LoRA, HF
    dataset, profiling and prefix-repetition options. It then adds the
    random-dataset options and all ``AsyncEngineArgs`` engine options.
    """
    parser.add_argument(
        "--backend",
        type=str,
        choices=["vllm", "hf", "mii", "vllm-chat"],
        default="vllm",
    )
    parser.add_argument(
        "--dataset-name",
        type=str,
        choices=[
            "sharegpt",
            "random",
            "sonnet",
            "burstgpt",
            "hf",
            "prefix_repetition",
            "random-mm",
            "random-rerank",
        ],
        help="Name of the dataset to benchmark on.",
        default="sharegpt",
    )
    parser.add_argument(
        "--dataset",
        type=str,
        default=None,
        # Plain implicit concatenation: the old backslash continuation
        # embedded a long run of indentation spaces in the help text.
        help="Path to the ShareGPT dataset, will be deprecated in "
        "the next release. The dataset is expected to "
        "be a json in form of list[dict[..., conversations: "
        "list[dict[..., value: <prompt_or_response>]]]]",
    )
    parser.add_argument(
        "--dataset-path", type=str, default=None, help="Path to the dataset"
    )
    parser.add_argument(
        "--input-len",
        type=int,
        default=None,
        help="Input prompt length for each request",
    )
    parser.add_argument(
        "--output-len",
        type=int,
        default=None,
        help="Output length for each request. Overrides the "
        "output length from the dataset.",
    )
    parser.add_argument(
        "--n", type=int, default=1, help="Number of generated sequences per prompt."
    )
    parser.add_argument(
        "--num-prompts", type=int, default=1000, help="Number of prompts to process."
    )
    parser.add_argument(
        "--hf-max-batch-size",
        type=int,
        default=None,
        help="Maximum batch size for HF backend.",
    )
    parser.add_argument(
        "--output-json",
        type=str,
        default=None,
        help="Path to save the throughput results in JSON format.",
    )
    parser.add_argument(
        "--async-engine",
        action="store_true",
        default=False,
        help="Use vLLM async engine rather than LLM class.",
    )
    parser.add_argument(
        "--disable-frontend-multiprocessing",
        action="store_true",
        default=False,
        help="Disable decoupled async engine frontend.",
    )
    parser.add_argument(
        "--disable-detokenize",
        action="store_true",
        help=(
            "Do not detokenize the response (i.e. do not include "
            "detokenization time in the measurement)"
        ),
    )
    # LoRA
    parser.add_argument(
        "--lora-path",
        type=str,
        default=None,
        help="Path to the lora adapters to use. This can be an absolute path, "
        "a relative path, or a Hugging Face model identifier.",
    )
    parser.add_argument(
        "--prefix-len",
        type=int,
        default=0,
        help="Number of fixed prefix tokens before the random "
        "context in a request (default: 0).",
    )

    # hf dataset
    parser.add_argument(
        "--hf-subset",
        type=str,
        default=None,
        help="Subset of the HF dataset.",
    )
    parser.add_argument(
        "--hf-split",
        type=str,
        default=None,
        help="Split of the HF dataset.",
    )
    parser.add_argument(
        "--profile",
        action="store_true",
        default=False,
        help="Use vLLM Profiling. --profiler-config must be provided on the server.",
    )

    # prefix repetition dataset
    parser.add_argument(
        "--prefix-repetition-prefix-len",
        type=int,
        default=None,
        help="Number of prefix tokens per request, used only for prefix "
        "repetition dataset.",
    )
    parser.add_argument(
        "--prefix-repetition-suffix-len",
        type=int,
        default=None,
        help="Number of suffix tokens per request, used only for prefix "
        "repetition dataset. Total input length is prefix_len + suffix_len.",
    )
    parser.add_argument(
        "--prefix-repetition-num-prefixes",
        type=int,
        default=None,
        help="Number of prefixes to generate, used only for prefix repetition "
        "dataset. Prompts per prefix is num_requests // num_prefixes.",
    )
    parser.add_argument(
        "--prefix-repetition-output-len",
        type=int,
        default=None,
        help="Number of output tokens per request, used only for prefix "
        "repetition dataset.",
    )

    # (random, random-mm, random-rerank)
    add_random_dataset_base_args(parser)
    add_random_multimodal_dataset_args(parser)

    parser = AsyncEngineArgs.add_cli_args(parser)


def main(args: argparse.Namespace):
    """Run the offline throughput benchmark described by ``args`` and report it.

    Flow:
      1. Validate arguments, seed Python's ``random`` module (seed 0 when unset),
         build the tokenizer and sample the benchmark requests.
      2. Dispatch to the selected backend (``vllm`` sync/async, ``hf`` or
         ``vllm-chat``) and time it.
      3. Count tokens, print throughput, and optionally write a JSON summary to
         ``args.output_json`` (also exported via
         ``save_to_pytorch_benchmark_format``).

    Token totals are counted from the returned ``RequestOutput`` objects when the
    backend provides them (sync ``vllm`` and ``vllm-chat``). Otherwise they are
    estimated from each request's ``prompt_len`` and ``expected_output_len``.

    Raises:
        AssertionError: ``hf`` backend with ``tensor_parallel_size != 1``.
        NotImplementedError: ``hf`` backend combined with ``--profile``.
        ValueError: an unrecognized ``args.backend``.
    """
    validate_args(args)
    if args.seed is None:
        args.seed = 0
    random.seed(args.seed)
    # Sample the requests.
    if args.backend in ("hf", "mii") and args.tokenizer_mode == "auto":
        # mistral_common tokenizer is only supported on vllm and vllm-chat backends;
        # for hf and mii backends, we use hf tokenizer
        args.tokenizer_mode = "hf"
    tokenizer = get_tokenizer(
        args.tokenizer,
        tokenizer_mode=args.tokenizer_mode,
        trust_remote_code=args.trust_remote_code,
    )
    requests = get_requests(args, tokenizer)
    num_requests = len(requests)
    is_multi_modal = any(request.multi_modal_data is not None for request in requests)
    # Only populated by backends that hand back RequestOutput objects; used
    # below to count real (rather than expected) token totals.
    request_outputs: list[RequestOutput] | None = None

    if args.backend == "vllm":
        if args.async_engine:
            # The async path returns only the elapsed time, no outputs.
            elapsed_time = uvloop.run(
                run_vllm_async(
                    requests,
                    args.n,
                    AsyncEngineArgs.from_cli_args(args),
                    disable_frontend_multiprocessing=args.disable_frontend_multiprocessing,
                    disable_detokenize=args.disable_detokenize,
                    do_profile=args.profile,
                )
            )
        else:
            elapsed_time, request_outputs = run_vllm(
                requests,
                args.n,
                EngineArgs.from_cli_args(args),
                disable_detokenize=args.disable_detokenize,
                do_profile=args.profile,
            )
    elif args.backend == "hf":
        assert args.tensor_parallel_size == 1
        if args.profile:
            raise NotImplementedError("Profiling not implemented yet for backend='hf'.")
        elapsed_time = run_hf(
            requests,
            args.model,
            tokenizer,
            args.n,
            args.hf_max_batch_size,
            args.trust_remote_code,
            args.disable_detokenize,
        )
    elif args.backend == "vllm-chat":
        elapsed_time, request_outputs = run_vllm_chat(
            requests,
            args.n,
            EngineArgs.from_cli_args(args),
            disable_detokenize=args.disable_detokenize,
            do_profile=args.profile,
        )
    else:
        raise ValueError(f"Unknown backend: {args.backend}")

    if request_outputs:
        # Note: with the vllm and vllm-chat backends,
        # we have request_outputs, which we use to count tokens.
        total_prompt_tokens = 0
        total_output_tokens = 0
        for ro in request_outputs:
            if not isinstance(ro, RequestOutput):
                continue
            total_prompt_tokens += (
                len(ro.prompt_token_ids) if ro.prompt_token_ids else 0
            )
            total_output_tokens += sum(len(o.token_ids) for o in ro.outputs if o)
        total_num_tokens = total_prompt_tokens + total_output_tokens
    else:
        # No engine outputs available: fall back to the lengths the dataset
        # sampler recorded for each request.
        total_num_tokens = sum(r.prompt_len + r.expected_output_len for r in requests)
        total_output_tokens = sum(r.expected_output_len for r in requests)
        total_prompt_tokens = total_num_tokens - total_output_tokens

    if is_multi_modal and args.backend != "vllm-chat":
        print(
            "\033[91mWARNING\033[0m: Multi-modal request with "
            f"{args.backend} backend detected. The "
            "following metrics are not accurate because image tokens are not"
            " counted. See vllm-project/vllm/issues/9778 for details."
        )
        # TODO(vllm-project/vllm/issues/9778): Count multi-modal token length.
        # vllm-chat backend counts the image tokens now

    print(
        f"Throughput: {num_requests / elapsed_time:.2f} requests/s, "
        f"{total_num_tokens / elapsed_time:.2f} total tokens/s, "
        f"{total_output_tokens / elapsed_time:.2f} output tokens/s"
    )
    print(f"Total num prompt tokens:  {total_prompt_tokens}")
    print(f"Total num output tokens:  {total_output_tokens}")

    # Output JSON results if specified
    if args.output_json:
        results = {
            "elapsed_time": elapsed_time,
            "num_requests": num_requests,
            "total_num_tokens": total_num_tokens,
            "requests_per_second": num_requests / elapsed_time,
            "tokens_per_second": total_num_tokens / elapsed_time,
        }
        with open(args.output_json, "w") as f:
            json.dump(results, f, indent=4)
        save_to_pytorch_benchmark_format(args, results)