flash_attn_interface.py 46.5 KB
Newer Older
1
2
# Copyright (c) 2023, Tri Dao.

Tri Dao's avatar
Tri Dao committed
3
4
from typing import Optional, Union

Tri Dao's avatar
Tri Dao committed
5
6
import torch
import torch.nn as nn
Tri Dao's avatar
Tri Dao committed
7

8
9
10
# isort: off
# We need to import the CUDA kernels after importing torch
import flash_attn_2_cuda as flash_attn_cuda
Tri Dao's avatar
Tri Dao committed
11

12
13
# isort: on

Tri Dao's avatar
Tri Dao committed
14

Tri Dao's avatar
Tri Dao committed
15
def _get_block_size_n(device, head_dim, is_dropout, is_causal):
Tri Dao's avatar
Tri Dao committed
16
17
18
19
20
21
22
    # This should match the block sizes in the CUDA kernel
    assert head_dim <= 256
    major, minor = torch.cuda.get_device_capability(device)
    is_sm8x = major == 8 and minor > 0  # Only include sm86 and sm89, exclude sm80 (A100)
    is_sm80 = major == 8 and minor == 0
    is_sm90 = major == 9 and minor == 0
    if head_dim <= 32:
Tri Dao's avatar
Tri Dao committed
23
        return 128
Tri Dao's avatar
Tri Dao committed
24
    if head_dim <= 64:
Tri Dao's avatar
Tri Dao committed
25
        return 128 if not is_dropout else 64
Tri Dao's avatar
Tri Dao committed
26
    elif head_dim <= 96:
Tri Dao's avatar
Tri Dao committed
27
        return 64
Tri Dao's avatar
Tri Dao committed
28
29
    elif head_dim <= 128:
        if is_sm8x:
Tri Dao's avatar
Tri Dao committed
30
            return 64 if (not is_dropout and is_causal) else 32
Tri Dao's avatar
Tri Dao committed
31
        else:
Tri Dao's avatar
Tri Dao committed
32
            return 64 if not is_dropout else 32
Tri Dao's avatar
Tri Dao committed
33
34
    elif head_dim <= 160:
        if is_sm8x:
Tri Dao's avatar
Tri Dao committed
35
            return 64
Tri Dao's avatar
Tri Dao committed
36
        else:
Tri Dao's avatar
Tri Dao committed
37
            return 32
Tri Dao's avatar
Tri Dao committed
38
    elif head_dim <= 192:
Tri Dao's avatar
Tri Dao committed
39
        return 64
Tri Dao's avatar
Tri Dao committed
40
    elif head_dim <= 224:
Tri Dao's avatar
Tri Dao committed
41
        return 64
Tri Dao's avatar
Tri Dao committed
42
    elif head_dim <= 256:
Tri Dao's avatar
Tri Dao committed
43
        return 64
Tri Dao's avatar
Tri Dao committed
44
45


Tri Dao's avatar
Tri Dao committed
46
def _flash_attn_forward(
    q, k, v, dropout_p, softmax_scale, causal, window_size, softcap, alibi_slopes, return_softmax
):
    """Run the fixed-length CUDA forward kernel ``flash_attn_cuda.fwd``.

    ``q``, ``k`` and ``v`` are first copied to be unit-stride in their last dimension if they
    are not already. ``window_size`` is a ``(left, right)`` pair forwarded element-wise. The
    literal ``None`` positional slots are passed through unchanged; their meaning is defined
    by the C++ binding, which is not visible here.

    Returns the 8-tuple produced by the kernel:
    ``(out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state)``.
    """

    def _last_dim_contiguous(t):
        # Only copy when the last dimension is strided; the kernel presumably needs
        # stride(-1) == 1.
        return t if t.stride(-1) == 1 else t.contiguous()

    q = _last_dim_contiguous(q)
    k = _last_dim_contiguous(k)
    v = _last_dim_contiguous(v)
    left, right = window_size[0], window_size[1]
    results = flash_attn_cuda.fwd(
        q,
        k,
        v,
        None,
        alibi_slopes,
        dropout_p,
        softmax_scale,
        causal,
        left,
        right,
        softcap,
        return_softmax,
        None,
    )
    out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = results
    return out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state
Tri Dao's avatar
Tri Dao committed
67
68


Tri Dao's avatar
Tri Dao committed
69
70
71
72
73
74
75
76
77
78
79
def _flash_attn_varlen_forward(
    q,
    k,
    v,
    cu_seqlens_q,
    cu_seqlens_k,
    max_seqlen_q,
    max_seqlen_k,
    dropout_p,
    softmax_scale,
    causal,
    window_size,
    softcap,
    alibi_slopes,
    return_softmax,
    block_table=None,
    leftpad_k=None,
):
    """Run the variable-length CUDA forward kernel ``flash_attn_cuda.varlen_fwd``.

    ``q``, ``k`` and ``v`` are first copied to be unit-stride in their last dimension if they
    are not already. ``leftpad_k`` and ``block_table`` are optional and forwarded as given.
    ``window_size`` is a ``(left, right)`` pair. The literal ``None`` / ``False`` positional
    slots are passed through unchanged; their meaning is defined by the C++ binding, which is
    not visible here.

    Returns the 8-tuple produced by the kernel:
    ``(out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state)``.
    """

    def _last_dim_contiguous(t):
        # Only copy when the last dimension is strided; the kernel presumably needs
        # stride(-1) == 1.
        return t if t.stride(-1) == 1 else t.contiguous()

    q, k, v = map(_last_dim_contiguous, (q, k, v))
    left, right = window_size[0], window_size[1]
    results = flash_attn_cuda.varlen_fwd(
        q,
        k,
        v,
        None,
        cu_seqlens_q,
        cu_seqlens_k,
        None,
        leftpad_k,
        block_table,
        alibi_slopes,
        max_seqlen_q,
        max_seqlen_k,
        dropout_p,
        softmax_scale,
        False,
        causal,
        left,
        right,
        softcap,
        return_softmax,
        None,
    )
    out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = results
    return out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state
Tri Dao's avatar
Tri Dao committed
115
116


Tri Dao's avatar
Tri Dao committed
117
def _flash_attn_backward(
    dout,
    q,
    k,
    v,
    out,
    softmax_lse,
    dq,
    dk,
    dv,
    dropout_p,
    softmax_scale,
    causal,
    window_size,
    softcap,
    alibi_slopes,
    deterministic,
    rng_state=None,
):
    """Run the fixed-length CUDA backward kernel ``flash_attn_cuda.bwd``.

    ``dq``, ``dk`` and ``dv`` are caller-provided gradient buffers handed to the kernel.
    ``dout``, ``q``, ``k``, ``v`` and ``out`` are first copied to be unit-stride in their last
    dimension if needed. ``window_size`` is a ``(left, right)`` pair. ``rng_state`` is
    forwarded unchanged (the autograd functions pass the value saved by the forward pass).

    Returns ``(dq, dk, dv, softmax_d)`` as produced by the kernel.
    """

    def _last_dim_contiguous(t):
        # Only copy when the last dimension is strided.
        return t if t.stride(-1) == 1 else t.contiguous()

    # The gradient buffers are freshly allocated by the autograd functions in this module,
    # so only the incoming tensors are normalised here.
    dout, q, k, v, out = map(_last_dim_contiguous, (dout, q, k, v, out))
    left, right = window_size[0], window_size[1]
    grads = flash_attn_cuda.bwd(
        dout,
        q,
        k,
        v,
        out,
        softmax_lse,
        dq,
        dk,
        dv,
        alibi_slopes,
        dropout_p,
        softmax_scale,
        causal,
        left,
        right,
        softcap,
        deterministic,
        None,
        rng_state,
    )
    dq, dk, dv, softmax_d = grads
    return dq, dk, dv, softmax_d


Tri Dao's avatar
Tri Dao committed
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
def _flash_attn_varlen_backward(
    dout,
    q,
    k,
    v,
    out,
    softmax_lse,
    dq,
    dk,
    dv,
    cu_seqlens_q,
    cu_seqlens_k,
    max_seqlen_q,
    max_seqlen_k,
    dropout_p,
    softmax_scale,
    causal,
    window_size,
    softcap,
    alibi_slopes,
    deterministic,
    rng_state=None,
):
    """Run the variable-length CUDA backward kernel ``flash_attn_cuda.varlen_bwd``.

    ``dq``, ``dk`` and ``dv`` are caller-provided gradient buffers handed to the kernel.
    ``dout``, ``q``, ``k``, ``v`` and ``out`` are first copied to be unit-stride in their last
    dimension if needed. ``window_size`` is a ``(left, right)`` pair. The literal
    ``None`` / ``False`` positional slots are passed through unchanged; their meaning is
    defined by the C++ binding, which is not visible here.

    Returns ``(dq, dk, dv, softmax_d)`` as produced by the kernel.
    """

    def _last_dim_contiguous(t):
        # Only copy when the last dimension is strided.
        return t if t.stride(-1) == 1 else t.contiguous()

    # The gradient buffers are freshly allocated by the autograd functions in this module,
    # so only the incoming tensors are normalised here.
    dout, q, k, v, out = map(_last_dim_contiguous, (dout, q, k, v, out))
    left, right = window_size[0], window_size[1]
    grads = flash_attn_cuda.varlen_bwd(
        dout,
        q,
        k,
        v,
        out,
        softmax_lse,
        dq,
        dk,
        dv,
        cu_seqlens_q,
        cu_seqlens_k,
        alibi_slopes,
        max_seqlen_q,
        max_seqlen_k,
        dropout_p,
        softmax_scale,
        False,
        causal,
        left,
        right,
        softcap,
        deterministic,
        None,
        rng_state,
    )
    dq, dk, dv, softmax_d = grads
    return dq, dk, dv, softmax_d
Tri Dao's avatar
Tri Dao committed
228
229


Tri Dao's avatar
Tri Dao committed
230
class FlashAttnQKVPackedFunc(torch.autograd.Function):
    """Autograd wrapper over the fixed-length kernels for a packed ``qkv`` tensor.

    q, k and v are the slices ``qkv[:, :, 0]``, ``qkv[:, :, 1]`` and ``qkv[:, :, 2]``. The
    backward pass writes all three gradients into slices of one packed ``dqkv`` buffer.
    """

    @staticmethod
    def forward(
        ctx,
        qkv,
        dropout_p,
        softmax_scale,
        causal,
        window_size,
        softcap,
        alibi_slopes,
        deterministic,
        return_softmax,
    ):
        # Default scaling is 1 / sqrt(last dimension of qkv).
        scale = qkv.shape[-1] ** (-0.5) if softmax_scale is None else softmax_scale
        results = _flash_attn_forward(
            qkv[:, :, 0],
            qkv[:, :, 1],
            qkv[:, :, 2],
            dropout_p,
            scale,
            causal=causal,
            window_size=window_size,
            softcap=softcap,
            alibi_slopes=alibi_slopes,
            # The kernel is only asked for the probability tensor when dropout is active.
            return_softmax=return_softmax and dropout_p > 0,
        )
        out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = results
        ctx.save_for_backward(q, k, v, out_padded, softmax_lse, rng_state)
        ctx.causal = causal
        ctx.window_size = window_size
        ctx.softcap = softcap
        ctx.dropout_p = dropout_p
        ctx.softmax_scale = scale
        ctx.alibi_slopes = alibi_slopes
        ctx.deterministic = deterministic
        if return_softmax:
            return out, softmax_lse, S_dmask
        return out

    @staticmethod
    def backward(ctx, dout, *args):
        q, k, v, out, softmax_lse, rng_state = ctx.saved_tensors
        # Packed gradient buffer: q's shape with a size-3 axis inserted before the last two dims.
        dqkv = torch.empty((*q.shape[:-2], 3, *q.shape[-2:]), dtype=q.dtype, device=q.device)
        _flash_attn_backward(
            dout,
            q,
            k,
            v,
            out,
            softmax_lse,
            dqkv[:, :, 0],
            dqkv[:, :, 1],
            dqkv[:, :, 2],
            ctx.dropout_p,
            ctx.softmax_scale,
            ctx.causal,
            ctx.window_size,
            ctx.softcap,
            ctx.alibi_slopes,
            ctx.deterministic,
            rng_state=rng_state,
        )
        # The saved tensors may carry a padded head dimension; trim it to match dout.
        head_dim = dout.shape[-1]
        # One entry per forward input; only qkv receives a gradient.
        return (dqkv[..., :head_dim],) + (None,) * 8
Tri Dao's avatar
Tri Dao committed
294
295
296
297


class FlashAttnVarlenQKVPackedFunc(torch.autograd.Function):
    """Autograd wrapper over the variable-length kernels for a packed ``qkv`` tensor.

    q, k and v are the slices ``qkv[:, 0]``, ``qkv[:, 1]`` and ``qkv[:, 2]``. Queries and keys
    share the same ``cu_seqlens`` / ``max_seqlen``. The backward pass writes all three
    gradients into slices of one packed ``dqkv`` buffer.
    """

    @staticmethod
    def forward(
        ctx,
        qkv,
        cu_seqlens,
        max_seqlen,
        dropout_p,
        softmax_scale,
        causal,
        window_size,
        softcap,
        alibi_slopes,
        deterministic,
        return_softmax,
    ):
        # Default scaling is 1 / sqrt(last dimension of qkv).
        scale = qkv.shape[-1] ** (-0.5) if softmax_scale is None else softmax_scale
        results = _flash_attn_varlen_forward(
            qkv[:, 0],
            qkv[:, 1],
            qkv[:, 2],
            cu_seqlens,
            cu_seqlens,
            max_seqlen,
            max_seqlen,
            dropout_p,
            scale,
            causal=causal,
            window_size=window_size,
            softcap=softcap,
            alibi_slopes=alibi_slopes,
            # The kernel is only asked for the probability tensor when dropout is active.
            return_softmax=return_softmax and dropout_p > 0,
            block_table=None,
        )
        out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = results
        ctx.save_for_backward(q, k, v, out_padded, softmax_lse, cu_seqlens, rng_state)
        ctx.max_seqlen = max_seqlen
        ctx.causal = causal
        ctx.window_size = window_size
        ctx.softcap = softcap
        ctx.dropout_p = dropout_p
        ctx.softmax_scale = scale
        ctx.alibi_slopes = alibi_slopes
        ctx.deterministic = deterministic
        if return_softmax:
            return out, softmax_lse, S_dmask
        return out

    @staticmethod
    def backward(ctx, dout, *args):
        q, k, v, out, softmax_lse, cu_seqlens, rng_state = ctx.saved_tensors
        # Packed gradient buffer: q's shape with a size-3 axis inserted before the last two dims.
        dqkv = torch.empty((*q.shape[:-2], 3, *q.shape[-2:]), dtype=q.dtype, device=q.device)
        _flash_attn_varlen_backward(
            dout,
            q,
            k,
            v,
            out,
            softmax_lse,
            dqkv[:, 0],
            dqkv[:, 1],
            dqkv[:, 2],
            cu_seqlens,
            cu_seqlens,
            ctx.max_seqlen,
            ctx.max_seqlen,
            ctx.dropout_p,
            ctx.softmax_scale,
            ctx.causal,
            ctx.window_size,
            ctx.softcap,
            ctx.alibi_slopes,
            ctx.deterministic,
            rng_state=rng_state,
        )
        # The saved tensors may carry a padded head dimension; trim it to match dout.
        head_dim = dout.shape[-1]
        # One entry per forward input; only qkv receives a gradient.
        return (dqkv[..., :head_dim],) + (None,) * 10
Tri Dao's avatar
Tri Dao committed
372
373


Tri Dao's avatar
Tri Dao committed
374
class FlashAttnKVPackedFunc(torch.autograd.Function):
    """Autograd wrapper over the fixed-length kernels for separate ``q`` and packed ``kv``.

    k and v are the slices ``kv[:, :, 0]`` and ``kv[:, :, 1]``. The backward pass writes their
    gradients into slices of one packed ``dkv`` buffer.
    """

    @staticmethod
    def forward(
        ctx,
        q,
        kv,
        dropout_p,
        softmax_scale,
        causal,
        window_size,
        softcap,
        alibi_slopes,
        deterministic,
        return_softmax,
    ):
        # Default scaling is 1 / sqrt(last dimension of q).
        scale = q.shape[-1] ** (-0.5) if softmax_scale is None else softmax_scale
        results = _flash_attn_forward(
            q,
            kv[:, :, 0],
            kv[:, :, 1],
            dropout_p,
            scale,
            causal=causal,
            window_size=window_size,
            softcap=softcap,
            alibi_slopes=alibi_slopes,
            # The kernel is only asked for the probability tensor when dropout is active.
            return_softmax=return_softmax and dropout_p > 0,
        )
        out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = results
        ctx.save_for_backward(q, k, v, out_padded, softmax_lse, rng_state)
        ctx.causal = causal
        ctx.window_size = window_size
        ctx.softcap = softcap
        ctx.dropout_p = dropout_p
        ctx.softmax_scale = scale
        ctx.alibi_slopes = alibi_slopes
        ctx.deterministic = deterministic
        if return_softmax:
            return out, softmax_lse, S_dmask
        return out

    @staticmethod
    def backward(ctx, dout, *args):
        q, k, v, out, softmax_lse, rng_state = ctx.saved_tensors
        dq = torch.empty_like(q)
        # Packed gradient buffer: k's shape with a size-2 axis inserted before the last two dims.
        dkv = torch.empty((*k.shape[:-2], 2, *k.shape[-2:]), dtype=k.dtype, device=k.device)
        _flash_attn_backward(
            dout,
            q,
            k,
            v,
            out,
            softmax_lse,
            dq,
            dkv[:, :, 0],
            dkv[:, :, 1],
            ctx.dropout_p,
            ctx.softmax_scale,
            ctx.causal,
            ctx.window_size,
            ctx.softcap,
            ctx.alibi_slopes,
            ctx.deterministic,
            rng_state=rng_state,
        )
        # The saved tensors may carry a padded head dimension; trim it to match dout.
        head_dim = dout.shape[-1]
        # One entry per forward input; only q and kv receive gradients.
        return (dq[..., :head_dim], dkv[..., :head_dim]) + (None,) * 8
Tri Dao's avatar
Tri Dao committed
441
442


Tri Dao's avatar
Tri Dao committed
443
class FlashAttnVarlenKVPackedFunc(torch.autograd.Function):
    """Autograd wrapper over the variable-length kernels for separate ``q`` and packed ``kv``.

    k and v are the slices ``kv[:, 0]`` and ``kv[:, 1]``. Queries and keys have independent
    ``cu_seqlens_*`` / ``max_seqlen_*``. The backward pass writes the k/v gradients into slices
    of one packed ``dkv`` buffer.
    """

    @staticmethod
    def forward(
        ctx,
        q,
        kv,
        cu_seqlens_q,
        cu_seqlens_k,
        max_seqlen_q,
        max_seqlen_k,
        dropout_p,
        softmax_scale,
        causal,
        window_size,
        softcap,
        alibi_slopes,
        deterministic,
        return_softmax,
    ):
        # Default scaling is 1 / sqrt(last dimension of q).
        scale = q.shape[-1] ** (-0.5) if softmax_scale is None else softmax_scale
        results = _flash_attn_varlen_forward(
            q,
            kv[:, 0],
            kv[:, 1],
            cu_seqlens_q,
            cu_seqlens_k,
            max_seqlen_q,
            max_seqlen_k,
            dropout_p,
            scale,
            causal=causal,
            window_size=window_size,
            softcap=softcap,
            alibi_slopes=alibi_slopes,
            # The kernel is only asked for the probability tensor when dropout is active.
            return_softmax=return_softmax and dropout_p > 0,
            block_table=None,
        )
        out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = results
        ctx.save_for_backward(
            q, k, v, out_padded, softmax_lse, cu_seqlens_q, cu_seqlens_k, rng_state
        )
        ctx.max_seqlen_q = max_seqlen_q
        ctx.max_seqlen_k = max_seqlen_k
        ctx.causal = causal
        ctx.window_size = window_size
        ctx.softcap = softcap
        ctx.dropout_p = dropout_p
        ctx.softmax_scale = scale
        ctx.alibi_slopes = alibi_slopes
        ctx.deterministic = deterministic
        if return_softmax:
            return out, softmax_lse, S_dmask
        return out

    @staticmethod
    def backward(ctx, dout, *args):
        q, k, v, out, softmax_lse, cu_seqlens_q, cu_seqlens_k, rng_state = ctx.saved_tensors
        dq = torch.empty_like(q)
        # Packed gradient buffer: k's shape with a size-2 axis inserted before the last two dims.
        dkv = torch.empty((*k.shape[:-2], 2, *k.shape[-2:]), dtype=k.dtype, device=k.device)
        _flash_attn_varlen_backward(
            dout,
            q,
            k,
            v,
            out,
            softmax_lse,
            dq,
            dkv[:, 0],
            dkv[:, 1],
            cu_seqlens_q,
            cu_seqlens_k,
            ctx.max_seqlen_q,
            ctx.max_seqlen_k,
            ctx.dropout_p,
            ctx.softmax_scale,
            ctx.causal,
            ctx.window_size,
            ctx.softcap,
            ctx.alibi_slopes,
            ctx.deterministic,
            rng_state=rng_state,
        )
        # The saved tensors may carry a padded head dimension; trim it to match dout.
        head_dim = dout.shape[-1]
        # One entry per forward input; only q and kv receive gradients.
        return (dq[..., :head_dim], dkv[..., :head_dim]) + (None,) * 12
Tri Dao's avatar
Tri Dao committed
527
528
529
530


class FlashAttnFunc(torch.autograd.Function):
    """Autograd wrapper over the fixed-length kernels for separate ``q``, ``k`` and ``v``."""

    @staticmethod
    def forward(
        ctx,
        q,
        k,
        v,
        dropout_p,
        softmax_scale,
        causal,
        window_size,
        softcap,
        alibi_slopes,
        deterministic,
        return_softmax,
    ):
        # Default scaling is 1 / sqrt(last dimension of q).
        scale = q.shape[-1] ** (-0.5) if softmax_scale is None else softmax_scale
        results = _flash_attn_forward(
            q,
            k,
            v,
            dropout_p,
            scale,
            causal=causal,
            window_size=window_size,
            softcap=softcap,
            alibi_slopes=alibi_slopes,
            # The kernel is only asked for the probability tensor when dropout is active.
            return_softmax=return_softmax and dropout_p > 0,
        )
        out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = results
        ctx.save_for_backward(q, k, v, out_padded, softmax_lse, rng_state)
        ctx.causal = causal
        ctx.window_size = window_size
        ctx.softcap = softcap
        ctx.dropout_p = dropout_p
        ctx.softmax_scale = scale
        ctx.alibi_slopes = alibi_slopes
        ctx.deterministic = deterministic
        if return_softmax:
            return out, softmax_lse, S_dmask
        return out

    @staticmethod
    def backward(ctx, dout, *args):
        q, k, v, out, softmax_lse, rng_state = ctx.saved_tensors
        dq, dk, dv = (torch.empty_like(t) for t in (q, k, v))
        _flash_attn_backward(
            dout,
            q,
            k,
            v,
            out,
            softmax_lse,
            dq,
            dk,
            dv,
            ctx.dropout_p,
            ctx.softmax_scale,
            ctx.causal,
            ctx.window_size,
            ctx.softcap,
            ctx.alibi_slopes,
            ctx.deterministic,
            rng_state=rng_state,
        )
        # The saved tensors may carry a padded head dimension; trim it to match dout.
        head_dim = dout.shape[-1]
        trimmed = tuple(g[..., :head_dim] for g in (dq, dk, dv))
        # One entry per forward input; only q, k and v receive gradients.
        return trimmed + (None,) * 8
Tri Dao's avatar
Tri Dao committed
596
597


Tri Dao's avatar
Tri Dao committed
598
class FlashAttnVarlenFunc(torch.autograd.Function):
    """Autograd wrapper over the variable-length kernels for separate ``q``, ``k`` and ``v``.

    Queries and keys have independent ``cu_seqlens_*`` / ``max_seqlen_*``. ``block_table`` is
    forwarded to the forward kernel unchanged and does not receive a gradient.
    """

    @staticmethod
    def forward(
        ctx,
        q,
        k,
        v,
        cu_seqlens_q,
        cu_seqlens_k,
        max_seqlen_q,
        max_seqlen_k,
        dropout_p,
        softmax_scale,
        causal,
        window_size,
        softcap,
        alibi_slopes,
        deterministic,
        return_softmax,
        block_table,
    ):
        # Default scaling is 1 / sqrt(last dimension of q).
        scale = q.shape[-1] ** (-0.5) if softmax_scale is None else softmax_scale
        results = _flash_attn_varlen_forward(
            q,
            k,
            v,
            cu_seqlens_q,
            cu_seqlens_k,
            max_seqlen_q,
            max_seqlen_k,
            dropout_p,
            scale,
            causal=causal,
            window_size=window_size,
            softcap=softcap,
            alibi_slopes=alibi_slopes,
            # The kernel is only asked for the probability tensor when dropout is active.
            return_softmax=return_softmax and dropout_p > 0,
            block_table=block_table,
        )
        out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = results
        ctx.save_for_backward(
            q, k, v, out_padded, softmax_lse, cu_seqlens_q, cu_seqlens_k, rng_state
        )
        ctx.max_seqlen_q = max_seqlen_q
        ctx.max_seqlen_k = max_seqlen_k
        ctx.causal = causal
        ctx.window_size = window_size
        ctx.softcap = softcap
        ctx.dropout_p = dropout_p
        ctx.softmax_scale = scale
        ctx.alibi_slopes = alibi_slopes
        ctx.deterministic = deterministic
        if return_softmax:
            return out, softmax_lse, S_dmask
        return out

    @staticmethod
    def backward(ctx, dout, *args):
        q, k, v, out, softmax_lse, cu_seqlens_q, cu_seqlens_k, rng_state = ctx.saved_tensors
        dq, dk, dv = (torch.empty_like(t) for t in (q, k, v))
        _flash_attn_varlen_backward(
            dout,
            q,
            k,
            v,
            out,
            softmax_lse,
            dq,
            dk,
            dv,
            cu_seqlens_q,
            cu_seqlens_k,
            ctx.max_seqlen_q,
            ctx.max_seqlen_k,
            ctx.dropout_p,
            ctx.softmax_scale,
            ctx.causal,
            ctx.window_size,
            ctx.softcap,
            ctx.alibi_slopes,
            ctx.deterministic,
            rng_state=rng_state,
        )
        # The saved tensors may carry a padded head dimension; trim it to match dout.
        head_dim = dout.shape[-1]
        trimmed = tuple(g[..., :head_dim] for g in (dq, dk, dv))
        # One entry per forward input (16 total); only q, k and v receive gradients.
        return trimmed + (None,) * 13
683
684


Tri Dao's avatar
Tri Dao committed
685
def flash_attn_qkvpacked_func(
    qkv,
    dropout_p=0.0,
    softmax_scale=None,
    causal=False,
    window_size=(-1, -1),  # -1 means infinite context window
    softcap=0.0,  # <=0.0 means deactivated
    alibi_slopes=None,
    deterministic=False,
    return_attn_probs=False,
):
    """dropout_p should be set to 0.0 during evaluation.

    If Q, K, V are already stacked into 1 tensor, this function will be faster than
    calling flash_attn_func on Q, K, V since the backward pass avoids explicit concatenation
    of the gradients of Q, K, V.

    For multi-query and grouped-query attention (MQA/GQA), please see
    flash_attn_kvpacked_func and flash_attn_func.

    If window_size != (-1, -1), implements sliding window local attention. Query at position i
    will only attend to keys between [i - window_size[0], i + window_size[1]] inclusive.

    Arguments:
        qkv: (batch_size, seqlen, 3, nheads, headdim)
        dropout_p: float. Dropout probability.
        softmax_scale: float. The scaling of QK^T before applying softmax.
            Default to 1 / sqrt(headdim).
        causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
        window_size: (left, right). If not (-1, -1), implements sliding window local attention.
            A value of -1 on a side means no limit on that side.
        softcap: float. Anything > 0 activates softcapping attention; <= 0.0 disables it.
        alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of (-alibi_slope * |i - j|) is added to
            the attention score of query i and key j.
        deterministic: bool. Whether to use the deterministic implementation of the backward pass,
            which is slightly slower and uses more memory. The forward pass is always deterministic.
        return_attn_probs: bool. Whether to also return softmax_lse and S_dmask. This option is for
            testing only. The returned probabilities are not guaranteed to be correct
            (they might not have the right scaling). The attention-probability tensor is only
            requested from the kernel when dropout_p > 0, so S_dmask should not be relied on
            when dropout_p == 0.
    Return:
        out: (batch_size, seqlen, nheads, headdim). Returned on its own when
            return_attn_probs=False; otherwise the result is the tuple (out, softmax_lse, S_dmask).
        softmax_lse [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen). The
            logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
            normalization factor).
        S_dmask [optional, if return_attn_probs=True; only populated when dropout_p > 0]:
            (batch_size, nheads, seqlen, seqlen). The output of softmax (possibly with different
            scaling). It also encodes the dropout pattern (negative means that location was
            dropped, nonnegative means it was kept).
    """
    # torch.autograd.Function.apply takes positional arguments in forward()'s order.
    return FlashAttnQKVPackedFunc.apply(
        qkv,
        dropout_p,
        softmax_scale,
        causal,
        window_size,
        softcap,
        alibi_slopes,
        deterministic,
        return_attn_probs,
    )
Tri Dao's avatar
Tri Dao committed
741
742


Tri Dao's avatar
Tri Dao committed
743
def flash_attn_kvpacked_func(
    q,
    kv,
    dropout_p=0.0,
    softmax_scale=None,
    causal=False,
    window_size=(-1, -1),  # -1 means infinite context window
    softcap=0.0,  # 0.0 means deactivated
    alibi_slopes=None,
    deterministic=False,
    return_attn_probs=False,
):
    """dropout_p should be set to 0.0 during evaluation
    If K, V are already stacked into 1 tensor, this function will be faster than
    calling flash_attn_func on Q, K, V since the backward pass avoids explicit concatenation
    of the gradients of K, V.
    Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads
    than Q. Note that the number of heads in Q must be divisible by the number of heads in KV.
    For example, if Q has 6 heads and K, V have 2 heads, heads 0, 1, 2 of Q will attend to head
    0 of K, V, and heads 3, 4, 5 of Q will attend to head 1 of K, V.

    If causal=True, the causal mask is aligned to the bottom right corner of the attention matrix.
    For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is:
        1 1 1 1 0
        1 1 1 1 1
    If seqlen_q = 5 and seqlen_k = 2, the causal mask is:
        0 0
        0 0
        0 0
        1 0
        1 1
    If the row of the mask is all zero, the output will be zero.

    If window_size != (-1, -1), implements sliding window local attention. Query at position i
    will only attend to keys between
    [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive.

    Arguments:
        q: (batch_size, seqlen_q, nheads, headdim)
        kv: (batch_size, seqlen_k, 2, nheads_k, headdim)
        dropout_p: float. Dropout probability.
        softmax_scale: float. The scaling of QK^T before applying softmax.
            Defaults to 1 / sqrt(headdim).
        causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
        window_size: (left, right). If not (-1, -1), implements sliding window local attention.
        softcap: float. Anything > 0 activates softcapping attention.
        alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of
            (-alibi_slope * |i + seqlen_k - seqlen_q - j|)
            is added to the attention score of query i and key j.
        deterministic: bool. Whether to use the deterministic implementation of the backward pass,
            which is slightly slower and uses more memory. The forward pass is always deterministic.
        return_attn_probs: bool. Whether to return the attention probabilities. This option is for
           testing only. The returned probabilities are not guaranteed to be correct
           (they might not have the right scaling).
    Return:
        out: (batch_size, seqlen_q, nheads, headdim).
        softmax_lse [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen_q). The
            logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
            normalization factor).
        S_dmask [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen_q, seqlen_k).
            The output of softmax (possibly with different scaling). It also encodes the dropout
            pattern (negative means that location was dropped, nonnegative means it was kept).
    """
    # Arguments are forwarded positionally, so their order here is significant.
    return FlashAttnKVPackedFunc.apply(
        q,
        kv,
        dropout_p,
        softmax_scale,
        causal,
        window_size,
        softcap,
        alibi_slopes,
        deterministic,
        return_attn_probs,
    )


def flash_attn_func(
    q,
    k,
    v,
    dropout_p=0.0,
    softmax_scale=None,
    causal=False,
    window_size=(-1, -1),  # -1 means infinite context window
    softcap=0.0,  # 0.0 means deactivated
    alibi_slopes=None,
    deterministic=False,
    return_attn_probs=False,
):
    """dropout_p should be set to 0.0 during evaluation
    Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads
    than Q. Note that the number of heads in Q must be divisible by the number of heads in KV.
    For example, if Q has 6 heads and K, V have 2 heads, heads 0, 1, 2 of Q will attend to head
    0 of K, V, and heads 3, 4, 5 of Q will attend to head 1 of K, V.

    If causal=True, the causal mask is aligned to the bottom right corner of the attention matrix.
    For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is:
        1 1 1 1 0
        1 1 1 1 1
    If seqlen_q = 5 and seqlen_k = 2, the causal mask is:
        0 0
        0 0
        0 0
        1 0
        1 1
    If the row of the mask is all zero, the output will be zero.

    If window_size != (-1, -1), implements sliding window local attention. Query at position i
    will only attend to keys between
    [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive.

    Arguments:
        q: (batch_size, seqlen_q, nheads, headdim)
        k: (batch_size, seqlen_k, nheads_k, headdim)
        v: (batch_size, seqlen_k, nheads_k, headdim)
        dropout_p: float. Dropout probability.
        softmax_scale: float. The scaling of QK^T before applying softmax.
            Defaults to 1 / sqrt(headdim).
        causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
        window_size: (left, right). If not (-1, -1), implements sliding window local attention.
        softcap: float. Anything > 0 activates softcapping attention.
        alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of
            (-alibi_slope * |i + seqlen_k - seqlen_q - j|)
            is added to the attention score of query i and key j.
        deterministic: bool. Whether to use the deterministic implementation of the backward pass,
            which is slightly slower and uses more memory. The forward pass is always deterministic.
        return_attn_probs: bool. Whether to return the attention probabilities. This option is for
           testing only. The returned probabilities are not guaranteed to be correct
           (they might not have the right scaling).
    Return:
        out: (batch_size, seqlen_q, nheads, headdim).
        softmax_lse [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen_q). The
            logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
            normalization factor).
        S_dmask [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen_q, seqlen_k).
            The output of softmax (possibly with different scaling). It also encodes the dropout
            pattern (negative means that location was dropped, nonnegative means it was kept).
    """
    # Arguments are forwarded positionally, so their order here is significant.
    return FlashAttnFunc.apply(
        q,
        k,
        v,
        dropout_p,
        softmax_scale,
        causal,
        window_size,
        softcap,
        alibi_slopes,
        deterministic,
        return_attn_probs,
    )


def flash_attn_varlen_qkvpacked_func(
    qkv,
    cu_seqlens,
    max_seqlen,
    dropout_p=0.0,
    softmax_scale=None,
    causal=False,
    window_size=(-1, -1),  # -1 means infinite context window
    softcap=0.0,  # 0.0 means deactivated
    alibi_slopes=None,
    deterministic=False,
    return_attn_probs=False,
):
    """Variable-length attention where Q, K and V are packed into a single tensor.

    Set dropout_p to 0.0 during evaluation.

    When Q, K, V already live in one stacked tensor, this entry point is faster than
    calling flash_attn_varlen_func on separate Q, K, V: the backward pass does not need
    to explicitly concatenate the gradients of Q, K, V.

    This variant does not support multi-query / grouped-query attention (MQA/GQA); use
    flash_attn_varlen_kvpacked_func or flash_attn_varlen_func for that.

    A window_size other than (-1, -1) enables sliding window local attention: the query
    at position i only attends to keys in [i - window_size[0], i + window_size[1]] inclusive.

    Arguments:
        qkv: (total, 3, nheads, headdim), where total = total number of tokens in the batch.
        cu_seqlens: (batch_size + 1,), dtype torch.int32. Cumulative sequence lengths of the
           sequences in the batch, used to index into qkv.
        max_seqlen: int. Longest sequence length in the batch.
        dropout_p: float. Dropout probability.
        softmax_scale: float. Scaling applied to QK^T before the softmax.
            Defaults to 1 / sqrt(headdim).
        causal: bool. Whether to apply a causal attention mask (e.g., for auto-regressive modeling).
        window_size: (left, right). If not (-1, -1), enables sliding window local attention.
        softcap: float. Any value > 0 enables softcapping attention.
        alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of (-alibi_slope * |i - j|)
            is added to the attention score of query i and key j.
        deterministic: bool. Whether to use the deterministic backward implementation, which is
            slightly slower and uses more memory. The forward pass is always deterministic.
        return_attn_probs: bool. Whether to also return attention probabilities. Intended for
           testing only; the returned probabilities are not guaranteed to be correct
           (they might not have the right scaling).
    Return:
        out: (total, nheads, headdim).
        softmax_lse [optional, if return_attn_probs=True]: (nheads, total_q_seqlen). The
            logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
            normalization factor).
        S_dmask [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen, seqlen).
            The softmax output (possibly with different scaling). It also encodes the dropout
            pattern (negative means that location was dropped, nonnegative means it was kept).
    """
    # apply() receives these positionally, so the tuple order is significant.
    packed_args = (
        qkv,
        cu_seqlens,
        max_seqlen,
        dropout_p,
        softmax_scale,
        causal,
        window_size,
        softcap,
        alibi_slopes,
        deterministic,
        return_attn_probs,
    )
    return FlashAttnVarlenQKVPackedFunc.apply(*packed_args)


def flash_attn_varlen_kvpacked_func(
    q,
    kv,
    cu_seqlens_q,
    cu_seqlens_k,
    max_seqlen_q,
    max_seqlen_k,
    dropout_p=0.0,
    softmax_scale=None,
    causal=False,
    window_size=(-1, -1),  # -1 means infinite context window
    softcap=0.0,  # 0.0 means deactivated
    alibi_slopes=None,
    deterministic=False,
    return_attn_probs=False,
):
    """dropout_p should be set to 0.0 during evaluation
    If K, V are already stacked into 1 tensor, this function will be faster than
    calling flash_attn_varlen_func on Q, K, V since the backward pass avoids explicit
    concatenation of the gradients of K, V.
    Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads
    than Q. Note that the number of heads in Q must be divisible by the number of heads in KV.
    For example, if Q has 6 heads and K, V have 2 heads, heads 0, 1, 2 of Q will attend to head
    0 of K, V, and heads 3, 4, 5 of Q will attend to head 1 of K, V.

    In the descriptions below, seqlen_q and seqlen_k refer to the lengths of one individual
    sequence in the batch, i.e. cu_seqlens_q[b + 1] - cu_seqlens_q[b] and
    cu_seqlens_k[b + 1] - cu_seqlens_k[b] for batch element b.

    If causal=True, the causal mask is aligned to the bottom right corner of the attention matrix.
    For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is:
        1 1 1 1 0
        1 1 1 1 1
    If seqlen_q = 5 and seqlen_k = 2, the causal mask is:
        0 0
        0 0
        0 0
        1 0
        1 1
    If the row of the mask is all zero, the output will be zero.

    If window_size != (-1, -1), implements sliding window local attention. Query at position i
    will only attend to keys between
    [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive.

    Arguments:
        q: (total_q, nheads, headdim), where total_q = total number of query tokens in the batch.
        kv: (total_k, 2, nheads_k, headdim), where total_k = total number of key tokens in the batch.
        cu_seqlens_q: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
           of the sequences in the batch, used to index into q.
        cu_seqlens_k: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
           of the sequences in the batch, used to index into kv.
        max_seqlen_q: int. Maximum query sequence length in the batch.
        max_seqlen_k: int. Maximum key sequence length in the batch.
        dropout_p: float. Dropout probability.
        softmax_scale: float. The scaling of QK^T before applying softmax.
            Defaults to 1 / sqrt(headdim).
        causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
        window_size: (left, right). If not (-1, -1), implements sliding window local attention.
        softcap: float. Anything > 0 activates softcapping attention.
        alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of
            (-alibi_slope * |i + seqlen_k - seqlen_q - j|)
            is added to the attention score of query i and key j.
        deterministic: bool. Whether to use the deterministic implementation of the backward pass,
            which is slightly slower and uses more memory. The forward pass is always deterministic.
        return_attn_probs: bool. Whether to return the attention probabilities. This option is for
           testing only. The returned probabilities are not guaranteed to be correct
           (they might not have the right scaling).
    Return:
        out: (total_q, nheads, headdim).
        softmax_lse [optional, if return_attn_probs=True]: (nheads, total_q). The
            logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
            normalization factor).
        S_dmask [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen, seqlen).
            The output of softmax (possibly with different scaling). It also encodes the dropout
            pattern (negative means that location was dropped, nonnegative means it was kept).
    """
    # Arguments are forwarded positionally, so their order here is significant.
    return FlashAttnVarlenKVPackedFunc.apply(
        q,
        kv,
        cu_seqlens_q,
        cu_seqlens_k,
        max_seqlen_q,
        max_seqlen_k,
        dropout_p,
        softmax_scale,
        causal,
        window_size,
        softcap,
        alibi_slopes,
        deterministic,
        return_attn_probs,
    )


def flash_attn_varlen_func(
    q,
    k,
    v,
    cu_seqlens_q,
    cu_seqlens_k,
    max_seqlen_q,
    max_seqlen_k,
    dropout_p=0.0,
    softmax_scale=None,
    causal=False,
    window_size=(-1, -1),  # -1 means infinite context window
    softcap=0.0,  # 0.0 means deactivated
    alibi_slopes=None,
    deterministic=False,
    return_attn_probs=False,
    block_table=None,
):
    """dropout_p should be set to 0.0 during evaluation
    Supports multi-query and grouped-query attention (MQA/GQA) by passing in K, V with fewer heads
    than Q. Note that the number of heads in Q must be divisible by the number of heads in KV.
    For example, if Q has 6 heads and K, V have 2 heads, heads 0, 1, 2 of Q will attend to head
    0 of K, V, and heads 3, 4, 5 of Q will attend to head 1 of K, V.

    In the descriptions below, seqlen_q and seqlen_k refer to the lengths of one individual
    sequence in the batch, i.e. cu_seqlens_q[b + 1] - cu_seqlens_q[b] and
    cu_seqlens_k[b + 1] - cu_seqlens_k[b] for batch element b.

    If causal=True, the causal mask is aligned to the bottom right corner of the attention matrix.
    For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is:
        1 1 1 1 0
        1 1 1 1 1
    If seqlen_q = 5 and seqlen_k = 2, the causal mask is:
        0 0
        0 0
        0 0
        1 0
        1 1
    If the row of the mask is all zero, the output will be zero.

    If window_size != (-1, -1), implements sliding window local attention. Query at position i
    will only attend to keys between
    [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive.

    Arguments:
        q: (total_q, nheads, headdim), where total_q = total number of query tokens in the batch.
        k: (total_k, nheads_k, headdim), where total_k = total number of key tokens in the batch.
        v: (total_k, nheads_k, headdim), where total_k = total number of key tokens in the batch.
        cu_seqlens_q: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
           of the sequences in the batch, used to index into q.
        cu_seqlens_k: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
           of the sequences in the batch, used to index into k and v.
        max_seqlen_q: int. Maximum query sequence length in the batch.
        max_seqlen_k: int. Maximum key sequence length in the batch.
        dropout_p: float. Dropout probability.
        softmax_scale: float. The scaling of QK^T before applying softmax.
            Defaults to 1 / sqrt(headdim).
        causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
        window_size: (left, right). If not (-1, -1), implements sliding window local attention.
        softcap: float. Anything > 0 activates softcapping attention.
        alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of
            (-alibi_slope * |i + seqlen_k - seqlen_q - j|)
            is added to the attention score of query i and key j.
        deterministic: bool. Whether to use the deterministic implementation of the backward pass,
            which is slightly slower and uses more memory. The forward pass is always deterministic.
        return_attn_probs: bool. Whether to return the attention probabilities. This option is for
           testing only. The returned probabilities are not guaranteed to be correct
           (they might not have the right scaling).
        block_table: optional, default None. Forwarded unchanged to FlashAttnVarlenFunc.
            NOTE(review): presumably a page table for a paged KV cache, as accepted by
            flash_attn_with_kvcache — confirm the expected shape/dtype.
    Return:
        out: (total_q, nheads, headdim).
        softmax_lse [optional, if return_attn_probs=True]: (nheads, total_q). The
            logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
            normalization factor).
        S_dmask [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen, seqlen).
            The output of softmax (possibly with different scaling). It also encodes the dropout
            pattern (negative means that location was dropped, nonnegative means it was kept).
    """
    # Arguments are forwarded positionally, so their order here is significant.
    return FlashAttnVarlenFunc.apply(
        q,
        k,
        v,
        cu_seqlens_q,
        cu_seqlens_k,
        max_seqlen_q,
        max_seqlen_k,
        dropout_p,
        softmax_scale,
        causal,
        window_size,
        softcap,
        alibi_slopes,
        deterministic,
        return_attn_probs,
        block_table,
    )


def flash_attn_with_kvcache(
    q,
    k_cache,
    v_cache,
    k=None,
    v=None,
    rotary_cos=None,
    rotary_sin=None,
    cache_seqlens: Optional[Union[(int, torch.Tensor)]] = None,
    cache_batch_idx: Optional[torch.Tensor] = None,
    cache_leftpad: Optional[torch.Tensor] = None,
    block_table: Optional[torch.Tensor] = None,
    softmax_scale=None,
    causal=False,
    window_size=(-1, -1),  # -1 means infinite context window
    softcap=0.0,  # 0.0 means deactivated
    rotary_interleaved=True,
    alibi_slopes=None,
    num_splits=0,
    return_softmax_lse=False,
):
    """
    If k and v are not None, k_cache and v_cache will be updated *inplace* with the new values from
    k and v. This is useful for incremental decoding: you can pass in the cached keys/values from
    the previous step, and update them with the new keys/values from the current step, and do
    attention with the updated cache, all in 1 kernel.

    If you pass in k / v, you must make sure that the cache is large enough to hold the new values.
    For example, the KV cache could be pre-allocated with the max sequence length, and you can use
    cache_seqlens to keep track of the current sequence lengths of each sequence in the batch.

    Also apply rotary embedding if rotary_cos and rotary_sin are passed in. The key @k will be
    rotated by rotary_cos and rotary_sin at indices cache_seqlens, cache_seqlens + 1, etc.
    If causal or local (i.e., window_size != (-1, -1)), the query @q will be rotated by rotary_cos
    and rotary_sin at indices cache_seqlens, cache_seqlens + 1, etc.
    If not causal and not local, the query @q will be rotated by rotary_cos and rotary_sin at
    indices cache_seqlens only (i.e. we consider all tokens in @q to be at position cache_seqlens).

    See tests/test_flash_attn.py::test_flash_attn_kvcache for examples of how to use this function.

    Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads
    than Q. Note that the number of heads in Q must be divisible by the number of heads in KV.
    For example, if Q has 6 heads and K, V have 2 heads, heads 0, 1, 2 of Q will attend to head
    0 of K, V, and heads 3, 4, 5 of Q will attend to head 1 of K, V.

    If causal=True, the causal mask is aligned to the bottom right corner of the attention matrix.
    For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is:
        1 1 1 1 0
        1 1 1 1 1
    If seqlen_q = 5 and seqlen_k = 2, the causal mask is:
        0 0
        0 0
        0 0
        1 0
        1 1
    If the row of the mask is all zero, the output will be zero.

    If window_size != (-1, -1), implements sliding window local attention. Query at position i
    will only attend to keys between
    [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive.

    Note: Does not support backward pass.

    Arguments:
        q: (batch_size, seqlen, nheads, headdim)
        k_cache: (batch_size_cache, seqlen_cache, nheads_k, headdim) if there's no block_table,
            or (num_blocks, page_block_size, nheads_k, headdim) if there's a block_table (i.e. paged KV cache)
            page_block_size must be a multiple of 256.
        v_cache: (batch_size_cache, seqlen_cache, nheads_k, headdim) if there's no block_table,
            or (num_blocks, page_block_size, nheads_k, headdim) if there's a block_table (i.e. paged KV cache)
        k [optional]: (batch_size, seqlen_new, nheads_k, headdim). If not None, we concatenate
            k with k_cache, starting at the indices specified by cache_seqlens.
        v [optional]: (batch_size, seqlen_new, nheads_k, headdim). Similar to k.
        rotary_cos [optional]: (seqlen_ro, rotary_dim / 2). If not None, we apply rotary embedding
            to k and q. Only applicable if k and v are passed in. rotary_dim must be divisible by 16.
        rotary_sin [optional]: (seqlen_ro, rotary_dim / 2). Similar to rotary_cos.
        cache_seqlens: int, or (batch_size,), dtype torch.int32. The sequence lengths of the
            KV cache. If an int is given, the same length is used for every one of the
            batch_size (= q.shape[0]) sequences.
        cache_batch_idx: (batch_size,), dtype torch.int32. The indices used to index into the KV cache.
            If None, we assume that the batch indices are [0, 1, 2, ..., batch_size - 1].
            If the indices are not distinct, and k and v are provided, the values updated in the cache
                 might come from any of the duplicate indices.
        cache_leftpad: (batch_size,), dtype torch.int32. The index that the KV cache starts. If None, assume 0.
        block_table [optional]: (batch_size, max_num_blocks_per_seq), dtype torch.int32.
        softmax_scale: float. The scaling of QK^T before applying softmax.
            Default to 1 / sqrt(headdim).
        causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
        window_size: (left, right). If not (-1, -1), implements sliding window local attention.
        softcap: float. Anything > 0 activates softcapping attention.
        rotary_interleaved: bool. Only applicable if rotary_cos and rotary_sin are passed in.
            If True, rotary embedding will combine dimensions 0 & 1, 2 & 3, etc. If False,
            rotary embedding will combine dimensions 0 & rotary_dim / 2, 1 & rotary_dim / 2 + 1
            (i.e. GPT-NeoX style).
        alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of
            (-alibi_slope * |i + seqlen_k - seqlen_q - j|)
            is added to the attention score of query i and key j.
        num_splits: int. If > 1, split the key/value into this many chunks along the sequence.
           If num_splits == 1, we don't split the key/value. If num_splits == 0, we use a heuristic
           to automatically determine the number of splits.
           Don't change this unless you know what you are doing.
        return_softmax_lse: bool. Whether to return the logsumexp of the attention scores.

    Return:
        out: (batch_size, seqlen, nheads, headdim).
        softmax_lse [optional, if return_softmax_lse=True]: (batch_size, nheads, seqlen). The
            logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
            normalization factor).
    """
    # The caches may be written in place, so we must not swap in a contiguous copy
    # (the caller would never see the update). Require the right layout instead.
    assert k_cache.stride(-1) == 1, "k_cache must have contiguous last dimension"
    assert v_cache.stride(-1) == 1, "v_cache must have contiguous last dimension"

    def maybe_contiguous(x):
        # Copy only when the last dimension is strided; None passes through untouched.
        return x.contiguous() if x is not None and x.stride(-1) != 1 else x

    q, k, v = [maybe_contiguous(x) for x in (q, k, v)]
    if softmax_scale is None:
        softmax_scale = q.shape[-1] ** (-0.5)
    if isinstance(cache_seqlens, int):
        # Broadcast a scalar length to one entry per query sequence. The batch size must
        # come from q: with a block_table, k_cache.shape[0] is num_blocks, and with
        # cache_batch_idx it is batch_size_cache. Neither of those is batch_size.
        cache_seqlens = torch.full(
            (q.shape[0],), cache_seqlens, dtype=torch.int32, device=k_cache.device
        )
    # Normalize caller-provided index tensors too, not only the one built above.
    cache_seqlens = maybe_contiguous(cache_seqlens)
    cache_batch_idx = maybe_contiguous(cache_batch_idx)
    cache_leftpad = maybe_contiguous(cache_leftpad)
    block_table = maybe_contiguous(block_table)
    out, softmax_lse = flash_attn_cuda.fwd_kvcache(
        q,
        k_cache,
        v_cache,
        k,
        v,
        cache_seqlens,
        rotary_cos,
        rotary_sin,
        cache_batch_idx,
        cache_leftpad,
        block_table,
        alibi_slopes,
        # NOTE(review): positional slot before softmax_scale; presumably an optional
        # preallocated output tensor in the fwd_kvcache binding. Confirm before changing.
        None,
        softmax_scale,
        causal,
        window_size[0],
        window_size[1],
        softcap,
        rotary_interleaved,
        num_splits,
    )
    return (out, softmax_lse) if return_softmax_lse else out