flash_attn_interface.py 44.3 KB
Newer Older
1
2
# Copyright (c) 2023, Tri Dao.

Tri Dao's avatar
Tri Dao committed
3
4
from typing import Optional, Union

Tri Dao's avatar
Tri Dao committed
5
6
import torch
import torch.nn as nn
Tri Dao's avatar
Tri Dao committed
7

8
9
10
# isort: off
# We need to import the CUDA kernels after importing torch
import flash_attn_2_cuda as flash_attn_cuda
Tri Dao's avatar
Tri Dao committed
11

12
13
# isort: on

Tri Dao's avatar
Tri Dao committed
14

Tri Dao's avatar
Tri Dao committed
15
def _get_block_size_n(device, head_dim, is_dropout, is_causal):
Tri Dao's avatar
Tri Dao committed
16
17
18
19
20
21
22
    # This should match the block sizes in the CUDA kernel
    assert head_dim <= 256
    major, minor = torch.cuda.get_device_capability(device)
    is_sm8x = major == 8 and minor > 0  # Only include sm86 and sm89, exclude sm80 (A100)
    is_sm80 = major == 8 and minor == 0
    is_sm90 = major == 9 and minor == 0
    if head_dim <= 32:
Tri Dao's avatar
Tri Dao committed
23
        return 128
Tri Dao's avatar
Tri Dao committed
24
    if head_dim <= 64:
Tri Dao's avatar
Tri Dao committed
25
        return 128 if not is_dropout else 64
Tri Dao's avatar
Tri Dao committed
26
    elif head_dim <= 96:
Tri Dao's avatar
Tri Dao committed
27
        return 64
Tri Dao's avatar
Tri Dao committed
28
29
    elif head_dim <= 128:
        if is_sm8x:
Tri Dao's avatar
Tri Dao committed
30
            return 64 if (not is_dropout and is_causal) else 32
Tri Dao's avatar
Tri Dao committed
31
        else:
Tri Dao's avatar
Tri Dao committed
32
            return 64 if not is_dropout else 32
Tri Dao's avatar
Tri Dao committed
33
34
    elif head_dim <= 160:
        if is_sm8x:
Tri Dao's avatar
Tri Dao committed
35
            return 64
Tri Dao's avatar
Tri Dao committed
36
        else:
Tri Dao's avatar
Tri Dao committed
37
            return 32
Tri Dao's avatar
Tri Dao committed
38
    elif head_dim <= 192:
Tri Dao's avatar
Tri Dao committed
39
        return 64
Tri Dao's avatar
Tri Dao committed
40
    elif head_dim <= 224:
Tri Dao's avatar
Tri Dao committed
41
        return 64
Tri Dao's avatar
Tri Dao committed
42
    elif head_dim <= 256:
Tri Dao's avatar
Tri Dao committed
43
        return 64
Tri Dao's avatar
Tri Dao committed
44
45


Tri Dao's avatar
Tri Dao committed
46
47
48
def _flash_attn_forward(
    q, k, v, dropout_p, softmax_scale, causal, window_size, alibi_slopes, return_softmax
):
    """Call the dense forward kernel ``flash_attn_cuda.fwd``.

    q, k and v are copied with ``.contiguous()`` when their last dimension does not
    have stride 1; otherwise they are passed through untouched. ``window_size`` is
    split into its first two entries for the kernel call.

    Return:
        The 8 values produced by the kernel, as a tuple:
        (out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state).
    """
    q, k, v = (t if t.stride(-1) == 1 else t.contiguous() for t in (q, k, v))
    fwd_args = (
        q,
        k,
        v,
        None,  # NOTE(review): presumably an optional preallocated output - confirm
        alibi_slopes,
        dropout_p,
        softmax_scale,
        causal,
        window_size[0],
        window_size[1],
        return_softmax,
        None,  # NOTE(review): presumably an optional RNG generator - confirm
    )
    out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = flash_attn_cuda.fwd(*fwd_args)
    return out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state
Tri Dao's avatar
Tri Dao committed
66
67


Tri Dao's avatar
Tri Dao committed
68
69
70
71
72
73
74
75
76
77
78
def _flash_attn_varlen_forward(
    q,
    k,
    v,
    cu_seqlens_q,
    cu_seqlens_k,
    max_seqlen_q,
    max_seqlen_k,
    dropout_p,
    softmax_scale,
    causal,
    window_size,
    alibi_slopes,
    return_softmax,
    block_table,
):
    """Call the variable-length forward kernel ``flash_attn_cuda.varlen_fwd``.

    q, k and v are copied with ``.contiguous()`` when their last dimension does not
    have stride 1; otherwise they are passed through untouched. ``window_size`` is
    split into its first two entries for the kernel call.

    Return:
        The 8 values produced by the kernel, as a tuple:
        (out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state).
    """
    q, k, v = (t if t.stride(-1) == 1 else t.contiguous() for t in (q, k, v))
    fwd_args = (
        q,
        k,
        v,
        None,  # NOTE(review): presumably an optional preallocated output - confirm
        cu_seqlens_q,
        cu_seqlens_k,
        None,  # NOTE(review): unused optional slot; meaning not visible here - confirm
        block_table,
        alibi_slopes,
        max_seqlen_q,
        max_seqlen_k,
        dropout_p,
        softmax_scale,
        False,  # NOTE(review): fixed flag; meaning defined by the extension - confirm
        causal,
        window_size[0],
        window_size[1],
        return_softmax,
        None,  # NOTE(review): presumably an optional RNG generator - confirm
    )
    out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = flash_attn_cuda.varlen_fwd(
        *fwd_args
    )
    return out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state
Tri Dao's avatar
Tri Dao committed
110
111


Tri Dao's avatar
Tri Dao committed
112
def _flash_attn_backward(
    dout,
    q,
    k,
    v,
    out,
    softmax_lse,
    dq,
    dk,
    dv,
    dropout_p,
    softmax_scale,
    causal,
    window_size,
    alibi_slopes,
    deterministic,
    rng_state=None,
):
    """Call the dense backward kernel ``flash_attn_cuda.bwd``.

    dout, q, k, v and out are copied with ``.contiguous()`` when their last
    dimension does not have stride 1. dq, dk and dv are forwarded as given, since
    the callers allocate them and they are expected to be contiguous already.

    Return:
        (dq, dk, dv, softmax_d) as returned by the kernel.
    """
    dout, q, k, v, out = (
        t if t.stride(-1) == 1 else t.contiguous() for t in (dout, q, k, v, out)
    )
    bwd_args = (
        dout,
        q,
        k,
        v,
        out,
        softmax_lse,
        dq,
        dk,
        dv,
        alibi_slopes,
        dropout_p,
        softmax_scale,
        causal,
        window_size[0],
        window_size[1],
        deterministic,
        None,  # NOTE(review): presumably an optional RNG generator - confirm
        rng_state,
    )
    dq, dk, dv, softmax_d = flash_attn_cuda.bwd(*bwd_args)
    return dq, dk, dv, softmax_d


Tri Dao's avatar
Tri Dao committed
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
def _flash_attn_varlen_backward(
    dout,
    q,
    k,
    v,
    out,
    softmax_lse,
    dq,
    dk,
    dv,
    cu_seqlens_q,
    cu_seqlens_k,
    max_seqlen_q,
    max_seqlen_k,
    dropout_p,
    softmax_scale,
    causal,
    window_size,
    alibi_slopes,
    deterministic,
    rng_state=None,
):
    """Call the variable-length backward kernel ``flash_attn_cuda.varlen_bwd``.

    dout, q, k, v and out are copied with ``.contiguous()`` when their last
    dimension does not have stride 1. dq, dk and dv are forwarded as given, since
    the callers allocate them and they are expected to be contiguous already.

    Return:
        (dq, dk, dv, softmax_d) as returned by the kernel.
    """
    dout, q, k, v, out = (
        t if t.stride(-1) == 1 else t.contiguous() for t in (dout, q, k, v, out)
    )
    bwd_args = (
        dout,
        q,
        k,
        v,
        out,
        softmax_lse,
        dq,
        dk,
        dv,
        cu_seqlens_q,
        cu_seqlens_k,
        alibi_slopes,
        max_seqlen_q,
        max_seqlen_k,
        dropout_p,
        softmax_scale,
        False,  # NOTE(review): fixed flag; meaning defined by the extension - confirm
        causal,
        window_size[0],
        window_size[1],
        deterministic,
        None,  # NOTE(review): presumably an optional RNG generator - confirm
        rng_state,
    )
    dq, dk, dv, softmax_d = flash_attn_cuda.varlen_bwd(*bwd_args)
    return dq, dk, dv, softmax_d
Tri Dao's avatar
Tri Dao committed
209
210


Tri Dao's avatar
Tri Dao committed
211
class FlashAttnQKVPackedFunc(torch.autograd.Function):
    """Autograd function for attention on a packed QKV tensor.

    Q, K and V are taken as indices 0, 1 and 2 along dimension 2 of ``qkv``.
    """

    @staticmethod
    def forward(
        ctx,
        qkv,
        dropout_p,
        softmax_scale,
        causal,
        window_size,
        alibi_slopes,
        deterministic,
        return_softmax,
    ):
        """Run the forward kernel and stash what backward needs on ``ctx``."""
        if softmax_scale is None:
            # Default scaling: (size of the last dimension of qkv) ** -0.5
            softmax_scale = qkv.shape[-1] ** (-0.5)
        # The softmax matrix is only requested from the kernel when dropout is active
        want_softmax = return_softmax and dropout_p > 0
        out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = _flash_attn_forward(
            qkv[:, :, 0],
            qkv[:, :, 1],
            qkv[:, :, 2],
            dropout_p,
            softmax_scale,
            causal=causal,
            window_size=window_size,
            alibi_slopes=alibi_slopes,
            return_softmax=want_softmax,
        )
        ctx.save_for_backward(q, k, v, out_padded, softmax_lse, rng_state)
        ctx.dropout_p = dropout_p
        ctx.softmax_scale = softmax_scale
        ctx.causal = causal
        ctx.window_size = window_size
        ctx.alibi_slopes = alibi_slopes
        ctx.deterministic = deterministic
        if return_softmax:
            return out, softmax_lse, S_dmask
        return out

    @staticmethod
    def backward(ctx, dout, *args):
        """Compute the gradient w.r.t. ``qkv``; all other inputs get ``None``."""
        q, k, v, out, softmax_lse, rng_state = ctx.saved_tensors
        # One packed buffer; the kernel writes dq/dk/dv into its three slices
        packed_shape = (*q.shape[:-2], 3, *q.shape[-2:])
        dqkv = torch.empty(packed_shape, dtype=q.dtype, device=q.device)
        _flash_attn_backward(
            dout,
            q,
            k,
            v,
            out,
            softmax_lse,
            dqkv[:, :, 0],
            dqkv[:, :, 1],
            dqkv[:, :, 2],
            ctx.dropout_p,
            ctx.softmax_scale,
            ctx.causal,
            ctx.window_size,
            ctx.alibi_slopes,
            ctx.deterministic,
            rng_state=rng_state,
        )
        # The head dimension may have been padded; trim back to dout's size
        head_dim = dout.shape[-1]
        dqkv = dqkv[..., :head_dim]
        # One gradient slot per forward input after ctx: qkv plus 7 non-tensor options
        return (dqkv,) + (None,) * 7
Tri Dao's avatar
Tri Dao committed
271
272
273
274


class FlashAttnVarlenQKVPackedFunc(torch.autograd.Function):
    """Autograd function for variable-length attention on a packed QKV tensor.

    Q, K and V are taken as indices 0, 1 and 2 along dimension 1 of ``qkv``. The
    same ``cu_seqlens`` / ``max_seqlen`` are used for both queries and keys.
    """

    @staticmethod
    def forward(
        ctx,
        qkv,
        cu_seqlens,
        max_seqlen,
        dropout_p,
        softmax_scale,
        causal,
        window_size,
        alibi_slopes,
        deterministic,
        return_softmax,
    ):
        """Run the varlen forward kernel and stash what backward needs on ``ctx``."""
        if softmax_scale is None:
            # Default scaling: (size of the last dimension of qkv) ** -0.5
            softmax_scale = qkv.shape[-1] ** (-0.5)
        # The softmax matrix is only requested from the kernel when dropout is active
        want_softmax = return_softmax and dropout_p > 0
        out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = _flash_attn_varlen_forward(
            qkv[:, 0],
            qkv[:, 1],
            qkv[:, 2],
            cu_seqlens,
            cu_seqlens,
            max_seqlen,
            max_seqlen,
            dropout_p,
            softmax_scale,
            causal=causal,
            window_size=window_size,
            alibi_slopes=alibi_slopes,
            return_softmax=want_softmax,
            block_table=None,
        )
        ctx.save_for_backward(q, k, v, out_padded, softmax_lse, cu_seqlens, rng_state)
        ctx.dropout_p = dropout_p
        ctx.max_seqlen = max_seqlen
        ctx.softmax_scale = softmax_scale
        ctx.causal = causal
        ctx.window_size = window_size
        ctx.alibi_slopes = alibi_slopes
        ctx.deterministic = deterministic
        if return_softmax:
            return out, softmax_lse, S_dmask
        return out

    @staticmethod
    def backward(ctx, dout, *args):
        """Compute the gradient w.r.t. ``qkv``; all other inputs get ``None``."""
        q, k, v, out, softmax_lse, cu_seqlens, rng_state = ctx.saved_tensors
        # One packed buffer; the kernel writes dq/dk/dv into its three slices
        packed_shape = (*q.shape[:-2], 3, *q.shape[-2:])
        dqkv = torch.empty(packed_shape, dtype=q.dtype, device=q.device)
        _flash_attn_varlen_backward(
            dout,
            q,
            k,
            v,
            out,
            softmax_lse,
            dqkv[:, 0],
            dqkv[:, 1],
            dqkv[:, 2],
            cu_seqlens,
            cu_seqlens,
            ctx.max_seqlen,
            ctx.max_seqlen,
            ctx.dropout_p,
            ctx.softmax_scale,
            ctx.causal,
            ctx.window_size,
            ctx.alibi_slopes,
            ctx.deterministic,
            rng_state=rng_state,
        )
        # The head dimension may have been padded; trim back to dout's size
        head_dim = dout.shape[-1]
        dqkv = dqkv[..., :head_dim]
        # One gradient slot per forward input after ctx: qkv plus 9 others
        return (dqkv,) + (None,) * 9
Tri Dao's avatar
Tri Dao committed
345
346


Tri Dao's avatar
Tri Dao committed
347
class FlashAttnKVPackedFunc(torch.autograd.Function):
    """Autograd function for attention with a separate Q and a packed KV tensor.

    K and V are taken as indices 0 and 1 along dimension 2 of ``kv``.
    """

    @staticmethod
    def forward(
        ctx,
        q,
        kv,
        dropout_p,
        softmax_scale,
        causal,
        window_size,
        alibi_slopes,
        deterministic,
        return_softmax,
    ):
        """Run the forward kernel and stash what backward needs on ``ctx``."""
        if softmax_scale is None:
            # Default scaling: (size of the last dimension of q) ** -0.5
            softmax_scale = q.shape[-1] ** (-0.5)
        # The softmax matrix is only requested from the kernel when dropout is active
        want_softmax = return_softmax and dropout_p > 0
        out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = _flash_attn_forward(
            q,
            kv[:, :, 0],
            kv[:, :, 1],
            dropout_p,
            softmax_scale,
            causal=causal,
            window_size=window_size,
            alibi_slopes=alibi_slopes,
            return_softmax=want_softmax,
        )
        ctx.save_for_backward(q, k, v, out_padded, softmax_lse, rng_state)
        ctx.dropout_p = dropout_p
        ctx.softmax_scale = softmax_scale
        ctx.causal = causal
        ctx.window_size = window_size
        ctx.alibi_slopes = alibi_slopes
        ctx.deterministic = deterministic
        if return_softmax:
            return out, softmax_lse, S_dmask
        return out

    @staticmethod
    def backward(ctx, dout, *args):
        """Compute gradients w.r.t. ``q`` and ``kv``; all other inputs get ``None``."""
        q, k, v, out, softmax_lse, rng_state = ctx.saved_tensors
        dq = torch.empty_like(q)
        # One packed buffer; the kernel writes dk/dv into its two slices
        packed_shape = (*k.shape[:-2], 2, *k.shape[-2:])
        dkv = torch.empty(packed_shape, dtype=k.dtype, device=k.device)
        _flash_attn_backward(
            dout,
            q,
            k,
            v,
            out,
            softmax_lse,
            dq,
            dkv[:, :, 0],
            dkv[:, :, 1],
            ctx.dropout_p,
            ctx.softmax_scale,
            ctx.causal,
            ctx.window_size,
            ctx.alibi_slopes,
            ctx.deterministic,
            rng_state=rng_state,
        )
        # The head dimension may have been padded; trim back to dout's size
        head_dim = dout.shape[-1]
        dq = dq[..., :head_dim]
        dkv = dkv[..., :head_dim]
        # One gradient slot per forward input after ctx: q, kv plus 7 others
        return (dq, dkv) + (None,) * 7
Tri Dao's avatar
Tri Dao committed
410
411


Tri Dao's avatar
Tri Dao committed
412
class FlashAttnVarlenKVPackedFunc(torch.autograd.Function):
    """Autograd function for variable-length attention with Q and a packed KV tensor.

    K and V are taken as indices 0 and 1 along dimension 1 of ``kv``.
    """

    @staticmethod
    def forward(
        ctx,
        q,
        kv,
        cu_seqlens_q,
        cu_seqlens_k,
        max_seqlen_q,
        max_seqlen_k,
        dropout_p,
        softmax_scale,
        causal,
        window_size,
        alibi_slopes,
        deterministic,
        return_softmax,
    ):
        """Run the varlen forward kernel and stash what backward needs on ``ctx``."""
        if softmax_scale is None:
            # Default scaling: (size of the last dimension of q) ** -0.5
            softmax_scale = q.shape[-1] ** (-0.5)
        # The softmax matrix is only requested from the kernel when dropout is active
        want_softmax = return_softmax and dropout_p > 0
        out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = _flash_attn_varlen_forward(
            q,
            kv[:, 0],
            kv[:, 1],
            cu_seqlens_q,
            cu_seqlens_k,
            max_seqlen_q,
            max_seqlen_k,
            dropout_p,
            softmax_scale,
            causal=causal,
            window_size=window_size,
            alibi_slopes=alibi_slopes,
            return_softmax=want_softmax,
            block_table=None,
        )
        ctx.save_for_backward(
            q, k, v, out_padded, softmax_lse, cu_seqlens_q, cu_seqlens_k, rng_state
        )
        ctx.dropout_p = dropout_p
        ctx.max_seqlen_q = max_seqlen_q
        ctx.max_seqlen_k = max_seqlen_k
        ctx.softmax_scale = softmax_scale
        ctx.causal = causal
        ctx.window_size = window_size
        ctx.alibi_slopes = alibi_slopes
        ctx.deterministic = deterministic
        if return_softmax:
            return out, softmax_lse, S_dmask
        return out

    @staticmethod
    def backward(ctx, dout, *args):
        """Compute gradients w.r.t. ``q`` and ``kv``; all other inputs get ``None``."""
        q, k, v, out, softmax_lse, cu_seqlens_q, cu_seqlens_k, rng_state = ctx.saved_tensors
        dq = torch.empty_like(q)
        # One packed buffer; the kernel writes dk/dv into its two slices
        packed_shape = (*k.shape[:-2], 2, *k.shape[-2:])
        dkv = torch.empty(packed_shape, dtype=k.dtype, device=k.device)
        _flash_attn_varlen_backward(
            dout,
            q,
            k,
            v,
            out,
            softmax_lse,
            dq,
            dkv[:, 0],
            dkv[:, 1],
            cu_seqlens_q,
            cu_seqlens_k,
            ctx.max_seqlen_q,
            ctx.max_seqlen_k,
            ctx.dropout_p,
            ctx.softmax_scale,
            ctx.causal,
            ctx.window_size,
            ctx.alibi_slopes,
            ctx.deterministic,
            rng_state=rng_state,
        )
        # The head dimension may have been padded; trim back to dout's size
        head_dim = dout.shape[-1]
        dq = dq[..., :head_dim]
        dkv = dkv[..., :head_dim]
        # One gradient slot per forward input after ctx: q, kv plus 11 others
        return (dq, dkv) + (None,) * 11
Tri Dao's avatar
Tri Dao committed
492
493
494
495


class FlashAttnFunc(torch.autograd.Function):
    """Autograd function for attention with separate Q, K and V tensors."""

    @staticmethod
    def forward(
        ctx,
        q,
        k,
        v,
        dropout_p,
        softmax_scale,
        causal,
        window_size,
        alibi_slopes,
        deterministic,
        return_softmax,
    ):
        """Run the forward kernel and stash what backward needs on ``ctx``."""
        if softmax_scale is None:
            # Default scaling: (size of the last dimension of q) ** -0.5
            softmax_scale = q.shape[-1] ** (-0.5)
        # The softmax matrix is only requested from the kernel when dropout is active
        want_softmax = return_softmax and dropout_p > 0
        out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = _flash_attn_forward(
            q,
            k,
            v,
            dropout_p,
            softmax_scale,
            causal=causal,
            window_size=window_size,
            alibi_slopes=alibi_slopes,
            return_softmax=want_softmax,
        )
        ctx.save_for_backward(q, k, v, out_padded, softmax_lse, rng_state)
        ctx.dropout_p = dropout_p
        ctx.softmax_scale = softmax_scale
        ctx.causal = causal
        ctx.window_size = window_size
        ctx.alibi_slopes = alibi_slopes
        ctx.deterministic = deterministic
        if return_softmax:
            return out, softmax_lse, S_dmask
        return out

    @staticmethod
    def backward(ctx, dout, *args):
        """Compute gradients w.r.t. q, k and v; all other inputs get ``None``."""
        q, k, v, out, softmax_lse, rng_state = ctx.saved_tensors
        dq = torch.empty_like(q)
        dk = torch.empty_like(k)
        dv = torch.empty_like(v)
        _flash_attn_backward(
            dout,
            q,
            k,
            v,
            out,
            softmax_lse,
            dq,
            dk,
            dv,
            ctx.dropout_p,
            ctx.softmax_scale,
            ctx.causal,
            ctx.window_size,
            ctx.alibi_slopes,
            ctx.deterministic,
            rng_state=rng_state,
        )
        # The head dimension may have been padded; trim back to dout's size
        head_dim = dout.shape[-1]
        dq = dq[..., :head_dim]
        dk = dk[..., :head_dim]
        dv = dv[..., :head_dim]
        # One gradient slot per forward input after ctx: q, k, v plus 7 others
        return (dq, dk, dv) + (None,) * 7
Tri Dao's avatar
Tri Dao committed
557
558


Tri Dao's avatar
Tri Dao committed
559
class FlashAttnVarlenFunc(torch.autograd.Function):
    """Autograd function for variable-length attention with separate Q, K and V."""

    @staticmethod
    def forward(
        ctx,
        q,
        k,
        v,
        cu_seqlens_q,
        cu_seqlens_k,
        max_seqlen_q,
        max_seqlen_k,
        dropout_p,
        softmax_scale,
        causal,
        window_size,
        alibi_slopes,
        deterministic,
        return_softmax,
        block_table,
    ):
        """Run the varlen forward kernel and stash what backward needs on ``ctx``."""
        if softmax_scale is None:
            # Default scaling: (size of the last dimension of q) ** -0.5
            softmax_scale = q.shape[-1] ** (-0.5)
        # The softmax matrix is only requested from the kernel when dropout is active
        want_softmax = return_softmax and dropout_p > 0
        out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = _flash_attn_varlen_forward(
            q,
            k,
            v,
            cu_seqlens_q,
            cu_seqlens_k,
            max_seqlen_q,
            max_seqlen_k,
            dropout_p,
            softmax_scale,
            causal=causal,
            window_size=window_size,
            alibi_slopes=alibi_slopes,
            return_softmax=want_softmax,
            block_table=block_table,
        )
        ctx.save_for_backward(
            q, k, v, out_padded, softmax_lse, cu_seqlens_q, cu_seqlens_k, rng_state
        )
        ctx.dropout_p = dropout_p
        ctx.max_seqlen_q = max_seqlen_q
        ctx.max_seqlen_k = max_seqlen_k
        ctx.softmax_scale = softmax_scale
        ctx.causal = causal
        ctx.window_size = window_size
        ctx.alibi_slopes = alibi_slopes
        ctx.deterministic = deterministic
        if return_softmax:
            return out, softmax_lse, S_dmask
        return out

    @staticmethod
    def backward(ctx, dout, *args):
        """Compute gradients w.r.t. q, k and v; all other inputs get ``None``."""
        q, k, v, out, softmax_lse, cu_seqlens_q, cu_seqlens_k, rng_state = ctx.saved_tensors
        dq = torch.empty_like(q)
        dk = torch.empty_like(k)
        dv = torch.empty_like(v)
        _flash_attn_varlen_backward(
            dout,
            q,
            k,
            v,
            out,
            softmax_lse,
            dq,
            dk,
            dv,
            cu_seqlens_q,
            cu_seqlens_k,
            ctx.max_seqlen_q,
            ctx.max_seqlen_k,
            ctx.dropout_p,
            ctx.softmax_scale,
            ctx.causal,
            ctx.window_size,
            ctx.alibi_slopes,
            ctx.deterministic,
            rng_state=rng_state,
        )
        # The head dimension may have been padded; trim back to dout's size
        head_dim = dout.shape[-1]
        dq = dq[..., :head_dim]
        dk = dk[..., :head_dim]
        dv = dv[..., :head_dim]
        # One gradient slot per forward input after ctx: q, k, v plus 12 others
        return (dq, dk, dv) + (None,) * 12
640
641


Tri Dao's avatar
Tri Dao committed
642
def flash_attn_qkvpacked_func(
    qkv,
    dropout_p=0.0,
    softmax_scale=None,
    causal=False,
    window_size=(-1, -1),  # -1 means infinite context window
    alibi_slopes=None,
    deterministic=False,
    return_attn_probs=False,
):
    """Attention on a packed QKV tensor. dropout_p should be set to 0.0 during evaluation.

    When Q, K, V are already stacked into one tensor, this is faster than calling
    flash_attn_func on Q, K, V separately, because the backward pass does not need to
    explicitly concatenate the gradients of Q, K, V.
    For multi-query and grouped-query attention (MQA/GQA), use
    flash_attn_kvpacked_func or flash_attn_func instead.

    If window_size != (-1, -1), sliding window local attention is used: the query at
    position i only attends to keys in [i - window_size[0], i + window_size[1]]
    (inclusive).

    Arguments:
        qkv: (batch_size, seqlen, 3, nheads, headdim)
        dropout_p: float. Dropout probability.
        softmax_scale: float. The scaling of QK^T before applying softmax.
            Defaults to 1 / sqrt(headdim).
        causal: bool. Whether to apply a causal attention mask (e.g., for
            auto-regressive modeling).
        window_size: (left, right). If not (-1, -1), implements sliding window local
            attention.
        alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of
            (-alibi_slope * |i - j|) is added to the attention score of query i and
            key j.
        deterministic: bool. Whether to use the deterministic implementation of the
            backward pass, which is slightly slower and uses more memory. The forward
            pass is always deterministic.
        return_attn_probs: bool. Whether to also return the attention probabilities.
            This option is for testing only; the returned probabilities are not
            guaranteed to be correct (they might not have the right scaling).
    Return:
        out: (batch_size, seqlen, nheads, headdim).
        softmax_lse [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen).
            The logsumexp of each row of the matrix QK^T * scaling (e.g., log of the
            softmax normalization factor).
        S_dmask [optional, if return_attn_probs=True]:
            (batch_size, nheads, seqlen, seqlen). The output of softmax (possibly with
            different scaling). It also encodes the dropout pattern (negative means
            that location was dropped, nonnegative means it was kept).
    """
    # autograd.Function.apply is called positionally, in forward()'s argument order
    forward_args = (
        qkv,
        dropout_p,
        softmax_scale,
        causal,
        window_size,
        alibi_slopes,
        deterministic,
        return_attn_probs,
    )
    return FlashAttnQKVPackedFunc.apply(*forward_args)
Tri Dao's avatar
Tri Dao committed
695
696


Tri Dao's avatar
Tri Dao committed
697
def flash_attn_kvpacked_func(
    q,
    kv,
    dropout_p=0.0,
    softmax_scale=None,
    causal=False,
    window_size=(-1, -1),  # -1 means infinite context window
    alibi_slopes=None,
    deterministic=False,
    return_attn_probs=False,
):
    """Attention for a separate Q tensor and K, V stacked into a single KV tensor.

    dropout_p should be set to 0.0 during evaluation.

    When K and V are already stacked into one tensor, this is faster than calling
    flash_attn_func on Q, K, V, because the backward pass avoids explicitly concatenating
    the gradients of K and V.

    Multi-query and grouped-query attention (MQA/GQA) are supported by passing KV with fewer
    heads than Q. The number of heads in Q must be divisible by the number of heads in KV.
    For example, if Q has 6 heads and K, V have 2 heads, heads 0, 1, 2 of Q attend to head 0
    of K, V, and heads 3, 4, 5 of Q attend to head 1 of K, V.

    If causal=True, the causal mask is aligned to the bottom right corner of the attention
    matrix. For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask
    (1 = keep, 0 = masked out) is:
        1 1 1 1 0
        1 1 1 1 1
    If seqlen_q = 5 and seqlen_k = 2, the causal mask is:
        0 0
        0 0
        0 0
        1 0
        1 1
    If a row of the mask is all zero, the output for that row will be zero.

    If window_size != (-1, -1), sliding window local attention is used. The query at
    position i only attends to keys in the inclusive range
    [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]].

    Arguments:
        q: (batch_size, seqlen, nheads, headdim)
        kv: (batch_size, seqlen, 2, nheads_k, headdim)
        dropout_p: float. Dropout probability.
        softmax_scale: float. The scaling of QK^T before applying softmax.
            Default to 1 / sqrt(headdim).
        causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive
            modeling).
        window_size: (left, right). If not (-1, -1), implements sliding window local attention.
        alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of
            (-alibi_slope * |i + seqlen_k - seqlen_q - j|)
            is added to the attention score of query i and key j.
        deterministic: bool. Whether to use the deterministic implementation of the backward
            pass, which is slightly slower and uses more memory. The forward pass is always
            deterministic.
        return_attn_probs: bool. Whether to return the attention probabilities. This option is
            for testing only. The returned probabilities are not guaranteed to be correct
            (they might not have the right scaling).
    Return:
        out: (batch_size, seqlen, nheads, headdim).
        softmax_lse [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen). The
            logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
            normalization factor).
        S_dmask [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen, seqlen).
            The output of softmax (possibly with different scaling). It also encodes the dropout
            pattern (negative means that location was dropped, nonnegative means it was kept).
    """
    # The autograd function is invoked positionally, in exactly this order.
    forwarded = (
        q,
        kv,
        dropout_p,
        softmax_scale,
        causal,
        window_size,
        alibi_slopes,
        deterministic,
        return_attn_probs,
    )
    return FlashAttnKVPackedFunc.apply(*forwarded)
Tri Dao's avatar
Tri Dao committed
769
770


Tri Dao's avatar
Tri Dao committed
771
def flash_attn_func(
    q,
    k,
    v,
    dropout_p=0.0,
    softmax_scale=None,
    causal=False,
    window_size=(-1, -1),  # -1 means infinite context window
    alibi_slopes=None,
    deterministic=False,
    return_attn_probs=False,
):
    """Attention for separate Q, K and V tensors.

    dropout_p should be set to 0.0 during evaluation.

    Multi-query and grouped-query attention (MQA/GQA) are supported by passing K, V with fewer
    heads than Q. The number of heads in Q must be divisible by the number of heads in KV.
    For example, if Q has 6 heads and K, V have 2 heads, heads 0, 1, 2 of Q attend to head 0
    of K, V, and heads 3, 4, 5 of Q attend to head 1 of K, V.

    If causal=True, the causal mask is aligned to the bottom right corner of the attention
    matrix. For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask
    (1 = keep, 0 = masked out) is:
        1 1 1 1 0
        1 1 1 1 1
    If seqlen_q = 5 and seqlen_k = 2, the causal mask is:
        0 0
        0 0
        0 0
        1 0
        1 1
    If a row of the mask is all zero, the output for that row will be zero.

    If window_size != (-1, -1), sliding window local attention is used. The query at
    position i only attends to keys in the inclusive range
    [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]].

    Arguments:
        q: (batch_size, seqlen, nheads, headdim)
        k: (batch_size, seqlen, nheads_k, headdim)
        v: (batch_size, seqlen, nheads_k, headdim)
        dropout_p: float. Dropout probability.
        softmax_scale: float. The scaling of QK^T before applying softmax.
            Default to 1 / sqrt(headdim).
        causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive
            modeling).
        window_size: (left, right). If not (-1, -1), implements sliding window local attention.
        alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of
            (-alibi_slope * |i + seqlen_k - seqlen_q - j|)
            is added to the attention score of query i and key j.
        deterministic: bool. Whether to use the deterministic implementation of the backward
            pass, which is slightly slower and uses more memory. The forward pass is always
            deterministic.
        return_attn_probs: bool. Whether to return the attention probabilities. This option is
            for testing only. The returned probabilities are not guaranteed to be correct
            (they might not have the right scaling).
    Return:
        out: (batch_size, seqlen, nheads, headdim).
        softmax_lse [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen). The
            logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
            normalization factor).
        S_dmask [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen, seqlen).
            The output of softmax (possibly with different scaling). It also encodes the dropout
            pattern (negative means that location was dropped, nonnegative means it was kept).
    """
    # The autograd function is invoked positionally, in exactly this order.
    forwarded = (
        q,
        k,
        v,
        dropout_p,
        softmax_scale,
        causal,
        window_size,
        alibi_slopes,
        deterministic,
        return_attn_probs,
    )
    return FlashAttnFunc.apply(*forwarded)
Tri Dao's avatar
Tri Dao committed
843
844


Tri Dao's avatar
Tri Dao committed
845
846
847
848
849
850
851
def flash_attn_varlen_qkvpacked_func(
    qkv,
    cu_seqlens,
    max_seqlen,
    dropout_p=0.0,
    softmax_scale=None,
    causal=False,
    window_size=(-1, -1),  # -1 means infinite context window
    alibi_slopes=None,
    deterministic=False,
    return_attn_probs=False,
):
    """Variable-length attention for Q, K, V stacked into a single QKV tensor.

    dropout_p should be set to 0.0 during evaluation.

    When Q, K, V are already stacked into one tensor, this is faster than calling
    flash_attn_varlen_func on Q, K, V, because the backward pass avoids explicitly
    concatenating the gradients of Q, K, V.

    For multi-query and grouped-query attention (MQA/GQA), please see
    flash_attn_varlen_kvpacked_func and flash_attn_varlen_func.

    If window_size != (-1, -1), sliding window local attention is used. The query at
    position i only attends to keys in the inclusive range
    [i - window_size[0], i + window_size[1]].

    Arguments:
        qkv: (total, 3, nheads, headdim), where total = total number of tokens in the batch.
        cu_seqlens: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
            of the sequences in the batch, used to index into qkv.
        max_seqlen: int. Maximum sequence length in the batch.
        dropout_p: float. Dropout probability.
        softmax_scale: float. The scaling of QK^T before applying softmax.
            Default to 1 / sqrt(headdim).
        causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive
            modeling).
        window_size: (left, right). If not (-1, -1), implements sliding window local attention.
        alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of
            (-alibi_slope * |i - j|) is added to the attention score of query i and key j.
        deterministic: bool. Whether to use the deterministic implementation of the backward
            pass, which is slightly slower and uses more memory. The forward pass is always
            deterministic.
        return_attn_probs: bool. Whether to return the attention probabilities. This option is
            for testing only. The returned probabilities are not guaranteed to be correct
            (they might not have the right scaling).
    Return:
        out: (total, nheads, headdim).
        softmax_lse [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen). The
            logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
            normalization factor).
        S_dmask [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen, seqlen).
            The output of softmax (possibly with different scaling). It also encodes the dropout
            pattern (negative means that location was dropped, nonnegative means it was kept).
    """
    # The autograd function is invoked positionally, in exactly this order.
    forwarded = (
        qkv,
        cu_seqlens,
        max_seqlen,
        dropout_p,
        softmax_scale,
        causal,
        window_size,
        alibi_slopes,
        deterministic,
        return_attn_probs,
    )
    return FlashAttnVarlenQKVPackedFunc.apply(*forwarded)
Tri Dao's avatar
Tri Dao committed
905
906


Tri Dao's avatar
Tri Dao committed
907
908
909
910
911
912
913
914
915
916
def flash_attn_varlen_kvpacked_func(
    q,
    kv,
    cu_seqlens_q,
    cu_seqlens_k,
    max_seqlen_q,
    max_seqlen_k,
    dropout_p=0.0,
    softmax_scale=None,
    causal=False,
    window_size=(-1, -1),  # -1 means infinite context window
    alibi_slopes=None,
    deterministic=False,
    return_attn_probs=False,
):
    """Variable-length attention for a separate Q tensor and a stacked KV tensor.

    dropout_p should be set to 0.0 during evaluation.

    When K and V are already stacked into one tensor, this is faster than calling the
    unpacked variant (flash_attn_varlen_func) on Q, K, V, because the backward pass avoids
    explicitly concatenating the gradients of K and V.

    Multi-query and grouped-query attention (MQA/GQA) are supported by passing KV with fewer
    heads than Q. The number of heads in Q must be divisible by the number of heads in KV.
    For example, if Q has 6 heads and K, V have 2 heads, heads 0, 1, 2 of Q attend to head 0
    of K, V, and heads 3, 4, 5 of Q attend to head 1 of K, V.

    If causal=True, the causal mask is aligned to the bottom right corner of the attention
    matrix. For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask
    (1 = keep, 0 = masked out) is:
        1 1 1 1 0
        1 1 1 1 1
    If seqlen_q = 5 and seqlen_k = 2, the causal mask is:
        0 0
        0 0
        0 0
        1 0
        1 1
    If a row of the mask is all zero, the output for that row will be zero.

    If window_size != (-1, -1), sliding window local attention is used. The query at
    position i only attends to keys in the inclusive range
    [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]].

    Arguments:
        q: (total_q, nheads, headdim), where total_q = total number of query tokens in the
            batch.
        kv: (total_k, 2, nheads_k, headdim), where total_k = total number of key tokens in the
            batch.
        cu_seqlens_q: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
            of the sequences in the batch, used to index into q.
        cu_seqlens_k: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
            of the sequences in the batch, used to index into kv.
        max_seqlen_q: int. Maximum query sequence length in the batch.
        max_seqlen_k: int. Maximum key sequence length in the batch.
        dropout_p: float. Dropout probability.
        softmax_scale: float. The scaling of QK^T before applying softmax.
            Default to 1 / sqrt(headdim).
        causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive
            modeling).
        window_size: (left, right). If not (-1, -1), implements sliding window local attention.
        alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of
            (-alibi_slope * |i + seqlen_k - seqlen_q - j|)
            is added to the attention score of query i and key j.
        deterministic: bool. Whether to use the deterministic implementation of the backward
            pass, which is slightly slower and uses more memory. The forward pass is always
            deterministic.
        return_attn_probs: bool. Whether to return the attention probabilities. This option is
            for testing only. The returned probabilities are not guaranteed to be correct
            (they might not have the right scaling).
    Return:
        out: (total_q, nheads, headdim).
        softmax_lse [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen). The
            logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
            normalization factor).
        S_dmask [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen, seqlen).
            The output of softmax (possibly with different scaling). It also encodes the dropout
            pattern (negative means that location was dropped, nonnegative means it was kept).
    """
    # The autograd function is invoked positionally, in exactly this order.
    forwarded = (
        q,
        kv,
        cu_seqlens_q,
        cu_seqlens_k,
        max_seqlen_q,
        max_seqlen_k,
        dropout_p,
        softmax_scale,
        causal,
        window_size,
        alibi_slopes,
        deterministic,
        return_attn_probs,
    )
    return FlashAttnVarlenKVPackedFunc.apply(*forwarded)
Tri Dao's avatar
Tri Dao committed
993

994

Tri Dao's avatar
Tri Dao committed
995
996
997
998
999
1000
1001
1002
1003
1004
1005
def flash_attn_varlen_func(
    q,
    k,
    v,
    cu_seqlens_q,
    cu_seqlens_k,
    max_seqlen_q,
    max_seqlen_k,
    dropout_p=0.0,
    softmax_scale=None,
    causal=False,
    window_size=(-1, -1),  # -1 means infinite context window
    alibi_slopes=None,
    deterministic=False,
    return_attn_probs=False,
    block_table=None,
):
    """Variable-length attention for separate Q, K and V tensors.

    dropout_p should be set to 0.0 during evaluation.

    Multi-query and grouped-query attention (MQA/GQA) are supported by passing K, V with fewer
    heads than Q. The number of heads in Q must be divisible by the number of heads in KV.
    For example, if Q has 6 heads and K, V have 2 heads, heads 0, 1, 2 of Q attend to head 0
    of K, V, and heads 3, 4, 5 of Q attend to head 1 of K, V.

    If causal=True, the causal mask is aligned to the bottom right corner of the attention
    matrix. For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask
    (1 = keep, 0 = masked out) is:
        1 1 1 1 0
        1 1 1 1 1
    If seqlen_q = 5 and seqlen_k = 2, the causal mask is:
        0 0
        0 0
        0 0
        1 0
        1 1
    If a row of the mask is all zero, the output for that row will be zero.

    If window_size != (-1, -1), sliding window local attention is used. The query at
    position i only attends to keys in the inclusive range
    [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]].

    Arguments:
        q: (total_q, nheads, headdim), where total_q = total number of query tokens in the
            batch.
        k: (total_k, nheads_k, headdim), where total_k = total number of key tokens in the
            batch.
        v: (total_k, nheads_k, headdim), where total_k = total number of key tokens in the
            batch.
        cu_seqlens_q: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
            of the sequences in the batch, used to index into q.
        cu_seqlens_k: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
            of the sequences in the batch, used to index into kv.
        max_seqlen_q: int. Maximum query sequence length in the batch.
        max_seqlen_k: int. Maximum key sequence length in the batch.
        dropout_p: float. Dropout probability.
        softmax_scale: float. The scaling of QK^T before applying softmax.
            Default to 1 / sqrt(headdim).
        causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive
            modeling).
        window_size: (left, right). If not (-1, -1), implements sliding window local attention.
        alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of
            (-alibi_slope * |i + seqlen_k - seqlen_q - j|)
            is added to the attention score of query i and key j.
        deterministic: bool. Whether to use the deterministic implementation of the backward
            pass, which is slightly slower and uses more memory. The forward pass is always
            deterministic.
        return_attn_probs: bool. Whether to return the attention probabilities. This option is
            for testing only. The returned probabilities are not guaranteed to be correct
            (they might not have the right scaling).
        block_table [optional]: forwarded unchanged as the last argument of
            FlashAttnVarlenFunc. NOTE(review): presumably a paged-KV block table like the one
            accepted by flash_attn_with_kvcache -- confirm shape/dtype against the kernel.
    Return:
        out: (total_q, nheads, headdim).
        softmax_lse [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen). The
            logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
            normalization factor).
        S_dmask [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen, seqlen).
            The output of softmax (possibly with different scaling). It also encodes the dropout
            pattern (negative means that location was dropped, nonnegative means it was kept).
    """
    # The autograd function is invoked positionally, in exactly this order.
    forwarded = (
        q,
        k,
        v,
        cu_seqlens_q,
        cu_seqlens_k,
        max_seqlen_q,
        max_seqlen_k,
        dropout_p,
        softmax_scale,
        causal,
        window_size,
        alibi_slopes,
        deterministic,
        return_attn_probs,
        block_table,
    )
    return FlashAttnVarlenFunc.apply(*forwarded)
Tri Dao's avatar
Tri Dao committed
1083
1084
1085
1086
1087
1088
1089
1090


def flash_attn_with_kvcache(
    q,
    k_cache,
    v_cache,
    k=None,
    v=None,
    rotary_cos=None,
    rotary_sin=None,
    cache_seqlens: Optional[Union[(int, torch.Tensor)]] = None,
    cache_batch_idx: Optional[torch.Tensor] = None,
    block_table: Optional[torch.Tensor] = None,
    softmax_scale=None,
    causal=False,
    window_size=(-1, -1),  # -1 means infinite context window
    rotary_interleaved=True,
    alibi_slopes=None,
    num_splits=0,
):
    """Forward-only attention against a KV cache, optionally updating the cache in place.

    If k and v are not None, k_cache and v_cache will be updated *inplace* with the new values
    from k and v. This is useful for incremental decoding: you can pass in the cached
    keys/values from the previous step, and update them with the new keys/values from the
    current step, and do attention with the updated cache, all in 1 kernel.

    If you pass in k / v, you must make sure that the cache is large enough to hold the new
    values. For example, the KV cache could be pre-allocated with the max sequence length, and
    you can use cache_seqlens to keep track of the current sequence lengths of each sequence in
    the batch.

    Also apply rotary embedding if rotary_cos and rotary_sin are passed in. The key @k will be
    rotated by rotary_cos and rotary_sin at indices cache_seqlens, cache_seqlens + 1, etc.
    If causal or local (i.e., window_size != (-1, -1)), the query @q will be rotated by
    rotary_cos and rotary_sin at indices cache_seqlens, cache_seqlens + 1, etc.
    If not causal and not local, the query @q will be rotated by rotary_cos and rotary_sin at
    indices cache_seqlens only (i.e. we consider all tokens in @q to be at position
    cache_seqlens).

    See tests/test_flash_attn.py::test_flash_attn_kvcache for examples of how to use this
    function.

    Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads
    than Q. Note that the number of heads in Q must be divisible by the number of heads in KV.
    For example, if Q has 6 heads and K, V have 2 heads, heads 0, 1, 2 of Q attend to head 0
    of K, V, and heads 3, 4, 5 of Q attend to head 1 of K, V.

    If causal=True, the causal mask is aligned to the bottom right corner of the attention
    matrix. For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask
    (1 = keep, 0 = masked out) is:
        1 1 1 1 0
        1 1 1 1 1
    If seqlen_q = 5 and seqlen_k = 2, the causal mask is:
        0 0
        0 0
        0 0
        1 0
        1 1
    If the row of the mask is all zero, the output will be zero.

    If window_size != (-1, -1), implements sliding window local attention. Query at position i
    will only attend to keys between
    [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]]
    inclusive.

    Note: Does not support backward pass.

    Arguments:
        q: (batch_size, seqlen, nheads, headdim)
        k_cache: (batch_size_cache, seqlen_cache, nheads_k, headdim) if there's no block_table,
            or (num_blocks, page_block_size, nheads_k, headdim) if there's a block_table
            (i.e. paged KV cache). page_block_size must be a multiple of 256.
            The last dimension must be contiguous (asserted).
        v_cache: (batch_size_cache, seqlen_cache, nheads_k, headdim) if there's no block_table,
            or (num_blocks, page_block_size, nheads_k, headdim) if there's a block_table
            (i.e. paged KV cache). The last dimension must be contiguous (asserted).
        k [optional]: (batch_size, seqlen_new, nheads_k, headdim). If not None, we concatenate
            k with k_cache, starting at the indices specified by cache_seqlens.
        v [optional]: (batch_size, seqlen_new, nheads_k, headdim). Similar to k.
        rotary_cos [optional]: (seqlen_ro, rotary_dim / 2). If not None, we apply rotary
            embedding to k and q. Only applicable if k and v are passed in. rotary_dim must be
            divisible by 16.
        rotary_sin [optional]: (seqlen_ro, rotary_dim / 2). Similar to rotary_cos.
        cache_seqlens: int, or (batch_size,), dtype torch.int32. The sequence lengths of the
            KV cache. An int is expanded to an int32 tensor filled with that value.
        cache_batch_idx: (batch_size,), dtype torch.int32. The indices used to index into the
            KV cache. If None, we assume that the batch indices are
            [0, 1, 2, ..., batch_size - 1]. If the indices are not distinct, and k and v are
            provided, the values updated in the cache might come from any of the duplicate
            indices.
        block_table [optional]: (batch_size, max_num_blocks_per_seq), dtype torch.int32.
        softmax_scale: float. The scaling of QK^T before applying softmax.
            Default to 1 / sqrt(headdim).
        causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive
            modeling).
        window_size: (left, right). If not (-1, -1), implements sliding window local attention.
        rotary_interleaved: bool. Only applicable if rotary_cos and rotary_sin are passed in.
            If True, rotary embedding will combine dimensions 0 & 1, 2 & 3, etc. If False,
            rotary embedding will combine dimensions 0 & rotary_dim / 2, 1 & rotary_dim / 2 + 1
            (i.e. GPT-NeoX style).
        alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of
            (-alibi_slope * |i + seqlen_k - seqlen_q - j|)
            is added to the attention score of query i and key j.
        num_splits: int. If > 1, split the key/value into this many chunks along the sequence.
            If num_splits == 1, we don't split the key/value. If num_splits == 0, we use a
            heuristic to automatically determine the number of splits.
            Don't change this unless you know what you are doing.

    Return:
        out: (batch_size, seqlen, nheads, headdim).
    """
    assert k_cache.stride(-1) == 1, "k_cache must have contiguous last dimension"
    assert v_cache.stride(-1) == 1, "v_cache must have contiguous last dimension"

    def maybe_contiguous(x):
        # Copy only when the innermost stride is not 1; None passes through untouched.
        return x.contiguous() if x is not None and x.stride(-1) != 1 else x

    q, k, v = [maybe_contiguous(x) for x in (q, k, v)]
    if softmax_scale is None:
        softmax_scale = q.shape[-1] ** (-0.5)
    if isinstance(cache_seqlens, int):
        # NOTE(review): the tensor is sized from k_cache.shape[0], which the docstring
        # describes as batch_size_cache (or num_blocks for a paged cache), while
        # cache_seqlens is documented as (batch_size,) -- confirm for the paged /
        # cache_batch_idx cases.
        cache_seqlens = torch.full(
            (k_cache.shape[0],), cache_seqlens, dtype=torch.int32, device=k_cache.device
        )
    # Normalise caller-supplied tensors as well; previously this ran only on the
    # freshly built (already contiguous) tensor from the int branch.
    cache_seqlens = maybe_contiguous(cache_seqlens)
    cache_batch_idx = maybe_contiguous(cache_batch_idx)
    block_table = maybe_contiguous(block_table)
    # The extension also returns a second value (the original bound it to `softmax_lse`);
    # only the attention output is exposed to callers.
    out, _ = flash_attn_cuda.fwd_kvcache(
        q,
        k_cache,
        v_cache,
        k,
        v,
        cache_seqlens,
        rotary_cos,
        rotary_sin,
        cache_batch_idx,
        block_table,
        alibi_slopes,
        None,  # presumably the optional preallocated-output slot -- confirm against the extension
        softmax_scale,
        causal,
        window_size[0],
        window_size[1],
        rotary_interleaved,
        num_splits,
    )
    return out