flash_attn_interface.py 27.5 KB
Newer Older
Tri Dao's avatar
Tri Dao committed
1
2
import torch
import torch.nn as nn
Tri Dao's avatar
Tri Dao committed
3
4
from einops import rearrange

5
6
7
8
9
# isort: off
# We need to import the CUDA kernels after importing torch
import flash_attn_2_cuda as flash_attn_cuda
# isort: on

Tri Dao's avatar
Tri Dao committed
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42

def _get_block_size(device, head_dim, is_dropout, is_causal):
    """Return the pair of block sizes used for a given head dimension and device.

    The table is keyed on ``head_dim`` (bucketed at 32, 64, 96, 128, 160, 192, 224, 256),
    on whether dropout and causal masking are enabled, and on the compute capability
    reported by ``torch.cuda.get_device_capability(device)``.

    Arguments:
        device: passed straight to ``torch.cuda.get_device_capability``.
        head_dim: head dimension; must be at most 256 (checked with an ``assert``).
        is_dropout: truthy when dropout is active.
        is_causal: truthy when the causal mask is active.
    Return:
        A 2-tuple of ints (the two block sizes).
    """
    # This should match the block sizes in the CUDA kernel
    assert head_dim <= 256
    major, minor = torch.cuda.get_device_capability(device)
    is_sm80 = (major, minor) == (8, 0)
    is_sm90 = (major, minor) == (9, 0)
    # Only sm86 and sm89 count as "sm8x" here; sm80 (A100) is deliberately excluded.
    is_sm8x = major == 8 and minor > 0

    if head_dim <= 32:
        return 128, 128

    if head_dim <= 64:
        return (128, 64) if is_dropout else (128, 128)

    if head_dim <= 96:
        if is_sm8x and is_causal:
            return 64, 64
        return 128, 64

    if head_dim <= 128:
        if not is_sm8x:
            return 128, (32 if is_dropout else 64)
        if is_causal and not is_dropout:
            return 64, 64
        return 128, 32

    if head_dim <= 160:
        if not is_sm8x:
            return 128, 32
        return (64, 64) if is_causal else (128, 64)

    if head_dim <= 192:
        return (64, 64) if is_dropout else (128, 64)

    if head_dim <= 224:
        return (128, 64) if (is_sm80 or is_sm90) else (64, 64)

    # Kept as an explicit check (rather than a bare fall-through return) so the function
    # still yields None for head_dim > 256 when asserts are disabled.
    if head_dim <= 256:
        return (128, 64) if is_sm80 else (64, 64)


def _flash_attn_forward(q, k, v, dropout_p, softmax_scale, causal, return_softmax):
    """Call the ``flash_attn_cuda.fwd`` kernel on q, k, v.

    Any of q, k, v whose last-dimension stride is not 1 is replaced by a ``.contiguous()``
    copy before the call.

    Return:
        The 8-tuple ``(out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state)`` exactly
        as produced by the kernel (q, k, v are the kernel's returned tensors, not
        necessarily the objects passed in).
    """

    def _unit_last_stride(t):
        return t if t.stride(-1) == 1 else t.contiguous()

    q = _unit_last_stride(q)
    k = _unit_last_stride(k)
    v = _unit_last_stride(v)
    # NOTE(review): the two ``None`` positionals are forwarded as-is; presumably an optional
    # output buffer and an optional RNG generator -- confirm against the extension signature.
    kernel_args = (q, k, v, None, dropout_p, softmax_scale, causal, return_softmax, None)
    out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = flash_attn_cuda.fwd(*kernel_args)
    return out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state
Tri Dao's avatar
Tri Dao committed
49
50


Tri Dao's avatar
Tri Dao committed
51
52
53
54
55
56
57
58
59
60
61
62
63
def _flash_attn_varlen_forward(
    q,
    k,
    v,
    cu_seqlens_q,
    cu_seqlens_k,
    max_seqlen_q,
    max_seqlen_k,
    dropout_p,
    softmax_scale,
    causal,
    return_softmax,
):
    """Call the ``flash_attn_cuda.varlen_fwd`` kernel on q, k, v.

    Any of q, k, v whose last-dimension stride is not 1 is replaced by a ``.contiguous()``
    copy before the call. The cumulative-sequence-length tensors and maximum sequence
    lengths are forwarded to the kernel unchanged.

    Return:
        The 8-tuple ``(out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state)`` exactly
        as produced by the kernel.
    """

    def _unit_last_stride(t):
        return t if t.stride(-1) == 1 else t.contiguous()

    q = _unit_last_stride(q)
    k = _unit_last_stride(k)
    v = _unit_last_stride(v)
    # NOTE(review): the literal ``None`` / ``False`` / ``None`` positionals are forwarded
    # as-is; their meaning is defined by the extension signature -- confirm there.
    kernel_args = (
        q,
        k,
        v,
        None,
        cu_seqlens_q,
        cu_seqlens_k,
        max_seqlen_q,
        max_seqlen_k,
        dropout_p,
        softmax_scale,
        False,
        causal,
        return_softmax,
        None,
    )
    out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = flash_attn_cuda.varlen_fwd(
        *kernel_args
    )
    return out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state
Tri Dao's avatar
Tri Dao committed
85
86


Tri Dao's avatar
Tri Dao committed
87
88
89
def _flash_attn_backward(
    dout, q, k, v, out, softmax_lse, dq, dk, dv, dropout_p, softmax_scale, causal, rng_state=None
):
    """Call the ``flash_attn_cuda.bwd`` kernel.

    ``dout``, ``q``, ``k``, ``v`` and ``out`` are made contiguous when their last-dimension
    stride is not 1. ``dq``, ``dk``, ``dv`` are passed through untouched: callers allocate
    them, so they are expected to be suitably laid out already.

    Return:
        The 4-tuple ``(dq, dk, dv, softmax_d)`` as produced by the kernel.
    """

    def _unit_last_stride(t):
        return t if t.stride(-1) == 1 else t.contiguous()

    dout = _unit_last_stride(dout)
    q = _unit_last_stride(q)
    k = _unit_last_stride(k)
    v = _unit_last_stride(v)
    out = _unit_last_stride(out)
    # NOTE(review): the ``None`` before rng_state is forwarded as-is; presumably an optional
    # RNG generator -- confirm against the extension signature.
    kernel_args = (
        dout,
        q,
        k,
        v,
        out,
        softmax_lse,
        dq,
        dk,
        dv,
        dropout_p,
        softmax_scale,
        causal,
        None,
        rng_state,
    )
    dq, dk, dv, softmax_d = flash_attn_cuda.bwd(*kernel_args)
    return dq, dk, dv, softmax_d


Tri Dao's avatar
Tri Dao committed
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
def _flash_attn_varlen_backward(
    dout,
    q,
    k,
    v,
    out,
    softmax_lse,
    dq,
    dk,
    dv,
    cu_seqlens_q,
    cu_seqlens_k,
    max_seqlen_q,
    max_seqlen_k,
    dropout_p,
    softmax_scale,
    causal,
    rng_state=None,
):
    """Call the ``flash_attn_cuda.varlen_bwd`` kernel.

    ``dout``, ``q``, ``k``, ``v`` and ``out`` are made contiguous when their last-dimension
    stride is not 1. ``dq``, ``dk``, ``dv`` are passed through untouched: callers allocate
    them, so they are expected to be suitably laid out already.

    Return:
        The 4-tuple ``(dq, dk, dv, softmax_d)`` as produced by the kernel.
    """

    def _unit_last_stride(t):
        return t if t.stride(-1) == 1 else t.contiguous()

    dout = _unit_last_stride(dout)
    q = _unit_last_stride(q)
    k = _unit_last_stride(k)
    v = _unit_last_stride(v)
    out = _unit_last_stride(out)
    # NOTE(review): the literal ``False`` and ``None`` positionals are forwarded as-is;
    # their meaning is defined by the extension signature -- confirm there.
    kernel_args = (
        dout,
        q,
        k,
        v,
        out,
        softmax_lse,
        dq,
        dk,
        dv,
        cu_seqlens_q,
        cu_seqlens_k,
        max_seqlen_q,
        max_seqlen_k,
        dropout_p,
        softmax_scale,
        False,
        causal,
        None,
        rng_state,
    )
    dq, dk, dv, softmax_d = flash_attn_cuda.varlen_bwd(*kernel_args)
    return dq, dk, dv, softmax_d
Tri Dao's avatar
Tri Dao committed
158
159


Tri Dao's avatar
Tri Dao committed
160
class FlashAttnQKVPackedFunc(torch.autograd.Function):
    """Autograd function for attention on a packed QKV tensor.

    Q, K and V are read as ``qkv[:, :, 0]``, ``qkv[:, :, 1]`` and ``qkv[:, :, 2]``.
    """

    @staticmethod
    def forward(ctx, qkv, dropout_p, softmax_scale, causal, return_softmax):
        """Run the forward kernel and stash what backward needs on ``ctx``.

        ``softmax_scale`` defaults to ``qkv.shape[-1] ** -0.5``. The kernel is only asked
        for the softmax matrix when ``return_softmax`` is set and ``dropout_p > 0``.
        Returns ``out``, or ``(out, softmax_lse, S_dmask)`` when ``return_softmax`` is truthy.
        """
        if softmax_scale is None:
            softmax_scale = qkv.shape[-1] ** (-0.5)
        q_in, k_in, v_in = qkv[:, :, 0], qkv[:, :, 1], qkv[:, :, 2]
        out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = _flash_attn_forward(
            q_in,
            k_in,
            v_in,
            dropout_p,
            softmax_scale,
            causal=causal,
            return_softmax=return_softmax and dropout_p > 0,
        )
        ctx.save_for_backward(q, k, v, out_padded, softmax_lse, rng_state)
        ctx.dropout_p = dropout_p
        ctx.softmax_scale = softmax_scale
        ctx.causal = causal
        if return_softmax:
            return out, softmax_lse, S_dmask
        return out

    @staticmethod
    def backward(ctx, dout, *args):
        """Compute the gradient w.r.t. ``qkv``; the other four inputs get ``None``."""
        q, k, v, out, softmax_lse, rng_state = ctx.saved_tensors
        # A single buffer holds all three gradients: a size-3 axis is inserted in front of
        # the last two dimensions of q, and the helper is handed one view per slot.
        packed_shape = q.shape[:-2] + (3, *q.shape[-2:])
        dqkv = torch.empty(packed_shape, dtype=q.dtype, device=q.device)
        dq, dk, dv = dqkv[:, :, 0], dqkv[:, :, 1], dqkv[:, :, 2]
        # The helper's return value is ignored; dqkv is what gets returned below.
        _flash_attn_backward(
            dout,
            q,
            k,
            v,
            out,
            softmax_lse,
            dq,
            dk,
            dv,
            ctx.dropout_p,
            ctx.softmax_scale,
            ctx.causal,
            rng_state=rng_state,
        )
        # We could have padded the head dimension
        head_dim = dout.shape[-1]
        dqkv = dqkv[..., :head_dim]
        return dqkv, None, None, None, None


class FlashAttnVarlenQKVPackedFunc(torch.autograd.Function):
    """Autograd function for variable-length attention on a packed QKV tensor.

    Q, K and V are read as ``qkv[:, 0]``, ``qkv[:, 1]`` and ``qkv[:, 2]``; the same
    ``cu_seqlens`` / ``max_seqlen`` are used for both the query and the key side.
    """

    @staticmethod
    def forward(ctx, qkv, cu_seqlens, max_seqlen, dropout_p, softmax_scale, causal, return_softmax):
        """Run the varlen forward kernel and stash what backward needs on ``ctx``.

        ``softmax_scale`` defaults to ``qkv.shape[-1] ** -0.5``. The kernel is only asked
        for the softmax matrix when ``return_softmax`` is set and ``dropout_p > 0``.
        Returns ``out``, or ``(out, softmax_lse, S_dmask)`` when ``return_softmax`` is truthy.
        """
        if softmax_scale is None:
            softmax_scale = qkv.shape[-1] ** (-0.5)
        q_in, k_in, v_in = qkv[:, 0], qkv[:, 1], qkv[:, 2]
        out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = _flash_attn_varlen_forward(
            q_in,
            k_in,
            v_in,
            cu_seqlens,
            cu_seqlens,
            max_seqlen,
            max_seqlen,
            dropout_p,
            softmax_scale,
            causal=causal,
            return_softmax=return_softmax and dropout_p > 0,
        )
        ctx.save_for_backward(q, k, v, out_padded, softmax_lse, cu_seqlens, rng_state)
        ctx.dropout_p = dropout_p
        ctx.max_seqlen = max_seqlen
        ctx.softmax_scale = softmax_scale
        ctx.causal = causal
        if return_softmax:
            return out, softmax_lse, S_dmask
        return out

    @staticmethod
    def backward(ctx, dout, *args):
        """Compute the gradient w.r.t. ``qkv``; the other six inputs get ``None``."""
        q, k, v, out, softmax_lse, cu_seqlens, rng_state = ctx.saved_tensors
        # A single buffer holds all three gradients: a size-3 axis is inserted in front of
        # the last two dimensions of q, and the helper is handed one view per slot.
        packed_shape = q.shape[:-2] + (3, *q.shape[-2:])
        dqkv = torch.empty(packed_shape, dtype=q.dtype, device=q.device)
        dq, dk, dv = dqkv[:, 0], dqkv[:, 1], dqkv[:, 2]
        # The helper's return value is ignored; dqkv is what gets returned below.
        _flash_attn_varlen_backward(
            dout,
            q,
            k,
            v,
            out,
            softmax_lse,
            dq,
            dk,
            dv,
            cu_seqlens,
            cu_seqlens,
            ctx.max_seqlen,
            ctx.max_seqlen,
            ctx.dropout_p,
            ctx.softmax_scale,
            ctx.causal,
            rng_state=rng_state,
        )
        # We could have padded the head dimension
        head_dim = dout.shape[-1]
        dqkv = dqkv[..., :head_dim]
        return dqkv, None, None, None, None, None, None
Tri Dao's avatar
Tri Dao committed
255
256


Tri Dao's avatar
Tri Dao committed
257
class FlashAttnKVPackedFunc(torch.autograd.Function):
    """Autograd function for attention with a separate Q and a packed KV tensor.

    K and V are read as ``kv[:, :, 0]`` and ``kv[:, :, 1]``.
    """

    @staticmethod
    def forward(ctx, q, kv, dropout_p, softmax_scale, causal, return_softmax):
        """Run the forward kernel and stash what backward needs on ``ctx``.

        ``softmax_scale`` defaults to ``q.shape[-1] ** -0.5``. The kernel is only asked
        for the softmax matrix when ``return_softmax`` is set and ``dropout_p > 0``.
        Returns ``out``, or ``(out, softmax_lse, S_dmask)`` when ``return_softmax`` is truthy.
        """
        if softmax_scale is None:
            softmax_scale = q.shape[-1] ** (-0.5)
        k_in, v_in = kv[:, :, 0], kv[:, :, 1]
        out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = _flash_attn_forward(
            q,
            k_in,
            v_in,
            dropout_p,
            softmax_scale,
            causal=causal,
            return_softmax=return_softmax and dropout_p > 0,
        )
        ctx.save_for_backward(q, k, v, out_padded, softmax_lse, rng_state)
        ctx.dropout_p = dropout_p
        ctx.softmax_scale = softmax_scale
        ctx.causal = causal
        if return_softmax:
            return out, softmax_lse, S_dmask
        return out

    @staticmethod
    def backward(ctx, dout, *args):
        """Compute gradients w.r.t. ``q`` and ``kv``; the other four inputs get ``None``."""
        q, k, v, out, softmax_lse, rng_state = ctx.saved_tensors
        dq = torch.empty_like(q)
        # dK and dV share one buffer: a size-2 axis is inserted in front of the last two
        # dimensions of k, and the helper is handed one view per slot.
        packed_shape = k.shape[:-2] + (2, *k.shape[-2:])
        dkv = torch.empty(packed_shape, dtype=k.dtype, device=k.device)
        dk, dv = dkv[:, :, 0], dkv[:, :, 1]
        # The helper's return value is ignored; dq and dkv are what get returned below.
        _flash_attn_backward(
            dout,
            q,
            k,
            v,
            out,
            softmax_lse,
            dq,
            dk,
            dv,
            ctx.dropout_p,
            ctx.softmax_scale,
            ctx.causal,
            rng_state=rng_state,
        )
        # We could have padded the head dimension
        head_dim = dout.shape[-1]
        dq = dq[..., :head_dim]
        dkv = dkv[..., :head_dim]
        return dq, dkv, None, None, None, None
Tri Dao's avatar
Tri Dao committed
301
302


Tri Dao's avatar
Tri Dao committed
303
class FlashAttnVarlenKVPackedFunc(torch.autograd.Function):
    """Autograd function for variable-length attention with a separate Q and a packed KV.

    K and V are read as ``kv[:, 0]`` and ``kv[:, 1]``.
    """

    @staticmethod
    def forward(
        ctx,
        q,
        kv,
        cu_seqlens_q,
        cu_seqlens_k,
        max_seqlen_q,
        max_seqlen_k,
        dropout_p,
        softmax_scale,
        causal,
        return_softmax,
    ):
        """Run the varlen forward kernel and stash what backward needs on ``ctx``.

        ``softmax_scale`` defaults to ``q.shape[-1] ** -0.5``. The kernel is only asked
        for the softmax matrix when ``return_softmax`` is set and ``dropout_p > 0``.
        Returns ``out``, or ``(out, softmax_lse, S_dmask)`` when ``return_softmax`` is truthy.
        """
        if softmax_scale is None:
            softmax_scale = q.shape[-1] ** (-0.5)
        k_in, v_in = kv[:, 0], kv[:, 1]
        out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = _flash_attn_varlen_forward(
            q,
            k_in,
            v_in,
            cu_seqlens_q,
            cu_seqlens_k,
            max_seqlen_q,
            max_seqlen_k,
            dropout_p,
            softmax_scale,
            causal=causal,
            return_softmax=return_softmax and dropout_p > 0,
        )
        ctx.save_for_backward(
            q, k, v, out_padded, softmax_lse, cu_seqlens_q, cu_seqlens_k, rng_state
        )
        ctx.dropout_p = dropout_p
        ctx.max_seqlen_q = max_seqlen_q
        ctx.max_seqlen_k = max_seqlen_k
        ctx.softmax_scale = softmax_scale
        ctx.causal = causal
        if return_softmax:
            return out, softmax_lse, S_dmask
        return out

    @staticmethod
    def backward(ctx, dout, *args):
        """Compute gradients w.r.t. ``q`` and ``kv``; the other eight inputs get ``None``."""
        q, k, v, out, softmax_lse, cu_seqlens_q, cu_seqlens_k, rng_state = ctx.saved_tensors
        dq = torch.empty_like(q)
        # dK and dV share one buffer: a size-2 axis is inserted in front of the last two
        # dimensions of k, and the helper is handed one view per slot.
        packed_shape = k.shape[:-2] + (2, *k.shape[-2:])
        dkv = torch.empty(packed_shape, dtype=k.dtype, device=k.device)
        dk, dv = dkv[:, 0], dkv[:, 1]
        # The helper's return value is ignored; dq and dkv are what get returned below.
        _flash_attn_varlen_backward(
            dout,
            q,
            k,
            v,
            out,
            softmax_lse,
            dq,
            dk,
            dv,
            cu_seqlens_q,
            cu_seqlens_k,
            ctx.max_seqlen_q,
            ctx.max_seqlen_k,
            ctx.dropout_p,
            ctx.softmax_scale,
            ctx.causal,
            rng_state=rng_state,
        )
        # We could have padded the head dimension
        head_dim = dout.shape[-1]
        dq = dq[..., :head_dim]
        dkv = dkv[..., :head_dim]
        return dq, dkv, None, None, None, None, None, None, None, None


class FlashAttnFunc(torch.autograd.Function):
    """Autograd function for attention with separate Q, K and V tensors."""

    @staticmethod
    def forward(ctx, q, k, v, dropout_p, softmax_scale, causal, return_softmax):
        """Run the forward kernel and stash what backward needs on ``ctx``.

        ``softmax_scale`` defaults to ``q.shape[-1] ** -0.5``. The kernel is only asked
        for the softmax matrix when ``return_softmax`` is set and ``dropout_p > 0``.
        Returns ``out``, or ``(out, softmax_lse, S_dmask)`` when ``return_softmax`` is truthy.
        """
        if softmax_scale is None:
            softmax_scale = q.shape[-1] ** (-0.5)
        out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = _flash_attn_forward(
            q,
            k,
            v,
            dropout_p,
            softmax_scale,
            causal=causal,
            return_softmax=return_softmax and dropout_p > 0,
        )
        ctx.save_for_backward(q, k, v, out_padded, softmax_lse, rng_state)
        ctx.dropout_p = dropout_p
        ctx.softmax_scale = softmax_scale
        ctx.causal = causal
        return out if not return_softmax else (out, softmax_lse, S_dmask)

    @staticmethod
    def backward(ctx, dout, *args):
        """Compute gradients w.r.t. ``q``, ``k`` and ``v``; the other four inputs get ``None``."""
        q, k, v, out, softmax_lse, rng_state = ctx.saved_tensors
        dq, dk, dv = torch.empty_like(q), torch.empty_like(k), torch.empty_like(v)
        # The helper's return value is ignored; dq, dk, dv are what get returned below.
        _flash_attn_backward(
            dout,
            q,
            k,
            v,
            out,
            softmax_lse,
            dq,
            dk,
            dv,
            ctx.dropout_p,
            ctx.softmax_scale,
            ctx.causal,
            rng_state=rng_state,
        )
        dq = dq[..., : dout.shape[-1]]  # We could have padded the head dimension
        dk = dk[..., : dout.shape[-1]]
        dv = dv[..., : dout.shape[-1]]
        # forward() takes exactly seven inputs after ctx (q, k, v, dropout_p, softmax_scale,
        # causal, return_softmax), so return exactly seven gradients -- one per input.
        return dq, dk, dv, None, None, None, None
Tri Dao's avatar
Tri Dao committed
416
417


Tri Dao's avatar
Tri Dao committed
418
class FlashAttnVarlenFunc(torch.autograd.Function):
    """Autograd function for variable-length attention with separate Q, K and V tensors."""

    @staticmethod
    def forward(
        ctx,
        q,
        k,
        v,
        cu_seqlens_q,
        cu_seqlens_k,
        max_seqlen_q,
        max_seqlen_k,
        dropout_p,
        softmax_scale,
        causal,
        return_softmax,
    ):
        """Run the varlen forward kernel and stash what backward needs on ``ctx``.

        ``softmax_scale`` defaults to ``q.shape[-1] ** -0.5``. The kernel is only asked
        for the softmax matrix when ``return_softmax`` is set and ``dropout_p > 0``.
        Returns ``out``, or ``(out, softmax_lse, S_dmask)`` when ``return_softmax`` is truthy.
        """
        if softmax_scale is None:
            softmax_scale = q.shape[-1] ** (-0.5)
        want_softmax = return_softmax and dropout_p > 0
        out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = _flash_attn_varlen_forward(
            q,
            k,
            v,
            cu_seqlens_q,
            cu_seqlens_k,
            max_seqlen_q,
            max_seqlen_k,
            dropout_p,
            softmax_scale,
            causal=causal,
            return_softmax=want_softmax,
        )
        ctx.save_for_backward(
            q, k, v, out_padded, softmax_lse, cu_seqlens_q, cu_seqlens_k, rng_state
        )
        ctx.dropout_p = dropout_p
        ctx.max_seqlen_q = max_seqlen_q
        ctx.max_seqlen_k = max_seqlen_k
        ctx.softmax_scale = softmax_scale
        ctx.causal = causal
        if return_softmax:
            return out, softmax_lse, S_dmask
        return out

    @staticmethod
    def backward(ctx, dout, *args):
        """Compute gradients w.r.t. ``q``, ``k`` and ``v``; the other eight inputs get ``None``."""
        q, k, v, out, softmax_lse, cu_seqlens_q, cu_seqlens_k, rng_state = ctx.saved_tensors
        dq = torch.empty_like(q)
        dk = torch.empty_like(k)
        dv = torch.empty_like(v)
        # The helper's return value is ignored; dq, dk, dv are what get returned below.
        _flash_attn_varlen_backward(
            dout,
            q,
            k,
            v,
            out,
            softmax_lse,
            dq,
            dk,
            dv,
            cu_seqlens_q,
            cu_seqlens_k,
            ctx.max_seqlen_q,
            ctx.max_seqlen_k,
            ctx.dropout_p,
            ctx.softmax_scale,
            ctx.causal,
            rng_state=rng_state,
        )
        # We could have padded the head dimension
        head_dim = dout.shape[-1]
        dq = dq[..., :head_dim]
        dk = dk[..., :head_dim]
        dv = dv[..., :head_dim]
        return dq, dk, dv, None, None, None, None, None, None, None, None
486
487


Tri Dao's avatar
Tri Dao committed
488
489
490
def flash_attn_qkvpacked_func(
    qkv, dropout_p=0.0, softmax_scale=None, causal=False, return_attn_probs=False
):
    """Compute attention from a single packed QKV tensor.

    dropout_p should be set to 0.0 during evaluation.
    When Q, K, V are already stacked into one tensor, this is faster than calling
    flash_attn_func on Q, K, V, because the backward pass avoids an explicit concatenation
    of the gradients of Q, K, V.
    For multi-query and grouped-query attention (MQA/GQA), see flash_attn_kvpacked_func
    and flash_attn_func.

    Arguments:
        qkv: (batch_size, seqlen, 3, nheads, headdim)
        dropout_p: float. Dropout probability.
        softmax_scale: float. The scaling of QK^T before applying softmax.
            Defaults to 1 / sqrt(headdim).
        causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
        return_attn_probs: bool. Whether to also return the attention probabilities. This option
            is for testing only. The returned probabilities are not guaranteed to be correct
            (they might not have the right scaling).
    Return:
        out: (batch_size, seqlen, nheads, headdim).
        softmax_lse [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen). The
            logsumexp of each row of the matrix QK^T * scaling (i.e., the log of the softmax
            normalization factor).
        S_dmask [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen, seqlen).
            The output of softmax (possibly with different scaling). It also encodes the dropout
            pattern (negative means that location was dropped, nonnegative means it was kept).
    """
    apply_args = (qkv, dropout_p, softmax_scale, causal, return_attn_probs)
    return FlashAttnQKVPackedFunc.apply(*apply_args)
Tri Dao's avatar
Tri Dao committed
517
518


Tri Dao's avatar
Tri Dao committed
519
520
521
def flash_attn_kvpacked_func(
    q, kv, dropout_p=0.0, softmax_scale=None, causal=False, return_attn_probs=False
):
    """Compute attention from a query tensor and a packed KV tensor.

    dropout_p should be set to 0.0 during evaluation.
    When K, V are already stacked into one tensor, this is faster than calling
    flash_attn_func on Q, K, V, because the backward pass avoids an explicit concatenation
    of the gradients of K, V.
    Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer
    heads than Q. Note that the number of heads in Q must be divisible by the number of
    heads in KV. For example, if Q has 6 heads and K, V have 2 heads, heads 0, 1, 2 of Q
    will attend to head 0 of K, V, and heads 3, 4, 5 of Q will attend to head 1 of K, V.

    Arguments:
        q: (batch_size, seqlen, nheads, headdim)
        kv: (batch_size, seqlen, 2, nheads_k, headdim)
        dropout_p: float. Dropout probability.
        softmax_scale: float. The scaling of QK^T before applying softmax.
            Defaults to 1 / sqrt(headdim).
        causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
        return_attn_probs: bool. Whether to also return the attention probabilities. This option
            is for testing only. The returned probabilities are not guaranteed to be correct
            (they might not have the right scaling).
    Return:
        out: (batch_size, seqlen, nheads, headdim).
        softmax_lse [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen). The
            logsumexp of each row of the matrix QK^T * scaling (i.e., the log of the softmax
            normalization factor).
        S_dmask [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen, seqlen).
            The output of softmax (possibly with different scaling). It also encodes the dropout
            pattern (negative means that location was dropped, nonnegative means it was kept).
    """
    apply_args = (q, kv, dropout_p, softmax_scale, causal, return_attn_probs)
    return FlashAttnKVPackedFunc.apply(*apply_args)


Tri Dao's avatar
Tri Dao committed
553
554
555
def flash_attn_func(
    q, k, v, dropout_p=0.0, softmax_scale=None, causal=False, return_attn_probs=False
):
    """Compute attention from separate Q, K and V tensors.

    dropout_p should be set to 0.0 during evaluation.
    Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer
    heads than Q. Note that the number of heads in Q must be divisible by the number of
    heads in KV. For example, if Q has 6 heads and K, V have 2 heads, heads 0, 1, 2 of Q
    will attend to head 0 of K, V, and heads 3, 4, 5 of Q will attend to head 1 of K, V.

    Arguments:
        q: (batch_size, seqlen, nheads, headdim)
        k: (batch_size, seqlen, nheads_k, headdim)
        v: (batch_size, seqlen, nheads_k, headdim)
        dropout_p: float. Dropout probability.
        softmax_scale: float. The scaling of QK^T before applying softmax.
            Defaults to 1 / sqrt(headdim).
        causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
        return_attn_probs: bool. Whether to also return the attention probabilities. This option
            is for testing only. The returned probabilities are not guaranteed to be correct
            (they might not have the right scaling).
    Return:
        out: (batch_size, seqlen, nheads, headdim).
        softmax_lse [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen). The
            logsumexp of each row of the matrix QK^T * scaling (i.e., the log of the softmax
            normalization factor).
        S_dmask [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen, seqlen).
            The output of softmax (possibly with different scaling). It also encodes the dropout
            pattern (negative means that location was dropped, nonnegative means it was kept).
    """
    apply_args = (q, k, v, dropout_p, softmax_scale, causal, return_attn_probs)
    return FlashAttnFunc.apply(*apply_args)


Tri Dao's avatar
Tri Dao committed
585
586
587
588
589
590
591
592
593
def flash_attn_varlen_qkvpacked_func(
    qkv,
    cu_seqlens,
    max_seqlen,
    dropout_p=0.0,
    softmax_scale=None,
    causal=False,
    return_attn_probs=False,
):
    """Compute variable-length attention from a single packed QKV tensor.

    dropout_p should be set to 0.0 during evaluation.
    When Q, K, V are already stacked into one tensor, this is faster than calling
    flash_attn_varlen_func on Q, K, V, because the backward pass avoids an explicit
    concatenation of the gradients of Q, K, V.
    For multi-query and grouped-query attention (MQA/GQA), see
    flash_attn_varlen_kvpacked_func and flash_attn_varlen_func.

    Arguments:
        qkv: (total, 3, nheads, headdim), where total = total number of tokens in the batch.
        cu_seqlens: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
            of the sequences in the batch, used to index into qkv.
        max_seqlen: int. Maximum sequence length in the batch.
        dropout_p: float. Dropout probability.
        softmax_scale: float. The scaling of QK^T before applying softmax.
            Defaults to 1 / sqrt(headdim).
        causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
        return_attn_probs: bool. Whether to also return the attention probabilities. This option
            is for testing only. The returned probabilities are not guaranteed to be correct
            (they might not have the right scaling).
    Return:
        out: (total, nheads, headdim).
        softmax_lse [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen). The
            logsumexp of each row of the matrix QK^T * scaling (i.e., the log of the softmax
            normalization factor).
        S_dmask [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen, seqlen).
            The output of softmax (possibly with different scaling). It also encodes the dropout
            pattern (negative means that location was dropped, nonnegative means it was kept).
    """
    apply_args = (
        qkv,
        cu_seqlens,
        max_seqlen,
        dropout_p,
        softmax_scale,
        causal,
        return_attn_probs,
    )
    return FlashAttnVarlenQKVPackedFunc.apply(*apply_args)
Tri Dao's avatar
Tri Dao committed
625
626


Tri Dao's avatar
Tri Dao committed
627
628
629
630
631
632
633
634
635
636
637
638
def flash_attn_varlen_kvpacked_func(
    q,
    kv,
    cu_seqlens_q,
    cu_seqlens_k,
    max_seqlen_q,
    max_seqlen_k,
    dropout_p=0.0,
    softmax_scale=None,
    causal=False,
    return_attn_probs=False,
):
    """Compute variable-length attention from a query tensor and a packed KV tensor.

    dropout_p should be set to 0.0 during evaluation.
    When K, V are already stacked into one tensor, this is faster than calling
    flash_attn_varlen_func on Q, K, V, because the backward pass avoids an explicit
    concatenation of the gradients of K, V.
    Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer
    heads than Q. Note that the number of heads in Q must be divisible by the number of
    heads in KV. For example, if Q has 6 heads and K, V have 2 heads, heads 0, 1, 2 of Q
    will attend to head 0 of K, V, and heads 3, 4, 5 of Q will attend to head 1 of K, V.

    Arguments:
        q: (total_q, nheads, headdim), where total_q = total number of query tokens in the batch.
        kv: (total_k, 2, nheads_k, headdim), where total_k = total number of key tokens in the
            batch.
        cu_seqlens_q: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
            of the sequences in the batch, used to index into q.
        cu_seqlens_k: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
            of the sequences in the batch, used to index into kv.
        max_seqlen_q: int. Maximum query sequence length in the batch.
        max_seqlen_k: int. Maximum key sequence length in the batch.
        dropout_p: float. Dropout probability.
        softmax_scale: float. The scaling of QK^T before applying softmax.
            Defaults to 1 / sqrt(headdim).
        causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
        return_attn_probs: bool. Whether to also return the attention probabilities. This option
            is for testing only. The returned probabilities are not guaranteed to be correct
            (they might not have the right scaling).
    Return:
        out: (total, nheads, headdim).
        softmax_lse [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen). The
            logsumexp of each row of the matrix QK^T * scaling (i.e., the log of the softmax
            normalization factor).
        S_dmask [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen, seqlen).
            The output of softmax (possibly with different scaling). It also encodes the dropout
            pattern (negative means that location was dropped, nonnegative means it was kept).
    """
    apply_args = (
        q,
        kv,
        cu_seqlens_q,
        cu_seqlens_k,
        max_seqlen_q,
        max_seqlen_k,
        dropout_p,
        softmax_scale,
        causal,
        return_attn_probs,
    )
    return FlashAttnVarlenKVPackedFunc.apply(*apply_args)
Tri Dao's avatar
Tri Dao committed
685

686

Tri Dao's avatar
Tri Dao committed
687
688
689
690
691
692
693
694
695
696
697
698
699
def flash_attn_varlen_func(
    q,
    k,
    v,
    cu_seqlens_q,
    cu_seqlens_k,
    max_seqlen_q,
    max_seqlen_k,
    dropout_p=0.0,
    softmax_scale=None,
    causal=False,
    return_attn_probs=False,
):
    """Variable-length attention over packed sequences, with separate Q, K and V tensors.

    dropout_p should be set to 0.0 during evaluation.

    Supports multi-query and grouped-query attention (MQA/GQA) by passing in K, V with fewer
    heads than Q. Note that the number of heads in Q must be divisible by the number of heads
    in KV. For example, if Q has 6 heads and K, V have 2 heads, heads 0, 1, 2 of Q will attend
    to head 0 of K, V, and heads 3, 4, 5 of Q will attend to head 1 of K, V.

    Arguments:
        q: (total_q, nheads, headdim), where total_q = total number of query tokens in the batch.
        k: (total_k, nheads_k, headdim), where total_k = total number of key tokens in the batch.
        v: (total_k, nheads_k, headdim), where total_k = total number of key tokens in the batch.
        cu_seqlens_q: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
           of the sequences in the batch, used to index into q.
        cu_seqlens_k: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
           of the sequences in the batch, used to index into k and v.
        max_seqlen_q: int. Maximum query sequence length in the batch.
        max_seqlen_k: int. Maximum key sequence length in the batch.
        dropout_p: float. Dropout probability.
        softmax_scale: float. The scaling of QK^T before applying softmax.
            Default to 1 / sqrt(headdim).
        causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
        return_attn_probs: bool. Whether to return the attention probabilities. This option is for
           testing only. The returned probabilities are not guaranteed to be correct
           (they might not have the right scaling).
    Return:
        out: (total, nheads, headdim).
        softmax_lse [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen). The
            logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
            normalization factor).
        S_dmask [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen, seqlen).
            The output of softmax (possibly with different scaling). It also encodes the dropout
            pattern (negative means that location was dropped, nonnegative means it was kept).
    """
    # Everything is forwarded positionally, in the same order as this function's signature.
    return FlashAttnVarlenFunc.apply(
        q,
        k,
        v,
        cu_seqlens_q,
        cu_seqlens_k,
        max_seqlen_q,
        max_seqlen_k,
        dropout_p,
        softmax_scale,
        causal,
        return_attn_probs,
    )