flash_attn_interface.py 44.3 KB
Newer Older
1
2
# Copyright (c) 2023, Tri Dao.

Tri Dao's avatar
Tri Dao committed
3
4
from typing import Optional, Union

Tri Dao's avatar
Tri Dao committed
5
6
import torch
import torch.nn as nn
Tri Dao's avatar
Tri Dao committed
7

8
9
10
# isort: off
# We need to import the CUDA kernels after importing torch
import flash_attn_2_cuda as flash_attn_cuda
Tri Dao's avatar
Tri Dao committed
11

12
13
# isort: on

Tri Dao's avatar
Tri Dao committed
14

Tri Dao's avatar
Tri Dao committed
15
def _get_block_size_n(device, head_dim, is_dropout, is_causal):
    """Return the block size along the key/seqlen_k axis used by the CUDA kernel.

    This table must mirror the block sizes hard-coded in the CUDA kernels.

    Args:
        device: device passed to ``torch.cuda.get_device_capability`` to pick the
            architecture-specific entry.
        head_dim: attention head dimension; must be <= 256.
        is_dropout: whether dropout is enabled.
        is_causal: whether causal masking is enabled.

    Returns:
        int: the block size for this configuration.

    Raises:
        AssertionError: if ``head_dim`` exceeds 256.
    """
    # This should match the block sizes in the CUDA kernel
    assert head_dim <= 256
    major, minor = torch.cuda.get_device_capability(device)
    # Only include sm86 and sm89, exclude sm80 (A100). The previously computed
    # is_sm80 / is_sm90 flags were never read and have been removed.
    is_sm8x = major == 8 and minor > 0
    if head_dim <= 32:
        return 128
    if head_dim <= 64:
        return 64 if is_dropout else 128
    if head_dim <= 96:
        return 64
    if head_dim <= 128:
        if is_sm8x:
            return 64 if (not is_dropout and is_causal) else 32
        return 32 if is_dropout else 64
    if head_dim <= 160:
        return 64 if is_sm8x else 32
    if head_dim <= 256:
        # The 192, 224 and 256 buckets all share the same block size.
        return 64
Tri Dao's avatar
Tri Dao committed
44
45


Tri Dao's avatar
Tri Dao committed
46
47
48
def _flash_attn_forward(
    q, k, v, dropout_p, softmax_scale, causal, window_size, alibi_slopes, return_softmax
):
    """Invoke the fixed-length FlashAttention forward kernel ``flash_attn_cuda.fwd``.

    Any of q, k, v whose last dimension is not unit-stride is copied to a contiguous
    tensor first. The eight values produced by the kernel are returned unchanged:
    ``(out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state)``.
    """

    def _ensure_unit_last_stride(t):
        # Only copy when the innermost dimension is strided.
        return t if t.stride(-1) == 1 else t.contiguous()

    q, k, v = (_ensure_unit_last_stride(t) for t in (q, k, v))
    # NOTE(review): the meaning of the positional None arguments is defined by the
    # CUDA extension's signature, which is not visible in this file.
    kernel_outputs = flash_attn_cuda.fwd(
        q,
        k,
        v,
        None,
        alibi_slopes,
        dropout_p,
        softmax_scale,
        causal,
        window_size[0],
        window_size[1],
        return_softmax,
        None,
    )
    out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = kernel_outputs
    return out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state
Tri Dao's avatar
Tri Dao committed
66
67


Tri Dao's avatar
Tri Dao committed
68
69
70
71
72
73
74
75
76
77
78
def _flash_attn_varlen_forward(
    q,
    k,
    v,
    cu_seqlens_q,
    cu_seqlens_k,
    max_seqlen_q,
    max_seqlen_k,
    dropout_p,
    softmax_scale,
    causal,
    window_size,
    alibi_slopes,
    return_softmax,
    block_table,
):
    """Invoke the variable-length FlashAttention forward kernel ``flash_attn_cuda.varlen_fwd``.

    q, k, v are given a unit stride in their last dimension (copied only when needed).
    ``block_table`` is forwarded to the kernel unchanged. The eight kernel outputs are
    returned as ``(out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state)``.
    """

    def _ensure_unit_last_stride(t):
        return t if t.stride(-1) == 1 else t.contiguous()

    q, k, v = (_ensure_unit_last_stride(t) for t in (q, k, v))
    # NOTE(review): the positional None/False arguments follow the CUDA extension's
    # varlen_fwd signature, which is not visible in this file.
    kernel_outputs = flash_attn_cuda.varlen_fwd(
        q,
        k,
        v,
        None,
        cu_seqlens_q,
        cu_seqlens_k,
        None,
        block_table,
        alibi_slopes,
        max_seqlen_q,
        max_seqlen_k,
        dropout_p,
        softmax_scale,
        False,
        causal,
        window_size[0],
        window_size[1],
        return_softmax,
        None,
    )
    out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = kernel_outputs
    return out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state
Tri Dao's avatar
Tri Dao committed
110
111


Tri Dao's avatar
Tri Dao committed
112
def _flash_attn_backward(
    dout,
    q,
    k,
    v,
    out,
    softmax_lse,
    dq,
    dk,
    dv,
    dropout_p,
    softmax_scale,
    causal,
    window_size,
    alibi_slopes,
    deterministic,
    rng_state=None,
):
    """Invoke the fixed-length FlashAttention backward kernel ``flash_attn_cuda.bwd``.

    ``dq``, ``dk`` and ``dv`` are caller-provided gradient buffers handed to the kernel.
    The autograd wrappers in this file ignore the return value and read their own
    buffers afterwards. Returns ``(dq, dk, dv, softmax_d)`` as produced by the kernel.
    """

    def _ensure_unit_last_stride(t):
        return t if t.stride(-1) == 1 else t.contiguous()

    # dq, dk, dv are allocated by us so they should already be contiguous
    dout, q, k, v, out = (_ensure_unit_last_stride(t) for t in (dout, q, k, v, out))
    dq, dk, dv, softmax_d = flash_attn_cuda.bwd(
        dout,
        q,
        k,
        v,
        out,
        softmax_lse,
        dq,
        dk,
        dv,
        alibi_slopes,
        dropout_p,
        softmax_scale,
        causal,
        window_size[0],
        window_size[1],
        deterministic,
        None,  # NOTE(review): positional slot defined by the extension signature
        rng_state,
    )
    return dq, dk, dv, softmax_d


Tri Dao's avatar
Tri Dao committed
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
def _flash_attn_varlen_backward(
    dout,
    q,
    k,
    v,
    out,
    softmax_lse,
    dq,
    dk,
    dv,
    cu_seqlens_q,
    cu_seqlens_k,
    max_seqlen_q,
    max_seqlen_k,
    dropout_p,
    softmax_scale,
    causal,
    window_size,
    alibi_slopes,
    deterministic,
    rng_state=None,
):
    """Invoke the variable-length FlashAttention backward kernel ``flash_attn_cuda.varlen_bwd``.

    ``dq``, ``dk`` and ``dv`` are caller-provided gradient buffers handed to the kernel.
    Returns ``(dq, dk, dv, softmax_d)`` as produced by the kernel.
    """

    def _ensure_unit_last_stride(t):
        return t if t.stride(-1) == 1 else t.contiguous()

    # dq, dk, dv are allocated by us so they should already be contiguous
    dout, q, k, v, out = (_ensure_unit_last_stride(t) for t in (dout, q, k, v, out))
    # NOTE(review): the positional False/None arguments follow the CUDA extension's
    # varlen_bwd signature, which is not visible in this file.
    dq, dk, dv, softmax_d = flash_attn_cuda.varlen_bwd(
        dout,
        q,
        k,
        v,
        out,
        softmax_lse,
        dq,
        dk,
        dv,
        cu_seqlens_q,
        cu_seqlens_k,
        alibi_slopes,
        max_seqlen_q,
        max_seqlen_k,
        dropout_p,
        softmax_scale,
        False,
        causal,
        window_size[0],
        window_size[1],
        deterministic,
        None,
        rng_state,
    )
    return dq, dk, dv, softmax_d
Tri Dao's avatar
Tri Dao committed
219
220


Tri Dao's avatar
Tri Dao committed
221
class FlashAttnQKVPackedFunc(torch.autograd.Function):
    """Autograd function for FlashAttention on a packed ``qkv`` tensor (fixed length).

    q, k and v are the slices ``qkv[:, :, 0]``, ``qkv[:, :, 1]`` and ``qkv[:, :, 2]``.
    The backward pass fills the three gradients directly into slices of one packed
    buffer, so no concatenation of dq/dk/dv is needed.
    """

    @staticmethod
    def forward(
        ctx,
        qkv,
        dropout_p,
        softmax_scale,
        causal,
        window_size,
        alibi_slopes,
        deterministic,
        return_softmax,
    ):
        if softmax_scale is None:
            # Default scaling: 1 / sqrt(headdim).
            softmax_scale = qkv.shape[-1] ** (-0.5)
        q_in, k_in, v_in = qkv[:, :, 0], qkv[:, :, 1], qkv[:, :, 2]
        # S_dmask is only requested from the kernel when dropout is active.
        want_softmax = return_softmax and dropout_p > 0
        out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = _flash_attn_forward(
            q_in,
            k_in,
            v_in,
            dropout_p,
            softmax_scale,
            causal=causal,
            window_size=window_size,
            alibi_slopes=alibi_slopes,
            return_softmax=want_softmax,
        )
        ctx.save_for_backward(q, k, v, out_padded, softmax_lse, rng_state)
        ctx.dropout_p = dropout_p
        ctx.softmax_scale = softmax_scale
        ctx.causal = causal
        ctx.window_size = window_size
        ctx.alibi_slopes = alibi_slopes
        ctx.deterministic = deterministic
        if return_softmax:
            return out, softmax_lse, S_dmask
        return out

    @staticmethod
    def backward(ctx, dout, *args):
        # Incoming grads for softmax_lse / S_dmask (if they were returned) are ignored.
        q, k, v, out, softmax_lse, rng_state = ctx.saved_tensors
        packed_shape = (*q.shape[:-2], 3, *q.shape[-2:])
        dqkv = torch.empty(packed_shape, dtype=q.dtype, device=q.device)
        _flash_attn_backward(
            dout,
            q,
            k,
            v,
            out,
            softmax_lse,
            dqkv[:, :, 0],
            dqkv[:, :, 1],
            dqkv[:, :, 2],
            ctx.dropout_p,
            ctx.softmax_scale,
            ctx.causal,
            ctx.window_size,
            ctx.alibi_slopes,
            ctx.deterministic,
            rng_state=rng_state,
        )
        # We could have padded the head dimension; trim back to dout's head size.
        head_dim = dout.shape[-1]
        dqkv = dqkv[..., :head_dim]
        # One gradient per forward input: qkv, then 7 non-differentiable options.
        return (dqkv,) + (None,) * 7
Tri Dao's avatar
Tri Dao committed
281
282
283
284


class FlashAttnVarlenQKVPackedFunc(torch.autograd.Function):
    """Autograd function for variable-length FlashAttention on a packed ``qkv`` tensor.

    q, k and v are the slices ``qkv[:, 0]``, ``qkv[:, 1]`` and ``qkv[:, 2]``.
    ``cu_seqlens`` and ``max_seqlen`` are used for both the query and the key side.
    """

    @staticmethod
    def forward(
        ctx,
        qkv,
        cu_seqlens,
        max_seqlen,
        dropout_p,
        softmax_scale,
        causal,
        window_size,
        alibi_slopes,
        deterministic,
        return_softmax,
    ):
        if softmax_scale is None:
            # Default scaling: 1 / sqrt(headdim).
            softmax_scale = qkv.shape[-1] ** (-0.5)
        q_in, k_in, v_in = qkv[:, 0], qkv[:, 1], qkv[:, 2]
        # S_dmask is only requested from the kernel when dropout is active.
        want_softmax = return_softmax and dropout_p > 0
        out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = _flash_attn_varlen_forward(
            q_in,
            k_in,
            v_in,
            cu_seqlens,
            cu_seqlens,
            max_seqlen,
            max_seqlen,
            dropout_p,
            softmax_scale,
            causal=causal,
            window_size=window_size,
            alibi_slopes=alibi_slopes,
            return_softmax=want_softmax,
            block_table=None,
        )
        ctx.save_for_backward(q, k, v, out_padded, softmax_lse, cu_seqlens, rng_state)
        ctx.dropout_p = dropout_p
        ctx.max_seqlen = max_seqlen
        ctx.softmax_scale = softmax_scale
        ctx.causal = causal
        ctx.window_size = window_size
        ctx.alibi_slopes = alibi_slopes
        ctx.deterministic = deterministic
        if return_softmax:
            return out, softmax_lse, S_dmask
        return out

    @staticmethod
    def backward(ctx, dout, *args):
        # Incoming grads for softmax_lse / S_dmask (if they were returned) are ignored.
        q, k, v, out, softmax_lse, cu_seqlens, rng_state = ctx.saved_tensors
        packed_shape = (*q.shape[:-2], 3, *q.shape[-2:])
        dqkv = torch.empty(packed_shape, dtype=q.dtype, device=q.device)
        _flash_attn_varlen_backward(
            dout,
            q,
            k,
            v,
            out,
            softmax_lse,
            dqkv[:, 0],
            dqkv[:, 1],
            dqkv[:, 2],
            cu_seqlens,
            cu_seqlens,
            ctx.max_seqlen,
            ctx.max_seqlen,
            ctx.dropout_p,
            ctx.softmax_scale,
            ctx.causal,
            ctx.window_size,
            ctx.alibi_slopes,
            ctx.deterministic,
            rng_state=rng_state,
        )
        # We could have padded the head dimension; trim back to dout's head size.
        head_dim = dout.shape[-1]
        dqkv = dqkv[..., :head_dim]
        # One gradient per forward input: qkv, then 9 non-differentiable arguments.
        return (dqkv,) + (None,) * 9
Tri Dao's avatar
Tri Dao committed
355
356


Tri Dao's avatar
Tri Dao committed
357
class FlashAttnKVPackedFunc(torch.autograd.Function):
    """Autograd function for FlashAttention with separate ``q`` and packed ``kv`` (fixed length).

    k and v are the slices ``kv[:, :, 0]`` and ``kv[:, :, 1]``. The backward pass fills
    dk/dv directly into slices of one packed ``dkv`` buffer.
    """

    @staticmethod
    def forward(
        ctx,
        q,
        kv,
        dropout_p,
        softmax_scale,
        causal,
        window_size,
        alibi_slopes,
        deterministic,
        return_softmax,
    ):
        if softmax_scale is None:
            # Default scaling: 1 / sqrt(headdim).
            softmax_scale = q.shape[-1] ** (-0.5)
        k_in, v_in = kv[:, :, 0], kv[:, :, 1]
        # S_dmask is only requested from the kernel when dropout is active.
        want_softmax = return_softmax and dropout_p > 0
        out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = _flash_attn_forward(
            q,
            k_in,
            v_in,
            dropout_p,
            softmax_scale,
            causal=causal,
            window_size=window_size,
            alibi_slopes=alibi_slopes,
            return_softmax=want_softmax,
        )
        ctx.save_for_backward(q, k, v, out_padded, softmax_lse, rng_state)
        ctx.dropout_p = dropout_p
        ctx.softmax_scale = softmax_scale
        ctx.causal = causal
        ctx.window_size = window_size
        ctx.alibi_slopes = alibi_slopes
        ctx.deterministic = deterministic
        if return_softmax:
            return out, softmax_lse, S_dmask
        return out

    @staticmethod
    def backward(ctx, dout, *args):
        # Incoming grads for softmax_lse / S_dmask (if they were returned) are ignored.
        q, k, v, out, softmax_lse, rng_state = ctx.saved_tensors
        dq = torch.empty_like(q)
        packed_shape = (*k.shape[:-2], 2, *k.shape[-2:])
        dkv = torch.empty(packed_shape, dtype=k.dtype, device=k.device)
        _flash_attn_backward(
            dout,
            q,
            k,
            v,
            out,
            softmax_lse,
            dq,
            dkv[:, :, 0],
            dkv[:, :, 1],
            ctx.dropout_p,
            ctx.softmax_scale,
            ctx.causal,
            ctx.window_size,
            ctx.alibi_slopes,
            ctx.deterministic,
            rng_state=rng_state,
        )
        # We could have padded the head dimension; trim back to dout's head size.
        head_dim = dout.shape[-1]
        dq = dq[..., :head_dim]
        dkv = dkv[..., :head_dim]
        # One gradient per forward input: q, kv, then 7 non-differentiable options.
        return (dq, dkv) + (None,) * 7
Tri Dao's avatar
Tri Dao committed
420
421


Tri Dao's avatar
Tri Dao committed
422
class FlashAttnVarlenKVPackedFunc(torch.autograd.Function):
    """Autograd function for variable-length FlashAttention with ``q`` and packed ``kv``.

    k and v are the slices ``kv[:, 0]`` and ``kv[:, 1]``. Separate cu_seqlens and
    max_seqlen values are used for the query and key sides.
    """

    @staticmethod
    def forward(
        ctx,
        q,
        kv,
        cu_seqlens_q,
        cu_seqlens_k,
        max_seqlen_q,
        max_seqlen_k,
        dropout_p,
        softmax_scale,
        causal,
        window_size,
        alibi_slopes,
        deterministic,
        return_softmax,
    ):
        if softmax_scale is None:
            # Default scaling: 1 / sqrt(headdim).
            softmax_scale = q.shape[-1] ** (-0.5)
        k_in, v_in = kv[:, 0], kv[:, 1]
        # S_dmask is only requested from the kernel when dropout is active.
        want_softmax = return_softmax and dropout_p > 0
        out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = _flash_attn_varlen_forward(
            q,
            k_in,
            v_in,
            cu_seqlens_q,
            cu_seqlens_k,
            max_seqlen_q,
            max_seqlen_k,
            dropout_p,
            softmax_scale,
            causal=causal,
            window_size=window_size,
            alibi_slopes=alibi_slopes,
            return_softmax=want_softmax,
            block_table=None,
        )
        ctx.save_for_backward(
            q, k, v, out_padded, softmax_lse, cu_seqlens_q, cu_seqlens_k, rng_state
        )
        ctx.dropout_p = dropout_p
        ctx.max_seqlen_q = max_seqlen_q
        ctx.max_seqlen_k = max_seqlen_k
        ctx.softmax_scale = softmax_scale
        ctx.causal = causal
        ctx.window_size = window_size
        ctx.alibi_slopes = alibi_slopes
        ctx.deterministic = deterministic
        if return_softmax:
            return out, softmax_lse, S_dmask
        return out

    @staticmethod
    def backward(ctx, dout, *args):
        # Incoming grads for softmax_lse / S_dmask (if they were returned) are ignored.
        q, k, v, out, softmax_lse, cu_seqlens_q, cu_seqlens_k, rng_state = ctx.saved_tensors
        dq = torch.empty_like(q)
        packed_shape = (*k.shape[:-2], 2, *k.shape[-2:])
        dkv = torch.empty(packed_shape, dtype=k.dtype, device=k.device)
        _flash_attn_varlen_backward(
            dout,
            q,
            k,
            v,
            out,
            softmax_lse,
            dq,
            dkv[:, 0],
            dkv[:, 1],
            cu_seqlens_q,
            cu_seqlens_k,
            ctx.max_seqlen_q,
            ctx.max_seqlen_k,
            ctx.dropout_p,
            ctx.softmax_scale,
            ctx.causal,
            ctx.window_size,
            ctx.alibi_slopes,
            ctx.deterministic,
            rng_state=rng_state,
        )
        # We could have padded the head dimension; trim back to dout's head size.
        head_dim = dout.shape[-1]
        dq = dq[..., :head_dim]
        dkv = dkv[..., :head_dim]
        # One gradient per forward input: q, kv, then 11 non-differentiable arguments.
        return (dq, dkv) + (None,) * 11
Tri Dao's avatar
Tri Dao committed
502
503
504
505


class FlashAttnFunc(torch.autograd.Function):
    """Autograd function for FlashAttention with separate q, k, v tensors (fixed length)."""

    @staticmethod
    def forward(
        ctx,
        q,
        k,
        v,
        dropout_p,
        softmax_scale,
        causal,
        window_size,
        alibi_slopes,
        deterministic,
        return_softmax,
    ):
        if softmax_scale is None:
            # Default scaling: 1 / sqrt(headdim).
            softmax_scale = q.shape[-1] ** (-0.5)
        # S_dmask is only requested from the kernel when dropout is active.
        want_softmax = return_softmax and dropout_p > 0
        out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = _flash_attn_forward(
            q,
            k,
            v,
            dropout_p,
            softmax_scale,
            causal=causal,
            window_size=window_size,
            alibi_slopes=alibi_slopes,
            return_softmax=want_softmax,
        )
        ctx.save_for_backward(q, k, v, out_padded, softmax_lse, rng_state)
        ctx.dropout_p = dropout_p
        ctx.softmax_scale = softmax_scale
        ctx.causal = causal
        ctx.window_size = window_size
        ctx.alibi_slopes = alibi_slopes
        ctx.deterministic = deterministic
        if return_softmax:
            return out, softmax_lse, S_dmask
        return out

    @staticmethod
    def backward(ctx, dout, *args):
        # Incoming grads for softmax_lse / S_dmask (if they were returned) are ignored.
        q, k, v, out, softmax_lse, rng_state = ctx.saved_tensors
        dq, dk, dv = (torch.empty_like(t) for t in (q, k, v))
        _flash_attn_backward(
            dout,
            q,
            k,
            v,
            out,
            softmax_lse,
            dq,
            dk,
            dv,
            ctx.dropout_p,
            ctx.softmax_scale,
            ctx.causal,
            ctx.window_size,
            ctx.alibi_slopes,
            ctx.deterministic,
            rng_state=rng_state,
        )
        # We could have padded the head dimension; trim back to dout's head size.
        head_dim = dout.shape[-1]
        dq, dk, dv = (g[..., :head_dim] for g in (dq, dk, dv))
        # One gradient per forward input: q, k, v, then 7 non-differentiable options.
        return (dq, dk, dv) + (None,) * 7
Tri Dao's avatar
Tri Dao committed
567
568


Tri Dao's avatar
Tri Dao committed
569
class FlashAttnVarlenFunc(torch.autograd.Function):
    """Autograd function for variable-length FlashAttention with separate q, k, v tensors.

    ``block_table`` is forwarded unchanged to the varlen forward kernel.
    """

    @staticmethod
    def forward(
        ctx,
        q,
        k,
        v,
        cu_seqlens_q,
        cu_seqlens_k,
        max_seqlen_q,
        max_seqlen_k,
        dropout_p,
        softmax_scale,
        causal,
        window_size,
        alibi_slopes,
        deterministic,
        return_softmax,
        block_table,
    ):
        if softmax_scale is None:
            # Default scaling: 1 / sqrt(headdim).
            softmax_scale = q.shape[-1] ** (-0.5)
        # S_dmask is only requested from the kernel when dropout is active.
        want_softmax = return_softmax and dropout_p > 0
        out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = _flash_attn_varlen_forward(
            q,
            k,
            v,
            cu_seqlens_q,
            cu_seqlens_k,
            max_seqlen_q,
            max_seqlen_k,
            dropout_p,
            softmax_scale,
            causal=causal,
            window_size=window_size,
            alibi_slopes=alibi_slopes,
            return_softmax=want_softmax,
            block_table=block_table,
        )
        ctx.save_for_backward(
            q, k, v, out_padded, softmax_lse, cu_seqlens_q, cu_seqlens_k, rng_state
        )
        ctx.dropout_p = dropout_p
        ctx.max_seqlen_q = max_seqlen_q
        ctx.max_seqlen_k = max_seqlen_k
        ctx.softmax_scale = softmax_scale
        ctx.causal = causal
        ctx.window_size = window_size
        ctx.alibi_slopes = alibi_slopes
        ctx.deterministic = deterministic
        if return_softmax:
            return out, softmax_lse, S_dmask
        return out

    @staticmethod
    def backward(ctx, dout, *args):
        # Incoming grads for softmax_lse / S_dmask (if they were returned) are ignored.
        q, k, v, out, softmax_lse, cu_seqlens_q, cu_seqlens_k, rng_state = ctx.saved_tensors
        dq, dk, dv = (torch.empty_like(t) for t in (q, k, v))
        _flash_attn_varlen_backward(
            dout,
            q,
            k,
            v,
            out,
            softmax_lse,
            dq,
            dk,
            dv,
            cu_seqlens_q,
            cu_seqlens_k,
            ctx.max_seqlen_q,
            ctx.max_seqlen_k,
            ctx.dropout_p,
            ctx.softmax_scale,
            ctx.causal,
            ctx.window_size,
            ctx.alibi_slopes,
            ctx.deterministic,
            rng_state=rng_state,
        )
        # We could have padded the head dimension; trim back to dout's head size.
        head_dim = dout.shape[-1]
        dq, dk, dv = (g[..., :head_dim] for g in (dq, dk, dv))
        # One gradient per forward input: q, k, v, then 12 non-differentiable arguments.
        return (dq, dk, dv) + (None,) * 12
650
651


Tri Dao's avatar
Tri Dao committed
652
def flash_attn_qkvpacked_func(
    qkv,
    dropout_p=0.0,
    softmax_scale=None,
    causal=False,
    window_size=(-1, -1),  # -1 means infinite context window
    alibi_slopes=None,
    deterministic=False,
    return_attn_probs=False,
):
    """dropout_p should be set to 0.0 during evaluation.

    When Q, K, V are already stacked into one tensor, this function is faster than
    calling flash_attn_func on Q, K, V separately: the backward pass writes the
    gradients of Q, K, V into one packed tensor instead of concatenating them.

    For multi-query and grouped-query attention (MQA/GQA), see
    flash_attn_kvpacked_func and flash_attn_func.

    If window_size != (-1, -1), sliding window local attention is used: the query at
    position i only attends to keys in [i - window_size[0], i + window_size[1]] inclusive.

    Arguments:
        qkv: (batch_size, seqlen, 3, nheads, headdim)
        dropout_p: float. Dropout probability.
        softmax_scale: float. The scaling of QK^T before applying softmax.
            Defaults to 1 / sqrt(headdim).
        causal: bool. Whether to apply a causal attention mask (e.g., for auto-regressive modeling).
        window_size: (left, right). If not (-1, -1), implements sliding window local attention.
        alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of (-alibi_slope * |i - j|)
            is added to the attention score of query i and key j.
        deterministic: bool. Whether to use the deterministic implementation of the backward pass,
            which is slightly slower and uses more memory. The forward pass is always deterministic.
        return_attn_probs: bool. Whether to return the attention probabilities. This option is for
            testing only. The returned probabilities are not guaranteed to be correct
            (they might not have the right scaling).
    Return:
        out: (batch_size, seqlen, nheads, headdim).
        softmax_lse [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen). The
            logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
            normalization factor).
        S_dmask [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen, seqlen).
            The output of softmax (possibly with different scaling). It also encodes the dropout
            pattern (negative means that location was dropped, nonnegative means it was kept).
    """
    # Positional order must match FlashAttnQKVPackedFunc.forward (after ctx).
    func_args = (
        qkv,
        dropout_p,
        softmax_scale,
        causal,
        window_size,
        alibi_slopes,
        deterministic,
        return_attn_probs,
    )
    return FlashAttnQKVPackedFunc.apply(*func_args)
Tri Dao's avatar
Tri Dao committed
705
706


Tri Dao's avatar
Tri Dao committed
707
def flash_attn_kvpacked_func(
    q,
    kv,
    dropout_p=0.0,
    softmax_scale=None,
    causal=False,
    window_size=(-1, -1),  # -1 means infinite context window
    alibi_slopes=None,
    deterministic=False,
    return_attn_probs=False,
):
    """dropout_p should be set to 0.0 during evaluation
    If K, V are already stacked into 1 tensor, this function will be faster than
    calling flash_attn_func on Q, K, V since the backward pass avoids explicit concatenation
    of the gradients of K, V.
    Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads
    than Q. Note that the number of heads in Q must be divisible by the number of heads in KV.
    For example, if Q has 6 heads and K, V have 2 heads, heads 0, 1, 2 of Q will attend to head
    0 of K, V, and heads 3, 4, 5 of Q will attend to head 1 of K, V.

    If causal=True, the causal mask is aligned to the bottom right corner of the attention matrix.
    For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is:
        1 1 1 1 0
        1 1 1 1 1
    If seqlen_q = 5 and seqlen_k = 2, the causal mask is:
        0 0
        0 0
        0 0
        1 0
        1 1
    If a row of the mask is all zero, the output for that query will be zero.

    If window_size != (-1, -1), implements sliding window local attention. Query at position i
    will only attend to keys between
    [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive.

    Arguments:
        q: (batch_size, seqlen_q, nheads, headdim)
        kv: (batch_size, seqlen_k, 2, nheads_k, headdim)
        dropout_p: float. Dropout probability.
        softmax_scale: float. The scaling of QK^T before applying softmax.
            Defaults to 1 / sqrt(headdim).
        causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
        window_size: (left, right). If not (-1, -1), implements sliding window local attention.
        alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of
            (-alibi_slope * |i + seqlen_k - seqlen_q - j|)
            is added to the attention score of query i and key j.
        deterministic: bool. Whether to use the deterministic implementation of the backward pass,
            which is slightly slower and uses more memory. The forward pass is always deterministic.
        return_attn_probs: bool. Whether to return the attention probabilities. This option is for
           testing only. The returned probabilities are not guaranteed to be correct
           (they might not have the right scaling).
    Return:
        out: (batch_size, seqlen_q, nheads, headdim).
        softmax_lse [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen_q). The
            logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
            normalization factor).
        S_dmask [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen_q, seqlen_k).
            The output of softmax (possibly with different scaling). It also encodes the dropout
            pattern (negative means that location was dropped, nonnegative means it was kept).
    """
    # Arguments are passed positionally, so their order must match the forward() signature
    # of FlashAttnKVPackedFunc.
    return FlashAttnKVPackedFunc.apply(
        q,
        kv,
        dropout_p,
        softmax_scale,
        causal,
        window_size,
        alibi_slopes,
        deterministic,
        return_attn_probs,
    )
Tri Dao's avatar
Tri Dao committed
779
780


Tri Dao's avatar
Tri Dao committed
781
def flash_attn_func(
    q,
    k,
    v,
    dropout_p=0.0,
    softmax_scale=None,
    causal=False,
    window_size=(-1, -1),  # -1 means infinite context window
    alibi_slopes=None,
    deterministic=False,
    return_attn_probs=False,
):
    """dropout_p should be set to 0.0 during evaluation
    Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads
    than Q. Note that the number of heads in Q must be divisible by the number of heads in KV.
    For example, if Q has 6 heads and K, V have 2 heads, heads 0, 1, 2 of Q will attend to head
    0 of K, V, and heads 3, 4, 5 of Q will attend to head 1 of K, V.

    If causal=True, the causal mask is aligned to the bottom right corner of the attention matrix.
    For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is:
        1 1 1 1 0
        1 1 1 1 1
    If seqlen_q = 5 and seqlen_k = 2, the causal mask is:
        0 0
        0 0
        0 0
        1 0
        1 1
    If a row of the mask is all zero, the output for that query will be zero.

    If window_size != (-1, -1), implements sliding window local attention. Query at position i
    will only attend to keys between
    [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive.

    Arguments:
        q: (batch_size, seqlen_q, nheads, headdim)
        k: (batch_size, seqlen_k, nheads_k, headdim)
        v: (batch_size, seqlen_k, nheads_k, headdim)
        dropout_p: float. Dropout probability.
        softmax_scale: float. The scaling of QK^T before applying softmax.
            Defaults to 1 / sqrt(headdim).
        causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
        window_size: (left, right). If not (-1, -1), implements sliding window local attention.
        alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of
            (-alibi_slope * |i + seqlen_k - seqlen_q - j|)
            is added to the attention score of query i and key j.
        deterministic: bool. Whether to use the deterministic implementation of the backward pass,
            which is slightly slower and uses more memory. The forward pass is always deterministic.
        return_attn_probs: bool. Whether to return the attention probabilities. This option is for
           testing only. The returned probabilities are not guaranteed to be correct
           (they might not have the right scaling).
    Return:
        out: (batch_size, seqlen_q, nheads, headdim).
        softmax_lse [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen_q). The
            logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
            normalization factor).
        S_dmask [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen_q, seqlen_k).
            The output of softmax (possibly with different scaling). It also encodes the dropout
            pattern (negative means that location was dropped, nonnegative means it was kept).
    """
    # Arguments are passed positionally, so their order must match the forward() signature
    # of FlashAttnFunc.
    return FlashAttnFunc.apply(
        q,
        k,
        v,
        dropout_p,
        softmax_scale,
        causal,
        window_size,
        alibi_slopes,
        deterministic,
        return_attn_probs,
    )
Tri Dao's avatar
Tri Dao committed
853
854


Tri Dao's avatar
Tri Dao committed
855
856
857
858
859
860
861
def flash_attn_varlen_qkvpacked_func(
    qkv,
    cu_seqlens,
    max_seqlen,
    dropout_p=0.0,
    softmax_scale=None,
    causal=False,
    window_size=(-1, -1),  # -1 means infinite context window
    alibi_slopes=None,
    deterministic=False,
    return_attn_probs=False,
):
    """dropout_p should be set to 0.0 during evaluation
    If Q, K, V are already stacked into 1 tensor, this function will be faster than
    calling flash_attn_varlen_func on Q, K, V since the backward pass avoids explicit concatenation
    of the gradients of Q, K, V.
    For multi-query and grouped-query attention (MQA/GQA), please see
    flash_attn_varlen_kvpacked_func and flash_attn_varlen_func.

    If window_size != (-1, -1), implements sliding window local attention. Query at position i
    will only attend to keys between [i - window_size[0], i + window_size[1]] inclusive.

    Arguments:
        qkv: (total, 3, nheads, headdim), where total = total number of tokens in the batch.
        cu_seqlens: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
           of the sequences in the batch, used to index into qkv.
        max_seqlen: int. Maximum sequence length in the batch.
        dropout_p: float. Dropout probability.
        softmax_scale: float. The scaling of QK^T before applying softmax.
            Defaults to 1 / sqrt(headdim).
        causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
        window_size: (left, right). If not (-1, -1), implements sliding window local attention.
        alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of (-alibi_slope * |i - j|)
            is added to the attention score of query i and key j.
        deterministic: bool. Whether to use the deterministic implementation of the backward pass,
            which is slightly slower and uses more memory. The forward pass is always deterministic.
        return_attn_probs: bool. Whether to return the attention probabilities. This option is for
           testing only. The returned probabilities are not guaranteed to be correct
           (they might not have the right scaling).
    Return:
        out: (total, nheads, headdim).
        softmax_lse [optional, if return_attn_probs=True]: (nheads, total_q_seqlen). The
            logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
            normalization factor).
        S_dmask [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen, seqlen).
            The output of softmax (possibly with different scaling). It also encodes the dropout
            pattern (negative means that location was dropped, nonnegative means it was kept).
    """
    # Arguments are passed positionally, so their order must match the forward() signature
    # of FlashAttnVarlenQKVPackedFunc.
    return FlashAttnVarlenQKVPackedFunc.apply(
        qkv,
        cu_seqlens,
        max_seqlen,
        dropout_p,
        softmax_scale,
        causal,
        window_size,
        alibi_slopes,
        deterministic,
        return_attn_probs,
    )
Tri Dao's avatar
Tri Dao committed
915
916


Tri Dao's avatar
Tri Dao committed
917
918
919
920
921
922
923
924
925
926
def flash_attn_varlen_kvpacked_func(
    q,
    kv,
    cu_seqlens_q,
    cu_seqlens_k,
    max_seqlen_q,
    max_seqlen_k,
    dropout_p=0.0,
    softmax_scale=None,
    causal=False,
    window_size=(-1, -1),  # -1 means infinite context window
    alibi_slopes=None,
    deterministic=False,
    return_attn_probs=False,
):
    """dropout_p should be set to 0.0 during evaluation
    If K, V are already stacked into 1 tensor, this function will be faster than
    calling flash_attn_varlen_func on Q, K, V since the backward pass avoids explicit
    concatenation of the gradients of K, V.
    Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads
    than Q. Note that the number of heads in Q must be divisible by the number of heads in KV.
    For example, if Q has 6 heads and K, V have 2 heads, heads 0, 1, 2 of Q will attend to head
    0 of K, V, and heads 3, 4, 5 of Q will attend to head 1 of K, V.

    If causal=True, the causal mask is aligned to the bottom right corner of the attention matrix.
    For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is:
        1 1 1 1 0
        1 1 1 1 1
    If seqlen_q = 5 and seqlen_k = 2, the causal mask is:
        0 0
        0 0
        0 0
        1 0
        1 1
    If a row of the mask is all zero, the output for that query will be zero.

    If window_size != (-1, -1), implements sliding window local attention. Query at position i
    will only attend to keys between
    [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive.

    Arguments:
        q: (total_q, nheads, headdim), where total_q = total number of query tokens in the batch.
        kv: (total_k, 2, nheads_k, headdim), where total_k = total number of key tokens in the batch.
        cu_seqlens_q: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
           of the sequences in the batch, used to index into q.
        cu_seqlens_k: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
           of the sequences in the batch, used to index into kv.
        max_seqlen_q: int. Maximum query sequence length in the batch.
        max_seqlen_k: int. Maximum key sequence length in the batch.
        dropout_p: float. Dropout probability.
        softmax_scale: float. The scaling of QK^T before applying softmax.
            Defaults to 1 / sqrt(headdim).
        causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
        window_size: (left, right). If not (-1, -1), implements sliding window local attention.
        alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of
            (-alibi_slope * |i + seqlen_k - seqlen_q - j|)
            is added to the attention score of query i and key j.
        deterministic: bool. Whether to use the deterministic implementation of the backward pass,
            which is slightly slower and uses more memory. The forward pass is always deterministic.
        return_attn_probs: bool. Whether to return the attention probabilities. This option is for
           testing only. The returned probabilities are not guaranteed to be correct
           (they might not have the right scaling).
    Return:
        out: (total_q, nheads, headdim).
        softmax_lse [optional, if return_attn_probs=True]: (nheads, total_q_seqlen). The
            logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
            normalization factor).
        S_dmask [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen, seqlen).
            The output of softmax (possibly with different scaling). It also encodes the dropout
            pattern (negative means that location was dropped, nonnegative means it was kept).
    """
    # Arguments are passed positionally, so their order must match the forward() signature
    # of FlashAttnVarlenKVPackedFunc.
    return FlashAttnVarlenKVPackedFunc.apply(
        q,
        kv,
        cu_seqlens_q,
        cu_seqlens_k,
        max_seqlen_q,
        max_seqlen_k,
        dropout_p,
        softmax_scale,
        causal,
        window_size,
        alibi_slopes,
        deterministic,
        return_attn_probs,
    )
Tri Dao's avatar
Tri Dao committed
1003

1004

Tri Dao's avatar
Tri Dao committed
1005
1006
1007
1008
1009
1010
1011
1012
1013
1014
1015
def flash_attn_varlen_func(
    q,
    k,
    v,
    cu_seqlens_q,
    cu_seqlens_k,
    max_seqlen_q,
    max_seqlen_k,
    dropout_p=0.0,
    softmax_scale=None,
    causal=False,
    window_size=(-1, -1),  # -1 means infinite context window
    alibi_slopes=None,
    deterministic=False,
    return_attn_probs=False,
    block_table=None,
):
    """dropout_p should be set to 0.0 during evaluation
    Supports multi-query and grouped-query attention (MQA/GQA) by passing in K, V with fewer heads
    than Q. Note that the number of heads in Q must be divisible by the number of heads in KV.
    For example, if Q has 6 heads and K, V have 2 heads, heads 0, 1, 2 of Q will attend to head
    0 of K, V, and heads 3, 4, 5 of Q will attend to head 1 of K, V.

    If causal=True, the causal mask is aligned to the bottom right corner of the attention matrix.
    For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is:
        1 1 1 1 0
        1 1 1 1 1
    If seqlen_q = 5 and seqlen_k = 2, the causal mask is:
        0 0
        0 0
        0 0
        1 0
        1 1
    If a row of the mask is all zero, the output for that query will be zero.

    If window_size != (-1, -1), implements sliding window local attention. Query at position i
    will only attend to keys between
    [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive.

    Arguments:
        q: (total_q, nheads, headdim), where total_q = total number of query tokens in the batch.
        k: (total_k, nheads_k, headdim), where total_k = total number of key tokens in the batch.
        v: (total_k, nheads_k, headdim), where total_k = total number of key tokens in the batch.
        cu_seqlens_q: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
           of the sequences in the batch, used to index into q.
        cu_seqlens_k: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
           of the sequences in the batch, used to index into k and v.
        max_seqlen_q: int. Maximum query sequence length in the batch.
        max_seqlen_k: int. Maximum key sequence length in the batch.
        dropout_p: float. Dropout probability.
        softmax_scale: float. The scaling of QK^T before applying softmax.
            Defaults to 1 / sqrt(headdim).
        causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
        window_size: (left, right). If not (-1, -1), implements sliding window local attention.
        alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of
            (-alibi_slope * |i + seqlen_k - seqlen_q - j|)
            is added to the attention score of query i and key j.
        deterministic: bool. Whether to use the deterministic implementation of the backward pass,
            which is slightly slower and uses more memory. The forward pass is always deterministic.
        return_attn_probs: bool. Whether to return the attention probabilities. This option is for
           testing only. The returned probabilities are not guaranteed to be correct
           (they might not have the right scaling).
        block_table [optional]: block table for a paged KV cache, forwarded unchanged to
            FlashAttnVarlenFunc. Presumably (batch_size, max_num_blocks_per_seq), dtype
            torch.int32, as in flash_attn_with_kvcache (TODO confirm).
    Return:
        out: (total_q, nheads, headdim).
        softmax_lse [optional, if return_attn_probs=True]: (nheads, total_q_seqlen). The
            logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
            normalization factor).
        S_dmask [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen, seqlen).
            The output of softmax (possibly with different scaling). It also encodes the dropout
            pattern (negative means that location was dropped, nonnegative means it was kept).
    """
    # Arguments are passed positionally, so their order must match the forward() signature
    # of FlashAttnVarlenFunc.
    return FlashAttnVarlenFunc.apply(
        q,
        k,
        v,
        cu_seqlens_q,
        cu_seqlens_k,
        max_seqlen_q,
        max_seqlen_k,
        dropout_p,
        softmax_scale,
        causal,
        window_size,
        alibi_slopes,
        deterministic,
        return_attn_probs,
        block_table,
    )
Tri Dao's avatar
Tri Dao committed
1093
1094
1095
1096
1097
1098
1099
1100


def flash_attn_with_kvcache(
    q,
    k_cache,
    v_cache,
    k=None,
    v=None,
    rotary_cos=None,
    rotary_sin=None,
    cache_seqlens: Optional[Union[(int, torch.Tensor)]] = None,
    cache_batch_idx: Optional[torch.Tensor] = None,
    block_table: Optional[torch.Tensor] = None,
    softmax_scale=None,
    causal=False,
    window_size=(-1, -1),  # -1 means infinite context window
    rotary_interleaved=True,
    alibi_slopes=None,
    num_splits=0,
):
    """
    If k and v are not None, k_cache and v_cache will be updated *inplace* with the new values from
    k and v. This is useful for incremental decoding: you can pass in the cached keys/values from
    the previous step, and update them with the new keys/values from the current step, and do
    attention with the updated cache, all in 1 kernel.

    If you pass in k / v, you must make sure that the cache is large enough to hold the new values.
    For example, the KV cache could be pre-allocated with the max sequence length, and you can use
    cache_seqlens to keep track of the current sequence lengths of each sequence in the batch.

    Also apply rotary embedding if rotary_cos and rotary_sin are passed in. The key @k will be
    rotated by rotary_cos and rotary_sin at indices cache_seqlens, cache_seqlens + 1, etc.
    If causal or local (i.e., window_size != (-1, -1)), the query @q will be rotated by rotary_cos
    and rotary_sin at indices cache_seqlens, cache_seqlens + 1, etc.
    If not causal and not local, the query @q will be rotated by rotary_cos and rotary_sin at
    indices cache_seqlens only (i.e. we consider all tokens in @q to be at position cache_seqlens).

    See tests/test_flash_attn.py::test_flash_attn_kvcache for examples of how to use this function.

    Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads
    than Q. Note that the number of heads in Q must be divisible by the number of heads in KV.
    For example, if Q has 6 heads and K, V have 2 heads, heads 0, 1, 2 of Q will attend to head
    0 of K, V, and heads 3, 4, 5 of Q will attend to head 1 of K, V.

    If causal=True, the causal mask is aligned to the bottom right corner of the attention matrix.
    For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is:
        1 1 1 1 0
        1 1 1 1 1
    If seqlen_q = 5 and seqlen_k = 2, the causal mask is:
        0 0
        0 0
        0 0
        1 0
        1 1
    If a row of the mask is all zero, the output for that query will be zero.

    If window_size != (-1, -1), implements sliding window local attention. Query at position i
    will only attend to keys between
    [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive.

    Note: Does not support backward pass.

    Arguments:
        q: (batch_size, seqlen, nheads, headdim)
        k_cache: (batch_size_cache, seqlen_cache, nheads_k, headdim) if there's no block_table,
            or (num_blocks, page_block_size, nheads_k, headdim) if there's a block_table (i.e. paged KV cache)
            page_block_size must be a multiple of 256.
        v_cache: (batch_size_cache, seqlen_cache, nheads_k, headdim) if there's no block_table,
            or (num_blocks, page_block_size, nheads_k, headdim) if there's a block_table (i.e. paged KV cache)
        k [optional]: (batch_size, seqlen_new, nheads_k, headdim). If not None, we concatenate
            k with k_cache, starting at the indices specified by cache_seqlens.
        v [optional]: (batch_size, seqlen_new, nheads_k, headdim). Similar to k.
        rotary_cos [optional]: (seqlen_ro, rotary_dim / 2). If not None, we apply rotary embedding
            to k and q. Only applicable if k and v are passed in. rotary_dim must be divisible by 16.
        rotary_sin [optional]: (seqlen_ro, rotary_dim / 2). Similar to rotary_cos.
        cache_seqlens: int, or (batch_size,), dtype torch.int32. The sequence lengths of the
            KV cache. An int is broadcast to a (batch_size,) int32 tensor, i.e. every sequence
            in the batch gets the same length.
        block_table [optional]: (batch_size, max_num_blocks_per_seq), dtype torch.int32.
        cache_batch_idx: (batch_size,), dtype torch.int32. The indices used to index into the KV cache.
            If None, we assume that the batch indices are [0, 1, 2, ..., batch_size - 1].
            If the indices are not distinct, and k and v are provided, the values updated in the cache
                 might come from any of the duplicate indices.
        softmax_scale: float. The scaling of QK^T before applying softmax.
            Defaults to 1 / sqrt(headdim).
        causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
        window_size: (left, right). If not (-1, -1), implements sliding window local attention.
        rotary_interleaved: bool. Only applicable if rotary_cos and rotary_sin are passed in.
            If True, rotary embedding will combine dimensions 0 & 1, 2 & 3, etc. If False,
            rotary embedding will combine dimensions 0 & rotary_dim / 2, 1 & rotary_dim / 2 + 1
            (i.e. GPT-NeoX style).
        alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of
            (-alibi_slope * |i + seqlen_k - seqlen_q - j|)
            is added to the attention score of query i and key j.
        num_splits: int. If > 1, split the key/value into this many chunks along the sequence.
           If num_splits == 1, we don't split the key/value. If num_splits == 0, we use a heuristic
           to automatically determine the number of splits.
           Don't change this unless you know what you are doing.

    Return:
        out: (batch_size, seqlen, nheads, headdim).
    """
    assert k_cache.stride(-1) == 1, "k_cache must have contiguous last dimension"
    assert v_cache.stride(-1) == 1, "v_cache must have contiguous last dimension"

    def maybe_contiguous(x):
        # Only a unit stride on the last dimension is required; copy only when it is missing.
        return x.contiguous() if x is not None and x.stride(-1) != 1 else x

    q, k, v = [maybe_contiguous(x) for x in (q, k, v)]
    if softmax_scale is None:
        softmax_scale = q.shape[-1] ** (-0.5)
    if isinstance(cache_seqlens, int):
        # cache_seqlens has one entry per sequence in the batch, i.e. (batch_size,) = q.shape[0].
        # k_cache.shape[0] is NOT that: it is batch_size_cache (which may differ when
        # cache_batch_idx is used) or num_blocks for a paged cache (block_table).
        cache_seqlens = torch.full(
            (q.shape[0],), cache_seqlens, dtype=torch.int32, device=k_cache.device
        )
    # Apply to user-supplied tensors as well, not only to the one built from an int.
    cache_seqlens = maybe_contiguous(cache_seqlens)
    cache_batch_idx = maybe_contiguous(cache_batch_idx)
    block_table = maybe_contiguous(block_table)
    # Arguments are positional; their order must match flash_attn_cuda.fwd_kvcache.
    out, _softmax_lse = flash_attn_cuda.fwd_kvcache(
        q,
        k_cache,
        v_cache,
        k,
        v,
        cache_seqlens,
        rotary_cos,
        rotary_sin,
        cache_batch_idx,
        block_table,
        alibi_slopes,
        None,  # presumably the optional preallocated output tensor (TODO confirm)
        softmax_scale,
        causal,
        window_size[0],
        window_size[1],
        rotary_interleaved,
        num_splits,
    )
    return out