# SPDX-License-Identifier: Apache-2.0

import random
import time
from typing import Optional

import torch

from vllm import _custom_ops as ops
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, FlexibleArgumentParser,
                        create_kv_caches_with_random)

logger = init_logger(__name__)

# Number of blocks requested from create_kv_caches_with_random; it is also the
# exclusive upper bound of the random block ids placed in the block tables.
NUM_BLOCKS = 128 * 1024
# Partition length that max_seq_len is divided by (rounding up) to size the
# temporary buffers of the v2 kernel.
PARTITION_SIZE = 512
# Partition length used for the v2 path on ROCm when --custom-paged-attn is
# given.
PARTITION_SIZE_ROCM = 256


@torch.inference_mode()
def main(
    version: str,
    num_seqs: int,
    seq_len: int,
    num_query_heads: int,
    num_kv_heads: int,
    head_size: int,
    use_alibi: bool,
    block_size: int,
    dtype: torch.dtype,
    seed: int,
    do_profile: bool,
    device: str = "cuda",
    kv_cache_dtype: Optional[str] = None,
    custom_paged_attn: bool = False,
) -> None:
    """Time one paged attention kernel on randomly generated inputs.

    Builds a random query, random block tables and a random KV cache, runs
    three warm-up iterations, then prints the mean wall-clock time per kernel
    call in microseconds.

    Args:
        version: Kernel to run, either "v1" or "v2".
        num_seqs: Number of sequences (first dimension of the query).
        seq_len: Length assigned to every sequence.
        num_query_heads: Number of query heads; must be a multiple of
            ``num_kv_heads``.
        num_kv_heads: Number of key/value heads.
        head_size: Size of each attention head.
        use_alibi: When true, random slopes (one per query head) are passed
            to the kernel; otherwise ``None`` is passed.
        block_size: Number of tokens per KV-cache block.
        dtype: Data type of the query and output tensors; also forwarded to
            ``create_kv_caches_with_random``.
        seed: Value handed to ``current_platform.seed_everything``.
        do_profile: When true, run a single iteration bracketed by
            ``cudaProfilerStart``/``cudaProfilerStop`` instead of 100 timed
            iterations.
        device: Device on which all tensors are created.
        kv_cache_dtype: Forwarded unchanged to
            ``create_kv_caches_with_random`` and to the kernel.
        custom_paged_attn: For "v2", call ``ops.paged_attention_rocm``
            instead of ``ops.paged_attention_v2``; on ROCm this also selects
            ``PARTITION_SIZE_ROCM`` as the partition length.

    Raises:
        ValueError: If ``version`` is not "v1" or "v2", or ``num_seqs`` is
            smaller than 1.
    """
    # Reject bad arguments before any device memory is allocated.
    if version not in ("v1", "v2"):
        raise ValueError(f"Invalid version: {version}")
    if num_seqs < 1:
        raise ValueError("num_seqs must be at least 1")
    assert num_query_heads % num_kv_heads == 0

    current_platform.seed_everything(seed)

    scale = float(1.0 / (head_size**0.5))
    query = torch.empty(num_seqs,
                        num_query_heads,
                        head_size,
                        dtype=dtype,
                        device=device)
    query.uniform_(-scale, scale)

    alibi_slopes = None
    if use_alibi:
        alibi_slopes = torch.randn(num_query_heads,
                                   dtype=torch.float,
                                   device=device)

    # Every sequence has the same length, so the maximum is seq_len itself.
    max_seq_len = seq_len
    seq_lens = torch.full((num_seqs, ),
                          seq_len,
                          dtype=torch.int,
                          device=device)

    # Create the block tables: random block ids, one row per sequence.
    max_num_blocks_per_seq = (max_seq_len + block_size - 1) // block_size
    block_tables_lst: list[list[int]] = [[
        random.randint(0, NUM_BLOCKS - 1)
        for _ in range(max_num_blocks_per_seq)
    ] for _ in range(num_seqs)]
    block_tables = torch.tensor(block_tables_lst,
                                dtype=torch.int,
                                device=device)

    # Create the KV cache (a single layer is requested).
    key_caches, value_caches = create_kv_caches_with_random(NUM_BLOCKS,
                                                            block_size,
                                                            1,
                                                            num_kv_heads,
                                                            head_size,
                                                            kv_cache_dtype,
                                                            dtype,
                                                            device=device)
    key_cache, value_cache = key_caches[0], value_caches[0]

    # Using default kv_scale. Created once here so that the allocation is not
    # counted in the measured kernel time.
    k_scale = v_scale = torch.tensor(1.0, dtype=torch.float32, device=device)

    # Prepare for the paged attention kernel.
    output = torch.empty_like(query)
    # Trailing positional arguments shared by every kernel variant.
    common_args = (
        query,
        key_cache,
        value_cache,
        num_kv_heads,
        scale,
        block_tables,
        seq_lens,
        block_size,
        max_seq_len,
        alibi_slopes,
        kv_cache_dtype,
        k_scale,
        v_scale,
    )
    if version == "v1":
        kernel = ops.paged_attention_v1
        kernel_args = (output, *common_args)
    else:
        # Pick the partition length locally instead of rewriting the
        # module-level constant.
        if current_platform.is_rocm():
            partition_size = (PARTITION_SIZE_ROCM
                              if custom_paged_attn else 1024)
        else:
            partition_size = PARTITION_SIZE
        num_partitions = (max_seq_len + partition_size - 1) // partition_size
        tmp_output = torch.empty(
            size=(num_seqs, num_query_heads, num_partitions, head_size),
            dtype=output.dtype,
            device=output.device,
        )
        exp_sums = torch.empty(
            size=(num_seqs, num_query_heads, num_partitions),
            dtype=torch.float32,
            device=output.device,
        )
        max_logits = torch.empty_like(exp_sums)
        kernel = (ops.paged_attention_rocm
                  if custom_paged_attn else ops.paged_attention_v2)
        kernel_args = (output, exp_sums, max_logits, tmp_output,
                       *common_args)

    def run_cuda_benchmark(num_iters: int, profile: bool = False) -> float:
        """Return the mean wall-clock seconds per call over num_iters."""
        # Drain pending work so it is not attributed to the timed region.
        torch.cuda.synchronize()
        if profile:
            torch.cuda.cudart().cudaProfilerStart()
        start_time = time.perf_counter()

        for _ in range(num_iters):
            kernel(*kernel_args)
        # Wait for the queued kernels to finish before reading the clock.
        torch.cuda.synchronize()

        end_time = time.perf_counter()
        if profile:
            torch.cuda.cudart().cudaProfilerStop()
        return (end_time - start_time) / num_iters

    # Warmup.
    print("Warming up...")
    run_cuda_benchmark(num_iters=3, profile=False)

    # Benchmark.
    if do_profile:
        latency = run_cuda_benchmark(num_iters=1, profile=True)
    else:
        latency = run_cuda_benchmark(num_iters=100, profile=False)
    print(f"Kernel running time: {latency * 1000000:.3f} us")


if __name__ == '__main__':
    logger.warning("This script benchmarks the paged attention kernel. "
                   "By default this is no longer used in vLLM inference.")

    parser = FlexibleArgumentParser(
        description="Benchmark the paged attention kernel.")
    parser.add_argument("--version",
                        type=str,
                        choices=["v1", "v2"],
                        default="v2")
    parser.add_argument("--batch-size", type=int, default=8)
    parser.add_argument("--seq-len", type=int, default=4096)
    parser.add_argument("--num-query-heads", type=int, default=64)
    parser.add_argument("--num-kv-heads", type=int, default=8)
    parser.add_argument("--head-size",
                        type=int,
                        choices=[64, 80, 96, 112, 120, 128, 192, 256],
                        default=128)
    parser.add_argument("--block-size", type=int, choices=[16, 32], default=16)
    parser.add_argument("--use-alibi", action="store_true")
    parser.add_argument("--dtype",
                        type=str,
                        choices=["half", "bfloat16", "float"],
                        default="half")
    parser.add_argument("--seed", type=int, default=0)
    parser.add_argument("--profile", action="store_true")
    parser.add_argument(
        "--kv-cache-dtype",
        type=str,
        choices=["auto", "fp8", "fp8_e5m2", "fp8_e4m3"],
        default="auto",
        help="Data type for kv cache storage. If 'auto', will use model "
        "data type. CUDA 11.8+ supports fp8 (=fp8_e4m3) and fp8_e5m2. "
        "ROCm (AMD GPU) supports fp8 (=fp8_e4m3)")
    parser.add_argument("--custom-paged-attn",
                        action="store_true",
                        help="Use custom paged attention")
    args = parser.parse_args()
    print(args)

    if args.num_query_heads % args.num_kv_heads != 0:
        raise ValueError("num_query_heads must be divisible by num_kv_heads")
    main(
        version=args.version,
        num_seqs=args.batch_size,
        seq_len=args.seq_len,
        num_query_heads=args.num_query_heads,
        num_kv_heads=args.num_kv_heads,
        head_size=args.head_size,
        block_size=args.block_size,
        use_alibi=args.use_alibi,
        dtype=STR_DTYPE_TO_TORCH_DTYPE[args.dtype],
        seed=args.seed,
        do_profile=args.profile,
        kv_cache_dtype=args.kv_cache_dtype,
        # Passed explicitly so main() does not depend on the global `args`.
        custom_paged_attn=args.custom_paged_attn,
    )