"""Benchmark offline inference throughput."""
import argparse
import dataclasses
import json
import random
import time
from typing import List, Optional

import numpy as np
import torch
import uvloop
from PIL import Image
from tqdm import tqdm
from transformers import (AutoModelForCausalLM, AutoTokenizer,
                          PreTrainedTokenizerBase)

from vllm.inputs import PromptType
from vllm.engine.arg_utils import AsyncEngineArgs, EngineArgs
from vllm.entrypoints.openai.api_server import (
    build_async_engine_client_from_engine_args)
from vllm.inputs import TextPrompt
from vllm.multimodal import MultiModalDataDict
from vllm.sampling_params import BeamSearchParams
from vllm.utils import FlexibleArgumentParser, merge_async_iterators


@dataclasses.dataclass
class SampleRequest:
    """One request fed to a benchmark backend.

    Attributes:
        prompt: Text handed to the model as the request's input.
        prompt_len: Number of tokens the prompt is accounted as.
        expected_output_len: Number of tokens the request should generate.
        multi_modal_data: Extra non-text inputs (for example an image), or
            ``None`` for a text-only request.
    """
    # Field order defines the positional constructor; keep it stable.
    prompt: str
    prompt_len: int
    expected_output_len: int
    multi_modal_data: Optional[MultiModalDataDict] = dataclasses.field(
        default=None)


def _get_prompt_for_image_model(question: str, *, model: str) -> str:
    """Prepend and append special tokens around the question to form a prompt.

    Args:
        question: The input question text to wrap with special tokens
        model: The name of the model being used, to determine which special
            tokens to add

    Returns:
        The formatted prompt string with appropriate special tokens for the
            model

    Raises:
        ValueError: If an unsupported model name is provided
    """
    model = model.lower()
    if "pixtral" in model:
        return f"<s>[INST]{question}\n[IMG][/INST]"
    raise ValueError(f"Unsupported model {model}")


def sample_requests(tokenizer: PreTrainedTokenizerBase,
                    args: argparse.Namespace) -> List[SampleRequest]:
    """Load, shuffle, filter and tokenize requests from a JSON dataset.

    The dataset is a JSON list of records. Each record has a
    ``conversations`` list; the ``value`` of its first two entries is used as
    prompt and completion. A record with an ``image`` path becomes an image
    request whose prompt is wrapped by ``_get_prompt_for_image_model``.

    Args:
        tokenizer: Tokenizer used to measure prompt and completion lengths.
        args: Parsed CLI arguments; reads ``dataset``, ``num_prompts``,
            ``output_len`` and ``model``.

    Returns:
        At most ``args.num_prompts`` requests that pass the length filters.

    Raises:
        ValueError: If ``args.output_len`` is set and below 4, if a record's
            ``image`` is not a single path string, or if an image record is
            used with a model that ``_get_prompt_for_image_model`` rejects.
    """
    dataset_path: str = args.dataset
    num_requests: int = args.num_prompts
    fixed_output_len: Optional[int] = args.output_len
    model: str = args.model

    if fixed_output_len is not None and fixed_output_len < 4:
        raise ValueError("output_len too small")

    # Load the dataset. JSON text is UTF-8, so do not rely on the platform
    # default encoding.
    with open(dataset_path, encoding="utf-8") as f:
        dataset = json.load(f)
    # Filter out the conversations with less than 2 turns.
    dataset = [data for data in dataset if len(data["conversations"]) >= 2]
    # Shuffle with the module-level ``random`` state.
    random.shuffle(dataset)

    # Filter out sequences that are too long or too short.
    filtered_dataset: List[SampleRequest] = []
    for data in dataset:
        if len(filtered_dataset) == num_requests:
            break

        # Only keep the first two turns of each conversation.
        prompt = data["conversations"][0]["value"]
        completion = data["conversations"][1]["value"]

        multi_modal_data: Optional[MultiModalDataDict] = None
        if "image" in data:
            image_path = data["image"]
            # TODO(vllm-project/vllm/issues/9778): Support multiple images.
            if not isinstance(image_path, str):
                raise ValueError("Only support single image input")
            try:
                # Close the source file once the RGB copy has been made.
                with Image.open(image_path) as image:
                    multi_modal_data = {"image": image.convert("RGB")}
            except FileNotFoundError:
                # Ignore datapoint where asset is missing
                continue
            prompt = _get_prompt_for_image_model(question=prompt, model=model)

        prompt_len = len(tokenizer(prompt).input_ids)
        if fixed_output_len is None:
            output_len = len(tokenizer(completion).input_ids)
        else:
            # The completion only matters for its length, which is overridden.
            output_len = fixed_output_len

        if prompt_len < 4 or output_len < 4:
            # Prune too short sequences.
            continue
        if prompt_len > 1024 or prompt_len + output_len > 2048:
            # Prune too long sequences.
            continue

        filtered_dataset.append(
            SampleRequest(prompt=prompt,
                          prompt_len=prompt_len,
                          expected_output_len=output_len,
                          multi_modal_data=multi_modal_data))

    return filtered_dataset


def _build_vllm_inputs(
    requests: List[SampleRequest],
    n: int,
) -> "tuple[List[TextPrompt], list]":
    """Return parallel lists of prompts and sampling params for ``requests``.

    Every request samples ``n`` sequences at temperature 1.0 / top_p 1.0,
    ignores EOS and generates exactly ``expected_output_len`` tokens.
    """
    from vllm import SamplingParams

    prompts: List[TextPrompt] = []
    sampling_params: List[SamplingParams] = []
    for request in requests:
        prompts.append(
            TextPrompt(prompt=request.prompt,
                       multi_modal_data=request.multi_modal_data))
        sampling_params.append(
            SamplingParams(
                n=n,
                temperature=1.0,
                top_p=1.0,
                ignore_eos=True,
                max_tokens=request.expected_output_len,
            ))
    return prompts, sampling_params


def run_vllm(
    warmup_requests: List[SampleRequest],
    requests: List[SampleRequest],
    n: int,
    engine_args: EngineArgs,
) -> float:
    """Time offline generation of ``requests`` with the vLLM ``LLM`` class.

    Args:
        warmup_requests: Requests generated (untimed) before the benchmark,
            repeated ``args.num_iters_warmup`` times. May be empty, in which
            case warmup is skipped.
        requests: Requests to benchmark.
        n: Number of sequences sampled per prompt.
        engine_args: Arguments used to construct the ``LLM`` engine.

    Returns:
        Wall-clock seconds spent in the timed generation call.
    """
    from vllm import LLM
    llm = LLM(**dataclasses.asdict(engine_args))

    # Add the requests to the engine.
    prompts, sampling_params = _build_vllm_inputs(requests, n)

    # Warmup: build inputs from ``warmup_requests`` and run them untimed.
    warmup_prompts, warmup_sampling_params = _build_vllm_inputs(
        warmup_requests, n)
    if warmup_prompts:
        print("Warming up...")
        # NOTE: reads the module-level ``args`` parsed in the ``__main__``
        # block, so this function only works when the file runs as a script.
        for _ in tqdm(range(args.num_iters_warmup), desc="Warmup iterations"):
            llm.generate(warmup_prompts, warmup_sampling_params, use_tqdm=True)

    # Beam search is currently hard-wired off; the else-branch is kept for
    # when it is re-enabled.
    use_beam_search = False

    if not use_beam_search:
        start = time.perf_counter()
        llm.generate(prompts, sampling_params, use_tqdm=True)
        end = time.perf_counter()
    else:
        beam_prompts = [request.prompt for request in requests]
        # output_len should be the same for all requests.
        output_len = requests[0].expected_output_len
        for request in requests:
            if request.expected_output_len != output_len:
                raise ValueError(
                    "Beam search requires the same output length for all "
                    "requests.")
        start = time.perf_counter()
        llm.beam_search(
            beam_prompts,
            BeamSearchParams(
                beam_width=n,
                max_tokens=output_len,
                ignore_eos=True,
            ))
        end = time.perf_counter()
    return end - start


async def run_vllm_async(
    requests: List[SampleRequest],
    n: int,
    engine_args: AsyncEngineArgs,
    disable_frontend_multiprocessing: bool = False,
) -> float:
    """Time generation of ``requests`` through the async vLLM engine client.

    All requests are submitted at once and their output streams are drained
    concurrently; only that submission/drain phase is timed.

    Args:
        requests: Requests to benchmark.
        n: Number of sequences sampled per prompt.
        engine_args: Arguments used to build the async engine client.
        disable_frontend_multiprocessing: Forwarded to
            ``build_async_engine_client_from_engine_args``.

    Returns:
        Wall-clock seconds from first submission until every stream finished.
    """
    from vllm import SamplingParams

    async with build_async_engine_client_from_engine_args(
            engine_args, disable_frontend_multiprocessing) as llm:

        # Add the requests to the engine.
        prompts: List[TextPrompt] = []
        sampling_params: List[SamplingParams] = []
        for request in requests:
            prompts.append(
                TextPrompt(prompt=request.prompt,
                           multi_modal_data=request.multi_modal_data))
            sampling_params.append(
                SamplingParams(
                    n=n,
                    temperature=1.0,
                    top_p=1.0,
                    ignore_eos=True,
                    max_tokens=request.expected_output_len,
                ))

        generators = []
        start = time.perf_counter()
        for i, (prompt, sp) in enumerate(zip(prompts, sampling_params)):
            generators.append(
                llm.generate(prompt, sp, request_id=f"test{i}"))
        # Outputs are discarded; draining the merged stream is what is timed.
        async for _ in merge_async_iterators(*generators):
            pass
        end = time.perf_counter()
        return end - start


def run_hf(
    requests: List[SampleRequest],
    model: str,
    tokenizer: PreTrainedTokenizerBase,
    n: int,
    max_batch_size: int,
    trust_remote_code: bool,
) -> float:
    """Time generation of ``requests`` with a Hugging Face causal LM on CUDA.

    Requests are greedily packed into batches of at most ``max_batch_size``
    while the batch's max prompt length plus max output length stays within
    2048 tokens. Decoding of the outputs is included in the timing.

    Args:
        requests: Requests to benchmark.
        model: Model name or path passed to ``from_pretrained``.
        tokenizer: Tokenizer for ``model``. For llama models its
            ``pad_token`` is set to ``eos_token`` (mutates the caller's object).
        n: Number of sequences returned per prompt.
        max_batch_size: Upper bound on requests per ``generate`` call.
        trust_remote_code: Forwarded to ``from_pretrained``.

    Returns:
        Wall-clock seconds spent generating and decoding all batches.
    """
    llm = AutoModelForCausalLM.from_pretrained(
        model, torch_dtype=torch.float16, trust_remote_code=trust_remote_code)
    if llm.config.model_type == "llama":
        # To enable padding in the HF backend.
        tokenizer.pad_token = tokenizer.eos_token
    llm = llm.cuda()

    pbar = tqdm(total=len(requests))
    try:
        start = time.perf_counter()
        batch: List[str] = []
        max_prompt_len = 0
        max_output_len = 0
        last_index = len(requests) - 1
        for i, request in enumerate(requests):
            # Add the prompt to the batch.
            batch.append(request.prompt)
            max_prompt_len = max(max_prompt_len, request.prompt_len)
            max_output_len = max(max_output_len, request.expected_output_len)
            if len(batch) < max_batch_size and i != last_index:
                # Check if we can add more requests to the batch.
                next_request = requests[i + 1]
                if (max(max_prompt_len, next_request.prompt_len) +
                        max(max_output_len,
                            next_request.expected_output_len)) <= 2048:
                    # We can add more requests to the batch.
                    continue

            # Generate the sequences.
            input_ids = tokenizer(batch, return_tensors="pt",
                                  padding=True).input_ids
            llm_outputs = llm.generate(
                input_ids=input_ids.cuda(),
                do_sample=True,
                num_return_sequences=n,
                temperature=1.0,
                top_p=1.0,
                use_cache=True,
                max_new_tokens=max_output_len,
            )
            # Include the decoding time.
            tokenizer.batch_decode(llm_outputs, skip_special_tokens=True)
            pbar.update(len(batch))

            # Clear the batch.
            batch = []
            max_prompt_len = 0
            max_output_len = 0
        end = time.perf_counter()
    finally:
        pbar.close()
    return end - start


def run_mii(
    requests: List[SampleRequest],
    model: str,
    tensor_parallel_size: int,
    output_len: int,
) -> float:
    """Time generation of ``requests`` with a DeepSpeed-MII server.

    Starts a server for ``model``, times one batched ``generate`` call and
    then terminates the server.

    Args:
        requests: Requests to benchmark (only their prompts are used).
        model: Model name or path served by MII.
        tensor_parallel_size: Forwarded as MII's ``tensor_parallel``.
        output_len: ``max_new_tokens`` for every prompt.

    Returns:
        Wall-clock seconds spent in the ``generate`` call.
    """
    from mii import client, serve
    llm = serve(model, tensor_parallel=tensor_parallel_size)
    prompts = [request.prompt for request in requests]

    try:
        start = time.perf_counter()
        llm.generate(prompts, max_new_tokens=output_len)
        end = time.perf_counter()
    finally:
        # Shut the server down even if generation fails, so it is not left
        # running after the benchmark exits.
        client(model).terminate_server()
    return end - start


def _synthesize_requests(tokenizer: PreTrainedTokenizerBase,
                         args: argparse.Namespace) -> List[SampleRequest]:
    """Build ``args.num_prompts`` random-token requests when no dataset is set.

    Each prompt is decoded from random token ids and re-encoded up to five
    times. Between attempts the ids are padded or trimmed, because the
    tokenizer may add extra tokens (such as BOS), so that the re-encoded
    length approaches ``args.input_len``.
    """
    vocab_size = tokenizer.vocab_size
    requests: List[SampleRequest] = []
    for _ in range(args.num_prompts):
        # Synthesize a prompt with the given input length.
        candidate_ids = [
            random.randint(0, vocab_size - 1) for _ in range(args.input_len)
        ]
        for _ in range(5):  # Max attempts to correct
            candidate_prompt = tokenizer.decode(candidate_ids)
            tokenized_len = len(tokenizer.encode(candidate_prompt))

            if tokenized_len == args.input_len:
                break

            # Adjust length based on difference
            diff = args.input_len - tokenized_len
            if diff > 0:
                candidate_ids.extend([
                    random.randint(100, vocab_size - 100)
                    for _ in range(diff)
                ])
            else:
                candidate_ids = candidate_ids[:diff]
        # NOTE(review): if every attempt misses, the last adjustment is never
        # re-decoded, and prompt_len below records the requested length rather
        # than the measured one.
        requests.append(
            SampleRequest(prompt=candidate_prompt,
                          prompt_len=args.input_len,
                          expected_output_len=args.output_len))
    return requests


def main(args: argparse.Namespace):
    """Build requests, run the selected backend and report throughput.

    Prints requests/s, total tokens/s and output tokens/s, and optionally
    writes a JSON summary to ``args.output_json``.

    Raises:
        ValueError: For an unknown backend, or a tensor-parallel HF run.
    """
    print(args)
    random.seed(args.seed)

    # Sample the requests.
    tokenizer = AutoTokenizer.from_pretrained(
        args.tokenizer, trust_remote_code=args.trust_remote_code)
    # A single short request used only to warm up the offline vLLM engine.
    # Its prompt_len is nominal, and it is not part of the reported metrics.
    warmup_prompt = "hi" * 10
    warmup_requests = [
        SampleRequest(prompt=warmup_prompt,
                      prompt_len=10,
                      expected_output_len=10)
    ]
    if args.dataset is None:
        requests = _synthesize_requests(tokenizer, args)
    else:
        requests = sample_requests(tokenizer, args)

    is_multi_modal = any(request.multi_modal_data is not None
                         for request in requests)
    if args.backend == "vllm":
        if args.async_engine:
            elapsed_time = uvloop.run(
                run_vllm_async(
                    requests,
                    args.n,
                    AsyncEngineArgs.from_cli_args(args),
                    args.disable_frontend_multiprocessing,
                ))
        else:
            elapsed_time = run_vllm(warmup_requests, requests, args.n,
                                    EngineArgs.from_cli_args(args))
    elif args.backend == "hf":
        if args.tensor_parallel_size != 1:
            raise ValueError(
                "HF backend only supports tensor_parallel_size == 1.")
        elapsed_time = run_hf(requests, args.model, tokenizer, args.n,
                              args.hf_max_batch_size, args.trust_remote_code)
    elif args.backend == "mii":
        elapsed_time = run_mii(requests, args.model, args.tensor_parallel_size,
                               args.output_len)
    else:
        raise ValueError(f"Unknown backend: {args.backend}")

    total_num_tokens = sum(request.prompt_len + request.expected_output_len
                           for request in requests)
    total_output_tokens = sum(request.expected_output_len
                              for request in requests)
    if is_multi_modal:
        print("\033[91mWARNING\033[0m: Multi-modal request detected. The "
              "following metrics are not accurate because image tokens are not"
              " counted. See vllm-project/vllm/issues/9778 for details.")
        # TODO(vllm-project/vllm/issues/9778): Count multi-modal token length.
    print(f"Throughput: {len(requests) / elapsed_time:.2f} requests/s, "
          f"{total_num_tokens / elapsed_time:.2f} total tokens/s, "
          f"{total_output_tokens / elapsed_time:.2f} output tokens/s")

    # Output JSON results if specified
    if args.output_json:
        results = {
            "elapsed_time": elapsed_time,
            "num_requests": len(requests),
            "total_num_tokens": total_num_tokens,
            "requests_per_second": len(requests) / elapsed_time,
            "tokens_per_second": total_num_tokens / elapsed_time,
        }
        with open(args.output_json, "w") as f:
            json.dump(results, f, indent=4)


if __name__ == "__main__":
    parser = FlexibleArgumentParser(description="Benchmark the throughput.")
    parser.add_argument("--backend",
                        type=str,
                        choices=["vllm", "hf", "mii"],
                        default="vllm")
    parser.add_argument("--dataset",
                        type=str,
                        default=None,
                        help="Path to the dataset. The dataset is expected to "
                        "be a json in form of List[Dict[..., conversations: "
                        "List[Dict[..., value: <prompt_or_response>]]]]")
    parser.add_argument("--input-len",
                        type=int,
                        default=None,
                        help="Input prompt length for each request")
    parser.add_argument("--output-len",
                        type=int,
                        default=None,
                        help="Output length for each request. Overrides the "
                        "output length from the dataset.")
    parser.add_argument("--n",
                        type=int,
                        default=1,
                        help="Number of generated sequences per prompt.")
    parser.add_argument('--num-iters-warmup',
                        type=int,
                        default=1,
                        help='Number of iterations to run for warmup.')
    parser.add_argument("--num-prompts",
                        type=int,
                        default=1000,
                        help="Number of prompts to process.")
    parser.add_argument("--hf-max-batch-size",
                        type=int,
                        default=None,
                        help="Maximum batch size for HF backend.")
    parser.add_argument(
        '--output-json',
        type=str,
        default=None,
        help='Path to save the throughput results in JSON format.')
    parser.add_argument("--async-engine",
                        action='store_true',
                        default=False,
                        help="Use vLLM async engine rather than LLM class.")
    parser.add_argument("--disable-frontend-multiprocessing",
                        action='store_true',
                        default=False,
                        help="Disable decoupled async engine frontend.")
    # Engine options read below (e.g. --model, --tokenizer, --quantization)
    # are not declared here; they are expected to come from add_cli_args.
    parser = AsyncEngineArgs.add_cli_args(parser)
    args = parser.parse_args()
    if args.tokenizer is None:
        args.tokenizer = args.model
    # Validate with parser.error rather than assert: asserts vanish under -O.
    if args.dataset is None:
        if args.input_len is None or args.output_len is None:
            parser.error("--input-len and --output-len are required when "
                         "--dataset is not given.")
    elif args.input_len is not None:
        parser.error("--input-len cannot be combined with --dataset.")

    if args.backend == "vllm":
        if args.hf_max_batch_size is not None:
            raise ValueError("HF max batch size is only for HF backend.")
    elif args.backend == "hf":
        if args.hf_max_batch_size is None:
            raise ValueError("HF max batch size is required for HF backend.")
        if args.quantization is not None:
            raise ValueError("Quantization is only for vLLM backend.")
    elif args.backend == "mii":
        if args.dtype != "auto":
            raise ValueError("dtype must be auto for MII backend.")
        if args.n != 1:
            raise ValueError("n must be 1 for MII backend.")
        if args.quantization is not None:
            raise ValueError("Quantization is only for vLLM backend.")
        if args.hf_max_batch_size is not None:
            raise ValueError("HF max batch size is only for HF backend.")
        if args.tokenizer != args.model:
            raise ValueError("Tokenizer must be the same as the model for MII "
                             "backend.")
    main(args)