# benchmark_throughput.py (25.3 KB)
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
"""Benchmark offline inference throughput."""
4

5
import argparse
6
import dataclasses
7
import json
8
import os
9
10
import random
import time
11
import warnings
12
from typing import Any, Optional, Union
13

14
import torch
15
import uvloop
16
from tqdm import tqdm
17
18
19
20
21
22
23
24
25
26
27
28
29
30
from transformers import AutoModelForCausalLM, AutoTokenizer, PreTrainedTokenizerBase

from benchmark_dataset import (
    AIMODataset,
    BurstGPTDataset,
    ConversationDataset,
    InstructCoderDataset,
    RandomDataset,
    SampleRequest,
    ShareGPTDataset,
    SonnetDataset,
    VisionArenaDataset,
)
from benchmark_utils import convert_to_pytorch_benchmark_format, write_to_json
31
from vllm.engine.arg_utils import AsyncEngineArgs, EngineArgs
32
from vllm.entrypoints.openai.api_server import (
33
34
    build_async_engine_client_from_engine_args,
)
35
from vllm.inputs import TextPrompt, TokensPrompt
36
from vllm.lora.request import LoRARequest
37
from vllm.outputs import RequestOutput
38
from vllm.sampling_params import BeamSearchParams
39
from vllm.utils import FlexibleArgumentParser, merge_async_iterators
40

41

def run_vllm(
    requests: list[SampleRequest],
    n: int,
    engine_args: EngineArgs,
    disable_detokenize: bool = False,
) -> tuple[float, Optional[list[RequestOutput]]]:
    """Benchmark offline generation with the synchronous ``LLM`` class.

    Args:
        requests: Sampled benchmark requests. A request whose (non-string)
            ``prompt`` contains a ``prompt_token_ids`` key is submitted as
            pre-tokenized input; every other prompt is submitted as text.
        n: Number of sequences to generate per prompt (the beam width on the
            beam-search path).
        engine_args: Engine configuration used to construct ``LLM``.
        disable_detokenize: If True, outputs are not detokenized, so
            detokenization time is excluded from the measurement.

    Returns:
        ``(elapsed_seconds, outputs)``. ``outputs`` is None on the beam-search
        path, which is currently hard-disabled via ``use_beam_search``.

    Raises:
        AssertionError: if any request's prompt_len + expected_output_len
            exceeds the model's ``max_model_len``.
    """
    from vllm import LLM, SamplingParams

    llm = LLM(**dataclasses.asdict(engine_args))
    assert all(
        llm.llm_engine.model_config.max_model_len
        >= (request.prompt_len + request.expected_output_len)
        for request in requests
    ), (
        "Please ensure that max_model_len is greater than the sum of"
        " prompt_len and expected_output_len for all requests."
    )

    # Add the requests to the engine.
    prompts: list[Union[TextPrompt, TokensPrompt]] = []
    sampling_params: list[SamplingParams] = []
    for request in requests:
        # For a str prompt ``in`` is a substring test, so text that merely
        # mentions "prompt_token_ids" would be misrouted to the dict lookup
        # below and fail; only non-string prompts can carry token ids.
        is_tokenized = (
            not isinstance(request.prompt, str)
            and "prompt_token_ids" in request.prompt
        )
        if is_tokenized:
            prompts.append(
                TokensPrompt(
                    prompt_token_ids=request.prompt["prompt_token_ids"],
                    multi_modal_data=request.multi_modal_data,
                )
            )
        else:
            prompts.append(
                TextPrompt(
                    prompt=request.prompt, multi_modal_data=request.multi_modal_data
                )
            )
        sampling_params.append(
            SamplingParams(
                n=n,
                temperature=1.0,
                top_p=1.0,
                ignore_eos=True,
                max_tokens=request.expected_output_len,
                detokenize=not disable_detokenize,
            )
        )

    lora_requests: Optional[list[LoRARequest]] = None
    if engine_args.enable_lora:
        lora_requests = [request.lora_request for request in requests]

    # The beam-search path is kept for reference but is hard-disabled here.
    use_beam_search = False

    outputs = None
    if not use_beam_search:
        start = time.perf_counter()
        outputs = llm.generate(
            prompts, sampling_params, lora_request=lora_requests, use_tqdm=True
        )
        end = time.perf_counter()
    else:
        assert lora_requests is None, "BeamSearch API does not support LoRA"
        prompts = [request.prompt for request in requests]
        # output_len should be the same for all requests.
        output_len = requests[0].expected_output_len
        for request in requests:
            assert request.expected_output_len == output_len
        start = time.perf_counter()
        llm.beam_search(
            prompts,
            BeamSearchParams(
                beam_width=n,
                max_tokens=output_len,
                ignore_eos=True,
            ),
        )
        end = time.perf_counter()
    return end - start, outputs


def run_vllm_chat(
    requests: list[SampleRequest],
    n: int,
    engine_args: EngineArgs,
    disable_detokenize: bool = False,
) -> tuple[float, list[RequestOutput]]:
    """
    Run vLLM chat benchmark. This function is recommended ONLY for benchmarking
    multimodal models as it properly handles multimodal inputs and chat
    formatting. For non-multimodal models, use run_vllm() instead.

    Every request's ``prompt`` is handed to ``LLM.chat`` unchanged. Returns
    ``(elapsed_seconds, outputs)`` where the elapsed time covers only the
    ``chat`` call.
    """
    from vllm import LLM, SamplingParams

    llm = LLM(**dataclasses.asdict(engine_args))

    # Every request must fit into the model context window.
    fits_context = (
        req.prompt_len + req.expected_output_len
        <= llm.llm_engine.model_config.max_model_len
        for req in requests
    )
    assert all(fits_context), (
        "Please ensure that max_model_len is greater than the sum of "
        "prompt_len and expected_output_len for all requests."
    )

    prompts = [req.prompt for req in requests]
    sampling_params: list[SamplingParams] = [
        SamplingParams(
            n=n,
            temperature=1.0,
            top_p=1.0,
            ignore_eos=True,
            max_tokens=req.expected_output_len,
            detokenize=not disable_detokenize,
        )
        for req in requests
    ]

    t0 = time.perf_counter()
    chat_outputs = llm.chat(prompts, sampling_params, use_tqdm=True)
    t1 = time.perf_counter()
    return t1 - t0, chat_outputs
158
159


160
async def run_vllm_async(
    requests: list[SampleRequest],
    n: int,
    engine_args: AsyncEngineArgs,
    disable_frontend_multiprocessing: bool = False,
    disable_detokenize: bool = False,
) -> float:
    """Benchmark generation through the async engine client.

    All requests are submitted up front, each with a unique ``request_id``
    (``test{i}``); their output streams are merged with
    ``merge_async_iterators`` and drained. Outputs are discarded.

    Args:
        requests: Sampled benchmark requests. A request whose (non-string)
            ``prompt`` contains a ``prompt_token_ids`` key is submitted as
            pre-tokenized input; every other prompt is submitted as text.
        n: Number of sequences to generate per prompt.
        engine_args: Engine configuration for the async client.
        disable_frontend_multiprocessing: Forwarded to
            ``build_async_engine_client_from_engine_args``.
        disable_detokenize: If True, outputs are not detokenized.

    Returns:
        Wall-clock seconds from the first submission until all streams end.
    """
    from vllm import SamplingParams

    async with build_async_engine_client_from_engine_args(
        engine_args, disable_frontend_multiprocessing
    ) as llm:
        model_config = await llm.get_model_config()
        assert all(
            model_config.max_model_len
            >= (request.prompt_len + request.expected_output_len)
            for request in requests
        ), (
            "Please ensure that max_model_len is greater than the sum of"
            " prompt_len and expected_output_len for all requests."
        )

        # Add the requests to the engine.
        prompts: list[Union[TextPrompt, TokensPrompt]] = []
        sampling_params: list[SamplingParams] = []
        lora_requests: list[Optional[LoRARequest]] = []
        for request in requests:
            # For a str prompt ``in`` is a substring test; only non-string
            # prompts can carry pre-tokenized ids.
            is_tokenized = (
                not isinstance(request.prompt, str)
                and "prompt_token_ids" in request.prompt
            )
            if is_tokenized:
                prompts.append(
                    TokensPrompt(
                        prompt_token_ids=request.prompt["prompt_token_ids"],
                        multi_modal_data=request.multi_modal_data,
                    )
                )
            else:
                prompts.append(
                    TextPrompt(
                        prompt=request.prompt,
                        multi_modal_data=request.multi_modal_data,
                    )
                )
            sampling_params.append(
                SamplingParams(
                    n=n,
                    temperature=1.0,
                    top_p=1.0,
                    ignore_eos=True,
                    max_tokens=request.expected_output_len,
                    detokenize=not disable_detokenize,
                )
            )
            lora_requests.append(request.lora_request)

        generators = []
        start = time.perf_counter()
        for i, (prompt, sp, lr) in enumerate(
            zip(prompts, sampling_params, lora_requests)
        ):
            generator = llm.generate(prompt, sp, lora_request=lr, request_id=f"test{i}")
            generators.append(generator)
        all_gens = merge_async_iterators(*generators)
        async for _ in all_gens:
            # Results are not inspected; only completion time is measured.
            pass
        end = time.perf_counter()
        return end - start


223
def run_hf(
    requests: list[SampleRequest],
    model: str,
    tokenizer: PreTrainedTokenizerBase,
    n: int,
    max_batch_size: int,
    trust_remote_code: bool,
    disable_detokenize: bool = False,
    max_batch_tokens: int = 2048,
) -> float:
    """Benchmark generation with a Hugging Face ``AutoModelForCausalLM``.

    Requests are grouped greedily, in order, into batches of at most
    ``max_batch_size`` prompts. A batch is closed early when adding the next
    request would make (longest prompt + longest output) exceed
    ``max_batch_tokens``.

    Args:
        requests: Sampled benchmark requests; ``prompt`` is tokenized as text.
        model: Model name or path passed to ``from_pretrained``.
        tokenizer: Tokenizer used to encode (and optionally decode) batches.
        n: ``num_return_sequences`` for each prompt.
        max_batch_size: Maximum number of prompts per batch.
        trust_remote_code: Forwarded to ``from_pretrained``.
        disable_detokenize: If True, skip ``batch_decode`` of the outputs.
        max_batch_tokens: Upper bound on padded prompt + output length per
            batch (defaults to the previously hard-coded 2048).

    Returns:
        Wall-clock seconds spent generating all batches.
    """
    llm = AutoModelForCausalLM.from_pretrained(
        model, torch_dtype=torch.float16, trust_remote_code=trust_remote_code
    )
    if llm.config.model_type == "llama":
        # To enable padding in the HF backend.
        tokenizer.pad_token = tokenizer.eos_token
    llm = llm.cuda()

    num_requests = len(requests)
    # Context manager guarantees the progress bar is closed.
    with tqdm(total=num_requests) as pbar:
        start = time.perf_counter()
        batch: list[str] = []
        max_prompt_len = 0
        max_output_len = 0
        for i, request in enumerate(requests):
            # Add the prompt to the batch.
            batch.append(request.prompt)
            max_prompt_len = max(max_prompt_len, request.prompt_len)
            max_output_len = max(max_output_len, request.expected_output_len)
            if len(batch) < max_batch_size and i != num_requests - 1:
                # Check if we can add more requests to the batch.
                next_request = requests[i + 1]
                padded_len = max(max_prompt_len, next_request.prompt_len) + max(
                    max_output_len, next_request.expected_output_len
                )
                if padded_len <= max_batch_tokens:
                    # We can add more requests to the batch.
                    continue

            # Generate the sequences. Pass the attention mask so padded
            # positions are not attended to during batched generation.
            encoded = tokenizer(batch, return_tensors="pt", padding=True)
            attention_mask = encoded.get("attention_mask")
            if attention_mask is not None:
                attention_mask = attention_mask.cuda()
            llm_outputs = llm.generate(
                input_ids=encoded.input_ids.cuda(),
                attention_mask=attention_mask,
                do_sample=True,
                num_return_sequences=n,
                temperature=1.0,
                top_p=1.0,
                use_cache=True,
                max_new_tokens=max_output_len,
            )
            if not disable_detokenize:
                # Include the decoding time.
                tokenizer.batch_decode(llm_outputs, skip_special_tokens=True)
            pbar.update(len(batch))

            # Clear the batch.
            batch = []
            max_prompt_len = 0
            max_output_len = 0
        end = time.perf_counter()
    return end - start


288
def run_mii(
    requests: list[SampleRequest],
    model: str,
    tensor_parallel_size: int,
    output_len: int,
) -> float:
    """Benchmark generation with a DeepSpeed-MII server.

    A server is started with ``mii.serve``, all prompts are generated in one
    call, and the server is terminated afterwards.

    Args:
        requests: Sampled benchmark requests; only ``prompt`` is used.
        model: Model name or path served by MII.
        tensor_parallel_size: Forwarded as ``tensor_parallel``.
        output_len: ``max_new_tokens`` used for every prompt.

    Returns:
        Wall-clock seconds spent in the ``generate`` call.
    """
    from mii import client, serve

    llm = serve(model, tensor_parallel=tensor_parallel_size)
    try:
        prompts = [request.prompt for request in requests]
        start = time.perf_counter()
        llm.generate(prompts, max_new_tokens=output_len)
        end = time.perf_counter()
    finally:
        # Shut the server down even if generation fails. Use a distinct name
        # so the imported ``client`` factory is not shadowed.
        mii_client = client(model)
        mii_client.terminate_server()
    return end - start


307
308
309
def save_to_pytorch_benchmark_format(
    args: argparse.Namespace, results: dict[str, Any]
) -> None:
    """Write ``results`` in the PyTorch benchmark record format.

    Throughput values become metrics; elapsed time, request count and token
    count become extra info. The records go next to ``args.output_json`` with
    a ``.pytorch.json`` suffix, and nothing is written when the converter
    returns no records.
    """
    metric_keys = ("requests_per_second", "tokens_per_second")
    info_keys = ("elapsed_time", "num_requests", "total_num_tokens")
    records = convert_to_pytorch_benchmark_format(
        args=args,
        metrics={key: [results[key]] for key in metric_keys},
        extra_info={key: results[key] for key in info_keys},
    )
    if not records:
        return
    # Don't use json suffix here as we don't want CI to pick it up
    stem, _suffix = os.path.splitext(args.output_json)
    write_to_json(f"{stem}.pytorch.json", records)
324
325


326
327
328
329
330
331
332
333
334
335
336
337
338
339
def get_requests(args, tokenizer):
    """Sample benchmark requests from the dataset selected on the command line.

    The dataset class is chosen from ``args.dataset_name`` (and, for ``hf``,
    from ``args.dataset_path``); falls back to ``RandomDataset`` when no
    dataset path is given. ``None``-valued sampling options are dropped so the
    dataset's own defaults apply.

    Raises:
        AssertionError: for the sonnet dataset when the tokenizer has no chat
            template.
        ValueError: for an unknown dataset name, or an ``hf`` dataset path
            that no supported dataset class recognizes.
    """
    # Common parameters for all dataset types.
    common_kwargs = {
        "dataset_path": args.dataset_path,
        "random_seed": args.seed,
    }
    sample_kwargs = {
        "tokenizer": tokenizer,
        "lora_path": args.lora_path,
        "max_loras": args.max_loras,
        "num_requests": args.num_prompts,
        "input_len": args.input_len,
        "output_len": args.output_len,
    }

    if args.dataset_path is None or args.dataset_name == "random":
        sample_kwargs["range_ratio"] = args.random_range_ratio
        sample_kwargs["prefix_len"] = args.prefix_len
        dataset_cls = RandomDataset
    elif args.dataset_name == "sharegpt":
        dataset_cls = ShareGPTDataset
        if args.backend == "vllm-chat":
            sample_kwargs["enable_multimodal_chat"] = True
    elif args.dataset_name == "sonnet":
        # ``default_chat_template`` no longer exists on recent transformers
        # tokenizers; read it defensively so the assertion below (not an
        # AttributeError) reports a missing template.
        has_chat_template = tokenizer.chat_template or getattr(
            tokenizer, "default_chat_template", None
        )
        assert has_chat_template, (
            "Tokenizer/model must have chat template for sonnet dataset."
        )
        dataset_cls = SonnetDataset
        sample_kwargs["prefix_len"] = args.prefix_len
        sample_kwargs["return_prompt_formatted"] = True
    elif args.dataset_name == "burstgpt":
        dataset_cls = BurstGPTDataset
    elif args.dataset_name == "hf":
        if args.dataset_path in VisionArenaDataset.SUPPORTED_DATASET_PATHS:
            dataset_cls = VisionArenaDataset
            common_kwargs["dataset_subset"] = None
            common_kwargs["dataset_split"] = "train"
            sample_kwargs["enable_multimodal_chat"] = True
        elif args.dataset_path in InstructCoderDataset.SUPPORTED_DATASET_PATHS:
            dataset_cls = InstructCoderDataset
            common_kwargs["dataset_split"] = "train"
        elif args.dataset_path in ConversationDataset.SUPPORTED_DATASET_PATHS:
            dataset_cls = ConversationDataset
            common_kwargs["dataset_subset"] = args.hf_subset
            common_kwargs["dataset_split"] = args.hf_split
            sample_kwargs["enable_multimodal_chat"] = True
        elif args.dataset_path in AIMODataset.SUPPORTED_DATASET_PATHS:
            dataset_cls = AIMODataset
            common_kwargs["dataset_subset"] = None
            common_kwargs["dataset_split"] = "train"
        else:
            # Without this branch ``dataset_cls`` would be unbound below.
            raise ValueError(f"{args.dataset_path} is not supported by hf dataset.")
    else:
        raise ValueError(f"Unknown dataset name: {args.dataset_name}")
    # Remove None values
    sample_kwargs = {k: v for k, v in sample_kwargs.items() if v is not None}
    return dataset_cls(**common_kwargs).sample(**sample_kwargs)


383
def _tally_tokens(requests, request_outputs):
    """Return ``(prompt_tokens, output_tokens, total_tokens)`` for a run.

    Uses the real token counts in ``request_outputs`` when the backend returned
    any (vllm and vllm-chat); otherwise estimates them from each request's
    ``prompt_len`` and ``expected_output_len``.
    """
    if not request_outputs:
        total = sum(r.prompt_len + r.expected_output_len for r in requests)
        output = sum(r.expected_output_len for r in requests)
        return total - output, output, total

    prompt = 0
    output = 0
    for ro in request_outputs:
        if not isinstance(ro, RequestOutput):
            continue
        if ro.prompt_token_ids:
            prompt += len(ro.prompt_token_ids)
        output += sum(len(o.token_ids) for o in ro.outputs if o)
    return prompt, output, prompt + output


def main(args: argparse.Namespace):
    """Run the selected backend over the sampled requests and report throughput.

    Prints request and token throughput plus token totals; when
    ``args.output_json`` is set, also writes the results there and a
    PyTorch-benchmark-format copy alongside.
    """
    if args.seed is None:
        args.seed = 0
    print(args)
    random.seed(args.seed)

    # Sample the requests.
    tokenizer = AutoTokenizer.from_pretrained(
        args.tokenizer, trust_remote_code=args.trust_remote_code
    )
    requests = get_requests(args, tokenizer)
    is_multi_modal = any(req.multi_modal_data is not None for req in requests)

    backend = args.backend
    request_outputs: Optional[list[RequestOutput]] = None
    if backend == "vllm" and args.async_engine:
        elapsed_time = uvloop.run(
            run_vllm_async(
                requests,
                args.n,
                AsyncEngineArgs.from_cli_args(args),
                args.disable_frontend_multiprocessing,
                args.disable_detokenize,
            )
        )
    elif backend == "vllm":
        elapsed_time, request_outputs = run_vllm(
            requests, args.n, EngineArgs.from_cli_args(args), args.disable_detokenize
        )
    elif backend == "hf":
        assert args.tensor_parallel_size == 1
        elapsed_time = run_hf(
            requests,
            args.model,
            tokenizer,
            args.n,
            args.hf_max_batch_size,
            args.trust_remote_code,
            args.disable_detokenize,
        )
    elif backend == "mii":
        elapsed_time = run_mii(
            requests, args.model, args.tensor_parallel_size, args.output_len
        )
    elif backend == "vllm-chat":
        elapsed_time, request_outputs = run_vllm_chat(
            requests, args.n, EngineArgs.from_cli_args(args), args.disable_detokenize
        )
    else:
        raise ValueError(f"Unknown backend: {args.backend}")

    total_prompt_tokens, total_output_tokens, total_num_tokens = _tally_tokens(
        requests, request_outputs
    )

    if is_multi_modal and backend != "vllm-chat":
        # TODO(vllm-project/vllm/issues/9778): Count multi-modal token length.
        # vllm-chat backend counts the image tokens now
        print(
            "\033[91mWARNING\033[0m: Multi-modal request with "
            f"{args.backend} backend detected. The "
            "following metrics are not accurate because image tokens are not"
            " counted. See vllm-project/vllm/issues/9778 for details."
        )

    num_requests = len(requests)
    print(
        f"Throughput: {num_requests / elapsed_time:.2f} requests/s, "
        f"{total_num_tokens / elapsed_time:.2f} total tokens/s, "
        f"{total_output_tokens / elapsed_time:.2f} output tokens/s"
    )
    print(f"Total num prompt tokens:  {total_prompt_tokens}")
    print(f"Total num output tokens:  {total_output_tokens}")

    # Output JSON results if specified
    if not args.output_json:
        return
    results = {
        "elapsed_time": elapsed_time,
        "num_requests": num_requests,
        "total_num_tokens": total_num_tokens,
        "requests_per_second": num_requests / elapsed_time,
        "tokens_per_second": total_num_tokens / elapsed_time,
    }
    with open(args.output_json, "w") as f:
        json.dump(results, f, indent=4)
    save_to_pytorch_benchmark_format(args, results)
483

484

485
486
487
488
489
490
491
492
493
494
def validate_args(args):
    """
    Validate command-line arguments.

    Also normalizes ``args`` in place:
    - ``--dataset`` (deprecated) is copied into ``dataset_path``;
    - an empty ``tokenizer`` falls back to ``model``;
    - ``dataset_name`` becomes ``"random"`` when no dataset path is given.

    Options that the selected dataset ignores produce warnings rather than
    errors.

    Raises:
        ValueError: for unsupported or inconsistent option combinations.
        AssertionError: when an ``hf`` dataset path is used with the wrong
            backend.
    """

    # === Deprecation and Defaulting ===
    if args.dataset is not None:
        warnings.warn(
            "The '--dataset' argument will be deprecated in the next release. "
            "Please use '--dataset-name' and '--dataset-path' instead.",
            stacklevel=2,
        )
        args.dataset_path = args.dataset

    if not getattr(args, "tokenizer", None):
        args.tokenizer = args.model

    # === Backend Validation ===
    valid_backends = {"vllm", "hf", "mii", "vllm-chat"}
    if args.backend not in valid_backends:
        raise ValueError(f"Unsupported backend: {args.backend}")

    # === Dataset Configuration ===
    if not args.dataset and not args.dataset_path:
        print("When dataset path is not set, it will default to random dataset")
        args.dataset_name = "random"
        if args.input_len is None:
            raise ValueError("input_len must be provided for a random dataset")

    # === Dataset Name Specific Checks ===
    # Messages below use implicit string concatenation; a backslash line
    # continuation inside a literal would embed the source indentation in the
    # text shown to the user.
    # --hf-subset and --hf-split: only used
    # when dataset_name is 'hf'
    if args.dataset_name != "hf" and (
        getattr(args, "hf_subset", None) is not None
        or getattr(args, "hf_split", None) is not None
    ):
        warnings.warn(
            "--hf-subset and --hf-split will be ignored "
            "since --dataset-name is not 'hf'.",
            stacklevel=2,
        )
    elif args.dataset_name == "hf":
        if args.dataset_path in (
            VisionArenaDataset.SUPPORTED_DATASET_PATHS.keys()
            | ConversationDataset.SUPPORTED_DATASET_PATHS
        ):
            assert args.backend == "vllm-chat", (
                f"{args.dataset_path} needs to use vllm-chat as the backend."
            )
        elif args.dataset_path in (
            InstructCoderDataset.SUPPORTED_DATASET_PATHS
            | AIMODataset.SUPPORTED_DATASET_PATHS
        ):
            assert args.backend == "vllm", (
                f"{args.dataset_path} needs to use vllm as the backend."
            )
        else:
            raise ValueError(f"{args.dataset_path} is not supported by hf dataset.")

    # --random-range-ratio: only used when dataset_name is 'random'
    if args.dataset_name != "random" and args.random_range_ratio is not None:
        warnings.warn(
            "--random-range-ratio will be ignored "
            "since --dataset-name is not 'random'.",
            stacklevel=2,
        )

    # --prefix-len: only used when dataset_name is 'random', 'sonnet', or not
    # set.
    if (
        args.dataset_name not in {"random", "sonnet", None}
        and args.prefix_len is not None
    ):
        warnings.warn(
            "--prefix-len will be ignored since --dataset-name "
            "is not 'random', 'sonnet', or not set.",
            stacklevel=2,
        )

    # === LoRA Settings ===
    if getattr(args, "enable_lora", False) and args.backend != "vllm":
        raise ValueError("LoRA benchmarking is only supported for vLLM backend")
    if getattr(args, "enable_lora", False) and args.lora_path is None:
        raise ValueError("LoRA path must be provided when enable_lora is True")

    # === Backend-specific Validations ===
    if args.backend == "hf" and args.hf_max_batch_size is None:
        raise ValueError("HF max batch size is required for HF backend")
    if args.backend != "hf" and args.hf_max_batch_size is not None:
        raise ValueError("HF max batch size is only for HF backend.")

    if (
        args.backend in {"hf", "mii"}
        and getattr(args, "quantization", None) is not None
    ):
        raise ValueError("Quantization is only for vLLM backend.")

    if args.backend == "mii" and args.dtype != "auto":
        raise ValueError("dtype must be auto for MII backend.")
    if args.backend == "mii" and args.n != 1:
        raise ValueError("n must be 1 for MII backend.")
    if args.backend == "mii" and args.tokenizer != args.model:
        raise ValueError("Tokenizer must be the same as the model for MII backend.")

    # --data-parallel is not supported currently.
    # https://github.com/vllm-project/vllm/issues/16222
    if args.data_parallel_size > 1:
        raise ValueError(
            "Data parallel is not supported in offline benchmark, "
            "please use benchmark serving instead"
        )
596

597

598
def create_argument_parser():
    """Build the command-line parser for the throughput benchmark.

    Benchmark-specific options are defined here; engine options are appended
    by ``AsyncEngineArgs.add_cli_args``.
    """
    parser = FlexibleArgumentParser(description="Benchmark the throughput.")
    parser.add_argument(
        "--backend",
        type=str,
        choices=["vllm", "hf", "mii", "vllm-chat"],
        default="vllm",
    )
    parser.add_argument(
        "--dataset-name",
        type=str,
        choices=["sharegpt", "random", "sonnet", "burstgpt", "hf"],
        help="Name of the dataset to benchmark on.",
        default="sharegpt",
    )
    parser.add_argument(
        "--dataset",
        type=str,
        default=None,
        # Implicit concatenation: a backslash continuation inside the literal
        # would embed the source indentation in the help text.
        help="Path to the ShareGPT dataset, will be deprecated in "
        "the next release. The dataset is expected to "
        "be a json in form of list[dict[..., conversations: "
        "list[dict[..., value: <prompt_or_response>]]]]",
    )
    parser.add_argument(
        "--dataset-path", type=str, default=None, help="Path to the dataset"
    )
    parser.add_argument(
        "--input-len",
        type=int,
        default=None,
        help="Input prompt length for each request",
    )
    parser.add_argument(
        "--output-len",
        type=int,
        default=None,
        help="Output length for each request. Overrides the "
        "output length from the dataset.",
    )
    parser.add_argument(
        "--n", type=int, default=1, help="Number of generated sequences per prompt."
    )
    parser.add_argument(
        "--num-prompts", type=int, default=1000, help="Number of prompts to process."
    )
    parser.add_argument(
        "--hf-max-batch-size",
        type=int,
        default=None,
        help="Maximum batch size for HF backend.",
    )
    parser.add_argument(
        "--output-json",
        type=str,
        default=None,
        help="Path to save the throughput results in JSON format.",
    )
    parser.add_argument(
        "--async-engine",
        action="store_true",
        default=False,
        help="Use vLLM async engine rather than LLM class.",
    )
    parser.add_argument(
        "--disable-frontend-multiprocessing",
        action="store_true",
        default=False,
        help="Disable decoupled async engine frontend.",
    )
    parser.add_argument(
        "--disable-detokenize",
        action="store_true",
        help=(
            "Do not detokenize the response (i.e. do not include "
            "detokenization time in the measurement)"
        ),
    )
    # LoRA
    parser.add_argument(
        "--lora-path",
        type=str,
        default=None,
        help="Path to the LoRA adapters to use. This can be an absolute path, "
        "a relative path, or a Hugging Face model identifier.",
    )
    parser.add_argument(
        "--prefix-len",
        type=int,
        default=None,
        help="Number of prefix tokens to be used in RandomDataset "
        "and SonnetDataset. For RandomDataset, the total input "
        "length is the sum of prefix-len (default: "
        f"{RandomDataset.DEFAULT_PREFIX_LEN}) and a random context length "
        "sampled from [input_len * (1 - range_ratio), "
        "input_len * (1 + range_ratio)]. For SonnetDataset, "
        f"prefix_len (default: {SonnetDataset.DEFAULT_PREFIX_LEN}) "
        "controls how much of the input is fixed lines versus "
        "random lines, but the total input length remains approximately "
        "input_len tokens.",
    )
    # random dataset
    parser.add_argument(
        "--random-range-ratio",
        type=float,
        default=None,
        help=f"Range ratio (default : {RandomDataset.DEFAULT_RANGE_RATIO}) "
        "for sampling input/output length, "
        "used only for RandomDataset. Must be in the range [0, 1) to "
        "define a symmetric sampling range "
        "[length * (1 - range_ratio), length * (1 + range_ratio)].",
    )

    # hf dataset
    parser.add_argument(
        "--hf-subset", type=str, default=None, help="Subset of the HF dataset."
    )
    parser.add_argument(
        "--hf-split", type=str, default=None, help="Split of the HF dataset."
    )

    parser = AsyncEngineArgs.add_cli_args(parser)

    return parser


if __name__ == "__main__":
    parser = create_argument_parser()
    args = parser.parse_args()
    # Default to the model's own tokenizer when none was given explicitly.
    args.tokenizer = args.model if args.tokenizer is None else args.tokenizer
    validate_args(args)
    main(args)