"vscode:/vscode.git/clone" did not exist on "3b81dd6c109471957ae43bb769393be16db2e9a6"
benchmark_throughput.py 25.2 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
"""Benchmark offline inference throughput."""
3

4
import argparse
5
import dataclasses
6
import json
7
import os
8
9
import random
import time
10
import warnings
11
from typing import Any, Optional, Union
12

13
import torch
14
import uvloop
15
from tqdm import tqdm
16
17
18
19
20
21
22
23
24
25
26
27
28
29
from transformers import AutoModelForCausalLM, AutoTokenizer, PreTrainedTokenizerBase

from benchmark_dataset import (
    AIMODataset,
    BurstGPTDataset,
    ConversationDataset,
    InstructCoderDataset,
    RandomDataset,
    SampleRequest,
    ShareGPTDataset,
    SonnetDataset,
    VisionArenaDataset,
)
from benchmark_utils import convert_to_pytorch_benchmark_format, write_to_json
30
from vllm.engine.arg_utils import AsyncEngineArgs, EngineArgs
31
from vllm.entrypoints.openai.api_server import (
32
33
    build_async_engine_client_from_engine_args,
)
34
from vllm.inputs import TextPrompt, TokensPrompt
35
from vllm.lora.request import LoRARequest
36
from vllm.outputs import RequestOutput
37
from vllm.sampling_params import BeamSearchParams
38
from vllm.utils import FlexibleArgumentParser, merge_async_iterators
39

40

Woosuk Kwon's avatar
Woosuk Kwon committed
41
def run_vllm(
    requests: list[SampleRequest],
    n: int,
    engine_args: EngineArgs,
    disable_detokenize: bool = False,
) -> tuple[float, Optional[list[RequestOutput]]]:
    """Run the requests through the synchronous ``LLM`` engine and time it.

    Args:
        requests: Sampled benchmark requests to submit.
        n: Number of sequences to generate per prompt (used as the beam
            width on the beam-search path).
        engine_args: Engine configuration; expanded into ``LLM(...)`` kwargs.
        disable_detokenize: When True, sampling is configured with
            ``detokenize=False``.

    Returns:
        A ``(elapsed_seconds, outputs)`` tuple. ``outputs`` is whatever
        ``llm.generate`` returned, or ``None`` on the beam-search path.
    """
    from vllm import LLM, SamplingParams

    llm = LLM(**dataclasses.asdict(engine_args))
    assert all(
        llm.llm_engine.model_config.max_model_len
        >= (request.prompt_len + request.expected_output_len)
        for request in requests
    ), (
        "Please ensure that max_model_len is greater than the sum of"
        " prompt_len and expected_output_len for all requests."
    )
    # Add the requests to the engine.
    prompts: list[Union[TextPrompt, TokensPrompt]] = []
    sampling_params: list[SamplingParams] = []
    for request in requests:
        prompts.append(
            TokensPrompt(
                prompt_token_ids=request.prompt["prompt_token_ids"],
                multi_modal_data=request.multi_modal_data,
            )
            if "prompt_token_ids" in request.prompt
            else TextPrompt(
                prompt=request.prompt, multi_modal_data=request.multi_modal_data
            )
        )
        sampling_params.append(
            SamplingParams(
                n=n,
                temperature=1.0,
                top_p=1.0,
                ignore_eos=True,
                max_tokens=request.expected_output_len,
                detokenize=not disable_detokenize,
            )
        )
    lora_requests: Optional[list[LoRARequest]] = None
    if engine_args.enable_lora:
        lora_requests = [request.lora_request for request in requests]

    # Beam search is hard-wired off; the else-branch below is kept for
    # manual experimentation only.
    use_beam_search = False

    outputs = None
    if not use_beam_search:
        start = time.perf_counter()
        outputs = llm.generate(
            prompts, sampling_params, lora_request=lora_requests, use_tqdm=True
        )
        end = time.perf_counter()
    else:
        assert lora_requests is None, "BeamSearch API does not support LoRA"
        prompts = [request.prompt for request in requests]
        # output_len should be the same for all requests.
        # Read the field by name, consistent with every other access to
        # SampleRequest in this file (the request is not indexed anywhere else).
        output_len = requests[0].expected_output_len
        for request in requests:
            assert request.expected_output_len == output_len
        start = time.perf_counter()
        llm.beam_search(
            prompts,
            BeamSearchParams(
                beam_width=n,
                max_tokens=output_len,
                ignore_eos=True,
            ),
        )
        end = time.perf_counter()
    return end - start, outputs


def run_vllm_chat(
    requests: list[SampleRequest],
    n: int,
    engine_args: EngineArgs,
    disable_detokenize: bool = False,
) -> tuple[float, list[RequestOutput]]:
    """
    Run vLLM chat benchmark. This function is recommended ONLY for benchmarking
    multimodal models as it properly handles multimodal inputs and chat
    formatting. For non-multimodal models, use run_vllm() instead.

    Args:
        requests: Sampled benchmark requests; each ``request.prompt`` is
            passed to ``llm.chat`` unchanged.
        n: Number of sequences to generate per prompt.
        engine_args: Engine configuration; expanded into ``LLM(...)`` kwargs.
        disable_detokenize: When True, sampling is configured with
            ``detokenize=False``.

    Returns:
        A ``(elapsed_seconds, outputs)`` tuple where ``outputs`` is the
        value returned by ``llm.chat``.
    """
    from vllm import LLM, SamplingParams

    llm = LLM(**dataclasses.asdict(engine_args))

    assert all(
        llm.llm_engine.model_config.max_model_len
        >= (request.prompt_len + request.expected_output_len)
        for request in requests
    ), (
        "Please ensure that max_model_len is greater than the sum of "
        "prompt_len and expected_output_len for all requests."
    )

    prompts = []
    sampling_params: list[SamplingParams] = []
    for request in requests:
        prompts.append(request.prompt)
        sampling_params.append(
            SamplingParams(
                n=n,
                temperature=1.0,
                top_p=1.0,
                ignore_eos=True,
                max_tokens=request.expected_output_len,
                detokenize=not disable_detokenize,
            )
        )
    # Only the chat call itself is timed; engine construction is excluded.
    start = time.perf_counter()
    outputs = llm.chat(prompts, sampling_params, use_tqdm=True)
    end = time.perf_counter()
    return end - start, outputs
157
158


159
async def run_vllm_async(
    requests: list[SampleRequest],
    n: int,
    engine_args: AsyncEngineArgs,
    disable_frontend_multiprocessing: bool = False,
    disable_detokenize: bool = False,
) -> float:
    """Run the requests through the async engine client and time it.

    Args:
        requests: Sampled benchmark requests to submit.
        n: Number of sequences to generate per prompt.
        engine_args: Async engine configuration used to build the client.
        disable_frontend_multiprocessing: Forwarded to
            ``build_async_engine_client_from_engine_args``.
        disable_detokenize: When True, sampling is configured with
            ``detokenize=False``.

    Returns:
        Elapsed seconds from just before the first request is submitted
        until every output stream has been drained.
    """
    from vllm import SamplingParams

    async with build_async_engine_client_from_engine_args(
        engine_args, disable_frontend_multiprocessing
    ) as llm:
        model_config = await llm.get_model_config()
        assert all(
            model_config.max_model_len
            >= (request.prompt_len + request.expected_output_len)
            for request in requests
        ), (
            "Please ensure that max_model_len is greater than the sum of"
            " prompt_len and expected_output_len for all requests."
        )

        # Add the requests to the engine.
        prompts: list[Union[TextPrompt, TokensPrompt]] = []
        sampling_params: list[SamplingParams] = []
        lora_requests: list[Optional[LoRARequest]] = []
        for request in requests:
            prompts.append(
                TokensPrompt(
                    prompt_token_ids=request.prompt["prompt_token_ids"],
                    multi_modal_data=request.multi_modal_data,
                )
                if "prompt_token_ids" in request.prompt
                else TextPrompt(
                    prompt=request.prompt, multi_modal_data=request.multi_modal_data
                )
            )
            sampling_params.append(
                SamplingParams(
                    n=n,
                    temperature=1.0,
                    top_p=1.0,
                    ignore_eos=True,
                    max_tokens=request.expected_output_len,
                    detokenize=not disable_detokenize,
                )
            )
            lora_requests.append(request.lora_request)

        generators = []
        start = time.perf_counter()
        for i, (prompt, sp, lr) in enumerate(
            zip(prompts, sampling_params, lora_requests)
        ):
            generator = llm.generate(prompt, sp, lora_request=lr, request_id=f"test{i}")
            generators.append(generator)
        all_gens = merge_async_iterators(*generators)
        # Drain every stream; the results themselves are not needed, only
        # the time it takes for all of them to finish.
        async for _ in all_gens:
            pass
        end = time.perf_counter()
        return end - start


222
def run_hf(
    requests: list[SampleRequest],
    model: str,
    tokenizer: PreTrainedTokenizerBase,
    n: int,
    max_batch_size: int,
    trust_remote_code: bool,
    disable_detokenize: bool = False,
) -> float:
    """Run the requests through a Hugging Face causal LM and time it.

    Requests are grouped greedily, in order, into batches of at most
    ``max_batch_size`` prompts. A batch is also closed early when adding the
    next request would push (longest prompt + longest output) above 2048.

    Args:
        requests: Sampled benchmark requests to run.
        model: Model name or path passed to ``from_pretrained``.
        tokenizer: Tokenizer used to encode each batch (with padding) and,
            unless disabled, to decode the outputs.
        n: Number of sequences to return per prompt.
        max_batch_size: Upper bound on prompts per ``generate`` call.
        trust_remote_code: Forwarded to ``from_pretrained``.
        disable_detokenize: When True, outputs are not decoded.

    Returns:
        Elapsed seconds spent batching, generating and (optionally) decoding.
        Model loading and the move to GPU happen before the timer starts.
    """
    llm = AutoModelForCausalLM.from_pretrained(
        model, torch_dtype=torch.float16, trust_remote_code=trust_remote_code
    )
    if llm.config.model_type == "llama":
        # To enable padding in the HF backend.
        tokenizer.pad_token = tokenizer.eos_token
    llm = llm.cuda()

    pbar = tqdm(total=len(requests))
    start = time.perf_counter()
    batch: list[str] = []
    max_prompt_len = 0
    max_output_len = 0
    last_index = len(requests) - 1
    for i, request in enumerate(requests):
        # Add the prompt to the batch.
        batch.append(request.prompt)
        max_prompt_len = max(max_prompt_len, request.prompt_len)
        max_output_len = max(max_output_len, request.expected_output_len)
        if len(batch) < max_batch_size and i != last_index:
            # Check if we can add more requests to the batch.
            next_request = requests[i + 1]
            if (
                max(max_prompt_len, next_request.prompt_len)
                + max(max_output_len, next_request.expected_output_len)
            ) <= 2048:
                # We can add more requests to the batch.
                continue

        # Generate the sequences.
        input_ids = tokenizer(batch, return_tensors="pt", padding=True).input_ids
        llm_outputs = llm.generate(
            input_ids=input_ids.cuda(),
            do_sample=True,
            num_return_sequences=n,
            temperature=1.0,
            top_p=1.0,
            use_cache=True,
            max_new_tokens=max_output_len,
        )
        if not disable_detokenize:
            # Include the decoding time.
            tokenizer.batch_decode(llm_outputs, skip_special_tokens=True)
        pbar.update(len(batch))

        # Clear the batch.
        batch = []
        max_prompt_len = 0
        max_output_len = 0
    end = time.perf_counter()
    # Close the bar only after the timer has stopped so that the cleanup is
    # not counted in the measurement.
    pbar.close()
    return end - start


287
def run_mii(
    requests: list[SampleRequest],
    model: str,
    tensor_parallel_size: int,
    output_len: int,
) -> float:
    """Run the requests through a MII server and time the generate call.

    Args:
        requests: Sampled benchmark requests; only ``request.prompt`` is used.
        model: Model identifier passed to ``mii.serve`` and ``mii.client``.
        tensor_parallel_size: Forwarded as ``tensor_parallel``.
        output_len: Forwarded as ``max_new_tokens`` for every prompt.

    Returns:
        Elapsed seconds of the single ``generate`` call.
    """
    from mii import client, serve

    llm = serve(model, tensor_parallel=tensor_parallel_size)
    try:
        prompts = [request.prompt for request in requests]

        start = time.perf_counter()
        llm.generate(prompts, max_new_tokens=output_len)
        end = time.perf_counter()
    finally:
        # Always shut the server down, even when generation fails, so that a
        # failed run does not leave it behind.
        client(model).terminate_server()
    return end - start


306
307
308
def save_to_pytorch_benchmark_format(
    args: argparse.Namespace, results: dict[str, Any]
) -> None:
    """Convert the results to PyTorch benchmark records and write them out.

    Args:
        args: Parsed CLI arguments; ``args.output_json`` determines the
            output file name.
        results: Result mapping; must contain ``requests_per_second``,
            ``tokens_per_second``, ``elapsed_time``, ``num_requests`` and
            ``total_num_tokens``.

    Nothing is written when the conversion yields no records.
    """
    pt_records = convert_to_pytorch_benchmark_format(
        args=args,
        metrics={
            "requests_per_second": [results["requests_per_second"]],
            "tokens_per_second": [results["tokens_per_second"]],
        },
        extra_info={
            k: results[k] for k in ["elapsed_time", "num_requests", "total_num_tokens"]
        },
    )
    if pt_records:
        # Don't use json suffix here as we don't want CI to pick it up
        # (the records go next to the main output as "<stem>.pytorch.json").
        pt_file = f"{os.path.splitext(args.output_json)[0]}.pytorch.json"
        write_to_json(pt_file, pt_records)
323
324


325
326
327
328
329
330
331
332
333
334
335
336
337
338
def get_requests(args, tokenizer):
    """Pick the dataset class from the CLI arguments and sample requests.

    Args:
        args: Parsed CLI arguments (dataset name/path, seed, lengths, LoRA
            options, backend, ...).
        tokenizer: Tokenizer handed to the dataset's ``sample`` method.

    Returns:
        Whatever ``dataset_cls(**common_kwargs).sample(**sample_kwargs)``
        returns for the selected dataset.

    Raises:
        ValueError: If the dataset name is unknown, or the dataset name is
            ``hf`` but the dataset path matches none of the supported HF
            datasets.
    """
    # Common parameters for all dataset types.
    common_kwargs = {
        "dataset_path": args.dataset_path,
        "random_seed": args.seed,
    }
    sample_kwargs = {
        "tokenizer": tokenizer,
        "lora_path": args.lora_path,
        "max_loras": args.max_loras,
        "num_requests": args.num_prompts,
        "input_len": args.input_len,
        "output_len": args.output_len,
    }

    if args.dataset_path is None or args.dataset_name == "random":
        sample_kwargs["range_ratio"] = args.random_range_ratio
        sample_kwargs["prefix_len"] = args.prefix_len
        dataset_cls = RandomDataset
    elif args.dataset_name == "sharegpt":
        dataset_cls = ShareGPTDataset
        if args.backend == "vllm-chat":
            sample_kwargs["enable_multimodal_chat"] = True
    elif args.dataset_name == "sonnet":
        # default_chat_template may be absent on the tokenizer, so read it
        # defensively instead of risking an AttributeError.
        assert tokenizer.chat_template or getattr(
            tokenizer, "default_chat_template", None
        ), "Tokenizer/model must have chat template for sonnet dataset."
        dataset_cls = SonnetDataset
        sample_kwargs["prefix_len"] = args.prefix_len
        sample_kwargs["return_prompt_formatted"] = True
    elif args.dataset_name == "burstgpt":
        dataset_cls = BurstGPTDataset
    elif args.dataset_name == "hf":
        if args.dataset_path in VisionArenaDataset.SUPPORTED_DATASET_PATHS:
            dataset_cls = VisionArenaDataset
            common_kwargs["dataset_subset"] = None
            common_kwargs["dataset_split"] = "train"
            sample_kwargs["enable_multimodal_chat"] = True
        elif args.dataset_path in InstructCoderDataset.SUPPORTED_DATASET_PATHS:
            dataset_cls = InstructCoderDataset
            common_kwargs["dataset_split"] = "train"
        elif args.dataset_path in ConversationDataset.SUPPORTED_DATASET_PATHS:
            dataset_cls = ConversationDataset
            common_kwargs["dataset_subset"] = args.hf_subset
            common_kwargs["dataset_split"] = args.hf_split
            sample_kwargs["enable_multimodal_chat"] = True
        elif args.dataset_path in AIMODataset.SUPPORTED_DATASET_PATHS:
            dataset_cls = AIMODataset
            common_kwargs["dataset_subset"] = None
            common_kwargs["dataset_split"] = "train"
        else:
            # Without this branch dataset_cls would stay unbound and the
            # function would fail later with an opaque UnboundLocalError.
            raise ValueError(f"{args.dataset_path} is not supported by hf dataset.")
    else:
        raise ValueError(f"Unknown dataset name: {args.dataset_name}")
    # Remove None values
    sample_kwargs = {k: v for k, v in sample_kwargs.items() if v is not None}
    return dataset_cls(**common_kwargs).sample(**sample_kwargs)


382
def main(args: argparse.Namespace):
    """Sample requests, run the selected backend, and report throughput.

    Prints requests/s, total tokens/s and output tokens/s, and optionally
    writes the results to ``args.output_json`` (plus the PyTorch benchmark
    format file derived from it).

    Args:
        args: Parsed and validated CLI arguments.

    Raises:
        ValueError: If ``args.backend`` is not a known backend.
    """
    if args.seed is None:
        args.seed = 0
    print(args)
    random.seed(args.seed)
    # Sample the requests.
    tokenizer = AutoTokenizer.from_pretrained(
        args.tokenizer, trust_remote_code=args.trust_remote_code
    )
    requests = get_requests(args, tokenizer)
    is_multi_modal = any(request.multi_modal_data is not None for request in requests)
    request_outputs: Optional[list[RequestOutput]] = None
    if args.backend == "vllm":
        if args.async_engine:
            elapsed_time = uvloop.run(
                run_vllm_async(
                    requests,
                    args.n,
                    AsyncEngineArgs.from_cli_args(args),
                    args.disable_frontend_multiprocessing,
                    args.disable_detokenize,
                )
            )
        else:
            elapsed_time, request_outputs = run_vllm(
                requests,
                args.n,
                EngineArgs.from_cli_args(args),
                args.disable_detokenize,
            )
    elif args.backend == "hf":
        assert args.tensor_parallel_size == 1
        elapsed_time = run_hf(
            requests,
            args.model,
            tokenizer,
            args.n,
            args.hf_max_batch_size,
            args.trust_remote_code,
            args.disable_detokenize,
        )
    elif args.backend == "mii":
        elapsed_time = run_mii(
            requests, args.model, args.tensor_parallel_size, args.output_len
        )
    elif args.backend == "vllm-chat":
        elapsed_time, request_outputs = run_vllm_chat(
            requests, args.n, EngineArgs.from_cli_args(args), args.disable_detokenize
        )
    else:
        raise ValueError(f"Unknown backend: {args.backend}")

    if request_outputs:
        # Note: with the vllm and vllm-chat backends,
        # we have request_outputs, which we use to count tokens.
        total_prompt_tokens = 0
        total_output_tokens = 0
        for ro in request_outputs:
            if not isinstance(ro, RequestOutput):
                continue
            total_prompt_tokens += (
                len(ro.prompt_token_ids) if ro.prompt_token_ids else 0
            )
            total_output_tokens += sum(len(o.token_ids) for o in ro.outputs if o)
        total_num_tokens = total_prompt_tokens + total_output_tokens
    else:
        # No engine outputs available: fall back to the lengths recorded on
        # the sampled requests.
        total_num_tokens = sum(r.prompt_len + r.expected_output_len for r in requests)
        total_output_tokens = sum(r.expected_output_len for r in requests)
        total_prompt_tokens = total_num_tokens - total_output_tokens

    if is_multi_modal and args.backend != "vllm-chat":
        print(
            "\033[91mWARNING\033[0m: Multi-modal request with "
            f"{args.backend} backend detected. The "
            "following metrics are not accurate because image tokens are not"
            " counted. See vllm-project/vllm/issues/9778 for details."
        )
        # TODO(vllm-project/vllm/issues/9778): Count multi-modal token length.
        # vllm-chat backend counts the image tokens now

    print(
        f"Throughput: {len(requests) / elapsed_time:.2f} requests/s, "
        f"{total_num_tokens / elapsed_time:.2f} total tokens/s, "
        f"{total_output_tokens / elapsed_time:.2f} output tokens/s"
    )
    print(f"Total num prompt tokens:  {total_prompt_tokens}")
    print(f"Total num output tokens:  {total_output_tokens}")

    # Output JSON results if specified
    if args.output_json:
        results = {
            "elapsed_time": elapsed_time,
            "num_requests": len(requests),
            "total_num_tokens": total_num_tokens,
            "requests_per_second": len(requests) / elapsed_time,
            "tokens_per_second": total_num_tokens / elapsed_time,
        }
        with open(args.output_json, "w") as f:
            json.dump(results, f, indent=4)
        save_to_pytorch_benchmark_format(args, results)
482

483

484
485
486
487
488
489
490
491
492
493
def validate_args(args):
    """
    Validate command-line arguments.

    Besides validating, this fills in defaults on ``args`` in place:
    ``dataset_path`` from the deprecated ``--dataset``, ``tokenizer`` from
    ``model``, and ``dataset_name = "random"`` when no dataset path is given.

    Raises:
        ValueError: For an unsupported backend, a random dataset without
            ``input_len``, an unsupported HF dataset path, inconsistent LoRA
            or backend-specific options, or ``data_parallel_size > 1``.
        AssertionError: When an HF dataset is paired with a backend it does
            not support.
    """

    # === Deprecation and Defaulting ===
    if args.dataset is not None:
        warnings.warn(
            "The '--dataset' argument will be deprecated in the next release. "
            "Please use '--dataset-name' and '--dataset-path' instead.",
            stacklevel=2,
        )
        args.dataset_path = args.dataset

    if not getattr(args, "tokenizer", None):
        args.tokenizer = args.model

    # === Backend Validation ===
    valid_backends = {"vllm", "hf", "mii", "vllm-chat"}
    if args.backend not in valid_backends:
        raise ValueError(f"Unsupported backend: {args.backend}")

    # === Dataset Configuration ===
    if not args.dataset and not args.dataset_path:
        print("When dataset path is not set, it will default to random dataset")
        args.dataset_name = "random"
        if args.input_len is None:
            raise ValueError("input_len must be provided for a random dataset")

    # === Dataset Name Specific Checks ===
    # --hf-subset and --hf-split: only used
    # when dataset_name is 'hf'
    if args.dataset_name != "hf" and (
        getattr(args, "hf_subset", None) is not None
        or getattr(args, "hf_split", None) is not None
    ):
        # Adjacent literals instead of a backslash continuation inside the
        # string, which would embed the source indentation in the message.
        warnings.warn(
            "--hf-subset and --hf-split will be ignored "
            "since --dataset-name is not 'hf'.",
            stacklevel=2,
        )
    elif args.dataset_name == "hf":
        if args.dataset_path in (
            VisionArenaDataset.SUPPORTED_DATASET_PATHS.keys()
            | ConversationDataset.SUPPORTED_DATASET_PATHS
        ):
            assert args.backend == "vllm-chat", (
                f"{args.dataset_path} needs to use vllm-chat as the backend."
            )  # noqa: E501
        elif args.dataset_path in (
            InstructCoderDataset.SUPPORTED_DATASET_PATHS
            | AIMODataset.SUPPORTED_DATASET_PATHS
        ):
            assert args.backend == "vllm", (
                f"{args.dataset_path} needs to use vllm as the backend."
            )  # noqa: E501
        else:
            raise ValueError(f"{args.dataset_path} is not supported by hf dataset.")

    # --random-range-ratio: only used when dataset_name is 'random'
    if args.dataset_name != "random" and args.random_range_ratio is not None:
        warnings.warn(
            "--random-range-ratio will be ignored since "
            "--dataset-name is not 'random'.",
            stacklevel=2,
        )

    # --prefix-len: only used when dataset_name is 'random', 'sonnet', or not
    # set.
    if (
        args.dataset_name not in {"random", "sonnet", None}
        and args.prefix_len is not None
    ):
        warnings.warn(
            "--prefix-len will be ignored since --dataset-name "
            "is not 'random', 'sonnet', or not set.",
            stacklevel=2,
        )

    # === LoRA Settings ===
    if getattr(args, "enable_lora", False) and args.backend != "vllm":
        raise ValueError("LoRA benchmarking is only supported for vLLM backend")
    if getattr(args, "enable_lora", False) and args.lora_path is None:
        raise ValueError("LoRA path must be provided when enable_lora is True")

    # === Backend-specific Validations ===
    if args.backend == "hf" and args.hf_max_batch_size is None:
        raise ValueError("HF max batch size is required for HF backend")
    if args.backend != "hf" and args.hf_max_batch_size is not None:
        raise ValueError("HF max batch size is only for HF backend.")

    if (
        args.backend in {"hf", "mii"}
        and getattr(args, "quantization", None) is not None
    ):
        raise ValueError("Quantization is only for vLLM backend.")

    if args.backend == "mii" and args.dtype != "auto":
        raise ValueError("dtype must be auto for MII backend.")
    if args.backend == "mii" and args.n != 1:
        raise ValueError("n must be 1 for MII backend.")
    if args.backend == "mii" and args.tokenizer != args.model:
        raise ValueError("Tokenizer must be the same as the model for MII backend.")

    # --data-parallel is not supported currently.
    # https://github.com/vllm-project/vllm/issues/16222
    if args.data_parallel_size > 1:
        raise ValueError(
            "Data parallel is not supported in offline benchmark, "
            "please use benchmark serving instead"
        )
595

596

597
# Script entry point: build the CLI, parse and validate the arguments, then
# run the benchmark.
if __name__ == "__main__":
    parser = FlexibleArgumentParser(description="Benchmark the throughput.")
    parser.add_argument(
        "--backend",
        type=str,
        choices=["vllm", "hf", "mii", "vllm-chat"],
        default="vllm",
    )
    parser.add_argument(
        "--dataset-name",
        type=str,
        choices=["sharegpt", "random", "sonnet", "burstgpt", "hf"],
        help="Name of the dataset to benchmark on.",
        default="sharegpt",
    )
    parser.add_argument(
        "--dataset",
        type=str,
        default=None,
        # Adjacent literals instead of a backslash continuation inside the
        # string, which would embed the source indentation in the help text.
        help="Path to the ShareGPT dataset, will be deprecated in "
        "the next release. The dataset is expected to "
        "be a json in form of list[dict[..., conversations: "
        "list[dict[..., value: <prompt_or_response>]]]]",
    )
    parser.add_argument(
        "--dataset-path", type=str, default=None, help="Path to the dataset"
    )
    parser.add_argument(
        "--input-len",
        type=int,
        default=None,
        help="Input prompt length for each request",
    )
    parser.add_argument(
        "--output-len",
        type=int,
        default=None,
        help="Output length for each request. Overrides the "
        "output length from the dataset.",
    )
    parser.add_argument(
        "--n", type=int, default=1, help="Number of generated sequences per prompt."
    )
    parser.add_argument(
        "--num-prompts", type=int, default=1000, help="Number of prompts to process."
    )
    parser.add_argument(
        "--hf-max-batch-size",
        type=int,
        default=None,
        help="Maximum batch size for HF backend.",
    )
    parser.add_argument(
        "--output-json",
        type=str,
        default=None,
        help="Path to save the throughput results in JSON format.",
    )
    parser.add_argument(
        "--async-engine",
        action="store_true",
        default=False,
        help="Use vLLM async engine rather than LLM class.",
    )
    parser.add_argument(
        "--disable-frontend-multiprocessing",
        action="store_true",
        default=False,
        help="Disable decoupled async engine frontend.",
    )
    parser.add_argument(
        "--disable-detokenize",
        action="store_true",
        help=(
            "Do not detokenize the response (i.e. do not include "
            "detokenization time in the measurement)"
        ),
    )
    # LoRA
    parser.add_argument(
        "--lora-path",
        type=str,
        default=None,
        help="Path to the LoRA adapters to use. This can be an absolute path, "
        "a relative path, or a Hugging Face model identifier.",
    )
    parser.add_argument(
        "--prefix-len",
        type=int,
        default=None,
        help="Number of prefix tokens to be used in RandomDataset "
        "and SonnetDataset. For RandomDataset, the total input "
        "length is the sum of prefix-len (default: "
        f"{RandomDataset.DEFAULT_PREFIX_LEN}) and a random context length "
        "sampled from [input_len * (1 - range_ratio), "
        "input_len * (1 + range_ratio)]. For SonnetDataset, "
        f"prefix_len (default: {SonnetDataset.DEFAULT_PREFIX_LEN}) "
        "controls how much of the input is fixed lines versus "
        "random lines, but the total input length remains approximately "
        "input_len tokens.",
    )
    # random dataset
    parser.add_argument(
        "--random-range-ratio",
        type=float,
        default=None,
        help=f"Range ratio (default : {RandomDataset.DEFAULT_RANGE_RATIO}) "
        "for sampling input/output length, "
        "used only for RandomDataset. Must be in the range [0, 1) to "
        "define a symmetric sampling range "
        "[length * (1 - range_ratio), length * (1 + range_ratio)].",
    )

    # hf dataset
    parser.add_argument(
        "--hf-subset", type=str, default=None, help="Subset of the HF dataset."
    )
    parser.add_argument(
        "--hf-split", type=str, default=None, help="Split of the HF dataset."
    )

    # Engine options (model, tokenizer, seed, parallelism, ...) come from vLLM.
    parser = AsyncEngineArgs.add_cli_args(parser)
    args = parser.parse_args()
    if args.tokenizer is None:
        args.tokenizer = args.model
    validate_args(args)
    main(args)