import torch
import torch.nn as nn

import flash_attn_2_cuda as flash_attn_cuda
from einops import rearrange


def _get_block_size(device, head_dim, is_dropout, is_causal):
    """Return the ``(block_m, block_n)`` tile sizes for the given configuration.

    The table must stay in sync with the block sizes chosen inside the CUDA
    kernel. The choice depends on ``head_dim``, on whether dropout and causal
    masking are enabled, and on the compute capability of ``device``.

    Raises:
        AssertionError: if ``head_dim`` is larger than 256. The check happens
            before the device is queried.
    """
    # This should match the block sizes in the CUDA kernel
    assert head_dim <= 256
    major, minor = torch.cuda.get_device_capability(device)
    # sm8x covers only sm86 / sm89; sm80 (A100) is tracked separately.
    sm8x = major == 8 and minor > 0
    sm80 = major == 8 and minor == 0
    sm90 = major == 9 and minor == 0

    if head_dim <= 32:
        return 128, 128
    if head_dim <= 64:
        if is_dropout:
            return 128, 64
        return 128, 128
    if head_dim <= 96:
        if sm8x and is_causal:
            return 64, 64
        return 128, 64
    if head_dim <= 128:
        if sm8x:
            if is_causal and not is_dropout:
                return 64, 64
            return 128, 32
        return (128, 32) if is_dropout else (128, 64)
    if head_dim <= 160:
        if sm8x:
            return (64, 64) if is_causal else (128, 64)
        return 128, 32
    if head_dim <= 192:
        return (64, 64) if is_dropout else (128, 64)
    if head_dim <= 224:
        return (128, 64) if (sm80 or sm90) else (64, 64)
    if head_dim <= 256:
        return (128, 64) if sm80 else (64, 64)


def _flash_attn_forward(q, k, v, dropout_p, softmax_scale, causal, return_softmax):
    """Run the fixed-length FlashAttention forward kernel.

    The kernel needs a unit stride in the last dimension. Any of q, k, v whose
    last-dim stride is not 1 is therefore copied with ``.contiguous()`` before
    the call.

    Returns:
        The 7-tuple produced by ``flash_attn_cuda.fwd``:
        ``(out, q, k, v, out_padded, softmax_lse, S_dmask)``. The q, k, v
        returned are the tensors the kernel actually consumed.
    """
    def _unit_last_stride(t):
        return t if t.stride(-1) == 1 else t.contiguous()

    q, k, v = _unit_last_stride(q), _unit_last_stride(k), _unit_last_stride(v)
    # The two ``None`` arguments are optional inputs of the extension API whose
    # meaning is not visible here — TODO confirm against the C++ binding.
    results = flash_attn_cuda.fwd(
        q, k, v, None, dropout_p, softmax_scale, causal, return_softmax, None
    )
    out, q, k, v, out_padded, softmax_lse, S_dmask = results
    return out, q, k, v, out_padded, softmax_lse, S_dmask


def _flash_attn_varlen_forward(q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k,
                               dropout_p, softmax_scale, causal, return_softmax):
    """Run the variable-length FlashAttention forward kernel.

    Sequences are addressed through the cumulative-length tensors
    ``cu_seqlens_q`` / ``cu_seqlens_k``. The kernel needs a unit stride in the
    last dimension, so non-conforming q, k, v are made contiguous first.

    Returns:
        The 7-tuple produced by ``flash_attn_cuda.varlen_fwd``:
        ``(out, q, k, v, out_padded, softmax_lse, S_dmask)``.
    """
    def _unit_last_stride(t):
        return t if t.stride(-1) == 1 else t.contiguous()

    q, k, v = _unit_last_stride(q), _unit_last_stride(k), _unit_last_stride(v)
    # The literal ``False`` after softmax_scale is a flag of the extension API
    # (presumably zero_tensors) — TODO confirm against the C++ binding.
    results = flash_attn_cuda.varlen_fwd(
        q, k, v, None, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, dropout_p,
        softmax_scale, False, causal, return_softmax, None
    )
    out, q, k, v, out_padded, softmax_lse, S_dmask = results
    return out, q, k, v, out_padded, softmax_lse, S_dmask


def _flash_attn_backward(dout, q, k, v, out, softmax_lse, dq, dk, dv,
                         dropout_p, softmax_scale, causal):
    """Run the fixed-length FlashAttention backward kernel.

    ``dq``, ``dk`` and ``dv`` are output buffers supplied by the caller. The
    autograd Functions in this module read the gradients from those buffers
    rather than from the return value.

    Returns:
        ``(dq, dk, dv, softmax_d)`` as produced by ``flash_attn_cuda.bwd``.
    """
    def _unit_last_stride(t):
        return t if t.stride(-1) == 1 else t.contiguous()

    # dq, dk, dv are allocated by the callers, so only the inputs are normalised.
    dout, q, k, v, out = (_unit_last_stride(t) for t in (dout, q, k, v, out))
    grads = flash_attn_cuda.bwd(
        dout, q, k, v, out, softmax_lse, dq, dk, dv, dropout_p, softmax_scale, causal, None
    )
    dq, dk, dv, softmax_d = grads
    return dq, dk, dv, softmax_d


def _flash_attn_varlen_backward(dout, q, k, v, out, softmax_lse, dq, dk, dv,
                                cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k,
                                dropout_p, softmax_scale, causal):
    """Run the variable-length FlashAttention backward kernel.

    ``dq``, ``dk`` and ``dv`` are output buffers supplied by the caller. The
    autograd Functions in this module read the gradients from those buffers
    rather than from the return value.

    Returns:
        ``(dq, dk, dv, softmax_d)`` as produced by ``flash_attn_cuda.varlen_bwd``.
    """
    def _unit_last_stride(t):
        return t if t.stride(-1) == 1 else t.contiguous()

    # dq, dk, dv are allocated by the callers, so only the inputs are normalised.
    dout, q, k, v, out = (_unit_last_stride(t) for t in (dout, q, k, v, out))
    # The literal ``False`` mirrors the flag passed to varlen_fwd (presumably
    # zero_tensors) — TODO confirm against the C++ binding.
    grads = flash_attn_cuda.varlen_bwd(
        dout, q, k, v, out, softmax_lse, dq, dk, dv, cu_seqlens_q, cu_seqlens_k,
        max_seqlen_q, max_seqlen_k, dropout_p, softmax_scale, False, causal, None
    )
    dq, dk, dv, softmax_d = grads
    return dq, dk, dv, softmax_d


class FlashAttnQKVPackedFunc(torch.autograd.Function):
    """Autograd Function for fixed-length attention on a packed QKV tensor.

    ``qkv`` is split along dim 2 (``qkv[:, :, i]``) into q, k, v. The backward
    pass writes all three gradients into one packed ``dqkv`` buffer, so no
    concatenation is needed.
    """

    @staticmethod
    def forward(ctx, qkv, dropout_p, softmax_scale, causal, return_softmax):
        # Save rng_state because the backward pass will regenerate the dropout mask
        rng_state = torch.cuda.get_rng_state() if dropout_p > 0 else None
        if softmax_scale is None:
            # Default scale is 1 / sqrt(headdim).
            softmax_scale = qkv.shape[-1] ** (-0.5)
        out, q, k, v, out_padded, softmax_lse, S_dmask = _flash_attn_forward(
            qkv[:, :, 0], qkv[:, :, 1], qkv[:, :, 2], dropout_p, softmax_scale,
            causal=causal, return_softmax=return_softmax and dropout_p > 0
        )
        ctx.save_for_backward(q, k, v, out_padded, softmax_lse, rng_state)
        ctx.dropout_p = dropout_p
        ctx.softmax_scale = softmax_scale
        ctx.causal = causal
        return out if not return_softmax else (out, softmax_lse, S_dmask)

    @staticmethod
    def backward(ctx, dout, *args):
        # *args receives the gradients for softmax_lse / S_dmask when
        # return_softmax=True; they are ignored.
        q, k, v, out, softmax_lse, rng_state = ctx.saved_tensors
        if rng_state is not None:
            cur_rng_state = torch.cuda.get_rng_state()
            torch.cuda.set_rng_state(rng_state)
        try:
            qkv_shape = q.shape[:-2] + (3, *q.shape[-2:])
            dqkv = torch.empty(qkv_shape, dtype=q.dtype, device=q.device)
            _flash_attn_backward(
                dout, q, k, v, out, softmax_lse,
                dqkv[:, :, 0], dqkv[:, :, 1], dqkv[:, :, 2],
                ctx.dropout_p, ctx.softmax_scale, ctx.causal
            )
        finally:
            # Restore the caller's RNG state even if the kernel raises. Otherwise
            # a failed backward leaves the global CUDA generator rewound.
            if rng_state is not None:
                torch.cuda.set_rng_state(cur_rng_state)
        dqkv = dqkv[..., :dout.shape[-1]]  # We could have padded the head dimension
        # One gradient per forward input: qkv, dropout_p, softmax_scale, causal, return_softmax.
        return dqkv, None, None, None, None


class FlashAttnVarlenQKVPackedFunc(torch.autograd.Function):
    """Autograd Function for variable-length attention on a packed QKV tensor.

    ``qkv`` is split along dim 1 (``qkv[:, i]``) into q, k, v. Queries and
    keys come from the same packed tensor, so the one ``cu_seqlens`` /
    ``max_seqlen`` pair is passed for both.
    """

    @staticmethod
    def forward(ctx, qkv, cu_seqlens, max_seqlen, dropout_p, softmax_scale, causal, return_softmax):
        # Save rng_state because the backward pass will regenerate the dropout mask
        rng_state = torch.cuda.get_rng_state() if dropout_p > 0 else None
        if softmax_scale is None:
            # Default scale is 1 / sqrt(headdim).
            softmax_scale = qkv.shape[-1] ** (-0.5)
        out, q, k, v, out_padded, softmax_lse, S_dmask = _flash_attn_varlen_forward(
            qkv[:, 0], qkv[:, 1], qkv[:, 2], cu_seqlens, cu_seqlens, max_seqlen, max_seqlen,
            dropout_p, softmax_scale, causal=causal, return_softmax=return_softmax and dropout_p > 0
        )
        ctx.save_for_backward(q, k, v, out_padded, softmax_lse, cu_seqlens, rng_state)
        ctx.dropout_p = dropout_p
        ctx.max_seqlen = max_seqlen
        ctx.softmax_scale = softmax_scale
        ctx.causal = causal
        return out if not return_softmax else (out, softmax_lse, S_dmask)

    @staticmethod
    def backward(ctx, dout, *args):
        # *args receives the gradients for softmax_lse / S_dmask when
        # return_softmax=True; they are ignored.
        q, k, v, out, softmax_lse, cu_seqlens, rng_state = ctx.saved_tensors
        if rng_state is not None:
            cur_rng_state = torch.cuda.get_rng_state()
            torch.cuda.set_rng_state(rng_state)
        try:
            qkv_shape = q.shape[:-2] + (3, *q.shape[-2:])
            dqkv = torch.empty(qkv_shape, dtype=q.dtype, device=q.device)
            _flash_attn_varlen_backward(
                dout, q, k, v, out, softmax_lse, dqkv[:, 0], dqkv[:, 1], dqkv[:, 2],
                cu_seqlens, cu_seqlens, ctx.max_seqlen, ctx.max_seqlen,
                ctx.dropout_p, ctx.softmax_scale, ctx.causal
            )
        finally:
            # Restore the caller's RNG state even if the kernel raises.
            if rng_state is not None:
                torch.cuda.set_rng_state(cur_rng_state)
        dqkv = dqkv[..., :dout.shape[-1]]  # We could have padded the head dimension
        # One gradient per forward input (7 inputs).
        return dqkv, None, None, None, None, None, None


class FlashAttnKVPackedFunc(torch.autograd.Function):
    """Autograd Function for fixed-length attention with a packed KV tensor.

    ``kv`` is split along dim 2 (``kv[:, :, i]``) into k, v. The backward
    pass writes dk and dv into one packed ``dkv`` buffer.
    """

    @staticmethod
    def forward(ctx, q, kv, dropout_p, softmax_scale, causal, return_softmax):
        # Save rng_state because the backward pass will regenerate the dropout mask
        rng_state = torch.cuda.get_rng_state() if dropout_p > 0 else None
        if softmax_scale is None:
            # Default scale is 1 / sqrt(headdim).
            softmax_scale = q.shape[-1] ** (-0.5)
        out, q, k, v, out_padded, softmax_lse, S_dmask = _flash_attn_forward(
            q, kv[:, :, 0], kv[:, :, 1], dropout_p, softmax_scale, causal=causal,
            return_softmax=return_softmax and dropout_p > 0
        )
        ctx.save_for_backward(q, k, v, out_padded, softmax_lse, rng_state)
        ctx.dropout_p = dropout_p
        ctx.softmax_scale = softmax_scale
        ctx.causal = causal
        return out if not return_softmax else (out, softmax_lse, S_dmask)

    @staticmethod
    def backward(ctx, dout, *args):
        # *args receives the gradients for softmax_lse / S_dmask when
        # return_softmax=True; they are ignored.
        q, k, v, out, softmax_lse, rng_state = ctx.saved_tensors
        if rng_state is not None:
            cur_rng_state = torch.cuda.get_rng_state()
            torch.cuda.set_rng_state(rng_state)
        try:
            dq = torch.empty_like(q)
            kv_shape = k.shape[:-2] + (2, *k.shape[-2:])
            dkv = torch.empty(kv_shape, dtype=k.dtype, device=k.device)
            _flash_attn_backward(
                dout, q, k, v, out, softmax_lse,
                dq, dkv[:, :, 0], dkv[:, :, 1], ctx.dropout_p, ctx.softmax_scale, ctx.causal
            )
        finally:
            # Restore the caller's RNG state even if the kernel raises.
            if rng_state is not None:
                torch.cuda.set_rng_state(cur_rng_state)
        dq = dq[..., :dout.shape[-1]]  # We could have padded the head dimension
        dkv = dkv[..., :dout.shape[-1]]
        # One gradient per forward input: q, kv, dropout_p, softmax_scale, causal, return_softmax.
        return dq, dkv, None, None, None, None


class FlashAttnVarlenKVPackedFunc(torch.autograd.Function):
    """Autograd Function for variable-length attention with a packed KV tensor.

    ``kv`` is split along dim 1 (``kv[:, i]``) into k, v. Query and key
    sequences have separate cumulative lengths and maxima.
    """

    @staticmethod
    def forward(ctx, q, kv, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, dropout_p,
                softmax_scale, causal, return_softmax):
        # Save rng_state because the backward pass will regenerate the dropout mask
        rng_state = torch.cuda.get_rng_state() if dropout_p > 0 else None
        if softmax_scale is None:
            # Default scale is 1 / sqrt(headdim).
            softmax_scale = q.shape[-1] ** (-0.5)
        out, q, k, v, out_padded, softmax_lse, S_dmask = _flash_attn_varlen_forward(
            q, kv[:, 0], kv[:, 1], cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k,
            dropout_p, softmax_scale, causal=causal, return_softmax=return_softmax and dropout_p > 0
        )
        ctx.save_for_backward(q, k, v, out_padded, softmax_lse,
                              cu_seqlens_q, cu_seqlens_k, rng_state)
        ctx.dropout_p = dropout_p
        ctx.max_seqlen_q = max_seqlen_q
        ctx.max_seqlen_k = max_seqlen_k
        ctx.softmax_scale = softmax_scale
        ctx.causal = causal
        return out if not return_softmax else (out, softmax_lse, S_dmask)

    @staticmethod
    def backward(ctx, dout, *args):
        # *args receives the gradients for softmax_lse / S_dmask when
        # return_softmax=True; they are ignored.
        q, k, v, out, softmax_lse, cu_seqlens_q, cu_seqlens_k, rng_state = ctx.saved_tensors
        if rng_state is not None:
            cur_rng_state = torch.cuda.get_rng_state()
            torch.cuda.set_rng_state(rng_state)
        try:
            dq = torch.empty_like(q)
            kv_shape = k.shape[:-2] + (2, *k.shape[-2:])
            dkv = torch.empty(kv_shape, dtype=k.dtype, device=k.device)
            _flash_attn_varlen_backward(
                dout, q, k, v, out, softmax_lse, dq, dkv[:, 0], dkv[:, 1],
                cu_seqlens_q, cu_seqlens_k, ctx.max_seqlen_q, ctx.max_seqlen_k,
                ctx.dropout_p, ctx.softmax_scale, ctx.causal
            )
        finally:
            # Restore the caller's RNG state even if the kernel raises.
            if rng_state is not None:
                torch.cuda.set_rng_state(cur_rng_state)
        dq = dq[..., :dout.shape[-1]]  # We could have padded the head dimension
        dkv = dkv[..., :dout.shape[-1]]
        # One gradient per forward input (10 inputs).
        return dq, dkv, None, None, None, None, None, None, None, None


class FlashAttnFunc(torch.autograd.Function):
    """Autograd Function for fixed-length attention on separate q, k, v tensors."""

    @staticmethod
    def forward(ctx, q, k, v, dropout_p, softmax_scale, causal, return_softmax):
        # Save rng_state because the backward pass will regenerate the dropout mask
        rng_state = torch.cuda.get_rng_state() if dropout_p > 0 else None
        if softmax_scale is None:
            # Default scale is 1 / sqrt(headdim).
            softmax_scale = q.shape[-1] ** (-0.5)
        out, q, k, v, out_padded, softmax_lse, S_dmask = _flash_attn_forward(
            q, k, v, dropout_p, softmax_scale, causal=causal,
            return_softmax=return_softmax and dropout_p > 0
        )
        ctx.save_for_backward(q, k, v, out_padded, softmax_lse, rng_state)
        ctx.dropout_p = dropout_p
        ctx.softmax_scale = softmax_scale
        ctx.causal = causal
        return out if not return_softmax else (out, softmax_lse, S_dmask)

    @staticmethod
    def backward(ctx, dout, *args):
        # *args receives the gradients for softmax_lse / S_dmask when
        # return_softmax=True; they are ignored.
        q, k, v, out, softmax_lse, rng_state = ctx.saved_tensors
        if rng_state is not None:
            cur_rng_state = torch.cuda.get_rng_state()
            torch.cuda.set_rng_state(rng_state)
        try:
            dq, dk, dv = torch.empty_like(q), torch.empty_like(k), torch.empty_like(v)
            _flash_attn_backward(
                dout, q, k, v, out, softmax_lse,
                dq, dk, dv, ctx.dropout_p, ctx.softmax_scale, ctx.causal
            )
        finally:
            # Restore the caller's RNG state even if the kernel raises.
            if rng_state is not None:
                torch.cuda.set_rng_state(cur_rng_state)
        dq = dq[..., :dout.shape[-1]]  # We could have padded the head dimension
        dk = dk[..., :dout.shape[-1]]
        dv = dv[..., :dout.shape[-1]]
        # Exactly one gradient per forward input:
        # q, k, v, dropout_p, softmax_scale, causal, return_softmax.
        return dq, dk, dv, None, None, None, None


class FlashAttnVarlenFunc(torch.autograd.Function):
    """Autograd Function for variable-length attention on separate q, k, v tensors.

    Query and key sequences have separate cumulative lengths and maxima.
    """

    @staticmethod
    def forward(ctx, q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, dropout_p,
                softmax_scale, causal, return_softmax):
        # Save rng_state because the backward pass will regenerate the dropout mask
        rng_state = torch.cuda.get_rng_state() if dropout_p > 0 else None
        if softmax_scale is None:
            # Default scale is 1 / sqrt(headdim).
            softmax_scale = q.shape[-1] ** (-0.5)
        out, q, k, v, out_padded, softmax_lse, S_dmask = _flash_attn_varlen_forward(
            q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k,
            dropout_p, softmax_scale, causal=causal, return_softmax=return_softmax and dropout_p > 0
        )
        ctx.save_for_backward(q, k, v, out_padded, softmax_lse,
                              cu_seqlens_q, cu_seqlens_k, rng_state)
        ctx.dropout_p = dropout_p
        ctx.max_seqlen_q = max_seqlen_q
        ctx.max_seqlen_k = max_seqlen_k
        ctx.softmax_scale = softmax_scale
        ctx.causal = causal
        return out if not return_softmax else (out, softmax_lse, S_dmask)

    @staticmethod
    def backward(ctx, dout, *args):
        # *args receives the gradients for softmax_lse / S_dmask when
        # return_softmax=True; they are ignored.
        q, k, v, out, softmax_lse, cu_seqlens_q, cu_seqlens_k, rng_state = ctx.saved_tensors
        if rng_state is not None:
            cur_rng_state = torch.cuda.get_rng_state()
            torch.cuda.set_rng_state(rng_state)
        try:
            dq, dk, dv = torch.empty_like(q), torch.empty_like(k), torch.empty_like(v)
            _flash_attn_varlen_backward(
                dout, q, k, v, out, softmax_lse, dq, dk, dv, cu_seqlens_q, cu_seqlens_k,
                ctx.max_seqlen_q, ctx.max_seqlen_k, ctx.dropout_p, ctx.softmax_scale, ctx.causal
            )
        finally:
            # Restore the caller's RNG state even if the kernel raises.
            if rng_state is not None:
                torch.cuda.set_rng_state(cur_rng_state)
        dq = dq[..., :dout.shape[-1]]  # We could have padded the head dimension
        dk = dk[..., :dout.shape[-1]]
        dv = dv[..., :dout.shape[-1]]
        # One gradient per forward input (11 inputs).
        return dq, dk, dv, None, None, None, None, None, None, None, None


def flash_attn_qkvpacked_func(qkv, dropout_p=0.0, softmax_scale=None, causal=False,
                              return_attn_probs=False):
    """dropout_p should be set to 0.0 during evaluation
    If Q, K, V are already stacked into 1 tensor, this function will be faster than
    calling flash_attn_func on Q, K, V since the backward pass avoids explicit concatenation
    of the gradients of Q, K, V.
    Because Q, K and V are packed into a single tensor, they necessarily have the same number
    of heads. For multi-query / grouped-query attention (MQA/GQA), use flash_attn_kvpacked_func
    or flash_attn_func with K, V that have fewer heads than Q.

    Arguments:
        qkv: (batch_size, seqlen, 3, nheads, headdim)
        dropout_p: float. Dropout probability.
        softmax_scale: float. The scaling of QK^T before applying softmax.
            Default to 1 / sqrt(headdim).
        causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
        return_attn_probs: bool. Whether to return the attention probabilities. This option is for
           testing only. The returned probabilities are not guaranteed to be correct
           (they might not have the right scaling).
    Return:
        out: (batch_size, seqlen, nheads, headdim).
        softmax_lse [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen). The
            logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
            normalization factor).
        S_dmask [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen, seqlen).
            The output of softmax (possibly with different scaling). It also encodes the dropout
            pattern (negative means that location was dropped, nonnegative means it was kept).
    """
    return FlashAttnQKVPackedFunc.apply(qkv, dropout_p, softmax_scale, causal, return_attn_probs)


def flash_attn_kvpacked_func(q, kv, dropout_p=0.0, softmax_scale=None, causal=False,
                             return_attn_probs=False):
    """dropout_p should be set to 0.0 during evaluation
    If K, V are already stacked into 1 tensor, this function will be faster than
    calling flash_attn_func on Q, K, V since the backward pass avoids explicit concatenation
    of the gradients of K, V.
    Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads
    than Q. Note that the number of heads in Q must be divisible by the number of heads in KV.
    For example, if Q has 6 heads and K, V have 2 heads, heads 0, 1, 2 of Q will attend to head
    0 of K, V, and heads 3, 4, 5 of Q will attend to head 1 of K, V.

    Arguments:
        q: (batch_size, seqlen, nheads, headdim)
        kv: (batch_size, seqlen, 2, nheads_k, headdim)
        dropout_p: float. Dropout probability.
        softmax_scale: float. The scaling of QK^T before applying softmax.
            Default to 1 / sqrt(headdim).
        causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
        return_attn_probs: bool. Whether to return the attention probabilities. This option is for
           testing only. The returned probabilities are not guaranteed to be correct
           (they might not have the right scaling).
    Return:
        out: (batch_size, seqlen, nheads, headdim).
        softmax_lse [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen). The
            logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
            normalization factor).
        S_dmask [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen, seqlen).
            The output of softmax (possibly with different scaling). It also encodes the dropout
            pattern (negative means that location was dropped, nonnegative means it was kept).
    """
    return FlashAttnKVPackedFunc.apply(q, kv, dropout_p, softmax_scale, causal, return_attn_probs)


def flash_attn_func(q, k, v, dropout_p=0.0, softmax_scale=None, causal=False,
                    return_attn_probs=False):
    """dropout_p should be set to 0.0 during evaluation
    Supports multi-query and grouped-query attention (MQA/GQA) by passing in K, V with fewer
    heads than Q. Note that the number of heads in Q must be divisible by the number of heads
    in K, V.
    For example, if Q has 6 heads and K, V have 2 heads, heads 0, 1, 2 of Q will attend to head
    0 of K, V, and heads 3, 4, 5 of Q will attend to head 1 of K, V.

    Arguments:
        q: (batch_size, seqlen, nheads, headdim)
        k: (batch_size, seqlen, nheads_k, headdim)
        v: (batch_size, seqlen, nheads_k, headdim)
        dropout_p: float. Dropout probability.
        softmax_scale: float. The scaling of QK^T before applying softmax.
            Default to 1 / sqrt(headdim).
        causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
        return_attn_probs: bool. Whether to return the attention probabilities. This option is for
           testing only. The returned probabilities are not guaranteed to be correct
           (they might not have the right scaling).
    Return:
        out: (batch_size, seqlen, nheads, headdim).
        softmax_lse [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen). The
            logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
            normalization factor).
        S_dmask [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen, seqlen).
            The output of softmax (possibly with different scaling). It also encodes the dropout
            pattern (negative means that location was dropped, nonnegative means it was kept).
    """
    return FlashAttnFunc.apply(q, k, v, dropout_p, softmax_scale, causal, return_attn_probs)


def flash_attn_varlen_qkvpacked_func(qkv, cu_seqlens, max_seqlen, dropout_p=0.0, softmax_scale=None,
                                     causal=False, return_attn_probs=False):
    """dropout_p should be set to 0.0 during evaluation
    If Q, K, V are already stacked into 1 tensor, this function will be faster than
    calling flash_attn_varlen_func on Q, K, V since the backward pass avoids explicit concatenation
    of the gradients of Q, K, V.
    Because Q, K and V are packed into a single tensor, they necessarily have the same number
    of heads and the same sequence lengths (one cu_seqlens / max_seqlen describes both). For
    MQA/GQA or different query/key lengths, use flash_attn_varlen_kvpacked_func or
    flash_attn_varlen_func.

    Arguments:
        qkv: (total, 3, nheads, headdim), where total = total number of tokens in the batch.
        cu_seqlens: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
           of the sequences in the batch, used to index into qkv.
        max_seqlen: int. Maximum sequence length in the batch.
        dropout_p: float. Dropout probability.
        softmax_scale: float. The scaling of QK^T before applying softmax.
            Default to 1 / sqrt(headdim).
        causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
        return_attn_probs: bool. Whether to return the attention probabilities. This option is for
           testing only. The returned probabilities are not guaranteed to be correct
           (they might not have the right scaling).
    Return:
        out: (total, nheads, headdim).
        softmax_lse [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen). The
            logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
            normalization factor).
        S_dmask [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen, seqlen).
            The output of softmax (possibly with different scaling). It also encodes the dropout
            pattern (negative means that location was dropped, nonnegative means it was kept).
    """
    return FlashAttnVarlenQKVPackedFunc.apply(
        qkv, cu_seqlens, max_seqlen, dropout_p, softmax_scale, causal, return_attn_probs
    )


def flash_attn_varlen_kvpacked_func(q, kv, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k,
                                    dropout_p=0.0, softmax_scale=None, causal=False,
                                    return_attn_probs=False):
    """dropout_p should be set to 0.0 during evaluation
    If K, V are already stacked into 1 tensor, this function will be faster than
    calling flash_attn_varlen_func on Q, K, V since the backward pass avoids explicit
    concatenation of the gradients of K, V.
    Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads
    than Q. Note that the number of heads in Q must be divisible by the number of heads in KV.
    For example, if Q has 6 heads and K, V have 2 heads, heads 0, 1, 2 of Q will attend to head
    0 of K, V, and heads 3, 4, 5 of Q will attend to head 1 of K, V.

    Arguments:
        q: (total_q, nheads, headdim), where total_q = total number of query tokens in the batch.
        kv: (total_k, 2, nheads_k, headdim), where total_k = total number of key tokens in the batch.
        cu_seqlens_q: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
           of the sequences in the batch, used to index into q.
        cu_seqlens_k: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
           of the sequences in the batch, used to index into kv.
        max_seqlen_q: int. Maximum query sequence length in the batch.
        max_seqlen_k: int. Maximum key sequence length in the batch.
        dropout_p: float. Dropout probability.
        softmax_scale: float. The scaling of QK^T before applying softmax.
            Default to 1 / sqrt(headdim).
        causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
        return_attn_probs: bool. Whether to return the attention probabilities. This option is for
           testing only. The returned probabilities are not guaranteed to be correct
           (they might not have the right scaling).
    Return:
        out: (total_q, nheads, headdim).
        softmax_lse [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen). The
            logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
            normalization factor).
        S_dmask [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen, seqlen).
            The output of softmax (possibly with different scaling). It also encodes the dropout
            pattern (negative means that location was dropped, nonnegative means it was kept).
    """
    return FlashAttnVarlenKVPackedFunc.apply(
        q, kv, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k,
        dropout_p, softmax_scale, causal, return_attn_probs
    )


def flash_attn_varlen_func(q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k,
                           dropout_p=0.0, softmax_scale=None, causal=False,
                           return_attn_probs=False):
    """dropout_p should be set to 0.0 during evaluation
    Supports multi-query and grouped-query attention (MQA/GQA) by passing in K, V with fewer heads
    than Q. Note that the number of heads in Q must be divisible by the number of heads in K, V.
    For example, if Q has 6 heads and K, V have 2 heads, heads 0, 1, 2 of Q will attend to head
    0 of K, V, and heads 3, 4, 5 of Q will attend to head 1 of K, V.

    Arguments:
        q: (total_q, nheads, headdim), where total_q = total number of query tokens in the batch.
        k: (total_k, nheads_k, headdim), where total_k = total number of key tokens in the batch.
        v: (total_k, nheads_k, headdim), where total_k = total number of key tokens in the batch.
        cu_seqlens_q: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
           of the sequences in the batch, used to index into q.
        cu_seqlens_k: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
           of the sequences in the batch, used to index into k and v.
        max_seqlen_q: int. Maximum query sequence length in the batch.
        max_seqlen_k: int. Maximum key sequence length in the batch.
        dropout_p: float. Dropout probability.
        softmax_scale: float. The scaling of QK^T before applying softmax.
            Default to 1 / sqrt(headdim).
        causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
        return_attn_probs: bool. Whether to return the attention probabilities. This option is for
           testing only. The returned probabilities are not guaranteed to be correct
           (they might not have the right scaling).
    Return:
        out: (total_q, nheads, headdim).
        softmax_lse [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen). The
            logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
            normalization factor).
        S_dmask [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen, seqlen).
            The output of softmax (possibly with different scaling). It also encodes the dropout
            pattern (negative means that location was dropped, nonnegative means it was kept).
    """
    return FlashAttnVarlenFunc.apply(
        q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k,
        dropout_p, softmax_scale, causal, return_attn_probs
    )