"""Benchmark the paged attention kernel.

Times vLLM's ``paged_attention_v1`` / ``paged_attention_v2`` custom ops (or
their ``*_opt`` / ``*_opt_tc`` variants, selected through the
``vllm.envs.VLLM_USE_OPT_OP`` and ``vllm.envs.VLLM_USE_TC_PAGED_ATTN`` flags)
on randomly generated inputs and prints the mean per-launch latency.
"""
import random
import time
from typing import List, Optional

import torch

from vllm import _custom_ops as ops
from vllm.platforms import current_platform
from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, FlexibleArgumentParser,
                        create_kv_caches_with_random)
import vllm.envs as envs


# Number of blocks allocated in the KV cache; the random block tables draw
# block indices uniformly from [0, NUM_BLOCKS).
NUM_BLOCKS = 1024
# Sequence-length partition size used to size the v2 kernel's per-partition
# scratch buffers (exp_sums / max_logits / tmp_output).
# NOTE(review): presumably must match the partition size the v2 kernels were
# built with — confirm before changing.
PARTITION_SIZE = 512


@torch.inference_mode()
def main(
    version: str,
    num_seqs: int,
    seq_len: int,
    num_query_heads: int,
    num_kv_heads: int,
    head_size: int,
    use_alibi: bool,
    block_size: int,
    dtype: torch.dtype,
    seed: int,
    do_profile: bool,
    device: str = "cuda",
    kv_cache_dtype: Optional[str] = None,
) -> None:
    """Benchmark one paged-attention kernel and print its mean latency.

    Builds random queries, random block tables and a random KV cache, then
    repeatedly launches the selected kernel and prints the average wall-clock
    time per launch in microseconds.

    Args:
        version: Kernel family, "v1" or "v2". "v2" additionally allocates
            per-partition scratch buffers (exp_sums, max_logits, tmp_output).
        num_seqs: Number of sequences in the batch; the query tensor holds one
            vector per sequence and head.
        seq_len: Context length used for every sequence.
        num_query_heads: Number of query heads; must be divisible by
            ``num_kv_heads``.
        num_kv_heads: Number of key/value heads.
        head_size: Per-head size; also determines the attention scale
            ``1 / sqrt(head_size)`` passed to the kernel.
        use_alibi: If True, pass random ALiBi slopes (one per query head).
        block_size: Number of tokens per KV-cache block.
        dtype: Dtype of the query/output tensors; also forwarded to
            ``create_kv_caches_with_random``.
        seed: Seed passed to ``current_platform.seed_everything``.
        do_profile: If True, time a single launch bracketed by
            ``cudaProfilerStart`` / ``cudaProfilerStop`` instead of 100
            unprofiled launches.
        device: Device on which all tensors are allocated.
        kv_cache_dtype: KV-cache storage dtype string forwarded to both the
            cache factory and the kernel (e.g. "auto", "fp8").

    Kernel variant selection (read once, before timing):
        ``envs.VLLM_USE_OPT_OP`` truthy -> ``*_opt`` kernels, and additionally
        ``envs.VLLM_USE_TC_PAGED_ATTN`` truthy -> ``*_opt_tc`` kernels.

    Raises:
        ValueError: If ``version`` is not "v1"/"v2", or ``num_query_heads``
            is not divisible by ``num_kv_heads``.
    """
    # Fail fast, before any allocation or timing, on a bad configuration.
    if version not in ("v1", "v2"):
        raise ValueError(f"Invalid version: {version}")
    if num_query_heads % num_kv_heads != 0:
        raise ValueError("num_query_heads must be divisible by num_kv_heads")

    current_platform.seed_everything(seed)

    scale = float(1.0 / (head_size**0.5))
    query = torch.empty(num_seqs,
                        num_query_heads,
                        head_size,
                        dtype=dtype,
                        device=device)
    query.uniform_(-scale, scale)

    alibi_slopes = None
    if use_alibi:
        alibi_slopes = torch.randn(num_query_heads,
                                   dtype=torch.float,
                                   device=device)

    # Every sequence in this benchmark has the same length.
    seq_lens = [seq_len for _ in range(num_seqs)]
    max_seq_len = max(seq_lens)
    seq_lens = torch.tensor(seq_lens, dtype=torch.int, device=device)

    # Create the block tables: each sequence maps to randomly chosen blocks,
    # enough of them to hold max_seq_len tokens (ceil division).
    max_num_blocks_per_seq = (max_seq_len + block_size - 1) // block_size
    block_tables_lst: List[List[int]] = [[
        random.randint(0, NUM_BLOCKS - 1)
        for _ in range(max_num_blocks_per_seq)
    ] for _ in range(num_seqs)]
    block_tables = torch.tensor(block_tables_lst,
                                dtype=torch.int,
                                device=device)

    # Create the KV cache. NOTE(review): the third argument (1) presumably
    # requests a single layer, hence indexing [0] below — confirm.
    key_caches, value_caches = create_kv_caches_with_random(NUM_BLOCKS,
                                                            block_size,
                                                            1,
                                                            num_kv_heads,
                                                            head_size,
                                                            kv_cache_dtype,
                                                            dtype,
                                                            device=device)
    key_cache, value_cache = key_caches[0], value_caches[0]

    # Prepare for the paged attention kernel.
    output = torch.empty_like(query)
    # Using default kv_scale.
    k_scale = v_scale = 1.0

    # Arguments shared by every kernel variant, in call order after the
    # version-specific output/scratch tensors.
    common_args = (query, key_cache, value_cache, num_kv_heads, scale,
                   block_tables, seq_lens, block_size, max_seq_len,
                   alibi_slopes, kv_cache_dtype, k_scale, v_scale)

    # Resolve the kernel and its full argument tuple once, so the timed loop
    # contains only kernel launches (no per-iteration envs lookups).
    if version == "v1":
        if envs.VLLM_USE_OPT_OP:
            kernel = (ops.paged_attention_v1_opt_tc
                      if envs.VLLM_USE_TC_PAGED_ATTN else
                      ops.paged_attention_v1_opt)
        else:
            kernel = ops.paged_attention_v1
        kernel_args = (output, *common_args)
    else:  # version == "v2" (validated above)
        num_partitions = ((max_seq_len + PARTITION_SIZE - 1) // PARTITION_SIZE)
        tmp_output = torch.empty(
            size=(num_seqs, num_query_heads, num_partitions, head_size),
            dtype=output.dtype,
            device=output.device,
        )
        exp_sums = torch.empty(
            size=(num_seqs, num_query_heads, num_partitions),
            dtype=torch.float32,
            device=output.device,
        )
        max_logits = torch.empty_like(exp_sums)
        if envs.VLLM_USE_OPT_OP:
            kernel = (ops.paged_attention_v2_opt_tc
                      if envs.VLLM_USE_TC_PAGED_ATTN else
                      ops.paged_attention_v2_opt)
        else:
            kernel = ops.paged_attention_v2
        kernel_args = (output, exp_sums, max_logits, tmp_output,
                       *common_args)

    def run_cuda_benchmark(num_iters: int, profile: bool = False) -> float:
        """Launch the kernel num_iters times; return mean seconds per launch."""
        torch.cuda.synchronize()
        if profile:
            torch.cuda.cudart().cudaProfilerStart()
        start_time = time.perf_counter()

        for _ in range(num_iters):
            kernel(*kernel_args)
        # Wait for all launched work to finish before reading the clock.
        torch.cuda.synchronize()

        end_time = time.perf_counter()
        if profile:
            # Close the profiled range opened above.
            torch.cuda.cudart().cudaProfilerStop()
        return (end_time - start_time) / num_iters

    # Warmup.
    print("Warming up...")
    run_cuda_benchmark(num_iters=3, profile=False)

    # Benchmark.
    if do_profile:
        latency = run_cuda_benchmark(num_iters=1, profile=True)
    else:
        latency = run_cuda_benchmark(num_iters=100, profile=False)
    print(f"Kernel running time: {latency * 1000000:.3f} us")


if __name__ == '__main__':
    # Command-line entry point: parse the benchmark configuration and run it.
    parser = FlexibleArgumentParser(
        description="Benchmark the paged attention kernel.")
    parser.add_argument("--version",
                        type=str,
                        choices=["v1", "v2"],
                        default="v2")
    # Number of sequences in the batch (passed to main as num_seqs).
    parser.add_argument("--batch-size", type=int, default=8)
    parser.add_argument("--seq-len", type=int, default=4096)
    parser.add_argument("--num-query-heads", type=int, default=64)
    parser.add_argument("--num-kv-heads", type=int, default=8)
    parser.add_argument("--head-size",
                        type=int,
                        choices=[64, 80, 96, 112, 120, 128, 192, 256],
                        default=128)
    parser.add_argument("--block-size", type=int, choices=[16, 32], default=16)
    parser.add_argument("--use-alibi", action="store_true")
    # String name, mapped to a torch dtype via STR_DTYPE_TO_TORCH_DTYPE below.
    parser.add_argument("--dtype",
                        type=str,
                        choices=["half", "bfloat16", "float"],
                        default="half")
    parser.add_argument("--seed", type=int, default=0)
    parser.add_argument("--profile", action="store_true")
    parser.add_argument(
        "--kv-cache-dtype",
        type=str,
        choices=["auto", "fp8", "fp8_e5m2", "fp8_e4m3"],
        default="auto",
        help="Data type for kv cache storage. If 'auto', will use model "
        "data type. CUDA 11.8+ supports fp8 (=fp8_e4m3) and fp8_e5m2. "
        "ROCm (hcu) supports fp8 (=fp8_e4m3)")
    args = parser.parse_args()
    print(args)

    # Reject grouped-query configurations that cannot split heads evenly.
    if args.num_query_heads % args.num_kv_heads != 0:
        raise ValueError("num_query_heads must be divisible by num_kv_heads")
    main(
        version=args.version,
        num_seqs=args.batch_size,
        seq_len=args.seq_len,
        num_query_heads=args.num_query_heads,
        num_kv_heads=args.num_kv_heads,
        head_size=args.head_size,
        block_size=args.block_size,
        use_alibi=args.use_alibi,
        dtype=STR_DTYPE_TO_TORCH_DTYPE[args.dtype],
        seed=args.seed,
        do_profile=args.profile,
        kv_cache_dtype=args.kv_cache_dtype,
    )