benchmark_paged_attention.py 12 KB
Newer Older
1
2
# SPDX-License-Identifier: Apache-2.0

3
4
import random
import time
5
from typing import Optional
6
7
8

import torch

9
from vllm import _custom_ops as ops
10
from vllm.logger import init_logger
11
from vllm.platforms import current_platform
12
from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, FlexibleArgumentParser,
13
                        create_kv_caches_with_random)
zhuwenwen's avatar
zhuwenwen committed
14
import vllm.envs as envs
15
16


17
18
logger = init_logger(__name__)

# Number of KV-cache blocks allocated for the benchmark; block-table entries
# are drawn uniformly from [0, NUM_BLOCKS - 1].
NUM_BLOCKS: int = 128 * 1024

# Default partition length (in tokens) for the v2 kernel; ROCm runs pick a
# different value at benchmark time.
PARTITION_SIZE: int = 512

# Partition length used on ROCm when --custom-paged-attn is requested.
PARTITION_SIZE_ROCM: int = 256

def _resolve_flag(value: Optional[bool], name: str) -> bool:
    """Return a kernel-selection flag, falling back to the parsed CLI args.

    ``main`` historically read ``gc_paged_attn`` / ``tc_paged_attn`` /
    ``custom_paged_attn`` from the module-level ``args`` created by the
    ``__main__`` block. An explicit ``value`` wins. ``None`` means "look the
    attribute up on that global ``args`` if it exists", and otherwise the
    flag is ``False``. This lets ``main`` be imported and called directly.
    """
    if value is not None:
        return bool(value)
    cli_args = globals().get("args")
    return bool(getattr(cli_args, name, False))


@torch.inference_mode()
def main(
    version: str,
    num_seqs: int,
    seq_len: int,
    num_query_heads: int,
    num_kv_heads: int,
    head_size: int,
    use_alibi: bool,
    block_size: int,
    dtype: torch.dtype,
    seed: int,
    do_profile: bool,
    device: str = "cuda",
    kv_cache_dtype: Optional[str] = None,
    gc_paged_attn: Optional[bool] = None,
    tc_paged_attn: Optional[bool] = None,
    custom_paged_attn: Optional[bool] = None,
) -> None:
    """Benchmark one paged-attention kernel and print its mean latency.

    Builds random query / block-table / KV-cache inputs, then runs the
    selected kernel. It does 3 warm-up iterations followed by either 100
    timed iterations, or a single iteration bracketed by
    ``cudaProfilerStart``/``cudaProfilerStop`` when ``do_profile`` is set.
    The per-call latency is printed in microseconds.

    Args:
        version: ``"v1"`` or ``"v2"``; selects the kernel family.
        num_seqs: Number of sequences in the batch.
        seq_len: Length of every sequence (all sequences are equal length).
        num_query_heads: Number of query heads; must be a multiple of
            ``num_kv_heads``.
        num_kv_heads: Number of key/value heads.
        head_size: Per-head hidden size; the softmax scale is
            ``1 / sqrt(head_size)``.
        use_alibi: If True, random ALiBi slopes (one per query head) are
            passed to the kernel.
        block_size: Block size of the paged KV cache. Each sequence gets
            ``ceil(seq_len / block_size)`` randomly chosen blocks.
        dtype: dtype of the query/output tensors. It is also forwarded as
            the model dtype to ``create_kv_caches_with_random``.
        seed: Seed forwarded to ``current_platform.seed_everything``.
        do_profile: Time a single profiled iteration instead of 100.
        device: Torch device for every tensor.
        kv_cache_dtype: KV-cache storage dtype string, forwarded to cache
            creation and to the kernel.
        gc_paged_attn: Use the ``*_opt`` kernels. ``None`` reads
            ``args.gc_paged_attn`` if a CLI ``args`` global exists, else False.
        tc_paged_attn: With ``gc_paged_attn``, use
            ``paged_attention_v1_opt_tc``. ``None`` falls back as above.
        custom_paged_attn: For v2, use ``paged_attention_rocm``.
            ``None`` falls back as above.

    Raises:
        ValueError: If ``version`` is not ``"v1"``/``"v2"``, or
            ``num_query_heads`` is not divisible by ``num_kv_heads``.
    """
    # Validate before allocating anything. Previously a bad version was only
    # detected during warm-up, after building the (large) KV cache.
    if version not in ("v1", "v2"):
        raise ValueError(f"Invalid version: {version}")
    if num_query_heads % num_kv_heads != 0:
        raise ValueError("num_query_heads must be divisible by num_kv_heads")

    gc_paged_attn = _resolve_flag(gc_paged_attn, "gc_paged_attn")
    tc_paged_attn = _resolve_flag(tc_paged_attn, "tc_paged_attn")
    custom_paged_attn = _resolve_flag(custom_paged_attn, "custom_paged_attn")

    current_platform.seed_everything(seed)

    # NOTE: the random draws below keep their original order (query, alibi,
    # block tables, KV cache) so a given seed reproduces the same inputs.
    scale = float(1.0 / (head_size**0.5))
    query = torch.empty(num_seqs,
                        num_query_heads,
                        head_size,
                        dtype=dtype,
                        device=device)
    query.uniform_(-scale, scale)

    alibi_slopes = None
    if use_alibi:
        alibi_slopes = torch.randn(num_query_heads,
                                   dtype=torch.float,
                                   device=device)

    seq_lens_lst = [seq_len for _ in range(num_seqs)]
    max_seq_len = max(seq_lens_lst)
    seq_lens = torch.tensor(seq_lens_lst, dtype=torch.int, device=device)

    # Create the block tables: each sequence maps to randomly chosen blocks.
    max_num_blocks_per_seq = (max_seq_len + block_size - 1) // block_size
    block_tables_lst: list[list[int]] = [[
        random.randint(0, NUM_BLOCKS - 1)
        for _ in range(max_num_blocks_per_seq)
    ] for _ in range(num_seqs)]
    block_tables = torch.tensor(block_tables_lst,
                                dtype=torch.int,
                                device=device)

    # Create the KV cache (a single layer is enough for the benchmark).
    key_caches, value_caches = create_kv_caches_with_random(NUM_BLOCKS,
                                                            block_size,
                                                            1,
                                                            num_kv_heads,
                                                            head_size,
                                                            kv_cache_dtype,
                                                            dtype,
                                                            device=device)
    key_cache, value_cache = key_caches[0], value_caches[0]

    # Prepare for the paged attention kernel.
    output = torch.empty_like(query)
    if version == "v2":
        # Use a local instead of rebinding the module-level PARTITION_SIZE.
        partition_size = PARTITION_SIZE
        if current_platform.is_rocm():
            # 1024 is the ROCm default kernel's partition size, carried over
            # unchanged from the original script.
            partition_size = (PARTITION_SIZE_ROCM
                              if custom_paged_attn else 1024)
        num_partitions = ((max_seq_len + partition_size - 1) //
                          partition_size)
        tmp_output = torch.empty(
            size=(num_seqs, num_query_heads, num_partitions, head_size),
            dtype=output.dtype,
            device=output.device,
        )
        exp_sums = torch.empty(
            size=(num_seqs, num_query_heads, num_partitions),
            dtype=torch.float32,
            device=output.device,
        )
        max_logits = torch.empty_like(exp_sums)

    # Using default kv_scale. Allocated once, outside the timed region, so
    # the allocation/copy is not counted in the measured kernel latency.
    k_scale = v_scale = torch.tensor(1.0, dtype=torch.float32, device=device)

    # Positional arguments shared by every kernel, in the order the ops take.
    common_args = (query, key_cache, value_cache, num_kv_heads, scale,
                   block_tables, seq_lens, block_size, max_seq_len,
                   alibi_slopes, kv_cache_dtype, k_scale, v_scale)
    v1_args = (output, *common_args)

    # Pick the kernel once. The original re-evaluated the flags on every
    # iteration, and its v2 "gc" path also launched paged_attention_v2 after
    # the optimized kernel because of a missing `else`. That timed two
    # kernels per iteration.
    if version == "v1":
        if gc_paged_attn:
            kernel = (ops.paged_attention_v1_opt_tc
                      if tc_paged_attn else ops.paged_attention_v1_opt)
        else:
            kernel = ops.paged_attention_v1
        kernel_args = v1_args
    else:
        v2_args = (output, exp_sums, max_logits, tmp_output, *common_args)
        if custom_paged_attn:
            kernel, kernel_args = ops.paged_attention_rocm, v2_args
        elif gc_paged_attn and tc_paged_attn:
            # As in the original, the tc path calls the v1 tc kernel with
            # the v1 argument list, even for version == "v2".
            kernel, kernel_args = ops.paged_attention_v1_opt_tc, v1_args
        elif gc_paged_attn:
            kernel, kernel_args = ops.paged_attention_v2_opt, v2_args
        else:
            kernel, kernel_args = ops.paged_attention_v2, v2_args

    def run_cuda_benchmark(num_iters: int, profile: bool = False) -> float:
        """Run the selected kernel num_iters times; return mean seconds."""
        # Kernel launches are asynchronous: synchronize on both sides so the
        # wall-clock interval covers the actual GPU work.
        torch.cuda.synchronize()
        if profile:
            torch.cuda.cudart().cudaProfilerStart()
        start_time = time.perf_counter()

        for _ in range(num_iters):
            kernel(*kernel_args)
        torch.cuda.synchronize()

        end_time = time.perf_counter()
        if profile:
            torch.cuda.cudart().cudaProfilerStop()
        return (end_time - start_time) / num_iters

    # Warmup.
    print("Warming up...")
    run_cuda_benchmark(num_iters=3, profile=False)

    # Benchmark.
    if do_profile:
        latency = run_cuda_benchmark(num_iters=1, profile=True)
    else:
        latency = run_cuda_benchmark(num_iters=100, profile=False)
    print(f"Kernel running time: {latency * 1000000:.3f} us")


if __name__ == '__main__':
    logger.warning("This script benchmarks the paged attention kernel. "
                   "By default this is no longer used in vLLM inference.")

    parser = FlexibleArgumentParser(
        description="Benchmark the paged attention kernel.")

    # Workload shape.
    parser.add_argument("--version", type=str, choices=["v1", "v2"],
                        default="v2")
    parser.add_argument("--batch-size", type=int, default=8)
    parser.add_argument("--seq-len", type=int, default=4096)
    parser.add_argument("--num-query-heads", type=int, default=64)
    parser.add_argument("--num-kv-heads", type=int, default=8)
    supported_head_sizes = [64, 80, 96, 112, 120, 128, 192, 256]
    parser.add_argument("--head-size", type=int,
                        choices=supported_head_sizes, default=128)
    parser.add_argument("--block-size", type=int, choices=[16, 32], default=16)
    parser.add_argument("--use-alibi", action="store_true")

    # Numerics / run control.
    parser.add_argument("--dtype", type=str,
                        choices=["half", "bfloat16", "float"],
                        default="half")
    parser.add_argument("--seed", type=int, default=0)
    parser.add_argument("--profile", action="store_true")
    kv_cache_help = ("Data type for kv cache storage. If 'auto', will use model "
                     "data type. CUDA 11.8+ supports fp8 (=fp8_e4m3) and fp8_e5m2. "
                     "ROCm (hcu) supports fp8 (=fp8_e4m3)")
    parser.add_argument("--kv-cache-dtype", type=str,
                        choices=["auto", "fp8", "fp8_e5m2", "fp8_e4m3"],
                        default="auto", help=kv_cache_help)

    # Kernel-variant switches, registered in their original order.
    kernel_switches = (
        ("--gc-paged-attn", "Use gc paged attention"),
        ("--tc-paged-attn", "Use tc paged attention"),
        ("--custom-paged-attn", "Use custom paged attention"),
    )
    for switch, switch_help in kernel_switches:
        parser.add_argument(switch, action="store_true", help=switch_help)

    # Bound at module level on purpose: `main` may read the kernel-variant
    # switches from this global.
    args = parser.parse_args()
    print(args)

    if args.num_query_heads % args.num_kv_heads:
        raise ValueError("num_query_heads must be divisible by num_kv_heads")

    benchmark_kwargs = dict(
        version=args.version,
        num_seqs=args.batch_size,
        seq_len=args.seq_len,
        num_query_heads=args.num_query_heads,
        num_kv_heads=args.num_kv_heads,
        head_size=args.head_size,
        block_size=args.block_size,
        use_alibi=args.use_alibi,
        dtype=STR_DTYPE_TO_TORCH_DTYPE[args.dtype],
        seed=args.seed,
        do_profile=args.profile,
        kv_cache_dtype=args.kv_cache_dtype,
    )
    main(**benchmark_kwargs)