import torch
import torch.nn.functional as F
import tilelang
from tilelang.autotuner import *
import tilelang.language as T
import argparse


@tilelang.jit(
    out_idx=[3, 4],
    pass_configs={
        tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
    },
)
def flashattn_fwd(batch, heads, seq_len, dim, is_causal, block_M, block_N):
    scale = (1.0 / dim) ** 0.5 * 1.44269504  # 1/sqrt(dim) folded with log2(e), since exp2 replaces exp below
    shape = [batch, seq_len, heads, dim]
    dtype = "float16"
    accum_dtype = "float"

    @T.prim_func
    def flash_fwd(
        Q: T.Tensor(shape, dtype),  # type: ignore
        K: T.Tensor(shape, dtype),  # type: ignore
        V: T.Tensor(shape, dtype),  # type: ignore
        Output: T.Tensor(shape, dtype),  # type: ignore
        lse: T.Tensor([batch, heads, seq_len], accum_dtype),  # type: ignore
    ):
        with T.Kernel(T.ceildiv(seq_len, block_M), heads, batch, threads=128) as (bx, by, bz):
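            # Grid: bx tiles the query sequence in block_M chunks, by indexes the head, bz the batch.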
            Q_shared = T.alloc_shared([block_M, dim], dtype)
            # Q_local = T.alloc_fragment([block_M, dim], dtype)
            K_shared = T.alloc_shared([block_N, dim], dtype)
            V_shared = T.alloc_shared([block_N, dim], dtype)
            acc_s = T.alloc_fragment([block_M, block_N], accum_dtype)
            acc_s_cast = T.alloc_fragment([block_M, block_N], dtype)
            acc_o = T.alloc_fragment([block_M, dim], accum_dtype)
            scores_max = T.alloc_fragment([block_M], accum_dtype)
            scores_max_prev = T.alloc_fragment([block_M], accum_dtype)
            scores_scale = T.alloc_fragment([block_M], accum_dtype)
            scores_sum = T.alloc_fragment([block_M], accum_dtype)
            logsum = T.alloc_fragment([block_M], accum_dtype)

            T.copy(Q[bz, bx * block_M : (bx + 1) * block_M, by, :], Q_shared)
            T.fill(acc_o, 0)
            T.fill(logsum, 0)
            T.fill(scores_max, -T.infinity(accum_dtype))
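            # Online-softmax running state: acc_o holds the unnormalized output, logsum the running
            # denominator, and scores_max the running row maximum (exponentials use base 2 via `scale`).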
            loop_range = T.ceildiv((bx + 1) * block_M, block_N) if is_causal else T.ceildiv(seq_len, block_N)
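            # With causal masking, key blocks strictly after the current query block are fully masked,
            # so the pipelined loop stops at ceildiv((bx + 1) * block_M, block_N).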
            for k in T.Pipelined(loop_range, num_stages=1):
                T.copy(K[bz, k * block_N : (k + 1) * block_N, by, :], K_shared)
                if is_causal:
                    for i, j in T.Parallel(block_M, block_N):
                        acc_s[i, j] = T.if_then_else(bx * block_M + i >= k * block_N + j, 0, -T.infinity(acc_s.dtype))
                else:
                    for i, j in T.Parallel(block_M, block_N):
                        acc_s[i, j] = T.if_then_else(k * block_N + j >= seq_len, -T.infinity(acc_s.dtype), 0)
                T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
                T.copy(V[bz, k * block_N : (k + 1) * block_N, by, :], V_shared)
                T.copy(scores_max, scores_max_prev)
                T.reduce_max(acc_s, scores_max, dim=1, clear=False)
                for i in T.Parallel(block_M):
                    scores_max[i] = T.max(scores_max[i], scores_max_prev[i])
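                # Online-softmax update: rescale the accumulated output and denominator by
                # exp2((m_prev - m_new) * scale) before folding in this key block.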
                for i in T.Parallel(block_M):
                    scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale)
                for i, j in T.Parallel(block_M, dim):
                    acc_o[i, j] *= scores_scale[i]
                for i, j in T.Parallel(block_M, block_N):
                    acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale)
                T.copy(acc_s, acc_s_cast)
                T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow)
                T.reduce_sum(acc_s, scores_sum, dim=1)
                for i in T.Parallel(block_M):
                    logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i]
            for i, j in T.Parallel(block_M, dim):
                acc_o[i, j] /= logsum[i]
            T.copy(acc_o, Output[bz, bx * block_M : (bx + 1) * block_M, by, :])
            for i in T.Parallel(block_M):
                logsum[i] = T.log2(logsum[i]) + scores_max[i] * scale
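            # lse = log2 of the softmax denominator plus the scaled row max (a base-2 log-sum-exp);
            # the backward pass uses it to recompute the probabilities without another max reduction.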
            T.copy(logsum, lse[bz, by, bx * block_M : (bx + 1) * block_M])

    return flash_fwd


@tilelang.jit(
    out_idx=[2],
    pass_configs={
        tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
    },
)
def flashattn_bwd_preprocess(batch, heads, seq_len, dim):
    dtype = "float16"
    accum_dtype = "float"
    shape = [batch, seq_len, heads, dim]
    blk = 32

    @T.prim_func
    def flash_bwd_prep(
        O: T.Tensor(shape, dtype),  # type: ignore
        dO: T.Tensor(shape, dtype),  # type: ignore
        Delta: T.Tensor([batch, heads, seq_len], accum_dtype),  # type: ignore
    ):
        with T.Kernel(heads, T.ceildiv(seq_len, blk), batch) as (bx, by, bz):
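            # Delta[b, h, i] = sum_d O[b, i, h, d] * dO[b, i, h, d], reduced blk columns of the head dim at a time.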
            o = T.alloc_fragment([blk, blk], dtype)
            do = T.alloc_fragment([blk, blk], dtype)
            acc = T.alloc_fragment([blk, blk], accum_dtype)
            delta = T.alloc_fragment([blk], accum_dtype)
            T.clear(acc)
            for k in range(T.ceildiv(dim, blk)):
                T.copy(O[bz, by * blk : (by + 1) * blk, bx, k * blk : (k + 1) * blk], o)
                T.copy(dO[bz, by * blk : (by + 1) * blk, bx, k * blk : (k + 1) * blk], do)
                for i, j in T.Parallel(blk, blk):
                    acc[i, j] += o[i, j] * do[i, j]
            T.reduce_sum(acc, delta, 1)
            T.copy(delta, Delta[bz, bx, by * blk : (by + 1) * blk])

    return flash_bwd_prep


def make_dq_layout(dQ):
    # atomicAdd cannot be vectorized, so reorder dQ to match the 8x8 GEMM fragment layout
    return T.Layout(dQ.shape, lambda b, l, h, d: [b, l // 8, h, d // 8, (d % 2), 4 * (l % 8) + (d % 8) // 2])


@tilelang.jit(
    out_idx=[1],
    pass_configs={
        tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
    },
)
def flashattn_bwd_postprocess(batch, heads, seq_len, dim):
    dtype = "float16"
    accum_dtype = "float"
    shape = [batch, seq_len, heads, dim]
    blk = 64

    @T.prim_func
    def flash_bwd_post(
        dQ: T.Tensor(shape, accum_dtype),  # type: ignore
        dQ_out: T.Tensor(shape, dtype),  # type: ignore
    ):
        with T.Kernel(T.ceildiv(seq_len, blk), heads, batch, threads=128) as (bx, by, bz):
            T.annotate_layout({dQ: make_dq_layout(dQ)})
            T.copy(
                dQ[bz, bx * blk : (bx + 1) * blk, by, :],
                dQ_out[bz, bx * blk : (bx + 1) * blk, by, :],
            )

    return flash_bwd_post


@tilelang.jit(
    pass_configs={
        tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
    }
)
def flashattn_bwd(batch, heads, seq_len, dim, is_causal, block_M, block_N):
    sm_scale = (1.0 / dim) ** 0.5
    scale = (1.0 / dim) ** 0.5 * 1.44269504  # log2(e)
    shape = [batch, seq_len, heads, dim]
    dtype = "float16"
    accum_dtype = "float"

    @T.prim_func
    def flash_bwd(
        Q: T.Tensor(shape, dtype),  # type: ignore
        K: T.Tensor(shape, dtype),  # type: ignore
        V: T.Tensor(shape, dtype),  # type: ignore
        dO: T.Tensor(shape, dtype),  # type: ignore
        lse: T.Tensor([batch, heads, seq_len], accum_dtype),  # type: ignore
        Delta: T.Tensor([batch, heads, seq_len], accum_dtype),  # type: ignore
        dQ: T.Tensor(shape, accum_dtype),  # type: ignore
        dK: T.Tensor(shape, dtype),  # type: ignore
        dV: T.Tensor(shape, dtype),  # type: ignore
    ):
        with T.Kernel(heads, T.ceildiv(seq_len, block_M), batch, threads=128) as (bx, by, bz):
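            # Grid: bx indexes the head, by tiles the key/value sequence in block_M chunks, bz the batch.
            # Each block accumulates dK/dV for its own K/V tile and scatters partial dQ via atomicAdd.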
            K_shared = T.alloc_shared([block_M, dim], dtype)
            dsT_shared = T.alloc_shared([block_M, block_N], dtype)
            # should not store K to local if dim is large
            # K_local = T.alloc_fragment([block_M, dim], dtype)
            # K_local_T = T.alloc_fragment([block_M, dim], dtype)
            # V_local = T.alloc_fragment([block_M, dim], dtype)
            q = T.alloc_shared([block_N, dim], dtype)
            V_shared = T.alloc_shared([block_M, dim], dtype)
            qkT = T.alloc_fragment([block_M, block_N], accum_dtype)
            dsT = T.alloc_fragment([block_M, block_N], accum_dtype)
            qkT_cast = T.alloc_fragment([block_M, block_N], dtype)
            dsT_cast = T.alloc_fragment([block_M, block_N], dtype)
            lse_shared = T.alloc_shared([block_N], accum_dtype)
            delta = T.alloc_shared([block_N], accum_dtype)
            do = T.alloc_shared([block_N, dim], dtype)
            dv = T.alloc_fragment([block_M, dim], accum_dtype)
            dk = T.alloc_fragment([block_M, dim], accum_dtype)
            dq = T.alloc_fragment([block_N, dim], accum_dtype)
            dv_shared = T.alloc_shared([block_M, dim], dtype)
            dk_shared = T.alloc_shared([block_M, dim], dtype)

            T.annotate_layout(
                {
                    dQ: make_dq_layout(dQ),
                }
            )
            T.copy(K[bz, by * block_M : (by + 1) * block_M, bx, :], K_shared)
            T.copy(V[bz, by * block_M : (by + 1) * block_M, bx, :], V_shared)
            T.clear(dv)
            T.clear(dk)
            loop_st = T.floordiv(by * block_M, block_N) if is_causal else 0
            loop_ed = T.ceildiv(seq_len, block_N)
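            # With causal masking, query positions before this key/value tile contribute nothing,
            # so start at the first query block that can attend to it.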
            for k in T.Pipelined(loop_st, loop_ed, num_stages=2):
                T.copy(Q[bz, k * block_N : (k + 1) * block_N, bx, :], q)
                T.clear(qkT)
                T.gemm(K_shared, q, qkT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
                T.copy(lse[bz, bx, k * block_N : (k + 1) * block_N], lse_shared)
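                # Recompute the transposed probabilities P^T[i, j] = exp2(K_i . Q_j * scale - lse[j]);
                # lse already folds in the row maximum, so no extra max reduction is needed here.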
                for i, j in T.Parallel(block_M, block_N):
                    qkT[i, j] = T.exp2(qkT[i, j] * scale - lse_shared[j])
                if is_causal:
                    for i, j in T.Parallel(block_M, block_N):
                        qkT[i, j] = T.if_then_else(by * block_M + i <= k * block_N + j, qkT[i, j], 0)
                # We don't need to handle OOB positions for non-causal cases,
                # since OOB values won't affect other positions here.
                T.copy(dO[bz, k * block_N : (k + 1) * block_N, bx, :], do)
                T.clear(dsT)
                T.gemm(V_shared, do, dsT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
                T.copy(qkT, qkT_cast)
                T.gemm(qkT_cast, do, dv, policy=T.GemmWarpPolicy.FullRow)

                T.copy(Delta[bz, bx, k * block_N : (k + 1) * block_N], delta)
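                # dS^T[i, j] = P^T[i, j] * (dP^T[i, j] - Delta[j]) * sm_scale, where dP^T = V . dO^T
                # was just computed into dsT and Delta = rowsum(O * dO) comes from the preprocess kernel.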

                for i, j in T.Parallel(block_M, block_N):
                    dsT_cast[i, j] = qkT[i, j] * (dsT[i, j] - delta[j]) * sm_scale
                T.gemm(dsT_cast, q, dk, policy=T.GemmWarpPolicy.FullRow)

                T.copy(dsT_cast, dsT_shared)
                T.clear(dq)
                T.gemm(dsT_shared, K_shared, dq, transpose_A=True)
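                # Rows of dQ are shared across key tiles, so partial contributions are accumulated
                # in float32 with atomicAdd; flashattn_bwd_postprocess casts the result back to fp16.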
                for i, j in T.Parallel(block_N, dim):
                    T.atomic_add(dQ[bz, k * block_N + i, bx, j], dq[i, j])
            T.copy(dv, dv_shared)
            T.copy(dk, dk_shared)
            T.copy(dv_shared, dV[bz, by * block_M : (by + 1) * block_M, bx, :])
            T.copy(dk_shared, dK[bz, by * block_M : (by + 1) * block_M, bx, :])

    return flash_bwd


class _attention(torch.autograd.Function):
    @staticmethod
    def forward(ctx, q, k, v, causal):
        BATCH, N_CTX, H, D_HEAD = q.shape
        block_M = 64
        block_N = 64 if D_HEAD <= 128 else 32
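        # out_idx=[3, 4] marks Output and lse as kernel outputs, so the compiled kernel returns (o, lse).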
        o, lse = flashattn_fwd(BATCH, H, N_CTX, D_HEAD, causal, block_M, block_N)(q, k, v)
        ctx.save_for_backward(q, k, v, o, lse)
        ctx.causal = causal
        return o

    @staticmethod
    def backward(ctx, do):
        q, k, v, o, lse = ctx.saved_tensors
        BATCH, N_CTX, H, D_HEAD = q.shape

        def maybe_contiguous(x):
            if x.stride(-1) != 1:
                return x.contiguous()
            return x

        do, q, k, v, o = [maybe_contiguous(x) for x in (do, q, k, v, o)]
        block_M = 64
        block_N = 64 if D_HEAD <= 64 else 32
        kernel_prep = flashattn_bwd_preprocess(BATCH, H, N_CTX, D_HEAD)
        kernel_post = flashattn_bwd_postprocess(BATCH, H, N_CTX, D_HEAD)
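        # The preprocess kernel computes Delta = rowsum(O * dO), which feeds the dS term in the main backward kernel.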
        delta = kernel_prep(o, do)
        kernel = flashattn_bwd(BATCH, H, N_CTX, D_HEAD, ctx.causal, block_M, block_N)
        shape = [BATCH, N_CTX, H, D_HEAD]
        dq = torch.zeros(shape, dtype=torch.float32, device=q.device)
        dk = torch.empty(shape, dtype=torch.float16, device=q.device)
        dv = torch.empty(shape, dtype=torch.float16, device=q.device)
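        # dq is float32 and zero-initialized because the backward kernel accumulates into it with
        # atomicAdd; kernel_post then converts it to fp16 in the standard layout.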
        kernel(q, k, v, do, lse, delta, dq, dk, dv)
        dq = kernel_post(dq)
        return dq, dk, dv, None


attention = _attention.apply


def ref_program(Q, K, V, is_causal):
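    # Plain-PyTorch reference; the einsum subscripts keep the BSHD (batch, seq, heads, dim) layout throughout.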
    dim = Q.size(-1)
    scores = torch.einsum("bqhd,bkhd->bhqk", Q, K)
    scores = scores / torch.sqrt(torch.tensor(dim, dtype=scores.dtype))
    if is_causal:
        seq_len = Q.size(1)
        mask = torch.tril(torch.ones(seq_len, seq_len, device=scores.device))
        mask = mask.unsqueeze(0).unsqueeze(0)
        scores = scores.masked_fill(mask == 0, float("-inf"))
    attention_weights = F.softmax(scores, dim=-1)
    output = torch.einsum("bhqk,bkhd->bqhd", attention_weights, V)
    return output


def main(
    BATCH: int = 8,
    H: int = 32,
    N_CTX: int = 1024,
    D_HEAD: int = 64,
    causal: bool = False,
):
    flops_per_matmul = 2.0 * BATCH * H * N_CTX * N_CTX * D_HEAD
    total_flops = 5 * flops_per_matmul
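    # Only the backward pass is benchmarked below. It performs 5 GEMMs per tile pair
    # (recomputed QK^T, dV = P^T dO, dP = dO V^T, dK = dS^T Q, dQ = dS K), hence 5 * flops_per_matmul.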
    if causal:
        total_flops *= 0.5
    Q = torch.empty(BATCH, N_CTX, H, D_HEAD, dtype=torch.half, device="cuda").normal_().requires_grad_()
    K = torch.empty_like(Q).normal_().requires_grad_()
    V = torch.empty_like(Q).normal_().requires_grad_()
    dO = torch.randn_like(Q)
    O = attention(Q, K, V, causal)
    O.backward(dO, retain_graph=True)
    dQ, Q.grad = Q.grad.clone(), None
    dK, K.grad = K.grad.clone(), None
    dV, V.grad = V.grad.clone(), None

    O_ref = ref_program(Q, K, V, causal)
    O_ref.backward(dO, retain_graph=True)
    dQ_ref, Q.grad = Q.grad.clone(), None
    dK_ref, K.grad = K.grad.clone(), None
    dV_ref, V.grad = V.grad.clone(), None

    assert torch.allclose(O, O_ref, rtol=1e-2, atol=1e-2)
    assert torch.allclose(dV, dV_ref, rtol=1e-2, atol=1e-2)
    assert torch.allclose(dK, dK_ref, rtol=1e-2, atol=1e-2)
    assert torch.allclose(dQ, dQ_ref, rtol=1e-2, atol=1e-2)

    def run():
        O_ref.backward(dO, retain_graph=True)

    def run1():
        O.backward(dO, retain_graph=True)

    from tilelang.profiler import do_bench

    latency = do_bench(run, warmup=500)
    print("torch: {:.2f} ms".format(latency))
    print("torch: {:.2f} TFlops".format(total_flops / latency * 1e-9))
    latency = do_bench(run1, warmup=500)
    print("tilelang: {:.2f} ms".format(latency))
    print("tilelang: {:.2f} TFlops".format(total_flops / latency * 1e-9))


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--batch", type=int, default=8, help="Batch size")
    parser.add_argument("--h", type=int, default=32, help="Number of heads")
    parser.add_argument("--n_ctx", type=int, default=1024, help="Context size")
    parser.add_argument("--d_head", type=int, default=64, help="Head dimension")
    parser.add_argument("--causal", type=bool, default=False, help="Causal flag")
    args = parser.parse_args()
    main(args.batch, args.h, args.n_ctx, args.d_head, args.causal)
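
# Usage sketch (assuming this file is saved as example_mha_bwd_bshd.py; --causal toggles causal masking):
#   python example_mha_bwd_bshd.py --batch 8 --h 32 --n_ctx 1024 --d_head 64 --causal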