benchmark_paged_attention.py 8.09 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
5
import random
import time
6
from typing import Optional
7
8
9

import torch

10
from vllm import _custom_ops as ops
11
from vllm.logger import init_logger
12
from vllm.platforms import current_platform
zhuwenwen's avatar
zhuwenwen committed
13
import vllm.envs as envs
14

15
16
17
18
19
from vllm.utils import (
    STR_DTYPE_TO_TORCH_DTYPE,
    FlexibleArgumentParser,
    create_kv_caches_with_random,
)
20

21
22
logger = init_logger(__name__)

# Number of blocks allocated in the benchmark KV cache; block-table entries
# are drawn at random from the range [0, NUM_BLOCKS).
NUM_BLOCKS = 128 * 1024
# Default number of tokens per partition when sizing the v2 kernel's
# per-partition scratch buffers.
PARTITION_SIZE = 512
# Partition size selected on ROCm when the custom paged attention kernel is
# requested or the platform reports itself as Navi.
PARTITION_SIZE_ROCM = 256
26
27
28
29
30
31


@torch.inference_mode()
def main(
    version: str,
    num_seqs: int,
    seq_len: int,
    num_query_heads: int,
    num_kv_heads: int,
    head_size: int,
    use_alibi: bool,
    block_size: int,
    dtype: torch.dtype,
    seed: int,
    do_profile: bool,
    device: str = "cuda",
    kv_cache_dtype: Optional[str] = None,
    custom_paged_attn: Optional[bool] = None,
) -> None:
    """Benchmark a paged attention kernel and print its average latency.

    Random query, block-table and KV-cache tensors are created, the selected
    kernel is warmed up for 3 iterations, and then timed (100 iterations, or
    a single profiled iteration when ``do_profile`` is set).

    Args:
        version: Kernel variant to run, ``"v1"`` or ``"v2"``.
        num_seqs: Number of sequences; the leading dimension of the query.
        seq_len: Length assigned to every sequence.
        num_query_heads: Number of query heads; must be a multiple of
            ``num_kv_heads``.
        num_kv_heads: Number of KV heads.
        head_size: Size of each attention head.
        use_alibi: When true, random ALiBi slopes are passed to the kernel;
            otherwise ``None`` is passed.
        block_size: Number of tokens per KV-cache block.
        dtype: Data type of the query/output tensors.
        seed: Seed forwarded to ``current_platform.seed_everything``.
        do_profile: Time a single iteration between CUDA profiler
            start/stop calls instead of 100 plain iterations.
        device: Device on which all tensors are created.
        kv_cache_dtype: Forwarded to ``create_kv_caches_with_random`` and to
            the kernel.
        custom_paged_attn: For ``version="v2"``, run
            ``ops.paged_attention_rocm`` instead of ``ops.paged_attention_v2``.
            When ``None`` (the default) the value is taken from a
            module-level ``args`` namespace if the script entry point created
            one, and is ``False`` otherwise.

    Raises:
        ValueError: If ``version`` is neither ``"v1"`` nor ``"v2"``.
    """
    # Fail fast, before the (large) KV cache is allocated.
    if version not in ("v1", "v2"):
        raise ValueError(f"Invalid version: {version}")

    if custom_paged_attn is None:
        # Backward compatibility: the script entry point publishes the parsed
        # CLI namespace as the module-level name `args`. Fall back to it when
        # present so existing callers keep their behaviour; default to False
        # when this function is called from an importing module.
        cli_args = globals().get("args")
        custom_paged_attn = bool(getattr(cli_args, "custom_paged_attn", False))

    current_platform.seed_everything(seed)

    scale = float(1.0 / (head_size**0.5))
    query = torch.empty(
        num_seqs, num_query_heads, head_size, dtype=dtype, device=device
    )
    query.uniform_(-scale, scale)

    assert num_query_heads % num_kv_heads == 0
    alibi_slopes = None
    if use_alibi:
        alibi_slopes = torch.randn(num_query_heads, dtype=torch.float, device=device)

    seq_lens = [seq_len for _ in range(num_seqs)]
    max_seq_len = max(seq_lens)
    seq_lens = torch.tensor(seq_lens, dtype=torch.int, device=device)

    # Create the block tables: every sequence gets enough randomly chosen
    # block ids to cover max_seq_len tokens.
    max_num_blocks_per_seq = (max_seq_len + block_size - 1) // block_size
    block_tables_lst: list[list[int]] = [
        [random.randint(0, NUM_BLOCKS - 1) for _ in range(max_num_blocks_per_seq)]
        for _ in range(num_seqs)
    ]
    block_tables = torch.tensor(block_tables_lst, dtype=torch.int, device=device)

    # Create the KV cache (a single layer) and take that layer's tensors.
    key_caches, value_caches = create_kv_caches_with_random(
        NUM_BLOCKS,
        block_size,
        1,
        num_kv_heads,
        head_size,
        kv_cache_dtype,
        dtype,
        device=device,
    )
    key_cache, value_cache = key_caches[0], value_caches[0]

    # Prepare for the paged attention kernel.
    output = torch.empty_like(query)
    if version == "v2":
        # Use a local partition size instead of rebinding the module-level
        # PARTITION_SIZE, so one benchmark call cannot leak into the next.
        partition_size = PARTITION_SIZE
        if current_platform.is_rocm():
            if not custom_paged_attn and not current_platform.is_navi():
                partition_size = 1024
            else:
                partition_size = PARTITION_SIZE_ROCM
        num_partitions = (max_seq_len + partition_size - 1) // partition_size
        # Per-partition scratch buffers consumed by the v2 / ROCm kernels.
        tmp_output = torch.empty(
            size=(num_seqs, num_query_heads, num_partitions, head_size),
            dtype=output.dtype,
            device=output.device,
        )
        exp_sums = torch.empty(
            size=(num_seqs, num_query_heads, num_partitions),
            dtype=torch.float32,
            device=output.device,
        )
        max_logits = torch.empty_like(exp_sums)

    # Using default kv_scale. Created once here so that the tensor
    # construction is not included in the timed region.
    k_scale = v_scale = torch.tensor(1.0, dtype=torch.float32, device=device)

    def run_cuda_benchmark(num_iters: int, profile: bool = False) -> float:
        """Launch the kernel ``num_iters`` times; return seconds per launch."""
        # Drain previously queued work so it is not attributed to this run.
        torch.cuda.synchronize()
        if profile:
            torch.cuda.cudart().cudaProfilerStart()
        try:
            start_time = time.perf_counter()

            for _ in range(num_iters):
                if version == "v1":
                    ops.paged_attention_v1(
                        output,
                        query,
                        key_cache,
                        value_cache,
                        num_kv_heads,
                        scale,
                        block_tables,
                        seq_lens,
                        block_size,
                        max_seq_len,
                        alibi_slopes,
                        kv_cache_dtype,
                        k_scale,
                        v_scale,
                    )
                elif version == "v2":
                    if not custom_paged_attn:
                        ops.paged_attention_v2(
                            output,
                            exp_sums,
                            max_logits,
                            tmp_output,
                            query,
                            key_cache,
                            value_cache,
                            num_kv_heads,
                            scale,
                            block_tables,
                            seq_lens,
                            block_size,
                            max_seq_len,
                            alibi_slopes,
                            kv_cache_dtype,
                            k_scale,
                            v_scale,
                        )
                    else:
                        ops.paged_attention_rocm(
                            output,
                            exp_sums,
                            max_logits,
                            tmp_output,
                            query,
                            key_cache,
                            value_cache,
                            num_kv_heads,
                            scale,
                            block_tables,
                            seq_lens,
                            # NOTE(review): extra positional slot passed as
                            # None, unchanged from the original call; confirm
                            # its meaning against ops.paged_attention_rocm.
                            None,
                            block_size,
                            max_seq_len,
                            alibi_slopes,
                            kv_cache_dtype,
                            k_scale,
                            v_scale,
                        )
                else:
                    # Unreachable after the up-front check; kept as a guard.
                    raise ValueError(f"Invalid version: {version}")
            # Wait for all queued launches to finish before reading the clock.
            torch.cuda.synchronize()

            end_time = time.perf_counter()
        finally:
            # Always stop the profiler, even if a kernel launch raised.
            if profile:
                torch.cuda.cudart().cudaProfilerStop()
        return (end_time - start_time) / num_iters

    # Warmup.
    print("Warming up...")
    run_benchmark = run_cuda_benchmark
    run_benchmark(num_iters=3, profile=False)

    # Benchmark.
    if do_profile:
        latency = run_benchmark(num_iters=1, profile=True)
    else:
        latency = run_benchmark(num_iters=100, profile=False)
    print(f"Kernel running time: {latency * 1000000:.3f} us")


202
203
204
205
206
if __name__ == "__main__":
    # Script entry point: parse the CLI options, validate the head counts and
    # run the benchmark once.
    logger.warning(
        "This script benchmarks the paged attention kernel. "
        "By default this is no longer used in vLLM inference."
    )

    parser = FlexibleArgumentParser(description="Benchmark the paged attention kernel.")

    # (flag, add_argument options) pairs, registered in declaration order.
    cli_options = (
        ("--version", dict(type=str, choices=["v1", "v2"], default="v2")),
        ("--batch-size", dict(type=int, default=8)),
        ("--seq-len", dict(type=int, default=4096)),
        ("--num-query-heads", dict(type=int, default=64)),
        ("--num-kv-heads", dict(type=int, default=8)),
        (
            "--head-size",
            dict(
                type=int,
                choices=[64, 80, 96, 112, 120, 128, 192, 256],
                default=128,
            ),
        ),
        ("--block-size", dict(type=int, choices=[16, 32, 64], default=64)),
        ("--use-alibi", dict(action="store_true")),
        (
            "--dtype",
            dict(type=str, choices=["half", "bfloat16", "float"], default="half"),
        ),
        ("--seed", dict(type=int, default=0)),
        ("--profile", dict(action="store_true")),
        (
            "--kv-cache-dtype",
            dict(
                type=str,
                choices=["auto", "fp8", "fp8_e5m2", "fp8_e4m3"],
                default="auto",
                help="Data type for kv cache storage. If 'auto', will use model "
                "data type. CUDA 11.8+ supports fp8 (=fp8_e4m3) and fp8_e5m2. "
                "ROCm (hcu) supports fp8 (=fp8_e4m3)",
            ),
        ),
        (
            "--custom-paged-attn",
            dict(action="store_true", help="Use custom paged attention"),
        ),
    )
    for flag, options in cli_options:
        parser.add_argument(flag, **options)

    # NOTE: keep the module-level name `args`; `main` looks it up to read
    # the --custom-paged-attn flag.
    args = parser.parse_args()
    print(args)

    if args.num_query_heads % args.num_kv_heads:
        raise ValueError("num_query_heads must be divisible by num_kv_heads")

    bench_kwargs = {
        "version": args.version,
        "num_seqs": args.batch_size,
        "seq_len": args.seq_len,
        "num_query_heads": args.num_query_heads,
        "num_kv_heads": args.num_kv_heads,
        "head_size": args.head_size,
        "block_size": args.block_size,
        "use_alibi": args.use_alibi,
        # Translate the CLI dtype name into the matching torch dtype.
        "dtype": STR_DTYPE_TO_TORCH_DTYPE[args.dtype],
        "seed": args.seed,
        "do_profile": args.profile,
        "kv_cache_dtype": args.kv_cache_dtype,
    }
    main(**bench_kwargs)