"README_ORIGIN.md" did not exist on "58f5a59769b89a9457dfbedaac9d200bb100be78"
throughput.py 35.5 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
"""Benchmark offline inference throughput."""
4

5
6
7
8
9
10
import argparse
import json
import os
import random
import time
import warnings
11
from typing import Any
12
13
14
15

import torch
import uvloop
from tqdm import tqdm
16
from transformers import AutoModelForCausalLM, PreTrainedTokenizerBase
17
18
19

from vllm.benchmarks.datasets import (
    AIMODataset,
20
    ASRDataset,
21
22
23
    BurstGPTDataset,
    ConversationDataset,
    InstructCoderDataset,
24
    MultiModalConversationDataset,
25
26
    PrefixRepetitionRandomDataset,
    RandomDataset,
27
28
    RandomDatasetForReranking,
    RandomMultiModalDataset,
29
30
31
32
    SampleRequest,
    ShareGPTDataset,
    SonnetDataset,
    VisionArenaDataset,
33
34
    add_random_dataset_base_args,
    add_random_multimodal_dataset_args,
35
36
)
from vllm.benchmarks.lib.utils import convert_to_pytorch_benchmark_format, write_to_json
37
38
39
40
from vllm.engine.arg_utils import AsyncEngineArgs, EngineArgs
from vllm.inputs import TextPrompt, TokensPrompt
from vllm.lora.request import LoRARequest
from vllm.outputs import RequestOutput
41
from vllm.platforms import current_platform
42
from vllm.sampling_params import BeamSearchParams
43
from vllm.tokenizers import TokenizerLike, get_tokenizer
44
from vllm.utils.async_utils import merge_async_iterators
45
46
47
48
49
50


def run_vllm(
    requests: list[SampleRequest],
    n: int,
    engine_args: EngineArgs,
    do_profile: bool,
    disable_detokenize: bool = False,
) -> tuple[float, list[RequestOutput] | None]:
    """Benchmark offline generation with the synchronous ``LLM`` engine.

    One prompt and one ``SamplingParams`` are built per request, then a single
    batched ``llm.generate`` call is timed (optionally wrapped in the engine
    profiler).

    Args:
        requests: Sampled benchmark requests.
        n: Number of sequences to generate per prompt (beam width on the
            beam-search path).
        engine_args: Arguments used to construct the ``LLM`` instance.
        do_profile: If True, call ``llm.start_profile()``/``llm.stop_profile()``
            around the timed region.
        disable_detokenize: If True, ask the engine not to detokenize outputs.

    Returns:
        ``(elapsed_seconds, outputs)``. ``outputs`` is None on the beam-search
        path, which does not keep its results.
    """
    from vllm import LLM, SamplingParams

    llm = LLM.from_engine_args(engine_args)
    assert all(
        llm.llm_engine.model_config.max_model_len
        >= (request.prompt_len + request.expected_output_len)
        for request in requests
    ), (
        "Please ensure that max_model_len is at least the sum of"
        " prompt_len and expected_output_len for all requests."
    )
    # Add the requests to the engine.
    prompts: list[TextPrompt | TokensPrompt] = []
    sampling_params: list[SamplingParams] = []
    for request in requests:
        # Pre-tokenized prompts arrive as a dict carrying "prompt_token_ids".
        # Check the type first: `in` on a plain string prompt is a substring
        # search and would misfire on text that happens to contain the key.
        if isinstance(request.prompt, dict) and "prompt_token_ids" in request.prompt:
            prompt = TokensPrompt(prompt_token_ids=request.prompt["prompt_token_ids"])
        else:
            prompt = TextPrompt(prompt=request.prompt)
        if request.multi_modal_data:
            assert isinstance(request.multi_modal_data, dict)
            prompt["multi_modal_data"] = request.multi_modal_data
        prompts.append(prompt)

        sampling_params.append(
            SamplingParams(
                n=n,
                temperature=1.0,
                top_p=1.0,
                ignore_eos=True,
                max_tokens=request.expected_output_len,
                detokenize=not disable_detokenize,
            )
        )
    lora_requests: list[LoRARequest] | None = None
    if engine_args.enable_lora:
        lora_requests = [request.lora_request for request in requests]

    # Beam search is hard-coded off, so the `else` branch below is not
    # reachable from this function as written.
    use_beam_search = False

    outputs = None
    if not use_beam_search:
        start = time.perf_counter()
        if do_profile:
            llm.start_profile()
        outputs = llm.generate(
            prompts, sampling_params, lora_request=lora_requests, use_tqdm=True
        )
        if do_profile:
            llm.stop_profile()
        end = time.perf_counter()
    else:
        assert lora_requests is None, "BeamSearch API does not support LoRA"
        prompts = [request.prompt for request in requests]
        # output_len should be the same for all requests.
        output_len = requests[0].expected_output_len
        for request in requests:
            assert request.expected_output_len == output_len
        start = time.perf_counter()
        if do_profile:
            llm.start_profile()
        llm.beam_search(
            prompts,
            BeamSearchParams(
                beam_width=n,
                max_tokens=output_len,
                ignore_eos=True,
            ),
        )
        if do_profile:
            llm.stop_profile()
        end = time.perf_counter()
    return end - start, outputs


def run_vllm_chat(
    requests: list[SampleRequest],
    n: int,
    engine_args: EngineArgs,
    do_profile: bool,
    disable_detokenize: bool = False,
) -> tuple[float, list[RequestOutput]]:
    """
    Run vLLM chat benchmark. This function is recommended ONLY for benchmarking
    multimodal models as it properly handles multimodal inputs and chat
    formatting. For non-multimodal models, use run_vllm() instead.

    Each ``request.prompt`` is passed to ``llm.chat`` unchanged, with one
    ``SamplingParams`` per request; only the ``llm.chat`` call (plus optional
    profiler start/stop) is timed.

    Args:
        requests: Sampled benchmark requests whose prompts are chat inputs.
        n: Number of sequences to generate per prompt.
        engine_args: Arguments used to construct the ``LLM`` instance.
        do_profile: If True, call ``llm.start_profile()``/``llm.stop_profile()``
            around the timed region.
        disable_detokenize: If True, ask the engine not to detokenize outputs.

    Returns:
        ``(elapsed_seconds, outputs)`` where ``outputs`` is what ``llm.chat``
        returned.
    """
    from vllm import LLM, SamplingParams

    llm = LLM.from_engine_args(engine_args)

    assert all(
        llm.llm_engine.model_config.max_model_len
        >= (request.prompt_len + request.expected_output_len)
        for request in requests
    ), (
        "Please ensure that max_model_len is at least the sum of "
        "prompt_len and expected_output_len for all requests."
    )

    prompts = []
    sampling_params: list[SamplingParams] = []
    for request in requests:
        prompts.append(request.prompt)
        sampling_params.append(
            SamplingParams(
                n=n,
                temperature=1.0,
                top_p=1.0,
                ignore_eos=True,
                max_tokens=request.expected_output_len,
                detokenize=not disable_detokenize,
            )
        )
    start = time.perf_counter()
    if do_profile:
        llm.start_profile()
    outputs = llm.chat(prompts, sampling_params, use_tqdm=True)
    if do_profile:
        llm.stop_profile()
    end = time.perf_counter()
    return end - start, outputs


async def run_vllm_async(
    requests: list[SampleRequest],
    n: int,
    engine_args: AsyncEngineArgs,
    do_profile: bool,
    disable_detokenize: bool = False,
) -> float:
    """Benchmark offline generation through the async engine client.

    Every request is submitted as its own ``llm.generate`` stream (request ids
    ``test0``, ``test1``, ...). The timed region ends once all streams, merged
    with ``merge_async_iterators``, are exhausted.

    Args:
        requests: Sampled benchmark requests.
        n: Number of sequences to generate per prompt.
        engine_args: Arguments used to build the async engine client.
        do_profile: If True, await ``llm.start_profile()``/``llm.stop_profile()``
            around the timed region.
        disable_detokenize: If True, ask the engine not to detokenize outputs.

    Returns:
        Elapsed wall-clock seconds for submitting and draining all requests.
    """
    from vllm import SamplingParams
    from vllm.entrypoints.openai.api_server import (
        build_async_engine_client_from_engine_args,
    )

    async with build_async_engine_client_from_engine_args(
        engine_args,
    ) as llm:
        model_config = llm.model_config
        assert all(
            model_config.max_model_len
            >= (request.prompt_len + request.expected_output_len)
            for request in requests
        ), (
            "Please ensure that max_model_len is at least the sum of"
            " prompt_len and expected_output_len for all requests."
        )

        # Add the requests to the engine.
        prompts: list[TextPrompt | TokensPrompt] = []
        sampling_params: list[SamplingParams] = []
        lora_requests: list[LoRARequest | None] = []
        for request in requests:
            # Pre-tokenized prompts arrive as a dict carrying
            # "prompt_token_ids". Check the type first: `in` on a plain string
            # prompt is a substring search and could misfire.
            if (
                isinstance(request.prompt, dict)
                and "prompt_token_ids" in request.prompt
            ):
                prompt = TokensPrompt(
                    prompt_token_ids=request.prompt["prompt_token_ids"]
                )
            else:
                prompt = TextPrompt(prompt=request.prompt)

            if request.multi_modal_data:
                assert isinstance(request.multi_modal_data, dict)
                prompt["multi_modal_data"] = request.multi_modal_data

            sampling_params.append(
                SamplingParams(
                    n=n,
                    temperature=1.0,
                    top_p=1.0,
                    ignore_eos=True,
                    max_tokens=request.expected_output_len,
                    detokenize=not disable_detokenize,
                )
            )
            prompts.append(prompt)
            lora_requests.append(request.lora_request)

        generators = []
        start = time.perf_counter()
        if do_profile:
            await llm.start_profile()
        for i, (prompt, sp, lr) in enumerate(
            zip(prompts, sampling_params, lora_requests)
        ):
            generator = llm.generate(prompt, sp, lora_request=lr, request_id=f"test{i}")
            generators.append(generator)
        # Drain every stream; only completion time matters, not the outputs.
        async for _ in merge_async_iterators(*generators):
            pass
        if do_profile:
            await llm.stop_profile()
        end = time.perf_counter()
        return end - start


def run_hf(
    requests: list[SampleRequest],
    model: str,
    tokenizer: TokenizerLike,
    n: int,
    max_batch_size: int,
    trust_remote_code: bool,
    disable_detokenize: bool = False,
    dtype: torch.dtype | None = torch.float16,
    enable_torch_compile: bool = False,
    max_batch_total_len: int = 2048,
) -> float:
    """Benchmark offline generation with a Hugging Face transformers model.

    Requests are grouped greedily in order. A batch is emitted when any of
    these holds:

    - it has ``max_batch_size`` requests;
    - it contains the final request;
    - adding the next request would make
      ``max(prompt_len) + max(expected_output_len)`` exceed
      ``max_batch_total_len``.

    Each batch is tokenized with padding and sampled via ``model.generate``.

    Args:
        requests: Sampled benchmark requests (``prompt`` must be a string).
        model: Model name or path for ``AutoModelForCausalLM.from_pretrained``.
        tokenizer: Must be a ``PreTrainedTokenizerBase``.
        n: ``num_return_sequences`` per prompt.
        max_batch_size: Maximum number of requests per batch.
        trust_remote_code: Forwarded to ``from_pretrained``.
        disable_detokenize: If True, skip the timed ``batch_decode`` call.
        dtype: Model dtype forwarded to ``from_pretrained``.
        enable_torch_compile: If True, wrap the model with ``torch.compile``.
        max_batch_total_len: Bound on ``max prompt_len + max output_len``
            within one batch (default 2048).

    Returns:
        Elapsed wall-clock seconds for generating all batches.
    """
    assert isinstance(tokenizer, PreTrainedTokenizerBase), (
        "the hf backend only supports HF tokenizers"
    )
    llm = AutoModelForCausalLM.from_pretrained(
        model, dtype=dtype, trust_remote_code=trust_remote_code
    )
    # Batched tokenization with padding=True needs a pad token. Llama always
    # pads with EOS; other models fall back to EOS only when none is set
    # (otherwise the tokenizer refuses to pad).
    if llm.config.model_type == "llama" or tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    device = current_platform.device_type
    llm = llm.to(device)
    if enable_torch_compile:
        llm = torch.compile(llm)

    pbar = tqdm(total=len(requests))
    start = time.perf_counter()
    batch: list[str] = []
    max_prompt_len = 0
    max_output_len = 0
    last_index = len(requests) - 1
    for i, request in enumerate(requests):
        # Add the prompt to the batch.
        batch.append(request.prompt)
        max_prompt_len = max(max_prompt_len, request.prompt_len)
        max_output_len = max(max_output_len, request.expected_output_len)
        if len(batch) < max_batch_size and i != last_index:
            # Check if we can add more requests to the batch.
            next_request = requests[i + 1]
            if (
                max(max_prompt_len, next_request.prompt_len)
                + max(max_output_len, next_request.expected_output_len)
            ) <= max_batch_total_len:
                # We can add more requests to the batch.
                continue

        # Generate the sequences. Pass the attention mask with the padded ids
        # so generation ignores pad positions.
        encoded = tokenizer(batch, return_tensors="pt", padding=True)
        generate_inputs = {"input_ids": encoded["input_ids"].to(device)}
        if "attention_mask" in encoded:
            generate_inputs["attention_mask"] = encoded["attention_mask"].to(device)
        llm_outputs = llm.generate(
            **generate_inputs,
            do_sample=True,
            num_return_sequences=n,
            temperature=1.0,
            top_p=1.0,
            use_cache=True,
            max_new_tokens=max_output_len,
        )
        if not disable_detokenize:
            # Include the decoding time.
            tokenizer.batch_decode(llm_outputs, skip_special_tokens=True)
        pbar.update(len(batch))

        # Clear the batch.
        batch = []
        max_prompt_len = 0
        max_output_len = 0
    end = time.perf_counter()
    pbar.close()
    return end - start


322
323
324
def save_to_pytorch_benchmark_format(
    args: argparse.Namespace, results: dict[str, Any]
) -> None:
    """Write throughput results in the PyTorch benchmark record format.

    Request and token throughput become metrics. Elapsed time, request count,
    and total token count go into ``extra_info``. When the converter yields
    any records, they are written to ``args.output_json`` with its extension
    replaced by ``.pytorch.json``.

    Args:
        args: Parsed CLI namespace; must provide ``output_json``.
        results: Must contain ``requests_per_second``, ``tokens_per_second``,
            ``elapsed_time``, ``num_requests``, and ``total_num_tokens``.
    """
    metric_keys = ("requests_per_second", "tokens_per_second")
    extra_keys = ("elapsed_time", "num_requests", "total_num_tokens")
    records = convert_to_pytorch_benchmark_format(
        args=args,
        metrics={key: [results[key]] for key in metric_keys},
        extra_info={key: results[key] for key in extra_keys},
    )
    if not records:
        return
    # Don't use json suffix here as we don't want CI to pick it up
    stem, _ = os.path.splitext(args.output_json)
    write_to_json(f"{stem}.pytorch.json", records)


def get_requests(args, tokenizer):
    """Sample benchmark requests for the dataset selected on the command line.

    Translates ``args`` into two kwargs dicts: ``common_kwargs`` for the
    dataset constructor and ``sample_kwargs`` for ``sample()``. ``None``
    sample kwargs are dropped. After sampling, only this data-parallel rank's
    share is kept, via ``filter_requests_for_dp``.

    Args:
        args: Parsed CLI namespace.
        tokenizer: Tokenizer forwarded to the dataset's ``sample()``.

    Returns:
        The sampled requests assigned to this rank.

    Raises:
        ValueError: If ``args.dataset_name`` is unknown, or if it is ``"hf"``
            and ``args.dataset_path`` is not a supported HF dataset.
    """
    # Common parameters for all dataset types.
    common_kwargs = {
        "dataset_path": args.dataset_path,
        "random_seed": args.seed,
    }
    sample_kwargs = {
        "tokenizer": tokenizer,
        "lora_path": args.lora_path,
        "max_loras": args.max_loras,
        "lora_assignment": getattr(args, "lora_assignment", "random"),
        "num_requests": args.num_prompts,
    }

    if args.dataset_name == "random" or (
        args.dataset_path is None
        and args.dataset_name not in {"prefix_repetition", "random-mm", "random-rerank"}
    ):
        sample_kwargs["range_ratio"] = args.random_range_ratio
        # prefer random_* arguments, fall back to regular arguments
        random_prefix_len = getattr(args, "random_prefix_len", None)
        sample_kwargs["prefix_len"] = (
            random_prefix_len if random_prefix_len is not None else args.prefix_len
        )
        random_input_len = getattr(args, "random_input_len", None)
        sample_kwargs["input_len"] = (
            random_input_len if random_input_len is not None else args.input_len
        )
        random_output_len = getattr(args, "random_output_len", None)
        sample_kwargs["output_len"] = (
            random_output_len if random_output_len is not None else args.output_len
        )
        dataset_cls = RandomDataset
    elif args.dataset_name == "sharegpt":
        dataset_cls = ShareGPTDataset
        if args.backend == "vllm-chat":
            sample_kwargs["enable_multimodal_chat"] = True
        if args.output_len is not None:
            sample_kwargs["output_len"] = args.output_len
    elif args.dataset_name == "sonnet":
        # `default_chat_template` no longer exists on recent transformers
        # tokenizers; look it up defensively so a tokenizer without a chat
        # template fails this assertion instead of raising AttributeError.
        assert getattr(tokenizer, "chat_template", None) or getattr(
            tokenizer, "default_chat_template", None
        ), "Tokenizer/model must have chat template for sonnet dataset."
        dataset_cls = SonnetDataset
        sample_kwargs["prefix_len"] = args.prefix_len
        sample_kwargs["return_prompt_formatted"] = True
        if args.input_len is not None:
            sample_kwargs["input_len"] = args.input_len
        if args.output_len is not None:
            sample_kwargs["output_len"] = args.output_len
    elif args.dataset_name == "burstgpt":
        dataset_cls = BurstGPTDataset
    elif args.dataset_name == "hf":
        if args.output_len is not None:
            sample_kwargs["output_len"] = args.output_len
        if args.dataset_path in VisionArenaDataset.SUPPORTED_DATASET_PATHS:
            dataset_cls = VisionArenaDataset
            common_kwargs["dataset_subset"] = None
            common_kwargs["dataset_split"] = "train"
            sample_kwargs["enable_multimodal_chat"] = True
        elif args.dataset_path in InstructCoderDataset.SUPPORTED_DATASET_PATHS:
            dataset_cls = InstructCoderDataset
            common_kwargs["dataset_split"] = "train"
        elif args.dataset_path in MultiModalConversationDataset.SUPPORTED_DATASET_PATHS:
            dataset_cls = MultiModalConversationDataset
            common_kwargs["dataset_subset"] = args.hf_subset
            common_kwargs["dataset_split"] = args.hf_split
            sample_kwargs["enable_multimodal_chat"] = True
        elif args.dataset_path in ConversationDataset.SUPPORTED_DATASET_PATHS:
            dataset_cls = ConversationDataset
            common_kwargs["dataset_subset"] = args.hf_subset
            common_kwargs["dataset_split"] = args.hf_split
            sample_kwargs["enable_multimodal_chat"] = True
        elif args.dataset_path in AIMODataset.SUPPORTED_DATASET_PATHS:
            dataset_cls = AIMODataset
            common_kwargs["dataset_subset"] = None
            common_kwargs["dataset_split"] = "train"
        elif args.dataset_path in ASRDataset.SUPPORTED_DATASET_PATHS:
            dataset_cls = ASRDataset
            common_kwargs["dataset_subset"] = args.hf_subset
            common_kwargs["dataset_split"] = args.hf_split
            sample_kwargs["asr_min_audio_len_sec"] = args.asr_min_audio_len_sec
            sample_kwargs["asr_max_audio_len_sec"] = args.asr_max_audio_len_sec
        else:
            # Fail with a clear error rather than reaching the sampling call
            # below with `dataset_cls` unbound (UnboundLocalError).
            raise ValueError(f"{args.dataset_path} is not supported by hf dataset.")
    elif args.dataset_name == "prefix_repetition":
        dataset_cls = PrefixRepetitionRandomDataset
        sample_kwargs["prefix_len"] = args.prefix_repetition_prefix_len
        sample_kwargs["suffix_len"] = args.prefix_repetition_suffix_len
        sample_kwargs["num_prefixes"] = args.prefix_repetition_num_prefixes
        sample_kwargs["output_len"] = args.prefix_repetition_output_len
    elif args.dataset_name == "random-mm":
        dataset_cls = RandomMultiModalDataset
        # prefer random_* arguments, fall back to regular arguments
        random_input_len = getattr(args, "random_input_len", None)
        sample_kwargs["input_len"] = (
            random_input_len
            if random_input_len is not None
            else getattr(args, "input_len", None)
        )
        random_output_len = getattr(args, "random_output_len", None)
        sample_kwargs["output_len"] = (
            random_output_len
            if random_output_len is not None
            else getattr(args, "output_len", None)
        )
        sample_kwargs["base_items_per_request"] = getattr(
            args, "random_mm_base_items_per_request", None
        )
        sample_kwargs["num_mm_items_range_ratio"] = getattr(
            args, "random_mm_num_mm_items_range_ratio", None
        )
        sample_kwargs["limit_mm_per_prompt"] = getattr(
            args, "random_mm_limit_mm_per_prompt", None
        )
        sample_kwargs["bucket_config"] = getattr(args, "random_mm_bucket_config", None)
        sample_kwargs["enable_multimodal_chat"] = True
        random_prefix_len = getattr(args, "random_prefix_len", None)
        prefix_len = getattr(args, "prefix_len", None)
        sample_kwargs["prefix_len"] = (
            random_prefix_len if random_prefix_len is not None else prefix_len
        )
        sample_kwargs["range_ratio"] = args.random_range_ratio
    elif args.dataset_name == "random-rerank":
        dataset_cls = RandomDatasetForReranking
        # prefer random_* arguments, fall back to regular arguments
        random_input_len = getattr(args, "random_input_len", None)
        sample_kwargs["input_len"] = (
            random_input_len
            if random_input_len is not None
            else getattr(args, "input_len", None)
        )
        random_output_len = getattr(args, "random_output_len", None)
        sample_kwargs["output_len"] = (
            random_output_len
            if random_output_len is not None
            else getattr(args, "output_len", None)
        )
        sample_kwargs["batchsize"] = getattr(args, "random_batch_size", 1)
        sample_kwargs["is_reranker"] = not getattr(args, "no_reranker", False)
        sample_kwargs["range_ratio"] = args.random_range_ratio
    else:
        raise ValueError(f"Unknown dataset name: {args.dataset_name}")
    # Remove None values
    sample_kwargs = {k: v for k, v in sample_kwargs.items() if v is not None}
    requests = dataset_cls(**common_kwargs).sample(**sample_kwargs)
    requests = filter_requests_for_dp(requests, args.data_parallel_size)
    return requests


def filter_requests_for_dp(requests, data_parallel_size):
    """Return the round-robin share of ``requests`` owned by this DP rank.

    The data-parallel rank is derived from the ``RANK`` and ``WORLD_SIZE``
    environment variables as ``RANK // (WORLD_SIZE // data_parallel_size)``.
    Request ``i`` is kept when ``i % data_parallel_size`` equals that rank.

    Args:
        requests: Sequence of requests to shard.
        data_parallel_size: Number of data-parallel groups.

    Returns:
        ``requests`` unchanged when ``data_parallel_size == 1``; otherwise a
        new list with this rank's share.

    Raises:
        KeyError: If ``RANK`` or ``WORLD_SIZE`` is not set (DP size > 1).
        ValueError: If ``WORLD_SIZE`` is not a positive multiple of
            ``data_parallel_size``.
    """
    # Note(zhuohan): The way we get data_parallel_rank is hacky and only
    # works for external launcher mode. Should be cleaned up and deprecated
    # in the future with a better vLLM distributed process design.
    if data_parallel_size == 1:
        return requests

    global_rank = int(os.environ["RANK"])
    world_size = int(os.environ["WORLD_SIZE"])
    # Without this check, WORLD_SIZE < DP size divides by zero, and a
    # non-multiple can map some ranks to a DP rank >= data_parallel_size,
    # silently giving them no requests.
    if world_size < data_parallel_size or world_size % data_parallel_size:
        raise ValueError(
            f"WORLD_SIZE ({world_size}) must be a positive multiple of "
            f"data_parallel_size ({data_parallel_size})."
        )
    ranks_per_dp_group = world_size // data_parallel_size
    data_parallel_rank = global_rank // ranks_per_dp_group
    return [
        r
        for i, r in enumerate(requests)
        if i % data_parallel_size == data_parallel_rank
    ]
504
505
506
507
508
509
510
511
512
513
514
515


def validate_args(args):
    """
    Validate command-line arguments.

    Also normalizes ``args`` in place:

    - the deprecated ``--dataset`` is copied into ``dataset_path``;
    - ``tokenizer`` defaults to ``model``;
    - a dataset that needs a path but has none falls back to ``"random"``.

    Conflicting or ignored options emit ``warnings.warn``.

    Raises:
        ValueError: On unsupported backends, missing required lengths for the
            random fallback, unsupported HF datasets, and invalid backend /
            LoRA / quantization / data-parallel combinations.
        AssertionError: If an HF dataset is paired with the wrong backend.
    """

    # === Deprecation and Defaulting ===
    if args.dataset is not None:
        warnings.warn(
            "The '--dataset' argument will be deprecated in the next release. "
            "Please use '--dataset-name' and '--dataset-path' instead.",
            stacklevel=2,
        )
        args.dataset_path = args.dataset

    if not getattr(args, "tokenizer", None):
        args.tokenizer = args.model

    # === Backend Validation ===
    valid_backends = {"vllm", "hf", "mii", "vllm-chat"}
    if args.backend not in valid_backends:
        raise ValueError(f"Unsupported backend: {args.backend}")

    # === Dataset Configuration ===
    # These datasets synthesize their prompts and need no --dataset-path.
    # They must not be rewritten to "random" here, or get_requests would
    # never reach their branches.
    path_free_datasets = {"prefix_repetition", "random-mm", "random-rerank"}
    if (
        not args.dataset
        and not args.dataset_path
        and args.dataset_name not in path_free_datasets
    ):
        print("When dataset path is not set, it will default to random dataset")
        args.dataset_name = "random"
        random_input_len = getattr(args, "random_input_len", None)
        if args.input_len is None and random_input_len is None:
            raise ValueError(
                "Either --input-len or --random-input-len must be provided "
                "for a random dataset"
            )

    # === Dataset Name Specific Checks ===
    # --hf-subset and --hf-split: only used
    # when dataset_name is 'hf'
    if args.dataset_name != "hf" and (
        getattr(args, "hf_subset", None) is not None
        or getattr(args, "hf_split", None) is not None
    ):
        warnings.warn(
            "--hf-subset and --hf-split will be ignored "
            "since --dataset-name is not 'hf'.",
            stacklevel=2,
        )
    elif args.dataset_name == "hf":
        if args.dataset_path in (
            VisionArenaDataset.SUPPORTED_DATASET_PATHS.keys()
            | MultiModalConversationDataset.SUPPORTED_DATASET_PATHS
            | ConversationDataset.SUPPORTED_DATASET_PATHS
        ):
            assert args.backend == "vllm-chat", (
                f"{args.dataset_path} needs to use vllm-chat as the backend."
            )
        elif args.dataset_path in (
            InstructCoderDataset.SUPPORTED_DATASET_PATHS
            | AIMODataset.SUPPORTED_DATASET_PATHS
            | ASRDataset.SUPPORTED_DATASET_PATHS
        ):
            assert args.backend == "vllm", (
                f"{args.dataset_path} needs to use vllm as the backend."
            )
        else:
            raise ValueError(f"{args.dataset_path} is not supported by hf dataset.")

    # --random-range-ratio: only used when dataset_name is 'random',
    # 'random-mm', or 'random-rerank'
    if (
        args.dataset_name not in {"random", "random-mm", "random-rerank"}
        and args.random_range_ratio is not None
    ):
        warnings.warn(
            "--random-range-ratio will be ignored since "
            "--dataset-name is not 'random', 'random-mm', or 'random-rerank'.",
            stacklevel=2,
        )

    # --random-batch-size: only used when dataset_name is 'random-rerank'
    if (
        args.dataset_name != "random-rerank"
        and getattr(args, "random_batch_size", None) is not None
    ) and args.random_batch_size != 1:
        warnings.warn(
            "--random-batch-size will be ignored since "
            "--dataset-name is not 'random-rerank'.",
            stacklevel=2,
        )

    # --no-reranker: only used when dataset_name is 'random-rerank'
    if args.dataset_name != "random-rerank" and getattr(args, "no_reranker", False):
        warnings.warn(
            "--no-reranker will be ignored since "
            "--dataset-name is not 'random-rerank'.",
            stacklevel=2,
        )

    # --prefix-len: only used when dataset_name is 'random', 'random-mm',
    # 'sonnet', or not set.
    if (
        args.dataset_name not in {"random", "random-mm", "sonnet", None}
        and args.prefix_len is not None
    ):
        warnings.warn(
            "--prefix-len will be ignored since --dataset-name "
            "is not 'random', 'random-mm', 'sonnet', or not set.",
            stacklevel=2,
        )

    # === Random Dataset Argument Conflict Detection ===
    # Check for conflicts between regular and random arguments when using
    # random datasets
    if args.dataset_name in {"random", "random-mm", "random-rerank"}:
        random_input_len = getattr(args, "random_input_len", None)
        random_output_len = getattr(args, "random_output_len", None)
        random_prefix_len = getattr(args, "random_prefix_len", None)

        if args.input_len is not None and random_input_len is not None:
            warnings.warn(
                "Both --input-len and --random-input-len are specified. "
                "The random version (--random-input-len) will be preferred "
                "in this run.",
                stacklevel=2,
            )
        if args.output_len is not None and random_output_len is not None:
            warnings.warn(
                "Both --output-len and --random-output-len are specified. "
                "The random version (--random-output-len) will be preferred "
                "in this run.",
                stacklevel=2,
            )
        if args.prefix_len is not None and random_prefix_len is not None:
            warnings.warn(
                "Both --prefix-len and --random-prefix-len are specified. "
                "The random version (--random-prefix-len) will be preferred "
                "in this run.",
                stacklevel=2,
            )

    # === LoRA Settings ===
    if getattr(args, "enable_lora", False) and args.backend != "vllm":
        raise ValueError("LoRA benchmarking is only supported for vLLM backend")
    if getattr(args, "enable_lora", False) and args.lora_path is None:
        raise ValueError("LoRA path must be provided when enable_lora is True")

    # === Backend-specific Validations ===
    if args.backend == "hf" and args.hf_max_batch_size is None:
        raise ValueError("HF max batch size is required for HF backend")
    if args.backend != "hf" and args.hf_max_batch_size is not None:
        raise ValueError("HF max batch size is only for HF backend.")

    if (
        args.backend in {"hf", "mii"}
        and getattr(args, "quantization", None) is not None
    ):
        raise ValueError("Quantization is only for vLLM backend.")

    if args.backend == "mii" and args.dtype != "auto":
        raise ValueError("dtype must be auto for MII backend.")
    if args.backend == "mii" and args.n != 1:
        raise ValueError("n must be 1 for MII backend.")
    if args.backend == "mii" and args.tokenizer != args.model:
        raise ValueError("Tokenizer must be the same as the model for MII backend.")

    if args.data_parallel_size > 1 and (
        args.distributed_executor_backend != "external_launcher" or args.async_engine
    ):
        # --data-parallel is not supported fully.
        # Old issue: https://github.com/vllm-project/vllm/issues/16222
        # Currently we only support data parallel with external launcher
        # mode (i.e., launch with torchrun).
        raise ValueError(
            "Data parallel is only supported with external launcher mode "
            "with synchronous engine in offline benchmark, "
            "please use benchmark serving instead"
        )
685
686
687


def add_cli_args(parser: argparse.ArgumentParser):
    """Register the offline throughput benchmark's CLI options on ``parser``.

    Adds options for:

    - backend and dataset selection;
    - request shape (input/output length, ``--n``, number of prompts);
    - the HF backend;
    - JSON result output;
    - the async engine and detokenization toggles;
    - LoRA and profiling;
    - dataset-specific knobs (HF subset/split, prefix repetition,
      random / random-mm / random-rerank, ASR).

    Finally, the vLLM engine options are appended via
    ``AsyncEngineArgs.add_cli_args``.

    Args:
        parser: Parser to extend in place.
    """
    parser.add_argument(
        "--backend",
        type=str,
        choices=["vllm", "hf", "mii", "vllm-chat"],
        default="vllm",
        help="Inference backend to benchmark.",
    )
    parser.add_argument(
        "--dataset-name",
        type=str,
        choices=[
            "sharegpt",
            "random",
            "sonnet",
            "burstgpt",
            "hf",
            "prefix_repetition",
            "random-mm",
            "random-rerank",
        ],
        help="Name of the dataset to benchmark on.",
        default="sharegpt",
    )
    parser.add_argument(
        "--dataset",
        type=str,
        default=None,
        # Implicit string concatenation rather than a backslash continuation:
        # the continuation embedded the next source line's indentation in
        # the help text.
        help="Path to the ShareGPT dataset, will be deprecated in "
        "the next release. The dataset is expected to "
        "be a json in form of list[dict[..., conversations: "
        "list[dict[..., value: <prompt_or_response>]]]]",
    )
    parser.add_argument(
        "--dataset-path", type=str, default=None, help="Path to the dataset"
    )
    parser.add_argument(
        "--input-len",
        type=int,
        default=None,
        help="Input prompt length for each request",
    )
    parser.add_argument(
        "--output-len",
        type=int,
        default=None,
        help="Output length for each request. Overrides the "
        "output length from the dataset.",
    )
    parser.add_argument(
        "--n", type=int, default=1, help="Number of generated sequences per prompt."
    )
    parser.add_argument(
        "--num-prompts", type=int, default=1000, help="Number of prompts to process."
    )
    parser.add_argument(
        "--hf-max-batch-size",
        type=int,
        default=None,
        help="Maximum batch size for HF backend.",
    )
    parser.add_argument(
        "--hf-enable-torch-compile",
        action="store_true",
        default=False,
        help="Enable Torch compile for HF backend.",
    )
    parser.add_argument(
        "--output-json",
        type=str,
        default=None,
        help="Path to save the throughput results in JSON format.",
    )
    parser.add_argument(
        "--async-engine",
        action="store_true",
        default=False,
        help="Use vLLM async engine rather than LLM class.",
    )
    parser.add_argument(
        "--disable-detokenize",
        action="store_true",
        help=(
            "Do not detokenize the response (i.e. do not include "
            "detokenization time in the measurement)"
        ),
    )

    # LoRA
    parser.add_argument(
        "--lora-path",
        type=str,
        default=None,
        help="Path to the lora adapters to use. This can be an absolute path, "
        "a relative path, or a Hugging Face model identifier.",
    )
    parser.add_argument(
        "--lora-assignment",
        type=str,
        default="random",
        choices=["random", "round-robin"],
        help="Strategy for assigning LoRA adapters to requests. "
        "'random' (default) selects a LoRA at random for each request. "
        "'round-robin' cycles through LoRAs deterministically.",
    )
    parser.add_argument(
        "--prefix-len",
        type=int,
        default=0,
        help="Number of fixed prefix tokens before the random "
        "context in a request (default: 0).",
    )

    # HF dataset
    parser.add_argument(
        "--hf-subset",
        type=str,
        default=None,
        help="Subset of the HF dataset.",
    )
    parser.add_argument(
        "--hf-split",
        type=str,
        default=None,
        help="Split of the HF dataset.",
    )
    parser.add_argument(
        "--profile",
        action="store_true",
        default=False,
        help="Use vLLM Profiling. --profiler-config must be provided on the server.",
    )

    # Prefix repetition dataset
    parser.add_argument(
        "--prefix-repetition-prefix-len",
        type=int,
        default=None,
        help="Number of prefix tokens per request, used only for prefix "
        "repetition dataset.",
    )
    parser.add_argument(
        "--prefix-repetition-suffix-len",
        type=int,
        default=None,
        help="Number of suffix tokens per request, used only for prefix "
        "repetition dataset. Total input length is prefix_len + suffix_len.",
    )
    parser.add_argument(
        "--prefix-repetition-num-prefixes",
        type=int,
        default=None,
        help="Number of prefixes to generate, used only for prefix repetition "
        "dataset. Prompts per prefix is num_requests // num_prefixes.",
    )
    parser.add_argument(
        "--prefix-repetition-output-len",
        type=int,
        default=None,
        help="Number of output tokens per request, used only for prefix "
        "repetition dataset.",
    )

    # Random-family datasets (random, random-mm, random-rerank)
    add_random_dataset_base_args(parser)
    add_random_multimodal_dataset_args(parser)

    # ASR dataset
    parser.add_argument(
        "--asr-min-audio-len-sec",
        type=float,
        default=0.0,
        help="Minimum audio duration in seconds for ASR dataset filtering.",
    )
    parser.add_argument(
        "--asr-max-audio-len-sec",
        type=float,
        default=float("inf"),
        help="Maximum audio duration in seconds for ASR dataset filtering.",
    )

    # Append the vLLM engine options (AsyncEngineArgs); the parser is
    # extended in place, so the return value is not needed here.
    AsyncEngineArgs.add_cli_args(parser)


def main(args: argparse.Namespace):
    """Run the offline throughput benchmark described by ``args``.

    The steps are:

    1. Validate the arguments and seed ``random`` (default seed 0).
    2. Build the tokenizer and sample the benchmark requests.
    3. Run the requests on the selected backend.
    4. Print request and token throughput.
    5. If ``--output-json`` is set, write a results dict to that file and
       pass it to ``save_to_pytorch_benchmark_format``.

    Token counts come from the returned ``RequestOutput`` objects when the
    backend provides them. Otherwise they fall back to each request's
    ``prompt_len`` / ``expected_output_len``.

    Args:
        args: Parsed CLI namespace (see ``add_cli_args``).

    Raises:
        ValueError: Raised when ``validate_args`` rejects the arguments, when
            the HF backend is combined with ``tensor_parallel_size != 1``, or
            when the backend has no runner below.
        NotImplementedError: If profiling is requested with the HF backend.
    """
    validate_args(args)
    if args.seed is None:
        args.seed = 0
    random.seed(args.seed)

    # Sample the requests.
    if args.backend in ("hf", "mii") and args.tokenizer_mode == "auto":
        # mistral_common tokenizer is only supported on vllm and vllm-chat
        # backends; for hf and mii backends, we use hf tokenizer
        args.tokenizer_mode = "hf"
    tokenizer = get_tokenizer(
        args.tokenizer,
        tokenizer_mode=args.tokenizer_mode,
        trust_remote_code=args.trust_remote_code,
    )
    requests = get_requests(args, tokenizer)
    is_multi_modal = any(request.multi_modal_data is not None for request in requests)
    # Stays None for backends that only report elapsed time.
    request_outputs: list[RequestOutput] | None = None

    if args.backend == "vllm":
        if args.async_engine:
            elapsed_time = uvloop.run(
                run_vllm_async(
                    requests,
                    args.n,
                    AsyncEngineArgs.from_cli_args(args),
                    disable_detokenize=args.disable_detokenize,
                    do_profile=args.profile,
                )
            )
        else:
            elapsed_time, request_outputs = run_vllm(
                requests,
                args.n,
                EngineArgs.from_cli_args(args),
                disable_detokenize=args.disable_detokenize,
                do_profile=args.profile,
            )
    elif args.backend == "hf":
        # Explicit check instead of ``assert`` so it is not stripped by -O.
        if args.tensor_parallel_size != 1:
            raise ValueError("HF backend only supports tensor_parallel_size=1.")
        if args.profile:
            raise NotImplementedError("Profiling not implemented yet for backend='hf'.")
        elapsed_time = run_hf(
            requests,
            args.model,
            tokenizer,
            args.n,
            args.hf_max_batch_size,
            args.trust_remote_code,
            args.disable_detokenize,
            dtype=args.dtype,
            enable_torch_compile=args.hf_enable_torch_compile,
        )
    elif args.backend == "vllm-chat":
        elapsed_time, request_outputs = run_vllm_chat(
            requests,
            args.n,
            EngineArgs.from_cli_args(args),
            disable_detokenize=args.disable_detokenize,
            do_profile=args.profile,
        )
    else:
        # NOTE(review): "mii" is accepted by --backend and validate_args but
        # has no branch here, so it ends up in this error. Confirm whether
        # MII support was intentionally dropped.
        raise ValueError(f"Unknown backend: {args.backend}")

    if request_outputs:
        # Note: with the vllm and vllm-chat backends,
        # we have request_outputs, which we use to count tokens.
        total_prompt_tokens = 0
        total_output_tokens = 0
        for ro in request_outputs:
            if not isinstance(ro, RequestOutput):
                continue
            total_prompt_tokens += (
                len(ro.prompt_token_ids) if ro.prompt_token_ids else 0
            )
            total_output_tokens += sum(len(o.token_ids) for o in ro.outputs if o)
        total_num_tokens = total_prompt_tokens + total_output_tokens
    else:
        total_num_tokens = sum(r.prompt_len + r.expected_output_len for r in requests)
        total_output_tokens = sum(r.expected_output_len for r in requests)
        total_prompt_tokens = total_num_tokens - total_output_tokens

    if is_multi_modal and args.backend != "vllm-chat":
        print(
            "\033[91mWARNING\033[0m: Multi-modal request with "
            f"{args.backend} backend detected. The "
            "following metrics are not accurate because image tokens are not"
            " counted. See vllm-project/vllm/issues/9778 for details."
        )
        # TODO(vllm-project/vllm/issues/9778): Count multi-modal token length.
        # vllm-chat backend counts the image tokens now

    # Compute the ratios once so the printout and the JSON agree exactly.
    num_requests = len(requests)
    requests_per_second = num_requests / elapsed_time
    tokens_per_second = total_num_tokens / elapsed_time
    output_tokens_per_second = total_output_tokens / elapsed_time

    print(
        f"Throughput: {requests_per_second:.2f} requests/s, "
        f"{tokens_per_second:.2f} total tokens/s, "
        f"{output_tokens_per_second:.2f} output tokens/s"
    )
    print(f"Total num prompt tokens:  {total_prompt_tokens}")
    print(f"Total num output tokens:  {total_output_tokens}")

    # Output JSON results if specified
    if args.output_json:
        results = {
            "elapsed_time": elapsed_time,
            "num_requests": num_requests,
            "total_num_tokens": total_num_tokens,
            "requests_per_second": requests_per_second,
            "tokens_per_second": tokens_per_second,
        }
        with open(args.output_json, "w") as f:
            json.dump(results, f, indent=4)
        save_to_pytorch_benchmark_format(args, results)