"vllm/tool_parsers/minimax_tool_parser.py" did not exist on "8fcaaf6a165e661f63fc51be906bc05b0767332f"
benchmark_paged_attention.py 7.02 KB
Newer Older
1
2
3
import argparse
import random
import time
4
from typing import Optional
5
6
7

import torch

8
from vllm import _custom_ops as ops
9
from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, create_kv_caches_with_random
10
11
12
13
14
15
16
17
18

# Block count handed to create_kv_caches_with_random in main(); block-table
# entries are drawn uniformly from [0, NUM_BLOCKS - 1].
NUM_BLOCKS = 1024
# Divisor used in main() to compute how many partitions the "v2" scratch
# buffers (tmp_output / exp_sums / max_logits) need: ceil(seq_len / 512).
PARTITION_SIZE = 512


@torch.inference_mode()
def main(
    version: str,
    num_seqs: int,
    seq_len: int,
    num_query_heads: int,
    num_kv_heads: int,
    head_size: int,
    use_alibi: bool,
    block_size: int,
    dtype: torch.dtype,
    seed: int,
    do_profile: bool,
    device: str = "cuda",
    kv_cache_dtype: Optional[str] = None,
) -> None:
    """Time the paged attention kernel on random inputs and print the latency.

    Builds a random query, random block tables and a random KV cache, runs a
    short warmup, then times repeated kernel launches and prints the mean
    wall-clock time per launch in microseconds.

    Args:
        version: "v1" selects ``ops.paged_attention_v1``, "v2" selects
            ``ops.paged_attention_v2``.
        num_seqs: Number of sequences; first dimension of the query tensor.
        seq_len: Length assigned to every sequence.
        num_query_heads: Number of query heads; must be a multiple of
            ``num_kv_heads``.
        num_kv_heads: Number of KV heads forwarded to the cache builder and
            to the kernel.
        head_size: Per-head dimension; also determines the softmax scale
            ``1 / sqrt(head_size)``.
        use_alibi: If true, random ALiBi slopes are generated and passed to
            the kernel; otherwise ``None`` is passed.
        block_size: Block size used to derive the number of block-table
            entries per sequence and forwarded to the kernel.
        dtype: Torch dtype of the query/output tensors.
        seed: Seed for ``random`` and the torch RNGs.
        do_profile: If true, run a single iteration bracketed by CUDA
            profiler start/stop instead of 100 timed iterations.
        device: Torch device string for all tensors.
        kv_cache_dtype: KV-cache dtype string forwarded unchanged to the
            cache builder and to the kernel.

    Raises:
        ValueError: If ``version`` is neither "v1" nor "v2".
    """
    random.seed(seed)
    torch.random.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)

    scale = float(1.0 / (head_size**0.5))
    query = torch.empty(num_seqs,
                        num_query_heads,
                        head_size,
                        dtype=dtype,
                        device=device)
    query.uniform_(-scale, scale)

    assert num_query_heads % num_kv_heads == 0
    alibi_slopes = None
    if use_alibi:
        alibi_slopes = torch.randn(num_query_heads,
                                   dtype=torch.float,
                                   device=device)

    # Every sequence gets the same length, so max_seq_len == seq_len here.
    seq_lens = [seq_len for _ in range(num_seqs)]
    max_seq_len = max(seq_lens)
    seq_lens = torch.tensor(seq_lens, dtype=torch.int, device=device)

    # Create the block tables: one row of random block indices per sequence.
    max_num_blocks_per_seq = (max_seq_len + block_size - 1) // block_size
    block_tables = []
    for _ in range(num_seqs):
        block_table = [
            random.randint(0, NUM_BLOCKS - 1)
            for _ in range(max_num_blocks_per_seq)
        ]
        block_tables.append(block_table)
    block_tables = torch.tensor(block_tables, dtype=torch.int, device=device)

    # Create the KV cache (a single layer is requested; take its entry).
    key_caches, value_caches = create_kv_caches_with_random(NUM_BLOCKS,
                                                            block_size,
                                                            1,
                                                            num_kv_heads,
                                                            head_size,
                                                            kv_cache_dtype,
                                                            dtype,
                                                            device=device)
    key_cache, value_cache = key_caches[0], value_caches[0]

    # Prepare for the paged attention kernel.
    output = torch.empty_like(query)
    if version == "v2":
        # Scratch buffers that are only passed to the v2 kernel.
        num_partitions = ((max_seq_len + PARTITION_SIZE - 1) // PARTITION_SIZE)
        tmp_output = torch.empty(
            size=(num_seqs, num_query_heads, num_partitions, head_size),
            dtype=output.dtype,
            device=output.device,
        )
        exp_sums = torch.empty(
            size=(num_seqs, num_query_heads, num_partitions),
            dtype=torch.float32,
            device=output.device,
        )
        max_logits = torch.empty_like(exp_sums)

    def run_cuda_benchmark(num_iters: int, profile: bool = False) -> float:
        """Launch the kernel ``num_iters`` times; return mean seconds/launch."""
        # Using default kv_scale
        kv_scale = 1.0

        # Drain pending GPU work so it is not attributed to the timed region.
        torch.cuda.synchronize()
        if profile:
            torch.cuda.cudart().cudaProfilerStart()
        start_time = time.perf_counter()

        for _ in range(num_iters):
            if version == "v1":
                ops.paged_attention_v1(
                    output,
                    query,
                    key_cache,
                    value_cache,
                    num_kv_heads,
                    scale,
                    block_tables,
                    seq_lens,
                    block_size,
                    max_seq_len,
                    alibi_slopes,
                    kv_cache_dtype,
                    kv_scale,
                )
            elif version == "v2":
                ops.paged_attention_v2(
                    output,
                    exp_sums,
                    max_logits,
                    tmp_output,
                    query,
                    key_cache,
                    value_cache,
                    num_kv_heads,
                    scale,
                    block_tables,
                    seq_lens,
                    block_size,
                    max_seq_len,
                    alibi_slopes,
                    kv_cache_dtype,
                    kv_scale,
                )
            else:
                raise ValueError(f"Invalid version: {version}")
        # Wait for the launched kernels to finish before reading the clock.
        torch.cuda.synchronize()

        end_time = time.perf_counter()
        if profile:
            # Close the profiling range opened above (the previous code
            # called cudaProfilerStart a second time here).
            torch.cuda.cudart().cudaProfilerStop()
        return (end_time - start_time) / num_iters

    # Warmup.
    print("Warming up...")
    run_benchmark = run_cuda_benchmark
    run_benchmark(num_iters=3, profile=False)

    # Benchmark.
    if do_profile:
        latency = run_benchmark(num_iters=1, profile=True)
    else:
        latency = run_benchmark(num_iters=100, profile=False)
    print(f"Kernel running time: {latency * 1000000:.3f} us")


# Command-line entry point: parse the benchmark configuration and run main().
if __name__ == '__main__':
    parser = argparse.ArgumentParser(
        description="Benchmark the paged attention kernel.")
    parser.add_argument("--version",
                        type=str,
                        choices=["v1", "v2"],
                        default="v2")
    parser.add_argument("--batch-size", type=int, default=8)
    # "--seq-len" matches the hyphenated style of every other flag; the
    # original "--seq_len" spelling is kept as an alias so existing command
    # lines keep working. The destination attribute is still `seq_len`.
    parser.add_argument("--seq-len", "--seq_len", type=int, default=4096)
    parser.add_argument("--num-query-heads", type=int, default=64)
    parser.add_argument("--num-kv-heads", type=int, default=8)
    parser.add_argument("--head-size",
                        type=int,
                        choices=[64, 80, 96, 112, 128, 256],
                        default=128)
    parser.add_argument("--block-size", type=int, choices=[16, 32], default=16)
    parser.add_argument("--use-alibi", action="store_true")
    parser.add_argument("--dtype",
                        type=str,
                        choices=["half", "bfloat16", "float"],
                        default="half")
    parser.add_argument("--seed", type=int, default=0)
    parser.add_argument("--profile", action="store_true")
    parser.add_argument(
        "--kv-cache-dtype",
        type=str,
        choices=["auto", "fp8", "fp8_e5m2", "fp8_e4m3"],
        default="auto",
        help="Data type for kv cache storage. If 'auto', will use model "
        "data type. CUDA 11.8+ supports fp8 (=fp8_e4m3) and fp8_e5m2. "
        "ROCm (AMD GPU) supports fp8 (=fp8_e4m3)")
    args = parser.parse_args()
    print(args)

    # Fail early with a clear message before any tensors are allocated.
    if args.num_query_heads % args.num_kv_heads != 0:
        raise ValueError("num_query_heads must be divisible by num_kv_heads")
    main(
        version=args.version,
        num_seqs=args.batch_size,
        seq_len=args.seq_len,
        num_query_heads=args.num_query_heads,
        num_kv_heads=args.num_kv_heads,
        head_size=args.head_size,
        block_size=args.block_size,
        use_alibi=args.use_alibi,
        dtype=STR_DTYPE_TO_TORCH_DTYPE[args.dtype],
        seed=args.seed,
        do_profile=args.profile,
        kv_cache_dtype=args.kv_cache_dtype,
    )