"""Benchmark offline inference throughput."""
import argparse
import json
import random
import time
from typing import List, Optional, Tuple

import torch
from tqdm import tqdm
from transformers import (AutoModelForCausalLM, AutoTokenizer,
                          PreTrainedTokenizerBase)


def sample_requests(
    dataset_path: str,
    num_requests: int,
    tokenizer: PreTrainedTokenizerBase,
    fixed_output_len: Optional[int],
) -> List[Tuple[str, int, int]]:
    """Sample benchmark requests from a conversation-style JSON dataset.

    The dataset is a JSON list whose entries carry a ``"conversations"`` list
    of turns, each with a ``"value"`` string. The first turn is used as the
    prompt and the second as the reference completion.

    Args:
        dataset_path: Path to the JSON dataset file (read as UTF-8).
        num_requests: Maximum number of requests to return.
        tokenizer: Callable tokenizer whose result exposes ``input_ids``.
        fixed_output_len: If given, every request uses this output length
            instead of the tokenized completion length.

    Returns:
        Up to ``num_requests`` ``(prompt, prompt_len, output_len)`` tuples,
        drawn in shuffled order (uses the global ``random`` state).

    Raises:
        ValueError: If ``fixed_output_len`` is given and smaller than 4.
    """
    if fixed_output_len is not None and fixed_output_len < 4:
        raise ValueError("output_len too small")

    # Load the dataset. Decode as UTF-8 explicitly so the result does not
    # depend on the platform's default locale encoding (the text is
    # free-form and may contain non-ASCII characters).
    with open(dataset_path, encoding="utf-8") as f:
        dataset = json.load(f)
    # Filter out the conversations with less than 2 turns and keep only the
    # first two turns (prompt, completion) of each remaining conversation.
    dataset = [(data["conversations"][0]["value"],
                data["conversations"][1]["value"]) for data in dataset
               if len(data["conversations"]) >= 2]

    # Shuffle the dataset.
    random.shuffle(dataset)

    # Filter out sequences that are too long or too short.
    filtered_dataset: List[Tuple[str, int, int]] = []
    for prompt, completion in dataset:
        if len(filtered_dataset) == num_requests:
            break

        prompt_len = len(tokenizer(prompt).input_ids)
        if fixed_output_len is None:
            output_len = len(tokenizer(completion).input_ids)
        else:
            # The completion's token count is irrelevant when the output
            # length is fixed, so skip tokenizing it.
            output_len = fixed_output_len

        if prompt_len < 4 or output_len < 4:
            # Prune too short sequences.
            continue
        if prompt_len > 1024 or prompt_len + output_len > 2048:
            # Prune too long sequences.
            continue
        filtered_dataset.append((prompt, prompt_len, output_len))

    return filtered_dataset


def run_vllm(
    requests: List[Tuple[str, int, int]],
    model: str,
    tokenizer: str,
    quantization: Optional[str],
    tensor_parallel_size: int,
    seed: int,
    n: int,
    use_beam_search: bool,
    trust_remote_code: bool,
    dtype: str,
    max_model_len: Optional[int],
    enforce_eager: bool,
    kv_cache_dtype: str,
    quantization_param_path: Optional[str],
    device: str,
    enable_prefix_caching: bool,
    gpu_memory_utilization: float = 0.9,
    download_dir: Optional[str] = None,
) -> float:
    """Measure offline generation time for ``requests`` with vLLM.

    Builds an ``LLM`` engine from the given configuration, enqueues one
    request per ``(prompt, prompt_len, output_len)`` tuple (``prompt_len`` is
    unused here), then times only the engine's run loop.

    Returns:
        Seconds spent in ``_run_engine``. Engine construction and request
        enqueueing happen before the timer starts.
    """
    from vllm import LLM, SamplingParams

    engine = LLM(
        model=model,
        tokenizer=tokenizer,
        quantization=quantization,
        tensor_parallel_size=tensor_parallel_size,
        seed=seed,
        trust_remote_code=trust_remote_code,
        dtype=dtype,
        max_model_len=max_model_len,
        gpu_memory_utilization=gpu_memory_utilization,
        enforce_eager=enforce_eager,
        kv_cache_dtype=kv_cache_dtype,
        quantization_param_path=quantization_param_path,
        device=device,
        enable_prefix_caching=enable_prefix_caching,
        download_dir=download_dir,
    )

    # Beam search is requested with temperature 0.0; otherwise sample at 1.0.
    temperature = 0.0 if use_beam_search else 1.0

    # Queue every request up front so the timed region covers only the
    # engine loop. ignore_eos forces each request to emit max_tokens tokens.
    for text, _prompt_len, max_new in requests:
        params = SamplingParams(
            n=n,
            temperature=temperature,
            top_p=1.0,
            use_beam_search=use_beam_search,
            ignore_eos=True,
            max_tokens=max_new,
        )
        # FIXME(woosuk): Do not use internal method.
        engine._add_request(
            prompt=text,
            prompt_token_ids=None,
            sampling_params=params,
        )

    t0 = time.perf_counter()
    # FIXME(woosuk): Do not use internal method.
    engine._run_engine(use_tqdm=True)
    return time.perf_counter() - t0


def run_hf(
    requests: List[Tuple[str, int, int]],
    model: str,
    tokenizer: PreTrainedTokenizerBase,
    n: int,
    use_beam_search: bool,
    max_batch_size: int,
    trust_remote_code: bool,
) -> float:
    """Measure generation time for ``requests`` with a Hugging Face model.

    The model is loaded in float16 and moved to CUDA. Requests, given as
    ``(prompt, prompt_len, output_len)`` tuples, are grouped into consecutive
    batches of at most ``max_batch_size``. A batch is cut early when adding
    the next request would make max(prompt_len) + max(output_len) exceed
    2048. Each batch is generated with sampling and then decoded.

    Returns:
        Seconds spent generating and decoding all batches (model loading is
        excluded).

    Raises:
        ValueError: If ``use_beam_search`` is set (not supported here).
    """
    # Raise rather than assert: an assert is stripped under ``python -O``,
    # which would silently benchmark greedy decoding instead of beam search.
    if use_beam_search:
        raise ValueError("Beam search is not supported for HF backend.")

    llm = AutoModelForCausalLM.from_pretrained(
        model, torch_dtype=torch.float16, trust_remote_code=trust_remote_code)
    if llm.config.model_type == "llama" or tokenizer.pad_token is None:
        # Batched tokenization with padding=True requires a pad token. Reuse
        # EOS for llama (as before) and for any other tokenizer without one.
        tokenizer.pad_token = tokenizer.eos_token
    llm = llm.cuda()

    pbar = tqdm(total=len(requests))
    try:
        start = time.perf_counter()
        batch: List[str] = []
        max_prompt_len = 0
        max_output_len = 0
        for i, (prompt, prompt_len, output_len) in enumerate(requests):
            # Add the prompt to the batch.
            batch.append(prompt)
            max_prompt_len = max(max_prompt_len, prompt_len)
            max_output_len = max(max_output_len, output_len)
            if len(batch) < max_batch_size and i != len(requests) - 1:
                # Check if we can add more requests to the batch.
                _, next_prompt_len, next_output_len = requests[i + 1]
                if (max(max_prompt_len, next_prompt_len) +
                        max(max_output_len, next_output_len)) <= 2048:
                    # We can add more requests to the batch.
                    continue

            # Generate the sequences.
            input_ids = tokenizer(batch, return_tensors="pt",
                                  padding=True).input_ids
            llm_outputs = llm.generate(
                input_ids=input_ids.cuda(),
                do_sample=not use_beam_search,
                num_return_sequences=n,
                temperature=1.0,
                top_p=1.0,
                use_cache=True,
                max_new_tokens=max_output_len,
            )
            # Include the decoding time.
            tokenizer.batch_decode(llm_outputs, skip_special_tokens=True)
            pbar.update(len(batch))

            # Clear the batch.
            batch = []
            max_prompt_len = 0
            max_output_len = 0
        end = time.perf_counter()
    finally:
        # Close the progress bar even if generation fails; this happens
        # outside the timed region.
        pbar.close()
    return end - start


def run_mii(
    requests: List[Tuple[str, int, int]],
    model: str,
    tensor_parallel_size: int,
    output_len: int,
) -> float:
    """Measure generation time for ``requests`` with DeepSpeed-MII.

    Starts an MII deployment via ``mii.serve``, then generates all prompts in
    one call with a shared ``max_new_tokens=output_len``. Per-request output
    lengths in ``requests`` are not used.

    Returns:
        Seconds spent in the ``generate`` call.
    """
    from mii import client, serve
    llm = serve(model, tensor_parallel=tensor_parallel_size)
    prompts = [prompt for prompt, _, _ in requests]

    start = time.perf_counter()
    try:
        llm.generate(prompts, max_new_tokens=output_len)
        end = time.perf_counter()
    finally:
        # Always terminate the server started by serve(), even if generation
        # raised, so a failed run does not leave it running. Use a distinct
        # name so the imported ``client`` factory is not shadowed.
        mii_client = client(model)
        mii_client.terminate_server()
    return end - start


def main(args: argparse.Namespace):
    """Build the request list, run the chosen backend and print throughput.

    Requests are synthesized from ``--input-len``/``--output-len`` when no
    dataset is given, otherwise sampled from the dataset. Throughput is
    reported both as requests/s and as (prompt + output) tokens/s.
    """
    print(args)
    random.seed(args.seed)

    # Sample the requests.
    tokenizer = AutoTokenizer.from_pretrained(
        args.tokenizer, trust_remote_code=args.trust_remote_code)
    if args.dataset is not None:
        requests = sample_requests(args.dataset, args.num_prompts, tokenizer,
                                   args.output_len)
    else:
        # Synthesize a prompt with the given input length.
        synthetic_prompt = "hi" * (args.input_len - 1)
        requests = [(synthetic_prompt, args.input_len, args.output_len)
                    ] * args.num_prompts

    backend = args.backend
    if backend == "vllm":
        elapsed_time = run_vllm(
            requests,
            model=args.model,
            tokenizer=args.tokenizer,
            quantization=args.quantization,
            tensor_parallel_size=args.tensor_parallel_size,
            seed=args.seed,
            n=args.n,
            use_beam_search=args.use_beam_search,
            trust_remote_code=args.trust_remote_code,
            dtype=args.dtype,
            max_model_len=args.max_model_len,
            enforce_eager=args.enforce_eager,
            kv_cache_dtype=args.kv_cache_dtype,
            quantization_param_path=args.quantization_param_path,
            device=args.device,
            enable_prefix_caching=args.enable_prefix_caching,
            gpu_memory_utilization=args.gpu_memory_utilization,
            download_dir=args.download_dir,
        )
    elif backend == "hf":
        assert args.tensor_parallel_size == 1
        elapsed_time = run_hf(
            requests,
            model=args.model,
            tokenizer=tokenizer,
            n=args.n,
            use_beam_search=args.use_beam_search,
            max_batch_size=args.hf_max_batch_size,
            trust_remote_code=args.trust_remote_code,
        )
    elif backend == "mii":
        elapsed_time = run_mii(
            requests,
            model=args.model,
            tensor_parallel_size=args.tensor_parallel_size,
            output_len=args.output_len,
        )
    else:
        raise ValueError(f"Unknown backend: {args.backend}")

    total_num_tokens = sum(p_len + o_len for _, p_len, o_len in requests)
    req_per_s = len(requests) / elapsed_time
    tok_per_s = total_num_tokens / elapsed_time
    print(f"Throughput: {req_per_s:.2f} requests/s, "
          f"{tok_per_s:.2f} tokens/s")


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Benchmark the throughput.")
    parser.add_argument("--backend",
                        type=str,
                        choices=["vllm", "hf", "mii"],
                        default="vllm")
    parser.add_argument("--dataset",
                        type=str,
                        default=None,
                        help="Path to the dataset.")
    parser.add_argument("--input-len",
                        type=int,
                        default=None,
                        help="Input prompt length for each request")
    parser.add_argument("--output-len",
                        type=int,
                        default=None,
                        help="Output length for each request. Overrides the "
                        "output length from the dataset.")
    parser.add_argument("--model", type=str, default="facebook/opt-125m")
    parser.add_argument("--tokenizer", type=str, default=None)
    parser.add_argument('--quantization',
                        '-q',
                        choices=['awq', 'gptq', 'squeezellm', None],
                        default=None)
    parser.add_argument("--tensor-parallel-size", "-tp", type=int, default=1)
    parser.add_argument("--n",
                        type=int,
                        default=1,
                        help="Number of generated sequences per prompt.")
    parser.add_argument("--use-beam-search", action="store_true")
    parser.add_argument("--num-prompts",
                        type=int,
                        default=1000,
                        help="Number of prompts to process.")
    parser.add_argument("--seed", type=int, default=0)
    parser.add_argument("--hf-max-batch-size",
                        type=int,
                        default=None,
                        help="Maximum batch size for HF backend.")
    parser.add_argument('--trust-remote-code',
                        action='store_true',
                        help='trust remote code from huggingface')
    parser.add_argument(
        '--max-model-len',
        type=int,
        default=None,
        help='Maximum length of a sequence (including prompt and output). '
        'If None, will be derived from the model.')
    parser.add_argument(
        '--dtype',
        type=str,
        default='auto',
        choices=['auto', 'half', 'float16', 'bfloat16', 'float', 'float32'],
        help='data type for model weights and activations. '
        'The "auto" option will use FP16 precision '
        'for FP32 and FP16 models, and BF16 precision '
        'for BF16 models.')
    # Note the trailing space after "0 to 1." so the two literals do not run
    # together ("0 to 1.If") in --help output.
    parser.add_argument('--gpu-memory-utilization',
                        type=float,
                        default=0.9,
                        help='the fraction of GPU memory to be used for '
                        'the model executor, which can range from 0 to 1. '
                        'If unspecified, will use the default value of 0.9.')
    parser.add_argument("--enforce-eager",
                        action="store_true",
                        help="enforce eager execution")
    parser.add_argument(
        "--kv-cache-dtype",
        type=str,
        choices=["auto", "fp8"],
        default="auto",
        help=
        'Data type for kv cache storage. If "auto", will use model data type. '
        'FP8_E5M2 (without scaling) is only supported on cuda version greater '
        'than 11.8. On ROCm (AMD GPU), FP8_E4M3 is instead supported for '
        'common inference criteria.')
    parser.add_argument(
        '--quantization-param-path',
        type=str,
        default=None,
        help='Path to the JSON file containing the KV cache scaling factors. '
        'This should generally be supplied, when KV cache dtype is FP8. '
        'Otherwise, KV cache scaling factors default to 1.0, which may cause '
        'accuracy issues. FP8_E5M2 (without scaling) is only supported on '
        'cuda version greater than 11.8. On ROCm (AMD GPU), FP8_E4M3 is '
        'instead supported for common inference criteria.')
    parser.add_argument(
        "--device",
        type=str,
        default="cuda",
        choices=["cuda"],
        help='device type for vLLM execution, supporting CUDA only currently.')
    parser.add_argument(
        "--enable-prefix-caching",
        action='store_true',
        help="enable automatic prefix caching for vLLM backend.")
    parser.add_argument('--download-dir',
                        type=str,
                        default=None,
                        help='directory to download and load the weights, '
                        'default to the default cache dir of huggingface')
    args = parser.parse_args()
    if args.tokenizer is None:
        args.tokenizer = args.model

    # Validate the request source. Raise instead of assert: asserts vanish
    # under ``python -O`` and give no hint about which flag is wrong.
    if args.dataset is None:
        if args.input_len is None:
            raise ValueError("--input-len is required when --dataset is not "
                             "given.")
        if args.output_len is None:
            raise ValueError("--output-len is required when --dataset is not "
                             "given.")
    elif args.input_len is not None:
        raise ValueError("--input-len cannot be used together with --dataset.")

    # Reject option combinations that the selected backend does not use.
    if args.backend == "vllm":
        if args.hf_max_batch_size is not None:
            raise ValueError("HF max batch size is only for HF backend.")
    elif args.backend == "hf":
        if args.hf_max_batch_size is None:
            raise ValueError("HF max batch size is required for HF backend.")
        if args.quantization is not None:
            raise ValueError("Quantization is only for vLLM backend.")
    elif args.backend == "mii":
        if args.dtype != "auto":
            raise ValueError("dtype must be auto for MII backend.")
        if args.n != 1:
            raise ValueError("n must be 1 for MII backend.")
        if args.use_beam_search:
            raise ValueError("Beam search is not supported for MII backend.")
        if args.quantization is not None:
            raise ValueError("Quantization is only for vLLM backend.")
        if args.hf_max_batch_size is not None:
            raise ValueError("HF max batch size is only for HF backend.")
        if args.tokenizer != args.model:
            raise ValueError("Tokenizer must be the same as the model for MII "
                             "backend.")
    main(args)