# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import random
import time
from typing import Optional

import torch

from vllm import _custom_ops as ops
from vllm.logger import init_logger
from vllm.platforms import current_platform
import vllm.envs as envs
from vllm.utils import (
    STR_DTYPE_TO_TORCH_DTYPE,
    FlexibleArgumentParser,
    create_kv_caches_with_random,
)

logger = init_logger(__name__)

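# Number of physical blocks in the synthetic KV cache, and the sequence
# partition sizes used by the paged-attention v2 kernels.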
NUM_BLOCKS = 128 * 1024
PARTITION_SIZE = 512
PARTITION_SIZE_ROCM = 256


@torch.inference_mode()
def main(
    version: str,
    num_seqs: int,
    seq_len: int,
    num_query_heads: int,
    num_kv_heads: int,
    head_size: int,
    use_alibi: bool,
    block_size: int,
    dtype: torch.dtype,
    seed: int,
    do_profile: bool,
    device: str = "cuda",
    kv_cache_dtype: Optional[str] = None,
) -> None:
    current_platform.seed_everything(seed)

    scale = float(1.0 / (head_size**0.5))
    query = torch.empty(
        num_seqs, num_query_heads, head_size, dtype=dtype, device=device
    )
    query.uniform_(-scale, scale)

    assert num_query_heads % num_kv_heads == 0
    alibi_slopes = None
    if use_alibi:
        alibi_slopes = torch.randn(num_query_heads, dtype=torch.float, device=device)

    seq_lens = [seq_len for _ in range(num_seqs)]
    max_seq_len = max(seq_lens)
    seq_lens = torch.tensor(seq_lens, dtype=torch.int, device=device)

    # Create the block tables.
    max_num_blocks_per_seq = (max_seq_len + block_size - 1) // block_size
    block_tables_lst: list[list[int]] = []
    for _ in range(num_seqs):
        block_table = [
            random.randint(0, NUM_BLOCKS - 1) for _ in range(max_num_blocks_per_seq)
        ]
        block_tables_lst.append(block_table)

    block_tables = torch.tensor(block_tables_lst, dtype=torch.int, device=device)

    # Create the KV cache.
    key_caches, value_caches = create_kv_caches_with_random(
        NUM_BLOCKS,
        block_size,
        1,
        num_kv_heads,
        head_size,
        kv_cache_dtype,
        dtype,
        device=device,
    )
    key_cache, value_cache = key_caches[0], value_caches[0]

    # Prepare for the paged attention kernel.
    output = torch.empty_like(query)
    if version == "v2":
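        # On ROCm the partition size depends on which kernel is benchmarked;
        # note that this reads the module-level `args` set in __main__.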
        if current_platform.is_rocm():
            global PARTITION_SIZE
            if not args.custom_paged_attn and not current_platform.is_navi():
                PARTITION_SIZE = 1024
            else:
                PARTITION_SIZE = PARTITION_SIZE_ROCM
        num_partitions = (max_seq_len + PARTITION_SIZE - 1) // PARTITION_SIZE
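        # The v2 kernel splits each sequence into partitions and reduces over
        # them, so it needs per-partition scratch buffers for outputs,
        # exp-sums, and max logits.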
        tmp_output = torch.empty(
            size=(num_seqs, num_query_heads, num_partitions, head_size),
            dtype=output.dtype,
            device=output.device,
        )
        exp_sums = torch.empty(
            size=(num_seqs, num_query_heads, num_partitions),
            dtype=torch.float32,
            device=output.device,
        )
        max_logits = torch.empty_like(exp_sums)

    if version == "v12":
        # Import lazily so the other code paths do not require the
        # flash_attn package, and so the import is not timed in the loop.
        from flash_attn import vllm_flash_attn_with_kvcache

        # Full attention: no sliding window and no logit soft-capping.
        sliding_window = (-1, -1)
        logits_soft_cap = 0.0

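    # Runs the selected kernel `num_iters` times and returns the mean
    # per-iteration latency in seconds.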
    def run_cuda_benchmark(num_iters: int, profile: bool = False) -> float:
        torch.cuda.synchronize()
        if profile:
            torch.cuda.cudart().cudaProfilerStart()
        start_time = time.perf_counter()

        # Using default kv_scale
        k_scale = v_scale = torch.tensor(1.0, dtype=torch.float32, device=device)

        for _ in range(num_iters):
            if version == "v1":
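                # --gc-paged-attn / --tc-paged-attn select the optimized
                # kernel variants; otherwise the stock v1 kernel is used.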
                if args.gc_paged_attn:
                    if args.tc_paged_attn:
                        ops.paged_attention_v1_opt_tc(
                            output,
                            query,
                            key_cache,
                            value_cache,
                            num_kv_heads,
                            scale,
                            block_tables,
                            seq_lens,
                            block_size,
                            max_seq_len,
                            alibi_slopes,
                            kv_cache_dtype,
                            k_scale,
                            v_scale,
                        )
                    else:
                        ops.paged_attention_v1_opt(
                            output,
                            query,
                            key_cache,
                            value_cache,
                            num_kv_heads,
                            scale,
                            block_tables,
                            seq_lens,
                            block_size,
                            max_seq_len,
                            alibi_slopes,
                            kv_cache_dtype,
                            k_scale,
                            v_scale,
                        )
                else:
                    ops.paged_attention_v1(
                        output,
                        query,
                        key_cache,
                        value_cache,
                        num_kv_heads,
                        scale,
                        block_tables,
                        seq_lens,
                        block_size,
                        max_seq_len,
                        alibi_slopes,
                        kv_cache_dtype,
                        k_scale,
                        v_scale,
                    )
            elif version == "v2":
                if not args.custom_paged_attn:
                    if args.gc_paged_attn:
                        if args.tc_paged_attn:
                            ops.paged_attention_v1_opt_tc(
                                output,
                                query,
                                key_cache,
                                value_cache,
                                num_kv_heads,
                                scale,
                                block_tables,
                                seq_lens,
                                block_size,
                                max_seq_len,
                                alibi_slopes,
                                kv_cache_dtype,
                                k_scale,
                                v_scale,
                            )
                        else:
                            ops.paged_attention_v2_opt(
                                output,
                                exp_sums,
                                max_logits,
                                tmp_output,
                                query,
                                key_cache,
                                value_cache,
                                num_kv_heads,
                                scale,
                                block_tables,
                                seq_lens,
                                block_size,
                                max_seq_len,
                                alibi_slopes,
                                kv_cache_dtype,
                                k_scale,
                                v_scale,
                            )
                    else:
                        ops.paged_attention_v2(
                            output,
                            exp_sums,
                            max_logits,
                            tmp_output,
                            query,
                            key_cache,
                            value_cache,
                            num_kv_heads,
                            scale,
                            block_tables,
                            seq_lens,
                            block_size,
                            max_seq_len,
                            alibi_slopes,
                            kv_cache_dtype,
                            k_scale,
                            v_scale,
                        )
                else:
                    ops.paged_attention_rocm(
                        output,
                        exp_sums,
                        max_logits,
                        tmp_output,
                        query,
                        key_cache,
                        value_cache,
                        num_kv_heads,
                        scale,
                        block_tables,
                        seq_lens,
                        None,
                        block_size,
                        max_seq_len,
                        alibi_slopes,
                        kv_cache_dtype,
                        k_scale,
                        v_scale,
                    )
            elif version == "v12":
                # One decode token per sequence: the kernel expects a sequence
                # dimension, hence the unsqueeze/squeeze around the call.
                vllm_flash_attn_with_kvcache(
                    q=query.unsqueeze(1),
                    k_cache=key_cache,
                    v_cache=value_cache,
                    cache_seqlens=seq_lens,
                    block_table=block_tables,
                    softmax_scale=scale,
                    causal=True,
                    window_size=sliding_window,
                    softcap=logits_soft_cap,
                    alibi_slopes=alibi_slopes,
                    return_softmax_lse=False,
                    k_scale=k_scale,
                    v_scale=v_scale,
                    kv_cache_dtype=kv_cache_dtype,
                ).squeeze(1)
            else:
                raise ValueError(f"Invalid version: {version}")
        torch.cuda.synchronize()

        end_time = time.perf_counter()
        if profile:
            torch.cuda.cudart().cudaProfilerStop()
        return (end_time - start_time) / num_iters

    # Warmup.
    print("Warming up...")
    run_benchmark = run_cuda_benchmark
    run_benchmark(num_iters=3, profile=False)

    # Benchmark.
    if do_profile:
        latency = run_benchmark(num_iters=1, profile=True)
    else:
        latency = run_benchmark(num_iters=100, profile=False)
    print(f"Kernel running time: {latency * 1000000:.3f} us")


if __name__ == "__main__":
    logger.warning(
        "This script benchmarks the paged attention kernel. "
        "By default this is no longer used in vLLM inference."
    )

    parser = FlexibleArgumentParser(
        description="Benchmark the paged attention kernel."
    )
    parser.add_argument(
        "--version", type=str, choices=["v1", "v2", "v12"], default="v12"
    )
    parser.add_argument("--batch-size", type=int, default=8)
    parser.add_argument("--seq-len", type=int, default=4096)
    parser.add_argument("--num-query-heads", type=int, default=64)
    parser.add_argument("--num-kv-heads", type=int, default=8)
    parser.add_argument(
        "--head-size",
        type=int,
        choices=[64, 80, 96, 112, 120, 128, 192, 256],
        default=128,
    )
    parser.add_argument("--block-size", type=int, choices=[16, 32, 64], default=64)
    parser.add_argument("--use-alibi", action="store_true")
    parser.add_argument(
        "--dtype", type=str, choices=["half", "bfloat16", "float"], default="half"
    )
    parser.add_argument("--seed", type=int, default=0)
    parser.add_argument("--profile", action="store_true")
    parser.add_argument(
        "--kv-cache-dtype",
        type=str,
        choices=["auto", "fp8", "fp8_e5m2", "fp8_e4m3"],
        default="auto",
        help="Data type for kv cache storage. If 'auto', will use model "
        "data type. CUDA 11.8+ supports fp8 (=fp8_e4m3) and fp8_e5m2. "
        "ROCm (hcu) supports fp8 (=fp8_e4m3).",
    )
    parser.add_argument(
        "--gc-paged-attn", action="store_true", help="Use gc paged attention"
    )
    parser.add_argument(
        "--tc-paged-attn", action="store_true", help="Use tc paged attention"
    )
    parser.add_argument(
        "--custom-paged-attn", action="store_true", help="Use custom paged attention"
    )
    args = parser.parse_args()
    print(args)

    if args.num_query_heads % args.num_kv_heads != 0:
        raise ValueError("num_query_heads must be divisible by num_kv_heads")
    main(
        version=args.version,
        num_seqs=args.batch_size,
        seq_len=args.seq_len,
        num_query_heads=args.num_query_heads,
        num_kv_heads=args.num_kv_heads,
        head_size=args.head_size,
        block_size=args.block_size,
        use_alibi=args.use_alibi,
        dtype=STR_DTYPE_TO_TORCH_DTYPE[args.dtype],
        seed=args.seed,
        do_profile=args.profile,
        kv_cache_dtype=args.kv_cache_dtype,
    )