flash_attn_interface.py 46.2 KB
Newer Older
1
2
# Copyright (c) 2023, Tri Dao.

Tri Dao's avatar
Tri Dao committed
3
4
from typing import Optional, Union

Tri Dao's avatar
Tri Dao committed
5
6
import torch
import torch.nn as nn
Tri Dao's avatar
Tri Dao committed
7

8
9
10
# isort: off
# We need to import the CUDA kernels after importing torch
import flash_attn_2_cuda as flash_attn_cuda
Tri Dao's avatar
Tri Dao committed
11

12
13
# isort: on

youkaichao's avatar
youkaichao committed
14
15
def maybe_contiguous(x):
    """Return ``x`` with a unit stride in its last dimension.

    ``None`` and tensors whose last dimension already has stride 1 are returned
    unchanged (same object); anything else is replaced by ``x.contiguous()``.
    """
    if x is None or x.stride(-1) == 1:
        return x
    return x.contiguous()
Tri Dao's avatar
Tri Dao committed
16

Tri Dao's avatar
Tri Dao committed
17
def _get_block_size_n(device, head_dim, is_dropout, is_causal):
Tri Dao's avatar
Tri Dao committed
18
19
20
21
22
23
24
    # This should match the block sizes in the CUDA kernel
    assert head_dim <= 256
    major, minor = torch.cuda.get_device_capability(device)
    is_sm8x = major == 8 and minor > 0  # Only include sm86 and sm89, exclude sm80 (A100)
    is_sm80 = major == 8 and minor == 0
    is_sm90 = major == 9 and minor == 0
    if head_dim <= 32:
Tri Dao's avatar
Tri Dao committed
25
        return 128
Tri Dao's avatar
Tri Dao committed
26
    if head_dim <= 64:
Tri Dao's avatar
Tri Dao committed
27
        return 128 if not is_dropout else 64
Tri Dao's avatar
Tri Dao committed
28
    elif head_dim <= 96:
Tri Dao's avatar
Tri Dao committed
29
        return 64
Tri Dao's avatar
Tri Dao committed
30
31
    elif head_dim <= 128:
        if is_sm8x:
Tri Dao's avatar
Tri Dao committed
32
            return 64 if (not is_dropout and is_causal) else 32
Tri Dao's avatar
Tri Dao committed
33
        else:
Tri Dao's avatar
Tri Dao committed
34
            return 64 if not is_dropout else 32
Tri Dao's avatar
Tri Dao committed
35
36
    elif head_dim <= 160:
        if is_sm8x:
Tri Dao's avatar
Tri Dao committed
37
            return 64
Tri Dao's avatar
Tri Dao committed
38
        else:
Tri Dao's avatar
Tri Dao committed
39
            return 32
Tri Dao's avatar
Tri Dao committed
40
    elif head_dim <= 192:
Tri Dao's avatar
Tri Dao committed
41
        return 64
Tri Dao's avatar
Tri Dao committed
42
    elif head_dim <= 224:
Tri Dao's avatar
Tri Dao committed
43
        return 64
Tri Dao's avatar
Tri Dao committed
44
    elif head_dim <= 256:
Tri Dao's avatar
Tri Dao committed
45
        return 64
Tri Dao's avatar
Tri Dao committed
46
47


Tri Dao's avatar
Tri Dao committed
48
def _flash_attn_forward(
    q, k, v, dropout_p, softmax_scale, causal, window_size, softcap, alibi_slopes, return_softmax
):
    """Run the fixed-length FlashAttention forward kernel (``flash_attn_cuda.fwd``).

    q, k and v are first given a unit last-dimension stride via ``maybe_contiguous``;
    all arguments are then passed to the extension positionally.

    Returns:
        The 8-tuple ``(out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state)``
        unpacked from the kernel result. Note that q, k, v are re-bound to the
        tensors the kernel returns, not the ones passed in.
    """
    q, k, v = map(maybe_contiguous, (q, k, v))
    window_left, window_right = window_size[0], window_size[1]
    results = flash_attn_cuda.fwd(
        q,
        k,
        v,
        None,  # NOTE(review): presumably the optional pre-allocated output buffer
        alibi_slopes,
        dropout_p,
        softmax_scale,
        causal,
        window_left,
        window_right,
        softcap,
        return_softmax,
        None,  # NOTE(review): presumably the optional RNG generator — confirm
    )
    out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = results
    return out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state
Tri Dao's avatar
Tri Dao committed
68
69


Tri Dao's avatar
Tri Dao committed
70
71
72
73
74
75
76
77
78
79
80
def _flash_attn_varlen_forward(
    q,
    k,
    v,
    cu_seqlens_q,
    cu_seqlens_k,
    max_seqlen_q,
    max_seqlen_k,
    dropout_p,
    softmax_scale,
    causal,
    window_size=(-1, -1),
    softcap=0.0,
    alibi_slopes=None,
    return_softmax=False,
    block_table=None,
    leftpad_k=None,
    seqused_k=None,
):
    """Run the variable-length FlashAttention forward kernel (``flash_attn_cuda.varlen_fwd``).

    q, k and v are first given a unit last-dimension stride via ``maybe_contiguous``;
    all arguments are then passed to the extension positionally. The optional
    ``seqused_k``, ``leftpad_k`` and ``block_table`` are forwarded as given
    (``None`` by default).

    Returns:
        The 8-tuple ``(out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state)``
        unpacked from the kernel result (q, k, v re-bound to the kernel's tensors).
    """
    q, k, v = map(maybe_contiguous, (q, k, v))
    window_left, window_right = window_size[0], window_size[1]
    results = flash_attn_cuda.varlen_fwd(
        q,
        k,
        v,
        None,  # NOTE(review): presumably the optional pre-allocated output buffer
        cu_seqlens_q,
        cu_seqlens_k,
        seqused_k,
        leftpad_k,
        block_table,
        alibi_slopes,
        max_seqlen_q,
        max_seqlen_k,
        dropout_p,
        softmax_scale,
        False,  # NOTE(review): flag fixed to False here; presumably zero_tensors — confirm
        causal,
        window_left,
        window_right,
        softcap,
        return_softmax,
        None,  # NOTE(review): presumably the optional RNG generator — confirm
    )
    out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = results
    return out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state
Tri Dao's avatar
Tri Dao committed
116
117


Tri Dao's avatar
Tri Dao committed
118
def _flash_attn_backward(
    dout,
    q,
    k,
    v,
    out,
    softmax_lse,
    dq,
    dk,
    dv,
    dropout_p,
    softmax_scale,
    causal,
    window_size,
    softcap,
    alibi_slopes,
    deterministic,
    rng_state=None,
):
    """Run the fixed-length FlashAttention backward kernel (``flash_attn_cuda.bwd``).

    ``dq``, ``dk`` and ``dv`` are caller-provided output buffers passed straight to
    the kernel; the autograd callers in this file ignore this function's return
    value and read the gradients from those buffers.

    Returns:
        ``(dq, dk, dv, softmax_d)`` unpacked from the kernel result.
    """
    # dq, dk, dv are allocated by us so they should already be contiguous
    dout, q, k, v, out = map(maybe_contiguous, (dout, q, k, v, out))
    grads = flash_attn_cuda.bwd(
        dout,
        q,
        k,
        v,
        out,
        softmax_lse,
        dq,
        dk,
        dv,
        alibi_slopes,
        dropout_p,
        softmax_scale,
        causal,
        window_size[0],
        window_size[1],
        softcap,
        deterministic,
        None,  # NOTE(review): presumably the optional RNG generator — confirm
        rng_state,
    )
    dq, dk, dv, softmax_d = grads
    return dq, dk, dv, softmax_d


Tri Dao's avatar
Tri Dao committed
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
def _flash_attn_varlen_backward(
    dout,
    q,
    k,
    v,
    out,
    softmax_lse,
    dq,
    dk,
    dv,
    cu_seqlens_q,
    cu_seqlens_k,
    max_seqlen_q,
    max_seqlen_k,
    dropout_p,
    softmax_scale,
    causal,
    window_size,
    softcap,
    alibi_slopes,
    deterministic,
    rng_state=None,
):
    """Run the variable-length FlashAttention backward kernel (``flash_attn_cuda.varlen_bwd``).

    ``dq``, ``dk`` and ``dv`` are caller-provided output buffers passed straight to
    the kernel; the autograd callers in this file ignore this function's return
    value and read the gradients from those buffers.

    Returns:
        ``(dq, dk, dv, softmax_d)`` unpacked from the kernel result.
    """
    # dq, dk, dv are allocated by us so they should already be contiguous
    dout, q, k, v, out = map(maybe_contiguous, (dout, q, k, v, out))
    grads = flash_attn_cuda.varlen_bwd(
        dout,
        q,
        k,
        v,
        out,
        softmax_lse,
        dq,
        dk,
        dv,
        cu_seqlens_q,
        cu_seqlens_k,
        alibi_slopes,
        max_seqlen_q,
        max_seqlen_k,
        dropout_p,
        softmax_scale,
        False,  # NOTE(review): flag fixed to False here; presumably zero_tensors — confirm
        causal,
        window_size[0],
        window_size[1],
        softcap,
        deterministic,
        None,  # NOTE(review): presumably the optional RNG generator — confirm
        rng_state,
    )
    dq, dk, dv, softmax_d = grads
    return dq, dk, dv, softmax_d
Tri Dao's avatar
Tri Dao committed
227
228


Tri Dao's avatar
Tri Dao committed
229
class FlashAttnQKVPackedFunc(torch.autograd.Function):
    """Autograd wrapper for FlashAttention on a packed QKV tensor.

    ``forward`` takes Q, K and V from index 0/1/2 of dimension 2 of ``qkv``
    (``qkv[:, :, i]``). ``backward`` has the kernel write dq/dk/dv into slices of
    a single buffer with the same packed layout, so no concatenation is needed.
    """

    @staticmethod
    def forward(
        ctx,
        qkv,
        dropout_p,
        softmax_scale,
        causal,
        window_size,
        softcap,
        alibi_slopes,
        deterministic,
        return_softmax,
    ):
        """Run the forward kernel and stash what ``backward`` needs on ``ctx``.

        ``softmax_scale=None`` defaults to ``qkv.shape[-1] ** -0.5``.
        Returns ``out``, or ``(out, softmax_lse, S_dmask)`` when ``return_softmax``.
        """
        scale = qkv.shape[-1] ** (-0.5) if softmax_scale is None else softmax_scale
        q, k, v = qkv[:, :, 0], qkv[:, :, 1], qkv[:, :, 2]
        out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = _flash_attn_forward(
            q,
            k,
            v,
            dropout_p,
            scale,
            causal=causal,
            window_size=window_size,
            softcap=softcap,
            alibi_slopes=alibi_slopes,
            # The probabilities are only requested from the kernel when dropout is on.
            return_softmax=return_softmax and dropout_p > 0,
        )
        # q, k, v are now the tensors returned by the kernel, not the input slices.
        ctx.save_for_backward(q, k, v, out_padded, softmax_lse, rng_state)
        ctx.dropout_p = dropout_p
        ctx.softmax_scale = scale
        ctx.causal = causal
        ctx.window_size = window_size
        ctx.softcap = softcap
        ctx.alibi_slopes = alibi_slopes
        ctx.deterministic = deterministic
        if return_softmax:
            return out, softmax_lse, S_dmask
        return out

    @staticmethod
    def backward(ctx, dout, *_ignored_grads):
        """Return the gradient w.r.t. ``qkv`` and ``None`` for every other input.

        Gradients for the optional ``softmax_lse`` / ``S_dmask`` outputs are ignored.
        """
        q, k, v, out, softmax_lse, rng_state = ctx.saved_tensors
        packed_shape = (*q.shape[:-2], 3, *q.shape[-2:])
        dqkv = torch.empty(packed_shape, dtype=q.dtype, device=q.device)
        # The return value is discarded: gradients are read back from the slices
        # of ``dqkv`` handed to the kernel as dq / dk / dv.
        _flash_attn_backward(
            dout,
            q,
            k,
            v,
            out,
            softmax_lse,
            dqkv[:, :, 0],
            dqkv[:, :, 1],
            dqkv[:, :, 2],
            ctx.dropout_p,
            ctx.softmax_scale,
            ctx.causal,
            ctx.window_size,
            ctx.softcap,
            ctx.alibi_slopes,
            ctx.deterministic,
            rng_state=rng_state,
        )
        # We could have padded the head dimension; trim back to dout's size.
        dqkv = dqkv[..., : dout.shape[-1]]
        # One entry per forward input after ctx: qkv + 8 non-differentiable args.
        return (dqkv,) + (None,) * 8
Tri Dao's avatar
Tri Dao committed
293
294
295
296


class FlashAttnVarlenQKVPackedFunc(torch.autograd.Function):
    """Autograd wrapper for variable-length FlashAttention on a packed QKV tensor.

    Q, K and V are taken from index 0/1/2 of dimension 1 of ``qkv`` (``qkv[:, i]``).
    The same ``cu_seqlens`` / ``max_seqlen`` describe both the query and key side.
    """

    @staticmethod
    def forward(
        ctx,
        qkv,
        cu_seqlens,
        max_seqlen,
        dropout_p,
        softmax_scale,
        causal,
        window_size,
        softcap,
        alibi_slopes,
        deterministic,
        return_softmax,
    ):
        """Run the varlen forward kernel and stash what ``backward`` needs on ``ctx``.

        ``softmax_scale=None`` defaults to ``qkv.shape[-1] ** -0.5``.
        Returns ``out``, or ``(out, softmax_lse, S_dmask)`` when ``return_softmax``.
        """
        scale = qkv.shape[-1] ** (-0.5) if softmax_scale is None else softmax_scale
        q, k, v = qkv[:, 0], qkv[:, 1], qkv[:, 2]
        out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = _flash_attn_varlen_forward(
            q,
            k,
            v,
            cu_seqlens,
            cu_seqlens,
            max_seqlen,
            max_seqlen,
            dropout_p,
            scale,
            causal=causal,
            window_size=window_size,
            softcap=softcap,
            alibi_slopes=alibi_slopes,
            # The probabilities are only requested from the kernel when dropout is on.
            return_softmax=return_softmax and dropout_p > 0,
            block_table=None,
        )
        ctx.save_for_backward(q, k, v, out_padded, softmax_lse, cu_seqlens, rng_state)
        ctx.dropout_p = dropout_p
        ctx.max_seqlen = max_seqlen
        ctx.softmax_scale = scale
        ctx.causal = causal
        ctx.window_size = window_size
        ctx.softcap = softcap
        ctx.alibi_slopes = alibi_slopes
        ctx.deterministic = deterministic
        if return_softmax:
            return out, softmax_lse, S_dmask
        return out

    @staticmethod
    def backward(ctx, dout, *_ignored_grads):
        """Return the gradient w.r.t. ``qkv`` and ``None`` for every other input."""
        q, k, v, out, softmax_lse, cu_seqlens, rng_state = ctx.saved_tensors
        packed_shape = (*q.shape[:-2], 3, *q.shape[-2:])
        dqkv = torch.empty(packed_shape, dtype=q.dtype, device=q.device)
        # Gradients are read back from the slices of ``dqkv`` given to the kernel.
        _flash_attn_varlen_backward(
            dout,
            q,
            k,
            v,
            out,
            softmax_lse,
            dqkv[:, 0],
            dqkv[:, 1],
            dqkv[:, 2],
            cu_seqlens,
            cu_seqlens,
            ctx.max_seqlen,
            ctx.max_seqlen,
            ctx.dropout_p,
            ctx.softmax_scale,
            ctx.causal,
            ctx.window_size,
            ctx.softcap,
            ctx.alibi_slopes,
            ctx.deterministic,
            rng_state=rng_state,
        )
        # We could have padded the head dimension; trim back to dout's size.
        dqkv = dqkv[..., : dout.shape[-1]]
        # qkv + 10 non-differentiable forward inputs.
        return (dqkv,) + (None,) * 10
Tri Dao's avatar
Tri Dao committed
371
372


Tri Dao's avatar
Tri Dao committed
373
class FlashAttnKVPackedFunc(torch.autograd.Function):
    """Autograd wrapper for FlashAttention with a separate Q and a packed KV tensor.

    K and V are taken from index 0/1 of dimension 2 of ``kv`` (``kv[:, :, i]``);
    ``backward`` has the kernel write dk/dv into slices of one packed buffer.
    """

    @staticmethod
    def forward(
        ctx,
        q,
        kv,
        dropout_p,
        softmax_scale,
        causal,
        window_size,
        softcap,
        alibi_slopes,
        deterministic,
        return_softmax,
    ):
        """Run the forward kernel and stash what ``backward`` needs on ``ctx``.

        ``softmax_scale=None`` defaults to ``q.shape[-1] ** -0.5``.
        Returns ``out``, or ``(out, softmax_lse, S_dmask)`` when ``return_softmax``.
        """
        scale = q.shape[-1] ** (-0.5) if softmax_scale is None else softmax_scale
        k, v = kv[:, :, 0], kv[:, :, 1]
        out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = _flash_attn_forward(
            q,
            k,
            v,
            dropout_p,
            scale,
            causal=causal,
            window_size=window_size,
            softcap=softcap,
            alibi_slopes=alibi_slopes,
            # The probabilities are only requested from the kernel when dropout is on.
            return_softmax=return_softmax and dropout_p > 0,
        )
        ctx.save_for_backward(q, k, v, out_padded, softmax_lse, rng_state)
        ctx.dropout_p = dropout_p
        ctx.softmax_scale = scale
        ctx.causal = causal
        ctx.window_size = window_size
        ctx.softcap = softcap
        ctx.alibi_slopes = alibi_slopes
        ctx.deterministic = deterministic
        if return_softmax:
            return out, softmax_lse, S_dmask
        return out

    @staticmethod
    def backward(ctx, dout, *_ignored_grads):
        """Return gradients w.r.t. ``q`` and ``kv`` and ``None`` for every other input."""
        q, k, v, out, softmax_lse, rng_state = ctx.saved_tensors
        dq = torch.empty_like(q)
        packed_shape = (*k.shape[:-2], 2, *k.shape[-2:])
        dkv = torch.empty(packed_shape, dtype=k.dtype, device=k.device)
        # Gradients are read back from ``dq`` and the slices of ``dkv``.
        _flash_attn_backward(
            dout,
            q,
            k,
            v,
            out,
            softmax_lse,
            dq,
            dkv[:, :, 0],
            dkv[:, :, 1],
            ctx.dropout_p,
            ctx.softmax_scale,
            ctx.causal,
            ctx.window_size,
            ctx.softcap,
            ctx.alibi_slopes,
            ctx.deterministic,
            rng_state=rng_state,
        )
        # We could have padded the head dimension; trim back to dout's size.
        head_dim = dout.shape[-1]
        dq = dq[..., :head_dim]
        dkv = dkv[..., :head_dim]
        # q, kv + 8 non-differentiable forward inputs.
        return (dq, dkv) + (None,) * 8
Tri Dao's avatar
Tri Dao committed
440
441


Tri Dao's avatar
Tri Dao committed
442
class FlashAttnVarlenKVPackedFunc(torch.autograd.Function):
    """Autograd wrapper for variable-length FlashAttention with a packed KV tensor.

    K and V are taken from index 0/1 of dimension 1 of ``kv`` (``kv[:, i]``).
    Query and key sides have their own ``cu_seqlens_*`` / ``max_seqlen_*``.
    """

    @staticmethod
    def forward(
        ctx,
        q,
        kv,
        cu_seqlens_q,
        cu_seqlens_k,
        max_seqlen_q,
        max_seqlen_k,
        dropout_p,
        softmax_scale,
        causal,
        window_size,
        softcap,
        alibi_slopes,
        deterministic,
        return_softmax,
    ):
        """Run the varlen forward kernel and stash what ``backward`` needs on ``ctx``.

        ``softmax_scale=None`` defaults to ``q.shape[-1] ** -0.5``.
        Returns ``out``, or ``(out, softmax_lse, S_dmask)`` when ``return_softmax``.
        """
        scale = q.shape[-1] ** (-0.5) if softmax_scale is None else softmax_scale
        k, v = kv[:, 0], kv[:, 1]
        out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = _flash_attn_varlen_forward(
            q,
            k,
            v,
            cu_seqlens_q,
            cu_seqlens_k,
            max_seqlen_q,
            max_seqlen_k,
            dropout_p,
            scale,
            causal=causal,
            window_size=window_size,
            softcap=softcap,
            alibi_slopes=alibi_slopes,
            # The probabilities are only requested from the kernel when dropout is on.
            return_softmax=return_softmax and dropout_p > 0,
            block_table=None,
        )
        ctx.save_for_backward(
            q, k, v, out_padded, softmax_lse, cu_seqlens_q, cu_seqlens_k, rng_state
        )
        ctx.dropout_p = dropout_p
        ctx.max_seqlen_q = max_seqlen_q
        ctx.max_seqlen_k = max_seqlen_k
        ctx.softmax_scale = scale
        ctx.causal = causal
        ctx.window_size = window_size
        ctx.softcap = softcap
        ctx.alibi_slopes = alibi_slopes
        ctx.deterministic = deterministic
        if return_softmax:
            return out, softmax_lse, S_dmask
        return out

    @staticmethod
    def backward(ctx, dout, *_ignored_grads):
        """Return gradients w.r.t. ``q`` and ``kv`` and ``None`` for every other input."""
        q, k, v, out, softmax_lse, cu_seqlens_q, cu_seqlens_k, rng_state = ctx.saved_tensors
        dq = torch.empty_like(q)
        packed_shape = (*k.shape[:-2], 2, *k.shape[-2:])
        dkv = torch.empty(packed_shape, dtype=k.dtype, device=k.device)
        # Gradients are read back from ``dq`` and the slices of ``dkv``.
        _flash_attn_varlen_backward(
            dout,
            q,
            k,
            v,
            out,
            softmax_lse,
            dq,
            dkv[:, 0],
            dkv[:, 1],
            cu_seqlens_q,
            cu_seqlens_k,
            ctx.max_seqlen_q,
            ctx.max_seqlen_k,
            ctx.dropout_p,
            ctx.softmax_scale,
            ctx.causal,
            ctx.window_size,
            ctx.softcap,
            ctx.alibi_slopes,
            ctx.deterministic,
            rng_state=rng_state,
        )
        # We could have padded the head dimension; trim back to dout's size.
        head_dim = dout.shape[-1]
        dq = dq[..., :head_dim]
        dkv = dkv[..., :head_dim]
        # q, kv + 12 non-differentiable forward inputs.
        return (dq, dkv) + (None,) * 12
Tri Dao's avatar
Tri Dao committed
526
527
528
529


class FlashAttnFunc(torch.autograd.Function):
    """Autograd wrapper for fixed-length FlashAttention with separate Q, K, V tensors."""

    @staticmethod
    def forward(
        ctx,
        q,
        k,
        v,
        dropout_p,
        softmax_scale,
        causal,
        window_size,
        softcap,
        alibi_slopes,
        deterministic,
        return_softmax,
    ):
        """Run the forward kernel and stash what ``backward`` needs on ``ctx``.

        ``softmax_scale=None`` defaults to ``q.shape[-1] ** -0.5``.
        Returns ``out``, or ``(out, softmax_lse, S_dmask)`` when ``return_softmax``.
        """
        scale = q.shape[-1] ** (-0.5) if softmax_scale is None else softmax_scale
        out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = _flash_attn_forward(
            q,
            k,
            v,
            dropout_p,
            scale,
            causal=causal,
            window_size=window_size,
            softcap=softcap,
            alibi_slopes=alibi_slopes,
            # The probabilities are only requested from the kernel when dropout is on.
            return_softmax=return_softmax and dropout_p > 0,
        )
        ctx.save_for_backward(q, k, v, out_padded, softmax_lse, rng_state)
        ctx.dropout_p = dropout_p
        ctx.softmax_scale = scale
        ctx.causal = causal
        ctx.window_size = window_size
        ctx.softcap = softcap
        ctx.alibi_slopes = alibi_slopes
        ctx.deterministic = deterministic
        if return_softmax:
            return out, softmax_lse, S_dmask
        return out

    @staticmethod
    def backward(ctx, dout, *_ignored_grads):
        """Return gradients w.r.t. ``q``, ``k``, ``v`` and ``None`` for every other input."""
        q, k, v, out, softmax_lse, rng_state = ctx.saved_tensors
        dq = torch.empty_like(q)
        dk = torch.empty_like(k)
        dv = torch.empty_like(v)
        # Gradients are read back from the dq / dk / dv buffers given to the kernel.
        _flash_attn_backward(
            dout,
            q,
            k,
            v,
            out,
            softmax_lse,
            dq,
            dk,
            dv,
            ctx.dropout_p,
            ctx.softmax_scale,
            ctx.causal,
            ctx.window_size,
            ctx.softcap,
            ctx.alibi_slopes,
            ctx.deterministic,
            rng_state=rng_state,
        )
        # We could have padded the head dimension; trim back to dout's size.
        head_dim = dout.shape[-1]
        trimmed = tuple(g[..., :head_dim] for g in (dq, dk, dv))
        # q, k, v + 8 non-differentiable forward inputs.
        return trimmed + (None,) * 8
Tri Dao's avatar
Tri Dao committed
595
596


Tri Dao's avatar
Tri Dao committed
597
class FlashAttnVarlenFunc(torch.autograd.Function):
    """Autograd wrapper for variable-length FlashAttention with separate Q, K, V tensors.

    ``block_table`` is forwarded to the forward kernel only; ``backward`` does not
    receive it. NOTE(review): confirm backward is valid when a block table is used.
    """

    @staticmethod
    def forward(
        ctx,
        q,
        k,
        v,
        cu_seqlens_q,
        cu_seqlens_k,
        max_seqlen_q,
        max_seqlen_k,
        dropout_p,
        softmax_scale,
        causal,
        window_size,
        softcap,
        alibi_slopes,
        deterministic,
        return_softmax,
        block_table,
    ):
        """Run the varlen forward kernel and stash what ``backward`` needs on ``ctx``.

        ``softmax_scale=None`` defaults to ``q.shape[-1] ** -0.5``.
        Returns ``out``, or ``(out, softmax_lse, S_dmask)`` when ``return_softmax``.
        """
        scale = q.shape[-1] ** (-0.5) if softmax_scale is None else softmax_scale
        out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = _flash_attn_varlen_forward(
            q,
            k,
            v,
            cu_seqlens_q,
            cu_seqlens_k,
            max_seqlen_q,
            max_seqlen_k,
            dropout_p,
            scale,
            causal=causal,
            window_size=window_size,
            softcap=softcap,
            alibi_slopes=alibi_slopes,
            # The probabilities are only requested from the kernel when dropout is on.
            return_softmax=return_softmax and dropout_p > 0,
            block_table=block_table,
        )
        ctx.save_for_backward(
            q, k, v, out_padded, softmax_lse, cu_seqlens_q, cu_seqlens_k, rng_state
        )
        ctx.dropout_p = dropout_p
        ctx.max_seqlen_q = max_seqlen_q
        ctx.max_seqlen_k = max_seqlen_k
        ctx.softmax_scale = scale
        ctx.causal = causal
        ctx.window_size = window_size
        ctx.softcap = softcap
        ctx.alibi_slopes = alibi_slopes
        ctx.deterministic = deterministic
        if return_softmax:
            return out, softmax_lse, S_dmask
        return out

    @staticmethod
    def backward(ctx, dout, *_ignored_grads):
        """Return gradients w.r.t. ``q``, ``k``, ``v`` and ``None`` for every other input."""
        q, k, v, out, softmax_lse, cu_seqlens_q, cu_seqlens_k, rng_state = ctx.saved_tensors
        dq = torch.empty_like(q)
        dk = torch.empty_like(k)
        dv = torch.empty_like(v)
        # Gradients are read back from the dq / dk / dv buffers given to the kernel.
        _flash_attn_varlen_backward(
            dout,
            q,
            k,
            v,
            out,
            softmax_lse,
            dq,
            dk,
            dv,
            cu_seqlens_q,
            cu_seqlens_k,
            ctx.max_seqlen_q,
            ctx.max_seqlen_k,
            ctx.dropout_p,
            ctx.softmax_scale,
            ctx.causal,
            ctx.window_size,
            ctx.softcap,
            ctx.alibi_slopes,
            ctx.deterministic,
            rng_state=rng_state,
        )
        # We could have padded the head dimension; trim back to dout's size.
        head_dim = dout.shape[-1]
        trimmed = tuple(g[..., :head_dim] for g in (dq, dk, dv))
        # q, k, v + 13 non-differentiable forward inputs.
        return trimmed + (None,) * 13
682
683


Tri Dao's avatar
Tri Dao committed
684
def flash_attn_qkvpacked_func(
    qkv,
    dropout_p=0.0,
    softmax_scale=None,
    causal=False,
    window_size=(-1, -1),  # -1 means infinite context window
    softcap=0.0,  # <=0.0 means deactivate
    alibi_slopes=None,
    deterministic=False,
    return_attn_probs=False,
):
    """FlashAttention on a single packed QKV tensor.

    dropout_p should be set to 0.0 during evaluation.

    If Q, K, V are already stacked into 1 tensor, this function will be faster than
    calling flash_attn_func on Q, K, V, because the backward pass avoids explicitly
    concatenating the gradients of Q, K, V.

    For multi-query and grouped-query attention (MQA/GQA), please see
    flash_attn_kvpacked_func and flash_attn_func.

    If window_size != (-1, -1), implements sliding window local attention. Query at
    position i will only attend to keys between [i - window_size[0], i + window_size[1]]
    inclusive.

    Arguments:
        qkv: (batch_size, seqlen, 3, nheads, headdim)
        dropout_p: float. Dropout probability.
        softmax_scale: float. The scaling of QK^T before applying softmax.
            Defaults to 1 / sqrt(headdim).
        causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
        window_size: (left, right). If not (-1, -1), implements sliding window local attention.
        softcap: float. Anything > 0 activates softcapping attention.
        alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of
            (-alibi_slope * |i - j|) is added to the attention score of query i and key j.
        deterministic: bool. Whether to use the deterministic implementation of the
            backward pass, which is slightly slower and uses more memory. The forward pass
            is always deterministic.
        return_attn_probs: bool. Whether to return the attention probabilities. This option
            is for testing only. The returned probabilities are not guaranteed to be correct
            (they might not have the right scaling). The probabilities are only requested
            from the kernel when dropout_p > 0.
    Return:
        out: (batch_size, seqlen, nheads, headdim).
        softmax_lse [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen). The
            logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
            normalization factor).
        S_dmask [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen, seqlen).
            The output of softmax (possibly with different scaling). It also encodes the dropout
            pattern (negative means that location was dropped, nonnegative means it was kept).
    """
    # autograd.Function.apply takes its inputs positionally, in forward()'s order.
    apply_args = (
        qkv,
        dropout_p,
        softmax_scale,
        causal,
        window_size,
        softcap,
        alibi_slopes,
        deterministic,
        return_attn_probs,
    )
    return FlashAttnQKVPackedFunc.apply(*apply_args)
Tri Dao's avatar
Tri Dao committed
740
741


Tri Dao's avatar
Tri Dao committed
742
def flash_attn_kvpacked_func(
    q,
    kv,
    dropout_p=0.0,
    softmax_scale=None,
    causal=False,
    window_size=(-1, -1),  # -1 means infinite context window
    softcap=0.0,  # 0.0 means deactivated
    alibi_slopes=None,
    deterministic=False,
    return_attn_probs=False,
):
    """dropout_p should be set to 0.0 during evaluation
    If K, V are already stacked into 1 tensor, this function will be faster than
    calling flash_attn_func on Q, K, V since the backward pass avoids explicit concatenation
    of the gradients of K, V.
    Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads
    than Q. Note that the number of heads in Q must be divisible by the number of heads in KV.
    For example, if Q has 6 heads and K, V have 2 heads, head 0, 1, 2 of Q will attend to head
    0 of K, V, and head 3, 4, 5 of Q will attend to head 1 of K, V.

    If causal=True, the causal mask is aligned to the bottom right corner of the attention matrix.
    For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is:
        1 1 1 1 0
        1 1 1 1 1
    If seqlen_q = 5 and seqlen_k = 2, the causal mask is:
        0 0
        0 0
        0 0
        1 0
        1 1
    If the row of the mask is all zero, the output will be zero.

    If window_size != (-1, -1), implements sliding window local attention. Query at position i
    will only attend to keys between
    [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive.

    Arguments:
        q: (batch_size, seqlen_q, nheads, headdim)
        kv: (batch_size, seqlen_k, 2, nheads_k, headdim)
        dropout_p: float. Dropout probability.
        softmax_scale: float. The scaling of QK^T before applying softmax.
            Default to 1 / sqrt(headdim).
        causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
        window_size: (left, right). If not (-1, -1), implements sliding window local attention.
        softcap: float. Anything > 0 activates softcapping attention.
        alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of
            (-alibi_slope * |i + seqlen_k - seqlen_q - j|)
            is added to the attention score of query i and key j.
        deterministic: bool. Whether to use the deterministic implementation of the backward pass,
            which is slightly slower and uses more memory. The forward pass is always deterministic.
        return_attn_probs: bool. Whether to return the attention probabilities. This option is for
           testing only. The returned probabilities are not guaranteed to be correct
           (they might not have the right scaling).
    Return:
        out: (batch_size, seqlen_q, nheads, headdim).
        softmax_lse [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen_q). The
            logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
            normalization factor).
        S_dmask [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen, seqlen).
            The output of softmax (possibly with different scaling). It also encodes the dropout
            pattern (negative means that location was dropped, nonnegative means it was kept).
    """
    # Arguments are forwarded positionally, so their order must match the parameter
    # order of FlashAttnKVPackedFunc (defined elsewhere in this file).
    return FlashAttnKVPackedFunc.apply(
        q,
        kv,
        dropout_p,
        softmax_scale,
        causal,
        window_size,
        softcap,
        alibi_slopes,
        deterministic,
        return_attn_probs,
    )
Tri Dao's avatar
Tri Dao committed
817
818


Tri Dao's avatar
Tri Dao committed
819
def flash_attn_func(
    q,
    k,
    v,
    dropout_p=0.0,
    softmax_scale=None,
    causal=False,
    window_size=(-1, -1),  # -1 means infinite context window
    softcap=0.0,  # 0.0 means deactivated
    alibi_slopes=None,
    deterministic=False,
    return_attn_probs=False,
):
    """dropout_p should be set to 0.0 during evaluation
    Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads
    than Q. Note that the number of heads in Q must be divisible by the number of heads in KV.
    For example, if Q has 6 heads and K, V have 2 heads, head 0, 1, 2 of Q will attend to head
    0 of K, V, and head 3, 4, 5 of Q will attend to head 1 of K, V.

    If causal=True, the causal mask is aligned to the bottom right corner of the attention matrix.
    For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is:
        1 1 1 1 0
        1 1 1 1 1
    If seqlen_q = 5 and seqlen_k = 2, the causal mask is:
        0 0
        0 0
        0 0
        1 0
        1 1
    If the row of the mask is all zero, the output will be zero.

    If window_size != (-1, -1), implements sliding window local attention. Query at position i
    will only attend to keys between
    [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive.

    Arguments:
        q: (batch_size, seqlen_q, nheads, headdim)
        k: (batch_size, seqlen_k, nheads_k, headdim)
        v: (batch_size, seqlen_k, nheads_k, headdim)
        dropout_p: float. Dropout probability.
        softmax_scale: float. The scaling of QK^T before applying softmax.
            Default to 1 / sqrt(headdim).
        causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
        window_size: (left, right). If not (-1, -1), implements sliding window local attention.
        softcap: float. Anything > 0 activates softcapping attention.
        alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of
            (-alibi_slope * |i + seqlen_k - seqlen_q - j|)
            is added to the attention score of query i and key j.
        deterministic: bool. Whether to use the deterministic implementation of the backward pass,
            which is slightly slower and uses more memory. The forward pass is always deterministic.
        return_attn_probs: bool. Whether to return the attention probabilities. This option is for
           testing only. The returned probabilities are not guaranteed to be correct
           (they might not have the right scaling).
    Return:
        out: (batch_size, seqlen_q, nheads, headdim).
        softmax_lse [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen_q). The
            logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
            normalization factor).
        S_dmask [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen, seqlen).
            The output of softmax (possibly with different scaling). It also encodes the dropout
            pattern (negative means that location was dropped, nonnegative means it was kept).
    """
    # Arguments are forwarded positionally, so their order must match the parameter
    # order of FlashAttnFunc (defined elsewhere in this file).
    return FlashAttnFunc.apply(
        q,
        k,
        v,
        dropout_p,
        softmax_scale,
        causal,
        window_size,
        softcap,
        alibi_slopes,
        deterministic,
        return_attn_probs,
    )
Tri Dao's avatar
Tri Dao committed
893
894


Tri Dao's avatar
Tri Dao committed
895
896
897
898
899
900
901
def flash_attn_varlen_qkvpacked_func(
    qkv,
    cu_seqlens,
    max_seqlen,
    dropout_p=0.0,
    softmax_scale=None,
    causal=False,
    window_size=(-1, -1),  # -1 means infinite context window
    softcap=0.0,  # 0.0 means deactivated
    alibi_slopes=None,
    deterministic=False,
    return_attn_probs=False,
):
    """dropout_p should be set to 0.0 during evaluation
    Variable-length attention on a single packed QKV tensor. When Q, K and V are already
    stacked into one tensor, this is faster than calling flash_attn_varlen_func on separate
    Q, K, V, because the backward pass does not have to explicitly concatenate the
    gradients of Q, K and V.
    For multi-query and grouped-query attention (MQA/GQA), use
    flash_attn_varlen_kvpacked_func or flash_attn_varlen_func instead.

    If window_size != (-1, -1), implements sliding window local attention. Query at position i
    will only attend to keys between [i - window_size[0], i + window_size[1]] inclusive.

    Arguments:
        qkv: (total, 3, nheads, headdim), where total = total number of tokens in the batch.
        cu_seqlens: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
           of the sequences in the batch, used to index into qkv.
        max_seqlen: int. Maximum sequence length in the batch.
        dropout_p: float. Dropout probability.
        softmax_scale: float. The scaling of QK^T before applying softmax.
            Default to 1 / sqrt(headdim).
        causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
        window_size: (left, right). If not (-1, -1), implements sliding window local attention.
        softcap: float. Anything > 0 activates softcapping attention.
        alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of (-alibi_slope * |i - j|)
            is added to the attention score of query i and key j.
        deterministic: bool. Whether to use the deterministic implementation of the backward pass,
            which is slightly slower and uses more memory. The forward pass is always deterministic.
        return_attn_probs: bool. Whether to return the attention probabilities. This option is for
           testing only. The returned probabilities are not guaranteed to be correct
           (they might not have the right scaling).
    Return:
        out: (total, nheads, headdim).
        softmax_lse [optional, if return_attn_probs=True]: (nheads, total_q_seqlen). The
            logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
            normalization factor).
        S_dmask [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen, seqlen).
            The output of softmax (possibly with different scaling). It also encodes the dropout
            pattern (negative means that location was dropped, nonnegative means it was kept).
    """
    # The underlying autograd Function receives its inputs positionally; keep this
    # tuple in the same order as FlashAttnVarlenQKVPackedFunc's parameters.
    packed_args = (
        qkv,
        cu_seqlens,
        max_seqlen,
        dropout_p,
        softmax_scale,
        causal,
        window_size,
        softcap,
        alibi_slopes,
        deterministic,
        return_attn_probs,
    )
    return FlashAttnVarlenQKVPackedFunc.apply(*packed_args)
Tri Dao's avatar
Tri Dao committed
958
959


Tri Dao's avatar
Tri Dao committed
960
961
962
963
964
965
966
967
968
969
def flash_attn_varlen_kvpacked_func(
    q,
    kv,
    cu_seqlens_q,
    cu_seqlens_k,
    max_seqlen_q,
    max_seqlen_k,
    dropout_p=0.0,
    softmax_scale=None,
    causal=False,
    window_size=(-1, -1),  # -1 means infinite context window
    softcap=0.0,  # 0.0 means deactivated
    alibi_slopes=None,
    deterministic=False,
    return_attn_probs=False,
):
    """dropout_p should be set to 0.0 during evaluation
    If K, V are already stacked into 1 tensor, this function will be faster than
    calling flash_attn_varlen_func on Q, K, V since the backward pass avoids explicit
    concatenation of the gradients of K, V.
    Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads
    than Q. Note that the number of heads in Q must be divisible by the number of heads in KV.
    For example, if Q has 6 heads and K, V have 2 heads, head 0, 1, 2 of Q will attend to head
    0 of K, V, and head 3, 4, 5 of Q will attend to head 1 of K, V.

    If causal=True, the causal mask is aligned to the bottom right corner of the attention matrix.
    For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is:
        1 1 1 1 0
        1 1 1 1 1
    If seqlen_q = 5 and seqlen_k = 2, the causal mask is:
        0 0
        0 0
        0 0
        1 0
        1 1
    If the row of the mask is all zero, the output will be zero.

    If window_size != (-1, -1), implements sliding window local attention. Query at position i
    will only attend to keys between
    [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive.

    Arguments:
        q: (total_q, nheads, headdim), where total_q = total number of query tokens in the batch.
        kv: (total_k, 2, nheads_k, headdim), where total_k = total number of key tokens in the batch.
        cu_seqlens_q: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
           of the sequences in the batch, used to index into q.
        cu_seqlens_k: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
           of the sequences in the batch, used to index into kv.
        max_seqlen_q: int. Maximum query sequence length in the batch.
        max_seqlen_k: int. Maximum key sequence length in the batch.
        dropout_p: float. Dropout probability.
        softmax_scale: float. The scaling of QK^T before applying softmax.
            Default to 1 / sqrt(headdim).
        causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
        window_size: (left, right). If not (-1, -1), implements sliding window local attention.
        softcap: float. Anything > 0 activates softcapping attention.
        alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of
            (-alibi_slope * |i + seqlen_k - seqlen_q - j|)
            is added to the attention score of query i and key j.
        deterministic: bool. Whether to use the deterministic implementation of the backward pass,
            which is slightly slower and uses more memory. The forward pass is always deterministic.
        return_attn_probs: bool. Whether to return the attention probabilities. This option is for
           testing only. The returned probabilities are not guaranteed to be correct
           (they might not have the right scaling).
    Return:
        out: (total_q, nheads, headdim).
        softmax_lse [optional, if return_attn_probs=True]: (nheads, total_q_seqlen). The
            logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
            normalization factor).
        S_dmask [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen, seqlen).
            The output of softmax (possibly with different scaling). It also encodes the dropout
            pattern (negative means that location was dropped, nonnegative means it was kept).
    """
    # Arguments are forwarded positionally, so their order must match the parameter
    # order of FlashAttnVarlenKVPackedFunc (defined elsewhere in this file).
    return FlashAttnVarlenKVPackedFunc.apply(
        q,
        kv,
        cu_seqlens_q,
        cu_seqlens_k,
        max_seqlen_q,
        max_seqlen_k,
        dropout_p,
        softmax_scale,
        causal,
        window_size,
        softcap,
        alibi_slopes,
        deterministic,
        return_attn_probs,
    )
Tri Dao's avatar
Tri Dao committed
1049

1050

Tri Dao's avatar
Tri Dao committed
1051
1052
1053
1054
1055
1056
1057
1058
1059
1060
1061
def flash_attn_varlen_func(
    q,
    k,
    v,
    cu_seqlens_q,
    cu_seqlens_k,
    max_seqlen_q,
    max_seqlen_k,
    dropout_p=0.0,
    softmax_scale=None,
    causal=False,
    window_size=(-1, -1),  # -1 means infinite context window
    softcap=0.0,  # 0.0 means deactivated
    alibi_slopes=None,
    deterministic=False,
    return_attn_probs=False,
    block_table=None,
):
    """dropout_p should be set to 0.0 during evaluation
    Supports multi-query and grouped-query attention (MQA/GQA) by passing in K, V with fewer heads
    than Q. Note that the number of heads in Q must be divisible by the number of heads in KV.
    For example, if Q has 6 heads and K, V have 2 heads, head 0, 1, 2 of Q will attend to head
    0 of K, V, and head 3, 4, 5 of Q will attend to head 1 of K, V.

    If causal=True, the causal mask is aligned to the bottom right corner of the attention matrix.
    For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is:
        1 1 1 1 0
        1 1 1 1 1
    If seqlen_q = 5 and seqlen_k = 2, the causal mask is:
        0 0
        0 0
        0 0
        1 0
        1 1
    If the row of the mask is all zero, the output will be zero.

    If window_size != (-1, -1), implements sliding window local attention. Query at position i
    will only attend to keys between
    [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive.

    Arguments:
        q: (total_q, nheads, headdim), where total_q = total number of query tokens in the batch.
        k: (total_k, nheads_k, headdim), where total_k = total number of key tokens in the batch.
        v: (total_k, nheads_k, headdim), where total_k = total number of key tokens in the batch.
        cu_seqlens_q: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
           of the sequences in the batch, used to index into q.
        cu_seqlens_k: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
           of the sequences in the batch, used to index into k and v.
        max_seqlen_q: int. Maximum query sequence length in the batch.
        max_seqlen_k: int. Maximum key sequence length in the batch.
        dropout_p: float. Dropout probability.
        softmax_scale: float. The scaling of QK^T before applying softmax.
            Default to 1 / sqrt(headdim).
        causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
        window_size: (left, right). If not (-1, -1), implements sliding window local attention.
        softcap: float. Anything > 0 activates softcapping attention.
        alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of
            (-alibi_slope * |i + seqlen_k - seqlen_q - j|)
            is added to the attention score of query i and key j.
        deterministic: bool. Whether to use the deterministic implementation of the backward pass,
            which is slightly slower and uses more memory. The forward pass is always deterministic.
        return_attn_probs: bool. Whether to return the attention probabilities. This option is for
           testing only. The returned probabilities are not guaranteed to be correct
           (they might not have the right scaling).
        block_table [optional]: passed through unchanged to FlashAttnVarlenFunc. Presumably a
            torch.Tensor block table for a paged K/V layout, as with the block_table argument
            of flash_attn_with_kvcache. NOTE(review): confirm the expected shape/dtype.
    Return:
        out: (total_q, nheads, headdim).
        softmax_lse [optional, if return_attn_probs=True]: (nheads, total_q_seqlen). The
            logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
            normalization factor).
        S_dmask [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen, seqlen).
            The output of softmax (possibly with different scaling). It also encodes the dropout
            pattern (negative means that location was dropped, nonnegative means it was kept).
    """
    # Arguments are forwarded positionally, so their order must match the parameter
    # order of FlashAttnVarlenFunc (defined elsewhere in this file).
    return FlashAttnVarlenFunc.apply(
        q,
        k,
        v,
        cu_seqlens_q,
        cu_seqlens_k,
        max_seqlen_q,
        max_seqlen_k,
        dropout_p,
        softmax_scale,
        causal,
        window_size,
        softcap,
        alibi_slopes,
        deterministic,
        return_attn_probs,
        block_table,
    )
Tri Dao's avatar
Tri Dao committed
1142
1143
1144
1145
1146
1147
1148
1149


def flash_attn_with_kvcache(
    q,
    k_cache,
    v_cache,
    k=None,
    v=None,
    rotary_cos=None,
    rotary_sin=None,
    cache_seqlens: Optional[Union[(int, torch.Tensor)]] = None,
    cache_batch_idx: Optional[torch.Tensor] = None,
    cache_leftpad: Optional[torch.Tensor] = None,
    block_table: Optional[torch.Tensor] = None,
    softmax_scale=None,
    causal=False,
    window_size=(-1, -1),  # -1 means infinite context window
    softcap=0.0,  # 0.0 means deactivated
    rotary_interleaved=True,
    alibi_slopes=None,
    num_splits=0,
    return_softmax_lse=False,
):
    """
    If k and v are not None, k_cache and v_cache will be updated *inplace* with the new values from
    k and v. This is useful for incremental decoding: you can pass in the cached keys/values from
    the previous step, and update them with the new keys/values from the current step, and do
    attention with the updated cache, all in 1 kernel.

    If you pass in k / v, you must make sure that the cache is large enough to hold the new values.
    For example, the KV cache could be pre-allocated with the max sequence length, and you can use
    cache_seqlens to keep track of the current sequence lengths of each sequence in the batch.

    Also apply rotary embedding if rotary_cos and rotary_sin are passed in. The key @k will be
    rotated by rotary_cos and rotary_sin at indices cache_seqlens, cache_seqlens + 1, etc.
    If causal or local (i.e., window_size != (-1, -1)), the query @q will be rotated by rotary_cos
    and rotary_sin at indices cache_seqlens, cache_seqlens + 1, etc.
    If not causal and not local, the query @q will be rotated by rotary_cos and rotary_sin at
    indices cache_seqlens only (i.e. we consider all tokens in @q to be at position cache_seqlens).

    See tests/test_flash_attn.py::test_flash_attn_kvcache for examples of how to use this function.

    Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads
    than Q. Note that the number of heads in Q must be divisible by the number of heads in KV.
    For example, if Q has 6 heads and K, V have 2 heads, head 0, 1, 2 of Q will attend to head
    0 of K, V, and head 3, 4, 5 of Q will attend to head 1 of K, V.

    If causal=True, the causal mask is aligned to the bottom right corner of the attention matrix.
    For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is:
        1 1 1 1 0
        1 1 1 1 1
    If seqlen_q = 5 and seqlen_k = 2, the causal mask is:
        0 0
        0 0
        0 0
        1 0
        1 1
    If the row of the mask is all zero, the output will be zero.

    If window_size != (-1, -1), implements sliding window local attention. Query at position i
    will only attend to keys between
    [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive.

    Note: Does not support backward pass.

    Arguments:
        q: (batch_size, seqlen, nheads, headdim)
        k_cache: (batch_size_cache, seqlen_cache, nheads_k, headdim) if there's no block_table,
            or (num_blocks, page_block_size, nheads_k, headdim) if there's a block_table (i.e. paged KV cache)
            page_block_size must be a multiple of 256.
        v_cache: (batch_size_cache, seqlen_cache, nheads_k, headdim) if there's no block_table,
            or (num_blocks, page_block_size, nheads_k, headdim) if there's a block_table (i.e. paged KV cache)
        k [optional]: (batch_size, seqlen_new, nheads_k, headdim). If not None, we concatenate
            k with k_cache, starting at the indices specified by cache_seqlens.
        v [optional]: (batch_size, seqlen_new, nheads_k, headdim). Similar to k.
        rotary_cos [optional]: (seqlen_ro, rotary_dim / 2). If not None, we apply rotary embedding
            to k and q. Only applicable if k and v are passed in. rotary_dim must be divisible by 16.
        rotary_sin [optional]: (seqlen_ro, rotary_dim / 2). Similar to rotary_cos.
        cache_seqlens: int, or (batch_size,), dtype torch.int32. The sequence lengths of the
            KV cache. An int is broadcast to a (batch_size,) int32 tensor, where batch_size is
            taken from q (not from k_cache, whose leading dim may be num_blocks or
            batch_size_cache).
        cache_batch_idx: (batch_size,), dtype torch.int32. The indices used to index into the KV cache.
            If None, we assume that the batch indices are [0, 1, 2, ..., batch_size - 1].
            If the indices are not distinct, and k and v are provided, the values updated in the cache
                 might come from any of the duplicate indices.
        cache_leftpad: (batch_size,), dtype torch.int32. The index that the KV cache starts. If None, assume 0.
        block_table [optional]: (batch_size, max_num_blocks_per_seq), dtype torch.int32.
        softmax_scale: float. The scaling of QK^T before applying softmax.
            Default to 1 / sqrt(headdim).
        causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
        window_size: (left, right). If not (-1, -1), implements sliding window local attention.
        softcap: float. Anything > 0 activates softcapping attention.
        rotary_interleaved: bool. Only applicable if rotary_cos and rotary_sin are passed in.
            If True, rotary embedding will combine dimensions 0 & 1, 2 & 3, etc. If False,
            rotary embedding will combine dimensions 0 & rotary_dim / 2, 1 & rotary_dim / 2 + 1
            (i.e. GPT-NeoX style).
        alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of
            (-alibi_slope * |i + seqlen_k - seqlen_q - j|)
            is added to the attention score of query i and key j.
        num_splits: int. If > 1, split the key/value into this many chunks along the sequence.
           If num_splits == 1, we don't split the key/value. If num_splits == 0, we use a heuristic
           to automatically determine the number of splits.
           Don't change this unless you know what you are doing.
        return_softmax_lse: bool. Whether to return the logsumexp of the attention scores.

    Return:
        out: (batch_size, seqlen, nheads, headdim).
        softmax_lse [optional, if return_softmax_lse=True]: (batch_size, nheads, seqlen). The
            logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
            normalization factor).
    """
    # The caches are updated in place, so they cannot be silently replaced by contiguous copies;
    # require the caller to provide a unit-stride last dimension instead.
    assert k_cache.stride(-1) == 1, "k_cache must have contiguous last dimension"
    assert v_cache.stride(-1) == 1, "v_cache must have contiguous last dimension"
    q, k, v = [maybe_contiguous(x) for x in (q, k, v)]
    if softmax_scale is None:
        softmax_scale = q.shape[-1] ** (-0.5)
    if cache_seqlens is not None and isinstance(cache_seqlens, int):
        # cache_seqlens is documented as (batch_size,), i.e. one entry per *query* batch element.
        # k_cache.shape[0] is num_blocks when block_table is given, and batch_size_cache
        # (possibly != batch_size) when cache_batch_idx is given, so size it from q instead.
        cache_seqlens = torch.full(
            (q.shape[0],), cache_seqlens, dtype=torch.int32, device=k_cache.device
        )
    # Normalize every optional index/bias tensor. cache_seqlens is done outside the int branch
    # above so that caller-supplied (possibly strided) tensors are covered as well.
    cache_seqlens = maybe_contiguous(cache_seqlens)
    cache_batch_idx = maybe_contiguous(cache_batch_idx)
    cache_leftpad = maybe_contiguous(cache_leftpad)
    block_table = maybe_contiguous(block_table)
    alibi_slopes = maybe_contiguous(alibi_slopes)
    out, softmax_lse = flash_attn_cuda.fwd_kvcache(
        q,
        k_cache,
        v_cache,
        k,
        v,
        cache_seqlens,
        rotary_cos,
        rotary_sin,
        cache_batch_idx,
        cache_leftpad,
        block_table,
        alibi_slopes,
        None,  # NOTE(review): presumably the optional preallocated `out` buffer — confirm against fwd_kvcache
        softmax_scale,
        causal,
        window_size[0],
        window_size[1],
        softcap,
        rotary_interleaved,
        num_splits,
    )
    return (out, softmax_lse) if return_softmax_lse else out