"""
We use the FlashAttention implementation from Phil Tillet as a starting point.
https://github.com/openai/triton/blob/master/python/tutorials/06-fused-attention.py

Changes:
- Implement both causal and non-causal attention.
- Implement cross-attention (not just self-attention).
- Support arbitrary seqlens (not just multiples of 128), for both forward and backward.
- [WIP] Support all head dimensions up to 128 (not just 16, 32, 64, 128), for both the forward pass
and backward pass. For the backward pass, head dims that are not 16, 32, 64, 128 will require
more testing since there seem to be some race conditions due to the Triton compiler.
- Speed up the forward pass a bit, and only store the LSE instead of m and l.
- Make the backward for d=128 much faster by reducing register spilling.
- Optionally parallelize the backward pass across seqlen_k, to deal with the case of
small batch size * nheads.

Differences between this Triton version and the CUDA version:
- Triton version doesn't support dropout.
- Triton forward is generally faster than CUDA forward.
- Triton backward is faster than CUDA backward when batch * nheads is small, and might be slightly
slower in other cases.
- Triton version does not yet support different sequence lengths in a batch (i.e., RaggedTensor/NestedTensor).
"""

import math

import torch

import triton
import triton.language as tl


# Forward attention kernel.
#
# Launch grid: axis 0 indexes blocks of BLOCK_M query rows, axis 1 indexes the flattened
# (batch, head) pair. Each program loads its block of Q once, then walks over K/V in blocks
# of BLOCK_N rows, keeping a running reference value (m_i), a running log-sum-exp (lse_i)
# and an output accumulator (acc_o). It finally writes the normalized output to Out and the
# per-row log-sum-exp to Lse.
#
# The head-dimension offset is added to every pointer without a stride, so the last
# dimension of Q, K, V and Out is addressed as contiguous (stride 1).
#
# EVEN_M / EVEN_N / EVEN_HEADDIM are compile-time flags derived by the heuristics below; they
# select between unmasked and masked loads/stores.
@triton.autotune(
    configs=[
        triton.Config({"BLOCK_M": 128, "BLOCK_N": 128}, num_warps=8, num_stages=1),
        # This config has a race condition when EVEN_M == False, disabling it for now.
        # triton.Config({"BLOCK_M": 64, "BLOCK_N": 64}, num_warps=4, num_stages=1),
    ],
    key=['CACHE_KEY_SEQLEN_Q', 'CACHE_KEY_SEQLEN_K', 'IS_CAUSAL', 'BLOCK_HEADDIM']
)
@triton.heuristics(
    {
        "EVEN_M": lambda args: args["seqlen_q"] % args["BLOCK_M"] == 0,
        "EVEN_N": lambda args: args["seqlen_k"] % args["BLOCK_N"] == 0,
        "EVEN_HEADDIM": lambda args: args["headdim"] == args["BLOCK_HEADDIM"],
    }
)
@triton.jit
def _fwd_kernel(
    Q, K, V, Out,
    Lse, TMP,  # NOTE: TMP is a scratchpad buffer to workaround a compiler bug
    softmax_scale,
    stride_qb, stride_qh, stride_qm,
    stride_kb, stride_kh, stride_kn,
    stride_vb, stride_vh, stride_vn,
    stride_ob, stride_oh, stride_om,
    nheads, seqlen_q, seqlen_k, seqlen_q_rounded, headdim,
    CACHE_KEY_SEQLEN_Q, CACHE_KEY_SEQLEN_K,
    IS_CAUSAL: tl.constexpr,
    BLOCK_HEADDIM: tl.constexpr,
    EVEN_M: tl.constexpr, EVEN_N: tl.constexpr, EVEN_HEADDIM: tl.constexpr,
    BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr,
):
    start_m = tl.program_id(0)
    # program_id(1) is the flattened (batch, head) index; split it back into its two parts
    off_hb = tl.program_id(1)
    off_b = off_hb // nheads
    off_h = off_hb % nheads
    # initialize offsets
    offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_n = tl.arange(0, BLOCK_N)
    offs_d = tl.arange(0, BLOCK_HEADDIM)
    # Initialize pointers to Q, K, V
    # Adding parenthesis around indexing might use int32 math instead of int64 math?
    # https://github.com/openai/triton/issues/741
    # I'm seeing a tiny bit of difference (5-7us)
    q_ptrs = Q + off_b * stride_qb + off_h * stride_qh + (offs_m[:, None] * stride_qm + offs_d[None, :])
    k_ptrs = K + off_b * stride_kb + off_h * stride_kh + (offs_n[:, None] * stride_kn + offs_d[None, :])
    v_ptrs = V + off_b * stride_vb + off_h * stride_vh + (offs_n[:, None] * stride_vn + offs_d[None, :])
    # pointer into the TMP scratchpad (one slot per query row), then initialize the running
    # log-sum-exp, the running reference value and the output accumulator
    t_ptrs = TMP + off_hb * seqlen_q_rounded + offs_m
    lse_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
    m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
    acc_o = tl.zeros([BLOCK_M, BLOCK_HEADDIM], dtype=tl.float32)
    # load q: it will stay in SRAM throughout
    # [2022-10-30] TD: Triton bug - in the case of EVEN_M=True and EVEN_N=False, if we just call
    # tl.load(q_ptrs), we get the wrong output!
    if EVEN_M & EVEN_N:
        if EVEN_HEADDIM:
            q = tl.load(q_ptrs)
        else:
            q = tl.load(q_ptrs, mask=offs_d[None, :] < headdim, other=0.0)
    else:
        if EVEN_HEADDIM:
            q = tl.load(q_ptrs, mask=offs_m[:, None] < seqlen_q, other=0.0)
        else:
            q = tl.load(q_ptrs, mask=(offs_m[:, None] < seqlen_q) & (offs_d[None, :] < headdim),
                        other=0.0)
    # loop over k, v and update accumulator
    # With causal masking, key blocks past the last query row of this program are skipped.
    end_n = seqlen_k if not IS_CAUSAL else tl.minimum((start_m + 1) * BLOCK_M, seqlen_k)
    for start_n in range(0, end_n, BLOCK_N):
        start_n = tl.multiple_of(start_n, BLOCK_N)
        # -- compute qk ----
        if EVEN_N & EVEN_M:  # If we just do "if EVEN_N", there seems to be some race condition
            if EVEN_HEADDIM:
                k = tl.load(k_ptrs + start_n * stride_kn)
            else:
                k = tl.load(k_ptrs + start_n * stride_kn, mask=offs_d[None, :] < headdim, other=0.0)
        else:
            if EVEN_HEADDIM:
                k = tl.load(k_ptrs + start_n * stride_kn, mask=(start_n + offs_n)[:, None] < seqlen_k,
                            other=0.0)
            else:
                k = tl.load(k_ptrs + start_n * stride_kn,
                            mask=((start_n + offs_n)[:, None] < seqlen_k) & (offs_d[None, :] < headdim),
                            other=0.0)
        qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
        qk += tl.dot(q, k, trans_b=True)
        # Out-of-range key columns get -inf so they contribute exp(-inf) = 0 to the softmax
        if not EVEN_N:
            qk += tl.where((start_n + offs_n)[None, :] < seqlen_k, 0, float("-inf"))
        # Causal mask: a query row only attends to key columns with index <= its own
        if IS_CAUSAL:
            qk += tl.where(offs_m[:, None] >= (start_n + offs_n)[None, :], 0, float("-inf"))
        # New reference value: the larger of this block's scaled row max and the running lse_i
        m_ij = tl.maximum(tl.max(qk, 1) * softmax_scale, lse_i)
        # Slightly faster to multiply the softmax_scale here since the compiler can then
        # fuse the mult and add into an fma instruction.
        p = tl.exp(qk * softmax_scale - m_ij[:, None])
        l_ij = tl.sum(p, 1)

        # scale acc_o
        acc_o_scale = tl.exp(m_i - m_ij)

        # # -- update output accumulator --
        # BUG: have to store and immediately load
        tl.store(t_ptrs, acc_o_scale)
        acc_o_scale = tl.load(t_ptrs)
        acc_o = acc_o * acc_o_scale[:, None]
        # update acc_o
        if EVEN_N & EVEN_M:  # If we just do "if EVEN_N", there seems to be some race condition
            if EVEN_HEADDIM:
                v = tl.load(v_ptrs + start_n * stride_vn)
            else:
                v = tl.load(v_ptrs + start_n * stride_vn, mask=offs_d[None, :] < headdim, other=0.0)
        else:
            if EVEN_HEADDIM:
                v = tl.load(v_ptrs + start_n * stride_vn, mask=(start_n + offs_n)[:, None] < seqlen_k,
                            other=0.0)
            else:
                v = tl.load(v_ptrs + start_n * stride_vn,
                            mask=((start_n + offs_n)[:, None] < seqlen_k) & (offs_d[None, :] < headdim),
                            other=0.0)
        p = p.to(v.dtype)
        acc_o += tl.dot(p, v)

        # -- update statistics
        m_i = m_ij
        l_i_new = tl.exp(lse_i - m_ij) + l_ij
        lse_i = m_ij + tl.log(l_i_new)

    # Final normalization: acc_o was accumulated relative to m_i, rescale it relative to lse_i
    o_scale = tl.exp(m_i - lse_i)
    # BUG: have to store and immediately load
    tl.store(t_ptrs, o_scale)
    o_scale = tl.load(t_ptrs)
    acc_o = acc_o * o_scale[:, None]
    # rematerialize offsets to save registers
    start_m = tl.program_id(0)
    offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
    # write back the log-sum-exp of every query row handled by this program
    lse_ptrs = Lse + off_hb * seqlen_q_rounded + offs_m
    tl.store(lse_ptrs, lse_i)
    # initialize pointers to output
    offs_n = tl.arange(0, BLOCK_HEADDIM)
    out_ptrs = Out + off_b * stride_ob + off_h * stride_oh + (offs_m[:, None] * stride_om + offs_n[None, :])
    if EVEN_M:
        if EVEN_HEADDIM:
            tl.store(out_ptrs, acc_o)
        else:
            tl.store(out_ptrs, acc_o, mask=offs_d[None, :] < headdim)
    else:
        if EVEN_HEADDIM:
            tl.store(out_ptrs, acc_o, mask=offs_m[:, None] < seqlen_q)
        else:
            tl.store(out_ptrs, acc_o,
                     mask=(offs_m[:, None] < seqlen_q) & (offs_d[None, :] < headdim))


# Backward pre-processing kernel: for every query row, computes the dot product of the
# forward output row and the incoming gradient row over the head dimension
# (delta = sum(o * do, axis=1)) and stores it in Delta.
#
# Launch grid: axis 0 indexes blocks of BLOCK_M query rows, axis 1 indexes the flattened
# (batch, head) pair. Delta is addressed as off_hb * seqlen_q_rounded + row.
@triton.jit
def _bwd_preprocess_do_o_dot(
    Out, DO, Delta,
    stride_ob, stride_oh, stride_om,
    stride_dob, stride_doh, stride_dom,
    nheads, seqlen_q, seqlen_q_rounded, headdim,
    BLOCK_M: tl.constexpr, BLOCK_HEADDIM: tl.constexpr,
):
    start_m = tl.program_id(0)
    off_hb = tl.program_id(1)
    off_b = off_hb // nheads
    off_h = off_hb % nheads
    # initialize offsets
    offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_d = tl.arange(0, BLOCK_HEADDIM)
    # load (rows past seqlen_q and columns past headdim read as 0, so they add nothing to the sum)
    o = tl.load(Out + off_b * stride_ob + off_h * stride_oh + offs_m[:, None] * stride_om + offs_d[None, :],
                mask=(offs_m[:, None] < seqlen_q) & (offs_d[None, :] < headdim), other=0.0).to(tl.float32)
    do = tl.load(DO + off_b * stride_dob + off_h * stride_doh + offs_m[:, None] * stride_dom + offs_d[None, :],
                 mask=(offs_m[:, None] < seqlen_q) & (offs_d[None, :] < headdim), other=0.0).to(tl.float32)
    delta = tl.sum(o * do, axis=1)
    # write-back (unmasked: all BLOCK_M slots of this block are written)
    tl.store(Delta + off_hb * seqlen_q_rounded + offs_m, delta)


# Backward pass for ONE block of BLOCK_N key/value rows (column block `start_n`).
#
# K and V for this column block are loaded once. The kernel then loops over blocks of
# BLOCK_M query rows, recomputes the attention probabilities from Q, K and the stored
# log-sum-exp (LSE), and accumulates dv and dk for this column block. The contribution to
# dq is either added with a load/add/store sequence (ATOMIC_ADD=False) or with
# tl.atomic_add (ATOMIC_ADD=True).
#
# Q, K, V, DO, DQ, DK, DV, LSE and D are used as given, with only row offsets added here;
# the caller is expected to have already offset them to the right batch/head.
@triton.jit
def _bwd_kernel_one_col_block(
    start_n,
    Q, K, V, softmax_scale,
    DO, DQ, DK, DV,
    LSE, D,
    stride_qm, stride_kn, stride_vn, stride_dom, stride_dqm, stride_dkn, stride_dvn,
    seqlen_q, seqlen_k, headdim,
    ATOMIC_ADD: tl.constexpr,
    IS_CAUSAL: tl.constexpr,
    BLOCK_HEADDIM: tl.constexpr,
    EVEN_M: tl.constexpr, EVEN_N: tl.constexpr, EVEN_HEADDIM: tl.constexpr,
    BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr,
):
    # We need to make sure begin_m is a multiple of BLOCK_M (not BLOCK_N)
    begin_m = 0 if not IS_CAUSAL else ((start_n * BLOCK_N) // BLOCK_M) * BLOCK_M
    # initialize row/col offsets
    offs_qm = begin_m + tl.arange(0, BLOCK_M)
    offs_n = start_n * BLOCK_N + tl.arange(0, BLOCK_N)
    offs_m = tl.arange(0, BLOCK_M)
    offs_d = tl.arange(0, BLOCK_HEADDIM)
    # initialize pointers to value-like data
    q_ptrs = Q + (offs_qm[:, None] * stride_qm + offs_d[None, :])
    k_ptrs = K + (offs_n[:, None] * stride_kn + offs_d[None, :])
    v_ptrs = V + (offs_n[:, None] * stride_vn + offs_d[None, :])
    do_ptrs = DO + (offs_qm[:, None] * stride_dom + offs_d[None, :])
    dq_ptrs = DQ + (offs_qm[:, None] * stride_dqm + offs_d[None, :])
    # initialize dv and dk
    dv = tl.zeros([BLOCK_N, BLOCK_HEADDIM], dtype=tl.float32)
    dk = tl.zeros([BLOCK_N, BLOCK_HEADDIM], dtype=tl.float32)
    # k and v stay in SRAM throughout
    # [2022-10-30] TD: Same bug as the fwd. In the case of EVEN_N=True and EVEN_M=False,
    # if we just call tl.load(k_ptrs), we get the wrong output!
    if EVEN_N & EVEN_M:
        if EVEN_HEADDIM:
            k = tl.load(k_ptrs)
            v = tl.load(v_ptrs)
        else:
            k = tl.load(k_ptrs, mask=offs_d[None, :] < headdim, other=0.0)
            v = tl.load(v_ptrs, mask=offs_d[None, :] < headdim, other=0.0)
    else:
        if EVEN_HEADDIM:
            k = tl.load(k_ptrs, mask=offs_n[:, None] < seqlen_k, other=0.0)
            v = tl.load(v_ptrs, mask=offs_n[:, None] < seqlen_k, other=0.0)
        else:
            k = tl.load(k_ptrs, mask=(offs_n[:, None] < seqlen_k) & (offs_d[None, :] < headdim),
                        other=0.0)
            v = tl.load(v_ptrs, mask=(offs_n[:, None] < seqlen_k) & (offs_d[None, :] < headdim),
                        other=0.0)
    # loop over rows
    num_block_m = tl.cdiv(seqlen_q, BLOCK_M)
    for start_m in range(begin_m, num_block_m * BLOCK_M, BLOCK_M):
        start_m = tl.multiple_of(start_m, BLOCK_M)
        offs_m_curr = start_m + offs_m
        # load q, k, v, do on-chip
        # Same bug as below. Otherwise gives wrong result for headdim=40, seqlen=(128, 117)
        if EVEN_M & EVEN_HEADDIM:
            q = tl.load(q_ptrs)
        else:
            if EVEN_HEADDIM:
                q = tl.load(q_ptrs, mask=offs_m_curr[:, None] < seqlen_q, other=0.0)
            else:
                q = tl.load(q_ptrs, mask=(offs_m_curr[:, None] < seqlen_q)
                                         & (offs_d[None, :] < headdim), other=0.0)
        # recompute p = softmax(qk, dim=-1).T
        qk = tl.dot(q, k, trans_b=True)
        if not EVEN_N:  # Need to mask out otherwise the softmax is wrong
            qk = tl.where(offs_n[None, :] < seqlen_k, qk, float("-inf"))
        if IS_CAUSAL:
            qk = tl.where(offs_m_curr[:, None] >= (offs_n[None, :]), qk, float("-inf"))
        # There seems to be a race condition when headdim=48/96, and dq, dk, dv are wrong.
        # Also wrong for headdim=64.
        if not EVEN_M:
            tl.debug_barrier()
        lse_i = tl.load(LSE + offs_m_curr)
        p = tl.exp(qk * softmax_scale - lse_i[:, None])
        # compute dv
        # [2022-10-30] TD: A Triton bug: if EVEN_M=True and EVEN_HEADDIM=False, if we call
        # do = tl.load(do_ptrs, mask=offs_d[None, :] < headdim, other=0.0), we get wrong outputs
        # in the case of headdim=48/96, seqlen_q & seqlen_k >= 512. If headdim=40 or seqlen < 512,
        # the output is correct.
        if EVEN_M & EVEN_HEADDIM:
            do = tl.load(do_ptrs)
        else:
            # [2022-11-01] TD: Triton bug, there's a race condition if we just use m_mask and not d_mask.
            # This is why both masks are always applied here instead of branching on
            # EVEN_M / EVEN_HEADDIM separately.
            do = tl.load(do_ptrs, mask=(offs_m_curr[:, None] < seqlen_q)
                                        & (offs_d[None, :] < headdim), other=0.0)
        dv += tl.dot(p.to(do.dtype), do, trans_a=True)
        # compute dp = dot(v, do)
        # There seems to be a race condition when headdim=48/96, and dq, dk are wrong.
        # Also wrong for headdim=128, seqlen=(108, 256), and ATOMIC_ADD=True
        # Also wrong for headdim=64, seqlen=(1023, 1024), and ATOMIC_ADD=False
        if not EVEN_M:
            tl.debug_barrier()
        dp = tl.dot(do, v, trans_b=True)
        # compute ds = p * (dp - delta[:, None])
        # Putting the subtraction after the dp matmul (instead of before) is slightly faster
        Di = tl.load(D + offs_m_curr)
        # Converting ds to q.dtype here reduces register pressure and makes it much faster
        # for BLOCK_HEADDIM=128
        ds = (p * (dp - Di[:, None]) * softmax_scale).to(q.dtype)
        # compute dk = dot(ds.T, q)
        dk += tl.dot(ds, q, trans_a=True)
        # compute dq
        if not ATOMIC_ADD:
            # Read-modify-write of DQ: add this column block's contribution to what is there
            if EVEN_M:
                if EVEN_HEADDIM:
                    dq = tl.load(dq_ptrs, eviction_policy="evict_last")
                    dq += tl.dot(ds, k)
                    tl.store(dq_ptrs, dq, eviction_policy="evict_last")
                else:
                    dq = tl.load(dq_ptrs, mask=offs_d[None, :] < headdim, other=0.0,
                                 eviction_policy="evict_last")
                    dq += tl.dot(ds, k)
                    tl.store(dq_ptrs, dq, mask=offs_d[None, :] < headdim, eviction_policy="evict_last")
            else:
                if EVEN_HEADDIM:
                    dq = tl.load(dq_ptrs, mask=offs_m_curr[:, None] < seqlen_q, other=0.0,
                                 eviction_policy="evict_last")
                    dq += tl.dot(ds, k)
                    tl.store(dq_ptrs, dq, mask=offs_m_curr[:, None] < seqlen_q,
                             eviction_policy="evict_last")
                else:
                    dq = tl.load(dq_ptrs,
                                 mask=(offs_m_curr[:, None] < seqlen_q) & (offs_d[None, :] < headdim),
                                 other=0.0, eviction_policy="evict_last")
                    dq += tl.dot(ds, k)
                    tl.store(dq_ptrs, dq,
                             mask=(offs_m_curr[:, None] < seqlen_q) & (offs_d[None, :] < headdim),
                             eviction_policy="evict_last")
        else:  # If we're parallelizing across the seqlen_k dimension
            dq = tl.dot(ds, k)
            if EVEN_M:
                if EVEN_HEADDIM:
                    tl.atomic_add(dq_ptrs, dq)
                else:
                    tl.atomic_add(dq_ptrs, dq, mask=offs_d[None, :] < headdim)
            else:
                if EVEN_HEADDIM:
                    tl.atomic_add(dq_ptrs, dq, mask=offs_m_curr[:, None] < seqlen_q)
                else:
                    tl.atomic_add(dq_ptrs, dq,
                                  mask=(offs_m_curr[:, None] < seqlen_q) & (offs_d[None, :] < headdim))
        # increment pointers
        dq_ptrs += BLOCK_M * stride_dqm
        q_ptrs += BLOCK_M * stride_qm
        do_ptrs += BLOCK_M * stride_dom
    # write-back
    dv_ptrs = DV + (offs_n[:, None] * stride_dvn + offs_d[None, :])
    dk_ptrs = DK + (offs_n[:, None] * stride_dkn + offs_d[None, :])
    # [2022-11-01] TD: Same bug. In the case of EVEN_N=True and EVEN_M=False,
    # if we just call tl.store(dv_ptrs), there's a race condition
    if EVEN_N & EVEN_M:
        if EVEN_HEADDIM:
            tl.store(dv_ptrs, dv)
            tl.store(dk_ptrs, dk)
        else:
            tl.store(dv_ptrs, dv, mask=offs_d[None, :] < headdim)
            tl.store(dk_ptrs, dk, mask=offs_d[None, :] < headdim)
    else:
        if EVEN_HEADDIM:
            tl.store(dv_ptrs, dv, mask=offs_n[:, None] < seqlen_k)
            tl.store(dk_ptrs, dk, mask=offs_n[:, None] < seqlen_k)
        else:
            tl.store(dv_ptrs, dv, mask=(offs_n[:, None] < seqlen_k) & (offs_d[None, :] < headdim))
            tl.store(dk_ptrs, dk, mask=(offs_n[:, None] < seqlen_k) & (offs_d[None, :] < headdim))


def init_to_zero(name):
    """Return a hook that zeroes, in place, the kernel argument called ``name``.

    The returned callable takes the mapping of kernel argument names to values, calls
    ``zero_()`` on the entry stored under ``name`` and returns the result of that call.
    It is used in this file as the ``pre_hook`` of the backward autotune configs.
    """
    def _zero_named_arg(nargs):
        target = nargs[name]
        return target.zero_()

    return _zero_named_arg

# Backward attention kernel (entry point).
#
# Launch grid: axis 1 indexes the flattened (batch, head) pair. All tensor pointers are first
# offset to that batch/head. Then:
#   - SEQUENCE_PARALLEL=False: a single program loops over every block of BLOCK_N key/value
#     rows and accumulates into DQ with plain load/add/store.
#   - SEQUENCE_PARALLEL=True: program_id(0) selects one key/value block, and the dq
#     contributions from different programs are combined with atomic adds.
#
# Both paths ADD into DQ, which is why every autotune config zeroes DQ in a pre_hook.
@triton.autotune(
    configs=[
        triton.Config({"BLOCK_M": 128, "BLOCK_N": 128, "SEQUENCE_PARALLEL": False}, num_warps=8, num_stages=1, pre_hook=init_to_zero('DQ')),
        triton.Config({"BLOCK_M": 128, "BLOCK_N": 128, "SEQUENCE_PARALLEL": True}, num_warps=8, num_stages=1, pre_hook=init_to_zero('DQ')),
        # Other configs seem to give wrong results when seqlen_q % 128 != 0, disabling them for now
        # # Kernel is buggy (give wrong result) if we set BLOCK_m=128, BLOCK_n=64, num_warps=*4*
        # triton.Config({"BLOCK_M": 128, "BLOCK_N": 64, "SEQUENCE_PARALLEL": False}, num_warps=8, num_stages=1, pre_hook=init_to_zero('DQ')),
        # triton.Config({"BLOCK_M": 128, "BLOCK_N": 64, "SEQUENCE_PARALLEL": True}, num_warps=8, num_stages=1, pre_hook=init_to_zero('DQ')),
        # triton.Config({"BLOCK_M": 64, "BLOCK_N": 64, "SEQUENCE_PARALLEL": False}, num_warps=4, num_stages=1, pre_hook=init_to_zero('DQ')),
        # triton.Config({"BLOCK_M": 64, "BLOCK_N": 64, "SEQUENCE_PARALLEL": True}, num_warps=4, num_stages=1, pre_hook=init_to_zero('DQ')),
    ],
    key=['CACHE_KEY_SEQLEN_Q', 'CACHE_KEY_SEQLEN_K', 'IS_CAUSAL', 'BLOCK_HEADDIM'],
    # reset_to_zero=['DQ']
)
@triton.heuristics(
    {
        "EVEN_M": lambda args: args["seqlen_q"] % args["BLOCK_M"] == 0,
        "EVEN_N": lambda args: args["seqlen_k"] % args["BLOCK_N"] == 0,
        "EVEN_HEADDIM": lambda args: args["headdim"] == args["BLOCK_HEADDIM"],
    }
)
@triton.jit
def _bwd_kernel(
    Q, K, V,
    DO, DQ, DK, DV,
    LSE, D,
    softmax_scale,
    stride_qb, stride_qh, stride_qm,
    stride_kb, stride_kh, stride_kn,
    stride_vb, stride_vh, stride_vn,
    stride_dob, stride_doh, stride_dom,
    stride_dqb, stride_dqh, stride_dqm,
    stride_dkb, stride_dkh, stride_dkn,
    stride_dvb, stride_dvh, stride_dvn,
    nheads, seqlen_q, seqlen_k, seqlen_q_rounded, headdim,
    CACHE_KEY_SEQLEN_Q, CACHE_KEY_SEQLEN_K,
    IS_CAUSAL: tl.constexpr,
    BLOCK_HEADDIM: tl.constexpr,
    SEQUENCE_PARALLEL: tl.constexpr,
    EVEN_M: tl.constexpr, EVEN_N: tl.constexpr, EVEN_HEADDIM: tl.constexpr,
    BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr,
):
    off_hb = tl.program_id(1)
    off_b = off_hb // nheads
    off_h = off_hb % nheads
    # offset pointers for batch/head
    Q += off_b * stride_qb + off_h * stride_qh
    K += off_b * stride_kb + off_h * stride_kh
    V += off_b * stride_vb + off_h * stride_vh
    DO += off_b * stride_dob + off_h * stride_doh
    DQ += off_b * stride_dqb + off_h * stride_dqh
    DK += off_b * stride_dkb + off_h * stride_dkh
    DV += off_b * stride_dvb + off_h * stride_dvh
    # pointer to row-wise quantities in value-like data
    D += off_hb * seqlen_q_rounded
    LSE += off_hb * seqlen_q_rounded
    if not SEQUENCE_PARALLEL:
        # One program handles every key/value column block of this (batch, head) in turn
        num_block_n = tl.cdiv(seqlen_k, BLOCK_N)
        for start_n in range(0, num_block_n):
            _bwd_kernel_one_col_block(
                start_n,
                Q, K, V, softmax_scale,
                DO, DQ, DK, DV,
                LSE, D,
                stride_qm, stride_kn, stride_vn, stride_dom, stride_dqm, stride_dkn, stride_dvn,
                seqlen_q, seqlen_k, headdim,
                ATOMIC_ADD=False,
                IS_CAUSAL=IS_CAUSAL,
                BLOCK_HEADDIM=BLOCK_HEADDIM,
                EVEN_M=EVEN_M, EVEN_N=EVEN_N, EVEN_HEADDIM=EVEN_HEADDIM,
                BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N
            )
    else:
        # One program per key/value column block; dq is combined across programs atomically
        start_n = tl.program_id(0)
        _bwd_kernel_one_col_block(
            start_n,
            Q, K, V, softmax_scale,
            DO, DQ, DK, DV,
            LSE, D,
            stride_qm, stride_kn, stride_vn, stride_dom, stride_dqm, stride_dkn, stride_dvn,
            seqlen_q, seqlen_k, headdim,
            ATOMIC_ADD=True,
            IS_CAUSAL=IS_CAUSAL,
            BLOCK_HEADDIM=BLOCK_HEADDIM,
            EVEN_M=EVEN_M, EVEN_N=EVEN_N, EVEN_HEADDIM=EVEN_HEADDIM,
            BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N
        )


def _flash_attn_forward(q, k, v, causal=False, softmax_scale=None):
    """Run the Triton forward attention kernel.

    Args:
        q: (batch, seqlen_q, nheads, headdim) CUDA tensor, fp16 or bf16.
        k, v: (batch, seqlen_k, nheads, headdim) CUDA tensors with the same dtype as ``q``.
        causal: whether to apply the causal mask.
        softmax_scale: scaling applied to q @ k^T before the softmax. Any falsy value
            (``None`` or ``0``) selects the default ``1 / sqrt(headdim)``.

    Returns:
        A tuple ``(o, lse, softmax_scale)``: the output (same shape and dtype as ``q``), the
        float32 log-sum-exp buffer of shape (batch, nheads, seqlen_q rounded up to a multiple
        of 128), and the scale that was actually used.

    Raises:
        AssertionError: if the shape, dtype, device or head-dimension constraints are violated.
    """
    # shape constraints
    batch, seqlen_q, nheads, d = q.shape
    _, seqlen_k, _, _ = k.shape
    assert k.shape == (batch, seqlen_k, nheads, d)
    assert v.shape == (batch, seqlen_k, nheads, d)
    assert d <= 128, 'FlashAttention only support head dimensions up to 128'
    assert q.dtype == k.dtype == v.dtype, 'All tensors must have the same type'
    assert q.dtype in [torch.float16, torch.bfloat16], 'Only support fp16 and bf16'
    assert q.is_cuda and k.is_cuda and v.is_cuda
    softmax_scale = softmax_scale or 1.0 / math.sqrt(d)
    # The kernel indexes lse/tmp with seqlen_q_rounded as the per-(batch, head) row stride
    seqlen_q_rounded = math.ceil(seqlen_q / 128) * 128
    lse = torch.empty((batch, nheads, seqlen_q_rounded), device=q.device, dtype=torch.float32)
    # Scratchpad the kernel uses to work around a compiler bug (store then immediately load)
    tmp = torch.empty((batch, nheads, seqlen_q_rounded), device=q.device, dtype=torch.float32)
    o = torch.empty_like(q)

    # Head dimension padded to a power of two, at least 16
    BLOCK_HEADDIM = max(triton.next_power_of_2(d), 16)
    grid = lambda META: (triton.cdiv(seqlen_q, META["BLOCK_M"]), batch * nheads)
    _fwd_kernel[grid](
        q, k, v, o,
        lse, tmp,
        softmax_scale,
        # strides are passed in (batch, head, seq) order while the tensors are (batch, seq, head, dim)
        q.stride(0), q.stride(2), q.stride(1),
        k.stride(0), k.stride(2), k.stride(1),
        v.stride(0), v.stride(2), v.stride(1),
        o.stride(0), o.stride(2), o.stride(1),
        nheads, seqlen_q, seqlen_k, seqlen_q_rounded, d,
        seqlen_q // 32,  seqlen_k // 32, # key for triton cache (limit number of compilations)
        # Can't use kwargs here because triton autotune expects key to be args, not kwargs
        # IS_CAUSAL=causal, BLOCK_HEADDIM=d,
        causal, BLOCK_HEADDIM,
    )
    return o, lse, softmax_scale  # softmax_scale could have been updated


def _flash_attn_backward(do, q, k, v, o, lse, dq, dk, dv, causal=False, softmax_scale=None):
    """Run the Triton backward attention kernels and fill ``dq``, ``dk``, ``dv`` in place.

    Args:
        do: gradient of the output; made contiguous if its last dimension is not.
        q: (batch, seqlen_q, nheads, headdim) tensor from the forward pass.
        k, v: key/value tensors from the forward pass; ``seqlen_k`` is read from ``k.shape[1]``.
        o: forward output.
        lse: log-sum-exp from the forward pass, of shape
            (batch, nheads, seqlen_q rounded up to a multiple of 128).
        dq, dk, dv: output tensors. ``dk`` and ``dv`` are written directly by the kernel;
            ``dq`` receives a copy of a float32 accumulator.
        causal: whether the forward pass used the causal mask.
        softmax_scale: scale used in the forward pass. Any falsy value selects the default
            ``1 / sqrt(headdim)``.

    Returns:
        None. Results are written into ``dq``, ``dk`` and ``dv``.

    Raises:
        AssertionError: if headdim > 128 or ``lse`` does not have the expected shape.
    """
    # Make sure that the last dimension is contiguous
    if do.stride(-1) != 1:
        do = do.contiguous()
    batch, seqlen_q, nheads, d = q.shape
    _, seqlen_k, _, _ = k.shape
    # assert d in {16, 32, 64, 128}
    assert d <= 128
    seqlen_q_rounded = math.ceil(seqlen_q / 128) * 128
    assert lse.shape == (batch, nheads, seqlen_q_rounded)
    softmax_scale = softmax_scale or 1.0 / math.sqrt(d)
    # float32 accumulator for dq. It is allocated uninitialized because every autotune
    # config of _bwd_kernel zeroes the DQ argument in its pre_hook.
    dq_accum = torch.empty_like(q, dtype=torch.float32)
    delta = torch.empty_like(lse)

    BLOCK_HEADDIM = max(triton.next_power_of_2(d), 16)
    # Step 1: delta[row] = sum(o[row] * do[row]) for every query row
    grid = lambda META: (triton.cdiv(seqlen_q, META["BLOCK_M"]), batch * nheads)
    _bwd_preprocess_do_o_dot[grid](
        o, do, delta,
        o.stride(0), o.stride(2), o.stride(1),
        do.stride(0), do.stride(2), do.stride(1),
        nheads, seqlen_q, seqlen_q_rounded, d,
        BLOCK_M=128, BLOCK_HEADDIM=BLOCK_HEADDIM,
    )

    # Step 2: main backward kernel. Grid axis 0 only has more than one program when the
    # selected config parallelizes across the seqlen_k dimension.
    grid = lambda META: (triton.cdiv(seqlen_k, META["BLOCK_N"]) if META["SEQUENCE_PARALLEL"] else 1,
                    batch * nheads)
    _bwd_kernel[grid](
        q, k, v,
        do, dq_accum, dk, dv,
        lse, delta,
        softmax_scale,
        q.stride(0), q.stride(2), q.stride(1),
        k.stride(0), k.stride(2), k.stride(1),
        v.stride(0), v.stride(2), v.stride(1),
        do.stride(0), do.stride(2), do.stride(1),
        dq_accum.stride(0), dq_accum.stride(2), dq_accum.stride(1),
        dk.stride(0), dk.stride(2), dk.stride(1),
        dv.stride(0), dv.stride(2), dv.stride(1),
        nheads, seqlen_q, seqlen_k, seqlen_q_rounded, d,
        seqlen_q // 32,  seqlen_k // 32, # key for triton cache (limit number of compilations)
        # Can't use kwargs here because triton autotune expects key to be args, not kwargs
        # IS_CAUSAL=causal, BLOCK_HEADDIM=d,
        causal, BLOCK_HEADDIM,
    )
    # Copy the float32 accumulator into the caller's dq (copy_ converts to dq's dtype)
    dq.copy_(dq_accum)


class FlashAttnQKVPackedFunc(torch.autograd.Function):
    """Autograd function for attention where q, k and v are packed in a single tensor."""

    @staticmethod
    def forward(ctx, qkv, causal=False, softmax_scale=None):
        """
            qkv: (batch, seqlen, 3, nheads, headdim)
            causal: whether to apply the causal mask
            softmax_scale: scaling of q @ k^T; the default is chosen by _flash_attn_forward
            Returns the attention output o.
        """
        # Make sure that the last dimension is contiguous
        if qkv.stride(-1) != 1:
            qkv = qkv.contiguous()
        # Index 0 / 1 / 2 along dim 2 selects q / k / v
        o, lse, ctx.softmax_scale = _flash_attn_forward(
            qkv[:, :, 0], qkv[:, :, 1], qkv[:, :, 2], causal=causal, softmax_scale=softmax_scale
        )
        ctx.save_for_backward(qkv, o, lse)
        ctx.causal = causal
        return o

    @staticmethod
    def backward(ctx, do):
        """Return the gradient w.r.t. qkv; causal and softmax_scale get no gradient."""
        qkv, o, lse = ctx.saved_tensors
        # Triton's autotune causes the Tensor._version to change, and so Pytorch autograd
        # does a memcpy. To avoid this we run in inference_mode, which doesn't track the version.
        with torch.inference_mode():
            dqkv = torch.empty_like(qkv)
            # dq, dk, dv are written straight into the slices of the packed gradient
            _flash_attn_backward(do, qkv[:, :, 0], qkv[:, :, 1], qkv[:, :, 2], o, lse,
                                 dqkv[:, :, 0], dqkv[:, :, 1], dqkv[:, :, 2],
                                 causal=ctx.causal, softmax_scale=ctx.softmax_scale)
        return dqkv, None, None


flash_attn_qkvpacked_func = FlashAttnQKVPackedFunc.apply


class FlashAttnKVPackedFunc(torch.autograd.Function):
    """Autograd function for attention where k and v are packed in a single tensor."""

    @staticmethod
    def forward(ctx, q, kv, causal=False, softmax_scale=None):
        """
            q: (batch, seqlen, nheads, headdim)
            kv: (batch, seqlen, 2, nheads, headdim)
            causal: whether to apply the causal mask
            softmax_scale: scaling of q @ k^T; the default is chosen by _flash_attn_forward
            Returns the attention output o.
        """
        # Make sure that the last dimension is contiguous
        q, kv = [x if x.stride(-1) == 1 else x.contiguous() for x in [q, kv]]
        # Index 0 / 1 along dim 2 of kv selects k / v
        o, lse, ctx.softmax_scale = _flash_attn_forward(
            q, kv[:, :, 0], kv[:, :, 1], causal=causal, softmax_scale=softmax_scale
        )
        ctx.save_for_backward(q, kv, o, lse)
        ctx.causal = causal
        return o

    @staticmethod
    def backward(ctx, do):
        """Return the gradients w.r.t. q and kv; causal and softmax_scale get no gradient."""
        q, kv, o, lse = ctx.saved_tensors
        # Triton's autotune causes the Tensor._version to change, and so Pytorch autograd
        # does a memcpy. To avoid this we run in inference_mode, which doesn't track the version.
        with torch.inference_mode():
            dq = torch.empty_like(q)
            dkv = torch.empty_like(kv)
            # k and v must be sliced from the saved `kv` tensor (there is no `qkv` in this
            # function); dk and dv are written straight into the slices of the packed gradient.
            _flash_attn_backward(do, q, kv[:, :, 0], kv[:, :, 1], o, lse,
                                 dq, dkv[:, :, 0], dkv[:, :, 1],
                                 causal=ctx.causal, softmax_scale=ctx.softmax_scale)
        return dq, dkv, None, None


flash_attn_kvpacked_func = FlashAttnKVPackedFunc.apply


class FlashAttnFunc(torch.autograd.Function):
    """Autograd function for attention with separate q, k and v tensors."""

    @staticmethod
    def forward(ctx, q, k, v, causal=False, softmax_scale=None):
        """
            q, k, v: (batch_size, seqlen, nheads, headdim)
            causal: whether to apply the causal mask
            softmax_scale: scaling of q @ k^T; the default is chosen by _flash_attn_forward
            Returns the attention output o.
        """
        # Make sure that the last dimension is contiguous
        q, k, v = [x if x.stride(-1) == 1 else x.contiguous() for x in [q, k, v]]
        o, lse, ctx.softmax_scale = _flash_attn_forward(q, k, v, causal=causal,
                                                        softmax_scale=softmax_scale)
        ctx.save_for_backward(q, k, v, o, lse)
        ctx.causal = causal
        return o

    @staticmethod
    def backward(ctx, do):
        """Return the gradients w.r.t. q, k and v; causal and softmax_scale get no gradient."""
        q, k, v, o, lse = ctx.saved_tensors
        # Triton's autotune causes the Tensor._version to change, and so Pytorch autograd
        # does a memcpy. To avoid this we run in inference_mode, which doesn't track the version.
        with torch.inference_mode():
            dq = torch.empty_like(q)
            dk = torch.empty_like(k)
            dv = torch.empty_like(v)
            _flash_attn_backward(do, q, k, v, o, lse, dq, dk, dv,
                                 causal=ctx.causal, softmax_scale=ctx.softmax_scale)
        return dq, dk, dv, None, None


flash_attn_func = FlashAttnFunc.apply