benchmark_throughput.py 25.6 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
"""Benchmark offline inference throughput."""
4

5
import argparse
6
import dataclasses
7
import json
8
import os
9
10
import random
import time
11
import warnings
12
from typing import Any, Optional, Union
13

14
import torch
15
import uvloop
16
from tqdm import tqdm
17
18
19
20
21
22
23
24
25
26
27
28
29
30
from transformers import AutoModelForCausalLM, AutoTokenizer, PreTrainedTokenizerBase

from benchmark_dataset import (
    AIMODataset,
    BurstGPTDataset,
    ConversationDataset,
    InstructCoderDataset,
    RandomDataset,
    SampleRequest,
    ShareGPTDataset,
    SonnetDataset,
    VisionArenaDataset,
)
from benchmark_utils import convert_to_pytorch_benchmark_format, write_to_json
31
from vllm.engine.arg_utils import AsyncEngineArgs, EngineArgs
32
from vllm.entrypoints.openai.api_server import (
33
34
    build_async_engine_client_from_engine_args,
)
35
from vllm.inputs import TextPrompt, TokensPrompt
36
from vllm.lora.request import LoRARequest
37
from vllm.outputs import RequestOutput
38
from vllm.sampling_params import BeamSearchParams
39
from vllm.utils import FlexibleArgumentParser, merge_async_iterators
40

41

Woosuk Kwon's avatar
Woosuk Kwon committed
42
def run_vllm(
    requests: list[SampleRequest],
    n: int,
    engine_args: EngineArgs,
    disable_detokenize: bool = False,
) -> tuple[float, Optional[list[RequestOutput]]]:
    """Benchmark offline generation with the synchronous ``LLM`` entrypoint.

    Args:
        requests: Sampled requests to run. ``request.prompt`` is either a text
            prompt or a dict carrying pre-tokenized ``prompt_token_ids``.
        n: Number of sequences generated per prompt (used as the beam width
            on the beam-search path).
        engine_args: Engine configuration used to construct the ``LLM``.
        disable_detokenize: If True, outputs are not detokenized, so that work
            is excluded from the measured time.

    Returns:
        ``(elapsed_seconds, outputs)`` where ``elapsed_seconds`` covers only the
        generation call. ``outputs`` is ``None`` on the beam-search path.
    """
    from vllm import LLM, SamplingParams

    llm = LLM(**dataclasses.asdict(engine_args))
    assert all(
        llm.llm_engine.model_config.max_model_len
        >= (request.prompt_len + request.expected_output_len)
        for request in requests
    ), (
        "Please ensure that max_model_len is greater than the sum of"
        " prompt_len and expected_output_len for all requests."
    )
    # Add the requests to the engine.
    prompts: list[Union[TextPrompt, TokensPrompt]] = []
    sampling_params: list[SamplingParams] = []
    for request in requests:
        raw_prompt = request.prompt
        # Only a dict prompt can carry pre-tokenized ids. A bare `in` test on a
        # str would be a substring search and the indexing below would fail.
        if isinstance(raw_prompt, dict) and "prompt_token_ids" in raw_prompt:
            prompts.append(
                TokensPrompt(
                    prompt_token_ids=raw_prompt["prompt_token_ids"],
                    multi_modal_data=request.multi_modal_data,
                )
            )
        else:
            prompts.append(
                TextPrompt(prompt=raw_prompt, multi_modal_data=request.multi_modal_data)
            )
        sampling_params.append(
            SamplingParams(
                n=n,
                temperature=1.0,
                top_p=1.0,
                ignore_eos=True,
                max_tokens=request.expected_output_len,
                detokenize=not disable_detokenize,
            )
        )
    lora_requests: Optional[list[LoRARequest]] = None
    if engine_args.enable_lora:
        lora_requests = [request.lora_request for request in requests]

    # The beam-search path is currently disabled; it is kept for reference.
    use_beam_search = False

    outputs = None
    if not use_beam_search:
        start = time.perf_counter()
        outputs = llm.generate(
            prompts, sampling_params, lora_request=lora_requests, use_tqdm=True
        )
        end = time.perf_counter()
    else:
        assert lora_requests is None, "BeamSearch API does not support LoRA"
        prompts = [request.prompt for request in requests]
        # output_len should be the same for all requests.
        output_len = requests[0].expected_output_len
        for request in requests:
            assert request.expected_output_len == output_len
        start = time.perf_counter()
        llm.beam_search(
            prompts,
            BeamSearchParams(
                beam_width=n,
                max_tokens=output_len,
                ignore_eos=True,
            ),
        )
        end = time.perf_counter()
    return end - start, outputs


def run_vllm_chat(
    requests: list[SampleRequest],
    n: int,
    engine_args: EngineArgs,
    disable_detokenize: bool = False,
) -> tuple[float, list[RequestOutput]]:
    """
    Time ``LLM.chat`` over all requests and return ``(seconds, outputs)``.

    Intended ONLY for benchmarking multimodal models, since the chat entrypoint
    handles multimodal inputs and chat formatting. For text-only models use
    run_vllm() instead. Each request gets ``n`` sampled sequences capped at its
    ``expected_output_len``; detokenization is skipped when
    ``disable_detokenize`` is True.
    """
    from vllm import LLM, SamplingParams

    llm = LLM(**dataclasses.asdict(engine_args))

    assert all(
        request.prompt_len + request.expected_output_len
        <= llm.llm_engine.model_config.max_model_len
        for request in requests
    ), (
        "Please ensure that max_model_len is greater than the sum of "
        "prompt_len and expected_output_len for all requests."
    )

    prompts = [request.prompt for request in requests]
    sampling_params: list[SamplingParams] = [
        SamplingParams(
            n=n,
            temperature=1.0,
            top_p=1.0,
            ignore_eos=True,
            max_tokens=request.expected_output_len,
            detokenize=not disable_detokenize,
        )
        for request in requests
    ]

    # Only the chat call itself is timed.
    started_at = time.perf_counter()
    chat_outputs = llm.chat(prompts, sampling_params, use_tqdm=True)
    elapsed = time.perf_counter() - started_at
    return elapsed, chat_outputs
158
159


160
async def run_vllm_async(
    requests: list[SampleRequest],
    n: int,
    engine_args: AsyncEngineArgs,
    disable_frontend_multiprocessing: bool = False,
    disable_detokenize: bool = False,
) -> float:
    """Benchmark generation through the async engine client.

    Every request is submitted to ``llm.generate`` up front. The resulting
    async iterators are merged with ``merge_async_iterators`` and drained.
    Outputs are discarded; only the elapsed time is reported.

    Args:
        requests: Sampled requests. ``request.prompt`` is either a text prompt
            or a dict carrying pre-tokenized ``prompt_token_ids``.
        n: Number of sequences generated per prompt.
        engine_args: Configuration for the async engine client.
        disable_frontend_multiprocessing: Forwarded to
            ``build_async_engine_client_from_engine_args``.
        disable_detokenize: If True, outputs are not detokenized.

    Returns:
        Elapsed wall-clock seconds from first submission until all streams
        are exhausted.
    """
    from vllm import SamplingParams

    async with build_async_engine_client_from_engine_args(
        engine_args,
        disable_frontend_multiprocessing=disable_frontend_multiprocessing,
    ) as llm:
        model_config = await llm.get_model_config()
        assert all(
            model_config.max_model_len
            >= (request.prompt_len + request.expected_output_len)
            for request in requests
        ), (
            "Please ensure that max_model_len is greater than the sum of"
            " prompt_len and expected_output_len for all requests."
        )

        # Add the requests to the engine.
        prompts: list[Union[TextPrompt, TokensPrompt]] = []
        sampling_params: list[SamplingParams] = []
        lora_requests: list[Optional[LoRARequest]] = []
        for request in requests:
            raw_prompt = request.prompt
            # Only a dict prompt can carry pre-tokenized ids. A bare `in` test
            # on a str would be a substring search and the indexing would fail.
            if isinstance(raw_prompt, dict) and "prompt_token_ids" in raw_prompt:
                prompts.append(
                    TokensPrompt(
                        prompt_token_ids=raw_prompt["prompt_token_ids"],
                        multi_modal_data=request.multi_modal_data,
                    )
                )
            else:
                prompts.append(
                    TextPrompt(
                        prompt=raw_prompt, multi_modal_data=request.multi_modal_data
                    )
                )
            sampling_params.append(
                SamplingParams(
                    n=n,
                    temperature=1.0,
                    top_p=1.0,
                    ignore_eos=True,
                    max_tokens=request.expected_output_len,
                    detokenize=not disable_detokenize,
                )
            )
            lora_requests.append(request.lora_request)

        generators = []
        start = time.perf_counter()
        for i, (prompt, sp, lr) in enumerate(
            zip(prompts, sampling_params, lora_requests)
        ):
            generator = llm.generate(prompt, sp, lora_request=lr, request_id=f"test{i}")
            generators.append(generator)
        all_gens = merge_async_iterators(*generators)
        # Drain every stream; the results themselves are not needed.
        async for _ in all_gens:
            pass
        end = time.perf_counter()
        return end - start


224
def run_hf(
    requests: list[SampleRequest],
    model: str,
    tokenizer: PreTrainedTokenizerBase,
    n: int,
    max_batch_size: int,
    trust_remote_code: bool,
    disable_detokenize: bool = False,
    max_batch_tokens: int = 2048,
) -> float:
    """Benchmark generation with a Hugging Face ``AutoModelForCausalLM``.

    The model is loaded in float16 and moved to CUDA. Requests are packed
    greedily into batches in their given order. A batch is flushed when it
    reaches ``max_batch_size`` prompts, at the last request, or when adding the
    next request would push ``max(prompt_len) + max(output_len)`` of the batch
    above ``max_batch_tokens``.

    Args:
        requests: Sampled requests; ``prompt``, ``prompt_len`` and
            ``expected_output_len`` are used.
        model: Model name or path passed to ``from_pretrained``.
        tokenizer: Tokenizer used to encode prompts and decode outputs. Its
            ``pad_token`` may be set to ``eos_token`` as a side effect.
        n: Number of sampled sequences per prompt (``num_return_sequences``).
        max_batch_size: Maximum number of prompts per batch.
        trust_remote_code: Forwarded to ``from_pretrained``.
        disable_detokenize: If True, outputs are not decoded (not timed).
        max_batch_tokens: Token budget used when packing batches
            (default 2048).

    Returns:
        Elapsed wall-clock seconds spent encoding, generating and (unless
        disabled) decoding all batches.
    """
    llm = AutoModelForCausalLM.from_pretrained(
        model, torch_dtype=torch.float16, trust_remote_code=trust_remote_code
    )
    if llm.config.model_type == "llama" or tokenizer.pad_token is None:
        # Batched encoding uses padding=True, which requires a pad token.
        # Llama models always reuse EOS; other models only when none is set.
        tokenizer.pad_token = tokenizer.eos_token
    llm = llm.cuda()

    num_requests = len(requests)
    # The context manager guarantees the progress bar is closed.
    with tqdm(total=num_requests) as pbar:
        start = time.perf_counter()
        batch: list[str] = []
        max_prompt_len = 0
        max_output_len = 0
        for i, request in enumerate(requests):
            # Add the prompt to the batch.
            batch.append(request.prompt)
            max_prompt_len = max(max_prompt_len, request.prompt_len)
            max_output_len = max(max_output_len, request.expected_output_len)
            if len(batch) < max_batch_size and i != num_requests - 1:
                # Keep growing the batch while the next request still fits.
                next_request = requests[i + 1]
                if (
                    max(max_prompt_len, next_request.prompt_len)
                    + max(max_output_len, next_request.expected_output_len)
                ) <= max_batch_tokens:
                    continue

            # Generate the sequences.
            input_ids = tokenizer(batch, return_tensors="pt", padding=True).input_ids
            llm_outputs = llm.generate(
                input_ids=input_ids.cuda(),
                do_sample=True,
                num_return_sequences=n,
                temperature=1.0,
                top_p=1.0,
                use_cache=True,
                max_new_tokens=max_output_len,
            )
            if not disable_detokenize:
                # Include the decoding time.
                tokenizer.batch_decode(llm_outputs, skip_special_tokens=True)
            pbar.update(len(batch))

            # Clear the batch.
            batch = []
            max_prompt_len = 0
            max_output_len = 0
        end = time.perf_counter()
    return end - start


289
def run_mii(
    requests: list[SampleRequest],
    model: str,
    tensor_parallel_size: int,
    output_len: int,
) -> float:
    """Benchmark generation with DeepSpeed-MII.

    Starts a deployment with ``mii.serve``, times one ``generate`` call over all
    prompts, then terminates the server through ``mii.client(model)``.

    Args:
        requests: Sampled requests; only ``prompt`` is used.
        model: Model name or path to serve.
        tensor_parallel_size: Forwarded as ``tensor_parallel`` to ``serve``.
        output_len: ``max_new_tokens`` applied to every prompt.

    Returns:
        Elapsed wall-clock seconds for the ``generate`` call.
    """
    from mii import client, serve

    llm = serve(model, tensor_parallel=tensor_parallel_size)
    prompts = [request.prompt for request in requests]
    try:
        start = time.perf_counter()
        llm.generate(prompts, max_new_tokens=output_len)
        end = time.perf_counter()
    finally:
        # Terminate the server even if generation fails, so it is not leaked.
        # Use a distinct local name instead of rebinding the imported `client`.
        mii_client = client(model)
        mii_client.terminate_server()
    return end - start


308
309
310
def save_to_pytorch_benchmark_format(
    args: argparse.Namespace, results: dict[str, Any]
) -> None:
    """Write ``results`` in PyTorch benchmark format alongside ``args.output_json``.

    The target file is ``args.output_json`` with its extension replaced by
    ``.pytorch.json``. Nothing is written when the conversion yields no records.
    """
    metric_keys = ("requests_per_second", "tokens_per_second")
    info_keys = ("elapsed_time", "num_requests", "total_num_tokens")
    records = convert_to_pytorch_benchmark_format(
        args=args,
        metrics={key: [results[key]] for key in metric_keys},
        extra_info={key: results[key] for key in info_keys},
    )
    if not records:
        return
    # Don't use json suffix here as we don't want CI to pick it up
    stem, _extension = os.path.splitext(args.output_json)
    write_to_json(f"{stem}.pytorch.json", records)
325
326


327
328
329
330
331
332
333
334
335
336
337
338
339
340
def get_requests(args, tokenizer):
    """Construct the dataset selected by ``args`` and sample requests from it.

    Args:
        args: Parsed CLI namespace. Always reads ``dataset_path``,
            ``dataset_name``, ``seed``, ``lora_path``, ``max_loras``,
            ``num_prompts``, ``input_len`` and ``output_len``. Depending on the
            dataset it also reads ``backend``, ``prefix_len``,
            ``random_range_ratio``, ``no_stream``, ``hf_subset`` and
            ``hf_split``.
        tokenizer: Tokenizer forwarded to the dataset's ``sample`` call.

    Returns:
        The result of ``<dataset class>(...).sample(...)``. ``None``-valued
        sampling options are dropped before the call.

    Raises:
        AssertionError: For the sonnet dataset when the tokenizer has no chat
            template.
        ValueError: For an unknown dataset name, or an hf dataset path that
            no supported dataset class handles.
    """
    # Common parameters for all dataset types.
    common_kwargs = {
        "dataset_path": args.dataset_path,
        "random_seed": args.seed,
    }
    sample_kwargs = {
        "tokenizer": tokenizer,
        "lora_path": args.lora_path,
        "max_loras": args.max_loras,
        "num_requests": args.num_prompts,
        "input_len": args.input_len,
        "output_len": args.output_len,
    }

    if args.dataset_path is None or args.dataset_name == "random":
        sample_kwargs["range_ratio"] = args.random_range_ratio
        sample_kwargs["prefix_len"] = args.prefix_len
        dataset_cls = RandomDataset
    elif args.dataset_name == "sharegpt":
        dataset_cls = ShareGPTDataset
        if args.backend == "vllm-chat":
            sample_kwargs["enable_multimodal_chat"] = True
    elif args.dataset_name == "sonnet":
        # `default_chat_template` is not guaranteed to exist on every tokenizer;
        # look it up defensively so a missing template trips the assertion
        # instead of an AttributeError.
        assert getattr(tokenizer, "chat_template", None) or getattr(
            tokenizer, "default_chat_template", None
        ), "Tokenizer/model must have chat template for sonnet dataset."
        dataset_cls = SonnetDataset
        sample_kwargs["prefix_len"] = args.prefix_len
        sample_kwargs["return_prompt_formatted"] = True
    elif args.dataset_name == "burstgpt":
        dataset_cls = BurstGPTDataset
    elif args.dataset_name == "hf":
        common_kwargs["no_stream"] = args.no_stream
        if args.dataset_path in VisionArenaDataset.SUPPORTED_DATASET_PATHS:
            dataset_cls = VisionArenaDataset
            common_kwargs["dataset_subset"] = None
            common_kwargs["dataset_split"] = "train"
            sample_kwargs["enable_multimodal_chat"] = True
        elif args.dataset_path in InstructCoderDataset.SUPPORTED_DATASET_PATHS:
            dataset_cls = InstructCoderDataset
            common_kwargs["dataset_split"] = "train"
        elif args.dataset_path in ConversationDataset.SUPPORTED_DATASET_PATHS:
            dataset_cls = ConversationDataset
            common_kwargs["dataset_subset"] = args.hf_subset
            common_kwargs["dataset_split"] = args.hf_split
            sample_kwargs["enable_multimodal_chat"] = True
        elif args.dataset_path in AIMODataset.SUPPORTED_DATASET_PATHS:
            dataset_cls = AIMODataset
            common_kwargs["dataset_subset"] = None
            common_kwargs["dataset_split"] = "train"
        else:
            # Without this branch `dataset_cls` would be unbound below.
            raise ValueError(f"{args.dataset_path} is not supported by hf dataset.")
    else:
        raise ValueError(f"Unknown dataset name: {args.dataset_name}")
    # Remove None values
    sample_kwargs = {k: v for k, v in sample_kwargs.items() if v is not None}
    return dataset_cls(**common_kwargs).sample(**sample_kwargs)


385
def main(args: argparse.Namespace):
    """Run the throughput benchmark described by ``args`` and report results.

    Samples requests, runs them on the selected backend, and prints request
    and token throughput. When ``args.output_json`` is set, it also writes
    the results to that file plus a PyTorch-benchmark-format companion.
    """
    if args.seed is None:
        args.seed = 0
    print(args)
    random.seed(args.seed)
    # Sample the requests.
    tokenizer = AutoTokenizer.from_pretrained(
        args.tokenizer, trust_remote_code=args.trust_remote_code
    )
    requests = get_requests(args, tokenizer)
    is_multi_modal = any(request.multi_modal_data is not None for request in requests)
    request_outputs: Optional[list[RequestOutput]] = None
    if args.backend == "vllm":
        if args.async_engine:
            elapsed_time = uvloop.run(
                run_vllm_async(
                    requests,
                    args.n,
                    AsyncEngineArgs.from_cli_args(args),
                    args.disable_frontend_multiprocessing,
                    args.disable_detokenize,
                )
            )
        else:
            elapsed_time, request_outputs = run_vllm(
                requests,
                args.n,
                EngineArgs.from_cli_args(args),
                args.disable_detokenize,
            )
    elif args.backend == "hf":
        assert args.tensor_parallel_size == 1
        elapsed_time = run_hf(
            requests,
            args.model,
            tokenizer,
            args.n,
            args.hf_max_batch_size,
            args.trust_remote_code,
            args.disable_detokenize,
        )
    elif args.backend == "mii":
        elapsed_time = run_mii(
            requests, args.model, args.tensor_parallel_size, args.output_len
        )
    elif args.backend == "vllm-chat":
        elapsed_time, request_outputs = run_vllm_chat(
            requests, args.n, EngineArgs.from_cli_args(args), args.disable_detokenize
        )
    else:
        raise ValueError(f"Unknown backend: {args.backend}")

    if request_outputs:
        # Note: with the vllm and vllm-chat backends,
        # we have request_outputs, which we use to count tokens.
        total_prompt_tokens = 0
        total_output_tokens = 0
        for ro in request_outputs:
            if not isinstance(ro, RequestOutput):
                continue
            total_prompt_tokens += (
                len(ro.prompt_token_ids) if ro.prompt_token_ids else 0
            )
            total_output_tokens += sum(len(o.token_ids) for o in ro.outputs if o)
        total_num_tokens = total_prompt_tokens + total_output_tokens
    else:
        # No outputs to inspect (async vLLM, HF, MII): estimate from the
        # sampled lengths. Every prompt produces `args.n` sequences, so its
        # expected output length is counted n times, matching the
        # request_outputs path above, which sums over all sequences.
        total_prompt_tokens = sum(r.prompt_len for r in requests)
        total_output_tokens = args.n * sum(r.expected_output_len for r in requests)
        total_num_tokens = total_prompt_tokens + total_output_tokens

    if is_multi_modal and args.backend != "vllm-chat":
        print(
            "\033[91mWARNING\033[0m: Multi-modal request with "
            f"{args.backend} backend detected. The "
            "following metrics are not accurate because image tokens are not"
            " counted. See vllm-project/vllm/issues/9778 for details."
        )
        # TODO(vllm-project/vllm/issues/9778): Count multi-modal token length.
        # vllm-chat backend counts the image tokens now

    print(
        f"Throughput: {len(requests) / elapsed_time:.2f} requests/s, "
        f"{total_num_tokens / elapsed_time:.2f} total tokens/s, "
        f"{total_output_tokens / elapsed_time:.2f} output tokens/s"
    )
    print(f"Total num prompt tokens:  {total_prompt_tokens}")
    print(f"Total num output tokens:  {total_output_tokens}")

    # Output JSON results if specified
    if args.output_json:
        results = {
            "elapsed_time": elapsed_time,
            "num_requests": len(requests),
            "total_num_tokens": total_num_tokens,
            "requests_per_second": len(requests) / elapsed_time,
            "tokens_per_second": total_num_tokens / elapsed_time,
        }
        with open(args.output_json, "w") as f:
            json.dump(results, f, indent=4)
        save_to_pytorch_benchmark_format(args, results)
485

486

487
488
489
490
491
492
493
494
495
496
def validate_args(args):
    """
    Validate command-line arguments.

    Mutates ``args`` in place:
    - ``dataset_path`` is taken from the deprecated ``--dataset``.
    - ``tokenizer`` defaults to ``model``.
    - ``dataset_name`` becomes ``"random"`` when no dataset path is given.

    Emits ``UserWarning`` for options that will be ignored.

    Raises:
        ValueError: For unsupported backends, hf dataset paths, or option
            combinations.
        AssertionError: When an hf dataset is paired with the wrong backend.
    """
    # === Deprecation and Defaulting ===
    if args.dataset is not None:
        warnings.warn(
            "The '--dataset' argument will be deprecated in the next release. "
            "Please use '--dataset-name' and '--dataset-path' instead.",
            stacklevel=2,
        )
        args.dataset_path = args.dataset

    if not getattr(args, "tokenizer", None):
        args.tokenizer = args.model

    # === Backend Validation ===
    valid_backends = {"vllm", "hf", "mii", "vllm-chat"}
    if args.backend not in valid_backends:
        raise ValueError(f"Unsupported backend: {args.backend}")

    # === Dataset Configuration ===
    if not args.dataset and not args.dataset_path:
        print("When dataset path is not set, it will default to random dataset")
        args.dataset_name = "random"
        if args.input_len is None:
            raise ValueError("input_len must be provided for a random dataset")

    # === Dataset Name Specific Checks ===
    # --hf-subset and --hf-split: only used when dataset_name is 'hf'.
    # Messages use implicit literal concatenation; a backslash continuation
    # inside a string literal would embed the next line's indentation.
    if args.dataset_name != "hf" and (
        getattr(args, "hf_subset", None) is not None
        or getattr(args, "hf_split", None) is not None
    ):
        warnings.warn(
            "--hf-subset and --hf-split will be ignored "
            "since --dataset-name is not 'hf'.",
            stacklevel=2,
        )
    elif args.dataset_name == "hf":
        if args.dataset_path in (
            VisionArenaDataset.SUPPORTED_DATASET_PATHS.keys()
            | ConversationDataset.SUPPORTED_DATASET_PATHS
        ):
            assert args.backend == "vllm-chat", (
                f"{args.dataset_path} needs to use vllm-chat as the backend."
            )
        elif args.dataset_path in (
            InstructCoderDataset.SUPPORTED_DATASET_PATHS
            | AIMODataset.SUPPORTED_DATASET_PATHS
        ):
            assert args.backend == "vllm", (
                f"{args.dataset_path} needs to use vllm as the backend."
            )
        else:
            raise ValueError(f"{args.dataset_path} is not supported by hf dataset.")

    # --random-range-ratio: only used when dataset_name is 'random'
    if args.dataset_name != "random" and args.random_range_ratio is not None:
        warnings.warn(
            "--random-range-ratio will be ignored since "
            "--dataset-name is not 'random'.",
            stacklevel=2,
        )

    # --prefix-len: only used when dataset_name is 'random', 'sonnet', or not
    # set.
    if (
        args.dataset_name not in {"random", "sonnet", None}
        and args.prefix_len is not None
    ):
        warnings.warn(
            "--prefix-len will be ignored since --dataset-name "
            "is not 'random', 'sonnet', or not set.",
            stacklevel=2,
        )

    # === LoRA Settings ===
    if getattr(args, "enable_lora", False) and args.backend != "vllm":
        raise ValueError("LoRA benchmarking is only supported for vLLM backend")
    if getattr(args, "enable_lora", False) and args.lora_path is None:
        raise ValueError("LoRA path must be provided when enable_lora is True")

    # === Backend-specific Validations ===
    if args.backend == "hf" and args.hf_max_batch_size is None:
        raise ValueError("HF max batch size is required for HF backend")
    if args.backend != "hf" and args.hf_max_batch_size is not None:
        raise ValueError("HF max batch size is only for HF backend.")

    if (
        args.backend in {"hf", "mii"}
        and getattr(args, "quantization", None) is not None
    ):
        raise ValueError("Quantization is only for vLLM backend.")

    if args.backend == "mii" and args.dtype != "auto":
        raise ValueError("dtype must be auto for MII backend.")
    if args.backend == "mii" and args.n != 1:
        raise ValueError("n must be 1 for MII backend.")
    if args.backend == "mii" and args.tokenizer != args.model:
        raise ValueError("Tokenizer must be the same as the model for MII backend.")

    # --data-parallel is not supported currently.
    # https://github.com/vllm-project/vllm/issues/16222
    if args.data_parallel_size > 1:
        raise ValueError(
            "Data parallel is not supported in offline benchmark, "
            "please use benchmark serving instead"
        )
598

599

600
def create_argument_parser():
    """Build the command-line parser for the throughput benchmark.

    Registers the benchmark-specific options, then appends the engine options
    added by ``AsyncEngineArgs.add_cli_args``.
    """
    parser = FlexibleArgumentParser(description="Benchmark the throughput.")
    parser.add_argument(
        "--backend",
        type=str,
        choices=["vllm", "hf", "mii", "vllm-chat"],
        default="vllm",
    )
    parser.add_argument(
        "--dataset-name",
        type=str,
        choices=["sharegpt", "random", "sonnet", "burstgpt", "hf"],
        help="Name of the dataset to benchmark on.",
        default="sharegpt",
    )
    parser.add_argument(
        "--no-stream",
        action="store_true",
        help="Do not load the dataset in streaming mode.",
    )
    # Implicit literal concatenation keeps the help text free of the
    # indentation a backslash continuation inside the literal would embed.
    parser.add_argument(
        "--dataset",
        type=str,
        default=None,
        help="Path to the ShareGPT dataset, will be deprecated in "
        "the next release. The dataset is expected to "
        "be a json in form of list[dict[..., conversations: "
        "list[dict[..., value: <prompt_or_response>]]]]",
    )
    parser.add_argument(
        "--dataset-path", type=str, default=None, help="Path to the dataset"
    )
    parser.add_argument(
        "--input-len",
        type=int,
        default=None,
        help="Input prompt length for each request",
    )
    parser.add_argument(
        "--output-len",
        type=int,
        default=None,
        help="Output length for each request. Overrides the "
        "output length from the dataset.",
    )
    parser.add_argument(
        "--n", type=int, default=1, help="Number of generated sequences per prompt."
    )
    parser.add_argument(
        "--num-prompts", type=int, default=1000, help="Number of prompts to process."
    )
    parser.add_argument(
        "--hf-max-batch-size",
        type=int,
        default=None,
        help="Maximum batch size for HF backend.",
    )
    parser.add_argument(
        "--output-json",
        type=str,
        default=None,
        help="Path to save the throughput results in JSON format.",
    )
    parser.add_argument(
        "--async-engine",
        action="store_true",
        default=False,
        help="Use vLLM async engine rather than LLM class.",
    )
    parser.add_argument(
        "--disable-frontend-multiprocessing",
        action="store_true",
        default=False,
        help="Disable decoupled async engine frontend.",
    )
    parser.add_argument(
        "--disable-detokenize",
        action="store_true",
        help=(
            "Do not detokenize the response (i.e. do not include "
            "detokenization time in the measurement)"
        ),
    )
    # LoRA
    parser.add_argument(
        "--lora-path",
        type=str,
        default=None,
        help="Path to the LoRA adapters to use. This can be an absolute path, "
        "a relative path, or a Hugging Face model identifier.",
    )
    parser.add_argument(
        "--prefix-len",
        type=int,
        default=None,
        help="Number of prefix tokens to be used in RandomDataset "
        "and SonnetDataset. For RandomDataset, the total input "
        "length is the sum of prefix-len (default: "
        f"{RandomDataset.DEFAULT_PREFIX_LEN}) and a random context length "
        "sampled from [input_len * (1 - range_ratio), "
        "input_len * (1 + range_ratio)]. For SonnetDataset, "
        f"prefix_len (default: {SonnetDataset.DEFAULT_PREFIX_LEN}) "
        "controls how much of the input is fixed lines versus "
        "random lines, but the total input length remains approximately "
        "input_len tokens.",
    )
    # random dataset
    parser.add_argument(
        "--random-range-ratio",
        type=float,
        default=None,
        help=f"Range ratio (default : {RandomDataset.DEFAULT_RANGE_RATIO}) "
        "for sampling input/output length, "
        "used only for RandomDataset. Must be in the range [0, 1) to "
        "define a symmetric sampling range "
        "[length * (1 - range_ratio), length * (1 + range_ratio)].",
    )

    # hf dataset
    parser.add_argument(
        "--hf-subset", type=str, default=None, help="Subset of the HF dataset."
    )
    parser.add_argument(
        "--hf-split", type=str, default=None, help="Split of the HF dataset."
    )

    parser = AsyncEngineArgs.add_cli_args(parser)

    return parser


if __name__ == "__main__":
    # Script entry point: parse the CLI, fill in defaults, validate, then run.
    parser = create_argument_parser()
    args = parser.parse_args()
    # Without an explicit --tokenizer, use the model's own tokenizer.
    args.tokenizer = args.model if args.tokenizer is None else args.tokenizer
    validate_args(args)
    main(args)