# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

5
import random
import time
6
from typing import Optional
7
8
9

import torch

10
from vllm import _custom_ops as ops
11
from vllm.logger import init_logger
12
from vllm.platforms import current_platform
13
import vllm.envs as envs
14

15
16
17
18
19
from vllm.utils import (
    STR_DTYPE_TO_TORCH_DTYPE,
    FlexibleArgumentParser,
    create_kv_caches_with_random,
)
20

21
22
logger = init_logger(__name__)

# Number of blocks allocated in the benchmark KV cache; every block-table
# entry is drawn uniformly from [0, NUM_BLOCKS - 1].
NUM_BLOCKS = 128 * 1024
# Default partition size used to size the "v2" kernel's per-partition
# scratch buffers (tmp_output / exp_sums / max_logits).
PARTITION_SIZE = 512
# Partition size used on ROCm when custom paged attention is requested or
# current_platform.is_navi() is true.
PARTITION_SIZE_ROCM = 256


@torch.inference_mode()
def main(
    version: str,
    num_seqs: int,
    seq_len: int,
    num_query_heads: int,
    num_kv_heads: int,
    head_size: int,
    use_alibi: bool,
    block_size: int,
    dtype: torch.dtype,
    seed: int,
    do_profile: bool,
    device: str = "cuda",
    kv_cache_dtype: Optional[str] = None,
    custom_paged_attn: bool = False,
    gc_paged_attn: bool = False,
    tc_paged_attn: bool = False,
) -> None:
    """Benchmark one paged-attention kernel and print its mean latency (us).

    Args:
        version: Kernel family: "v1", "v2", or "v12"
            (``flash_attn.vllm_flash_attn_with_kvcache``; requires flash_attn).
        num_seqs: Number of sequences; each contributes one query row.
        seq_len: Context length used for every sequence.
        num_query_heads: Number of query heads; must be divisible by
            ``num_kv_heads``.
        num_kv_heads: Number of key/value heads.
        head_size: Per-head hidden size.
        use_alibi: If True, pass random ALiBi slopes to the kernel.
        block_size: Tokens per KV-cache block.
        dtype: Dtype of the query and of the KV cache's model dtype.
        seed: Seed forwarded to ``current_platform.seed_everything``.
        do_profile: If True, run a single iteration bracketed by
            ``cudaProfilerStart``/``cudaProfilerStop`` instead of 100 timed
            iterations.
        device: Device on which all tensors are allocated.
        kv_cache_dtype: KV-cache storage dtype string forwarded to the kernels.
        custom_paged_attn: For "v2", call ``ops.paged_attention_rocm``.
        gc_paged_attn: For "v1"/"v2" (non-custom), use the ``*_opt`` kernels.
        tc_paged_attn: Together with ``gc_paged_attn``, use
            ``ops.paged_attention_v1_opt_tc``.

    Raises:
        ValueError: If ``version`` is unsupported or ``num_query_heads`` is
            not divisible by ``num_kv_heads``.
    """
    # Validate before allocating the (large) KV cache.
    if version not in ("v1", "v2", "v12"):
        raise ValueError(f"Invalid version: {version}")
    if num_query_heads % num_kv_heads != 0:
        raise ValueError("num_query_heads must be divisible by num_kv_heads")

    current_platform.seed_everything(seed)

    scale = float(1.0 / (head_size**0.5))
    query = torch.empty(
        num_seqs, num_query_heads, head_size, dtype=dtype, device=device
    )
    query.uniform_(-scale, scale)

    alibi_slopes = None
    if use_alibi:
        alibi_slopes = torch.randn(num_query_heads, dtype=torch.float, device=device)

    seq_lens_lst = [seq_len] * num_seqs
    max_seq_len = max(seq_lens_lst)
    seq_lens = torch.tensor(seq_lens_lst, dtype=torch.int, device=device)

    # Create the block tables: random physical block ids for each logical block.
    max_num_blocks_per_seq = (max_seq_len + block_size - 1) // block_size
    block_tables_lst: list[list[int]] = [
        [random.randint(0, NUM_BLOCKS - 1) for _ in range(max_num_blocks_per_seq)]
        for _ in range(num_seqs)
    ]
    block_tables = torch.tensor(block_tables_lst, dtype=torch.int, device=device)

    # Create the KV cache (a single layer).
    key_caches, value_caches = create_kv_caches_with_random(
        NUM_BLOCKS,
        block_size,
        1,
        num_kv_heads,
        head_size,
        kv_cache_dtype,
        dtype,
        device=device,
    )
    key_cache, value_cache = key_caches[0], value_caches[0]

    # Prepare for the paged attention kernel.
    output = torch.empty_like(query)
    # Default (unit) kv scales; allocated once so the device allocation/copy
    # is not part of the timed region.
    k_scale = v_scale = torch.tensor(1.0, dtype=torch.float32, device=device)

    v1_args = (
        output, query, key_cache, value_cache, num_kv_heads, scale,
        block_tables, seq_lens, block_size, max_seq_len, alibi_slopes,
        kv_cache_dtype, k_scale, v_scale,
    )  # fmt: skip
    kernel_kwargs: dict[str, object] = {}

    # Resolve the kernel and its arguments once so the timed loop only
    # contains kernel launches.
    if version == "v1":
        if gc_paged_attn:
            kernel = (
                ops.paged_attention_v1_opt_tc
                if tc_paged_attn
                else ops.paged_attention_v1_opt
            )
        else:
            kernel = ops.paged_attention_v1
        kernel_args: tuple = v1_args
    elif version == "v2":
        # Local copy instead of mutating the module-level PARTITION_SIZE.
        partition_size = PARTITION_SIZE
        if current_platform.is_rocm():
            if not custom_paged_attn and not current_platform.is_navi():
                partition_size = 1024
            else:
                partition_size = PARTITION_SIZE_ROCM
        num_partitions = (max_seq_len + partition_size - 1) // partition_size
        tmp_output = torch.empty(
            size=(num_seqs, num_query_heads, num_partitions, head_size),
            dtype=output.dtype,
            device=output.device,
        )
        exp_sums = torch.empty(
            size=(num_seqs, num_query_heads, num_partitions),
            dtype=torch.float32,
            device=output.device,
        )
        max_logits = torch.empty_like(exp_sums)
        v2_args = (
            output, exp_sums, max_logits, tmp_output, query, key_cache,
            value_cache, num_kv_heads, scale, block_tables, seq_lens,
            block_size, max_seq_len, alibi_slopes, kv_cache_dtype,
            k_scale, v_scale,
        )  # fmt: skip
        if custom_paged_attn:
            kernel = ops.paged_attention_rocm
            # Same as v2_args but with ``None`` passed right after seq_lens.
            kernel_args = (
                output, exp_sums, max_logits, tmp_output, query, key_cache,
                value_cache, num_kv_heads, scale, block_tables, seq_lens,
                None, block_size, max_seq_len, alibi_slopes, kv_cache_dtype,
                k_scale, v_scale,
            )  # fmt: skip
        elif gc_paged_attn and tc_paged_attn:
            # NOTE(review): kept from the original script — the TC path calls
            # the v1 kernel even for "v2" (presumably no v2 TC kernel); confirm.
            kernel, kernel_args = ops.paged_attention_v1_opt_tc, v1_args
        elif gc_paged_attn:
            kernel, kernel_args = ops.paged_attention_v2_opt, v2_args
        else:
            # Previously this also ran after the gc branch (missing ``else``),
            # so --gc-paged-attn timed two kernels per iteration.
            kernel, kernel_args = ops.paged_attention_v2, v2_args
    else:  # "v12"
        # Imported lazily so the v1/v2 benchmarks do not require flash_attn.
        from flash_attn import vllm_flash_attn_with_kvcache

        kernel, kernel_args = vllm_flash_attn_with_kvcache, ()
        kernel_kwargs = {
            "q": query.unsqueeze(1),
            "k_cache": key_cache,
            "v_cache": value_cache,
            "cache_seqlens": seq_lens,
            "block_table": block_tables,
            "softmax_scale": scale,
            "causal": True,
            # (-1, -1) / 0.0: flash-attn's "no sliding window" / "no softcap".
            "window_size": (-1, -1),
            "softcap": 0.0,
            "alibi_slopes": alibi_slopes,
            "return_softmax_lse": False,
            "k_scale": k_scale,
            "v_scale": v_scale,
            "kv_cache_dtype": kv_cache_dtype,
        }

    def run_cuda_benchmark(num_iters: int, profile: bool = False) -> float:
        """Return mean wall-clock seconds per launch over ``num_iters`` runs."""
        torch.cuda.synchronize()
        if profile:
            torch.cuda.cudart().cudaProfilerStart()
        start_time = time.perf_counter()
        for _ in range(num_iters):
            kernel(*kernel_args, **kernel_kwargs)
        torch.cuda.synchronize()
        end_time = time.perf_counter()
        if profile:
            torch.cuda.cudart().cudaProfilerStop()
        return (end_time - start_time) / num_iters

    # Warmup.
    print("Warming up...")
    run_cuda_benchmark(num_iters=3, profile=False)

    # Benchmark.
    if do_profile:
        latency = run_cuda_benchmark(num_iters=1, profile=True)
    else:
        latency = run_cuda_benchmark(num_iters=100, profile=False)
    print(f"Kernel running time: {latency * 1000000:.3f} us")


if __name__ == "__main__":
    logger.warning(
        "This script benchmarks the paged attention kernel. "
        "By default this is no longer used in vLLM inference."
    )

    parser = FlexibleArgumentParser(description="Benchmark the paged attention kernel.")
    parser.add_argument(
        "--version", type=str, choices=["v1", "v2", "v12"], default="v12"
    )
    parser.add_argument("--batch-size", type=int, default=8)
    parser.add_argument("--seq-len", type=int, default=4096)
    parser.add_argument("--num-query-heads", type=int, default=64)
    parser.add_argument("--num-kv-heads", type=int, default=8)
    parser.add_argument(
        "--head-size",
        type=int,
        choices=[64, 80, 96, 112, 120, 128, 192, 256],
        default=128,
    )
    parser.add_argument("--block-size", type=int, choices=[16, 32, 64], default=64)
    parser.add_argument("--use-alibi", action="store_true")
    parser.add_argument(
        "--dtype", type=str, choices=["half", "bfloat16", "float"], default="half"
    )
    parser.add_argument("--seed", type=int, default=0)
    parser.add_argument("--profile", action="store_true")
    parser.add_argument(
        "--kv-cache-dtype",
        type=str,
        choices=["auto", "fp8", "fp8_e5m2", "fp8_e4m3"],
        default="auto",
        help="Data type for kv cache storage. If 'auto', will use model "
        "data type. CUDA 11.8+ supports fp8 (=fp8_e4m3) and fp8_e5m2. "
        "ROCm (hcu) supports fp8 (=fp8_e4m3)",
    )
    parser.add_argument(
        "--gc-paged-attn", action="store_true", help="Use gc paged attention"
    )
    parser.add_argument(
        "--tc-paged-attn", action="store_true", help="Use tc paged attention"
    )
    parser.add_argument(
        "--custom-paged-attn", action="store_true", help="Use custom paged attention"
    )
    args = parser.parse_args()
    print(args)

    if args.num_query_heads % args.num_kv_heads != 0:
        raise ValueError("num_query_heads must be divisible by num_kv_heads")
    if args.tc_paged_attn and not args.gc_paged_attn:
        # main() only consults tc_paged_attn inside the gc_paged_attn branch.
        logger.warning("--tc-paged-attn has no effect without --gc-paged-attn.")
    main(
        version=args.version,
        num_seqs=args.batch_size,
        seq_len=args.seq_len,
        num_query_heads=args.num_query_heads,
        num_kv_heads=args.num_kv_heads,
        head_size=args.head_size,
        block_size=args.block_size,
        use_alibi=args.use_alibi,
        dtype=STR_DTYPE_TO_TORCH_DTYPE[args.dtype],
        seed=args.seed,
        do_profile=args.profile,
        kv_cache_dtype=args.kv_cache_dtype,
        custom_paged_attn=args.custom_paged_attn,
        gc_paged_attn=args.gc_paged_attn,
        tc_paged_attn=args.tc_paged_attn,
    )