"""Benchmark offline inference throughput."""
import argparse
import json
import random
import time
from typing import List, Optional, Tuple

import torch
from transformers import AutoModelForCausalLM, PreTrainedTokenizerBase
from tqdm import tqdm

from vllm import LLM, SamplingParams
from vllm.transformers_utils.tokenizer import get_tokenizer


def sample_requests(
    dataset_path: str,
    num_requests: int,
    tokenizer: PreTrainedTokenizerBase,
) -> List[Tuple[str, int, int]]:
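    """Sample num_requests (prompt, prompt_len, output_len) tuples from a JSON
    dataset of two-turn conversations; lengths are counted in tokenizer tokens.
    """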
    # Load the dataset.
    with open(dataset_path) as f:
        dataset = json.load(f)
    # Filter out the conversations with fewer than 2 turns.
    dataset = [data for data in dataset if len(data["conversations"]) >= 2]
    # Only keep the first two turns of each conversation.
    dataset = [(data["conversations"][0]["value"],
                data["conversations"][1]["value"]) for data in dataset]

    # Tokenize the prompts and completions.
    prompts = [prompt for prompt, _ in dataset]
    prompt_token_ids = tokenizer(prompts).input_ids
    completions = [completion for _, completion in dataset]
    completion_token_ids = tokenizer(completions).input_ids
    tokenized_dataset = []
    for i in range(len(dataset)):
        output_len = len(completion_token_ids[i])
        tokenized_dataset.append((prompts[i], prompt_token_ids[i], output_len))

    # Filter out too long sequences.
    filtered_dataset: List[Tuple[str, int, int]] = []
    for prompt, prompt_token_ids, output_len in tokenized_dataset:
        prompt_len = len(prompt_token_ids)
        if prompt_len < 4 or output_len < 4:
            # Prune too short sequences.
            continue
        if prompt_len > 1024 or prompt_len + output_len > 2048:
            # Prune too long sequences.
            continue
        filtered_dataset.append((prompt, prompt_len, output_len))

    # Sample the requests.
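    # Note: random.sample raises ValueError if fewer than num_requests
    # prompts survive the filtering above.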
    sampled_requests = random.sample(filtered_dataset, num_requests)
    return sampled_requests


def run_vllm(
    requests: List[Tuple[str, int, int]],
    model: str,
    tokenizer: str,
    quantization: Optional[str],
    tensor_parallel_size: int,
    seed: int,
    n: int,
    use_beam_search: bool,
    trust_remote_code: bool,
    dtype: str,
) -> float:
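    """Run all requests through the vLLM engine and return the elapsed time
    in seconds."""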
    llm = LLM(
        model=model,
        tokenizer=tokenizer,
        quantization=quantization,
        tensor_parallel_size=tensor_parallel_size,
        seed=seed,
        trust_remote_code=trust_remote_code,
        dtype=dtype,
    )

    # Add the requests to the engine.
    for prompt, _, output_len in requests:
        sampling_params = SamplingParams(
            n=n,
            temperature=0.0 if use_beam_search else 1.0,
            top_p=1.0,
            use_beam_search=use_beam_search,
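            # Generate exactly output_len tokens per request so the measured
            # throughput reflects the dataset's completion lengths.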
            ignore_eos=True,
            max_tokens=output_len,
        )
        # FIXME(woosuk): Do not use internal method.
        llm._add_request(
            prompt=prompt,
            prompt_token_ids=None,
            sampling_params=sampling_params,
        )

    start = time.time()
    # FIXME(woosuk): Do not use internal method.
    llm._run_engine(use_tqdm=True)
    end = time.time()
    return end - start


def run_hf(
    requests: List[Tuple[str, int, int]],
    model: str,
    tokenizer: PreTrainedTokenizerBase,
    n: int,
    use_beam_search: bool,
    max_batch_size: int,
    trust_remote_code: bool,
) -> float:
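    """Run all requests through the Hugging Face generate() API with simple
    batching and return the elapsed time in seconds."""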
    assert not use_beam_search
    llm = AutoModelForCausalLM.from_pretrained(
        model, torch_dtype=torch.float16, trust_remote_code=trust_remote_code)
    if llm.config.model_type == "llama":
        # To enable padding in the HF backend.
        tokenizer.pad_token = tokenizer.eos_token
    llm = llm.cuda()

    pbar = tqdm(total=len(requests))
    start = time.time()
    batch: List[str] = []
    max_prompt_len = 0
    max_output_len = 0
    for i in range(len(requests)):
        prompt, prompt_len, output_len = requests[i]
        # Add the prompt to the batch.
        batch.append(prompt)
        max_prompt_len = max(max_prompt_len, prompt_len)
        max_output_len = max(max_output_len, output_len)
        if len(batch) < max_batch_size and i != len(requests) - 1:
            # Check if we can add more requests to the batch.
            _, next_prompt_len, next_output_len = requests[i + 1]
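            # Keep growing the batch only while the padded batch (longest
            # prompt plus longest output) fits within a 2048-token window.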
            if (max(max_prompt_len, next_prompt_len) +
                    max(max_output_len, next_output_len)) <= 2048:
                # We can add more requests to the batch.
                continue

        # Generate the sequences.
        input_ids = tokenizer(batch, return_tensors="pt",
                              padding=True).input_ids
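        # Beam search is unsupported for the HF backend (asserted above), so
        # generation always samples here.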
        llm_outputs = llm.generate(
            input_ids=input_ids.cuda(),
            do_sample=not use_beam_search,
            num_return_sequences=n,
            temperature=1.0,
            top_p=1.0,
            use_cache=True,
            max_new_tokens=max_output_len,
        )
        # Include the decoding time.
        tokenizer.batch_decode(llm_outputs, skip_special_tokens=True)
        pbar.update(len(batch))

        # Clear the batch.
        batch = []
        max_prompt_len = 0
        max_output_len = 0
    end = time.time()
    return end - start


def main(args: argparse.Namespace):
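    """Build the benchmark requests, run the selected backend, and report
    request and token throughput."""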
    print(args)
    random.seed(args.seed)

    # Sample the requests.
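    # This tokenizer is used for dataset sampling and by the HF backend; the
    # vLLM backend loads its own tokenizer from args.tokenizer.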
    tokenizer = get_tokenizer(args.tokenizer,
                              trust_remote_code=args.trust_remote_code)
    requests = sample_requests(args.dataset, args.num_prompts, tokenizer)

    if args.backend == "vllm":
        elapsed_time = run_vllm(requests, args.model, args.tokenizer,
                                args.quantization, args.tensor_parallel_size,
                                args.seed, args.n, args.use_beam_search,
                                args.trust_remote_code, args.dtype)
    elif args.backend == "hf":
        assert args.tensor_parallel_size == 1
        elapsed_time = run_hf(requests, args.model, tokenizer, args.n,
                              args.use_beam_search, args.hf_max_batch_size,
                              args.trust_remote_code)
    else:
        raise ValueError(f"Unknown backend: {args.backend}")
    total_num_tokens = sum(prompt_len + output_len
                           for _, prompt_len, output_len in requests)
    print(f"Throughput: {len(requests) / elapsed_time:.2f} requests/s, "
          f"{total_num_tokens / elapsed_time:.2f} tokens/s")


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Benchmark the throughput.")
    parser.add_argument("--backend",
                        type=str,
                        choices=["vllm", "hf"],
                        default="vllm")
    parser.add_argument("--dataset",
                        type=str,
                        required=True,
                        help="Path to the dataset.")
    parser.add_argument("--model", type=str, default="facebook/opt-125m")
    parser.add_argument("--tokenizer", type=str, default=None)
    parser.add_argument('--quantization',
                        '-q',
                        choices=['awq', None],
                        default=None)
    parser.add_argument("--tensor-parallel-size", "-tp", type=int, default=1)
    parser.add_argument("--n",
                        type=int,
                        default=1,
                        help="Number of generated sequences per prompt.")
    parser.add_argument("--use-beam-search", action="store_true")
    parser.add_argument("--num-prompts",
                        type=int,
                        default=1000,
                        help="Number of prompts to process.")
    parser.add_argument("--seed", type=int, default=0)
    parser.add_argument("--hf-max-batch-size",
                        type=int,
                        default=None,
                        help="Maximum batch size for HF backend.")
    parser.add_argument('--trust-remote-code',
                        action='store_true',
                        help='Trust remote code from Hugging Face.')
    parser.add_argument(
        '--dtype',
        type=str,
        default='auto',
        choices=['auto', 'half', 'float16', 'bfloat16', 'float', 'float32'],
        help='data type for model weights and activations. '
        'The "auto" option will use FP16 precision '
        'for FP32 and FP16 models, and BF16 precision '
        'for BF16 models.')
    args = parser.parse_args()

    if args.backend == "vllm":
        if args.hf_max_batch_size is not None:
            raise ValueError("HF max batch size is only for HF backend.")
    elif args.backend == "hf":
        if args.hf_max_batch_size is None:
            raise ValueError("HF max batch size is required for HF backend.")
        if args.quantization is not None:
            raise ValueError("Quantization is only for vLLM backend.")
    if args.tokenizer is None:
        args.tokenizer = args.model

    main(args)