import torch
import torch.nn.functional as F
import tilelang
import tilelang.language as T
import argparse


@tilelang.jit(
    out_idx=[3, 4],
    pass_configs={
        tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
    },
)
def flashattn_fwd(batch, heads, seq_len, dim_qk, dim_v, is_causal, block_M, block_N, groups=1):
    scale = (1.0 / dim_qk) ** 0.5 * 1.44269504  # log2(e)
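    # The softmax below is evaluated with exp2 rather than exp: exp(x) == exp2(x * log2(e)),
    # so folding log2(e) into `scale` lets every exponential in the kernel use the exp2 intrinsic.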
    head_kv = heads // groups
    q_shape = [batch, seq_len, heads, dim_qk]
    k_shape = [batch, seq_len, head_kv, dim_qk]
    v_shape = [batch, seq_len, head_kv, dim_v]
    dtype = "float16"
    accum_dtype = "float"

    @T.prim_func
    def flash_fwd(
        Q: T.Tensor(q_shape, dtype),  # type: ignore
        K: T.Tensor(k_shape, dtype),  # type: ignore
        V: T.Tensor(v_shape, dtype),  # type: ignore
        Output: T.Tensor([batch, seq_len, heads, dim_v], dtype),  # type: ignore
        lse: T.Tensor([batch, heads, seq_len], accum_dtype),  # type: ignore
    ):
        with T.Kernel(T.ceildiv(seq_len, block_M), heads, batch, threads=256) as (bx, by, bz):
            Q_shared = T.alloc_shared([block_M, dim_qk], dtype)
            K_shared = T.alloc_shared([block_N, dim_qk], dtype)
            V_shared = T.alloc_shared([block_N, dim_v], dtype)
            acc_s = T.alloc_fragment([block_M, block_N], accum_dtype)
            acc_s_cast = T.alloc_fragment([block_M, block_N], dtype)
            acc_o = T.alloc_fragment([block_M, dim_v], accum_dtype)
            scores_max = T.alloc_fragment([block_M], accum_dtype)
            scores_max_prev = T.alloc_fragment([block_M], accum_dtype)
            scores_scale = T.alloc_fragment([block_M], accum_dtype)
            scores_sum = T.alloc_fragment([block_M], accum_dtype)
            logsum = T.alloc_fragment([block_M], accum_dtype)

            T.annotate_layout({Q_shared: tilelang.layout.make_swizzled_layout(Q_shared)})
            T.copy(Q[bz, bx * block_M : (bx + 1) * block_M, by, :], Q_shared)
            T.fill(acc_o, 0)
            T.fill(logsum, 0)
            T.fill(scores_max, -T.infinity(accum_dtype))
            loop_range = T.ceildiv((bx + 1) * block_M, block_N) if is_causal else T.ceildiv(seq_len, block_N)
            for k in T.Pipelined(loop_range, num_stages=1):
                T.copy(K[bz, k * block_N : (k + 1) * block_N, by // groups, :], K_shared)
                if is_causal:
                    for i, j in T.Parallel(block_M, block_N):
                        acc_s[i, j] = T.if_then_else(bx * block_M + i >= k * block_N + j, 0, -T.infinity(acc_s.dtype))
                else:
                    for i, j in T.Parallel(block_M, block_N):
                        acc_s[i, j] = T.if_then_else(k * block_N + j >= seq_len, -T.infinity(acc_s.dtype), 0)
                T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
                T.copy(V[bz, k * block_N : (k + 1) * block_N, by // groups, :], V_shared)
                T.copy(scores_max, scores_max_prev)
                T.reduce_max(acc_s, scores_max, dim=1, clear=False)
                for i in T.Parallel(block_M):
                    scores_max[i] = T.max(scores_max[i], scores_max_prev[i])
                for i in T.Parallel(block_M):
                    scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale)
                for i, j in T.Parallel(block_M, dim_v):
                    acc_o[i, j] *= scores_scale[i]
                for i, j in T.Parallel(block_M, block_N):
                    acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale)
                T.copy(acc_s, acc_s_cast)
                T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow)
                T.reduce_sum(acc_s, scores_sum, dim=1)
                for i in T.Parallel(block_M):
                    logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i]
            for i, j in T.Parallel(block_M, dim_v):
                acc_o[i, j] /= logsum[i]
            T.copy(acc_o, Output[bz, bx * block_M : (bx + 1) * block_M, by, :])
            for i in T.Parallel(block_M):
                logsum[i] = T.log2(logsum[i]) + scores_max[i] * scale
            T.copy(logsum, lse[bz, by, bx * block_M : (bx + 1) * block_M])

    return flash_fwd
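
# Usage sketch for calling the forward kernel directly (shapes are illustrative
# assumptions; the autograd wrapper below drives it the same way):
#   fwd = flashattn_fwd(batch=1, heads=32, seq_len=256, dim_qk=192, dim_v=128,
#                       is_causal=True, block_M=128, block_N=64, groups=16)
#   o, lse = fwd(q, k, v)  # q: [B, T, H, dim_qk]; k/v: [B, T, H // groups, dim_qk/dim_v], fp16 on CUDA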


@tilelang.jit(
    out_idx=[2],
    pass_configs={
        tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
    },
)
def flashattn_bwd_preprocess(batch, heads, seq_len, dim_v):
    dtype = "float16"
    accum_dtype = "float"
    shape = [batch, seq_len, heads, dim_v]
    blk = 32

    @T.prim_func
    def flash_bwd_prep(
        O: T.Tensor(shape, dtype),  # type: ignore
        dO: T.Tensor(shape, dtype),  # type: ignore
        Delta: T.Tensor([batch, heads, seq_len], accum_dtype),  # type: ignore
    ):
        with T.Kernel(heads, T.ceildiv(seq_len, blk), batch) as (bx, by, bz):
            o = T.alloc_fragment([blk, blk], dtype)
            do = T.alloc_fragment([blk, blk], dtype)
            acc = T.alloc_fragment([blk, blk], accum_dtype)
            delta = T.alloc_fragment([blk], accum_dtype)
            T.clear(acc)
            for k in range(T.ceildiv(dim_v, blk)):
                T.copy(O[bz, by * blk : (by + 1) * blk, bx, k * blk : (k + 1) * blk], o)
                T.copy(dO[bz, by * blk : (by + 1) * blk, bx, k * blk : (k + 1) * blk], do)
                for i, j in T.Parallel(blk, blk):
                    acc[i, j] += o[i, j] * do[i, j]
            T.reduce_sum(acc, delta, 1)
            T.copy(delta, Delta[bz, bx, by * blk : (by + 1) * blk])

    return flash_bwd_prep
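
# The preprocess kernel computes Delta[b, h, t] = sum_d O[b, t, h, d] * dO[b, t, h, d],
# the row-wise dot product that appears in the dS term of the backward pass.
# A plain PyTorch reference (for illustration only):
#   delta = (O.float() * dO.float()).sum(-1).permute(0, 2, 1)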


def make_dq_layout(dQ):
    # atomicAdd cannot be vectorized, so dq is reordered to match the 8x8 GEMM fragment layout
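    # Worked index example (illustration only): element (b=0, l=10, h=0, d=13) maps to
    # [0, l // 8 = 1, 0, d // 8 = 1, d % 2 = 1, 4 * (l % 8) + (d % 8) // 2 = 10].
    # The (l, d) plane is tiled into 8x8 blocks; within a tile, 4 * (l % 8) + (d % 8) // 2
    # is the lane that owns the element in the 8x8 MMA accumulator fragment and d % 2
    # selects one of that lane's two values, so the stored order mirrors the fragment layout.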
    return T.Layout(dQ.shape, lambda b, l, h, d: [b, l // 8, h, d // 8, (d % 2), 4 * (l % 8) + (d % 8) // 2])


@tilelang.jit(
    out_idx=[1],
    pass_configs={
        tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
    },
)
def flashattn_bwd_postprocess(batch, heads, seq_len, dim_qk):
    dtype = "float16"
    accum_dtype = "float"
    shape = [batch, seq_len, heads, dim_qk]
    blk = 64

    @T.prim_func
    def flash_bwd_post(
        dQ: T.Tensor(shape, accum_dtype),  # type: ignore
        dQ_out: T.Tensor(shape, dtype),  # type: ignore
    ):
        with T.Kernel(T.ceildiv(seq_len, blk), heads, batch, threads=128) as (bx, by, bz):
            T.annotate_layout({dQ: make_dq_layout(dQ)})
            T.copy(
                dQ[bz, bx * blk : (bx + 1) * blk, by, :],
                dQ_out[bz, bx * blk : (bx + 1) * blk, by, :],
            )

    return flash_bwd_post


@tilelang.jit(
    pass_configs={
        tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
    }
)
def flashattn_bwd_atomic_add(batch, heads, seq_len, dim_qk, dim_v, is_causal, block_M, block_N, threads=256, num_stages=2, groups=1):
    sm_scale = (1.0 / dim_qk) ** 0.5
    scale = (1.0 / dim_qk) ** 0.5 * 1.44269504  # log2(e)
    head_kv = heads // groups
    q_shape = [batch, seq_len, heads, dim_qk]
    k_shape = [batch, seq_len, head_kv, dim_qk]
    v_shape = [batch, seq_len, head_kv, dim_v]
    dtype = "float16"
    accum_dtype = "float"

    @T.prim_func
    def flash_bwd(
        Q: T.Tensor(q_shape, dtype),  # type: ignore
        K: T.Tensor(k_shape, dtype),  # type: ignore
        V: T.Tensor(v_shape, dtype),  # type: ignore
        dO: T.Tensor([batch, seq_len, heads, dim_v], dtype),  # type: ignore
        lse: T.Tensor([batch, heads, seq_len], accum_dtype),  # type: ignore
        Delta: T.Tensor([batch, heads, seq_len], accum_dtype),  # type: ignore
        dQ: T.Tensor(q_shape, accum_dtype),  # type: ignore
        dK: T.Tensor(k_shape, accum_dtype),  # type: ignore
        dV: T.Tensor(v_shape, accum_dtype),  # type: ignore
    ):
        with T.Kernel(heads, T.ceildiv(seq_len, block_M), batch, threads=threads) as (bx, by, bz):
            K_shared = T.alloc_shared([block_M, dim_qk], dtype)
            dsT_shared = T.alloc_shared([block_M, block_N], dtype)
            q = T.alloc_shared([block_N, dim_qk], dtype)
            V_shared = T.alloc_shared([block_M, dim_v], dtype)
            qkT = T.alloc_fragment([block_M, block_N], accum_dtype)
            dsT = T.alloc_fragment([block_M, block_N], accum_dtype)
            qkT_cast = T.alloc_fragment([block_M, block_N], dtype)
            dsT_cast = T.alloc_fragment([block_M, block_N], dtype)
            lse_shared = T.alloc_shared([block_N], accum_dtype)
            delta = T.alloc_shared([block_N], accum_dtype)
            do = T.alloc_shared([block_N, dim_v], dtype)
            dv = T.alloc_fragment([block_M, dim_v], accum_dtype)
            dk = T.alloc_fragment([block_M, dim_qk], accum_dtype)
            dq = T.alloc_fragment([block_N, dim_qk], accum_dtype)
            dk_shared = T.alloc_shared([block_M, dim_qk], accum_dtype)
            dv_shared = T.alloc_shared([block_M, dim_v], accum_dtype)

            T.annotate_layout(
                {
                    dQ: make_dq_layout(dQ),
                    K_shared: tilelang.layout.make_swizzled_layout(K_shared),
                }
            )

            T.copy(K[bz, by * block_M : (by + 1) * block_M, bx // groups, :], K_shared)
            T.copy(V[bz, by * block_M : (by + 1) * block_M, bx // groups, :], V_shared)
            T.clear(dv)
            T.clear(dk)
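            # dK/dV for this K/V tile are accumulated in registers over the query-block loop
            # and flushed once at the end with atomic_add (query heads sharing a KV head when
            # groups > 1 write to the same dK/dV rows). dQ receives contributions from every
            # K/V tile, so it is scattered with atomic_add inside the loop and converted back
            # to fp16 by the postprocess kernel.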
            loop_st = T.floordiv(by * block_M, block_N) if is_causal else 0
            loop_ed = T.ceildiv(seq_len, block_N)
            for k in T.Pipelined(loop_st, loop_ed, num_stages=num_stages):
                T.copy(Q[bz, k * block_N : (k + 1) * block_N, bx, :], q)
                T.clear(qkT)
                T.gemm(K_shared, q, qkT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
                T.copy(lse[bz, bx, k * block_N : (k + 1) * block_N], lse_shared)
                for i, j in T.Parallel(block_M, block_N):
                    qkT[i, j] = T.exp2(qkT[i, j] * scale - lse_shared[j])
                if is_causal:
                    for i, j in T.Parallel(block_M, block_N):
                        qkT[i, j] = T.if_then_else(by * block_M + i <= k * block_N + j, qkT[i, j], 0)
                T.copy(dO[bz, k * block_N : (k + 1) * block_N, bx, :], do)
                T.clear(dsT)
                T.gemm(V_shared, do, dsT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
                T.copy(qkT, qkT_cast)
                T.gemm(qkT_cast, do, dv, policy=T.GemmWarpPolicy.FullRow)

                T.copy(Delta[bz, bx, k * block_N : (k + 1) * block_N], delta)

                for i, j in T.Parallel(block_M, block_N):
                    dsT_cast[i, j] = qkT[i, j] * (dsT[i, j] - delta[j]) * sm_scale
                T.gemm(dsT_cast, q, dk, policy=T.GemmWarpPolicy.FullRow)

                T.copy(dsT_cast, dsT_shared)
                T.clear(dq)
                T.gemm(dsT_shared, K_shared, dq, transpose_A=True)
                for i, j in T.Parallel(block_N, dim_qk):
                    T.atomic_add(dQ[bz, k * block_N + i, bx, j], dq[i, j])
            T.copy(dv, dv_shared)
            T.atomic_add(dV[bz, by * block_M : (by + 1) * block_M, bx // groups, :], dv_shared)
            T.copy(dk, dk_shared)
            T.atomic_add(dK[bz, by * block_M : (by + 1) * block_M, bx // groups, :], dk_shared)

    return flash_bwd


@tilelang.jit(
    pass_configs={
        tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
    }
)
def flashattn_bwd_split(batch, heads, seq_len, dim_qk, dim_v, is_causal, block_M, block_N, threads=256, num_stages=2, groups=1):
    sm_scale = (1.0 / dim_qk) ** 0.5
    scale = (1.0 / dim_qk) ** 0.5 * 1.44269504  # log2(e)
    head_kv = heads // groups
    q_shape = [batch, seq_len, heads, dim_qk]
    k_shape = [batch, seq_len, head_kv, dim_qk]
    v_shape = [batch, seq_len, head_kv, dim_v]
    dk_shape = [groups, batch, seq_len, head_kv, dim_qk]  # sum after kernel
    dv_shape = [groups, batch, seq_len, head_kv, dim_v]  # sum after kernel
    dtype = "float16"
    accum_dtype = "float"

    @T.prim_func
    def flash_bwd(
        Q: T.Tensor(q_shape, dtype),  # type: ignore
        K: T.Tensor(k_shape, dtype),  # type: ignore
        V: T.Tensor(v_shape, dtype),  # type: ignore
        dO: T.Tensor([batch, seq_len, heads, dim_v], dtype),  # type: ignore
        lse: T.Tensor([batch, heads, seq_len], accum_dtype),  # type: ignore
        Delta: T.Tensor([batch, heads, seq_len], accum_dtype),  # type: ignore
        dQ: T.Tensor(q_shape, accum_dtype),  # type: ignore
        dK: T.Tensor(dk_shape, dtype),  # type: ignore
        dV: T.Tensor(dv_shape, dtype),  # type: ignore
    ):
        with T.Kernel(heads, T.ceildiv(seq_len, block_M), batch, threads=threads) as (bx, by, bz):
            K_shared = T.alloc_shared([block_M, dim_qk], dtype)
            dsT_shared = T.alloc_shared([block_M, block_N], dtype)
            q = T.alloc_shared([block_N, dim_qk], dtype)
            V_shared = T.alloc_shared([block_M, dim_v], dtype)
            qkT = T.alloc_fragment([block_M, block_N], accum_dtype)
            dsT = T.alloc_fragment([block_M, block_N], accum_dtype)
            qkT_cast = T.alloc_fragment([block_M, block_N], dtype)
            dsT_cast = T.alloc_fragment([block_M, block_N], dtype)
            lse_shared = T.alloc_shared([block_N], accum_dtype)
            delta = T.alloc_shared([block_N], accum_dtype)
            do = T.alloc_shared([block_N, dim_v], dtype)
            dv = T.alloc_fragment([block_M, dim_v], accum_dtype)
            dk = T.alloc_fragment([block_M, dim_qk], accum_dtype)
            dq = T.alloc_fragment([block_N, dim_qk], accum_dtype)
            dv_shared = T.alloc_shared([block_M, dim_v], dtype)
            dk_shared = T.alloc_shared([block_M, dim_qk], dtype)

            T.annotate_layout(
                {
                    dQ: make_dq_layout(dQ),
                    K_shared: tilelang.layout.make_swizzled_layout(K_shared),
                    dv_shared: tilelang.layout.make_swizzled_layout(dv_shared),
                    dk_shared: tilelang.layout.make_swizzled_layout(dk_shared),
                }
            )

            T.copy(K[bz, by * block_M : (by + 1) * block_M, bx // groups, :], K_shared)
            T.copy(V[bz, by * block_M : (by + 1) * block_M, bx // groups, :], V_shared)
            T.clear(dv)
            T.clear(dk)
            loop_st = T.floordiv(by * block_M, block_N) if is_causal else 0
            loop_ed = T.ceildiv(seq_len, block_N)
            for k in T.Pipelined(loop_st, loop_ed, num_stages=num_stages):
                T.copy(Q[bz, k * block_N : (k + 1) * block_N, bx, :], q)
                T.clear(qkT)
                T.gemm(K_shared, q, qkT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
                T.copy(dO[bz, k * block_N : (k + 1) * block_N, bx, :], do)
                T.clear(dsT)
                T.gemm(V_shared, do, dsT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
                T.copy(lse[bz, bx, k * block_N : (k + 1) * block_N], lse_shared)
                for i, j in T.Parallel(block_M, block_N):
                    qkT[i, j] = T.exp2(qkT[i, j] * scale - lse_shared[j])
                if is_causal:
                    for i, j in T.Parallel(block_M, block_N):
                        qkT[i, j] = T.if_then_else(by * block_M + i <= k * block_N + j, qkT[i, j], 0)
                T.copy(qkT, qkT_cast)
                T.gemm(qkT_cast, do, dv, policy=T.GemmWarpPolicy.FullRow)

                T.copy(Delta[bz, bx, k * block_N : (k + 1) * block_N], delta)

                for i, j in T.Parallel(block_M, block_N):
                    dsT_cast[i, j] = qkT[i, j] * (dsT[i, j] - delta[j]) * sm_scale
                T.gemm(dsT_cast, q, dk, policy=T.GemmWarpPolicy.FullRow)

                T.copy(dsT_cast, dsT_shared)
                T.clear(dq)
                T.gemm(dsT_shared, K_shared, dq, transpose_A=True)
                for i, j in T.Parallel(block_N, dim_qk):
                    T.atomic_add(dQ[bz, k * block_N + i, bx, j], dq[i, j])

            T.copy(dv, dv_shared)
            T.copy(dv_shared, dV[bx % groups, bz, by * block_M : (by + 1) * block_M, bx // groups, :])
            T.copy(dk, dk_shared)
            T.copy(dk_shared, dK[bx % groups, bz, by * block_M : (by + 1) * block_M, bx // groups, :])

    return flash_bwd


@torch.compile
class _attention(torch.autograd.Function):
    @staticmethod
    def forward(ctx, q, k, v, causal, groups=1, use_atomic=True):
        BATCH, N_CTX, H, D_HEAD_QK = q.shape
        D_HEAD_V = v.shape[-1]
        block_M = 128
        block_N = 64
        mod = flashattn_fwd(BATCH, H, N_CTX, D_HEAD_QK, D_HEAD_V, causal, block_M, block_N, groups)
        o, lse = mod(q, k, v)
        ctx.save_for_backward(q, k, v, o, lse)
        ctx.causal = causal
        ctx.use_atomic = use_atomic
        return o

    @staticmethod
    def backward(ctx, do):
        q, k, v, o, lse = ctx.saved_tensors
        BATCH, N_CTX, H, D_HEAD_QK = q.shape
        HEAD_KV, D_HEAD_V = v.shape[-2], v.shape[-1]
        groups = H // HEAD_KV

        def maybe_contiguous(x):
            if x.stride(-1) != 1:
                return x.contiguous()
            return x

        do, q, k, v, o = [maybe_contiguous(x) for x in (do, q, k, v, o)]
        block_M = 128
        block_N = 32
        mod_prep = flashattn_bwd_preprocess(BATCH, H, N_CTX, D_HEAD_V)
        mod_post = flashattn_bwd_postprocess(BATCH, H, N_CTX, D_HEAD_QK)
        delta = mod_prep(o, do)

        if ctx.use_atomic:
            kernel = flashattn_bwd_atomic_add(
                BATCH, H, N_CTX, D_HEAD_QK, D_HEAD_V, ctx.causal, block_M, block_N, threads=256, num_stages=2, groups=groups
            )
            shape_q = [BATCH, N_CTX, H, D_HEAD_QK]
            shape_k = [BATCH, N_CTX, HEAD_KV, D_HEAD_QK]
            shape_v = [BATCH, N_CTX, HEAD_KV, D_HEAD_V]
            dq = torch.zeros(shape_q, dtype=torch.float32, device=q.device)
            dk = torch.zeros(shape_k, dtype=torch.float32, device=q.device)
            dv = torch.zeros(shape_v, dtype=torch.float32, device=q.device)
            kernel(q, k, v, do, lse, delta, dq, dk, dv)
            dq = mod_post(dq)
            dk = dk.to(torch.float16)
            dv = dv.to(torch.float16)
        else:
            kernel = flashattn_bwd_split(
                BATCH, H, N_CTX, D_HEAD_QK, D_HEAD_V, ctx.causal, block_M, block_N, threads=256, num_stages=2, groups=groups
            )
            shape_q = [BATCH, N_CTX, H, D_HEAD_QK]
            shape_k = [groups, BATCH, N_CTX, HEAD_KV, D_HEAD_QK]  # sum after kernel
            shape_v = [groups, BATCH, N_CTX, HEAD_KV, D_HEAD_V]  # sum after kernel
            dq = torch.zeros(shape_q, dtype=torch.float32, device=q.device)
            dk = torch.empty(shape_k, dtype=torch.float16, device=q.device)
            dv = torch.empty(shape_v, dtype=torch.float16, device=q.device)
            kernel(q, k, v, do, lse, delta, dq, dk, dv)
            dq = mod_post(dq)
            dk, dv = dk.sum(0), dv.sum(0)

        return dq, dk, dv, None, None, None


attention = _attention.apply
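
# Minimal usage sketch (shapes are illustrative assumptions; see main() below):
#   q = torch.randn(1, 256, 32, 192, dtype=torch.half, device="cuda", requires_grad=True)
#   k = torch.randn(1, 256, 2, 192, dtype=torch.half, device="cuda", requires_grad=True)
#   v = torch.randn(1, 256, 2, 128, dtype=torch.half, device="cuda", requires_grad=True)
#   out = attention(q, k, v, True, 16, True)  # (q, k, v, causal, groups, use_atomic)
#   out.backward(torch.randn_like(out))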


def ref_program(Q, K, V, is_causal, groups=1):
    # Q: [B, T, HQ, D_QK]
    # K: [B, T, HK, D_QK]
    # V: [B, T, HV, D_V]
    # HQ = HK * groups (HK == HV)
    assert Q.size(2) == K.size(2) * groups, f"Q.size(2): {Q.size(2)}, K.size(2): {K.size(2)}, groups: {groups}"
    assert Q.size(2) == V.size(2) * groups, f"Q.size(2): {Q.size(2)}, V.size(2): {V.size(2)}, groups: {groups}"

    dim_qk = Q.size(-1)
    K = K.repeat_interleave(groups, dim=2)
    V = V.repeat_interleave(groups, dim=2)
    scores = torch.einsum("bqhd,bkhd->bhqk", Q, K)
    scores = scores / torch.sqrt(torch.tensor(dim_qk, dtype=scores.dtype))
    if is_causal:
        seq_len = Q.size(1)
        mask = torch.tril(torch.ones(seq_len, seq_len, device=scores.device))
        mask = mask.unsqueeze(0).unsqueeze(0)
        scores = scores.masked_fill(mask == 0, float("-inf"))
    attention_weights = F.softmax(scores, dim=-1)
    output = torch.einsum("bhqk,bkhd->bqhd", attention_weights, V)
    return output
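
# An equivalent check could use PyTorch's fused kernel instead of the explicit einsum
# (a sketch; assumes a PyTorch build whose scaled_dot_product_attention supports enable_gqa):
#   out = F.scaled_dot_product_attention(
#       Q.transpose(1, 2), K.transpose(1, 2), V.transpose(1, 2),
#       is_causal=is_causal, enable_gqa=True,
#   ).transpose(1, 2)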


def main(
    BATCH: int = 1,
    H: int = 32,
    N_CTX: int = 256,
    D_HEAD_QK: int = 192,
    D_HEAD_V: int = 128,
    groups: int = 16,
    causal: bool = False,
    use_atomic: bool = True,
):
    flops_per_qk = 2.0 * BATCH * H * N_CTX * N_CTX * D_HEAD_QK
    flops_per_v = 2.0 * BATCH * H * N_CTX * N_CTX * D_HEAD_V
    total_flops = 3 * flops_per_qk + 2 * flops_per_v
    if causal:
        total_flops *= 0.5
    Q = torch.empty(BATCH, N_CTX, H, D_HEAD_QK, dtype=torch.half, device="cuda").normal_().requires_grad_()

    head_kv = H // groups
    K = torch.empty(BATCH, N_CTX, head_kv, D_HEAD_QK, dtype=torch.half, device="cuda").normal_().requires_grad_()
    V = torch.empty(BATCH, N_CTX, head_kv, D_HEAD_V, dtype=torch.half, device="cuda").normal_().requires_grad_()
    dO = torch.empty(BATCH, N_CTX, H, D_HEAD_V, dtype=torch.half, device="cuda").normal_().requires_grad_()
    O = attention(Q, K, V, causal, groups, use_atomic)
    O.backward(dO, retain_graph=True)
    dQ, Q.grad = Q.grad.clone(), None
    dK, K.grad = K.grad.clone(), None
    dV, V.grad = V.grad.clone(), None

    O_ref = ref_program(Q, K, V, causal, groups)
    O_ref.backward(dO, retain_graph=True)
    dQ_ref, Q.grad = Q.grad.clone(), None
    dK_ref, K.grad = K.grad.clone(), None
    dV_ref, V.grad = V.grad.clone(), None

    torch.testing.assert_close(O, O_ref, rtol=1e-2, atol=1e-2)
    torch.testing.assert_close(dV, dV_ref, rtol=1e-2, atol=1e-2)
    torch.testing.assert_close(dK, dK_ref, rtol=1e-2, atol=1e-2)
    torch.testing.assert_close(dQ, dQ_ref, rtol=1e-2, atol=1e-2)
    print("All checks passed.✅")

    def run():
        O_ref.backward(dO, retain_graph=True)

    def run1():
        O.backward(dO, retain_graph=True)

    from tilelang.profiler import do_bench

    latency = do_bench(run, warmup=500)
    print("torch: {:.2f} ms".format(latency))
    print("torch: {:.2f} TFlops".format(total_flops / latency * 1e-9))
    latency = do_bench(run1, warmup=500)
    print("tilelang: {:.2f} ms".format(latency))
    print("tilelang: {:.2f} TFlops".format(total_flops / latency * 1e-9))


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--batch", type=int, default=8, help="Batch size")
    parser.add_argument("--h", type=int, default=32, help="Number of heads")
    parser.add_argument("--n_ctx", type=int, default=1024, help="Context size")
    parser.add_argument("--d_head_qk", type=int, default=192, help="Head dimension for Q/K")
    parser.add_argument("--d_head_v", type=int, default=128, help="Head dimension for V")
    parser.add_argument("--causal", action="store_true", help="Causal flag")
    parser.add_argument("--groups", type=int, default=16, help="Number of query heads per KV head (GQA groups)")
    parser.add_argument("--use_atomic", action="store_true", default=False, help="Use atomic add for dK/dV")
    parser.add_argument("--use_split", action="store_true", default=False, help="Use split for dK/dV")
    args = parser.parse_args()

    # Select the backward variant: --use_split takes precedence, otherwise default to atomic
    if args.use_split:
        use_atomic = False
    elif args.use_atomic:
        use_atomic = True
    else:
        # Default: use atomic
        use_atomic = True

    main(args.batch, args.h, args.n_ctx, args.d_head_qk, args.d_head_v, args.groups, args.causal, use_atomic)