from typing import Optional, Union

import torch
import torch.nn as nn

# isort: off
# We need to import the CUDA kernels after importing torch
import flash_attn_2_cuda as flash_attn_cuda

# isort: on


def _get_block_size(device, head_dim, is_dropout, is_causal):
    # This should match the block sizes in the CUDA kernel
    assert head_dim <= 256
    major, minor = torch.cuda.get_device_capability(device)
    is_sm8x = major == 8 and minor > 0  # Only include sm86 and sm89, exclude sm80 (A100)
    is_sm80 = major == 8 and minor == 0
    is_sm90 = major == 9 and minor == 0
    if head_dim <= 32:
        return 128, 128
    if head_dim <= 64:
        return (128, 128) if not is_dropout else (128, 64)
    elif head_dim <= 96:
        return (64, 64) if (is_sm8x and is_causal) else (128, 64)
    elif head_dim <= 128:
        if is_sm8x:
            return (64, 64) if (not is_dropout and is_causal) else (128, 32)
        else:
            return 128, (64 if not is_dropout else 32)
    elif head_dim <= 160:
        if is_sm8x:
            return (128, 64) if not is_causal else (64, 64)
        else:
            return 128, 32
    elif head_dim <= 192:
        return (128, 64) if not is_dropout else (64, 64)
    elif head_dim <= 224:
        return (128, 64) if (is_sm80 or is_sm90) else (64, 64)
    elif head_dim <= 256:
        return (128, 64) if is_sm80 else (64, 64)


def _flash_attn_forward(q, k, v, dropout_p, softmax_scale, causal, window_size, return_softmax):
    maybe_contiguous = lambda x: x.contiguous() if x.stride(-1) != 1 else x
    q, k, v = [maybe_contiguous(x) for x in (q, k, v)]
    out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = flash_attn_cuda.fwd(
        q,
        k,
        v,
        None,
        dropout_p,
        softmax_scale,
        causal,
        window_size[0],
        window_size[1],
        return_softmax,
        None,
    )
    return out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state


def _flash_attn_varlen_forward(
    q,
    k,
    v,
    cu_seqlens_q,
    cu_seqlens_k,
    max_seqlen_q,
    max_seqlen_k,
    dropout_p,
    softmax_scale,
    causal,
    window_size,
    return_softmax,
):
    maybe_contiguous = lambda x: x.contiguous() if x.stride(-1) != 1 else x
    q, k, v = [maybe_contiguous(x) for x in (q, k, v)]
    out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = flash_attn_cuda.varlen_fwd(
        q,
        k,
        v,
        None,
        cu_seqlens_q,
        cu_seqlens_k,
        max_seqlen_q,
        max_seqlen_k,
        dropout_p,
        softmax_scale,
        False,
        causal,
        window_size[0],
        window_size[1],
        return_softmax,
        None,
    )
    # if out.isnan().any() or softmax_lse.isnan().any():
    #     breakpoint()
    return out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state


def _flash_attn_backward(
    dout,
    q,
    k,
    v,
    out,
    softmax_lse,
    dq,
    dk,
    dv,
    dropout_p,
    softmax_scale,
    causal,
    window_size,
    rng_state=None,
):
    maybe_contiguous = lambda x: x.contiguous() if x.stride(-1) != 1 else x
    # dq, dk, dv are allocated by us so they should already be contiguous
    dout, q, k, v, out = [maybe_contiguous(x) for x in (dout, q, k, v, out)]
    dq, dk, dv, softmax_d, = flash_attn_cuda.bwd(
        dout,
        q,
        k,
        v,
        out,
        softmax_lse,
        dq,
        dk,
        dv,
        dropout_p,
        softmax_scale,
        causal,
        window_size[0],
        window_size[1],
        None,
        rng_state,
    )
    return dq, dk, dv, softmax_d


def _flash_attn_varlen_backward(
    dout,
    q,
    k,
    v,
    out,
    softmax_lse,
    dq,
    dk,
    dv,
    cu_seqlens_q,
    cu_seqlens_k,
    max_seqlen_q,
    max_seqlen_k,
    dropout_p,
    softmax_scale,
    causal,
    window_size,
    rng_state=None,
):
    maybe_contiguous = lambda x: x.contiguous() if x.stride(-1) != 1 else x
    # dq, dk, dv are allocated by us so they should already be contiguous
    dout, q, k, v, out = [maybe_contiguous(x) for x in (dout, q, k, v, out)]
    dq, dk, dv, softmax_d, = flash_attn_cuda.varlen_bwd(
        dout,
        q,
        k,
        v,
        out,
        softmax_lse,
        dq,
        dk,
        dv,
        cu_seqlens_q,
        cu_seqlens_k,
        max_seqlen_q,
        max_seqlen_k,
        dropout_p,
        softmax_scale,
        False,
        causal,
        window_size[0],
        window_size[1],
        None,
        rng_state,
    )
    # if dq.isnan().any() or dk.isnan().any() or dv.isnan().any() or softmax_d.isnan().any():
    #     breakpoint()
    return dq, dk, dv, softmax_d


class FlashAttnQKVPackedFunc(torch.autograd.Function):
    @staticmethod
    def forward(ctx, qkv, dropout_p, softmax_scale, causal, window_size, return_softmax):
        if softmax_scale is None:
            softmax_scale = qkv.shape[-1] ** (-0.5)
        out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = _flash_attn_forward(
            qkv[:, :, 0],
            qkv[:, :, 1],
            qkv[:, :, 2],
            dropout_p,
            softmax_scale,
            causal=causal,
            window_size=window_size,
            return_softmax=return_softmax and dropout_p > 0,
        )
        ctx.save_for_backward(q, k, v, out_padded, softmax_lse, rng_state)
        ctx.dropout_p = dropout_p
        ctx.softmax_scale = softmax_scale
        ctx.causal = causal
        ctx.window_size = window_size
        return out if not return_softmax else (out, softmax_lse, S_dmask)

    @staticmethod
    def backward(ctx, dout, *args):
        q, k, v, out, softmax_lse, rng_state = ctx.saved_tensors
        qkv_shape = q.shape[:-2] + (3, *q.shape[-2:])
        dqkv = torch.empty(qkv_shape, dtype=q.dtype, device=q.device)
        _flash_attn_backward(
            dout,
            q,
            k,
            v,
            out,
            softmax_lse,
            dqkv[:, :, 0],
            dqkv[:, :, 1],
            dqkv[:, :, 2],
            ctx.dropout_p,
            ctx.softmax_scale,
            ctx.causal,
            ctx.window_size,
            rng_state=rng_state,
        )
        dqkv = dqkv[..., : dout.shape[-1]]  # We could have padded the head dimension
        return dqkv, None, None, None, None, None


class FlashAttnVarlenQKVPackedFunc(torch.autograd.Function):
    @staticmethod
    def forward(
        ctx,
        qkv,
        cu_seqlens,
        max_seqlen,
        dropout_p,
        softmax_scale,
        causal,
        window_size,
        return_softmax,
    ):
        if softmax_scale is None:
            softmax_scale = qkv.shape[-1] ** (-0.5)
        out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = _flash_attn_varlen_forward(
            qkv[:, 0],
            qkv[:, 1],
            qkv[:, 2],
            cu_seqlens,
            cu_seqlens,
            max_seqlen,
            max_seqlen,
            dropout_p,
            softmax_scale,
            causal=causal,
            window_size=window_size,
            return_softmax=return_softmax and dropout_p > 0,
        )
        ctx.save_for_backward(q, k, v, out_padded, softmax_lse, cu_seqlens, rng_state)
        ctx.dropout_p = dropout_p
        ctx.max_seqlen = max_seqlen
        ctx.softmax_scale = softmax_scale
        ctx.causal = causal
        ctx.window_size = window_size
        return out if not return_softmax else (out, softmax_lse, S_dmask)

    @staticmethod
    def backward(ctx, dout, *args):
        q, k, v, out, softmax_lse, cu_seqlens, rng_state = ctx.saved_tensors
        qkv_shape = q.shape[:-2] + (3, *q.shape[-2:])
        dqkv = torch.empty(qkv_shape, dtype=q.dtype, device=q.device)
        _flash_attn_varlen_backward(
            dout,
            q,
            k,
            v,
            out,
            softmax_lse,
            dqkv[:, 0],
            dqkv[:, 1],
            dqkv[:, 2],
            cu_seqlens,
            cu_seqlens,
            ctx.max_seqlen,
            ctx.max_seqlen,
            ctx.dropout_p,
            ctx.softmax_scale,
            ctx.causal,
            ctx.window_size,
            rng_state=rng_state,
        )
        dqkv = dqkv[..., : dout.shape[-1]]  # We could have padded the head dimension
        return dqkv, None, None, None, None, None, None, None


class FlashAttnKVPackedFunc(torch.autograd.Function):
    @staticmethod
    def forward(ctx, q, kv, dropout_p, softmax_scale, causal, window_size, return_softmax):
        if softmax_scale is None:
            softmax_scale = q.shape[-1] ** (-0.5)
        out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = _flash_attn_forward(
            q,
            kv[:, :, 0],
            kv[:, :, 1],
            dropout_p,
            softmax_scale,
            causal=causal,
            window_size=window_size,
            return_softmax=return_softmax and dropout_p > 0,
        )
        ctx.save_for_backward(q, k, v, out_padded, softmax_lse, rng_state)
        ctx.dropout_p = dropout_p
        ctx.softmax_scale = softmax_scale
        ctx.causal = causal
        ctx.window_size = window_size
        return out if not return_softmax else (out, softmax_lse, S_dmask)

    @staticmethod
    def backward(ctx, dout, *args):
        q, k, v, out, softmax_lse, rng_state = ctx.saved_tensors
        dq = torch.empty_like(q)
        kv_shape = k.shape[:-2] + (2, *k.shape[-2:])
        dkv = torch.empty(kv_shape, dtype=k.dtype, device=k.device)
        _flash_attn_backward(
            dout,
            q,
            k,
            v,
            out,
            softmax_lse,
            dq,
            dkv[:, :, 0],
            dkv[:, :, 1],
            ctx.dropout_p,
            ctx.softmax_scale,
            ctx.causal,
            ctx.window_size,
            rng_state=rng_state,
        )
        dq = dq[..., : dout.shape[-1]]  # We could have padded the head dimension
        dkv = dkv[..., : dout.shape[-1]]
        return dq, dkv, None, None, None, None, None


class FlashAttnVarlenKVPackedFunc(torch.autograd.Function):
    @staticmethod
    def forward(
        ctx,
        q,
        kv,
        cu_seqlens_q,
        cu_seqlens_k,
        max_seqlen_q,
        max_seqlen_k,
        dropout_p,
        softmax_scale,
        causal,
        window_size,
        return_softmax,
    ):
        if softmax_scale is None:
            softmax_scale = q.shape[-1] ** (-0.5)
        out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = _flash_attn_varlen_forward(
            q,
            kv[:, 0],
            kv[:, 1],
            cu_seqlens_q,
            cu_seqlens_k,
            max_seqlen_q,
            max_seqlen_k,
            dropout_p,
            softmax_scale,
            causal=causal,
            window_size=window_size,
            return_softmax=return_softmax and dropout_p > 0,
        )
        ctx.save_for_backward(
            q, k, v, out_padded, softmax_lse, cu_seqlens_q, cu_seqlens_k, rng_state
        )
        ctx.dropout_p = dropout_p
        ctx.max_seqlen_q = max_seqlen_q
        ctx.max_seqlen_k = max_seqlen_k
        ctx.softmax_scale = softmax_scale
        ctx.causal = causal
        ctx.window_size = window_size
        return out if not return_softmax else (out, softmax_lse, S_dmask)

    @staticmethod
    def backward(ctx, dout, *args):
        q, k, v, out, softmax_lse, cu_seqlens_q, cu_seqlens_k, rng_state = ctx.saved_tensors
        dq = torch.empty_like(q)
        kv_shape = k.shape[:-2] + (2, *k.shape[-2:])
        dkv = torch.empty(kv_shape, dtype=k.dtype, device=k.device)
        _flash_attn_varlen_backward(
            dout,
            q,
            k,
            v,
            out,
            softmax_lse,
            dq,
            dkv[:, 0],
            dkv[:, 1],
            cu_seqlens_q,
            cu_seqlens_k,
            ctx.max_seqlen_q,
            ctx.max_seqlen_k,
            ctx.dropout_p,
            ctx.softmax_scale,
            ctx.causal,
            ctx.window_size,
            rng_state=rng_state,
        )
        dq = dq[..., : dout.shape[-1]]  # We could have padded the head dimension
        dkv = dkv[..., : dout.shape[-1]]
        return dq, dkv, None, None, None, None, None, None, None, None, None


class FlashAttnFunc(torch.autograd.Function):
    @staticmethod
    def forward(ctx, q, k, v, dropout_p, softmax_scale, causal, window_size, return_softmax):
        if softmax_scale is None:
            softmax_scale = q.shape[-1] ** (-0.5)
        out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = _flash_attn_forward(
            q,
            k,
            v,
            dropout_p,
            softmax_scale,
            causal=causal,
            window_size=window_size,
            return_softmax=return_softmax and dropout_p > 0,
        )
        ctx.save_for_backward(q, k, v, out_padded, softmax_lse, rng_state)
        ctx.dropout_p = dropout_p
        ctx.softmax_scale = softmax_scale
        ctx.causal = causal
        ctx.window_size = window_size
        return out if not return_softmax else (out, softmax_lse, S_dmask)

    @staticmethod
    def backward(ctx, dout, *args):
        q, k, v, out, softmax_lse, rng_state = ctx.saved_tensors
        dq, dk, dv = torch.empty_like(q), torch.empty_like(k), torch.empty_like(v)
        _flash_attn_backward(
            dout,
            q,
            k,
            v,
            out,
            softmax_lse,
            dq,
            dk,
            dv,
            ctx.dropout_p,
            ctx.softmax_scale,
            ctx.causal,
            ctx.window_size,
            rng_state=rng_state,
        )
        dq = dq[..., : dout.shape[-1]]  # We could have padded the head dimension
        dk = dk[..., : dout.shape[-1]]
        dv = dv[..., : dout.shape[-1]]
        return dq, dk, dv, None, None, None, None, None, None, None, None, None


class FlashAttnVarlenFunc(torch.autograd.Function):
    @staticmethod
    def forward(
        ctx,
        q,
        k,
        v,
        cu_seqlens_q,
        cu_seqlens_k,
        max_seqlen_q,
        max_seqlen_k,
        dropout_p,
        softmax_scale,
        causal,
        window_size,
        return_softmax,
    ):
        if softmax_scale is None:
            softmax_scale = q.shape[-1] ** (-0.5)
        out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = _flash_attn_varlen_forward(
            q,
            k,
            v,
            cu_seqlens_q,
            cu_seqlens_k,
            max_seqlen_q,
            max_seqlen_k,
            dropout_p,
            softmax_scale,
            causal=causal,
            window_size=window_size,
            return_softmax=return_softmax and dropout_p > 0,
        )
        ctx.save_for_backward(
            q, k, v, out_padded, softmax_lse, cu_seqlens_q, cu_seqlens_k, rng_state
        )
        ctx.dropout_p = dropout_p
        ctx.max_seqlen_q = max_seqlen_q
        ctx.max_seqlen_k = max_seqlen_k
        ctx.softmax_scale = softmax_scale
        ctx.causal = causal
        ctx.window_size = window_size
        return out if not return_softmax else (out, softmax_lse, S_dmask)

    @staticmethod
    def backward(ctx, dout, *args):
        q, k, v, out, softmax_lse, cu_seqlens_q, cu_seqlens_k, rng_state = ctx.saved_tensors
        dq, dk, dv = torch.empty_like(q), torch.empty_like(k), torch.empty_like(v)
        _flash_attn_varlen_backward(
            dout,
            q,
            k,
            v,
            out,
            softmax_lse,
            dq,
            dk,
            dv,
            cu_seqlens_q,
            cu_seqlens_k,
            ctx.max_seqlen_q,
            ctx.max_seqlen_k,
            ctx.dropout_p,
            ctx.softmax_scale,
            ctx.causal,
            ctx.window_size,
            rng_state=rng_state,
        )
        dq = dq[..., : dout.shape[-1]]  # We could have padded the head dimension
        dk = dk[..., : dout.shape[-1]]
        dv = dv[..., : dout.shape[-1]]
        return dq, dk, dv, None, None, None, None, None, None, None, None, None


def flash_attn_qkvpacked_func(
    qkv,
    dropout_p=0.0,
    softmax_scale=None,
    causal=False,
    window_size=(-1, -1),  # -1 means infinite context window
    return_attn_probs=False,
):
    """dropout_p should be set to 0.0 during evaluation
    If Q, K, V are already stacked into 1 tensor, this function will be faster than
    calling flash_attn_func on Q, K, V since the backward pass avoids explicit concatenation
    of the gradients of Q, K, V.
    For multi-query and grouped-query attention (MQA/GQA), please see
    flash_attn_kvpacked_func and flash_attn_func.

    If window_size != (-1, -1), implements sliding window local attention. Query at position i
    will only attend to keys between [i - window_size[0], i + window_size[1]] inclusive.

    Arguments:
        qkv: (batch_size, seqlen, 3, nheads, headdim)
        dropout_p: float. Dropout probability.
        softmax_scale: float. The scaling of QK^T before applying softmax.
            Default to 1 / sqrt(headdim).
        causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
        window_size: (left, right). If not (-1, -1), implements sliding window local attention.
        return_attn_probs: bool. Whether to return the attention probabilities. This option is for
           testing only. The returned probabilities are not guaranteed to be correct
           (they might not have the right scaling).
    Return:
        out: (batch_size, seqlen, nheads, headdim).
        softmax_lse [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen). The
            logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
            normalization factor).
        S_dmask [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen, seqlen).
            The output of softmax (possibly with different scaling). It also encodes the dropout
            pattern (negative means that location was dropped, nonnegative means it was kept).
    """
    return FlashAttnQKVPackedFunc.apply(
        qkv, dropout_p, softmax_scale, causal, window_size, return_attn_probs
    )
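# Illustrative usage sketch (not part of the library itself): the batch size, sequence length,
# head count, head dimension, and dropout value below are assumptions chosen for demonstration;
# the kernels expect fp16/bf16 CUDA tensors.
#
#     qkv = torch.randn(2, 1024, 3, 16, 64, dtype=torch.float16, device="cuda")
#     out = flash_attn_qkvpacked_func(qkv, dropout_p=0.1, causal=True)
#     # out: (2, 1024, 16, 64)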


def flash_attn_kvpacked_func(
    q,
    kv,
    dropout_p=0.0,
    softmax_scale=None,
    causal=False,
    window_size=(-1, -1),  # -1 means infinite context window
    return_attn_probs=False,
):
    """dropout_p should be set to 0.0 during evaluation
    If K, V are already stacked into 1 tensor, this function will be faster than
    calling flash_attn_func on Q, K, V since the backward pass avoids explicit concatenation
    of the gradients of K, V.
    Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads
    than Q. Note that the number of heads in Q must be divisible by the number of heads in KV.
    For example, if Q has 6 heads and K, V have 2 heads, head 0, 1, 2 of Q will attend to head
    0 of K, V, and head 3, 4, 5 of Q will attend to head 1 of K, V.

    If causal=True, the causal mask is aligned to the bottom right corner of the attention matrix.
    For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is:
        1 1 1 1 0
        1 1 1 1 1
    If seqlen_q = 5 and seqlen_k = 2, the causal mask is:
        0 0
        0 0
        0 0
        1 0
        1 1
    If the row of the mask is all zero, the output will be zero.

    If window_size != (-1, -1), implements sliding window local attention. Query at position i
    will only attend to keys between
    [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive.

    Arguments:
        q: (batch_size, seqlen, nheads, headdim)
        kv: (batch_size, seqlen, 2, nheads_k, headdim)
        dropout_p: float. Dropout probability.
        softmax_scale: float. The scaling of QK^T before applying softmax.
            Default to 1 / sqrt(headdim).
        causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
        window_size: (left, right). If not (-1, -1), implements sliding window local attention.
        return_attn_probs: bool. Whether to return the attention probabilities. This option is for
           testing only. The returned probabilities are not guaranteed to be correct
           (they might not have the right scaling).
    Return:
        out: (batch_size, seqlen, nheads, headdim).
        softmax_lse [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen). The
            logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
            normalization factor).
        S_dmask [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen, seqlen).
            The output of softmax (possibly with different scaling). It also encodes the dropout
            pattern (negative means that location was dropped, nonnegative means it was kept).
    """
    return FlashAttnKVPackedFunc.apply(
        q, kv, dropout_p, softmax_scale, causal, window_size, return_attn_probs
    )
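# Illustrative GQA usage sketch (all sizes are assumptions for demonstration): q has 8 heads,
# the packed kv has 2 heads, so each kv head serves 4 query heads.
#
#     q = torch.randn(2, 1024, 8, 64, dtype=torch.float16, device="cuda")
#     kv = torch.randn(2, 1024, 2, 2, 64, dtype=torch.float16, device="cuda")
#     out = flash_attn_kvpacked_func(q, kv, causal=True)
#     # out: (2, 1024, 8, 64)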


def flash_attn_func(
    q,
    k,
    v,
    dropout_p=0.0,
    softmax_scale=None,
    causal=False,
    window_size=(-1, -1),  # -1 means infinite context window
    return_attn_probs=False,
):
    """dropout_p should be set to 0.0 during evaluation
    Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads
    than Q. Note that the number of heads in Q must be divisible by the number of heads in KV.
    For example, if Q has 6 heads and K, V have 2 heads, head 0, 1, 2 of Q will attend to head
    0 of K, V, and head 3, 4, 5 of Q will attend to head 1 of K, V.

    If causal=True, the causal mask is aligned to the bottom right corner of the attention matrix.
    For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is:
        1 1 1 1 0
        1 1 1 1 1
    If seqlen_q = 5 and seqlen_k = 2, the causal mask is:
        0 0
        0 0
        0 0
        1 0
        1 1
    If the row of the mask is all zero, the output will be zero.

    If window_size != (-1, -1), implements sliding window local attention. Query at position i
    will only attend to keys between
    [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive.

    Arguments:
        q: (batch_size, seqlen, nheads, headdim)
        k: (batch_size, seqlen, nheads_k, headdim)
        v: (batch_size, seqlen, nheads_k, headdim)
        dropout_p: float. Dropout probability.
        softmax_scale: float. The scaling of QK^T before applying softmax.
            Default to 1 / sqrt(headdim).
        causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
        window_size: (left, right). If not (-1, -1), implements sliding window local attention.
        return_attn_probs: bool. Whether to return the attention probabilities. This option is for
           testing only. The returned probabilities are not guaranteed to be correct
           (they might not have the right scaling).
    Return:
        out: (batch_size, seqlen, nheads, headdim).
        softmax_lse [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen). The
            logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
            normalization factor).
        S_dmask [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen, seqlen).
            The output of softmax (possibly with different scaling). It also encodes the dropout
            pattern (negative means that location was dropped, nonnegative means it was kept).
    """
    return FlashAttnFunc.apply(
        q, k, v, dropout_p, softmax_scale, causal, window_size, return_attn_probs
    )
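# Illustrative usage sketch (shapes, dtype, and the sliding-window width are assumptions for
# demonstration): causal attention restricted to a 256-token left window.
#
#     q = torch.randn(2, 2048, 16, 64, dtype=torch.bfloat16, device="cuda")
#     k = torch.randn(2, 2048, 16, 64, dtype=torch.bfloat16, device="cuda")
#     v = torch.randn(2, 2048, 16, 64, dtype=torch.bfloat16, device="cuda")
#     out = flash_attn_func(q, k, v, causal=True, window_size=(256, 0))
#     # out: (2, 2048, 16, 64)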


def flash_attn_varlen_qkvpacked_func(
    qkv,
    cu_seqlens,
    max_seqlen,
    dropout_p=0.0,
    softmax_scale=None,
    causal=False,
    window_size=(-1, -1),  # -1 means infinite context window
    return_attn_probs=False,
):
    """dropout_p should be set to 0.0 during evaluation
    If Q, K, V are already stacked into 1 tensor, this function will be faster than
    calling flash_attn_varlen_func on Q, K, V since the backward pass avoids explicit concatenation
    of the gradients of Q, K, V.
    For multi-query and grouped-query attention (MQA/GQA), please see
    flash_attn_varlen_kvpacked_func and flash_attn_varlen_func.

    If window_size != (-1, -1), implements sliding window local attention. Query at position i
    will only attend to keys between [i - window_size[0], i + window_size[1]] inclusive.

    Arguments:
        qkv: (total, 3, nheads, headdim), where total = total number of tokens in the batch.
        cu_seqlens: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
           of the sequences in the batch, used to index into qkv.
        max_seqlen: int. Maximum sequence length in the batch.
        dropout_p: float. Dropout probability.
        softmax_scale: float. The scaling of QK^T before applying softmax.
            Default to 1 / sqrt(headdim).
        causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
        window_size: (left, right). If not (-1, -1), implements sliding window local attention.
        return_attn_probs: bool. Whether to return the attention probabilities. This option is for
           testing only. The returned probabilities are not guaranteed to be correct
           (they might not have the right scaling).
    Return:
        out: (total, nheads, headdim).
        softmax_lse [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen). The
            logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
            normalization factor).
        S_dmask [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen, seqlen).
            The output of softmax (possibly with different scaling). It also encodes the dropout
            pattern (negative means that location was dropped, nonnegative means it was kept).
    """
    return FlashAttnVarlenQKVPackedFunc.apply(
        qkv,
        cu_seqlens,
        max_seqlen,
        dropout_p,
        softmax_scale,
        causal,
        window_size,
        return_attn_probs,
    )
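# Illustrative usage sketch (the sequence lengths are assumptions): two sequences of lengths
# 5 and 3 packed into one (total, 3, nheads, headdim) tensor and indexed by cu_seqlens.
#
#     qkv = torch.randn(8, 3, 16, 64, dtype=torch.float16, device="cuda")
#     cu_seqlens = torch.tensor([0, 5, 8], dtype=torch.int32, device="cuda")
#     out = flash_attn_varlen_qkvpacked_func(qkv, cu_seqlens, max_seqlen=5, causal=True)
#     # out: (8, 16, 64)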


def flash_attn_varlen_kvpacked_func(
    q,
    kv,
    cu_seqlens_q,
    cu_seqlens_k,
    max_seqlen_q,
    max_seqlen_k,
    dropout_p=0.0,
    softmax_scale=None,
    causal=False,
    window_size=(-1, -1),  # -1 means infinite context window
    return_attn_probs=False,
):
    """dropout_p should be set to 0.0 during evaluation
    If K, V are already stacked into 1 tensor, this function will be faster than
    calling flash_attn_func on Q, K, V since the backward pass avoids explicit concatenation
    of the gradients of K, V.
    Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads
    than Q. Note that the number of heads in Q must be divisible by the number of heads in KV.
    For example, if Q has 6 heads and K, V have 2 heads, head 0, 1, 2 of Q will attend to head
    0 of K, V, and head 3, 4, 5 of Q will attend to head 1 of K, V.

    If causal=True, the causal mask is aligned to the bottom right corner of the attention matrix.
    For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is:
        1 1 1 1 0
        1 1 1 1 1
    If seqlen_q = 5 and seqlen_k = 2, the causal mask is:
        0 0
        0 0
        0 0
        1 0
        1 1
    If the row of the mask is all zero, the output will be zero.

    If window_size != (-1, -1), implements sliding window local attention. Query at position i
    will only attend to keys between
    [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive.

    Arguments:
        q: (total_q, nheads, headdim), where total_q = total number of query tokens in the batch.
        kv: (total_k, 2, nheads_k, headdim), where total_k = total number of key tokens in the batch.
        cu_seqlens_q: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
           of the sequences in the batch, used to index into q.
        cu_seqlens_k: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
           of the sequences in the batch, used to index into kv.
        max_seqlen_q: int. Maximum query sequence length in the batch.
        max_seqlen_k: int. Maximum key sequence length in the batch.
        dropout_p: float. Dropout probability.
        softmax_scale: float. The scaling of QK^T before applying softmax.
            Default to 1 / sqrt(headdim).
        causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
        window_size: (left, right). If not (-1, -1), implements sliding window local attention.
        return_attn_probs: bool. Whether to return the attention probabilities. This option is for
           testing only. The returned probabilities are not guaranteed to be correct
           (they might not have the right scaling).
    Return:
        out: (total, nheads, headdim).
        softmax_lse [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen). The
            logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
            normalization factor).
        S_dmask [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen, seqlen).
            The output of softmax (possibly with different scaling). It also encodes the dropout
            pattern (negative means that location was dropped, nonnegative means it was kept).
    """
    return FlashAttnVarlenKVPackedFunc.apply(
        q,
        kv,
        cu_seqlens_q,
        cu_seqlens_k,
        max_seqlen_q,
        max_seqlen_k,
        dropout_p,
        softmax_scale,
        causal,
        window_size,
        return_attn_probs,
    )
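# Illustrative cross-attention sketch (all sizes are assumptions): 2 query sequences of
# lengths 4 and 2 attend to 2 packed key/value sequences of lengths 10 and 7.
#
#     q = torch.randn(6, 8, 64, dtype=torch.float16, device="cuda")
#     kv = torch.randn(17, 2, 2, 64, dtype=torch.float16, device="cuda")
#     cu_q = torch.tensor([0, 4, 6], dtype=torch.int32, device="cuda")
#     cu_k = torch.tensor([0, 10, 17], dtype=torch.int32, device="cuda")
#     out = flash_attn_varlen_kvpacked_func(q, kv, cu_q, cu_k, max_seqlen_q=4, max_seqlen_k=10)
#     # out: (6, 8, 64)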


def flash_attn_varlen_func(
    q,
    k,
    v,
    cu_seqlens_q,
    cu_seqlens_k,
    max_seqlen_q,
    max_seqlen_k,
    dropout_p=0.0,
    softmax_scale=None,
    causal=False,
    window_size=(-1, -1),  # -1 means infinite context window
    return_attn_probs=False,
):
    """dropout_p should be set to 0.0 during evaluation
    Supports multi-query and grouped-query attention (MQA/GQA) by passing in K, V with fewer heads
    than Q. Note that the number of heads in Q must be divisible by the number of heads in KV.
    For example, if Q has 6 heads and K, V have 2 heads, head 0, 1, 2 of Q will attend to head
    0 of K, V, and head 3, 4, 5 of Q will attend to head 1 of K, V.

    If causal=True, the causal mask is aligned to the bottom right corner of the attention matrix.
    For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is:
        1 1 1 1 0
        1 1 1 1 1
    If seqlen_q = 5 and seqlen_k = 2, the causal mask is:
        0 0
        0 0
        0 0
        1 0
        1 1
    If the row of the mask is all zero, the output will be zero.

    If window_size != (-1, -1), implements sliding window local attention. Query at position i
    will only attend to keys between
    [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive.

    Arguments:
        q: (total_q, nheads, headdim), where total_q = total number of query tokens in the batch.
        k: (total_k, nheads_k, headdim), where total_k = total number of key tokens in the batch.
        v: (total_k, nheads_k, headdim), where total_k = total number of key tokens in the batch.
        cu_seqlens_q: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
           of the sequences in the batch, used to index into q.
        cu_seqlens_k: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
           of the sequences in the batch, used to index into kv.
        max_seqlen_q: int. Maximum query sequence length in the batch.
        max_seqlen_k: int. Maximum key sequence length in the batch.
        dropout_p: float. Dropout probability.
        softmax_scale: float. The scaling of QK^T before applying softmax.
            Default to 1 / sqrt(headdim).
        causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
        window_size: (left, right). If not (-1, -1), implements sliding window local attention.
        return_attn_probs: bool. Whether to return the attention probabilities. This option is for
           testing only. The returned probabilities are not guaranteed to be correct
           (they might not have the right scaling).
    Return:
        out: (total, nheads, headdim).
        softmax_lse [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen). The
            logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
            normalization factor).
        S_dmask [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen, seqlen).
            The output of softmax (possibly with different scaling). It also encodes the dropout
            pattern (negative means that location was dropped, nonnegative means it was kept).
    """
    return FlashAttnVarlenFunc.apply(
        q,
        k,
        v,
        cu_seqlens_q,
        cu_seqlens_k,
        max_seqlen_q,
        max_seqlen_k,
        dropout_p,
        softmax_scale,
        causal,
        window_size,
        return_attn_probs,
    )
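# Illustrative usage sketch (sizes are assumptions): unpadded self-attention over a batch of
# 3 sequences with lengths 7, 1, and 4, so total = 12 tokens and cu_seqlens = [0, 7, 8, 12];
# q uses 16 heads while k/v use 4 heads (GQA).
#
#     q = torch.randn(12, 16, 64, dtype=torch.float16, device="cuda")
#     k = torch.randn(12, 4, 64, dtype=torch.float16, device="cuda")
#     v = torch.randn(12, 4, 64, dtype=torch.float16, device="cuda")
#     cu = torch.tensor([0, 7, 8, 12], dtype=torch.int32, device="cuda")
#     out = flash_attn_varlen_func(q, k, v, cu, cu, max_seqlen_q=7, max_seqlen_k=7, causal=True)
#     # out: (12, 16, 64)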


def flash_attn_with_kvcache(
    q,
    k_cache,
    v_cache,
    k=None,
    v=None,
    rotary_cos=None,
    rotary_sin=None,
    cache_seqlens: Optional[Union[int, torch.Tensor]] = None,
    cache_batch_idx: Optional[torch.Tensor] = None,
    softmax_scale=None,
    causal=False,
    window_size=(-1, -1),  # -1 means infinite context window
    rotary_interleaved=True,
    num_splits=0,
):
    """
    If k and v are not None, k_cache and v_cache will be updated *inplace* with the new values from
    k and v. This is useful for incremental decoding: you can pass in the cached keys/values from
    the previous step, and update them with the new keys/values from the current step, and do
    attention with the updated cache, all in 1 kernel.

    If you pass in k / v, you must make sure that the cache is large enough to hold the new values.
    For example, the KV cache could be pre-allocated with the max sequence length, and you can use
    cache_seqlens to keep track of the current sequence lengths of each sequence in the batch.

    Also apply rotary embedding if rotary_cos and rotary_sin are passed in. The key @k will be
    rotated by rotary_cos and rotary_sin at indices cache_seqlens, cache_seqlens + 1, etc.
    If causal or local (i.e., window_size != (-1, -1)), the query @q will be rotated by rotary_cos
    and rotary_sin at indices cache_seqlens, cache_seqlens + 1, etc.
    If not causal and not local, the query @q will be rotated by rotary_cos and rotary_sin at
    indices cache_seqlens only (i.e. we consider all tokens in @q to be at position cache_seqlens).

    See tests/test_flash_attn.py::test_flash_attn_kvcache for examples of how to use this function.

    Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads
    than Q. Note that the number of heads in Q must be divisible by the number of heads in KV.
    For example, if Q has 6 heads and K, V have 2 heads, head 0, 1, 2 of Q will attend to head
    0 of K, V, and head 3, 4, 5 of Q will attend to head 1 of K, V.

    If causal=True, the causal mask is aligned to the bottom right corner of the attention matrix.
    For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is:
        1 1 1 1 0
        1 1 1 1 1
    If seqlen_q = 5 and seqlen_k = 2, the causal mask is:
        0 0
        0 0
        0 0
        1 0
        1 1
    If the row of the mask is all zero, the output will be zero.

    If window_size != (-1, -1), implements sliding window local attention. Query at position i
    will only attend to keys between
    [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive.

    Note: Does not support backward pass.

    Arguments:
        q: (batch_size, seqlen, nheads, headdim)
        k_cache: (batch_size_cache, seqlen_cache, nheads_k, headdim)
        v_cache: (batch_size_cache, seqlen_cache, nheads_k, headdim)
        k [optional]: (batch_size, seqlen_new, nheads_k, headdim). If not None, we concatenate
            k with k_cache, starting at the indices specified by cache_seqlens.
        v [optional]: (batch_size, seqlen_new, nheads_k, headdim). Similar to k.
        rotary_cos [optional]: (seqlen_ro, rotary_dim / 2). If not None, we apply rotary embedding
            to k and q. Only applicable if k and v are passed in. rotary_dim must be divisible by 16.
        rotary_sin [optional]: (seqlen_ro, rotary_dim / 2). Similar to rotary_cos.
        cache_seqlens: int, or (batch_size,), dtype torch.int32. The sequence lengths of the
            KV cache.
        cache_batch_idx: (batch_size,), dtype torch.int32. The indices used to index into the KV cache.
            If None, we assume that the batch indices are [0, 1, 2, ..., batch_size - 1].
            If the indices are not distinct, and k and v are provided, the values updated in the cache
                 might come from any of the duplicate indices.
        softmax_scale: float. The scaling of QK^T before applying softmax.
            Default to 1 / sqrt(headdim).
        causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
        window_size: (left, right). If not (-1, -1), implements sliding window local attention.
        rotary_interleaved: bool. Only applicable if rotary_cos and rotary_sin are passed in.
            If True, rotary embedding will combine dimensions 0 & 1, 2 & 3, etc. If False,
            rotary embedding will combine dimensions 0 & rotary_dim / 2, 1 & rotary_dim / 2 + 1
            (i.e. GPT-NeoX style).
        num_splits: int. If > 1, split the key/value into this many chunks along the sequence.
           If num_splits == 1, we don't split the key/value. If num_splits == 0, we use a heuristic
           to automatically determine the number of splits.
           Don't change this unless you know what you are doing.

    Return:
        out: (batch_size, seqlen, nheads, headdim).
    """
    assert k_cache.stride(-1) == 1, "k_cache must have contiguous last dimension"
    assert v_cache.stride(-1) == 1, "v_cache must have contiguous last dimension"
    maybe_contiguous = lambda x: x.contiguous() if x is not None and x.stride(-1) != 1 else x
    q, k, v = [maybe_contiguous(x) for x in (q, k, v)]
    if softmax_scale is None:
        softmax_scale = q.shape[-1] ** (-0.5)
    if cache_seqlens is not None and isinstance(cache_seqlens, int):
        cache_seqlens = torch.full(
            (k_cache.shape[0],), cache_seqlens, dtype=torch.int32, device=k_cache.device
        )
        cache_seqlens = maybe_contiguous(cache_seqlens)
    cache_batch_idx = maybe_contiguous(cache_batch_idx)
    out, softmax_lse = flash_attn_cuda.fwd_kvcache(
        q,
        k_cache,
        v_cache,
        k,
        v,
        cache_seqlens,
        rotary_cos,
        rotary_sin,
        cache_batch_idx,
        None,
        softmax_scale,
        causal,
        window_size[0],
        window_size[1],
        rotary_interleaved,
        num_splits,
    )
    return out
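# Illustrative single-step decoding sketch (cache sizes, head counts, and the current sequence
# lengths are assumptions): append one new token per sequence to a pre-allocated KV cache and
# attend against the updated cache in a single call.
#
#     batch, max_cache_len, nheads, nheads_k, d = 2, 4096, 16, 4, 64
#     k_cache = torch.zeros(batch, max_cache_len, nheads_k, d, dtype=torch.float16, device="cuda")
#     v_cache = torch.zeros_like(k_cache)
#     q = torch.randn(batch, 1, nheads, d, dtype=torch.float16, device="cuda")
#     k_new = torch.randn(batch, 1, nheads_k, d, dtype=torch.float16, device="cuda")
#     v_new = torch.randn(batch, 1, nheads_k, d, dtype=torch.float16, device="cuda")
#     cache_seqlens = torch.tensor([10, 37], dtype=torch.int32, device="cuda")  # current lengths
#     out = flash_attn_with_kvcache(q, k_cache, v_cache, k_new, v_new,
#                                   cache_seqlens=cache_seqlens, causal=True)
#     # out: (batch, 1, nheads, d); k_cache and v_cache are updated in place at positions 10 and 37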