benchmark_throughput.py 28.3 KB
Newer Older
zhuwenwen's avatar
zhuwenwen committed
1
# SPDX-License-Identifier: Apache-2.0
zhuwenwen's avatar
zhuwenwen committed
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
"""Benchmark offline inference throughput."""
zhuwenwen's avatar
zhuwenwen committed
4

5
import argparse
zhuwenwen's avatar
zhuwenwen committed
6
import dataclasses
7
import json
zhuwenwen's avatar
zhuwenwen committed
8
import os
9
10
import random
import time
zhuwenwen's avatar
zhuwenwen committed
11

zhuwenwen's avatar
zhuwenwen committed
12
from pathlib import Path
zhuwenwen's avatar
zhuwenwen committed
13
14
import warnings
from typing import Any, Optional, Union
15
16
17

import numpy as np
import torch
zhuwenwen's avatar
zhuwenwen committed
18
import uvloop
19
from tqdm import tqdm
zhuwenwen's avatar
zhuwenwen committed
20
21

from vllm.inputs import PromptType
zhuwenwen's avatar
zhuwenwen committed
22
from transformers import AutoModelForCausalLM, AutoTokenizer, PreTrainedTokenizerBase
23
from typing_extensions import deprecated
zhuwenwen's avatar
zhuwenwen committed
24
25
26
27
28
29
30
31
32
33
34
35
36

from benchmark_dataset import (
    AIMODataset,
    BurstGPTDataset,
    ConversationDataset,
    InstructCoderDataset,
    RandomDataset,
    SampleRequest,
    ShareGPTDataset,
    SonnetDataset,
    VisionArenaDataset,
)
from benchmark_utils import convert_to_pytorch_benchmark_format, write_to_json
zhuwenwen's avatar
zhuwenwen committed
37
from vllm.engine.arg_utils import AsyncEngineArgs, EngineArgs
zhuwenwen's avatar
zhuwenwen committed
38
from vllm.entrypoints.openai.api_server import (
zhuwenwen's avatar
zhuwenwen committed
39
40
    build_async_engine_client_from_engine_args,
)
zhuwenwen's avatar
zhuwenwen committed
41
from vllm.inputs import TextPrompt, TokensPrompt
42
from vllm.lora.request import LoRARequest
zhuwenwen's avatar
zhuwenwen committed
43
from vllm.outputs import RequestOutput
zhuwenwen's avatar
zhuwenwen committed
44
from vllm.sampling_params import BeamSearchParams
zhuwenwen's avatar
zhuwenwen committed
45
from vllm.utils import FlexibleArgumentParser, merge_async_iterators
46
47
48


def run_vllm(
    requests: list[SampleRequest],
    n: int,
    num_iters_warmup: int,
    engine_args: EngineArgs,
    disable_detokenize: bool = False,
    profile: Optional[bool] = None,
    profile_result_dir: Optional[Union[str, Path]] = None,
) -> tuple[float, Optional[list[RequestOutput]]]:
    """Run the synchronous vLLM benchmark and time the generation call.

    Args:
        requests: Sampled benchmark requests; each provides ``prompt``,
            ``prompt_len``, ``expected_output_len``, ``multi_modal_data``
            and ``lora_request``.
        n: Number of sequences to generate per prompt (beam width when
            beam search is enabled).
        num_iters_warmup: Number of untimed warmup iterations to run
            before the measured generation.
        engine_args: Engine configuration; expanded into ``LLM(...)``.
        disable_detokenize: When True, sampling skips detokenization.
        profile: Whether to wrap the timed run in ``torch.profiler``.
            ``None`` (default) falls back to ``args.profile`` of the
            module-level ``args`` namespace when this file runs as a
            script, and to ``False`` when no such namespace exists.
        profile_result_dir: Directory for profiler traces. ``None`` falls
            back to ``args.profile_result_dir`` and then to a timestamped
            directory under ``./vllm_benchmark_result``.

    Returns:
        ``(elapsed_seconds, outputs)``. ``outputs`` is the list returned by
        ``llm.generate`` and stays ``None`` on the beam-search path.

    Raises:
        AssertionError: If any request needs more tokens than the model's
            ``max_model_len``, or if LoRA is combined with beam search.
    """
    from vllm import LLM, SamplingParams

    # The original code read a module-level ``args`` directly, which only
    # exists when this file is executed as a script. Resolve the profiling
    # options explicitly so the function also works when imported, while
    # keeping the script behaviour unchanged.
    cli_args = globals().get("args")
    if profile is None:
        profile = bool(getattr(cli_args, "profile", False))
    if profile_result_dir is None:
        profile_result_dir = getattr(cli_args, "profile_result_dir", None)

    llm = LLM(**dataclasses.asdict(engine_args))
    assert all(
        llm.llm_engine.model_config.max_model_len
        >= (request.prompt_len + request.expected_output_len)
        for request in requests
    ), (
        "Please ensure that max_model_len is greater than the sum of"
        " prompt_len and expected_output_len for all requests."
    )

    # Add the requests to the engine.
    prompts: list[Union[TextPrompt, TokensPrompt]] = []
    sampling_params: list[SamplingParams] = []
    for request in requests:
        prompts.append(
            TokensPrompt(
                prompt_token_ids=request.prompt["prompt_token_ids"],
                multi_modal_data=request.multi_modal_data,
            )
            if "prompt_token_ids" in request.prompt
            else TextPrompt(
                prompt=request.prompt, multi_modal_data=request.multi_modal_data
            )
        )
        sampling_params.append(
            SamplingParams(
                n=n,
                temperature=1.0,
                top_p=1.0,
                ignore_eos=True,
                max_tokens=request.expected_output_len,
                detokenize=not disable_detokenize,
            )
        )
    lora_requests: Optional[list[LoRARequest]] = None
    if engine_args.enable_lora:
        lora_requests = [request.lora_request for request in requests]

    # Warmup: a single dummy prompt of 10 random token ids, generating a
    # short fixed number of tokens so warmup stays cheap.
    warmup_max_tokens = 10
    warmup_sampling_params = SamplingParams(
        n=n,
        temperature=1.0,
        top_p=1.0,
        ignore_eos=True,
        max_tokens=warmup_max_tokens,
    )
    dummy_prompt_token_ids = np.random.randint(10000, size=(1, 10))
    dummy_prompts: list[PromptType] = [
        {"prompt_token_ids": batch} for batch in dummy_prompt_token_ids.tolist()
    ]

    # Beam search is currently disabled for this benchmark; the branches
    # are kept so it can be switched back on in one place.
    use_beam_search = False

    print("Warming up...")
    for _ in tqdm(range(num_iters_warmup), desc="Warmup iterations"):
        if not use_beam_search:
            llm.generate(
                dummy_prompts,
                sampling_params=warmup_sampling_params,
                use_tqdm=False,
            )
        else:
            # Use the function's own ``n`` and the warmup token budget
            # rather than reaching into the CLI namespace.
            llm.beam_search(
                dummy_prompts,
                BeamSearchParams(
                    beam_width=n,
                    max_tokens=warmup_max_tokens,
                    ignore_eos=True,
                ),
            )

    outputs = None
    if not use_beam_search:
        if profile:
            profile_dir = profile_result_dir
            if not profile_dir:
                profile_dir = (
                    Path(".")
                    / "vllm_benchmark_result"
                    / f"latency_result_{time.time()}"
                )
            print(f"Profiling (results will be saved to '{profile_dir}')...")
            with torch.profiler.profile(
                activities=[
                    torch.profiler.ProfilerActivity.CPU,
                    torch.profiler.ProfilerActivity.CUDA,
                ],
                record_shapes=True,
                on_trace_ready=torch.profiler.tensorboard_trace_handler(
                    str(profile_dir)
                ),
            ) as prof:
                # Timing happens inside the profiler context, so the
                # reported latency includes profiler overhead.
                start = time.perf_counter()
                outputs = llm.generate(
                    prompts, sampling_params, lora_request=lora_requests, use_tqdm=True
                )
                end = time.perf_counter()
            print('Prepare time report')
            print(
                prof.key_averages(group_by_input_shape=True).table(
                    sort_by="self_cuda_time_total", row_limit=-1
                )
            )
        else:
            start = time.perf_counter()
            outputs = llm.generate(
                prompts, sampling_params, lora_request=lora_requests, use_tqdm=True
            )
            end = time.perf_counter()
    else:
        assert lora_requests is None, "BeamSearch API does not support LoRA"
        # output_len should be the same for all requests.
        output_len = requests[0].expected_output_len
        for request in requests:
            assert request.expected_output_len == output_len
        start = time.perf_counter()
        llm.beam_search(
            prompts,
            BeamSearchParams(
                beam_width=n,
                max_tokens=output_len,
                ignore_eos=True,
            ),
        )
        end = time.perf_counter()
    return end - start, outputs


def run_vllm_chat(
    requests: list[SampleRequest],
    n: int,
    engine_args: EngineArgs,
    disable_detokenize: bool = False,
) -> tuple[float, list[RequestOutput]]:
    """
    Time a batch of requests through ``LLM.chat``.

    Recommended ONLY for multimodal models, because the chat entry point
    handles multimodal inputs and chat formatting. For non-multimodal
    models, use run_vllm() instead.

    Returns the elapsed wall-clock seconds of the ``llm.chat`` call together
    with the outputs it produced.
    """
    from vllm import LLM, SamplingParams

    llm = LLM(**dataclasses.asdict(engine_args))

    # Every request must fit into the model's context window.
    context_limit = llm.llm_engine.model_config.max_model_len
    all_fit = all(
        request.prompt_len + request.expected_output_len <= context_limit
        for request in requests
    )
    assert all_fit, (
        "Please ensure that max_model_len is greater than the sum of "
        "prompt_len and expected_output_len for all requests."
    )

    # One prompt and one SamplingParams per request, in request order.
    prompts = [request.prompt for request in requests]
    sampling_params: list[SamplingParams] = [
        SamplingParams(
            n=n,
            temperature=1.0,
            top_p=1.0,
            ignore_eos=True,
            max_tokens=request.expected_output_len,
            detokenize=not disable_detokenize,
        )
        for request in requests
    ]

    started_at = time.perf_counter()
    outputs = llm.chat(prompts, sampling_params, use_tqdm=True)
    elapsed = time.perf_counter() - started_at
    return elapsed, outputs
212
213


zhuwenwen's avatar
zhuwenwen committed
214
async def run_vllm_async(
    requests: list[SampleRequest],
    n: int,
    engine_args: AsyncEngineArgs,
    disable_frontend_multiprocessing: bool = False,
    disable_detokenize: bool = False,
) -> float:
    """Time the requests through the async engine client.

    All requests are submitted up front (request ids ``test0``, ``test1``,
    ...) and their result streams are drained through
    ``merge_async_iterators``. Only the elapsed wall-clock seconds are
    returned; the generated outputs are discarded.
    """
    from vllm import SamplingParams

    def to_engine_prompt(request: SampleRequest) -> Union[TextPrompt, TokensPrompt]:
        # Pre-tokenized prompts carry a "prompt_token_ids" entry; anything
        # else is passed through as a text prompt.
        if "prompt_token_ids" in request.prompt:
            return TokensPrompt(
                prompt_token_ids=request.prompt["prompt_token_ids"],
                multi_modal_data=request.multi_modal_data,
            )
        return TextPrompt(
            prompt=request.prompt, multi_modal_data=request.multi_modal_data
        )

    async with build_async_engine_client_from_engine_args(
        engine_args,
        disable_frontend_multiprocessing=disable_frontend_multiprocessing,
    ) as llm:
        model_config = await llm.get_model_config()
        context_limit = model_config.max_model_len
        all_fit = all(
            request.prompt_len + request.expected_output_len <= context_limit
            for request in requests
        )
        assert all_fit, (
            "Please ensure that max_model_len is greater than the sum of"
            " prompt_len and expected_output_len for all requests."
        )

        # Build the per-request engine inputs.
        prompts: list[Union[TextPrompt, TokensPrompt]] = [
            to_engine_prompt(request) for request in requests
        ]
        sampling_params: list[SamplingParams] = [
            SamplingParams(
                n=n,
                temperature=1.0,
                top_p=1.0,
                ignore_eos=True,
                max_tokens=request.expected_output_len,
                detokenize=not disable_detokenize,
            )
            for request in requests
        ]
        lora_requests: list[Optional[LoRARequest]] = [
            request.lora_request for request in requests
        ]

        # The clock starts before submission so request creation is timed.
        started_at = time.perf_counter()
        generators = [
            llm.generate(prompt, sp, lora_request=lr, request_id=f"test{idx}")
            for idx, (prompt, sp, lr) in enumerate(
                zip(prompts, sampling_params, lora_requests)
            )
        ]
        merged = merge_async_iterators(*generators)
        async for _idx, _result in merged:
            # Results are only drained; nothing is kept.
            pass
        finished_at = time.perf_counter()
        return finished_at - started_at


278
def run_hf(
    requests: list[SampleRequest],
    model: str,
    tokenizer: PreTrainedTokenizerBase,
    n: int,
    max_batch_size: int,
    trust_remote_code: bool,
    disable_detokenize: bool = False,
) -> float:
    """Time generation with a Hugging Face ``AutoModelForCausalLM`` on CUDA.

    Requests are greedily packed into batches: a batch grows until it holds
    ``max_batch_size`` prompts or until adding the next request would push
    (longest prompt + longest output) beyond 2048 tokens. Returns the
    elapsed wall-clock seconds, which includes tokenization and, unless
    ``disable_detokenize`` is set, decoding.
    """
    llm = AutoModelForCausalLM.from_pretrained(
        model, torch_dtype=torch.float16, trust_remote_code=trust_remote_code
    )
    if llm.config.model_type == "llama":
        # To enable padding in the HF backend.
        tokenizer.pad_token = tokenizer.eos_token
    llm = llm.cuda()

    total_requests = len(requests)
    pbar = tqdm(total=total_requests)
    start = time.perf_counter()

    batch: list[str] = []
    max_prompt_len = 0
    max_output_len = 0
    for idx, request in enumerate(requests):
        # Add the prompt to the batch and track the padded dimensions.
        batch.append(request.prompt)
        max_prompt_len = max(max_prompt_len, request.prompt_len)
        max_output_len = max(max_output_len, request.expected_output_len)

        has_next = idx != total_requests - 1
        if has_next and len(batch) < max_batch_size:
            # Peek at the next request to see whether it still fits.
            upcoming = requests[idx + 1]
            padded_total = max(max_prompt_len, upcoming.prompt_len) + max(
                max_output_len, upcoming.expected_output_len
            )
            if padded_total <= 2048:
                # We can add more requests to the batch.
                continue

        # Generate the sequences for the current batch.
        input_ids = tokenizer(batch, return_tensors="pt", padding=True).input_ids
        llm_outputs = llm.generate(
            input_ids=input_ids.cuda(),
            do_sample=True,
            num_return_sequences=n,
            temperature=1.0,
            top_p=1.0,
            use_cache=True,
            max_new_tokens=max_output_len,
        )
        if not disable_detokenize:
            # Include the decoding time.
            tokenizer.batch_decode(llm_outputs, skip_special_tokens=True)
        pbar.update(len(batch))

        # Start a fresh batch.
        batch = []
        max_prompt_len = 0
        max_output_len = 0

    end = time.perf_counter()
    return end - start


def run_mii(
    requests: list[SampleRequest],
    model: str,
    tensor_parallel_size: int,
    output_len: int,
) -> float:
    """Time generation through a DeepSpeed-MII server.

    Starts a server for ``model``, times a single ``generate`` call over all
    prompts, then terminates the server. Returns the elapsed wall-clock
    seconds of the ``generate`` call only.
    """
    from mii import client as make_client
    from mii import serve

    server = serve(model, tensor_parallel=tensor_parallel_size)
    prompts = [request.prompt for request in requests]

    started_at = time.perf_counter()
    server.generate(prompts, max_new_tokens=output_len)
    elapsed = time.perf_counter() - started_at

    # Shut the server down after timing so teardown is not measured.
    make_client(model).terminate_server()
    return elapsed


zhuwenwen's avatar
zhuwenwen committed
362
363
364
def save_to_pytorch_benchmark_format(
    args: argparse.Namespace, results: dict[str, Any]
) -> None:
    """Write the results as PyTorch-benchmark records next to the JSON output.

    The records are built by ``convert_to_pytorch_benchmark_format``; if it
    returns nothing, no file is written. The target path is
    ``args.output_json`` with its extension replaced by ``.pytorch.json``.
    """
    throughput_metrics = {
        "requests_per_second": [results["requests_per_second"]],
        "tokens_per_second": [results["tokens_per_second"]],
    }
    run_info = {}
    for key in ("elapsed_time", "num_requests", "total_num_tokens"):
        run_info[key] = results[key]

    pt_records = convert_to_pytorch_benchmark_format(
        args=args,
        metrics=throughput_metrics,
        extra_info=run_info,
    )
    if not pt_records:
        return

    # Don't use json suffix here as we don't want CI to pick it up
    base_path, _extension = os.path.splitext(args.output_json)
    write_to_json(f"{base_path}.pytorch.json", pt_records)


def get_requests(args, tokenizer):
    """Pick the dataset class for the CLI arguments and sample requests.

    Args:
        args: Parsed command-line namespace. Reads ``dataset_path``,
            ``dataset_name``, ``seed``, ``lora_path``, ``max_loras``,
            ``num_prompts``, ``input_len``, ``output_len`` and, depending on
            the dataset, ``random_range_ratio``, ``prefix_len``, ``backend``,
            ``no_stream``, ``hf_subset`` and ``hf_split``.
        tokenizer: Tokenizer forwarded to the dataset's ``sample`` call.

    Returns:
        Whatever ``dataset_cls(**common_kwargs).sample(**sample_kwargs)``
        returns for the selected dataset.

    Raises:
        ValueError: If the dataset name is unknown, or if ``dataset_name``
            is ``"hf"`` and ``dataset_path`` matches none of the supported
            Hugging Face datasets.
        AssertionError: If the sonnet dataset is requested with a tokenizer
            that has no chat template.
    """
    # Common parameters for all dataset types.
    common_kwargs = {
        "dataset_path": args.dataset_path,
        "random_seed": args.seed,
    }
    sample_kwargs = {
        "tokenizer": tokenizer,
        "lora_path": args.lora_path,
        "max_loras": args.max_loras,
        "num_requests": args.num_prompts,
        "input_len": args.input_len,
        "output_len": args.output_len,
    }

    if args.dataset_path is None or args.dataset_name == "random":
        sample_kwargs["range_ratio"] = args.random_range_ratio
        sample_kwargs["prefix_len"] = args.prefix_len
        dataset_cls = RandomDataset
    elif args.dataset_name == "sharegpt":
        dataset_cls = ShareGPTDataset
        if args.backend == "vllm-chat":
            sample_kwargs["enable_multimodal_chat"] = True
    elif args.dataset_name == "sonnet":
        # ``default_chat_template`` is absent on some tokenizer versions;
        # look it up defensively so a missing attribute produces the
        # assertion message below instead of an AttributeError.
        has_chat_template = bool(
            tokenizer.chat_template
            or getattr(tokenizer, "default_chat_template", None)
        )
        assert has_chat_template, (
            "Tokenizer/model must have chat template for sonnet dataset."
        )
        dataset_cls = SonnetDataset
        sample_kwargs["prefix_len"] = args.prefix_len
        sample_kwargs["return_prompt_formatted"] = True
    elif args.dataset_name == "burstgpt":
        dataset_cls = BurstGPTDataset
    elif args.dataset_name == "hf":
        common_kwargs["no_stream"] = args.no_stream
        if args.dataset_path in VisionArenaDataset.SUPPORTED_DATASET_PATHS:
            dataset_cls = VisionArenaDataset
            common_kwargs["dataset_subset"] = None
            common_kwargs["dataset_split"] = "train"
            sample_kwargs["enable_multimodal_chat"] = True
        elif args.dataset_path in InstructCoderDataset.SUPPORTED_DATASET_PATHS:
            dataset_cls = InstructCoderDataset
            common_kwargs["dataset_split"] = "train"
        elif args.dataset_path in ConversationDataset.SUPPORTED_DATASET_PATHS:
            dataset_cls = ConversationDataset
            common_kwargs["dataset_subset"] = args.hf_subset
            common_kwargs["dataset_split"] = args.hf_split
            sample_kwargs["enable_multimodal_chat"] = True
        elif args.dataset_path in AIMODataset.SUPPORTED_DATASET_PATHS:
            dataset_cls = AIMODataset
            common_kwargs["dataset_subset"] = None
            common_kwargs["dataset_split"] = "train"
        else:
            # Without this branch ``dataset_cls`` would be unbound and the
            # return below would fail with an unrelated UnboundLocalError.
            raise ValueError(f"{args.dataset_path} is not supported by hf dataset.")
    else:
        raise ValueError(f"Unknown dataset name: {args.dataset_name}")

    # Remove None values so each dataset's own defaults apply.
    sample_kwargs = {k: v for k, v in sample_kwargs.items() if v is not None}
    return dataset_cls(**common_kwargs).sample(**sample_kwargs)
zhuwenwen's avatar
zhuwenwen committed
437
438


439
440
441
442
@deprecated(
    "benchmark_throughput.py is deprecated and will be removed in a "
    "future version. Please use 'vllm bench throughput' instead.",
)
def main(args: argparse.Namespace):
    """Sample requests, run the selected backend and report throughput.

    Prints latency and throughput figures, and when ``args.output_json`` is
    set also writes them to that file and to the PyTorch-benchmark
    companion file.

    Raises:
        ValueError: If ``args.backend`` is not a known backend.
    """
    if args.seed is None:
        args.seed = 0
    print(args)
    random.seed(args.seed)

    # Sample the requests.
    tokenizer = AutoTokenizer.from_pretrained(
        args.tokenizer, trust_remote_code=args.trust_remote_code
    )
    requests = get_requests(args, tokenizer)
    is_multi_modal = any(request.multi_modal_data is not None for request in requests)

    # Dispatch to the backend. Only the synchronous vLLM paths return
    # request outputs.
    request_outputs: Optional[list[RequestOutput]] = None
    backend = args.backend
    if backend == "vllm" and args.async_engine:
        elapsed_time = uvloop.run(
            run_vllm_async(
                requests,
                args.n,
                AsyncEngineArgs.from_cli_args(args),
                args.disable_frontend_multiprocessing,
                args.disable_detokenize,
            )
        )
    elif backend == "vllm":
        elapsed_time, request_outputs = run_vllm(
            requests,
            args.n,
            args.num_iters_warmup,
            EngineArgs.from_cli_args(args),
            args.disable_detokenize,
        )
    elif backend == "hf":
        assert args.tensor_parallel_size == 1
        elapsed_time = run_hf(
            requests,
            args.model,
            tokenizer,
            args.n,
            args.hf_max_batch_size,
            args.trust_remote_code,
            args.disable_detokenize,
        )
    elif backend == "mii":
        elapsed_time = run_mii(
            requests, args.model, args.tensor_parallel_size, args.output_len
        )
    elif backend == "vllm-chat":
        elapsed_time, request_outputs = run_vllm_chat(
            requests, args.n, EngineArgs.from_cli_args(args), args.disable_detokenize
        )
    else:
        raise ValueError(f"Unknown backend: {args.backend}")

    if request_outputs:
        # Note: with the vllm and vllm-chat backends,
        # we have request_outputs, which we use to count tokens.
        completed = [ro for ro in request_outputs if isinstance(ro, RequestOutput)]
        total_prompt_tokens = sum(
            len(ro.prompt_token_ids) for ro in completed if ro.prompt_token_ids
        )
        total_output_tokens = sum(
            len(o.token_ids) for ro in completed for o in ro.outputs if o
        )
        total_num_tokens = total_prompt_tokens + total_output_tokens
    else:
        # No outputs available: fall back to the lengths the requests
        # were sampled with.
        total_prompt_tokens = sum(r.prompt_len for r in requests)
        total_output_tokens = sum(r.expected_output_len for r in requests)
        total_num_tokens = total_prompt_tokens + total_output_tokens

    if is_multi_modal and args.backend != "vllm-chat":
        print(
            "\033[91mWARNING\033[0m: Multi-modal request with "
            f"{args.backend} backend detected. The "
            "following metrics are not accurate because image tokens are not"
            " counted. See vllm-project/vllm/issues/9778 for details."
        )
        # TODO(vllm-project/vllm/issues/9778): Count multi-modal token length.
        # vllm-chat backend counts the image tokens now

    print(f"Latency: {elapsed_time:.2f} s")
    num_requests = len(requests)
    print(
        f"Throughput: {num_requests / elapsed_time:.2f} requests/s, "
        f"{total_num_tokens / elapsed_time:.2f} total tokens/s, "
        f"{total_output_tokens / elapsed_time:.2f} output tokens/s"
    )
    print(f"Total num prompt tokens:  {total_prompt_tokens}")
    print(f"Total num output tokens:  {total_output_tokens}")

    # Output JSON results if specified
    if not args.output_json:
        return
    results = {
        "elapsed_time": elapsed_time,
        "num_requests": num_requests,
        "total_num_tokens": total_num_tokens,
        "requests_per_second": num_requests / elapsed_time,
        "tokens_per_second": total_num_tokens / elapsed_time,
    }
    with open(args.output_json, "w") as f:
        json.dump(results, f, indent=4)
    save_to_pytorch_benchmark_format(args, results)
545
546


zhuwenwen's avatar
zhuwenwen committed
547
548
549
550
551
552
553
554
555
556
def validate_args(args):
    """
    Validate command-line arguments.

    Fills in defaults on ``args`` in place (``dataset_path`` from the
    deprecated ``--dataset``, ``tokenizer`` from ``model``, and
    ``dataset_name = "random"`` when no dataset path is given), warns about
    options that will be ignored for the chosen dataset, and rejects
    unsupported combinations.

    Raises:
        ValueError: For an unsupported backend, a random dataset without
            ``input_len``, an unsupported hf dataset path, invalid LoRA,
            HF or MII settings, or ``data_parallel_size > 1``.
        AssertionError: If an hf dataset is paired with the wrong backend.
    """

    # === Deprecation and Defaulting ===
    if args.dataset is not None:
        warnings.warn(
            "The '--dataset' argument will be deprecated in the next release. "
            "Please use '--dataset-name' and '--dataset-path' instead.",
            stacklevel=2,
        )
        args.dataset_path = args.dataset

    if not getattr(args, "tokenizer", None):
        args.tokenizer = args.model

    # === Backend Validation ===
    valid_backends = {"vllm", "hf", "mii", "vllm-chat"}
    if args.backend not in valid_backends:
        raise ValueError(f"Unsupported backend: {args.backend}")

    # === Dataset Configuration ===
    if not args.dataset and not args.dataset_path:
        print("When dataset path is not set, it will default to random dataset")
        args.dataset_name = "random"
        if args.input_len is None:
            raise ValueError("input_len must be provided for a random dataset")

    # === Dataset Name Specific Checks ===
    # --hf-subset and --hf-split: only used
    # when dataset_name is 'hf'
    # NOTE: the warning texts below use implicit string concatenation; a
    # backslash continuation inside the literal would embed the source
    # indentation in the message.
    if args.dataset_name != "hf" and (
        getattr(args, "hf_subset", None) is not None
        or getattr(args, "hf_split", None) is not None
    ):
        warnings.warn(
            "--hf-subset and --hf-split will be ignored "
            "since --dataset-name is not 'hf'.",
            stacklevel=2,
        )
    elif args.dataset_name == "hf":
        if args.dataset_path in (
            VisionArenaDataset.SUPPORTED_DATASET_PATHS.keys()
            | ConversationDataset.SUPPORTED_DATASET_PATHS
        ):
            assert args.backend == "vllm-chat", (
                f"{args.dataset_path} needs to use vllm-chat as the backend."
            )  # noqa: E501
        elif args.dataset_path in (
            InstructCoderDataset.SUPPORTED_DATASET_PATHS
            | AIMODataset.SUPPORTED_DATASET_PATHS
        ):
            assert args.backend == "vllm", (
                f"{args.dataset_path} needs to use vllm as the backend."
            )  # noqa: E501
        else:
            raise ValueError(f"{args.dataset_path} is not supported by hf dataset.")

    # --random-range-ratio: only used when dataset_name is 'random'
    if args.dataset_name != "random" and args.random_range_ratio is not None:
        warnings.warn(
            "--random-range-ratio will be ignored since "
            "--dataset-name is not 'random'.",
            stacklevel=2,
        )

    # --prefix-len: only used when dataset_name is 'random', 'sonnet', or not
    # set.
    if (
        args.dataset_name not in {"random", "sonnet", None}
        and args.prefix_len is not None
    ):
        warnings.warn(
            "--prefix-len will be ignored since --dataset-name "
            "is not 'random', 'sonnet', or not set.",
            stacklevel=2,
        )

    # === LoRA Settings ===
    if getattr(args, "enable_lora", False) and args.backend != "vllm":
        raise ValueError("LoRA benchmarking is only supported for vLLM backend")
    if getattr(args, "enable_lora", False) and args.lora_path is None:
        raise ValueError("LoRA path must be provided when enable_lora is True")

    # === Backend-specific Validations ===
    if args.backend == "hf" and args.hf_max_batch_size is None:
        raise ValueError("HF max batch size is required for HF backend")
    if args.backend != "hf" and args.hf_max_batch_size is not None:
        raise ValueError("HF max batch size is only for HF backend.")

    if (
        args.backend in {"hf", "mii"}
        and getattr(args, "quantization", None) is not None
    ):
        raise ValueError("Quantization is only for vLLM backend.")

    if args.backend == "mii" and args.dtype != "auto":
        raise ValueError("dtype must be auto for MII backend.")
    if args.backend == "mii" and args.n != 1:
        raise ValueError("n must be 1 for MII backend.")
    if args.backend == "mii" and args.tokenizer != args.model:
        raise ValueError("Tokenizer must be the same as the model for MII backend.")

    # --data-parallel is not supported currently.
    # https://github.com/vllm-project/vllm/issues/16222
    if args.data_parallel_size > 1:
        raise ValueError(
            "Data parallel is not supported in offline benchmark, "
            "please use benchmark serving instead"
        )
zhuwenwen's avatar
zhuwenwen committed
658
659


zhuwenwen's avatar
zhuwenwen committed
660
def create_argument_parser():
    """Build the command-line parser for the throughput benchmark.

    Registers the benchmark-specific options (backend, dataset selection,
    input/output lengths, warmup, profiling, result output, LoRA and
    dataset-specific knobs) and then lets ``AsyncEngineArgs.add_cli_args``
    append the engine options to the same parser.

    Returns:
        The parser returned by ``AsyncEngineArgs.add_cli_args``.
    """
    parser = FlexibleArgumentParser(description="Benchmark the throughput.")
    parser.add_argument(
        "--backend",
        type=str,
        choices=["vllm", "hf", "mii", "vllm-chat"],
        default="vllm",
    )
    parser.add_argument(
        "--dataset-name",
        type=str,
        choices=["sharegpt", "random", "sonnet", "burstgpt", "hf"],
        help="Name of the dataset to benchmark on.",
        default="sharegpt",
    )
    parser.add_argument(
        "--no-stream",
        action="store_true",
        help="Do not load the dataset in streaming mode.",
    )
    parser.add_argument(
        "--dataset",
        type=str,
        default=None,
        # Plain adjacent literals only: a backslash-newline inside a string
        # literal would splice the next line's indentation into the help text.
        help="Path to the ShareGPT dataset, will be deprecated in "
        "the next release. The dataset is expected to "
        "be a json in form of list[dict[..., conversations: "
        "list[dict[..., value: <prompt_or_response>]]]]",
    )
    parser.add_argument(
        "--dataset-path", type=str, default=None, help="Path to the dataset"
    )
    parser.add_argument(
        "--input-len",
        type=int,
        default=None,
        help="Input prompt length for each request",
    )
    parser.add_argument(
        "--output-len",
        type=int,
        default=None,
        help="Output length for each request. Overrides the "
        "output length from the dataset.",
    )
    parser.add_argument(
        "--n", type=int, default=1, help="Number of generated sequences per prompt."
    )
    parser.add_argument(
        "--num-iters-warmup",
        type=int,
        default=1,
        help="Number of iterations to run for warmup.",
    )
    parser.add_argument(
        "--num-prompts", type=int, default=1000, help="Number of prompts to process."
    )
    parser.add_argument(
        "--profile",
        action="store_true",
        help="profile the generation process of a single batch",
    )
    parser.add_argument(
        "--profile-result-dir",
        type=str,
        default=None,
        help=(
            "path to save the pytorch profiler output. Can be visualized "
            "with ui.perfetto.dev or Tensorboard."
        ),
    )
    parser.add_argument(
        "--hf-max-batch-size",
        type=int,
        default=None,
        help="Maximum batch size for HF backend.",
    )
    parser.add_argument(
        "--output-json",
        type=str,
        default=None,
        help="Path to save the throughput results in JSON format.",
    )
    parser.add_argument(
        "--async-engine",
        action="store_true",
        default=False,
        help="Use vLLM async engine rather than LLM class.",
    )
    parser.add_argument(
        "--disable-frontend-multiprocessing",
        action="store_true",
        default=False,
        help="Disable decoupled async engine frontend.",
    )
    parser.add_argument(
        "--disable-detokenize",
        action="store_true",
        help=(
            "Do not detokenize the response (i.e. do not include "
            "detokenization time in the measurement)"
        ),
    )
    # LoRA
    parser.add_argument(
        "--lora-path",
        type=str,
        default=None,
        help="Path to the LoRA adapters to use. This can be an absolute path, "
        "a relative path, or a Hugging Face model identifier.",
    )
    parser.add_argument(
        "--prefix-len",
        type=int,
        default=None,
        help="Number of prefix tokens to be used in RandomDataset "
        "and SonnetDataset. For RandomDataset, the total input "
        "length is the sum of prefix-len (default: "
        f"{RandomDataset.DEFAULT_PREFIX_LEN}) and a random context length "
        "sampled from [input_len * (1 - range_ratio), "
        "input_len * (1 + range_ratio)]. For SonnetDataset, "
        f"prefix_len (default: {SonnetDataset.DEFAULT_PREFIX_LEN}) "
        "controls how much of the input is fixed lines versus "
        "random lines, but the total input length remains approximately "
        "input_len tokens.",
    )
    # random dataset
    parser.add_argument(
        "--random-range-ratio",
        type=float,
        default=None,
        help=f"Range ratio (default : {RandomDataset.DEFAULT_RANGE_RATIO}) "
        "for sampling input/output length, "
        "used only for RandomDataset. Must be in the range [0, 1) to "
        "define a symmetric sampling range "
        "[length * (1 - range_ratio), length * (1 + range_ratio)].",
    )

    # hf dataset
    parser.add_argument(
        "--hf-subset", type=str, default=None, help="Subset of the HF dataset."
    )
    parser.add_argument(
        "--hf-split", type=str, default=None, help="Split of the HF dataset."
    )

    # Engine options are registered last, on the same parser instance.
    parser = AsyncEngineArgs.add_cli_args(parser)

    return parser


if __name__ == "__main__":
    # Script entry point: build the CLI, parse it, normalise the tokenizer
    # option, validate the option combination, then run the benchmark.
    parser = create_argument_parser()
    args = parser.parse_args()
    # When no tokenizer was given, reuse the model identifier for it.
    args.tokenizer = args.model if args.tokenizer is None else args.tokenizer
    validate_args(args)
    main(args)