# Adapted from tilelang/examples/flash_attention/example_mha_bwd_bhsd.py
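# Flash attention forward and backward kernels with attention sinks (BHSD layout), a PyTorch reference, and a correctness/benchmark driver.
# Example: python example_mha_sink_bwd_bhsd.py --batch 1 --h 64 --n_ctx 4096 --d_head 128 --window_size 1024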

import torch
import tilelang
from tilelang.profiler import do_bench
import tilelang.language as T
import argparse
from typing import Optional


def get_bwd_configs():
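    # Tile configuration for the backward kernel: returns (block_M, block_N, num_stages, threads), tuned per SM architecture.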
    sm_major, sm_minor = torch.cuda.get_device_capability()
    sm_version = sm_major * 10 + sm_minor
    if sm_version == 80:
        return 64, 32, 1, 128
    elif sm_version == 90:
        return 128, 32, 2, 256
    else:
        raise ValueError(f"Unsupported SM version: {sm_version}")


@tilelang.jit(
    out_idx=[3, 4],
    pass_configs={
        tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
    },
)
def flashattn_fwd(
    batch,
    heads,
    seq_len,
    dim,
    window_size=None,  # None for full attention
    sm_scale=None,
    block_M=64,
    block_N=64,
    num_stages=1,
    threads=128,
    dtype: str = "float16",
):
    if window_size is not None:
        assert window_size % block_N == 0, "window_size must be divisible by block_N"

    if sm_scale is None:
        sm_scale = (1.0 / dim) ** 0.5
    scale = sm_scale * 1.44269504  # log2(e)

    shape = [batch, heads, seq_len, dim]
    accum_dtype = "float"

    @T.prim_func
    def flash_fwd(
        Q: T.Tensor(shape, dtype),  # type: ignore
        K: T.Tensor(shape, dtype),  # type: ignore
        V: T.Tensor(shape, dtype),  # type: ignore
        Output: T.Tensor(shape, dtype),  # type: ignore
        lse: T.Tensor([batch, heads, seq_len], accum_dtype),  # type: ignore
        Sinks: T.Tensor([heads], dtype),  # type: ignore
    ):
        with T.Kernel(T.ceildiv(seq_len, block_M), heads, batch, threads=threads) as (bx, by, bz):
            Q_shared = T.alloc_shared([block_M, dim], dtype)
            K_shared = T.alloc_shared([block_N, dim], dtype)
            V_shared = T.alloc_shared([block_N, dim], dtype)
            acc_s = T.alloc_fragment([block_M, block_N], accum_dtype)
            acc_s_cast = T.alloc_fragment([block_M, block_N], dtype)
            acc_o = T.alloc_fragment([block_M, dim], accum_dtype)
            scores_max = T.alloc_fragment([block_M], accum_dtype)
            scores_max_prev = T.alloc_fragment([block_M], accum_dtype)
            scores_scale = T.alloc_fragment([block_M], accum_dtype)
            scores_sum = T.alloc_fragment([block_M], accum_dtype)
            logsum = T.alloc_fragment([block_M], accum_dtype)
            sinks = T.alloc_fragment([block_M], dtype)

            T.annotate_layout({Q_shared: tilelang.layout.make_swizzled_layout(Q_shared)})
            T.copy(Q[bz, by, bx * block_M : (bx + 1) * block_M, :], Q_shared)
            T.fill(acc_o, 0)
            T.fill(logsum, 0)
            T.fill(scores_max, -T.infinity(accum_dtype))
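            # Broadcast this head's sink logit to every query row of the tile.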
            for i in T.Parallel(block_M):
                sinks[i] = Sinks[by]
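
            # Only K/V tiles in [start, end) are visited: `end` enforces causality for this Q tile, `start` applies the sliding window when set.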
            end = T.min(T.ceildiv(seq_len, block_N), T.ceildiv((bx + 1) * block_M, block_N))
            start = T.max(0, (bx * block_M - window_size) // block_N) if window_size is not None else 0

            for k in T.Pipelined(start, end, num_stages=num_stages):
                T.copy(K[bz, by, k * block_N : (k + 1) * block_N, :], K_shared)
                for i, j in T.Parallel(block_M, block_N):
                    q_idx = bx * block_M + i
                    k_idx = k * block_N + j
                    if window_size is not None:
                        acc_s[i, j] = T.if_then_else(q_idx >= k_idx and q_idx < k_idx + window_size, 0, -T.infinity(acc_s.dtype))
                    else:
                        acc_s[i, j] = T.if_then_else(q_idx >= k_idx, 0, -T.infinity(acc_s.dtype))
                T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)

                T.copy(V[bz, by, k * block_N : (k + 1) * block_N, :], V_shared)
                T.copy(scores_max, scores_max_prev)
                T.reduce_max(acc_s, scores_max, dim=1, clear=False)
                for i in T.Parallel(block_M):
                    scores_max[i] = T.max(scores_max[i], scores_max_prev[i])
                # For causal softmax, reset scores_max to 0 where it is -inf (a fully masked row) so the exp2 below stays finite.
                # This is the Check_inf step in the FlashAttention-3 code.
                # NOTE(wt): check_inf is necessary for sliding window attention.
                for i in T.Parallel(block_M):
                    if window_size is not None:
                        scores_max[i] = T.if_then_else(scores_max[i] == -T.infinity(accum_dtype), 0, scores_max[i])
                    scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale)

                for i, j in T.Parallel(block_M, dim):
                    acc_o[i, j] *= scores_scale[i]

                for i, j in T.Parallel(block_M, block_N):
                    acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale)

                T.copy(acc_s, acc_s_cast)
                T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow)

                T.reduce_sum(acc_s, scores_sum, dim=1)
                for i in T.Parallel(block_M):
                    logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i]

            for i in T.Parallel(block_M):
                logsum[i] += T.exp2(sinks[i] * 1.44269504 - scores_max[i] * scale)  # The only change for attention sink
            for i, j in T.Parallel(block_M, dim):
                acc_o[i, j] /= logsum[i]
            T.copy(acc_o, Output[bz, by, bx * block_M : (bx + 1) * block_M, :])
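            # lse holds log2 of the softmax denominator (sink included), with the running max added back;
            # the backward kernel reconstructs the probabilities from it via exp2(qkT * scale - lse).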
            for i in T.Parallel(block_M):
                logsum[i] = T.log2(logsum[i]) + scores_max[i] * scale
            T.copy(logsum, lse[bz, by, bx * block_M : (bx + 1) * block_M])

    return flash_fwd


@tilelang.jit(
    out_idx=[2],
    pass_configs={
        tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
    },
)
def flashattn_bwd_preprocess(batch, heads, seq_len, dim, dtype: str = "float16"):
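    # Computes Delta[b, h, i] = sum_d O[b, h, i, d] * dO[b, h, i, d], used by flash_bwd as the per-row term in dS = P * (dP - Delta).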
    accum_dtype = "float"
    shape = [batch, heads, seq_len, dim]
    blk = 32

    @T.prim_func
    def flash_bwd_prep(
        O: T.Tensor(shape, dtype),  # type: ignore
        dO: T.Tensor(shape, dtype),  # type: ignore
        Delta: T.Tensor([batch, heads, seq_len], accum_dtype),  # type: ignore
    ):
        with T.Kernel(heads, T.ceildiv(seq_len, blk), batch) as (bx, by, bz):
            o = T.alloc_fragment([blk, blk], dtype)
            do = T.alloc_fragment([blk, blk], dtype)
            acc = T.alloc_fragment([blk, blk], accum_dtype)
            delta = T.alloc_fragment([blk], accum_dtype)
            T.clear(acc)
            for k in range(T.ceildiv(dim, blk)):
                T.copy(O[bz, bx, by * blk : (by + 1) * blk, k * blk : (k + 1) * blk], o)
                T.copy(dO[bz, bx, by * blk : (by + 1) * blk, k * blk : (k + 1) * blk], do)
                for i, j in T.Parallel(blk, blk):
                    acc[i, j] += o[i, j] * do[i, j]
            T.reduce_sum(acc, delta, 1)
            T.copy(delta, Delta[bz, bx, by * blk : (by + 1) * blk])

    return flash_bwd_prep


def make_dq_layout(dQ):
    # atomicAdd cannot be vectorized, so dQ is reordered to match the 8x8 GEMM fragment layout
    return T.Layout(dQ.shape, lambda b, h, l, d: [b, h, l // 8, d // 8, (d % 2), 4 * (l % 8) + (d % 8) // 2])


@tilelang.jit(
    out_idx=[1],
    pass_configs={
        tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
    },
)
def flashattn_bwd_postprocess(batch, heads, seq_len, dim, dtype: str = "float16"):
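    # Casts the fp32 dQ accumulator (written via atomicAdd in the layout from make_dq_layout) back to the output dtype in the standard row-major layout.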
    accum_dtype = "float"
    shape = [batch, heads, seq_len, dim]
    blk = 64

    @T.prim_func
    def flash_bwd_post(
        dQ: T.Tensor(shape, accum_dtype),  # type: ignore
        dQ_out: T.Tensor(shape, dtype),  # type: ignore
    ):
        with T.Kernel(T.ceildiv(seq_len, blk), heads, batch, threads=128) as (bx, by, bz):
            T.annotate_layout({dQ: make_dq_layout(dQ)})
            T.copy(
                dQ[bz, by, bx * blk : (bx + 1) * blk, :],
                dQ_out[bz, by, bx * blk : (bx + 1) * blk, :],
            )

    return flash_bwd_post


@tilelang.jit(
    pass_configs={
        tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
    }
)
def flashattn_bwd(
    batch,
    heads,
    seq_len,
    dim,
    window_size=None,  # None for full attention
    sm_scale=None,
    dtype: str = "float16",
):
    block_M, block_N, num_stages, threads = get_bwd_configs()

    if sm_scale is None:
        sm_scale = (1.0 / dim) ** 0.5
    scale = sm_scale * 1.44269504  # log2(e)

    shape = [batch, heads, seq_len, dim]
    accum_dtype = "float"

    if window_size is not None:
        assert window_size % block_N == 0, "window_size must be divisible by block_N"

    @T.prim_func
    def flash_bwd(
        Q: T.Tensor(shape, dtype),  # type: ignore
        K: T.Tensor(shape, dtype),  # type: ignore
        V: T.Tensor(shape, dtype),  # type: ignore
        dO: T.Tensor(shape, dtype),  # type: ignore
        lse: T.Tensor([batch, heads, seq_len], accum_dtype),  # type: ignore
        Delta: T.Tensor([batch, heads, seq_len], accum_dtype),  # type: ignore
        dQ: T.Tensor(shape, accum_dtype),  # type: ignore
        dK: T.Tensor(shape, dtype),  # type: ignore
        dV: T.Tensor(shape, dtype),  # type: ignore
    ):
        with T.Kernel(heads, T.ceildiv(seq_len, block_M), batch, threads=threads) as (bx, by, bz):
            K_shared = T.alloc_shared([block_M, dim], dtype)
            dsT_shared = T.alloc_shared([block_M, block_N], dtype)
            # should not store K to local if dim is large
            # K_local = T.alloc_fragment([block_M, dim], dtype)
            # K_local_T = T.alloc_fragment([block_M, dim], dtype)
            # V_local = T.alloc_fragment([block_M, dim], dtype)
            q = T.alloc_shared([block_N, dim], dtype)
            V_shared = T.alloc_shared([block_M, dim], dtype)
            qkT = T.alloc_fragment([block_M, block_N], accum_dtype)
            dsT = T.alloc_fragment([block_M, block_N], accum_dtype)
            qkT_cast = T.alloc_fragment([block_M, block_N], dtype)
            dsT_cast = T.alloc_fragment([block_M, block_N], dtype)
            lse_shared = T.alloc_shared([block_N], accum_dtype)
            delta = T.alloc_shared([block_N], accum_dtype)
            do = T.alloc_shared([block_N, dim], dtype)
            dv = T.alloc_fragment([block_M, dim], accum_dtype)
            dk = T.alloc_fragment([block_M, dim], accum_dtype)
            dq = T.alloc_fragment([block_N, dim], accum_dtype)
            dv_shared = T.alloc_shared([block_M, dim], dtype)
            dk_shared = T.alloc_shared([block_M, dim], dtype)

            T.annotate_layout(
                {
                    dQ: make_dq_layout(dQ),
                    K_shared: tilelang.layout.make_swizzled_layout(K_shared),
                    dv_shared: tilelang.layout.make_swizzled_layout(dv_shared),
                    dk_shared: tilelang.layout.make_swizzled_layout(dk_shared),
                }
            )
            T.copy(K[bz, bx, by * block_M : (by + 1) * block_M, :], K_shared)
            T.copy(V[bz, bx, by * block_M : (by + 1) * block_M, :], V_shared)
            T.clear(dv)
            T.clear(dk)

            loop_st = T.floordiv(by * block_M, block_N)
            loop_ed = (
                T.min(T.ceildiv((by + 1) * block_M + window_size, block_N), T.ceildiv(seq_len, block_N))
                if window_size is not None
                else T.ceildiv(seq_len, block_N)
            )
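            # Each (head, K/V tile) block walks over the Q tiles that attend to it, from the causal start loop_st up to loop_ed (capped by the sliding window when set).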
            for k in T.Pipelined(loop_st, loop_ed, num_stages=num_stages):
                T.copy(Q[bz, bx, k * block_N : (k + 1) * block_N, :], q)
                T.clear(qkT)
                T.gemm(K_shared, q, qkT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
                T.copy(lse[bz, bx, k * block_N : (k + 1) * block_N], lse_shared)
                for i, j in T.Parallel(block_M, block_N):
                    qkT[i, j] = T.exp2(qkT[i, j] * scale - lse_shared[j])
                for i, j in T.Parallel(block_M, block_N):
                    if window_size is not None:
                        qkT[i, j] = T.if_then_else(
                            by * block_M + i <= k * block_N + j and by * block_M + i > k * block_N + j - window_size, qkT[i, j], 0
                        )
                    else:
                        qkT[i, j] = T.if_then_else(by * block_M + i <= k * block_N + j, qkT[i, j], 0)
                T.copy(dO[bz, bx, k * block_N : (k + 1) * block_N, :], dst=do)
                T.clear(dsT)
                T.gemm(V_shared, do, dsT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
                T.copy(qkT, qkT_cast)
                T.gemm(qkT_cast, B=do, C=dv, policy=T.GemmWarpPolicy.FullRow)

                T.copy(Delta[bz, bx, k * block_N : (k + 1) * block_N], delta)

                for i, j in T.Parallel(block_M, block_N):
                    dsT_cast[i, j] = qkT[i, j] * (dsT[i, j] - delta[j]) * sm_scale
                T.gemm(dsT_cast, q, dk, policy=T.GemmWarpPolicy.FullRow)

                T.copy(dsT_cast, dsT_shared)
                T.clear(dq)
                T.gemm(dsT_shared, K_shared, dq, transpose_A=True)
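                # Several K/V tiles update the same Q rows, so dQ is accumulated in fp32 with atomicAdd and cast back later by flashattn_bwd_postprocess.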
                T.atomic_add(dQ[bz, bx, k * block_N : (k + 1) * block_N, :], dq)

            T.copy(dv, dv_shared)
            T.copy(dk, dk_shared)
            T.copy(dv_shared, dV[bz, bx, by * block_M : (by + 1) * block_M, :])
            T.copy(dk_shared, dK[bz, bx, by * block_M : (by + 1) * block_M, :])

    return flash_bwd


@tilelang.jit(out_idx=-1)
def flashattn_bwd_dsink(batch, heads, seq_len, block=128, dtype: str = "float16"):
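    # Per-position contribution to the sink gradient: -exp2(sink * log2(e) - lse) * Delta.
    # _attention.backward reduces these over batch and sequence to obtain one gradient per head.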
    accum_dtype = "float"
    shape = [batch, heads, seq_len]

    @T.prim_func
    def flash_bwd_dsink(
        Sinks: T.Tensor([heads], dtype),  # type: ignore
        Delta: T.Tensor(shape, accum_dtype),  # type: ignore
        lse: T.Tensor(shape, accum_dtype),  # type: ignore
        dsinks: T.Tensor(shape, accum_dtype),  # type: ignore
    ):
        with T.Kernel(heads, T.ceildiv(seq_len, block), batch, threads=128) as (bx, by, bz):
            sink = T.alloc_local([1], dtype)
            lse_fragment = T.alloc_fragment([block], accum_dtype)
            delta_fragment = T.alloc_fragment([block], accum_dtype)
            dsink_fragment = T.alloc_fragment([block], accum_dtype)

            sink[0] = Sinks[bx]
            T.copy(lse[bz, bx, by * block : (by + 1) * block], lse_fragment)
            T.copy(Delta[bz, bx, by * block : (by + 1) * block], delta_fragment)
            for i in T.Parallel(block):
                dsink_fragment[i] = -T.exp2(sink[0] * 1.44269504 - lse_fragment[i]) * delta_fragment[i]
            T.copy(dsink_fragment, dsinks[bz, bx, by * block : (by + 1) * block])

    return flash_bwd_dsink


class _attention(torch.autograd.Function):
    @staticmethod
    def forward(ctx, q, k, v, sinks, window_size):
        BATCH, H, N_CTX, D_HEAD = q.shape
        dtype = "float16" if q.dtype == torch.float16 else "bfloat16"
        kernel = flashattn_fwd(BATCH, H, N_CTX, D_HEAD, window_size, dtype=dtype)
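        # out_idx=[3, 4]: the kernel allocates and returns Output and lse.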
        o, lse = kernel(q, k, v, sinks)
        ctx.save_for_backward(q, k, v, sinks, o, lse)
        ctx.window_size = window_size
        return o

    @staticmethod
    def backward(ctx, do):
        q, k, v, sinks, o, lse = ctx.saved_tensors
        BATCH, H, N_CTX, D_HEAD = q.shape

        def maybe_contiguous(x):
            if x.stride(-1) != 1:
                return x.contiguous()
            return x

        do, q, k, v, sinks, o = [maybe_contiguous(x) for x in (do, q, k, v, sinks, o)]
        dtype = "float16" if q.dtype == torch.float16 else "bfloat16"
        kernel_prep = flashattn_bwd_preprocess(BATCH, H, N_CTX, D_HEAD, dtype=dtype)
        kernel_post = flashattn_bwd_postprocess(BATCH, H, N_CTX, D_HEAD, dtype=dtype)
        delta = kernel_prep(o, do)
        kernel = flashattn_bwd(BATCH, H, N_CTX, D_HEAD, ctx.window_size, dtype=dtype)
        shape = [BATCH, H, N_CTX, D_HEAD]
        dq = torch.zeros(shape, dtype=torch.float32, device=q.device)  # acc for atomicAdd
        dk = torch.empty(shape, dtype=q.dtype, device=q.device)
        dv = torch.empty(shape, dtype=q.dtype, device=q.device)
        kernel(q, k, v, do, lse, delta, dq, dk, dv)
        dq = kernel_post(dq)

        kernel_dsink = flashattn_bwd_dsink(BATCH, H, N_CTX, dtype=dtype)
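        # The dsink kernel emits per-position values of shape [batch, heads, seq_len]; sum over batch and sequence to get one gradient per head.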
        dsinks = kernel_dsink(sinks, delta, lse).sum(0).sum(1).to(sinks.dtype)
        return dq, dk, dv, dsinks, None


attention = _attention.apply


# Adapted and optimized from
# https://github.com/openai/gpt-oss/blob/main/gpt_oss/triton/attention.py
def ref_program(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    sinks: torch.Tensor,
    sliding_window: Optional[int] = None,
    dtype: torch.dtype = torch.float16,
) -> torch.Tensor:
    query = query.transpose(1, 2).contiguous().unsqueeze(3)  # align with the original function's interface
    key = key.transpose(1, 2).contiguous()
    value = value.transpose(1, 2).contiguous()

    batch_size, num_queries, num_key_value_heads, num_key_value_groups, head_dim = query.shape
    batch_size, num_keys, num_key_value_heads, head_dim = key.shape
    start_q = num_keys - num_queries

    sm_scale: float = 1.0 / head_dim**0.5

    sinks = sinks.view(1, num_key_value_heads, num_key_value_groups, 1, 1)
    key = key.unsqueeze(3)
    value = value.unsqueeze(3)

    pos_keys = torch.arange(num_keys, device=query.device)
    pos_queries = torch.arange(num_queries, device=query.device) + start_q
    mask = pos_keys[None, :] > pos_queries[:, None]
    mask = mask.float().masked_fill(mask, float("-inf"))

    if sliding_window:
        too_old = pos_keys[None, :] < (pos_queries[:, None] - sliding_window + 1)
        mask.masked_fill_(too_old, float("-inf"))

    logits = torch.einsum("bqhmd,bkhmd->bhmqk", query.float(), key.float()) * sm_scale
    logits = logits + mask[None, None, None, :, :]

    logits_max = torch.max(logits, dim=-1, keepdim=True).values
    logits_or_sinks_max = torch.maximum(sinks, logits_max)
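    # The sink acts as one extra logit per head: it joins the max for numerical stability and the softmax denominator, but has no value vector.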
    sinks = torch.exp(sinks - logits_or_sinks_max)
    unnormalized_scores = torch.exp(logits - logits_or_sinks_max)
    normalizer = unnormalized_scores.sum(dim=-1, keepdim=True) + sinks
    scores = unnormalized_scores / normalizer

    output = torch.einsum("bhmqk,bkhmd->bqhmd", scores, value.float())

    output = output.reshape(batch_size, num_queries, num_key_value_heads * num_key_value_groups, head_dim).to(dtype)
    return output.transpose(1, 2).contiguous()


def main(BATCH: int = 1, H: int = 1, N_CTX: int = 512, D_HEAD: int = 128, window_size: Optional[int] = None, dtype: str = "float16"):
    torch_dtype = {"float16": torch.float16, "bfloat16": torch.bfloat16}[dtype]
    if window_size is not None:
        print("Using sliding window attention.")
        assert window_size <= N_CTX
        flops_per_matmul = 2.0 * BATCH * H * min(window_size, N_CTX // 2) * N_CTX * D_HEAD  # a rough estimate
    else:
        print("Using full attention.")
        flops_per_matmul = 2.0 * BATCH * H * N_CTX * N_CTX * D_HEAD * 0.5
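    # The backward kernel performs five GEMMs per tile (qkT, dv, dsT, dk, dq), hence the factor of 5.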
    total_flops = 5 * flops_per_matmul

    Q = torch.randn(BATCH, H, N_CTX, D_HEAD, dtype=torch_dtype, device="cuda").requires_grad_()
    K = torch.randn_like(Q).requires_grad_()
    V = torch.randn_like(Q).requires_grad_()
    sinks = torch.randn(H, dtype=torch_dtype, device=Q.device).requires_grad_()
    dO = torch.randn_like(Q)

    O = attention(Q, K, V, sinks, window_size)
    O.backward(dO, retain_graph=True)
    dQ, Q.grad = Q.grad.clone(), None
    dK, K.grad = K.grad.clone(), None
    dV, V.grad = V.grad.clone(), None
    dsinks, sinks.grad = sinks.grad.clone(), None

    O_ref = ref_program(Q, K, V, sinks, window_size, dtype=torch_dtype)
    O_ref.backward(dO, retain_graph=True)
    dQ_ref, Q.grad = Q.grad.clone(), None
    dK_ref, K.grad = K.grad.clone(), None
    dV_ref, V.grad = V.grad.clone(), None
    dsinks_ref, sinks.grad = sinks.grad.clone(), None

    # Checks
    rtol, atol = {
        "float16": (1e-2, 1e-2),
        "bfloat16": (2e-2, 2e-2),
    }[dtype]
    assert torch.allclose(O, O_ref, rtol=rtol, atol=atol), f"O max err: {(O - O_ref).abs().max()}"
    assert torch.allclose(dV, dV_ref, rtol=rtol, atol=atol), f"dV max err: {(dV - dV_ref).abs().max()}"
    assert torch.allclose(dK, dK_ref, rtol=rtol, atol=atol), f"dK max err: {(dK - dK_ref).abs().max()}"
    assert torch.allclose(dQ, dQ_ref, rtol=rtol, atol=atol), f"dQ max err: {(dQ - dQ_ref).abs().max()}"
    assert torch.allclose(dsinks, dsinks_ref, rtol=rtol, atol=atol), f"dsinks max err: {(dsinks - dsinks_ref).abs().max()}"

    print("All checks passed for tilelang kernels.✅")

    # Only benchmark backward here
    def torch_bwd():
        O_ref.backward(dO, retain_graph=True)

    def tl_bwd():
        O.backward(dO, retain_graph=True)

    latency = do_bench(torch_bwd, warmup=500)
    print("torch: {:.2f} ms".format(latency))
    print("torch: {:.2f} TFlops".format(total_flops / latency * 1e-9))
    latency = do_bench(tl_bwd, warmup=500)
    print("tilelang: {:.2f} ms".format(latency))
    print("tilelang: {:.2f} TFlops".format(total_flops / latency * 1e-9))


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--batch", type=int, default=1, help="Batch size")
    parser.add_argument("--h", type=int, default=64, help="Number of heads")
    parser.add_argument("--n_ctx", type=int, default=4096, help="Context size")
    parser.add_argument("--d_head", type=int, default=128, help="Head dimension")
    parser.add_argument("--window_size", type=int, default=None, help="window size (default: None, which means full attention)")
    parser.add_argument("--dtype", type=str, default="float16", help="dtype, can be float16 or bfloat16")
    args = parser.parse_args()
    main(args.batch, args.h, args.n_ctx, args.d_head, args.window_size, args.dtype)