# SPDX-License-Identifier: Apache-2.0

"""Benchmark offline inference throughput."""
import argparse
import contextlib
import dataclasses
import json
import os
import random
import time
from functools import cache
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple

import numpy as np
import torch
import uvloop
from benchmark_utils import convert_to_pytorch_benchmark_format
from PIL import Image
from tqdm import tqdm
from transformers import (AutoModelForCausalLM, AutoTokenizer,
                          PreTrainedTokenizerBase)

from vllm.engine.arg_utils import AsyncEngineArgs, EngineArgs
from vllm.entrypoints.openai.api_server import (
    build_async_engine_client_from_engine_args)
from vllm.inputs import PromptType, TextPrompt
from vllm.lora.request import LoRARequest
from vllm.lora.utils import get_adapter_absolute_path
from vllm.multimodal import MultiModalDataDict
from vllm.sampling_params import BeamSearchParams
from vllm.transformers_utils.tokenizer import AnyTokenizer, get_lora_tokenizer
from vllm.utils import FlexibleArgumentParser, merge_async_iterators


@dataclasses.dataclass
class SampleRequest:
    """A class representing a single inference request for benchmarking.

    Attributes:
        prompt: The input text prompt for the model.
        prompt_len: The length of the prompt in tokens.
        expected_output_len: The expected length of the output in tokens.
        multi_modal_data: Optional dictionary containing multi-modal data (e.g.
            images).
        lora_request: Optional LoRARequest specifying the LoRA to use.
    """
    prompt: str
    prompt_len: int
    expected_output_len: int
    multi_modal_data: Optional[MultiModalDataDict] = None
    lora_request: Optional[LoRARequest] = None


def _get_prompt_for_image_model(question: str, *, model: str) -> str:
    """Prepend and append special tokens around the question to form a prompt.

    Args:
        question: The input question text to wrap with special tokens
        model: The name of the model being used, to determine which special
            tokens to add

    Returns:
        The formatted prompt string with appropriate special tokens for the
            model

    Raises:
        ValueError: If an unsupported model name is provided
    """
    model = model.lower()
    if "pixtral" in model:
        return f"<s>[INST]{question}\n[IMG][/INST]"
    raise ValueError(f"Unsupported model {model}")


@cache
def lora_path_on_disk(lora_path: str) -> str:
    """Resolve ``lora_path`` to an absolute adapter path.

    Resolution is delegated to vLLM's ``get_adapter_absolute_path``. The
    ``@cache`` decorator memoizes the result per distinct ``lora_path`` so the
    lookup runs only once per path.
    """
    return get_adapter_absolute_path(lora_path)


# Memoized LoRA tokenizers, keyed by LoRA integer id.
lora_tokenizer_cache: Dict[int, AnyTokenizer] = {}


def get_random_lora_request(
        args: argparse.Namespace
) -> Tuple[LoRARequest, Optional[AnyTokenizer]]:
    """Build a LoRARequest for a uniformly random adapter id.

    The id is drawn from ``1..args.max_loras`` (inclusive). Every id points at
    the same adapter location, ``args.lora_path``. The tokenizer for each id is
    fetched once with ``get_lora_tokenizer`` and then served from
    ``lora_tokenizer_cache``.

    Returns:
        A ``(lora_request, tokenizer)`` pair. The tokenizer is whatever
        ``get_lora_tokenizer`` returned for that id and may be falsy.
    """
    chosen_id = random.randint(1, args.max_loras)
    request = LoRARequest(lora_name=str(chosen_id),
                          lora_int_id=chosen_id,
                          lora_path=lora_path_on_disk(args.lora_path))
    # Membership test (not .get) so that a cached falsy tokenizer is reused
    # instead of being looked up again.
    if chosen_id not in lora_tokenizer_cache:
        lora_tokenizer_cache[chosen_id] = get_lora_tokenizer(request)
    return request, lora_tokenizer_cache[chosen_id]


def sample_requests(tokenizer: PreTrainedTokenizerBase,
                    args: argparse.Namespace) -> List[SampleRequest]:
    """Sample benchmark requests from a conversation-style JSON dataset.

    The dataset is a JSON list of entries that each carry a ``conversations``
    list. The first turn is used as the prompt and the second as the reference
    completion. Entries with an ``image`` key get that image attached as
    multi-modal data; entries whose image file is missing are skipped.

    Args:
        tokenizer: Default tokenizer used to measure prompt/completion
            lengths. When LoRA is enabled and a LoRA tokenizer is available,
            that tokenizer is used instead for the request.
        args: Parsed CLI arguments. Reads ``dataset``, ``num_prompts``,
            ``output_len``, ``model`` and ``enable_lora``, plus the LoRA
            options consumed by ``get_random_lora_request``.

    Returns:
        Up to ``args.num_prompts`` requests. Prompts or outputs shorter than 4
        tokens are pruned. So are prompts longer than 1024 tokens and
        prompt+output totals longer than 2048 tokens.

    Raises:
        ValueError: If ``args.output_len`` is set and smaller than 4.
    """
    dataset_path: str = args.dataset
    num_requests: int = args.num_prompts
    fixed_output_len: Optional[int] = args.output_len
    model: str = args.model

    if fixed_output_len is not None and fixed_output_len < 4:
        raise ValueError("output_len too small")

    # Load the dataset (JSON text is UTF-8; don't depend on the locale).
    with open(dataset_path, encoding="utf-8") as f:
        dataset = json.load(f)
    # Filter out the conversations with less than 2 turns.
    dataset = [data for data in dataset if len(data["conversations"]) >= 2]
    # Shuffle the dataset.
    random.shuffle(dataset)

    # Filter out sequences that are too long or too short.
    filtered_dataset: List[SampleRequest] = []
    # The bar counts dataset entries scanned; the loop may stop early once
    # enough requests have been collected.
    for data in tqdm(dataset, total=len(dataset), desc="sampling requests"):
        if len(filtered_dataset) == num_requests:
            break

        # Only keep the first two turns of each conversation.
        prompt = data["conversations"][0]["value"]
        completion = data["conversations"][1]["value"]

        multi_modal_data: Optional[MultiModalDataDict] = None
        if "image" in data:
            image_path = data["image"]
            # TODO(vllm-project/vllm/issues/9778): Support multiple images.
            assert isinstance(image_path,
                              str), "Only support single image input"
            try:
                # Close the source file right away; ``convert`` returns an
                # independent in-memory copy of the pixel data.
                with Image.open(image_path) as image:
                    multi_modal_data = {"image": image.convert("RGB")}
            except FileNotFoundError:
                # Ignore datapoint where asset is missing
                continue
            prompt = _get_prompt_for_image_model(question=prompt, model=model)

        request_tokenizer = tokenizer
        lora_request: Optional[LoRARequest] = None
        if args.enable_lora:
            lora_request, lora_tokenizer = get_random_lora_request(args)
            if lora_tokenizer:
                request_tokenizer = lora_tokenizer

        # Tokenize the prompt, and the completion only when its length is
        # actually needed (i.e. no fixed output length was requested).
        prompt_len = len(request_tokenizer(prompt).input_ids)
        if fixed_output_len is None:
            output_len = len(request_tokenizer(completion).input_ids)
        else:
            output_len = fixed_output_len
        if prompt_len < 4 or output_len < 4:
            # Prune too short sequences.
            continue
        if prompt_len > 1024 or prompt_len + output_len > 2048:
            # Prune too long sequences.
            continue

        filtered_dataset.append(
            SampleRequest(prompt=prompt,
                          prompt_len=prompt_len,
                          expected_output_len=output_len,
                          multi_modal_data=multi_modal_data,
                          lora_request=lora_request))

    return filtered_dataset


def run_vllm(
    requests: List[SampleRequest],
    n: int,
    num_iters_warmup: int,
    engine_args: EngineArgs,
) -> float:
    """Run the offline throughput benchmark with the synchronous ``LLM`` API.

    Builds an ``LLM`` from ``engine_args`` and runs ``num_iters_warmup`` short
    warmup generations on a random 10-token prompt. It then times one
    ``generate`` call over all ``requests``, optionally under
    ``torch.profiler``.

    Args:
        requests: Requests to benchmark.
        n: Number of sequences generated per prompt (also used for warmup).
        num_iters_warmup: Number of warmup ``generate`` calls before timing.
        engine_args: Engine configuration used to construct the ``LLM``.

    Returns:
        Wall-clock seconds spent in the timed generation.

    NOTE(review): profiling is driven by the module-level ``args``
    (``args.profile`` / ``args.profile_result_dir``) created in the
    ``__main__`` block, so this function expects to run as part of the script.
    """
    from vllm import LLM, SamplingParams
    llm = LLM(**dataclasses.asdict(engine_args))

    # Add the requests to the engine.
    prompts: List[TextPrompt] = []
    sampling_params: List[SamplingParams] = []
    for request in requests:
        prompts.append(
            TextPrompt(prompt=request.prompt,
                       multi_modal_data=request.multi_modal_data))
        sampling_params.append(
            SamplingParams(
                n=n,
                temperature=1.0,
                top_p=1.0,
                ignore_eos=True,
                max_tokens=request.expected_output_len,
            ))
    lora_requests: Optional[List[LoRARequest]] = None
    if engine_args.enable_lora:
        lora_requests = [request.lora_request for request in requests]

    # Warm up before timing, using the `n` parameter (not the global args).
    warmup_sampling_params = SamplingParams(
        n=n,
        temperature=1.0,
        top_p=1.0,
        ignore_eos=True,
        max_tokens=10,
    )
    dummy_prompt_token_ids = np.random.randint(10000, size=(1, 10))
    dummy_prompts: List[PromptType] = [{
        "prompt_token_ids": batch
    } for batch in dummy_prompt_token_ids.tolist()]

    print("Warming up...")
    for _ in tqdm(range(num_iters_warmup), desc="Warmup iterations"):
        llm.generate(dummy_prompts,
                     sampling_params=warmup_sampling_params,
                     use_tqdm=False)

    use_beam_search = False

    if not use_beam_search:
        profiling = args.profile
        if profiling:
            profile_dir = args.profile_result_dir
            if not profile_dir:
                profile_dir = Path(
                    "."
                ) / "vllm_benchmark_result" / f"latency_result_{time.time()}"
            print(f"Profiling (results will be saved to '{profile_dir}')...")
            profiler_ctx = torch.profiler.profile(
                activities=[
                    torch.profiler.ProfilerActivity.CPU,
                    torch.profiler.ProfilerActivity.CUDA,
                ],
                record_shapes=True,
                on_trace_ready=torch.profiler.tensorboard_trace_handler(
                    str(profile_dir)))
        else:
            # A no-op context keeps the timed section identical in both modes.
            profiler_ctx = contextlib.nullcontext()

        with profiler_ctx as prof:
            start = time.perf_counter()
            llm.generate(prompts,
                         sampling_params,
                         lora_request=lora_requests,
                         use_tqdm=True)
            end = time.perf_counter()

        if profiling:
            print('Prepare time report')
            print(
                prof.key_averages(group_by_input_shape=True).table(
                    sort_by="self_cuda_time_total", row_limit=-1))
    else:
        assert lora_requests is None, "BeamSearch API does not support LoRA"
        beam_prompts = [request.prompt for request in requests]
        # output_len should be the same for all requests. SampleRequest is a
        # dataclass, so read the field instead of indexing it.
        output_len = requests[0].expected_output_len
        for request in requests:
            assert request.expected_output_len == output_len
        start = time.perf_counter()
        llm.beam_search(
            beam_prompts,
            BeamSearchParams(
                beam_width=n,
                max_tokens=output_len,
                ignore_eos=True,
            ))
        end = time.perf_counter()
    return end - start


async def run_vllm_async(
    requests: List[SampleRequest],
    n: int,
    engine_args: AsyncEngineArgs,
    disable_frontend_multiprocessing: bool = False,
) -> float:
    """Run the throughput benchmark against vLLM's async engine client.

    All requests are submitted up front, one ``generate`` stream each, and the
    merged streams are then drained. The measured time spans from the first
    submission until the last stream is exhausted.

    Args:
        requests: Requests to benchmark.
        n: Number of sequences generated per prompt.
        engine_args: Configuration for the async engine client.
        disable_frontend_multiprocessing: Forwarded to
            ``build_async_engine_client_from_engine_args``.

    Returns:
        Wall-clock seconds from submission to completion of all requests.
    """
    from vllm import SamplingParams

    async with build_async_engine_client_from_engine_args(
            engine_args, disable_frontend_multiprocessing) as llm:

        # Add the requests to the engine.
        prompts: List[TextPrompt] = []
        sampling_params: List[SamplingParams] = []
        lora_requests: List[Optional[LoRARequest]] = []
        for request in requests:
            prompts.append(
                TextPrompt(prompt=request.prompt,
                           multi_modal_data=request.multi_modal_data))
            sampling_params.append(
                SamplingParams(
                    n=n,
                    temperature=1.0,
                    top_p=1.0,
                    ignore_eos=True,
                    max_tokens=request.expected_output_len,
                ))
            lora_requests.append(request.lora_request)

        generators = []
        start = time.perf_counter()
        for i, (prompt, sp,
                lr) in enumerate(zip(prompts, sampling_params, lora_requests)):
            generator = llm.generate(prompt,
                                     sp,
                                     lora_request=lr,
                                     request_id=f"test{i}")
            generators.append(generator)
        all_gens = merge_async_iterators(*generators)
        # Outputs are discarded; only completion time matters.
        async for _ in all_gens:
            pass
        end = time.perf_counter()
        return end - start


def run_hf(
    requests: List[SampleRequest],
    model: str,
    tokenizer: PreTrainedTokenizerBase,
    n: int,
    max_batch_size: int,
    trust_remote_code: bool,
) -> float:
    """Run the throughput benchmark with HuggingFace ``generate`` on CUDA.

    Consecutive requests are grouped into a batch while the batch has fewer
    than ``max_batch_size`` entries. A batch also stops growing once adding
    the next request would push max prompt length + max output length over
    2048 tokens. Each batch is generated with sampling and then decoded, and
    decoding is included in the measured time.

    Args:
        requests: Requests to benchmark.
        model: HuggingFace model name or path, loaded in float16.
        tokenizer: Tokenizer used to encode batches and decode outputs.
        n: Number of returned sequences per prompt.
        max_batch_size: Upper bound on requests per batch.
        trust_remote_code: Forwarded to ``from_pretrained``.

    Returns:
        Wall-clock seconds spent generating and decoding all batches.
    """
    llm = AutoModelForCausalLM.from_pretrained(
        model, torch_dtype=torch.float16, trust_remote_code=trust_remote_code)
    if llm.config.model_type == "llama":
        # To enable padding in the HF backend.
        tokenizer.pad_token = tokenizer.eos_token
    llm = llm.cuda()

    pbar = tqdm(total=len(requests))
    start = time.perf_counter()
    batch: List[str] = []
    max_prompt_len = 0
    max_output_len = 0
    last_index = len(requests) - 1
    for i, request in enumerate(requests):
        # Add the prompt to the batch. SampleRequest is a dataclass (not a
        # tuple), so read its fields by name.
        batch.append(request.prompt)
        max_prompt_len = max(max_prompt_len, request.prompt_len)
        max_output_len = max(max_output_len, request.expected_output_len)
        if len(batch) < max_batch_size and i != last_index:
            # Check if we can add more requests to the batch.
            next_request = requests[i + 1]
            if (max(max_prompt_len, next_request.prompt_len) +
                    max(max_output_len,
                        next_request.expected_output_len)) <= 2048:
                # We can add more requests to the batch.
                continue

        # Generate the sequences.
        input_ids = tokenizer(batch, return_tensors="pt",
                              padding=True).input_ids
        llm_outputs = llm.generate(
            input_ids=input_ids.cuda(),
            do_sample=True,
            num_return_sequences=n,
            temperature=1.0,
            top_p=1.0,
            use_cache=True,
            max_new_tokens=max_output_len,
        )
        # Include the decoding time.
        tokenizer.batch_decode(llm_outputs, skip_special_tokens=True)
        pbar.update(len(batch))

        # Clear the batch.
        batch = []
        max_prompt_len = 0
        max_output_len = 0
    end = time.perf_counter()
    pbar.close()
    return end - start


def run_mii(
    requests: List[SampleRequest],
    model: str,
    tensor_parallel_size: int,
    output_len: int,
) -> float:
    """Run the throughput benchmark with the ``mii`` backend.

    Starts a server with ``mii.serve`` and times one ``generate`` call over
    all prompts with ``max_new_tokens=output_len``. The server is shut down
    afterwards, even if generation raises.

    Args:
        requests: Requests to benchmark (only ``prompt`` is used).
        model: Model name passed to ``serve`` and ``client``.
        tensor_parallel_size: Forwarded as ``tensor_parallel``.
        output_len: Number of new tokens generated per prompt.

    Returns:
        Wall-clock seconds spent in the timed generation.
    """
    from mii import client, serve
    llm = serve(model, tensor_parallel=tensor_parallel_size)
    prompts = [request.prompt for request in requests]

    try:
        start = time.perf_counter()
        llm.generate(prompts, max_new_tokens=output_len)
        end = time.perf_counter()
    finally:
        # Separate name so the imported ``client`` factory is not shadowed.
        mii_client = client(model)
        mii_client.terminate_server()
    return end - start


def save_to_pytorch_benchmark_format(args: argparse.Namespace,
                                     results: Dict[str, Any]) -> None:
    """Write throughput ``results`` as PyTorch-benchmark-format records.

    The metrics are converted with ``convert_to_pytorch_benchmark_format``.
    If any records are produced, they are written as ``<stem>.pytorch.json``,
    where ``<stem>`` is ``args.output_json`` with its extension removed.

    Args:
        args: Parsed CLI arguments. ``args.output_json`` must be set.
        results: Must contain ``requests_per_second``, ``tokens_per_second``,
            ``elapsed_time``, ``num_requests`` and ``total_num_tokens``.
    """
    pt_records = convert_to_pytorch_benchmark_format(
        args=args,
        metrics={
            "requests_per_second": [results["requests_per_second"]],
            "tokens_per_second": [results["tokens_per_second"]],
        },
        extra_info={
            k: results[k]
            for k in ["elapsed_time", "num_requests", "total_num_tokens"]
        })
    if pt_records:
        # Don't use json suffix here as we don't want CI to pick it up
        pt_file = f"{os.path.splitext(args.output_json)[0]}.pytorch.json"
        with open(pt_file, "w") as f:
            json.dump(pt_records, f)


def _synthesize_random_requests(
        tokenizer: PreTrainedTokenizerBase,
        args: argparse.Namespace) -> List[SampleRequest]:
    """Build ``args.num_prompts`` random-token prompts of ~``args.input_len``."""
    vocab_size = tokenizer.vocab_size
    requests: List[SampleRequest] = []
    for _ in range(args.num_prompts):
        request_tokenizer = tokenizer
        lora_request: Optional[LoRARequest] = None
        if args.enable_lora:
            lora_request, lora_tokenizer = get_random_lora_request(args)
            if lora_tokenizer:
                request_tokenizer = lora_tokenizer

        # Synthesize a prompt with the given input length.
        candidate_ids = [
            random.randint(0, vocab_size - 1) for _ in range(args.input_len)
        ]
        # As tokenizer may add additional tokens like BOS, we need to try
        # different lengths to get the desired input length.
        for _ in range(5):  # Max attempts to correct
            candidate_prompt = request_tokenizer.decode(candidate_ids)
            tokenized_len = len(request_tokenizer.encode(candidate_prompt))
            if tokenized_len == args.input_len:
                break

            # Adjust length based on difference
            diff = args.input_len - tokenized_len
            if diff > 0:
                candidate_ids.extend([
                    random.randint(100, vocab_size - 100)
                    for _ in range(diff)
                ])
            else:
                candidate_ids = candidate_ids[:diff]
        # NOTE: if no exact match was reached within the attempts above, the
        # real token count of candidate_prompt can differ from input_len.
        requests.append(
            SampleRequest(prompt=candidate_prompt,
                          prompt_len=args.input_len,
                          expected_output_len=args.output_len,
                          lora_request=lora_request))
    return requests


def main(args: argparse.Namespace):
    """Sample requests, run the selected backend and report throughput.

    Requests are synthesized from random tokens when ``args.dataset`` is not
    set, and sampled from the dataset otherwise. They are run on
    ``args.backend`` ("vllm", "hf" or "mii"). Latency and throughput are
    printed, and optionally also written to ``args.output_json`` together with
    a PyTorch-benchmark-format copy.

    Raises:
        ValueError: If ``args.backend`` is not a known backend.
    """
    print(args)
    random.seed(args.seed)

    # Sample the requests.
    tokenizer = AutoTokenizer.from_pretrained(
        args.tokenizer, trust_remote_code=args.trust_remote_code)
    if args.dataset is None:
        requests = _synthesize_random_requests(tokenizer, args)
    else:
        requests = sample_requests(tokenizer, args)

    is_multi_modal = any(request.multi_modal_data is not None
                         for request in requests)
    if args.backend == "vllm":
        if args.async_engine:
            elapsed_time = uvloop.run(
                run_vllm_async(
                    requests,
                    args.n,
                    AsyncEngineArgs.from_cli_args(args),
                    args.disable_frontend_multiprocessing,
                ))
        else:
            elapsed_time = run_vllm(requests, args.n, args.num_iters_warmup,
                                    EngineArgs.from_cli_args(args))
    elif args.backend == "hf":
        assert args.tensor_parallel_size == 1
        elapsed_time = run_hf(requests, args.model, tokenizer, args.n,
                              args.hf_max_batch_size, args.trust_remote_code)
    elif args.backend == "mii":
        elapsed_time = run_mii(requests, args.model, args.tensor_parallel_size,
                               args.output_len)
    else:
        raise ValueError(f"Unknown backend: {args.backend}")

    total_num_tokens = sum(request.prompt_len + request.expected_output_len
                           for request in requests)
    total_output_tokens = sum(request.expected_output_len
                              for request in requests)
    if is_multi_modal:
        print("\033[91mWARNING\033[0m: Multi-modal request detected. The "
              "following metrics are not accurate because image tokens are not"
              " counted. See vllm-project/vllm/issues/9778 for details.")
        # TODO(vllm-project/vllm/issues/9778): Count multi-modal token length.
    print(f"Latency: {elapsed_time:.2f} s")
    print(f"Throughput: {len(requests) / elapsed_time:.2f} requests/s, "
          f"{total_num_tokens / elapsed_time:.2f} total tokens/s, "
          f"{total_output_tokens / elapsed_time:.2f} output tokens/s")

    # Output JSON results if specified
    if args.output_json:
        results = {
            "elapsed_time": elapsed_time,
            "num_requests": len(requests),
            "total_num_tokens": total_num_tokens,
            "requests_per_second": len(requests) / elapsed_time,
            "tokens_per_second": total_num_tokens / elapsed_time,
        }
        with open(args.output_json, "w") as f:
            json.dump(results, f, indent=4)
        save_to_pytorch_benchmark_format(args, results)


if __name__ == "__main__":
    parser = FlexibleArgumentParser(description="Benchmark the throughput.")
    parser.add_argument("--backend",
                        type=str,
                        choices=["vllm", "hf", "mii"],
                        default="vllm")
    parser.add_argument("--dataset",
                        type=str,
                        default=None,
                        help="Path to the dataset. The dataset is expected to "
                        "be a json in form of List[Dict[..., conversations: "
                        "List[Dict[..., value: <prompt_or_response>]]]]")
    parser.add_argument("--input-len",
                        type=int,
                        default=None,
                        help="Input prompt length for each request")
    parser.add_argument("--output-len",
                        type=int,
                        default=None,
                        help="Output length for each request. Overrides the "
                        "output length from the dataset.")
    parser.add_argument("--n",
                        type=int,
                        default=1,
                        help="Number of generated sequences per prompt.")
    parser.add_argument('--num-iters-warmup',
                        type=int,
                        default=1,
                        help='Number of iterations to run for warmup.')
    parser.add_argument("--num-prompts",
                        type=int,
                        default=1000,
                        help="Number of prompts to process.")
    parser.add_argument("--hf-max-batch-size",
                        type=int,
                        default=None,
                        help="Maximum batch size for HF backend.")
    parser.add_argument(
        '--profile',
        action='store_true',
        help='profile the generation process of a single batch')
    parser.add_argument(
        '--profile-result-dir',
        type=str,
        default=None,
        help=('path to save the pytorch profiler output. Can be visualized '
              'with ui.perfetto.dev or Tensorboard.'))
    parser.add_argument(
        '--output-json',
        type=str,
        default=None,
        help='Path to save the throughput results in JSON format.')
    parser.add_argument("--async-engine",
                        action='store_true',
                        default=False,
                        help="Use vLLM async engine rather than LLM class.")
    parser.add_argument("--disable-frontend-multiprocessing",
                        action='store_true',
                        default=False,
                        help="Disable decoupled async engine frontend.")
    # LoRA
    parser.add_argument(
        "--lora-path",
        type=str,
        default=None,
        help="Path to the lora adapters to use. This can be an absolute path, "
        "a relative path, or a Hugging Face model identifier.")

    parser = AsyncEngineArgs.add_cli_args(parser)
    args = parser.parse_args()
    if args.tokenizer is None:
        args.tokenizer = args.model

    # Validate CLI combinations with parser.error instead of assert, so the
    # checks survive `python -O` and print a usage message.
    if args.dataset is None:
        if args.input_len is None or args.output_len is None:
            parser.error("--input-len and --output-len are required when "
                         "--dataset is not given.")
    elif args.input_len is not None:
        parser.error("--input-len cannot be used together with --dataset.")
    if args.enable_lora and args.lora_path is None:
        parser.error("--lora-path is required when --enable-lora is set.")

    # `enable_lora` is a boolean flag, so test its truthiness (comparing it
    # with None would always be true and reject every hf/mii run).
    if args.backend == "vllm":
        if args.hf_max_batch_size is not None:
            raise ValueError("HF max batch size is only for HF backend.")
    elif args.backend == "hf":
        if args.hf_max_batch_size is None:
            raise ValueError("HF max batch size is required for HF backend.")
        if args.quantization is not None:
            raise ValueError("Quantization is only for vLLM backend.")
        if args.enable_lora:
            raise ValueError("LoRA benchmarking is only supported for vLLM"
                             " backend")
    elif args.backend == "mii":
        if args.dtype != "auto":
            raise ValueError("dtype must be auto for MII backend.")
        if args.n != 1:
            raise ValueError("n must be 1 for MII backend.")
        if args.quantization is not None:
            raise ValueError("Quantization is only for vLLM backend.")
        if args.hf_max_batch_size is not None:
            raise ValueError("HF max batch size is only for HF backend.")
        if args.tokenizer != args.model:
            raise ValueError("Tokenizer must be the same as the model for MII "
                             "backend.")
        if args.enable_lora:
            raise ValueError("LoRA benchmarking is only supported for vLLM"
                             " backend")
    main(args)