flash_attn_interface.py 20 KB
Newer Older
Tri Dao's avatar
Tri Dao committed
1
2
import torch
import torch.nn as nn
3
import torch.nn.functional as F
Tri Dao's avatar
Tri Dao committed
4

Tri Dao's avatar
Tri Dao committed
5
import flash_attn_cuda
Tri Dao's avatar
Tri Dao committed
6
7


Tri Dao's avatar
Tri Dao committed
8
def _get_block_size(device, head_dim, is_dropout):
    """Return the block size associated with a head dimension.

    Arguments:
        device: accepted for interface compatibility; not used by this function.
        head_dim: int. Must be a multiple of 8 and at most 128.
        is_dropout: accepted for interface compatibility; not used by this function.
    Return:
        256 if head_dim <= 64, otherwise 128.
    Raises:
        AssertionError: if head_dim is not a multiple of 8 or is larger than 128.
    """
    assert head_dim % 8 == 0 and head_dim <= 128, (
        f"head_dim must be a multiple of 8 and at most 128, got {head_dim}"
    )
    return 256 if head_dim <= 64 else 128
Tri Dao's avatar
Tri Dao committed
11
12


13
def _flash_attn_forward(q, k, v, out, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k,
                        dropout_p, softmax_scale, causal, return_softmax, num_splits=0,
                        generator=None):
    """Call the forward kernel ``flash_attn_cuda.fwd`` and unpack its results.

    num_splits: how much to parallelize over the seqlen_q dimension. num_splits=0 means
    it will be set by an internal heuristic. We're exposing num_splits mostly for benchmarking.
    Don't change it unless you know what you're doing.

    Return:
        (out, softmax_lse, rng_state, S_dmask). ``out`` is the same object that was passed in.
        ``softmax_lse`` and ``rng_state`` are the first two values returned by the kernel.
        ``S_dmask`` is the third value returned by the kernel when ``return_softmax`` is
        truthy, and None otherwise.
    """
    # NOTE(review): the positional ``False`` after softmax_scale is a fixed kernel flag whose
    # meaning is not visible here -- confirm against the flash_attn_cuda.fwd signature.
    softmax_lse, rng_state, *rest = flash_attn_cuda.fwd(
        q, k, v, out, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, dropout_p,
        softmax_scale, False, causal, return_softmax, num_splits, generator
    )
    S_dmask = rest[0] if return_softmax else None
    return out, softmax_lse, rng_state, S_dmask
Tri Dao's avatar
Tri Dao committed
29
30


Tri Dao's avatar
Tri Dao committed
31
def _flash_attn_backward(dout, q, k, v, out, softmax_lse, dq, dk, dv, cu_seqlens_q, cu_seqlens_k,
                         max_seqlen_q, max_seqlen_k, dropout_p, softmax_scale, causal,
                         rng_state=None, num_splits=0, generator=None):
    """Call the backward kernel ``flash_attn_cuda.bwd``.

    num_splits: whether to parallelize over the seqlen_k dimension (num_splits > 1) or
    not (num_splits = 1). num_splits=0 means it will be set by an internal heuristic.
    Any value above 1 will call the same kernel (i.e. num_splits=2 would call the same kernel
    as num_splits=3), so effectively the choices are 0, 1, and 2.
    This hyperparameter can be tuned for performance, but default value (heuristic) should work fine.

    Return:
        (dq, dk, dv, softmax_d). ``dq``, ``dk`` and ``dv`` are the same objects that were passed
        in; the first three values returned by the kernel are discarded, so the kernel is
        presumably expected to fill the passed-in buffers in place (confirm against the
        extension). ``softmax_d`` is the fourth value returned by the kernel.
    """
    dout = dout.contiguous()  # CUDA code assumes that dout is contiguous
    # NOTE(review): the positional ``False`` after softmax_scale is a fixed kernel flag whose
    # meaning is not visible here -- confirm against the flash_attn_cuda.bwd signature.
    _, _, _, softmax_d = flash_attn_cuda.bwd(
        dout, q, k, v, out, softmax_lse, dq, dk, dv, cu_seqlens_q, cu_seqlens_k,
        max_seqlen_q, max_seqlen_k, dropout_p, softmax_scale, False, causal,
        num_splits, generator, rng_state)
    return dq, dk, dv, softmax_d
Tri Dao's avatar
Tri Dao committed
49
50


Tri Dao's avatar
Tri Dao committed
51
class FlashAttnQKVPackedFunc(torch.autograd.Function):
    """Autograd function for attention over a single packed ``qkv`` tensor.

    ``qkv[:, 0]``, ``qkv[:, 1]`` and ``qkv[:, 2]`` are passed to the kernels as q, k and v, and
    the same ``cu_seqlens`` / ``max_seqlen`` are used for both the query and the key side.
    """

    @staticmethod
    def forward(ctx, qkv, cu_seqlens, max_seqlen, dropout_p, softmax_scale, causal,
                return_softmax, deterministic):
        """Run the forward kernel and stash what backward needs on ``ctx``.

        Returns ``out`` alone, or ``(out, softmax_lse, S_dmask)`` when ``return_softmax``
        is truthy.
        """
        if softmax_scale is None:
            # Default scaling: (size of the last dimension of qkv) ** -0.5.
            softmax_scale = qkv.shape[-1] ** (-0.5)
        out, softmax_lse, rng_state, S_dmask = _flash_attn_forward(
            qkv[:, 0], qkv[:, 1], qkv[:, 2], torch.empty_like(qkv[:, 0]), cu_seqlens, cu_seqlens,
            max_seqlen, max_seqlen, dropout_p, softmax_scale, causal=causal,
            return_softmax=return_softmax
        )
        # rng_state is saved so that backward can hand it back to the backward kernel.
        ctx.save_for_backward(qkv, out, softmax_lse, cu_seqlens, rng_state)
        ctx.dropout_p = dropout_p
        ctx.max_seqlen = max_seqlen
        ctx.softmax_scale = softmax_scale
        ctx.causal = causal
        ctx.deterministic = deterministic
        return out if not return_softmax else (out, softmax_lse, S_dmask)

    @staticmethod
    def backward(ctx, dout, *args):
        """Compute the gradient w.r.t. ``qkv``; gradients of extra outputs (``*args``) are ignored."""
        qkv, out, softmax_lse, cu_seqlens, rng_state = ctx.saved_tensors
        dqkv = torch.empty_like(qkv)
        # The three slices of dqkv are passed as the dq/dk/dv buffers of the backward call.
        # deterministic selects num_splits=1; otherwise 0 lets the kernel heuristic decide.
        _flash_attn_backward(
            dout, qkv[:, 0], qkv[:, 1], qkv[:, 2], out, softmax_lse,
            dqkv[:, 0], dqkv[:, 1], dqkv[:, 2], cu_seqlens, cu_seqlens,
            ctx.max_seqlen, ctx.max_seqlen, ctx.dropout_p, ctx.softmax_scale, ctx.causal,
            rng_state=rng_state, num_splits=1 if ctx.deterministic else 0,
        )
        # One slot per forward input (8 in total); only qkv receives a gradient.
        return dqkv, None, None, None, None, None, None, None
Tri Dao's avatar
Tri Dao committed
82
83


Tri Dao's avatar
Tri Dao committed
84
class FlashAttnKVPackedFunc(torch.autograd.Function):
    """Autograd function for attention with a separate ``q`` and a packed ``kv`` tensor.

    ``kv[:, 0]`` and ``kv[:, 1]`` are passed to the kernels as k and v.
    """

    @staticmethod
    def forward(ctx, q, kv, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, dropout_p,
                softmax_scale, causal, return_softmax, deterministic):
        """Run the forward kernel and stash what backward needs on ``ctx``.

        Returns ``out`` alone, or ``(out, softmax_lse, S_dmask)`` when ``return_softmax``
        is truthy.
        """
        if softmax_scale is None:
            # Default scaling: (size of the last dimension of q) ** -0.5.
            softmax_scale = q.shape[-1] ** (-0.5)
        out, softmax_lse, rng_state, S_dmask = _flash_attn_forward(
            q, kv[:, 0], kv[:, 1], torch.empty_like(q), cu_seqlens_q, cu_seqlens_k, max_seqlen_q,
            max_seqlen_k, dropout_p, softmax_scale, causal=causal, return_softmax=return_softmax
        )
        # rng_state is saved so that backward can hand it back to the backward kernel.
        ctx.save_for_backward(q, kv, out, softmax_lse, cu_seqlens_q, cu_seqlens_k, rng_state)
        ctx.dropout_p = dropout_p
        ctx.max_seqlen_q = max_seqlen_q
        ctx.max_seqlen_k = max_seqlen_k
        ctx.softmax_scale = softmax_scale
        ctx.causal = causal
        ctx.deterministic = deterministic
        return out if not return_softmax else (out, softmax_lse, S_dmask)

    @staticmethod
    def backward(ctx, dout, *args):
        """Compute gradients w.r.t. ``q`` and ``kv``; gradients of extra outputs are ignored."""
        q, kv, out, softmax_lse, cu_seqlens_q, cu_seqlens_k, rng_state = ctx.saved_tensors
        dq = torch.empty_like(q)
        dkv = torch.empty_like(kv)
        # The two slices of dkv are passed as the dk/dv buffers of the backward call.
        # deterministic selects num_splits=1; otherwise 0 lets the kernel heuristic decide.
        _flash_attn_backward(
            dout, q, kv[:, 0], kv[:, 1], out, softmax_lse,
            dq, dkv[:, 0], dkv[:, 1], cu_seqlens_q, cu_seqlens_k,
            ctx.max_seqlen_q, ctx.max_seqlen_k, ctx.dropout_p, ctx.softmax_scale, ctx.causal,
            rng_state=rng_state, num_splits=1 if ctx.deterministic else 0,
        )
        # One slot per forward input (11 in total); only q and kv receive gradients.
        return dq, dkv, None, None, None, None, None, None, None, None, None
Tri Dao's avatar
Tri Dao committed
116
117
118
119
120
121


class FlashAttnFunc(torch.autograd.Function):
    """Autograd function for attention with separate ``q``, ``k`` and ``v`` tensors."""

    @staticmethod
    def forward(ctx, q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, dropout_p,
                softmax_scale, causal, return_softmax, deterministic):
        """Run the forward kernel and stash what backward needs on ``ctx``.

        Returns ``out`` alone, or ``(out, softmax_lse, S_dmask)`` when ``return_softmax``
        is truthy.
        """
        if softmax_scale is None:
            # Default scaling: (size of the last dimension of q) ** -0.5.
            softmax_scale = q.shape[-1] ** (-0.5)
        out, softmax_lse, rng_state, S_dmask = _flash_attn_forward(
            q, k, v, torch.empty_like(q), cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k,
            dropout_p, softmax_scale, causal=causal, return_softmax=return_softmax
        )
        # rng_state is saved so that backward can hand it back to the backward kernel.
        ctx.save_for_backward(q, k, v, out, softmax_lse, cu_seqlens_q, cu_seqlens_k, rng_state)
        ctx.dropout_p = dropout_p
        ctx.max_seqlen_q = max_seqlen_q
        ctx.max_seqlen_k = max_seqlen_k
        ctx.softmax_scale = softmax_scale
        ctx.causal = causal
        ctx.deterministic = deterministic
        return out if not return_softmax else (out, softmax_lse, S_dmask)

    @staticmethod
    def backward(ctx, dout, *args):
        """Compute gradients w.r.t. ``q``, ``k`` and ``v``; gradients of extra outputs are ignored."""
        q, k, v, out, softmax_lse, cu_seqlens_q, cu_seqlens_k, rng_state = ctx.saved_tensors
        dq, dk, dv = torch.empty_like(q), torch.empty_like(k), torch.empty_like(v)
        # deterministic selects num_splits=1; otherwise 0 lets the kernel heuristic decide.
        _flash_attn_backward(
            dout, q, k, v, out, softmax_lse, dq, dk, dv, cu_seqlens_q, cu_seqlens_k,
            ctx.max_seqlen_q, ctx.max_seqlen_k, ctx.dropout_p, ctx.softmax_scale, ctx.causal,
            rng_state=rng_state, num_splits=1 if ctx.deterministic else 0,
        )
        # One slot per forward input (12 in total); only q, k and v receive gradients.
        return dq, dk, dv, None, None, None, None, None, None, None, None, None
Tri Dao's avatar
Tri Dao committed
148
149


150
151
152
153
class FlashAttnQKVPackedSplitFunc(torch.autograd.Function):
    """Autograd function that runs packed-qkv attention as two kernel launches.

    The first launch uses ``cu_seqlens[:batch_size0 + 1]`` with ``max_seqlen0`` on the current
    stream; the second uses ``cu_seqlens[batch_size0:]`` with ``max_seqlen1`` on a freshly
    created side stream. Both launches write into the same ``out`` buffer.
    """

    @staticmethod
    def forward(ctx, qkv, cu_seqlens, max_seqlen0, max_seqlen1, batch_size0, dropout_p,
                softmax_scale, causal, return_softmax, deterministic):
        """Run both forward launches and stash what backward needs on ``ctx``.

        Returns ``out`` alone, or ``(out, softmax_lse, S_dmask0, S_dmask1)`` when
        ``return_softmax`` is truthy, where ``softmax_lse`` is the two per-part results
        padded on the last dimension to a common length and concatenated along dim 0.
        """
        # Save rng_state because the backward pass will regenerate the dropout mask
        if dropout_p > 0:
            rng_state0 = torch.cuda.get_rng_state()
            generator1 = torch.Generator(device='cuda')
            rng_state1 = generator1.get_state()
        else:
            rng_state0, generator1, rng_state1 = None, None, None
        if softmax_scale is None:
            softmax_scale = qkv.shape[-1] ** (-0.5)
        out = torch.empty_like(qkv[:, 0])
        # _flash_attn_forward returns (out, softmax_lse, rng_state, S_dmask). The kernel's
        # rng_state is discarded here because this class replays dropout from the
        # generator states captured above.
        _, softmax_lse0, _, S_dmask0 = _flash_attn_forward(
            qkv[:, 0], qkv[:, 1], qkv[:, 2], out, cu_seqlens[:batch_size0 + 1],
            cu_seqlens[:batch_size0 + 1], max_seqlen0, max_seqlen0, dropout_p, softmax_scale,
            causal=causal, return_softmax=return_softmax
        )
        # NOTE(review): the side stream does not wait on the current stream before launching;
        # confirm that qkv is guaranteed to be ready when the second kernel starts.
        s = torch.cuda.Stream()
        with torch.cuda.stream(s):
            _, softmax_lse1, _, S_dmask1 = _flash_attn_forward(
                qkv[:, 0], qkv[:, 1], qkv[:, 2], out, cu_seqlens[batch_size0:],
                cu_seqlens[batch_size0:], max_seqlen1, max_seqlen1, dropout_p, softmax_scale,
                causal=causal, return_softmax=return_softmax, generator=generator1
            )
        # Make the current stream wait for the side-stream launch before out is used.
        torch.cuda.current_stream().wait_stream(s)
        ctx.save_for_backward(qkv, out, softmax_lse0, softmax_lse1, cu_seqlens,
                              rng_state0, rng_state1)
        ctx.dropout_p = dropout_p
        ctx.max_seqlen0 = max_seqlen0
        ctx.max_seqlen1 = max_seqlen1
        ctx.batch_size0 = batch_size0
        ctx.softmax_scale = softmax_scale
        ctx.causal = causal
        ctx.deterministic = deterministic
        if not return_softmax:
            return out
        # Pad the last dimension of both softmax_lse parts to the longer one, then stack the
        # two parts along dim 0.
        max_seqlen_q = max(softmax_lse0.shape[2], softmax_lse1.shape[2])
        softmax_lse = torch.cat([F.pad(softmax_lse0, (0, max_seqlen_q - softmax_lse0.shape[2])),
                                 F.pad(softmax_lse1, (0, max_seqlen_q - softmax_lse1.shape[2]))],
                                dim=0)
        return out, softmax_lse, S_dmask0, S_dmask1

    @staticmethod
    def backward(ctx, dout, *args):
        """Compute the gradient w.r.t. ``qkv`` with two backward launches mirroring forward."""
        qkv, out, softmax_lse0, softmax_lse1, cu_seqlens, rng_state0, rng_state1 = ctx.saved_tensors
        batch_size0 = ctx.batch_size0
        if rng_state0 is not None:
            # Temporarily restore the RNG state captured in forward; the current state is
            # put back at the end of this method.
            cur_rng_state = torch.cuda.get_rng_state()
            torch.cuda.set_rng_state(rng_state0)
        if rng_state1 is not None:
            generator1 = torch.Generator(device='cuda')
            generator1.set_state(rng_state1)
        else:
            generator1 = None
        dqkv = torch.empty_like(qkv)
        # NOTE(review): neither launch passes rng_state to _flash_attn_backward (it defaults to
        # None); confirm the backward kernel falls back to the generator in that case.
        _flash_attn_backward(
            dout, qkv[:, 0], qkv[:, 1], qkv[:, 2], out, softmax_lse0,
            dqkv[:, 0], dqkv[:, 1], dqkv[:, 2], cu_seqlens[:batch_size0 + 1],
            cu_seqlens[:batch_size0 + 1], ctx.max_seqlen0, ctx.max_seqlen0, ctx.dropout_p,
            ctx.softmax_scale, ctx.causal, num_splits=1 if ctx.deterministic else 0,
        )
        s = torch.cuda.Stream()
        with torch.cuda.stream(s):
            _flash_attn_backward(
                dout, qkv[:, 0], qkv[:, 1], qkv[:, 2], out, softmax_lse1,
                dqkv[:, 0], dqkv[:, 1], dqkv[:, 2], cu_seqlens[batch_size0:],
                cu_seqlens[batch_size0:], ctx.max_seqlen1, ctx.max_seqlen1, ctx.dropout_p,
                ctx.softmax_scale, ctx.causal, generator=generator1,
                num_splits=1 if ctx.deterministic else 0,
            )
        torch.cuda.current_stream().wait_stream(s)
        if rng_state0 is not None:
            torch.cuda.set_rng_state(cur_rng_state)
        # One slot per forward input (10 in total); only qkv receives a gradient.
        return dqkv, None, None, None, None, None, None, None, None, None
229
230


Tri Dao's avatar
Tri Dao committed
231
def flash_attn_unpadded_qkvpacked_func(qkv, cu_seqlens, max_seqlen, dropout_p, softmax_scale=None,
                                       causal=False, return_attn_probs=False, deterministic=False):
    """dropout_p should be set to 0.0 during evaluation
    Arguments:
        qkv: (total, 3, nheads, headdim), where total = total number of tokens in the batch.
        cu_seqlens: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
           of the sequences in the batch, used to index into qkv.
        max_seqlen: int. Maximum sequence length in the batch.
        dropout_p: float. Dropout probability.
        softmax_scale: float. The scaling of QK^T before applying softmax.
            Default to 1 / sqrt(headdim).
        causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
        return_attn_probs: bool. Whether to return the attention probabilities. This option is for
           testing only. The returned probabilities are not guaranteed to be correct
           (they might not have the right scaling).
        deterministic: bool. Whether or not to ensure deterministic execution.
    Return:
        out: (total, nheads, headdim).
        softmax_lse [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen). The
            logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
            normalization factor).
        S_dmask [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen, seqlen).
            The output of softmax (possibly with different scaling). It also encodes the dropout
            pattern (negative means that location was dropped, nonnegative means it was kept).
    """
    return FlashAttnQKVPackedFunc.apply(qkv, cu_seqlens, max_seqlen, dropout_p, softmax_scale,
                                        causal, return_attn_probs, deterministic)
Tri Dao's avatar
Tri Dao committed
258
259
260
261


def flash_attn_unpadded_kvpacked_func(q, kv, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k,
                                      dropout_p, softmax_scale=None, causal=False,
                                      return_attn_probs=False, deterministic=False):
    """dropout_p should be set to 0.0 during evaluation
    Arguments:
        q: (total_q, nheads, headdim), where total_q = total number of query tokens in the batch.
        kv: (total_k, 2, nheads, headdim), where total_k = total number of key tokens in the batch.
        cu_seqlens_q: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
           of the sequences in the batch, used to index into q.
        cu_seqlens_k: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
           of the sequences in the batch, used to index into kv.
        max_seqlen_q: int. Maximum query sequence length in the batch.
        max_seqlen_k: int. Maximum key sequence length in the batch.
        dropout_p: float. Dropout probability.
        softmax_scale: float. The scaling of QK^T before applying softmax.
            Default to 1 / sqrt(headdim).
        causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
        return_attn_probs: bool. Whether to return the attention probabilities. This option is for
           testing only. The returned probabilities are not guaranteed to be correct
           (they might not have the right scaling).
        deterministic: bool. Whether or not to ensure deterministic execution.
    Return:
        out: (total_q, nheads, headdim), i.e. the same shape as q.
        softmax_lse [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen). The
            logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
            normalization factor).
        S_dmask [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen, seqlen).
            The output of softmax (possibly with different scaling). It also encodes the dropout
            pattern (negative means that location was dropped, nonnegative means it was kept).
    """
    return FlashAttnKVPackedFunc.apply(q, kv, cu_seqlens_q, cu_seqlens_k,
                                       max_seqlen_q, max_seqlen_k, dropout_p, softmax_scale, causal,
                                       return_attn_probs, deterministic)
Tri Dao's avatar
Tri Dao committed
293
294
295


def flash_attn_unpadded_func(q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k,
                             dropout_p, softmax_scale=None, causal=False, return_attn_probs=False,
                             deterministic=False):
    """dropout_p should be set to 0.0 during evaluation
    Arguments:
        q: (total_q, nheads, headdim), where total_q = total number of query tokens in the batch.
        k: (total_k, nheads, headdim), where total_k = total number of key tokens in the batch.
        v: (total_k, nheads, headdim), where total_k = total number of key tokens in the batch.
        cu_seqlens_q: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
           of the sequences in the batch, used to index into q.
        cu_seqlens_k: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
           of the sequences in the batch, used to index into k and v.
        max_seqlen_q: int. Maximum query sequence length in the batch.
        max_seqlen_k: int. Maximum key sequence length in the batch.
        dropout_p: float. Dropout probability.
        softmax_scale: float. The scaling of QK^T before applying softmax.
            Default to 1 / sqrt(headdim).
        causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
        return_attn_probs: bool. Whether to return the attention probabilities. This option is for
           testing only. The returned probabilities are not guaranteed to be correct
           (they might not have the right scaling).
        deterministic: bool. Whether or not to ensure deterministic execution.
    Return:
        out: (total_q, nheads, headdim), i.e. the same shape as q.
        softmax_lse [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen). The
            logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
            normalization factor).
        S_dmask [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen, seqlen).
            The output of softmax (possibly with different scaling). It also encodes the dropout
            pattern (negative means that location was dropped, nonnegative means it was kept).
    """
    return FlashAttnFunc.apply(q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k,
                               dropout_p, softmax_scale, causal, return_attn_probs, deterministic)
Tri Dao's avatar
Tri Dao committed
328
329


330
331
def flash_attn_unpadded_qkvpacked_split_func(
        qkv, cu_seqlens, max_seqlen0, max_seqlen1, batch_size0, dropout_p, softmax_scale=None,
        causal=False, return_attn_probs=False, deterministic=False):
    """
    Split attention into 2 kernels running on 2 separate streams for performance reason:
    e.g., if the batch has some sequences of length <= 128 and some > 128, it might be faster to
    have one kernel dealing with seqlen <= 128 and one kernel for seqlen > 128.

    dropout_p should be set to 0.0 during evaluation.

    Arguments:
        qkv: (total, 3, nheads, headdim), where total = total number of tokens in the batch.
        cu_seqlens: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
           of the sequences in the batch, used to index into qkv.
        max_seqlen0: int. Maximum sequence length in 1st part of the batch.
        max_seqlen1: int. Maximum sequence length in 2nd part of the batch.
        batch_size0: int. Number of sequences in the 1st part of the batch.
        dropout_p: float. Dropout probability.
        softmax_scale: float. The scaling of QK^T before applying softmax.
            Default to 1 / sqrt(headdim).
        causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
        return_attn_probs: bool. Whether to return the attention probabilities. This option is for
           testing only. The returned probabilities are not guaranteed to be correct
           (they might not have the right scaling).
        deterministic: bool. Whether or not to ensure deterministic execution.
    Return:
        out: (total, nheads, headdim).
        softmax_lse [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen). The
            logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
            normalization factor). The results of the two parts are padded on the last
            dimension to a common length and concatenated along the first dimension.
        S_dmask0, S_dmask1 [optional, if return_attn_probs=True]: one tensor per part of the
            batch, each (part batch size, nheads, seqlen, seqlen).
            The output of softmax (possibly with different scaling). It also encodes the dropout
            pattern (negative means that location was dropped, nonnegative means it was kept).
    """
    return FlashAttnQKVPackedSplitFunc.apply(qkv, cu_seqlens, max_seqlen0, max_seqlen1, batch_size0,
                                             dropout_p, softmax_scale, causal, return_attn_probs,
                                             deterministic)
367
368


Tri Dao's avatar
Tri Dao committed
369
def flash_attn_func(qkv, cu_seqlens, dropout_p, max_s, softmax_scale=None, causal=False,
                    return_attn_probs=False):
    """For backward-compatibility only, will remove soon.
    dropout_p should be set to 0.0 during evaluation

    Forwards to flash_attn_unpadded_qkvpacked_func. Note the argument order: this function
    takes ``dropout_p`` before ``max_s``, while the callee takes the maximum sequence length
    before ``dropout_p``. ``deterministic`` is not exposed here, so the callee's default
    (False) applies.
    """
    return flash_attn_unpadded_qkvpacked_func(qkv, cu_seqlens, max_s, dropout_p, softmax_scale,
                                              causal, return_attn_probs)