flash_attn_triton.py 36.3 KB
Newer Older
Tri Dao's avatar
Tri Dao committed
1
"""
*Experimental* implementation of FlashAttention in Triton.

We use the FlashAttention implementation from Phil Tillet as a starting point.
https://github.com/openai/triton/blob/master/python/tutorials/06-fused-attention.py

Changes:
- Implement both causal and non-causal attention.
- Implement both self-attention and cross-attention.
- Support arbitrary seqlens (not just multiples of 128), for both forward and backward.
- Support all head dimensions up to 128 (not just 16, 32, 64, 128), for both forward and backward.
- Support attention bias.
- Speed up the forward pass a bit, and only store the LSE instead of m and l.
- Make the backward for d=128 much faster by reducing register spilling.
- Optionally parallelize the backward pass across seqlen_k, to deal with the case of
small batch size * nheads.

Caution:
- This is an *experimental* implementation. The forward pass should be quite robust but
I'm not 100% sure that the backward pass doesn't have race conditions (due to the Triton compiler).
- If you plan to use headdim other than 64 and 128, you should test for race conditions
(due to the Triton compiler), as done in tests/test_flash_attn.py
"test_flash_attn_triton_race_condition". I've tested and fixed many race conditions
for different head dimensions (40, 48, 64, 128, 80, 88, 96), but I'm still not 100% confident
that there are none left for other head dimensions.

Differences between this Triton version and the CUDA version:
- Triton version doesn't support dropout.
- Triton forward is generally faster than CUDA forward.
- Triton backward is faster than CUDA backward when batch * nheads is small, and when headdim=64.
It is slightly slower when headdim=128 and batch * nheads is large.
- Triton version doesn't yet support different sequence lengths in a batch (i.e., RaggedTensor/NestedTensor).
"""

import math

import torch

39
40
from einops import rearrange, repeat

Tri Dao's avatar
Tri Dao committed
41
42
43
44
45
46
47
import triton
import triton.language as tl


# Forward attention kernel.
#
# Launch layout, as read from the body: program_id(0) selects a tile of BLOCK_M query
# rows and program_id(1) is a flattened (batch, head) index that is split with nheads.
# Each program loads its Q tile once, streams over K/V in tiles of BLOCK_N rows, and
# keeps a running log-sum-exp (lse_i), running max (m_i) and an un-normalized output
# accumulator (acc_o). At the end it writes the normalized output tile to Out and the
# per-row log-sum-exp to Lse.
#
# Pointer arithmetic adds the head-dimension offset without a stride, so the last
# dimension of Q, K, V, Bias and Out is addressed as contiguous.
# Lse and TMP are indexed as off_hb * seqlen_q_rounded + offs_m and are written without
# a row mask -- presumably the caller allocates them with seqlen_q_rounded >= the padded
# row count; confirm against the launcher.
#
# BIAS_TYPE is one of 'none', 'vector' (one bias value per key column) or 'matrix'
# (one bias value per (query row, key column) pair).
# EVEN_M / EVEN_N / EVEN_HEADDIM are derived by the heuristics below and select
# unmasked loads/stores when the corresponding dimension needs no bounds check.
@triton.autotune(
    configs=[
        triton.Config({"BLOCK_M": 128, "BLOCK_N": 128}, num_warps=8, num_stages=1),
        # This config has a race condition when EVEN_M == False, disabling it for now.
        # triton.Config({"BLOCK_M": 64, "BLOCK_N": 64}, num_warps=4, num_stages=1),
    ],
    key=['CACHE_KEY_SEQLEN_Q', 'CACHE_KEY_SEQLEN_K', 'BIAS_TYPE', 'IS_CAUSAL', 'BLOCK_HEADDIM']
)
@triton.heuristics(
    {
        "EVEN_M": lambda args: args["seqlen_q"] % args["BLOCK_M"] == 0,
        "EVEN_N": lambda args: args["seqlen_k"] % args["BLOCK_N"] == 0,
        "EVEN_HEADDIM": lambda args: args["headdim"] == args["BLOCK_HEADDIM"],
    }
)
@triton.jit
def _fwd_kernel(
    Q, K, V, Bias, Out,
    Lse, TMP,  # NOTE: TMP is a scratchpad buffer to workaround a compiler bug
    softmax_scale,
    stride_qb, stride_qh, stride_qm,
    stride_kb, stride_kh, stride_kn,
    stride_vb, stride_vh, stride_vn,
    stride_bb, stride_bh, stride_bm,
    stride_ob, stride_oh, stride_om,
    nheads, seqlen_q, seqlen_k, seqlen_q_rounded, headdim,
    CACHE_KEY_SEQLEN_Q, CACHE_KEY_SEQLEN_K,
    BIAS_TYPE: tl.constexpr,
    IS_CAUSAL: tl.constexpr,
    BLOCK_HEADDIM: tl.constexpr,
    EVEN_M: tl.constexpr, EVEN_N: tl.constexpr, EVEN_HEADDIM: tl.constexpr,
    BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr,
):
    start_m = tl.program_id(0)
    off_hb = tl.program_id(1)
    off_b = off_hb // nheads
    off_h = off_hb % nheads
    # off_b = tl.program_id(1)
    # off_h = tl.program_id(2)
    # off_hb = off_b * nheads + off_h
    # initialize offsets
    offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_n = tl.arange(0, BLOCK_N)
    offs_d = tl.arange(0, BLOCK_HEADDIM)
    # Initialize pointers to Q, K, V
    # Adding parentheses around indexing might use int32 math instead of int64 math?
    # https://github.com/openai/triton/issues/741
    # I'm seeing a tiny bit of difference (5-7us)
    q_ptrs = Q + off_b * stride_qb + off_h * stride_qh + (offs_m[:, None] * stride_qm + offs_d[None, :])
    k_ptrs = K + off_b * stride_kb + off_h * stride_kh + (offs_n[:, None] * stride_kn + offs_d[None, :])
    v_ptrs = V + off_b * stride_vb + off_h * stride_vh + (offs_n[:, None] * stride_vn + offs_d[None, :])
    if BIAS_TYPE == 'vector':
        b_ptrs = Bias + off_b * stride_bb + off_h * stride_bh + offs_n
    elif BIAS_TYPE == 'matrix':
        b_ptrs = Bias + off_b * stride_bb + off_h * stride_bh + (offs_m[:, None] * stride_bm + offs_n[None, :])
    # initialize the pointer into the TMP scratchpad, the running statistics
    # (log-sum-exp and max, both starting at -inf) and the output accumulator
    t_ptrs = TMP + off_hb * seqlen_q_rounded + offs_m
    lse_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
    m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
    acc_o = tl.zeros([BLOCK_M, BLOCK_HEADDIM], dtype=tl.float32)
    # load q: it will stay in SRAM throughout
    # [2022-10-30] TD: Triton bug - in the case of EVEN_M=True and EVEN_N=False, if we just call
    # tl.load(q_ptrs), we get the wrong output!
    if EVEN_M & EVEN_N:
        if EVEN_HEADDIM:
            q = tl.load(q_ptrs)
        else:
            q = tl.load(q_ptrs, mask=offs_d[None, :] < headdim, other=0.0)
    else:
        if EVEN_HEADDIM:
            q = tl.load(q_ptrs, mask=offs_m[:, None] < seqlen_q, other=0.0)
        else:
            q = tl.load(q_ptrs, mask=(offs_m[:, None] < seqlen_q) & (offs_d[None, :] < headdim),
                        other=0.0)
    # loop over k, v and update accumulator
    # With causal masking, key columns past the last query row of this tile are never needed.
    end_n = seqlen_k if not IS_CAUSAL else tl.minimum((start_m + 1) * BLOCK_M, seqlen_k)
    for start_n in range(0, end_n, BLOCK_N):
        start_n = tl.multiple_of(start_n, BLOCK_N)
        # -- compute qk ----
        if EVEN_N & EVEN_M:  # If we just do "if EVEN_N", there seems to be some race condition
            if EVEN_HEADDIM:
                k = tl.load(k_ptrs + start_n * stride_kn)
            else:
                k = tl.load(k_ptrs + start_n * stride_kn, mask=offs_d[None, :] < headdim, other=0.0)
        else:
            if EVEN_HEADDIM:
                k = tl.load(k_ptrs + start_n * stride_kn, mask=(start_n + offs_n)[:, None] < seqlen_k,
                            other=0.0)
            else:
                k = tl.load(k_ptrs + start_n * stride_kn,
                            mask=((start_n + offs_n)[:, None] < seqlen_k) & (offs_d[None, :] < headdim),
                            other=0.0)
        qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
        qk += tl.dot(q, k, trans_b=True)
        # Trying to combine the two masks seems to make the result wrong
        if not EVEN_N:  # Need to mask out otherwise the softmax is wrong
            qk += tl.where((start_n + offs_n)[None, :] < seqlen_k, 0, float("-inf"))
        if IS_CAUSAL:
            qk += tl.where(offs_m[:, None] >= (start_n + offs_n)[None, :], 0, float("-inf"))
        if BIAS_TYPE != 'none':
            if BIAS_TYPE == 'vector':
                if EVEN_N:
                    bias = tl.load(b_ptrs + start_n).to(tl.float32)
                else:
                    bias = tl.load(b_ptrs + start_n, mask=(start_n + offs_n) < seqlen_k, other=0.0).to(tl.float32)
                bias = bias[None, :]
            elif BIAS_TYPE == 'matrix':
                if EVEN_M & EVEN_N:
                    bias = tl.load(b_ptrs + start_n).to(tl.float32)
                else:
                    bias = tl.load(b_ptrs + start_n,
                                   mask=(offs_m[:, None] < seqlen_q)
                                        & ((start_n + offs_n)[None, :] < seqlen_k),
                                   other=0.0).to(tl.float32)
            # Slightly faster to multiply the softmax_scale in the tl.exp below since the compiler
            # can then fuse the mult and add into an fma instruction. But if we have bias we need
            # to multiply with softmax_scale here.
            qk = qk * softmax_scale + bias
            m_ij = tl.maximum(tl.max(qk, 1), lse_i)
            p = tl.exp(qk - m_ij[:, None])
        else:
            m_ij = tl.maximum(tl.max(qk, 1) * softmax_scale, lse_i)
            p = tl.exp(qk * softmax_scale - m_ij[:, None])
        l_ij = tl.sum(p, 1)

        # scale acc_o
        acc_o_scale = tl.exp(m_i - m_ij)

        # # -- update output accumulator --
        # BUG: have to store and immediately load
        tl.store(t_ptrs, acc_o_scale)
        acc_o_scale = tl.load(t_ptrs)
        acc_o = acc_o * acc_o_scale[:, None]
        # update acc_o
        if EVEN_N & EVEN_M:  # If we just do "if EVEN_N", there seems to be some race condition
            if EVEN_HEADDIM:
                v = tl.load(v_ptrs + start_n * stride_vn)
            else:
                v = tl.load(v_ptrs + start_n * stride_vn, mask=offs_d[None, :] < headdim, other=0.0)
        else:
            if EVEN_HEADDIM:
                v = tl.load(v_ptrs + start_n * stride_vn, mask=(start_n + offs_n)[:, None] < seqlen_k,
                            other=0.0)
            else:
                v = tl.load(v_ptrs + start_n * stride_vn,
                            mask=((start_n + offs_n)[:, None] < seqlen_k) & (offs_d[None, :] < headdim),
                            other=0.0)
        p = p.to(v.dtype)
        acc_o += tl.dot(p, v)

        # -- update statistics
        m_i = m_ij
        l_i_new = tl.exp(lse_i - m_ij) + l_ij
        lse_i = m_ij + tl.log(l_i_new)

    # acc_o is still scaled relative to the running max m_i; rescale it by
    # exp(m_i - lse_i) to obtain the softmax-normalized output.
    o_scale = tl.exp(m_i - lse_i)
    # BUG: have to store and immediately load
    tl.store(t_ptrs, o_scale)
    o_scale = tl.load(t_ptrs)
    acc_o = acc_o * o_scale[:, None]
    # rematerialize offsets to save registers
    start_m = tl.program_id(0)
    offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
    # write back the per-row log-sum-exp
    lse_ptrs = Lse + off_hb * seqlen_q_rounded + offs_m
    tl.store(lse_ptrs, lse_i)
    # initialize pointers to output
    offs_n = tl.arange(0, BLOCK_HEADDIM)
    out_ptrs = Out + off_b * stride_ob + off_h * stride_oh + (offs_m[:, None] * stride_om + offs_n[None, :])
    if EVEN_M:
        if EVEN_HEADDIM:
            tl.store(out_ptrs, acc_o)
        else:
            tl.store(out_ptrs, acc_o, mask=offs_d[None, :] < headdim)
    else:
        if EVEN_HEADDIM:
            tl.store(out_ptrs, acc_o, mask=offs_m[:, None] < seqlen_q)
        else:
            tl.store(out_ptrs, acc_o,
                     mask=(offs_m[:, None] < seqlen_q) & (offs_d[None, :] < headdim))
Tri Dao's avatar
Tri Dao committed
225
226
227
228
229
230
231


# Backward-pass preprocessing kernel: for every query row, compute
# delta = sum over the head dimension of (Out * DO) and store it in Delta.
#
# program_id(0) selects a tile of BLOCK_M query rows; program_id(1) is a flattened
# (batch, head) index that is split with nheads. Out and DO are addressed with a
# contiguous last dimension (the head-dimension offset is added without a stride).
# Delta is indexed as off_hb * seqlen_q_rounded + offs_m and stored without a mask --
# presumably the caller pads its last dimension to seqlen_q_rounded; confirm against
# the launcher.
@triton.jit
def _bwd_preprocess_do_o_dot(
    Out, DO, Delta,
    stride_ob, stride_oh, stride_om,
    stride_dob, stride_doh, stride_dom,
    nheads, seqlen_q, seqlen_q_rounded, headdim,
    BLOCK_M: tl.constexpr, BLOCK_HEADDIM: tl.constexpr,
):
    start_m = tl.program_id(0)
    off_hb = tl.program_id(1)
    off_b = off_hb // nheads
    off_h = off_hb % nheads
    # initialize offsets
    offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_d = tl.arange(0, BLOCK_HEADDIM)
    # load; out-of-range rows/columns read as 0 so they contribute nothing to the sum
    o = tl.load(Out + off_b * stride_ob + off_h * stride_oh + offs_m[:, None] * stride_om + offs_d[None, :],
                mask=(offs_m[:, None] < seqlen_q) & (offs_d[None, :] < headdim), other=0.0).to(tl.float32)
    do = tl.load(DO + off_b * stride_dob + off_h * stride_doh + offs_m[:, None] * stride_dom + offs_d[None, :],
                 mask=(offs_m[:, None] < seqlen_q) & (offs_d[None, :] < headdim), other=0.0).to(tl.float32)
    delta = tl.sum(o * do, axis=1)
    # write-back
    tl.store(Delta + off_hb * seqlen_q_rounded + offs_m, delta)


# Backward pass for a single block of BLOCK_N key/value rows (column block start_n).
#
# K and V for the column block are loaded once. The kernel then loops over tiles of
# BLOCK_M query rows, recomputes the attention probabilities p from Q, K and the stored
# LSE, and accumulates:
#   dv += p^T @ do
#   dk += ds^T @ q      with ds = p * (dp - D) * softmax_scale and dp = do @ v^T
#   dq += ds @ k        written to DQ per row tile
# dk and dv are stored to DK / DV once after the loop.
#
# All tensor arguments are expected to be already offset to the current (batch, head)
# by the caller; only row strides are passed in, and the head-dimension offset is added
# without a stride.
# ATOMIC_ADD selects how DQ is updated: a load/add/store when False, tl.atomic_add when
# True (used when several column blocks run in parallel and update the same DQ rows).
# The tl.debug_barrier() calls are workarounds for the race conditions described in the
# comments next to them; do not reorder or remove them without re-testing.
@triton.jit
def _bwd_kernel_one_col_block(
    start_n,
    Q, K, V, Bias,
    DO, DQ, DK, DV,
    LSE, D,
    softmax_scale,
    stride_qm, stride_kn, stride_vn, stride_bm,
    stride_dom, stride_dqm, stride_dkn, stride_dvn,
    seqlen_q, seqlen_k, headdim,
    ATOMIC_ADD: tl.constexpr,
    BIAS_TYPE: tl.constexpr,
    IS_CAUSAL: tl.constexpr,
    BLOCK_HEADDIM: tl.constexpr,
    EVEN_M: tl.constexpr, EVEN_N: tl.constexpr, EVEN_HEADDIM: tl.constexpr,
    BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr,
):
    # We need to make sure begin_m is a multiple of BLOCK_M (not BLOCK_N)
    # With causal masking, query rows before this column block are skipped.
    begin_m = 0 if not IS_CAUSAL else ((start_n * BLOCK_N) // BLOCK_M) * BLOCK_M
    # initialize row/col offsets
    offs_qm = begin_m + tl.arange(0, BLOCK_M)
    offs_n = start_n * BLOCK_N + tl.arange(0, BLOCK_N)
    offs_m = tl.arange(0, BLOCK_M)
    offs_d = tl.arange(0, BLOCK_HEADDIM)
    # initialize pointers to value-like data
    q_ptrs = Q + (offs_qm[:, None] * stride_qm + offs_d[None, :])
    k_ptrs = K + (offs_n[:, None] * stride_kn + offs_d[None, :])
    v_ptrs = V + (offs_n[:, None] * stride_vn + offs_d[None, :])
    do_ptrs = DO + (offs_qm[:, None] * stride_dom + offs_d[None, :])
    dq_ptrs = DQ + (offs_qm[:, None] * stride_dqm + offs_d[None, :])
    if BIAS_TYPE == 'vector':
        b_ptrs = Bias + offs_n
    elif BIAS_TYPE == 'matrix':
        b_ptrs = Bias + (offs_qm[:, None] * stride_bm + offs_n[None, :])
    # initialize dv and dk
    dv = tl.zeros([BLOCK_N, BLOCK_HEADDIM], dtype=tl.float32)
    dk = tl.zeros([BLOCK_N, BLOCK_HEADDIM], dtype=tl.float32)
    # k and v stay in SRAM throughout
    # [2022-10-30] TD: Same bug as the fwd. In the case of EVEN_N=True and EVEN_M=False,
    # if we just call tl.load(k_ptrs), we get the wrong output!
    if EVEN_N & EVEN_M:
        if EVEN_HEADDIM:
            k = tl.load(k_ptrs)
            v = tl.load(v_ptrs)
        else:
            k = tl.load(k_ptrs, mask=offs_d[None, :] < headdim, other=0.0)
            v = tl.load(v_ptrs, mask=offs_d[None, :] < headdim, other=0.0)
    else:
        if EVEN_HEADDIM:
            k = tl.load(k_ptrs, mask=offs_n[:, None] < seqlen_k, other=0.0)
            v = tl.load(v_ptrs, mask=offs_n[:, None] < seqlen_k, other=0.0)
        else:
            k = tl.load(k_ptrs, mask=(offs_n[:, None] < seqlen_k) & (offs_d[None, :] < headdim),
                        other=0.0)
            v = tl.load(v_ptrs, mask=(offs_n[:, None] < seqlen_k) & (offs_d[None, :] < headdim),
                        other=0.0)
    # loop over rows
    num_block_m = tl.cdiv(seqlen_q, BLOCK_M)
    for start_m in range(begin_m, num_block_m * BLOCK_M, BLOCK_M):
        start_m = tl.multiple_of(start_m, BLOCK_M)
        offs_m_curr = start_m + offs_m
        # load q, k, v, do on-chip
        # Same bug as below. Otherwise gives wrong result for headdim=40, seqlen=(128, 117)
        if EVEN_M & EVEN_HEADDIM:
            q = tl.load(q_ptrs)
        else:
            if EVEN_HEADDIM:
                q = tl.load(q_ptrs, mask=offs_m_curr[:, None] < seqlen_q, other=0.0)
            else:
                q = tl.load(q_ptrs, mask=(offs_m_curr[:, None] < seqlen_q)
                                         & (offs_d[None, :] < headdim), other=0.0)
        # recompute p = softmax(qk, dim=-1).T
        qk = tl.dot(q, k, trans_b=True)
        # Trying to combine the two masks seems to make the result wrong
        if not EVEN_N:  # Need to mask out otherwise the softmax is wrong
            qk = tl.where(offs_n[None, :] < seqlen_k, qk, float("-inf"))
        if IS_CAUSAL:
            qk = tl.where(offs_m_curr[:, None] >= (offs_n[None, :]), qk, float("-inf"))
        if BIAS_TYPE != 'none':
            tl.debug_barrier()  # Race condition otherwise
            if BIAS_TYPE == 'vector':
                if EVEN_N:
                    bias = tl.load(b_ptrs).to(tl.float32)
                else:
                    bias = tl.load(b_ptrs, mask=offs_n < seqlen_k, other=0.0).to(tl.float32)
                bias = bias[None, :]
            elif BIAS_TYPE == 'matrix':
                if EVEN_M & EVEN_N:
                    bias = tl.load(b_ptrs).to(tl.float32)
                else:
                    bias = tl.load(b_ptrs,
                                   mask=(offs_m_curr[:, None] < seqlen_q)
                                        & (offs_n[None, :] < seqlen_k),
                                   other=0.0).to(tl.float32)
            qk = qk * softmax_scale + bias
        # There seems to be a race condition when headdim=48/96, and dq, dk, dv are wrong.
        # Also wrong for headdim=64.
        if not (EVEN_M & EVEN_HEADDIM):
            tl.debug_barrier()
        lse_i = tl.load(LSE + offs_m_curr)
        # With bias, softmax_scale has already been folded into qk above.
        if BIAS_TYPE == 'none':
            p = tl.exp(qk * softmax_scale - lse_i[:, None])
        else:
            p = tl.exp(qk - lse_i[:, None])
        # compute dv
        # [2022-10-30] TD: A Triton bug: if EVEN_M=True and EVEN_HEADDIM=False, if we call
        # do = tl.load(do_ptrs, mask=offs_d[None, :] < headdim, other=0.0), we get wrong outputs
        # in the case of headdim=48/96, seqlen_q & seqlen_k >= 512. If headdim=40 or seqlen < 512,
        # the output is correct.
        if EVEN_M & EVEN_HEADDIM:
            do = tl.load(do_ptrs)
        else:
            # [2022-11-01] TD: Triton bug, there's a race condition if we just use m_mask and not d_mask.
            do = tl.load(do_ptrs, mask=(offs_m_curr[:, None] < seqlen_q)
                                        & (offs_d[None, :] < headdim), other=0.0)
        # if EVEN_M:
        #     if EVEN_HEADDIM:
        #         do = tl.load(do_ptrs)
        #     else:
        #         do = tl.load(do_ptrs, mask=offs_d[None, :] < headdim, other=0.0)
        # else:
        #     if EVEN_HEADDIM:
        #         do = tl.load(do_ptrs, mask=offs_m_curr[:, None] < seqlen_q, other=0.0)
        #     else:
        #         do = tl.load(do_ptrs, mask=(offs_m_curr[:, None] < seqlen_q)
        #                                    & (offs_d[None, :] < headdim), other=0.0)
        dv += tl.dot(p.to(do.dtype), do, trans_a=True)
        # compute dp = dot(v, do)
        # There seems to be a race condition when headdim=48/96, and dq, dk are wrong.
        # Also wrong for headdim=128, seqlen=(108, 256), and ATOMIC_ADD=True
        # Also wrong for headdim=64, seqlen=(1023, 1024), and ATOMIC_ADD=False
        if not (EVEN_M & EVEN_HEADDIM):
            tl.debug_barrier()
        dp = tl.dot(do, v, trans_b=True)
        # There's a race condition for headdim=48
        if not EVEN_HEADDIM:
            tl.debug_barrier()
        # compute ds = p * (dp - delta[:, None])
        # Putting the subtraction after the dp matmul (instead of before) is slightly faster
        Di = tl.load(D + offs_m_curr)
        # Converting ds to q.dtype here reduces register pressure and makes it much faster
        # for BLOCK_HEADDIM=128
        ds = (p * (dp - Di[:, None]) * softmax_scale).to(q.dtype)
        # compute dk = dot(ds.T, q)
        dk += tl.dot(ds, q, trans_a=True)
        # compute dq
        if not (EVEN_M & EVEN_HEADDIM):  # Otherwise there's a race condition when BIAS_TYPE='matrix'
            tl.debug_barrier()
        if not ATOMIC_ADD:
            if EVEN_M & EVEN_HEADDIM:  # Race condition if we just do EVEN_M
                dq = tl.load(dq_ptrs, eviction_policy="evict_last")
                dq += tl.dot(ds, k)
                tl.store(dq_ptrs, dq, eviction_policy="evict_last")
            else:
                if EVEN_HEADDIM:
                    dq = tl.load(dq_ptrs, mask=offs_m_curr[:, None] < seqlen_q, other=0.0,
                                eviction_policy="evict_last")
                    dq += tl.dot(ds, k)
                    tl.store(dq_ptrs, dq, mask=offs_m_curr[:, None] < seqlen_q,
                            eviction_policy="evict_last")
                else:
                    dq = tl.load(dq_ptrs,
                                 mask=(offs_m_curr[:, None] < seqlen_q) & (offs_d[None, :] < headdim),
                                 other=0.0, eviction_policy="evict_last")
                    dq += tl.dot(ds, k)
                    tl.store(dq_ptrs, dq,
                             mask=(offs_m_curr[:, None] < seqlen_q) & (offs_d[None, :] < headdim),
                             eviction_policy="evict_last")
        else:  # If we're parallelizing across the seqlen_k dimension
            dq = tl.dot(ds, k)
            if EVEN_M & EVEN_HEADDIM:  # Race condition if we just do EVEN_M
                tl.atomic_add(dq_ptrs, dq)
            else:
                if EVEN_HEADDIM:
                    tl.atomic_add(dq_ptrs, dq, mask=offs_m_curr[:, None] < seqlen_q)
                else:
                    tl.atomic_add(dq_ptrs, dq,
                                  mask=(offs_m_curr[:, None] < seqlen_q) & (offs_d[None, :] < headdim))
        # increment pointers
        dq_ptrs += BLOCK_M * stride_dqm
        q_ptrs += BLOCK_M * stride_qm
        do_ptrs += BLOCK_M * stride_dom
        if BIAS_TYPE == 'matrix':
            b_ptrs += BLOCK_M * stride_bm
    # write-back
    dv_ptrs = DV + (offs_n[:, None] * stride_dvn + offs_d[None, :])
    dk_ptrs = DK + (offs_n[:, None] * stride_dkn + offs_d[None, :])
    # [2022-11-01] TD: Same bug. In the case of EVEN_N=True and EVEN_M=False,
    # if we just call tl.store(dv_ptrs), there's a race condition
    if EVEN_N & EVEN_M:
        if EVEN_HEADDIM:
            tl.store(dv_ptrs, dv)
            tl.store(dk_ptrs, dk)
        else:
            tl.store(dv_ptrs, dv, mask=offs_d[None, :] < headdim)
            tl.store(dk_ptrs, dk, mask=offs_d[None, :] < headdim)
    else:
        if EVEN_HEADDIM:
            tl.store(dv_ptrs, dv, mask=offs_n[:, None] < seqlen_k)
            tl.store(dk_ptrs, dk, mask=offs_n[:, None] < seqlen_k)
        else:
            tl.store(dv_ptrs, dv, mask=(offs_n[:, None] < seqlen_k) & (offs_d[None, :] < headdim))
            tl.store(dk_ptrs, dk, mask=(offs_n[:, None] < seqlen_k) & (offs_d[None, :] < headdim))
Tri Dao's avatar
Tri Dao committed
455
456
457
458
459


def init_to_zero(name):
    """Build a hook that zeroes, in place, the kernel argument called ``name``.

    The returned callable takes a mapping of argument names to values, calls
    ``zero_()`` on the entry stored under ``name`` and returns that call's result.
    """
    def _zero_named_arg(nargs):
        target = nargs[name]
        return target.zero_()

    return _zero_named_arg

460

Tri Dao's avatar
Tri Dao committed
461
462
463
464
# Backward attention kernel (entry point).
#
# program_id(1) is a flattened (batch, head) index that is split with nheads; every
# tensor pointer is advanced to that (batch, head) slice before the per-column-block
# work is delegated to _bwd_kernel_one_col_block.
#
# SEQUENCE_PARALLEL (chosen by the autotuner) selects the work split:
#   - False: one program walks over all column blocks of seqlen_k sequentially and
#     updates DQ with plain load/add/store (ATOMIC_ADD=False).
#   - True: program_id(0) selects a single column block and DQ is updated with atomic
#     adds (ATOMIC_ADD=True), since several programs contribute to the same DQ rows.
# Both autotune configs register init_to_zero('DQ') as a pre-hook because DQ is
# accumulated into rather than overwritten.
@triton.autotune(
    configs=[
        triton.Config({"BLOCK_M": 128, "BLOCK_N": 128, "SEQUENCE_PARALLEL": False}, num_warps=8, num_stages=1, pre_hook=init_to_zero('DQ')),
        triton.Config({"BLOCK_M": 128, "BLOCK_N": 128, "SEQUENCE_PARALLEL": True}, num_warps=8, num_stages=1, pre_hook=init_to_zero('DQ')),
        # Other configs seem to give wrong results when seqlen_q % 128 != 0, disabling them for now
        # # Kernel is buggy (give wrong result) if we set BLOCK_m=128, BLOCK_n=64, num_warps=*4*
        # triton.Config({"BLOCK_M": 128, "BLOCK_N": 64, "SEQUENCE_PARALLEL": False}, num_warps=8, num_stages=1, pre_hook=init_to_zero('DQ')),
        # triton.Config({"BLOCK_M": 128, "BLOCK_N": 64, "SEQUENCE_PARALLEL": True}, num_warps=8, num_stages=1, pre_hook=init_to_zero('DQ')),
        # triton.Config({"BLOCK_M": 64, "BLOCK_N": 64, "SEQUENCE_PARALLEL": False}, num_warps=4, num_stages=1, pre_hook=init_to_zero('DQ')),
        # triton.Config({"BLOCK_M": 64, "BLOCK_N": 64, "SEQUENCE_PARALLEL": True}, num_warps=4, num_stages=1, pre_hook=init_to_zero('DQ')),
    ],
    key=['CACHE_KEY_SEQLEN_Q', 'CACHE_KEY_SEQLEN_K', 'BIAS_TYPE', 'IS_CAUSAL', 'BLOCK_HEADDIM'],
)
@triton.heuristics(
    {
        "EVEN_M": lambda args: args["seqlen_q"] % args["BLOCK_M"] == 0,
        "EVEN_N": lambda args: args["seqlen_k"] % args["BLOCK_N"] == 0,
        "EVEN_HEADDIM": lambda args: args["headdim"] == args["BLOCK_HEADDIM"],
    }
)
@triton.jit
def _bwd_kernel(
    Q, K, V, Bias,
    DO, DQ, DK, DV,
    LSE, D,
    softmax_scale,
    stride_qb, stride_qh, stride_qm,
    stride_kb, stride_kh, stride_kn,
    stride_vb, stride_vh, stride_vn,
    stride_bb, stride_bh, stride_bm,
    stride_dob, stride_doh, stride_dom,
    stride_dqb, stride_dqh, stride_dqm,
    stride_dkb, stride_dkh, stride_dkn,
    stride_dvb, stride_dvh, stride_dvn,
    nheads, seqlen_q, seqlen_k, seqlen_q_rounded, headdim,
    CACHE_KEY_SEQLEN_Q, CACHE_KEY_SEQLEN_K,
    BIAS_TYPE: tl.constexpr,
    IS_CAUSAL: tl.constexpr,
    BLOCK_HEADDIM: tl.constexpr,
    SEQUENCE_PARALLEL: tl.constexpr,
    EVEN_M: tl.constexpr, EVEN_N: tl.constexpr, EVEN_HEADDIM: tl.constexpr,
    BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr,
):
    off_hb = tl.program_id(1)
    off_b = off_hb // nheads
    off_h = off_hb % nheads
    # offset pointers for batch/head
    Q += off_b * stride_qb + off_h * stride_qh
    K += off_b * stride_kb + off_h * stride_kh
    V += off_b * stride_vb + off_h * stride_vh
    DO += off_b * stride_dob + off_h * stride_doh
    DQ += off_b * stride_dqb + off_h * stride_dqh
    DK += off_b * stride_dkb + off_h * stride_dkh
    DV += off_b * stride_dvb + off_h * stride_dvh
    if BIAS_TYPE != 'none':
        Bias += off_b * stride_bb + off_h * stride_bh
    # pointer to row-wise quantities in value-like data
    D += off_hb * seqlen_q_rounded
    LSE += off_hb * seqlen_q_rounded
    if not SEQUENCE_PARALLEL:
        # One program handles every column block in turn.
        num_block_n = tl.cdiv(seqlen_k, BLOCK_N)
        for start_n in range(0, num_block_n):
            _bwd_kernel_one_col_block(
                start_n,
                Q, K, V, Bias,
                DO, DQ, DK, DV,
                LSE, D,
                softmax_scale,
                stride_qm, stride_kn, stride_vn, stride_bm,
                stride_dom, stride_dqm, stride_dkn, stride_dvn,
                seqlen_q, seqlen_k, headdim,
                ATOMIC_ADD=False,
                BIAS_TYPE=BIAS_TYPE,
                IS_CAUSAL=IS_CAUSAL,
                BLOCK_HEADDIM=BLOCK_HEADDIM,
                EVEN_M=EVEN_M, EVEN_N=EVEN_N, EVEN_HEADDIM=EVEN_HEADDIM,
                BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N
            )
    else:
        # One program per column block; DQ contributions are combined atomically.
        start_n = tl.program_id(0)
        _bwd_kernel_one_col_block(
            start_n,
            Q, K, V, Bias,
            DO, DQ, DK, DV,
            LSE, D,
            softmax_scale,
            stride_qm, stride_kn, stride_vn, stride_bm,
            stride_dom, stride_dqm, stride_dkn, stride_dvn,
            seqlen_q, seqlen_k, headdim,
            ATOMIC_ADD=True,
            BIAS_TYPE=BIAS_TYPE,
            IS_CAUSAL=IS_CAUSAL,
            BLOCK_HEADDIM=BLOCK_HEADDIM,
            EVEN_M=EVEN_M, EVEN_N=EVEN_N, EVEN_HEADDIM=EVEN_HEADDIM,
            BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N
        )


559
def _flash_attn_forward(q, k, v, bias=None, causal=False, softmax_scale=None):
    """Validate the inputs and launch the Triton forward kernel.

    Args:
        q: (batch, seqlen_q, nheads, headdim) fp16/bf16 CUDA tensor.
        k, v: (batch, seqlen_k, nheads, headdim) CUDA tensors with the dtype of q.
        bias: optional 4-D CUDA tensor (dtype of q or fp32). Its last two
            dimensions must be (1, seqlen_k) or (seqlen_q, seqlen_k); its first
            two must each be 1 or match (batch, nheads).
        causal: passed through to the kernel.
        softmax_scale: scaling passed to the kernel; 1 / sqrt(headdim) is used
            when it is None (or otherwise falsy).

    Returns:
        Tuple (o, lse, softmax_scale): the output (same shape as q), the fp32
        buffer of shape (batch, nheads, seqlen_q rounded up to a multiple of 128)
        that the kernel fills, and the scale that was actually used.

    Raises:
        AssertionError: if a shape / dtype / device constraint is violated.
        RuntimeError: if the last two dimensions of bias are not supported.
    """
    # shape constraints
    batch, seqlen_q, nheads, d = q.shape
    _, seqlen_k, _, _ = k.shape
    assert k.shape == (batch, seqlen_k, nheads, d)
    assert v.shape == (batch, seqlen_k, nheads, d)
    assert d <= 128, 'FlashAttention only support head dimensions up to 128'
    assert q.dtype == k.dtype == v.dtype, 'All tensors must have the same type'
    assert q.dtype in [torch.float16, torch.bfloat16], 'Only support fp16 and bf16'
    assert q.is_cuda and k.is_cuda and v.is_cuda
    softmax_scale = softmax_scale or 1.0 / math.sqrt(d)

    has_bias = bias is not None
    bias_type = 'none'
    if has_bias:
        assert bias.dtype in [q.dtype, torch.float]
        assert bias.is_cuda
        assert bias.dim() == 4
        if bias.stride(-1) != 1:
            bias = bias.contiguous()
        if bias.shape[2:] == (1, seqlen_k):
            bias_type = 'vector'
        elif bias.shape[2:] == (seqlen_q, seqlen_k):
            bias_type = 'matrix'
        else:
            raise RuntimeError('Last 2 dimensions of bias must be (1, seqlen_k)'
                               ' or (seqlen_q, seqlen_k)')
        assert bias.shape[0] in (1, batch) and bias.shape[1] in (1, nheads), \
            'First 2 dimensions of bias must be broadcastible to (batch, nheads)'
        # Broadcast over batch / heads with a stride-0 view rather than a copy.
        # This also covers a bias whose first two dimensions are both 1.
        bias = bias.expand(batch, nheads, *bias.shape[2:])
    bias_strides = (bias.stride(0), bias.stride(1), bias.stride(2)) if has_bias else (0, 0, 0)

    seqlen_q_rounded = math.ceil(seqlen_q / 128) * 128
    lse = torch.empty((batch, nheads, seqlen_q_rounded), device=q.device, dtype=torch.float32)
    # Second fp32 buffer of the same shape as lse, handed to the kernel.
    tmp = torch.empty((batch, nheads, seqlen_q_rounded), device=q.device, dtype=torch.float32)
    o = torch.empty_like(q)

    BLOCK_HEADDIM = max(triton.next_power_of_2(d), 16)
    # One program per (block of query rows, batch * head).
    grid = lambda META: (triton.cdiv(seqlen_q, META["BLOCK_M"]), batch * nheads)
    _fwd_kernel[grid](
        q, k, v, bias, o,
        lse, tmp,
        softmax_scale,
        # Strides are passed in (batch, head, seqlen) order.
        q.stride(0), q.stride(2), q.stride(1),
        k.stride(0), k.stride(2), k.stride(1),
        v.stride(0), v.stride(2), v.stride(1),
        *bias_strides,
        o.stride(0), o.stride(2), o.stride(1),
        nheads, seqlen_q, seqlen_k, seqlen_q_rounded, d,
        seqlen_q // 32,  seqlen_k // 32, # key for triton cache (limit number of compilations)
        # Can't use kwargs here because triton autotune expects key to be args, not kwargs
        bias_type, causal, BLOCK_HEADDIM,
    )
    return o, lse, softmax_scale  # softmax_scale could have been updated


623
def _flash_attn_backward(do, q, k, v, o, lse, dq, dk, dv, bias=None, causal=False, softmax_scale=None):
    """Launch the Triton backward kernels and write the gradients into dq, dk, dv.

    Args:
        do: gradient of the output; made contiguous if its last stride is not 1.
        q: (batch, seqlen_q, nheads, headdim) tensor used in the forward.
        k, v: tensors used in the forward (seqlen_k is read from k.shape[1]).
        o: output of the forward.
        lse: buffer returned by the forward, shape
            (batch, nheads, seqlen_q rounded up to a multiple of 128).
        dq, dk, dv: pre-allocated gradient tensors, filled in place.
        bias: optional 4-D CUDA tensor, same constraints as in the forward.
        causal: passed through to the kernel.
        softmax_scale: scaling passed to the kernel; 1 / sqrt(headdim) is used
            when it is None (or otherwise falsy).

    Returns:
        None. Results are written to dq, dk and dv.

    Raises:
        AssertionError: if a shape / stride / dtype constraint is violated.
        RuntimeError: if the last two dimensions of bias are not supported.
    """
    # Make sure that the last dimension is contiguous
    if do.stride(-1) != 1:
        do = do.contiguous()
    batch, seqlen_q, nheads, d = q.shape
    _, seqlen_k, _, _ = k.shape
    assert d <= 128
    seqlen_q_rounded = math.ceil(seqlen_q / 128) * 128
    assert lse.shape == (batch, nheads, seqlen_q_rounded)
    assert q.stride(-1) == k.stride(-1) == v.stride(-1) == o.stride(-1) == 1
    assert dq.stride(-1) == dk.stride(-1) == dv.stride(-1) == 1
    softmax_scale = softmax_scale or 1.0 / math.sqrt(d)
    # dq is accumulated in fp32 and copied into dq (cast to its dtype) at the end.
    dq_accum = torch.empty_like(q, dtype=torch.float32)
    delta = torch.empty_like(lse)

    BLOCK_HEADDIM = max(triton.next_power_of_2(d), 16)
    grid = lambda META: (triton.cdiv(seqlen_q, META["BLOCK_M"]), batch * nheads)
    # Preprocessing kernel: reads o and do, writes delta.
    _bwd_preprocess_do_o_dot[grid](
        o, do, delta,
        o.stride(0), o.stride(2), o.stride(1),
        do.stride(0), do.stride(2), do.stride(1),
        nheads, seqlen_q, seqlen_q_rounded, d,
        BLOCK_M=128, BLOCK_HEADDIM=BLOCK_HEADDIM,
    )

    has_bias = bias is not None
    bias_type = 'none'
    if has_bias:
        assert bias.dtype in [q.dtype, torch.float]
        assert bias.is_cuda
        assert bias.dim() == 4
        # The forward accepts a bias with a non-unit last stride (it makes its own
        # contiguous copy), and the autograd wrappers save the caller's original
        # tensor, so apply the same fallback here instead of asserting.
        if bias.stride(-1) != 1:
            bias = bias.contiguous()
        if bias.shape[2:] == (1, seqlen_k):
            bias_type = 'vector'
        elif bias.shape[2:] == (seqlen_q, seqlen_k):
            bias_type = 'matrix'
        else:
            raise RuntimeError('Last 2 dimensions of bias must be (1, seqlen_k)'
                               ' or (seqlen_q, seqlen_k)')
        assert bias.shape[0] in (1, batch) and bias.shape[1] in (1, nheads), \
            'First 2 dimensions of bias must be broadcastible to (batch, nheads)'
        # Broadcast over batch / heads with a stride-0 view rather than a copy.
        # This also covers a bias whose first two dimensions are both 1.
        bias = bias.expand(batch, nheads, *bias.shape[2:])
    bias_strides = (bias.stride(0), bias.stride(1), bias.stride(2)) if has_bias else (0, 0, 0)

    # With SEQUENCE_PARALLEL the first grid axis has one program per block of keys;
    # otherwise it has a single program.
    grid = lambda META: (triton.cdiv(seqlen_k, META["BLOCK_N"]) if META["SEQUENCE_PARALLEL"] else 1,
                    batch * nheads)
    _bwd_kernel[grid](
        q, k, v, bias,
        do, dq_accum, dk, dv,
        lse, delta,
        softmax_scale,
        # Strides are passed in (batch, head, seqlen) order.
        q.stride(0), q.stride(2), q.stride(1),
        k.stride(0), k.stride(2), k.stride(1),
        v.stride(0), v.stride(2), v.stride(1),
        *bias_strides,
        do.stride(0), do.stride(2), do.stride(1),
        dq_accum.stride(0), dq_accum.stride(2), dq_accum.stride(1),
        dk.stride(0), dk.stride(2), dk.stride(1),
        dv.stride(0), dv.stride(2), dv.stride(1),
        nheads, seqlen_q, seqlen_k, seqlen_q_rounded, d,
        seqlen_q // 32,  seqlen_k // 32, # key for triton cache (limit number of compilations)
        # Can't use kwargs here because triton autotune expects key to be args, not kwargs
        bias_type, causal, BLOCK_HEADDIM,
    )
    dq.copy_(dq_accum)


class FlashAttnQKVPackedFunc(torch.autograd.Function):
    """Autograd function for attention with Q, K and V packed in one tensor."""

    @staticmethod
    def forward(ctx, qkv, bias=None, causal=False, softmax_scale=None):
        """
            qkv: (batch, seqlen, 3, nheads, headdim)
            bias: optional, shape broadcastible to (batch, nheads, seqlen, seqlen).
                For example, ALiBi mask for causal would have shape (1, nheads, 1, seqlen).
                ALiBi mask for non-causal would have shape (1, nheads, seqlen, seqlen)
        """
        # The last dimension must have unit stride; copy only when it does not.
        packed = qkv if qkv.stride(-1) == 1 else qkv.contiguous()
        query, key, value = packed[:, :, 0], packed[:, :, 1], packed[:, :, 2]
        out, logsumexp, ctx.softmax_scale = _flash_attn_forward(
            query, key, value, bias=bias, causal=causal, softmax_scale=softmax_scale
        )
        ctx.save_for_backward(packed, out, logsumexp, bias)
        ctx.causal = causal
        return out

    @staticmethod
    def backward(ctx, do):
        packed, out, logsumexp, bias = ctx.saved_tensors
        assert not ctx.needs_input_grad[1], 'FlashAttention does not support bias gradient yet'
        # Triton's autotune bumps Tensor._version, which makes Pytorch autograd do a
        # memcpy. Running under inference_mode avoids that since versions are not tracked.
        with torch.inference_mode():
            grad_packed = torch.empty_like(packed)
            # Views into the packed gradient, so the kernel writes straight into it.
            grad_q, grad_k, grad_v = grad_packed[:, :, 0], grad_packed[:, :, 1], grad_packed[:, :, 2]
            _flash_attn_backward(
                do, packed[:, :, 0], packed[:, :, 1], packed[:, :, 2], out, logsumexp,
                grad_q, grad_k, grad_v,
                bias=bias, causal=ctx.causal, softmax_scale=ctx.softmax_scale,
            )
        # One entry per forward input: qkv, bias, causal, softmax_scale.
        return grad_packed, None, None, None
Tri Dao's avatar
Tri Dao committed
736
737
738
739
740
741
742
743


# Function-style entry point: arguments are forwarded to
# FlashAttnQKVPackedFunc.forward through torch.autograd.Function.apply.
flash_attn_qkvpacked_func = FlashAttnQKVPackedFunc.apply


class FlashAttnKVPackedFunc(torch.autograd.Function):
    """Autograd function for attention with a separate Q and packed K/V."""

    @staticmethod
    def forward(ctx, q, kv, bias=None, causal=False, softmax_scale=None):
        """
            q: (batch, seqlen_q, nheads, headdim)
            kv: (batch, seqlen_k, 2, nheads, headdim)
            bias: optional, shape broadcastible to (batch, nheads, seqlen_q, seqlen_k).
                For example, ALiBi mask for causal would have shape (1, nheads, 1, seqlen_k).
                ALiBi mask for non-causal would have shape (1, nheads, seqlen_q, seqlen_k)
        """
        # Make sure that the last dimension is contiguous
        q, kv = [x if x.stride(-1) == 1 else x.contiguous() for x in [q, kv]]
        o, lse, ctx.softmax_scale = _flash_attn_forward(
            q, kv[:, :, 0], kv[:, :, 1], bias=bias, causal=causal, softmax_scale=softmax_scale
        )
        ctx.save_for_backward(q, kv, o, lse, bias)
        ctx.causal = causal
        return o

    @staticmethod
    def backward(ctx, do):
        q, kv, o, lse, bias = ctx.saved_tensors
        assert not ctx.needs_input_grad[2], 'FlashAttention does not support bias gradient yet'
        # Triton's autotune causes the Tensor._version to change, and so Pytorch autograd
        # does a memcpy. To avoid this we run in inference_mode, which doesn't track the version.
        with torch.inference_mode():
            dq = torch.empty_like(q)
            dkv = torch.empty_like(kv)
            # K and V come from the saved `kv` tensor (index 0 / 1 along dim 2);
            # dk and dv are views into dkv so the kernel fills it directly.
            _flash_attn_backward(do, q, kv[:, :, 0], kv[:, :, 1], o, lse,
                                 dq, dkv[:, :, 0], dkv[:, :, 1],
                                 bias=bias, causal=ctx.causal, softmax_scale=ctx.softmax_scale)
        # One entry per forward input: q, kv, bias, causal, softmax_scale.
        return dq, dkv, None, None, None
Tri Dao's avatar
Tri Dao committed
774
775
776
777
778
779
780
781


# Function-style entry point: arguments are forwarded to
# FlashAttnKVPackedFunc.forward through torch.autograd.Function.apply.
flash_attn_kvpacked_func = FlashAttnKVPackedFunc.apply


class FlashAttnFunc(torch.autograd.Function):
    """Autograd function for attention with separate Q, K and V tensors."""

    @staticmethod
    def forward(ctx, q, k, v, bias=None, causal=False, softmax_scale=None):
        """
            q: (batch_size, seqlen_q, nheads, headdim)
            k, v: (batch_size, seqlen_k, nheads, headdim)
            bias: optional, shape broadcastible to (batch, nheads, seqlen_q, seqlen_k).
                For example, ALiBi mask for causal would have shape (1, nheads, 1, seqlen_k).
                ALiBi mask for non-causal would have shape (1, nheads, seqlen_q, seqlen_k)
        """
        # The last dimension must have unit stride; copy only the tensors that need it.
        if q.stride(-1) != 1:
            q = q.contiguous()
        if k.stride(-1) != 1:
            k = k.contiguous()
        if v.stride(-1) != 1:
            v = v.contiguous()
        out, logsumexp, ctx.softmax_scale = _flash_attn_forward(
            q, k, v, bias=bias, causal=causal, softmax_scale=softmax_scale
        )
        ctx.save_for_backward(q, k, v, out, logsumexp, bias)
        ctx.causal = causal
        return out

    @staticmethod
    def backward(ctx, do):
        q, k, v, out, logsumexp, bias = ctx.saved_tensors
        assert not ctx.needs_input_grad[3], 'FlashAttention does not support bias gradient yet'
        # Triton's autotune bumps Tensor._version, which makes Pytorch autograd do a
        # memcpy. Running under inference_mode avoids that since versions are not tracked.
        with torch.inference_mode():
            grad_q = torch.empty_like(q)
            grad_k = torch.empty_like(k)
            grad_v = torch.empty_like(v)
            _flash_attn_backward(
                do, q, k, v, out, logsumexp, grad_q, grad_k, grad_v,
                bias=bias, causal=ctx.causal, softmax_scale=ctx.softmax_scale,
            )
        # One entry per forward input: q, k, v, bias, causal, softmax_scale.
        return grad_q, grad_k, grad_v, None, None, None
Tri Dao's avatar
Tri Dao committed
812
813
814


# Function-style entry point: arguments are forwarded to
# FlashAttnFunc.forward through torch.autograd.Function.apply.
flash_attn_func = FlashAttnFunc.apply