benchmark_throughput.py 23.6 KB
Newer Older
1
"""Benchmark offline inference throughput."""
2
3
4
5
import argparse
import json
import random
import time
6
from typing import List, Optional, Tuple
7

zhuwenwen's avatar
zhuwenwen committed
8
import numpy as np
9
import torch
10
import uvloop
11
from tqdm import tqdm
12
13
from transformers import (AutoModelForCausalLM, AutoTokenizer,
                          PreTrainedTokenizerBase)
14

15
from vllm.inputs import PromptInputs
16
from vllm.engine.arg_utils import DEVICE_OPTIONS, AsyncEngineArgs, EngineArgs
17
18
from vllm.entrypoints.openai.api_server import (
    build_async_engine_client_from_engine_args)
19
from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
20
from vllm.utils import FlexibleArgumentParser, merge_async_iterators
21

22
23
24
25
26

def sample_requests(
    dataset_path: str,
    num_requests: int,
    tokenizer: PreTrainedTokenizerBase,
    fixed_output_len: Optional[int],
) -> List[Tuple[str, int, int]]:
    """Sample up to ``num_requests`` (prompt, prompt_len, output_len) tuples.

    The dataset is a JSON list of records, each holding a ``"conversations"``
    list whose entries carry a ``"value"`` string. Only the first two turns
    of each record are used: turn 0 is the prompt, turn 1 the completion.

    Args:
        dataset_path: Path to the JSON dataset file.
        num_requests: Maximum number of requests to return.
        tokenizer: Callable returning an object with ``input_ids`` for a
            given text; used to measure prompt/completion lengths.
        fixed_output_len: If not ``None``, used as the output length of every
            request instead of the tokenized completion length.

    Returns:
        A list of ``(prompt, prompt_len, output_len)`` tuples; it may be
        shorter than ``num_requests`` if too few records pass the filters.

    Raises:
        ValueError: If ``fixed_output_len`` is given and smaller than 4.
    """
    if fixed_output_len is not None and fixed_output_len < 4:
        raise ValueError("output_len too small")

    # JSON is UTF-8 by specification; do not depend on the locale default.
    with open(dataset_path, encoding="utf-8") as f:
        dataset = json.load(f)

    # Keep conversations with at least 2 turns, and only their first 2 turns.
    conversations = [(data["conversations"][0]["value"],
                      data["conversations"][1]["value"]) for data in dataset
                     if len(data["conversations"]) >= 2]

    # Shuffle the dataset (uses the global `random` state).
    random.shuffle(conversations)

    # Filter out sequences that are too long or too short.
    filtered_dataset: List[Tuple[str, int, int]] = []
    for prompt, completion in conversations:
        if len(filtered_dataset) == num_requests:
            break

        prompt_len = len(tokenizer(prompt).input_ids)
        if fixed_output_len is None:
            output_len = len(tokenizer(completion).input_ids)
        else:
            # The completion length is overridden, so there is no need to
            # tokenize the completion at all.
            output_len = fixed_output_len

        if prompt_len < 4 or output_len < 4:
            # Prune too short sequences.
            continue
        if prompt_len > 1024 or prompt_len + output_len > 2048:
            # Prune too long sequences.
            continue
        filtered_dataset.append((prompt, prompt_len, output_len))

    return filtered_dataset
67
68


Woosuk Kwon's avatar
Woosuk Kwon committed
69
def run_vllm(
    warmup_requests: List[Tuple[str, int, int]],
    requests: List[Tuple[str, int, int]],
    model: str,
    tokenizer: str,
    quantization: Optional[str],
    tensor_parallel_size: int,
    seed: int,
    n: int,
    use_beam_search: bool,
    trust_remote_code: bool,
    dtype: str,
    max_model_len: Optional[int],
    enforce_eager: bool,
    kv_cache_dtype: str,
    quantization_param_path: Optional[str],
    device: str,
    enable_prefix_caching: bool,
    enable_chunked_prefill: bool,
    max_num_batched_tokens: int,
    distributed_executor_backend: Optional[str],
    gpu_memory_utilization: float = 0.9,
    num_scheduler_steps: int = 1,
    use_v2_block_manager: bool = False,
    download_dir: Optional[str] = None,
    load_format: str = EngineArgs.load_format,
    disable_async_output_proc: bool = False,
    num_iters_warmup: Optional[int] = None,
) -> float:
    """Time one offline ``LLM.generate`` call over ``requests``.

    An ``LLM`` is built from the engine options, ``warmup_requests`` are
    generated ``num_iters_warmup`` times (untimed), and then a single
    ``generate`` call over all ``requests`` is timed.

    Args:
        warmup_requests: ``(prompt, prompt_len, output_len)`` tuples used only
            for the untimed warm-up passes.
        requests: ``(prompt, prompt_len, output_len)`` tuples to benchmark.
        n: Number of sequences requested per prompt.
        use_beam_search: Forwarded to ``SamplingParams``; also switches the
            sampling temperature from 1.0 to 0.0.
        num_iters_warmup: Number of warm-up passes. When ``None`` the value is
            read from a module-level ``args`` namespace if one exists (the
            behaviour when this file is run as a script), else 1.

        All remaining parameters are forwarded unchanged to ``LLM(...)``.

    Returns:
        Wall-clock seconds spent in the timed ``generate`` call.
    """
    from vllm import LLM, SamplingParams

    if num_iters_warmup is None:
        # This function used to read the global `args` directly, which raised
        # NameError when it was imported and called from other code. Keep the
        # script behaviour, but degrade gracefully without the global.
        cli_args = globals().get("args")
        num_iters_warmup = getattr(cli_args, "num_iters_warmup", 1)

    def build_batch(batch_requests):
        """Split requests into parallel prompt / SamplingParams lists."""
        batch_prompts: List[str] = []
        batch_params: List[SamplingParams] = []
        for prompt, _, output_len in batch_requests:
            batch_prompts.append(prompt)
            batch_params.append(
                SamplingParams(
                    n=n,
                    temperature=0.0 if use_beam_search else 1.0,
                    top_p=1.0,
                    use_beam_search=use_beam_search,
                    # Always generate exactly `output_len` tokens.
                    ignore_eos=True,
                    max_tokens=output_len,
                ))
        return batch_prompts, batch_params

    llm = LLM(
        model=model,
        tokenizer=tokenizer,
        quantization=quantization,
        tensor_parallel_size=tensor_parallel_size,
        seed=seed,
        trust_remote_code=trust_remote_code,
        dtype=dtype,
        max_model_len=max_model_len,
        gpu_memory_utilization=gpu_memory_utilization,
        enforce_eager=enforce_eager,
        kv_cache_dtype=kv_cache_dtype,
        quantization_param_path=quantization_param_path,
        device=device,
        enable_prefix_caching=enable_prefix_caching,
        download_dir=download_dir,
        enable_chunked_prefill=enable_chunked_prefill,
        max_num_batched_tokens=max_num_batched_tokens,
        distributed_executor_backend=distributed_executor_backend,
        load_format=load_format,
        num_scheduler_steps=num_scheduler_steps,
        use_v2_block_manager=use_v2_block_manager,
        disable_async_output_proc=disable_async_output_proc,
    )

    # Build the inputs before starting the timer.
    prompts, sampling_params = build_batch(requests)
    warmup_prompts, warmup_sampling_params = build_batch(warmup_requests)

    # Untimed warm-up passes.
    print("Warming up...")
    for _ in tqdm(range(num_iters_warmup), desc="Warmup iterations"):
        llm.generate(warmup_prompts, warmup_sampling_params, use_tqdm=True)

    # Timed run.
    start = time.perf_counter()
    llm.generate(prompts, sampling_params, use_tqdm=True)
    end = time.perf_counter()
    return end - start


179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
async def run_vllm_async(
    requests: List[Tuple[str, int, int]],
    model: str,
    tokenizer: str,
    quantization: Optional[str],
    tensor_parallel_size: int,
    seed: int,
    n: int,
    use_beam_search: bool,
    trust_remote_code: bool,
    dtype: str,
    max_model_len: Optional[int],
    enforce_eager: bool,
    kv_cache_dtype: str,
    quantization_param_path: Optional[str],
    device: str,
    enable_prefix_caching: bool,
    enable_chunked_prefill: bool,
    max_num_batched_tokens: int,
    distributed_executor_backend: Optional[str],
    gpu_memory_utilization: float = 0.9,
    num_scheduler_steps: int = 1,
    use_v2_block_manager: bool = False,
    download_dir: Optional[str] = None,
    load_format: str = EngineArgs.load_format,
    disable_async_output_proc: bool = False,
    disable_frontend_multiprocessing: bool = False,
) -> float:
    """Time generation of all ``requests`` through the async engine client.

    Every request is submitted as its own ``generate`` stream (request ids
    ``test0``, ``test1``, ...); the streams are merged and drained. The timer
    covers submission and draining, not engine start-up or input preparation.

    Args:
        requests: ``(prompt, prompt_len, output_len)`` tuples to benchmark.
        n: Number of sequences requested per prompt.
        use_beam_search: Forwarded to ``SamplingParams``; also switches the
            sampling temperature from 1.0 to 0.0.
        disable_frontend_multiprocessing: Passed to
            ``build_async_engine_client_from_engine_args``.

        All remaining parameters are forwarded to ``AsyncEngineArgs``.

    Returns:
        Wall-clock seconds from the first submission until all streams end.
    """
    from vllm import SamplingParams

    async_engine_args = AsyncEngineArgs(
        model=model,
        tokenizer=tokenizer,
        quantization=quantization,
        tensor_parallel_size=tensor_parallel_size,
        seed=seed,
        trust_remote_code=trust_remote_code,
        dtype=dtype,
        max_model_len=max_model_len,
        gpu_memory_utilization=gpu_memory_utilization,
        enforce_eager=enforce_eager,
        kv_cache_dtype=kv_cache_dtype,
        quantization_param_path=quantization_param_path,
        device=device,
        enable_prefix_caching=enable_prefix_caching,
        download_dir=download_dir,
        enable_chunked_prefill=enable_chunked_prefill,
        max_num_batched_tokens=max_num_batched_tokens,
        distributed_executor_backend=distributed_executor_backend,
        load_format=load_format,
        num_scheduler_steps=num_scheduler_steps,
        use_v2_block_manager=use_v2_block_manager,
        disable_async_output_proc=disable_async_output_proc,
        worker_use_ray=False,
        engine_use_ray=False,
        disable_log_requests=True,
    )

    async with build_async_engine_client_from_engine_args(
            async_engine_args, disable_frontend_multiprocessing) as client:

        # Prepare the prompt and its sampling parameters for every request
        # before the timer starts.
        prompt_texts: List[str] = [text for text, _, _ in requests]
        per_request_params: List[SamplingParams] = [
            SamplingParams(
                n=n,
                temperature=0.0 if use_beam_search else 1.0,
                top_p=1.0,
                use_beam_search=use_beam_search,
                ignore_eos=True,
                max_tokens=out_len,
            ) for _, _, out_len in requests
        ]

        t_begin = time.perf_counter()
        streams = [
            client.generate(text, params, request_id=f"test{idx}")
            for idx, (text,
                      params) in enumerate(zip(prompt_texts,
                                               per_request_params))
        ]
        # Drain every stream; the outputs themselves are discarded.
        async for _stream_idx, _output in merge_async_iterators(*streams):
            pass
        t_finish = time.perf_counter()
        return t_finish - t_begin


266
267
268
269
270
271
272
def run_hf(
    requests: List[Tuple[str, int, int]],
    model: str,
    tokenizer: PreTrainedTokenizerBase,
    n: int,
    use_beam_search: bool,
    max_batch_size: int,
    trust_remote_code: bool,
) -> float:
    """Time generation of ``requests`` with a HuggingFace causal LM on CUDA.

    Requests are greedily packed, in order, into batches of at most
    ``max_batch_size`` prompts such that the longest prompt plus the longest
    output in a batch does not exceed 2048 tokens. Each batch is padded,
    generated with ``max_new_tokens`` equal to the batch's longest output,
    and decoded; decoding time is included in the measurement.

    Args:
        requests: ``(prompt, prompt_len, output_len)`` tuples to benchmark.
        model: Model name or path passed to ``from_pretrained``.
        tokenizer: Tokenizer used for batching/padding and decoding. Its
            ``pad_token`` is set to ``eos_token`` for llama models.
        n: Number of sequences returned per prompt.
        use_beam_search: Must be false; beam search is not supported here.
        max_batch_size: Maximum number of prompts per batch.
        trust_remote_code: Passed to ``from_pretrained``.

    Returns:
        Wall-clock seconds spent tokenizing, generating and decoding.
    """
    assert not use_beam_search
    llm = AutoModelForCausalLM.from_pretrained(
        model, torch_dtype=torch.float16, trust_remote_code=trust_remote_code)
    if llm.config.model_type == "llama":
        # To enable padding in the HF backend.
        tokenizer.pad_token = tokenizer.eos_token
    llm = llm.cuda()

    # The context manager closes the progress bar even if generation fails.
    with tqdm(total=len(requests)) as pbar:
        start = time.perf_counter()
        batch: List[str] = []
        max_prompt_len = 0
        max_output_len = 0
        for i, (prompt, prompt_len, output_len) in enumerate(requests):
            # Add the prompt to the batch.
            batch.append(prompt)
            max_prompt_len = max(max_prompt_len, prompt_len)
            max_output_len = max(max_output_len, output_len)
            if len(batch) < max_batch_size and i != len(requests) - 1:
                # Check if the next request still fits into this batch.
                _, next_prompt_len, next_output_len = requests[i + 1]
                if (max(max_prompt_len, next_prompt_len) +
                        max(max_output_len, next_output_len)) <= 2048:
                    # We can add more requests to the batch.
                    continue

            # Generate the sequences.
            input_ids = tokenizer(batch, return_tensors="pt",
                                  padding=True).input_ids
            llm_outputs = llm.generate(
                input_ids=input_ids.cuda(),
                do_sample=not use_beam_search,
                num_return_sequences=n,
                temperature=1.0,
                top_p=1.0,
                use_cache=True,
                max_new_tokens=max_output_len,
            )
            # Include the decoding time.
            tokenizer.batch_decode(llm_outputs, skip_special_tokens=True)
            pbar.update(len(batch))

            # Clear the batch.
            batch = []
            max_prompt_len = 0
            max_output_len = 0
        end = time.perf_counter()
    return end - start


326
327
328
329
330
331
def run_mii(
    requests: List[Tuple[str, int, int]],
    model: str,
    tensor_parallel_size: int,
    output_len: int,
) -> float:
    """Time one ``generate`` call over all prompts on an MII deployment.

    A server is started with ``mii.serve``, all prompts are generated in a
    single timed call, and the server is terminated afterwards (also when
    generation raises, so no server is left running).

    Args:
        requests: ``(prompt, prompt_len, output_len)`` tuples; only the
            prompt of each tuple is used.
        model: Model name passed to ``mii.serve`` and ``mii.client``.
        tensor_parallel_size: Passed to ``mii.serve`` as ``tensor_parallel``.
        output_len: Passed to ``generate`` as ``max_new_tokens``.

    Returns:
        Wall-clock seconds spent in the ``generate`` call.
    """
    from mii import client, serve

    llm = serve(model, tensor_parallel=tensor_parallel_size)
    try:
        prompts = [prompt for prompt, _, _ in requests]

        start = time.perf_counter()
        llm.generate(prompts, max_new_tokens=output_len)
        end = time.perf_counter()
    finally:
        # Use a distinct name so the imported `client` factory is not
        # shadowed by the object it returns.
        mii_client = client(model)
        mii_client.terminate_server()
    return end - start


344
345
346
347
348
def main(args: argparse.Namespace):
    """Run the throughput benchmark described by ``args`` and report results.

    Builds the request list (synthetic when ``args.dataset`` is ``None``,
    otherwise sampled from the dataset), dispatches to the backend selected
    by ``args.backend``, prints latency and throughput, and optionally writes
    a JSON summary to ``args.output_json``.

    Raises:
        ValueError: If ``args.backend`` is not one of the known backends.
    """
    print(args)
    random.seed(args.seed)

    # Sample the requests.
    tokenizer = AutoTokenizer.from_pretrained(
        args.tokenizer, trust_remote_code=args.trust_remote_code)

    # A single short synthetic request used for the untimed warm-up passes.
    warmup_prompt = "hi" * 10
    warmup_requests = [(warmup_prompt, 10, 10)]

    if args.dataset is None:
        # Synthesize a prompt with the given input length.
        prompt = "hi" * (args.input_len - 1)
        requests = [(prompt, args.input_len, args.output_len)
                    for _ in range(args.num_prompts)]
    else:
        requests = sample_requests(args.dataset, args.num_prompts, tokenizer,
                                   args.output_len)

    if args.backend == "vllm":
        # Positional arguments shared by run_vllm and run_vllm_async.
        # `warmup_requests` must NOT be part of this list: run_vllm_async has
        # no such parameter, so including it shifts every argument by one and
        # the async call fails with too many positional arguments.
        run_args = [
            requests, args.model, args.tokenizer, args.quantization,
            args.tensor_parallel_size, args.seed, args.n, args.use_beam_search,
            args.trust_remote_code, args.dtype, args.max_model_len,
            args.enforce_eager, args.kv_cache_dtype,
            args.quantization_param_path, args.device,
            args.enable_prefix_caching, args.enable_chunked_prefill,
            args.max_num_batched_tokens, args.distributed_executor_backend,
            args.gpu_memory_utilization, args.num_scheduler_steps,
            args.use_v2_block_manager, args.download_dir, args.load_format,
            args.disable_async_output_proc
        ]

        if args.async_engine:
            # The async path takes no warm-up requests.
            elapsed_time = uvloop.run(
                run_vllm_async(*run_args,
                               args.disable_frontend_multiprocessing))
        else:
            elapsed_time = run_vllm(warmup_requests, *run_args)
    elif args.backend == "hf":
        assert args.tensor_parallel_size == 1
        elapsed_time = run_hf(requests, args.model, tokenizer, args.n,
                              args.use_beam_search, args.hf_max_batch_size,
                              args.trust_remote_code)
    elif args.backend == "mii":
        elapsed_time = run_mii(requests, args.model, args.tensor_parallel_size,
                               args.output_len)
    else:
        raise ValueError(f"Unknown backend: {args.backend}")

    total_num_tokens = sum(prompt_len + output_len
                           for _, prompt_len, output_len in requests)
    # For synthetic requests every output_len equals args.output_len, so this
    # sum also covers the no-dataset case.
    total_out_tokens = sum(output_len for _, _, output_len in requests)

    print(f"Latency: {elapsed_time:.2f} s")
    print(f"All Throughput: {len(requests) / elapsed_time:.2f} requests/s, "
          f"{total_num_tokens / elapsed_time:.2f} tokens/s")
    print(f"Generate Throughput: {total_out_tokens / elapsed_time:.2f} tokens/s")

    # Output JSON results if specified
    if args.output_json:
        results = {
            "elapsed_time": elapsed_time,
            "num_requests": len(requests),
            "total_num_tokens": total_num_tokens,
            "requests_per_second": len(requests) / elapsed_time,
            "tokens_per_second": total_num_tokens / elapsed_time,
        }
        with open(args.output_json, "w") as f:
            json.dump(results, f, indent=4)

417
418

if __name__ == "__main__":
    # Command-line entry point: declare the options, validate combinations
    # that depend on the selected backend, then run the benchmark.
    parser = FlexibleArgumentParser(description="Benchmark the throughput.")
    parser.add_argument("--backend",
                        type=str,
                        choices=["vllm", "hf", "mii"],
                        default="vllm")
    parser.add_argument("--dataset",
                        type=str,
                        default=None,
                        help="Path to the dataset.")
    parser.add_argument("--input-len",
                        type=int,
                        default=None,
                        help="Input prompt length for each request")
    parser.add_argument("--output-len",
                        type=int,
                        default=None,
                        help="Output length for each request. Overrides the "
                        "output length from the dataset.")
    parser.add_argument("--model", type=str, default="facebook/opt-125m")
    parser.add_argument("--tokenizer", type=str, default=None)
    parser.add_argument('--quantization',
                        '-q',
                        choices=[*QUANTIZATION_METHODS, None],
                        default=None)
    parser.add_argument("--tensor-parallel-size", "-tp", type=int, default=1)
    parser.add_argument("--n",
                        type=int,
                        default=1,
                        help="Number of generated sequences per prompt.")
    parser.add_argument("--use-beam-search", action="store_true")
    parser.add_argument('--num-iters-warmup',
                        type=int,
                        default=1,
                        help='Number of iterations to run for warmup.')
    parser.add_argument("--num-prompts",
                        type=int,
                        default=1000,
                        help="Number of prompts to process.")
    parser.add_argument("--seed", type=int, default=0)
    parser.add_argument("--hf-max-batch-size",
                        type=int,
                        default=None,
                        help="Maximum batch size for HF backend.")
    parser.add_argument('--trust-remote-code',
                        action='store_true',
                        help='trust remote code from huggingface')
    parser.add_argument(
        '--max-model-len',
        type=int,
        default=None,
        help='Maximum length of a sequence (including prompt and output). '
        'If None, will be derived from the model.')
    parser.add_argument(
        '--dtype',
        type=str,
        default='auto',
        choices=['auto', 'half', 'float16', 'bfloat16', 'float', 'float32'],
        help='data type for model weights and activations. '
        'The "auto" option will use FP16 precision '
        'for FP32 and FP16 models, and BF16 precision '
        'for BF16 models.')
    parser.add_argument('--gpu-memory-utilization',
                        type=float,
                        default=0.9,
                        help='the fraction of GPU memory to be used for '
                        'the model executor, which can range from 0 to 1. '
                        'If unspecified, will use the default value of 0.9.')
    parser.add_argument("--enforce-eager",
                        action="store_true",
                        help="enforce eager execution")
    parser.add_argument(
        '--kv-cache-dtype',
        type=str,
        choices=['auto', 'fp8', 'fp8_e5m2', 'fp8_e4m3'],
        default="auto",
        help='Data type for kv cache storage. If "auto", will use model '
        'data type. CUDA 11.8+ supports fp8 (=fp8_e4m3) and fp8_e5m2. '
        'ROCm (AMD GPU) supports fp8 (=fp8_e4m3)')
    parser.add_argument(
        '--quantization-param-path',
        type=str,
        default=None,
        help='Path to the JSON file containing the KV cache scaling factors. '
        'This should generally be supplied, when KV cache dtype is FP8. '
        'Otherwise, KV cache scaling factors default to 1.0, which may cause '
        'accuracy issues. FP8_E5M2 (without scaling) is only supported on '
        'cuda version greater than 11.8. On ROCm (AMD GPU), FP8_E4M3 is '
        'instead supported for common inference criteria.')
    parser.add_argument("--device",
                        type=str,
                        default="auto",
                        choices=DEVICE_OPTIONS,
                        help='device type for vLLM execution')
    parser.add_argument(
        "--num-scheduler-steps",
        type=int,
        default=1,
        help="Maximum number of forward steps per scheduler call.")
    parser.add_argument("--use-v2-block-manager",
                        action='store_true',
                        help="Enable block manager v2.")
    parser.add_argument(
        "--enable-prefix-caching",
        action='store_true',
        help="Enable automatic prefix caching for vLLM backend.")
    parser.add_argument("--enable-chunked-prefill",
                        action='store_true',
                        help="enable chunked prefill for vLLM backend.")
    parser.add_argument('--max-num-batched-tokens',
                        type=int,
                        default=None,
                        help='maximum number of batched tokens per '
                        'iteration')
    parser.add_argument('--download-dir',
                        type=str,
                        default=None,
                        help='directory to download and load the weights, '
                        'default to the default cache dir of huggingface')
    parser.add_argument(
        '--output-json',
        type=str,
        default=None,
        help='Path to save the throughput results in JSON format.')
    parser.add_argument(
        '--distributed-executor-backend',
        choices=['ray', 'mp'],
        default=None,
        help='Backend to use for distributed serving. When more than 1 GPU '
        'is used, will be automatically set to "ray" if installed '
        'or "mp" (multiprocessing) otherwise.')
    parser.add_argument(
        '--load-format',
        type=str,
        default=EngineArgs.load_format,
        choices=[
            'auto', 'pt', 'safetensors', 'npcache', 'dummy', 'tensorizer',
            'bitsandbytes'
        ],
        help='The format of the model weights to load.\n\n'
        '* "auto" will try to load the weights in the safetensors format '
        'and fall back to the pytorch bin format if safetensors format '
        'is not available.\n'
        '* "pt" will load the weights in the pytorch bin format.\n'
        '* "safetensors" will load the weights in the safetensors format.\n'
        '* "npcache" will load the weights in pytorch format and store '
        'a numpy cache to speed up the loading.\n'
        '* "dummy" will initialize the weights with random values, '
        'which is mainly for profiling.\n'
        '* "tensorizer" will load the weights using tensorizer from '
        'CoreWeave. See the Tensorize vLLM Model script in the Examples '
        'section for more information.\n'
        '* "bitsandbytes" will load the weights using bitsandbytes '
        'quantization.\n')
    parser.add_argument(
        "--disable-async-output-proc",
        action='store_true',
        default=False,
        help="Disable async output processor for vLLM backend.")
    parser.add_argument("--async-engine",
                        action='store_true',
                        default=False,
                        help="Use vLLM async engine rather than LLM class.")
    parser.add_argument("--disable-frontend-multiprocessing",
                        action='store_true',
                        default=False,
                        help="Disable decoupled async engine frontend.")
    # NOTE: `args` stays a module-level name; run_vllm may read
    # `args.num_iters_warmup` from it.
    args = parser.parse_args()
    if args.tokenizer is None:
        args.tokenizer = args.model

    # Validate with explicit exceptions rather than `assert`, which is
    # stripped when Python runs with -O.
    if args.dataset is None:
        if args.input_len is None:
            raise ValueError(
                "--input-len is required when --dataset is not given.")
        if args.output_len is None:
            raise ValueError(
                "--output-len is required when --dataset is not given.")
    elif args.input_len is not None:
        raise ValueError("--input-len cannot be used together with --dataset.")

    if args.backend == "vllm":
        if args.hf_max_batch_size is not None:
            raise ValueError("HF max batch size is only for HF backend.")
    elif args.backend == "hf":
        if args.hf_max_batch_size is None:
            raise ValueError("HF max batch size is required for HF backend.")
        if args.quantization is not None:
            raise ValueError("Quantization is only for vLLM backend.")
    elif args.backend == "mii":
        if args.dtype != "auto":
            raise ValueError("dtype must be auto for MII backend.")
        if args.n != 1:
            raise ValueError("n must be 1 for MII backend.")
        if args.use_beam_search:
            raise ValueError("Beam search is not supported for MII backend.")
        if args.quantization is not None:
            raise ValueError("Quantization is only for vLLM backend.")
        if args.hf_max_batch_size is not None:
            raise ValueError("HF max batch size is only for HF backend.")
        if args.tokenizer != args.model:
            raise ValueError("Tokenizer must be the same as the model for MII "
                             "backend.")
    main(args)