flash_attn_interface.py 46.3 KB
Newer Older
1
2
# Copyright (c) 2023, Tri Dao.

Tri Dao's avatar
Tri Dao committed
3
4
from typing import Optional, Union

Tri Dao's avatar
Tri Dao committed
5
6
import torch
import torch.nn as nn
Tri Dao's avatar
Tri Dao committed
7

8
9
10
# isort: off
# We need to import the CUDA kernels after importing torch
import flash_attn_2_cuda as flash_attn_cuda
Tri Dao's avatar
Tri Dao committed
11

12
13
# isort: on

Tri Dao's avatar
Tri Dao committed
14

Tri Dao's avatar
Tri Dao committed
15
def _get_block_size_n(device, head_dim, is_dropout, is_causal):
Tri Dao's avatar
Tri Dao committed
16
17
18
19
20
21
22
    # This should match the block sizes in the CUDA kernel
    assert head_dim <= 256
    major, minor = torch.cuda.get_device_capability(device)
    is_sm8x = major == 8 and minor > 0  # Only include sm86 and sm89, exclude sm80 (A100)
    is_sm80 = major == 8 and minor == 0
    is_sm90 = major == 9 and minor == 0
    if head_dim <= 32:
Tri Dao's avatar
Tri Dao committed
23
        return 128
Tri Dao's avatar
Tri Dao committed
24
    if head_dim <= 64:
Tri Dao's avatar
Tri Dao committed
25
        return 128 if not is_dropout else 64
Tri Dao's avatar
Tri Dao committed
26
    elif head_dim <= 96:
Tri Dao's avatar
Tri Dao committed
27
        return 64
Tri Dao's avatar
Tri Dao committed
28
29
    elif head_dim <= 128:
        if is_sm8x:
Tri Dao's avatar
Tri Dao committed
30
            return 64 if (not is_dropout and is_causal) else 32
Tri Dao's avatar
Tri Dao committed
31
        else:
Tri Dao's avatar
Tri Dao committed
32
            return 64 if not is_dropout else 32
Tri Dao's avatar
Tri Dao committed
33
34
    elif head_dim <= 160:
        if is_sm8x:
Tri Dao's avatar
Tri Dao committed
35
            return 64
Tri Dao's avatar
Tri Dao committed
36
        else:
Tri Dao's avatar
Tri Dao committed
37
            return 32
Tri Dao's avatar
Tri Dao committed
38
    elif head_dim <= 192:
Tri Dao's avatar
Tri Dao committed
39
        return 64
Tri Dao's avatar
Tri Dao committed
40
    elif head_dim <= 224:
Tri Dao's avatar
Tri Dao committed
41
        return 64
Tri Dao's avatar
Tri Dao committed
42
    elif head_dim <= 256:
Tri Dao's avatar
Tri Dao committed
43
        return 64
Tri Dao's avatar
Tri Dao committed
44
45


Tri Dao's avatar
Tri Dao committed
46
def _flash_attn_forward(
    q, k, v, dropout_p, softmax_scale, causal, window_size, softcap, alibi_slopes, return_softmax
):
    """Invoke the fixed-length FlashAttention forward kernel (``flash_attn_cuda.fwd``).

    Args:
        q, k, v: query / key / value tensors. Any of them whose last dimension is not
            unit-stride is copied to a contiguous tensor before the kernel call.
        dropout_p: dropout probability forwarded to the kernel.
        softmax_scale: scale applied to QK^T (the autograd wrappers in this file
            default it to 1 / sqrt(headdim)).
        causal: whether causal masking is applied.
        window_size: (left, right) sliding-window extents; -1 means unbounded.
        softcap: softcapping value forwarded to the kernel.
        alibi_slopes: ALiBi slopes tensor, or None.
        return_softmax: whether the kernel should also produce ``S_dmask``.

    Returns:
        Tuple ``(out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state)`` as
        produced by the kernel. The returned q/k/v may not be the input objects
        (possibly head-dim padded — TODO confirm against the kernel).
    """

    def _unit_stride_last(t):
        # The kernel is handed tensors whose innermost dimension has stride 1.
        return t if t.stride(-1) == 1 else t.contiguous()

    q, k, v = map(_unit_stride_last, (q, k, v))
    left_window = window_size[0]
    right_window = window_size[1]
    results = flash_attn_cuda.fwd(
        q,
        k,
        v,
        None,  # assumed: optional preallocated output buffer — TODO confirm
        alibi_slopes,
        dropout_p,
        softmax_scale,
        causal,
        left_window,
        right_window,
        softcap,
        return_softmax,
        None,  # assumed: optional RNG generator — TODO confirm
    )
    out, q, k, v, out_padded, softmax_lse, s_dmask, rng_state = results
    return out, q, k, v, out_padded, softmax_lse, s_dmask, rng_state
Tri Dao's avatar
Tri Dao committed
67
68


Tri Dao's avatar
Tri Dao committed
69
70
71
72
73
74
75
76
77
78
79
def _flash_attn_varlen_forward(
    q,
    k,
    v,
    cu_seqlens_q,
    cu_seqlens_k,
    max_seqlen_q,
    max_seqlen_k,
    dropout_p,
    softmax_scale,
    causal,
    window_size,
    softcap,
    alibi_slopes,
    return_softmax,
    block_table,
):
    """Invoke the variable-length FlashAttention forward kernel (``flash_attn_cuda.varlen_fwd``).

    Args:
        q, k, v: packed query / key / value tensors. Any of them whose last dimension
            is not unit-stride is made contiguous first.
        cu_seqlens_q, cu_seqlens_k: cumulative sequence-length tensors forwarded to
            the kernel.
        max_seqlen_q, max_seqlen_k: maximum sequence lengths forwarded to the kernel.
        dropout_p: dropout probability.
        softmax_scale: scale applied to QK^T.
        causal: whether causal masking is applied.
        window_size: (left, right) sliding-window extents; -1 means unbounded.
        softcap: softcapping value forwarded to the kernel.
        alibi_slopes: ALiBi slopes tensor, or None.
        return_softmax: whether the kernel should also produce ``S_dmask``.
        block_table: forwarded to the kernel (None in the packed wrappers in this file;
            presumably a paged-KV block table — TODO confirm).

    Returns:
        Tuple ``(out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state)`` as
        produced by the kernel.
    """

    def _unit_stride_last(t):
        # The kernel is handed tensors whose innermost dimension has stride 1.
        return t if t.stride(-1) == 1 else t.contiguous()

    q, k, v = map(_unit_stride_last, (q, k, v))
    left_window = window_size[0]
    right_window = window_size[1]
    results = flash_attn_cuda.varlen_fwd(
        q,
        k,
        v,
        None,  # assumed: optional preallocated output buffer — TODO confirm
        cu_seqlens_q,
        cu_seqlens_k,
        None,  # positional slot left unset; meaning defined by the kernel binding
        block_table,
        alibi_slopes,
        max_seqlen_q,
        max_seqlen_k,
        dropout_p,
        softmax_scale,
        False,  # boolean flag fixed to False; see the kernel binding for its meaning
        causal,
        left_window,
        right_window,
        softcap,
        return_softmax,
        None,  # assumed: optional RNG generator — TODO confirm
    )
    out, q, k, v, out_padded, softmax_lse, s_dmask, rng_state = results
    return out, q, k, v, out_padded, softmax_lse, s_dmask, rng_state
Tri Dao's avatar
Tri Dao committed
113
114


Tri Dao's avatar
Tri Dao committed
115
def _flash_attn_backward(
    dout,
    q,
    k,
    v,
    out,
    softmax_lse,
    dq,
    dk,
    dv,
    dropout_p,
    softmax_scale,
    causal,
    window_size,
    softcap,
    alibi_slopes,
    deterministic,
    rng_state=None,
):
    """Invoke the fixed-length FlashAttention backward kernel (``flash_attn_cuda.bwd``).

    ``dq``, ``dk`` and ``dv`` are caller-provided gradient buffers passed straight to
    the kernel. ``dout``, ``q``, ``k``, ``v`` and ``out`` are made contiguous first if
    their last dimension is not unit-stride.

    Args:
        dout: gradient of the attention output.
        q, k, v, out, softmax_lse: tensors saved from the forward pass.
        dq, dk, dv: gradient buffers handed to the kernel.
        dropout_p, softmax_scale, causal, window_size, softcap, alibi_slopes:
            the same settings used in the forward pass.
        deterministic: whether to request the deterministic backward implementation.
        rng_state: RNG state captured by the forward kernel, forwarded as-is.

    Returns:
        Tuple ``(dq, dk, dv, softmax_d)`` as returned by the kernel.
    """

    def _unit_stride_last(t):
        return t if t.stride(-1) == 1 else t.contiguous()

    # The gradient buffers are allocated by the callers and are expected to be
    # contiguous already, so only the inputs are normalized here.
    dout, q, k, v, out = map(_unit_stride_last, (dout, q, k, v, out))
    left_window = window_size[0]
    right_window = window_size[1]
    grads = flash_attn_cuda.bwd(
        dout,
        q,
        k,
        v,
        out,
        softmax_lse,
        dq,
        dk,
        dv,
        alibi_slopes,
        dropout_p,
        softmax_scale,
        causal,
        left_window,
        right_window,
        softcap,
        deterministic,
        None,  # assumed: optional RNG generator — TODO confirm
        rng_state,
    )
    dq, dk, dv, softmax_d = grads
    return dq, dk, dv, softmax_d


Tri Dao's avatar
Tri Dao committed
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
def _flash_attn_varlen_backward(
    dout,
    q,
    k,
    v,
    out,
    softmax_lse,
    dq,
    dk,
    dv,
    cu_seqlens_q,
    cu_seqlens_k,
    max_seqlen_q,
    max_seqlen_k,
    dropout_p,
    softmax_scale,
    causal,
    window_size,
    softcap,
    alibi_slopes,
    deterministic,
    rng_state=None,
):
    """Invoke the variable-length FlashAttention backward kernel (``flash_attn_cuda.varlen_bwd``).

    ``dq``, ``dk`` and ``dv`` are caller-provided gradient buffers passed straight to
    the kernel. ``dout``, ``q``, ``k``, ``v`` and ``out`` are made contiguous first if
    their last dimension is not unit-stride.

    Args:
        dout: gradient of the attention output.
        q, k, v, out, softmax_lse: tensors saved from the forward pass.
        dq, dk, dv: gradient buffers handed to the kernel.
        cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k: the same sequence
            metadata used in the forward pass.
        dropout_p, softmax_scale, causal, window_size, softcap, alibi_slopes:
            the same settings used in the forward pass.
        deterministic: whether to request the deterministic backward implementation.
        rng_state: RNG state captured by the forward kernel, forwarded as-is.

    Returns:
        Tuple ``(dq, dk, dv, softmax_d)`` as returned by the kernel.
    """

    def _unit_stride_last(t):
        return t if t.stride(-1) == 1 else t.contiguous()

    # The gradient buffers are allocated by the callers and are expected to be
    # contiguous already, so only the inputs are normalized here.
    dout, q, k, v, out = map(_unit_stride_last, (dout, q, k, v, out))
    left_window = window_size[0]
    right_window = window_size[1]
    grads = flash_attn_cuda.varlen_bwd(
        dout,
        q,
        k,
        v,
        out,
        softmax_lse,
        dq,
        dk,
        dv,
        cu_seqlens_q,
        cu_seqlens_k,
        alibi_slopes,
        max_seqlen_q,
        max_seqlen_k,
        dropout_p,
        softmax_scale,
        False,  # boolean flag fixed to False; see the kernel binding for its meaning
        causal,
        left_window,
        right_window,
        softcap,
        deterministic,
        None,  # assumed: optional RNG generator — TODO confirm
        rng_state,
    )
    dq, dk, dv, softmax_d = grads
    return dq, dk, dv, softmax_d
Tri Dao's avatar
Tri Dao committed
226
227


Tri Dao's avatar
Tri Dao committed
228
class FlashAttnQKVPackedFunc(torch.autograd.Function):
    """Autograd function for fixed-length FlashAttention on a packed ``qkv`` tensor.

    The forward pass reads q, k, v as ``qkv[:, :, 0..2]``. The backward pass writes
    the three gradients directly into slices of a single packed gradient buffer.
    """

    @staticmethod
    def forward(
        ctx,
        qkv,
        dropout_p,
        softmax_scale,
        causal,
        window_size,
        softcap,
        alibi_slopes,
        deterministic,
        return_softmax,
    ):
        if softmax_scale is None:
            softmax_scale = qkv.shape[-1] ** (-0.5)
        q_in, k_in, v_in = (qkv[:, :, idx] for idx in range(3))
        # The kernel is only asked for S_dmask when dropout is active.
        want_s_dmask = return_softmax and dropout_p > 0
        out, q, k, v, out_padded, softmax_lse, s_dmask, rng_state = _flash_attn_forward(
            q_in,
            k_in,
            v_in,
            dropout_p,
            softmax_scale,
            causal=causal,
            window_size=window_size,
            softcap=softcap,
            alibi_slopes=alibi_slopes,
            return_softmax=want_s_dmask,
        )
        ctx.save_for_backward(q, k, v, out_padded, softmax_lse, rng_state)
        ctx.dropout_p = dropout_p
        ctx.softmax_scale = softmax_scale
        ctx.causal = causal
        ctx.window_size = window_size
        ctx.softcap = softcap
        ctx.alibi_slopes = alibi_slopes
        ctx.deterministic = deterministic
        if return_softmax:
            return out, softmax_lse, s_dmask
        return out

    @staticmethod
    def backward(ctx, dout, *args):
        q, k, v, out, softmax_lse, rng_state = ctx.saved_tensors
        packed_shape = (*q.shape[:-2], 3, *q.shape[-2:])
        dqkv = torch.empty(packed_shape, dtype=q.dtype, device=q.device)
        _flash_attn_backward(
            dout,
            q,
            k,
            v,
            out,
            softmax_lse,
            dqkv[:, :, 0],
            dqkv[:, :, 1],
            dqkv[:, :, 2],
            ctx.dropout_p,
            ctx.softmax_scale,
            ctx.causal,
            ctx.window_size,
            ctx.softcap,
            ctx.alibi_slopes,
            ctx.deterministic,
            rng_state=rng_state,
        )
        # The kernel may have padded the head dimension; trim back to dout's width.
        dqkv = dqkv[..., : dout.shape[-1]]
        # One gradient per forward input: qkv, then 8 non-differentiable arguments.
        return (dqkv,) + (None,) * 8
Tri Dao's avatar
Tri Dao committed
292
293
294
295


class FlashAttnVarlenQKVPackedFunc(torch.autograd.Function):
    """Autograd function for variable-length FlashAttention on a packed ``qkv`` tensor.

    q, k, v are read as ``qkv[:, 0..2]``. The same ``cu_seqlens`` / ``max_seqlen`` are
    used for queries and keys.
    """

    @staticmethod
    def forward(
        ctx,
        qkv,
        cu_seqlens,
        max_seqlen,
        dropout_p,
        softmax_scale,
        causal,
        window_size,
        softcap,
        alibi_slopes,
        deterministic,
        return_softmax,
    ):
        if softmax_scale is None:
            softmax_scale = qkv.shape[-1] ** (-0.5)
        q_in, k_in, v_in = (qkv[:, idx] for idx in range(3))
        want_s_dmask = return_softmax and dropout_p > 0
        out, q, k, v, out_padded, softmax_lse, s_dmask, rng_state = _flash_attn_varlen_forward(
            q_in,
            k_in,
            v_in,
            cu_seqlens,
            cu_seqlens,
            max_seqlen,
            max_seqlen,
            dropout_p,
            softmax_scale,
            causal=causal,
            window_size=window_size,
            softcap=softcap,
            alibi_slopes=alibi_slopes,
            return_softmax=want_s_dmask,
            block_table=None,
        )
        ctx.save_for_backward(q, k, v, out_padded, softmax_lse, cu_seqlens, rng_state)
        ctx.dropout_p = dropout_p
        ctx.max_seqlen = max_seqlen
        ctx.softmax_scale = softmax_scale
        ctx.causal = causal
        ctx.window_size = window_size
        ctx.softcap = softcap
        ctx.alibi_slopes = alibi_slopes
        ctx.deterministic = deterministic
        if return_softmax:
            return out, softmax_lse, s_dmask
        return out

    @staticmethod
    def backward(ctx, dout, *args):
        q, k, v, out, softmax_lse, cu_seqlens, rng_state = ctx.saved_tensors
        packed_shape = (*q.shape[:-2], 3, *q.shape[-2:])
        dqkv = torch.empty(packed_shape, dtype=q.dtype, device=q.device)
        _flash_attn_varlen_backward(
            dout,
            q,
            k,
            v,
            out,
            softmax_lse,
            dqkv[:, 0],
            dqkv[:, 1],
            dqkv[:, 2],
            cu_seqlens,
            cu_seqlens,
            ctx.max_seqlen,
            ctx.max_seqlen,
            ctx.dropout_p,
            ctx.softmax_scale,
            ctx.causal,
            ctx.window_size,
            ctx.softcap,
            ctx.alibi_slopes,
            ctx.deterministic,
            rng_state=rng_state,
        )
        # The kernel may have padded the head dimension; trim back to dout's width.
        dqkv = dqkv[..., : dout.shape[-1]]
        # One gradient per forward input: qkv, then 10 non-differentiable arguments.
        return (dqkv,) + (None,) * 10
Tri Dao's avatar
Tri Dao committed
370
371


Tri Dao's avatar
Tri Dao committed
372
class FlashAttnKVPackedFunc(torch.autograd.Function):
    """Autograd function for fixed-length FlashAttention with a separate ``q`` and packed ``kv``.

    k and v are read as ``kv[:, :, 0]`` and ``kv[:, :, 1]``. Their gradients are
    written into slices of one packed ``dkv`` buffer.
    """

    @staticmethod
    def forward(
        ctx,
        q,
        kv,
        dropout_p,
        softmax_scale,
        causal,
        window_size,
        softcap,
        alibi_slopes,
        deterministic,
        return_softmax,
    ):
        if softmax_scale is None:
            softmax_scale = q.shape[-1] ** (-0.5)
        k_in, v_in = kv[:, :, 0], kv[:, :, 1]
        want_s_dmask = return_softmax and dropout_p > 0
        out, q, k, v, out_padded, softmax_lse, s_dmask, rng_state = _flash_attn_forward(
            q,
            k_in,
            v_in,
            dropout_p,
            softmax_scale,
            causal=causal,
            window_size=window_size,
            softcap=softcap,
            alibi_slopes=alibi_slopes,
            return_softmax=want_s_dmask,
        )
        ctx.save_for_backward(q, k, v, out_padded, softmax_lse, rng_state)
        ctx.dropout_p = dropout_p
        ctx.softmax_scale = softmax_scale
        ctx.causal = causal
        ctx.window_size = window_size
        ctx.softcap = softcap
        ctx.alibi_slopes = alibi_slopes
        ctx.deterministic = deterministic
        if return_softmax:
            return out, softmax_lse, s_dmask
        return out

    @staticmethod
    def backward(ctx, dout, *args):
        q, k, v, out, softmax_lse, rng_state = ctx.saved_tensors
        dq = torch.empty_like(q)
        packed_shape = (*k.shape[:-2], 2, *k.shape[-2:])
        dkv = torch.empty(packed_shape, dtype=k.dtype, device=k.device)
        _flash_attn_backward(
            dout,
            q,
            k,
            v,
            out,
            softmax_lse,
            dq,
            dkv[:, :, 0],
            dkv[:, :, 1],
            ctx.dropout_p,
            ctx.softmax_scale,
            ctx.causal,
            ctx.window_size,
            ctx.softcap,
            ctx.alibi_slopes,
            ctx.deterministic,
            rng_state=rng_state,
        )
        # The kernel may have padded the head dimension; trim back to dout's width.
        head_dim = dout.shape[-1]
        dq = dq[..., :head_dim]
        dkv = dkv[..., :head_dim]
        # One gradient per forward input: q, kv, then 8 non-differentiable arguments.
        return (dq, dkv) + (None,) * 8
Tri Dao's avatar
Tri Dao committed
439
440


Tri Dao's avatar
Tri Dao committed
441
class FlashAttnVarlenKVPackedFunc(torch.autograd.Function):
    """Autograd function for variable-length FlashAttention with a separate ``q`` and packed ``kv``.

    k and v are read as ``kv[:, 0]`` and ``kv[:, 1]``. Query and key sequences carry
    their own ``cu_seqlens`` / ``max_seqlen`` values.
    """

    @staticmethod
    def forward(
        ctx,
        q,
        kv,
        cu_seqlens_q,
        cu_seqlens_k,
        max_seqlen_q,
        max_seqlen_k,
        dropout_p,
        softmax_scale,
        causal,
        window_size,
        softcap,
        alibi_slopes,
        deterministic,
        return_softmax,
    ):
        if softmax_scale is None:
            softmax_scale = q.shape[-1] ** (-0.5)
        k_in, v_in = kv[:, 0], kv[:, 1]
        want_s_dmask = return_softmax and dropout_p > 0
        out, q, k, v, out_padded, softmax_lse, s_dmask, rng_state = _flash_attn_varlen_forward(
            q,
            k_in,
            v_in,
            cu_seqlens_q,
            cu_seqlens_k,
            max_seqlen_q,
            max_seqlen_k,
            dropout_p,
            softmax_scale,
            causal=causal,
            window_size=window_size,
            softcap=softcap,
            alibi_slopes=alibi_slopes,
            return_softmax=want_s_dmask,
            block_table=None,
        )
        ctx.save_for_backward(
            q, k, v, out_padded, softmax_lse, cu_seqlens_q, cu_seqlens_k, rng_state
        )
        ctx.dropout_p = dropout_p
        ctx.max_seqlen_q = max_seqlen_q
        ctx.max_seqlen_k = max_seqlen_k
        ctx.softmax_scale = softmax_scale
        ctx.causal = causal
        ctx.window_size = window_size
        ctx.softcap = softcap
        ctx.alibi_slopes = alibi_slopes
        ctx.deterministic = deterministic
        if return_softmax:
            return out, softmax_lse, s_dmask
        return out

    @staticmethod
    def backward(ctx, dout, *args):
        q, k, v, out, softmax_lse, cu_seqlens_q, cu_seqlens_k, rng_state = ctx.saved_tensors
        dq = torch.empty_like(q)
        packed_shape = (*k.shape[:-2], 2, *k.shape[-2:])
        dkv = torch.empty(packed_shape, dtype=k.dtype, device=k.device)
        _flash_attn_varlen_backward(
            dout,
            q,
            k,
            v,
            out,
            softmax_lse,
            dq,
            dkv[:, 0],
            dkv[:, 1],
            cu_seqlens_q,
            cu_seqlens_k,
            ctx.max_seqlen_q,
            ctx.max_seqlen_k,
            ctx.dropout_p,
            ctx.softmax_scale,
            ctx.causal,
            ctx.window_size,
            ctx.softcap,
            ctx.alibi_slopes,
            ctx.deterministic,
            rng_state=rng_state,
        )
        # The kernel may have padded the head dimension; trim back to dout's width.
        head_dim = dout.shape[-1]
        dq = dq[..., :head_dim]
        dkv = dkv[..., :head_dim]
        # One gradient per forward input: q, kv, then 12 non-differentiable arguments.
        return (dq, dkv) + (None,) * 12
Tri Dao's avatar
Tri Dao committed
525
526
527
528


class FlashAttnFunc(torch.autograd.Function):
    """Autograd function for fixed-length FlashAttention on separate ``q``, ``k``, ``v`` tensors."""

    @staticmethod
    def forward(
        ctx,
        q,
        k,
        v,
        dropout_p,
        softmax_scale,
        causal,
        window_size,
        softcap,
        alibi_slopes,
        deterministic,
        return_softmax,
    ):
        if softmax_scale is None:
            softmax_scale = q.shape[-1] ** (-0.5)
        want_s_dmask = return_softmax and dropout_p > 0
        out, q, k, v, out_padded, softmax_lse, s_dmask, rng_state = _flash_attn_forward(
            q,
            k,
            v,
            dropout_p,
            softmax_scale,
            causal=causal,
            window_size=window_size,
            softcap=softcap,
            alibi_slopes=alibi_slopes,
            return_softmax=want_s_dmask,
        )
        ctx.save_for_backward(q, k, v, out_padded, softmax_lse, rng_state)
        ctx.dropout_p = dropout_p
        ctx.softmax_scale = softmax_scale
        ctx.causal = causal
        ctx.window_size = window_size
        ctx.softcap = softcap
        ctx.alibi_slopes = alibi_slopes
        ctx.deterministic = deterministic
        if return_softmax:
            return out, softmax_lse, s_dmask
        return out

    @staticmethod
    def backward(ctx, dout, *args):
        q, k, v, out, softmax_lse, rng_state = ctx.saved_tensors
        dq = torch.empty_like(q)
        dk = torch.empty_like(k)
        dv = torch.empty_like(v)
        _flash_attn_backward(
            dout,
            q,
            k,
            v,
            out,
            softmax_lse,
            dq,
            dk,
            dv,
            ctx.dropout_p,
            ctx.softmax_scale,
            ctx.causal,
            ctx.window_size,
            ctx.softcap,
            ctx.alibi_slopes,
            ctx.deterministic,
            rng_state=rng_state,
        )
        # The kernel may have padded the head dimension; trim back to dout's width.
        head_dim = dout.shape[-1]
        dq, dk, dv = (grad[..., :head_dim] for grad in (dq, dk, dv))
        # One gradient per forward input: q, k, v, then 8 non-differentiable arguments.
        return (dq, dk, dv) + (None,) * 8
Tri Dao's avatar
Tri Dao committed
594
595


Tri Dao's avatar
Tri Dao committed
596
class FlashAttnVarlenFunc(torch.autograd.Function):
    """Autograd function for variable-length FlashAttention on separate ``q``, ``k``, ``v`` tensors.

    This is the only wrapper in this file that forwards a caller-supplied ``block_table``
    to the forward kernel.
    """

    @staticmethod
    def forward(
        ctx,
        q,
        k,
        v,
        cu_seqlens_q,
        cu_seqlens_k,
        max_seqlen_q,
        max_seqlen_k,
        dropout_p,
        softmax_scale,
        causal,
        window_size,
        softcap,
        alibi_slopes,
        deterministic,
        return_softmax,
        block_table,
    ):
        if softmax_scale is None:
            softmax_scale = q.shape[-1] ** (-0.5)
        want_s_dmask = return_softmax and dropout_p > 0
        out, q, k, v, out_padded, softmax_lse, s_dmask, rng_state = _flash_attn_varlen_forward(
            q,
            k,
            v,
            cu_seqlens_q,
            cu_seqlens_k,
            max_seqlen_q,
            max_seqlen_k,
            dropout_p,
            softmax_scale,
            causal=causal,
            window_size=window_size,
            softcap=softcap,
            alibi_slopes=alibi_slopes,
            return_softmax=want_s_dmask,
            block_table=block_table,
        )
        ctx.save_for_backward(
            q, k, v, out_padded, softmax_lse, cu_seqlens_q, cu_seqlens_k, rng_state
        )
        ctx.dropout_p = dropout_p
        ctx.max_seqlen_q = max_seqlen_q
        ctx.max_seqlen_k = max_seqlen_k
        ctx.softmax_scale = softmax_scale
        ctx.causal = causal
        ctx.window_size = window_size
        ctx.softcap = softcap
        ctx.alibi_slopes = alibi_slopes
        ctx.deterministic = deterministic
        if return_softmax:
            return out, softmax_lse, s_dmask
        return out

    @staticmethod
    def backward(ctx, dout, *args):
        q, k, v, out, softmax_lse, cu_seqlens_q, cu_seqlens_k, rng_state = ctx.saved_tensors
        dq = torch.empty_like(q)
        dk = torch.empty_like(k)
        dv = torch.empty_like(v)
        _flash_attn_varlen_backward(
            dout,
            q,
            k,
            v,
            out,
            softmax_lse,
            dq,
            dk,
            dv,
            cu_seqlens_q,
            cu_seqlens_k,
            ctx.max_seqlen_q,
            ctx.max_seqlen_k,
            ctx.dropout_p,
            ctx.softmax_scale,
            ctx.causal,
            ctx.window_size,
            ctx.softcap,
            ctx.alibi_slopes,
            ctx.deterministic,
            rng_state=rng_state,
        )
        # The kernel may have padded the head dimension; trim back to dout's width.
        head_dim = dout.shape[-1]
        dq, dk, dv = (grad[..., :head_dim] for grad in (dq, dk, dv))
        # One gradient per forward input: q, k, v, then 13 non-differentiable arguments.
        return (dq, dk, dv) + (None,) * 13
681
682


Tri Dao's avatar
Tri Dao committed
683
def flash_attn_qkvpacked_func(
    qkv,
    dropout_p=0.0,
    softmax_scale=None,
    causal=False,
    window_size=(-1, -1),  # -1 means infinite context window
    softcap=0.0,  # <=0.0 means deactivate
    alibi_slopes=None,
    deterministic=False,
    return_attn_probs=False,
):
    """FlashAttention over a single packed QKV tensor.

    Set dropout_p to 0.0 during evaluation.

    When Q, K and V are already stacked in one tensor, this is faster than calling
    flash_attn_func on separate Q, K, V. The backward pass writes all three gradients
    into one buffer instead of concatenating them explicitly.

    For multi-query and grouped-query attention (MQA/GQA), use
    flash_attn_kvpacked_func or flash_attn_func instead.

    If window_size != (-1, -1), sliding window local attention is used. The query at
    position i attends only to keys in [i - window_size[0], i + window_size[1]] inclusive.

    Arguments:
        qkv: (batch_size, seqlen, 3, nheads, headdim)
        dropout_p: float. Dropout probability.
        softmax_scale: float. Scaling applied to QK^T before the softmax.
            Defaults to 1 / sqrt(headdim).
        causal: bool. Whether to apply a causal attention mask (e.g. for
            auto-regressive modeling).
        window_size: (left, right). Any value other than (-1, -1) enables sliding
            window local attention.
        softcap: float. Any value > 0 enables softcapping attention.
        alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of
            (-alibi_slope * |i - j|) is added to the attention score of query i and key j.
        deterministic: bool. Whether to use the deterministic implementation of the
            backward pass, which is slightly slower and uses more memory. The forward
            pass is always deterministic.
        return_attn_probs: bool. Whether to also return the attention probabilities.
            This option is for testing only. The returned probabilities are not
            guaranteed to be correct (they might not have the right scaling). The
            kernel is only asked to produce S_dmask when dropout_p > 0.
    Return:
        out: (batch_size, seqlen, nheads, headdim).
        softmax_lse [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen).
            The logsumexp of each row of the matrix QK^T * scaling (e.g. the log of the
            softmax normalization factor).
        S_dmask [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen, seqlen).
            The output of softmax (possibly with different scaling). It also encodes the
            dropout pattern (negative means that location was dropped, nonnegative means
            it was kept).
    """
    # Positional order must match FlashAttnQKVPackedFunc.forward.
    apply_args = (
        qkv,
        dropout_p,
        softmax_scale,
        causal,
        window_size,
        softcap,
        alibi_slopes,
        deterministic,
        return_attn_probs,
    )
    return FlashAttnQKVPackedFunc.apply(*apply_args)
Tri Dao's avatar
Tri Dao committed
739
740


Tri Dao's avatar
Tri Dao committed
741
def flash_attn_kvpacked_func(
    q,
    kv,
    dropout_p=0.0,
    softmax_scale=None,
    causal=False,
    window_size=(-1, -1),  # -1 means infinite context window
    softcap=0.0,  # 0.0 means deactivated
    alibi_slopes=None,
    deterministic=False,
    return_attn_probs=False,
):
    """Compute attention of q over a key/value pair that is packed into one tensor `kv`.

    dropout_p should be set to 0.0 during evaluation.
    If K, V are already stacked into 1 tensor, this function will be faster than
    calling flash_attn_func on Q, K, V since the backward pass avoids explicit concatenation
    of the gradients of K, V.
    Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads
    than Q. Note that the number of heads in Q must be divisible by the number of heads in KV.
    For example, if Q has 6 heads and K, V have 2 heads, head 0, 1, 2 of Q will attend to head
    0 of K, V, and head 3, 4, 5 of Q will attend to head 1 of K, V.

    If causal=True, the causal mask is aligned to the bottom right corner of the attention matrix.
    For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is:
        1 1 1 1 0
        1 1 1 1 1
    If seqlen_q = 5 and seqlen_k = 2, the causal mask is:
        0 0
        0 0
        0 0
        1 0
        1 1
    If the row of the mask is all zero, the output will be zero.

    If window_size != (-1, -1), implements sliding window local attention. Query at position i
    will only attend to keys between
    [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive.

    Arguments:
        q: (batch_size, seqlen, nheads, headdim)
        kv: (batch_size, seqlen, 2, nheads_k, headdim)
        dropout_p: float. Dropout probability.
        softmax_scale: float. The scaling of QK^T before applying softmax.
            Default to 1 / sqrt(headdim).
        causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
        window_size: (left, right). If not (-1, -1), implements sliding window local attention.
        softcap: float. Anything > 0 activates softcapping attention.
        alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of
            (-alibi_slope * |i + seqlen_k - seqlen_q - j|)
            is added to the attention score of query i and key j.
        deterministic: bool. Whether to use the deterministic implementation of the backward pass,
            which is slightly slower and uses more memory. The forward pass is always deterministic.
        return_attn_probs: bool. Whether to return the attention probabilities. This option is for
           testing only. The returned probabilities are not guaranteed to be correct
           (they might not have the right scaling).
    Return:
        out: (batch_size, seqlen, nheads, headdim).
        softmax_lse [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen). The
            logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
            normalization factor).
        S_dmask [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen, seqlen).
            The output of softmax (possibly with different scaling). It also encodes the dropout
            pattern (negative means that location was dropped, nonnegative means it was kept).
    """
    # Arguments are forwarded positionally, so their order must match the forward
    # signature of FlashAttnKVPackedFunc (defined elsewhere in this module).
    return FlashAttnKVPackedFunc.apply(
        q,
        kv,
        dropout_p,
        softmax_scale,
        causal,
        window_size,
        softcap,
        alibi_slopes,
        deterministic,
        return_attn_probs,
    )
Tri Dao's avatar
Tri Dao committed
816
817


Tri Dao's avatar
Tri Dao committed
818
def flash_attn_func(
    q,
    k,
    v,
    dropout_p=0.0,
    softmax_scale=None,
    causal=False,
    window_size=(-1, -1),  # -1 means infinite context window
    softcap=0.0,  # 0.0 means deactivated
    alibi_slopes=None,
    deterministic=False,
    return_attn_probs=False,
):
    """Compute attention of q over separate key and value tensors k, v.

    dropout_p should be set to 0.0 during evaluation.
    Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads
    than Q. Note that the number of heads in Q must be divisible by the number of heads in KV.
    For example, if Q has 6 heads and K, V have 2 heads, head 0, 1, 2 of Q will attend to head
    0 of K, V, and head 3, 4, 5 of Q will attend to head 1 of K, V.

    If causal=True, the causal mask is aligned to the bottom right corner of the attention matrix.
    For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is:
        1 1 1 1 0
        1 1 1 1 1
    If seqlen_q = 5 and seqlen_k = 2, the causal mask is:
        0 0
        0 0
        0 0
        1 0
        1 1
    If the row of the mask is all zero, the output will be zero.

    If window_size != (-1, -1), implements sliding window local attention. Query at position i
    will only attend to keys between
    [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive.

    Arguments:
        q: (batch_size, seqlen, nheads, headdim)
        k: (batch_size, seqlen, nheads_k, headdim)
        v: (batch_size, seqlen, nheads_k, headdim)
        dropout_p: float. Dropout probability.
        softmax_scale: float. The scaling of QK^T before applying softmax.
            Default to 1 / sqrt(headdim).
        causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
        window_size: (left, right). If not (-1, -1), implements sliding window local attention.
        softcap: float. Anything > 0 activates softcapping attention.
        alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of
            (-alibi_slope * |i + seqlen_k - seqlen_q - j|)
            is added to the attention score of query i and key j.
        deterministic: bool. Whether to use the deterministic implementation of the backward pass,
            which is slightly slower and uses more memory. The forward pass is always deterministic.
        return_attn_probs: bool. Whether to return the attention probabilities. This option is for
           testing only. The returned probabilities are not guaranteed to be correct
           (they might not have the right scaling).
    Return:
        out: (batch_size, seqlen, nheads, headdim).
        softmax_lse [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen). The
            logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
            normalization factor).
        S_dmask [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen, seqlen).
            The output of softmax (possibly with different scaling). It also encodes the dropout
            pattern (negative means that location was dropped, nonnegative means it was kept).
    """
    # Arguments are forwarded positionally, so their order must match the forward
    # signature of FlashAttnFunc (defined elsewhere in this module).
    return FlashAttnFunc.apply(
        q,
        k,
        v,
        dropout_p,
        softmax_scale,
        causal,
        window_size,
        softcap,
        alibi_slopes,
        deterministic,
        return_attn_probs,
    )
Tri Dao's avatar
Tri Dao committed
892
893


Tri Dao's avatar
Tri Dao committed
894
895
896
897
898
899
900
def flash_attn_varlen_qkvpacked_func(
    qkv,
    cu_seqlens,
    max_seqlen,
    dropout_p=0.0,
    softmax_scale=None,
    causal=False,
    window_size=(-1, -1),  # -1 means infinite context window
    softcap=0.0,  # 0.0 means deactivated
    alibi_slopes=None,
    deterministic=False,
    return_attn_probs=False,
):
    """Compute self-attention over variable-length sequences packed as one qkv tensor.

    dropout_p should be set to 0.0 during evaluation.
    If Q, K, V are already stacked into 1 tensor, this function will be faster than
    calling flash_attn_varlen_func on Q, K, V since the backward pass avoids explicit concatenation
    of the gradients of Q, K, V.
    For multi-query and grouped-query attention (MQA/GQA), please see
    flash_attn_varlen_kvpacked_func and flash_attn_varlen_func.

    If window_size != (-1, -1), implements sliding window local attention. Query at position i
    will only attend to keys between [i - window_size[0], i + window_size[1]] inclusive.

    Arguments:
        qkv: (total, 3, nheads, headdim), where total = total number of tokens in the batch.
        cu_seqlens: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
           of the sequences in the batch, used to index into qkv.
        max_seqlen: int. Maximum sequence length in the batch.
        dropout_p: float. Dropout probability.
        softmax_scale: float. The scaling of QK^T before applying softmax.
            Default to 1 / sqrt(headdim).
        causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
        window_size: (left, right). If not (-1, -1), implements sliding window local attention.
        softcap: float. Anything > 0 activates softcapping attention.
        alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of (-alibi_slope * |i - j|)
            is added to the attention score of query i and key j.
        deterministic: bool. Whether to use the deterministic implementation of the backward pass,
            which is slightly slower and uses more memory. The forward pass is always deterministic.
        return_attn_probs: bool. Whether to return the attention probabilities. This option is for
           testing only. The returned probabilities are not guaranteed to be correct
           (they might not have the right scaling).
    Return:
        out: (total, nheads, headdim).
        softmax_lse [optional, if return_attn_probs=True]: (nheads, total_q_seqlen). The
            logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
            normalization factor).
        S_dmask [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen, seqlen).
            The output of softmax (possibly with different scaling). It also encodes the dropout
            pattern (negative means that location was dropped, nonnegative means it was kept).
    """
    # Arguments are forwarded positionally, so their order must match the forward
    # signature of FlashAttnVarlenQKVPackedFunc (defined elsewhere in this module).
    return FlashAttnVarlenQKVPackedFunc.apply(
        qkv,
        cu_seqlens,
        max_seqlen,
        dropout_p,
        softmax_scale,
        causal,
        window_size,
        softcap,
        alibi_slopes,
        deterministic,
        return_attn_probs,
    )
Tri Dao's avatar
Tri Dao committed
957
958


Tri Dao's avatar
Tri Dao committed
959
960
961
962
963
964
965
966
967
968
def flash_attn_varlen_kvpacked_func(
    q,
    kv,
    cu_seqlens_q,
    cu_seqlens_k,
    max_seqlen_q,
    max_seqlen_k,
    dropout_p=0.0,
    softmax_scale=None,
    causal=False,
    window_size=(-1, -1),  # -1 means infinite context window
    softcap=0.0,  # 0.0 means deactivated
    alibi_slopes=None,
    deterministic=False,
    return_attn_probs=False,
):
    """Compute attention of variable-length q over a packed variable-length kv tensor.

    dropout_p should be set to 0.0 during evaluation.
    If K, V are already stacked into 1 tensor, this function will be faster than
    calling flash_attn_varlen_func on Q, K, V since the backward pass avoids explicit
    concatenation of the gradients of K, V.
    Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads
    than Q. Note that the number of heads in Q must be divisible by the number of heads in KV.
    For example, if Q has 6 heads and K, V have 2 heads, head 0, 1, 2 of Q will attend to head
    0 of K, V, and head 3, 4, 5 of Q will attend to head 1 of K, V.

    If causal=True, the causal mask is aligned to the bottom right corner of the attention matrix.
    For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is:
        1 1 1 1 0
        1 1 1 1 1
    If seqlen_q = 5 and seqlen_k = 2, the causal mask is:
        0 0
        0 0
        0 0
        1 0
        1 1
    If the row of the mask is all zero, the output will be zero.

    If window_size != (-1, -1), implements sliding window local attention. Query at position i
    will only attend to keys between
    [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive.

    Arguments:
        q: (total_q, nheads, headdim), where total_q = total number of query tokens in the batch.
        kv: (total_k, 2, nheads_k, headdim), where total_k = total number of key tokens in the batch.
        cu_seqlens_q: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
           of the sequences in the batch, used to index into q.
        cu_seqlens_k: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
           of the sequences in the batch, used to index into kv.
        max_seqlen_q: int. Maximum query sequence length in the batch.
        max_seqlen_k: int. Maximum key sequence length in the batch.
        dropout_p: float. Dropout probability.
        softmax_scale: float. The scaling of QK^T before applying softmax.
            Default to 1 / sqrt(headdim).
        causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
        window_size: (left, right). If not (-1, -1), implements sliding window local attention.
        softcap: float. Anything > 0 activates softcapping attention.
        alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of
            (-alibi_slope * |i + seqlen_k - seqlen_q - j|)
            is added to the attention score of query i and key j.
        deterministic: bool. Whether to use the deterministic implementation of the backward pass,
            which is slightly slower and uses more memory. The forward pass is always deterministic.
        return_attn_probs: bool. Whether to return the attention probabilities. This option is for
           testing only. The returned probabilities are not guaranteed to be correct
           (they might not have the right scaling).
    Return:
        out: (total_q, nheads, headdim), one row per query token.
        softmax_lse [optional, if return_attn_probs=True]: (nheads, total_q_seqlen). The
            logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
            normalization factor).
        S_dmask [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen, seqlen).
            The output of softmax (possibly with different scaling). It also encodes the dropout
            pattern (negative means that location was dropped, nonnegative means it was kept).
    """
    # Arguments are forwarded positionally, so their order must match the forward
    # signature of FlashAttnVarlenKVPackedFunc (defined elsewhere in this module).
    return FlashAttnVarlenKVPackedFunc.apply(
        q,
        kv,
        cu_seqlens_q,
        cu_seqlens_k,
        max_seqlen_q,
        max_seqlen_k,
        dropout_p,
        softmax_scale,
        causal,
        window_size,
        softcap,
        alibi_slopes,
        deterministic,
        return_attn_probs,
    )
Tri Dao's avatar
Tri Dao committed
1048

1049

Tri Dao's avatar
Tri Dao committed
1050
1051
1052
1053
1054
1055
1056
1057
1058
1059
1060
def flash_attn_varlen_func(
    q,
    k,
    v,
    cu_seqlens_q,
    cu_seqlens_k,
    max_seqlen_q,
    max_seqlen_k,
    dropout_p=0.0,
    softmax_scale=None,
    causal=False,
    window_size=(-1, -1),  # -1 means infinite context window
    softcap=0.0,  # 0.0 means deactivated
    alibi_slopes=None,
    deterministic=False,
    return_attn_probs=False,
    block_table=None,
):
    """Compute attention of variable-length q over separate variable-length k and v tensors.

    dropout_p should be set to 0.0 during evaluation.
    Supports multi-query and grouped-query attention (MQA/GQA) by passing in K, V with fewer heads
    than Q. Note that the number of heads in Q must be divisible by the number of heads in KV.
    For example, if Q has 6 heads and K, V have 2 heads, head 0, 1, 2 of Q will attend to head
    0 of K, V, and head 3, 4, 5 of Q will attend to head 1 of K, V.

    If causal=True, the causal mask is aligned to the bottom right corner of the attention matrix.
    For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is:
        1 1 1 1 0
        1 1 1 1 1
    If seqlen_q = 5 and seqlen_k = 2, the causal mask is:
        0 0
        0 0
        0 0
        1 0
        1 1
    If the row of the mask is all zero, the output will be zero.

    If window_size != (-1, -1), implements sliding window local attention. Query at position i
    will only attend to keys between
    [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive.

    Arguments:
        q: (total_q, nheads, headdim), where total_q = total number of query tokens in the batch.
        k: (total_k, nheads_k, headdim), where total_k = total number of key tokens in the batch.
        v: (total_k, nheads_k, headdim), where total_k = total number of key tokens in the batch.
        cu_seqlens_q: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
           of the sequences in the batch, used to index into q.
        cu_seqlens_k: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
           of the sequences in the batch, used to index into k and v.
        max_seqlen_q: int. Maximum query sequence length in the batch.
        max_seqlen_k: int. Maximum key sequence length in the batch.
        dropout_p: float. Dropout probability.
        softmax_scale: float. The scaling of QK^T before applying softmax.
            Default to 1 / sqrt(headdim).
        causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
        window_size: (left, right). If not (-1, -1), implements sliding window local attention.
        softcap: float. Anything > 0 activates softcapping attention.
        alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of
            (-alibi_slope * |i + seqlen_k - seqlen_q - j|)
            is added to the attention score of query i and key j.
        deterministic: bool. Whether to use the deterministic implementation of the backward pass,
            which is slightly slower and uses more memory. The forward pass is always deterministic.
        return_attn_probs: bool. Whether to return the attention probabilities. This option is for
           testing only. The returned probabilities are not guaranteed to be correct
           (they might not have the right scaling).
        block_table [optional]: forwarded unchanged to FlashAttnVarlenFunc. NOTE(review):
            presumably a paged-KV block table (cf. the block_table argument of
            flash_attn_with_kvcache); confirm its expected shape/dtype and how k, v must be
            laid out when it is given.
    Return:
        out: (total_q, nheads, headdim), one row per query token.
        softmax_lse [optional, if return_attn_probs=True]: (nheads, total_q_seqlen). The
            logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
            normalization factor).
        S_dmask [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen, seqlen).
            The output of softmax (possibly with different scaling). It also encodes the dropout
            pattern (negative means that location was dropped, nonnegative means it was kept).
    """
    # Arguments are forwarded positionally, so their order must match the forward
    # signature of FlashAttnVarlenFunc (defined elsewhere in this module).
    return FlashAttnVarlenFunc.apply(
        q,
        k,
        v,
        cu_seqlens_q,
        cu_seqlens_k,
        max_seqlen_q,
        max_seqlen_k,
        dropout_p,
        softmax_scale,
        causal,
        window_size,
        softcap,
        alibi_slopes,
        deterministic,
        return_attn_probs,
        block_table,
    )
Tri Dao's avatar
Tri Dao committed
1141
1142
1143
1144
1145
1146
1147
1148


def flash_attn_with_kvcache(
    q,
    k_cache,
    v_cache,
    k=None,
    v=None,
    rotary_cos=None,
    rotary_sin=None,
    cache_seqlens: Optional[Union[(int, torch.Tensor)]] = None,
    cache_batch_idx: Optional[torch.Tensor] = None,
    block_table: Optional[torch.Tensor] = None,
    softmax_scale=None,
    causal=False,
    window_size=(-1, -1),  # -1 means infinite context window
    softcap=0.0,  # 0.0 means deactivated
    rotary_interleaved=True,
    alibi_slopes=None,
    num_splits=0,
    return_softmax_lse=False,
):
    """
    If k and v are not None, k_cache and v_cache will be updated *inplace* with the new values from
    k and v. This is useful for incremental decoding: you can pass in the cached keys/values from
    the previous step, and update them with the new keys/values from the current step, and do
    attention with the updated cache, all in 1 kernel.

    If you pass in k / v, you must make sure that the cache is large enough to hold the new values.
    For example, the KV cache could be pre-allocated with the max sequence length, and you can use
    cache_seqlens to keep track of the current sequence lengths of each sequence in the batch.

    Also apply rotary embedding if rotary_cos and rotary_sin are passed in. The key @k will be
    rotated by rotary_cos and rotary_sin at indices cache_seqlens, cache_seqlens + 1, etc.
    If causal or local (i.e., window_size != (-1, -1)), the query @q will be rotated by rotary_cos
    and rotary_sin at indices cache_seqlens, cache_seqlens + 1, etc.
    If not causal and not local, the query @q will be rotated by rotary_cos and rotary_sin at
    indices cache_seqlens only (i.e. we consider all tokens in @q to be at position cache_seqlens).

    See tests/test_flash_attn.py::test_flash_attn_kvcache for examples of how to use this function.

    Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads
    than Q. Note that the number of heads in Q must be divisible by the number of heads in KV.
    For example, if Q has 6 heads and K, V have 2 heads, heads 0, 1, 2 of Q will attend to head
    0 of K, V, and heads 3, 4, 5 of Q will attend to head 1 of K, V.

    If causal=True, the causal mask is aligned to the bottom right corner of the attention matrix.
    For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is:
        1 1 1 1 0
        1 1 1 1 1
    If seqlen_q = 5 and seqlen_k = 2, the causal mask is:
        0 0
        0 0
        0 0
        1 0
        1 1
    If the row of the mask is all zero, the output will be zero.

    If window_size != (-1, -1), implements sliding window local attention. Query at position i
    will only attend to keys between
    [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive.

    Note: Does not support backward pass.

    Arguments:
        q: (batch_size, seqlen, nheads, headdim)
        k_cache: (batch_size_cache, seqlen_cache, nheads_k, headdim) if there's no block_table,
            or (num_blocks, page_block_size, nheads_k, headdim) if there's a block_table (i.e. paged KV cache)
            page_block_size must be a multiple of 256.
        v_cache: (batch_size_cache, seqlen_cache, nheads_k, headdim) if there's no block_table,
            or (num_blocks, page_block_size, nheads_k, headdim) if there's a block_table (i.e. paged KV cache)
        k [optional]: (batch_size, seqlen_new, nheads_k, headdim). If not None, we concatenate
            k with k_cache, starting at the indices specified by cache_seqlens.
        v [optional]: (batch_size, seqlen_new, nheads_k, headdim). Similar to k.
        rotary_cos [optional]: (seqlen_ro, rotary_dim / 2). If not None, we apply rotary embedding
            to k and q. Only applicable if k and v are passed in. rotary_dim must be divisible by 16.
        rotary_sin [optional]: (seqlen_ro, rotary_dim / 2). Similar to rotary_cos.
        cache_seqlens: int, or (batch_size,), dtype torch.int32. The sequence lengths of the
            KV cache. A plain int is broadcast to every one of the batch_size (= q.shape[0])
            sequences.
        block_table [optional]: (batch_size, max_num_blocks_per_seq), dtype torch.int32.
        cache_batch_idx: (batch_size,), dtype torch.int32. The indices used to index into the KV cache.
            If None, we assume that the batch indices are [0, 1, 2, ..., batch_size - 1].
            If the indices are not distinct, and k and v are provided, the values updated in the cache
                 might come from any of the duplicate indices.
        softmax_scale: float. The scaling of QK^T before applying softmax.
            Default to 1 / sqrt(headdim).
        causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
        window_size: (left, right). If not (-1, -1), implements sliding window local attention.
        softcap: float. Anything > 0 activates softcapping attention.
        rotary_interleaved: bool. Only applicable if rotary_cos and rotary_sin are passed in.
            If True, rotary embedding will combine dimensions 0 & 1, 2 & 3, etc. If False,
            rotary embedding will combine dimensions 0 & rotary_dim / 2, 1 & rotary_dim / 2 + 1
            (i.e. GPT-NeoX style).
        alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of
            (-alibi_slope * |i + seqlen_k - seqlen_q - j|)
            is added to the attention score of query i and key j.
        num_splits: int. If > 1, split the key/value into this many chunks along the sequence.
           If num_splits == 1, we don't split the key/value. If num_splits == 0, we use a heuristic
           to automatically determine the number of splits.
           Don't change this unless you know what you are doing.
        return_softmax_lse: bool. Whether to return the logsumexp of the attention scores.

    Return:
        out: (batch_size, seqlen, nheads, headdim).
        softmax_lse [optional, if return_softmax_lse=True]: (batch_size, nheads, seqlen). The
            logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
            normalization factor).
    """
    assert k_cache.stride(-1) == 1, "k_cache must have contiguous last dimension"
    assert v_cache.stride(-1) == 1, "v_cache must have contiguous last dimension"
    maybe_contiguous = lambda x: x.contiguous() if x is not None and x.stride(-1) != 1 else x
    q, k, v = [maybe_contiguous(x) for x in (q, k, v)]
    if softmax_scale is None:
        softmax_scale = q.shape[-1] ** (-0.5)
    if cache_seqlens is not None and isinstance(cache_seqlens, int):
        # cache_seqlens is indexed per query sequence, so size it by the q batch. k_cache.shape[0]
        # is num_blocks under a block_table and batch_size_cache under cache_batch_idx, and
        # neither of those has to equal batch_size.
        cache_seqlens = torch.full(
            (q.shape[0],), cache_seqlens, dtype=torch.int32, device=k_cache.device
        )
    # Normalize caller-supplied tensors too. torch.full above already returns a contiguous tensor.
    cache_seqlens = maybe_contiguous(cache_seqlens)
    cache_batch_idx = maybe_contiguous(cache_batch_idx)
    block_table = maybe_contiguous(block_table)
    out, softmax_lse = flash_attn_cuda.fwd_kvcache(
        q,
        k_cache,
        v_cache,
        k,
        v,
        cache_seqlens,
        rotary_cos,
        rotary_sin,
        cache_batch_idx,
        block_table,
        alibi_slopes,
        None,  # NOTE(review): presumably the optional preallocated `out` buffer; confirm in flash_api.cpp
        softmax_scale,
        causal,
        window_size[0],
        window_size[1],
        softcap,
        rotary_interleaved,
        num_splits,
    )
    return (out, softmax_lse) if return_softmax_lse else out