flash_attn_triton.py 36.9 KB
Newer Older
Tri Dao's avatar
Tri Dao committed
1
"""
2
3
*Experimental* implementation of FlashAttention in Triton.

Tri Dao's avatar
Tri Dao committed
4
We use the FlashAttention implementation from Phil Tillet as a starting point.
Tri Dao's avatar
Tri Dao committed
5
6
7
https://github.com/openai/triton/blob/master/python/tutorials/06-fused-attention.py

Changes:
8
- Implement both causal and non-causal attention.
9
- Implement both self-attention and cross-attention.
10
- Support arbitrary seqlens (not just multiples of 128), for both forward and backward.
11
- Support all head dimensions up to 128 (not just 16, 32, 64, 128), for both forward and backward.
12
- Support attention bias.
13
- Speed up the forward pass a bit, and only store the LSE instead of m and l.
Tri Dao's avatar
Tri Dao committed
14
- Make the backward for d=128 much faster by reducing register spilling.
15
- Optionally parallelize the backward pass across seqlen_k, to deal with the case of
Tri Dao's avatar
Tri Dao committed
16
small batch size * nheads.
Tri Dao's avatar
Tri Dao committed
17

18
Caution:
19
20
- This is an *experimental* implementation. The forward pass should be quite robust but
I'm not 100% sure that the backward pass doesn't have race conditions (due to the Triton compiler).
21
- This implementation has only been tested on A100.
22
23
24
25
26
27
- If you plan to use headdim other than 64 and 128, you should test for race conditions
(due to the Triton compiler), as done in tests/test_flash_attn.py
"test_flash_attn_triton_race_condition". I've tested and fixed many race conditions
for different head dimensions (40, 48, 64, 128, 80, 88, 96), but I'm still not 100% confident
that there are none left for other head dimensions.

Tri Dao's avatar
Tri Dao committed
28
29
Differences between this Triton version and the CUDA version:
- Triton version doesn't support dropout.
30
31
32
33
34
- Triton forward is generally faster than CUDA forward, while Triton backward is
generally slower than CUDA backward. Overall Triton forward + backward is slightly slower
than CUDA forward + backward.
- Triton version doesn't support different sequence lengths in a batch (i.e., RaggedTensor/NestedTensor).
- Triton version supports attention bias, while CUDA version doesn't.
Tri Dao's avatar
Tri Dao committed
35
36
37
38
39
40
41
42
43
44
"""

import math

import torch

import triton
import triton.language as tl


45
46
47
48
49
50
51
52
53
# Disabling autotune for now, set num_warps=4 if headdim=64 and num_warps=8 if headdim=128
# @triton.autotune(
#     configs=[
#         triton.Config({"BLOCK_M": 128, "BLOCK_N": 128}, num_warps=4, num_stages=1),
#         # This config has a race condition when EVEN_M == False, disabling it for now.
#         # triton.Config({"BLOCK_M": 64, "BLOCK_N": 64}, num_warps=4, num_stages=1),
#     ],
#     key=['CACHE_KEY_SEQLEN_Q', 'CACHE_KEY_SEQLEN_K', 'BIAS_TYPE', 'IS_CAUSAL', 'BLOCK_HEADDIM']
# )
Tri Dao's avatar
Tri Dao committed
54
55
56
@triton.heuristics(
    {
        "EVEN_M": lambda args: args["seqlen_q"] % args["BLOCK_M"] == 0,
        "EVEN_N": lambda args: args["seqlen_k"] % args["BLOCK_N"] == 0,
        "EVEN_HEADDIM": lambda args: args["headdim"] == args["BLOCK_HEADDIM"],
    }
)
@triton.jit
def _fwd_kernel(
    Q, K, V, Bias, Out,
    Lse, TMP,  # NOTE: TMP is a scratchpad buffer to workaround a compiler bug
    softmax_scale,
    stride_qb, stride_qh, stride_qm,
    stride_kb, stride_kh, stride_kn,
    stride_vb, stride_vh, stride_vn,
    stride_bb, stride_bh, stride_bm,
    stride_ob, stride_oh, stride_om,
    nheads, seqlen_q, seqlen_k, seqlen_q_rounded, headdim,
    CACHE_KEY_SEQLEN_Q, CACHE_KEY_SEQLEN_K,
    BIAS_TYPE: tl.constexpr,
    IS_CAUSAL: tl.constexpr,
    BLOCK_HEADDIM: tl.constexpr,
    EVEN_M: tl.constexpr, EVEN_N: tl.constexpr, EVEN_HEADDIM: tl.constexpr,
    BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr,
):
    # Forward attention kernel. Each program instance handles one block of BLOCK_M query
    # rows (program_id(0)) for one (batch, head) pair (program_id(1) = batch * nheads + head).
    # It walks over the keys/values in blocks of BLOCK_N, keeping a running row maximum
    # (m_i), a running log-sum-exp (lse_i) and an output accumulator (acc_o) that is
    # rescaled on every step, then stores the final output block to Out and lse_i to Lse.
    #
    # Compile-time switches:
    #   BIAS_TYPE     - 'none', 'vector' (bias indexed by key position only) or 'matrix'
    #                   (bias indexed by query and key position).
    #   IS_CAUSAL     - mask out keys whose index is greater than the query index.
    #   EVEN_M/EVEN_N - seqlen_q / seqlen_k is a multiple of BLOCK_M / BLOCK_N (set by the
    #                   heuristics above), which allows unmasked loads and stores.
    #   EVEN_HEADDIM  - headdim == BLOCK_HEADDIM, so no masking along the head dimension.
    # CACHE_KEY_SEQLEN_Q / CACHE_KEY_SEQLEN_K are not read in the body; they only appear in
    # the (currently commented-out) autotune key.
    # NOTE(review): the head-dimension offset offs_d is added without a stride, so the last
    # dimension of Q/K/V/Out is presumably expected to be contiguous -- confirm at call site.
    start_m = tl.program_id(0)
    off_hb = tl.program_id(1)
    off_b = off_hb // nheads
    off_h = off_hb % nheads
    # off_b = tl.program_id(1)
    # off_h = tl.program_id(2)
    # off_hb = off_b * nheads + off_h
    # initialize offsets
    offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_n = tl.arange(0, BLOCK_N)
    offs_d = tl.arange(0, BLOCK_HEADDIM)
    # Initialize pointers to Q, K, V
    # Adding parenthesis around indexing might use int32 math instead of int64 math?
    # https://github.com/openai/triton/issues/741
    # I'm seeing a tiny bit of difference (5-7us)
    q_ptrs = Q + off_b * stride_qb + off_h * stride_qh + (offs_m[:, None] * stride_qm + offs_d[None, :])
    k_ptrs = K + off_b * stride_kb + off_h * stride_kh + (offs_n[:, None] * stride_kn + offs_d[None, :])
    v_ptrs = V + off_b * stride_vb + off_h * stride_vh + (offs_n[:, None] * stride_vn + offs_d[None, :])
    # Bias pointers: a 'vector' bias is addressed by key offset only, a 'matrix' bias by
    # (query row, key offset). For BIAS_TYPE == 'none', b_ptrs is never defined or used.
    if BIAS_TYPE == 'vector':
        b_ptrs = Bias + off_b * stride_bb + off_h * stride_bh + offs_n
    elif BIAS_TYPE == 'matrix':
        b_ptrs = Bias + off_b * stride_bb + off_h * stride_bh + (offs_m[:, None] * stride_bm + offs_n[None, :])
    # Pointers into the TMP scratchpad (one float32 slot per query row of this program),
    # followed by the running softmax statistics and the output accumulator.
    t_ptrs = TMP + off_hb * seqlen_q_rounded + offs_m
    lse_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
    m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
    acc_o = tl.zeros([BLOCK_M, BLOCK_HEADDIM], dtype=tl.float32)
    # load q: it will stay in SRAM throughout
    # [2022-10-30] TD: Triton bug - in the case of EVEN_M=True and EVEN_N=False, if we just call
    # tl.load(q_ptrs), we get the wrong output!
    # (This is why the unmasked path below is guarded by EVEN_M & EVEN_N rather than EVEN_M.)
    if EVEN_M & EVEN_N:
        if EVEN_HEADDIM:
            q = tl.load(q_ptrs)
        else:
            q = tl.load(q_ptrs, mask=offs_d[None, :] < headdim, other=0.0)
    else:
        if EVEN_HEADDIM:
            q = tl.load(q_ptrs, mask=offs_m[:, None] < seqlen_q, other=0.0)
        else:
            q = tl.load(q_ptrs, mask=(offs_m[:, None] < seqlen_q) & (offs_d[None, :] < headdim),
                        other=0.0)
    # loop over k, v and update accumulator
    # In the causal case the loop stops at the end of this query block's diagonal, since
    # keys beyond (start_m + 1) * BLOCK_M are masked out for every row of the block.
    end_n = seqlen_k if not IS_CAUSAL else tl.minimum((start_m + 1) * BLOCK_M, seqlen_k)
    for start_n in range(0, end_n, BLOCK_N):
        start_n = tl.multiple_of(start_n, BLOCK_N)
        # -- compute qk ----
        if EVEN_N & EVEN_M:  # If we just do "if EVEN_N", there seems to be some race condition
            if EVEN_HEADDIM:
                k = tl.load(k_ptrs + start_n * stride_kn)
            else:
                k = tl.load(k_ptrs + start_n * stride_kn, mask=offs_d[None, :] < headdim, other=0.0)
        else:
            if EVEN_HEADDIM:
                k = tl.load(k_ptrs + start_n * stride_kn, mask=(start_n + offs_n)[:, None] < seqlen_k,
                            other=0.0)
            else:
                k = tl.load(k_ptrs + start_n * stride_kn,
                            mask=((start_n + offs_n)[:, None] < seqlen_k) & (offs_d[None, :] < headdim),
                            other=0.0)
        qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
        qk += tl.dot(q, k, trans_b=True)
        # Trying to combine the two masks seems to make the result wrong
        if not EVEN_N:  # Need to mask out otherwise the softmax is wrong
            qk += tl.where((start_n + offs_n)[None, :] < seqlen_k, 0, float("-inf"))
        if IS_CAUSAL:
            qk += tl.where(offs_m[:, None] >= (start_n + offs_n)[None, :], 0, float("-inf"))
        if BIAS_TYPE != 'none':
            if BIAS_TYPE == 'vector':
                if EVEN_N:
                    bias = tl.load(b_ptrs + start_n).to(tl.float32)
                else:
                    bias = tl.load(b_ptrs + start_n, mask=(start_n + offs_n) < seqlen_k, other=0.0).to(tl.float32)
                # broadcast the per-key bias over all query rows of the block
                bias = bias[None, :]
            elif BIAS_TYPE == 'matrix':
                if EVEN_M & EVEN_N:
                    bias = tl.load(b_ptrs + start_n).to(tl.float32)
                else:
                    bias = tl.load(b_ptrs + start_n,
                                   mask=(offs_m[:, None] < seqlen_q)
                                        & ((start_n + offs_n)[None, :] < seqlen_k),
                                   other=0.0).to(tl.float32)
            # Slightly faster to multiply the softmax_scale in the tl.exp below since the compiler
            # can then fuse the mult and add into an fma instruction. But if we have bias we need
            # to multiply with softmax_scale here.
            qk = qk * softmax_scale + bias
            # new row maximum: max over this block's scores and the running log-sum-exp
            m_ij = tl.maximum(tl.max(qk, 1), lse_i)
            p = tl.exp(qk - m_ij[:, None])
        else:
            # no bias: softmax_scale is folded into the max and the exponent instead
            m_ij = tl.maximum(tl.max(qk, 1) * softmax_scale, lse_i)
            p = tl.exp(qk * softmax_scale - m_ij[:, None])
        l_ij = tl.sum(p, 1)

        # scale acc_o
        # (factor that re-expresses the accumulator relative to the new maximum m_ij)
        acc_o_scale = tl.exp(m_i - m_ij)

        # # -- update output accumulator --
        # BUG: have to store and immediately load
        # (round trip through the TMP scratchpad; see the NOTE on the TMP parameter)
        tl.store(t_ptrs, acc_o_scale)
        acc_o_scale = tl.load(t_ptrs)
        acc_o = acc_o * acc_o_scale[:, None]
        # update acc_o
        if EVEN_N & EVEN_M:  # If we just do "if EVEN_N", there seems to be some race condition
            if EVEN_HEADDIM:
                v = tl.load(v_ptrs + start_n * stride_vn)
            else:
                v = tl.load(v_ptrs + start_n * stride_vn, mask=offs_d[None, :] < headdim, other=0.0)
        else:
            if EVEN_HEADDIM:
                v = tl.load(v_ptrs + start_n * stride_vn, mask=(start_n + offs_n)[:, None] < seqlen_k,
                            other=0.0)
            else:
                v = tl.load(v_ptrs + start_n * stride_vn,
                            mask=((start_n + offs_n)[:, None] < seqlen_k) & (offs_d[None, :] < headdim),
                            other=0.0)
        # cast the probabilities to v's dtype before the second matmul
        p = p.to(v.dtype)
        acc_o += tl.dot(p, v)

        # -- update statistics
        # lse_i becomes m_ij + log(exp(old lse_i - m_ij) + sum of p over this block)
        m_i = m_ij
        l_i_new = tl.exp(lse_i - m_ij) + l_ij
        lse_i = m_ij + tl.log(l_i_new)

    # Final normalisation: acc_o is expressed relative to m_i, so exp(m_i - lse_i) turns it
    # into the softmax-weighted sum.
    o_scale = tl.exp(m_i - lse_i)
    # BUG: have to store and immediately load
    tl.store(t_ptrs, o_scale)
    o_scale = tl.load(t_ptrs)
    acc_o = acc_o * o_scale[:, None]
    # rematerialize offsets to save registers
    start_m = tl.program_id(0)
    offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
    # write back the log-sum-exp (stored unmasked; rows are addressed within
    # seqlen_q_rounded per (batch, head))
    lse_ptrs = Lse + off_hb * seqlen_q_rounded + offs_m
    tl.store(lse_ptrs, lse_i)
    # initialize pointers to output
    offs_d = tl.arange(0, BLOCK_HEADDIM)
    out_ptrs = Out + off_b * stride_ob + off_h * stride_oh + (offs_m[:, None] * stride_om + offs_d[None, :])
    # store the output block, masking rows past seqlen_q and columns past headdim as needed
    if EVEN_M:
        if EVEN_HEADDIM:
            tl.store(out_ptrs, acc_o)
        else:
            tl.store(out_ptrs, acc_o, mask=offs_d[None, :] < headdim)
    else:
        if EVEN_HEADDIM:
            tl.store(out_ptrs, acc_o, mask=offs_m[:, None] < seqlen_q)
        else:
            tl.store(out_ptrs, acc_o,
                     mask=(offs_m[:, None] < seqlen_q) & (offs_d[None, :] < headdim))
Tri Dao's avatar
Tri Dao committed
226
227
228
229
230
231
232


@triton.jit
def _bwd_preprocess_do_o_dot(
    Out, DO, Delta,
    stride_ob, stride_oh, stride_om,
    stride_dob, stride_doh, stride_dom,
    nheads, seqlen_q, seqlen_q_rounded, headdim,
    BLOCK_M: tl.constexpr, BLOCK_HEADDIM: tl.constexpr,
):
    # Backward pre-processing: for one block of BLOCK_M query rows (program_id(0)) of one
    # (batch, head) pair (program_id(1) = batch * nheads + head), compute the row-wise
    # sum over the head dimension of Out * DO and store it to Delta.
    # Rows past seqlen_q and columns past headdim are loaded as 0, so they contribute 0.
    start_m = tl.program_id(0)
    off_hb = tl.program_id(1)
    off_b = off_hb // nheads
    off_h = off_hb % nheads
    # initialize offsets
    offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_d = tl.arange(0, BLOCK_HEADDIM)
    # load (both operands are upcast to float32 before the product)
    o = tl.load(Out + off_b * stride_ob + off_h * stride_oh + offs_m[:, None] * stride_om + offs_d[None, :],
                mask=(offs_m[:, None] < seqlen_q) & (offs_d[None, :] < headdim), other=0.0).to(tl.float32)
    do = tl.load(DO + off_b * stride_dob + off_h * stride_doh + offs_m[:, None] * stride_dom + offs_d[None, :],
                 mask=(offs_m[:, None] < seqlen_q) & (offs_d[None, :] < headdim), other=0.0).to(tl.float32)
    delta = tl.sum(o * do, axis=1)
    # write-back (unmasked: Delta is addressed with seqlen_q_rounded rows per (batch, head);
    # NOTE(review): presumably seqlen_q_rounded is a multiple of BLOCK_M -- confirm at call site)
    tl.store(Delta + off_hb * seqlen_q_rounded + offs_m, delta)


253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
@triton.jit
def _bwd_store_dk_dv(
    dk_ptrs, dv_ptrs, dk, dv, offs_n, offs_d, seqlen_k, headdim,
    EVEN_M: tl.constexpr, EVEN_N: tl.constexpr, EVEN_HEADDIM: tl.constexpr,
):
    # Store the dk and dv tiles to dk_ptrs / dv_ptrs, choosing the store mask from the
    # compile-time flags:
    #   - rows (offs_n) are masked against seqlen_k unless both EVEN_N and EVEN_M hold;
    #   - columns (offs_d) are masked against headdim unless EVEN_HEADDIM holds.
    # [2022-11-01] TD: Same bug. In the case of EVEN_N=True and EVEN_M=False,
    # if we just call tl.store(dv_ptrs), there's a race condition
    # (hence the unmasked path requires EVEN_M as well, even though only the key
    # dimension is being written).
    if EVEN_N & EVEN_M:
        if EVEN_HEADDIM:
            tl.store(dv_ptrs, dv)
            tl.store(dk_ptrs, dk)
        else:
            tl.store(dv_ptrs, dv, mask=offs_d[None, :] < headdim)
            tl.store(dk_ptrs, dk, mask=offs_d[None, :] < headdim)
    else:
        if EVEN_HEADDIM:
            tl.store(dv_ptrs, dv, mask=offs_n[:, None] < seqlen_k)
            tl.store(dk_ptrs, dk, mask=offs_n[:, None] < seqlen_k)
        else:
            tl.store(dv_ptrs, dv, mask=(offs_n[:, None] < seqlen_k) & (offs_d[None, :] < headdim))
            tl.store(dk_ptrs, dk, mask=(offs_n[:, None] < seqlen_k) & (offs_d[None, :] < headdim))


Tri Dao's avatar
Tri Dao committed
276
277
278
@triton.jit
def _bwd_kernel_one_col_block(
    start_n,
    Q, K, V, Bias,
    DO, DQ, DK, DV,
    LSE, D,
    softmax_scale,
    stride_qm, stride_kn, stride_vn, stride_bm,
    stride_dom, stride_dqm, stride_dkn, stride_dvn,
    seqlen_q, seqlen_k, headdim,
    ATOMIC_ADD: tl.constexpr,
    BIAS_TYPE: tl.constexpr,
    IS_CAUSAL: tl.constexpr,
    BLOCK_HEADDIM: tl.constexpr,
    EVEN_M: tl.constexpr, EVEN_N: tl.constexpr, EVEN_HEADDIM: tl.constexpr,
    BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr,
):
    # Backward pass for one block of BLOCK_N key/value rows (block index start_n).
    # The k and v tiles are loaded once; the function then loops over blocks of BLOCK_M
    # query rows, recomputes p = exp(scores - lse) from q, k and the stored LSE,
    # accumulates dv and dk for this key block, and adds this block's contribution to dq
    # (load/add/store when ATOMIC_ADD is False, tl.atomic_add when it is True).
    # At the end dk and dv are written out through _bwd_store_dk_dv.
    # The tensor pointers are used without any batch/head offset here; the caller
    # (_bwd_kernel) applies those offsets before calling.
    #
    # We need to make sure begin_m is a multiple of BLOCK_M (not BLOCK_N)
    # (in the causal case, query blocks entirely before this key block are skipped)
    begin_m = 0 if not IS_CAUSAL else ((start_n * BLOCK_N) // BLOCK_M) * BLOCK_M
    # initialize row/col offsets
    offs_qm = begin_m + tl.arange(0, BLOCK_M)
    offs_n = start_n * BLOCK_N + tl.arange(0, BLOCK_N)
    offs_m = tl.arange(0, BLOCK_M)
    offs_d = tl.arange(0, BLOCK_HEADDIM)
    # initialize pointers to value-like data
    q_ptrs = Q + (offs_qm[:, None] * stride_qm + offs_d[None, :])
    k_ptrs = K + (offs_n[:, None] * stride_kn + offs_d[None, :])
    v_ptrs = V + (offs_n[:, None] * stride_vn + offs_d[None, :])
    do_ptrs = DO + (offs_qm[:, None] * stride_dom + offs_d[None, :])
    dq_ptrs = DQ + (offs_qm[:, None] * stride_dqm + offs_d[None, :])
    if BIAS_TYPE == 'vector':
        b_ptrs = Bias + offs_n
    elif BIAS_TYPE == 'matrix':
        b_ptrs = Bias + (offs_qm[:, None] * stride_bm + offs_n[None, :])
    # initialize dv and dk
    dv = tl.zeros([BLOCK_N, BLOCK_HEADDIM], dtype=tl.float32)
    dk = tl.zeros([BLOCK_N, BLOCK_HEADDIM], dtype=tl.float32)
    # There seems to be some problem with Triton pipelining that makes results wrong for
    # headdim=64, seqlen=(113, 255), bias_type='matrix'. In this case the for loop
    # may have zero step, and pipelining with the bias matrix could screw it up.
    # So we just exit early.
    # (the early exit still writes the all-zero dk/dv tiles for this key block)
    if begin_m >= seqlen_q:
        dv_ptrs = DV + (offs_n[:, None] * stride_dvn + offs_d[None, :])
        dk_ptrs = DK + (offs_n[:, None] * stride_dkn + offs_d[None, :])
        _bwd_store_dk_dv(dk_ptrs, dv_ptrs, dk, dv, offs_n, offs_d, seqlen_k, headdim,
                         EVEN_M=EVEN_M, EVEN_N=EVEN_N, EVEN_HEADDIM=EVEN_HEADDIM)
        return
    # k and v stay in SRAM throughout
    # [2022-10-30] TD: Same bug as the fwd. In the case of EVEN_N=True and EVEN_M=False,
    # if we just call tl.load(k_ptrs), we get the wrong output!
    if EVEN_N & EVEN_M:
        if EVEN_HEADDIM:
            k = tl.load(k_ptrs)
            v = tl.load(v_ptrs)
        else:
            k = tl.load(k_ptrs, mask=offs_d[None, :] < headdim, other=0.0)
            v = tl.load(v_ptrs, mask=offs_d[None, :] < headdim, other=0.0)
    else:
        if EVEN_HEADDIM:
            k = tl.load(k_ptrs, mask=offs_n[:, None] < seqlen_k, other=0.0)
            v = tl.load(v_ptrs, mask=offs_n[:, None] < seqlen_k, other=0.0)
        else:
            k = tl.load(k_ptrs, mask=(offs_n[:, None] < seqlen_k) & (offs_d[None, :] < headdim),
                        other=0.0)
            v = tl.load(v_ptrs, mask=(offs_n[:, None] < seqlen_k) & (offs_d[None, :] < headdim),
                        other=0.0)
    # loop over rows
    num_block_m = tl.cdiv(seqlen_q, BLOCK_M)
    for start_m in range(begin_m, num_block_m * BLOCK_M, BLOCK_M):
        start_m = tl.multiple_of(start_m, BLOCK_M)
        offs_m_curr = start_m + offs_m
        # load q, k, v, do on-chip
        # Same bug as below. Otherwise gives wrong result for headdim=40, seqlen=(128, 117)
        if EVEN_M & EVEN_HEADDIM:
            q = tl.load(q_ptrs)
        else:
            if EVEN_HEADDIM:
                q = tl.load(q_ptrs, mask=offs_m_curr[:, None] < seqlen_q, other=0.0)
            else:
                q = tl.load(q_ptrs, mask=(offs_m_curr[:, None] < seqlen_q)
                                         & (offs_d[None, :] < headdim), other=0.0)
        # recompute p = softmax(qk, dim=-1).T
        qk = tl.dot(q, k, trans_b=True)
        # Trying to combine the two masks seems to make the result wrong
        if not EVEN_N:  # Need to mask out otherwise the softmax is wrong
            qk = tl.where(offs_n[None, :] < seqlen_k, qk, float("-inf"))
        if IS_CAUSAL:
            qk = tl.where(offs_m_curr[:, None] >= (offs_n[None, :]), qk, float("-inf"))
        if BIAS_TYPE != 'none':
            tl.debug_barrier()  # Race condition otherwise
            if BIAS_TYPE == 'vector':
                if EVEN_N:
                    bias = tl.load(b_ptrs).to(tl.float32)
                else:
                    bias = tl.load(b_ptrs, mask=offs_n < seqlen_k, other=0.0).to(tl.float32)
                # broadcast the per-key bias over all query rows of the block
                bias = bias[None, :]
            elif BIAS_TYPE == 'matrix':
                if EVEN_M & EVEN_N:
                    bias = tl.load(b_ptrs).to(tl.float32)
                else:
                    bias = tl.load(b_ptrs,
                                   mask=(offs_m_curr[:, None] < seqlen_q)
                                        & (offs_n[None, :] < seqlen_k),
                                   other=0.0).to(tl.float32)
            # with a bias, the scale is applied here; without one it is applied inside
            # the tl.exp below
            qk = qk * softmax_scale + bias
        # There seems to be a race condition when headdim=48/96, and dq, dk, dv are wrong.
        # Also wrong for headdim=64.
        if not (EVEN_M & EVEN_HEADDIM):
            tl.debug_barrier()
        # per-row log-sum-exp for this query block (loaded unmasked)
        lse_i = tl.load(LSE + offs_m_curr)
        if BIAS_TYPE == 'none':
            p = tl.exp(qk * softmax_scale - lse_i[:, None])
        else:
            p = tl.exp(qk - lse_i[:, None])
        # compute dv
        # [2022-10-30] TD: A Triton bug: if EVEN_M=True and EVEN_HEADDIM=False, if we call
        # do = tl.load(do_ptrs, mask=offs_d[None, :] < headdim, other=0.0), we get wrong outputs
        # in the case of headdim=48/96, seqlen_q & seqlen_k >= 512. If headdim=40 or seqlen < 512,
        # the output is correct.
        if EVEN_M & EVEN_HEADDIM:
            do = tl.load(do_ptrs)
        else:
            # [2022-11-01] TD: Triton bug, there's a race condition if we just use m_mask and not d_mask.
            do = tl.load(do_ptrs, mask=(offs_m_curr[:, None] < seqlen_q)
                                        & (offs_d[None, :] < headdim), other=0.0)
        # if EVEN_M:
        #     if EVEN_HEADDIM:
        #         do = tl.load(do_ptrs)
        #     else:
        #         do = tl.load(do_ptrs, mask=offs_d[None, :] < headdim, other=0.0)
        # else:
        #     if EVEN_HEADDIM:
        #         do = tl.load(do_ptrs, mask=offs_m_curr[:, None] < seqlen_q, other=0.0)
        #     else:
        #         do = tl.load(do_ptrs, mask=(offs_m_curr[:, None] < seqlen_q)
        #                                    & (offs_d[None, :] < headdim), other=0.0)
        # dv += p^T @ do (p is cast to do's dtype for the matmul)
        dv += tl.dot(p.to(do.dtype), do, trans_a=True)
        # compute dp = dot(v, do)
        # There seems to be a race condition when headdim=48/96, and dq, dk are wrong.
        # Also wrong for headdim=128, seqlen=(108, 256), and ATOMIC_ADD=True
        # Also wrong for headdim=64, seqlen=(1023, 1024), and ATOMIC_ADD=False
        if not (EVEN_M & EVEN_HEADDIM):
            tl.debug_barrier()
        dp = tl.dot(do, v, trans_b=True)
        # There's a race condition for headdim=48
        if not EVEN_HEADDIM:
            tl.debug_barrier()
        # compute ds = p * (dp - delta[:, None])
        # Putting the subtraction after the dp matmul (instead of before) is slightly faster
        # (D holds the per-row delta values; loaded unmasked)
        Di = tl.load(D + offs_m_curr)
        # Converting ds to q.dtype here reduces register pressure and makes it much faster
        # for BLOCK_HEADDIM=128
        ds = (p * (dp - Di[:, None]) * softmax_scale).to(q.dtype)
        # compute dk = dot(ds.T, q)
        dk += tl.dot(ds, q, trans_a=True)
        # compute dq
        if not (EVEN_M & EVEN_HEADDIM):  # Otherwise there's a race condition when BIAS_TYPE='matrix'
            tl.debug_barrier()
        if not ATOMIC_ADD:
            # read-modify-write of dq: load the current tile, add ds @ k, store it back
            if EVEN_M & EVEN_HEADDIM:  # Race condition if we just do EVEN_M
                dq = tl.load(dq_ptrs, eviction_policy="evict_last")
                dq += tl.dot(ds, k)
                tl.store(dq_ptrs, dq, eviction_policy="evict_last")
            else:
                if EVEN_HEADDIM:
                    dq = tl.load(dq_ptrs, mask=offs_m_curr[:, None] < seqlen_q, other=0.0,
                                eviction_policy="evict_last")
                    dq += tl.dot(ds, k)
                    tl.store(dq_ptrs, dq, mask=offs_m_curr[:, None] < seqlen_q,
                            eviction_policy="evict_last")
                else:
                    dq = tl.load(dq_ptrs,
                                 mask=(offs_m_curr[:, None] < seqlen_q) & (offs_d[None, :] < headdim),
                                 other=0.0, eviction_policy="evict_last")
                    dq += tl.dot(ds, k)
                    tl.store(dq_ptrs, dq,
                             mask=(offs_m_curr[:, None] < seqlen_q) & (offs_d[None, :] < headdim),
                             eviction_policy="evict_last")
        else:  # If we're parallelizing across the seqlen_k dimension
            dq = tl.dot(ds, k)
            if EVEN_M & EVEN_HEADDIM:  # Race condition if we just do EVEN_M
                tl.atomic_add(dq_ptrs, dq)
            else:
                if EVEN_HEADDIM:
                    tl.atomic_add(dq_ptrs, dq, mask=offs_m_curr[:, None] < seqlen_q)
                else:
                    tl.atomic_add(dq_ptrs, dq,
                                  mask=(offs_m_curr[:, None] < seqlen_q) & (offs_d[None, :] < headdim))
        # increment pointers
        # (advance all query-row-indexed pointers to the next block of BLOCK_M rows;
        # a 'vector' bias does not depend on the query row, so only 'matrix' advances)
        dq_ptrs += BLOCK_M * stride_dqm
        q_ptrs += BLOCK_M * stride_qm
        do_ptrs += BLOCK_M * stride_dom
        if BIAS_TYPE == 'matrix':
            b_ptrs += BLOCK_M * stride_bm
    # write-back
    dv_ptrs = DV + (offs_n[:, None] * stride_dvn + offs_d[None, :])
    dk_ptrs = DK + (offs_n[:, None] * stride_dkn + offs_d[None, :])
    _bwd_store_dk_dv(dk_ptrs, dv_ptrs, dk, dv, offs_n, offs_d, seqlen_k, headdim,
                     EVEN_M=EVEN_M, EVEN_N=EVEN_N, EVEN_HEADDIM=EVEN_HEADDIM)
Tri Dao's avatar
Tri Dao committed
475
476
477
478
479


def init_to_zero(name):
    """Build a hook that zeroes the argument called ``name`` in place.

    The returned callable takes a mapping from argument names to values,
    looks up ``name`` in it and calls ``zero_()`` on that value, returning
    whatever ``zero_()`` returns. In this file it is passed as the
    ``pre_hook`` of the autotune configs.
    """

    def _zero_named_arg(nargs):
        target = nargs[name]
        return target.zero_()

    return _zero_named_arg

480

Tri Dao's avatar
Tri Dao committed
481
482
483
484
# Autotune between the sequential and the sequence-parallel variant of the backward kernel.
# Both configs use a pre_hook that zeroes DQ before the kernel runs, because dq is
# accumulated into DQ in place by _bwd_kernel_one_col_block.
@triton.autotune(
    configs=[
        triton.Config({"BLOCK_M": 128, "BLOCK_N": 128, "SEQUENCE_PARALLEL": False}, num_warps=8, num_stages=1, pre_hook=init_to_zero('DQ')),
        triton.Config({"BLOCK_M": 128, "BLOCK_N": 128, "SEQUENCE_PARALLEL": True}, num_warps=8, num_stages=1, pre_hook=init_to_zero('DQ')),
        # Other configs seem to give wrong results when seqlen_q % 128 != 0, disabling them for now
        # # Kernel is buggy (give wrong result) if we set BLOCK_m=128, BLOCK_n=64, num_warps=*4*
        # triton.Config({"BLOCK_M": 128, "BLOCK_N": 64, "SEQUENCE_PARALLEL": False}, num_warps=8, num_stages=1, pre_hook=init_to_zero('DQ')),
        # triton.Config({"BLOCK_M": 128, "BLOCK_N": 64, "SEQUENCE_PARALLEL": True}, num_warps=8, num_stages=1, pre_hook=init_to_zero('DQ')),
        # triton.Config({"BLOCK_M": 64, "BLOCK_N": 64, "SEQUENCE_PARALLEL": False}, num_warps=4, num_stages=1, pre_hook=init_to_zero('DQ')),
        # triton.Config({"BLOCK_M": 64, "BLOCK_N": 64, "SEQUENCE_PARALLEL": True}, num_warps=4, num_stages=1, pre_hook=init_to_zero('DQ')),
    ],
    key=['CACHE_KEY_SEQLEN_Q', 'CACHE_KEY_SEQLEN_K', 'BIAS_TYPE', 'IS_CAUSAL', 'BLOCK_HEADDIM'],
)
@triton.heuristics(
    {
        "EVEN_M": lambda args: args["seqlen_q"] % args["BLOCK_M"] == 0,
        "EVEN_N": lambda args: args["seqlen_k"] % args["BLOCK_N"] == 0,
        "EVEN_HEADDIM": lambda args: args["headdim"] == args["BLOCK_HEADDIM"],
    }
)
@triton.jit
def _bwd_kernel(
    Q, K, V, Bias,
    DO, DQ, DK, DV,
    LSE, D,
    softmax_scale,
    stride_qb, stride_qh, stride_qm,
    stride_kb, stride_kh, stride_kn,
    stride_vb, stride_vh, stride_vn,
    stride_bb, stride_bh, stride_bm,
    stride_dob, stride_doh, stride_dom,
    stride_dqb, stride_dqh, stride_dqm,
    stride_dkb, stride_dkh, stride_dkn,
    stride_dvb, stride_dvh, stride_dvn,
    nheads, seqlen_q, seqlen_k, seqlen_q_rounded, headdim,
    CACHE_KEY_SEQLEN_Q, CACHE_KEY_SEQLEN_K,
    BIAS_TYPE: tl.constexpr,
    IS_CAUSAL: tl.constexpr,
    BLOCK_HEADDIM: tl.constexpr,
    SEQUENCE_PARALLEL: tl.constexpr,
    EVEN_M: tl.constexpr, EVEN_N: tl.constexpr, EVEN_HEADDIM: tl.constexpr,
    BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr,
):
    # Backward kernel entry point. program_id(1) selects the (batch, head) pair
    # (= batch * nheads + head). All tensor pointers are advanced to that pair and the
    # per-key-block work is delegated to _bwd_kernel_one_col_block:
    #   - SEQUENCE_PARALLEL == False: this program loops over every key block itself and
    #     dq is updated with plain load/add/store (ATOMIC_ADD=False).
    #   - SEQUENCE_PARALLEL == True: program_id(0) selects a single key block and dq is
    #     updated with atomic adds (ATOMIC_ADD=True), since several programs add into the
    #     same DQ rows.
    # CACHE_KEY_SEQLEN_Q / CACHE_KEY_SEQLEN_K are not read in the body; they are part of
    # the autotune key above.
    # NOTE(review): in the non-parallel mode program_id(0) is never read, so the launch
    # grid presumably has size 1 along axis 0 in that case -- confirm at the call site.
    off_hb = tl.program_id(1)
    off_b = off_hb // nheads
    off_h = off_hb % nheads
    # offset pointers for batch/head
    Q += off_b * stride_qb + off_h * stride_qh
    K += off_b * stride_kb + off_h * stride_kh
    V += off_b * stride_vb + off_h * stride_vh
    DO += off_b * stride_dob + off_h * stride_doh
    DQ += off_b * stride_dqb + off_h * stride_dqh
    DK += off_b * stride_dkb + off_h * stride_dkh
    DV += off_b * stride_dvb + off_h * stride_dvh
    if BIAS_TYPE != 'none':
        Bias += off_b * stride_bb + off_h * stride_bh
    # pointer to row-wise quantities in value-like data
    # (D and LSE are addressed with seqlen_q_rounded entries per (batch, head))
    D += off_hb * seqlen_q_rounded
    LSE += off_hb * seqlen_q_rounded
    if not SEQUENCE_PARALLEL:
        num_block_n = tl.cdiv(seqlen_k, BLOCK_N)
        for start_n in range(0, num_block_n):
            _bwd_kernel_one_col_block(
                start_n,
                Q, K, V, Bias,
                DO, DQ, DK, DV,
                LSE, D,
                softmax_scale,
                stride_qm, stride_kn, stride_vn, stride_bm,
                stride_dom, stride_dqm, stride_dkn, stride_dvn,
                seqlen_q, seqlen_k, headdim,
                ATOMIC_ADD=False,
                BIAS_TYPE=BIAS_TYPE,
                IS_CAUSAL=IS_CAUSAL,
                BLOCK_HEADDIM=BLOCK_HEADDIM,
                EVEN_M=EVEN_M, EVEN_N=EVEN_N, EVEN_HEADDIM=EVEN_HEADDIM,
                BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N
            )
    else:
        start_n = tl.program_id(0)
        _bwd_kernel_one_col_block(
            start_n,
            Q, K, V, Bias,
            DO, DQ, DK, DV,
            LSE, D,
            softmax_scale,
            stride_qm, stride_kn, stride_vn, stride_bm,
            stride_dom, stride_dqm, stride_dkn, stride_dvn,
            seqlen_q, seqlen_k, headdim,
            ATOMIC_ADD=True,
            BIAS_TYPE=BIAS_TYPE,
            IS_CAUSAL=IS_CAUSAL,
            BLOCK_HEADDIM=BLOCK_HEADDIM,
            EVEN_M=EVEN_M, EVEN_N=EVEN_N, EVEN_HEADDIM=EVEN_HEADDIM,
            BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N
        )


579
def _flash_attn_forward(q, k, v, bias=None, causal=False, softmax_scale=None):
    """Check the inputs, allocate the outputs and launch the Triton forward kernel.

    Arguments:
        q: (batch, seqlen_q, nheads, headdim), fp16 or bf16, on a CUDA device.
        k, v: (batch, seqlen_k, nheads, headdim), same dtype as q, on a CUDA device.
        bias: optional 4-D CUDA tensor of dtype q.dtype or torch.float whose last
            two dimensions are (1, seqlen_k) or (seqlen_q, seqlen_k). It is made
            contiguous if its last stride is not 1 and then expanded to
            (batch, nheads, seqlen_q, seqlen_k).
        causal: passed to the kernel in the IS_CAUSAL position.
        softmax_scale: any falsy value (None, 0) selects 1 / sqrt(headdim).
    Return:
        o: tensor allocated like q, passed to the kernel as the output buffer.
        lse: float32 (batch, nheads, seqlen_q_rounded) buffer, where
            seqlen_q_rounded is seqlen_q rounded up to a multiple of 128.
        softmax_scale: the scale that was actually handed to the kernel.
    Raises:
        AssertionError: if a shape, dtype or device check fails.
        RuntimeError: if the last two dimensions of bias are not supported.
    """
    # shape constraints
    batch, seqlen_q, nheads, d = q.shape
    _, seqlen_k, _, _ = k.shape
    expected_kv_shape = (batch, seqlen_k, nheads, d)
    assert k.shape == expected_kv_shape
    assert v.shape == expected_kv_shape
    assert d <= 128, 'FlashAttention only support head dimensions up to 128'
    assert q.dtype == k.dtype == v.dtype, 'All tensors must have the same type'
    assert q.dtype in [torch.float16, torch.bfloat16], 'Only support fp16 and bf16'
    assert q.is_cuda and k.is_cuda and v.is_cuda
    softmax_scale = softmax_scale or 1.0 / math.sqrt(d)

    # Without a bias the kernel receives bias_type 'none' and zero strides.
    bias_type = 'none'
    bias_strides = (0, 0, 0)
    if bias is not None:
        assert bias.dtype in [q.dtype, torch.float]
        assert bias.is_cuda
        assert bias.dim() == 4
        if bias.stride(-1) != 1:
            bias = bias.contiguous()
        trailing_dims = bias.shape[2:]
        # The 'vector' test comes first, so seqlen_q == 1 is classified as 'vector'.
        if trailing_dims == (1, seqlen_k):
            bias_type = 'vector'
        elif trailing_dims == (seqlen_q, seqlen_k):
            bias_type = 'matrix'
        else:
            raise RuntimeError(
                'Last 2 dimensions of bias must be (1, seqlen_k) or (seqlen_q, seqlen_k)'
            )
        bias = bias.expand(batch, nheads, seqlen_q, seqlen_k)
        bias_strides = bias.stride()[:3]

    seqlen_q_rounded = math.ceil(seqlen_q / 128) * 128
    lse = torch.empty((batch, nheads, seqlen_q_rounded), device=q.device, dtype=torch.float32)
    # Second buffer with the same layout as lse; handed to the kernel right after it.
    tmp = torch.empty_like(lse)
    o = torch.empty_like(q)

    BLOCK_HEADDIM = max(triton.next_power_of_2(d), 16)
    BLOCK = 128
    if d <= 64:
        num_warps = 4
    else:
        num_warps = 8

    def grid(META):
        # One program per BLOCK_M rows of q, for every (batch, head) pair.
        return (triton.cdiv(seqlen_q, META["BLOCK_M"]), batch * nheads)

    _fwd_kernel[grid](
        q, k, v, bias, o,
        lse, tmp,
        softmax_scale,
        q.stride(0), q.stride(2), q.stride(1),
        k.stride(0), k.stride(2), k.stride(1),
        v.stride(0), v.stride(2), v.stride(1),
        *bias_strides,
        o.stride(0), o.stride(2), o.stride(1),
        nheads, seqlen_q, seqlen_k, seqlen_q_rounded, d,
        # Bucketed sequence lengths: key for the triton cache, limits the number
        # of compilations.
        seqlen_q // 32, seqlen_k // 32,
        # These are positional because triton autotune expects key to be args,
        # not kwargs.
        bias_type, causal, BLOCK_HEADDIM,
        BLOCK_M=BLOCK, BLOCK_N=BLOCK,
        num_warps=num_warps,
        num_stages=1,
    )
    return o, lse, softmax_scale  # softmax_scale could have been updated


639
def _flash_attn_backward(do, q, k, v, o, lse, dq, dk, dv, bias=None, causal=False, softmax_scale=None):
    """Check the inputs and launch the Triton backward kernels.

    Arguments:
        do: gradient w.r.t. the attention output; made contiguous if its last
            stride is not 1.
        q: (batch, seqlen_q, nheads, headdim), last stride 1.
        k, v: tensors whose second dimension is seqlen_k, last stride 1.
        o: attention output of the forward pass, last stride 1.
        lse: (batch, nheads, seqlen_q_rounded) buffer returned by the forward
            pass, where seqlen_q_rounded is seqlen_q rounded up to a multiple
            of 128.
        dq, dk, dv: preallocated gradient buffers with last stride 1. dk and dv
            are handed to the kernel in the DK / DV positions; dq is filled by
            copying from a float32 accumulator after the kernel has run.
        bias: optional 4-D CUDA tensor of dtype q.dtype or torch.float whose last
            two dimensions are (1, seqlen_k) or (seqlen_q, seqlen_k). As in
            _flash_attn_forward, it is made contiguous if its last stride is not 1.
        causal: passed to the kernel in the IS_CAUSAL position.
        softmax_scale: any falsy value (None, 0) selects 1 / sqrt(headdim).
    Return:
        None. The results are written into dq, dk and dv.
    Raises:
        AssertionError: if a shape, stride, dtype or device check fails.
        RuntimeError: if the last two dimensions of bias are not supported.
    """
    # Make sure that the last dimension is contiguous
    if do.stride(-1) != 1:
        do = do.contiguous()
    batch, seqlen_q, nheads, d = q.shape
    _, seqlen_k, _, _ = k.shape
    assert d <= 128
    seqlen_q_rounded = math.ceil(seqlen_q / 128) * 128
    assert lse.shape == (batch, nheads, seqlen_q_rounded)
    assert q.stride(-1) == k.stride(-1) == v.stride(-1) == o.stride(-1) == 1
    assert dq.stride(-1) == dk.stride(-1) == dv.stride(-1) == 1
    softmax_scale = softmax_scale or 1.0 / math.sqrt(d)

    # Validate the bias before any kernel is launched so bad input fails fast.
    has_bias = bias is not None
    bias_type = 'none'
    if has_bias:
        assert bias.dtype in [q.dtype, torch.float]
        assert bias.is_cuda
        assert bias.dim() == 4
        # The forward pass accepts a bias with a strided last dimension by making
        # a contiguous copy, but the autograd wrappers save the caller's original
        # tensor. Apply the same fallback here instead of asserting, otherwise the
        # backward pass rejects an input the forward pass accepted.
        if bias.stride(-1) != 1:
            bias = bias.contiguous()
        if bias.shape[2:] == (1, seqlen_k):
            bias_type = 'vector'
        elif bias.shape[2:] == (seqlen_q, seqlen_k):
            bias_type = 'matrix'
        else:
            raise RuntimeError('Last 2 dimensions of bias must be (1, seqlen_k)'
                               ' or (seqlen_q, seqlen_k)')
        bias = bias.expand(batch, nheads, seqlen_q, seqlen_k)
    bias_strides = (bias.stride(0), bias.stride(1), bias.stride(2)) if has_bias else (0, 0, 0)

    # NOTE(review): dq_accum is allocated uninitialized although dq is copied
    # from it below; presumably the kernel (or its autotune configuration, not
    # visible in this section) zeroes it before accumulating - confirm.
    dq_accum = torch.empty_like(q, dtype=torch.float32)
    delta = torch.empty_like(lse)

    BLOCK_HEADDIM = max(triton.next_power_of_2(d), 16)
    grid = lambda META: (triton.cdiv(seqlen_q, META["BLOCK_M"]), batch * nheads)
    _bwd_preprocess_do_o_dot[grid](
        o, do, delta,
        o.stride(0), o.stride(2), o.stride(1),
        do.stride(0), do.stride(2), do.stride(1),
        nheads, seqlen_q, seqlen_q_rounded, d,
        BLOCK_M=128, BLOCK_HEADDIM=BLOCK_HEADDIM,
    )

    # SEQUENCE_PARALLEL, BLOCK_M, BLOCK_N, num_warps and num_stages are not passed
    # at this call site; the grid reads SEQUENCE_PARALLEL and BLOCK_N from META.
    # With SEQUENCE_PARALLEL the first grid axis covers the blocks of seqlen_k,
    # otherwise it has a single program per (batch, head) pair.
    grid = lambda META: (triton.cdiv(seqlen_k, META["BLOCK_N"]) if META["SEQUENCE_PARALLEL"] else 1,
                         batch * nheads)
    _bwd_kernel[grid](
        q, k, v, bias,
        do, dq_accum, dk, dv,
        lse, delta,
        softmax_scale,
        q.stride(0), q.stride(2), q.stride(1),
        k.stride(0), k.stride(2), k.stride(1),
        v.stride(0), v.stride(2), v.stride(1),
        *bias_strides,
        do.stride(0), do.stride(2), do.stride(1),
        dq_accum.stride(0), dq_accum.stride(2), dq_accum.stride(1),
        dk.stride(0), dk.stride(2), dk.stride(1),
        dv.stride(0), dv.stride(2), dv.stride(1),
        nheads, seqlen_q, seqlen_k, seqlen_q_rounded, d,
        seqlen_q // 32, seqlen_k // 32,  # key for triton cache (limit number of compilations)
        # Can't use kwargs here because triton autotune expects key to be args, not kwargs
        bias_type, causal, BLOCK_HEADDIM,
    )
    # Convert the float32 accumulator into the caller's dq buffer.
    dq.copy_(dq_accum)


class FlashAttnQKVPackedFunc(torch.autograd.Function):
    """Autograd function for attention on a single packed qkv tensor."""

    @staticmethod
    def forward(ctx, qkv, bias=None, causal=False, softmax_scale=None):
        """
            qkv: (batch, seqlen, 3, nheads, headdim)
            bias: optional, shape broadcastable to (batch, nheads, seqlen, seqlen).
                For example, ALiBi mask for causal would have shape (1, nheads, 1, seqlen).
                ALiBi mask for non-causal would have shape (1, nheads, seqlen, seqlen)
        """
        # The last dimension must have stride 1.
        if qkv.stride(-1) != 1:
            qkv = qkv.contiguous()
        # Views of the packed tensor along dim 2: index 0 is q, 1 is k, 2 is v.
        q, k, v = (qkv.select(2, idx) for idx in range(3))
        o, lse, ctx.softmax_scale = _flash_attn_forward(
            q, k, v, bias=bias, causal=causal, softmax_scale=softmax_scale
        )
        ctx.save_for_backward(qkv, o, lse, bias)
        ctx.causal = causal
        return o

    @staticmethod
    def backward(ctx, do):
        """Return the gradient w.r.t. qkv; the other forward inputs get None."""
        qkv, o, lse, bias = ctx.saved_tensors
        assert not ctx.needs_input_grad[1], 'FlashAttention does not support bias gradient yet'
        # Triton's autotune causes the Tensor._version to change, and so Pytorch autograd
        # does a memcpy. To avoid this we run in inference_mode, which doesn't track the version.
        with torch.inference_mode():
            dqkv = torch.empty_like(qkv)
            q, k, v = (qkv.select(2, idx) for idx in range(3))
            # The three gradients are written straight into views of dqkv.
            dq, dk, dv = (dqkv.select(2, idx) for idx in range(3))
            _flash_attn_backward(
                do, q, k, v, o, lse, dq, dk, dv,
                bias=bias, causal=ctx.causal, softmax_scale=ctx.softmax_scale,
            )
        # One entry per forward input: qkv, bias, causal, softmax_scale.
        return dqkv, None, None, None
Tri Dao's avatar
Tri Dao committed
748
749
750
751
752
753
754
755


# Functional entry point for packed-qkv attention; the arguments are those of
# FlashAttnQKVPackedFunc.forward after ctx: (qkv, bias, causal, softmax_scale).
flash_attn_qkvpacked_func = FlashAttnQKVPackedFunc.apply


class FlashAttnKVPackedFunc(torch.autograd.Function):
    """Autograd function for attention with a separate q and a packed kv tensor."""

    @staticmethod
    def forward(ctx, q, kv, bias=None, causal=False, softmax_scale=None):
        """
            q: (batch, seqlen_q, nheads, headdim)
            kv: (batch, seqlen_k, 2, nheads, headdim)
            bias: optional, shape broadcastable to (batch, nheads, seqlen_q, seqlen_k).
                For example, ALiBi mask for causal would have shape (1, nheads, 1, seqlen_k).
                ALiBi mask for non-causal would have shape (1, nheads, seqlen_q, seqlen_k)
        """
        # The last dimension must have stride 1.
        if q.stride(-1) != 1:
            q = q.contiguous()
        if kv.stride(-1) != 1:
            kv = kv.contiguous()
        # Views of the packed tensor along dim 2: index 0 is k, 1 is v.
        k, v = (kv.select(2, idx) for idx in range(2))
        o, lse, ctx.softmax_scale = _flash_attn_forward(
            q, k, v, bias=bias, causal=causal, softmax_scale=softmax_scale
        )
        ctx.save_for_backward(q, kv, o, lse, bias)
        ctx.causal = causal
        return o

    @staticmethod
    def backward(ctx, do):
        """Return the gradients w.r.t. q and kv; the other forward inputs get None."""
        q, kv, o, lse, bias = ctx.saved_tensors
        assert not ctx.needs_input_grad[2], 'FlashAttention does not support bias gradient yet'
        # Triton's autotune causes the Tensor._version to change, and so Pytorch autograd
        # does a memcpy. To avoid this we run in inference_mode, which doesn't track the version.
        with torch.inference_mode():
            dq = torch.empty_like(q)
            dkv = torch.empty_like(kv)
            k, v = (kv.select(2, idx) for idx in range(2))
            # dk and dv are written straight into views of dkv.
            dk, dv = (dkv.select(2, idx) for idx in range(2))
            _flash_attn_backward(
                do, q, k, v, o, lse, dq, dk, dv,
                bias=bias, causal=ctx.causal, softmax_scale=ctx.softmax_scale,
            )
        # One entry per forward input: q, kv, bias, causal, softmax_scale.
        return dq, dkv, None, None, None
Tri Dao's avatar
Tri Dao committed
786
787
788
789
790
791
792
793


# Functional entry point for attention with packed kv; the arguments are those of
# FlashAttnKVPackedFunc.forward after ctx: (q, kv, bias, causal, softmax_scale).
flash_attn_kvpacked_func = FlashAttnKVPackedFunc.apply


class FlashAttnFunc(torch.autograd.Function):
    """Autograd function for attention with separate q, k and v tensors."""

    @staticmethod
    def forward(ctx, q, k, v, bias=None, causal=False, softmax_scale=None):
        """
            q: (batch_size, seqlen_q, nheads, headdim)
            k, v: (batch_size, seqlen_k, nheads, headdim)
            bias: optional, shape broadcastable to (batch, nheads, seqlen_q, seqlen_k).
                For example, ALiBi mask for causal would have shape (1, nheads, 1, seqlen_k).
                ALiBi mask for non-causal would have shape (1, nheads, seqlen_q, seqlen_k)
        """
        # The last dimension must have stride 1.
        if q.stride(-1) != 1:
            q = q.contiguous()
        if k.stride(-1) != 1:
            k = k.contiguous()
        if v.stride(-1) != 1:
            v = v.contiguous()
        o, lse, ctx.softmax_scale = _flash_attn_forward(
            q, k, v, bias=bias, causal=causal, softmax_scale=softmax_scale
        )
        ctx.save_for_backward(q, k, v, o, lse, bias)
        ctx.causal = causal
        return o

    @staticmethod
    def backward(ctx, do):
        """Return the gradients w.r.t. q, k and v; the other forward inputs get None."""
        q, k, v, o, lse, bias = ctx.saved_tensors
        assert not ctx.needs_input_grad[3], 'FlashAttention does not support bias gradient yet'
        # Triton's autotune causes the Tensor._version to change, and so Pytorch autograd
        # does a memcpy. To avoid this we run in inference_mode, which doesn't track the version.
        with torch.inference_mode():
            dq, dk, dv = (torch.empty_like(t) for t in (q, k, v))
            _flash_attn_backward(
                do, q, k, v, o, lse, dq, dk, dv,
                bias=bias, causal=ctx.causal, softmax_scale=ctx.softmax_scale,
            )
        # One entry per forward input: q, k, v, bias, causal, softmax_scale.
        return dq, dk, dv, None, None, None
Tri Dao's avatar
Tri Dao committed
824
825
826


# Functional entry point for attention with separate q, k, v; the arguments are
# those of FlashAttnFunc.forward after ctx: (q, k, v, bias, causal, softmax_scale).
flash_attn_func = FlashAttnFunc.apply