# example_gqa_bwd_wgmma_pipelined.py
import torch
import torch.nn.functional as F
import tilelang
import tilelang.language as T
import argparse


@tilelang.jit(
    out_idx=[3, 4],
    pass_configs={
        tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
    },
)
def flashattn_fwd(batch, heads, seq_len, dim_qk, dim_v, is_causal, block_M, block_N, groups=1):
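    # Forward FlashAttention kernel builder (GQA: every `groups` query heads share
    # one KV head). The softmax scale folds in log2(e) so the kernel can work in
    # base 2 (exp2/log2), which maps to cheaper hardware instructions.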
    scale = (1.0 / dim_qk) ** 0.5 * 1.44269504  # log2(e)
    head_kv = heads // groups
    q_shape = [batch, seq_len, heads, dim_qk]
    k_shape = [batch, seq_len, head_kv, dim_qk]
    v_shape = [batch, seq_len, head_kv, dim_v]
    dtype = T.float16
    accum_dtype = T.float32

    @T.prim_func
    def flash_fwd(
        Q: T.Tensor(q_shape, dtype),  # type: ignore
        K: T.Tensor(k_shape, dtype),  # type: ignore
        V: T.Tensor(v_shape, dtype),  # type: ignore
        Output: T.Tensor([batch, seq_len, heads, dim_v], dtype),  # type: ignore
        lse: T.Tensor([batch, heads, seq_len], accum_dtype),  # type: ignore
    ):
        with T.Kernel(T.ceildiv(seq_len, block_M), heads, batch, threads=256) as (bx, by, bz):
            Q_shared = T.alloc_shared([block_M, dim_qk], dtype)
            K_shared = T.alloc_shared([block_N, dim_qk], dtype)
            V_shared = T.alloc_shared([block_N, dim_v], dtype)
            acc_s = T.alloc_fragment([block_M, block_N], accum_dtype)
            acc_s_cast = T.alloc_fragment([block_M, block_N], dtype)
            acc_o = T.alloc_fragment([block_M, dim_v], accum_dtype)
            scores_max = T.alloc_fragment([block_M], accum_dtype)
            scores_max_prev = T.alloc_fragment([block_M], accum_dtype)
            scores_scale = T.alloc_fragment([block_M], accum_dtype)
            scores_sum = T.alloc_fragment([block_M], accum_dtype)
            logsum = T.alloc_fragment([block_M], accum_dtype)

            T.annotate_layout({Q_shared: tilelang.layout.make_swizzled_layout(Q_shared)})
            T.copy(Q[bz, bx * block_M : (bx + 1) * block_M, by, :], Q_shared)
            T.fill(acc_o, 0)
            T.fill(logsum, 0)
            T.fill(scores_max, -T.infinity(accum_dtype))
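            # With causal masking, only KV tiles at or before this query tile's
            # diagonal can contribute, so the loop stops early.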
            loop_range = T.ceildiv((bx + 1) * block_M, block_N) if is_causal else T.ceildiv(seq_len, block_N)
            for k in T.Pipelined(loop_range, num_stages=1):
                T.copy(K[bz, k * block_N : (k + 1) * block_N, by // groups, :], K_shared)
                if is_causal:
                    for i, j in T.Parallel(block_M, block_N):
                        acc_s[i, j] = T.if_then_else(bx * block_M + i >= k * block_N + j, 0, -T.infinity(acc_s.dtype))
                else:
                    for i, j in T.Parallel(block_M, block_N):
                        acc_s[i, j] = T.if_then_else(k * block_N + j >= seq_len, -T.infinity(acc_s.dtype), 0)
                T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
                T.copy(V[bz, k * block_N : (k + 1) * block_N, by // groups, :], V_shared)
                T.copy(scores_max, scores_max_prev)
                T.reduce_max(acc_s, scores_max, dim=1, clear=False)
                for i in T.Parallel(block_M):
                    scores_max[i] = T.max(scores_max[i], scores_max_prev[i])
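                # Online-softmax update: rescale the running output and row sums by
                # exp2((m_prev - m_new) * scale) before folding in this KV tile.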
                for i in T.Parallel(block_M):
                    scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale)
                for i, j in T.Parallel(block_M, dim_v):
                    acc_o[i, j] *= scores_scale[i]
                for i, j in T.Parallel(block_M, block_N):
                    acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale)
                T.copy(acc_s, acc_s_cast)
                T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow)
                T.reduce_sum(acc_s, scores_sum, dim=1)
                for i in T.Parallel(block_M):
                    logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i]
            for i, j in T.Parallel(block_M, dim_v):
                acc_o[i, j] /= logsum[i]
            T.copy(acc_o, Output[bz, bx * block_M : (bx + 1) * block_M, by, :])
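            # Store the per-row log-sum-exp in base 2; the backward kernel uses it to
            # reconstruct the softmax probabilities without renormalizing.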
            for i in T.Parallel(block_M):
                logsum[i] = T.log2(logsum[i]) + scores_max[i] * scale
            T.copy(logsum, lse[bz, by, bx * block_M : (bx + 1) * block_M])

    return flash_fwd


@tilelang.jit(
    out_idx=[2],
    pass_configs={
        tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
    },
)
def flashattn_bwd_preprocess(batch, heads, seq_len, dim_v):
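    # Backward preprocess: Delta[b, h, i] = sum_d O[b, i, h, d] * dO[b, i, h, d],
    # the row-wise correction term needed when forming dS.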
    dtype = T.float16
    accum_dtype = T.float32
    shape = [batch, seq_len, heads, dim_v]
    blk = 32

    @T.prim_func
    def flash_bwd_prep(
        O: T.Tensor(shape, dtype),  # type: ignore
        dO: T.Tensor(shape, dtype),  # type: ignore
        Delta: T.Tensor([batch, heads, seq_len], accum_dtype),  # type: ignore
    ):
        with T.Kernel(heads, T.ceildiv(seq_len, blk), batch) as (bx, by, bz):
            o = T.alloc_fragment([blk, blk], dtype)
            do = T.alloc_fragment([blk, blk], dtype)
            acc = T.alloc_fragment([blk, blk], accum_dtype)
            delta = T.alloc_fragment([blk], accum_dtype)
            T.clear(acc)
            for k in range(T.ceildiv(dim_v, blk)):
                T.copy(O[bz, by * blk : (by + 1) * blk, bx, k * blk : (k + 1) * blk], o)
                T.copy(dO[bz, by * blk : (by + 1) * blk, bx, k * blk : (k + 1) * blk], do)
                for i, j in T.Parallel(blk, blk):
                    acc[i, j] += o[i, j] * do[i, j]
            T.reduce_sum(acc, delta, 1)
            T.copy(delta, Delta[bz, bx, by * blk : (by + 1) * blk])

    return flash_bwd_prep


@tilelang.jit(
    pass_configs={
        tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
    }
)
def flashattn_bwd(batch, heads, seq_len, dim_qk, dim_v, is_causal, block_M, block_N, threads=256, num_stages=2, groups=1):
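    # Backward kernel: each thread block owns one block_M-wide K/V tile for one query
    # head, loops over block_N-wide query tiles, accumulates dK/dV in registers, and
    # scatters dQ (and finally dK/dV) into global memory with atomic adds.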
    sm_scale = (1.0 / dim_qk) ** 0.5
    scale = (1.0 / dim_qk) ** 0.5 * 1.44269504  # log2(e)
    head_kv = heads // groups
    q_shape = [batch, seq_len, heads, dim_qk]
    k_shape = [batch, seq_len, head_kv, dim_qk]
    v_shape = [batch, seq_len, head_kv, dim_v]
    dtype = T.float16
    accum_dtype = T.float32

    @T.prim_func
    def flash_bwd(
        Q: T.Tensor(q_shape, dtype),  # type: ignore
        K: T.Tensor(k_shape, dtype),  # type: ignore
        V: T.Tensor(v_shape, dtype),  # type: ignore
        dO: T.Tensor([batch, seq_len, heads, dim_v], dtype),  # type: ignore
        lse: T.Tensor([batch, heads, seq_len], accum_dtype),  # type: ignore
        Delta: T.Tensor([batch, heads, seq_len], accum_dtype),  # type: ignore
        dQ: T.Tensor(q_shape, accum_dtype),  # type: ignore
        dK: T.Tensor(k_shape, accum_dtype),  # type: ignore
        dV: T.Tensor(v_shape, accum_dtype),  # type: ignore
    ):
        with T.Kernel(heads, T.ceildiv(seq_len, block_M), batch, threads=threads) as (bx, by, bz):
            K_shared = T.alloc_shared([block_M, dim_qk], dtype)
            dsT_shared = T.alloc_shared([block_M, block_N], dtype)
            q = T.alloc_shared([block_N, dim_qk], dtype)
            V_shared = T.alloc_shared([block_M, dim_v], dtype)
            qkT = T.alloc_fragment([block_M, block_N], accum_dtype)
            dsT = T.alloc_fragment([block_M, block_N], accum_dtype)
            qkT_cast = T.alloc_fragment([block_M, block_N], dtype)
            dsT_cast = T.alloc_fragment([block_M, block_N], dtype)
            lse_shared = T.alloc_shared([block_N], accum_dtype)
            delta = T.alloc_shared([block_N], accum_dtype)
            do = T.alloc_shared([block_N, dim_v], dtype)
            dv = T.alloc_fragment([block_M, dim_v], accum_dtype)
            dk = T.alloc_fragment([block_M, dim_qk], accum_dtype)
            dq = T.alloc_fragment([block_N, dim_qk], accum_dtype)
            dk_shared = T.alloc_shared([block_M, dim_qk], accum_dtype)
            dv_shared = T.alloc_shared([block_M, dim_v], accum_dtype)
            dq_shared = T.alloc_shared([block_N, dim_qk], accum_dtype)

            T.annotate_layout(
                {
                    K_shared: tilelang.layout.make_swizzled_layout(K_shared),
                    dq_shared: tilelang.layout.make_swizzled_layout(dq_shared),
                    dk_shared: tilelang.layout.make_swizzled_layout(dk_shared),
                    dv_shared: tilelang.layout.make_swizzled_layout(dv_shared),
                }
            )

            T.copy(K[bz, by * block_M : (by + 1) * block_M, bx // groups, :], K_shared)
            T.copy(V[bz, by * block_M : (by + 1) * block_M, bx // groups, :], V_shared)
            T.clear(dv)
            T.clear(dk)
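            # With causal masking, query tiles that end before this KV tile begins are
            # fully masked, so start the loop at the diagonal.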
            loop_st = T.floordiv(by * block_M, block_N) if is_causal else 0
            loop_ed = T.ceildiv(seq_len, block_N)
            for k in T.Pipelined(loop_st, loop_ed, num_stages=num_stages):
                T.copy(Q[bz, k * block_N : (k + 1) * block_N, bx, :], q)
                T.clear(qkT)
                T.gemm(K_shared, q, qkT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow, wg_wait=-1)
                T.copy(lse[bz, bx, k * block_N : (k + 1) * block_N], lse_shared)
                for i, j in T.Parallel(block_M, block_N):
                    qkT[i, j] = T.exp2(qkT[i, j] * scale - lse_shared[j])
                if is_causal:
                    for i, j in T.Parallel(block_M, block_N):
                        qkT[i, j] = T.if_then_else(by * block_M + i <= k * block_N + j, qkT[i, j], 0)
                T.copy(dO[bz, k * block_N : (k + 1) * block_N, bx, :], do)
                T.clear(dsT)
                T.gemm(V_shared, do, dsT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow, wg_wait=-1)
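                # The gemms above were issued asynchronously (wg_wait=-1); allow at most
                # one warp-group MMA to remain in flight so that qkT from the first gemm
                # is ready to be read.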
                T.wait_wgmma(1)
                T.copy(qkT, qkT_cast)
                T.gemm(qkT_cast, do, dv, policy=T.GemmWarpPolicy.FullRow, wg_wait=-1)

                T.copy(Delta[bz, bx, k * block_N : (k + 1) * block_N], delta)

                # dS (transposed): P^T * (dP^T - Delta) * softmax scale, matching the
                # [KV rows, query cols] layout used throughout this kernel.
                for i, j in T.Parallel(block_M, block_N):
                    dsT_cast[i, j] = qkT[i, j] * (dsT[i, j] - delta[j]) * sm_scale
                T.wait_wgmma(0)
                T.gemm(dsT_cast, q, dk, policy=T.GemmWarpPolicy.FullRow, wg_wait=1)

                T.copy(dsT_cast, dsT_shared)
                T.clear(dq)
                T.gemm(dsT_shared, K_shared, dq, transpose_A=True, wg_wait=1)
                T.wait_wgmma(0)
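                # dQ rows are shared across KV tiles (one tile per thread block), so
                # accumulate them in global memory atomically.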
                T.copy(dq, dq_shared)
                T.atomic_add(dQ[bz, k * block_N : (k + 1) * block_N, bx, :], dq_shared)
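            # dK/dV for this KV tile are complete; flush with atomics because all query
            # heads in a group write to the same KV head.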
            T.copy(dv, dv_shared)
            T.atomic_add(dV[bz, by * block_M : (by + 1) * block_M, bx // groups, :], dv_shared)
            T.copy(dk, dk_shared)
            T.atomic_add(dK[bz, by * block_M : (by + 1) * block_M, bx // groups, :], dk_shared)

    return flash_bwd


@torch.compile
class _attention(torch.autograd.Function):
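    # Autograd wrapper: forward launches the TileLang forward kernel and saves
    # (q, k, v, o, lse); backward runs the preprocess kernel plus the pipelined
    # backward kernel and returns fp16 gradients.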
    @staticmethod
    def forward(ctx, q, k, v, causal, groups=1, use_atomic=True):
        BATCH, N_CTX, H, D_HEAD_QK = q.shape
        D_HEAD_V = v.shape[-1]
        block_M = 128
        block_N = 64
        mod = flashattn_fwd(BATCH, H, N_CTX, D_HEAD_QK, D_HEAD_V, causal, block_M, block_N, groups)
        o, lse = mod(q, k, v)
        ctx.save_for_backward(q, k, v, o, lse)
        ctx.causal = causal
        ctx.use_atomic = use_atomic
        return o

    @staticmethod
    def backward(ctx, do):
        q, k, v, o, lse = ctx.saved_tensors
        BATCH, N_CTX, H, D_HEAD_QK = q.shape
        HEAD_KV, D_HEAD_V = v.shape[-2], v.shape[-1]
        groups = H // HEAD_KV

        def maybe_contiguous(x):
            if x.stride(-1) != 1:
                return x.contiguous()
            return x

        do, q, k, v, o = [maybe_contiguous(x) for x in (do, q, k, v, o)]
        block_M = 128
        block_N = 32
        mod_prep = flashattn_bwd_preprocess(BATCH, H, N_CTX, D_HEAD_V)
        delta = mod_prep(o, do)

        kernel = flashattn_bwd(BATCH, H, N_CTX, D_HEAD_QK, D_HEAD_V, ctx.causal, block_M, block_N, threads=256, num_stages=2, groups=groups)
        shape_q = [BATCH, N_CTX, H, D_HEAD_QK]
        shape_k = [BATCH, N_CTX, HEAD_KV, D_HEAD_QK]
        shape_v = [BATCH, N_CTX, HEAD_KV, D_HEAD_V]
        dq = torch.zeros(shape_q, dtype=torch.float32, device=q.device)
        dk = torch.zeros(shape_k, dtype=torch.float32, device=q.device)
        dv = torch.zeros(shape_v, dtype=torch.float32, device=q.device)
        kernel(q, k, v, do, lse, delta, dq, dk, dv)
        dq = dq.to(torch.float16)
        dk = dk.to(torch.float16)
        dv = dv.to(torch.float16)

        return dq, dk, dv, None, None, None


attention = _attention.apply
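# Expected inputs: q [B, T, H, D_qk], k [B, T, H // groups, D_qk],
# v [B, T, H // groups, D_v], all fp16 on CUDA; returns [B, T, H, D_v].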


def ref_program(Q, K, V, is_causal, groups=1):
    # Q: [B, T, HQ, D_QK]
    # K: [B, T, HK, D_QK]
    # V: [B, T, HV, D_V]
    # HQ = HKV * groups
    assert Q.size(2) == K.size(2) * groups, f"Q.size(2): {Q.size(2)}, K.size(2): {K.size(2)}, groups: {groups}"
    assert Q.size(2) == V.size(2) * groups, f"Q.size(2): {Q.size(2)}, V.size(2): {V.size(2)}, groups: {groups}"

    dim_qk = Q.size(-1)
    K = K.repeat_interleave(groups, dim=2)
    V = V.repeat_interleave(groups, dim=2)
    scores = torch.einsum("bqhd,bkhd->bhqk", Q, K)
    scores = scores / torch.sqrt(torch.tensor(dim_qk, dtype=scores.dtype))
    if is_causal:
        seq_len = Q.size(1)
        mask = torch.tril(torch.ones(seq_len, seq_len, device=scores.device))
        mask = mask.unsqueeze(0).unsqueeze(0)
        scores = scores.masked_fill(mask == 0, float("-inf"))
    attention_weights = F.softmax(scores, dim=-1)
    output = torch.einsum("bhqk,bkhd->bqhd", attention_weights, V)
    return output


def main(BATCH: int = 1, H: int = 32, N_CTX: int = 256, D_HEAD_QK: int = 192, D_HEAD_V: int = 128, groups: int = 16, causal: bool = False):
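    # FLOP count for the backward pass: three QK-sized GEMMs (S recompute, dQ, dK) and
    # two V-sized GEMMs (dV, dP) per query-key pair; causal masking halves the work.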
    flops_per_qk = 2.0 * BATCH * H * N_CTX * N_CTX * D_HEAD_QK
    flops_per_v = 2.0 * BATCH * H * N_CTX * N_CTX * D_HEAD_V
    total_flops = 3 * flops_per_qk + 2 * flops_per_v
    if causal:
        total_flops *= 0.5
    Q = torch.empty(BATCH, N_CTX, H, D_HEAD_QK, dtype=torch.half, device="cuda").normal_().requires_grad_()

    head_kv = H // groups
    K = torch.empty(BATCH, N_CTX, head_kv, D_HEAD_QK, dtype=torch.half, device="cuda").normal_().requires_grad_()
    V = torch.empty(BATCH, N_CTX, head_kv, D_HEAD_V, dtype=torch.half, device="cuda").normal_().requires_grad_()
    dO = torch.empty(BATCH, N_CTX, H, D_HEAD_V, dtype=torch.half, device="cuda").normal_().requires_grad_()
    O = attention(Q, K, V, causal, groups)
    O.backward(dO, retain_graph=True)
    dQ, Q.grad = Q.grad.clone(), None
    dK, K.grad = K.grad.clone(), None
    dV, V.grad = V.grad.clone(), None

    O_ref = ref_program(Q, K, V, causal, groups)
    O_ref.backward(dO, retain_graph=True)
    dQ_ref, Q.grad = Q.grad.clone(), None
    dK_ref, K.grad = K.grad.clone(), None
    dV_ref, V.grad = V.grad.clone(), None

    torch.testing.assert_close(O, O_ref, rtol=1e-2, atol=1e-2)
    torch.testing.assert_close(dV, dV_ref, rtol=1e-2, atol=1e-2)
    torch.testing.assert_close(dK, dK_ref, rtol=1e-2, atol=1e-2)
    torch.testing.assert_close(dQ, dQ_ref, rtol=1e-2, atol=1e-2)
    print("All checks passed.✅")

    # Benchmark only the backward pass: `run` times the PyTorch reference,
    # `run1` times the TileLang kernels through the autograd wrapper.
    def run():
        O_ref.backward(dO, retain_graph=True)

    def run1():
        O.backward(dO, retain_graph=True)

    from tilelang.profiler import do_bench

    latency = do_bench(run, warmup=500)
    print("torch: {:.2f} ms".format(latency))
    print("torch: {:.2f} TFlops".format(total_flops / latency * 1e-9))
    latency = do_bench(run1, warmup=500)
    print("tilelang: {:.2f} ms".format(latency))
    print("tilelang: {:.2f} TFlops".format(total_flops / latency * 1e-9))


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--batch", type=int, default=8, help="Batch size")
    parser.add_argument("--h", type=int, default=32, help="Number of heads")
    parser.add_argument("--n_ctx", type=int, default=1024, help="Context size")
    parser.add_argument("--d_head_qk", type=int, default=192, help="Head dimension for Q/K")
    parser.add_argument("--d_head_v", type=int, default=128, help="Head dimension for V")
    parser.add_argument("--causal", action="store_true", help="Causal flag")
    parser.add_argument("--groups", type=int, default=16, help="groups")
    args = parser.parse_args()

    main(args.batch, args.h, args.n_ctx, args.d_head_qk, args.d_head_v, args.groups, args.causal)