benchmark_throughput.py 20.7 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
"""Benchmark offline inference throughput."""
3
import argparse
4
import dataclasses
5
6
7
import json
import random
import time
8
9
from functools import cache
from typing import Dict, List, Optional, Tuple
10

zhuwenwen's avatar
zhuwenwen committed
11
import numpy as np
12
import torch
13
import uvloop
14
from PIL import Image
15
from tqdm import tqdm
16
17
from transformers import (AutoModelForCausalLM, AutoTokenizer,
                          PreTrainedTokenizerBase)
18

zhuwenwen's avatar
zhuwenwen committed
19
20

from vllm.inputs import PromptType
21
from vllm.engine.arg_utils import AsyncEngineArgs, EngineArgs
22
23
from vllm.entrypoints.openai.api_server import (
    build_async_engine_client_from_engine_args)
24
from vllm.inputs import TextPrompt
25
26
from vllm.lora.request import LoRARequest
from vllm.lora.utils import get_adapter_absolute_path
27
from vllm.multimodal import MultiModalDataDict
28
from vllm.sampling_params import BeamSearchParams
29
from vllm.transformers_utils.tokenizer import AnyTokenizer, get_lora_tokenizer
30
from vllm.utils import FlexibleArgumentParser, merge_async_iterators
31

32

33
34
35
36
37
38
39
40
@dataclasses.dataclass
class SampleRequest:
    """One benchmark request: a prompt plus the generation budget for it.

    Attributes:
        prompt: Text prompt sent to the model.
        prompt_len: Number of tokens in ``prompt``.
        expected_output_len: Number of tokens the model is asked to generate.
        multi_modal_data: Optional multi-modal inputs (e.g. an image) sent
            together with the prompt; ``None`` for text-only requests.
        lora_request: Optional LoRA adapter to apply to this request;
            ``None`` when LoRA is not used.
    """

    prompt: str
    prompt_len: int
    expected_output_len: int
    multi_modal_data: Optional[MultiModalDataDict] = dataclasses.field(
        default=None)
    lora_request: Optional[LoRARequest] = dataclasses.field(default=None)
50
51


52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
def _get_prompt_for_image_model(question: str, *, model: str) -> str:
    """Prepend and append special tokens around the question to form a prompt.

    Args:
        question: The input question text to wrap with special tokens
        model: The name of the model being used, to determine which special
            tokens to add

    Returns:
        The formatted prompt string with appropriate special tokens for the
            model

    Raises:
        ValueError: If an unsupported model name is provided
    """
    model = model.lower()
    if "pixtral" in model:
        return f"<s>[INST]{question}\n[IMG][/INST]"
    raise ValueError(f"Unsupported model {model}")


73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
@cache
def lora_path_on_disk(lora_path: str) -> str:
    """Resolve ``lora_path`` to an absolute adapter path, memoized per input.

    Resolution is delegated to ``get_adapter_absolute_path``. ``@cache``
    makes repeated calls with the same ``lora_path`` reuse the first result.
    """
    resolved_path = get_adapter_absolute_path(lora_path)
    return resolved_path


# Tokenizers keyed by LoRA int id, so each id's tokenizer is fetched only once.
lora_tokenizer_cache: Dict[int, AnyTokenizer] = {}


def get_random_lora_request(
        args: argparse.Namespace
) -> Tuple[LoRARequest, Optional[AnyTokenizer]]:
    """Build a LoRARequest for a uniformly random LoRA id.

    The id is drawn from ``[1, args.max_loras]``. Every id points at the same
    adapter path (``args.lora_path``). The tokenizer from
    ``get_lora_tokenizer`` is memoized in ``lora_tokenizer_cache``, including
    a ``None`` result, so it is fetched at most once per id.

    Returns:
        ``(lora_request, tokenizer)``; ``tokenizer`` may be ``None``.
    """
    chosen_id = random.randint(1, args.max_loras)
    request = LoRARequest(lora_name=str(chosen_id),
                          lora_int_id=chosen_id,
                          lora_path=lora_path_on_disk(args.lora_path))
    try:
        tokenizer = lora_tokenizer_cache[chosen_id]
    except KeyError:
        tokenizer = get_lora_tokenizer(request)
        lora_tokenizer_cache[chosen_id] = tokenizer
    return request, tokenizer


94
95
def sample_requests(tokenizer: PreTrainedTokenizerBase,
                    args: argparse.Namespace) -> List[SampleRequest]:
    """Sample benchmark requests from a conversation-style JSON dataset.

    The file at ``args.dataset`` is read as a JSON list of records, each with
    a ``"conversations"`` list whose items carry a ``"value"`` string. The
    first turn is used as the prompt and the second as the reference
    completion. A record may also carry an ``"image"`` path. In that case the
    image is loaded as RGB and the prompt is wrapped for the image model.

    Records are shuffled and then filtered by length until
    ``args.num_prompts`` requests have been collected. Fewer are returned if
    the dataset runs out first.

    Args:
        tokenizer: Tokenizer used to measure prompt/completion lengths when
            no LoRA tokenizer applies.
        args: Parsed CLI arguments. Reads ``dataset``, ``num_prompts``,
            ``output_len``, ``model``, ``enable_lora`` and (via
            ``get_random_lora_request``) the LoRA options.

    Returns:
        The sampled requests.

    Raises:
        ValueError: If ``args.output_len`` is given and smaller than 4.
    """
    dataset_path: str = args.dataset
    num_requests: int = args.num_prompts
    fixed_output_len: Optional[int] = args.output_len
    model: str = args.model

    if fixed_output_len is not None and fixed_output_len < 4:
        raise ValueError("output_len too small")

    # Load the dataset.
    with open(dataset_path) as f:
        dataset = json.load(f)
    # Filter out the conversations with less than 2 turns.
    dataset = [data for data in dataset if len(data["conversations"]) >= 2]

    # Shuffle the dataset.
    random.shuffle(dataset)

    # Filter out sequences that are too long or too short.
    filtered_dataset: List[SampleRequest] = []
    # The bar's upper bound is the number of candidate records. The loop may
    # stop earlier once enough requests have been collected. (Using the
    # still-empty filtered list here would give total=0.)
    for data in tqdm(dataset, total=len(dataset), desc="sampling requests"):
        if len(filtered_dataset) == num_requests:
            break

        # Only keep the first two turns of each conversation.
        prompt = data["conversations"][0]["value"]
        completion = data["conversations"][1]["value"]

        multi_modal_data: Optional[MultiModalDataDict] = None
        if "image" in data:
            image_path = data["image"]
            # TODO(vllm-project/vllm/issues/9778): Support multiple images.
            assert isinstance(image_path,
                              str), "Only support single image input"
            try:
                # Close the source file once converted. ``convert`` returns
                # a new image, so the result stays usable after the close.
                with Image.open(image_path) as image:
                    multi_modal_data = {"image": image.convert("RGB")}
            except FileNotFoundError:
                # Ignore datapoint where asset is missing
                continue
            prompt = _get_prompt_for_image_model(question=prompt, model=model)

        request_tokenizer = tokenizer
        lora_request: Optional[LoRARequest] = None
        if args.enable_lora:
            lora_request, lora_tokenizer = get_random_lora_request(args)
            if lora_tokenizer:
                request_tokenizer = lora_tokenizer

        # Tokenize the prompt, and the completion only when its length is
        # actually needed (i.e. no fixed output length was requested).
        prompt_len = len(request_tokenizer(prompt).input_ids)
        if fixed_output_len is None:
            output_len = len(request_tokenizer(completion).input_ids)
        else:
            output_len = fixed_output_len

        if prompt_len < 4 or output_len < 4:
            # Prune too short sequences.
            continue
        if prompt_len > 1024 or prompt_len + output_len > 2048:
            # Prune too long sequences.
            continue
        filtered_dataset.append(
            SampleRequest(prompt=prompt,
                          prompt_len=prompt_len,
                          expected_output_len=output_len,
                          multi_modal_data=multi_modal_data,
                          lora_request=lora_request))

    return filtered_dataset
166
167


Woosuk Kwon's avatar
Woosuk Kwon committed
168
def run_vllm(
    requests: List[SampleRequest],
    n: int,
    num_iters_warmup: int,
    engine_args: EngineArgs,
) -> float:
    """Time offline generation of ``requests`` with the synchronous ``LLM``.

    The engine is built from ``engine_args`` and warmed up with
    ``num_iters_warmup`` small dummy calls before the timed run. Engine
    construction and warmup are excluded from the returned time.

    Args:
        requests: Requests to generate for.
        n: Sequences sampled per prompt (used as the beam width on the
            beam-search path).
        num_iters_warmup: Number of untimed warmup ``generate`` calls.
        engine_args: Arguments used to construct the ``LLM``.

    Returns:
        Wall-clock seconds spent in the timed generation call.
    """
    from vllm import LLM, SamplingParams
    llm = LLM(**dataclasses.asdict(engine_args))

    # Add the requests to the engine.
    prompts: List[TextPrompt] = []
    sampling_params: List[SamplingParams] = []
    for request in requests:
        prompts.append(
            TextPrompt(prompt=request.prompt,
                       multi_modal_data=request.multi_modal_data))
        sampling_params.append(
            SamplingParams(
                n=n,
                temperature=1.0,
                top_p=1.0,
                ignore_eos=True,
                max_tokens=request.expected_output_len,
            ))
    lora_requests: Optional[List[LoRARequest]] = None
    if engine_args.enable_lora:
        lora_requests = [request.lora_request for request in requests]

    # Warmup (untimed). Use the ``n`` parameter rather than a module-level
    # ``args``, which only exists when this file is run as a script.
    warmup_sampling_params = SamplingParams(
        n=n,
        temperature=1.0,
        top_p=1.0,
        ignore_eos=True,
        max_tokens=10,
    )
    # One dummy prompt made of 10 random token ids in [0, 10000).
    dummy_prompt_token_ids = np.random.randint(10000, size=(1, 10))
    dummy_prompts: List[PromptType] = [{
        "prompt_token_ids": batch
    } for batch in dummy_prompt_token_ids.tolist()]

    print("Warming up...")
    for _ in tqdm(range(num_iters_warmup), desc="Warmup iterations"):
        llm.generate(dummy_prompts,
                     sampling_params=warmup_sampling_params,
                     use_tqdm=False)

    use_beam_search = False

    if not use_beam_search:
        start = time.perf_counter()
        llm.generate(prompts,
                     sampling_params,
                     lora_request=lora_requests,
                     use_tqdm=True)
        end = time.perf_counter()
    else:
        assert lora_requests is None, "BeamSearch API does not support LoRA"
        beam_prompts = [request.prompt for request in requests]
        # output_len should be the same for all requests.
        # SampleRequest is a dataclass, so read the field by name; indexing
        # it like a tuple would raise TypeError.
        output_len = requests[0].expected_output_len
        for request in requests:
            assert request.expected_output_len == output_len
        start = time.perf_counter()
        llm.beam_search(
            beam_prompts,
            BeamSearchParams(
                beam_width=n,
                max_tokens=output_len,
                ignore_eos=True,
            ))
        end = time.perf_counter()
    return end - start


243
async def run_vllm_async(
    requests: List[SampleRequest],
    n: int,
    engine_args: AsyncEngineArgs,
    disable_frontend_multiprocessing: bool = False,
) -> float:
    """Time generation of ``requests`` through an async engine client.

    Each request is submitted with its own id (``test0``, ``test1``, ...).
    The resulting output streams are drained together via
    ``merge_async_iterators``. Timing starts just before submission and ends
    once every stream is exhausted.

    Returns:
        Wall-clock seconds from submission until all outputs were consumed.
    """
    from vllm import SamplingParams

    async with build_async_engine_client_from_engine_args(
            engine_args, disable_frontend_multiprocessing) as llm:

        # Per-request engine inputs, kept in request order.
        prompts: List[TextPrompt] = []
        sampling_params: List[SamplingParams] = []
        lora_requests: List[Optional[LoRARequest]] = []
        for req in requests:
            prompts.append(
                TextPrompt(prompt=req.prompt,
                           multi_modal_data=req.multi_modal_data))
            sampling_params.append(
                SamplingParams(
                    n=n,
                    temperature=1.0,
                    top_p=1.0,
                    ignore_eos=True,
                    max_tokens=req.expected_output_len,
                ))
            lora_requests.append(req.lora_request)

        start = time.perf_counter()
        generators = [
            llm.generate(prompt,
                         params,
                         lora_request=lora,
                         request_id=f"test{idx}")
            for idx, (prompt, params, lora) in enumerate(
                zip(prompts, sampling_params, lora_requests))
        ]
        # Outputs are discarded; only the time to completion matters.
        async for _ in merge_async_iterators(*generators):
            pass
        end = time.perf_counter()
        return end - start


288
def run_hf(
    requests: List[SampleRequest],
    model: str,
    tokenizer: PreTrainedTokenizerBase,
    n: int,
    max_batch_size: int,
    trust_remote_code: bool,
) -> float:
    """Time generation of ``requests`` with a Hugging Face causal LM.

    Requests are batched greedily in order. A batch keeps growing while it
    has fewer than ``max_batch_size`` prompts and
    ``max(prompt_len) + max(output_len)`` over the batch (including the next
    request) stays within 2048 tokens. Each batch is sampled with
    ``num_return_sequences=n`` and then decoded, so decoding time is part of
    the measurement.

    Returns:
        Wall-clock seconds spent generating and decoding all batches. Model
        loading is excluded.
    """
    llm = AutoModelForCausalLM.from_pretrained(
        model, torch_dtype=torch.float16, trust_remote_code=trust_remote_code)
    if llm.config.model_type == "llama":
        # To enable padding in the HF backend.
        tokenizer.pad_token = tokenizer.eos_token
    llm = llm.cuda()

    pbar = tqdm(total=len(requests))
    start = time.perf_counter()
    batch: List[str] = []
    max_prompt_len = 0
    max_output_len = 0
    for i, request in enumerate(requests):
        # SampleRequest is a dataclass: read fields by name (it cannot be
        # tuple-unpacked).
        batch.append(request.prompt)
        max_prompt_len = max(max_prompt_len, request.prompt_len)
        max_output_len = max(max_output_len, request.expected_output_len)
        if len(batch) < max_batch_size and i != len(requests) - 1:
            # Check if we can add more requests to the batch.
            next_request = requests[i + 1]
            if (max(max_prompt_len, next_request.prompt_len) +
                    max(max_output_len,
                        next_request.expected_output_len)) <= 2048:
                # We can add more requests to the batch.
                continue

        # Generate the sequences.
        input_ids = tokenizer(batch, return_tensors="pt",
                              padding=True).input_ids
        llm_outputs = llm.generate(
            input_ids=input_ids.cuda(),
            do_sample=True,
            num_return_sequences=n,
            temperature=1.0,
            top_p=1.0,
            use_cache=True,
            max_new_tokens=max_output_len,
        )
        # Include the decoding time.
        tokenizer.batch_decode(llm_outputs, skip_special_tokens=True)
        pbar.update(len(batch))

        # Clear the batch.
        batch = []
        max_prompt_len = 0
        max_output_len = 0
    end = time.perf_counter()
    # Close the progress bar outside the timed region.
    pbar.close()
    return end - start


346
def run_mii(
    requests: List[SampleRequest],
    model: str,
    tensor_parallel_size: int,
    output_len: int,
) -> float:
    """Time generation of ``requests`` with the ``mii`` serving backend.

    A server for ``model`` is started with ``serve``. All prompts are
    generated in one ``generate`` call with ``output_len`` new tokens. The
    server-termination step then runs through a ``client`` handle, also when
    generation raises.

    Returns:
        Wall-clock seconds spent in the ``generate`` call.
    """
    from mii import client, serve
    llm = serve(model, tensor_parallel=tensor_parallel_size)
    try:
        prompts = [request.prompt for request in requests]
        start = time.perf_counter()
        llm.generate(prompts, max_new_tokens=output_len)
        end = time.perf_counter()
    finally:
        # Terminate the server even if generation failed. A distinct name
        # avoids shadowing the imported ``client`` factory.
        mii_client = client(model)
        mii_client.terminate_server()
    return end - start


364
365
366
367
368
def _synthesize_random_requests(
        tokenizer: PreTrainedTokenizerBase,
        args: argparse.Namespace) -> List[SampleRequest]:
    """Build ``args.num_prompts`` requests whose prompts are random token ids.

    Each prompt is decoded from ``args.input_len`` random token ids and then
    re-encoded. Decoding and re-encoding may not round-trip (e.g. the
    tokenizer may add a BOS token), so the id list is extended or trimmed up
    to 5 times to approach ``args.input_len`` tokens. The recorded
    ``prompt_len`` is always ``args.input_len``.
    """
    vocab_size = tokenizer.vocab_size
    requests: List[SampleRequest] = []
    for _ in range(args.num_prompts):
        request_tokenizer = tokenizer
        lora_request: Optional[LoRARequest] = None
        if args.enable_lora:
            lora_request, lora_tokenizer = get_random_lora_request(args)
            if lora_tokenizer:
                request_tokenizer = lora_tokenizer

        # Synthesize a prompt with the given input length.
        candidate_ids = [
            random.randint(0, vocab_size - 1) for _ in range(args.input_len)
        ]
        # As tokenizer may add additional tokens like BOS, we need to try
        # different lengths to get the desired input length.
        for _ in range(5):  # Max attempts to correct
            candidate_prompt = request_tokenizer.decode(candidate_ids)
            tokenized_len = len(request_tokenizer.encode(candidate_prompt))
            if tokenized_len == args.input_len:
                break

            # Adjust length based on difference
            diff = args.input_len - tokenized_len
            if diff > 0:
                candidate_ids.extend([
                    random.randint(100, vocab_size - 100)
                    for _ in range(diff)
                ])
            else:
                candidate_ids = candidate_ids[:diff]
        requests.append(
            SampleRequest(prompt=candidate_prompt,
                          prompt_len=args.input_len,
                          expected_output_len=args.output_len,
                          lora_request=lora_request))
    return requests


def main(args: argparse.Namespace):
    """Run the throughput benchmark described by the parsed CLI ``args``.

    Requests are synthesized from random tokens when ``args.dataset`` is
    None, and sampled from the dataset otherwise. The selected backend is run
    and request/token throughput is printed. Results are optionally written
    as JSON to ``args.output_json``.

    Raises:
        ValueError: If ``args.backend`` is not one of ``vllm``/``hf``/``mii``.
    """
    print(args)
    random.seed(args.seed)

    # Sample the requests.
    tokenizer = AutoTokenizer.from_pretrained(
        args.tokenizer, trust_remote_code=args.trust_remote_code)
    if args.dataset is None:
        requests = _synthesize_random_requests(tokenizer, args)
    else:
        requests = sample_requests(tokenizer, args)

    is_multi_modal = any(request.multi_modal_data is not None
                         for request in requests)
    if args.backend == "vllm":
        if args.async_engine:
            elapsed_time = uvloop.run(
                run_vllm_async(
                    requests,
                    args.n,
                    AsyncEngineArgs.from_cli_args(args),
                    args.disable_frontend_multiprocessing,
                ))
        else:
            elapsed_time = run_vllm(requests, args.n, args.num_iters_warmup,
                                    EngineArgs.from_cli_args(args))
    elif args.backend == "hf":
        assert args.tensor_parallel_size == 1
        elapsed_time = run_hf(requests, args.model, tokenizer, args.n,
                              args.hf_max_batch_size, args.trust_remote_code)
    elif args.backend == "mii":
        elapsed_time = run_mii(requests, args.model, args.tensor_parallel_size,
                               args.output_len)
    else:
        raise ValueError(f"Unknown backend: {args.backend}")

    total_num_tokens = sum(request.prompt_len + request.expected_output_len
                           for request in requests)
    total_output_tokens = sum(request.expected_output_len
                              for request in requests)
    if is_multi_modal:
        print("\033[91mWARNING\033[0m: Multi-modal request detected. The "
              "following metrics are not accurate because image tokens are not"
              " counted. See vllm-project/vllm/issues/9778 for details.")
        # TODO(vllm-project/vllm/issues/9778): Count multi-modal token length.
    print(f"Throughput: {len(requests) / elapsed_time:.2f} requests/s, "
          f"{total_num_tokens / elapsed_time:.2f} total tokens/s, "
          f"{total_output_tokens / elapsed_time:.2f} output tokens/s")

    # Output JSON results if specified
    if args.output_json:
        results = {
            "elapsed_time": elapsed_time,
            "num_requests": len(requests),
            "total_num_tokens": total_num_tokens,
            "requests_per_second": len(requests) / elapsed_time,
            "tokens_per_second": total_num_tokens / elapsed_time,
            # Extra keys (existing ones unchanged) so the JSON carries the
            # same output-token figure that is printed above.
            "total_output_tokens": total_output_tokens,
            "output_tokens_per_second": total_output_tokens / elapsed_time,
        }
        with open(args.output_json, "w") as f:
            json.dump(results, f, indent=4)

462
463

if __name__ == "__main__":
    parser = FlexibleArgumentParser(description="Benchmark the throughput.")
    parser.add_argument("--backend",
                        type=str,
                        choices=["vllm", "hf", "mii"],
                        default="vllm")
    parser.add_argument("--dataset",
                        type=str,
                        default=None,
                        help="Path to the dataset. The dataset is expected to "
                        "be a json in form of List[Dict[..., conversations: "
                        "List[Dict[..., value: <prompt_or_response>]]]]")
    parser.add_argument("--input-len",
                        type=int,
                        default=None,
                        help="Input prompt length for each request")
    parser.add_argument("--output-len",
                        type=int,
                        default=None,
                        help="Output length for each request. Overrides the "
                        "output length from the dataset.")
    parser.add_argument("--n",
                        type=int,
                        default=1,
                        help="Number of generated sequences per prompt.")
    parser.add_argument('--num-iters-warmup',
                        type=int,
                        default=1,
                        help='Number of iterations to run for warmup.')
    parser.add_argument("--num-prompts",
                        type=int,
                        default=1000,
                        help="Number of prompts to process.")
    parser.add_argument("--hf-max-batch-size",
                        type=int,
                        default=None,
                        help="Maximum batch size for HF backend.")
    parser.add_argument(
        '--output-json',
        type=str,
        default=None,
        help='Path to save the throughput results in JSON format.')
    parser.add_argument("--async-engine",
                        action='store_true',
                        default=False,
                        help="Use vLLM async engine rather than LLM class.")
    parser.add_argument("--disable-frontend-multiprocessing",
                        action='store_true',
                        default=False,
                        help="Disable decoupled async engine frontend.")
    # LoRA
    parser.add_argument(
        "--lora-path",
        type=str,
        default=None,
        help="Path to the lora adapters to use. This can be an absolute path, "
        "a relative path, or a Hugging Face model identifier.")

    # Engine options used below (model, tokenizer, quantization, dtype,
    # enable_lora, ...) are registered by vLLM's engine-argument parser.
    parser = AsyncEngineArgs.add_cli_args(parser)
    args = parser.parse_args()
    if args.tokenizer is None:
        args.tokenizer = args.model

    # Explicit checks instead of `assert`, which is stripped under `python -O`.
    if args.dataset is None:
        if args.input_len is None or args.output_len is None:
            raise ValueError("--input-len and --output-len are required "
                             "when --dataset is not given.")
    elif args.input_len is not None:
        raise ValueError("--input-len cannot be combined with --dataset.")
    if args.enable_lora and args.lora_path is None:
        raise ValueError("--lora-path is required when LoRA is enabled.")

    # `enable_lora` is used as a boolean throughout this script (vLLM
    # registers it as a store_true flag, i.e. False when absent), so test its
    # truth value. An `is not None` test would reject every hf/mii run.
    if args.backend == "vllm":
        if args.hf_max_batch_size is not None:
            raise ValueError("HF max batch size is only for HF backend.")
    elif args.backend == "hf":
        if args.hf_max_batch_size is None:
            raise ValueError("HF max batch size is required for HF backend.")
        if args.quantization is not None:
            raise ValueError("Quantization is only for vLLM backend.")
        if args.enable_lora:
            raise ValueError("LoRA benchmarking is only supported for vLLM"
                             " backend")
    elif args.backend == "mii":
        if args.dtype != "auto":
            raise ValueError("dtype must be auto for MII backend.")
        if args.n != 1:
            raise ValueError("n must be 1 for MII backend.")
        if args.quantization is not None:
            raise ValueError("Quantization is only for vLLM backend.")
        if args.hf_max_batch_size is not None:
            raise ValueError("HF max batch size is only for HF backend.")
        if args.tokenizer != args.model:
            raise ValueError("Tokenizer must be the same as the model for MII "
                             "backend.")
        if args.enable_lora:
            raise ValueError("LoRA benchmarking is only supported for vLLM"
                             " backend")
    main(args)