# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import csv
import os
from datetime import datetime
from typing import Optional

import flashinfer
import torch

FLOAT32_BYTES = torch.finfo(torch.float).bits // 8
13
FP8_DTYPE = torch.float8_e4m3fn
14
15
16
17
18
19
20
21
22
23
24
25
26


def to_float8(x, dtype=torch.float8_e4m3fn):
    finfo = torch.finfo(dtype)
    min_val, max_val = x.aminmax()
    amax = torch.maximum(min_val.abs(), max_val.abs()).clamp(min=1e-12)
    scale = finfo.max / amax * 0.1
    x_scl_sat = (x * scale).clamp(min=finfo.min, max=finfo.max)
    return x_scl_sat.to(dtype), scale.float().reciprocal()


@torch.no_grad()
def benchmark_decode(
    dtype: torch.dtype,
    quant_dtypes: tuple[
        Optional[torch.dtype], Optional[torch.dtype], Optional[torch.dtype]
    ],
    batch_size: int,
    max_seq_len: int,
    num_heads: tuple[int, int] = (64, 8),
    head_size: int = 128,
    kv_layout: str = "HND",
    block_size: int = 16,
    warmup: int = 10,
    trials: int = 20,
):
    """Time ``trtllm_batch_decode_with_kv_cache`` against the FlashInfer wrapper.

    Builds a random query, a random paged KV cache and random block tables,
    then times two decode paths on the same data: the
    ``flashinfer.BatchDecodeWithPagedKVCacheWrapper`` baseline (run on the
    unquantized / dequantized reference tensors) and
    ``flashinfer.decode.trtllm_batch_decode_with_kv_cache`` (run on the
    possibly FP8-quantized tensors). One tab-separated result row is printed.

    Side effects: sets the torch default device to ``"cuda"`` and reseeds the
    global RNG with ``0``.

    Args:
        dtype: Base dtype used to generate the tensors and run the baseline.
        quant_dtypes: ``(q, kv, o)`` quantization dtypes; a ``None`` entry
            falls back to ``dtype``. Query and KV cache are quantized with
            ``to_float8`` only when their entry equals ``FP8_DTYPE``.
        batch_size: Number of sequences in the batch.
        max_seq_len: Upper bound for the random sequence lengths; the last
            sequence is forced to exactly this length.
        num_heads: ``(num_qo_heads, num_kv_heads)``; the first must be a
            multiple of the second.
        head_size: Size of each attention head.
        kv_layout: ``"NHD"`` or ``"HND"``; selects the KV cache shape.
        block_size: Number of tokens per KV cache block.
        warmup: Untimed calls made before measuring each decode path.
        trials: Timed calls per decode path.

    Returns:
        A dict with the timing statistics (in milliseconds) and the
        configuration, keyed as expected by ``write_results_to_csv``. Note
        that ``"speedup_percent"`` holds the fraction
        ``(baseline_mean - trtllm_mean) / baseline_mean`` (it is not
        multiplied by 100); positive means the TRT-LLM path is faster.

    Raises:
        ValueError: If ``kv_layout`` is neither ``"NHD"`` nor ``"HND"``.
    """
    torch.set_default_device("cuda")
    torch.manual_seed(0)

    # A ``None`` entry means "not quantized": fall back to the base dtype.
    q_quant_dtype, kv_quant_dtype, o_quant_dtype = quant_dtypes
    q_quant_dtype = q_quant_dtype or dtype
    kv_quant_dtype = kv_quant_dtype or dtype
    o_quant_dtype = o_quant_dtype or dtype

    num_qo_heads, num_kv_heads = num_heads
    assert num_qo_heads % num_kv_heads == 0

    sm_scale = float(1.0 / (head_size**0.5))

    # large number to reduce kv_cache reuse
    NUM_BLOCKS = int(256000 / block_size)

    if kv_layout == "NHD":
        kv_cache_shape = (NUM_BLOCKS, 2, block_size, num_kv_heads, head_size)
    elif kv_layout == "HND":
        kv_cache_shape = (NUM_BLOCKS, 2, num_kv_heads, block_size, head_size)
    else:
        raise ValueError(f"Invalid kv_layout: {kv_layout}")

    # ``query`` feeds the TRT-LLM path; ``ref_query`` is its dequantized
    # counterpart in the base dtype and feeds the baseline.
    query = torch.randn(batch_size, num_qo_heads, head_size, dtype=dtype)
    if q_quant_dtype == FP8_DTYPE:
        query, q_scale = to_float8(query)
        ref_query = query.to(dtype) * q_scale
    else:
        q_scale = 1.0
        ref_query = query

    # Random lengths in [1, max_seq_len); pin the last one to max_seq_len so
    # the requested maximum is always present in the batch.
    kv_lens = torch.randint(1, max_seq_len, (batch_size,), dtype=torch.int32)
    kv_lens[-1] = max_seq_len

    seq_lens = kv_lens
    max_seq_len = torch.max(seq_lens).item()

    kv_cache = torch.randn(kv_cache_shape, dtype=dtype)
    if kv_quant_dtype == FP8_DTYPE:
        kv_cache, kv_scale = to_float8(kv_cache)
        ref_kv_cache = kv_cache.to(dtype) * kv_scale
    else:
        kv_scale = 1.0
        ref_kv_cache = kv_cache
    k_scale = v_scale = kv_scale

    max_num_blocks_per_seq = (max_seq_len + block_size - 1) // block_size
    block_tables = torch.randint(
        0, NUM_BLOCKS, (batch_size, max_num_blocks_per_seq), dtype=torch.int32
    )

    # Copy the lengths to the host once so the bookkeeping below works on
    # plain Python ints instead of indexing device tensors per element.
    kv_indptr = [0]
    kv_index_chunks = []
    kv_last_page_lens = []
    for i, seq_len in enumerate(seq_lens.tolist()):
        assert seq_len > 0
        num_blocks = (seq_len + block_size - 1) // block_size
        kv_index_chunks.append(block_tables[i, :num_blocks])
        kv_indptr.append(kv_indptr[-1] + num_blocks)
        # A completely filled last page is recorded as block_size, not 0.
        kv_last_page_lens.append(seq_len % block_size or block_size)

    kv_indptr = torch.tensor(kv_indptr, dtype=torch.int32)
    kv_indices = torch.cat(kv_index_chunks)
    kv_last_page_lens = torch.tensor(kv_last_page_lens, dtype=torch.int32)

    # 1 GiB int8 buffer, handed to both decode paths.
    workspace_buffer = torch.zeros(1024 * 1024 * 1024, dtype=torch.int8)

    wrapper = flashinfer.BatchDecodeWithPagedKVCacheWrapper(
        workspace_buffer,
        kv_layout,
        use_tensor_cores=True,
    )
    # The baseline always runs in the base dtype (on the ref_* tensors).
    wrapper.plan(
        kv_indptr,
        kv_indices,
        kv_last_page_lens,
        num_qo_heads,
        num_kv_heads,
        head_size,
        block_size,
        "NONE",
        sm_scale=sm_scale,
        q_data_type=dtype,
        kv_data_type=dtype,
    )

    def time_fn(fn, warmup=10, trials=20):
        """Return ``(mean_ms, std_ms)`` of ``fn`` measured with CUDA events."""
        torch.cuda.synchronize()
        start = torch.cuda.Event(enable_timing=True)
        end = torch.cuda.Event(enable_timing=True)
        times = []
        for _ in range(warmup):
            fn()
        for _ in range(trials):
            start.record()
            fn()
            end.record()
            # Wait for the recorded events before reading the elapsed time.
            torch.cuda.synchronize()
            times.append(start.elapsed_time(end))  # ms
        return sum(times) / len(times), torch.std(torch.tensor(times))

    o_scale = 1.0
    output_baseline = torch.empty(ref_query.shape, dtype=dtype)
    output_trtllm = torch.empty(query.shape, dtype=o_quant_dtype)

    def baseline_decode():
        return wrapper.run(ref_query, ref_kv_cache, out=output_baseline)

    def trtllm_decode():
        return flashinfer.decode.trtllm_batch_decode_with_kv_cache(
            query=query,
            kv_cache=kv_cache,
            workspace_buffer=workspace_buffer,
            block_tables=block_tables,
            seq_lens=seq_lens,
            max_seq_len=max_seq_len,
            bmm1_scale=q_scale * k_scale * sm_scale,
            bmm2_scale=v_scale / o_scale,
            out=output_trtllm,
        )

    # Forward the caller's warmup/trials; previously they were ignored and
    # time_fn always ran with its own defaults.
    baseline_mean, baseline_std = time_fn(baseline_decode, warmup, trials)
    trtllm_mean, trtllm_std = time_fn(trtllm_decode, warmup, trials)

    # Calculate percentage speedup (positive means TRT is faster)
    speedup_percent = (baseline_mean - trtllm_mean) / baseline_mean

    print(
        f"\t{batch_size}\t{max_seq_len}\t{trtllm_mean:.3f}\t{trtllm_std.item():.3f}"
        f"\t{baseline_mean:.3f}\t{baseline_std.item():.3f}\t{speedup_percent:.3f}"
    )

    # Return results for CSV writing
    return {
        "batch_size": batch_size,
        "trtllm_mean": trtllm_mean,
        "trtllm_std": trtllm_std.item(),
        "baseline_mean": baseline_mean,
        "baseline_std": baseline_std.item(),
        "speedup_percent": speedup_percent,
        "q_dtype": str(q_quant_dtype),
        "kv_cache_dtype": str(kv_quant_dtype),
        "output_dtype": str(o_quant_dtype),
        "block_size": block_size,
        "num_kv_heads": num_kv_heads,
        "head_size": head_size,
        "max_seq_len": max_seq_len,
    }


def write_results_to_csv(results, filename=None):
    """Append benchmark result rows to a CSV file.

    Args:
        results: Iterable of dicts as returned by ``benchmark_decode``; each
            dict must only contain the column names listed below.
        filename: Destination path. When ``None``, a name of the form
            ``flashinfer_trtllm_benchmark_<YYYYmmdd_HHMMSS>.csv`` is generated
            from the current local time.

    The file is opened in append mode. The header row is written only when
    the file does not exist yet or is empty, so repeated calls with the same
    path accumulate rows under a single header.
    """
    if filename is None:
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        filename = f"flashinfer_trtllm_benchmark_{timestamp}.csv"

    fieldnames = [
        "batch_size",
        "trtllm_mean",
        "trtllm_std",
        "baseline_mean",
        "baseline_std",
        "speedup_percent",
        "q_dtype",
        "kv_cache_dtype",
        "output_dtype",
        "block_size",
        "num_kv_heads",
        "head_size",
        "max_seq_len",
    ]

    # An existing but empty file still needs its header row.
    needs_header = not os.path.exists(filename) or os.path.getsize(filename) == 0

    with open(filename, "a", newline="") as csvfile:
        writer = csv.DictWriter(csvfile, fieldnames=fieldnames)

        if needs_header:
            writer.writeheader()

        writer.writerows(results)

    print(f"Results written to {filename}")


if __name__ == "__main__":
    # Sweep grid: every batch size is benchmarked at every max sequence length.
    batch_sizes = [1, 4, 8, 16, 32, 64, 128, 256]
    max_seq_lens = [1024, 2048, 4096, 8192, 16384, 32768, 65536, 131072]
    all_results = []

    dtype = torch.bfloat16
    quant_dtypes = [
        # (q_quant_dtype, kv_quant_dtype, o_quant_dtype)
        (None, None, None),
        (None, FP8_DTYPE, None),
        (FP8_DTYPE, FP8_DTYPE, FP8_DTYPE),
    ]

    for quant_dtype in quant_dtypes:
        # Resolve ``None`` entries to the base dtype for the banner only;
        # benchmark_decode receives the raw tuple and resolves it itself.
        q_quant_dtype, kv_quant_dtype, o_quant_dtype = quant_dtype
        q_quant_dtype = q_quant_dtype or dtype
        kv_quant_dtype = kv_quant_dtype or dtype
        o_quant_dtype = o_quant_dtype or dtype

        print(
            f"Running benchmark for q_dtype = {q_quant_dtype}, "
            f"kv_cache_dtype: {kv_quant_dtype}, "
            f"output_dtype: {o_quant_dtype}"
        )
        # Column header matching the row printed by benchmark_decode.
        print(
            "\tbatch_size\tmax_seq_len\ttrtllm_mean\ttrtllm_std\tbaseline_mean\t"
            "baseline_std\tspeedup_percent"
        )
        for max_seq_len in max_seq_lens:
            for bs in batch_sizes:
                result = benchmark_decode(
                    dtype=dtype,
                    quant_dtypes=quant_dtype,
                    batch_size=bs,
                    max_seq_len=max_seq_len,
                )
                all_results.append(result)

    # Write all results to CSV
    write_results_to_csv(all_results)