# example_gqa_sink_bwd_bhsd.py
# Adapted from tilelang/examples/flash_attention/example_gqa_bwd.py

import torch
import tilelang
from tilelang.profiler import do_bench
import tilelang.language as T
import argparse
from typing import Optional


def get_bwd_configs():
    # Returns (block_M, block_N, num_stages, threads) for the detected GPU architecture.
    sm_major, sm_minor = torch.cuda.get_device_capability()
    sm_version = sm_major * 10 + sm_minor
    if sm_version == 80:  # Ampere (A100)
        return 64, 32, 1, 128
    elif sm_version == 90:  # Hopper (H100)
        return 128, 32, 2, 256
    else:
        raise ValueError(f"Unsupported SM version: {sm_version}")


@tilelang.jit(
    out_idx=[3, 4],
    pass_configs={
        tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
    },
)
def flashattn_fwd(
    batch,
    heads,
    seq_len,
    dim,
    groups=1,
    window_size=None,  # None for full attention
    sm_scale=None,
    block_M=64,
    block_N=64,
    num_stages=1,
    threads=128,
    dtype: str = "float16",
):
    if window_size is not None:
        assert window_size % block_N == 0, "window_size must be divisible by block_N"

    if sm_scale is None:
        sm_scale = (1.0 / dim) ** 0.5
    # Work in base 2: folding log2(e) into the scale makes exp2(x * scale) == exp(x * sm_scale),
    # and exp2/log2 map to faster hardware instructions than exp/log.
    scale = sm_scale * 1.44269504  # log2(e)

    head_kv = heads // groups
    q_shape = [batch, heads, seq_len, dim]
    kv_shape = [batch, head_kv, seq_len, dim]
    accum_dtype = "float"

    @T.prim_func
    def flash_fwd(
        Q: T.Tensor(q_shape, dtype),  # type: ignore
        K: T.Tensor(kv_shape, dtype),  # type: ignore
        V: T.Tensor(kv_shape, dtype),  # type: ignore
        Output: T.Tensor(q_shape, dtype),  # type: ignore
        lse: T.Tensor([batch, heads, seq_len], accum_dtype),  # type: ignore
        Sinks: T.Tensor([heads], dtype),  # type: ignore
    ):
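        # One CTA per (query block, query head, batch). K/V of head by // groups are
        # streamed block by block with an online softmax kept in base 2; the per-head
        # sink only enters the normalizer, and lse is written out in base 2 for the
        # backward kernels.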
        with T.Kernel(T.ceildiv(seq_len, block_M), heads, batch, threads=threads) as (bx, by, bz):
            Q_shared = T.alloc_shared([block_M, dim], dtype)
            K_shared = T.alloc_shared([block_N, dim], dtype)
            V_shared = T.alloc_shared([block_N, dim], dtype)
            acc_s = T.alloc_fragment([block_M, block_N], accum_dtype)
            acc_s_cast = T.alloc_fragment([block_M, block_N], dtype)
            acc_o = T.alloc_fragment([block_M, dim], accum_dtype)
            scores_max = T.alloc_fragment([block_M], accum_dtype)
            scores_max_prev = T.alloc_fragment([block_M], accum_dtype)
            scores_scale = T.alloc_fragment([block_M], accum_dtype)
            scores_sum = T.alloc_fragment([block_M], accum_dtype)
            logsum = T.alloc_fragment([block_M], accum_dtype)
            sinks = T.alloc_fragment([heads], dtype)

            T.annotate_layout({Q_shared: tilelang.layout.make_swizzled_layout(Q_shared)})
            T.copy(Q[bz, by, bx * block_M : (bx + 1) * block_M, :], Q_shared)
            T.fill(acc_o, 0)
            T.fill(logsum, 0)
            T.fill(scores_max, -T.infinity(accum_dtype))
            for i in T.Parallel(block_M):
                sinks[i] = Sinks[by]

            end = T.min(T.ceildiv(seq_len, block_N), T.ceildiv((bx + 1) * block_M, block_N))
            start = T.max(0, (bx * block_M - window_size) // block_N) if window_size is not None else 0

            for k in T.Pipelined(start, end, num_stages=num_stages):
                T.copy(K[bz, by // groups, k * block_N : (k + 1) * block_N, :], K_shared)
                for i, j in T.Parallel(block_M, block_N):
                    q_idx = bx * block_M + i
                    k_idx = k * block_N + j
                    if window_size is not None:
                        acc_s[i, j] = T.if_then_else(q_idx >= k_idx and q_idx < k_idx + window_size, 0, -T.infinity(acc_s.dtype))
                    else:
                        acc_s[i, j] = T.if_then_else(q_idx >= k_idx, 0, -T.infinity(acc_s.dtype))
                T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)

                T.copy(V[bz, by // groups, k * block_N : (k + 1) * block_N, :], V_shared)
                T.copy(scores_max, scores_max_prev)
                T.reduce_max(acc_s, scores_max, dim=1, clear=False)
                for i in T.Parallel(block_M):
                    scores_max[i] = T.max(scores_max[i], scores_max_prev[i])
                # For the masked softmax, reset scores_max to 0 when an entire row is
                # masked out (-inf). FlashAttention-3 calls this step Check_inf.
                # NOTE(wt): it is only needed for sliding-window attention, where a row
                # can have no valid keys in the visited block range.
                for i in T.Parallel(block_M):
                    if window_size is not None:
                        scores_max[i] = T.if_then_else(scores_max[i] == -T.infinity(accum_dtype), 0, scores_max[i])
                    scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale)
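                # Online-softmax rescale: the running output and running row sum are
                # shrunk by exp2(scale * (old_max - new_max)) before this K/V block's
                # contribution is added.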

                for i, j in T.Parallel(block_M, dim):
                    acc_o[i, j] *= scores_scale[i]

                for i, j in T.Parallel(block_M, block_N):
                    acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale)

                T.copy(acc_s, acc_s_cast)
                T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow)

                T.reduce_sum(acc_s, scores_sum, dim=1)
                for i in T.Parallel(block_M):
                    logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i]

            for i in T.Parallel(block_M):
                # The only change for attention sink: the sink logit joins the softmax
                # denominator but is not paired with any value.
                logsum[i] += T.exp2(sinks[i] * 1.44269504 - scores_max[i] * scale)
            for i, j in T.Parallel(block_M, dim):
                acc_o[i, j] /= logsum[i]
            T.copy(acc_o, Output[bz, by, bx * block_M : (bx + 1) * block_M, :])
            for i in T.Parallel(block_M):
                logsum[i] = T.log2(logsum[i]) + scores_max[i] * scale
            T.copy(logsum, lse[bz, by, bx * block_M : (bx + 1) * block_M])

    return flash_fwd


@tilelang.jit(
    out_idx=[2],
    pass_configs={
        tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
    },
)
def flashattn_bwd_preprocess(batch, heads, seq_len, dim, dtype: str = "float16"):
    accum_dtype = "float"
    shape = [batch, heads, seq_len, dim]
    blk = 32

    @T.prim_func
    def flash_bwd_prep(
        O: T.Tensor(shape, dtype),  # type: ignore
        dO: T.Tensor(shape, dtype),  # type: ignore
        Delta: T.Tensor([batch, heads, seq_len], accum_dtype),  # type: ignore
    ):
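        # Delta[b, h, i] = sum_d O[b, h, i, d] * dO[b, h, i, d], the rowwise O.dO term
        # used by the softmax backward (dS = P * (dP - Delta)).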
        with T.Kernel(heads, T.ceildiv(seq_len, blk), batch) as (bx, by, bz):
            o = T.alloc_fragment([blk, blk], dtype)
            do = T.alloc_fragment([blk, blk], dtype)
            acc = T.alloc_fragment([blk, blk], accum_dtype)
            delta = T.alloc_fragment([blk], accum_dtype)
            T.clear(acc)
            for k in range(T.ceildiv(dim, blk)):
                T.copy(O[bz, bx, by * blk : (by + 1) * blk, k * blk : (k + 1) * blk], o)
                T.copy(dO[bz, bx, by * blk : (by + 1) * blk, k * blk : (k + 1) * blk], do)
                for i, j in T.Parallel(blk, blk):
                    acc[i, j] += o[i, j] * do[i, j]
            T.reduce_sum(acc, delta, 1)
            T.copy(delta, Delta[bz, bx, by * blk : (by + 1) * blk])

    return flash_bwd_prep


def make_dq_layout(dQ):
    # atomicAdd cannot be vectorized, so reorder dQ to match the 8x8 GEMM fragment.
    return T.Layout(dQ.shape, lambda b, h, l, d: [b, h, l // 8, d // 8, (d % 2), 4 * (l % 8) + (d % 8) // 2])


@tilelang.jit(
    out_idx=[1],
    pass_configs={
        tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
    },
)
def flashattn_bwd_postprocess(batch, heads, seq_len, dim, dtype: str = "float16"):
    accum_dtype = "float"
    shape = [batch, heads, seq_len, dim]
    blk = 64

    @T.prim_func
    def flash_bwd_post(
        dQ: T.Tensor(shape, accum_dtype),  # type: ignore
        dQ_out: T.Tensor(shape, dtype),  # type: ignore
    ):
        with T.Kernel(T.ceildiv(seq_len, blk), heads, batch, threads=128) as (bx, by, bz):
            T.annotate_layout({dQ: make_dq_layout(dQ)})
            T.copy(
                dQ[bz, by, bx * blk : (bx + 1) * blk, :],
                dQ_out[bz, by, bx * blk : (bx + 1) * blk, :],
            )

    return flash_bwd_post


@tilelang.jit(
    pass_configs={
        tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
    }
)
def flashattn_bwd(batch, heads, seq_len, dim, groups, window_size=None, sm_scale=None, dtype="float16"):
    # window_size=None means full attention
    if sm_scale is None:
        sm_scale = (1.0 / dim) ** 0.5
    scale = sm_scale * 1.44269504  # log2(e)

    head_kv = heads // groups
    q_shape = [batch, heads, seq_len, dim]
    kv_shape = [batch, head_kv, seq_len, dim]
    accum_dtype = "float"

    block_M, block_N, num_stages, threads = get_bwd_configs()

    if window_size is not None:
        assert window_size % block_N == 0, "window_size must be divisible by block_N"

    @T.prim_func
    def flash_bwd(
        Q: T.Tensor(q_shape, dtype),  # type: ignore
        K: T.Tensor(kv_shape, dtype),  # type: ignore
        V: T.Tensor(kv_shape, dtype),  # type: ignore
        dO: T.Tensor(q_shape, dtype),  # type: ignore
        lse: T.Tensor([batch, heads, seq_len], accum_dtype),  # type: ignore
        Delta: T.Tensor([batch, heads, seq_len], accum_dtype),  # type: ignore
        dQ: T.Tensor(q_shape, accum_dtype),  # type: ignore
        dK: T.Tensor(kv_shape, accum_dtype),  # type: ignore
        dV: T.Tensor(kv_shape, accum_dtype),  # type: ignore
    ):
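        # One CTA per (query head, K/V row block, batch). dK/dV of the owned K/V tile
        # are accumulated in registers while looping over query blocks; dQ is scattered
        # with atomicAdd every iteration, and dK/dV are atomically added once at the
        # end (several query heads of the same group map to the same KV head).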
        with T.Kernel(heads, T.ceildiv(seq_len, block_M), batch, threads=threads) as (bx, by, bz):
            K_shared = T.alloc_shared([block_M, dim], dtype)
            dsT_shared = T.alloc_shared([block_M, block_N], dtype)
            q = T.alloc_shared([block_N, dim], dtype)
            V_shared = T.alloc_shared([block_M, dim], dtype)
            qkT = T.alloc_fragment([block_M, block_N], accum_dtype)
            dsT = T.alloc_fragment([block_M, block_N], accum_dtype)
            qkT_cast = T.alloc_fragment([block_M, block_N], dtype)
            dsT_cast = T.alloc_fragment([block_M, block_N], dtype)
            lse_shared = T.alloc_shared([block_N], accum_dtype)
            delta = T.alloc_shared([block_N], accum_dtype)
            do = T.alloc_shared([block_N, dim], dtype)
            dv = T.alloc_fragment([block_M, dim], accum_dtype)
            dk = T.alloc_fragment([block_M, dim], accum_dtype)
            dq = T.alloc_fragment([block_N, dim], accum_dtype)
            dv_shared = T.alloc_shared([block_M, dim], accum_dtype)
            dk_shared = T.alloc_shared([block_M, dim], accum_dtype)

            T.annotate_layout(
                {
                    dQ: make_dq_layout(dQ),
                    K_shared: tilelang.layout.make_swizzled_layout(K_shared),
                    dv_shared: tilelang.layout.make_swizzled_layout(dv_shared),
                    dk_shared: tilelang.layout.make_swizzled_layout(dk_shared),
                }
            )
            T.copy(K[bz, bx // groups, by * block_M : (by + 1) * block_M, :], K_shared)
            T.copy(V[bz, bx // groups, by * block_M : (by + 1) * block_M, :], V_shared)
            T.clear(dv)
            T.clear(dk)

            loop_st = T.floordiv(by * block_M, block_N)
            loop_ed = (
                T.min(T.ceildiv((by + 1) * block_M + window_size, block_N), T.ceildiv(seq_len, block_N))
                if window_size is not None
                else T.ceildiv(seq_len, block_N)
            )

            for k in T.Pipelined(loop_st, loop_ed, num_stages=num_stages):
                T.copy(Q[bz, bx, k * block_N : (k + 1) * block_N, :], q)
                T.clear(qkT)
                T.gemm(K_shared, q, qkT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
                T.copy(lse[bz, bx, k * block_N : (k + 1) * block_N], lse_shared)
                for i, j in T.Parallel(block_M, block_N):
                    qkT[i, j] = T.exp2(qkT[i, j] * scale - lse_shared[j])
                for i, j in T.Parallel(block_M, block_N):
                    if window_size is not None:
                        qkT[i, j] = T.if_then_else(
                            by * block_M + i <= k * block_N + j and by * block_M + i > k * block_N + j - window_size, qkT[i, j], 0
                        )
                    else:
                        qkT[i, j] = T.if_then_else(by * block_M + i <= k * block_N + j, qkT[i, j], 0)
                T.copy(dO[bz, bx, k * block_N : (k + 1) * block_N, :], dst=do)
                T.clear(dsT)
                T.gemm(V_shared, do, dsT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
                T.copy(qkT, qkT_cast)
                T.gemm(qkT_cast, do, dv, policy=T.GemmWarpPolicy.FullRow)

                T.copy(Delta[bz, bx, k * block_N : (k + 1) * block_N], delta)

                for i, j in T.Parallel(block_M, block_N):
                    dsT_cast[i, j] = qkT[i, j] * (dsT[i, j] - delta[j]) * sm_scale
                T.gemm(dsT_cast, q, dk, policy=T.GemmWarpPolicy.FullRow)

                T.copy(dsT_cast, dsT_shared)
                T.clear(dq)
                T.gemm(dsT_shared, K_shared, dq, transpose_A=True)
                T.atomic_add(dQ[bz, bx, k * block_N : (k + 1) * block_N, :], dq)

            # dK/dV of this KV head receive contributions from all `groups` query
            # heads, so they are accumulated with atomics as well.
            T.copy(dv, dv_shared)
            T.atomic_add(dV[bz, bx // groups, by * block_M : (by + 1) * block_M, :], dv_shared)
            T.copy(dk, dk_shared)
            T.atomic_add(dK[bz, bx // groups, by * block_M : (by + 1) * block_M, :], dk_shared)

    return flash_bwd


@tilelang.jit(out_idx=-1)
def flashattn_bwd_dsink(batch, heads, seq_len, block=256, dtype: str = "float16"):
    accum_dtype = "float"
    shape = [batch, heads, seq_len]

    @T.prim_func
    def flash_bwd_dsink(
        Sinks: T.Tensor([heads], dtype),  # type: ignore
        Delta: T.Tensor(shape, accum_dtype),  # type: ignore
        lse: T.Tensor(shape, accum_dtype),  # type: ignore
        dsinks: T.Tensor(shape, dtype),  # type: ignore
    ):
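        # Per-position sink gradient: d(sink) = -softmax_prob(sink) * Delta
        #                                     = -exp2(sink * log2(e) - lse) * Delta,
        # since the sink only appears in the softmax normalizer. The caller reduces
        # over batch and sequence to obtain one gradient per head.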
        with T.Kernel(heads, T.ceildiv(seq_len, block), batch, threads=256) as (bx, by, bz):
            sink = T.alloc_local([1], dtype)
            lse_fragment = T.alloc_fragment([block], accum_dtype)
            delta_fragment = T.alloc_fragment([block], accum_dtype)
            dsink_fragment = T.alloc_fragment([block], dtype)

            sink[0] = Sinks[bx]
            T.copy(lse[bz, bx, by * block : (by + 1) * block], lse_fragment)
            T.copy(Delta[bz, bx, by * block : (by + 1) * block], delta_fragment)
            for i in T.Parallel(block):
                dsink_fragment[i] = -T.exp2(Sinks[bx] * 1.44269504 - lse_fragment[i]) * delta_fragment[i]
            T.copy(dsink_fragment, dsinks[bz, bx, by * block : (by + 1) * block])

    return flash_bwd_dsink


class _attention(torch.autograd.Function):
    @staticmethod
    def forward(ctx, q, k, v, sinks, window_size, groups):
        def maybe_contiguous(x):
            if x.stride(-1) != 1:
                return x.contiguous()
            return x

        q, k, v, sinks = [maybe_contiguous(x) for x in (q, k, v, sinks)]
        BATCH, H, N_CTX, D_HEAD = q.shape
        dtype = "float16" if q.dtype == torch.float16 else "bfloat16"
        kernel = flashattn_fwd(BATCH, H, N_CTX, D_HEAD, groups, window_size, dtype=dtype)
        o, lse = kernel(q, k, v, sinks)
        ctx.save_for_backward(q, k, v, sinks, o, lse)
        ctx.window_size = window_size
        ctx.groups = groups
        return o

    @staticmethod
    def backward(ctx, do):
        q, k, v, sinks, o, lse = ctx.saved_tensors
        BATCH, H, N_CTX, D_HEAD = q.shape
        groups = ctx.groups
        dtype = "float16" if q.dtype == torch.float16 else "bfloat16"

        kernel_prep = flashattn_bwd_preprocess(BATCH, H, N_CTX, D_HEAD, dtype=dtype)
        kernel_post = flashattn_bwd_postprocess(BATCH, H, N_CTX, D_HEAD, dtype=dtype)
        delta = kernel_prep(o, do)
        kernel = flashattn_bwd(BATCH, H, N_CTX, D_HEAD, groups, ctx.window_size, dtype=dtype)
        q_shape = [BATCH, H, N_CTX, D_HEAD]
        head_kv = H // groups
        kv_shape = [BATCH, head_kv, N_CTX, D_HEAD]
        dq = torch.zeros(q_shape, dtype=torch.float32, device=q.device)  # acc for atomicAdd
        dk = torch.zeros(kv_shape, dtype=torch.float32, device=q.device)
        dv = torch.zeros(kv_shape, dtype=torch.float32, device=q.device)
        kernel(q, k, v, do, lse, delta, dq, dk, dv)
        dq = kernel_post(dq)

        kernel_dsink = flashattn_bwd_dsink(BATCH, H, N_CTX, dtype=dtype)
        # dsinks has shape [batch, heads, seq]; reduce over batch and positions to get
        # one gradient per head.
        dsinks = kernel_dsink(sinks, delta, lse).sum(0).sum(1)
        return dq, dk, dv, dsinks, None, None


attention = _attention.apply
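
# A minimal usage sketch (illustrative shapes only; BHSD layout, fp16 on CUDA):
#   q = torch.randn(1, 8, 512, 64, device="cuda", dtype=torch.float16, requires_grad=True)
#   k = torch.randn(1, 4, 512, 64, device="cuda", dtype=torch.float16, requires_grad=True)
#   v = torch.randn_like(k).requires_grad_()
#   sinks = torch.randn(8, device="cuda", dtype=torch.float16, requires_grad=True)
#   out = attention(q, k, v, sinks, None, 2)  # window_size=None (full attention), groups=2
#   out.backward(torch.randn_like(out))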


# Adapted and optimized from
# https://github.com/openai/gpt-oss/blob/main/gpt_oss/triton/attention.py
def ref_program(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    sinks: torch.Tensor,
    sliding_window: Optional[int] = None,
    dtype: torch.dtype = torch.float16,
) -> torch.Tensor:
    key = key.transpose(1, 2).contiguous()
    value = value.transpose(1, 2).contiguous()
    batch_size, num_keys, num_key_value_heads, head_dim = key.shape
    query = query.transpose(1, 2).contiguous()
    query = query.view(batch_size, query.shape[1], num_key_value_heads, -1, head_dim)
    batch_size, num_queries, num_key_value_heads, num_key_value_groups, head_dim = query.shape

    start_q = num_keys - num_queries
    sm_scale: float = 1.0 / head_dim**0.5

    sinks = sinks.view(1, num_key_value_heads, num_key_value_groups, 1, 1).float()
    key = key.unsqueeze(3)
    value = value.unsqueeze(3)

    pos_keys = torch.arange(num_keys, device=query.device)
    pos_queries = torch.arange(num_queries, device=query.device) + start_q
    mask = pos_keys[None, :] > pos_queries[:, None]
    mask = mask.float().masked_fill(mask, float("-inf"))

    if sliding_window:
        too_old = pos_keys[None, :] < (pos_queries[:, None] - sliding_window + 1)
        mask.masked_fill_(too_old, float("-inf"))

    logits = torch.einsum("bqhmd,bkhmd->bhmqk", query.float(), key.float()) * sm_scale
    logits = logits + mask[None, None, None, :, :]

    logits_max = torch.max(logits, dim=-1, keepdim=True).values
    logits_or_sinks_max = torch.maximum(sinks, logits_max)
    sinks = torch.exp(sinks - logits_or_sinks_max)
    unnormalized_scores = torch.exp(logits - logits_or_sinks_max)
    normalizer = unnormalized_scores.sum(dim=-1, keepdim=True) + sinks
    scores = unnormalized_scores / normalizer
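    # The sink acts as one extra softmax logit per head: it takes probability mass in
    # the normalizer but has no value vector, so it only scales the scores down.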

    output = torch.einsum("bhmqk,bkhmd->bqhmd", scores, value.float())

    output = output.reshape(batch_size, num_queries, num_key_value_heads * num_key_value_groups, head_dim).to(dtype)
    return output.transpose(1, 2).contiguous()


def main(
    BATCH: int = 1,
    H: int = 8,
    N_CTX: int = 512,
    D_HEAD: int = 64,
    groups: int = 2,
    window_size: Optional[int] = None,
    dtype: str = "float16",
):
    torch_dtype = {"float16": torch.float16, "bfloat16": torch.bfloat16}[dtype]
    if window_size is not None:
        print("Using sliding window attention.")
        assert window_size <= N_CTX
        flops_per_matmul = 2.0 * BATCH * H * min(window_size, N_CTX // 2) * N_CTX * D_HEAD  # rough estimate
    else:
        print("Using full attention.")
        flops_per_matmul = 2.0 * BATCH * H * N_CTX * N_CTX * D_HEAD * 0.5
    # The backward kernel performs five GEMMs per tile (qkT, dsT, dV, dK, dQ).
    total_flops = 5 * flops_per_matmul

    Q = torch.randn(BATCH, H, N_CTX, D_HEAD, dtype=torch_dtype, device="cuda").requires_grad_()
    K = torch.randn(BATCH, H // groups, N_CTX, D_HEAD, dtype=torch_dtype, device="cuda").requires_grad_()
    V = torch.randn_like(K).requires_grad_()
    sinks = torch.randn(H, dtype=torch_dtype, device="cuda").requires_grad_()
    dO = torch.randn_like(Q)

    O = attention(Q, K, V, sinks, window_size, groups)
    O.backward(dO, retain_graph=True)
    dQ, Q.grad = Q.grad.clone(), None
    dK, K.grad = K.grad.clone(), None
    dV, V.grad = V.grad.clone(), None
    dsinks, sinks.grad = sinks.grad.clone(), None

    O_ref = ref_program(Q, K, V, sinks, window_size, dtype=torch_dtype)
    O_ref.backward(dO, retain_graph=True)
    dQ_ref, Q.grad = Q.grad.clone(), None
    dK_ref, K.grad = K.grad.clone(), None
    dV_ref, V.grad = V.grad.clone(), None
    dsinks_ref, sinks.grad = sinks.grad.clone(), None

    # Checks
    rtol, atol = {
        "float16": (1e-2, 1e-2),
        "bfloat16": (2e-2, 2e-2),
    }[dtype]
    assert torch.allclose(O, O_ref, rtol=rtol, atol=atol), f"O max err: {(O - O_ref).abs().max()}"
    assert torch.allclose(dV, dV_ref, rtol=rtol, atol=atol), f"dV max err: {(dV - dV_ref).abs().max()}"
    assert torch.allclose(dK, dK_ref, rtol=rtol, atol=atol), f"dK max err: {(dK - dK_ref).abs().max()}"
    assert torch.allclose(dQ, dQ_ref, rtol=rtol, atol=atol), f"dq max err: {(dQ - dQ_ref).abs().max()}"
    assert torch.allclose(dsinks, dsinks_ref, rtol=rtol, atol=atol), f"dsinks max err: {(dsinks - dsinks_ref).abs().max()}"

    print("All checks passed for tilelang kernels.✅")

    # Only benchmark backward here
    def torch_bwd():
        O_ref.backward(dO, retain_graph=True)

    def tl_bwd():
        O.backward(dO, retain_graph=True)

    latency = do_bench(torch_bwd, warmup=500)
    print("torch: {:.2f} ms".format(latency))
    print("torch: {:.2f} TFlops".format(total_flops / latency * 1e-9))
    latency = do_bench(tl_bwd, warmup=500)
    print("tilelang: {:.2f} ms".format(latency))
    print("tilelang: {:.2f} TFlops".format(total_flops / latency * 1e-9))


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--batch", type=int, default=1, help="Batch size")
    parser.add_argument("--h", type=int, default=64, help="Number of heads")
    parser.add_argument("--n_ctx", type=int, default=4096, help="Context size")
    parser.add_argument("--d_head", type=int, default=128, help="Head dimension")
    parser.add_argument("--groups", type=int, default=8, help="Groups")
    parser.add_argument("--window_size", type=int, default=None, help="window size (default: None, which means full attention)")
    parser.add_argument("--dtype", type=str, default="float16", help="dtype, can be float16 or bfloat16")
    args = parser.parse_args()
    main(args.batch, args.h, args.n_ctx, args.d_head, args.groups, args.window_size, args.dtype)