# benchmark_throughput.py
"""Benchmark offline inference throughput."""
2
import argparse
3
import dataclasses
4
5
6
import json
import random
import time
7
from typing import List, Optional
8

9
import torch
10
import uvloop
11
from tqdm import tqdm
12
13
from transformers import (AutoModelForCausalLM, AutoTokenizer,
                          PreTrainedTokenizerBase)
14

15
from vllm.engine.arg_utils import AsyncEngineArgs, EngineArgs
16
17
from vllm.entrypoints.openai.api_server import (
    build_async_engine_client_from_engine_args)
18
19
from vllm.inputs import TextPrompt
from vllm.multimodal import MultiModalDataDict
20
from vllm.sampling_params import BeamSearchParams
21
from vllm.utils import FlexibleArgumentParser, merge_async_iterators
22

23

24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
@dataclasses.dataclass
class SampleRequest:
    """A single prompt submitted to a backend during the throughput benchmark.

    Attributes:
        prompt: Text handed to the model.
        prompt_len: Length of ``prompt`` in tokens.
        expected_output_len: Number of tokens the request is expected to
            produce.
        multi_modal_data: Optional multi-modal payload (for example images)
            attached to the prompt; ``None`` for text-only requests.
    """

    # Text fed to the model.
    prompt: str
    # Tokenized length of ``prompt``.
    prompt_len: int
    # Generation length, in tokens.
    expected_output_len: int
    # Extra non-text inputs; left as None for plain text prompts.
    multi_modal_data: Optional[MultiModalDataDict] = dataclasses.field(
        default=None)


41
42
43
44
def sample_requests(
    dataset_path: str,
    num_requests: int,
    tokenizer: PreTrainedTokenizerBase,
    fixed_output_len: Optional[int],
) -> List[SampleRequest]:
    """Sample benchmark requests from a JSON conversation dataset.

    The dataset is a JSON list whose entries carry a ``"conversations"`` list.
    The first turn's ``"value"`` is used as the prompt and the second turn's
    ``"value"`` as the reference completion. Pairs are shuffled with the
    global ``random`` state, then pruned by token length until
    ``num_requests`` requests have been collected.

    Args:
        dataset_path: Path to the JSON dataset file.
        num_requests: Maximum number of requests to return.
        tokenizer: Tokenizer used to measure prompt/completion lengths.
        fixed_output_len: If not ``None``, used as every request's output
            length instead of the tokenized completion length.

    Returns:
        Up to ``num_requests`` requests. Fewer are returned when not enough
        entries pass the length filters.

    Raises:
        ValueError: If ``fixed_output_len`` is smaller than 4.
    """
    # Sequences shorter than this many tokens are pruned.
    min_len = 4
    # Upper bounds, in tokens, for the prompt and for prompt + output.
    max_prompt_len = 1024
    max_total_len = 2048

    if fixed_output_len is not None and fixed_output_len < min_len:
        raise ValueError("output_len too small")

    # Load the dataset. Use explicit UTF-8 so non-ASCII conversations decode
    # the same way regardless of the platform's default locale encoding.
    with open(dataset_path, encoding="utf-8") as f:
        dataset = json.load(f)
    # Drop conversations with fewer than 2 turns; keep only the first two.
    pairs = [(data["conversations"][0]["value"],
              data["conversations"][1]["value"]) for data in dataset
             if len(data["conversations"]) >= 2]

    # Shuffle the dataset.
    random.shuffle(pairs)

    # Filter out sequences that are too long or too short.
    filtered_dataset: List[SampleRequest] = []
    for prompt, completion in pairs:
        if len(filtered_dataset) == num_requests:
            break

        prompt_len = len(tokenizer(prompt).input_ids)
        if fixed_output_len is None:
            output_len = len(tokenizer(completion).input_ids)
        else:
            # The completion's own length is not used in this case, so
            # skip tokenizing it.
            output_len = fixed_output_len

        if prompt_len < min_len or output_len < min_len:
            # Prune too short sequences.
            continue
        if (prompt_len > max_prompt_len
                or prompt_len + output_len > max_total_len):
            # Prune too long sequences.
            continue
        filtered_dataset.append(
            SampleRequest(prompt=prompt,
                          prompt_len=prompt_len,
                          expected_output_len=output_len))

    return filtered_dataset
88
89


def run_vllm(
    requests: List[SampleRequest],
    n: int,
    engine_args: EngineArgs,
    use_beam_search: bool = False,
) -> float:
    """Time offline generation of ``requests`` with the synchronous ``LLM``.

    Args:
        requests: Requests to generate completions for.
        n: Number of sampled sequences per prompt, or the beam width when
            ``use_beam_search`` is set.
        engine_args: Engine configuration used to construct the ``LLM``.
        use_beam_search: Run ``LLM.beam_search`` instead of sampling. All
            requests must then share the same ``expected_output_len``.

    Returns:
        Wall-clock seconds spent in generation; building the ``LLM`` happens
        before the timer starts.

    Raises:
        ValueError: If beam search is requested with no requests or with
            differing output lengths.
    """
    from vllm import LLM, SamplingParams
    llm = LLM(**dataclasses.asdict(engine_args))

    if not use_beam_search:
        # Add the requests to the engine.
        prompts: List[TextPrompt] = []
        sampling_params: List[SamplingParams] = []
        for request in requests:
            prompts.append(TextPrompt(prompt=request.prompt))
            sampling_params.append(
                SamplingParams(
                    n=n,
                    temperature=1.0,
                    top_p=1.0,
                    ignore_eos=True,
                    max_tokens=request.expected_output_len,
                ))

        start = time.perf_counter()
        llm.generate(prompts, sampling_params, use_tqdm=True)
        end = time.perf_counter()
    else:
        if not requests:
            raise ValueError("Beam search requires at least one request.")
        # BeamSearchParams carries a single max_tokens, so output_len must be
        # the same for all requests.
        output_len = requests[0].expected_output_len
        if any(request.expected_output_len != output_len
               for request in requests):
            raise ValueError(
                "Beam search requires all requests to share the same "
                "output length.")
        beam_prompts = [request.prompt for request in requests]

        start = time.perf_counter()
        llm.beam_search(
            beam_prompts,
            BeamSearchParams(
                beam_width=n,
                max_tokens=output_len,
                ignore_eos=True,
            ))
        end = time.perf_counter()
    return end - start


136
async def run_vllm_async(
    requests: List[SampleRequest],
    n: int,
    engine_args: AsyncEngineArgs,
    disable_frontend_multiprocessing: bool = False,
) -> float:
    """Time generation of ``requests`` through the async engine client.

    Every request is submitted up front and the merged output stream is
    drained. Only that submit-and-drain phase is timed; building the prompts
    and sampling parameters happens before the timer starts.

    Args:
        requests: Requests to generate completions for.
        n: Number of sampled sequences per prompt.
        engine_args: Configuration for the async engine.
        disable_frontend_multiprocessing: Forwarded unchanged to
            ``build_async_engine_client_from_engine_args``.

    Returns:
        Elapsed wall-clock seconds for the submit-and-drain phase.
    """
    from vllm import SamplingParams

    async with build_async_engine_client_from_engine_args(
            engine_args, disable_frontend_multiprocessing) as llm:

        # Pair every prompt with its sampling configuration.
        work = []
        for req in requests:
            text_prompt = TextPrompt(prompt=req.prompt)
            params = SamplingParams(
                n=n,
                temperature=1.0,
                top_p=1.0,
                ignore_eos=True,
                max_tokens=req.expected_output_len,
            )
            work.append((text_prompt, params))

        t0 = time.perf_counter()
        streams = [
            llm.generate(tp, sp, request_id=f"test{idx}")
            for idx, (tp, sp) in enumerate(work)
        ]
        async for _idx, _out in merge_async_iterators(*streams):
            # Outputs are discarded; the loop only waits for completion.
            pass
        elapsed = time.perf_counter() - t0
        return elapsed


173
def run_hf(
    requests: List[SampleRequest],
    model: str,
    tokenizer: PreTrainedTokenizerBase,
    n: int,
    max_batch_size: int,
    trust_remote_code: bool,
) -> float:
    """Time generation of ``requests`` with a Hugging Face causal LM.

    Requests are grouped greedily, in order, into batches of at most
    ``max_batch_size``. A batch keeps growing while the largest prompt length
    plus the largest output length stays within 2048 tokens. Each batch is
    generated with sampling and then decoded, so decoding time is included.

    Args:
        requests: Requests to generate completions for.
        model: Model name or path passed to ``from_pretrained``.
        tokenizer: Tokenizer used to encode and decode each batch.
        n: Number of returned sequences per prompt.
        max_batch_size: Maximum number of prompts per ``generate`` call.
        trust_remote_code: Forwarded to ``from_pretrained``.

    Returns:
        Wall-clock seconds for batching, generation and decoding; model
        loading happens before the timer starts.
    """
    # Token budget (max prompt length + max output length) for one batch.
    max_batch_tokens = 2048

    llm = AutoModelForCausalLM.from_pretrained(
        model, torch_dtype=torch.float16, trust_remote_code=trust_remote_code)
    if llm.config.model_type == "llama":
        # To enable padding in the HF backend.
        tokenizer.pad_token = tokenizer.eos_token
    llm = llm.cuda()

    pbar = tqdm(total=len(requests))
    start = time.perf_counter()
    batch: List[str] = []
    max_prompt_len = 0
    max_output_len = 0
    for i, request in enumerate(requests):
        # Add the prompt to the batch.
        batch.append(request.prompt)
        max_prompt_len = max(max_prompt_len, request.prompt_len)
        max_output_len = max(max_output_len, request.expected_output_len)
        if len(batch) < max_batch_size and i != len(requests) - 1:
            # Check if we can add more requests to the batch.
            next_request = requests[i + 1]
            next_prompt_len = max(max_prompt_len, next_request.prompt_len)
            next_output_len = max(max_output_len,
                                  next_request.expected_output_len)
            if next_prompt_len + next_output_len <= max_batch_tokens:
                # We can add more requests to the batch.
                continue

        # Generate the sequences.
        input_ids = tokenizer(batch, return_tensors="pt",
                              padding=True).input_ids
        llm_outputs = llm.generate(
            input_ids=input_ids.cuda(),
            do_sample=True,
            num_return_sequences=n,
            temperature=1.0,
            top_p=1.0,
            use_cache=True,
            max_new_tokens=max_output_len,
        )
        # Include the decoding time.
        tokenizer.batch_decode(llm_outputs, skip_special_tokens=True)
        pbar.update(len(batch))

        # Clear the batch.
        batch = []
        max_prompt_len = 0
        max_output_len = 0
    end = time.perf_counter()
    # Close the progress bar after timing so it does not affect the result.
    pbar.close()
    return end - start


231
def run_mii(
    requests: List[SampleRequest],
    model: str,
    tensor_parallel_size: int,
    output_len: int,
) -> float:
    """Time generation of ``requests`` with a DeepSpeed-MII server.

    A server is started for ``model``, all prompts are sent in one
    ``generate`` call, and the server is terminated afterwards, including
    when generation raises.

    Args:
        requests: Requests whose prompts are sent to the server.
        model: Model name passed to ``mii.serve`` and ``mii.client``.
        tensor_parallel_size: Forwarded as ``tensor_parallel`` to
            ``mii.serve``.
        output_len: Forwarded as ``max_new_tokens`` to ``generate``.

    Returns:
        Wall-clock seconds spent in the ``generate`` call.
    """
    from mii import client, serve
    llm = serve(model, tensor_parallel=tensor_parallel_size)
    try:
        prompts = [request.prompt for request in requests]

        start = time.perf_counter()
        llm.generate(prompts, max_new_tokens=output_len)
        end = time.perf_counter()
    finally:
        # Terminate the server even if generation fails. Use a distinct name
        # so the imported ``client`` factory is not shadowed.
        mii_client = client(model)
        mii_client.terminate_server()
    return end - start


249
250
251
252
253
def _synthesize_requests(
    tokenizer: PreTrainedTokenizerBase,
    input_len: int,
    output_len: int,
    num_prompts: int,
) -> List[SampleRequest]:
    """Build ``num_prompts`` identical requests with ``input_len``-token prompts.

    Raises:
        ValueError: If no repetition count of ``"hi "`` within +/-10 of
            ``input_len`` tokenizes to exactly ``input_len`` tokens.
    """
    # As tokenizer may add additional tokens like BOS, we need to try
    # different lengths to get the desired input length.
    for delta in range(-10, 10):
        prompt = "hi " * (input_len + delta)
        if len(tokenizer(prompt).input_ids) == input_len:
            break
    else:
        raise ValueError(
            f"Failed to synthesize a prompt with {input_len} tokens.")
    return [
        SampleRequest(prompt=prompt,
                      prompt_len=input_len,
                      expected_output_len=output_len)
        for _ in range(num_prompts)
    ]


def _run_backend(
    args: argparse.Namespace,
    requests: List[SampleRequest],
    tokenizer: PreTrainedTokenizerBase,
) -> float:
    """Run the backend selected by ``args.backend`` and return elapsed seconds.

    Raises:
        ValueError: If the backend is unknown, or the HF backend is asked
            for tensor parallelism.
    """
    if args.backend == "vllm":
        if args.async_engine:
            return uvloop.run(
                run_vllm_async(
                    requests,
                    args.n,
                    AsyncEngineArgs.from_cli_args(args),
                    args.disable_frontend_multiprocessing,
                ))
        return run_vllm(requests, args.n, EngineArgs.from_cli_args(args))
    if args.backend == "hf":
        # Raise instead of assert so the check survives `python -O`.
        if args.tensor_parallel_size != 1:
            raise ValueError(
                "HF backend only supports tensor_parallel_size == 1.")
        return run_hf(requests, args.model, tokenizer, args.n,
                      args.hf_max_batch_size, args.trust_remote_code)
    if args.backend == "mii":
        return run_mii(requests, args.model, args.tensor_parallel_size,
                       args.output_len)
    raise ValueError(f"Unknown backend: {args.backend}")


def main(args: argparse.Namespace):
    """Run the throughput benchmark described by the parsed CLI ``args``.

    Builds the request list, either synthetic or sampled from
    ``args.dataset``, and runs the selected backend. Prints request and token
    throughput and, when ``args.output_json`` is set, writes the same figures
    to that path as JSON.
    """
    print(args)
    random.seed(args.seed)

    # Sample the requests.
    tokenizer = AutoTokenizer.from_pretrained(
        args.tokenizer, trust_remote_code=args.trust_remote_code)
    if args.dataset is None:
        requests = _synthesize_requests(tokenizer, args.input_len,
                                        args.output_len, args.num_prompts)
    else:
        requests = sample_requests(args.dataset, args.num_prompts, tokenizer,
                                   args.output_len)

    elapsed_time = _run_backend(args, requests, tokenizer)

    total_num_tokens = sum(request.prompt_len + request.expected_output_len
                           for request in requests)
    total_output_tokens = sum(request.expected_output_len
                              for request in requests)
    print(f"Throughput: {len(requests) / elapsed_time:.2f} requests/s, "
          f"{total_num_tokens / elapsed_time:.2f} total tokens/s, "
          f"{total_output_tokens / elapsed_time:.2f} output tokens/s")

    # Output JSON results if specified. Output-token figures are included
    # so the file matches what is printed above.
    if args.output_json:
        results = {
            "elapsed_time": elapsed_time,
            "num_requests": len(requests),
            "total_num_tokens": total_num_tokens,
            "requests_per_second": len(requests) / elapsed_time,
            "tokens_per_second": total_num_tokens / elapsed_time,
            "total_output_tokens": total_output_tokens,
            "output_tokens_per_second": total_output_tokens / elapsed_time,
        }
        with open(args.output_json, "w") as f:
            json.dump(results, f, indent=4)

319
320

if __name__ == "__main__":
    parser = FlexibleArgumentParser(description="Benchmark the throughput.")
    parser.add_argument("--backend",
                        type=str,
                        choices=["vllm", "hf", "mii"],
                        default="vllm")
    parser.add_argument("--dataset",
                        type=str,
                        default=None,
                        help="Path to the dataset. The dataset is expected to "
                        "be a json in form of List[Dict[..., conversations: "
                        "List[Dict[..., value: <prompt_or_response>]]]]")
    parser.add_argument("--input-len",
                        type=int,
                        default=None,
                        help="Input prompt length for each request")
    parser.add_argument("--output-len",
                        type=int,
                        default=None,
                        help="Output length for each request. Overrides the "
                        "output length from the dataset.")
    parser.add_argument("--n",
                        type=int,
                        default=1,
                        help="Number of generated sequences per prompt.")
    parser.add_argument("--num-prompts",
                        type=int,
                        default=1000,
                        help="Number of prompts to process.")
    parser.add_argument("--hf-max-batch-size",
                        type=int,
                        default=None,
                        help="Maximum batch size for HF backend.")
    parser.add_argument(
        '--output-json',
        type=str,
        default=None,
        help='Path to save the throughput results in JSON format.')
    parser.add_argument("--async-engine",
                        action='store_true',
                        default=False,
                        help="Use vLLM async engine rather than LLM class.")
    parser.add_argument("--disable-frontend-multiprocessing",
                        action='store_true',
                        default=False,
                        help="Disable decoupled async engine frontend.")
    parser = AsyncEngineArgs.add_cli_args(parser)
    args = parser.parse_args()
    if args.tokenizer is None:
        args.tokenizer = args.model

    # Validate with parser.error instead of assert: asserts are stripped
    # under `python -O`, and parser.error reports the problem with usage.
    if args.dataset is None:
        if args.input_len is None or args.output_len is None:
            parser.error("--input-len and --output-len are required when "
                         "--dataset is not given.")
    elif args.input_len is not None:
        parser.error("--input-len cannot be combined with --dataset.")

    if args.backend == "vllm":
        if args.hf_max_batch_size is not None:
            raise ValueError("HF max batch size is only for HF backend.")
    elif args.backend == "hf":
        if args.hf_max_batch_size is None:
            raise ValueError("HF max batch size is required for HF backend.")
        if args.quantization is not None:
            raise ValueError("Quantization is only for vLLM backend.")
    elif args.backend == "mii":
        if args.dtype != "auto":
            raise ValueError("dtype must be auto for MII backend.")
        if args.n != 1:
            raise ValueError("n must be 1 for MII backend.")
        if args.quantization is not None:
            raise ValueError("Quantization is only for vLLM backend.")
        if args.hf_max_batch_size is not None:
            raise ValueError("HF max batch size is only for HF backend.")
        if args.tokenizer != args.model:
            raise ValueError("Tokenizer must be the same as the model for MII "
                             "backend.")
    main(args)