"vscode:/vscode.git/clone" did not exist on "542a4059b2bb0f790e82822c8b9cbcf8cde91adb"
benchmark_paged_attention.py 7.31 KB
Newer Older
1
2
import random
import time
3
from typing import List, Optional
4
5
6

import torch

7
from vllm import _custom_ops as ops
8
from vllm.platforms import current_platform
9
from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, FlexibleArgumentParser,
10
                        create_kv_caches_with_random)
11
12
13
14
15
16
17
18
19

# Number of blocks requested when creating the random KV cache; block-table
# entries are drawn uniformly from [0, NUM_BLOCKS - 1].
NUM_BLOCKS = 1024
# Partition length used to size the v2 scratch buffers:
# num_partitions = ceil(max_seq_len / PARTITION_SIZE).
# NOTE(review): presumably has to match the partition size the v2 kernel was
# built with -- confirm against the kernel source.
PARTITION_SIZE = 512


@torch.inference_mode()
def main(
    version: str,
    num_seqs: int,
    seq_len: int,
    num_query_heads: int,
    num_kv_heads: int,
    head_size: int,
    use_alibi: bool,
    block_size: int,
    dtype: torch.dtype,
    seed: int,
    do_profile: bool,
    device: str = "cuda",
    kv_cache_dtype: Optional[str] = None,
) -> None:
    """Benchmark one paged attention kernel variant and print its latency.

    Builds random query, KV-cache and block-table inputs, runs a short
    warmup, then times ``ops.paged_attention_v1`` or
    ``ops.paged_attention_v2`` and prints the mean wall-clock time per call
    in microseconds.

    Args:
        version: Kernel variant to run, ``"v1"`` or ``"v2"``.
        num_seqs: Number of sequences; first dimension of the query tensor.
        seq_len: Length assigned to every sequence in the batch.
        num_query_heads: Number of query heads; second dimension of the
            query tensor.
        num_kv_heads: Number of KV heads; must divide ``num_query_heads``.
        head_size: Last dimension of the query tensor. The softmax scale is
            ``1 / sqrt(head_size)``.
        use_alibi: If true, random ALiBi slopes (one per query head) are
            passed to the kernel; otherwise ``None`` is passed.
        block_size: KV-cache block size, used to size the block tables and
            forwarded to the cache factory and the kernel.
        dtype: Dtype of the query and output tensors; also forwarded to
            ``create_kv_caches_with_random``.
        seed: Seed handed to ``current_platform.seed_everything``.
        do_profile: If true, run a single iteration bracketed by the CUDA
            profiler start/stop calls instead of 100 timed iterations.
        device: Device on which all tensors are created.
        kv_cache_dtype: Forwarded unchanged to the KV-cache factory and to
            the kernel.

    Raises:
        ValueError: If ``version`` is not ``"v1"``/``"v2"`` or if
            ``num_query_heads`` is not a multiple of ``num_kv_heads``.
    """
    # Validate up front so bad arguments fail before any device allocation.
    if version not in ("v1", "v2"):
        raise ValueError(f"Invalid version: {version}")
    if num_query_heads % num_kv_heads != 0:
        raise ValueError("num_query_heads must be divisible by num_kv_heads")

    current_platform.seed_everything(seed)

    scale = float(1.0 / (head_size**0.5))
    query = torch.empty(num_seqs,
                        num_query_heads,
                        head_size,
                        dtype=dtype,
                        device=device)
    query.uniform_(-scale, scale)

    alibi_slopes = None
    if use_alibi:
        alibi_slopes = torch.randn(num_query_heads,
                                   dtype=torch.float,
                                   device=device)

    # Every sequence gets the same length.
    seq_lens = [seq_len for _ in range(num_seqs)]
    max_seq_len = max(seq_lens)
    seq_lens = torch.tensor(seq_lens, dtype=torch.int, device=device)

    # Create the block tables: each sequence maps to random cache blocks.
    max_num_blocks_per_seq = (max_seq_len + block_size - 1) // block_size
    block_tables_lst: List[List[int]] = [
        [
            random.randint(0, NUM_BLOCKS - 1)
            for _ in range(max_num_blocks_per_seq)
        ]
        for _ in range(num_seqs)
    ]
    block_tables = torch.tensor(block_tables_lst,
                                dtype=torch.int,
                                device=device)

    # Create the KV cache (a single layer is enough for a kernel benchmark).
    key_caches, value_caches = create_kv_caches_with_random(NUM_BLOCKS,
                                                            block_size,
                                                            1,
                                                            num_kv_heads,
                                                            head_size,
                                                            kv_cache_dtype,
                                                            dtype,
                                                            device=device)
    key_cache, value_cache = key_caches[0], value_caches[0]

    # Prepare for the paged attention kernel.
    output = torch.empty_like(query)
    if version == "v2":
        # v2 needs per-partition scratch buffers in addition to the output.
        num_partitions = ((max_seq_len + PARTITION_SIZE - 1) // PARTITION_SIZE)
        tmp_output = torch.empty(
            size=(num_seqs, num_query_heads, num_partitions, head_size),
            dtype=output.dtype,
            device=output.device,
        )
        exp_sums = torch.empty(
            size=(num_seqs, num_query_heads, num_partitions),
            dtype=torch.float32,
            device=output.device,
        )
        max_logits = torch.empty_like(exp_sums)

    def run_cuda_benchmark(num_iters: int, profile: bool = False) -> float:
        """Return the mean seconds per kernel call over ``num_iters`` calls."""
        # Using default kv_scale. Allocated before the timer starts so the
        # allocation / host-to-device copy is not billed to the kernel.
        k_scale = v_scale = torch.tensor(1.0,
                                         dtype=torch.float32,
                                         device=device)

        torch.cuda.synchronize()
        if profile:
            torch.cuda.cudart().cudaProfilerStart()
        start_time = time.perf_counter()

        for _ in range(num_iters):
            if version == "v1":
                ops.paged_attention_v1(
                    output,
                    query,
                    key_cache,
                    value_cache,
                    num_kv_heads,
                    scale,
                    block_tables,
                    seq_lens,
                    block_size,
                    max_seq_len,
                    alibi_slopes,
                    kv_cache_dtype,
                    k_scale,
                    v_scale,
                )
            else:  # "v2" -- version was validated at the top of main().
                ops.paged_attention_v2(
                    output,
                    exp_sums,
                    max_logits,
                    tmp_output,
                    query,
                    key_cache,
                    value_cache,
                    num_kv_heads,
                    scale,
                    block_tables,
                    seq_lens,
                    block_size,
                    max_seq_len,
                    alibi_slopes,
                    kv_cache_dtype,
                    k_scale,
                    v_scale,
                )
        # Kernel launches are asynchronous; wait before reading the clock.
        torch.cuda.synchronize()

        end_time = time.perf_counter()
        if profile:
            # Close the profiling range opened above.
            torch.cuda.cudart().cudaProfilerStop()
        return (end_time - start_time) / num_iters

    # Warmup.
    print("Warming up...")
    run_cuda_benchmark(num_iters=3, profile=False)

    # Benchmark.
    if do_profile:
        latency = run_cuda_benchmark(num_iters=1, profile=True)
    else:
        latency = run_cuda_benchmark(num_iters=100, profile=False)
    print(f"Kernel running time: {latency * 1000000:.3f} us")


if __name__ == '__main__':
    # Command-line entry point: parse the benchmark configuration and run it.
    parser = FlexibleArgumentParser(
        description="Benchmark the paged attention kernel.")
    parser.add_argument("--version",
                        type=str,
                        choices=["v1", "v2"],
                        default="v2")
    parser.add_argument("--batch-size", type=int, default=8)
    parser.add_argument("--seq-len", type=int, default=4096)
    parser.add_argument("--num-query-heads", type=int, default=64)
    parser.add_argument("--num-kv-heads", type=int, default=8)
    parser.add_argument("--head-size",
                        type=int,
                        choices=[64, 80, 96, 112, 120, 128, 192, 256],
                        default=128)
    parser.add_argument("--block-size", type=int, choices=[16, 32], default=16)
    parser.add_argument("--use-alibi", action="store_true")
    parser.add_argument("--dtype",
                        type=str,
                        choices=["half", "bfloat16", "float"],
                        default="half")
    parser.add_argument("--seed", type=int, default=0)
    parser.add_argument("--profile", action="store_true")
    parser.add_argument(
        "--kv-cache-dtype",
        type=str,
        choices=["auto", "fp8", "fp8_e5m2", "fp8_e4m3"],
        default="auto",
        help="Data type for kv cache storage. If 'auto', will use model "
        "data type. CUDA 11.8+ supports fp8 (=fp8_e4m3) and fp8_e5m2. "
        "ROCm (AMD GPU) supports fp8 (=fp8_e4m3)")
    args = parser.parse_args()
    print(args)

    # Reject inconsistent head counts before touching the device.
    if args.num_query_heads % args.num_kv_heads != 0:
        raise ValueError("num_query_heads must be divisible by num_kv_heads")
    main(
        version=args.version,
        num_seqs=args.batch_size,
        seq_len=args.seq_len,
        num_query_heads=args.num_query_heads,
        num_kv_heads=args.num_kv_heads,
        head_size=args.head_size,
        block_size=args.block_size,
        use_alibi=args.use_alibi,
        # Map the CLI dtype name to the corresponding torch dtype.
        dtype=STR_DTYPE_TO_TORCH_DTYPE[args.dtype],
        seed=args.seed,
        do_profile=args.profile,
        kv_cache_dtype=args.kv_cache_dtype,
    )