# benchmark_paged_attention.py: micro-benchmark for the vLLM paged attention kernels.
import argparse
import random
import time
4
from typing import Optional
5
6
7

import torch

8
from vllm._C import ops
9
from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, create_kv_caches_with_random
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27

# Number of KV cache blocks allocated for the benchmark; every block-table
# entry is drawn uniformly from [0, NUM_BLOCKS - 1].
NUM_BLOCKS = 1 << 10
# Tokens per partition used to size the v2 kernel's scratch buffers
# (presumably must equal the partition size the v2 kernel was built with;
# TODO confirm against the kernel source).
PARTITION_SIZE = 2**9


@torch.inference_mode()
def main(
    version: str,
    num_seqs: int,
    context_len: int,
    num_query_heads: int,
    num_kv_heads: int,
    head_size: int,
    use_alibi: bool,
    block_size: int,
    dtype: torch.dtype,
    seed: int,
    do_profile: bool,
    device: str = "cuda",
    kv_cache_dtype: Optional[str] = None,
) -> None:
    """Benchmark one paged attention kernel and print its mean latency.

    Builds a random query batch, random block tables and a randomly
    initialised single-layer KV cache, runs a short warmup, then times the
    selected kernel (or runs it once under the CUDA profiler).

    Args:
        version: Kernel to run: "v1" (``ops.paged_attention_v1``) or "v2"
            (``ops.paged_attention_v2``).
        num_seqs: Number of sequences in the batch.
        context_len: Context length used for every sequence.
        num_query_heads: Number of query heads.
        num_kv_heads: Number of KV heads; must divide ``num_query_heads``.
        head_size: Size of each attention head.
        use_alibi: If True, random ALiBi slopes are passed to the kernel.
        block_size: Number of tokens per KV cache block.
        dtype: Torch dtype of the query; also passed to
            ``create_kv_caches_with_random`` as the model dtype.
        seed: Seed for Python's ``random`` and the torch RNGs.
        do_profile: If True, run a single iteration bracketed by
            ``cudaProfilerStart``/``cudaProfilerStop``; otherwise time 100
            iterations.
        device: Device on which all benchmark tensors are allocated.
        kv_cache_dtype: KV cache storage dtype string forwarded to the cache
            constructor and the kernel ("auto" or "fp8" from the CLI).
            ``None`` is treated as "auto".

    Raises:
        ValueError: If ``version`` is neither "v1" nor "v2".
    """
    # Validate before allocating anything, instead of failing mid-warmup.
    if version not in ("v1", "v2"):
        raise ValueError(f"Invalid version: {version}")
    # The CLI only ever passes "auto" or "fp8"; give direct callers relying on
    # the None default the same behaviour as the CLI default.
    if kv_cache_dtype is None:
        kv_cache_dtype = "auto"

    random.seed(seed)
    torch.random.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)

    # Queries are drawn from U(-scale, scale), with scale = 1/sqrt(head_size).
    scale = float(1.0 / (head_size**0.5))
    query = torch.empty(num_seqs,
                        num_query_heads,
                        head_size,
                        dtype=dtype,
                        device=device)
    query.uniform_(-scale, scale)

    assert num_query_heads % num_kv_heads == 0
    alibi_slopes = None
    if use_alibi:
        alibi_slopes = torch.randn(num_query_heads,
                                   dtype=torch.float,
                                   device=device)

    # Every sequence has the same context length.
    context_lens = [context_len] * num_seqs
    max_context_len = max(context_lens)
    context_lens = torch.tensor(context_lens, dtype=torch.int, device=device)

    # Create the block tables: each sequence maps to randomly chosen physical
    # blocks (ceil(max_context_len / block_size) entries per sequence).
    max_num_blocks_per_seq = (max_context_len + block_size - 1) // block_size
    block_tables = [[
        random.randint(0, NUM_BLOCKS - 1)
        for _ in range(max_num_blocks_per_seq)
    ] for _ in range(num_seqs)]
    block_tables = torch.tensor(block_tables, dtype=torch.int, device=device)

    # Create the KV cache (a single layer is enough for one kernel call).
    key_caches, value_caches = create_kv_caches_with_random(NUM_BLOCKS,
                                                            block_size,
                                                            1,
                                                            num_kv_heads,
                                                            head_size,
                                                            kv_cache_dtype,
                                                            dtype,
                                                            device=device)
    key_cache, value_cache = key_caches[0], value_caches[0]

    # Prepare for the paged attention kernel.
    output = torch.empty_like(query)
    if version == "v2":
        # v2 splits each context into PARTITION_SIZE-token partitions and
        # needs per-partition scratch buffers.
        num_partitions = ((max_context_len + PARTITION_SIZE - 1) //
                          PARTITION_SIZE)
        tmp_output = torch.empty(
            size=(num_seqs, num_query_heads, num_partitions, head_size),
            dtype=output.dtype,
            device=output.device,
        )
        exp_sums = torch.empty(
            size=(num_seqs, num_query_heads, num_partitions),
            dtype=torch.float32,
            device=output.device,
        )
        max_logits = torch.empty_like(exp_sums)

    # Using default kv_scale (loop-invariant, so kept out of the timed region).
    kv_scale = 1.0

    def run_cuda_benchmark(num_iters: int, profile: bool = False) -> float:
        """Run the selected kernel ``num_iters`` times.

        Returns the mean wall-clock seconds per call, measured between two
        ``torch.cuda.synchronize()`` calls.
        """
        torch.cuda.synchronize()
        if profile:
            torch.cuda.cudart().cudaProfilerStart()
        start_time = time.perf_counter()

        for _ in range(num_iters):
            if version == "v1":
                ops.paged_attention_v1(
                    output,
                    query,
                    key_cache,
                    value_cache,
                    num_kv_heads,
                    scale,
                    block_tables,
                    context_lens,
                    block_size,
                    max_context_len,
                    alibi_slopes,
                    kv_cache_dtype,
                    kv_scale,
                )
            else:  # "v2"; version was validated at entry.
                ops.paged_attention_v2(
                    output,
                    exp_sums,
                    max_logits,
                    tmp_output,
                    query,
                    key_cache,
                    value_cache,
                    num_kv_heads,
                    scale,
                    block_tables,
                    context_lens,
                    block_size,
                    max_context_len,
                    alibi_slopes,
                    kv_cache_dtype,
                    kv_scale,
                )
        # Kernel launches are asynchronous; wait for them before stopping
        # the clock.
        torch.cuda.synchronize()

        end_time = time.perf_counter()
        if profile:
            # Close the profiler capture range opened above.
            torch.cuda.cudart().cudaProfilerStop()
        return (end_time - start_time) / num_iters

    # Warmup.
    print("Warming up...")
    run_cuda_benchmark(num_iters=3, profile=False)

    # Benchmark.
    if do_profile:
        latency = run_cuda_benchmark(num_iters=1, profile=True)
    else:
        latency = run_cuda_benchmark(num_iters=100, profile=False)
    print(f"Kernel running time: {latency * 1000000:.3f} us")


if __name__ == '__main__':
    # Command-line entry point: parse benchmark options, validate the head
    # configuration, then run main() once with the parsed settings.
    parser = argparse.ArgumentParser(
        description="Benchmark the paged attention kernel.")
    parser.add_argument("--version",
                        type=str,
                        choices=["v1", "v2"],
                        default="v2")
    parser.add_argument("--batch-size", type=int, default=8)
    parser.add_argument("--context-len", type=int, default=4096)
    parser.add_argument("--num-query-heads", type=int, default=64)
    parser.add_argument("--num-kv-heads", type=int, default=8)
    parser.add_argument("--head-size",
                        type=int,
                        choices=[64, 80, 96, 112, 128, 256],
                        default=128)
    parser.add_argument("--block-size", type=int, choices=[16, 32], default=16)
    parser.add_argument("--use-alibi", action="store_true")
    parser.add_argument("--dtype",
                        type=str,
                        choices=["half", "bfloat16", "float"],
                        default="half")
    parser.add_argument("--seed", type=int, default=0)
    parser.add_argument("--profile", action="store_true")
    parser.add_argument(
        "--kv-cache-dtype",
        type=str,
        choices=["auto", "fp8"],
        default="auto",
        help=
        'Data type for kv cache storage. If "auto", will use model data type. '
        'FP8_E5M2 (without scaling) is only supported on cuda version greater '
        'than 11.8. On ROCm (AMD GPU), FP8_E4M3 is instead supported for '
        'common inference criteria.')
    args = parser.parse_args()
    print(args)

    # Check explicitly so that 0 gives a clear message instead of a
    # ZeroDivisionError from the modulo below.
    if args.num_kv_heads <= 0:
        raise ValueError("num_kv_heads must be a positive integer")
    if args.num_query_heads % args.num_kv_heads != 0:
        raise ValueError("num_query_heads must be divisible by num_kv_heads")
    main(
        version=args.version,
        num_seqs=args.batch_size,
        context_len=args.context_len,
        num_query_heads=args.num_query_heads,
        num_kv_heads=args.num_kv_heads,
        head_size=args.head_size,
        block_size=args.block_size,
        use_alibi=args.use_alibi,
        dtype=STR_DTYPE_TO_TORCH_DTYPE[args.dtype],
        seed=args.seed,
        do_profile=args.profile,
        kv_cache_dtype=args.kv_cache_dtype,
    )