# benchmark_trtllm_prefill_attention.py
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import csv
import os
from datetime import datetime
7
from typing import Optional
8
9
10
11

import flashinfer
import torch

12
13
from vllm.utils import round_up

14
FLOAT32_BYTES = torch.finfo(torch.float).bits // 8
15
FP8_DTYPE = torch.float8_e4m3fn
16
FP4_DTYPE = torch.uint8
17
18
19
20
21
22
23
24
25
26
27
28


def to_float8(x, dtype=torch.float8_e4m3fn):
    finfo = torch.finfo(dtype)
    min_val, max_val = x.aminmax()
    amax = torch.maximum(min_val.abs(), max_val.abs()).clamp(min=1e-12)
    scale = finfo.max / amax * 0.1
    x_scl_sat = (x * scale).clamp(min=finfo.min, max=finfo.max)
    return x_scl_sat.to(dtype), scale.float().reciprocal()


@torch.no_grad()
def benchmark_prefill(
    dtype: torch.dtype,
    quant_dtypes: tuple[
        Optional[torch.dtype], Optional[torch.dtype], Optional[torch.dtype]
    ],
    batch_size: int,
    max_seq_len: int,
    num_heads: tuple[int, int] = (64, 8),
    head_size: int = 128,
    kv_layout: str = "HND",
    block_size: int = 16,
    warmup: int = 10,
    trials: int = 20,
):
    """Time the TRT-LLM prefill kernel against FlashInfer's paged-prefill wrapper.

    Random queries, a random paged KV cache and random block tables are created
    on the CUDA device with a fixed seed. The baseline runs on the unquantized
    tensors through ``BatchPrefillWithPagedKVCacheWrapper``; the TRT-LLM path
    runs on the (optionally FP8-quantized) tensors. Both are timed with CUDA
    events and one tab-separated summary row is printed.

    Args:
        dtype: Unquantized dtype of the reference query / KV cache; also used
            wherever an entry of ``quant_dtypes`` is ``None``.
        quant_dtypes: ``(query, kv_cache, output)`` dtypes for the TRT-LLM path.
        batch_size: Number of sequences in the batch.
        max_seq_len: Upper bound for both the per-sequence query length and the
            per-sequence pre-existing KV length. The last sequence is pinned to
            this bound for both.
        num_heads: ``(num_qo_heads, num_kv_heads)``; the first must be a
            multiple of the second.
        head_size: Size of each attention head.
        kv_layout: ``"NHD"`` or ``"HND"``; selects the KV cache tensor shape.
        block_size: Number of tokens per KV cache block.
        warmup: Untimed calls made before measuring each kernel.
        trials: Timed calls per kernel.

    Returns:
        A dict with one CSV row: mean/std latency in milliseconds for both
        kernels, ``speedup_percent`` (stored as the fraction
        ``(baseline - trtllm) / baseline``, not multiplied by 100), the dtypes
        as strings, and the shape parameters. ``"max_seq_len"`` holds the
        longest total length (query + pre-existing KV) in the batch, which is
        ``2 * max_seq_len`` because of the pinned last sequence.

    Raises:
        ValueError: If ``kv_layout`` is neither ``"NHD"`` nor ``"HND"``.
    """
    torch.set_default_device("cuda")
    torch.manual_seed(0)

    q_quant_dtype, kv_quant_dtype, o_quant_dtype = quant_dtypes
    q_quant_dtype = q_quant_dtype or dtype
    kv_quant_dtype = kv_quant_dtype or dtype
    o_quant_dtype = o_quant_dtype or dtype

    max_q_len = max_kv_len = max_seq_len

    num_qo_heads, num_kv_heads = num_heads
    assert num_qo_heads % num_kv_heads == 0

    sm_scale = float(1.0 / (head_size**0.5))

    # large number to reduce kv_cache reuse
    NUM_BLOCKS = int(256000 / block_size)

    if kv_layout == "NHD":
        kv_cache_shape = (NUM_BLOCKS, 2, block_size, num_kv_heads, head_size)
    elif kv_layout == "HND":
        kv_cache_shape = (NUM_BLOCKS, 2, num_kv_heads, block_size, head_size)
    else:
        raise ValueError(f"Invalid kv_layout: {kv_layout}")

    # NOTE: the random draws below happen in the same order as before
    # (q_lens, query, kv_lens, kv cache, block tables) so a given seed keeps
    # producing the same data.
    q_lens = torch.randint(1, max_q_len, (batch_size,), dtype=torch.int32)
    q_lens[-1] = max_q_len
    q_indptr = torch.cat(
        [
            torch.tensor([0], dtype=torch.int32),
            torch.cumsum(q_lens, dim=0, dtype=torch.int32),
        ]
    )

    # Always using 1.0 scale to reflect the real perf in benchmarking
    q_scale = 1.0
    ref_query = torch.randn(
        torch.sum(q_lens).item(), num_qo_heads, head_size, dtype=dtype
    )
    if q_quant_dtype == FP8_DTYPE:
        query, _ = to_float8(ref_query)
    else:
        query = ref_query

    kv_lens = torch.randint(0, max_kv_len, (batch_size,), dtype=torch.int32)
    kv_lens[-1] = max_kv_len

    # Total length per sequence = pre-existing KV tokens + new query tokens.
    seq_lens = kv_lens + q_lens
    max_total_len = torch.max(seq_lens).item()

    # Always using 1.0 scale to reflect the real perf in benchmarking
    k_scale = v_scale = 1.0
    ref_kv_cache = torch.randn(kv_cache_shape, dtype=dtype)
    if kv_quant_dtype == FP8_DTYPE:
        kv_cache, _ = to_float8(ref_kv_cache)
    else:
        kv_cache = ref_kv_cache

    max_num_blocks_per_seq = (max_total_len + block_size - 1) // block_size
    block_tables = torch.randint(
        0, NUM_BLOCKS, (batch_size, max_num_blocks_per_seq), dtype=torch.int32
    )

    # Paged-KV metadata for the baseline wrapper, built with tensor ops
    # instead of a per-sequence Python loop over 0-d device tensors.
    assert bool(torch.all(seq_lens > 0))
    blocks_per_seq = (seq_lens + block_size - 1) // block_size
    kv_indptr = torch.cat(
        [
            torch.tensor([0], dtype=torch.int32),
            torch.cumsum(blocks_per_seq, dim=0, dtype=torch.int32),
        ]
    )
    kv_indices = torch.cat(
        [block_tables[i, :n] for i, n in enumerate(blocks_per_seq.tolist())]
    )
    # Tokens in the last block of each sequence; a completely full last block
    # yields block_size rather than 0 (seq_lens > 0 is asserted above).
    kv_last_page_lens = (seq_lens - 1) % block_size + 1

    workspace_buffer = torch.zeros(1024 * 1024 * 1024, dtype=torch.int8)

    wrapper = flashinfer.BatchPrefillWithPagedKVCacheWrapper(
        workspace_buffer, kv_layout
    )
    wrapper.plan(
        q_indptr,
        kv_indptr,
        kv_indices,
        kv_last_page_lens,
        num_qo_heads,
        num_kv_heads,
        head_size,
        block_size,
        causal=True,
        sm_scale=sm_scale,
        q_data_type=dtype,
        kv_data_type=dtype,
    )

    def time_fn(fn, warmup=10, trials=20):
        """Return (mean, std) of ``fn``'s latency in ms over ``trials`` calls."""
        torch.cuda.synchronize()
        start = torch.cuda.Event(enable_timing=True)
        end = torch.cuda.Event(enable_timing=True)
        times = []
        for _ in range(warmup):
            fn()
        for _ in range(trials):
            start.record()
            fn()
            end.record()
            # Wait for the kernel to finish before reading the event timer.
            torch.cuda.synchronize()
            times.append(start.elapsed_time(end))  # ms
        return sum(times) / len(times), torch.std(torch.tensor(times))

    o_scale = 1.0
    o_sf_scale = None
    output_baseline = torch.empty(ref_query.shape, dtype=dtype)
    if o_quant_dtype == FP4_DTYPE:
        o_sf_scale = 500.0
        # FP4 output: a uint8 data tensor with the last dimension halved plus
        # a float8 scale-factor tensor whose dimensions are rounded up to
        # multiples of 128 and 4.
        output_trtllm = flashinfer.utils.FP4Tensor(
            torch.empty(query.shape[:-1] + (query.shape[-1] // 2,), dtype=torch.uint8),
            torch.empty(
                (
                    round_up(query.shape[0], 128),
                    round_up(query.shape[1] * query.shape[2] // 16, 4),
                ),
                dtype=torch.float8_e4m3fn,
            ),
        )
    else:
        output_trtllm = torch.empty(query.shape, dtype=o_quant_dtype)

    def baseline_prefill():
        return wrapper.run(
            ref_query,
            ref_kv_cache,
            k_scale=k_scale,
            v_scale=v_scale,
            out=output_baseline,
        )

    def trtllm_prefill():
        return flashinfer.prefill.trtllm_batch_context_with_kv_cache(
            query=query,
            kv_cache=kv_cache,
            workspace_buffer=workspace_buffer,
            block_tables=block_tables,
            seq_lens=seq_lens,
            max_q_len=max_q_len,
            max_kv_len=max_total_len,
            bmm1_scale=q_scale * k_scale * sm_scale,
            bmm2_scale=v_scale / o_scale,
            batch_size=batch_size,
            cum_seq_lens_q=q_indptr,
            cum_seq_lens_kv=kv_indptr,
            o_sf_scale=o_sf_scale,
            out=output_trtllm,
        )

    # Forward the caller's warmup/trials; previously they were ignored and the
    # nested defaults were always used.
    baseline_mean, baseline_std = time_fn(
        baseline_prefill, warmup=warmup, trials=trials
    )
    trtllm_mean, trtllm_std = time_fn(trtllm_prefill, warmup=warmup, trials=trials)

    # Calculate percentage speedup (positive means TRT is faster)
    speedup_percent = (baseline_mean - trtllm_mean) / baseline_mean

    print(
        f"\t{batch_size}\t{max_total_len}\t{trtllm_mean:8.3f}\t{trtllm_std.item():8.3f}"
        f"\t{baseline_mean:8.3f}\t{baseline_std.item():8.3f}\t{speedup_percent:8.3f}"
    )

    # Return results for CSV writing
    return {
        "batch_size": batch_size,
        "trtllm_mean": trtllm_mean,
        "trtllm_std": trtllm_std.item(),
        "baseline_mean": baseline_mean,
        "baseline_std": baseline_std.item(),
        "speedup_percent": speedup_percent,
        "q_dtype": str(q_quant_dtype),
        "kv_cache_dtype": str(kv_quant_dtype),
        "output_dtype": str(o_quant_dtype),
        "block_size": block_size,
        "num_kv_heads": num_kv_heads,
        "head_size": head_size,
        "max_seq_len": max_total_len,
    }


def write_results_to_csv(results, filename=None):
    """Append benchmark result rows to a CSV file.

    Args:
        results: Iterable of dicts keyed by the column names listed in
            ``fieldnames`` below (the rows returned by ``benchmark_prefill``).
        filename: Destination path. When ``None``, a name of the form
            ``flashinfer_trtllm_benchmark_<YYYYmmdd_HHMMSS>.csv`` is generated
            from the current local time.

    The file is opened in append mode. The header row is written only when the
    file does not exist yet or exists but is empty, so repeated calls with the
    same ``filename`` accumulate rows under a single header.
    """
    if filename is None:
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        filename = f"flashinfer_trtllm_benchmark_{timestamp}.csv"

    fieldnames = [
        "batch_size",
        "trtllm_mean",
        "trtllm_std",
        "baseline_mean",
        "baseline_std",
        "speedup_percent",
        "q_dtype",
        "kv_cache_dtype",
        "output_dtype",
        "block_size",
        "num_kv_heads",
        "head_size",
        "max_seq_len",
    ]

    # A pre-existing zero-byte file still needs a header; checking existence
    # alone would leave it with header-less rows.
    needs_header = not os.path.exists(filename) or os.path.getsize(filename) == 0

    with open(filename, "a", newline="") as csvfile:
        writer = csv.DictWriter(csvfile, fieldnames=fieldnames)

        if needs_header:
            writer.writeheader()

        writer.writerows(results)

    print(f"Results written to {filename}")


if __name__ == "__main__":
    # Sweep: every quantization combination x every max_seq_len x every batch
    # size, using the default head counts, head size, layout and block size.
    batch_sizes = [1, 4, 8, 16, 32, 64, 128, 256]
    max_seq_lens = [1024, 2048, 4096, 8192, 16384, 32768, 65536, 131072]
    all_results = []

    dtype = torch.bfloat16
    quant_dtypes = [
        # (q_quant_dtype, kv_quant_dtype, o_quant_dtype)
        (None, None, None),
        (FP8_DTYPE, FP8_DTYPE, FP8_DTYPE),
        (FP8_DTYPE, FP8_DTYPE, FP4_DTYPE),
    ]

    try:
        for quant_dtype in quant_dtypes:
            # Resolve None entries to the base dtype, for the banner only;
            # benchmark_prefill receives the original tuple.
            q_quant_dtype, kv_quant_dtype, o_quant_dtype = quant_dtype
            q_quant_dtype = q_quant_dtype or dtype
            kv_quant_dtype = kv_quant_dtype or dtype
            o_quant_dtype = o_quant_dtype or dtype

            print(
                f"Running benchmark for q_dtype = {q_quant_dtype}, "
                f"kv_cache_dtype: {kv_quant_dtype}, "
                f"output_dtype: {o_quant_dtype}"
            )
            print(
                "\tbatch_size\tmax_seq_len\ttrtllm_mean\ttrtllm_std\tbaseline_mean\t"
                "baseline_std\tspeedup_percent"
            )
            for max_seq_len in max_seq_lens:
                for bs in batch_sizes:
                    result = benchmark_prefill(
                        dtype=dtype,
                        quant_dtypes=quant_dtype,
                        batch_size=bs,
                        max_seq_len=max_seq_len,
                    )
                    all_results.append(result)
    finally:
        # Write whatever was collected even if a later configuration fails or
        # the run is interrupted, so completed measurements are not lost.
        if all_results:
            write_results_to_csv(all_results)