"""Benchmark offline inference throughput."""
import argparse
import json
import random
import time
from typing import List, Tuple

import torch
from transformers import AutoModelForCausalLM, PreTrainedTokenizerBase
from tqdm import tqdm

from vllm import LLM, SamplingParams
from vllm.transformers_utils.tokenizer import get_tokenizer


def sample_requests(
    dataset_path: str,
    num_requests: int,
    tokenizer: PreTrainedTokenizerBase,
) -> List[Tuple[str, int, int]]:
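    """Sample num_requests (prompt, prompt_len, output_len) tuples from a
    ShareGPT-style JSON dataset, using the length of each conversation's
    first reply as the target output length."""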
    # Load the dataset.
    with open(dataset_path) as f:
        dataset = json.load(f)
    # Filter out the conversations with less than 2 turns.
    dataset = [
        data for data in dataset
        if len(data["conversations"]) >= 2
    ]
    # Only keep the first two turns of each conversation.
    dataset = [
        (data["conversations"][0]["value"], data["conversations"][1]["value"])
        for data in dataset
    ]

    # Tokenize the prompts and completions.
    prompts = [prompt for prompt, _ in dataset]
    prompt_token_ids = tokenizer(prompts).input_ids
    completions = [completion for _, completion in dataset]
    completion_token_ids = tokenizer(completions).input_ids
    tokenized_dataset = []
    for i in range(len(dataset)):
        output_len = len(completion_token_ids[i])
        tokenized_dataset.append((prompts[i], prompt_token_ids[i], output_len))

    # Filter out too long sequences.
    filtered_dataset: List[Tuple[str, int, int]] = []
    for prompt, prompt_token_ids, output_len in tokenized_dataset:
        prompt_len = len(prompt_token_ids)
        if prompt_len < 4 or output_len < 4:
            # Prune too short sequences.
            continue
        if prompt_len > 1024 or prompt_len + output_len > 2048:
            # Prune too long sequences.
            continue
        filtered_dataset.append((prompt, prompt_len, output_len))

    # Sample the requests.
    sampled_requests = random.sample(filtered_dataset, num_requests)
    return sampled_requests


def run_vllm(
    requests: List[Tuple[str, int, int]],
    model: str,
    tokenizer: str,
    tensor_parallel_size: int,
    seed: int,
    n: int,
    use_beam_search: bool,
    trust_remote_code: bool,
) -> float:
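    """Benchmark vLLM: enqueue all requests, run the engine to completion,
    and return the elapsed wall-clock time in seconds. Engine construction
    and request enqueueing are excluded from the measurement."""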
    llm = LLM(
        model=model,
        tokenizer=tokenizer,
        tensor_parallel_size=tensor_parallel_size,
        seed=seed,
        trust_remote_code=trust_remote_code,
    )

    # Add the requests to the engine.
    for prompt, _, output_len in requests:
        sampling_params = SamplingParams(
            n=n,
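            # Beam search decodes greedily (temperature 0.0); otherwise
            # sample at temperature 1.0 with full top-p.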
            temperature=0.0 if use_beam_search else 1.0,
            top_p=1.0,
            use_beam_search=use_beam_search,
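            # Ignore EOS so every request generates exactly output_len
            # tokens, matching the dataset's completion lengths.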
            ignore_eos=True,
            max_tokens=output_len,
        )
        # FIXME(woosuk): Do not use internal method.
        llm._add_request(
            prompt=prompt,
            prompt_token_ids=None,
            sampling_params=sampling_params,
        )

    start = time.time()
    # FIXME(woosuk): Do not use internal method.
    llm._run_engine(use_tqdm=True)
    end = time.time()
    return end - start


def run_hf(
    requests: List[Tuple[str, int, int]],
    model: str,
    tokenizer: PreTrainedTokenizerBase,
    n: int,
    use_beam_search: bool,
    max_batch_size: int,
    trust_remote_code: bool,
) -> float:
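    """Benchmark Hugging Face generate() with static batching and return
    the elapsed wall-clock time in seconds. Model loading is excluded."""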
    assert not use_beam_search
    llm = AutoModelForCausalLM.from_pretrained(
        model, torch_dtype=torch.float16, trust_remote_code=trust_remote_code)
    if llm.config.model_type == "llama":
        # Llama's tokenizer has no pad token by default; reuse the EOS token
        # to enable padding in the HF backend.
        tokenizer.pad_token = tokenizer.eos_token
    llm = llm.cuda()

    pbar = tqdm(total=len(requests))
    start = time.time()
    batch: List[str] = []
    max_prompt_len = 0
    max_output_len = 0
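    # Greedily grow each batch: stop adding requests once padding every
    # sequence to the batch's max prompt length, plus generating up to the
    # batch's max output length, would exceed the 2048-token context.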
    for i in range(len(requests)):
        prompt, prompt_len, output_len = requests[i]
        # Add the prompt to the batch.
        batch.append(prompt)
        max_prompt_len = max(max_prompt_len, prompt_len)
        max_output_len = max(max_output_len, output_len)
        if len(batch) < max_batch_size and i != len(requests) - 1:
            # Check if we can add more requests to the batch.
            _, next_prompt_len, next_output_len = requests[i + 1]
            if (max(max_prompt_len, next_prompt_len) + max(
                max_output_len, next_output_len)) <= 2048:
                # We can add more requests to the batch.
                continue

        # Generate the sequences.
        input_ids = tokenizer(batch, return_tensors="pt", padding=True).input_ids
        llm_outputs = llm.generate(
            input_ids=input_ids.cuda(),
            do_sample=not use_beam_search,
            num_return_sequences=n,
            temperature=1.0,
            top_p=1.0,
            use_cache=True,
            max_new_tokens=max_output_len,
        )
        # Include the decoding time.
        tokenizer.batch_decode(llm_outputs, skip_special_tokens=True)
        pbar.update(len(batch))

        # Clear the batch.
        batch = []
        max_prompt_len = 0
        max_output_len = 0
    end = time.time()
    return end - start


def main(args: argparse.Namespace):
    print(args)
    random.seed(args.seed)

    # Sample the requests.
    tokenizer = get_tokenizer(args.tokenizer,
                              trust_remote_code=args.trust_remote_code)
    requests = sample_requests(args.dataset, args.num_prompts, tokenizer)

    if args.backend == "vllm":
        elapsed_time = run_vllm(
            requests, args.model, args.tokenizer, args.tensor_parallel_size,
            args.seed, args.n, args.use_beam_search, args.trust_remote_code)
    elif args.backend == "hf":
        assert args.tensor_parallel_size == 1
        elapsed_time = run_hf(requests, args.model, tokenizer, args.n,
                              args.use_beam_search, args.hf_max_batch_size,
                              args.trust_remote_code)
    else:
        raise ValueError(f"Unknown backend: {args.backend}")
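    # Throughput counts both prompt and generated tokens.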
    total_num_tokens = sum(
        prompt_len + output_len
        for _, prompt_len, output_len in requests
    )
    print(f"Throughput: {len(requests) / elapsed_time:.2f} requests/s, "
          f"{total_num_tokens / elapsed_time:.2f} tokens/s")


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Benchmark the throughput.")
    parser.add_argument("--backend", type=str, choices=["vllm", "hf"],
                        default="vllm")
    parser.add_argument("--dataset", type=str, required=True,
                        help="Path to the dataset.")
    parser.add_argument("--model", type=str, default="facebook/opt-125m")
    parser.add_argument("--tokenizer", type=str, default=None)
    parser.add_argument("--tensor-parallel-size", "-tp", type=int, default=1)
    parser.add_argument("--n", type=int, default=1,
                        help="Number of generated sequences per prompt.")
    parser.add_argument("--use-beam-search", action="store_true")
    parser.add_argument("--num-prompts", type=int, default=1000,
                        help="Number of prompts to process.")
    parser.add_argument("--seed", type=int, default=0)
    parser.add_argument("--hf-max-batch-size", type=int, default=None,
                        help="Maximum batch size for HF backend.")
    parser.add_argument("--trust-remote-code",
                        action="store_true",
                        help="Trust remote code from Hugging Face.")
    args = parser.parse_args()

    if args.backend == "vllm":
        if args.hf_max_batch_size is not None:
            raise ValueError("HF max batch size is only for HF backend.")
    elif args.backend == "hf":
        if args.hf_max_batch_size is None:
            raise ValueError("HF max batch size is required for HF backend.")
    if args.tokenizer is None:
        args.tokenizer = args.model

    main(args)