# example_gqa_bwd_tma_reduce.py
import torch
import torch.nn.functional as F
import tilelang
import tilelang.language as T
from tilelang.contrib import nvcc
import argparse

tilelang.disable_cache()


@tilelang.jit(
    out_idx=[3, 4],
    pass_configs={
        tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
    },
)
def flashattn_fwd(batch, heads, seq_len, dim_qk, dim_v, is_causal, block_M, block_N, groups=1):
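    # The softmax is evaluated with exp2, so the 1/sqrt(dim_qk) scale below is
    # pre-multiplied by log2(e): exp(s / sqrt(d)) == exp2(s * scale).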
    scale = (1.0 / dim_qk) ** 0.5 * 1.44269504  # log2(e)
    head_kv = heads // groups
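    # GQA: every `groups` consecutive query heads share one KV head, so K/V carry
    # heads // groups heads and are indexed with `query-head index // groups` in the kernels.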
    q_shape = [batch, seq_len, heads, dim_qk]
    k_shape = [batch, seq_len, head_kv, dim_qk]
    v_shape = [batch, seq_len, head_kv, dim_v]
    dtype = "float16"
    accum_dtype = "float"

    @T.prim_func
    def flash_fwd(
        Q: T.Tensor(q_shape, dtype),  # type: ignore
        K: T.Tensor(k_shape, dtype),  # type: ignore
        V: T.Tensor(v_shape, dtype),  # type: ignore
        Output: T.Tensor([batch, seq_len, heads, dim_v], dtype),  # type: ignore
        lse: T.Tensor([batch, heads, seq_len], accum_dtype),  # type: ignore
    ):
        with T.Kernel(T.ceildiv(seq_len, block_M), heads, batch, threads=256) as (bx, by, bz):
            Q_shared = T.alloc_shared([block_M, dim_qk], dtype)
            K_shared = T.alloc_shared([block_N, dim_qk], dtype)
            V_shared = T.alloc_shared([block_N, dim_v], dtype)
            acc_s = T.alloc_fragment([block_M, block_N], accum_dtype)
            acc_s_cast = T.alloc_fragment([block_M, block_N], dtype)
            acc_o = T.alloc_fragment([block_M, dim_v], accum_dtype)
            scores_max = T.alloc_fragment([block_M], accum_dtype)
            scores_max_prev = T.alloc_fragment([block_M], accum_dtype)
            scores_scale = T.alloc_fragment([block_M], accum_dtype)
            scores_sum = T.alloc_fragment([block_M], accum_dtype)
            logsum = T.alloc_fragment([block_M], accum_dtype)

            T.annotate_layout({Q_shared: tilelang.layout.make_swizzled_layout(Q_shared)})
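            # The Q tile (swizzled shared-memory layout above) is loaded once per
            # block and reused against every K/V tile in the loop below.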
            T.copy(Q[bz, bx * block_M : (bx + 1) * block_M, by, :], Q_shared)
            T.fill(acc_o, 0)
            T.fill(logsum, 0)
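            # Online-softmax running state: acc_o is the un-normalized output,
            # logsum the running denominator, scores_max the running row maximum.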
            # Warning: in causal/varlen/unaligned-seqlen scenarios, -inf causes undefined behavior in the exp ops,
            # so we initialize with a large negative number instead.
            T.fill(scores_max, T.Cast(accum_dtype, -1e30))
            loop_range = T.ceildiv((bx + 1) * block_M, block_N) if is_causal else T.ceildiv(seq_len, block_N)
            for k in T.Pipelined(loop_range, num_stages=1):
                T.copy(K[bz, k * block_N : (k + 1) * block_N, by // groups, :], K_shared)
                if is_causal:
                    for i, j in T.Parallel(block_M, block_N):
                        acc_s[i, j] = T.if_then_else(bx * block_M + i >= k * block_N + j, 0, T.Cast(accum_dtype, -1e30))
                else:
                    for i, j in T.Parallel(block_M, block_N):
                        acc_s[i, j] = T.if_then_else(k * block_N + j >= seq_len, -T.infinity(acc_s.dtype), 0)
                T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
                T.copy(V[bz, k * block_N : (k + 1) * block_N, by // groups, :], V_shared)
                T.copy(scores_max, scores_max_prev)
                T.reduce_max(acc_s, scores_max, dim=1, clear=False)
                for i in T.Parallel(block_M):
                    scores_max[i] = T.max(scores_max[i], scores_max_prev[i])
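                # Online-softmax update: rescale the previous accumulator by
                # exp2((old_max - new_max) * scale) before adding this block's scores.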
                for i in T.Parallel(block_M):
                    scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale)
                for i, j in T.Parallel(block_M, dim_v):
                    acc_o[i, j] *= scores_scale[i]
                for i, j in T.Parallel(block_M, block_N):
                    acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale)
                T.copy(acc_s, acc_s_cast)
                T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow)
                T.reduce_sum(acc_s, scores_sum, dim=1)
                for i in T.Parallel(block_M):
                    logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i]
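            # After the last K/V block, normalize by the accumulated softmax denominator.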
            for i, j in T.Parallel(block_M, dim_v):
                acc_o[i, j] /= logsum[i]
            T.copy(acc_o, Output[bz, bx * block_M : (bx + 1) * block_M, by, :])
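            # lse is stored in base 2: log2(sum_j exp2(score * scale)); the backward
            # kernels recover P with exp2(score * scale - lse).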
            for i in T.Parallel(block_M):
                logsum[i] = T.log2(logsum[i]) + scores_max[i] * scale
            T.copy(logsum, lse[bz, by, bx * block_M : (bx + 1) * block_M])

    return flash_fwd


@tilelang.jit(
    out_idx=[2],
    pass_configs={
        tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
    },
)
def flashattn_bwd_preprocess(batch, heads, seq_len, dim_v):
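    # Computes Delta[b, h, i] = sum_d O[b, i, h, d] * dO[b, i, h, d], the rowwise
    # O.dO dot product used by the backward kernels in dS = P * (dP - Delta).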
    dtype = "float16"
    accum_dtype = "float"
    shape = [batch, seq_len, heads, dim_v]
    blk = 32

    @T.prim_func
    def flash_bwd_prep(
        O: T.Tensor(shape, dtype),  # type: ignore
        dO: T.Tensor(shape, dtype),  # type: ignore
        Delta: T.Tensor([batch, heads, seq_len], accum_dtype),  # type: ignore
    ):
        with T.Kernel(heads, T.ceildiv(seq_len, blk), batch) as (bx, by, bz):
            o = T.alloc_fragment([blk, blk], dtype)
            do = T.alloc_fragment([blk, blk], dtype)
            acc = T.alloc_fragment([blk, blk], accum_dtype)
            delta = T.alloc_fragment([blk], accum_dtype)
            T.clear(acc)
            for k in range(T.ceildiv(dim_v, blk)):
                T.copy(O[bz, by * blk : (by + 1) * blk, bx, k * blk : (k + 1) * blk], o)
                T.copy(dO[bz, by * blk : (by + 1) * blk, bx, k * blk : (k + 1) * blk], do)
                for i, j in T.Parallel(blk, blk):
                    acc[i, j] += o[i, j] * do[i, j]
            T.reduce_sum(acc, delta, 1)
            T.copy(delta, Delta[bz, bx, by * blk : (by + 1) * blk])

    return flash_bwd_prep


def make_dq_layout(dQ):
    # Remap [b, l, h, d] -> [b, h, l, d] so the gradients can be accumulated with the TMA reduction instruction.
    return T.Layout(dQ.shape, lambda b, l, h, d: [b, h, l, d])


@tilelang.jit(
    out_idx=[3, 4, 5],
    pass_configs={
        tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
    },
)
def flashattn_bwd_postprocess(batch, heads, head_kv, seq_len, dim_qk, dim_v):
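    # Casts the float32 gradient accumulators back to float16, reading them through
    # the [b, h, l, d] layout (make_dq_layout) they were accumulated with.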
    dtype = "float16"
    accum_dtype = "float"
    q_shape = [batch, seq_len, heads, dim_qk]
    k_shape = [batch, seq_len, head_kv, dim_qk]
    v_shape = [batch, seq_len, head_kv, dim_v]
    blk = 64

    @T.prim_func
    def flash_bwd_post(
        dQ: T.Tensor(q_shape, accum_dtype),  # type: ignore
        dK: T.Tensor(k_shape, accum_dtype),  # type: ignore
        dV: T.Tensor(v_shape, accum_dtype),  # type: ignore
        dQ_out: T.Tensor(q_shape, dtype),  # type: ignore
        dK_out: T.Tensor(k_shape, dtype),  # type: ignore
        dV_out: T.Tensor(v_shape, dtype),  # type: ignore
    ):
        with T.Kernel(T.ceildiv(seq_len, blk), heads, batch, threads=128) as (bx, by, bz):
            T.annotate_layout({dQ: make_dq_layout(dQ)})
            T.copy(dQ[bz, bx * blk : (bx + 1) * blk, by, :], dQ_out[bz, bx * blk : (bx + 1) * blk, by, :])
        with T.Kernel(T.ceildiv(seq_len, blk), head_kv, batch, threads=128) as (bx, by, bz):
            T.annotate_layout(
                {
                    dK: make_dq_layout(dK),
                    dV: make_dq_layout(dV),
                }
            )
            T.copy(dK[bz, bx * blk : (bx + 1) * blk, by, :], dK_out[bz, bx * blk : (bx + 1) * blk, by, :])
            T.copy(dV[bz, bx * blk : (bx + 1) * blk, by, :], dV_out[bz, bx * blk : (bx + 1) * blk, by, :])

    return flash_bwd_post


@tilelang.jit(
    pass_configs={
        tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
    }
)
def flashattn_bwd_atomic_add(batch, heads, seq_len, dim_qk, dim_v, is_causal, block_M, block_N, threads=256, num_stages=2, groups=1):
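    # Each CTA owns one block_M-wide K/V tile (index `by`) and loops over Q tiles:
    # dK/dV accumulate in registers and are flushed once per CTA with TMA atomic
    # adds; dQ is reduced into global memory with a TMA atomic add every iteration.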
    sm_scale = (1.0 / dim_qk) ** 0.5
    scale = (1.0 / dim_qk) ** 0.5 * 1.44269504  # log2(e)
    head_kv = heads // groups
    q_shape = [batch, seq_len, heads, dim_qk]
    k_shape = [batch, seq_len, head_kv, dim_qk]
    v_shape = [batch, seq_len, head_kv, dim_v]
    dtype = "float16"
    accum_dtype = "float"

    @T.prim_func
    def flash_bwd(
        Q: T.Tensor(q_shape, dtype),  # type: ignore
        K: T.Tensor(k_shape, dtype),  # type: ignore
        V: T.Tensor(v_shape, dtype),  # type: ignore
        dO: T.Tensor([batch, seq_len, heads, dim_v], dtype),  # type: ignore
        lse: T.Tensor([batch, heads, seq_len], accum_dtype),  # type: ignore
        Delta: T.Tensor([batch, heads, seq_len], accum_dtype),  # type: ignore
        dQ: T.Tensor(q_shape, accum_dtype),  # type: ignore
        dK: T.Tensor(k_shape, accum_dtype),  # type: ignore
        dV: T.Tensor(v_shape, accum_dtype),  # type: ignore
    ):
        with T.Kernel(heads, T.ceildiv(seq_len, block_M), batch, threads=threads) as (bx, by, bz):
            K_shared = T.alloc_shared([block_M, dim_qk], dtype)
            dsT_shared = T.alloc_shared([block_M, block_N], dtype)
            q = T.alloc_shared([block_N, dim_qk], dtype)
            V_shared = T.alloc_shared([block_M, dim_v], dtype)
            qkT = T.alloc_fragment([block_M, block_N], accum_dtype)
            dsT = T.alloc_fragment([block_M, block_N], accum_dtype)
            qkT_cast = T.alloc_fragment([block_M, block_N], dtype)
            dsT_cast = T.alloc_fragment([block_M, block_N], dtype)
            lse_shared = T.alloc_shared([block_N], accum_dtype)
            delta = T.alloc_shared([block_N], accum_dtype)
            do = T.alloc_shared([block_N, dim_v], dtype)
            dv = T.alloc_fragment([block_M, dim_v], accum_dtype)
            dk = T.alloc_fragment([block_M, dim_qk], accum_dtype)
            dq = T.alloc_fragment([block_N, dim_qk], accum_dtype)
            dk_shared = T.alloc_shared([block_M, dim_qk], accum_dtype)
            dv_shared = T.alloc_shared([block_M, dim_v], accum_dtype)
            dq_shared = T.alloc_shared([block_N, dim_qk], accum_dtype)

            T.annotate_layout(
                {
                    dQ: make_dq_layout(dQ),
                    dK: make_dq_layout(dK),
                    dV: make_dq_layout(dV),
                    K_shared: tilelang.layout.make_swizzled_layout(K_shared),
                }
            )

            T.copy(K[bz, by * block_M : (by + 1) * block_M, bx // groups, :], K_shared)
            T.copy(V[bz, by * block_M : (by + 1) * block_M, bx // groups, :], V_shared)
            T.clear(dv)
            T.clear(dk)
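            # Causal masking: this K/V tile (rows from by * block_M) only receives
            # contributions from Q positions at or after that row, so start there.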
            loop_st = T.floordiv(by * block_M, block_N) if is_causal else 0
            loop_ed = T.ceildiv(seq_len, block_N)
            for k in T.Pipelined(loop_st, loop_ed, num_stages=num_stages):
                T.copy(Q[bz, k * block_N : (k + 1) * block_N, bx, :], q)
                T.clear(qkT)
                T.gemm(K_shared, q, qkT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
                T.copy(lse[bz, bx, k * block_N : (k + 1) * block_N], lse_shared)
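                # Recompute P = exp2(S * scale - lse) from the stored base-2
                # log-sum-exp instead of materializing it in the forward pass.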
                for i, j in T.Parallel(block_M, block_N):
                    qkT[i, j] = T.exp2(qkT[i, j] * scale - lse_shared[j])
                if is_causal:
                    for i, j in T.Parallel(block_M, block_N):
                        qkT[i, j] = T.if_then_else(by * block_M + i <= k * block_N + j, qkT[i, j], 0)
                T.copy(dO[bz, k * block_N : (k + 1) * block_N, bx, :], do)
                T.clear(dsT)
                T.gemm(V_shared, do, dsT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
                T.copy(qkT, qkT_cast)
                T.gemm(qkT_cast, do, dv, policy=T.GemmWarpPolicy.FullRow)

                T.copy(Delta[bz, bx, k * block_N : (k + 1) * block_N], delta)
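                # delta[j] = rowsum(O * dO) for query j, the correction term in
                # dS = P * (dP - delta).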

                for i, j in T.Parallel(block_M, block_N):
                    dsT_cast[i, j] = qkT[i, j] * (dsT[i, j] - delta[j]) * sm_scale
                T.gemm(dsT_cast, q, dk, policy=T.GemmWarpPolicy.FullRow)

                T.copy(dsT_cast, dsT_shared)
                T.clear(dq)
                T.gemm(dsT_shared, K_shared, dq, transpose_A=True)
                T.copy(dq, dq_shared)
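                # use_tma=True lowers the atomic add to the TMA reduction path,
                # which is why dQ/dK/dV are annotated with the [b, h, l, d] layout.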
                T.atomic_add(dQ[bz, k * block_N : (k + 1) * block_N, bx, :], dq_shared, use_tma=True)
            T.copy(dv, dv_shared)
            T.atomic_add(dV[bz, by * block_M : (by + 1) * block_M, bx // groups, :], dv_shared, use_tma=True)
            T.copy(dk, dk_shared)
            T.atomic_add(dK[bz, by * block_M : (by + 1) * block_M, bx // groups, :], dk_shared, use_tma=True)

    return flash_bwd


@tilelang.jit(
    pass_configs={
        tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
    }
)
def flashattn_bwd_split_novarlen(batch, heads, seq_len, dim_qk, dim_v, is_causal, block_M, block_N, threads=256, num_stages=2, groups=1):
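    # "Split" variant: instead of atomically reducing dK/dV across the query heads
    # of a group, each query head writes its own slice of a [groups, ...] buffer
    # and the host sums over dim 0 afterwards; dQ still uses atomic adds.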
    sm_scale = (1.0 / dim_qk) ** 0.5
    scale = (1.0 / dim_qk) ** 0.5 * 1.44269504  # log2(e)
    head_kv = heads // groups
    q_shape = [batch, seq_len, heads, dim_qk]
    k_shape = [batch, seq_len, head_kv, dim_qk]
    v_shape = [batch, seq_len, head_kv, dim_v]
    dk_shape = [groups, batch, seq_len, head_kv, dim_qk]  # sum after kernel
    dv_shape = [groups, batch, seq_len, head_kv, dim_v]  # sum after kernel
    dtype = "float16"
    accum_dtype = "float"

    @T.prim_func
    def flash_bwd(
        Q: T.Tensor(q_shape, dtype),  # type: ignore
        K: T.Tensor(k_shape, dtype),  # type: ignore
        V: T.Tensor(v_shape, dtype),  # type: ignore
        dO: T.Tensor([batch, seq_len, heads, dim_v], dtype),  # type: ignore
        lse: T.Tensor([batch, heads, seq_len], accum_dtype),  # type: ignore
        Delta: T.Tensor([batch, heads, seq_len], accum_dtype),  # type: ignore
        dQ: T.Tensor(q_shape, accum_dtype),  # type: ignore
        dK: T.Tensor(dk_shape, dtype),  # type: ignore
        dV: T.Tensor(dv_shape, dtype),  # type: ignore
    ):
        with T.Kernel(heads, T.ceildiv(seq_len, block_M), batch, threads=threads) as (bx, by, bz):
            K_shared = T.alloc_shared([block_M, dim_qk], dtype)
            dsT_shared = T.alloc_shared([block_M, block_N], dtype)
            q = T.alloc_shared([block_N, dim_qk], dtype)
            V_shared = T.alloc_shared([block_M, dim_v], dtype)
            qkT = T.alloc_fragment([block_M, block_N], accum_dtype)
            dsT = T.alloc_fragment([block_M, block_N], accum_dtype)
            qkT_cast = T.alloc_fragment([block_M, block_N], dtype)
            dsT_cast = T.alloc_fragment([block_M, block_N], dtype)
            lse_shared = T.alloc_shared([block_N], accum_dtype)
            delta = T.alloc_shared([block_N], accum_dtype)
            do = T.alloc_shared([block_N, dim_v], dtype)
            dv = T.alloc_fragment([block_M, dim_v], accum_dtype)
            dk = T.alloc_fragment([block_M, dim_qk], accum_dtype)
            dq = T.alloc_fragment([block_N, dim_qk], accum_dtype)
            dv_shared = T.alloc_shared([block_M, dim_v], dtype)
            dk_shared = T.alloc_shared([block_M, dim_qk], dtype)

            T.annotate_layout(
                {
                    dQ: make_dq_layout(dQ),
                    K_shared: tilelang.layout.make_swizzled_layout(K_shared),
                    dv_shared: tilelang.layout.make_swizzled_layout(dv_shared),
                    dk_shared: tilelang.layout.make_swizzled_layout(dk_shared),
                }
            )

            T.copy(K[bz, by * block_M : (by + 1) * block_M, bx // groups, :], K_shared)
            T.copy(V[bz, by * block_M : (by + 1) * block_M, bx // groups, :], V_shared)
            T.clear(dv)
            T.clear(dk)
            loop_st = T.floordiv(by * block_M, block_N) if is_causal else 0
            loop_ed = T.ceildiv(seq_len, block_N)
            for k in T.Pipelined(loop_st, loop_ed, num_stages=num_stages):
                T.copy(Q[bz, k * block_N : (k + 1) * block_N, bx, :], q)
                T.clear(qkT)
                T.gemm(K_shared, q, qkT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
                T.copy(dO[bz, k * block_N : (k + 1) * block_N, bx, :], do)
                T.clear(dsT)
                T.gemm(V_shared, do, dsT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
                T.copy(lse[bz, bx, k * block_N : (k + 1) * block_N], lse_shared)
                for i, j in T.Parallel(block_M, block_N):
                    qkT[i, j] = T.exp2(qkT[i, j] * scale - lse_shared[j])
                if is_causal:
                    for i, j in T.Parallel(block_M, block_N):
                        qkT[i, j] = T.if_then_else(by * block_M + i <= k * block_N + j, qkT[i, j], 0)
                T.copy(qkT, qkT_cast)
                T.gemm(qkT_cast, do, dv, policy=T.GemmWarpPolicy.FullRow)

                T.copy(Delta[bz, bx, k * block_N : (k + 1) * block_N], delta)

                for i, j in T.Parallel(block_M, block_N):
                    dsT_cast[i, j] = qkT[i, j] * (dsT[i, j] - delta[j]) * sm_scale
                T.gemm(dsT_cast, q, dk, policy=T.GemmWarpPolicy.FullRow)

                T.copy(dsT_cast, dsT_shared)
                T.clear(dq)
                T.gemm(dsT_shared, K_shared, dq, transpose_A=True)
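                # dQ is reduced with plain per-element atomic adds here (no TMA
                # reduction in this variant).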
                for i, j in T.Parallel(block_N, dim_qk):
                    T.atomic_add(dQ[bz, k * block_N + i, bx, j], dq[i, j])

            T.copy(dv, dv_shared)
            T.copy(dv_shared, dV[bx % groups, bz, by * block_M : (by + 1) * block_M, bx // groups, :])
            T.copy(dk, dk_shared)
            T.copy(dk_shared, dK[bx % groups, bz, by * block_M : (by + 1) * block_M, bx // groups, :])

    return flash_bwd


@torch.compile
class _attention(torch.autograd.Function):
    @staticmethod
    def forward(ctx, q, k, v, causal, groups=1, use_atomic=True):
        BATCH, N_CTX, H, D_HEAD_QK = q.shape
        D_HEAD_V = v.shape[-1]
        block_M = 128
        block_N = 64
        mod = flashattn_fwd(BATCH, H, N_CTX, D_HEAD_QK, D_HEAD_V, causal, block_M, block_N, groups)
        o, lse = mod(q, k, v)
        ctx.save_for_backward(q, k, v, o, lse)
        ctx.causal = causal
        ctx.use_atomic = use_atomic
        return o

    @staticmethod
    def backward(ctx, do):
        q, k, v, o, lse = ctx.saved_tensors
        BATCH, N_CTX, H, D_HEAD_QK = q.shape
        HEAD_KV, D_HEAD_V = v.shape[-2], v.shape[-1]
        groups = H // HEAD_KV

        def maybe_contiguous(x):
            if x.stride(-1) != 1:
                return x.contiguous()
            return x

        do, q, k, v, o = [maybe_contiguous(x) for x in (do, q, k, v, o)]
        block_M = 128
        block_N = 32
        mod_prep = flashattn_bwd_preprocess(BATCH, H, N_CTX, D_HEAD_V)
        mod_post = flashattn_bwd_postprocess(BATCH, H, HEAD_KV, N_CTX, D_HEAD_QK, D_HEAD_V)
        delta = mod_prep(o, do)
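        # Two dK/dV reduction strategies: float32 buffers filled with TMA atomic adds
        # and cast back by the postprocess kernel, or per-group float16 outputs summed
        # on the host after the split kernel.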

        if ctx.use_atomic:
            kernel = flashattn_bwd_atomic_add(
                BATCH, H, N_CTX, D_HEAD_QK, D_HEAD_V, ctx.causal, block_M, block_N, threads=256, num_stages=2, groups=groups
            )
            shape_q = [BATCH, N_CTX, H, D_HEAD_QK]
            shape_k = [BATCH, N_CTX, HEAD_KV, D_HEAD_QK]
            shape_v = [BATCH, N_CTX, HEAD_KV, D_HEAD_V]
            dq = torch.zeros(shape_q, dtype=torch.float32, device=q.device)
            dk = torch.zeros(shape_k, dtype=torch.float32, device=q.device)
            dv = torch.zeros(shape_v, dtype=torch.float32, device=q.device)
            kernel(q, k, v, do, lse, delta, dq, dk, dv)
            dq, dk, dv = mod_post(dq, dk, dv)
        else:
            kernel = flashattn_bwd_split_novarlen(
                BATCH, H, N_CTX, D_HEAD_QK, D_HEAD_V, ctx.causal, block_M, block_N, threads=256, num_stages=2, groups=groups
            )
            shape_q = [BATCH, N_CTX, H, D_HEAD_QK]
            shape_k = [groups, BATCH, N_CTX, HEAD_KV, D_HEAD_QK]  # sum after kernel
            shape_v = [groups, BATCH, N_CTX, HEAD_KV, D_HEAD_V]  # sum after kernel
            dq = torch.zeros(shape_q, dtype=torch.float32, device=q.device)
            dk = torch.empty(shape_k, dtype=torch.float16, device=q.device)
            dv = torch.empty(shape_v, dtype=torch.float16, device=q.device)
            kernel(q, k, v, do, lse, delta, dq, dk, dv)
            dq, _, _ = mod_post(dq, torch.zeros_like(k, dtype=torch.float32), torch.zeros_like(v, dtype=torch.float32))
            dk, dv = dk.sum(0), dv.sum(0)

        return dq, dk, dv, None, None, None


attention = _attention.apply


def ref_program(Q, K, V, is_causal, groups=1):
    # Q: [B, T, HQ, D_QK]
    # K: [B, T, HK, D_QK]
    # V: [B, T, HV, D_V]
    # HQ = HKV * groups
    assert Q.size(2) == K.size(2) * groups, f"Q.size(2): {Q.size(2)}, K.size(2): {K.size(2)}, groups: {groups}"
    assert Q.size(2) == V.size(2) * groups, f"Q.size(2): {Q.size(2)}, V.size(2): {V.size(2)}, groups: {groups}"

    dim_qk = Q.size(-1)
    K = K.repeat_interleave(groups, dim=2)
    V = V.repeat_interleave(groups, dim=2)
    scores = torch.einsum("bqhd,bkhd->bhqk", Q, K)
    scores = scores / torch.sqrt(torch.tensor(dim_qk, dtype=scores.dtype))
    if is_causal:
        seq_len = Q.size(1)
        mask = torch.tril(torch.ones(seq_len, seq_len, device=scores.device))
        mask = mask.unsqueeze(0).unsqueeze(0)
        scores = scores.masked_fill(mask == 0, float("-inf"))
    attention_weights = F.softmax(scores, dim=-1)
    output = torch.einsum("bhqk,bkhd->bqhd", attention_weights, V)
    return output


def main(
    BATCH: int = 1,
    H: int = 32,
    N_CTX: int = 256,
    D_HEAD_QK: int = 192,
    D_HEAD_V: int = 128,
    groups: int = 16,
    causal: bool = False,
    use_atomic: bool = True,
):
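    # Backward FLOPs: three GEMMs over dim_qk (S recompute, dQ, dK) and two over
    # dim_v (dV, dP), halved for causal attention.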
    flops_per_qk = 2.0 * BATCH * H * N_CTX * N_CTX * D_HEAD_QK
    flops_per_v = 2.0 * BATCH * H * N_CTX * N_CTX * D_HEAD_V
    total_flops = 3 * flops_per_qk + 2 * flops_per_v
    if causal:
        total_flops *= 0.5
    Q = torch.empty(BATCH, N_CTX, H, D_HEAD_QK, dtype=torch.half, device="cuda").normal_().requires_grad_()

    head_kv = H // groups
    K = torch.empty(BATCH, N_CTX, head_kv, D_HEAD_QK, dtype=torch.half, device="cuda").normal_().requires_grad_()
    V = torch.empty(BATCH, N_CTX, head_kv, D_HEAD_V, dtype=torch.half, device="cuda").normal_().requires_grad_()
    dO = torch.empty(BATCH, N_CTX, H, D_HEAD_V, dtype=torch.half, device="cuda").normal_().requires_grad_()
    O = attention(Q, K, V, causal, groups, use_atomic)
    O.backward(dO, retain_graph=True)
    dQ, Q.grad = Q.grad.clone(), None
    dK, K.grad = K.grad.clone(), None
    dV, V.grad = V.grad.clone(), None

    O_ref = ref_program(Q, K, V, causal, groups)
    O_ref.backward(dO, retain_graph=True)
    dQ_ref, Q.grad = Q.grad.clone(), None
    dK_ref, K.grad = K.grad.clone(), None
    dV_ref, V.grad = V.grad.clone(), None

    torch.testing.assert_close(O, O_ref, rtol=1e-2, atol=1e-2)
    torch.testing.assert_close(dV, dV_ref, rtol=1e-2, atol=1e-2)
    torch.testing.assert_close(dK, dK_ref, rtol=1e-2, atol=1e-2)
    torch.testing.assert_close(dQ, dQ_ref, rtol=1e-2, atol=1e-2)
    print("All checks passed.✅")

    def run():
        O_ref.backward(dO, retain_graph=True)

    def run1():
        O.backward(dO, retain_graph=True)

    from tilelang.profiler import do_bench

    latency = do_bench(run, warmup=500)
    print("torch: {:.2f} ms".format(latency))
    print("torch: {:.2f} TFlops".format(total_flops / latency * 1e-9))
    latency = do_bench(run1, warmup=500)
    print("tilelang: {:.2f} ms".format(latency))
    print("tilelang: {:.2f} TFlops".format(total_flops / latency * 1e-9))


if __name__ == "__main__":
    arch = nvcc.get_target_compute_version()
    print(f"Detected GPU compute capability: {arch}")
    assert float(arch) >= 9.0, "This example only supports GPU with compute capability >= 9.0"
    parser = argparse.ArgumentParser()
    parser.add_argument("--batch", type=int, default=8, help="Batch size")
    parser.add_argument("--h", type=int, default=32, help="Number of heads")
    parser.add_argument("--n_ctx", type=int, default=1024, help="Context size")
    parser.add_argument("--d_head_qk", type=int, default=192, help="Head dimension for Q/K")
    parser.add_argument("--d_head_v", type=int, default=128, help="Head dimension for V")
    parser.add_argument("--causal", action="store_true", help="Causal flag")
    parser.add_argument("--groups", type=int, default=16, help="Number of query heads per KV head (GQA groups)")
    parser.add_argument("--use_atomic", action="store_true", default=False, help="Use atomic add for dK/dV")
    parser.add_argument("--use_split", action="store_true", default=False, help="Use split for dK/dV")
    args = parser.parse_args()

    # Handle backward compatibility and logic
    if args.use_split:
        use_atomic = False
    elif args.use_atomic:
        use_atomic = True
    else:
        # Default: use atomic
        use_atomic = True

    main(args.batch, args.h, args.n_ctx, args.d_head_qk, args.d_head_v, args.groups, args.causal, use_atomic)