"""Benchmark offline inference throughput."""
import argparse
import dataclasses
import json
import random
import time
from typing import List, Optional, Tuple

import numpy as np
import torch
import uvloop
from tqdm import tqdm
from transformers import (AutoModelForCausalLM, AutoTokenizer,
                          PreTrainedTokenizerBase)

from vllm.engine.arg_utils import AsyncEngineArgs, EngineArgs
from vllm.entrypoints.openai.api_server import (
    build_async_engine_client_from_engine_args)
from vllm.sampling_params import BeamSearchParams
from vllm.utils import FlexibleArgumentParser, merge_async_iterators


def sample_requests(
    dataset_path: str,
    num_requests: int,
    tokenizer: PreTrainedTokenizerBase,
    fixed_output_len: Optional[int],
) -> List[Tuple[str, int, int]]:
    """Sample benchmark requests from a conversation dataset on disk.

    The file at ``dataset_path`` is read as a JSON list of records, each with a
    ``"conversations"`` list whose items carry the text under ``"value"``. The
    first turn is used as the prompt and the second turn as the reference
    completion.

    Args:
        dataset_path: Path to the JSON dataset file.
        num_requests: Maximum number of requests to return.
        tokenizer: Tokenizer used to measure prompt/completion lengths.
        fixed_output_len: If given, use this output length for every request
            instead of the tokenized length of the reference completion.

    Returns:
        Up to ``num_requests`` ``(prompt, prompt_len, output_len)`` tuples, in
        shuffled order (driven by the global ``random`` state).

    Raises:
        ValueError: If ``fixed_output_len`` is given and smaller than 4.
    """
    if fixed_output_len is not None and fixed_output_len < 4:
        raise ValueError("output_len too small")

    # Load the dataset. The text may be non-ASCII, so do not depend on the
    # platform's default encoding.
    with open(dataset_path, encoding="utf-8") as f:
        dataset = json.load(f)
    # Keep only the first two turns of conversations that have at least two.
    dataset = [(data["conversations"][0]["value"],
                data["conversations"][1]["value"]) for data in dataset
               if len(data["conversations"]) >= 2]

    # Shuffle the dataset.
    random.shuffle(dataset)

    # Filter out sequences that are too long or too short.
    filtered_dataset: List[Tuple[str, int, int]] = []
    for prompt, completion in dataset:
        if len(filtered_dataset) == num_requests:
            break

        prompt_len = len(tokenizer(prompt).input_ids)
        if fixed_output_len is None:
            output_len = len(tokenizer(completion).input_ids)
        else:
            # The completion length is irrelevant; skip tokenizing it.
            output_len = fixed_output_len
        if prompt_len < 4 or output_len < 4:
            # Prune too short sequences.
            continue
        if prompt_len > 1024 or prompt_len + output_len > 2048:
            # Prune too long sequences.
            continue
        filtered_dataset.append((prompt, prompt_len, output_len))

    return filtered_dataset


def run_vllm(
    warmup_requests: List[Tuple[str, int, int]],
    requests: List[Tuple[str, int, int]],
    n: int,
    engine_args: EngineArgs,
    *,
    use_beam_search: bool = False,
) -> float:
    """Time offline generation of ``requests`` with the synchronous ``LLM``.

    The engine is first warmed up with ``warmup_requests``. Only the main
    ``requests`` run is timed.

    Args:
        warmup_requests: ``(prompt, prompt_len, output_len)`` tuples for the
            untimed warmup pass. Only the prompt and output_len are used.
        requests: ``(prompt, prompt_len, output_len)`` tuples to benchmark.
        n: Sequences generated per prompt (the beam width when
            ``use_beam_search`` is set).
        engine_args: Arguments used to construct the ``LLM`` instance.
        use_beam_search: Time ``LLM.beam_search`` instead of sampling. All
            requests must then share one output length. Defaults to False,
            matching the previous hard-coded behavior.

    Returns:
        Wall-clock seconds spent on the timed run.

    Raises:
        ValueError: If ``use_beam_search`` is set and the requests do not all
            share the same output length.
    """
    from vllm import LLM, SamplingParams

    def to_sampling_params(output_len: int) -> SamplingParams:
        # ignore_eos keeps generating past an EOS token, so each sequence
        # runs up to max_tokens.
        return SamplingParams(
            n=n,
            temperature=1.0,
            top_p=1.0,
            ignore_eos=True,
            max_tokens=output_len,
        )

    llm = LLM(**dataclasses.asdict(engine_args))

    # Build the request inputs.
    prompts = [prompt for prompt, _, _ in requests]
    sampling_params = [
        to_sampling_params(output_len) for _, _, output_len in requests
    ]

    # Untimed warmup.
    # NOTE(review): the iteration count comes from the module-level ``args``
    # created in the ``__main__`` block. Calling this function from an
    # importing module raises NameError unless such a global exists.
    warmup_prompts = [prompt for prompt, _, _ in warmup_requests]
    warmup_sampling_params = [
        to_sampling_params(output_len)
        for _, _, output_len in warmup_requests
    ]
    print("Warming up...")
    for _ in tqdm(range(args.num_iters_warmup), desc="Warmup iterations"):
        llm.generate(warmup_prompts, warmup_sampling_params, use_tqdm=True)

    if not use_beam_search:
        start = time.perf_counter()
        llm.generate(prompts, sampling_params, use_tqdm=True)
        end = time.perf_counter()
    else:
        # BeamSearchParams takes one max_tokens for the whole batch, so the
        # output length must be uniform.
        output_len = requests[0][2]
        if any(req_output_len != output_len
               for _, _, req_output_len in requests):
            raise ValueError("All requests must share the same output length "
                             "when using beam search.")
        start = time.perf_counter()
        llm.beam_search(
            prompts,
            BeamSearchParams(
                beam_width=n,
                max_tokens=output_len,
                ignore_eos=True,
            ))
        end = time.perf_counter()
    return end - start


async def run_vllm_async(
    requests: List[Tuple[str, int, int]],
    n: int,
    engine_args: AsyncEngineArgs,
    disable_frontend_multiprocessing: bool = False,
) -> float:
    """Time generation of ``requests`` through the async vLLM engine client.

    Args:
        requests: ``(prompt, prompt_len, output_len)`` tuples to benchmark.
        n: Sequences generated per prompt.
        engine_args: Arguments used to build the async engine client.
        disable_frontend_multiprocessing: Forwarded to
            ``build_async_engine_client_from_engine_args``.

    Returns:
        Wall-clock seconds from submitting the first request until every
        result stream has been drained.
    """
    from vllm import SamplingParams

    async with build_async_engine_client_from_engine_args(
            engine_args, disable_frontend_multiprocessing) as llm:

        # Build the request inputs before starting the clock.
        prompts = [prompt for prompt, _, _ in requests]
        sampling_params = [
            SamplingParams(
                n=n,
                temperature=1.0,
                top_p=1.0,
                ignore_eos=True,
                max_tokens=output_len,
            ) for _, _, output_len in requests
        ]

        start = time.perf_counter()
        generators = [
            llm.generate(prompt, sp, request_id=f"test{i}")
            for i, (prompt, sp) in enumerate(zip(prompts, sampling_params))
        ]
        # Consume every stream. The outputs themselves are not needed.
        async for _ in merge_async_iterators(*generators):
            pass
        end = time.perf_counter()
        return end - start


def run_hf(
    requests: List[Tuple[str, int, int]],
    model: str,
    tokenizer: PreTrainedTokenizerBase,
    n: int,
    max_batch_size: int,
    trust_remote_code: bool,
) -> float:
    """Time generation of ``requests`` with a HuggingFace causal LM on CUDA.

    Requests are batched in order. A batch is flushed when it holds
    ``max_batch_size`` prompts, when the last request is reached, or when
    adding the next request would push (max prompt len + max output len)
    above 2048.

    Args:
        requests: ``(prompt, prompt_len, output_len)`` tuples to benchmark.
        model: Model name or path passed to ``from_pretrained``.
        tokenizer: Tokenizer used to encode/decode batches.
        n: ``num_return_sequences`` per prompt.
        max_batch_size: Maximum number of prompts per generate call.
        trust_remote_code: Forwarded to ``from_pretrained``.

    Returns:
        Wall-clock seconds for all batches, including decoding.
    """
    llm = AutoModelForCausalLM.from_pretrained(
        model, torch_dtype=torch.float16, trust_remote_code=trust_remote_code)
    if llm.config.model_type == "llama" or tokenizer.pad_token is None:
        # Batched tokenization with padding=True needs a pad token. Reuse EOS
        # for llama (as before) and for any tokenizer that has none.
        tokenizer.pad_token = tokenizer.eos_token
    llm = llm.cuda()

    pbar = tqdm(total=len(requests))
    start = time.perf_counter()
    batch: List[str] = []
    max_prompt_len = 0
    max_output_len = 0
    for i, (prompt, prompt_len, output_len) in enumerate(requests):
        # Add the prompt to the batch.
        batch.append(prompt)
        max_prompt_len = max(max_prompt_len, prompt_len)
        max_output_len = max(max_output_len, output_len)
        if len(batch) < max_batch_size and i != len(requests) - 1:
            # Keep growing the batch while the next request still fits.
            _, next_prompt_len, next_output_len = requests[i + 1]
            if (max(max_prompt_len, next_prompt_len) +
                    max(max_output_len, next_output_len)) <= 2048:
                continue

        # Generate the sequences.
        input_ids = tokenizer(batch, return_tensors="pt",
                              padding=True).input_ids
        llm_outputs = llm.generate(
            input_ids=input_ids.cuda(),
            do_sample=True,
            num_return_sequences=n,
            temperature=1.0,
            top_p=1.0,
            use_cache=True,
            max_new_tokens=max_output_len,
        )
        # Include the decoding time.
        tokenizer.batch_decode(llm_outputs, skip_special_tokens=True)
        pbar.update(len(batch))

        # Clear the batch.
        batch = []
        max_prompt_len = 0
        max_output_len = 0
    end = time.perf_counter()
    pbar.close()
    return end - start


def run_mii(
    requests: List[Tuple[str, int, int]],
    model: str,
    tensor_parallel_size: int,
    output_len: int,
) -> float:
    """Time one batched generate call over all prompts using an MII server.

    Args:
        requests: ``(prompt, prompt_len, output_len)`` tuples. Only the
            prompts are used.
        model: Model passed to ``mii.serve`` and ``mii.client``.
        tensor_parallel_size: Forwarded as ``tensor_parallel``.
        output_len: ``max_new_tokens`` applied to every prompt.

    Returns:
        Wall-clock seconds spent in ``generate``.
    """
    from mii import client, serve
    llm = serve(model, tensor_parallel=tensor_parallel_size)
    prompts = [prompt for prompt, _, _ in requests]

    try:
        start = time.perf_counter()
        llm.generate(prompts, max_new_tokens=output_len)
        end = time.perf_counter()
    finally:
        # Shut the server down even if generation raises. Bind to a new name
        # instead of shadowing the imported ``client`` factory.
        mii_client = client(model)
        mii_client.terminate_server()
    return end - start


def main(args: argparse.Namespace):
    """Run the throughput benchmark described by the parsed CLI ``args``.

    Builds the request list (synthetic prompts of ``args.input_len`` tokens,
    or samples from ``args.dataset``) and times the selected backend. It then
    prints latency and throughput, and optionally writes them to
    ``args.output_json``.

    Raises:
        ValueError: If no synthetic prompt of the requested length can be
            built, the backend is unknown, or the HF backend is asked for
            tensor parallelism.
    """
    print(args)
    random.seed(args.seed)

    # Sample the requests.
    tokenizer = AutoTokenizer.from_pretrained(
        args.tokenizer, trust_remote_code=args.trust_remote_code)

    # A single short request used to warm up the synchronous vLLM engine.
    warmup_prompt = "hi" * 10
    warmup_requests = [(warmup_prompt, 10, 10)]

    if args.dataset is None:
        # Synthesize a prompt with the given input length.
        # As tokenizer may add additional tokens like BOS, we need to try
        # different lengths to get the desired input length.
        for i in range(-10, 10):
            prompt = "hi " * (args.input_len + i)
            tokenized_prompt = tokenizer(prompt).input_ids
            if len(tokenized_prompt) == args.input_len:
                break
        else:
            raise ValueError(
                f"Failed to synthesize a prompt with {args.input_len} tokens.")
        requests = [(prompt, args.input_len, args.output_len)
                    for _ in range(args.num_prompts)]
    else:
        requests = sample_requests(args.dataset, args.num_prompts, tokenizer,
                                   args.output_len)

    if args.backend == "vllm":
        if args.async_engine:
            elapsed_time = uvloop.run(
                run_vllm_async(
                    requests,
                    args.n,
                    AsyncEngineArgs.from_cli_args(args),
                    args.disable_frontend_multiprocessing,
                ))
        else:
            elapsed_time = run_vllm(warmup_requests, requests, args.n,
                                    EngineArgs.from_cli_args(args))
    elif args.backend == "hf":
        # Explicit check instead of assert so it is not stripped under -O.
        if args.tensor_parallel_size != 1:
            raise ValueError(
                "HF backend only supports tensor_parallel_size == 1.")
        elapsed_time = run_hf(requests, args.model, tokenizer, args.n,
                              args.hf_max_batch_size, args.trust_remote_code)
    elif args.backend == "mii":
        elapsed_time = run_mii(requests, args.model, args.tensor_parallel_size,
                               args.output_len)
    else:
        raise ValueError(f"Unknown backend: {args.backend}")

    total_num_tokens = sum(prompt_len + output_len
                           for _, prompt_len, output_len in requests)
    # For synthetic prompts every request carries args.output_len, so this
    # sum equals args.output_len * args.num_prompts; no special case needed.
    total_out_tokens = sum(output_len for _, _, output_len in requests)

    print(f"Latency: {elapsed_time:.2f} s")
    print(f"All Throughput: {len(requests) / elapsed_time:.2f} requests/s, "
          f"{total_num_tokens / elapsed_time:.2f} tokens/s")
    print(f"Generate Throughput: {total_out_tokens / elapsed_time:.2f} tokens/s")

    # Output JSON results if specified
    if args.output_json:
        results = {
            "elapsed_time": elapsed_time,
            "num_requests": len(requests),
            "total_num_tokens": total_num_tokens,
            "requests_per_second": len(requests) / elapsed_time,
            "tokens_per_second": total_num_tokens / elapsed_time,
            # Mirror the printed "Generate Throughput" line.
            "total_output_tokens": total_out_tokens,
            "output_tokens_per_second": total_out_tokens / elapsed_time,
        }
        with open(args.output_json, "w") as f:
            json.dump(results, f, indent=4)


if __name__ == "__main__":
    parser = FlexibleArgumentParser(description="Benchmark the throughput.")
    parser.add_argument("--backend",
                        type=str,
                        choices=["vllm", "hf", "mii"],
                        default="vllm")
    parser.add_argument("--dataset",
                        type=str,
                        default=None,
                        help="Path to the dataset.")
    parser.add_argument("--input-len",
                        type=int,
                        default=None,
                        help="Input prompt length for each request")
    parser.add_argument("--output-len",
                        type=int,
                        default=None,
                        help="Output length for each request. Overrides the "
                        "output length from the dataset.")
    parser.add_argument("--n",
                        type=int,
                        default=1,
                        help="Number of generated sequences per prompt.")
    # Read by run_vllm (synchronous vLLM path) through the global ``args``.
    parser.add_argument('--num-iters-warmup',
                        type=int,
                        default=1,
                        help='Number of iterations to run for warmup.')
    parser.add_argument("--num-prompts",
                        type=int,
                        default=1000,
                        help="Number of prompts to process.")
    parser.add_argument("--hf-max-batch-size",
                        type=int,
                        default=None,
                        help="Maximum batch size for HF backend.")
    parser.add_argument(
        '--output-json',
        type=str,
        default=None,
        help='Path to save the throughput results in JSON format.')
    parser.add_argument("--async-engine",
                        action='store_true',
                        default=False,
                        help="Use vLLM async engine rather than LLM class.")
    parser.add_argument("--disable-frontend-multiprocessing",
                        action='store_true',
                        default=False,
                        help="Disable decoupled async engine frontend.")
    parser = AsyncEngineArgs.add_cli_args(parser)
    args = parser.parse_args()
    if args.tokenizer is None:
        args.tokenizer = args.model

    # Validate with explicit errors rather than asserts, which are stripped
    # when Python runs with -O.
    if args.dataset is None:
        if args.input_len is None or args.output_len is None:
            raise ValueError("--input-len and --output-len are required when "
                             "--dataset is not given.")
    elif args.input_len is not None:
        raise ValueError("--input-len cannot be combined with --dataset.")

    if args.backend == "vllm":
        if args.hf_max_batch_size is not None:
            raise ValueError("HF max batch size is only for HF backend.")
    elif args.backend == "hf":
        if args.hf_max_batch_size is None:
            raise ValueError("HF max batch size is required for HF backend.")
        if args.quantization is not None:
            raise ValueError("Quantization is only for vLLM backend.")
    elif args.backend == "mii":
        if args.dtype != "auto":
            raise ValueError("dtype must be auto for MII backend.")
        if args.n != 1:
            raise ValueError("n must be 1 for MII backend.")
        if args.quantization is not None:
            raise ValueError("Quantization is only for vLLM backend.")
        if args.hf_max_batch_size is not None:
            raise ValueError("HF max batch size is only for HF backend.")
        if args.tokenizer != args.model:
            raise ValueError("Tokenizer must be the same as the model for MII "
                             "backend.")
    main(args)