# benchmark_paged_attention.py
# SPDX-License-Identifier: Apache-2.0

3
4
import random
import time
5
from typing import Optional
6
7
8

import torch

9
from vllm import _custom_ops as ops
10
from vllm.logger import init_logger
11
from vllm.platforms import current_platform
12
13
14
15
16
from vllm.utils import (
    STR_DTYPE_TO_TORCH_DTYPE,
    FlexibleArgumentParser,
    create_kv_caches_with_random,
)
17

18
19
logger = init_logger(__name__)

# Number of physical KV-cache blocks allocated for the benchmark; every
# block-table entry is drawn uniformly from [0, NUM_BLOCKS - 1].
NUM_BLOCKS = 128 * 1024

# Partition sizes (in tokens) used to size the v2 kernel's scratch buffers:
# the generic default and the value used with the custom ROCm kernel.
PARTITION_SIZE, PARTITION_SIZE_ROCM = 512, 256


def _v2_partition_size(custom_paged_attn: bool) -> int:
    """Return the partition size used to size the v2 kernels' scratch buffers.

    Computed per call rather than by rebinding the module-level
    ``PARTITION_SIZE``, so running ``main`` leaves no global state behind.
    """
    if not current_platform.is_rocm():
        return PARTITION_SIZE
    # ROCm: 1024 for the stock v2 kernel, PARTITION_SIZE_ROCM for the custom
    # ROCm kernel. These presumably have to match the kernels' internal
    # partition sizes -- confirm if the kernels change.
    return PARTITION_SIZE_ROCM if custom_paged_attn else 1024


@torch.inference_mode()
def main(
    version: str,
    num_seqs: int,
    seq_len: int,
    num_query_heads: int,
    num_kv_heads: int,
    head_size: int,
    use_alibi: bool,
    block_size: int,
    dtype: torch.dtype,
    seed: int,
    do_profile: bool,
    device: str = "cuda",
    kv_cache_dtype: Optional[str] = None,
    custom_paged_attn: bool = False,
) -> None:
    """Benchmark one paged-attention kernel and print its mean latency.

    Builds a random query, random block tables and a random KV cache. Warms
    the kernel up with 3 calls, then times 100 calls. With ``do_profile``,
    it instead times a single call bracketed by
    ``cudaProfilerStart``/``cudaProfilerStop``. Finally it prints the mean
    per-call time in microseconds.

    Args:
        version: ``"v1"`` or ``"v2"`` paged-attention kernel.
        num_seqs: Number of sequences in the batch.
        seq_len: Context length used for every sequence.
        num_query_heads: Number of query heads.
        num_kv_heads: Number of KV heads; must divide ``num_query_heads``.
        head_size: Size of each attention head.
        use_alibi: Pass random ALiBi slopes to the kernel.
        block_size: Tokens per KV-cache block.
        dtype: Dtype of the query tensor; also passed to
            ``create_kv_caches_with_random`` as the model dtype.
        seed: Seed passed to ``current_platform.seed_everything``.
        do_profile: Time a single profiled call instead of 100 calls.
        device: Device for all tensors. Timing synchronizes through
            ``torch.cuda``, so this is expected to be a GPU device.
        kv_cache_dtype: KV-cache storage dtype (e.g. ``"auto"``, ``"fp8"``).
            ``None`` is treated as ``"auto"``.
        custom_paged_attn: For ``"v2"``, call ``ops.paged_attention_rocm``
            instead of ``ops.paged_attention_v2``.

    Raises:
        ValueError: If ``version`` is unknown or ``num_query_heads`` is not
            divisible by ``num_kv_heads``.
    """
    # Validate before touching the GPU or starting the profiler.
    if version not in ("v1", "v2"):
        raise ValueError(f"Invalid version: {version}")
    if num_query_heads % num_kv_heads != 0:
        raise ValueError("num_query_heads must be divisible by num_kv_heads")
    # The CLI always passes a string (default "auto"). Map None to the same
    # default so direct callers get a valid dtype string.
    if kv_cache_dtype is None:
        kv_cache_dtype = "auto"

    current_platform.seed_everything(seed)

    scale = float(1.0 / (head_size**0.5))
    query = torch.empty(
        num_seqs, num_query_heads, head_size, dtype=dtype, device=device
    )
    query.uniform_(-scale, scale)

    alibi_slopes = None
    if use_alibi:
        alibi_slopes = torch.randn(num_query_heads, dtype=torch.float, device=device)

    # All sequences share the same length, so the longest one is seq_len.
    max_seq_len = seq_len
    seq_lens = torch.full((num_seqs,), seq_len, dtype=torch.int, device=device)

    # Create the block tables: each sequence maps to random physical blocks.
    max_num_blocks_per_seq = (max_seq_len + block_size - 1) // block_size
    block_tables_lst: list[list[int]] = [
        [random.randint(0, NUM_BLOCKS - 1) for _ in range(max_num_blocks_per_seq)]
        for _ in range(num_seqs)
    ]
    block_tables = torch.tensor(block_tables_lst, dtype=torch.int, device=device)

    # Create the KV cache (a single layer).
    key_caches, value_caches = create_kv_caches_with_random(
        NUM_BLOCKS,
        block_size,
        1,
        num_kv_heads,
        head_size,
        kv_cache_dtype,
        dtype,
        device=device,
    )
    key_cache, value_cache = key_caches[0], value_caches[0]

    # Prepare for the paged attention kernel.
    output = torch.empty_like(query)
    exp_sums = max_logits = tmp_output = None
    if version == "v2":
        partition_size = _v2_partition_size(custom_paged_attn)
        num_partitions = (max_seq_len + partition_size - 1) // partition_size
        tmp_output = torch.empty(
            size=(num_seqs, num_query_heads, num_partitions, head_size),
            dtype=output.dtype,
            device=output.device,
        )
        exp_sums = torch.empty(
            size=(num_seqs, num_query_heads, num_partitions),
            dtype=torch.float32,
            device=output.device,
        )
        max_logits = torch.empty_like(exp_sums)

    # Default (unit) KV scales. They are created once here so the allocation
    # and host-to-device copy are not counted in the timed region.
    k_scale = v_scale = torch.tensor(1.0, dtype=torch.float32, device=device)

    # Resolve the kernel once instead of re-branching on every iteration.
    if version == "v1":

        def launch() -> None:
            ops.paged_attention_v1(
                output,
                query,
                key_cache,
                value_cache,
                num_kv_heads,
                scale,
                block_tables,
                seq_lens,
                block_size,
                max_seq_len,
                alibi_slopes,
                kv_cache_dtype,
                k_scale,
                v_scale,
            )

    else:
        # NOTE(review): the custom kernel presumably exists only in ROCm
        # builds -- confirm before using --custom-paged-attn elsewhere.
        v2_kernel = (
            ops.paged_attention_rocm if custom_paged_attn else ops.paged_attention_v2
        )

        def launch() -> None:
            v2_kernel(
                output,
                exp_sums,
                max_logits,
                tmp_output,
                query,
                key_cache,
                value_cache,
                num_kv_heads,
                scale,
                block_tables,
                seq_lens,
                block_size,
                max_seq_len,
                alibi_slopes,
                kv_cache_dtype,
                k_scale,
                v_scale,
            )

    def run_cuda_benchmark(num_iters: int, profile: bool = False) -> float:
        """Run the kernel ``num_iters`` times; return mean seconds per call."""
        torch.cuda.synchronize()
        if profile:
            torch.cuda.cudart().cudaProfilerStart()
        try:
            start_time = time.perf_counter()
            for _ in range(num_iters):
                launch()
            # Kernel launches are asynchronous; wait for them before stopping
            # the clock.
            torch.cuda.synchronize()
            end_time = time.perf_counter()
        finally:
            # Always close the profiler range, even if a kernel call raised.
            if profile:
                torch.cuda.cudart().cudaProfilerStop()
        return (end_time - start_time) / num_iters

    # Warmup.
    print("Warming up...")
    run_cuda_benchmark(num_iters=3, profile=False)

    # Benchmark.
    if do_profile:
        latency = run_cuda_benchmark(num_iters=1, profile=True)
    else:
        latency = run_cuda_benchmark(num_iters=100, profile=False)
    print(f"Kernel running time: {latency * 1000000:.3f} us")


if __name__ == "__main__":
    logger.warning(
        "This script benchmarks the paged attention kernel. "
        "By default this is no longer used in vLLM inference."
    )

    parser = FlexibleArgumentParser(description="Benchmark the paged attention kernel.")
    parser.add_argument("--version", type=str, choices=["v1", "v2"], default="v2")
    parser.add_argument("--batch-size", type=int, default=8)
    parser.add_argument("--seq-len", type=int, default=4096)
    parser.add_argument("--num-query-heads", type=int, default=64)
    parser.add_argument("--num-kv-heads", type=int, default=8)
    parser.add_argument(
        "--head-size",
        type=int,
        choices=[64, 80, 96, 112, 120, 128, 192, 256],
        default=128,
    )
    parser.add_argument("--block-size", type=int, choices=[16, 32], default=16)
    parser.add_argument("--use-alibi", action="store_true")
    parser.add_argument(
        "--dtype", type=str, choices=["half", "bfloat16", "float"], default="half"
    )
    parser.add_argument("--seed", type=int, default=0)
    parser.add_argument("--profile", action="store_true")
    parser.add_argument(
        "--kv-cache-dtype",
        type=str,
        choices=["auto", "fp8", "fp8_e5m2", "fp8_e4m3"],
        default="auto",
        help="Data type for kv cache storage. If 'auto', will use model "
        "data type. CUDA 11.8+ supports fp8 (=fp8_e4m3) and fp8_e5m2. "
        "ROCm (AMD GPU) supports fp8 (=fp8_e4m3)",
    )
    parser.add_argument(
        "--custom-paged-attn", action="store_true", help="Use custom paged attention"
    )
    args = parser.parse_args()
    print(args)

    if args.num_query_heads % args.num_kv_heads != 0:
        raise ValueError("num_query_heads must be divisible by num_kv_heads")
    main(
        version=args.version,
        num_seqs=args.batch_size,
        seq_len=args.seq_len,
        num_query_heads=args.num_query_heads,
        num_kv_heads=args.num_kv_heads,
        head_size=args.head_size,
        block_size=args.block_size,
        use_alibi=args.use_alibi,
        dtype=STR_DTYPE_TO_TORCH_DTYPE[args.dtype],
        seed=args.seed,
        do_profile=args.profile,
        kv_cache_dtype=args.kv_cache_dtype,
        custom_paged_attn=args.custom_paged_attn,
    )