# SPDX-License-Identifier: Apache-2.0
"""Benchmark offline inference throughput."""
import argparse
import dataclasses
import json
import os
import random
import time
import warnings
from pathlib import Path
from typing import Any, Optional, Union

import numpy as np
import torch
import uvloop
from benchmark_dataset import (AIMODataset, BurstGPTDataset,
                               ConversationDataset, InstructCoderDataset,
                               RandomDataset, SampleRequest, ShareGPTDataset,
                               SonnetDataset, VisionArenaDataset)
from benchmark_utils import convert_to_pytorch_benchmark_format, write_to_json
from tqdm import tqdm
from transformers import (AutoModelForCausalLM, AutoTokenizer,
                          PreTrainedTokenizerBase)

from vllm.engine.arg_utils import AsyncEngineArgs, EngineArgs
from vllm.entrypoints.openai.api_server import (
    build_async_engine_client_from_engine_args)
from vllm.inputs import PromptType, TextPrompt, TokensPrompt
from vllm.lora.request import LoRARequest
from vllm.outputs import RequestOutput
from vllm.sampling_params import BeamSearchParams
from vllm.utils import FlexibleArgumentParser, merge_async_iterators


def run_vllm(
    requests: list[SampleRequest],
    n: int,
    num_iters_warmup: int,
    engine_args: EngineArgs,
    disable_detokenize: bool = False,
    profile: Optional[bool] = None,
    profile_result_dir: Optional[str] = None,
) -> tuple[float, Optional[list[RequestOutput]]]:
    """Benchmark offline generation with the synchronous ``LLM`` API.

    An ``LLM`` is built from ``engine_args`` and warmed up with
    ``num_iters_warmup`` calls on a single random 10-token prompt. Then all
    requests are submitted in one ``llm.generate`` call, whose wall-clock
    time is measured.

    Args:
        requests: Sampled benchmark requests.
        n: Number of sequences generated per prompt (also used for warmup).
        num_iters_warmup: Number of untimed warmup ``generate`` calls.
        engine_args: Arguments used to construct the ``LLM`` instance.
        disable_detokenize: If True, outputs are not detokenized.
        profile: Wrap the timed call in ``torch.profiler`` and print a
            per-op table. ``None`` falls back to ``args.profile`` of the
            script-level namespace when it exists (the previous behavior).
        profile_result_dir: Where the TensorBoard trace is written. ``None``
            falls back to ``args.profile_result_dir``. If that is also
            unset, a timestamped directory under ``./vllm_benchmark_result``
            is used.

    Returns:
        ``(elapsed_seconds, outputs)``. ``outputs`` holds the
        ``RequestOutput`` list of the timed call, or ``None`` on the
        beam-search path.
    """
    from vllm import LLM, SamplingParams

    # Backward compatibility: profiling used to be toggled only through the
    # module-level ``args`` that exists when this file runs as a script.
    # Reading it via globals() avoids a NameError when the module is imported.
    script_args = globals().get("args")
    if profile is None:
        profile = bool(getattr(script_args, "profile", False))
    if profile_result_dir is None:
        profile_result_dir = getattr(script_args, "profile_result_dir", None)

    llm = LLM(**dataclasses.asdict(engine_args))
    assert all(
        llm.llm_engine.model_config.max_model_len >= (
            request.prompt_len + request.expected_output_len)
        for request in requests), (
            "Please ensure that max_model_len is greater than the sum of"
            " prompt_len and expected_output_len for all requests.")

    # Add the requests to the engine.
    prompts: list[Union[TextPrompt, TokensPrompt]] = []
    sampling_params: list[SamplingParams] = []
    for request in requests:
        # Only a dict prompt can carry pre-tokenized ids; a plain `in` test
        # on a str prompt would be a substring check and could misfire.
        if (isinstance(request.prompt, dict)
                and "prompt_token_ids" in request.prompt):
            prompts.append(
                TokensPrompt(
                    prompt_token_ids=request.prompt["prompt_token_ids"],
                    multi_modal_data=request.multi_modal_data))
        else:
            prompts.append(
                TextPrompt(prompt=request.prompt,
                           multi_modal_data=request.multi_modal_data))
        sampling_params.append(
            SamplingParams(
                n=n,
                temperature=1.0,
                top_p=1.0,
                ignore_eos=True,
                max_tokens=request.expected_output_len,
                detokenize=not disable_detokenize,
            ))
    lora_requests: Optional[list[LoRARequest]] = None
    if engine_args.enable_lora:
        lora_requests = [request.lora_request for request in requests]

    # Warmup on a tiny random prompt so one-time initialization cost is
    # excluded from the timed run.
    warmup_sampling_params = SamplingParams(
        n=n,
        temperature=1.0,
        top_p=1.0,
        ignore_eos=True,
        max_tokens=10,
    )
    dummy_prompt_token_ids = np.random.randint(10000, size=(1, 10))
    dummy_prompts: list[PromptType] = [{
        "prompt_token_ids": batch
    } for batch in dummy_prompt_token_ids.tolist()]

    print("Warming up...")
    for _ in tqdm(range(num_iters_warmup), desc="Warmup iterations"):
        llm.generate(dummy_prompts,
                     sampling_params=warmup_sampling_params,
                     use_tqdm=False)

    use_beam_search = False

    outputs: Optional[list[RequestOutput]] = None
    if not use_beam_search:
        if profile:
            profile_dir = profile_result_dir or (
                Path(".") / "vllm_benchmark_result" /
                f"latency_result_{time.time()}")
            print(f"Profiling (results will be saved to '{profile_dir}')...")
            with torch.profiler.profile(
                    activities=[
                        torch.profiler.ProfilerActivity.CPU,
                        torch.profiler.ProfilerActivity.CUDA,
                    ],
                    record_shapes=True,
                    on_trace_ready=torch.profiler.tensorboard_trace_handler(
                        str(profile_dir)),
            ) as prof:
                start = time.perf_counter()
                outputs = llm.generate(prompts,
                                       sampling_params,
                                       lora_request=lora_requests,
                                       use_tqdm=True)
                end = time.perf_counter()
            print('Prepare time report')
            print(
                prof.key_averages(group_by_input_shape=True).table(
                    sort_by="self_cuda_time_total", row_limit=-1))
        else:
            start = time.perf_counter()
            # Keep the outputs so the caller can count real tokens.
            outputs = llm.generate(prompts,
                                   sampling_params,
                                   lora_request=lora_requests,
                                   use_tqdm=True)
            end = time.perf_counter()
    else:
        assert lora_requests is None, "BeamSearch API does not support LoRA"
        beam_prompts = [request.prompt for request in requests]
        # output_len should be the same for all requests.
        output_len = requests[0].expected_output_len
        for request in requests:
            assert request.expected_output_len == output_len
        start = time.perf_counter()
        llm.beam_search(
            beam_prompts,
            BeamSearchParams(
                beam_width=n,
                max_tokens=output_len,
                ignore_eos=True,
            ))
        end = time.perf_counter()
    return end - start, outputs


def run_vllm_chat(
        requests: list[SampleRequest],
        n: int,
        engine_args: EngineArgs,
        disable_detokenize: bool = False) -> tuple[float, list[RequestOutput]]:
    """
    Time a single ``LLM.chat`` call over every request.

    Meant for benchmarking multimodal models: ``LLM.chat`` takes care of
    chat formatting and multimodal inputs. Text-only models should be
    benchmarked with run_vllm() instead.

    Returns:
        ``(elapsed_seconds, outputs)``, where ``outputs`` is what
        ``llm.chat`` returned.
    """
    from vllm import LLM, SamplingParams
    llm = LLM(**dataclasses.asdict(engine_args))

    limit = llm.llm_engine.model_config.max_model_len
    everything_fits = all(
        req.prompt_len + req.expected_output_len <= limit
        for req in requests)
    assert everything_fits, (
        "Please ensure that max_model_len is greater than the sum of "
        "prompt_len and expected_output_len for all requests.")

    chat_inputs = [req.prompt for req in requests]
    sampling_params: list[SamplingParams] = [
        SamplingParams(
            n=n,
            temperature=1.0,
            top_p=1.0,
            ignore_eos=True,
            max_tokens=req.expected_output_len,
            detokenize=not disable_detokenize,
        ) for req in requests
    ]

    t0 = time.perf_counter()
    results = llm.chat(chat_inputs, sampling_params, use_tqdm=True)
    t1 = time.perf_counter()
    return t1 - t0, results


async def run_vllm_async(
    requests: list[SampleRequest],
    n: int,
    engine_args: AsyncEngineArgs,
    disable_frontend_multiprocessing: bool = False,
    disable_detokenize: bool = False,
) -> float:
    """Benchmark offline generation through the async engine client.

    Every request is submitted as its own ``generate`` stream. The returned
    time covers submission through draining all streams.

    Args:
        requests: Sampled benchmark requests.
        n: Number of sequences generated per prompt.
        engine_args: Arguments used to build the async engine client.
        disable_frontend_multiprocessing: Forwarded to
            ``build_async_engine_client_from_engine_args``.
        disable_detokenize: If True, outputs are not detokenized.

    Returns:
        Elapsed wall-clock seconds.
    """
    from vllm import SamplingParams

    async with build_async_engine_client_from_engine_args(
            engine_args, disable_frontend_multiprocessing) as llm:
        assert all(
            llm.model_config.max_model_len >= (request.prompt_len +
                                               request.expected_output_len)
            for request in requests), (
                "Please ensure that max_model_len is greater than the sum of"
                " prompt_len and expected_output_len for all requests.")

        # Add the requests to the engine.
        prompts: list[Union[TextPrompt, TokensPrompt]] = []
        sampling_params: list[SamplingParams] = []
        lora_requests: list[Optional[LoRARequest]] = []
        for request in requests:
            # Only a dict prompt can carry pre-tokenized ids; a plain `in`
            # test on a str prompt would be a substring check.
            if (isinstance(request.prompt, dict)
                    and "prompt_token_ids" in request.prompt):
                prompts.append(
                    TokensPrompt(
                        prompt_token_ids=request.prompt["prompt_token_ids"],
                        multi_modal_data=request.multi_modal_data))
            else:
                prompts.append(
                    TextPrompt(prompt=request.prompt,
                               multi_modal_data=request.multi_modal_data))
            sampling_params.append(
                SamplingParams(
                    n=n,
                    temperature=1.0,
                    top_p=1.0,
                    ignore_eos=True,
                    max_tokens=request.expected_output_len,
                    detokenize=not disable_detokenize,
                ))
            lora_requests.append(request.lora_request)

        generators = []
        start = time.perf_counter()
        for i, (prompt, sp,
                lr) in enumerate(zip(prompts, sampling_params, lora_requests)):
            generators.append(
                llm.generate(prompt,
                             sp,
                             lora_request=lr,
                             request_id=f"test{i}"))
        # Drain every stream; only completion time matters here.
        async for _ in merge_async_iterators(*generators):
            pass
        end = time.perf_counter()
        return end - start


def run_hf(
    requests: list[SampleRequest],
    model: str,
    tokenizer: PreTrainedTokenizerBase,
    n: int,
    max_batch_size: int,
    trust_remote_code: bool,
    disable_detokenize: bool = False,
    max_context_len: int = 2048,
) -> float:
    """Benchmark Hugging Face ``generate`` with greedy dynamic batching.

    Requests are appended to a batch until one of two limits is hit:
    ``max_batch_size``, or adding the next request would push the batch's
    max prompt length plus max output length above ``max_context_len``.
    Each batch is then generated in one call.

    Args:
        requests: Sampled benchmark requests.
        model: Model name or path for ``AutoModelForCausalLM``.
        tokenizer: Tokenizer used to encode and (optionally) decode.
        n: ``num_return_sequences`` per prompt.
        max_batch_size: Maximum number of prompts per ``generate`` call.
        trust_remote_code: Forwarded to ``from_pretrained``.
        disable_detokenize: If True, the decode step is skipped.
        max_context_len: Token budget per batch (previously hard-coded
            2048).

    Returns:
        Elapsed wall-clock seconds.
    """
    llm = AutoModelForCausalLM.from_pretrained(
        model, torch_dtype=torch.float16, trust_remote_code=trust_remote_code)
    if llm.config.model_type == "llama":
        # To enable padding in the HF backend.
        tokenizer.pad_token = tokenizer.eos_token
    llm = llm.cuda()

    pbar = tqdm(total=len(requests))
    start = time.perf_counter()
    batch: list[str] = []
    max_prompt_len = 0
    max_output_len = 0
    last_index = len(requests) - 1
    for i, request in enumerate(requests):
        # SampleRequest is accessed by attribute (it is not a tuple).
        batch.append(request.prompt)
        max_prompt_len = max(max_prompt_len, request.prompt_len)
        max_output_len = max(max_output_len, request.expected_output_len)
        if len(batch) < max_batch_size and i != last_index:
            # Check if we can add more requests to the batch.
            nxt = requests[i + 1]
            if (max(max_prompt_len, nxt.prompt_len) +
                    max(max_output_len,
                        nxt.expected_output_len)) <= max_context_len:
                continue

        # Generate the sequences.
        input_ids = tokenizer(batch, return_tensors="pt",
                              padding=True).input_ids
        llm_outputs = llm.generate(
            input_ids=input_ids.cuda(),
            do_sample=True,
            num_return_sequences=n,
            temperature=1.0,
            top_p=1.0,
            use_cache=True,
            max_new_tokens=max_output_len,
        )
        if not disable_detokenize:
            # Include the decoding time.
            tokenizer.batch_decode(llm_outputs, skip_special_tokens=True)
        pbar.update(len(batch))

        # Clear the batch.
        batch = []
        max_prompt_len = 0
        max_output_len = 0
    end = time.perf_counter()
    pbar.close()
    return end - start


def run_mii(
    requests: list[SampleRequest],
    model: str,
    tensor_parallel_size: int,
    output_len: int,
) -> float:
    """Benchmark DeepSpeed-MII by serving ``model`` and generating all prompts.

    Args:
        requests: Sampled benchmark requests (only ``prompt`` is used).
        model: Model name passed to ``mii.serve`` / ``mii.client``.
        tensor_parallel_size: Forwarded as ``tensor_parallel``.
        output_len: ``max_new_tokens`` applied to every prompt.

    Returns:
        Elapsed wall-clock seconds of the ``generate`` call.
    """
    from mii import client, serve
    llm = serve(model, tensor_parallel=tensor_parallel_size)
    prompts = [request.prompt for request in requests]

    try:
        start = time.perf_counter()
        llm.generate(prompts, max_new_tokens=output_len)
        end = time.perf_counter()
    finally:
        # Terminate the server started by `serve` even if generation raises.
        client(model).terminate_server()
    return end - start


def save_to_pytorch_benchmark_format(args: argparse.Namespace,
                                     results: dict[str, Any]) -> None:
    """Write ``results`` again in the format of
    ``convert_to_pytorch_benchmark_format``.

    Throughput values become metrics; elapsed time and counts become extra
    info. The file is written next to ``args.output_json`` with its
    extension replaced by ``.pytorch.json``. Nothing is written when the
    converter returns no records.
    """
    metric_keys = ("requests_per_second", "tokens_per_second")
    info_keys = ("elapsed_time", "num_requests", "total_num_tokens")
    records = convert_to_pytorch_benchmark_format(
        args=args,
        metrics={key: [results[key]] for key in metric_keys},
        extra_info={key: results[key] for key in info_keys},
    )
    if not records:
        return
    # Use a distinct ".pytorch.json" name rather than the plain json output,
    # which (per the original note) CI should not pick up.
    stem, _ = os.path.splitext(args.output_json)
    write_to_json(f"{stem}.pytorch.json", records)


def get_requests(args, tokenizer):
    """Build the dataset selected by ``args`` and sample benchmark requests.

    Args:
        args: Parsed CLI namespace (dataset, sampling and LoRA options).
        tokenizer: Tokenizer passed to the dataset's ``sample`` method.

    Returns:
        Whatever ``<DatasetClass>.sample`` returns for the chosen dataset.

    Raises:
        ValueError: if ``args.dataset_name`` is unknown, or an ``hf``
            dataset path is not supported by any known dataset class.
        AssertionError: if the sonnet dataset is used with a tokenizer that
            has no chat template.
    """
    # Common parameters for all dataset types.
    common_kwargs = {
        "dataset_path": args.dataset_path,
        "random_seed": args.seed,
    }
    sample_kwargs = {
        "tokenizer": tokenizer,
        "lora_path": args.lora_path,
        "max_loras": args.max_loras,
        "num_requests": args.num_prompts,
        "input_len": args.input_len,
        "output_len": args.output_len,
    }

    if args.dataset_path is None or args.dataset_name == "random":
        sample_kwargs["range_ratio"] = args.random_range_ratio
        sample_kwargs["prefix_len"] = args.prefix_len
        dataset_cls = RandomDataset
    elif args.dataset_name == "sharegpt":
        dataset_cls = ShareGPTDataset
        if args.backend == "vllm-chat":
            sample_kwargs["enable_multimodal_chat"] = True
    elif args.dataset_name == "sonnet":
        # `default_chat_template` was removed from newer transformers
        # releases, so look it up defensively instead of via attribute access.
        has_chat_template = bool(
            getattr(tokenizer, "chat_template", None)
            or getattr(tokenizer, "default_chat_template", None))
        assert has_chat_template, (
            "Tokenizer/model must have chat template for sonnet dataset.")
        dataset_cls = SonnetDataset
        sample_kwargs["prefix_len"] = args.prefix_len
        sample_kwargs["return_prompt_formatted"] = True
    elif args.dataset_name == "burstgpt":
        dataset_cls = BurstGPTDataset
    elif args.dataset_name == "hf":
        if args.dataset_path in VisionArenaDataset.SUPPORTED_DATASET_PATHS:
            dataset_cls = VisionArenaDataset
            common_kwargs['dataset_subset'] = None
            common_kwargs['dataset_split'] = "train"
            sample_kwargs["enable_multimodal_chat"] = True
        elif args.dataset_path in InstructCoderDataset.SUPPORTED_DATASET_PATHS:
            dataset_cls = InstructCoderDataset
            common_kwargs['dataset_split'] = "train"
        elif args.dataset_path in ConversationDataset.SUPPORTED_DATASET_PATHS:
            dataset_cls = ConversationDataset
            common_kwargs['dataset_subset'] = args.hf_subset
            common_kwargs['dataset_split'] = args.hf_split
            sample_kwargs["enable_multimodal_chat"] = True
        elif args.dataset_path in AIMODataset.SUPPORTED_DATASET_PATHS:
            dataset_cls = AIMODataset
            common_kwargs['dataset_subset'] = None
            common_kwargs['dataset_split'] = "train"
        else:
            # Without this branch `dataset_cls` would be unbound below.
            raise ValueError(
                f"{args.dataset_path} is not supported by hf dataset.")
    else:
        raise ValueError(f"Unknown dataset name: {args.dataset_name}")
    # Remove None values
    sample_kwargs = {k: v for k, v in sample_kwargs.items() if v is not None}
    return dataset_cls(**common_kwargs).sample(**sample_kwargs)


def _count_tokens(
    requests: list[SampleRequest],
    request_outputs: Optional[list[RequestOutput]],
    n: int = 1,
) -> tuple[int, int]:
    """Return ``(total_prompt_tokens, total_output_tokens)`` for a run.

    When engine outputs are available (vllm / vllm-chat), real token counts
    are summed from them. Otherwise the counts are estimated from the
    requests: each prompt is counted once, and each request contributes
    ``n`` sequences of ``expected_output_len`` tokens.
    """
    if request_outputs:
        prompt_tokens = 0
        output_tokens = 0
        for ro in request_outputs:
            if not isinstance(ro, RequestOutput):
                continue
            prompt_tokens += len(
                ro.prompt_token_ids) if ro.prompt_token_ids else 0
            output_tokens += sum(len(o.token_ids) for o in ro.outputs if o)
        return prompt_tokens, output_tokens
    prompt_tokens = sum(r.prompt_len for r in requests)
    # Backends are asked for `n` sequences per prompt, so the estimate
    # scales output tokens by `n` (identical to before when n == 1).
    output_tokens = n * sum(r.expected_output_len for r in requests)
    return prompt_tokens, output_tokens


def main(args: argparse.Namespace):
    """Run the throughput benchmark described by ``args`` and print results.

    Samples requests, dispatches to the selected backend, and prints latency
    and throughput. If ``--output-json`` is set, the results are also
    written to JSON (plus the PyTorch benchmark-format variant).

    Raises:
        ValueError: if ``args.backend`` is not a known backend.
    """
    if args.seed is None:
        args.seed = 0
    print(args)
    random.seed(args.seed)
    # Sample the requests.
    tokenizer = AutoTokenizer.from_pretrained(
        args.tokenizer, trust_remote_code=args.trust_remote_code)
    requests = get_requests(args, tokenizer)
    is_multi_modal = any(request.multi_modal_data is not None
                         for request in requests)
    request_outputs: Optional[list[RequestOutput]] = None
    if args.backend == "vllm":
        if args.async_engine:
            elapsed_time = uvloop.run(
                run_vllm_async(
                    requests,
                    args.n,
                    AsyncEngineArgs.from_cli_args(args),
                    args.disable_frontend_multiprocessing,
                    args.disable_detokenize,
                ))
        else:
            elapsed_time, request_outputs = run_vllm(
                requests, args.n, args.num_iters_warmup,
                EngineArgs.from_cli_args(args), args.disable_detokenize)
    elif args.backend == "hf":
        assert args.tensor_parallel_size == 1
        elapsed_time = run_hf(requests, args.model, tokenizer, args.n,
                              args.hf_max_batch_size, args.trust_remote_code,
                              args.disable_detokenize)
    elif args.backend == "mii":
        elapsed_time = run_mii(requests, args.model, args.tensor_parallel_size,
                               args.output_len)
    elif args.backend == "vllm-chat":
        elapsed_time, request_outputs = run_vllm_chat(
            requests, args.n, EngineArgs.from_cli_args(args),
            args.disable_detokenize)
    else:
        raise ValueError(f"Unknown backend: {args.backend}")

    total_prompt_tokens, total_output_tokens = _count_tokens(
        requests, request_outputs, args.n)
    total_num_tokens = total_prompt_tokens + total_output_tokens

    if is_multi_modal and args.backend != "vllm-chat":
        print("\033[91mWARNING\033[0m: Multi-modal request with "
              f"{args.backend} backend detected. The "
              "following metrics are not accurate because image tokens are not"
              " counted. See vllm-project/vllm/issues/9778 for details.")
        # TODO(vllm-project/vllm/issues/9778): Count multi-modal token length.
        # vllm-chat backend counts the image tokens now

    print(f"Latency: {elapsed_time:.2f} s")
    print(f"Throughput: {len(requests) / elapsed_time:.2f} requests/s, "
          f"{total_num_tokens / elapsed_time:.2f} total tokens/s, "
          f"{total_output_tokens / elapsed_time:.2f} output tokens/s")
    print(f"Total num prompt tokens:  {total_prompt_tokens}")
    print(f"Total num output tokens:  {total_output_tokens}")

    # Output JSON results if specified
    if args.output_json:
        results = {
            "elapsed_time": elapsed_time,
            "num_requests": len(requests),
            "total_num_tokens": total_num_tokens,
            "requests_per_second": len(requests) / elapsed_time,
            "tokens_per_second": total_num_tokens / elapsed_time,
        }
        with open(args.output_json, "w") as f:
            json.dump(results, f, indent=4)
        save_to_pytorch_benchmark_format(args, results)


def validate_args(args):
    """
    Validate and normalize command-line arguments in place.

    Applies defaults: ``dataset_path`` from the deprecated ``--dataset``,
    ``tokenizer`` from ``model``, and ``dataset_name='random'`` when no
    dataset is given. Warns about options the chosen dataset ignores, and
    rejects incompatible backend/option combinations.

    Raises:
        ValueError: on an unsupported backend, a missing required option, or
            an option that the selected backend does not support.
        AssertionError: when an HF dataset is paired with the wrong backend.
    """

    # === Deprecation and Defaulting ===
    if args.dataset is not None:
        warnings.warn(
            "The '--dataset' argument will be deprecated in the next release. "
            "Please use '--dataset-name' and '--dataset-path' instead.",
            stacklevel=2)
        args.dataset_path = args.dataset

    if not getattr(args, "tokenizer", None):
        args.tokenizer = args.model

    # === Backend Validation ===
    valid_backends = {"vllm", "hf", "mii", "vllm-chat"}
    if args.backend not in valid_backends:
        raise ValueError(f"Unsupported backend: {args.backend}")

    # === Dataset Configuration ===
    if not args.dataset and not args.dataset_path:
        print(
            "When dataset path is not set, it will default to random dataset")
        args.dataset_name = 'random'
        if args.input_len is None:
            raise ValueError("input_len must be provided for a random dataset")

    # === Dataset Name Specific Checks ===
    # --hf-subset and --hf-split: only used
    # when dataset_name is 'hf'
    # (Adjacent string literals are used below; the previous backslash
    # continuations embedded long runs of spaces in the warning text.)
    if args.dataset_name != "hf" and (
            getattr(args, "hf_subset", None) is not None
            or getattr(args, "hf_split", None) is not None):
        warnings.warn(
            "--hf-subset and --hf-split will be ignored since "
            "--dataset-name is not 'hf'.",
            stacklevel=2)
    elif args.dataset_name == "hf":
        if args.dataset_path in (
                VisionArenaDataset.SUPPORTED_DATASET_PATHS.keys()
                | ConversationDataset.SUPPORTED_DATASET_PATHS):
            assert args.backend == "vllm-chat", f"{args.dataset_path} needs to use vllm-chat as the backend."  #noqa: E501
        elif args.dataset_path in (InstructCoderDataset.SUPPORTED_DATASET_PATHS
                                   | AIMODataset.SUPPORTED_DATASET_PATHS):
            assert args.backend == "vllm", f"{args.dataset_path} needs to use vllm as the backend."  #noqa: E501
        else:
            raise ValueError(
                f"{args.dataset_path} is not supported by hf dataset.")

    # --random-range-ratio: only used when dataset_name is 'random'
    if args.dataset_name != 'random' and args.random_range_ratio is not None:
        warnings.warn(
            "--random-range-ratio will be ignored since --dataset-name is "
            "not 'random'.",
            stacklevel=2)

    # --prefix-len: only used when dataset_name is 'random', 'sonnet', or not
    # set.
    if args.dataset_name not in {"random", "sonnet", None
                                 } and args.prefix_len is not None:
        warnings.warn(
            "--prefix-len will be ignored since --dataset-name is not "
            "'random', 'sonnet', or not set.",
            stacklevel=2)

    # === LoRA Settings ===
    if getattr(args, "enable_lora", False) and args.backend != "vllm":
        raise ValueError(
            "LoRA benchmarking is only supported for vLLM backend")
    if getattr(args, "enable_lora", False) and args.lora_path is None:
        raise ValueError("LoRA path must be provided when enable_lora is True")

    # === Backend-specific Validations ===
    if args.backend == "hf" and args.hf_max_batch_size is None:
        raise ValueError("HF max batch size is required for HF backend")
    if args.backend != "hf" and args.hf_max_batch_size is not None:
        raise ValueError("HF max batch size is only for HF backend.")

    if args.backend in {"hf", "mii"} and getattr(args, "quantization",
                                                 None) is not None:
        raise ValueError("Quantization is only for vLLM backend.")

    if args.backend == "mii" and args.dtype != "auto":
        raise ValueError("dtype must be auto for MII backend.")
    if args.backend == "mii" and args.n != 1:
        raise ValueError("n must be 1 for MII backend.")
    if args.backend == "mii" and args.tokenizer != args.model:
        raise ValueError(
            "Tokenizer must be the same as the model for MII backend.")


if __name__ == "__main__":
    # Script entry point: build the CLI (benchmark options plus all engine
    # arguments), validate it, then run the benchmark.
    parser = FlexibleArgumentParser(description="Benchmark the throughput.")
    parser.add_argument("--backend",
                        type=str,
                        choices=["vllm", "hf", "mii", "vllm-chat"],
                        default="vllm")
    parser.add_argument(
        "--dataset-name",
        type=str,
        choices=["sharegpt", "random", "sonnet", "burstgpt", "hf"],
        help="Name of the dataset to benchmark on.",
        default="sharegpt")
    parser.add_argument(
        "--dataset",
        type=str,
        default=None,
        help="Path to the ShareGPT dataset, will be deprecated in "
        "the next release. The dataset is expected to "
        "be a json in form of list[dict[..., conversations: "
        "list[dict[..., value: <prompt_or_response>]]]]")
    parser.add_argument("--dataset-path",
                        type=str,
                        default=None,
                        help="Path to the dataset")
    parser.add_argument("--input-len",
                        type=int,
                        default=None,
                        help="Input prompt length for each request")
    parser.add_argument("--output-len",
                        type=int,
                        default=None,
                        help="Output length for each request. Overrides the "
                        "output length from the dataset.")
    parser.add_argument("--n",
                        type=int,
                        default=1,
                        help="Number of generated sequences per prompt.")
    parser.add_argument('--num-iters-warmup',
                        type=int,
                        default=1,
                        help='Number of iterations to run for warmup.')
    parser.add_argument("--num-prompts",
                        type=int,
                        default=1000,
                        help="Number of prompts to process.")
    parser.add_argument("--hf-max-batch-size",
                        type=int,
                        default=None,
                        help="Maximum batch size for HF backend.")
    parser.add_argument(
        '--profile',
        action='store_true',
        help='Profile the timed generation call with torch.profiler '
        '(vllm backend without --async-engine only).')
    parser.add_argument(
        '--profile-result-dir',
        type=str,
        default=None,
        help=('path to save the pytorch profiler output. Can be visualized '
              'with ui.perfetto.dev or Tensorboard.'))
    parser.add_argument(
        '--output-json',
        type=str,
        default=None,
        help='Path to save the throughput results in JSON format.')
    parser.add_argument("--async-engine",
                        action='store_true',
                        default=False,
                        help="Use vLLM async engine rather than LLM class.")
    parser.add_argument("--disable-frontend-multiprocessing",
                        action='store_true',
                        default=False,
                        help="Disable decoupled async engine frontend.")
    parser.add_argument(
        "--disable-detokenize",
        action="store_true",
        help=("Do not detokenize the response (i.e. do not include "
              "detokenization time in the measurement)"))
    # LoRA
    parser.add_argument(
        "--lora-path",
        type=str,
        default=None,
        help="Path to the lora adapters to use. This can be an absolute path, "
        "a relative path, or a Hugging Face model identifier.")
    parser.add_argument("--prefix-len",
                        type=int,
                        default=None,
                        help="Number of prefix tokens per request. "
                        "This is for the RandomDataset and SonnetDataset")
    # random dataset
    parser.add_argument(
        "--random-range-ratio",
        type=float,
        default=None,
        help="Range of sampled ratio of input/output length, "
        "used only for RandomDataset.",
    )

    # hf dataset
    parser.add_argument("--hf-subset",
                        type=str,
                        default=None,
                        help="Subset of the HF dataset.")
    parser.add_argument("--hf-split",
                        type=str,
                        default=None,
                        help="Split of the HF dataset.")

    parser = AsyncEngineArgs.add_cli_args(parser)
    args = parser.parse_args()
    if args.tokenizer is None:
        args.tokenizer = args.model
    validate_args(args)
    main(args)