bench_attention_sink_triton.py 7.32 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
import argparse

import torch
import triton

from sglang.srt.layers.attention.triton_ops.decode_attention import (
    decode_attention_fwd_grouped,
)
from sglang.srt.layers.attention.triton_ops.extend_attention import extend_attention_fwd

# Attention geometry benchmarked here (gpt oss).
head_num = 64  # query heads
head_kv_num = 8  # key/value heads; 64 / 8 = 8 query heads share each KV head
head_dim = 64  # per-head dimension for Q, K and V


@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=["S"],  # sequence length on x-axis
        x_vals=[128, 256, 512, 1024, 2048, 4096],
        x_log=True,
        line_arg="B",  # batch size as different lines
        line_vals=[1, 8, 32, 128],
        line_names=["B=1", "B=8", "B=32", "B=128"],
        styles=[
            ("blue", "-"),
            ("green", "-"),
            ("red", "-"),
            ("cyan", "-"),
        ],
        ylabel="TFLOPS",
        plot_name="attention-sink-triton-decode",
        args={},
    )
)
def benchmark_decode(B, S, H_Q, H_KV, D):
    """Benchmark grouped Triton decode attention with per-head attention sinks.

    Each of the ``B`` requests decodes one new query token against a KV
    cache of ``S`` tokens. ``decode_attention_fwd_grouped`` is called 5
    times for warmup and then 500 times under CUDA-event timing.

    Args:
        B: Batch size (number of requests).
        S: KV-cache length per request.
        H_Q: Number of query heads.
        H_KV: Number of key/value heads.
        D: Head dimension (used for both K and V).

    Returns:
        Achieved throughput in TFLOPS. Both the Q @ K^T and the P @ V
        matmuls are counted.
    """
    D_V = D
    dtype = torch.bfloat16
    device = torch.device("cuda")
    total_tokens = B * S
    sm_scale = 1.0 / (D**0.5)
    max_kv_splits = 8
    num_kv_splits = torch.full((B,), 4, dtype=torch.int32, device=device)

    # One new query token per request.
    q = torch.randn(B, H_Q, D, dtype=dtype, device=device)

    # KV cache: the S previous tokens of every request, packed back to back.
    k_buffer = torch.randn(total_tokens, H_KV, D, dtype=dtype, device=device)
    v_buffer = torch.randn(total_tokens, H_KV, D_V, dtype=dtype, device=device)

    o = torch.zeros(B, H_Q, D_V, dtype=dtype, device=device)

    b_seq_len = torch.full((B,), S, device=device)
    kv_indptr = torch.zeros((B + 1,), dtype=torch.int32, device=device)
    kv_indptr[1 : B + 1] = torch.cumsum(b_seq_len, dim=0)
    kv_indices = torch.arange(total_tokens, device=device)

    # Split-KV scratch buffers. attn_logits holds one partial D_V-wide output
    # per (request, head, split). attn_lse holds one log-sum-exp scalar per
    # (request, head, split), so it has no D_V axis. This matches the
    # (bs, heads, max_kv_splits) allocation used by SGLang itself.
    attn_logits = torch.empty(
        (B, H_Q, max_kv_splits, D_V),
        dtype=torch.float32,
        device=device,
    )
    attn_lse = torch.empty(
        (B, H_Q, max_kv_splits),
        dtype=torch.float32,
        device=device,
    )
    sink = torch.randn(H_Q, device=device, dtype=torch.float32)

    def run():
        # A single decode step; the same call is used for warmup and timing.
        decode_attention_fwd_grouped(
            q,
            k_buffer,
            v_buffer,
            o,
            kv_indptr,
            kv_indices,
            attn_logits,
            attn_lse,
            num_kv_splits,
            max_kv_splits,
            sm_scale,
            logit_cap=0.0,
            sinks=sink,
        )

    warmup_steps = 5
    for _ in range(warmup_steps):
        run()

    run_step = 500
    start_event = torch.cuda.Event(enable_timing=True)
    end_event = torch.cuda.Event(enable_timing=True)
    start_event.record()
    for _ in range(run_step):
        run()
    end_event.record()
    # Waiting on the end event is enough: the start event was recorded earlier
    # on the same stream.
    end_event.synchronize()
    ms = start_event.elapsed_time(end_event) / run_step

    # Q @ K^T costs 2 * S * D FLOPs per (request, head), and P @ V costs
    # 2 * S * D_V. Decode has no causal mask to discount, because the single
    # query attends to all S cached keys.
    total_flops = 2 * B * H_Q * S * (D + D_V)
    return total_flops * 1e-12 / (ms * 1e-3)


@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=["S"],  # sequence length on x-axis
        x_vals=[128, 256, 512, 1024, 2048, 4096],
        x_log=True,
        line_arg="B",  # batch size as different lines
        line_vals=[1, 8, 32, 128],
        line_names=["B=1", "B=8", "B=32", "B=128"],
        styles=[
            ("blue", "-"),
            ("green", "-"),
            ("red", "-"),
            ("cyan", "-"),
        ],
        ylabel="TFLOPS",
        plot_name="attention-sink-triton-extend",
        args={},
    )
)
def benchmark_extend(B, S, H_Q, H_KV, D):
    """Benchmark Triton extend (prefix + new tokens) attention with sinks.

    Each of the ``B`` requests has ``S // 2`` tokens already in the KV
    buffer (the prefix) and appends ``S // 4`` new tokens (the extension)
    with ``is_causal=True``. ``extend_attention_fwd`` is called 5 times for
    warmup and then 500 times under CUDA-event timing.

    Args:
        B: Batch size (number of requests).
        S: Base length (N_CTX in the kernel test). Each request has a prefix
            of ``S // 2`` tokens and an extension of ``S // 4`` tokens.
        H_Q: Number of query heads.
        H_KV: Number of key/value heads.
        D: Head dimension (used for both K and V).

    Returns:
        Approximate throughput in TFLOPS. Both the Q @ K^T and the P @ V
        matmuls are counted.
    """
    dtype = torch.bfloat16
    device = "cuda"

    # Split S into prefix and extension lengths (prefix is the longer one).
    prefill_len = S // 2
    extend_len = S // 4

    total_extend_tokens = B * extend_len
    total_prefix_tokens = B * prefill_len

    # Queries, keys and values of the new (extension) tokens, packed per request.
    q_extend = torch.randn(total_extend_tokens, H_Q, D, dtype=dtype, device=device)
    k_extend = torch.randn(total_extend_tokens, H_KV, D, dtype=dtype, device=device)
    v_extend = torch.randn(total_extend_tokens, H_KV, D, dtype=dtype, device=device)
    o_extend = torch.empty_like(q_extend)

    # Key/value buffers holding the prefix tokens.
    k_buffer = torch.randn(total_prefix_tokens, H_KV, D, dtype=dtype, device=device)
    v_buffer = torch.randn(total_prefix_tokens, H_KV, D, dtype=dtype, device=device)

    # Request i owns rows [i * len, (i + 1) * len) of the packed tensors.
    # Using arange(B + 1) * len gives the same offsets as arange(..., step=len)
    # without failing when len == 0.
    qo_indptr = torch.arange(B + 1, dtype=torch.int32, device=device) * extend_len
    kv_indptr = torch.arange(B + 1, dtype=torch.int32, device=device) * prefill_len
    kv_indices = torch.arange(total_prefix_tokens, dtype=torch.int32, device=device)

    sm_scale = 1.0 / (D**0.5)
    # -1 disables the sliding window. The GPT-OSS config value (128) is
    # skipped for now.
    sliding_window = -1

    sink = torch.randn(H_Q, device=device, dtype=torch.float32)

    def run():
        # A single extend forward; the same call is used for warmup and timing.
        extend_attention_fwd(
            q_extend,
            k_extend,
            v_extend,
            o_extend,
            k_buffer,
            v_buffer,
            qo_indptr,
            kv_indptr,
            kv_indices,
            custom_mask=None,
            is_causal=True,
            mask_indptr=None,
            max_len_extend=extend_len,
            sm_scale=sm_scale,
            sliding_window_size=sliding_window,
            sinks=sink,
        )

    warmup_steps = 5
    for _ in range(warmup_steps):
        run()

    run_step = 500
    start_event = torch.cuda.Event(enable_timing=True)
    end_event = torch.cuda.Event(enable_timing=True)
    start_event.record()
    for _ in range(run_step):
        run()
    end_event.record()
    # Waiting on the end event is enough: the start event was recorded earlier
    # on the same stream.
    end_event.synchronize()
    ms = start_event.elapsed_time(end_event) / run_step

    # Under the causal mask, each new token attends to the whole prefix plus,
    # on average, about half of the extension. For each (token, head),
    # Q @ K^T and P @ V each cost 2 * kv_len * D FLOPs, so 4 * kv_len * D total.
    avg_kv_len = prefill_len + extend_len / 2
    total_flops = 4 * total_extend_tokens * H_Q * avg_kv_len * D
    return total_flops * 1e-12 / (ms * 1e-3)


if __name__ == "__main__":
    # Command-line entry: choose which benchmark(s) to run. `choices` makes
    # argparse reject unknown values. Without it, a typo would silently run
    # nothing and still report success.
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--bench",
        type=str,
        default="all",
        choices=["all", "extend", "decode"],
        help="all, extend, decode",
    )
    args = parser.parse_args()

    # Fixed model geometry, forwarded to every benchmark point.
    kwargs = {
        "H_Q": head_num,
        "H_KV": head_kv_num,
        "D": head_dim,
    }

    if args.bench in ("all", "decode"):
        benchmark_decode.run(print_data=True, show_plots=False, **kwargs)

    if args.bench in ("all", "extend"):
        benchmark_extend.run(print_data=True, show_plots=False, **kwargs)

    print("Benchmark finished!")