flash_attn_interface.py 40 KB
Newer Older
1
2
# Copyright (c) 2023, Tri Dao.

Tri Dao's avatar
Tri Dao committed
3
4
from typing import Optional, Union

Tri Dao's avatar
Tri Dao committed
5
6
import torch
import torch.nn as nn
Tri Dao's avatar
Tri Dao committed
7

8
9
10
# isort: off
# We need to import the CUDA kernels after importing torch
import flash_attn_2_cuda as flash_attn_cuda
Tri Dao's avatar
Tri Dao committed
11

12
13
# isort: on

Tri Dao's avatar
Tri Dao committed
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45

def _get_block_size(device, head_dim, is_dropout, is_causal):
    # This should match the block sizes in the CUDA kernel
    assert head_dim <= 256
    major, minor = torch.cuda.get_device_capability(device)
    is_sm8x = major == 8 and minor > 0  # Only include sm86 and sm89, exclude sm80 (A100)
    is_sm80 = major == 8 and minor == 0
    is_sm90 = major == 9 and minor == 0
    if head_dim <= 32:
        return 128, 128
    if head_dim <= 64:
        return (128, 128) if not is_dropout else (128, 64)
    elif head_dim <= 96:
        return (64, 64) if (is_sm8x and is_causal) else (128, 64)
    elif head_dim <= 128:
        if is_sm8x:
            return (64, 64) if (not is_dropout and is_causal) else (128, 32)
        else:
            return 128, (64 if not is_dropout else 32)
    elif head_dim <= 160:
        if is_sm8x:
            return (128, 64) if not is_causal else (64, 64)
        else:
            return 128, 32
    elif head_dim <= 192:
        return (128, 64) if not is_dropout else (64, 64)
    elif head_dim <= 224:
        return (128, 64) if (is_sm80 or is_sm90) else (64, 64)
    elif head_dim <= 256:
        return (128, 64) if is_sm80 else (64, 64)


Tri Dao's avatar
Tri Dao committed
46
47
48
def _flash_attn_forward(
    q, k, v, dropout_p, softmax_scale, causal, window_size, alibi_slopes, return_softmax
):
    """Run the fixed-length (batched) FlashAttention forward kernel.

    Any of ``q``, ``k``, ``v`` whose last dimension is not unit-stride is copied
    with ``.contiguous()`` before the call; the rest are passed through as-is
    (presumably the kernel requires a unit stride in the last dim — confirm).

    Args:
        q, k, v: query / key / value tensors forwarded to ``flash_attn_cuda.fwd``.
        dropout_p: dropout probability.
        softmax_scale: scale applied to QK^T before the softmax.
        causal: whether to apply the causal mask.
        window_size: (left, right) pair; both entries are forwarded separately.
        alibi_slopes: forwarded unchanged to the kernel (may be None).
        return_softmax: forwarded to the kernel to request ``S_dmask``.

    Returns:
        The 8-tuple ``(out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state)``
        produced by the kernel. The ``q, k, v`` returned are the kernel's
        versions and may differ from the inputs (e.g. padded) — TODO confirm.
    """
    q, k, v = [t if t.stride(-1) == 1 else t.contiguous() for t in (q, k, v)]
    (
        out,
        q,
        k,
        v,
        out_padded,
        softmax_lse,
        S_dmask,
        rng_state,
    ) = flash_attn_cuda.fwd(
        q,
        k,
        v,
        None,  # optional argument left unset (presumably a preallocated output) — TODO confirm
        dropout_p,
        softmax_scale,
        causal,
        window_size[0],
        window_size[1],
        alibi_slopes,
        return_softmax,
        None,  # optional argument left unset (presumably an RNG generator) — TODO confirm
    )
    return out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state
Tri Dao's avatar
Tri Dao committed
66
67


Tri Dao's avatar
Tri Dao committed
68
69
70
71
72
73
74
75
76
77
78
def _flash_attn_varlen_forward(
    q,
    k,
    v,
    cu_seqlens_q,
    cu_seqlens_k,
    max_seqlen_q,
    max_seqlen_k,
    dropout_p,
    softmax_scale,
    causal,
    window_size,
    alibi_slopes,
    return_softmax,
):
    """Run the variable-length FlashAttention forward kernel.

    Inputs whose last dimension is not unit-stride are made contiguous first.
    Sequence boundaries are described by ``cu_seqlens_q`` / ``cu_seqlens_k``
    (presumably cumulative offsets into the packed token axis — confirm).

    Args:
        q, k, v: packed query / key / value tensors.
        cu_seqlens_q, cu_seqlens_k: sequence-boundary tensors for queries / keys.
        max_seqlen_q, max_seqlen_k: maximum sequence lengths for queries / keys.
        dropout_p: dropout probability.
        softmax_scale: scale applied to QK^T before the softmax.
        causal: whether to apply the causal mask.
        window_size: (left, right) pair; both entries are forwarded separately.
        alibi_slopes: forwarded unchanged to the kernel (may be None).
        return_softmax: forwarded to the kernel to request ``S_dmask``.

    Returns:
        The 8-tuple ``(out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state)``
        produced by ``flash_attn_cuda.varlen_fwd``.
    """
    q, k, v = [t if t.stride(-1) == 1 else t.contiguous() for t in (q, k, v)]
    (
        out,
        q,
        k,
        v,
        out_padded,
        softmax_lse,
        S_dmask,
        rng_state,
    ) = flash_attn_cuda.varlen_fwd(
        q,
        k,
        v,
        None,  # optional argument left unset (presumably a preallocated output) — TODO confirm
        cu_seqlens_q,
        cu_seqlens_k,
        None,  # optional argument left unset — meaning not established here
        max_seqlen_q,
        max_seqlen_k,
        dropout_p,
        softmax_scale,
        False,  # flag always disabled by this wrapper — meaning not established here
        causal,
        window_size[0],
        window_size[1],
        alibi_slopes,
        return_softmax,
        None,  # optional argument left unset (presumably an RNG generator) — TODO confirm
    )
    return out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state
Tri Dao's avatar
Tri Dao committed
108
109


Tri Dao's avatar
Tri Dao committed
110
def _flash_attn_backward(
    dout,
    q,
    k,
    v,
    out,
    softmax_lse,
    dq,
    dk,
    dv,
    dropout_p,
    softmax_scale,
    causal,
    window_size,
    alibi_slopes,
    rng_state=None,
):
    """Run the fixed-length (batched) FlashAttention backward kernel.

    ``dq``, ``dk``, ``dv`` are caller-provided gradient buffers handed to the
    kernel; the autograd wrappers in this module ignore the return value and
    read the gradients from those buffers, so the kernel presumably writes
    them in place — confirm.

    Args:
        dout: gradient of the attention output.
        q, k, v, out, softmax_lse: tensors saved from the forward pass.
        dq, dk, dv: gradient buffers for q, k, v.
        dropout_p, softmax_scale, causal: same values used in the forward pass.
        window_size: (left, right) pair; both entries are forwarded separately.
        alibi_slopes: forwarded unchanged to the kernel (may be None).
        rng_state: RNG state returned by the forward kernel (may be None).

    Returns:
        ``(dq, dk, dv, softmax_d)`` as returned by ``flash_attn_cuda.bwd``.
    """
    # Only the incoming tensors are normalised; dq/dk/dv are allocated by the
    # callers and should already be contiguous.
    dout, q, k, v, out = (
        t if t.stride(-1) == 1 else t.contiguous() for t in (dout, q, k, v, out)
    )
    grads = flash_attn_cuda.bwd(
        dout,
        q,
        k,
        v,
        out,
        softmax_lse,
        dq,
        dk,
        dv,
        dropout_p,
        softmax_scale,
        causal,
        window_size[0],
        window_size[1],
        alibi_slopes,
        None,  # optional argument left unset (presumably an RNG generator) — TODO confirm
        rng_state,
    )
    dq, dk, dv, softmax_d = grads
    return dq, dk, dv, softmax_d


Tri Dao's avatar
Tri Dao committed
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
def _flash_attn_varlen_backward(
    dout,
    q,
    k,
    v,
    out,
    softmax_lse,
    dq,
    dk,
    dv,
    cu_seqlens_q,
    cu_seqlens_k,
    max_seqlen_q,
    max_seqlen_k,
    dropout_p,
    softmax_scale,
    causal,
    window_size,
    alibi_slopes,
    rng_state=None,
):
    """Run the variable-length FlashAttention backward kernel.

    ``dq``, ``dk``, ``dv`` are caller-provided gradient buffers handed to the
    kernel (presumably written in place — the callers in this module read the
    gradients from them; confirm).

    Args:
        dout: gradient of the attention output.
        q, k, v, out, softmax_lse: tensors saved from the forward pass.
        dq, dk, dv: gradient buffers for q, k, v.
        cu_seqlens_q, cu_seqlens_k: sequence-boundary tensors used in the forward pass.
        max_seqlen_q, max_seqlen_k: maximum sequence lengths used in the forward pass.
        dropout_p, softmax_scale, causal: same values used in the forward pass.
        window_size: (left, right) pair; both entries are forwarded separately.
        alibi_slopes: forwarded unchanged to the kernel (may be None).
        rng_state: RNG state returned by the forward kernel (may be None).

    Returns:
        ``(dq, dk, dv, softmax_d)`` as returned by ``flash_attn_cuda.varlen_bwd``.
    """
    # Only the incoming tensors are normalised; dq/dk/dv are allocated by the
    # callers and should already be contiguous.
    dout, q, k, v, out = (
        t if t.stride(-1) == 1 else t.contiguous() for t in (dout, q, k, v, out)
    )
    grads = flash_attn_cuda.varlen_bwd(
        dout,
        q,
        k,
        v,
        out,
        softmax_lse,
        dq,
        dk,
        dv,
        cu_seqlens_q,
        cu_seqlens_k,
        max_seqlen_q,
        max_seqlen_k,
        dropout_p,
        softmax_scale,
        False,  # flag always disabled by this wrapper — meaning not established here
        causal,
        window_size[0],
        window_size[1],
        alibi_slopes,
        None,  # optional argument left unset (presumably an RNG generator) — TODO confirm
        rng_state,
    )
    dq, dk, dv, softmax_d = grads
    return dq, dk, dv, softmax_d
Tri Dao's avatar
Tri Dao committed
203
204


Tri Dao's avatar
Tri Dao committed
205
class FlashAttnQKVPackedFunc(torch.autograd.Function):
    """Autograd function for attention over a single packed QKV tensor.

    q, k, v are taken as ``qkv[:, :, 0]``, ``qkv[:, :, 1]``, ``qkv[:, :, 2]``.
    In backward, gradients go into one packed buffer, so no concatenation of
    dq/dk/dv is needed.
    """

    @staticmethod
    def forward(
        ctx, qkv, dropout_p, softmax_scale, causal, window_size, alibi_slopes, return_softmax
    ):
        scale = qkv.shape[-1] ** (-0.5) if softmax_scale is None else softmax_scale
        # The kernel is only asked for the softmax/dropout mask when dropout is active.
        want_dmask = return_softmax and dropout_p > 0
        out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = _flash_attn_forward(
            qkv[:, :, 0],
            qkv[:, :, 1],
            qkv[:, :, 2],
            dropout_p,
            scale,
            causal=causal,
            window_size=window_size,
            alibi_slopes=alibi_slopes,
            return_softmax=want_dmask,
        )
        ctx.save_for_backward(q, k, v, out_padded, softmax_lse, rng_state)
        ctx.dropout_p = dropout_p
        ctx.softmax_scale = scale
        ctx.causal = causal
        ctx.window_size = window_size
        ctx.alibi_slopes = alibi_slopes
        if return_softmax:
            return out, softmax_lse, S_dmask
        return out

    @staticmethod
    def backward(ctx, dout, *args):
        # *args: gradients of the optional extra outputs; not used.
        q, k, v, out, softmax_lse, rng_state = ctx.saved_tensors
        # Packed gradient buffer: a size-3 axis inserted before the last two dims of q.
        grad_shape = (*q.shape[:-2], 3, *q.shape[-2:])
        dqkv = torch.empty(grad_shape, dtype=q.dtype, device=q.device)
        _flash_attn_backward(
            dout,
            q,
            k,
            v,
            out,
            softmax_lse,
            dqkv[:, :, 0],
            dqkv[:, :, 1],
            dqkv[:, :, 2],
            ctx.dropout_p,
            ctx.softmax_scale,
            ctx.causal,
            ctx.window_size,
            ctx.alibi_slopes,
            rng_state=rng_state,
        )
        # Trim any head-dimension padding back to the size of dout.
        head_dim = dout.shape[-1]
        return (dqkv[..., :head_dim],) + (None,) * 6
Tri Dao's avatar
Tri Dao committed
255
256
257
258


class FlashAttnVarlenQKVPackedFunc(torch.autograd.Function):
    """Autograd function for variable-length attention over a packed QKV tensor.

    q, k, v are taken as ``qkv[:, 0]``, ``qkv[:, 1]``, ``qkv[:, 2]``. The same
    ``cu_seqlens`` / ``max_seqlen`` describe both queries and keys.
    """

    @staticmethod
    def forward(
        ctx,
        qkv,
        cu_seqlens,
        max_seqlen,
        dropout_p,
        softmax_scale,
        causal,
        window_size,
        alibi_slopes,
        return_softmax,
    ):
        scale = qkv.shape[-1] ** (-0.5) if softmax_scale is None else softmax_scale
        # The kernel is only asked for the softmax/dropout mask when dropout is active.
        want_dmask = return_softmax and dropout_p > 0
        out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = _flash_attn_varlen_forward(
            qkv[:, 0],
            qkv[:, 1],
            qkv[:, 2],
            cu_seqlens,
            cu_seqlens,
            max_seqlen,
            max_seqlen,
            dropout_p,
            scale,
            causal=causal,
            window_size=window_size,
            alibi_slopes=alibi_slopes,
            return_softmax=want_dmask,
        )
        ctx.save_for_backward(q, k, v, out_padded, softmax_lse, cu_seqlens, rng_state)
        ctx.dropout_p = dropout_p
        ctx.max_seqlen = max_seqlen
        ctx.softmax_scale = scale
        ctx.causal = causal
        ctx.window_size = window_size
        ctx.alibi_slopes = alibi_slopes
        if return_softmax:
            return out, softmax_lse, S_dmask
        return out

    @staticmethod
    def backward(ctx, dout, *args):
        # *args: gradients of the optional extra outputs; not used.
        q, k, v, out, softmax_lse, cu_seqlens, rng_state = ctx.saved_tensors
        # Packed gradient buffer: a size-3 axis inserted before the last two dims of q.
        grad_shape = (*q.shape[:-2], 3, *q.shape[-2:])
        dqkv = torch.empty(grad_shape, dtype=q.dtype, device=q.device)
        _flash_attn_varlen_backward(
            dout,
            q,
            k,
            v,
            out,
            softmax_lse,
            dqkv[:, 0],
            dqkv[:, 1],
            dqkv[:, 2],
            cu_seqlens,
            cu_seqlens,
            ctx.max_seqlen,
            ctx.max_seqlen,
            ctx.dropout_p,
            ctx.softmax_scale,
            ctx.causal,
            ctx.window_size,
            ctx.alibi_slopes,
            rng_state=rng_state,
        )
        # Trim any head-dimension padding back to the size of dout.
        head_dim = dout.shape[-1]
        return (dqkv[..., :head_dim],) + (None,) * 8
Tri Dao's avatar
Tri Dao committed
325
326


Tri Dao's avatar
Tri Dao committed
327
class FlashAttnKVPackedFunc(torch.autograd.Function):
    """Autograd function for attention with a separate q and a packed KV tensor.

    k and v are taken as ``kv[:, :, 0]`` and ``kv[:, :, 1]``. Their gradients
    go into one packed buffer in backward.
    """

    @staticmethod
    def forward(
        ctx, q, kv, dropout_p, softmax_scale, causal, window_size, alibi_slopes, return_softmax
    ):
        scale = q.shape[-1] ** (-0.5) if softmax_scale is None else softmax_scale
        # The kernel is only asked for the softmax/dropout mask when dropout is active.
        want_dmask = return_softmax and dropout_p > 0
        out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = _flash_attn_forward(
            q,
            kv[:, :, 0],
            kv[:, :, 1],
            dropout_p,
            scale,
            causal=causal,
            window_size=window_size,
            alibi_slopes=alibi_slopes,
            return_softmax=want_dmask,
        )
        ctx.save_for_backward(q, k, v, out_padded, softmax_lse, rng_state)
        ctx.dropout_p = dropout_p
        ctx.softmax_scale = scale
        ctx.causal = causal
        ctx.window_size = window_size
        ctx.alibi_slopes = alibi_slopes
        if return_softmax:
            return out, softmax_lse, S_dmask
        return out

    @staticmethod
    def backward(ctx, dout, *args):
        # *args: gradients of the optional extra outputs; not used.
        q, k, v, out, softmax_lse, rng_state = ctx.saved_tensors
        dq = torch.empty_like(q)
        # Packed KV gradient buffer: a size-2 axis inserted before the last two dims of k.
        grad_kv_shape = (*k.shape[:-2], 2, *k.shape[-2:])
        dkv = torch.empty(grad_kv_shape, dtype=k.dtype, device=k.device)
        _flash_attn_backward(
            dout,
            q,
            k,
            v,
            out,
            softmax_lse,
            dq,
            dkv[:, :, 0],
            dkv[:, :, 1],
            ctx.dropout_p,
            ctx.softmax_scale,
            ctx.causal,
            ctx.window_size,
            ctx.alibi_slopes,
            rng_state=rng_state,
        )
        # Trim any head-dimension padding back to the size of dout.
        head_dim = dout.shape[-1]
        return (dq[..., :head_dim], dkv[..., :head_dim]) + (None,) * 6
Tri Dao's avatar
Tri Dao committed
379
380


Tri Dao's avatar
Tri Dao committed
381
class FlashAttnVarlenKVPackedFunc(torch.autograd.Function):
    """Autograd function for variable-length attention with a packed KV tensor.

    k and v are taken as ``kv[:, 0]`` and ``kv[:, 1]``. Queries and keys have
    their own boundary tensors and maximum lengths.
    """

    @staticmethod
    def forward(
        ctx,
        q,
        kv,
        cu_seqlens_q,
        cu_seqlens_k,
        max_seqlen_q,
        max_seqlen_k,
        dropout_p,
        softmax_scale,
        causal,
        window_size,
        alibi_slopes,
        return_softmax,
    ):
        scale = q.shape[-1] ** (-0.5) if softmax_scale is None else softmax_scale
        # The kernel is only asked for the softmax/dropout mask when dropout is active.
        want_dmask = return_softmax and dropout_p > 0
        out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = _flash_attn_varlen_forward(
            q,
            kv[:, 0],
            kv[:, 1],
            cu_seqlens_q,
            cu_seqlens_k,
            max_seqlen_q,
            max_seqlen_k,
            dropout_p,
            scale,
            causal=causal,
            window_size=window_size,
            alibi_slopes=alibi_slopes,
            return_softmax=want_dmask,
        )
        ctx.save_for_backward(
            q, k, v, out_padded, softmax_lse, cu_seqlens_q, cu_seqlens_k, rng_state
        )
        ctx.dropout_p = dropout_p
        ctx.max_seqlen_q = max_seqlen_q
        ctx.max_seqlen_k = max_seqlen_k
        ctx.softmax_scale = scale
        ctx.causal = causal
        ctx.window_size = window_size
        ctx.alibi_slopes = alibi_slopes
        if return_softmax:
            return out, softmax_lse, S_dmask
        return out

    @staticmethod
    def backward(ctx, dout, *args):
        # *args: gradients of the optional extra outputs; not used.
        q, k, v, out, softmax_lse, cu_seqlens_q, cu_seqlens_k, rng_state = ctx.saved_tensors
        dq = torch.empty_like(q)
        # Packed KV gradient buffer: a size-2 axis inserted before the last two dims of k.
        grad_kv_shape = (*k.shape[:-2], 2, *k.shape[-2:])
        dkv = torch.empty(grad_kv_shape, dtype=k.dtype, device=k.device)
        _flash_attn_varlen_backward(
            dout,
            q,
            k,
            v,
            out,
            softmax_lse,
            dq,
            dkv[:, 0],
            dkv[:, 1],
            cu_seqlens_q,
            cu_seqlens_k,
            ctx.max_seqlen_q,
            ctx.max_seqlen_k,
            ctx.dropout_p,
            ctx.softmax_scale,
            ctx.causal,
            ctx.window_size,
            ctx.alibi_slopes,
            rng_state=rng_state,
        )
        # Trim any head-dimension padding back to the size of dout.
        head_dim = dout.shape[-1]
        return (dq[..., :head_dim], dkv[..., :head_dim]) + (None,) * 10
Tri Dao's avatar
Tri Dao committed
457
458
459
460


class FlashAttnFunc(torch.autograd.Function):
    """Autograd function for attention with separate q, k, v tensors."""

    @staticmethod
    def forward(
        ctx, q, k, v, dropout_p, softmax_scale, causal, window_size, alibi_slopes, return_softmax
    ):
        scale = q.shape[-1] ** (-0.5) if softmax_scale is None else softmax_scale
        # The kernel is only asked for the softmax/dropout mask when dropout is active.
        want_dmask = return_softmax and dropout_p > 0
        out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = _flash_attn_forward(
            q,
            k,
            v,
            dropout_p,
            scale,
            causal=causal,
            window_size=window_size,
            alibi_slopes=alibi_slopes,
            return_softmax=want_dmask,
        )
        ctx.save_for_backward(q, k, v, out_padded, softmax_lse, rng_state)
        ctx.dropout_p = dropout_p
        ctx.softmax_scale = scale
        ctx.causal = causal
        ctx.window_size = window_size
        ctx.alibi_slopes = alibi_slopes
        if return_softmax:
            return out, softmax_lse, S_dmask
        return out

    @staticmethod
    def backward(ctx, dout, *args):
        # *args: gradients of the optional extra outputs; not used.
        q, k, v, out, softmax_lse, rng_state = ctx.saved_tensors
        dq = torch.empty_like(q)
        dk = torch.empty_like(k)
        dv = torch.empty_like(v)
        _flash_attn_backward(
            dout,
            q,
            k,
            v,
            out,
            softmax_lse,
            dq,
            dk,
            dv,
            ctx.dropout_p,
            ctx.softmax_scale,
            ctx.causal,
            ctx.window_size,
            ctx.alibi_slopes,
            rng_state=rng_state,
        )
        # Trim any head-dimension padding back to the size of dout.
        head_dim = dout.shape[-1]
        trimmed = tuple(g[..., :head_dim] for g in (dq, dk, dv))
        return trimmed + (None,) * 6
Tri Dao's avatar
Tri Dao committed
510
511


Tri Dao's avatar
Tri Dao committed
512
class FlashAttnVarlenFunc(torch.autograd.Function):
    """Autograd function for variable-length attention with separate q, k, v."""

    @staticmethod
    def forward(
        ctx,
        q,
        k,
        v,
        cu_seqlens_q,
        cu_seqlens_k,
        max_seqlen_q,
        max_seqlen_k,
        dropout_p,
        softmax_scale,
        causal,
        window_size,
        alibi_slopes,
        return_softmax,
    ):
        scale = q.shape[-1] ** (-0.5) if softmax_scale is None else softmax_scale
        # The kernel is only asked for the softmax/dropout mask when dropout is active.
        want_dmask = return_softmax and dropout_p > 0
        out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = _flash_attn_varlen_forward(
            q,
            k,
            v,
            cu_seqlens_q,
            cu_seqlens_k,
            max_seqlen_q,
            max_seqlen_k,
            dropout_p,
            scale,
            causal=causal,
            window_size=window_size,
            alibi_slopes=alibi_slopes,
            return_softmax=want_dmask,
        )
        ctx.save_for_backward(
            q, k, v, out_padded, softmax_lse, cu_seqlens_q, cu_seqlens_k, rng_state
        )
        ctx.dropout_p = dropout_p
        ctx.max_seqlen_q = max_seqlen_q
        ctx.max_seqlen_k = max_seqlen_k
        ctx.softmax_scale = scale
        ctx.causal = causal
        ctx.window_size = window_size
        ctx.alibi_slopes = alibi_slopes
        if return_softmax:
            return out, softmax_lse, S_dmask
        return out

    @staticmethod
    def backward(ctx, dout, *args):
        # *args: gradients of the optional extra outputs; not used.
        q, k, v, out, softmax_lse, cu_seqlens_q, cu_seqlens_k, rng_state = ctx.saved_tensors
        dq = torch.empty_like(q)
        dk = torch.empty_like(k)
        dv = torch.empty_like(v)
        _flash_attn_varlen_backward(
            dout,
            q,
            k,
            v,
            out,
            softmax_lse,
            dq,
            dk,
            dv,
            cu_seqlens_q,
            cu_seqlens_k,
            ctx.max_seqlen_q,
            ctx.max_seqlen_k,
            ctx.dropout_p,
            ctx.softmax_scale,
            ctx.causal,
            ctx.window_size,
            ctx.alibi_slopes,
            rng_state=rng_state,
        )
        # Trim any head-dimension padding back to the size of dout.
        head_dim = dout.shape[-1]
        trimmed = tuple(g[..., :head_dim] for g in (dq, dk, dv))
        return trimmed + (None,) * 10
588
589


Tri Dao's avatar
Tri Dao committed
590
def flash_attn_qkvpacked_func(
    qkv,
    dropout_p=0.0,
    softmax_scale=None,
    causal=False,
    window_size=(-1, -1),  # -1 means infinite context window
    alibi_slopes=None,
    return_attn_probs=False,
):
    """dropout_p should be set to 0.0 during evaluation.

    If Q, K, V are already stacked into 1 tensor, this function will be faster than
    calling flash_attn_func on Q, K, V since the backward pass avoids explicit concatenation
    of the gradients of Q, K, V.
    For multi-query and grouped-query attention (MQA/GQA), please see
    flash_attn_kvpacked_func and flash_attn_func.

    If window_size != (-1, -1), implements sliding window local attention. Query at position i
    will only attend to keys between [i - window_size[0], i + window_size[1]] inclusive.

    Arguments:
        qkv: (batch_size, seqlen, 3, nheads, headdim)
        dropout_p: float. Dropout probability.
        softmax_scale: float. The scaling of QK^T before applying softmax.
            Defaults to 1 / sqrt(headdim).
        causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
        window_size: (left, right). If not (-1, -1), implements sliding window local attention.
        alibi_slopes: optional, default None. Forwarded unchanged to the CUDA kernel,
            presumably as ALiBi (linear attention bias) slopes.
            NOTE(review): the expected shape/dtype is not checked here — confirm against
            the kernel before relying on it.
        return_attn_probs: bool. Whether to return the attention probabilities. This option is for
           testing only. The returned probabilities are not guaranteed to be correct
           (they might not have the right scaling). The kernel is only asked to produce
           S_dmask when dropout_p > 0.
    Return:
        out: (batch_size, seqlen, nheads, headdim).
        softmax_lse [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen). The
            logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
            normalization factor).
        S_dmask [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen, seqlen).
            The output of softmax (possibly with different scaling). It also encodes the dropout
            pattern (negative means that location was dropped, nonnegative means it was kept).
    """
    return FlashAttnQKVPackedFunc.apply(
        qkv, dropout_p, softmax_scale, causal, window_size, alibi_slopes, return_attn_probs
    )
Tri Dao's avatar
Tri Dao committed
631
632


Tri Dao's avatar
Tri Dao committed
633
def flash_attn_kvpacked_func(
    q,
    kv,
    dropout_p=0.0,
    softmax_scale=None,
    causal=False,
    window_size=(-1, -1),  # -1 means infinite context window
    alibi_slopes=None,
    return_attn_probs=False,
):
    """dropout_p should be set to 0.0 during evaluation.

    If K, V are already stacked into 1 tensor, this function will be faster than
    calling flash_attn_func on Q, K, V since the backward pass avoids explicit concatenation
    of the gradients of K, V.
    Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads
    than Q. Note that the number of heads in Q must be divisible by the number of heads in KV.
    For example, if Q has 6 heads and K, V have 2 heads, heads 0, 1, 2 of Q will attend to head
    0 of K, V, and heads 3, 4, 5 of Q will attend to head 1 of K, V.

    If causal=True, the causal mask is aligned to the bottom right corner of the attention matrix.
    For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is:
        1 1 1 1 0
        1 1 1 1 1
    If seqlen_q = 5 and seqlen_k = 2, the causal mask is:
        0 0
        0 0
        0 0
        1 0
        1 1
    If the row of the mask is all zero, the output will be zero.

    If window_size != (-1, -1), implements sliding window local attention. Query at position i
    will only attend to keys between
    [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive.

    Arguments:
        q: (batch_size, seqlen_q, nheads, headdim)
        kv: (batch_size, seqlen_k, 2, nheads_k, headdim)
        dropout_p: float. Dropout probability.
        softmax_scale: float. The scaling of QK^T before applying softmax.
            Defaults to 1 / sqrt(headdim).
        causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
        window_size: (left, right). If not (-1, -1), implements sliding window local attention.
        alibi_slopes: optional, default None. Forwarded unchanged to the CUDA kernel,
            presumably as ALiBi (linear attention bias) slopes.
            NOTE(review): the expected shape/dtype is not checked here — confirm against
            the kernel before relying on it.
        return_attn_probs: bool. Whether to return the attention probabilities. This option is for
           testing only. The returned probabilities are not guaranteed to be correct
           (they might not have the right scaling). The kernel is only asked to produce
           S_dmask when dropout_p > 0.
    Return:
        out: (batch_size, seqlen_q, nheads, headdim).
        softmax_lse [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen_q). The
            logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
            normalization factor).
        S_dmask [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen_q, seqlen_k).
            The output of softmax (possibly with different scaling). It also encodes the dropout
            pattern (negative means that location was dropped, nonnegative means it was kept).
    """
    return FlashAttnKVPackedFunc.apply(
        q, kv, dropout_p, softmax_scale, causal, window_size, alibi_slopes, return_attn_probs
    )
Tri Dao's avatar
Tri Dao committed
691
692


Tri Dao's avatar
Tri Dao committed
693
def flash_attn_func(
    q,
    k,
    v,
    dropout_p=0.0,
    softmax_scale=None,
    causal=False,
    window_size=(-1, -1),  # -1 means infinite context window
    alibi_slopes=None,
    return_attn_probs=False,
):
    """Attention over separate Q, K, V tensors, dispatched to FlashAttnFunc.

    dropout_p should be set to 0.0 during evaluation.

    Multi-query and grouped-query attention (MQA/GQA) are supported by passing K, V with fewer
    heads than Q. The number of heads in Q must be divisible by the number of heads in K, V.
    For example, if Q has 6 heads and K, V have 2 heads, heads 0, 1, 2 of Q attend to head 0 of
    K, V, and heads 3, 4, 5 of Q attend to head 1 of K, V.

    If causal=True, the causal mask is aligned to the bottom-right corner of the attention matrix.
    For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is:
        1 1 1 1 0
        1 1 1 1 1
    If seqlen_q = 5 and seqlen_k = 2, the causal mask is:
        0 0
        0 0
        0 0
        1 0
        1 1
    If a row of the mask is all zeros, the output for that row will be zero.

    If window_size != (-1, -1), implements sliding window local attention. Query at position i
    will only attend to keys between
    [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive.

    Arguments:
        q: (batch_size, seqlen_q, nheads, headdim)
        k: (batch_size, seqlen_k, nheads_k, headdim)
        v: (batch_size, seqlen_k, nheads_k, headdim)
        dropout_p: float. Dropout probability.
        softmax_scale: float. The scaling of QK^T before applying softmax.
            Defaults to 1 / sqrt(headdim).
        causal: bool. Whether to apply a causal attention mask (e.g., for auto-regressive modeling).
        window_size: (left, right). If not (-1, -1), implements sliding window local attention.
        alibi_slopes: optional ALiBi slopes, forwarded unchanged to FlashAttnFunc. Assumed (per
            upstream FlashAttention docs) to be (nheads,) or (batch_size, nheads) in fp32, adding
            a bias of -alibi_slope * |i + seqlen_k - seqlen_q - j| to the score of query i and
            key j -- TODO confirm against the installed flash_attn_2_cuda build.
        return_attn_probs: bool. Whether to return the attention probabilities. This option is for
           testing only. The returned probabilities are not guaranteed to be correct
           (they might not have the right scaling).
    Return:
        out: (batch_size, seqlen_q, nheads, headdim).
        softmax_lse [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen). The
            logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
            normalization factor).
        S_dmask [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen, seqlen).
            The output of softmax (possibly with different scaling). It also encodes the dropout
            pattern (negative means that location was dropped, nonnegative means it was kept).
    """
    # The autograd Function takes its inputs positionally, in exactly this order.
    autograd_args = (
        q,
        k,
        v,
        dropout_p,
        softmax_scale,
        causal,
        window_size,
        alibi_slopes,
        return_attn_probs,
    )
    return FlashAttnFunc.apply(*autograd_args)
Tri Dao's avatar
Tri Dao committed
750
751


Tri Dao's avatar
Tri Dao committed
752
753
754
755
756
757
758
def flash_attn_varlen_qkvpacked_func(
    qkv,
    cu_seqlens,
    max_seqlen,
    dropout_p=0.0,
    softmax_scale=None,
    causal=False,
    window_size=(-1, -1),  # -1 means infinite context window
    alibi_slopes=None,
    return_attn_probs=False,
):
    """Variable-length attention over a packed QKV tensor, dispatched to
    FlashAttnVarlenQKVPackedFunc.

    dropout_p should be set to 0.0 during evaluation.

    If Q, K, V are already stacked into one tensor, this function will be faster than calling
    flash_attn_varlen_func on Q, K, V, since the backward pass avoids explicit concatenation of
    the gradients of Q, K, V.

    For multi-query and grouped-query attention (MQA/GQA), please see
    flash_attn_varlen_kvpacked_func and flash_attn_varlen_func.

    If window_size != (-1, -1), implements sliding window local attention. Query at position i
    will only attend to keys between [i - window_size[0], i + window_size[1]] inclusive.

    Arguments:
        qkv: (total, 3, nheads, headdim), where total = total number of tokens in the batch.
        cu_seqlens: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
           of the sequences in the batch, used to index into qkv.
        max_seqlen: int. Maximum sequence length in the batch.
        dropout_p: float. Dropout probability.
        softmax_scale: float. The scaling of QK^T before applying softmax.
            Defaults to 1 / sqrt(headdim).
        causal: bool. Whether to apply a causal attention mask (e.g., for auto-regressive modeling).
        window_size: (left, right). If not (-1, -1), implements sliding window local attention.
        alibi_slopes: optional ALiBi slopes, forwarded unchanged to the autograd Function. Assumed
            (per upstream FlashAttention docs) to be (nheads,) or (batch_size, nheads) in fp32,
            adding a bias of -alibi_slope * |i - j| to the score of query i and key j -- TODO
            confirm against the installed flash_attn_2_cuda build.
        return_attn_probs: bool. Whether to return the attention probabilities. This option is for
           testing only. The returned probabilities are not guaranteed to be correct
           (they might not have the right scaling).
    Return:
        out: (total, nheads, headdim).
        softmax_lse [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen). The
            logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
            normalization factor).
        S_dmask [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen, seqlen).
            The output of softmax (possibly with different scaling). It also encodes the dropout
            pattern (negative means that location was dropped, nonnegative means it was kept).
    """
    # The autograd Function takes its inputs positionally, in exactly this order.
    packed_args = (
        qkv,
        cu_seqlens,
        max_seqlen,
        dropout_p,
        softmax_scale,
        causal,
        window_size,
        alibi_slopes,
        return_attn_probs,
    )
    return FlashAttnVarlenQKVPackedFunc.apply(*packed_args)
Tri Dao's avatar
Tri Dao committed
806
807


Tri Dao's avatar
Tri Dao committed
808
809
810
811
812
813
814
815
816
817
def flash_attn_varlen_kvpacked_func(
    q,
    kv,
    cu_seqlens_q,
    cu_seqlens_k,
    max_seqlen_q,
    max_seqlen_k,
    dropout_p=0.0,
    softmax_scale=None,
    causal=False,
    window_size=(-1, -1),  # -1 means infinite context window
    alibi_slopes=None,
    return_attn_probs=False,
):
    """Variable-length attention with Q separate and K, V packed, dispatched to
    FlashAttnVarlenKVPackedFunc.

    dropout_p should be set to 0.0 during evaluation.

    If K, V are already stacked into one tensor, this function will be faster than calling
    flash_attn_varlen_func on Q, K, V, since the backward pass avoids explicit concatenation of
    the gradients of K, V.

    Multi-query and grouped-query attention (MQA/GQA) are supported by passing KV with fewer
    heads than Q. The number of heads in Q must be divisible by the number of heads in KV.
    For example, if Q has 6 heads and K, V have 2 heads, heads 0, 1, 2 of Q attend to head 0 of
    K, V, and heads 3, 4, 5 of Q attend to head 1 of K, V.

    If causal=True, the causal mask is aligned to the bottom-right corner of the attention matrix.
    For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is:
        1 1 1 1 0
        1 1 1 1 1
    If seqlen_q = 5 and seqlen_k = 2, the causal mask is:
        0 0
        0 0
        0 0
        1 0
        1 1
    If a row of the mask is all zeros, the output for that row will be zero.

    If window_size != (-1, -1), implements sliding window local attention. Query at position i
    will only attend to keys between
    [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive.

    Arguments:
        q: (total_q, nheads, headdim), where total_q = total number of query tokens in the batch.
        kv: (total_k, 2, nheads_k, headdim), where total_k = total number of key tokens in the batch.
        cu_seqlens_q: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
           of the sequences in the batch, used to index into q.
        cu_seqlens_k: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
           of the sequences in the batch, used to index into kv.
        max_seqlen_q: int. Maximum query sequence length in the batch.
        max_seqlen_k: int. Maximum key sequence length in the batch.
        dropout_p: float. Dropout probability.
        softmax_scale: float. The scaling of QK^T before applying softmax.
            Defaults to 1 / sqrt(headdim).
        causal: bool. Whether to apply a causal attention mask (e.g., for auto-regressive modeling).
        window_size: (left, right). If not (-1, -1), implements sliding window local attention.
        alibi_slopes: optional ALiBi slopes, forwarded unchanged to the autograd Function. Assumed
            (per upstream FlashAttention docs) to be (nheads,) or (batch_size, nheads) in fp32,
            adding a bias of -alibi_slope * |i + seqlen_k - seqlen_q - j| to the score of query i
            and key j -- TODO confirm against the installed flash_attn_2_cuda build.
        return_attn_probs: bool. Whether to return the attention probabilities. This option is for
           testing only. The returned probabilities are not guaranteed to be correct
           (they might not have the right scaling).
    Return:
        out: (total_q, nheads, headdim).
        softmax_lse [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen). The
            logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
            normalization factor).
        S_dmask [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen, seqlen).
            The output of softmax (possibly with different scaling). It also encodes the dropout
            pattern (negative means that location was dropped, nonnegative means it was kept).
    """
    # The autograd Function takes its inputs positionally, in exactly this order.
    varlen_args = (
        q,
        kv,
        cu_seqlens_q,
        cu_seqlens_k,
        max_seqlen_q,
        max_seqlen_k,
        dropout_p,
        softmax_scale,
        causal,
        window_size,
        alibi_slopes,
        return_attn_probs,
    )
    return FlashAttnVarlenKVPackedFunc.apply(*varlen_args)
Tri Dao's avatar
Tri Dao committed
887

888

Tri Dao's avatar
Tri Dao committed
889
890
891
892
893
894
895
896
897
898
899
def flash_attn_varlen_func(
    q,
    k,
    v,
    cu_seqlens_q,
    cu_seqlens_k,
    max_seqlen_q,
    max_seqlen_k,
    dropout_p=0.0,
    softmax_scale=None,
    causal=False,
    window_size=(-1, -1),  # -1 means infinite context window
    alibi_slopes=None,
    return_attn_probs=False,
):
    """Variable-length attention over separate Q, K, V tensors, dispatched to
    FlashAttnVarlenFunc.

    dropout_p should be set to 0.0 during evaluation.

    Multi-query and grouped-query attention (MQA/GQA) are supported by passing K, V with fewer
    heads than Q. The number of heads in Q must be divisible by the number of heads in K, V.
    For example, if Q has 6 heads and K, V have 2 heads, heads 0, 1, 2 of Q attend to head 0 of
    K, V, and heads 3, 4, 5 of Q attend to head 1 of K, V.

    If causal=True, the causal mask is aligned to the bottom-right corner of the attention matrix.
    For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is:
        1 1 1 1 0
        1 1 1 1 1
    If seqlen_q = 5 and seqlen_k = 2, the causal mask is:
        0 0
        0 0
        0 0
        1 0
        1 1
    If a row of the mask is all zeros, the output for that row will be zero.

    If window_size != (-1, -1), implements sliding window local attention. Query at position i
    will only attend to keys between
    [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive.

    Arguments:
        q: (total_q, nheads, headdim), where total_q = total number of query tokens in the batch.
        k: (total_k, nheads_k, headdim), where total_k = total number of key tokens in the batch.
        v: (total_k, nheads_k, headdim), where total_k = total number of key tokens in the batch.
        cu_seqlens_q: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
           of the sequences in the batch, used to index into q.
        cu_seqlens_k: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
           of the sequences in the batch, used to index into k and v.
        max_seqlen_q: int. Maximum query sequence length in the batch.
        max_seqlen_k: int. Maximum key sequence length in the batch.
        dropout_p: float. Dropout probability.
        softmax_scale: float. The scaling of QK^T before applying softmax.
            Defaults to 1 / sqrt(headdim).
        causal: bool. Whether to apply a causal attention mask (e.g., for auto-regressive modeling).
        window_size: (left, right). If not (-1, -1), implements sliding window local attention.
        alibi_slopes: optional ALiBi slopes, forwarded unchanged to the autograd Function. Assumed
            (per upstream FlashAttention docs) to be (nheads,) or (batch_size, nheads) in fp32,
            adding a bias of -alibi_slope * |i + seqlen_k - seqlen_q - j| to the score of query i
            and key j -- TODO confirm against the installed flash_attn_2_cuda build.
        return_attn_probs: bool. Whether to return the attention probabilities. This option is for
           testing only. The returned probabilities are not guaranteed to be correct
           (they might not have the right scaling).
    Return:
        out: (total_q, nheads, headdim).
        softmax_lse [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen). The
            logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
            normalization factor).
        S_dmask [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen, seqlen).
            The output of softmax (possibly with different scaling). It also encodes the dropout
            pattern (negative means that location was dropped, nonnegative means it was kept).
    """
    # The autograd Function takes its inputs positionally, in exactly this order.
    varlen_args = (
        q,
        k,
        v,
        cu_seqlens_q,
        cu_seqlens_k,
        max_seqlen_q,
        max_seqlen_k,
        dropout_p,
        softmax_scale,
        causal,
        window_size,
        alibi_slopes,
        return_attn_probs,
    )
    return FlashAttnVarlenFunc.apply(*varlen_args)
Tri Dao's avatar
Tri Dao committed
968
969
970
971
972
973
974
975


def flash_attn_with_kvcache(
    q,
    k_cache,
    v_cache,
    k=None,
    v=None,
    rotary_cos=None,
    rotary_sin=None,
    cache_seqlens: Optional[Union[(int, torch.Tensor)]] = None,
    cache_batch_idx: Optional[torch.Tensor] = None,
    softmax_scale=None,
    causal=False,
    window_size=(-1, -1),  # -1 means infinite context window
    rotary_interleaved=True,
    num_splits=0,
    alibi_slopes=None,
):
    """
    If k and v are not None, k_cache and v_cache will be updated *inplace* with the new values from
    k and v. This is useful for incremental decoding: you can pass in the cached keys/values from
    the previous step, and update them with the new keys/values from the current step, and do
    attention with the updated cache, all in 1 kernel.

    If you pass in k / v, you must make sure that the cache is large enough to hold the new values.
    For example, the KV cache could be pre-allocated with the max sequence length, and you can use
    cache_seqlens to keep track of the current sequence lengths of each sequence in the batch.

    Also apply rotary embedding if rotary_cos and rotary_sin are passed in. The key @k will be
    rotated by rotary_cos and rotary_sin at indices cache_seqlens, cache_seqlens + 1, etc.
    If causal or local (i.e., window_size != (-1, -1)), the query @q will be rotated by rotary_cos
    and rotary_sin at indices cache_seqlens, cache_seqlens + 1, etc.
    If not causal and not local, the query @q will be rotated by rotary_cos and rotary_sin at
    indices cache_seqlens only (i.e. we consider all tokens in @q to be at position cache_seqlens).

    See tests/test_flash_attn.py::test_flash_attn_kvcache for examples of how to use this function.

    Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads
    than Q. Note that the number of heads in Q must be divisible by the number of heads in KV.
    For example, if Q has 6 heads and K, V have 2 heads, heads 0, 1, 2 of Q attend to head 0 of
    K, V, and heads 3, 4, 5 of Q attend to head 1 of K, V.

    If causal=True, the causal mask is aligned to the bottom-right corner of the attention matrix.
    For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is:
        1 1 1 1 0
        1 1 1 1 1
    If seqlen_q = 5 and seqlen_k = 2, the causal mask is:
        0 0
        0 0
        0 0
        1 0
        1 1
    If a row of the mask is all zeros, the output for that row will be zero.

    If window_size != (-1, -1), implements sliding window local attention. Query at position i
    will only attend to keys between
    [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive.

    Note: Does not support backward pass.

    Arguments:
        q: (batch_size, seqlen, nheads, headdim)
        k_cache: (batch_size_cache, seqlen_cache, nheads_k, headdim)
        v_cache: (batch_size_cache, seqlen_cache, nheads_k, headdim)
        k [optional]: (batch_size, seqlen_new, nheads_k, headdim). If not None, we concatenate
            k with k_cache, starting at the indices specified by cache_seqlens.
        v [optional]: (batch_size, seqlen_new, nheads_k, headdim). Similar to k.
        rotary_cos [optional]: (seqlen_ro, rotary_dim / 2). If not None, we apply rotary embedding
            to k and q. Only applicable if k and v are passed in. rotary_dim must be divisible by 16.
        rotary_sin [optional]: (seqlen_ro, rotary_dim / 2). Similar to rotary_cos.
        cache_seqlens: int, or (batch_size,), dtype torch.int32. The sequence lengths of the
            KV cache. An int is broadcast to a (batch_size,) int32 tensor, where batch_size is
            q.shape[0].
        cache_batch_idx: (batch_size,), dtype torch.int32. The indices used to index into the KV cache.
            If None, we assume that the batch indices are [0, 1, 2, ..., batch_size - 1].
            If the indices are not distinct, and k and v are provided, the values updated in the cache
            might come from any of the duplicate indices.
        softmax_scale: float. The scaling of QK^T before applying softmax.
            Defaults to 1 / sqrt(headdim).
        causal: bool. Whether to apply a causal attention mask (e.g., for auto-regressive modeling).
        window_size: (left, right). If not (-1, -1), implements sliding window local attention.
        rotary_interleaved: bool. Only applicable if rotary_cos and rotary_sin are passed in.
            If True, rotary embedding will combine dimensions 0 & 1, 2 & 3, etc. If False,
            rotary embedding will combine dimensions 0 & rotary_dim / 2, 1 & rotary_dim / 2 + 1
            (i.e. GPT-NeoX style).
        num_splits: int. If > 1, split the key/value into this many chunks along the sequence.
           If num_splits == 1, we don't split the key/value. If num_splits == 0, we use a heuristic
           to automatically determine the number of splits.
           Don't change this unless you know what you are doing.
        alibi_slopes: optional ALiBi slopes, forwarded unchanged to the kernel. Assumed (per
            upstream FlashAttention docs) to be (nheads,) or (batch_size, nheads) in fp32, adding
            a bias of -alibi_slope * |i + seqlen_k - seqlen_q - j| to the score of query i and
            key j -- TODO confirm against the installed flash_attn_2_cuda build.

    Return:
        out: (batch_size, seqlen, nheads, headdim).
    """
    assert k_cache.stride(-1) == 1, "k_cache must have contiguous last dimension"
    assert v_cache.stride(-1) == 1, "v_cache must have contiguous last dimension"
    # Only the last dimension has to be unit-stride; None passes through untouched.
    maybe_contiguous = lambda x: x.contiguous() if x is not None and x.stride(-1) != 1 else x
    q, k, v = [maybe_contiguous(x) for x in (q, k, v)]
    if softmax_scale is None:
        softmax_scale = q.shape[-1] ** (-0.5)
    if cache_seqlens is not None and isinstance(cache_seqlens, int):
        # Broadcast a scalar length to one entry per *query* batch element. cache_seqlens is
        # documented as (batch_size,) = q.shape[0], which differs from k_cache.shape[0]
        # (batch_size_cache) whenever cache_batch_idx selects into a larger cache.
        cache_seqlens = torch.full(
            (q.shape[0],), cache_seqlens, dtype=torch.int32, device=k_cache.device
        )
    # Applied unconditionally so a caller-supplied strided tensor is also normalized, not just
    # the freshly built (already contiguous) one from the int branch above.
    cache_seqlens = maybe_contiguous(cache_seqlens)
    cache_batch_idx = maybe_contiguous(cache_batch_idx)
    out, _softmax_lse = flash_attn_cuda.fwd_kvcache(
        q,
        k_cache,
        v_cache,
        k,
        v,
        cache_seqlens,
        rotary_cos,
        rotary_sin,
        cache_batch_idx,
        None,  # NOTE(review): presumably the optional preallocated `out` buffer -- confirm
        softmax_scale,
        causal,
        window_size[0],
        window_size[1],
        rotary_interleaved,
        num_splits,
        alibi_slopes,
    )
    return out