"docs/vscode:/vscode.git/clone" did not exist on "9fc983c707a9371951135d2e5a62183da7e09369"
benchmark_throughput.py 18.4 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
"""Benchmark offline inference throughput."""
3
import argparse
4
import dataclasses
5
import json
6
import os
7
8
import random
import time
9
import warnings
10
from typing import Any, Optional, Union
11

12
import torch
13
import uvloop
14
15
from benchmark_dataset import (BurstGPTDataset, RandomDataset, SampleRequest,
                               ShareGPTDataset, SonnetDataset)
16
from benchmark_utils import convert_to_pytorch_benchmark_format, write_to_json
17
from tqdm import tqdm
18
19
from transformers import (AutoModelForCausalLM, AutoTokenizer,
                          PreTrainedTokenizerBase)
20

21
from vllm.engine.arg_utils import AsyncEngineArgs, EngineArgs
22
23
from vllm.entrypoints.openai.api_server import (
    build_async_engine_client_from_engine_args)
24
from vllm.inputs import TextPrompt, TokensPrompt
25
from vllm.lora.request import LoRARequest
26
from vllm.sampling_params import BeamSearchParams
27
from vllm.utils import FlexibleArgumentParser, merge_async_iterators
28

29

Woosuk Kwon's avatar
Woosuk Kwon committed
30
def run_vllm(
    requests: list[SampleRequest],
    n: int,
    engine_args: EngineArgs,
    disable_detokenize: bool = False,
) -> float:
    """Run every request through the synchronous ``LLM`` engine and time it.

    Args:
        requests: Sampled requests to benchmark.
        n: Number of sequences generated per prompt (used as the beam width
            on the beam-search path).
        engine_args: Engine configuration; expanded into ``LLM(...)`` keyword
            arguments via ``dataclasses.asdict``.
        disable_detokenize: If True, ``detokenize=False`` is passed in the
            sampling parameters so detokenization is not part of the timing.

    Returns:
        Wall-clock seconds spent inside the generation call only; engine
        construction and request preparation are not timed.
    """
    from vllm import LLM, SamplingParams
    llm = LLM(**dataclasses.asdict(engine_args))
    assert all(
        llm.llm_engine.model_config.max_model_len >= (
            request.prompt_len + request.expected_output_len)
        for request in requests), (
            "Please ensure that max_model_len is greater than the sum of"
            " prompt_len and expected_output_len for all requests.")

    # Add the requests to the engine.
    prompts: list[Union[TextPrompt, TokensPrompt]] = []
    sampling_params: list[SamplingParams] = []
    for request in requests:
        # Pre-tokenized prompts carry a "prompt_token_ids" entry; anything
        # else is passed through as a text prompt.
        prompts.append(
            TokensPrompt(prompt_token_ids=request.prompt["prompt_token_ids"],
                         multi_modal_data=request.multi_modal_data)
            if "prompt_token_ids" in request.prompt else
            TextPrompt(prompt=request.prompt,
                       multi_modal_data=request.multi_modal_data))
        sampling_params.append(
            SamplingParams(
                n=n,
                temperature=1.0,
                top_p=1.0,
                # Always generate exactly `expected_output_len` tokens.
                ignore_eos=True,
                max_tokens=request.expected_output_len,
                detokenize=not disable_detokenize,
            ))
    lora_requests: Optional[list[LoRARequest]] = None
    if engine_args.enable_lora:
        lora_requests = [request.lora_request for request in requests]

    # Developer toggle: flip to True to benchmark the beam-search API.
    use_beam_search = False

    if not use_beam_search:
        start = time.perf_counter()
        llm.generate(prompts,
                     sampling_params,
                     lora_request=lora_requests,
                     use_tqdm=True)
        end = time.perf_counter()
    else:
        assert lora_requests is None, "BeamSearch API does not support LoRA"
        prompts = [request.prompt for request in requests]
        # output_len should be the same for all requests.
        # Read it by attribute, like every other access in this file, rather
        # than by positional index into the request object.
        output_len = requests[0].expected_output_len
        for request in requests:
            assert request.expected_output_len == output_len
        start = time.perf_counter()
        llm.beam_search(
            prompts,
            BeamSearchParams(
                beam_width=n,
                max_tokens=output_len,
                ignore_eos=True,
            ))
        end = time.perf_counter()
    return end - start


95
async def run_vllm_async(
    requests: list[SampleRequest],
    n: int,
    engine_args: AsyncEngineArgs,
    disable_frontend_multiprocessing: bool = False,
    disable_detokenize: bool = False,
) -> float:
    """Run every request through the async engine client and time it.

    Args:
        requests: Sampled requests to benchmark.
        n: Number of sequences generated per prompt.
        engine_args: Configuration handed to
            ``build_async_engine_client_from_engine_args``.
        disable_frontend_multiprocessing: Forwarded unchanged to the client
            builder.
        disable_detokenize: If True, ``detokenize=False`` is passed in the
            sampling parameters.

    Returns:
        Wall-clock seconds from just before the first ``generate`` call until
        every output stream has been fully consumed.
    """
    from vllm import SamplingParams

    async with build_async_engine_client_from_engine_args(
            engine_args, disable_frontend_multiprocessing) as llm:
        # Every request must fit in the model's context window.
        for req in requests:
            assert llm.model_config.max_model_len >= (
                req.prompt_len + req.expected_output_len), (
                    "Please ensure that max_model_len is greater than the sum of"
                    " prompt_len and expected_output_len for all requests.")

        # Prepare (prompt, sampling params, lora request) per request.
        engine_inputs = []
        for req in requests:
            if "prompt_token_ids" in req.prompt:
                prompt = TokensPrompt(
                    prompt_token_ids=req.prompt["prompt_token_ids"],
                    multi_modal_data=req.multi_modal_data)
            else:
                prompt = TextPrompt(prompt=req.prompt,
                                    multi_modal_data=req.multi_modal_data)
            params = SamplingParams(
                n=n,
                temperature=1.0,
                top_p=1.0,
                ignore_eos=True,
                max_tokens=req.expected_output_len,
                detokenize=not disable_detokenize,
            )
            engine_inputs.append((prompt, params, req.lora_request))

        start = time.perf_counter()
        streams = [
            llm.generate(prompt,
                         params,
                         lora_request=lora,
                         request_id=f"test{idx}")
            for idx, (prompt, params, lora) in enumerate(engine_inputs)
        ]
        # Drain all streams; the outputs themselves are not needed.
        async for _idx, _output in merge_async_iterators(*streams):
            pass
        end = time.perf_counter()
        return end - start


151
def run_hf(
    requests: list[SampleRequest],
    model: str,
    tokenizer: PreTrainedTokenizerBase,
    n: int,
    max_batch_size: int,
    trust_remote_code: bool,
    disable_detokenize: bool = False,
) -> float:
    """Benchmark generation with a Hugging Face ``AutoModelForCausalLM``.

    Requests are grouped greedily into batches: a batch keeps growing while
    it has fewer than ``max_batch_size`` prompts and the longest prompt plus
    the longest requested output (including the next candidate) stays within
    2048 tokens.

    Args:
        requests: Sampled requests to benchmark.
        model: Model name or path passed to ``from_pretrained``.
        tokenizer: Tokenizer used to encode each batch (and decode outputs).
        n: Number of sequences returned per prompt.
        max_batch_size: Upper bound on prompts per ``generate`` call.
        trust_remote_code: Forwarded to ``from_pretrained``.
        disable_detokenize: If True, skip ``batch_decode`` so decoding time
            is not included in the measurement.

    Returns:
        Wall-clock seconds spent batching, generating and (optionally)
        decoding; model loading is not timed.
    """
    llm = AutoModelForCausalLM.from_pretrained(
        model, torch_dtype=torch.float16, trust_remote_code=trust_remote_code)
    if llm.config.model_type == "llama":
        # To enable padding in the HF backend.
        tokenizer.pad_token = tokenizer.eos_token
    llm = llm.cuda()

    pbar = tqdm(total=len(requests))
    start = time.perf_counter()
    batch: list[str] = []
    max_prompt_len = 0
    max_output_len = 0
    last_index = len(requests) - 1
    for i, request in enumerate(requests):
        # Read fields by attribute, as the rest of this file does; the
        # request objects are not unpacked as 3-tuples.
        # Add the prompt to the batch.
        batch.append(request.prompt)
        max_prompt_len = max(max_prompt_len, request.prompt_len)
        max_output_len = max(max_output_len, request.expected_output_len)
        if len(batch) < max_batch_size and i != last_index:
            # Check if we can add more requests to the batch.
            next_request = requests[i + 1]
            if (max(max_prompt_len, next_request.prompt_len) +
                    max(max_output_len,
                        next_request.expected_output_len)) <= 2048:
                # We can add more requests to the batch.
                continue

        # Generate the sequences.
        input_ids = tokenizer(batch, return_tensors="pt",
                              padding=True).input_ids
        llm_outputs = llm.generate(
            input_ids=input_ids.cuda(),
            do_sample=True,
            num_return_sequences=n,
            temperature=1.0,
            top_p=1.0,
            use_cache=True,
            max_new_tokens=max_output_len,
        )
        if not disable_detokenize:
            # Include the decoding time.
            tokenizer.batch_decode(llm_outputs, skip_special_tokens=True)
        pbar.update(len(batch))

        # Clear the batch.
        batch = []
        max_prompt_len = 0
        max_output_len = 0
    end = time.perf_counter()
    return end - start


211
def run_mii(
    requests: list[SampleRequest],
    model: str,
    tensor_parallel_size: int,
    output_len: int,
) -> float:
    """Benchmark generation through a DeepSpeed-MII server.

    Args:
        requests: Sampled requests; only their ``prompt`` field is used.
        model: Model name or path passed to ``mii.serve`` / ``mii.client``.
        tensor_parallel_size: Forwarded as ``tensor_parallel`` to ``serve``.
        output_len: ``max_new_tokens`` used for every prompt.

    Returns:
        Wall-clock seconds spent in the single ``generate`` call; starting
        and terminating the server are not timed.
    """
    from mii import client as mii_client, serve as mii_serve

    server = mii_serve(model, tensor_parallel=tensor_parallel_size)
    prompts = [req.prompt for req in requests]

    start = time.perf_counter()
    server.generate(prompts, max_new_tokens=output_len)
    end = time.perf_counter()

    # Shut the server down through a fresh client handle.
    handle = mii_client(model)
    handle.terminate_server()
    return end - start


229
def save_to_pytorch_benchmark_format(args: argparse.Namespace,
                                     results: dict[str, Any]) -> None:
    """Write the results in PyTorch benchmark format next to the JSON output.

    The throughput figures are passed as metrics and the raw counters as
    extra info to ``convert_to_pytorch_benchmark_format``. If that returns
    any records they are written to ``<output_json stem>.pytorch.json``.

    Args:
        args: Parsed CLI arguments; ``args.output_json`` supplies the base
            path.
        results: Must contain ``requests_per_second``, ``tokens_per_second``,
            ``elapsed_time``, ``num_requests`` and ``total_num_tokens``.
    """
    metrics = {
        "requests_per_second": [results["requests_per_second"]],
        "tokens_per_second": [results["tokens_per_second"]],
    }
    extra_info = {}
    for key in ("elapsed_time", "num_requests", "total_num_tokens"):
        extra_info[key] = results[key]

    pt_records = convert_to_pytorch_benchmark_format(args=args,
                                                     metrics=metrics,
                                                     extra_info=extra_info)
    if not pt_records:
        return

    # NOTE(review): the original comment says the name avoids a json suffix
    # so CI does not pick it up, yet it ends in ".pytorch.json" - confirm
    # what the CI glob actually matches.
    stem, _ext = os.path.splitext(args.output_json)
    write_to_json(f"{stem}.pytorch.json", pt_records)
245
246


247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
def get_requests(args, tokenizer):
    """Build the dataset selected on the command line and sample requests.

    The random dataset is used when no dataset path is given or when
    ``--dataset-name random`` is chosen; otherwise the class is picked from
    ``args.dataset_name``. Sampling options whose value is ``None`` are
    dropped before calling ``sample``.

    Args:
        args: Parsed CLI arguments.
        tokenizer: Tokenizer forwarded to the dataset's ``sample`` method.

    Returns:
        Whatever the chosen dataset's ``sample(...)`` returns.

    Raises:
        ValueError: If ``args.dataset_name`` is not a known dataset.
    """
    # Constructor arguments shared by every dataset class.
    dataset_kwargs = {
        "dataset_path": args.dataset_path,
        "random_seed": args.seed,
    }
    # Sampling arguments; dataset-specific extras are added below.
    sample_kwargs = {
        "tokenizer": tokenizer,
        "lora_path": args.lora_path,
        "max_loras": args.max_loras,
        "num_requests": args.num_prompts,
        "input_len": args.input_len,
        "output_len": args.output_len,
    }

    if args.dataset_path is None or args.dataset_name == "random":
        sample_kwargs.update(range_ratio=args.random_range_ratio,
                             prefix_len=args.prefix_len)
        dataset_cls = RandomDataset
    elif args.dataset_name == "sharegpt":
        dataset_cls = ShareGPTDataset
    elif args.dataset_name == "sonnet":
        assert tokenizer.chat_template or tokenizer.default_chat_template, (
            "Tokenizer/model must have chat template for sonnet dataset.")
        dataset_cls = SonnetDataset
        sample_kwargs.update(prefix_len=args.prefix_len,
                             return_prompt_formatted=True)
    elif args.dataset_name == "burstgpt":
        dataset_cls = BurstGPTDataset
    else:
        raise ValueError(f"Unknown dataset name: {args.dataset_name}")

    # Only forward the options that were actually set.
    provided = {}
    for key, value in sample_kwargs.items():
        if value is not None:
            provided[key] = value
    return dataset_cls(**dataset_kwargs).sample(**provided)


282
def main(args: argparse.Namespace):
    """Sample requests, run the selected backend and report throughput.

    Prints requests/s, total tokens/s and output tokens/s. When
    ``args.output_json`` is set, the results are also written to that file
    and passed to ``save_to_pytorch_benchmark_format``.

    Args:
        args: Parsed and validated CLI arguments.

    Raises:
        ValueError: If ``args.backend`` is not one of the known backends.
    """
    if args.seed is None:
        args.seed = 0
    print(args)
    random.seed(args.seed)

    # Sample the requests.
    tokenizer = AutoTokenizer.from_pretrained(
        args.tokenizer, trust_remote_code=args.trust_remote_code)
    requests = get_requests(args, tokenizer)
    is_multi_modal = any(req.multi_modal_data is not None
                         for req in requests)

    backend = args.backend
    if backend == "vllm" and args.async_engine:
        elapsed_time = uvloop.run(
            run_vllm_async(
                requests,
                args.n,
                AsyncEngineArgs.from_cli_args(args),
                args.disable_frontend_multiprocessing,
                args.disable_detokenize,
            ))
    elif backend == "vllm":
        elapsed_time = run_vllm(requests, args.n,
                                EngineArgs.from_cli_args(args),
                                args.disable_detokenize)
    elif backend == "hf":
        assert args.tensor_parallel_size == 1
        elapsed_time = run_hf(requests, args.model, tokenizer, args.n,
                              args.hf_max_batch_size, args.trust_remote_code,
                              args.disable_detokenize)
    elif backend == "mii":
        elapsed_time = run_mii(requests, args.model, args.tensor_parallel_size,
                               args.output_len)
    else:
        raise ValueError(f"Unknown backend: {args.backend}")

    num_requests = len(requests)
    total_output_tokens = sum(req.expected_output_len for req in requests)
    total_num_tokens = sum(req.prompt_len + req.expected_output_len
                           for req in requests)
    requests_per_second = num_requests / elapsed_time
    tokens_per_second = total_num_tokens / elapsed_time

    if is_multi_modal:
        print("\033[91mWARNING\033[0m: Multi-modal request detected. The "
              "following metrics are not accurate because image tokens are not"
              " counted. See vllm-project/vllm/issues/9778 for details.")
        # TODO(vllm-project/vllm/issues/9778): Count multi-modal token length.
    print(f"Throughput: {requests_per_second:.2f} requests/s, "
          f"{tokens_per_second:.2f} total tokens/s, "
          f"{total_output_tokens / elapsed_time:.2f} output tokens/s")

    # Output JSON results if specified
    if not args.output_json:
        return
    results = {
        "elapsed_time": elapsed_time,
        "num_requests": num_requests,
        "total_num_tokens": total_num_tokens,
        "requests_per_second": requests_per_second,
        "tokens_per_second": tokens_per_second,
    }
    with open(args.output_json, "w") as f:
        json.dump(results, f, indent=4)
    save_to_pytorch_benchmark_format(args, results)
342

343
344

if __name__ == "__main__":
    # Script entry point: declare the benchmark-specific options, add the
    # engine options, validate the combination, then run the benchmark.
    parser = FlexibleArgumentParser(description="Benchmark the throughput.")
    parser.add_argument("--backend",
                        type=str,
                        choices=["vllm", "hf", "mii"],
                        default="vllm")
    parser.add_argument("--dataset-name",
                        type=str,
                        choices=["sharegpt", "random", "sonnet", "burstgpt"],
                        help="Name of the dataset to benchmark on.",
                        default="sharegpt")
    # The help text is built from adjacent literals only; a backslash
    # continuation inside a literal would embed the indentation in the text.
    parser.add_argument(
        "--dataset",
        type=str,
        default=None,
        help="Path to the ShareGPT dataset, will be deprecated in "
        "the next release. The dataset is expected to "
        "be a json in form of list[dict[..., conversations: "
        "list[dict[..., value: <prompt_or_response>]]]]")
    parser.add_argument("--dataset-path",
                        type=str,
                        default=None,
                        help="Path to the dataset")
    parser.add_argument("--input-len",
                        type=int,
                        default=None,
                        help="Input prompt length for each request")
    parser.add_argument("--output-len",
                        type=int,
                        default=None,
                        help="Output length for each request. Overrides the "
                        "output length from the dataset.")
    parser.add_argument("--n",
                        type=int,
                        default=1,
                        help="Number of generated sequences per prompt.")
    parser.add_argument("--num-prompts",
                        type=int,
                        default=1000,
                        help="Number of prompts to process.")
    parser.add_argument("--hf-max-batch-size",
                        type=int,
                        default=None,
                        help="Maximum batch size for HF backend.")
    parser.add_argument(
        '--output-json',
        type=str,
        default=None,
        help='Path to save the throughput results in JSON format.')
    parser.add_argument("--async-engine",
                        action='store_true',
                        default=False,
                        help="Use vLLM async engine rather than LLM class.")
    parser.add_argument("--disable-frontend-multiprocessing",
                        action='store_true',
                        default=False,
                        help="Disable decoupled async engine frontend.")
    parser.add_argument(
        "--disable-detokenize",
        action="store_true",
        help=("Do not detokenize the response (i.e. do not include "
              "detokenization time in the measurement)"))
    # LoRA
    parser.add_argument(
        "--lora-path",
        type=str,
        default=None,
        help="Path to the lora adapters to use. This can be an absolute path, "
        "a relative path, or a Hugging Face model identifier.")
    parser.add_argument("--prefix-len",
                        type=int,
                        default=None,
                        help="Number of prefix tokens per request. "
                        "This is for the RandomDataset and SonnetDataset")
    # random dataset
    parser.add_argument(
        "--random-range-ratio",
        type=float,
        default=1.0,
        help="Range of sampled ratio of input/output length, "
        "used only for RandomDataSet.",
    )

    # Engine options (model, tokenizer, dtype, quantization, LoRA, ...).
    parser = AsyncEngineArgs.add_cli_args(parser)
    args = parser.parse_args()
    if args.tokenizer is None:
        args.tokenizer = args.model

    # Legacy `--dataset` is still accepted and mapped onto `--dataset-path`.
    if args.dataset is not None:
        warnings.warn(
            "The '--dataset' argument will be deprecated in the next "
            "release. Please use '--dataset-name' and "
            "'--dataset-path' in the future runs.",
            stacklevel=2)
        args.dataset_path = args.dataset
    if args.dataset is None and args.dataset_path is None:
        # for random dataset, the default sampling setting is in
        # benchmark_dataset.RandomDataset
        print("When dataset is not set, it will default to random dataset")
    else:
        # An explicit input length is only accepted when no dataset is given.
        assert args.input_len is None
    if args.enable_lora:
        assert args.lora_path is not None

    # Backend-specific validation.
    if args.backend == "vllm":
        if args.hf_max_batch_size is not None:
            raise ValueError("HF max batch size is only for HF backend.")
    elif args.backend == "hf":
        if args.hf_max_batch_size is None:
            raise ValueError("HF max batch size is required for HF backend.")
        if args.quantization is not None:
            raise ValueError("Quantization is only for vLLM backend.")
        # Truthiness check, matching the other uses of this flag in the
        # script; `is not None` would also reject an explicit False
        # (presumably the flag's default - confirm in the engine arg parser).
        if args.enable_lora:
            raise ValueError("LoRA benchmarking is only supported for vLLM"
                             " backend")
    elif args.backend == "mii":
        if args.dtype != "auto":
            raise ValueError("dtype must be auto for MII backend.")
        if args.n != 1:
            raise ValueError("n must be 1 for MII backend.")
        if args.quantization is not None:
            raise ValueError("Quantization is only for vLLM backend.")
        if args.hf_max_batch_size is not None:
            raise ValueError("HF max batch size is only for HF backend.")
        if args.tokenizer != args.model:
            raise ValueError("Tokenizer must be the same as the model for MII "
                             "backend.")
        # Same truthiness check as for the HF backend above.
        if args.enable_lora:
            raise ValueError("LoRA benchmarking is only supported for vLLM"
                             " backend")
    main(args)