flash_attn_interface.py 46.5 KB
Newer Older
1
2
# Copyright (c) 2023, Tri Dao.

Tri Dao's avatar
Tri Dao committed
3
4
from typing import Optional, Union

Tri Dao's avatar
Tri Dao committed
5
6
import torch
import torch.nn as nn
Tri Dao's avatar
Tri Dao committed
7

8
9
10
# isort: off
# We need to import the CUDA kernels after importing torch
import flash_attn_2_cuda as flash_attn_cuda
Tri Dao's avatar
Tri Dao committed
11

12
13
# isort: on

Tri Dao's avatar
Tri Dao committed
14

Tri Dao's avatar
Tri Dao committed
15
def _get_block_size_n(device, head_dim, is_dropout, is_causal):
    """Return the key-block size (block N) the CUDA kernel uses for this configuration.

    This table mirrors the tile-size selection in the CUDA kernel and must be kept
    in sync with it.

    Arguments:
        device: device passed to torch.cuda.get_device_capability.
        head_dim: int. Head dimension; must be <= 256.
        is_dropout: bool. Whether dropout is enabled.
        is_causal: bool. Whether causal masking is enabled.
    Return:
        int. The block size along the key dimension (32, 64 or 128).
    """
    # This should match the block sizes in the CUDA kernel
    assert head_dim <= 256
    major, minor = torch.cuda.get_device_capability(device)
    # sm8x with minor > 0 (e.g. sm86, sm89); sm80 (A100) is deliberately excluded.
    is_sm8x = major == 8 and minor > 0
    if head_dim <= 32:
        return 128
    if head_dim <= 64:
        return 64 if is_dropout else 128
    if head_dim <= 96:
        return 64
    if head_dim <= 128:
        if is_sm8x:
            return 64 if (not is_dropout and is_causal) else 32
        return 32 if is_dropout else 64
    if head_dim <= 160:
        return 64 if is_sm8x else 32
    if head_dim <= 256:
        # 160 < head_dim <= 256 (the former 192 / 224 / 256 buckets all use 64).
        return 64
    # Only reachable when asserts are disabled (python -O) and head_dim > 256.
    return None
Tri Dao's avatar
Tri Dao committed
44
45


Tri Dao's avatar
Tri Dao committed
46
def _flash_attn_forward(
    q, k, v, dropout_p, softmax_scale, causal, window_size, softcap, alibi_slopes, return_softmax
):
    """Run the fixed-length FlashAttention forward kernel.

    q, k and v are made unit-stride in their last dimension (copied only when
    necessary) before being handed to ``flash_attn_cuda.fwd``.

    Returns:
        The 8-tuple (out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state)
        produced by the kernel.
    """

    def _unit_stride_last(t):
        # Copy only when the last dimension is not already unit-stride.
        return t if t.stride(-1) == 1 else t.contiguous()

    q = _unit_stride_last(q)
    k = _unit_stride_last(k)
    v = _unit_stride_last(v)
    # The extension takes its arguments positionally; the order below must be preserved.
    (
        attn_out,
        q_used,
        k_used,
        v_used,
        out_pad,
        lse,
        s_dmask,
        rng,
    ) = flash_attn_cuda.fwd(
        q,
        k,
        v,
        None,
        alibi_slopes,
        dropout_p,
        softmax_scale,
        causal,
        window_size[0],
        window_size[1],
        softcap,
        return_softmax,
        None,
    )
    return attn_out, q_used, k_used, v_used, out_pad, lse, s_dmask, rng
Tri Dao's avatar
Tri Dao committed
67
68


Tri Dao's avatar
Tri Dao committed
69
70
71
72
73
74
75
76
77
78
79
def _flash_attn_varlen_forward(
    q,
    k,
    v,
    cu_seqlens_q,
    cu_seqlens_k,
    max_seqlen_q,
    max_seqlen_k,
    dropout_p,
    softmax_scale,
    causal,
    window_size=(-1, -1),
    softcap=0.0,
    alibi_slopes=None,
    return_softmax=False,
    block_table=None,
    leftpad_k=None,
    seqused_k=None,
):
    """Run the variable-length FlashAttention forward kernel.

    q, k and v are made unit-stride in their last dimension (copied only when
    necessary). They are passed to ``flash_attn_cuda.varlen_fwd`` together with:
    - the cumulative sequence-length tensors;
    - the optional ``seqused_k``, ``leftpad_k``, ``block_table`` and
      ``alibi_slopes`` inputs.

    Returns:
        The 8-tuple (out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state)
        produced by the kernel.
    """

    def maybe_contiguous(x):
        # Copy only when the last dimension is not already unit-stride.
        return x.contiguous() if x.stride(-1) != 1 else x

    q, k, v = [maybe_contiguous(x) for x in (q, k, v)]
    # The extension takes its arguments positionally; the order below must be preserved.
    out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = flash_attn_cuda.varlen_fwd(
        q,
        k,
        v,
        None,
        cu_seqlens_q,
        cu_seqlens_k,
        seqused_k,
        leftpad_k,
        block_table,
        alibi_slopes,
        max_seqlen_q,
        max_seqlen_k,
        dropout_p,
        softmax_scale,
        False,
        causal,
        window_size[0],
        window_size[1],
        softcap,
        return_softmax,
        None,
    )
    return out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state
Tri Dao's avatar
Tri Dao committed
116
117


Tri Dao's avatar
Tri Dao committed
118
def _flash_attn_backward(
    dout,
    q,
    k,
    v,
    out,
    softmax_lse,
    dq,
    dk,
    dv,
    dropout_p,
    softmax_scale,
    causal,
    window_size,
    softcap,
    alibi_slopes,
    deterministic,
    rng_state=None,
):
    """Run the fixed-length FlashAttention backward kernel.

    dout, q, k, v and out are made unit-stride in their last dimension (copied
    only when necessary). The gradient buffers dq, dk and dv are passed through
    untouched, since the callers allocate them.

    Returns:
        (dq, dk, dv, softmax_d) as produced by ``flash_attn_cuda.bwd``.
    """

    def _unit_stride_last(t):
        return t if t.stride(-1) == 1 else t.contiguous()

    # Only the inputs need fixing up: dq/dk/dv are freshly allocated by the caller.
    dout, q, k, v, out = map(_unit_stride_last, (dout, q, k, v, out))
    # The extension takes its arguments positionally; the order below must be preserved.
    grads = flash_attn_cuda.bwd(
        dout,
        q,
        k,
        v,
        out,
        softmax_lse,
        dq,
        dk,
        dv,
        alibi_slopes,
        dropout_p,
        softmax_scale,
        causal,
        window_size[0],
        window_size[1],
        softcap,
        deterministic,
        None,
        rng_state,
    )
    grad_q, grad_k, grad_v, lse_grad_d = grads
    return grad_q, grad_k, grad_v, lse_grad_d


Tri Dao's avatar
Tri Dao committed
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
def _flash_attn_varlen_backward(
    dout,
    q,
    k,
    v,
    out,
    softmax_lse,
    dq,
    dk,
    dv,
    cu_seqlens_q,
    cu_seqlens_k,
    max_seqlen_q,
    max_seqlen_k,
    dropout_p,
    softmax_scale,
    causal,
    window_size,
    softcap,
    alibi_slopes,
    deterministic,
    rng_state=None,
):
    """Run the variable-length FlashAttention backward kernel.

    dout, q, k, v and out are made unit-stride in their last dimension (copied
    only when necessary). The gradient buffers dq, dk and dv are passed through
    untouched, since the callers allocate them.

    Returns:
        (dq, dk, dv, softmax_d) as produced by ``flash_attn_cuda.varlen_bwd``.
    """

    def maybe_contiguous(x):
        # Copy only when the last dimension is not already unit-stride.
        return x.contiguous() if x.stride(-1) != 1 else x

    # dq, dk, dv are allocated by us so they should already be contiguous
    dout, q, k, v, out = [maybe_contiguous(x) for x in (dout, q, k, v, out)]
    # The extension takes its arguments positionally; the order below must be preserved.
    (
        dq,
        dk,
        dv,
        softmax_d,
    ) = flash_attn_cuda.varlen_bwd(
        dout,
        q,
        k,
        v,
        out,
        softmax_lse,
        dq,
        dk,
        dv,
        cu_seqlens_q,
        cu_seqlens_k,
        alibi_slopes,
        max_seqlen_q,
        max_seqlen_k,
        dropout_p,
        softmax_scale,
        False,
        causal,
        window_size[0],
        window_size[1],
        softcap,
        deterministic,
        None,
        rng_state,
    )
    return dq, dk, dv, softmax_d
Tri Dao's avatar
Tri Dao committed
229
230


Tri Dao's avatar
Tri Dao committed
231
class FlashAttnQKVPackedFunc(torch.autograd.Function):
    """Autograd wrapper for fixed-length attention on a packed qkv tensor.

    The packed tensor is sliced along dim 2 into q, k and v for the forward
    kernel. The backward pass writes all three gradients into slices of a
    single packed gradient buffer.
    """

    @staticmethod
    def forward(
        ctx,
        qkv,
        dropout_p,
        softmax_scale,
        causal,
        window_size,
        softcap,
        alibi_slopes,
        deterministic,
        return_softmax,
    ):
        # Default scale is 1 / sqrt(headdim).
        scale = qkv.shape[-1] ** (-0.5) if softmax_scale is None else softmax_scale
        (
            attn_out,
            q_fwd,
            k_fwd,
            v_fwd,
            out_padded,
            lse,
            s_dmask,
            rng_state,
        ) = _flash_attn_forward(
            qkv[:, :, 0],
            qkv[:, :, 1],
            qkv[:, :, 2],
            dropout_p,
            scale,
            causal=causal,
            window_size=window_size,
            softcap=softcap,
            alibi_slopes=alibi_slopes,
            # The probability matrix is only requested from the kernel when dropout is active.
            return_softmax=return_softmax and dropout_p > 0,
        )
        # The padded output is what backward reads as `out`.
        ctx.save_for_backward(q_fwd, k_fwd, v_fwd, out_padded, lse, rng_state)
        ctx.deterministic = deterministic
        ctx.alibi_slopes = alibi_slopes
        ctx.softcap = softcap
        ctx.window_size = window_size
        ctx.causal = causal
        ctx.softmax_scale = scale
        ctx.dropout_p = dropout_p
        if return_softmax:
            return attn_out, lse, s_dmask
        return attn_out

    @staticmethod
    def backward(ctx, dout, *args):
        q_fwd, k_fwd, v_fwd, out_saved, lse, rng_state = ctx.saved_tensors
        # Packed gradient buffer: q's leading dims, then a size-3 axis, then q's last two dims.
        packed_shape = q_fwd.shape[:-2] + (3, *q_fwd.shape[-2:])
        grad_qkv = torch.empty(packed_shape, dtype=q_fwd.dtype, device=q_fwd.device)
        _flash_attn_backward(
            dout,
            q_fwd,
            k_fwd,
            v_fwd,
            out_saved,
            lse,
            grad_qkv[:, :, 0],
            grad_qkv[:, :, 1],
            grad_qkv[:, :, 2],
            ctx.dropout_p,
            ctx.softmax_scale,
            ctx.causal,
            ctx.window_size,
            ctx.softcap,
            ctx.alibi_slopes,
            ctx.deterministic,
            rng_state=rng_state,
        )
        # The head dimension may have been padded; trim back to dout's head size.
        grad_qkv = grad_qkv[..., : dout.shape[-1]]
        # One entry per forward input: only qkv receives a gradient.
        return (grad_qkv,) + (None,) * 8
Tri Dao's avatar
Tri Dao committed
295
296
297
298


class FlashAttnVarlenQKVPackedFunc(torch.autograd.Function):
    """Autograd wrapper for variable-length attention on a packed qkv tensor.

    ``qkv`` is sliced along dim 1 into q, k and v, and queries and keys share
    the same ``cu_seqlens`` / ``max_seqlen``. The backward pass writes the three
    gradients into slices of one packed gradient buffer.
    """

    @staticmethod
    def forward(
        ctx,
        qkv,
        cu_seqlens,
        max_seqlen,
        dropout_p,
        softmax_scale,
        causal,
        window_size,
        softcap,
        alibi_slopes,
        deterministic,
        return_softmax,
    ):
        # Default scale is 1 / sqrt(headdim).
        scale = qkv.shape[-1] ** (-0.5) if softmax_scale is None else softmax_scale
        results = _flash_attn_varlen_forward(
            qkv[:, 0],
            qkv[:, 1],
            qkv[:, 2],
            cu_seqlens,
            cu_seqlens,
            max_seqlen,
            max_seqlen,
            dropout_p,
            scale,
            causal=causal,
            window_size=window_size,
            softcap=softcap,
            alibi_slopes=alibi_slopes,
            # The probability matrix is only requested from the kernel when dropout is active.
            return_softmax=return_softmax and dropout_p > 0,
            block_table=None,
        )
        attn_out, q_fwd, k_fwd, v_fwd, out_padded, lse, s_dmask, rng_state = results
        ctx.save_for_backward(q_fwd, k_fwd, v_fwd, out_padded, lse, cu_seqlens, rng_state)
        ctx.deterministic = deterministic
        ctx.alibi_slopes = alibi_slopes
        ctx.softcap = softcap
        ctx.window_size = window_size
        ctx.causal = causal
        ctx.softmax_scale = scale
        ctx.max_seqlen = max_seqlen
        ctx.dropout_p = dropout_p
        if return_softmax:
            return attn_out, lse, s_dmask
        return attn_out

    @staticmethod
    def backward(ctx, dout, *args):
        q_fwd, k_fwd, v_fwd, out_saved, lse, cu_seqlens, rng_state = ctx.saved_tensors
        # Packed gradient buffer: q's leading dims, then a size-3 axis, then q's last two dims.
        packed_shape = q_fwd.shape[:-2] + (3, *q_fwd.shape[-2:])
        grad_qkv = torch.empty(packed_shape, dtype=q_fwd.dtype, device=q_fwd.device)
        _flash_attn_varlen_backward(
            dout,
            q_fwd,
            k_fwd,
            v_fwd,
            out_saved,
            lse,
            grad_qkv[:, 0],
            grad_qkv[:, 1],
            grad_qkv[:, 2],
            cu_seqlens,
            cu_seqlens,
            ctx.max_seqlen,
            ctx.max_seqlen,
            ctx.dropout_p,
            ctx.softmax_scale,
            ctx.causal,
            ctx.window_size,
            ctx.softcap,
            ctx.alibi_slopes,
            ctx.deterministic,
            rng_state=rng_state,
        )
        # The head dimension may have been padded; trim back to dout's head size.
        grad_qkv = grad_qkv[..., : dout.shape[-1]]
        # One entry per forward input: only qkv receives a gradient.
        return (grad_qkv,) + (None,) * 10
Tri Dao's avatar
Tri Dao committed
373
374


Tri Dao's avatar
Tri Dao committed
375
class FlashAttnKVPackedFunc(torch.autograd.Function):
    """Autograd wrapper for fixed-length attention with a separate q and a packed kv tensor.

    ``kv`` is sliced along dim 2 into k and v. The backward pass writes dk and
    dv into slices of one packed gradient buffer and returns dq separately.
    """

    @staticmethod
    def forward(
        ctx,
        q,
        kv,
        dropout_p,
        softmax_scale,
        causal,
        window_size,
        softcap,
        alibi_slopes,
        deterministic,
        return_softmax,
    ):
        # Default scale is 1 / sqrt(headdim).
        scale = q.shape[-1] ** (-0.5) if softmax_scale is None else softmax_scale
        (
            attn_out,
            q_fwd,
            k_fwd,
            v_fwd,
            out_padded,
            lse,
            s_dmask,
            rng_state,
        ) = _flash_attn_forward(
            q,
            kv[:, :, 0],
            kv[:, :, 1],
            dropout_p,
            scale,
            causal=causal,
            window_size=window_size,
            softcap=softcap,
            alibi_slopes=alibi_slopes,
            # The probability matrix is only requested from the kernel when dropout is active.
            return_softmax=return_softmax and dropout_p > 0,
        )
        ctx.save_for_backward(q_fwd, k_fwd, v_fwd, out_padded, lse, rng_state)
        ctx.deterministic = deterministic
        ctx.alibi_slopes = alibi_slopes
        ctx.softcap = softcap
        ctx.window_size = window_size
        ctx.causal = causal
        ctx.softmax_scale = scale
        ctx.dropout_p = dropout_p
        if return_softmax:
            return attn_out, lse, s_dmask
        return attn_out

    @staticmethod
    def backward(ctx, dout, *args):
        q_fwd, k_fwd, v_fwd, out_saved, lse, rng_state = ctx.saved_tensors
        grad_q = torch.empty_like(q_fwd)
        # Packed kv gradient buffer: k's leading dims, then a size-2 axis, then k's last two dims.
        packed_shape = k_fwd.shape[:-2] + (2, *k_fwd.shape[-2:])
        grad_kv = torch.empty(packed_shape, dtype=k_fwd.dtype, device=k_fwd.device)
        _flash_attn_backward(
            dout,
            q_fwd,
            k_fwd,
            v_fwd,
            out_saved,
            lse,
            grad_q,
            grad_kv[:, :, 0],
            grad_kv[:, :, 1],
            ctx.dropout_p,
            ctx.softmax_scale,
            ctx.causal,
            ctx.window_size,
            ctx.softcap,
            ctx.alibi_slopes,
            ctx.deterministic,
            rng_state=rng_state,
        )
        # The head dimension may have been padded; trim back to dout's head size.
        head_dim = dout.shape[-1]
        grad_q = grad_q[..., :head_dim]
        grad_kv = grad_kv[..., :head_dim]
        # One entry per forward input: only q and kv receive gradients.
        return (grad_q, grad_kv) + (None,) * 8
Tri Dao's avatar
Tri Dao committed
442
443


Tri Dao's avatar
Tri Dao committed
444
class FlashAttnVarlenKVPackedFunc(torch.autograd.Function):
    """Autograd wrapper for variable-length attention with a separate q and a packed kv tensor.

    ``kv`` is sliced along dim 1 into k and v. Queries and keys have their own
    cumulative sequence lengths and maximum sequence lengths.
    """

    @staticmethod
    def forward(
        ctx,
        q,
        kv,
        cu_seqlens_q,
        cu_seqlens_k,
        max_seqlen_q,
        max_seqlen_k,
        dropout_p,
        softmax_scale,
        causal,
        window_size,
        softcap,
        alibi_slopes,
        deterministic,
        return_softmax,
    ):
        # Default scale is 1 / sqrt(headdim).
        scale = q.shape[-1] ** (-0.5) if softmax_scale is None else softmax_scale
        results = _flash_attn_varlen_forward(
            q,
            kv[:, 0],
            kv[:, 1],
            cu_seqlens_q,
            cu_seqlens_k,
            max_seqlen_q,
            max_seqlen_k,
            dropout_p,
            scale,
            causal=causal,
            window_size=window_size,
            softcap=softcap,
            alibi_slopes=alibi_slopes,
            # The probability matrix is only requested from the kernel when dropout is active.
            return_softmax=return_softmax and dropout_p > 0,
            block_table=None,
        )
        attn_out, q_fwd, k_fwd, v_fwd, out_padded, lse, s_dmask, rng_state = results
        ctx.save_for_backward(
            q_fwd, k_fwd, v_fwd, out_padded, lse, cu_seqlens_q, cu_seqlens_k, rng_state
        )
        ctx.deterministic = deterministic
        ctx.alibi_slopes = alibi_slopes
        ctx.softcap = softcap
        ctx.window_size = window_size
        ctx.causal = causal
        ctx.softmax_scale = scale
        ctx.max_seqlen_k = max_seqlen_k
        ctx.max_seqlen_q = max_seqlen_q
        ctx.dropout_p = dropout_p
        if return_softmax:
            return attn_out, lse, s_dmask
        return attn_out

    @staticmethod
    def backward(ctx, dout, *args):
        (
            q_fwd,
            k_fwd,
            v_fwd,
            out_saved,
            lse,
            cu_seqlens_q,
            cu_seqlens_k,
            rng_state,
        ) = ctx.saved_tensors
        grad_q = torch.empty_like(q_fwd)
        # Packed kv gradient buffer: k's leading dims, then a size-2 axis, then k's last two dims.
        packed_shape = k_fwd.shape[:-2] + (2, *k_fwd.shape[-2:])
        grad_kv = torch.empty(packed_shape, dtype=k_fwd.dtype, device=k_fwd.device)
        _flash_attn_varlen_backward(
            dout,
            q_fwd,
            k_fwd,
            v_fwd,
            out_saved,
            lse,
            grad_q,
            grad_kv[:, 0],
            grad_kv[:, 1],
            cu_seqlens_q,
            cu_seqlens_k,
            ctx.max_seqlen_q,
            ctx.max_seqlen_k,
            ctx.dropout_p,
            ctx.softmax_scale,
            ctx.causal,
            ctx.window_size,
            ctx.softcap,
            ctx.alibi_slopes,
            ctx.deterministic,
            rng_state=rng_state,
        )
        # The head dimension may have been padded; trim back to dout's head size.
        head_dim = dout.shape[-1]
        grad_q = grad_q[..., :head_dim]
        grad_kv = grad_kv[..., :head_dim]
        # One entry per forward input: only q and kv receive gradients.
        return (grad_q, grad_kv) + (None,) * 12
Tri Dao's avatar
Tri Dao committed
528
529
530
531


class FlashAttnFunc(torch.autograd.Function):
    """Autograd wrapper for fixed-length attention with separate q, k and v tensors."""

    @staticmethod
    def forward(
        ctx,
        q,
        k,
        v,
        dropout_p,
        softmax_scale,
        causal,
        window_size,
        softcap,
        alibi_slopes,
        deterministic,
        return_softmax,
    ):
        # Default scale is 1 / sqrt(headdim).
        scale = q.shape[-1] ** (-0.5) if softmax_scale is None else softmax_scale
        (
            attn_out,
            q_fwd,
            k_fwd,
            v_fwd,
            out_padded,
            lse,
            s_dmask,
            rng_state,
        ) = _flash_attn_forward(
            q,
            k,
            v,
            dropout_p,
            scale,
            causal=causal,
            window_size=window_size,
            softcap=softcap,
            alibi_slopes=alibi_slopes,
            # The probability matrix is only requested from the kernel when dropout is active.
            return_softmax=return_softmax and dropout_p > 0,
        )
        ctx.save_for_backward(q_fwd, k_fwd, v_fwd, out_padded, lse, rng_state)
        ctx.deterministic = deterministic
        ctx.alibi_slopes = alibi_slopes
        ctx.softcap = softcap
        ctx.window_size = window_size
        ctx.causal = causal
        ctx.softmax_scale = scale
        ctx.dropout_p = dropout_p
        if return_softmax:
            return attn_out, lse, s_dmask
        return attn_out

    @staticmethod
    def backward(ctx, dout, *args):
        q_fwd, k_fwd, v_fwd, out_saved, lse, rng_state = ctx.saved_tensors
        grad_q = torch.empty_like(q_fwd)
        grad_k = torch.empty_like(k_fwd)
        grad_v = torch.empty_like(v_fwd)
        _flash_attn_backward(
            dout,
            q_fwd,
            k_fwd,
            v_fwd,
            out_saved,
            lse,
            grad_q,
            grad_k,
            grad_v,
            ctx.dropout_p,
            ctx.softmax_scale,
            ctx.causal,
            ctx.window_size,
            ctx.softcap,
            ctx.alibi_slopes,
            ctx.deterministic,
            rng_state=rng_state,
        )
        # The head dimension may have been padded; trim back to dout's head size.
        head_dim = dout.shape[-1]
        grad_q, grad_k, grad_v = (g[..., :head_dim] for g in (grad_q, grad_k, grad_v))
        # One entry per forward input: only q, k and v receive gradients.
        return (grad_q, grad_k, grad_v) + (None,) * 8
Tri Dao's avatar
Tri Dao committed
597
598


Tri Dao's avatar
Tri Dao committed
599
class FlashAttnVarlenFunc(torch.autograd.Function):
    """Autograd wrapper for variable-length attention with separate q, k and v tensors.

    Optionally forwards a ``block_table`` to the varlen forward kernel.
    """

    @staticmethod
    def forward(
        ctx,
        q,
        k,
        v,
        cu_seqlens_q,
        cu_seqlens_k,
        max_seqlen_q,
        max_seqlen_k,
        dropout_p,
        softmax_scale,
        causal,
        window_size,
        softcap,
        alibi_slopes,
        deterministic,
        return_softmax,
        block_table,
    ):
        # Default scale is 1 / sqrt(headdim).
        scale = q.shape[-1] ** (-0.5) if softmax_scale is None else softmax_scale
        results = _flash_attn_varlen_forward(
            q,
            k,
            v,
            cu_seqlens_q,
            cu_seqlens_k,
            max_seqlen_q,
            max_seqlen_k,
            dropout_p,
            scale,
            causal=causal,
            window_size=window_size,
            softcap=softcap,
            alibi_slopes=alibi_slopes,
            # The probability matrix is only requested from the kernel when dropout is active.
            return_softmax=return_softmax and dropout_p > 0,
            block_table=block_table,
        )
        attn_out, q_fwd, k_fwd, v_fwd, out_padded, lse, s_dmask, rng_state = results
        ctx.save_for_backward(
            q_fwd, k_fwd, v_fwd, out_padded, lse, cu_seqlens_q, cu_seqlens_k, rng_state
        )
        ctx.deterministic = deterministic
        ctx.alibi_slopes = alibi_slopes
        ctx.softcap = softcap
        ctx.window_size = window_size
        ctx.causal = causal
        ctx.softmax_scale = scale
        ctx.max_seqlen_k = max_seqlen_k
        ctx.max_seqlen_q = max_seqlen_q
        ctx.dropout_p = dropout_p
        if return_softmax:
            return attn_out, lse, s_dmask
        return attn_out

    @staticmethod
    def backward(ctx, dout, *args):
        (
            q_fwd,
            k_fwd,
            v_fwd,
            out_saved,
            lse,
            cu_seqlens_q,
            cu_seqlens_k,
            rng_state,
        ) = ctx.saved_tensors
        grad_q = torch.empty_like(q_fwd)
        grad_k = torch.empty_like(k_fwd)
        grad_v = torch.empty_like(v_fwd)
        _flash_attn_varlen_backward(
            dout,
            q_fwd,
            k_fwd,
            v_fwd,
            out_saved,
            lse,
            grad_q,
            grad_k,
            grad_v,
            cu_seqlens_q,
            cu_seqlens_k,
            ctx.max_seqlen_q,
            ctx.max_seqlen_k,
            ctx.dropout_p,
            ctx.softmax_scale,
            ctx.causal,
            ctx.window_size,
            ctx.softcap,
            ctx.alibi_slopes,
            ctx.deterministic,
            rng_state=rng_state,
        )
        # The head dimension may have been padded; trim back to dout's head size.
        head_dim = dout.shape[-1]
        grad_q, grad_k, grad_v = (g[..., :head_dim] for g in (grad_q, grad_k, grad_v))
        # One entry per forward input: only q, k and v receive gradients.
        return (grad_q, grad_k, grad_v) + (None,) * 13
684
685


Tri Dao's avatar
Tri Dao committed
686
def flash_attn_qkvpacked_func(
    qkv,
    dropout_p=0.0,
    softmax_scale=None,
    causal=False,
    window_size=(-1, -1),  # -1 means infinite context window
    softcap=0.0,  # <=0.0 means deactivate
    alibi_slopes=None,
    deterministic=False,
    return_attn_probs=False,
):
    """Attention over a single packed qkv tensor (set dropout_p to 0.0 during evaluation).

    Passing Q, K and V already stacked into one tensor is faster than calling
    flash_attn_func on separate tensors. The backward pass then avoids
    explicitly concatenating the gradients of Q, K and V.

    For multi-query and grouped-query attention (MQA/GQA), use
    flash_attn_kvpacked_func or flash_attn_func instead.

    If window_size != (-1, -1), sliding-window local attention is used: the query
    at position i attends only to keys in [i - window_size[0], i + window_size[1]],
    inclusive.

    Arguments:
        qkv: (batch_size, seqlen, 3, nheads, headdim)
        dropout_p: float. Dropout probability.
        softmax_scale: float. The scaling of QK^T before applying softmax.
            Defaults to 1 / sqrt(headdim).
        causal: bool. Whether to apply a causal attention mask (e.g. for auto-regressive modeling).
        window_size: (left, right). If not (-1, -1), implements sliding window local attention.
        softcap: float. Any value > 0 activates softcapping attention.
        alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of
            (-alibi_slope * |i - j|) is added to the attention score of query i and key j.
        deterministic: bool. Whether to use the deterministic implementation of the
            backward pass, which is slightly slower and uses more memory. The forward
            pass is always deterministic.
        return_attn_probs: bool. Whether to also return the attention probabilities.
            This option is for testing only. The returned probabilities are not
            guaranteed to be correct (they might not have the right scaling).
    Return:
        out: (batch_size, seqlen, nheads, headdim).
        softmax_lse [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen).
            The logsumexp of each row of the matrix QK^T * scaling (i.e. the log of the
            softmax normalization factor).
        S_dmask [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen, seqlen).
            The output of softmax (possibly with different scaling). It also encodes the
            dropout pattern (negative means that location was dropped, nonnegative means
            it was kept).
    """
    # autograd.Function.apply only takes positional arguments, in forward()'s order.
    forward_args = (
        qkv,
        dropout_p,
        softmax_scale,
        causal,
        window_size,
        softcap,
        alibi_slopes,
        deterministic,
        return_attn_probs,
    )
    return FlashAttnQKVPackedFunc.apply(*forward_args)
Tri Dao's avatar
Tri Dao committed
742
743


Tri Dao's avatar
Tri Dao committed
744
def flash_attn_kvpacked_func(
    q,
    kv,
    dropout_p=0.0,
    softmax_scale=None,
    causal=False,
    window_size=(-1, -1),  # -1 means infinite context window
    softcap=0.0,  # 0.0 means deactivated
    alibi_slopes=None,
    deterministic=False,
    return_attn_probs=False,
):
    """Flash attention over a query tensor and a packed key/value tensor.

    dropout_p should be set to 0.0 during evaluation.

    When K and V are already stacked into a single tensor, this entry point is faster than
    calling flash_attn_func on separate Q, K, V, because the backward pass does not need to
    explicitly concatenate the gradients of K and V.

    Multi-query and grouped-query attention (MQA/GQA) are supported by passing KV with fewer
    heads than Q; the number of heads in Q must be divisible by the number of heads in KV.
    For example, if Q has 6 heads and K, V have 2 heads, heads 0, 1, 2 of Q attend to head
    0 of K, V, and heads 3, 4, 5 of Q attend to head 1 of K, V.

    If causal=True, the causal mask is aligned to the bottom-right corner of the attention
    matrix. For example, with seqlen_q = 2 and seqlen_k = 5 the mask (1 = keep, 0 = masked out) is:
        1 1 1 1 0
        1 1 1 1 1
    With seqlen_q = 5 and seqlen_k = 2 the mask is:
        0 0
        0 0
        0 0
        1 0
        1 1
    A query row whose mask is entirely zero produces an all-zero output.

    If window_size != (-1, -1), sliding-window local attention is used: the query at
    position i only attends to keys in the inclusive range
    [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]].

    Arguments:
        q: (batch_size, seqlen, nheads, headdim)
        kv: (batch_size, seqlen, 2, nheads_k, headdim)
        dropout_p: float. Dropout probability.
        softmax_scale: float. The scaling of QK^T before applying softmax.
            Default to 1 / sqrt(headdim).
        causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
        window_size: (left, right). If not (-1, -1), implements sliding window local attention.
        softcap: float. Anything > 0 activates softcapping attention.
        alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of
            (-alibi_slope * |i + seqlen_k - seqlen_q - j|)
            is added to the attention score of query i and key j.
        deterministic: bool. Whether to use the deterministic implementation of the backward pass,
            which is slightly slower and uses more memory. The forward pass is always deterministic.
        return_attn_probs: bool. Whether to return the attention probabilities. This option is for
           testing only. The returned probabilities are not guaranteed to be correct
           (they might not have the right scaling).
    Return:
        out: (batch_size, seqlen, nheads, headdim).
        softmax_lse [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen). The
            logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
            normalization factor).
        S_dmask [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen, seqlen).
            The output of softmax (possibly with different scaling). It also encodes the dropout
            pattern (negative means that location was dropped, nonnegative means it was kept).
    """
    # Arguments are forwarded positionally; this order must match what
    # FlashAttnKVPackedFunc (defined elsewhere in this file) expects.
    forward_args = (
        q,
        kv,
        dropout_p,
        softmax_scale,
        causal,
        window_size,
        softcap,
        alibi_slopes,
        deterministic,
        return_attn_probs,
    )
    return FlashAttnKVPackedFunc.apply(*forward_args)
Tri Dao's avatar
Tri Dao committed
819
820


Tri Dao's avatar
Tri Dao committed
821
def flash_attn_func(
    q,
    k,
    v,
    dropout_p=0.0,
    softmax_scale=None,
    causal=False,
    window_size=(-1, -1),  # -1 means infinite context window
    softcap=0.0,  # 0.0 means deactivated
    alibi_slopes=None,
    deterministic=False,
    return_attn_probs=False,
):
    """dropout_p should be set to 0.0 during evaluation
    Supports multi-query and grouped-query attention (MQA/GQA) by passing in K, V with fewer heads
    than Q. Note that the number of heads in Q must be divisible by the number of heads in K, V.
    For example, if Q has 6 heads and K, V have 2 heads, heads 0, 1, 2 of Q will attend to head
    0 of K, V, and heads 3, 4, 5 of Q will attend to head 1 of K, V.

    If causal=True, the causal mask is aligned to the bottom right corner of the attention matrix.
    For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is:
        1 1 1 1 0
        1 1 1 1 1
    If seqlen_q = 5 and seqlen_k = 2, the causal mask is:
        0 0
        0 0
        0 0
        1 0
        1 1
    If the row of the mask is all zero, the output will be zero.

    If window_size != (-1, -1), implements sliding window local attention. Query at position i
    will only attend to keys between
    [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive.

    Arguments:
        q: (batch_size, seqlen, nheads, headdim)
        k: (batch_size, seqlen, nheads_k, headdim)
        v: (batch_size, seqlen, nheads_k, headdim)
        dropout_p: float. Dropout probability.
        softmax_scale: float. The scaling of QK^T before applying softmax.
            Default to 1 / sqrt(headdim).
        causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
        window_size: (left, right). If not (-1, -1), implements sliding window local attention.
        softcap: float. Anything > 0 activates softcapping attention.
        alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of
            (-alibi_slope * |i + seqlen_k - seqlen_q - j|)
            is added to the attention score of query i and key j.
        deterministic: bool. Whether to use the deterministic implementation of the backward pass,
            which is slightly slower and uses more memory. The forward pass is always deterministic.
        return_attn_probs: bool. Whether to return the attention probabilities. This option is for
           testing only. The returned probabilities are not guaranteed to be correct
           (they might not have the right scaling).
    Return:
        out: (batch_size, seqlen, nheads, headdim).
        softmax_lse [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen). The
            logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
            normalization factor).
        S_dmask [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen, seqlen).
            The output of softmax (possibly with different scaling). It also encodes the dropout
            pattern (negative means that location was dropped, nonnegative means it was kept).
    """
    # Arguments are forwarded positionally; this order must match what
    # FlashAttnFunc (defined elsewhere in this file) expects.
    return FlashAttnFunc.apply(
        q,
        k,
        v,
        dropout_p,
        softmax_scale,
        causal,
        window_size,
        softcap,
        alibi_slopes,
        deterministic,
        return_attn_probs,
    )
Tri Dao's avatar
Tri Dao committed
895
896


Tri Dao's avatar
Tri Dao committed
897
898
899
900
901
902
903
def flash_attn_varlen_qkvpacked_func(
    qkv,
    cu_seqlens,
    max_seqlen,
    dropout_p=0.0,
    softmax_scale=None,
    causal=False,
    window_size=(-1, -1),  # -1 means infinite context window
    softcap=0.0,  # 0.0 means deactivated
    alibi_slopes=None,
    deterministic=False,
    return_attn_probs=False,
):
    """Variable-length flash attention over a single packed Q/K/V tensor.

    dropout_p should be set to 0.0 during evaluation.

    When Q, K and V are already stacked into one tensor, this entry point is faster than
    calling flash_attn_varlen_func on separate Q, K, V, because the backward pass does not
    need to explicitly concatenate the gradients of Q, K and V.

    For multi-query and grouped-query attention (MQA/GQA), use
    flash_attn_varlen_kvpacked_func or flash_attn_varlen_func instead.

    If window_size != (-1, -1), sliding-window local attention is used: the query at
    position i only attends to keys in the inclusive range
    [i - window_size[0], i + window_size[1]].

    Arguments:
        qkv: (total, 3, nheads, headdim), where total = total number of tokens in the batch.
        cu_seqlens: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
           of the sequences in the batch, used to index into qkv.
        max_seqlen: int. Maximum sequence length in the batch.
        dropout_p: float. Dropout probability.
        softmax_scale: float. The scaling of QK^T before applying softmax.
            Default to 1 / sqrt(headdim).
        causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
        window_size: (left, right). If not (-1, -1), implements sliding window local attention.
        softcap: float. Anything > 0 activates softcapping attention.
        alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of (-alibi_slope * |i - j|)
            is added to the attention score of query i and key j.
        deterministic: bool. Whether to use the deterministic implementation of the backward pass,
            which is slightly slower and uses more memory. The forward pass is always deterministic.
        return_attn_probs: bool. Whether to return the attention probabilities. This option is for
           testing only. The returned probabilities are not guaranteed to be correct
           (they might not have the right scaling).
    Return:
        out: (total, nheads, headdim).
        softmax_lse [optional, if return_attn_probs=True]: (nheads, total_q_seqlen). The
            logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
            normalization factor).
        S_dmask [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen, seqlen).
            The output of softmax (possibly with different scaling). It also encodes the dropout
            pattern (negative means that location was dropped, nonnegative means it was kept).
    """
    # Arguments are forwarded positionally; this order must match what
    # FlashAttnVarlenQKVPackedFunc (defined elsewhere in this file) expects.
    packed_args = (
        qkv,
        cu_seqlens,
        max_seqlen,
        dropout_p,
        softmax_scale,
        causal,
        window_size,
        softcap,
        alibi_slopes,
        deterministic,
        return_attn_probs,
    )
    return FlashAttnVarlenQKVPackedFunc.apply(*packed_args)
Tri Dao's avatar
Tri Dao committed
960
961


Tri Dao's avatar
Tri Dao committed
962
963
964
965
966
967
968
969
970
971
def flash_attn_varlen_kvpacked_func(
    q,
    kv,
    cu_seqlens_q,
    cu_seqlens_k,
    max_seqlen_q,
    max_seqlen_k,
    dropout_p=0.0,
    softmax_scale=None,
    causal=False,
    window_size=(-1, -1),  # -1 means infinite context window
    softcap=0.0,  # 0.0 means deactivated
    alibi_slopes=None,
    deterministic=False,
    return_attn_probs=False,
):
    """dropout_p should be set to 0.0 during evaluation
    If K, V are already stacked into 1 tensor, this function will be faster than
    calling flash_attn_varlen_func on Q, K, V since the backward pass avoids explicit concatenation
    of the gradients of K, V.
    Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads
    than Q. Note that the number of heads in Q must be divisible by the number of heads in KV.
    For example, if Q has 6 heads and K, V have 2 heads, heads 0, 1, 2 of Q will attend to head
    0 of K, V, and heads 3, 4, 5 of Q will attend to head 1 of K, V.

    If causal=True, the causal mask is aligned to the bottom right corner of the attention matrix.
    For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is:
        1 1 1 1 0
        1 1 1 1 1
    If seqlen_q = 5 and seqlen_k = 2, the causal mask is:
        0 0
        0 0
        0 0
        1 0
        1 1
    If the row of the mask is all zero, the output will be zero.

    If window_size != (-1, -1), implements sliding window local attention. Query at position i
    will only attend to keys between
    [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive.

    Arguments:
        q: (total_q, nheads, headdim), where total_q = total number of query tokens in the batch.
        kv: (total_k, 2, nheads_k, headdim), where total_k = total number of key tokens in the batch.
        cu_seqlens_q: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
           of the sequences in the batch, used to index into q.
        cu_seqlens_k: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
           of the sequences in the batch, used to index into kv.
        max_seqlen_q: int. Maximum query sequence length in the batch.
        max_seqlen_k: int. Maximum key sequence length in the batch.
        dropout_p: float. Dropout probability.
        softmax_scale: float. The scaling of QK^T before applying softmax.
            Default to 1 / sqrt(headdim).
        causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
        window_size: (left, right). If not (-1, -1), implements sliding window local attention.
        softcap: float. Anything > 0 activates softcapping attention.
        alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of
            (-alibi_slope * |i + seqlen_k - seqlen_q - j|)
            is added to the attention score of query i and key j.
        deterministic: bool. Whether to use the deterministic implementation of the backward pass,
            which is slightly slower and uses more memory. The forward pass is always deterministic.
        return_attn_probs: bool. Whether to return the attention probabilities. This option is for
           testing only. The returned probabilities are not guaranteed to be correct
           (they might not have the right scaling).
    Return:
        out: (total_q, nheads, headdim), laid out like q.
        softmax_lse [optional, if return_attn_probs=True]: (nheads, total_q_seqlen). The
            logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
            normalization factor).
        S_dmask [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen, seqlen).
            The output of softmax (possibly with different scaling). It also encodes the dropout
            pattern (negative means that location was dropped, nonnegative means it was kept).
    """
    # Arguments are forwarded positionally; this order must match what
    # FlashAttnVarlenKVPackedFunc (defined elsewhere in this file) expects.
    return FlashAttnVarlenKVPackedFunc.apply(
        q,
        kv,
        cu_seqlens_q,
        cu_seqlens_k,
        max_seqlen_q,
        max_seqlen_k,
        dropout_p,
        softmax_scale,
        causal,
        window_size,
        softcap,
        alibi_slopes,
        deterministic,
        return_attn_probs,
    )
Tri Dao's avatar
Tri Dao committed
1051

1052

Tri Dao's avatar
Tri Dao committed
1053
1054
1055
1056
1057
1058
1059
1060
1061
1062
1063
def flash_attn_varlen_func(
    q,
    k,
    v,
    cu_seqlens_q,
    cu_seqlens_k,
    max_seqlen_q,
    max_seqlen_k,
    dropout_p=0.0,
    softmax_scale=None,
    causal=False,
    window_size=(-1, -1),  # -1 means infinite context window
    softcap=0.0,  # 0.0 means deactivated
    alibi_slopes=None,
    deterministic=False,
    return_attn_probs=False,
    block_table=None,
):
    """dropout_p should be set to 0.0 during evaluation
    Supports multi-query and grouped-query attention (MQA/GQA) by passing in K, V with fewer heads
    than Q. Note that the number of heads in Q must be divisible by the number of heads in K, V.
    For example, if Q has 6 heads and K, V have 2 heads, heads 0, 1, 2 of Q will attend to head
    0 of K, V, and heads 3, 4, 5 of Q will attend to head 1 of K, V.

    If causal=True, the causal mask is aligned to the bottom right corner of the attention matrix.
    For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is:
        1 1 1 1 0
        1 1 1 1 1
    If seqlen_q = 5 and seqlen_k = 2, the causal mask is:
        0 0
        0 0
        0 0
        1 0
        1 1
    If the row of the mask is all zero, the output will be zero.

    If window_size != (-1, -1), implements sliding window local attention. Query at position i
    will only attend to keys between
    [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive.

    Arguments:
        q: (total_q, nheads, headdim), where total_q = total number of query tokens in the batch.
        k: (total_k, nheads_k, headdim), where total_k = total number of key tokens in the batch.
        v: (total_k, nheads_k, headdim), where total_k = total number of key tokens in the batch.
        cu_seqlens_q: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
           of the sequences in the batch, used to index into q.
        cu_seqlens_k: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
           of the sequences in the batch, used to index into k and v.
        max_seqlen_q: int. Maximum query sequence length in the batch.
        max_seqlen_k: int. Maximum key sequence length in the batch.
        dropout_p: float. Dropout probability.
        softmax_scale: float. The scaling of QK^T before applying softmax.
            Default to 1 / sqrt(headdim).
        causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
        window_size: (left, right). If not (-1, -1), implements sliding window local attention.
        softcap: float. Anything > 0 activates softcapping attention.
        alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of
            (-alibi_slope * |i + seqlen_k - seqlen_q - j|)
            is added to the attention score of query i and key j.
        deterministic: bool. Whether to use the deterministic implementation of the backward pass,
            which is slightly slower and uses more memory. The forward pass is always deterministic.
        return_attn_probs: bool. Whether to return the attention probabilities. This option is for
           testing only. The returned probabilities are not guaranteed to be correct
           (they might not have the right scaling).
        block_table: optional, default None. Forwarded unchanged as the last argument to
           FlashAttnVarlenFunc. NOTE(review): presumably the block/page table for a paged
           KV cache (e.g. (batch_size, max_num_blocks_per_seq), dtype torch.int32) -- confirm
           the expected shape and dtype against FlashAttnVarlenFunc.
    Return:
        out: (total_q, nheads, headdim), laid out like q.
        softmax_lse [optional, if return_attn_probs=True]: (nheads, total_q_seqlen). The
            logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
            normalization factor).
        S_dmask [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen, seqlen).
            The output of softmax (possibly with different scaling). It also encodes the dropout
            pattern (negative means that location was dropped, nonnegative means it was kept).
    """
    # Arguments are forwarded positionally; this order must match what
    # FlashAttnVarlenFunc (defined elsewhere in this file) expects.
    return FlashAttnVarlenFunc.apply(
        q,
        k,
        v,
        cu_seqlens_q,
        cu_seqlens_k,
        max_seqlen_q,
        max_seqlen_k,
        dropout_p,
        softmax_scale,
        causal,
        window_size,
        softcap,
        alibi_slopes,
        deterministic,
        return_attn_probs,
        block_table,
    )
Tri Dao's avatar
Tri Dao committed
1144
1145
1146
1147
1148
1149
1150
1151


def flash_attn_with_kvcache(
    q,
    k_cache,
    v_cache,
    k=None,
    v=None,
    rotary_cos=None,
    rotary_sin=None,
    cache_seqlens: Optional[Union[(int, torch.Tensor)]] = None,
    cache_batch_idx: Optional[torch.Tensor] = None,
    cache_leftpad: Optional[torch.Tensor] = None,
    block_table: Optional[torch.Tensor] = None,
    softmax_scale=None,
    causal=False,
    window_size=(-1, -1),  # -1 means infinite context window
    softcap=0.0,  # 0.0 means deactivated
    rotary_interleaved=True,
    alibi_slopes=None,
    num_splits=0,
    return_softmax_lse=False,
):
    """
    If k and v are not None, k_cache and v_cache will be updated *inplace* with the new values from
    k and v. This is useful for incremental decoding: you can pass in the cached keys/values from
    the previous step, and update them with the new keys/values from the current step, and do
    attention with the updated cache, all in 1 kernel.

    If you pass in k / v, you must make sure that the cache is large enough to hold the new values.
    For example, the KV cache could be pre-allocated with the max sequence length, and you can use
    cache_seqlens to keep track of the current sequence lengths of each sequence in the batch.

    Also apply rotary embedding if rotary_cos and rotary_sin are passed in. The key @k will be
    rotated by rotary_cos and rotary_sin at indices cache_seqlens, cache_seqlens + 1, etc.
    If causal or local (i.e., window_size != (-1, -1)), the query @q will be rotated by rotary_cos
    and rotary_sin at indices cache_seqlens, cache_seqlens + 1, etc.
    If not causal and not local, the query @q will be rotated by rotary_cos and rotary_sin at
    indices cache_seqlens only (i.e. we consider all tokens in @q to be at position cache_seqlens).

    See tests/test_flash_attn.py::test_flash_attn_kvcache for examples of how to use this function.

    Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads
    than Q. Note that the number of heads in Q must be divisible by the number of heads in KV.
    For example, if Q has 6 heads and K, V have 2 heads, heads 0, 1, 2 of Q will attend to head
    0 of K, V, and heads 3, 4, 5 of Q will attend to head 1 of K, V.

    If causal=True, the causal mask is aligned to the bottom right corner of the attention matrix.
    For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is:
        1 1 1 1 0
        1 1 1 1 1
    If seqlen_q = 5 and seqlen_k = 2, the causal mask is:
        0 0
        0 0
        0 0
        1 0
        1 1
    If the row of the mask is all zero, the output will be zero.

    If window_size != (-1, -1), implements sliding window local attention. Query at position i
    will only attend to keys between
    [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive.

    Note: Does not support backward pass.

    Arguments:
        q: (batch_size, seqlen, nheads, headdim)
        k_cache: (batch_size_cache, seqlen_cache, nheads_k, headdim) if there's no block_table,
            or (num_blocks, page_block_size, nheads_k, headdim) if there's a block_table (i.e. paged KV cache)
            page_block_size must be a multiple of 256.
        v_cache: (batch_size_cache, seqlen_cache, nheads_k, headdim) if there's no block_table,
            or (num_blocks, page_block_size, nheads_k, headdim) if there's a block_table (i.e. paged KV cache)
        k [optional]: (batch_size, seqlen_new, nheads_k, headdim). If not None, we concatenate
            k with k_cache, starting at the indices specified by cache_seqlens.
        v [optional]: (batch_size, seqlen_new, nheads_k, headdim). Similar to k.
        rotary_cos [optional]: (seqlen_ro, rotary_dim / 2). If not None, we apply rotary embedding
            to k and q. Only applicable if k and v are passed in. rotary_dim must be divisible by 16.
        rotary_sin [optional]: (seqlen_ro, rotary_dim / 2). Similar to rotary_cos.
        cache_seqlens: int, or (batch_size,), dtype torch.int32. The sequence lengths of the
            KV cache. An int is broadcast to a (batch_size,) int32 tensor, where batch_size is
            taken from q (k_cache.shape[0] is num_blocks for a paged cache and batch_size_cache
            when cache_batch_idx is used, so it cannot serve as the batch size).
        cache_batch_idx: (batch_size,), dtype torch.int32. The indices used to index into the KV cache.
            If None, we assume that the batch indices are [0, 1, 2, ..., batch_size - 1].
            If the indices are not distinct, and k and v are provided, the values updated in the cache
                 might come from any of the duplicate indices.
        cache_leftpad: (batch_size,), dtype torch.int32. The index at which the KV cache starts.
            If None, assume 0.
        block_table [optional]: (batch_size, max_num_blocks_per_seq), dtype torch.int32.
        softmax_scale: float. The scaling of QK^T before applying softmax.
            Default to 1 / sqrt(headdim).
        causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
        window_size: (left, right). If not (-1, -1), implements sliding window local attention.
        softcap: float. Anything > 0 activates softcapping attention.
        rotary_interleaved: bool. Only applicable if rotary_cos and rotary_sin are passed in.
            If True, rotary embedding will combine dimensions 0 & 1, 2 & 3, etc. If False,
            rotary embedding will combine dimensions 0 & rotary_dim / 2, 1 & rotary_dim / 2 + 1
            (i.e. GPT-NeoX style).
        alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of
            (-alibi_slope * |i + seqlen_k - seqlen_q - j|)
            is added to the attention score of query i and key j.
        num_splits: int. If > 1, split the key/value into this many chunks along the sequence.
           If num_splits == 1, we don't split the key/value. If num_splits == 0, we use a heuristic
           to automatically determine the number of splits.
           Don't change this unless you know what you are doing.
        return_softmax_lse: bool. Whether to return the logsumexp of the attention scores.

    Return:
        out: (batch_size, seqlen, nheads, headdim).
        softmax_lse [optional, if return_softmax_lse=True]: (batch_size, nheads, seqlen). The
            logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
            normalization factor).
    """
    # The caches may be updated in place, so they cannot be silently copied to fix their layout.
    assert k_cache.stride(-1) == 1, "k_cache must have contiguous last dimension"
    assert v_cache.stride(-1) == 1, "v_cache must have contiguous last dimension"

    def maybe_contiguous(x):
        # Copy only when the last dimension is strided; None passes through unchanged.
        return x.contiguous() if x is not None and x.stride(-1) != 1 else x

    q, k, v = [maybe_contiguous(x) for x in (q, k, v)]
    if softmax_scale is None:
        softmax_scale = q.shape[-1] ** (-0.5)

    if isinstance(cache_seqlens, int):
        # cache_seqlens is documented as (batch_size,), and batch_size is q.shape[0]. k_cache.shape[0]
        # is num_blocks for a paged cache and batch_size_cache when cache_batch_idx is given.
        cache_seqlens = torch.full(
            (q.shape[0],), cache_seqlens, dtype=torch.int32, device=k_cache.device
        )
    # Applied to caller-supplied tensors as well, not only to the one built above.
    cache_seqlens = maybe_contiguous(cache_seqlens)
    cache_batch_idx = maybe_contiguous(cache_batch_idx)
    cache_leftpad = maybe_contiguous(cache_leftpad)
    block_table = maybe_contiguous(block_table)

    out, softmax_lse = flash_attn_cuda.fwd_kvcache(
        q,
        k_cache,
        v_cache,
        k,
        v,
        cache_seqlens,
        rotary_cos,
        rotary_sin,
        cache_batch_idx,
        cache_leftpad,
        block_table,
        alibi_slopes,
        # NOTE(review): presumably an optional preallocated output buffer (None lets the kernel
        # allocate one) — confirm against the fwd_kvcache binding signature.
        None,
        softmax_scale,
        causal,
        window_size[0],
        window_size[1],
        softcap,
        rotary_interleaved,
        num_splits,
    )
    return (out, softmax_lse) if return_softmax_lse else out