flash_attn_interface.py 25.6 KB
Newer Older
Tri Dao's avatar
Tri Dao committed
1
2
3
import torch
import torch.nn as nn

Tri Dao's avatar
Tri Dao committed
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
import flash_attn_2_cuda as flash_attn_cuda
from einops import rearrange


def _get_block_size(device, head_dim, is_dropout, is_causal):
    # This should match the block sizes in the CUDA kernel
    assert head_dim <= 256
    major, minor = torch.cuda.get_device_capability(device)
    is_sm8x = major == 8 and minor > 0  # Only include sm86 and sm89, exclude sm80 (A100)
    is_sm80 = major == 8 and minor == 0
    is_sm90 = major == 9 and minor == 0
    if head_dim <= 32:
        return 128, 128
    if head_dim <= 64:
        return (128, 128) if not is_dropout else (128, 64)
    elif head_dim <= 96:
        return (64, 64) if (is_sm8x and is_causal) else (128, 64)
    elif head_dim <= 128:
        if is_sm8x:
            return (64, 64) if (not is_dropout and is_causal) else (128, 32)
        else:
            return 128, (64 if not is_dropout else 32)
    elif head_dim <= 160:
        if is_sm8x:
            return (128, 64) if not is_causal else (64, 64)
        else:
            return 128, 32
    elif head_dim <= 192:
        return (128, 64) if not is_dropout else (64, 64)
    elif head_dim <= 224:
        return (128, 64) if (is_sm80 or is_sm90) else (64, 64)
    elif head_dim <= 256:
        return (128, 64) if is_sm80 else (64, 64)


def _flash_attn_forward(q, k, v, dropout_p, softmax_scale, causal, return_softmax):
    """Run the fixed-length FlashAttention forward kernel (``flash_attn_cuda.fwd``).

    Each of q, k, v is copied to contiguous memory only when its last dimension is not
    unit-stride. All other arguments are forwarded to the kernel untouched.

    Returns:
        The kernel's 8 results as a tuple:
        ``(out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state)``.
        The returned q, k, v are what callers in this file save for the backward pass.
    """
    def _unit_stride_last(t):
        # Copy only when the last dimension is not already unit-stride.
        return t if t.stride(-1) == 1 else t.contiguous()

    q, k, v = (_unit_stride_last(t) for t in (q, k, v))
    # NOTE(review): the two positional None arguments are kernel options defined by the
    # extension (presumably an optional output buffer and an RNG generator) -- confirm.
    kernel_results = flash_attn_cuda.fwd(
        q, k, v, None, dropout_p, softmax_scale, causal, return_softmax, None
    )
    out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = kernel_results
    return out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state
Tri Dao's avatar
Tri Dao committed
46
47
48
49


def _flash_attn_varlen_forward(q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k,
                               dropout_p, softmax_scale, causal, return_softmax):
    """Run the variable-length FlashAttention forward kernel (``flash_attn_cuda.varlen_fwd``).

    q, k, v are copied to contiguous memory only when their last dimension is not
    unit-stride. Every other argument is forwarded to the kernel as given.

    Args:
        q, k, v: query / key / value tensors of all sequences in the batch, concatenated
            along the token dimension (see flash_attn_varlen_func for the shapes).
        cu_seqlens_q, cu_seqlens_k: cumulative sequence lengths used to index into q and k/v.
        max_seqlen_q, max_seqlen_k: maximum query / key sequence length in the batch.
        dropout_p: dropout probability.
        softmax_scale: scale applied to QK^T before the softmax.
        causal: whether to apply a causal mask.
        return_softmax: whether the kernel should also produce S_dmask.

    Returns:
        ``(out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state)`` from the kernel.
        The returned q, k, v are what callers in this file save for the backward pass.
    """
    def maybe_contiguous(x):
        # Copy only when the last dimension is not unit-stride.
        return x.contiguous() if x.stride(-1) != 1 else x

    q, k, v = [maybe_contiguous(x) for x in (q, k, v)]
    # NOTE(review): the positional None / False / None arguments are kernel options defined by
    # the extension (presumably an optional output buffer, a zero-tensors flag and an RNG
    # generator) -- confirm against flash_attn_2_cuda before changing them.
    out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = flash_attn_cuda.varlen_fwd(
        q, k, v, None, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, dropout_p,
        softmax_scale, False, causal, return_softmax, None
    )
    return out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state
Tri Dao's avatar
Tri Dao committed
59
60


Tri Dao's avatar
Tri Dao committed
61
def _flash_attn_backward(dout, q, k, v, out, softmax_lse, dq, dk, dv,
                         dropout_p, softmax_scale, causal, rng_state=None):
    """Run the fixed-length FlashAttention backward kernel (``flash_attn_cuda.bwd``).

    dout, q, k, v and out are made unit-stride along their last dimension if necessary.
    The gradient buffers dq, dk, dv are handed to the kernel as-is. Callers in this file
    read the gradients back from those buffers, some of which are views into a packed tensor.

    Args:
        rng_state: the value returned by the matching forward call; it is passed straight
            to the kernel (presumably so dropout can be replayed -- confirm).

    Returns:
        ``(dq, dk, dv, softmax_d)`` as returned by the kernel.
    """
    def _unit_stride_last(t):
        return t if t.stride(-1) == 1 else t.contiguous()

    # Leave dq, dk, dv alone: copying a view would detach it from the caller's buffer.
    dout, q, k, v, out = (_unit_stride_last(t) for t in (dout, q, k, v, out))
    dq, dk, dv, softmax_d = flash_attn_cuda.bwd(
        dout, q, k, v, out, softmax_lse, dq, dk, dv, dropout_p,
        softmax_scale, causal, None, rng_state,
    )
    return dq, dk, dv, softmax_d


def _flash_attn_varlen_backward(dout, q, k, v, out, softmax_lse, dq, dk, dv,
                                cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k,
                                dropout_p, softmax_scale, causal, rng_state=None):
    """Run the variable-length FlashAttention backward kernel (``flash_attn_cuda.varlen_bwd``).

    dout, q, k, v and out are made unit-stride along their last dimension if necessary.
    The caller-provided gradient buffers dq, dk, dv are passed to the kernel unchanged.
    Callers in this file read the gradients from those buffers after the call.

    Args:
        cu_seqlens_q, cu_seqlens_k: cumulative sequence lengths for queries and keys.
        max_seqlen_q, max_seqlen_k: maximum query / key sequence length in the batch.
        rng_state: the value returned by the matching forward call; it is passed straight
            to the kernel (presumably so dropout can be replayed -- confirm).

    Returns:
        ``(dq, dk, dv, softmax_d)`` as returned by the kernel.
    """
    def maybe_contiguous(x):
        return x.contiguous() if x.stride(-1) != 1 else x

    # dq, dk, dv are output buffers allocated by the callers (possibly views into a packed
    # gradient tensor); they must not be copied, or the kernel's writes would be lost.
    dout, q, k, v, out = [maybe_contiguous(x) for x in (dout, q, k, v, out)]
    # NOTE(review): the positional False / None arguments are kernel options defined by the
    # extension (presumably a zero-tensors flag and an RNG generator) -- confirm.
    dq, dk, dv, softmax_d = flash_attn_cuda.varlen_bwd(
        dout, q, k, v, out, softmax_lse, dq, dk, dv, cu_seqlens_q, cu_seqlens_k,
        max_seqlen_q, max_seqlen_k, dropout_p, softmax_scale, False, causal, None, rng_state
    )
    return dq, dk, dv, softmax_d
Tri Dao's avatar
Tri Dao committed
86
87


Tri Dao's avatar
Tri Dao committed
88
class FlashAttnQKVPackedFunc(torch.autograd.Function):
    """Autograd wrapper around FlashAttention for a packed ``qkv`` tensor.

    The forward pass slices ``qkv[:, :, i]`` into q, k, v. The backward pass lets the kernel
    write dq, dk, dv straight into slices of one packed gradient tensor, so no explicit
    concatenation is needed.
    """

    @staticmethod
    def forward(ctx, qkv, dropout_p, softmax_scale, causal, return_softmax):
        """Compute attention; returns ``out`` or ``(out, softmax_lse, S_dmask)``."""
        scale = qkv.shape[-1] ** (-0.5) if softmax_scale is None else softmax_scale
        q_in, k_in, v_in = qkv[:, :, 0], qkv[:, :, 1], qkv[:, :, 2]
        # S_dmask is only requested from the kernel when dropout is active.
        want_probs = return_softmax and dropout_p > 0
        (out, q, k, v, out_padded, softmax_lse,
         S_dmask, rng_state) = _flash_attn_forward(
            q_in, k_in, v_in, dropout_p, scale, causal=causal, return_softmax=want_probs
        )
        ctx.save_for_backward(q, k, v, out_padded, softmax_lse, rng_state)
        ctx.dropout_p = dropout_p
        ctx.softmax_scale = scale
        ctx.causal = causal
        if return_softmax:
            return out, softmax_lse, S_dmask
        return out

    @staticmethod
    def backward(ctx, dout, *args):
        """Return the packed gradient for ``qkv`` and None for the four other inputs."""
        q, k, v, out, softmax_lse, rng_state = ctx.saved_tensors
        # Packed gradient buffer: q's leading dims, then 3, then q's last two dims.
        dqkv = torch.empty((*q.shape[:-2], 3, *q.shape[-2:]), dtype=q.dtype, device=q.device)
        _flash_attn_backward(
            dout, q, k, v, out, softmax_lse,
            dqkv[:, :, 0], dqkv[:, :, 1], dqkv[:, :, 2],
            ctx.dropout_p, ctx.softmax_scale, ctx.causal, rng_state=rng_state,
        )
        # The saved tensors may carry a padded head dimension; trim to dout's size.
        head_dim = dout.shape[-1]
        return dqkv[..., :head_dim], None, None, None, None


class FlashAttnVarlenQKVPackedFunc(torch.autograd.Function):
    """Autograd wrapper around variable-length FlashAttention for a packed ``qkv`` tensor.

    ``qkv[:, i]`` yields q, k, v. Because all three come from one tensor, the same
    ``cu_seqlens`` / ``max_seqlen`` are used for both the query and the key side.
    """

    @staticmethod
    def forward(ctx, qkv, cu_seqlens, max_seqlen, dropout_p, softmax_scale, causal, return_softmax):
        """Compute attention; returns ``out`` or ``(out, softmax_lse, S_dmask)``."""
        scale = softmax_scale if softmax_scale is not None else qkv.shape[-1] ** (-0.5)
        fwd_results = _flash_attn_varlen_forward(
            qkv[:, 0], qkv[:, 1], qkv[:, 2],
            cu_seqlens, cu_seqlens, max_seqlen, max_seqlen,
            dropout_p, scale,
            causal=causal, return_softmax=return_softmax and dropout_p > 0,
        )
        out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = fwd_results
        ctx.save_for_backward(q, k, v, out_padded, softmax_lse, cu_seqlens, rng_state)
        ctx.dropout_p = dropout_p
        ctx.max_seqlen = max_seqlen
        ctx.softmax_scale = scale
        ctx.causal = causal
        if not return_softmax:
            return out
        return out, softmax_lse, S_dmask

    @staticmethod
    def backward(ctx, dout, *args):
        """Return the packed gradient for ``qkv`` and None for the six other inputs."""
        q, k, v, out, softmax_lse, cu_seqlens, rng_state = ctx.saved_tensors
        packed_shape = (*q.shape[:-2], 3, *q.shape[-2:])
        dqkv = torch.empty(packed_shape, dtype=q.dtype, device=q.device)
        grad_q_view, grad_k_view, grad_v_view = dqkv[:, 0], dqkv[:, 1], dqkv[:, 2]
        _flash_attn_varlen_backward(
            dout, q, k, v, out, softmax_lse, grad_q_view, grad_k_view, grad_v_view,
            cu_seqlens, cu_seqlens, ctx.max_seqlen, ctx.max_seqlen,
            ctx.dropout_p, ctx.softmax_scale, ctx.causal, rng_state=rng_state,
        )
        # The saved tensors may carry a padded head dimension; trim to dout's size.
        grad_qkv = dqkv[..., :dout.shape[-1]]
        return grad_qkv, None, None, None, None, None, None
Tri Dao's avatar
Tri Dao committed
146
147


Tri Dao's avatar
Tri Dao committed
148
class FlashAttnKVPackedFunc(torch.autograd.Function):
    """Autograd wrapper around FlashAttention with a separate ``q`` and a packed ``kv``.

    ``kv[:, :, 0]`` and ``kv[:, :, 1]`` are the keys and values. In the backward pass
    dk and dv are written into one packed gradient tensor.
    """

    @staticmethod
    def forward(ctx, q, kv, dropout_p, softmax_scale, causal, return_softmax):
        """Compute attention; returns ``out`` or ``(out, softmax_lse, S_dmask)``."""
        scale = q.shape[-1] ** (-0.5) if softmax_scale is None else softmax_scale
        keys, values = kv[:, :, 0], kv[:, :, 1]
        # S_dmask is only requested from the kernel when dropout is active.
        probs_requested = return_softmax and dropout_p > 0
        (out, q, k, v, out_padded, softmax_lse,
         S_dmask, rng_state) = _flash_attn_forward(
            q, keys, values, dropout_p, scale, causal=causal, return_softmax=probs_requested
        )
        ctx.save_for_backward(q, k, v, out_padded, softmax_lse, rng_state)
        ctx.dropout_p = dropout_p
        ctx.softmax_scale = scale
        ctx.causal = causal
        return (out, softmax_lse, S_dmask) if return_softmax else out

    @staticmethod
    def backward(ctx, dout, *args):
        """Return gradients for ``q`` and ``kv`` and None for the four other inputs."""
        q, k, v, out, softmax_lse, rng_state = ctx.saved_tensors
        dq = torch.empty_like(q)
        dkv = torch.empty((*k.shape[:-2], 2, *k.shape[-2:]), dtype=k.dtype, device=k.device)
        _flash_attn_backward(
            dout, q, k, v, out, softmax_lse,
            dq, dkv[:, :, 0], dkv[:, :, 1],
            ctx.dropout_p, ctx.softmax_scale, ctx.causal, rng_state=rng_state,
        )
        # The saved tensors may carry a padded head dimension; trim to dout's size.
        keep = dout.shape[-1]
        return dq[..., :keep], dkv[..., :keep], None, None, None, None
Tri Dao's avatar
Tri Dao committed
178
179


Tri Dao's avatar
Tri Dao committed
180
class FlashAttnVarlenKVPackedFunc(torch.autograd.Function):
    """Autograd wrapper around variable-length FlashAttention with ``q`` and a packed ``kv``.

    ``kv[:, 0]`` and ``kv[:, 1]`` are the keys and values. The query and key sides each have
    their own cumulative and maximum sequence lengths.
    """

    @staticmethod
    def forward(ctx, q, kv, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, dropout_p,
                softmax_scale, causal, return_softmax):
        """Compute attention; returns ``out`` or ``(out, softmax_lse, S_dmask)``."""
        scale = q.shape[-1] ** (-0.5) if softmax_scale is None else softmax_scale
        keys, values = kv[:, 0], kv[:, 1]
        ask_for_probs = return_softmax and dropout_p > 0
        results = _flash_attn_varlen_forward(
            q, keys, values, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k,
            dropout_p, scale, causal=causal, return_softmax=ask_for_probs,
        )
        out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = results
        ctx.save_for_backward(q, k, v, out_padded, softmax_lse,
                              cu_seqlens_q, cu_seqlens_k, rng_state)
        ctx.dropout_p = dropout_p
        ctx.max_seqlen_q = max_seqlen_q
        ctx.max_seqlen_k = max_seqlen_k
        ctx.softmax_scale = scale
        ctx.causal = causal
        if return_softmax:
            return out, softmax_lse, S_dmask
        return out

    @staticmethod
    def backward(ctx, dout, *args):
        """Return gradients for ``q`` and ``kv`` and None for the eight other inputs."""
        (q, k, v, out, softmax_lse,
         cu_seqlens_q, cu_seqlens_k, rng_state) = ctx.saved_tensors
        dq = torch.empty_like(q)
        dkv = torch.empty((*k.shape[:-2], 2, *k.shape[-2:]), dtype=k.dtype, device=k.device)
        _flash_attn_varlen_backward(
            dout, q, k, v, out, softmax_lse, dq, dkv[:, 0], dkv[:, 1],
            cu_seqlens_q, cu_seqlens_k, ctx.max_seqlen_q, ctx.max_seqlen_k,
            ctx.dropout_p, ctx.softmax_scale, ctx.causal, rng_state=rng_state,
        )
        # The saved tensors may carry a padded head dimension; trim to dout's size.
        width = dout.shape[-1]
        # q, kv, then None for cu_seqlens_q ... return_softmax (eight inputs).
        return (dq[..., :width], dkv[..., :width]) + (None,) * 8


class FlashAttnFunc(torch.autograd.Function):
    """Autograd wrapper around FlashAttention for separate ``q``, ``k``, ``v`` tensors."""

    @staticmethod
    def forward(ctx, q, k, v, dropout_p, softmax_scale, causal, return_softmax):
        """Compute attention.

        Args:
            q, k, v: query / key / value tensors (shapes as in flash_attn_func).
            dropout_p: dropout probability.
            softmax_scale: scale for QK^T; defaults to ``q.shape[-1] ** -0.5`` when None.
            causal: whether to apply a causal mask.
            return_softmax: when truthy, also return ``softmax_lse`` and ``S_dmask``.

        Returns:
            ``out``, or ``(out, softmax_lse, S_dmask)`` when ``return_softmax`` is set.
        """
        if softmax_scale is None:
            softmax_scale = q.shape[-1] ** (-0.5)
        # S_dmask is only requested from the kernel when dropout is active.
        out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = _flash_attn_forward(
            q, k, v, dropout_p, softmax_scale, causal=causal,
            return_softmax=return_softmax and dropout_p > 0
        )
        ctx.save_for_backward(q, k, v, out_padded, softmax_lse, rng_state)
        ctx.dropout_p = dropout_p
        ctx.softmax_scale = softmax_scale
        ctx.causal = causal
        return out if not return_softmax else (out, softmax_lse, S_dmask)

    @staticmethod
    def backward(ctx, dout, *args):
        """Return exactly one gradient per forward input.

        The order is dq, dk, dv, then None for dropout_p, softmax_scale, causal and
        return_softmax.
        """
        q, k, v, out, softmax_lse, rng_state = ctx.saved_tensors
        dq, dk, dv = torch.empty_like(q), torch.empty_like(k), torch.empty_like(v)
        _flash_attn_backward(
            dout, q, k, v, out, softmax_lse,
            dq, dk, dv, ctx.dropout_p, ctx.softmax_scale, ctx.causal,
            rng_state=rng_state
        )
        # The saved tensors may carry a padded head dimension; trim to dout's size.
        dq = dq[..., :dout.shape[-1]]
        dk = dk[..., :dout.shape[-1]]
        dv = dv[..., :dout.shape[-1]]
        # forward() takes 7 inputs, so return 7 values (previously 11 were returned).
        return dq, dk, dv, None, None, None, None
Tri Dao's avatar
Tri Dao committed
245
246


Tri Dao's avatar
Tri Dao committed
247
class FlashAttnVarlenFunc(torch.autograd.Function):
    """Autograd wrapper around variable-length FlashAttention for separate q, k, v."""

    @staticmethod
    def forward(ctx, q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, dropout_p,
                softmax_scale, causal, return_softmax):
        """Compute attention; returns ``out`` or ``(out, softmax_lse, S_dmask)``."""
        scale = softmax_scale if softmax_scale is not None else q.shape[-1] ** (-0.5)
        # S_dmask is only requested from the kernel when dropout is active.
        want_probs = return_softmax and dropout_p > 0
        fwd_results = _flash_attn_varlen_forward(
            q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k,
            dropout_p, scale, causal=causal, return_softmax=want_probs,
        )
        out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = fwd_results
        ctx.save_for_backward(q, k, v, out_padded, softmax_lse,
                              cu_seqlens_q, cu_seqlens_k, rng_state)
        ctx.dropout_p = dropout_p
        ctx.max_seqlen_q = max_seqlen_q
        ctx.max_seqlen_k = max_seqlen_k
        ctx.softmax_scale = scale
        ctx.causal = causal
        if not return_softmax:
            return out
        return out, softmax_lse, S_dmask

    @staticmethod
    def backward(ctx, dout, *args):
        """Return dq, dk, dv and None for the eight other inputs."""
        q, k, v, out, softmax_lse, cu_seqlens_q, cu_seqlens_k, rng_state = ctx.saved_tensors
        grads = [torch.empty_like(t) for t in (q, k, v)]
        _flash_attn_varlen_backward(
            dout, q, k, v, out, softmax_lse, *grads, cu_seqlens_q, cu_seqlens_k,
            ctx.max_seqlen_q, ctx.max_seqlen_k, ctx.dropout_p, ctx.softmax_scale, ctx.causal,
            rng_state=rng_state,
        )
        # The saved tensors may carry a padded head dimension; trim to dout's size.
        head_dim = dout.shape[-1]
        dq, dk, dv = (g[..., :head_dim] for g in grads)
        return dq, dk, dv, None, None, None, None, None, None, None, None
280
281


Tri Dao's avatar
Tri Dao committed
282
283
def flash_attn_qkvpacked_func(qkv, dropout_p=0.0, softmax_scale=None, causal=False,
                              return_attn_probs=False):
    """Compute FlashAttention for Q, K and V packed into a single tensor.

    Set dropout_p to 0.0 during evaluation.

    When Q, K, V already live in one stacked tensor, this is faster than calling
    flash_attn_func on them separately. The backward pass writes the gradients of Q, K, V
    into one packed tensor, so no explicit concatenation is needed.

    For multi-query and grouped-query attention (MQA/GQA), use flash_attn_kvpacked_func or
    flash_attn_func instead.

    Arguments:
        qkv: (batch_size, seqlen, 3, nheads, headdim)
        dropout_p: float. Dropout probability.
        softmax_scale: float. Scaling of QK^T before the softmax.
            Defaults to 1 / sqrt(headdim).
        causal: bool. Whether to apply a causal attention mask (e.g., for auto-regressive
            modeling).
        return_attn_probs: bool. Whether to also return the attention probabilities. For
            testing only: the returned probabilities are not guaranteed to be correct
            (they might not have the right scaling). S_dmask is only requested from the
            kernel when dropout_p > 0.
    Return:
        out: (batch_size, seqlen, nheads, headdim).
        softmax_lse [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen). The
            logsumexp of each row of QK^T * scaling (i.e., the log of the softmax
            normalization factor).
        S_dmask [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen, seqlen).
            The softmax output (possibly with different scaling). It also encodes the dropout
            pattern (negative means the location was dropped, nonnegative means it was kept).
    """
    apply_args = (qkv, dropout_p, softmax_scale, causal, return_attn_probs)
    return FlashAttnQKVPackedFunc.apply(*apply_args)
Tri Dao's avatar
Tri Dao committed
310
311


Tri Dao's avatar
Tri Dao committed
312
313
def flash_attn_kvpacked_func(q, kv, dropout_p=0.0, softmax_scale=None, causal=False,
                             return_attn_probs=False):
    """Compute FlashAttention with a separate Q and K, V packed into a single tensor.

    dropout_p should be set to 0.0 during evaluation.

    If K, V are already stacked into one tensor, this function is faster than calling
    flash_attn_func on Q, K, V, because the backward pass avoids explicit concatenation
    of the gradients of K, V.

    Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer
    heads than Q. The number of heads in Q must be divisible by the number of heads in KV.
    For example, if Q has 6 heads and K, V have 2 heads, heads 0, 1, 2 of Q attend to head
    0 of K, V, and heads 3, 4, 5 of Q attend to head 1 of K, V.

    Arguments:
        q: (batch_size, seqlen, nheads, headdim)
        kv: (batch_size, seqlen, 2, nheads_k, headdim)
        dropout_p: float. Dropout probability.
        softmax_scale: float. The scaling of QK^T before applying softmax.
            Default to 1 / sqrt(headdim).
        causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
        return_attn_probs: bool. Whether to return the attention probabilities. This option is for
           testing only. The returned probabilities are not guaranteed to be correct
           (they might not have the right scaling). S_dmask is only requested from the
           kernel when dropout_p > 0.
    Return:
        out: (batch_size, seqlen, nheads, headdim).
        softmax_lse [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen). The
            logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
            normalization factor).
        S_dmask [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen, seqlen).
            The output of softmax (possibly with different scaling). It also encodes the dropout
            pattern (negative means that location was dropped, nonnegative means it was kept).
    """
    return FlashAttnKVPackedFunc.apply(q, kv, dropout_p, softmax_scale, causal, return_attn_probs)


def flash_attn_func(q, k, v, dropout_p=0.0, softmax_scale=None, causal=False,
                    return_attn_probs=False):
    """Compute FlashAttention for separate Q, K, V tensors.

    dropout_p should be set to 0.0 during evaluation.

    Supports multi-query and grouped-query attention (MQA/GQA) by passing in K, V with fewer
    heads than Q. The number of heads in Q must be divisible by the number of heads in K, V.
    For example, if Q has 6 heads and K, V have 2 heads, heads 0, 1, 2 of Q attend to head
    0 of K, V, and heads 3, 4, 5 of Q attend to head 1 of K, V.

    Arguments:
        q: (batch_size, seqlen, nheads, headdim)
        k: (batch_size, seqlen, nheads_k, headdim)
        v: (batch_size, seqlen, nheads_k, headdim)
        dropout_p: float. Dropout probability.
        softmax_scale: float. The scaling of QK^T before applying softmax.
            Default to 1 / sqrt(headdim).
        causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
        return_attn_probs: bool. Whether to return the attention probabilities. This option is for
           testing only. The returned probabilities are not guaranteed to be correct
           (they might not have the right scaling). S_dmask is only requested from the
           kernel when dropout_p > 0.
    Return:
        out: (batch_size, seqlen, nheads, headdim).
        softmax_lse [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen). The
            logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
            normalization factor).
        S_dmask [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen, seqlen).
            The output of softmax (possibly with different scaling). It also encodes the dropout
            pattern (negative means that location was dropped, nonnegative means it was kept).
    """
    return FlashAttnFunc.apply(q, k, v, dropout_p, softmax_scale, causal, return_attn_probs)


def flash_attn_varlen_qkvpacked_func(qkv, cu_seqlens, max_seqlen, dropout_p=0.0, softmax_scale=None,
                                     causal=False, return_attn_probs=False):
    """Compute variable-length FlashAttention for Q, K and V packed into one tensor.

    Set dropout_p to 0.0 during evaluation.

    When Q, K, V already live in one stacked tensor, this is faster than calling
    flash_attn_varlen_func on them separately. The backward pass writes the gradients of
    Q, K, V into one packed tensor, so no explicit concatenation is needed.

    For multi-query and grouped-query attention (MQA/GQA), use
    flash_attn_varlen_kvpacked_func or flash_attn_varlen_func instead.

    Arguments:
        qkv: (total, 3, nheads, headdim), where total = total number of tokens in the batch.
        cu_seqlens: (batch_size + 1,), dtype torch.int32. Cumulative sequence lengths of the
           sequences in the batch, used to index into qkv.
        max_seqlen: int. Maximum sequence length in the batch.
        dropout_p: float. Dropout probability.
        softmax_scale: float. Scaling of QK^T before the softmax.
            Defaults to 1 / sqrt(headdim).
        causal: bool. Whether to apply a causal attention mask (e.g., for auto-regressive
            modeling).
        return_attn_probs: bool. Whether to also return the attention probabilities. For
           testing only: the returned probabilities are not guaranteed to be correct
           (they might not have the right scaling). S_dmask is only requested from the
           kernel when dropout_p > 0.
    Return:
        out: (total, nheads, headdim).
        softmax_lse [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen). The
            logsumexp of each row of QK^T * scaling (i.e., the log of the softmax
            normalization factor).
        S_dmask [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen, seqlen).
            The softmax output (possibly with different scaling). It also encodes the dropout
            pattern (negative means the location was dropped, nonnegative means it was kept).
    """
    apply_args = (qkv, cu_seqlens, max_seqlen, dropout_p, softmax_scale, causal,
                  return_attn_probs)
    return FlashAttnVarlenQKVPackedFunc.apply(*apply_args)
Tri Dao's avatar
Tri Dao committed
409
410


Tri Dao's avatar
Tri Dao committed
411
412
413
def flash_attn_varlen_kvpacked_func(q, kv, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k,
                                    dropout_p=0.0, softmax_scale=None, causal=False,
                                    return_attn_probs=False):
    """Compute variable-length FlashAttention with Q and K, V packed into one tensor.

    dropout_p should be set to 0.0 during evaluation.

    If K, V are already stacked into one tensor, this function is faster than calling
    flash_attn_varlen_func on Q, K, V, because the backward pass avoids explicit
    concatenation of the gradients of K, V.

    Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer
    heads than Q. The number of heads in Q must be divisible by the number of heads in KV.
    For example, if Q has 6 heads and K, V have 2 heads, heads 0, 1, 2 of Q attend to head
    0 of K, V, and heads 3, 4, 5 of Q attend to head 1 of K, V.

    Arguments:
        q: (total_q, nheads, headdim), where total_q = total number of query tokens in the batch.
        kv: (total_k, 2, nheads_k, headdim), where total_k = total number of key tokens in the batch.
        cu_seqlens_q: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
           of the sequences in the batch, used to index into q.
        cu_seqlens_k: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
           of the sequences in the batch, used to index into kv.
        max_seqlen_q: int. Maximum query sequence length in the batch.
        max_seqlen_k: int. Maximum key sequence length in the batch.
        dropout_p: float. Dropout probability.
        softmax_scale: float. The scaling of QK^T before applying softmax.
            Default to 1 / sqrt(headdim).
        causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
        return_attn_probs: bool. Whether to return the attention probabilities. This option is for
           testing only. The returned probabilities are not guaranteed to be correct
           (they might not have the right scaling). S_dmask is only requested from the
           kernel when dropout_p > 0.
    Return:
        out: (total_q, nheads, headdim).
        softmax_lse [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen). The
            logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
            normalization factor).
        S_dmask [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen, seqlen).
            The output of softmax (possibly with different scaling). It also encodes the dropout
            pattern (negative means that location was dropped, nonnegative means it was kept).
    """
    return FlashAttnVarlenKVPackedFunc.apply(
        q, kv, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k,
        dropout_p, softmax_scale, causal, return_attn_probs
    )
Tri Dao's avatar
Tri Dao committed
452

453

Tri Dao's avatar
Tri Dao committed
454
455
456
457
458
def flash_attn_varlen_func(q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k,
                           dropout_p=0.0, softmax_scale=None, causal=False,
                           return_attn_probs=False):
    """Compute variable-length FlashAttention for separate Q, K, V tensors.

    dropout_p should be set to 0.0 during evaluation.

    Supports multi-query and grouped-query attention (MQA/GQA) by passing in K, V with fewer
    heads than Q. The number of heads in Q must be divisible by the number of heads in K, V.
    For example, if Q has 6 heads and K, V have 2 heads, heads 0, 1, 2 of Q attend to head
    0 of K, V, and heads 3, 4, 5 of Q attend to head 1 of K, V.

    Arguments:
        q: (total_q, nheads, headdim), where total_q = total number of query tokens in the batch.
        k: (total_k, nheads_k, headdim), where total_k = total number of key tokens in the batch.
        v: (total_k, nheads_k, headdim), where total_k = total number of key tokens in the batch.
        cu_seqlens_q: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
           of the sequences in the batch, used to index into q.
        cu_seqlens_k: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
           of the sequences in the batch, used to index into k and v.
        max_seqlen_q: int. Maximum query sequence length in the batch.
        max_seqlen_k: int. Maximum key sequence length in the batch.
        dropout_p: float. Dropout probability.
        softmax_scale: float. The scaling of QK^T before applying softmax.
            Default to 1 / sqrt(headdim).
        causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
        return_attn_probs: bool. Whether to return the attention probabilities. This option is for
           testing only. The returned probabilities are not guaranteed to be correct
           (they might not have the right scaling). S_dmask is only requested from the
           kernel when dropout_p > 0.
    Return:
        out: (total_q, nheads, headdim).
        softmax_lse [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen). The
            logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
            normalization factor).
        S_dmask [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen, seqlen).
            The output of softmax (possibly with different scaling). It also encodes the dropout
            pattern (negative means that location was dropped, nonnegative means it was kept).
    """
    return FlashAttnVarlenFunc.apply(
        q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k,
        dropout_p, softmax_scale, causal, return_attn_probs
    )