"examples/tool_chat_template_hermes.jinja" did not exist on "561d6f8077c54c7af5dbf2ed92131ce9f7d9b56b"
benchmark_throughput.py 21.6 KB
Newer Older
zhuwenwen's avatar
zhuwenwen committed
1
# SPDX-License-Identifier: Apache-2.0
2
3
"""Benchmark offline inference throughput."""
import argparse
zhuwenwen's avatar
zhuwenwen committed
4
import dataclasses
5
import json
zhuwenwen's avatar
zhuwenwen committed
6
import os
7
8
import random
import time
9
from functools import cache
zhuwenwen's avatar
zhuwenwen committed
10
from typing import Any, Dict, List, Optional, Tuple
11
12
13

import numpy as np
import torch
zhuwenwen's avatar
zhuwenwen committed
14
import uvloop
zhuwenwen's avatar
zhuwenwen committed
15
from benchmark_utils import convert_to_pytorch_benchmark_format
zhuwenwen's avatar
zhuwenwen committed
16
from PIL import Image
17
18
19
from tqdm import tqdm
from transformers import (AutoModelForCausalLM, AutoTokenizer,
                          PreTrainedTokenizerBase)
zhuwenwen's avatar
zhuwenwen committed
20

zhuwenwen's avatar
zhuwenwen committed
21
22
23

from vllm.inputs import PromptType
from vllm.engine.arg_utils import AsyncEngineArgs, EngineArgs
zhuwenwen's avatar
zhuwenwen committed
24
25
from vllm.entrypoints.openai.api_server import (
    build_async_engine_client_from_engine_args)
zhuwenwen's avatar
zhuwenwen committed
26
from vllm.inputs import TextPrompt
27
28
from vllm.lora.request import LoRARequest
from vllm.lora.utils import get_adapter_absolute_path
zhuwenwen's avatar
zhuwenwen committed
29
30
from vllm.multimodal import MultiModalDataDict
from vllm.sampling_params import BeamSearchParams
31
from vllm.transformers_utils.tokenizer import AnyTokenizer, get_lora_tokenizer
zhuwenwen's avatar
zhuwenwen committed
32
from vllm.utils import FlexibleArgumentParser, merge_async_iterators
33
34


zhuwenwen's avatar
zhuwenwen committed
35
36
37
38
39
40
41
42
@dataclasses.dataclass
class SampleRequest:
    """One benchmark request: a prompt plus its generation budget.

    Attributes:
        prompt: Text prompt sent to the model.
        prompt_len: Number of tokens in ``prompt``.
        expected_output_len: Number of tokens the request should generate.
        multi_modal_data: Optional multi-modal inputs keyed by modality
            (e.g. ``"image"``); ``None`` for text-only requests.
        lora_request: Optional LoRARequest selecting the LoRA adapter to use;
            ``None`` when no adapter is attached.
    """
    # Mandatory fields, in positional order.
    prompt: str
    prompt_len: int
    expected_output_len: int
    # Optional extras default to None so text-only / non-LoRA requests can
    # omit them.
    multi_modal_data: Optional[MultiModalDataDict] = dataclasses.field(
        default=None)
    lora_request: Optional[LoRARequest] = dataclasses.field(default=None)
zhuwenwen's avatar
zhuwenwen committed
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74


def _get_prompt_for_image_model(question: str, *, model: str) -> str:
    """Prepend and append special tokens around the question to form a prompt.

    Args:
        question: The input question text to wrap with special tokens
        model: The name of the model being used, to determine which special
            tokens to add

    Returns:
        The formatted prompt string with appropriate special tokens for the
            model

    Raises:
        ValueError: If an unsupported model name is provided
    """
    model = model.lower()
    if "pixtral" in model:
        return f"<s>[INST]{question}\n[IMG][/INST]"
    raise ValueError(f"Unsupported model {model}")


75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
@cache
def lora_path_on_disk(lora_path: str) -> str:
    """Resolve *lora_path* via ``get_adapter_absolute_path`` (memoized).

    ``@cache`` ensures each distinct ``lora_path`` is resolved only once per
    process.
    """
    absolute_path = get_adapter_absolute_path(lora_path)
    return absolute_path


# Tokenizer per LoRA id, filled lazily by get_random_lora_request.
lora_tokenizer_cache: Dict[int, AnyTokenizer] = {}


def get_random_lora_request(
        args: argparse.Namespace
) -> Tuple[LoRARequest, Optional[AnyTokenizer]]:
    """Pick a LoRA id uniformly from ``[1, args.max_loras]`` and build its request.

    Every id points at the same adapter path (``args.lora_path``). The
    tokenizer for an id is fetched once with ``get_lora_tokenizer`` and then
    served from ``lora_tokenizer_cache`` (a cached ``None`` is reused too).

    Returns:
        The LoRA request and its tokenizer, which may be ``None``.
    """
    chosen_id = random.randint(1, args.max_loras)
    request = LoRARequest(lora_name=str(chosen_id),
                          lora_int_id=chosen_id,
                          lora_path=lora_path_on_disk(args.lora_path))
    # Only the dict's contents are mutated, so no ``global`` is needed.
    if chosen_id not in lora_tokenizer_cache:
        lora_tokenizer_cache[chosen_id] = get_lora_tokenizer(request)
    return request, lora_tokenizer_cache[chosen_id]


zhuwenwen's avatar
zhuwenwen committed
96
97
def sample_requests(tokenizer: PreTrainedTokenizerBase,
                    args: argparse.Namespace) -> List[SampleRequest]:
    """Load, filter and tokenize conversation requests from a JSON dataset.

    The dataset at ``args.dataset`` is read as a JSON list whose entries have
    a ``"conversations"`` list of turns carrying a ``"value"`` string, and
    optionally an ``"image"`` path. The first turn is used as the prompt and
    the second as the reference completion.

    Args:
        tokenizer: Tokenizer used to measure lengths when no LoRA-specific
            tokenizer is selected.
        args: Parsed CLI arguments. Reads ``dataset``, ``num_prompts``,
            ``output_len``, ``model`` and ``enable_lora`` (plus whatever
            ``get_random_lora_request`` reads when LoRA is enabled).

    Returns:
        Up to ``args.num_prompts`` requests whose prompt is 4..1024 tokens,
        whose output is at least 4 tokens, and whose prompt + output is at
        most 2048 tokens. Entries whose image file is missing are skipped.

    Raises:
        ValueError: If ``args.output_len`` is set and smaller than 4.
    """
    dataset_path: str = args.dataset
    num_requests: int = args.num_prompts
    fixed_output_len: Optional[int] = args.output_len
    model: str = args.model
    if fixed_output_len is not None and fixed_output_len < 4:
        raise ValueError("output_len too small")

    # Load the dataset. JSON text is UTF-8 by specification, so do not rely
    # on the platform's default locale encoding.
    with open(dataset_path, encoding="utf-8") as f:
        dataset = json.load(f)
    # Filter out the conversations with less than 2 turns.
    dataset = [data for data in dataset if len(data["conversations"]) >= 2]
    # Shuffle the dataset.
    random.shuffle(dataset)

    # Filter out sequences that are too long or too short.
    filtered_dataset: List[SampleRequest] = []
    # The bar counts dataset entries scanned, so its total is the dataset
    # size; ``len(filtered_dataset)`` would always be 0 at this point.
    for data in tqdm(dataset, total=len(dataset), desc="sampling requests"):
        if len(filtered_dataset) == num_requests:
            break

        # Only keep the first two turns of each conversation.
        prompt = data["conversations"][0]["value"]
        completion = data["conversations"][1]["value"]

        multi_modal_data: Optional[MultiModalDataDict] = None
        if "image" in data:
            multi_modal_data = {}
            image_path = data["image"]
            # TODO(vllm-project/vllm/issues/9778): Support multiple images.
            assert isinstance(image_path,
                              str), "Only support single image input"
            try:
                # Close the source file deterministically; convert() returns
                # an independent in-memory image.
                with Image.open(image_path) as image:
                    multi_modal_data["image"] = image.convert("RGB")
            except FileNotFoundError:
                # Ignore datapoint where asset is missing
                continue
            prompt = _get_prompt_for_image_model(question=prompt, model=model)

        # Pick a random LoRA (if enabled) and measure lengths with its
        # tokenizer when one is available.
        request_tokenizer = tokenizer
        lora_request: Optional[LoRARequest] = None
        if args.enable_lora:
            lora_request, lora_tokenizer = get_random_lora_request(args)
            if lora_tokenizer:
                request_tokenizer = lora_tokenizer

        # Tokenize the prompts and completions.
        prompt_token_ids = request_tokenizer(prompt).input_ids
        completion_token_ids = request_tokenizer(completion).input_ids
        prompt_len = len(prompt_token_ids)
        output_len = (len(completion_token_ids)
                      if fixed_output_len is None else fixed_output_len)
        if prompt_len < 4 or output_len < 4:
            # Prune too short sequences.
            continue
        if prompt_len > 1024 or prompt_len + output_len > 2048:
            # Prune too long sequences.
            continue
        filtered_dataset.append(
            SampleRequest(prompt=prompt,
                          prompt_len=prompt_len,
                          expected_output_len=output_len,
                          multi_modal_data=multi_modal_data,
                          lora_request=lora_request))

    return filtered_dataset


def run_vllm(
    requests: List[SampleRequest],
    n: int,
    num_iters_warmup: int,
    engine_args: EngineArgs,
) -> float:
    """Benchmark offline generation with the synchronous ``vllm.LLM`` engine.

    Args:
        requests: Requests to generate. Each supplies the prompt, optional
            multi-modal data, optional LoRA request and the number of output
            tokens to produce.
        n: Number of sequences sampled per prompt (beam width if beam search
            is used).
        num_iters_warmup: Number of warm-up ``generate`` calls, each on one
            random 10-token prompt, before timing starts.
        engine_args: Arguments used to construct the ``LLM`` instance.

    Returns:
        Wall-clock seconds of the timed generation call. Engine construction
        and warm-up are excluded.
    """
    from vllm import LLM, SamplingParams
    llm = LLM(**dataclasses.asdict(engine_args))

    # Add the requests to the engine.
    prompts: List[TextPrompt] = []
    sampling_params: List[SamplingParams] = []
    for request in requests:
        prompts.append(
            TextPrompt(prompt=request.prompt,
                       multi_modal_data=request.multi_modal_data))
        sampling_params.append(
            SamplingParams(
                n=n,
                temperature=1.0,
                top_p=1.0,
                ignore_eos=True,
                max_tokens=request.expected_output_len,
            ))
    lora_requests: Optional[List[LoRARequest]] = None
    if engine_args.enable_lora:
        lora_requests = [request.lora_request for request in requests]

    # Warm up. Use the ``n`` parameter, not the module-global ``args``, which
    # only exists when this file is executed as a script.
    warmup_sampling_params = SamplingParams(
        n=n,
        temperature=1.0,
        top_p=1.0,
        ignore_eos=True,
        max_tokens=10,
    )
    dummy_prompt_token_ids = np.random.randint(10000, size=(1, 10))
    dummy_prompts: List[PromptType] = [{
        "prompt_token_ids": batch
    } for batch in dummy_prompt_token_ids.tolist()]

    print("Warming up...")
    for _ in tqdm(range(num_iters_warmup), desc="Warmup iterations"):
        llm.generate(dummy_prompts,
                     sampling_params=warmup_sampling_params,
                     use_tqdm=False)

    use_beam_search = False

    if not use_beam_search:
        start = time.perf_counter()
        llm.generate(prompts,
                     sampling_params,
                     lora_request=lora_requests,
                     use_tqdm=True)
        end = time.perf_counter()
    else:
        assert lora_requests is None, "BeamSearch API does not support LoRA"
        beam_prompts = [request.prompt for request in requests]
        # output_len should be the same for all requests. SampleRequest is a
        # dataclass, so read the field by name (tuple indexing raises
        # TypeError).
        output_len = requests[0].expected_output_len
        for request in requests:
            assert request.expected_output_len == output_len
        start = time.perf_counter()
        llm.beam_search(
            beam_prompts,
            BeamSearchParams(
                beam_width=n,
                max_tokens=output_len,
                ignore_eos=True,
            ))
        end = time.perf_counter()
    return end - start


zhuwenwen's avatar
zhuwenwen committed
245
async def run_vllm_async(
    requests: List[SampleRequest],
    n: int,
    engine_args: AsyncEngineArgs,
    disable_frontend_multiprocessing: bool = False,
) -> float:
    """Benchmark generation through vLLM's async engine client.

    All engine inputs are built before the clock starts. The timed region
    covers submitting one ``generate`` stream per request and draining the
    merged streams to completion.

    Args:
        requests: Requests to submit.
        n: Number of sequences sampled per prompt.
        engine_args: Arguments used to build the async engine client.
        disable_frontend_multiprocessing: Forwarded to
            ``build_async_engine_client_from_engine_args``.

    Returns:
        Wall-clock seconds from first submission until every stream is done.
    """
    from vllm import SamplingParams

    async with build_async_engine_client_from_engine_args(
            engine_args, disable_frontend_multiprocessing) as llm:

        # Engine inputs, one entry per request, in request order.
        prompts: List[TextPrompt] = [
            TextPrompt(prompt=req.prompt,
                       multi_modal_data=req.multi_modal_data)
            for req in requests
        ]
        sampling_params: List[SamplingParams] = [
            SamplingParams(
                n=n,
                temperature=1.0,
                top_p=1.0,
                ignore_eos=True,
                max_tokens=req.expected_output_len,
            ) for req in requests
        ]
        lora_requests: List[Optional[LoRARequest]] = [
            req.lora_request for req in requests
        ]

        start = time.perf_counter()
        streams = [
            llm.generate(prompt, params, lora_request=lora,
                         request_id=f"test{idx}")
            for idx, (prompt, params, lora) in enumerate(
                zip(prompts, sampling_params, lora_requests))
        ]
        # Drain every stream; outputs are discarded, only timing matters.
        async for _stream_idx, _output in merge_async_iterators(*streams):
            pass
        end = time.perf_counter()
        return end - start


290
def run_hf(
    requests: List[SampleRequest],
    model: str,
    tokenizer: PreTrainedTokenizerBase,
    n: int,
    max_batch_size: int,
    trust_remote_code: bool,
) -> float:
    """Benchmark generation with a HuggingFace ``AutoModelForCausalLM``.

    Requests are batched greedily in their given order. A batch is flushed
    when it holds ``max_batch_size`` prompts, when the last request has been
    added, or when adding the next request would push the batch's longest
    prompt plus longest output past 2048 tokens.

    Args:
        requests: Requests to run. The prompt text, prompt length and
            expected output length are read from each.
        model: Model name or path passed to ``from_pretrained``.
        tokenizer: Tokenizer used to encode batches and decode outputs. For
            llama models its ``pad_token`` is set to ``eos_token``, so the
            object passed in is mutated.
        n: Number of sequences returned per prompt.
        max_batch_size: Upper bound on prompts per ``generate`` call.
        trust_remote_code: Forwarded to ``from_pretrained``.

    Returns:
        Wall-clock seconds for all batches, covering generation plus
        decoding. Model loading is excluded.
    """
    llm = AutoModelForCausalLM.from_pretrained(
        model, torch_dtype=torch.float16, trust_remote_code=trust_remote_code)
    if llm.config.model_type == "llama":
        # To enable padding in the HF backend.
        tokenizer.pad_token = tokenizer.eos_token
    llm = llm.cuda()

    pbar = tqdm(total=len(requests))
    start = time.perf_counter()
    batch: List[str] = []
    max_prompt_len = 0
    max_output_len = 0
    last_index = len(requests) - 1
    for i, request in enumerate(requests):
        # SampleRequest is a dataclass, not a tuple: read fields by name.
        batch.append(request.prompt)
        max_prompt_len = max(max_prompt_len, request.prompt_len)
        max_output_len = max(max_output_len, request.expected_output_len)
        if len(batch) < max_batch_size and i != last_index:
            # Check if we can add more requests to the batch.
            next_request = requests[i + 1]
            if (max(max_prompt_len, next_request.prompt_len) +
                    max(max_output_len,
                        next_request.expected_output_len)) <= 2048:
                # We can add more requests to the batch.
                continue

        # Generate the sequences.
        input_ids = tokenizer(batch, return_tensors="pt",
                              padding=True).input_ids
        llm_outputs = llm.generate(
            input_ids=input_ids.cuda(),
            do_sample=True,
            num_return_sequences=n,
            temperature=1.0,
            top_p=1.0,
            use_cache=True,
            max_new_tokens=max_output_len,
        )
        # Include the decoding time.
        tokenizer.batch_decode(llm_outputs, skip_special_tokens=True)
        pbar.update(len(batch))

        # Clear the batch.
        batch = []
        max_prompt_len = 0
        max_output_len = 0
    end = time.perf_counter()
    pbar.close()
    return end - start


def run_mii(
    requests: List[SampleRequest],
    model: str,
    tensor_parallel_size: int,
    output_len: int,
) -> float:
    """Benchmark generation with DeepSpeed-MII.

    Args:
        requests: Requests whose prompts are generated in one call.
        model: Model name passed to ``mii.serve`` / ``mii.client``.
        tensor_parallel_size: Forwarded as ``tensor_parallel``.
        output_len: ``max_new_tokens`` used for every prompt.

    Returns:
        Wall-clock seconds of the single ``generate`` call.
    """
    # Import under distinct names so the local result does not shadow the
    # imported ``client`` function.
    from mii import client as mii_client
    from mii import serve
    llm = serve(model, tensor_parallel=tensor_parallel_size)
    prompts = [request.prompt for request in requests]
    try:
        start = time.perf_counter()
        llm.generate(prompts, max_new_tokens=output_len)
        end = time.perf_counter()
    finally:
        # Terminate the server started by serve() even if generation raises,
        # so it is not left running.
        mii_client(model).terminate_server()
    return end - start


zhuwenwen's avatar
zhuwenwen committed
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
def save_to_pytorch_benchmark_format(args: argparse.Namespace,
                                     results: Dict[str, Any]) -> None:
    """Write *results* as PyTorch-benchmark records next to ``args.output_json``.

    Throughput numbers become metrics; elapsed time, request count and token
    count go into ``extra_info``. Nothing is written when the converter
    returns no records.
    """
    metrics = {
        "requests_per_second": [results["requests_per_second"]],
        "tokens_per_second": [results["tokens_per_second"]],
    }
    extra_keys = ("elapsed_time", "num_requests", "total_num_tokens")
    extra_info = {key: results[key] for key in extra_keys}
    pt_records = convert_to_pytorch_benchmark_format(args=args,
                                                     metrics=metrics,
                                                     extra_info=extra_info)
    if not pt_records:
        return
    # Don't use json suffix here as we don't want CI to pick it up
    stem, _ext = os.path.splitext(args.output_json)
    with open(f"{stem}.pytorch.json", "w") as f:
        json.dump(pt_records, f)


385
386
387
388
389
390
391
392
def _synthesize_random_requests(
        tokenizer: PreTrainedTokenizerBase,
        args: argparse.Namespace) -> List[SampleRequest]:
    """Build ``args.num_prompts`` requests whose prompts are random tokens.

    Each prompt is decoded from ``args.input_len`` random token ids and then
    re-encoded. A tokenizer may add tokens such as BOS, so the id list is
    grown or trimmed over up to five attempts to approach the target length.
    ``prompt_len`` always records ``args.input_len``, even if the target was
    not reached exactly.
    """
    vocab_size = tokenizer.vocab_size
    requests: List[SampleRequest] = []
    for _ in range(args.num_prompts):
        request_tokenizer = tokenizer
        lora_request: Optional[LoRARequest] = None
        if args.enable_lora:
            lora_request, lora_tokenizer = get_random_lora_request(args)
            if lora_tokenizer:
                request_tokenizer = lora_tokenizer

        # Synthesize a prompt with the given input length.
        candidate_ids = [
            random.randint(0, vocab_size - 1) for _ in range(args.input_len)
        ]
        for _ in range(5):  # Max attempts to correct
            candidate_prompt = request_tokenizer.decode(candidate_ids)
            tokenized_len = len(request_tokenizer.encode(candidate_prompt))
            if tokenized_len == args.input_len:
                break
            # Adjust length based on difference
            diff = args.input_len - tokenized_len
            if diff > 0:
                candidate_ids.extend([
                    random.randint(100, vocab_size - 100)
                    for _ in range(diff)
                ])
            else:
                candidate_ids = candidate_ids[:diff]
        requests.append(
            SampleRequest(prompt=candidate_prompt,
                          prompt_len=args.input_len,
                          expected_output_len=args.output_len,
                          lora_request=lora_request))
    return requests


def main(args: argparse.Namespace):
    """Sample requests, run the selected backend and report throughput.

    Prints latency and request/token throughput. When ``args.output_json``
    is set, it also writes the results as JSON plus a
    PyTorch-benchmark-format companion file.

    Raises:
        ValueError: For an unknown backend, or for the HF backend with
            ``tensor_parallel_size != 1``.
    """
    print(args)
    random.seed(args.seed)

    # Sample the requests.
    tokenizer = AutoTokenizer.from_pretrained(
        args.tokenizer, trust_remote_code=args.trust_remote_code)
    if args.dataset is None:
        requests = _synthesize_random_requests(tokenizer, args)
    else:
        requests = sample_requests(tokenizer, args)

    is_multi_modal = any(request.multi_modal_data is not None
                         for request in requests)
    if args.backend == "vllm":
        if args.async_engine:
            elapsed_time = uvloop.run(
                run_vllm_async(
                    requests,
                    args.n,
                    AsyncEngineArgs.from_cli_args(args),
                    args.disable_frontend_multiprocessing,
                ))
        else:
            elapsed_time = run_vllm(requests, args.n, args.num_iters_warmup,
                                    EngineArgs.from_cli_args(args))
    elif args.backend == "hf":
        # Explicit check instead of assert so it survives `python -O`.
        if args.tensor_parallel_size != 1:
            raise ValueError(
                "HF backend only supports tensor_parallel_size == 1.")
        elapsed_time = run_hf(requests, args.model, tokenizer, args.n,
                              args.hf_max_batch_size, args.trust_remote_code)
    elif args.backend == "mii":
        elapsed_time = run_mii(requests, args.model, args.tensor_parallel_size,
                               args.output_len)
    else:
        raise ValueError(f"Unknown backend: {args.backend}")
    total_num_tokens = sum(request.prompt_len + request.expected_output_len
                           for request in requests)
    total_output_tokens = sum(request.expected_output_len
                              for request in requests)
    if is_multi_modal:
        print("\033[91mWARNING\033[0m: Multi-modal request detected. The "
              "following metrics are not accurate because image tokens are not"
              " counted. See vllm-project/vllm/issues/9778 for details.")
        # TODO(vllm-project/vllm/issues/9778): Count multi-modal token length.
    print(f"Latency: {elapsed_time:.2f} s")
    print(f"Throughput: {len(requests) / elapsed_time:.2f} requests/s, "
          f"{total_num_tokens / elapsed_time:.2f} total tokens/s, "
          f"{total_output_tokens / elapsed_time:.2f} output tokens/s")

    # Output JSON results if specified
    if args.output_json:
        results = {
            "elapsed_time": elapsed_time,
            "num_requests": len(requests),
            "total_num_tokens": total_num_tokens,
            "requests_per_second": len(requests) / elapsed_time,
            "tokens_per_second": total_num_tokens / elapsed_time,
        }
        with open(args.output_json, "w") as f:
            json.dump(results, f, indent=4)
        save_to_pytorch_benchmark_format(args, results)
484
485
486


if __name__ == "__main__":
    parser = FlexibleArgumentParser(description="Benchmark the throughput.")
    parser.add_argument("--backend",
                        type=str,
                        choices=["vllm", "hf", "mii"],
                        default="vllm")
    parser.add_argument("--dataset",
                        type=str,
                        default=None,
                        help="Path to the dataset. The dataset is expected to "
                        "be a json in form of List[Dict[..., conversations: "
                        "List[Dict[..., value: <prompt_or_response>]]]]")
    parser.add_argument("--input-len",
                        type=int,
                        default=None,
                        help="Input prompt length for each request")
    parser.add_argument("--output-len",
                        type=int,
                        default=None,
                        help="Output length for each request. Overrides the "
                        "output length from the dataset.")
    parser.add_argument("--n",
                        type=int,
                        default=1,
                        help="Number of generated sequences per prompt.")
    parser.add_argument('--num-iters-warmup',
                        type=int,
                        default=1,
                        help='Number of iterations to run for warmup.')
    parser.add_argument("--num-prompts",
                        type=int,
                        default=1000,
                        help="Number of prompts to process.")
    parser.add_argument("--hf-max-batch-size",
                        type=int,
                        default=None,
                        help="Maximum batch size for HF backend.")
    parser.add_argument(
        '--output-json',
        type=str,
        default=None,
        help='Path to save the throughput results in JSON format.')
    parser.add_argument("--async-engine",
                        action='store_true',
                        default=False,
                        help="Use vLLM async engine rather than LLM class.")
    parser.add_argument("--disable-frontend-multiprocessing",
                        action='store_true',
                        default=False,
                        help="Disable decoupled async engine frontend.")
    # LoRA
    parser.add_argument(
        "--lora-path",
        type=str,
        default=None,
        help="Path to the lora adapters to use. This can be an absolute path, "
        "a relative path, or a Hugging Face model identifier.")

    parser = AsyncEngineArgs.add_cli_args(parser)
    args = parser.parse_args()
    if args.tokenizer is None:
        args.tokenizer = args.model

    # Validate with parser.error rather than assert: checks survive
    # `python -O` and the user gets a usage message.
    if args.dataset is None:
        if args.input_len is None or args.output_len is None:
            parser.error("--input-len and --output-len are required when "
                         "--dataset is not given.")
    elif args.input_len is not None:
        parser.error("--input-len cannot be combined with --dataset.")
    if args.enable_lora and args.lora_path is None:
        parser.error("--lora-path is required with --enable-lora.")

    # ``enable_lora`` is checked for truthiness: with a store_true flag its
    # default is False, so an ``is not None`` test would reject every run.
    if args.backend == "vllm":
        if args.hf_max_batch_size is not None:
            raise ValueError("HF max batch size is only for HF backend.")
    elif args.backend == "hf":
        if args.hf_max_batch_size is None:
            raise ValueError("HF max batch size is required for HF backend.")
        if args.quantization is not None:
            raise ValueError("Quantization is only for vLLM backend.")
        if args.enable_lora:
            raise ValueError("LoRA benchmarking is only supported for vLLM"
                             " backend")
    elif args.backend == "mii":
        if args.dtype != "auto":
            raise ValueError("dtype must be auto for MII backend.")
        if args.n != 1:
            raise ValueError("n must be 1 for MII backend.")
        if args.quantization is not None:
            raise ValueError("Quantization is only for vLLM backend.")
        if args.hf_max_batch_size is not None:
            raise ValueError("HF max batch size is only for HF backend.")
        if args.tokenizer != args.model:
            raise ValueError("Tokenizer must be the same as the model for MII "
                             "backend.")
        if args.enable_lora:
            raise ValueError("LoRA benchmarking is only supported for vLLM"
                             " backend")
    main(args)