# benchmark_throughput.py
"""Benchmark offline inference throughput."""
2
import argparse
3
import dataclasses
4
5
6
import json
import random
import time
7
from typing import List, Optional
8

9
import torch
10
import uvloop
11
from PIL import Image
12
from tqdm import tqdm
13
14
from transformers import (AutoModelForCausalLM, AutoTokenizer,
                          PreTrainedTokenizerBase)
15

16
from vllm.engine.arg_utils import AsyncEngineArgs, EngineArgs
17
18
from vllm.entrypoints.openai.api_server import (
    build_async_engine_client_from_engine_args)
19
20
from vllm.inputs import TextPrompt
from vllm.multimodal import MultiModalDataDict
21
from vllm.sampling_params import BeamSearchParams
22
from vllm.utils import FlexibleArgumentParser, merge_async_iterators
23

@dataclasses.dataclass
class SampleRequest:
    """One benchmark request together with its token-length bookkeeping.

    Attributes:
        prompt: Text prompt handed to the backend.
        prompt_len: Length of ``prompt`` in tokens.
        expected_output_len: Number of tokens the request should generate.
        multi_modal_data: Optional multi-modal payload (e.g. an image);
            ``None`` for text-only requests.
    """

    # Required fields first so the defaulted multi-modal field can follow.
    prompt: str
    prompt_len: int
    expected_output_len: int
    multi_modal_data: Optional[MultiModalDataDict] = dataclasses.field(
        default=None)


42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
def _get_prompt_for_image_model(question: str, *, model: str) -> str:
    """Prepend and append special tokens around the question to form a prompt.

    Args:
        question: The input question text to wrap with special tokens
        model: The name of the model being used, to determine which special
            tokens to add

    Returns:
        The formatted prompt string with appropriate special tokens for the
            model

    Raises:
        ValueError: If an unsupported model name is provided
    """
    model = model.lower()
    if "pixtral" in model:
        return f"<s>[INST]{question}\n[IMG][/INST]"
    raise ValueError(f"Unsupported model {model}")


def sample_requests(tokenizer: PreTrainedTokenizerBase,
                    args: argparse.Namespace) -> List[SampleRequest]:
    """Load, shuffle and length-filter requests from a conversation dataset.

    The dataset at ``args.dataset`` is a JSON list of entries. Each entry has
    a ``conversations`` list whose first two turns' ``value`` fields are used
    as prompt and completion. An entry may also have an ``image`` path.
    Entries are shuffled with the module-level ``random`` state and scanned
    until ``args.num_prompts`` requests have been kept.

    Args:
        tokenizer: Tokenizer used to count prompt tokens. It also counts
            completion tokens when no fixed output length is given.
        args: Parsed CLI arguments. Reads ``dataset``, ``num_prompts``,
            ``output_len`` and ``model``.

    Returns:
        At most ``args.num_prompts`` requests that satisfy all of:
        prompt length >= 4, output length >= 4, prompt length <= 1024, and
        prompt + output length <= 2048.

    Raises:
        ValueError: If ``args.output_len`` is set and smaller than 4, or if
            an image entry is used with an unsupported model.
    """
    dataset_path: str = args.dataset
    num_requests: int = args.num_prompts
    fixed_output_len: Optional[int] = args.output_len
    model: str = args.model

    if fixed_output_len is not None and fixed_output_len < 4:
        raise ValueError("output_len too small")

    # Load the dataset. JSON is UTF-8 by specification, so do not depend on
    # the platform's locale encoding.
    with open(dataset_path, encoding="utf-8") as f:
        dataset = json.load(f)
    # Filter out the conversations with less than 2 turns.
    dataset = [data for data in dataset if len(data["conversations"]) >= 2]

    # Shuffle the dataset.
    random.shuffle(dataset)

    # Filter out sequences that are too long or too short
    filtered_dataset: List[SampleRequest] = []
    for data in dataset:
        if len(filtered_dataset) == num_requests:
            break

        # Only keep the first two turns of each conversation.
        prompt = data["conversations"][0]["value"]
        completion = data["conversations"][1]["value"]

        multi_modal_data: Optional[MultiModalDataDict] = None
        if "image" in data:
            image_path = data["image"]
            # TODO(vllm-project/vllm/issues/9778): Support multiple images.
            assert isinstance(image_path,
                              str), "Only support single image input"
            try:
                # The context manager releases the file handle once the
                # converted copy has been produced; convert() returns a new
                # image that does not depend on the open file.
                with Image.open(image_path) as image:
                    multi_modal_data = {"image": image.convert("RGB")}
            except FileNotFoundError:
                # Ignore datapoint where asset is missing
                continue
            prompt = _get_prompt_for_image_model(question=prompt, model=model)

        # Tokenize the prompt. The completion only needs tokenizing when no
        # fixed output length overrides it.
        prompt_len = len(tokenizer(prompt).input_ids)
        if fixed_output_len is None:
            output_len = len(tokenizer(completion).input_ids)
        else:
            output_len = fixed_output_len

        if prompt_len < 4 or output_len < 4:
            # Prune too short sequences.
            continue
        if prompt_len > 1024 or prompt_len + output_len > 2048:
            # Prune too long sequences.
            continue

        filtered_dataset.append(
            SampleRequest(prompt=prompt,
                          prompt_len=prompt_len,
                          expected_output_len=output_len,
                          multi_modal_data=multi_modal_data))

    return filtered_dataset


def run_vllm(
    requests: List[SampleRequest],
    n: int,
    engine_args: EngineArgs,
) -> float:
    """Time offline generation of ``requests`` with the synchronous ``LLM``.

    Args:
        requests: Requests to run. Each contributes its prompt, optional
            multi-modal data, and ``expected_output_len`` as ``max_tokens``.
        n: Sequences sampled per prompt. This is the beam width when the
            beam-search path is enabled.
        engine_args: Engine configuration, expanded into ``LLM(...)``.

    Returns:
        Wall-clock seconds spent generating. Engine construction and request
        preparation happen before the timer starts.
    """
    from vllm import LLM, SamplingParams
    llm = LLM(**dataclasses.asdict(engine_args))

    # Add the requests to the engine.
    prompts: List[TextPrompt] = []
    sampling_params: List[SamplingParams] = []
    for request in requests:
        prompts.append(
            TextPrompt(prompt=request.prompt,
                       multi_modal_data=request.multi_modal_data))
        sampling_params.append(
            SamplingParams(
                n=n,
                temperature=1.0,
                top_p=1.0,
                ignore_eos=True,
                max_tokens=request.expected_output_len,
            ))

    # Beam search is hard-wired off. The branch below is kept so it can be
    # enabled locally.
    use_beam_search = False

    if not use_beam_search:
        start = time.perf_counter()
        llm.generate(prompts, sampling_params, use_tqdm=True)
        end = time.perf_counter()
    else:
        beam_prompts = [request.prompt for request in requests]
        # output_len should be the same for all requests.
        # SampleRequest is a dataclass, so read the field by name; indexing
        # it like a tuple raises TypeError.
        output_len = requests[0].expected_output_len
        for request in requests:
            assert request.expected_output_len == output_len
        start = time.perf_counter()
        llm.beam_search(
            beam_prompts,
            BeamSearchParams(
                beam_width=n,
                max_tokens=output_len,
                ignore_eos=True,
            ))
        end = time.perf_counter()
    return end - start


async def run_vllm_async(
    requests: List[SampleRequest],
    n: int,
    engine_args: AsyncEngineArgs,
    disable_frontend_multiprocessing: bool = False,
) -> float:
    """Time generation of ``requests`` through the async engine client.

    All requests are submitted at once, with ids ``test0``, ``test1``, and so
    on. Their output streams are merged and drained to completion.

    Args:
        requests: Requests to run.
        n: Sequences sampled per prompt.
        engine_args: Configuration used to build the async engine client.
        disable_frontend_multiprocessing: Forwarded to
            ``build_async_engine_client_from_engine_args``.

    Returns:
        Wall-clock seconds from submitting the first request until every
        stream is exhausted.
    """
    from vllm import SamplingParams

    async with build_async_engine_client_from_engine_args(
            engine_args, disable_frontend_multiprocessing) as llm:

        # Prepare every input up front so construction cost is not timed.
        prompts: List[TextPrompt] = [
            TextPrompt(prompt=req.prompt,
                       multi_modal_data=req.multi_modal_data)
            for req in requests
        ]
        sampling_params: List[SamplingParams] = [
            SamplingParams(
                n=n,
                temperature=1.0,
                top_p=1.0,
                ignore_eos=True,
                max_tokens=req.expected_output_len,
            ) for req in requests
        ]

        begin = time.perf_counter()
        streams = [
            llm.generate(prompt, params, request_id=f"test{idx}")
            for idx, (prompt, params) in enumerate(zip(prompts,
                                                       sampling_params))
        ]
        # Only completion matters; the outputs themselves are discarded.
        async for _stream_idx, _output in merge_async_iterators(*streams):
            pass
        finish = time.perf_counter()
        return finish - begin


def run_hf(
    requests: List[SampleRequest],
    model: str,
    tokenizer: PreTrainedTokenizerBase,
    n: int,
    max_batch_size: int,
    trust_remote_code: bool,
) -> float:
    """Time generation of ``requests`` with a HuggingFace causal LM.

    Requests are packed, in order, into batches of at most ``max_batch_size``
    prompts. A batch is also closed early when adding the next request would
    push ``max prompt len + max output len`` over 2048 tokens. Each batch is
    sampled with ``n`` return sequences and then decoded.

    Args:
        requests: Requests to run.
        model: Model name or path passed to ``from_pretrained``.
        tokenizer: Tokenizer used to pad/encode batches and decode outputs.
            For llama models its ``pad_token`` is set to ``eos_token``.
        n: Return sequences per prompt.
        max_batch_size: Upper bound on prompts per batch.
        trust_remote_code: Forwarded to ``from_pretrained``.

    Returns:
        Wall-clock seconds for generating and decoding all batches. Model
        loading is excluded.
    """
    llm = AutoModelForCausalLM.from_pretrained(
        model, torch_dtype=torch.float16, trust_remote_code=trust_remote_code)
    if llm.config.model_type == "llama":
        # To enable padding in the HF backend.
        tokenizer.pad_token = tokenizer.eos_token
    llm = llm.cuda()

    # The context manager closes the progress bar even if generation fails.
    with tqdm(total=len(requests)) as pbar:
        start = time.perf_counter()
        batch: List[str] = []
        max_prompt_len = 0
        max_output_len = 0
        last_index = len(requests) - 1
        for i, request in enumerate(requests):
            # SampleRequest is a dataclass: read fields by name. Tuple
            # unpacking it raises TypeError.
            batch.append(request.prompt)
            max_prompt_len = max(max_prompt_len, request.prompt_len)
            max_output_len = max(max_output_len, request.expected_output_len)
            if len(batch) < max_batch_size and i != last_index:
                # Check if we can add more requests to the batch.
                next_request = requests[i + 1]
                if (max(max_prompt_len, next_request.prompt_len) +
                        max(max_output_len,
                            next_request.expected_output_len)) <= 2048:
                    # We can add more requests to the batch.
                    continue

            # Generate the sequences.
            input_ids = tokenizer(batch, return_tensors="pt",
                                  padding=True).input_ids
            llm_outputs = llm.generate(
                input_ids=input_ids.cuda(),
                do_sample=True,
                num_return_sequences=n,
                temperature=1.0,
                top_p=1.0,
                use_cache=True,
                max_new_tokens=max_output_len,
            )
            # Include the decoding time.
            tokenizer.batch_decode(llm_outputs, skip_special_tokens=True)
            pbar.update(len(batch))

            # Clear the batch.
            batch = []
            max_prompt_len = 0
            max_output_len = 0
        end = time.perf_counter()
    return end - start


def run_mii(
    requests: List[SampleRequest],
    model: str,
    tensor_parallel_size: int,
    output_len: int,
) -> float:
    """Time generation of ``requests`` with DeepSpeed-MII.

    A server is started with ``mii.serve``. All prompts are sent in one
    ``generate`` call, and the server is then terminated through
    ``mii.client``.

    Args:
        requests: Requests whose prompts are generated.
        model: Model name passed to ``serve`` and ``client``.
        tensor_parallel_size: Forwarded as ``tensor_parallel``.
        output_len: ``max_new_tokens`` applied to every prompt.

    Returns:
        Wall-clock seconds spent in the ``generate`` call.
    """
    from mii import client, serve
    llm = serve(model, tensor_parallel=tensor_parallel_size)
    try:
        prompts = [request.prompt for request in requests]

        start = time.perf_counter()
        llm.generate(prompts, max_new_tokens=output_len)
        end = time.perf_counter()
    finally:
        # Terminate the server started above even when generation raises.
        # Use a distinct name instead of shadowing the imported factory.
        mii_client = client(model)
        mii_client.terminate_server()
    return end - start


def _synthesize_requests(tokenizer: PreTrainedTokenizerBase, input_len: int,
                         output_len: int,
                         num_prompts: int) -> List[SampleRequest]:
    """Build ``num_prompts`` identical "hi hi ..." requests of ``input_len`` tokens."""
    # As tokenizer may add additional tokens like BOS, we need to try
    # different lengths to get the desired input length.
    for i in range(-10, 10):
        prompt = "hi " * (input_len + i)
        tokenized_prompt = tokenizer(prompt).input_ids
        if len(tokenized_prompt) == input_len:
            break
    else:
        raise ValueError(
            f"Failed to synthesize a prompt with {input_len} tokens.")
    return [
        SampleRequest(prompt=prompt,
                      prompt_len=input_len,
                      expected_output_len=output_len)
        for _ in range(num_prompts)
    ]


def main(args: argparse.Namespace):
    """Run the throughput benchmark described by ``args`` and report it.

    Requests are either synthesized (when ``args.dataset`` is None) or
    sampled from the dataset. They are executed on the selected backend.
    Request and token throughput is printed, and optionally written as JSON
    to ``args.output_json``.

    Raises:
        ValueError: For an unknown backend, a non-1 tensor-parallel size with
            the HF backend, or a prompt that cannot be synthesized.
    """
    print(args)
    random.seed(args.seed)

    # Sample the requests.
    tokenizer = AutoTokenizer.from_pretrained(
        args.tokenizer, trust_remote_code=args.trust_remote_code)
    if args.dataset is None:
        # Synthesize a prompt with the given input length.
        requests = _synthesize_requests(tokenizer, args.input_len,
                                        args.output_len, args.num_prompts)
    else:
        requests = sample_requests(tokenizer, args)

    is_multi_modal = any(request.multi_modal_data is not None
                         for request in requests)
    if args.backend == "vllm":
        if args.async_engine:
            elapsed_time = uvloop.run(
                run_vllm_async(
                    requests,
                    args.n,
                    AsyncEngineArgs.from_cli_args(args),
                    args.disable_frontend_multiprocessing,
                ))
        else:
            elapsed_time = run_vllm(requests, args.n,
                                    EngineArgs.from_cli_args(args))
    elif args.backend == "hf":
        # Raise explicitly rather than assert so the check survives `-O`.
        if args.tensor_parallel_size != 1:
            raise ValueError(
                "HF backend only supports tensor_parallel_size == 1.")
        elapsed_time = run_hf(requests, args.model, tokenizer, args.n,
                              args.hf_max_batch_size, args.trust_remote_code)
    elif args.backend == "mii":
        elapsed_time = run_mii(requests, args.model, args.tensor_parallel_size,
                               args.output_len)
    else:
        raise ValueError(f"Unknown backend: {args.backend}")
    total_num_tokens = sum(request.prompt_len + request.expected_output_len
                           for request in requests)
    total_output_tokens = sum(request.expected_output_len
                              for request in requests)
    if is_multi_modal:
        print("\033[91mWARNING\033[0m: Multi-modal request detected. The "
              "following metrics are not accurate because image tokens are not"
              " counted. See vllm-project/vllm/issues/9778 for details.")
        # TODO(vllm-project/vllm/issues/9778): Count multi-modal token length.
    requests_per_second = len(requests) / elapsed_time
    tokens_per_second = total_num_tokens / elapsed_time
    print(f"Throughput: {requests_per_second:.2f} requests/s, "
          f"{tokens_per_second:.2f} total tokens/s, "
          f"{total_output_tokens / elapsed_time:.2f} output tokens/s")

    # Output JSON results if specified
    if args.output_json:
        results = {
            "elapsed_time": elapsed_time,
            "num_requests": len(requests),
            "total_num_tokens": total_num_tokens,
            "requests_per_second": requests_per_second,
            "tokens_per_second": tokens_per_second,
        }
        with open(args.output_json, "w") as f:
            json.dump(results, f, indent=4)

if __name__ == "__main__":
    parser = FlexibleArgumentParser(description="Benchmark the throughput.")
    parser.add_argument("--backend",
                        type=str,
                        choices=["vllm", "hf", "mii"],
                        default="vllm")
    parser.add_argument("--dataset",
                        type=str,
                        default=None,
                        help="Path to the dataset. The dataset is expected to "
                        "be a json in form of List[Dict[..., conversations: "
                        "List[Dict[..., value: <prompt_or_response>]]]]")
    parser.add_argument("--input-len",
                        type=int,
                        default=None,
                        help="Input prompt length for each request")
    parser.add_argument("--output-len",
                        type=int,
                        default=None,
                        help="Output length for each request. Overrides the "
                        "output length from the dataset.")
    parser.add_argument("--n",
                        type=int,
                        default=1,
                        help="Number of generated sequences per prompt.")
    parser.add_argument("--num-prompts",
                        type=int,
                        default=1000,
                        help="Number of prompts to process.")
    parser.add_argument("--hf-max-batch-size",
                        type=int,
                        default=None,
                        help="Maximum batch size for HF backend.")
    parser.add_argument(
        '--output-json',
        type=str,
        default=None,
        help='Path to save the throughput results in JSON format.')
    parser.add_argument("--async-engine",
                        action='store_true',
                        default=False,
                        help="Use vLLM async engine rather than LLM class.")
    parser.add_argument("--disable-frontend-multiprocessing",
                        action='store_true',
                        default=False,
                        help="Disable decoupled async engine frontend.")
    parser = AsyncEngineArgs.add_cli_args(parser)
    args = parser.parse_args()
    if args.tokenizer is None:
        args.tokenizer = args.model

    # Validate with explicit exceptions (not `assert`) so the checks still
    # run under `python -O`, matching the backend checks below.
    if args.dataset is None:
        if args.input_len is None:
            raise ValueError(
                "--input-len is required when --dataset is not given.")
        if args.output_len is None:
            raise ValueError(
                "--output-len is required when --dataset is not given.")
    elif args.input_len is not None:
        raise ValueError("--input-len cannot be combined with --dataset.")

    if args.backend == "vllm":
        if args.hf_max_batch_size is not None:
            raise ValueError("HF max batch size is only for HF backend.")
    elif args.backend == "hf":
        if args.hf_max_batch_size is None:
            raise ValueError("HF max batch size is required for HF backend.")
        if args.quantization is not None:
            raise ValueError("Quantization is only for vLLM backend.")
    elif args.backend == "mii":
        if args.dtype != "auto":
            raise ValueError("dtype must be auto for MII backend.")
        if args.n != 1:
            raise ValueError("n must be 1 for MII backend.")
        if args.quantization is not None:
            raise ValueError("Quantization is only for vLLM backend.")
        if args.hf_max_batch_size is not None:
            raise ValueError("HF max batch size is only for HF backend.")
        if args.tokenizer != args.model:
            raise ValueError("Tokenizer must be the same as the model for MII "
                             "backend.")
    main(args)