flash_attn_triton.py 37.4 KB
Newer Older
Tri Dao's avatar
Tri Dao committed
1
"""
*Experimental* implementation of FlashAttention in Triton.

We use the FlashAttention implementation from Phil Tillet as a starting point.
https://github.com/openai/triton/blob/master/python/tutorials/06-fused-attention.py

Changes:
- Implement both causal and non-causal attention.
- Implement both self-attention and cross-attention.
- Support arbitrary seqlens (not just multiples of 128), for both forward and backward.
- Support all head dimensions up to 128 (not just 16, 32, 64, 128), for both forward and backward.
- Support attention bias.
- Speed up the forward pass a bit, and only store the LSE instead of m and l.
- Make the backward for d=128 much faster by reducing register spilling.
- Optionally parallelize the backward pass across seqlen_k, to deal with the case of
small batch size * nheads.

Caution:
- This is an *experimental* implementation. The forward pass should be quite robust but
I'm not 100% sure that the backward pass doesn't have race conditions (due to the Triton compiler).
- This implementation has only been tested on A100.
- If you plan to use headdim other than 64 and 128, you should test for race conditions
(due to the Triton compiler), as done in tests/test_flash_attn.py
"test_flash_attn_triton_race_condition". I've tested and fixed many race conditions
for different head dimensions (40, 48, 64, 128, 80, 88, 96), but I'm still not 100% confident
that there are none left for other head dimensions.

Differences between this Triton version and the CUDA version:
- Triton version doesn't support dropout.
- Triton forward is generally faster than CUDA forward, while Triton backward is
generally slower than CUDA backward. Overall Triton forward + backward is slightly slower
than CUDA forward + backward.
- Triton version doesn't support different sequence lengths in a batch (i.e., RaggedTensor/NestedTensor).
- Triton version supports attention bias, while CUDA version doesn't.
"""

import math

import torch

41
42
from einops import rearrange, repeat

Tri Dao's avatar
Tri Dao committed
43
44
45
46
import triton
import triton.language as tl


47
48
49
50
51
52
53
54
55
# Disabling autotune for now, set num_warps=4 if headdim=64 and num_warps=8 if headdim=128
# @triton.autotune(
#     configs=[
#         triton.Config({"BLOCK_M": 128, "BLOCK_N": 128}, num_warps=4, num_stages=1),
#         # This config has a race condition when EVEN_M == False, disabling it for now.
#         # triton.Config({"BLOCK_M": 64, "BLOCK_N": 64}, num_warps=4, num_stages=1),
#     ],
#     key=['CACHE_KEY_SEQLEN_Q', 'CACHE_KEY_SEQLEN_K', 'BIAS_TYPE', 'IS_CAUSAL', 'BLOCK_HEADDIM']
# )
Tri Dao's avatar
Tri Dao committed
56
57
58
@triton.heuristics(
    {
        "EVEN_M": lambda args: args["seqlen_q"] % args["BLOCK_M"] == 0,
        "EVEN_N": lambda args: args["seqlen_k"] % args["BLOCK_N"] == 0,
        "EVEN_HEADDIM": lambda args: args["headdim"] == args["BLOCK_HEADDIM"],
    }
)
@triton.jit
def _fwd_kernel(
    Q, K, V, Bias, Out,
    Lse, TMP,  # NOTE: TMP is a scratchpad buffer to workaround a compiler bug
    softmax_scale,
    stride_qb, stride_qh, stride_qm,
    stride_kb, stride_kh, stride_kn,
    stride_vb, stride_vh, stride_vn,
    stride_bb, stride_bh, stride_bm,
    stride_ob, stride_oh, stride_om,
    nheads, seqlen_q, seqlen_k, seqlen_q_rounded, headdim,
    CACHE_KEY_SEQLEN_Q, CACHE_KEY_SEQLEN_K,
    BIAS_TYPE: tl.constexpr,
    IS_CAUSAL: tl.constexpr,
    BLOCK_HEADDIM: tl.constexpr,
    EVEN_M: tl.constexpr, EVEN_N: tl.constexpr, EVEN_HEADDIM: tl.constexpr,
    BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr,
):
    # Attention forward pass for one block of queries.
    #
    # Program axis 0 selects a block of BLOCK_M query rows; program axis 1 is a flattened
    # (batch, head) index that is split using nheads. The program walks over the keys and
    # values in blocks of BLOCK_N rows, keeping a running row max (m_i), a running
    # log-sum-exp (lse_i) and an output accumulator (acc_o) that is expressed relative to
    # m_i; acc_o is rescaled by exp(m_i - lse_i) once after the loop.
    #
    # Results: Out receives the attention output of the query block, and Lse receives the
    # per-row log-sum-exp at offset off_hb * seqlen_q_rounded + row.
    #
    # Notes on the arguments, as they are used below:
    # - The head-dim axis of Q, K, V and Out, and the key axis of Bias, are addressed with
    #   an implicit stride of 1 (offs_d / offs_n are added without a stride).
    # - BIAS_TYPE is 'none', 'vector' (one value per key, broadcast over the query rows) or
    #   'matrix' (one value per (query, key) pair).
    # - EVEN_M / EVEN_N / EVEN_HEADDIM are filled in by the heuristics above and pick the
    #   unmasked loads/stores when the matching dimension needs no bounds check.
    # - CACHE_KEY_SEQLEN_Q / CACHE_KEY_SEQLEN_K are not read in the body.
    # - TMP is indexed exactly like Lse (off_hb * seqlen_q_rounded + row).
    start_m = tl.program_id(0)
    off_hb = tl.program_id(1)
    off_b = off_hb // nheads
    off_h = off_hb % nheads
    # off_b = tl.program_id(1)
    # off_h = tl.program_id(2)
    # off_hb = off_b * nheads + off_h
    # initialize offsets
    offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_n = tl.arange(0, BLOCK_N)
    offs_d = tl.arange(0, BLOCK_HEADDIM)
    # Initialize pointers to Q, K, V
    # Adding parenthesis around indexing might use int32 math instead of int64 math?
    # https://github.com/openai/triton/issues/741
    # I'm seeing a tiny bit of difference (5-7us)
    q_ptrs = Q + off_b * stride_qb + off_h * stride_qh + (offs_m[:, None] * stride_qm + offs_d[None, :])
    k_ptrs = K + off_b * stride_kb + off_h * stride_kh + (offs_n[:, None] * stride_kn + offs_d[None, :])
    v_ptrs = V + off_b * stride_vb + off_h * stride_vh + (offs_n[:, None] * stride_vn + offs_d[None, :])
    if BIAS_TYPE == 'vector':
        b_ptrs = Bias + off_b * stride_bb + off_h * stride_bh + offs_n
    elif BIAS_TYPE == 'matrix':
        b_ptrs = Bias + off_b * stride_bb + off_h * stride_bh + (offs_m[:, None] * stride_bm + offs_n[None, :])
    # pointers to this program's rows of the TMP scratch buffer
    t_ptrs = TMP + off_hb * seqlen_q_rounded + offs_m
    # running log-sum-exp, running max and (un-normalized) output accumulator
    lse_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
    m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
    acc_o = tl.zeros([BLOCK_M, BLOCK_HEADDIM], dtype=tl.float32)
    # load q: it will stay in SRAM throughout
    # [2022-10-30] TD: Triton bug - in the case of EVEN_M=True and EVEN_N=False, if we just call
    # tl.load(q_ptrs), we get the wrong output!
    if EVEN_M & EVEN_N:
        if EVEN_HEADDIM:
            q = tl.load(q_ptrs)
        else:
            q = tl.load(q_ptrs, mask=offs_d[None, :] < headdim, other=0.0)
    else:
        if EVEN_HEADDIM:
            q = tl.load(q_ptrs, mask=offs_m[:, None] < seqlen_q, other=0.0)
        else:
            q = tl.load(q_ptrs, mask=(offs_m[:, None] < seqlen_q) & (offs_d[None, :] < headdim),
                        other=0.0)
    # loop over k, v and update accumulator
    # With causal masking, keys past the last query row of this block are never visited.
    end_n = seqlen_k if not IS_CAUSAL else tl.minimum((start_m + 1) * BLOCK_M, seqlen_k)
    for start_n in range(0, end_n, BLOCK_N):
        start_n = tl.multiple_of(start_n, BLOCK_N)
        # -- compute qk ----
        if EVEN_N & EVEN_M:  # If we just do "if EVEN_N", there seems to be some race condition
            if EVEN_HEADDIM:
                k = tl.load(k_ptrs + start_n * stride_kn)
            else:
                k = tl.load(k_ptrs + start_n * stride_kn, mask=offs_d[None, :] < headdim, other=0.0)
        else:
            if EVEN_HEADDIM:
                k = tl.load(k_ptrs + start_n * stride_kn, mask=(start_n + offs_n)[:, None] < seqlen_k,
                            other=0.0)
            else:
                k = tl.load(k_ptrs + start_n * stride_kn,
                            mask=((start_n + offs_n)[:, None] < seqlen_k) & (offs_d[None, :] < headdim),
                            other=0.0)
        qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
        qk += tl.dot(q, k, trans_b=True)
        # Trying to combine the two masks seems to make the result wrong
        if not EVEN_N:  # Need to mask out otherwise the softmax is wrong
            qk += tl.where((start_n + offs_n)[None, :] < seqlen_k, 0, float("-inf"))
        if IS_CAUSAL:
            qk += tl.where(offs_m[:, None] >= (start_n + offs_n)[None, :], 0, float("-inf"))
        if BIAS_TYPE != 'none':
            if BIAS_TYPE == 'vector':
                if EVEN_N:
                    bias = tl.load(b_ptrs + start_n).to(tl.float32)
                else:
                    bias = tl.load(b_ptrs + start_n, mask=(start_n + offs_n) < seqlen_k, other=0.0).to(tl.float32)
                bias = bias[None, :]
            elif BIAS_TYPE == 'matrix':
                if EVEN_M & EVEN_N:
                    bias = tl.load(b_ptrs + start_n).to(tl.float32)
                else:
                    bias = tl.load(b_ptrs + start_n,
                                   mask=(offs_m[:, None] < seqlen_q)
                                        & ((start_n + offs_n)[None, :] < seqlen_k),
                                   other=0.0).to(tl.float32)
            # Slightly faster to multiply the softmax_scale in the tl.exp below since the compiler
            # can then fuse the mult and add into an fma instruction. But if we have bias we need to
            # to multiply with softmax_scale here.
            qk = qk * softmax_scale + bias
            m_ij = tl.maximum(tl.max(qk, 1), lse_i)
            p = tl.exp(qk - m_ij[:, None])
        else:
            m_ij = tl.maximum(tl.max(qk, 1) * softmax_scale, lse_i)
            p = tl.exp(qk * softmax_scale - m_ij[:, None])
        l_ij = tl.sum(p, 1)

        # scale acc_o
        acc_o_scale = tl.exp(m_i - m_ij)

        # # -- update output accumulator --
        # BUG: have to store and immediately load
        tl.store(t_ptrs, acc_o_scale)
        acc_o_scale = tl.load(t_ptrs)
        acc_o = acc_o * acc_o_scale[:, None]
        # update acc_o
        if EVEN_N & EVEN_M:  # If we just do "if EVEN_N", there seems to be some race condition
            if EVEN_HEADDIM:
                v = tl.load(v_ptrs + start_n * stride_vn)
            else:
                v = tl.load(v_ptrs + start_n * stride_vn, mask=offs_d[None, :] < headdim, other=0.0)
        else:
            if EVEN_HEADDIM:
                v = tl.load(v_ptrs + start_n * stride_vn, mask=(start_n + offs_n)[:, None] < seqlen_k,
                            other=0.0)
            else:
                v = tl.load(v_ptrs + start_n * stride_vn,
                            mask=((start_n + offs_n)[:, None] < seqlen_k) & (offs_d[None, :] < headdim),
                            other=0.0)
        p = p.to(v.dtype)
        acc_o += tl.dot(p, v)

        # -- update statistics
        m_i = m_ij
        l_i_new = tl.exp(lse_i - m_ij) + l_ij
        lse_i = m_ij + tl.log(l_i_new)

    # acc_o is relative to the last running max m_i; rescale it to be relative to lse_i
    o_scale = tl.exp(m_i - lse_i)
    # BUG: have to store and immediately load
    tl.store(t_ptrs, o_scale)
    o_scale = tl.load(t_ptrs)
    acc_o = acc_o * o_scale[:, None]
    # rematerialize offsets to save registers
    start_m = tl.program_id(0)
    offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
    # write back the log-sum-exp (stored unmasked: rows are indexed up to seqlen_q_rounded)
    lse_ptrs = Lse + off_hb * seqlen_q_rounded + offs_m
    tl.store(lse_ptrs, lse_i)
    # initialize pointers to output
    offs_d = tl.arange(0, BLOCK_HEADDIM)
    out_ptrs = Out + off_b * stride_ob + off_h * stride_oh + (offs_m[:, None] * stride_om + offs_d[None, :])
    if EVEN_M:
        if EVEN_HEADDIM:
            tl.store(out_ptrs, acc_o)
        else:
            tl.store(out_ptrs, acc_o, mask=offs_d[None, :] < headdim)
    else:
        if EVEN_HEADDIM:
            tl.store(out_ptrs, acc_o, mask=offs_m[:, None] < seqlen_q)
        else:
            tl.store(out_ptrs, acc_o,
                     mask=(offs_m[:, None] < seqlen_q) & (offs_d[None, :] < headdim))
Tri Dao's avatar
Tri Dao committed
228
229
230
231
232
233
234


@triton.jit
def _bwd_preprocess_do_o_dot(
    Out, DO, Delta,
    stride_ob, stride_oh, stride_om,
    stride_dob, stride_doh, stride_dom,
    nheads, seqlen_q, seqlen_q_rounded, headdim,
    BLOCK_M: tl.constexpr, BLOCK_HEADDIM: tl.constexpr,
):
    # Backward pre-processing: for every query row compute delta = sum_d(Out * DO) in fp32
    # and store it in Delta at offset off_hb * seqlen_q_rounded + row.
    #
    # Program axis 0 selects a block of BLOCK_M rows; program axis 1 is a flattened
    # (batch, head) index that is split using nheads. The head-dim axis of Out and DO is
    # addressed with an implicit stride of 1.
    start_m = tl.program_id(0)
    off_hb = tl.program_id(1)
    off_b = off_hb // nheads
    off_h = off_hb % nheads
    # initialize offsets
    offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_d = tl.arange(0, BLOCK_HEADDIM)
    # load (rows >= seqlen_q and columns >= headdim read as 0, so they add nothing to delta)
    o = tl.load(Out + off_b * stride_ob + off_h * stride_oh + offs_m[:, None] * stride_om + offs_d[None, :],
                mask=(offs_m[:, None] < seqlen_q) & (offs_d[None, :] < headdim), other=0.0).to(tl.float32)
    do = tl.load(DO + off_b * stride_dob + off_h * stride_doh + offs_m[:, None] * stride_dom + offs_d[None, :],
                 mask=(offs_m[:, None] < seqlen_q) & (offs_d[None, :] < headdim), other=0.0).to(tl.float32)
    delta = tl.sum(o * do, axis=1)
    # write-back (unmasked: Delta is indexed up to seqlen_q_rounded rows per (batch, head))
    tl.store(Delta + off_hb * seqlen_q_rounded + offs_m, delta)


255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
@triton.jit
def _bwd_store_dk_dv(
    dk_ptrs, dv_ptrs, dk, dv, offs_n, offs_d, seqlen_k, headdim,
    EVEN_M: tl.constexpr, EVEN_N: tl.constexpr, EVEN_HEADDIM: tl.constexpr,
):
    # Store the dk / dv tiles through dk_ptrs / dv_ptrs.
    #
    # offs_n and offs_d are the key-row and head-dim offsets of the tile; rows with
    # offs_n >= seqlen_k and columns with offs_d >= headdim are masked out of the store
    # unless the EVEN_* flags select an unmasked variant. dv is always stored before dk.
    #
    # [2022-11-01] TD: Same bug. In the case of EVEN_N=True and EVEN_M=False,
    # if we just call tl.store(dv_ptrs), there's a race condition
    if EVEN_N & EVEN_M:
        # No row mask here; only the head dim may need one.
        if EVEN_HEADDIM:
            tl.store(dv_ptrs, dv)
            tl.store(dk_ptrs, dk)
        else:
            tl.store(dv_ptrs, dv, mask=offs_d[None, :] < headdim)
            tl.store(dk_ptrs, dk, mask=offs_d[None, :] < headdim)
    else:
        # Always mask the key rows on this path (this also covers EVEN_N=True with
        # EVEN_M=False, per the note above).
        if EVEN_HEADDIM:
            tl.store(dv_ptrs, dv, mask=offs_n[:, None] < seqlen_k)
            tl.store(dk_ptrs, dk, mask=offs_n[:, None] < seqlen_k)
        else:
            tl.store(dv_ptrs, dv, mask=(offs_n[:, None] < seqlen_k) & (offs_d[None, :] < headdim))
            tl.store(dk_ptrs, dk, mask=(offs_n[:, None] < seqlen_k) & (offs_d[None, :] < headdim))


Tri Dao's avatar
Tri Dao committed
278
279
280
@triton.jit
def _bwd_kernel_one_col_block(
    start_n,
    Q, K, V, Bias,
    DO, DQ, DK, DV,
    LSE, D,
    softmax_scale,
    stride_qm, stride_kn, stride_vn, stride_bm,
    stride_dom, stride_dqm, stride_dkn, stride_dvn,
    seqlen_q, seqlen_k, headdim,
    ATOMIC_ADD: tl.constexpr,
    BIAS_TYPE: tl.constexpr,
    IS_CAUSAL: tl.constexpr,
    BLOCK_HEADDIM: tl.constexpr,
    EVEN_M: tl.constexpr, EVEN_N: tl.constexpr, EVEN_HEADDIM: tl.constexpr,
    BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr,
):
    # Backward pass for one block of BLOCK_N keys/values (block index start_n).
    #
    # Only row strides are applied to Q, K, V, Bias, DO, DQ, DK, DV here, and LSE / D are
    # indexed by query row only, so all of them must already point at the (batch, head)
    # slice being processed (_bwd_kernel applies those offsets before calling in).
    #
    # The K/V block is loaded once. The function then loops over blocks of BLOCK_M query
    # rows, recomputes p = exp(scores - lse) and accumulates
    #     dv += p^T @ do
    #     ds  = p * (do @ v^T - D) * softmax_scale
    #     dk += ds^T @ q
    #     dq += ds @ k
    # dk / dv are accumulated locally and stored once at the end. dq is accumulated in
    # memory: with a load/add/store when ATOMIC_ADD is False, with tl.atomic_add otherwise.
    # With IS_CAUSAL, query blocks that lie entirely before this key block are skipped.
    #
    # We need to make sure begin_m is a multiple of BLOCK_M (not BLOCK_N)
    begin_m = 0 if not IS_CAUSAL else ((start_n * BLOCK_N) // BLOCK_M) * BLOCK_M
    # initialize row/col offsets
    offs_qm = begin_m + tl.arange(0, BLOCK_M)
    offs_n = start_n * BLOCK_N + tl.arange(0, BLOCK_N)
    offs_m = tl.arange(0, BLOCK_M)
    offs_d = tl.arange(0, BLOCK_HEADDIM)
    # initialize pointers to value-like data
    q_ptrs = Q + (offs_qm[:, None] * stride_qm + offs_d[None, :])
    k_ptrs = K + (offs_n[:, None] * stride_kn + offs_d[None, :])
    v_ptrs = V + (offs_n[:, None] * stride_vn + offs_d[None, :])
    do_ptrs = DO + (offs_qm[:, None] * stride_dom + offs_d[None, :])
    dq_ptrs = DQ + (offs_qm[:, None] * stride_dqm + offs_d[None, :])
    if BIAS_TYPE == 'vector':
        b_ptrs = Bias + offs_n
    elif BIAS_TYPE == 'matrix':
        b_ptrs = Bias + (offs_qm[:, None] * stride_bm + offs_n[None, :])
    # initialize dv and dk
    dv = tl.zeros([BLOCK_N, BLOCK_HEADDIM], dtype=tl.float32)
    dk = tl.zeros([BLOCK_N, BLOCK_HEADDIM], dtype=tl.float32)
    # There seems to be some problem with Triton pipelining that makes results wrong for
    # headdim=64, seqlen=(113, 255), bias_type='matrix'. In this case the for loop
    # may have zero step, and pipelining with the bias matrix could screw it up.
    # So we just exit early (dk and dv are still written, as zeros).
    if begin_m >= seqlen_q:
        dv_ptrs = DV + (offs_n[:, None] * stride_dvn + offs_d[None, :])
        dk_ptrs = DK + (offs_n[:, None] * stride_dkn + offs_d[None, :])
        _bwd_store_dk_dv(dk_ptrs, dv_ptrs, dk, dv, offs_n, offs_d, seqlen_k, headdim,
                         EVEN_M=EVEN_M, EVEN_N=EVEN_N, EVEN_HEADDIM=EVEN_HEADDIM)
        return
    # k and v stay in SRAM throughout
    # [2022-10-30] TD: Same bug as the fwd. In the case of EVEN_N=True and EVEN_M=False,
    # if we just call tl.load(k_ptrs), we get the wrong output!
    if EVEN_N & EVEN_M:
        if EVEN_HEADDIM:
            k = tl.load(k_ptrs)
            v = tl.load(v_ptrs)
        else:
            k = tl.load(k_ptrs, mask=offs_d[None, :] < headdim, other=0.0)
            v = tl.load(v_ptrs, mask=offs_d[None, :] < headdim, other=0.0)
    else:
        if EVEN_HEADDIM:
            k = tl.load(k_ptrs, mask=offs_n[:, None] < seqlen_k, other=0.0)
            v = tl.load(v_ptrs, mask=offs_n[:, None] < seqlen_k, other=0.0)
        else:
            k = tl.load(k_ptrs, mask=(offs_n[:, None] < seqlen_k) & (offs_d[None, :] < headdim),
                        other=0.0)
            v = tl.load(v_ptrs, mask=(offs_n[:, None] < seqlen_k) & (offs_d[None, :] < headdim),
                        other=0.0)
    # loop over rows
    num_block_m = tl.cdiv(seqlen_q, BLOCK_M)
    for start_m in range(begin_m, num_block_m * BLOCK_M, BLOCK_M):
        start_m = tl.multiple_of(start_m, BLOCK_M)
        offs_m_curr = start_m + offs_m
        # load q, k, v, do on-chip
        # Same bug as below. Otherwise gives wrong result for headdim=40, seqlen=(128, 117)
        if EVEN_M & EVEN_HEADDIM:
            q = tl.load(q_ptrs)
        else:
            if EVEN_HEADDIM:
                q = tl.load(q_ptrs, mask=offs_m_curr[:, None] < seqlen_q, other=0.0)
            else:
                q = tl.load(q_ptrs, mask=(offs_m_curr[:, None] < seqlen_q)
                                         & (offs_d[None, :] < headdim), other=0.0)
        # recompute p = softmax(qk, dim=-1).T
        qk = tl.dot(q, k, trans_b=True)
        # Trying to combine the two masks seems to make the result wrong
        if not EVEN_N:  # Need to mask out otherwise the softmax is wrong
            qk = tl.where(offs_n[None, :] < seqlen_k, qk, float("-inf"))
        if IS_CAUSAL:
            qk = tl.where(offs_m_curr[:, None] >= (offs_n[None, :]), qk, float("-inf"))
        if BIAS_TYPE != 'none':
            tl.debug_barrier()  # Race condition otherwise
            if BIAS_TYPE == 'vector':
                if EVEN_N:
                    bias = tl.load(b_ptrs).to(tl.float32)
                else:
                    bias = tl.load(b_ptrs, mask=offs_n < seqlen_k, other=0.0).to(tl.float32)
                bias = bias[None, :]
            elif BIAS_TYPE == 'matrix':
                if EVEN_M & EVEN_N:
                    bias = tl.load(b_ptrs).to(tl.float32)
                else:
                    bias = tl.load(b_ptrs,
                                   mask=(offs_m_curr[:, None] < seqlen_q)
                                        & (offs_n[None, :] < seqlen_k),
                                   other=0.0).to(tl.float32)
            # With a bias the scale has to be applied before the bias is added.
            qk = qk * softmax_scale + bias
        # There seems to be a race condition when headdim=48/96, and dq, dk, dv are wrong.
        # Also wrong for headdim=64.
        if not (EVEN_M & EVEN_HEADDIM):
            tl.debug_barrier()
        lse_i = tl.load(LSE + offs_m_curr)
        if BIAS_TYPE == 'none':
            p = tl.exp(qk * softmax_scale - lse_i[:, None])
        else:
            p = tl.exp(qk - lse_i[:, None])
        # compute dv
        # [2022-10-30] TD: A Triton bug: if EVEN_M=True and EVEN_HEADDIM=False, if we call
        # do = tl.load(do_ptrs, mask=offs_d[None, :] < headdim, other=0.0), we get wrong outputs
        # in the case of headdim=48/96, seqlen_q & seqlen_k >= 512. If headdim=40 or seqlen < 512,
        # the output is correct.
        if EVEN_M & EVEN_HEADDIM:
            do = tl.load(do_ptrs)
        else:
            # [2022-11-01] TD: Triton bug, there's a race condition if we just use m_mask and not d_mask.
            do = tl.load(do_ptrs, mask=(offs_m_curr[:, None] < seqlen_q)
                                        & (offs_d[None, :] < headdim), other=0.0)
        # if EVEN_M:
        #     if EVEN_HEADDIM:
        #         do = tl.load(do_ptrs)
        #     else:
        #         do = tl.load(do_ptrs, mask=offs_d[None, :] < headdim, other=0.0)
        # else:
        #     if EVEN_HEADDIM:
        #         do = tl.load(do_ptrs, mask=offs_m_curr[:, None] < seqlen_q, other=0.0)
        #     else:
        #         do = tl.load(do_ptrs, mask=(offs_m_curr[:, None] < seqlen_q)
        #                                    & (offs_d[None, :] < headdim), other=0.0)
        dv += tl.dot(p.to(do.dtype), do, trans_a=True)
        # compute dp = dot(v, do)
        # There seems to be a race condition when headdim=48/96, and dq, dk are wrong.
        # Also wrong for headdim=128, seqlen=(108, 256), and ATOMIC_ADD=True
        # Also wrong for headdim=64, seqlen=(1023, 1024), and ATOMIC_ADD=False
        if not (EVEN_M & EVEN_HEADDIM):
            tl.debug_barrier()
        dp = tl.dot(do, v, trans_b=True)
        # There's a race condition for headdim=48
        if not EVEN_HEADDIM:
            tl.debug_barrier()
        # compute ds = p * (dp - delta[:, None])
        # Putting the subtraction after the dp matmul (instead of before) is slightly faster
        # NOTE(review): D presumably holds the per-row sum(o * do) produced by
        # _bwd_preprocess_do_o_dot -- confirm against the caller.
        Di = tl.load(D + offs_m_curr)
        # Converting ds to q.dtype here reduces register pressure and makes it much faster
        # for BLOCK_HEADDIM=128
        ds = (p * (dp - Di[:, None]) * softmax_scale).to(q.dtype)
        # compute dk = dot(ds.T, q)
        dk += tl.dot(ds, q, trans_a=True)
        # compute dq
        if not (EVEN_M & EVEN_HEADDIM):  # Otherwise there's a race condition when BIAS_TYPE='matrix'
            tl.debug_barrier()
        if not ATOMIC_ADD:
            if EVEN_M & EVEN_HEADDIM:  # Race condition if we just do EVEN_M
                dq = tl.load(dq_ptrs, eviction_policy="evict_last")
                dq += tl.dot(ds, k)
                tl.store(dq_ptrs, dq, eviction_policy="evict_last")
            else:
                if EVEN_HEADDIM:
                    dq = tl.load(dq_ptrs, mask=offs_m_curr[:, None] < seqlen_q, other=0.0,
                                eviction_policy="evict_last")
                    dq += tl.dot(ds, k)
                    tl.store(dq_ptrs, dq, mask=offs_m_curr[:, None] < seqlen_q,
                            eviction_policy="evict_last")
                else:
                    dq = tl.load(dq_ptrs,
                                 mask=(offs_m_curr[:, None] < seqlen_q) & (offs_d[None, :] < headdim),
                                 other=0.0, eviction_policy="evict_last")
                    dq += tl.dot(ds, k)
                    tl.store(dq_ptrs, dq,
                             mask=(offs_m_curr[:, None] < seqlen_q) & (offs_d[None, :] < headdim),
                             eviction_policy="evict_last")
        else:  # If we're parallelizing across the seqlen_k dimension
            dq = tl.dot(ds, k)
            if EVEN_M & EVEN_HEADDIM:  # Race condition if we just do EVEN_M
                tl.atomic_add(dq_ptrs, dq)
            else:
                if EVEN_HEADDIM:
                    tl.atomic_add(dq_ptrs, dq, mask=offs_m_curr[:, None] < seqlen_q)
                else:
                    tl.atomic_add(dq_ptrs, dq,
                                  mask=(offs_m_curr[:, None] < seqlen_q) & (offs_d[None, :] < headdim))
        # increment pointers
        dq_ptrs += BLOCK_M * stride_dqm
        q_ptrs += BLOCK_M * stride_qm
        do_ptrs += BLOCK_M * stride_dom
        if BIAS_TYPE == 'matrix':
            b_ptrs += BLOCK_M * stride_bm
    # write-back
    dv_ptrs = DV + (offs_n[:, None] * stride_dvn + offs_d[None, :])
    dk_ptrs = DK + (offs_n[:, None] * stride_dkn + offs_d[None, :])
    _bwd_store_dk_dv(dk_ptrs, dv_ptrs, dk, dv, offs_n, offs_d, seqlen_k, headdim,
                     EVEN_M=EVEN_M, EVEN_N=EVEN_N, EVEN_HEADDIM=EVEN_HEADDIM)
Tri Dao's avatar
Tri Dao committed
477
478
479
480
481


def init_to_zero(name):
    """Build a hook that zeroes the argument called ``name`` in place.

    The returned callable takes a mapping from argument names to values, calls
    ``zero_()`` on the entry stored under ``name`` and returns whatever that
    call returns.
    """

    def _zero_named_arg(nargs):
        target = nargs[name]
        return target.zero_()

    return _zero_named_arg

482

Tri Dao's avatar
Tri Dao committed
483
484
485
486
@triton.autotune(
    configs=[
        triton.Config({"BLOCK_M": 128, "BLOCK_N": 128, "SEQUENCE_PARALLEL": False}, num_warps=8, num_stages=1, pre_hook=init_to_zero('DQ')),
        triton.Config({"BLOCK_M": 128, "BLOCK_N": 128, "SEQUENCE_PARALLEL": True}, num_warps=8, num_stages=1, pre_hook=init_to_zero('DQ')),
        # Other configs seem to give wrong results when seqlen_q % 128 != 0, disabling them for now
        # # Kernel is buggy (give wrong result) if we set BLOCK_m=128, BLOCK_n=64, num_warps=*4*
        # triton.Config({"BLOCK_M": 128, "BLOCK_N": 64, "SEQUENCE_PARALLEL": False}, num_warps=8, num_stages=1, pre_hook=init_to_zero('DQ')),
        # triton.Config({"BLOCK_M": 128, "BLOCK_N": 64, "SEQUENCE_PARALLEL": True}, num_warps=8, num_stages=1, pre_hook=init_to_zero('DQ')),
        # triton.Config({"BLOCK_M": 64, "BLOCK_N": 64, "SEQUENCE_PARALLEL": False}, num_warps=4, num_stages=1, pre_hook=init_to_zero('DQ')),
        # triton.Config({"BLOCK_M": 64, "BLOCK_N": 64, "SEQUENCE_PARALLEL": True}, num_warps=4, num_stages=1, pre_hook=init_to_zero('DQ')),
    ],
    key=['CACHE_KEY_SEQLEN_Q', 'CACHE_KEY_SEQLEN_K', 'BIAS_TYPE', 'IS_CAUSAL', 'BLOCK_HEADDIM'],
)
@triton.heuristics(
    {
        "EVEN_M": lambda args: args["seqlen_q"] % args["BLOCK_M"] == 0,
        "EVEN_N": lambda args: args["seqlen_k"] % args["BLOCK_N"] == 0,
        "EVEN_HEADDIM": lambda args: args["headdim"] == args["BLOCK_HEADDIM"],
    }
)
@triton.jit
def _bwd_kernel(
    Q, K, V, Bias,
    DO, DQ, DK, DV,
    LSE, D,
    softmax_scale,
    stride_qb, stride_qh, stride_qm,
    stride_kb, stride_kh, stride_kn,
    stride_vb, stride_vh, stride_vn,
    stride_bb, stride_bh, stride_bm,
    stride_dob, stride_doh, stride_dom,
    stride_dqb, stride_dqh, stride_dqm,
    stride_dkb, stride_dkh, stride_dkn,
    stride_dvb, stride_dvh, stride_dvn,
    nheads, seqlen_q, seqlen_k, seqlen_q_rounded, headdim,
    CACHE_KEY_SEQLEN_Q, CACHE_KEY_SEQLEN_K,
    BIAS_TYPE: tl.constexpr,
    IS_CAUSAL: tl.constexpr,
    BLOCK_HEADDIM: tl.constexpr,
    SEQUENCE_PARALLEL: tl.constexpr,
    EVEN_M: tl.constexpr, EVEN_N: tl.constexpr, EVEN_HEADDIM: tl.constexpr,
    BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr,
):
    # Backward entry point: offsets every tensor to this program's (batch, head) slice and
    # delegates the actual work to _bwd_kernel_one_col_block.
    #
    # Program axis 1 is a flattened (batch, head) index that is split using nheads.
    # - SEQUENCE_PARALLEL=False: the program loops over all key blocks itself and dq is
    #   accumulated with plain load/add/store (ATOMIC_ADD=False); program axis 0 is not read.
    # - SEQUENCE_PARALLEL=True: program axis 0 selects the single key block this program
    #   handles and dq is accumulated with atomic adds (ATOMIC_ADD=True).
    # Either way dq is only ever accumulated into DQ, which is why both autotune configs
    # attach init_to_zero('DQ') as their pre_hook.
    # CACHE_KEY_SEQLEN_Q / CACHE_KEY_SEQLEN_K are not read in the body; they appear in the
    # autotune key above.
    off_hb = tl.program_id(1)
    off_b = off_hb // nheads
    off_h = off_hb % nheads
    # offset pointers for batch/head
    Q += off_b * stride_qb + off_h * stride_qh
    K += off_b * stride_kb + off_h * stride_kh
    V += off_b * stride_vb + off_h * stride_vh
    DO += off_b * stride_dob + off_h * stride_doh
    DQ += off_b * stride_dqb + off_h * stride_dqh
    DK += off_b * stride_dkb + off_h * stride_dkh
    DV += off_b * stride_dvb + off_h * stride_dvh
    if BIAS_TYPE != 'none':
        Bias += off_b * stride_bb + off_h * stride_bh
    # pointer to row-wise quantities in value-like data
    D += off_hb * seqlen_q_rounded
    LSE += off_hb * seqlen_q_rounded
    if not SEQUENCE_PARALLEL:
        num_block_n = tl.cdiv(seqlen_k, BLOCK_N)
        for start_n in range(0, num_block_n):
            _bwd_kernel_one_col_block(
                start_n,
                Q, K, V, Bias,
                DO, DQ, DK, DV,
                LSE, D,
                softmax_scale,
                stride_qm, stride_kn, stride_vn, stride_bm,
                stride_dom, stride_dqm, stride_dkn, stride_dvn,
                seqlen_q, seqlen_k, headdim,
                ATOMIC_ADD=False,
                BIAS_TYPE=BIAS_TYPE,
                IS_CAUSAL=IS_CAUSAL,
                BLOCK_HEADDIM=BLOCK_HEADDIM,
                EVEN_M=EVEN_M, EVEN_N=EVEN_N, EVEN_HEADDIM=EVEN_HEADDIM,
                BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N
            )
    else:
        start_n = tl.program_id(0)
        _bwd_kernel_one_col_block(
            start_n,
            Q, K, V, Bias,
            DO, DQ, DK, DV,
            LSE, D,
            softmax_scale,
            stride_qm, stride_kn, stride_vn, stride_bm,
            stride_dom, stride_dqm, stride_dkn, stride_dvn,
            seqlen_q, seqlen_k, headdim,
            ATOMIC_ADD=True,
            BIAS_TYPE=BIAS_TYPE,
            IS_CAUSAL=IS_CAUSAL,
            BLOCK_HEADDIM=BLOCK_HEADDIM,
            EVEN_M=EVEN_M, EVEN_N=EVEN_N, EVEN_HEADDIM=EVEN_HEADDIM,
            BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N
        )


581
def _flash_attn_forward(q, k, v, bias=None, causal=False, softmax_scale=None):
    """Validate the inputs and launch the Triton forward kernel.

    Args:
        q: (batch, seqlen_q, nheads, headdim) CUDA tensor, fp16 or bf16.
        k, v: (batch, seqlen_k, nheads, headdim) CUDA tensors with the same dtype as q.
        bias: optional 4-D CUDA tensor, of q's dtype or fp32. Its last two dimensions
            must be (1, seqlen_k) or (seqlen_q, seqlen_k); each of its first two
            dimensions must be 1 or match (batch, nheads).
        causal: passed through to the kernel as its causal flag.
        softmax_scale: scaling passed to the kernel. Any falsy value (None or 0)
            is replaced by 1 / sqrt(headdim).

    Returns:
        (o, lse, softmax_scale): o has the shape and dtype of q; lse is fp32 of shape
        (batch, nheads, seqlen_q rounded up to a multiple of 128); softmax_scale is
        the value that was actually handed to the kernel.

    Raises:
        AssertionError: if a shape, dtype or device constraint is violated.
        RuntimeError: if the last two dimensions of bias are not supported.
    """
    # shape constraints
    batch, seqlen_q, nheads, d = q.shape
    _, seqlen_k, _, _ = k.shape
    assert k.shape == (batch, seqlen_k, nheads, d)
    assert v.shape == (batch, seqlen_k, nheads, d)
    assert d <= 128, 'FlashAttention only support head dimensions up to 128'
    assert q.dtype == k.dtype == v.dtype, 'All tensors must have the same type'
    assert q.dtype in [torch.float16, torch.bfloat16], 'Only support fp16 and bf16'
    assert q.is_cuda and k.is_cuda and v.is_cuda
    softmax_scale = softmax_scale or 1.0 / math.sqrt(d)

    has_bias = bias is not None
    bias_type = 'none'
    if has_bias:
        assert bias.dtype in [q.dtype, torch.float]
        assert bias.is_cuda
        assert bias.dim() == 4
        if bias.stride(-1) != 1:
            bias = bias.contiguous()
        if bias.shape[2:] == (1, seqlen_k):
            bias_type = 'vector'
        elif bias.shape[2:] == (seqlen_q, seqlen_k):
            bias_type = 'matrix'
        else:
            raise RuntimeError('Last 2 dimensions of bias must be (1, seqlen_k)'
                               ' or (seqlen_q, seqlen_k)')
        assert bias.shape[0] in (1, batch) and bias.shape[1] in (1, nheads), \
            'First 2 dimensions of bias must be broadcastible to (batch, nheads)'
        # Broadcast over batch / heads as a view (stride 0 on the broadcast dims)
        # instead of materialising a copy. This also accepts a (1, 1, ...) bias.
        bias = bias.expand(batch, nheads, -1, -1)
    bias_strides = (bias.stride(0), bias.stride(1), bias.stride(2)) if has_bias else (0, 0, 0)

    # lse is padded along seqlen_q up to a multiple of 128.
    seqlen_q_rounded = math.ceil(seqlen_q / 128) * 128
    lse = torch.empty((batch, nheads, seqlen_q_rounded), device=q.device, dtype=torch.float32)
    # Scratch buffer handed to the kernel; its contents are not returned.
    tmp = torch.empty((batch, nheads, seqlen_q_rounded), device=q.device, dtype=torch.float32)
    o = torch.empty_like(q)

    BLOCK_HEADDIM = max(triton.next_power_of_2(d), 16)
    BLOCK = 128
    num_warps = 4 if d <= 64 else 8
    grid = lambda META: (triton.cdiv(seqlen_q, META["BLOCK_M"]), batch * nheads)
    # Tensors are laid out (batch, seqlen, nheads, headdim), so the strides below
    # are passed in (batch, head, seqlen) order.
    _fwd_kernel[grid](
        q, k, v, bias, o,
        lse, tmp,
        softmax_scale,
        q.stride(0), q.stride(2), q.stride(1),
        k.stride(0), k.stride(2), k.stride(1),
        v.stride(0), v.stride(2), v.stride(1),
        *bias_strides,
        o.stride(0), o.stride(2), o.stride(1),
        nheads, seqlen_q, seqlen_k, seqlen_q_rounded, d,
        seqlen_q // 32,  seqlen_k // 32, # key for triton cache (limit number of compilations)
        # Can't use kwargs here because triton autotune expects key to be args, not kwargs
        # IS_CAUSAL=causal, BLOCK_HEADDIM=d,
        bias_type, causal, BLOCK_HEADDIM,
        BLOCK_M=BLOCK, BLOCK_N=BLOCK,
        num_warps=num_warps,
        num_stages=1,
    )
    return o, lse, softmax_scale  # softmax_scale could have been updated


645
def _flash_attn_backward(do, q, k, v, o, lse, dq, dk, dv, bias=None, causal=False, softmax_scale=None):
    """Launch the Triton backward kernels; gradients are written into dq, dk, dv.

    Args:
        do: gradient of the attention output; made contiguous in its last
            dimension if it is not already.
        q: (batch, seqlen_q, nheads, headdim).
        k, v: tensors whose shape starts with (batch, seqlen_k, ...); only
            k's seqlen_k is read here.
        o: forward output; its last dimension must be contiguous.
        lse: (batch, nheads, seqlen_q rounded up to a multiple of 128), as
            produced by the forward pass.
        dq, dk, dv: preallocated output tensors with a contiguous last dimension.
            dk and dv are handed to the kernel directly; dq receives a copy of an
            fp32 accumulator once the kernel has run.
        bias: optional 4-D CUDA tensor with the same constraints as in the
            forward pass.
        causal: passed through to the kernel as its causal flag.
        softmax_scale: scaling passed to the kernel. Any falsy value (None or 0)
            is replaced by 1 / sqrt(headdim).

    Returns:
        None. Results are stored in dq, dk and dv.

    Raises:
        AssertionError: if a shape, stride, dtype or device constraint is violated.
        RuntimeError: if the last two dimensions of bias are not supported.
    """
    # Make sure that the last dimension is contiguous
    if do.stride(-1) != 1:
        do = do.contiguous()
    batch, seqlen_q, nheads, d = q.shape
    _, seqlen_k, _, _ = k.shape
    # assert d in {16, 32, 64, 128}
    assert d <= 128
    seqlen_q_rounded = math.ceil(seqlen_q / 128) * 128
    assert lse.shape == (batch, nheads, seqlen_q_rounded)
    assert q.stride(-1) == k.stride(-1) == v.stride(-1) == o.stride(-1) == 1
    assert dq.stride(-1) == dk.stride(-1) == dv.stride(-1) == 1
    softmax_scale = softmax_scale or 1.0 / math.sqrt(d)
    # Both buffers are allocated uninitialised on purpose (the zeros variants were
    # tried and dropped). NOTE(review): this relies on the kernels initialising
    # every element they later read -- confirm against the kernel definitions.
    dq_accum = torch.empty_like(q, dtype=torch.float32)
    delta = torch.empty_like(lse)

    BLOCK_HEADDIM = max(triton.next_power_of_2(d), 16)
    grid = lambda META: (triton.cdiv(seqlen_q, META["BLOCK_M"]), batch * nheads)
    # Preprocessing launch over (o, do) that fills delta, which the main kernel consumes.
    _bwd_preprocess_do_o_dot[grid](
        o, do, delta,
        o.stride(0), o.stride(2), o.stride(1),
        do.stride(0), do.stride(2), do.stride(1),
        nheads, seqlen_q, seqlen_q_rounded, d,
        BLOCK_M=128, BLOCK_HEADDIM=BLOCK_HEADDIM,
    )

    has_bias = bias is not None
    bias_type = 'none'
    if has_bias:
        assert bias.dtype in [q.dtype, torch.float]
        assert bias.is_cuda
        assert bias.dim() == 4
        # The forward pass only makes a local contiguous copy of the bias, while the
        # autograd wrappers save the caller's original tensor. Redo the conversion
        # here instead of asserting, so a bias accepted in forward works in backward.
        if bias.stride(-1) != 1:
            bias = bias.contiguous()
        if bias.shape[2:] == (1, seqlen_k):
            bias_type = 'vector'
        elif bias.shape[2:] == (seqlen_q, seqlen_k):
            bias_type = 'matrix'
        else:
            raise RuntimeError('Last 2 dimensions of bias must be (1, seqlen_k)'
                               ' or (seqlen_q, seqlen_k)')
        assert bias.shape[0] in (1, batch) and bias.shape[1] in (1, nheads), \
            'First 2 dimensions of bias must be broadcastible to (batch, nheads)'
        # Broadcast over batch / heads as a view (stride 0 on the broadcast dims)
        # instead of materialising a copy. This also accepts a (1, 1, ...) bias.
        bias = bias.expand(batch, nheads, -1, -1)
    bias_strides = (bias.stride(0), bias.stride(1), bias.stride(2)) if has_bias else (0, 0, 0)

    # BLOCK_M = 128
    # BLOCK_N = 64
    # num_warps = 4
    # BLOCK_N and SEQUENCE_PARALLEL are not passed below; the grid reads them from
    # META (presumably filled in by the kernel's autotune configs -- confirm).
    grid = lambda META: (triton.cdiv(seqlen_k, META["BLOCK_N"]) if META["SEQUENCE_PARALLEL"] else 1,
                    batch * nheads)
    _bwd_kernel[grid](
        q, k, v, bias,
        do, dq_accum, dk, dv,
        lse, delta,
        softmax_scale,
        q.stride(0), q.stride(2), q.stride(1),
        k.stride(0), k.stride(2), k.stride(1),
        v.stride(0), v.stride(2), v.stride(1),
        *bias_strides,
        do.stride(0), do.stride(2), do.stride(1),
        dq_accum.stride(0), dq_accum.stride(2), dq_accum.stride(1),
        dk.stride(0), dk.stride(2), dk.stride(1),
        dv.stride(0), dv.stride(2), dv.stride(1),
        nheads, seqlen_q, seqlen_k, seqlen_q_rounded, d,
        seqlen_q // 32,  seqlen_k // 32, # key for triton cache (limit number of compilations)
        # Can't use kwargs here because triton autotune expects key to be args, not kwargs
        # IS_CAUSAL=causal, BLOCK_HEADDIM=d,
        bias_type, causal, BLOCK_HEADDIM,
        # SEQUENCE_PARALLEL=False,
        # BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N,
        # num_warps=num_warps,
        # num_stages=1,
    )
    # Copy the fp32 accumulator into the caller's dq, converting to its dtype.
    dq.copy_(dq_accum)


class FlashAttnQKVPackedFunc(torch.autograd.Function):
    """Autograd function for attention on a single packed qkv tensor."""

    @staticmethod
    def forward(ctx, qkv, bias=None, causal=False, softmax_scale=None):
        """
            qkv: (batch, seqlen, 3, nheads, headdim)
            bias: optional, shape broadcastible to (batch, nheads, seqlen, seqlen).
                For example, ALiBi mask for causal would have shape (1, nheads, 1, seqlen).
                ALiBi mask for non-causal would have shape (1, nheads, seqlen, seqlen)
            causal: forwarded to _flash_attn_forward.
            softmax_scale: forwarded to _flash_attn_forward; the value it returns
                is kept on ctx for the backward pass.
            Returns the attention output o.
        """
        # Make sure that the last dimension is contiguous
        if qkv.stride(-1) != 1:
            qkv = qkv.contiguous()
        # Index 0 / 1 / 2 along dim 2 selects q / k / v.
        o, lse, ctx.softmax_scale = _flash_attn_forward(
            qkv[:, :, 0], qkv[:, :, 1], qkv[:, :, 2], bias=bias, causal=causal,
            softmax_scale=softmax_scale
        )
        ctx.save_for_backward(qkv, o, lse, bias)
        ctx.causal = causal
        return o

    @staticmethod
    def backward(ctx, do):
        """
            do: gradient with respect to the output of forward.
            Returns (dqkv, None, None, None): a gradient for qkv only; bias, causal
            and softmax_scale get None.
        """
        qkv, o, lse, bias = ctx.saved_tensors
        assert not ctx.needs_input_grad[1], 'FlashAttention does not support bias gradient yet'
        # Triton's autotune causes the Tensor._version to change, and so Pytorch autograd
        # does a memcpy. To avoid this we run in inference_mode, which doesn't track the version.
        with torch.inference_mode():
            dqkv = torch.empty_like(qkv)
            # The three slices of dqkv are views, so the gradients land in dqkv itself.
            _flash_attn_backward(do, qkv[:, :, 0], qkv[:, :, 1], qkv[:, :, 2], o, lse,
                                 dqkv[:, :, 0], dqkv[:, :, 1], dqkv[:, :, 2],
                                 bias=bias, causal=ctx.causal, softmax_scale=ctx.softmax_scale)
        return dqkv, None, None, None


# Functional entry point: flash_attn_qkvpacked_func(qkv, bias, causal, softmax_scale).
flash_attn_qkvpacked_func = FlashAttnQKVPackedFunc.apply


class FlashAttnKVPackedFunc(torch.autograd.Function):
    """Autograd function for attention with a separate q and a packed kv tensor."""

    @staticmethod
    def forward(ctx, q, kv, bias=None, causal=False, softmax_scale=None):
        """
            q: (batch, seqlen_q, nheads, headdim)
            kv: (batch, seqlen_k, 2, nheads, headdim)
            bias: optional, shape broadcastible to (batch, nheads, seqlen_q, seqlen_k).
                For example, ALiBi mask for causal would have shape (1, nheads, 1, seqlen_k).
                ALiBi mask for non-causal would have shape (1, nheads, seqlen_q, seqlen_k)
            causal: forwarded to _flash_attn_forward.
            softmax_scale: forwarded to _flash_attn_forward; the value it returns
                is kept on ctx for the backward pass.
            Returns the attention output o.
        """
        # Make sure that the last dimension is contiguous
        q, kv = [x if x.stride(-1) == 1 else x.contiguous() for x in [q, kv]]
        # Index 0 / 1 along dim 2 of kv selects k / v.
        o, lse, ctx.softmax_scale = _flash_attn_forward(
            q, kv[:, :, 0], kv[:, :, 1], bias=bias, causal=causal, softmax_scale=softmax_scale
        )
        ctx.save_for_backward(q, kv, o, lse, bias)
        ctx.causal = causal
        return o

    @staticmethod
    def backward(ctx, do):
        """
            do: gradient with respect to the output of forward.
            Returns (dq, dkv, None, None, None): gradients for q and kv; bias,
            causal and softmax_scale get None.
        """
        q, kv, o, lse, bias = ctx.saved_tensors
        assert not ctx.needs_input_grad[2], 'FlashAttention does not support bias gradient yet'
        # Triton's autotune causes the Tensor._version to change, and so Pytorch autograd
        # does a memcpy. To avoid this we run in inference_mode, which doesn't track the version.
        with torch.inference_mode():
            dq = torch.empty_like(q)
            dkv = torch.empty_like(kv)
            # The two slices of dkv are views, so the k / v gradients land in dkv itself.
            _flash_attn_backward(do, q, kv[:, :, 0], kv[:, :, 1], o, lse,
                                 dq, dkv[:, :, 0], dkv[:, :, 1],
                                 bias=bias, causal=ctx.causal, softmax_scale=ctx.softmax_scale)
        return dq, dkv, None, None, None


# Functional entry point: flash_attn_kvpacked_func(q, kv, bias, causal, softmax_scale).
flash_attn_kvpacked_func = FlashAttnKVPackedFunc.apply


class FlashAttnFunc(torch.autograd.Function):
    """Autograd function for attention with separate q, k and v tensors."""

    @staticmethod
    def forward(ctx, q, k, v, bias=None, causal=False, softmax_scale=None):
        """
            q: (batch_size, seqlen_q, nheads, headdim)
            k, v: (batch_size, seqlen_k, nheads, headdim)
            bias: optional, shape broadcastible to (batch, nheads, seqlen_q, seqlen_k).
                For example, ALiBi mask for causal would have shape (1, nheads, 1, seqlen_k).
                ALiBi mask for non-causal would have shape (1, nheads, seqlen_q, seqlen_k)
            causal: forwarded to _flash_attn_forward.
            softmax_scale: forwarded to _flash_attn_forward; the value it returns
                is kept on ctx for the backward pass.
            Returns the attention output o.
        """
        # Make sure that the last dimension is contiguous
        q, k, v = [x if x.stride(-1) == 1 else x.contiguous() for x in [q, k, v]]
        o, lse, ctx.softmax_scale = _flash_attn_forward(
            q, k, v, bias=bias, causal=causal, softmax_scale=softmax_scale
        )
        ctx.save_for_backward(q, k, v, o, lse, bias)
        ctx.causal = causal
        return o

    @staticmethod
    def backward(ctx, do):
        """
            do: gradient with respect to the output of forward.
            Returns (dq, dk, dv, None, None, None): gradients for q, k and v; bias,
            causal and softmax_scale get None.
        """
        q, k, v, o, lse, bias = ctx.saved_tensors
        assert not ctx.needs_input_grad[3], 'FlashAttention does not support bias gradient yet'
        # Triton's autotune causes the Tensor._version to change, and so Pytorch autograd
        # does a memcpy. To avoid this we run in inference_mode, which doesn't track the version.
        with torch.inference_mode():
            dq = torch.empty_like(q)
            dk = torch.empty_like(k)
            dv = torch.empty_like(v)
            _flash_attn_backward(do, q, k, v, o, lse, dq, dk, dv,
                                 bias=bias, causal=ctx.causal, softmax_scale=ctx.softmax_scale)
        return dq, dk, dv, None, None, None


# Functional entry point: flash_attn_func(q, k, v, bias, causal, softmax_scale).
flash_attn_func = FlashAttnFunc.apply