# Copyright (c) 2023, Tri Dao.

from typing import Optional, Union

import torch
import torch.nn as nn

# isort: off
# We need to import the CUDA kernels after importing torch
import flash_attn_2_cuda as flash_attn_cuda

# isort: on


def _get_block_size(device, head_dim, is_dropout, is_causal):
    # This should match the block sizes in the CUDA kernel
    assert head_dim <= 256
    major, minor = torch.cuda.get_device_capability(device)
    is_sm8x = major == 8 and minor > 0  # Only include sm86 and sm89, exclude sm80 (A100)
    is_sm80 = major == 8 and minor == 0
    is_sm90 = major == 9 and minor == 0
    if head_dim <= 32:
        return 128, 128
    if head_dim <= 64:
        return (128, 128) if not is_dropout else (128, 64)
    elif head_dim <= 96:
        return (64, 64) if (is_sm8x and is_causal) else (128, 64)
    elif head_dim <= 128:
        if is_sm8x:
            return (64, 64) if (not is_dropout and is_causal) else (128, 32)
        else:
            return 128, (64 if not is_dropout else 32)
    elif head_dim <= 160:
        if is_sm8x:
            return (128, 64) if not is_causal else (64, 64)
        else:
            return 128, 32
    elif head_dim <= 192:
        return (128, 64) if not is_dropout else (64, 64)
    elif head_dim <= 224:
        return (128, 64) if (is_sm80 or is_sm90) else (64, 64)
    elif head_dim <= 256:
        return (128, 64) if is_sm80 else (64, 64)


def _flash_attn_forward(q, k, v, dropout_p, softmax_scale, causal, window_size, return_softmax):
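    # The CUDA kernels require the last (head) dimension to have stride 1, so copy only when needed.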
    maybe_contiguous = lambda x: x.contiguous() if x.stride(-1) != 1 else x
    q, k, v = [maybe_contiguous(x) for x in (q, k, v)]
    out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = flash_attn_cuda.fwd(
        q,
        k,
        v,
        None,
        dropout_p,
        softmax_scale,
        causal,
        window_size[0],
        window_size[1],
        return_softmax,
        None,
    )
    return out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state


def _flash_attn_varlen_forward(
    q,
    k,
    v,
    cu_seqlens_q,
    cu_seqlens_k,
    max_seqlen_q,
    max_seqlen_k,
    dropout_p,
    softmax_scale,
    causal,
    window_size,
    return_softmax,
):
    maybe_contiguous = lambda x: x.contiguous() if x.stride(-1) != 1 else x
    q, k, v = [maybe_contiguous(x) for x in (q, k, v)]
    out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = flash_attn_cuda.varlen_fwd(
        q,
        k,
        v,
        None,
        cu_seqlens_q,
        cu_seqlens_k,
        None,
        max_seqlen_q,
        max_seqlen_k,
        dropout_p,
        softmax_scale,
        False,
        causal,
        window_size[0],
        window_size[1],
        return_softmax,
        None,
    )
    # if out.isnan().any() or softmax_lse.isnan().any():
    #     breakpoint()
    return out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state


def _flash_attn_backward(
    dout,
    q,
    k,
    v,
    out,
    softmax_lse,
    dq,
    dk,
    dv,
    dropout_p,
    softmax_scale,
    causal,
    window_size,
    rng_state=None,
):
    maybe_contiguous = lambda x: x.contiguous() if x.stride(-1) != 1 else x
    # dq, dk, dv are allocated by us so they should already be contiguous
    dout, q, k, v, out = [maybe_contiguous(x) for x in (dout, q, k, v, out)]
    dq, dk, dv, softmax_d = flash_attn_cuda.bwd(
        dout,
        q,
        k,
        v,
        out,
        softmax_lse,
        dq,
        dk,
        dv,
        dropout_p,
        softmax_scale,
        causal,
        window_size[0],
        window_size[1],
        None,
        rng_state,
    )
    return dq, dk, dv, softmax_d


def _flash_attn_varlen_backward(
    dout,
    q,
    k,
    v,
    out,
    softmax_lse,
    dq,
    dk,
    dv,
    cu_seqlens_q,
    cu_seqlens_k,
    max_seqlen_q,
    max_seqlen_k,
    dropout_p,
    softmax_scale,
    causal,
    window_size,
    rng_state=None,
):
    maybe_contiguous = lambda x: x.contiguous() if x.stride(-1) != 1 else x
    # dq, dk, dv are allocated by us so they should already be contiguous
    dout, q, k, v, out = [maybe_contiguous(x) for x in (dout, q, k, v, out)]
    dq, dk, dv, softmax_d = flash_attn_cuda.varlen_bwd(
        dout,
        q,
        k,
        v,
        out,
        softmax_lse,
        dq,
        dk,
        dv,
        cu_seqlens_q,
        cu_seqlens_k,
        max_seqlen_q,
        max_seqlen_k,
        dropout_p,
        softmax_scale,
        False,
        causal,
        window_size[0],
        window_size[1],
        None,
        rng_state,
    )
    # if dq.isnan().any() or dk.isnan().any() or dv.isnan().any() or softmax_d.isnan().any():
    #     breakpoint()
    return dq, dk, dv, softmax_d


class FlashAttnQKVPackedFunc(torch.autograd.Function):
    @staticmethod
    def forward(ctx, qkv, dropout_p, softmax_scale, causal, window_size, return_softmax):
        if softmax_scale is None:
            softmax_scale = qkv.shape[-1] ** (-0.5)
        out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = _flash_attn_forward(
            qkv[:, :, 0],
            qkv[:, :, 1],
            qkv[:, :, 2],
            dropout_p,
            softmax_scale,
            causal=causal,
            window_size=window_size,
            return_softmax=return_softmax and dropout_p > 0,
        )
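        # Save the kernel's (possibly head-dim-padded) outputs together with the RNG state so the
        # backward pass can reproduce the same dropout pattern.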
        ctx.save_for_backward(q, k, v, out_padded, softmax_lse, rng_state)
        ctx.dropout_p = dropout_p
        ctx.softmax_scale = softmax_scale
        ctx.causal = causal
        ctx.window_size = window_size
        return out if not return_softmax else (out, softmax_lse, S_dmask)

    @staticmethod
    def backward(ctx, dout, *args):
        q, k, v, out, softmax_lse, rng_state = ctx.saved_tensors
        qkv_shape = q.shape[:-2] + (3, *q.shape[-2:])
        dqkv = torch.empty(qkv_shape, dtype=q.dtype, device=q.device)
        _flash_attn_backward(
            dout,
            q,
            k,
            v,
            out,
            softmax_lse,
            dqkv[:, :, 0],
            dqkv[:, :, 1],
            dqkv[:, :, 2],
            ctx.dropout_p,
            ctx.softmax_scale,
            ctx.causal,
            ctx.window_size,
            rng_state=rng_state,
        )
        dqkv = dqkv[..., : dout.shape[-1]]  # We could have padded the head dimension
        return dqkv, None, None, None, None, None


class FlashAttnVarlenQKVPackedFunc(torch.autograd.Function):
    @staticmethod
    def forward(
        ctx,
        qkv,
        cu_seqlens,
        max_seqlen,
        dropout_p,
        softmax_scale,
        causal,
        window_size,
        return_softmax,
    ):
        if softmax_scale is None:
            softmax_scale = qkv.shape[-1] ** (-0.5)
        out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = _flash_attn_varlen_forward(
            qkv[:, 0],
            qkv[:, 1],
            qkv[:, 2],
            cu_seqlens,
            cu_seqlens,
            max_seqlen,
            max_seqlen,
            dropout_p,
            softmax_scale,
            causal=causal,
            window_size=window_size,
            return_softmax=return_softmax and dropout_p > 0,
        )
        ctx.save_for_backward(q, k, v, out_padded, softmax_lse, cu_seqlens, rng_state)
        ctx.dropout_p = dropout_p
        ctx.max_seqlen = max_seqlen
        ctx.softmax_scale = softmax_scale
        ctx.causal = causal
        ctx.window_size = window_size
        return out if not return_softmax else (out, softmax_lse, S_dmask)

    @staticmethod
    def backward(ctx, dout, *args):
        q, k, v, out, softmax_lse, cu_seqlens, rng_state = ctx.saved_tensors
        qkv_shape = q.shape[:-2] + (3, *q.shape[-2:])
        dqkv = torch.empty(qkv_shape, dtype=q.dtype, device=q.device)
        _flash_attn_varlen_backward(
            dout,
            q,
            k,
            v,
            out,
            softmax_lse,
            dqkv[:, 0],
            dqkv[:, 1],
            dqkv[:, 2],
            cu_seqlens,
            cu_seqlens,
            ctx.max_seqlen,
            ctx.max_seqlen,
            ctx.dropout_p,
            ctx.softmax_scale,
            ctx.causal,
            ctx.window_size,
            rng_state=rng_state,
        )
        dqkv = dqkv[..., : dout.shape[-1]]  # We could have padded the head dimension
        return dqkv, None, None, None, None, None, None, None


class FlashAttnKVPackedFunc(torch.autograd.Function):
    @staticmethod
    def forward(ctx, q, kv, dropout_p, softmax_scale, causal, window_size, return_softmax):
        if softmax_scale is None:
            softmax_scale = q.shape[-1] ** (-0.5)
        out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = _flash_attn_forward(
            q,
            kv[:, :, 0],
            kv[:, :, 1],
            dropout_p,
            softmax_scale,
            causal=causal,
            window_size=window_size,
            return_softmax=return_softmax and dropout_p > 0,
        )
        ctx.save_for_backward(q, k, v, out_padded, softmax_lse, rng_state)
        ctx.dropout_p = dropout_p
        ctx.softmax_scale = softmax_scale
        ctx.causal = causal
        ctx.window_size = window_size
        return out if not return_softmax else (out, softmax_lse, S_dmask)

    @staticmethod
    def backward(ctx, dout, *args):
        q, k, v, out, softmax_lse, rng_state = ctx.saved_tensors
        dq = torch.empty_like(q)
        kv_shape = k.shape[:-2] + (2, *k.shape[-2:])
        dkv = torch.empty(kv_shape, dtype=k.dtype, device=k.device)
        _flash_attn_backward(
            dout,
            q,
            k,
            v,
            out,
            softmax_lse,
            dq,
            dkv[:, :, 0],
            dkv[:, :, 1],
            ctx.dropout_p,
            ctx.softmax_scale,
            ctx.causal,
            ctx.window_size,
            rng_state=rng_state,
        )
        dq = dq[..., : dout.shape[-1]]  # We could have padded the head dimension
        dkv = dkv[..., : dout.shape[-1]]
        return dq, dkv, None, None, None, None, None


class FlashAttnVarlenKVPackedFunc(torch.autograd.Function):
    @staticmethod
    def forward(
        ctx,
        q,
        kv,
        cu_seqlens_q,
        cu_seqlens_k,
        max_seqlen_q,
        max_seqlen_k,
        dropout_p,
        softmax_scale,
        causal,
        window_size,
        return_softmax,
    ):
        if softmax_scale is None:
            softmax_scale = q.shape[-1] ** (-0.5)
        out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = _flash_attn_varlen_forward(
            q,
            kv[:, 0],
            kv[:, 1],
            cu_seqlens_q,
            cu_seqlens_k,
            max_seqlen_q,
            max_seqlen_k,
            dropout_p,
            softmax_scale,
            causal=causal,
            window_size=window_size,
            return_softmax=return_softmax and dropout_p > 0,
        )
        ctx.save_for_backward(
            q, k, v, out_padded, softmax_lse, cu_seqlens_q, cu_seqlens_k, rng_state
        )
        ctx.dropout_p = dropout_p
        ctx.max_seqlen_q = max_seqlen_q
        ctx.max_seqlen_k = max_seqlen_k
        ctx.softmax_scale = softmax_scale
        ctx.causal = causal
        ctx.window_size = window_size
        return out if not return_softmax else (out, softmax_lse, S_dmask)

    @staticmethod
    def backward(ctx, dout, *args):
        q, k, v, out, softmax_lse, cu_seqlens_q, cu_seqlens_k, rng_state = ctx.saved_tensors
        dq = torch.empty_like(q)
        kv_shape = k.shape[:-2] + (2, *k.shape[-2:])
        dkv = torch.empty(kv_shape, dtype=k.dtype, device=k.device)
        _flash_attn_varlen_backward(
            dout,
            q,
            k,
            v,
            out,
            softmax_lse,
            dq,
            dkv[:, 0],
            dkv[:, 1],
            cu_seqlens_q,
            cu_seqlens_k,
            ctx.max_seqlen_q,
            ctx.max_seqlen_k,
            ctx.dropout_p,
            ctx.softmax_scale,
            ctx.causal,
            ctx.window_size,
            rng_state=rng_state,
        )
        dq = dq[..., : dout.shape[-1]]  # We could have padded the head dimension
        dkv = dkv[..., : dout.shape[-1]]
        return dq, dkv, None, None, None, None, None, None, None, None, None


class FlashAttnFunc(torch.autograd.Function):
    @staticmethod
    def forward(ctx, q, k, v, dropout_p, softmax_scale, causal, window_size, return_softmax):
        if softmax_scale is None:
            softmax_scale = q.shape[-1] ** (-0.5)
        out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = _flash_attn_forward(
            q,
            k,
            v,
            dropout_p,
            softmax_scale,
            causal=causal,
            window_size=window_size,
            return_softmax=return_softmax and dropout_p > 0,
        )
        ctx.save_for_backward(q, k, v, out_padded, softmax_lse, rng_state)
        ctx.dropout_p = dropout_p
        ctx.softmax_scale = softmax_scale
        ctx.causal = causal
        ctx.window_size = window_size
        return out if not return_softmax else (out, softmax_lse, S_dmask)

    @staticmethod
    def backward(ctx, dout, *args):
        q, k, v, out, softmax_lse, rng_state = ctx.saved_tensors
        dq, dk, dv = torch.empty_like(q), torch.empty_like(k), torch.empty_like(v)
        _flash_attn_backward(
            dout,
            q,
            k,
            v,
            out,
            softmax_lse,
            dq,
            dk,
            dv,
            ctx.dropout_p,
            ctx.softmax_scale,
            ctx.causal,
            ctx.window_size,
            rng_state=rng_state,
        )
        dq = dq[..., : dout.shape[-1]]  # We could have padded the head dimension
        dk = dk[..., : dout.shape[-1]]
        dv = dv[..., : dout.shape[-1]]
        return dq, dk, dv, None, None, None, None, None, None, None, None, None


class FlashAttnVarlenFunc(torch.autograd.Function):
    @staticmethod
    def forward(
        ctx,
        q,
        k,
        v,
        cu_seqlens_q,
        cu_seqlens_k,
        max_seqlen_q,
        max_seqlen_k,
        dropout_p,
        softmax_scale,
        causal,
        window_size,
        return_softmax,
    ):
        if softmax_scale is None:
            softmax_scale = q.shape[-1] ** (-0.5)
        out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = _flash_attn_varlen_forward(
            q,
            k,
            v,
            cu_seqlens_q,
            cu_seqlens_k,
            max_seqlen_q,
            max_seqlen_k,
            dropout_p,
            softmax_scale,
            causal=causal,
            window_size=window_size,
            return_softmax=return_softmax and dropout_p > 0,
        )
        ctx.save_for_backward(
            q, k, v, out_padded, softmax_lse, cu_seqlens_q, cu_seqlens_k, rng_state
        )
        ctx.dropout_p = dropout_p
        ctx.max_seqlen_q = max_seqlen_q
        ctx.max_seqlen_k = max_seqlen_k
        ctx.softmax_scale = softmax_scale
        ctx.causal = causal
        ctx.window_size = window_size
        return out if not return_softmax else (out, softmax_lse, S_dmask)

    @staticmethod
    def backward(ctx, dout, *args):
        q, k, v, out, softmax_lse, cu_seqlens_q, cu_seqlens_k, rng_state = ctx.saved_tensors
        dq, dk, dv = torch.empty_like(q), torch.empty_like(k), torch.empty_like(v)
        _flash_attn_varlen_backward(
            dout,
            q,
            k,
            v,
            out,
            softmax_lse,
            dq,
            dk,
            dv,
            cu_seqlens_q,
            cu_seqlens_k,
            ctx.max_seqlen_q,
            ctx.max_seqlen_k,
            ctx.dropout_p,
            ctx.softmax_scale,
            ctx.causal,
            ctx.window_size,
            rng_state=rng_state,
        )
        dq = dq[..., : dout.shape[-1]]  # We could have padded the head dimension
        dk = dk[..., : dout.shape[-1]]
        dv = dv[..., : dout.shape[-1]]
        return dq, dk, dv, None, None, None, None, None, None, None, None, None


def flash_attn_qkvpacked_func(
    qkv,
    dropout_p=0.0,
    softmax_scale=None,
    causal=False,
    window_size=(-1, -1),  # -1 means infinite context window
    return_attn_probs=False,
):
    """dropout_p should be set to 0.0 during evaluation
    If Q, K, V are already stacked into 1 tensor, this function will be faster than
    calling flash_attn_func on Q, K, V since the backward pass avoids explicit concatenation
    of the gradients of Q, K, V.
    For multi-query and grouped-query attention (MQA/GQA), please see
    flash_attn_kvpacked_func and flash_attn_func.

    If window_size != (-1, -1), implements sliding window local attention. Query at position i
    will only attend to keys between [i - window_size[0], i + window_size[1]] inclusive.

    Arguments:
        qkv: (batch_size, seqlen, 3, nheads, headdim)
        dropout_p: float. Dropout probability.
        softmax_scale: float. The scaling of QK^T before applying softmax.
            Default to 1 / sqrt(headdim).
        causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
        window_size: (left, right). If not (-1, -1), implements sliding window local attention.
        return_attn_probs: bool. Whether to return the attention probabilities. This option is for
           testing only. The returned probabilities are not guaranteed to be correct
           (they might not have the right scaling).
    Return:
        out: (batch_size, seqlen, nheads, headdim).
        softmax_lse [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen). The
            logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
            normalization factor).
        S_dmask [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen, seqlen).
            The output of softmax (possibly with different scaling). It also encodes the dropout
            pattern (negative means that location was dropped, nonnegative means it was kept).
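
    Example (illustrative sketch; shapes are arbitrary, requires a CUDA device and fp16/bf16):
        qkv = torch.randn(2, 1024, 3, 16, 64, dtype=torch.float16, device="cuda")
        out = flash_attn_qkvpacked_func(qkv, dropout_p=0.1, causal=True)
        # out: (2, 1024, 16, 64)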
    """
    return FlashAttnQKVPackedFunc.apply(
        qkv, dropout_p, softmax_scale, causal, window_size, return_attn_probs
    )


def flash_attn_kvpacked_func(
    q,
    kv,
    dropout_p=0.0,
    softmax_scale=None,
    causal=False,
    window_size=(-1, -1),  # -1 means infinite context window
    return_attn_probs=False,
):
    """dropout_p should be set to 0.0 during evaluation
    If K, V are already stacked into 1 tensor, this function will be faster than
    calling flash_attn_func on Q, K, V since the backward pass avoids explicit concatenation
    of the gradients of K, V.
    Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads
    than Q. Note that the number of heads in Q must be divisible by the number of heads in KV.
    For example, if Q has 6 heads and K, V have 2 heads, heads 0, 1, 2 of Q will attend to head
    0 of K, V, and heads 3, 4, 5 of Q will attend to head 1 of K, V.

    If causal=True, the causal mask is aligned to the bottom right corner of the attention matrix.
    For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is:
        1 1 1 1 0
        1 1 1 1 1
    If seqlen_q = 5 and seqlen_k = 2, the causal mask is:
        0 0
        0 0
        0 0
        1 0
        1 1
    If the row of the mask is all zero, the output will be zero.

    If window_size != (-1, -1), implements sliding window local attention. Query at position i
    will only attend to keys between
    [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive.

    Arguments:
        q: (batch_size, seqlen, nheads, headdim)
        kv: (batch_size, seqlen, 2, nheads_k, headdim)
        dropout_p: float. Dropout probability.
        softmax_scale: float. The scaling of QK^T before applying softmax.
            Default to 1 / sqrt(headdim).
        causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
        window_size: (left, right). If not (-1, -1), implements sliding window local attention.
        return_attn_probs: bool. Whether to return the attention probabilities. This option is for
           testing only. The returned probabilities are not guaranteed to be correct
           (they might not have the right scaling).
    Return:
        out: (batch_size, seqlen, nheads, headdim).
        softmax_lse [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen). The
            logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
            normalization factor).
        S_dmask [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen, seqlen).
            The output of softmax (possibly with different scaling). It also encodes the dropout
            pattern (negative means that location was dropped, nonnegative means it was kept).
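
    Example (illustrative sketch; shapes are arbitrary, requires a CUDA device and fp16/bf16),
    here GQA with 16 query heads sharing 4 KV heads:
        q = torch.randn(2, 1024, 16, 64, dtype=torch.float16, device="cuda")
        kv = torch.randn(2, 1024, 2, 4, 64, dtype=torch.float16, device="cuda")
        out = flash_attn_kvpacked_func(q, kv, causal=True)
        # out: (2, 1024, 16, 64)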
    """
    return FlashAttnKVPackedFunc.apply(
        q, kv, dropout_p, softmax_scale, causal, window_size, return_attn_probs
    )


def flash_attn_func(
    q,
    k,
    v,
    dropout_p=0.0,
    softmax_scale=None,
    causal=False,
    window_size=(-1, -1),  # -1 means infinite context window
    return_attn_probs=False,
):
    """dropout_p should be set to 0.0 during evaluation
    Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads
    than Q. Note that the number of heads in Q must be divisible by the number of heads in KV.
    For example, if Q has 6 heads and K, V have 2 heads, heads 0, 1, 2 of Q will attend to head
    0 of K, V, and heads 3, 4, 5 of Q will attend to head 1 of K, V.

    If causal=True, the causal mask is aligned to the bottom right corner of the attention matrix.
    For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is:
        1 1 1 1 0
        1 1 1 1 1
    If seqlen_q = 5 and seqlen_k = 2, the causal mask is:
        0 0
        0 0
        0 0
        1 0
        1 1
    If the row of the mask is all zero, the output will be zero.

    If window_size != (-1, -1), implements sliding window local attention. Query at position i
    will only attend to keys between
    [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive.

    Arguments:
        q: (batch_size, seqlen, nheads, headdim)
        k: (batch_size, seqlen, nheads_k, headdim)
        v: (batch_size, seqlen, nheads_k, headdim)
        dropout_p: float. Dropout probability.
        softmax_scale: float. The scaling of QK^T before applying softmax.
            Default to 1 / sqrt(headdim).
        causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
        window_size: (left, right). If not (-1, -1), implements sliding window local attention.
        return_attn_probs: bool. Whether to return the attention probabilities. This option is for
           testing only. The returned probabilities are not guaranteed to be correct
           (they might not have the right scaling).
    Return:
        out: (batch_size, seqlen, nheads, headdim).
        softmax_lse [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen). The
            logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
            normalization factor).
        S_dmask [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen, seqlen).
            The output of softmax (possibly with different scaling). It also encodes the dropout
            pattern (negative means that location was dropped, nonnegative means it was kept).
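
    Example (illustrative sketch; shapes are arbitrary, requires a CUDA device and fp16/bf16),
    here causal attention restricted to a 128-token sliding window on the left:
        q, k, v = [torch.randn(2, 2048, 16, 64, dtype=torch.float16, device="cuda") for _ in range(3)]
        out = flash_attn_func(q, k, v, causal=True, window_size=(128, 0))
        # out: (2, 2048, 16, 64)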
    """
    return FlashAttnFunc.apply(
        q, k, v, dropout_p, softmax_scale, causal, window_size, return_attn_probs
    )


def flash_attn_varlen_qkvpacked_func(
    qkv,
    cu_seqlens,
    max_seqlen,
    dropout_p=0.0,
    softmax_scale=None,
    causal=False,
    window_size=(-1, -1),  # -1 means infinite context window
    return_attn_probs=False,
):
    """dropout_p should be set to 0.0 during evaluation
    If Q, K, V are already stacked into 1 tensor, this function will be faster than
    calling flash_attn_varlen_func on Q, K, V since the backward pass avoids explicit concatenation
    of the gradients of Q, K, V.
    For multi-query and grouped-query attention (MQA/GQA), please see
    flash_attn_varlen_kvpacked_func and flash_attn_varlen_func.

    If window_size != (-1, -1), implements sliding window local attention. Query at position i
    will only attend to keys between [i - window_size[0], i + window_size[1]] inclusive.

    Arguments:
        qkv: (total, 3, nheads, headdim), where total = total number of tokens in the batch.
        cu_seqlens: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
           of the sequences in the batch, used to index into qkv.
        max_seqlen: int. Maximum sequence length in the batch.
        dropout_p: float. Dropout probability.
        softmax_scale: float. The scaling of QK^T before applying softmax.
            Default to 1 / sqrt(headdim).
        causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
        window_size: (left, right). If not (-1, -1), implements sliding window local attention.
        return_attn_probs: bool. Whether to return the attention probabilities. This option is for
           testing only. The returned probabilities are not guaranteed to be correct
           (they might not have the right scaling).
    Return:
        out: (total, nheads, headdim).
        softmax_lse [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen). The
            logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
            normalization factor).
        S_dmask [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen, seqlen).
            The output of softmax (possibly with different scaling). It also encodes the dropout
            pattern (negative means that location was dropped, nonnegative means it was kept).
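
    Example (illustrative sketch; lengths are arbitrary, requires a CUDA device and fp16/bf16),
    here two sequences of lengths 5 and 11 packed into one (total = 16) tensor:
        qkv = torch.randn(16, 3, 16, 64, dtype=torch.float16, device="cuda")
        cu_seqlens = torch.tensor([0, 5, 16], dtype=torch.int32, device="cuda")
        out = flash_attn_varlen_qkvpacked_func(qkv, cu_seqlens, max_seqlen=11, causal=True)
        # out: (16, 16, 64)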
    """
    return FlashAttnVarlenQKVPackedFunc.apply(
        qkv,
        cu_seqlens,
        max_seqlen,
        dropout_p,
        softmax_scale,
        causal,
        window_size,
        return_attn_probs,
    )


def flash_attn_varlen_kvpacked_func(
    q,
    kv,
    cu_seqlens_q,
    cu_seqlens_k,
    max_seqlen_q,
    max_seqlen_k,
    dropout_p=0.0,
    softmax_scale=None,
    causal=False,
    window_size=(-1, -1),  # -1 means infinite context window
    return_attn_probs=False,
):
    """dropout_p should be set to 0.0 during evaluation
    If K, V are already stacked into 1 tensor, this function will be faster than
    calling flash_attn_varlen_func on Q, K, V since the backward pass avoids explicit concatenation
    of the gradients of K, V.
    Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads
    than Q. Note that the number of heads in Q must be divisible by the number of heads in KV.
    For example, if Q has 6 heads and K, V have 2 heads, heads 0, 1, 2 of Q will attend to head
    0 of K, V, and heads 3, 4, 5 of Q will attend to head 1 of K, V.

    If causal=True, the causal mask is aligned to the bottom right corner of the attention matrix.
    For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is:
        1 1 1 1 0
        1 1 1 1 1
    If seqlen_q = 5 and seqlen_k = 2, the causal mask is:
        0 0
        0 0
        0 0
        1 0
        1 1
    If the row of the mask is all zero, the output will be zero.

    If window_size != (-1, -1), implements sliding window local attention. Query at position i
    will only attend to keys between
    [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive.

    Arguments:
        q: (total_q, nheads, headdim), where total_q = total number of query tokens in the batch.
        kv: (total_k, 2, nheads_k, headdim), where total_k = total number of key tokens in the batch.
        cu_seqlens_q: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
           of the sequences in the batch, used to index into q.
        cu_seqlens_k: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
           of the sequences in the batch, used to index into kv.
        max_seqlen_q: int. Maximum query sequence length in the batch.
        max_seqlen_k: int. Maximum key sequence length in the batch.
        dropout_p: float. Dropout probability.
        softmax_scale: float. The scaling of QK^T before applying softmax.
            Default to 1 / sqrt(headdim).
        causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
        window_size: (left, right). If not (-1, -1), implements sliding window local attention.
        return_attn_probs: bool. Whether to return the attention probabilities. This option is for
           testing only. The returned probabilities are not guaranteed to be correct
           (they might not have the right scaling).
    Return:
        out: (total_q, nheads, headdim).
        softmax_lse [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen). The
            logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
            normalization factor).
        S_dmask [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen, seqlen).
            The output of softmax (possibly with different scaling). It also encodes the dropout
            pattern (negative means that location was dropped, nonnegative means it was kept).
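
    Example (illustrative sketch; lengths are arbitrary, requires a CUDA device and fp16/bf16),
    here query sequences of lengths 4 and 6 attending to key sequences of lengths 9 and 15,
    with 16 query heads and 4 KV heads:
        q = torch.randn(10, 16, 64, dtype=torch.float16, device="cuda")
        kv = torch.randn(24, 2, 4, 64, dtype=torch.float16, device="cuda")
        cu_seqlens_q = torch.tensor([0, 4, 10], dtype=torch.int32, device="cuda")
        cu_seqlens_k = torch.tensor([0, 9, 24], dtype=torch.int32, device="cuda")
        out = flash_attn_varlen_kvpacked_func(
            q, kv, cu_seqlens_q, cu_seqlens_k, max_seqlen_q=6, max_seqlen_k=15
        )
        # out: (10, 16, 64)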
    """
    return FlashAttnVarlenKVPackedFunc.apply(
        q,
        kv,
        cu_seqlens_q,
        cu_seqlens_k,
        max_seqlen_q,
        max_seqlen_k,
        dropout_p,
        softmax_scale,
        causal,
        window_size,
        return_attn_probs,
    )


def flash_attn_varlen_func(
    q,
    k,
    v,
    cu_seqlens_q,
    cu_seqlens_k,
    max_seqlen_q,
    max_seqlen_k,
    dropout_p=0.0,
    softmax_scale=None,
    causal=False,
    window_size=(-1, -1),  # -1 means infinite context window
    return_attn_probs=False,
):
    """dropout_p should be set to 0.0 during evaluation
    Supports multi-query and grouped-query attention (MQA/GQA) by passing in K, V with fewer heads
    than Q. Note that the number of heads in Q must be divisible by the number of heads in KV.
    For example, if Q has 6 heads and K, V have 2 heads, heads 0, 1, 2 of Q will attend to head
    0 of K, V, and heads 3, 4, 5 of Q will attend to head 1 of K, V.

    If causal=True, the causal mask is aligned to the bottom right corner of the attention matrix.
    For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is:
        1 1 1 1 0
        1 1 1 1 1
    If seqlen_q = 5 and seqlen_k = 2, the causal mask is:
        0 0
        0 0
        0 0
        1 0
        1 1
    If the row of the mask is all zero, the output will be zero.

    If window_size != (-1, -1), implements sliding window local attention. Query at position i
    will only attend to keys between
    [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive.

    Arguments:
        q: (total_q, nheads, headdim), where total_q = total number of query tokens in the batch.
        k: (total_k, nheads_k, headdim), where total_k = total number of key tokens in the batch.
        v: (total_k, nheads_k, headdim), where total_k = total number of key tokens in the batch.
        cu_seqlens_q: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
           of the sequences in the batch, used to index into q.
        cu_seqlens_k: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
           of the sequences in the batch, used to index into kv.
        max_seqlen_q: int. Maximum query sequence length in the batch.
        max_seqlen_k: int. Maximum key sequence length in the batch.
        dropout_p: float. Dropout probability.
        softmax_scale: float. The scaling of QK^T before applying softmax.
            Default to 1 / sqrt(headdim).
        causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
        window_size: (left, right). If not (-1, -1), implements sliding window local attention.
        return_attn_probs: bool. Whether to return the attention probabilities. This option is for
           testing only. The returned probabilities are not guaranteed to be correct
           (they might not have the right scaling).
    Return:
        out: (total_q, nheads, headdim).
        softmax_lse [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen). The
            logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
            normalization factor).
        S_dmask [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen, seqlen).
            The output of softmax (possibly with different scaling). It also encodes the dropout
            pattern (negative means that location was dropped, nonnegative means it was kept).
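
    Example (illustrative sketch; lengths are arbitrary, requires a CUDA device and fp16/bf16),
    here causal self-attention over two packed sequences of lengths 3 and 7 with 4 KV heads:
        q = torch.randn(10, 16, 64, dtype=torch.float16, device="cuda")
        k = torch.randn(10, 4, 64, dtype=torch.float16, device="cuda")
        v = torch.randn(10, 4, 64, dtype=torch.float16, device="cuda")
        cu_seqlens = torch.tensor([0, 3, 10], dtype=torch.int32, device="cuda")
        out = flash_attn_varlen_func(
            q, k, v, cu_seqlens, cu_seqlens, max_seqlen_q=7, max_seqlen_k=7, causal=True
        )
        # out: (10, 16, 64)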
    """
    return FlashAttnVarlenFunc.apply(
        q,
        k,
        v,
        cu_seqlens_q,
        cu_seqlens_k,
        max_seqlen_q,
        max_seqlen_k,
        dropout_p,
        softmax_scale,
        causal,
        window_size,
        return_attn_probs,
    )


def flash_attn_with_kvcache(
    q,
    k_cache,
    v_cache,
    k=None,
    v=None,
    rotary_cos=None,
    rotary_sin=None,
    cache_seqlens: Optional[Union[int, torch.Tensor]] = None,
    cache_batch_idx: Optional[torch.Tensor] = None,
    softmax_scale=None,
    causal=False,
    window_size=(-1, -1),  # -1 means infinite context window
    rotary_interleaved=True,
    num_splits=0,
):
    """
    If k and v are not None, k_cache and v_cache will be updated *inplace* with the new values from
    k and v. This is useful for incremental decoding: you can pass in the cached keys/values from
    the previous step, and update them with the new keys/values from the current step, and do
    attention with the updated cache, all in 1 kernel.

    If you pass in k / v, you must make sure that the cache is large enough to hold the new values.
    For example, the KV cache could be pre-allocated with the max sequence length, and you can use
    cache_seqlens to keep track of the current sequence lengths of each sequence in the batch.

    Also apply rotary embedding if rotary_cos and rotary_sin are passed in. The key @k will be
    rotated by rotary_cos and rotary_sin at indices cache_seqlens, cache_seqlens + 1, etc.
    If causal or local (i.e., window_size != (-1, -1)), the query @q will be rotated by rotary_cos
    and rotary_sin at indices cache_seqlens, cache_seqlens + 1, etc.
    If not causal and not local, the query @q will be rotated by rotary_cos and rotary_sin at
    indices cache_seqlens only (i.e. we consider all tokens in @q to be at position cache_seqlens).

    See tests/test_flash_attn.py::test_flash_attn_kvcache for examples of how to use this function.

    Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads
    than Q. Note that the number of heads in Q must be divisible by the number of heads in KV.
    For example, if Q has 6 heads and K, V have 2 heads, heads 0, 1, 2 of Q will attend to head
    0 of K, V, and heads 3, 4, 5 of Q will attend to head 1 of K, V.

    If causal=True, the causal mask is aligned to the bottom right corner of the attention matrix.
    For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is:
        1 1 1 1 0
        1 1 1 1 1
    If seqlen_q = 5 and seqlen_k = 2, the causal mask is:
        0 0
        0 0
        0 0
        1 0
        1 1
    If the row of the mask is all zero, the output will be zero.

    If window_size != (-1, -1), implements sliding window local attention. Query at position i
    will only attend to keys between
    [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive.

    Note: Does not support backward pass.

    Arguments:
        q: (batch_size, seqlen, nheads, headdim)
        k_cache: (batch_size_cache, seqlen_cache, nheads_k, headdim)
        v_cache: (batch_size_cache, seqlen_cache, nheads_k, headdim)
        k [optional]: (batch_size, seqlen_new, nheads_k, headdim). If not None, we concatenate
            k with k_cache, starting at the indices specified by cache_seqlens.
        v [optional]: (batch_size, seqlen_new, nheads_k, headdim). Similar to k.
        rotary_cos [optional]: (seqlen_ro, rotary_dim / 2). If not None, we apply rotary embedding
            to k and q. Only applicable if k and v are passed in. rotary_dim must be divisible by 16.
        rotary_sin [optional]: (seqlen_ro, rotary_dim / 2). Similar to rotary_cos.
        cache_seqlens: int, or (batch_size,), dtype torch.int32. The sequence lengths of the
            KV cache.
        cache_batch_idx: (batch_size,), dtype torch.int32. The indices used to index into the KV cache.
            If None, we assume that the batch indices are [0, 1, 2, ..., batch_size - 1].
            If the indices are not distinct, and k and v are provided, the values updated in the cache
                 might come from any of the duplicate indices.
        softmax_scale: float. The scaling of QK^T before applying softmax.
            Default to 1 / sqrt(headdim).
        causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
        window_size: (left, right). If not (-1, -1), implements sliding window local attention.
        rotary_interleaved: bool. Only applicable if rotary_cos and rotary_sin are passed in.
            If True, rotary embedding will combine dimensions 0 & 1, 2 & 3, etc. If False,
            rotary embedding will combine dimensions 0 & rotary_dim / 2, 1 & rotary_dim / 2 + 1
            (i.e. GPT-NeoX style).
        num_splits: int. If > 1, split the key/value into this many chunks along the sequence.
           If num_splits == 1, we don't split the key/value. If num_splits == 0, we use a heuristic
           to automatically determine the number of splits.
           Don't change this unless you know what you are doing.

    Return:
        out: (batch_size, seqlen, nheads, headdim).
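
    Example (illustrative sketch; shapes are arbitrary, requires a CUDA device and fp16/bf16),
    here appending one new token per sequence to a pre-allocated cache and attending against it:
        k_cache = torch.zeros(2, 4096, 4, 64, dtype=torch.float16, device="cuda")
        v_cache = torch.zeros(2, 4096, 4, 64, dtype=torch.float16, device="cuda")
        q = torch.randn(2, 1, 16, 64, dtype=torch.float16, device="cuda")
        k_new = torch.randn(2, 1, 4, 64, dtype=torch.float16, device="cuda")
        v_new = torch.randn(2, 1, 4, 64, dtype=torch.float16, device="cuda")
        cache_seqlens = torch.tensor([17, 30], dtype=torch.int32, device="cuda")
        out = flash_attn_with_kvcache(
            q, k_cache, v_cache, k=k_new, v=v_new, cache_seqlens=cache_seqlens, causal=True
        )
        # out: (2, 1, 16, 64); k_cache and v_cache are updated in place at positions 17 and 30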
    """
    assert k_cache.stride(-1) == 1, "k_cache must have contiguous last dimension"
    assert v_cache.stride(-1) == 1, "v_cache must have contiguous last dimension"
    maybe_contiguous = lambda x: x.contiguous() if x is not None and x.stride(-1) != 1 else x
    q, k, v = [maybe_contiguous(x) for x in (q, k, v)]
    if softmax_scale is None:
        softmax_scale = q.shape[-1] ** (-0.5)
    if cache_seqlens is not None and isinstance(cache_seqlens, int):
        cache_seqlens = torch.full(
            (k_cache.shape[0],), cache_seqlens, dtype=torch.int32, device=k_cache.device
        )
        cache_seqlens = maybe_contiguous(cache_seqlens)
    cache_batch_idx = maybe_contiguous(cache_batch_idx)
    out, softmax_lse = flash_attn_cuda.fwd_kvcache(
        q,
        k_cache,
        v_cache,
        k,
        v,
        cache_seqlens,
        rotary_cos,
        rotary_sin,
        cache_batch_idx,
        None,
        softmax_scale,
        causal,
        window_size[0],
        window_size[1],
        rotary_interleaved,
        num_splits,
    )
    return out