flash_attn_interface.py 40 KB
Newer Older
1
2
# Copyright (c) 2023, Tri Dao.

Tri Dao's avatar
Tri Dao committed
3
4
from typing import Optional, Union

Tri Dao's avatar
Tri Dao committed
5
6
import torch
import torch.nn as nn
Tri Dao's avatar
Tri Dao committed
7

8
9
10
# isort: off
# We need to import the CUDA kernels after importing torch
import flash_attn_2_cuda as flash_attn_cuda
Tri Dao's avatar
Tri Dao committed
11

12
13
# isort: on

Tri Dao's avatar
Tri Dao committed
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45

def _get_block_size(device, head_dim, is_dropout, is_causal):
    # This should match the block sizes in the CUDA kernel
    assert head_dim <= 256
    major, minor = torch.cuda.get_device_capability(device)
    is_sm8x = major == 8 and minor > 0  # Only include sm86 and sm89, exclude sm80 (A100)
    is_sm80 = major == 8 and minor == 0
    is_sm90 = major == 9 and minor == 0
    if head_dim <= 32:
        return 128, 128
    if head_dim <= 64:
        return (128, 128) if not is_dropout else (128, 64)
    elif head_dim <= 96:
        return (64, 64) if (is_sm8x and is_causal) else (128, 64)
    elif head_dim <= 128:
        if is_sm8x:
            return (64, 64) if (not is_dropout and is_causal) else (128, 32)
        else:
            return 128, (64 if not is_dropout else 32)
    elif head_dim <= 160:
        if is_sm8x:
            return (128, 64) if not is_causal else (64, 64)
        else:
            return 128, 32
    elif head_dim <= 192:
        return (128, 64) if not is_dropout else (64, 64)
    elif head_dim <= 224:
        return (128, 64) if (is_sm80 or is_sm90) else (64, 64)
    elif head_dim <= 256:
        return (128, 64) if is_sm80 else (64, 64)


46
def _flash_attn_forward(q, k, v, dropout_p, softmax_scale, causal, window_size, alibi_slopes, return_softmax):
    """Run the fixed-length FlashAttention forward kernel (``flash_attn_cuda.fwd``).

    Q, K and V are copied to a contiguous layout only if their innermost stride is not 1.
    The kernel's 8 outputs are returned unchanged.

    Returns:
        (out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state). The q/k/v returned are
        the tensors handed back by the kernel, not the caller's originals.
    """

    def _last_dim_contiguous(t):
        # Copy only when the last dimension is strided; otherwise pass the tensor through.
        return t if t.stride(-1) == 1 else t.contiguous()

    q, k, v = map(_last_dim_contiguous, (q, k, v))
    left_window, right_window = window_size[0], window_size[1]  # -1 means unbounded
    results = flash_attn_cuda.fwd(
        q,
        k,
        v,
        None,  # NOTE(review): presumably the optional preallocated output buffer — confirm
        dropout_p,
        softmax_scale,
        causal,
        left_window,
        right_window,
        alibi_slopes,
        return_softmax,
        None,  # NOTE(review): presumably the optional RNG generator — confirm
    )
    out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = results
    return out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state
Tri Dao's avatar
Tri Dao committed
64
65


Tri Dao's avatar
Tri Dao committed
66
67
68
69
70
71
72
73
74
75
76
def _flash_attn_varlen_forward(
    q,
    k,
    v,
    cu_seqlens_q,
    cu_seqlens_k,
    max_seqlen_q,
    max_seqlen_k,
    dropout_p,
    softmax_scale,
    causal,
    window_size,
    alibi_slopes,
    return_softmax,
):
    """Run the variable-length FlashAttention forward kernel (``flash_attn_cuda.varlen_fwd``).

    Q, K and V are copied to a contiguous layout only if their innermost stride is not 1.
    Sequence boundaries are given by ``cu_seqlens_q`` / ``cu_seqlens_k``.

    Returns:
        The kernel's 8 outputs unchanged:
        (out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state).
    """

    def _last_dim_contiguous(t):
        # Copy only when the last dimension is strided; otherwise pass the tensor through.
        return t if t.stride(-1) == 1 else t.contiguous()

    q, k, v = map(_last_dim_contiguous, (q, k, v))
    left_window, right_window = window_size[0], window_size[1]  # -1 means unbounded
    results = flash_attn_cuda.varlen_fwd(
        q,
        k,
        v,
        None,  # NOTE(review): presumably the optional preallocated output buffer — confirm
        cu_seqlens_q,
        cu_seqlens_k,
        None,  # NOTE(review): optional extra argument left unset here — confirm its meaning
        max_seqlen_q,
        max_seqlen_k,
        dropout_p,
        softmax_scale,
        False,  # NOTE(review): boolean kernel flag always disabled here — confirm its meaning
        causal,
        left_window,
        right_window,
        alibi_slopes,
        return_softmax,
        None,  # NOTE(review): presumably the optional RNG generator — confirm
    )
    out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = results
    return out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state
Tri Dao's avatar
Tri Dao committed
106
107


Tri Dao's avatar
Tri Dao committed
108
def _flash_attn_backward(
    dout,
    q,
    k,
    v,
    out,
    softmax_lse,
    dq,
    dk,
    dv,
    dropout_p,
    softmax_scale,
    causal,
    window_size,
    alibi_slopes,
    rng_state=None,
):
    """Run the fixed-length FlashAttention backward kernel (``flash_attn_cuda.bwd``).

    ``dq``/``dk``/``dv`` are caller-allocated buffers handed to the kernel. ``rng_state`` is the
    value captured by the forward pass.

    Returns:
        (dq, dk, dv, softmax_d) as returned by the kernel.
    """

    def _last_dim_contiguous(t):
        # Copy only when the last dimension is strided; otherwise pass the tensor through.
        return t if t.stride(-1) == 1 else t.contiguous()

    # dq, dk, dv are allocated by the caller, so they should already be contiguous.
    dout, q, k, v, out = map(_last_dim_contiguous, (dout, q, k, v, out))
    left_window, right_window = window_size[0], window_size[1]  # -1 means unbounded
    grads = flash_attn_cuda.bwd(
        dout,
        q,
        k,
        v,
        out,
        softmax_lse,
        dq,
        dk,
        dv,
        dropout_p,
        softmax_scale,
        causal,
        left_window,
        right_window,
        alibi_slopes,
        None,  # NOTE(review): presumably the optional RNG generator — confirm
        rng_state,
    )
    dq, dk, dv, softmax_d = grads
    return dq, dk, dv, softmax_d


Tri Dao's avatar
Tri Dao committed
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
def _flash_attn_varlen_backward(
    dout,
    q,
    k,
    v,
    out,
    softmax_lse,
    dq,
    dk,
    dv,
    cu_seqlens_q,
    cu_seqlens_k,
    max_seqlen_q,
    max_seqlen_k,
    dropout_p,
    softmax_scale,
    causal,
    window_size,
    alibi_slopes,
    rng_state=None,
):
    """Run the variable-length FlashAttention backward kernel (``flash_attn_cuda.varlen_bwd``).

    ``dq``/``dk``/``dv`` are caller-allocated buffers handed to the kernel. Sequence boundaries
    are given by ``cu_seqlens_q`` / ``cu_seqlens_k``.

    Returns:
        (dq, dk, dv, softmax_d) as returned by the kernel.
    """

    def _last_dim_contiguous(t):
        # Copy only when the last dimension is strided; otherwise pass the tensor through.
        return t if t.stride(-1) == 1 else t.contiguous()

    # dq, dk, dv are allocated by the caller, so they should already be contiguous.
    dout, q, k, v, out = map(_last_dim_contiguous, (dout, q, k, v, out))
    left_window, right_window = window_size[0], window_size[1]  # -1 means unbounded
    grads = flash_attn_cuda.varlen_bwd(
        dout,
        q,
        k,
        v,
        out,
        softmax_lse,
        dq,
        dk,
        dv,
        cu_seqlens_q,
        cu_seqlens_k,
        max_seqlen_q,
        max_seqlen_k,
        dropout_p,
        softmax_scale,
        False,  # NOTE(review): boolean kernel flag always disabled here — confirm its meaning
        causal,
        left_window,
        right_window,
        alibi_slopes,
        None,  # NOTE(review): presumably the optional RNG generator — confirm
        rng_state,
    )
    dq, dk, dv, softmax_d = grads
    return dq, dk, dv, softmax_d
Tri Dao's avatar
Tri Dao committed
201
202


Tri Dao's avatar
Tri Dao committed
203
class FlashAttnQKVPackedFunc(torch.autograd.Function):
    """Autograd function for attention over a packed QKV tensor.

    ``qkv`` is split along dim 2 into Q, K and V. In backward, the three gradients are written
    into slices of a single packed buffer, so no concatenation is needed.
    """

    @staticmethod
    def forward(ctx, qkv, dropout_p, softmax_scale, causal, window_size, alibi_slopes, return_softmax):
        # Default scaling is 1 / sqrt(headdim).
        scale = softmax_scale if softmax_scale is not None else qkv.shape[-1] ** (-0.5)
        # Attention probabilities are only requested from the kernel when dropout is active.
        want_probs = return_softmax and dropout_p > 0
        q_in, k_in, v_in = qkv[:, :, 0], qkv[:, :, 1], qkv[:, :, 2]
        out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = _flash_attn_forward(
            q_in,
            k_in,
            v_in,
            dropout_p,
            scale,
            causal=causal,
            window_size=window_size,
            alibi_slopes=alibi_slopes,
            return_softmax=want_probs,
        )
        ctx.save_for_backward(q, k, v, out_padded, softmax_lse, rng_state)
        ctx.causal = causal
        ctx.dropout_p = dropout_p
        ctx.softmax_scale = scale
        ctx.alibi_slopes = alibi_slopes
        ctx.window_size = window_size
        if return_softmax:
            return out, softmax_lse, S_dmask
        return out

    @staticmethod
    def backward(ctx, dout, *args):
        q, k, v, out_padded, softmax_lse, rng_state = ctx.saved_tensors
        # Packed gradient buffer: q's shape with a size-3 axis inserted before the last two dims.
        grad_qkv = torch.empty(q.shape[:-2] + (3,) + q.shape[-2:], dtype=q.dtype, device=q.device)
        # The backward call fills these views in place; its return value is not needed here.
        _flash_attn_backward(
            dout,
            q,
            k,
            v,
            out_padded,
            softmax_lse,
            grad_qkv[:, :, 0],
            grad_qkv[:, :, 1],
            grad_qkv[:, :, 2],
            ctx.dropout_p,
            ctx.softmax_scale,
            ctx.causal,
            ctx.window_size,
            ctx.alibi_slopes,
            rng_state=rng_state,
        )
        # The head dimension may have been padded; trim back to dout's size.
        grad_qkv = grad_qkv[..., : dout.shape[-1]]
        # One None per non-differentiated forward input after qkv.
        return (grad_qkv,) + (None,) * 6
Tri Dao's avatar
Tri Dao committed
251
252
253
254


class FlashAttnVarlenQKVPackedFunc(torch.autograd.Function):
    """Autograd function for variable-length attention over a packed QKV tensor.

    ``qkv`` is split along dim 1 into Q, K and V. Because all three come from the same packed
    tensor, the same ``cu_seqlens`` / ``max_seqlen`` are passed for both the query and key side.
    """

    @staticmethod
    def forward(
        ctx,
        qkv,
        cu_seqlens,
        max_seqlen,
        dropout_p,
        softmax_scale,
        causal,
        window_size,
        alibi_slopes,
        return_softmax,
    ):
        # Default scaling is 1 / sqrt(headdim).
        scale = softmax_scale if softmax_scale is not None else qkv.shape[-1] ** (-0.5)
        # Attention probabilities are only requested from the kernel when dropout is active.
        want_probs = return_softmax and dropout_p > 0
        q_in, k_in, v_in = qkv[:, 0], qkv[:, 1], qkv[:, 2]
        out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = _flash_attn_varlen_forward(
            q_in,
            k_in,
            v_in,
            cu_seqlens,  # query boundaries
            cu_seqlens,  # key boundaries (identical: packed input)
            max_seqlen,
            max_seqlen,
            dropout_p,
            scale,
            causal=causal,
            window_size=window_size,
            alibi_slopes=alibi_slopes,
            return_softmax=want_probs,
        )
        ctx.save_for_backward(q, k, v, out_padded, softmax_lse, cu_seqlens, rng_state)
        ctx.causal = causal
        ctx.dropout_p = dropout_p
        ctx.max_seqlen = max_seqlen
        ctx.softmax_scale = scale
        ctx.alibi_slopes = alibi_slopes
        ctx.window_size = window_size
        if return_softmax:
            return out, softmax_lse, S_dmask
        return out

    @staticmethod
    def backward(ctx, dout, *args):
        q, k, v, out_padded, softmax_lse, cu_seqlens, rng_state = ctx.saved_tensors
        # Packed gradient buffer: q's shape with a size-3 axis inserted before the last two dims.
        grad_qkv = torch.empty(q.shape[:-2] + (3,) + q.shape[-2:], dtype=q.dtype, device=q.device)
        # The backward call fills these views in place; its return value is not needed here.
        _flash_attn_varlen_backward(
            dout,
            q,
            k,
            v,
            out_padded,
            softmax_lse,
            grad_qkv[:, 0],
            grad_qkv[:, 1],
            grad_qkv[:, 2],
            cu_seqlens,
            cu_seqlens,
            ctx.max_seqlen,
            ctx.max_seqlen,
            ctx.dropout_p,
            ctx.softmax_scale,
            ctx.causal,
            ctx.window_size,
            ctx.alibi_slopes,
            rng_state=rng_state,
        )
        # The head dimension may have been padded; trim back to dout's size.
        grad_qkv = grad_qkv[..., : dout.shape[-1]]
        # One None per non-differentiated forward input after qkv.
        return (grad_qkv,) + (None,) * 8
Tri Dao's avatar
Tri Dao committed
321
322


Tri Dao's avatar
Tri Dao committed
323
class FlashAttnKVPackedFunc(torch.autograd.Function):
    """Autograd function for attention with a separate Q and a packed KV tensor.

    ``kv`` is split along dim 2 into K and V. In backward, dK and dV are written into slices of
    a single packed buffer.
    """

    @staticmethod
    def forward(ctx, q, kv, dropout_p, softmax_scale, causal, window_size, alibi_slopes, return_softmax):
        # Default scaling is 1 / sqrt(headdim), taken from q.
        scale = softmax_scale if softmax_scale is not None else q.shape[-1] ** (-0.5)
        # Attention probabilities are only requested from the kernel when dropout is active.
        want_probs = return_softmax and dropout_p > 0
        k_in, v_in = kv[:, :, 0], kv[:, :, 1]
        out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = _flash_attn_forward(
            q,
            k_in,
            v_in,
            dropout_p,
            scale,
            causal=causal,
            window_size=window_size,
            alibi_slopes=alibi_slopes,
            return_softmax=want_probs,
        )
        ctx.save_for_backward(q, k, v, out_padded, softmax_lse, rng_state)
        ctx.causal = causal
        ctx.dropout_p = dropout_p
        ctx.softmax_scale = scale
        ctx.alibi_slopes = alibi_slopes
        ctx.window_size = window_size
        if return_softmax:
            return out, softmax_lse, S_dmask
        return out

    @staticmethod
    def backward(ctx, dout, *args):
        q, k, v, out_padded, softmax_lse, rng_state = ctx.saved_tensors
        grad_q = torch.empty_like(q)
        # Packed KV gradient: k's shape with a size-2 axis inserted before the last two dims.
        grad_kv = torch.empty(k.shape[:-2] + (2,) + k.shape[-2:], dtype=k.dtype, device=k.device)
        # The backward call fills these buffers in place; its return value is not needed here.
        _flash_attn_backward(
            dout,
            q,
            k,
            v,
            out_padded,
            softmax_lse,
            grad_q,
            grad_kv[:, :, 0],
            grad_kv[:, :, 1],
            ctx.dropout_p,
            ctx.softmax_scale,
            ctx.causal,
            ctx.window_size,
            ctx.alibi_slopes,
            rng_state=rng_state,
        )
        # The head dimension may have been padded; trim back to dout's size.
        head_dim = dout.shape[-1]
        grad_q = grad_q[..., :head_dim]
        grad_kv = grad_kv[..., :head_dim]
        # One None per non-differentiated forward input after q and kv.
        return (grad_q, grad_kv) + (None,) * 6
Tri Dao's avatar
Tri Dao committed
373
374


Tri Dao's avatar
Tri Dao committed
375
class FlashAttnVarlenKVPackedFunc(torch.autograd.Function):
    """Autograd function for variable-length attention with separate Q and packed KV.

    ``kv`` is split along dim 1 into K and V. Query and key sequence boundaries are given
    independently by ``cu_seqlens_q`` / ``cu_seqlens_k``.
    """

    @staticmethod
    def forward(
        ctx,
        q,
        kv,
        cu_seqlens_q,
        cu_seqlens_k,
        max_seqlen_q,
        max_seqlen_k,
        dropout_p,
        softmax_scale,
        causal,
        window_size,
        alibi_slopes,
        return_softmax,
    ):
        # Default scaling is 1 / sqrt(headdim), taken from q.
        scale = softmax_scale if softmax_scale is not None else q.shape[-1] ** (-0.5)
        # Attention probabilities are only requested from the kernel when dropout is active.
        want_probs = return_softmax and dropout_p > 0
        k_in, v_in = kv[:, 0], kv[:, 1]
        out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = _flash_attn_varlen_forward(
            q,
            k_in,
            v_in,
            cu_seqlens_q,
            cu_seqlens_k,
            max_seqlen_q,
            max_seqlen_k,
            dropout_p,
            scale,
            causal=causal,
            window_size=window_size,
            alibi_slopes=alibi_slopes,
            return_softmax=want_probs,
        )
        ctx.save_for_backward(
            q, k, v, out_padded, softmax_lse, cu_seqlens_q, cu_seqlens_k, rng_state
        )
        ctx.causal = causal
        ctx.dropout_p = dropout_p
        ctx.max_seqlen_q, ctx.max_seqlen_k = max_seqlen_q, max_seqlen_k
        ctx.softmax_scale = scale
        ctx.alibi_slopes = alibi_slopes
        ctx.window_size = window_size
        if return_softmax:
            return out, softmax_lse, S_dmask
        return out

    @staticmethod
    def backward(ctx, dout, *args):
        q, k, v, out_padded, softmax_lse, cu_seqlens_q, cu_seqlens_k, rng_state = ctx.saved_tensors
        grad_q = torch.empty_like(q)
        # Packed KV gradient: k's shape with a size-2 axis inserted before the last two dims.
        grad_kv = torch.empty(k.shape[:-2] + (2,) + k.shape[-2:], dtype=k.dtype, device=k.device)
        # The backward call fills these buffers in place; its return value is not needed here.
        _flash_attn_varlen_backward(
            dout,
            q,
            k,
            v,
            out_padded,
            softmax_lse,
            grad_q,
            grad_kv[:, 0],
            grad_kv[:, 1],
            cu_seqlens_q,
            cu_seqlens_k,
            ctx.max_seqlen_q,
            ctx.max_seqlen_k,
            ctx.dropout_p,
            ctx.softmax_scale,
            ctx.causal,
            ctx.window_size,
            ctx.alibi_slopes,
            rng_state=rng_state,
        )
        # The head dimension may have been padded; trim back to dout's size.
        head_dim = dout.shape[-1]
        grad_q = grad_q[..., :head_dim]
        grad_kv = grad_kv[..., :head_dim]
        # One None per non-differentiated forward input after q and kv.
        return (grad_q, grad_kv) + (None,) * 10
Tri Dao's avatar
Tri Dao committed
451
452
453
454


class FlashAttnFunc(torch.autograd.Function):
    """Autograd function for fixed-length attention with separate Q, K and V tensors."""

    @staticmethod
    def forward(ctx, q, k, v, dropout_p, softmax_scale, causal, window_size, alibi_slopes, return_softmax):
        # Default scaling is 1 / sqrt(headdim), taken from q.
        scale = softmax_scale if softmax_scale is not None else q.shape[-1] ** (-0.5)
        # Attention probabilities are only requested from the kernel when dropout is active.
        want_probs = return_softmax and dropout_p > 0
        out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = _flash_attn_forward(
            q,
            k,
            v,
            dropout_p,
            scale,
            causal=causal,
            window_size=window_size,
            alibi_slopes=alibi_slopes,
            return_softmax=want_probs,
        )
        ctx.save_for_backward(q, k, v, out_padded, softmax_lse, rng_state)
        ctx.causal = causal
        ctx.dropout_p = dropout_p
        ctx.softmax_scale = scale
        ctx.alibi_slopes = alibi_slopes
        ctx.window_size = window_size
        if return_softmax:
            return out, softmax_lse, S_dmask
        return out

    @staticmethod
    def backward(ctx, dout, *args):
        q, k, v, out_padded, softmax_lse, rng_state = ctx.saved_tensors
        grad_q, grad_k, grad_v = (torch.empty_like(t) for t in (q, k, v))
        # The backward call fills these buffers in place; its return value is not needed here.
        _flash_attn_backward(
            dout,
            q,
            k,
            v,
            out_padded,
            softmax_lse,
            grad_q,
            grad_k,
            grad_v,
            ctx.dropout_p,
            ctx.softmax_scale,
            ctx.causal,
            ctx.window_size,
            ctx.alibi_slopes,
            rng_state=rng_state,
        )
        # The head dimension may have been padded; trim back to dout's size.
        head_dim = dout.shape[-1]
        grads = tuple(g[..., :head_dim] for g in (grad_q, grad_k, grad_v))
        # One None per non-differentiated forward input after q, k, v.
        return grads + (None,) * 6
Tri Dao's avatar
Tri Dao committed
502
503


Tri Dao's avatar
Tri Dao committed
504
class FlashAttnVarlenFunc(torch.autograd.Function):
    """Autograd function for variable-length attention with separate Q, K and V tensors.

    Query and key sequence boundaries are given by ``cu_seqlens_q`` / ``cu_seqlens_k``.
    """

    @staticmethod
    def forward(
        ctx,
        q,
        k,
        v,
        cu_seqlens_q,
        cu_seqlens_k,
        max_seqlen_q,
        max_seqlen_k,
        dropout_p,
        softmax_scale,
        causal,
        window_size,
        alibi_slopes,
        return_softmax,
    ):
        # Default scaling is 1 / sqrt(headdim), taken from q.
        scale = softmax_scale if softmax_scale is not None else q.shape[-1] ** (-0.5)
        # Attention probabilities are only requested from the kernel when dropout is active.
        want_probs = return_softmax and dropout_p > 0
        out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = _flash_attn_varlen_forward(
            q,
            k,
            v,
            cu_seqlens_q,
            cu_seqlens_k,
            max_seqlen_q,
            max_seqlen_k,
            dropout_p,
            scale,
            causal=causal,
            window_size=window_size,
            alibi_slopes=alibi_slopes,
            return_softmax=want_probs,
        )
        ctx.save_for_backward(
            q, k, v, out_padded, softmax_lse, cu_seqlens_q, cu_seqlens_k, rng_state
        )
        ctx.causal = causal
        ctx.dropout_p = dropout_p
        ctx.max_seqlen_q, ctx.max_seqlen_k = max_seqlen_q, max_seqlen_k
        ctx.softmax_scale = scale
        ctx.alibi_slopes = alibi_slopes
        ctx.window_size = window_size
        if return_softmax:
            return out, softmax_lse, S_dmask
        return out

    @staticmethod
    def backward(ctx, dout, *args):
        q, k, v, out_padded, softmax_lse, cu_seqlens_q, cu_seqlens_k, rng_state = ctx.saved_tensors
        grad_q, grad_k, grad_v = (torch.empty_like(t) for t in (q, k, v))
        # The backward call fills these buffers in place; its return value is not needed here.
        _flash_attn_varlen_backward(
            dout,
            q,
            k,
            v,
            out_padded,
            softmax_lse,
            grad_q,
            grad_k,
            grad_v,
            cu_seqlens_q,
            cu_seqlens_k,
            ctx.max_seqlen_q,
            ctx.max_seqlen_k,
            ctx.dropout_p,
            ctx.softmax_scale,
            ctx.causal,
            ctx.window_size,
            ctx.alibi_slopes,
            rng_state=rng_state,
        )
        # The head dimension may have been padded; trim back to dout's size.
        head_dim = dout.shape[-1]
        grads = tuple(g[..., :head_dim] for g in (grad_q, grad_k, grad_v))
        # One None per non-differentiated forward input after q, k, v.
        return grads + (None,) * 10
580
581


Tri Dao's avatar
Tri Dao committed
582
def flash_attn_qkvpacked_func(
    qkv,
    dropout_p=0.0,
    softmax_scale=None,
    causal=False,
    window_size=(-1, -1),  # -1 means infinite context window
    alibi_slopes=None,
    return_attn_probs=False,
):
    """dropout_p should be set to 0.0 during evaluation
    If Q, K, V are already stacked into 1 tensor, this function will be faster than
    calling flash_attn_func on Q, K, V since the backward pass avoids explicit concatenation
    of the gradients of Q, K, V.
    For multi-query and grouped-query attention (MQA/GQA), please see
    flash_attn_kvpacked_func and flash_attn_func.

    If window_size != (-1, -1), implements sliding window local attention. Query at position i
    will only attend to keys between [i - window_size[0], i + window_size[1]] inclusive.

    Arguments:
        qkv: (batch_size, seqlen, 3, nheads, headdim)
        dropout_p: float. Dropout probability.
        softmax_scale: float. The scaling of QK^T before applying softmax.
            Default to 1 / sqrt(headdim).
        causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
        window_size: (left, right). If not (-1, -1), implements sliding window local attention.
        alibi_slopes: optional tensor of ALiBi slopes, passed through unchanged to the CUDA
            kernel. The default None passes no ALiBi slopes. NOTE(review): the expected shape is
            presumably (nheads,) or (batch_size, nheads) in fp32; confirm against the kernel.
        return_attn_probs: bool. Whether to return the attention probabilities. This option is for
           testing only. The returned probabilities are not guaranteed to be correct
           (they might not have the right scaling).
    Return:
        out: (batch_size, seqlen, nheads, headdim).
        softmax_lse [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen). The
            logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
            normalization factor).
        S_dmask [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen, seqlen).
            The output of softmax (possibly with different scaling). It also encodes the dropout
            pattern (negative means that location was dropped, nonnegative means it was kept).
            Note: the kernel is only asked to produce it when dropout_p > 0.
    """
    return FlashAttnQKVPackedFunc.apply(
        qkv, dropout_p, softmax_scale, causal, window_size, alibi_slopes, return_attn_probs
    )
Tri Dao's avatar
Tri Dao committed
623
624


Tri Dao's avatar
Tri Dao committed
625
def flash_attn_kvpacked_func(
    q,
    kv,
    dropout_p=0.0,
    softmax_scale=None,
    causal=False,
    window_size=(-1, -1),  # -1 means infinite context window
    alibi_slopes=None,
    return_attn_probs=False,
):
    """dropout_p should be set to 0.0 during evaluation
    If K, V are already stacked into 1 tensor, this function will be faster than
    calling flash_attn_func on Q, K, V since the backward pass avoids explicit concatenation
    of the gradients of K, V.
    Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads
    than Q. Note that the number of heads in Q must be divisible by the number of heads in KV.
    For example, if Q has 6 heads and K, V have 2 heads, heads 0, 1, 2 of Q will attend to head
    0 of K, V, and heads 3, 4, 5 of Q will attend to head 1 of K, V.

    If causal=True, the causal mask is aligned to the bottom right corner of the attention matrix.
    For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is:
        1 1 1 1 0
        1 1 1 1 1
    If seqlen_q = 5 and seqlen_k = 2, the causal mask is:
        0 0
        0 0
        0 0
        1 0
        1 1
    If the row of the mask is all zero, the output will be zero.

    If window_size != (-1, -1), implements sliding window local attention. Query at position i
    will only attend to keys between
    [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive.

    Arguments:
        q: (batch_size, seqlen_q, nheads, headdim)
        kv: (batch_size, seqlen_k, 2, nheads_k, headdim)
        dropout_p: float. Dropout probability.
        softmax_scale: float. The scaling of QK^T before applying softmax.
            Default to 1 / sqrt(headdim).
        causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
        window_size: (left, right). If not (-1, -1), implements sliding window local attention.
        alibi_slopes: optional tensor of ALiBi slopes, passed through unchanged to the CUDA
            kernel. The default None passes no ALiBi slopes. NOTE(review): the expected shape is
            presumably (nheads,) or (batch_size, nheads) in fp32; confirm against the kernel.
        return_attn_probs: bool. Whether to return the attention probabilities. This option is for
           testing only. The returned probabilities are not guaranteed to be correct
           (they might not have the right scaling).
    Return:
        out: (batch_size, seqlen_q, nheads, headdim).
        softmax_lse [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen_q). The
            logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
            normalization factor).
        S_dmask [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen_q, seqlen_k).
            The output of softmax (possibly with different scaling). It also encodes the dropout
            pattern (negative means that location was dropped, nonnegative means it was kept).
            Note: the kernel is only asked to produce it when dropout_p > 0.
    """
    return FlashAttnKVPackedFunc.apply(
        q, kv, dropout_p, softmax_scale, causal, window_size, alibi_slopes, return_attn_probs
    )
Tri Dao's avatar
Tri Dao committed
683
684


Tri Dao's avatar
Tri Dao committed
685
def flash_attn_func(
    q,
    k,
    v,
    dropout_p=0.0,
    softmax_scale=None,
    causal=False,
    window_size=(-1, -1),  # -1 means infinite context window
    alibi_slopes=None,
    return_attn_probs=False,
):
    """dropout_p should be set to 0.0 during evaluation
    Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads
    than Q. Note that the number of heads in Q must be divisible by the number of heads in KV.
    For example, if Q has 6 heads and K, V have 2 heads, head 0, 1, 2 of Q will attend to head
    0 of K, V, and head 3, 4, 5 of Q will attend to head 1 of K, V.

    If causal=True, the causal mask is aligned to the bottom right corner of the attention matrix.
    For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is:
        1 1 1 1 0
        1 1 1 1 1
    If seqlen_q = 5 and seqlen_k = 2, the causal mask is:
        0 0
        0 0
        0 0
        1 0
        1 1
    If the row of the mask is all zero, the output will be zero.

    If window_size != (-1, -1), implements sliding window local attention. Query at position i
    will only attend to keys between
    [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive.

    Arguments:
        q: (batch_size, seqlen_q, nheads, headdim)
        k: (batch_size, seqlen_k, nheads_k, headdim)
        v: (batch_size, seqlen_k, nheads_k, headdim)
        dropout_p: float. Dropout probability.
        softmax_scale: float. The scaling of QK^T before applying softmax.
            Default to 1 / sqrt(headdim).
        causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
        window_size: (left, right). If not (-1, -1), implements sliding window local attention.
        alibi_slopes: optional tensor of ALiBi slopes, forwarded unchanged to FlashAttnFunc.
            Defaults to None. NOTE(review): the expected shape/dtype is not visible here;
            presumably (nheads,) or (batch_size, nheads) in fp32 -- confirm against the kernel.
        return_attn_probs: bool. Whether to return the attention probabilities. This option is for
           testing only. The returned probabilities are not guaranteed to be correct
           (they might not have the right scaling).
    Return:
        out: (batch_size, seqlen_q, nheads, headdim).
        softmax_lse [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen_q). The
            logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
            normalization factor).
        S_dmask [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen_q, seqlen_k).
            The output of softmax (possibly with different scaling). It also encodes the dropout
            pattern (negative means that location was dropped, nonnegative means it was kept).
    """
    # Arguments are passed positionally; their order must match FlashAttnFunc.forward.
    return FlashAttnFunc.apply(
        q, k, v, dropout_p, softmax_scale, causal, window_size, alibi_slopes, return_attn_probs
    )
Tri Dao's avatar
Tri Dao committed
742
743


Tri Dao's avatar
Tri Dao committed
744
745
746
747
748
749
750
def flash_attn_varlen_qkvpacked_func(
    qkv,
    cu_seqlens,
    max_seqlen,
    dropout_p=0.0,
    softmax_scale=None,
    causal=False,
    window_size=(-1, -1),  # -1 means infinite context window
    alibi_slopes=None,
    return_attn_probs=False,
):
    """dropout_p should be set to 0.0 during evaluation
    If Q, K, V are already stacked into 1 tensor, this function will be faster than
    calling flash_attn_varlen_func on Q, K, V since the backward pass avoids explicit concatenation
    of the gradients of Q, K, V.
    For multi-query and grouped-query attention (MQA/GQA), please see
    flash_attn_varlen_kvpacked_func and flash_attn_varlen_func.

    If window_size != (-1, -1), implements sliding window local attention. Query at position i
    will only attend to keys between [i - window_size[0], i + window_size[1]] inclusive.

    Arguments:
        qkv: (total, 3, nheads, headdim), where total = total number of tokens in the batch.
        cu_seqlens: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
           of the sequences in the batch, used to index into qkv.
        max_seqlen: int. Maximum sequence length in the batch.
        dropout_p: float. Dropout probability.
        softmax_scale: float. The scaling of QK^T before applying softmax.
            Default to 1 / sqrt(headdim).
        causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
        window_size: (left, right). If not (-1, -1), implements sliding window local attention.
        alibi_slopes: optional tensor of ALiBi slopes, forwarded unchanged to
            FlashAttnVarlenQKVPackedFunc. Defaults to None. NOTE(review): the expected
            shape/dtype is not visible here; presumably (nheads,) or (batch_size, nheads) in
            fp32 -- confirm against the kernel.
        return_attn_probs: bool. Whether to return the attention probabilities. This option is for
           testing only. The returned probabilities are not guaranteed to be correct
           (they might not have the right scaling).
    Return:
        out: (total, nheads, headdim).
        softmax_lse [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen). The
            logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
            normalization factor).
        S_dmask [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen, seqlen).
            The output of softmax (possibly with different scaling). It also encodes the dropout
            pattern (negative means that location was dropped, nonnegative means it was kept).
    """
    # Arguments are passed positionally; their order must match
    # FlashAttnVarlenQKVPackedFunc.forward.
    return FlashAttnVarlenQKVPackedFunc.apply(
        qkv,
        cu_seqlens,
        max_seqlen,
        dropout_p,
        softmax_scale,
        causal,
        window_size,
        alibi_slopes,
        return_attn_probs,
    )
Tri Dao's avatar
Tri Dao committed
798
799


Tri Dao's avatar
Tri Dao committed
800
801
802
803
804
805
806
807
808
809
def flash_attn_varlen_kvpacked_func(
    q,
    kv,
    cu_seqlens_q,
    cu_seqlens_k,
    max_seqlen_q,
    max_seqlen_k,
    dropout_p=0.0,
    softmax_scale=None,
    causal=False,
    window_size=(-1, -1),  # -1 means infinite context window
    alibi_slopes=None,
    return_attn_probs=False,
):
    """dropout_p should be set to 0.0 during evaluation
    If K, V are already stacked into 1 tensor, this function will be faster than
    calling flash_attn_varlen_func on Q, K, V since the backward pass avoids explicit concatenation
    of the gradients of K, V.
    Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads
    than Q. Note that the number of heads in Q must be divisible by the number of heads in KV.
    For example, if Q has 6 heads and K, V have 2 heads, head 0, 1, 2 of Q will attend to head
    0 of K, V, and head 3, 4, 5 of Q will attend to head 1 of K, V.

    If causal=True, the causal mask is aligned to the bottom right corner of the attention matrix.
    For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is:
        1 1 1 1 0
        1 1 1 1 1
    If seqlen_q = 5 and seqlen_k = 2, the causal mask is:
        0 0
        0 0
        0 0
        1 0
        1 1
    If the row of the mask is all zero, the output will be zero.

    If window_size != (-1, -1), implements sliding window local attention. Query at position i
    will only attend to keys between
    [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive.

    Arguments:
        q: (total_q, nheads, headdim), where total_q = total number of query tokens in the batch.
        kv: (total_k, 2, nheads_k, headdim), where total_k = total number of key tokens in the batch.
        cu_seqlens_q: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
           of the sequences in the batch, used to index into q.
        cu_seqlens_k: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
           of the sequences in the batch, used to index into kv.
        max_seqlen_q: int. Maximum query sequence length in the batch.
        max_seqlen_k: int. Maximum key sequence length in the batch.
        dropout_p: float. Dropout probability.
        softmax_scale: float. The scaling of QK^T before applying softmax.
            Default to 1 / sqrt(headdim).
        causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
        window_size: (left, right). If not (-1, -1), implements sliding window local attention.
        alibi_slopes: optional tensor of ALiBi slopes, forwarded unchanged to
            FlashAttnVarlenKVPackedFunc. Defaults to None. NOTE(review): the expected
            shape/dtype is not visible here; presumably (nheads,) or (batch_size, nheads) in
            fp32 -- confirm against the kernel.
        return_attn_probs: bool. Whether to return the attention probabilities. This option is for
           testing only. The returned probabilities are not guaranteed to be correct
           (they might not have the right scaling).
    Return:
        out: (total_q, nheads, headdim).
        softmax_lse [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen). The
            logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
            normalization factor).
        S_dmask [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen, seqlen).
            The output of softmax (possibly with different scaling). It also encodes the dropout
            pattern (negative means that location was dropped, nonnegative means it was kept).
    """
    # Arguments are passed positionally; their order must match
    # FlashAttnVarlenKVPackedFunc.forward.
    return FlashAttnVarlenKVPackedFunc.apply(
        q,
        kv,
        cu_seqlens_q,
        cu_seqlens_k,
        max_seqlen_q,
        max_seqlen_k,
        dropout_p,
        softmax_scale,
        causal,
        window_size,
        alibi_slopes,
        return_attn_probs,
    )
Tri Dao's avatar
Tri Dao committed
879

880

Tri Dao's avatar
Tri Dao committed
881
882
883
884
885
886
887
888
889
890
891
def flash_attn_varlen_func(
    q,
    k,
    v,
    cu_seqlens_q,
    cu_seqlens_k,
    max_seqlen_q,
    max_seqlen_k,
    dropout_p=0.0,
    softmax_scale=None,
    causal=False,
    window_size=(-1, -1),  # -1 means infinite context window
    alibi_slopes=None,
    return_attn_probs=False,
):
    """dropout_p should be set to 0.0 during evaluation
    Supports multi-query and grouped-query attention (MQA/GQA) by passing in K, V with fewer heads
    than Q. Note that the number of heads in Q must be divisible by the number of heads in KV.
    For example, if Q has 6 heads and K, V have 2 heads, head 0, 1, 2 of Q will attend to head
    0 of K, V, and head 3, 4, 5 of Q will attend to head 1 of K, V.

    If causal=True, the causal mask is aligned to the bottom right corner of the attention matrix.
    For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is:
        1 1 1 1 0
        1 1 1 1 1
    If seqlen_q = 5 and seqlen_k = 2, the causal mask is:
        0 0
        0 0
        0 0
        1 0
        1 1
    If the row of the mask is all zero, the output will be zero.

    If window_size != (-1, -1), implements sliding window local attention. Query at position i
    will only attend to keys between
    [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive.

    Arguments:
        q: (total_q, nheads, headdim), where total_q = total number of query tokens in the batch.
        k: (total_k, nheads_k, headdim), where total_k = total number of key tokens in the batch.
        v: (total_k, nheads_k, headdim), where total_k = total number of key tokens in the batch.
        cu_seqlens_q: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
           of the sequences in the batch, used to index into q.
        cu_seqlens_k: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
           of the sequences in the batch, used to index into k and v.
        max_seqlen_q: int. Maximum query sequence length in the batch.
        max_seqlen_k: int. Maximum key sequence length in the batch.
        dropout_p: float. Dropout probability.
        softmax_scale: float. The scaling of QK^T before applying softmax.
            Default to 1 / sqrt(headdim).
        causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
        window_size: (left, right). If not (-1, -1), implements sliding window local attention.
        alibi_slopes: optional tensor of ALiBi slopes, forwarded unchanged to
            FlashAttnVarlenFunc. Defaults to None. NOTE(review): the expected shape/dtype is
            not visible here; presumably (nheads,) or (batch_size, nheads) in fp32 -- confirm
            against the kernel.
        return_attn_probs: bool. Whether to return the attention probabilities. This option is for
           testing only. The returned probabilities are not guaranteed to be correct
           (they might not have the right scaling).
    Return:
        out: (total_q, nheads, headdim).
        softmax_lse [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen). The
            logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
            normalization factor).
        S_dmask [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen, seqlen).
            The output of softmax (possibly with different scaling). It also encodes the dropout
            pattern (negative means that location was dropped, nonnegative means it was kept).
    """
    # Arguments are passed positionally; their order must match FlashAttnVarlenFunc.forward.
    return FlashAttnVarlenFunc.apply(
        q,
        k,
        v,
        cu_seqlens_q,
        cu_seqlens_k,
        max_seqlen_q,
        max_seqlen_k,
        dropout_p,
        softmax_scale,
        causal,
        window_size,
        alibi_slopes,
        return_attn_probs,
    )
Tri Dao's avatar
Tri Dao committed
960
961
962
963
964
965
966
967


def flash_attn_with_kvcache(
    q,
    k_cache,
    v_cache,
    k=None,
    v=None,
    rotary_cos=None,
    rotary_sin=None,
    cache_seqlens: Optional[Union[int, torch.Tensor]] = None,
    cache_batch_idx: Optional[torch.Tensor] = None,
    softmax_scale=None,
    causal=False,
    window_size=(-1, -1),  # -1 means infinite context window
    rotary_interleaved=True,
    num_splits=0,
    alibi_slopes=None,
):
    """
    If k and v are not None, k_cache and v_cache will be updated *inplace* with the new values from
    k and v. This is useful for incremental decoding: you can pass in the cached keys/values from
    the previous step, and update them with the new keys/values from the current step, and do
    attention with the updated cache, all in 1 kernel.

    If you pass in k / v, you must make sure that the cache is large enough to hold the new values.
    For example, the KV cache could be pre-allocated with the max sequence length, and you can use
    cache_seqlens to keep track of the current sequence lengths of each sequence in the batch.

    Also apply rotary embedding if rotary_cos and rotary_sin are passed in. The key @k will be
    rotated by rotary_cos and rotary_sin at indices cache_seqlens, cache_seqlens + 1, etc.
    If causal or local (i.e., window_size != (-1, -1)), the query @q will be rotated by rotary_cos
    and rotary_sin at indices cache_seqlens, cache_seqlens + 1, etc.
    If not causal and not local, the query @q will be rotated by rotary_cos and rotary_sin at
    indices cache_seqlens only (i.e. we consider all tokens in @q to be at position cache_seqlens).

    See tests/test_flash_attn.py::test_flash_attn_kvcache for examples of how to use this function.

    Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads
    than Q. Note that the number of heads in Q must be divisible by the number of heads in KV.
    For example, if Q has 6 heads and K, V have 2 heads, head 0, 1, 2 of Q will attend to head
    0 of K, V, and head 3, 4, 5 of Q will attend to head 1 of K, V.

    If causal=True, the causal mask is aligned to the bottom right corner of the attention matrix.
    For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is:
        1 1 1 1 0
        1 1 1 1 1
    If seqlen_q = 5 and seqlen_k = 2, the causal mask is:
        0 0
        0 0
        0 0
        1 0
        1 1
    If the row of the mask is all zero, the output will be zero.

    If window_size != (-1, -1), implements sliding window local attention. Query at position i
    will only attend to keys between
    [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive.

    Note: Does not support backward pass.

    Arguments:
        q: (batch_size, seqlen, nheads, headdim)
        k_cache: (batch_size_cache, seqlen_cache, nheads_k, headdim)
        v_cache: (batch_size_cache, seqlen_cache, nheads_k, headdim)
        k [optional]: (batch_size, seqlen_new, nheads_k, headdim). If not None, we concatenate
            k with k_cache, starting at the indices specified by cache_seqlens.
        v [optional]: (batch_size, seqlen_new, nheads_k, headdim). Similar to k.
        rotary_cos [optional]: (seqlen_ro, rotary_dim / 2). If not None, we apply rotary embedding
            to k and q. Only applicable if k and v are passed in. rotary_dim must be divisible by 16.
        rotary_sin [optional]: (seqlen_ro, rotary_dim / 2). Similar to rotary_cos.
        cache_seqlens: int, or (batch_size,), dtype torch.int32. The sequence lengths of the
            KV cache. An int is expanded to a (batch_size,) int32 tensor on k_cache's device,
            where batch_size = q.shape[0] (not batch_size_cache, which can differ when
            cache_batch_idx is used).
        cache_batch_idx: (batch_size,), dtype torch.int32. The indices used to index into the KV cache.
            If None, we assume that the batch indices are [0, 1, 2, ..., batch_size - 1].
            If the indices are not distinct, and k and v are provided, the values updated in the cache
                 might come from any of the duplicate indices.
        softmax_scale: float. The scaling of QK^T before applying softmax.
            Default to 1 / sqrt(headdim).
        causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
        window_size: (left, right). If not (-1, -1), implements sliding window local attention.
        rotary_interleaved: bool. Only applicable if rotary_cos and rotary_sin are passed in.
            If True, rotary embedding will combine dimensions 0 & 1, 2 & 3, etc. If False,
            rotary embedding will combine dimensions 0 & rotary_dim / 2, 1 & rotary_dim / 2 + 1
            (i.e. GPT-NeoX style).
        num_splits: int. If > 1, split the key/value into this many chunks along the sequence.
           If num_splits == 1, we don't split the key/value. If num_splits == 0, we use a heuristic
           to automatically determine the number of splits.
           Don't change this unless you know what you are doing.
        alibi_slopes [optional]: tensor of ALiBi slopes, forwarded unchanged to
            flash_attn_cuda.fwd_kvcache. NOTE(review): shape/dtype not visible here; presumably
            (nheads,) or (batch_size, nheads) in fp32 -- confirm against the CUDA extension.

    Return:
        out: (batch_size, seqlen, nheads, headdim).
    """
    # The caches are checked rather than copied: .contiguous() would return a new tensor,
    # and the in-place cache update described above would then be lost to the caller.
    assert k_cache.stride(-1) == 1, "k_cache must have contiguous last dimension"
    assert v_cache.stride(-1) == 1, "v_cache must have contiguous last dimension"

    def maybe_contiguous(x):
        # Copy only when the last dimension is strided; None passes through unchanged.
        return x.contiguous() if x is not None and x.stride(-1) != 1 else x

    q, k, v = [maybe_contiguous(x) for x in (q, k, v)]
    if softmax_scale is None:
        softmax_scale = q.shape[-1] ** (-0.5)
    if isinstance(cache_seqlens, int):
        # One length per query batch element: the documented shape is (batch_size,), and
        # batch_size comes from q, not from the cache (they differ with cache_batch_idx).
        cache_seqlens = torch.full(
            (q.shape[0],), cache_seqlens, dtype=torch.int32, device=k_cache.device
        )
    # Apply to caller-supplied tensors as well, not only to the freshly built one.
    cache_seqlens = maybe_contiguous(cache_seqlens)
    cache_batch_idx = maybe_contiguous(cache_batch_idx)
    out, _ = flash_attn_cuda.fwd_kvcache(
        q,
        k_cache,
        v_cache,
        k,
        v,
        cache_seqlens,
        rotary_cos,
        rotary_sin,
        cache_batch_idx,
        None,
        softmax_scale,
        causal,
        window_size[0],
        window_size[1],
        rotary_interleaved,
        num_splits,
        alibi_slopes
    )
    return out