flash_attn_interface.py 46.5 KB
Newer Older
1
2
# Copyright (c) 2023, Tri Dao.

Tri Dao's avatar
Tri Dao committed
3
4
from typing import Optional, Union

Tri Dao's avatar
Tri Dao committed
5
6
import torch
import torch.nn as nn
Tri Dao's avatar
Tri Dao committed
7

8
9
# isort: off
# We need to import the CUDA kernels after importing torch
Woosuk Kwon's avatar
Woosuk Kwon committed
10
import vllm_flash_attn_2_cuda as flash_attn_cuda
Tri Dao's avatar
Tri Dao committed
11

12
13
# isort: on

14
15
def maybe_contiguous(x):
    """Return ``x`` with unit stride in its last dimension.

    ``None`` passes through unchanged, as does any tensor whose innermost
    stride is already 1 (even if the tensor is not fully contiguous).
    Anything else is replaced by ``x.contiguous()``.
    """
    if x is None or x.stride(-1) == 1:
        return x
    return x.contiguous()
Tri Dao's avatar
Tri Dao committed
16

Tri Dao's avatar
Tri Dao committed
17
def _get_block_size_n(device, head_dim, is_dropout, is_causal):
Tri Dao's avatar
Tri Dao committed
18
19
20
21
22
23
24
    # This should match the block sizes in the CUDA kernel
    assert head_dim <= 256
    major, minor = torch.cuda.get_device_capability(device)
    is_sm8x = major == 8 and minor > 0  # Only include sm86 and sm89, exclude sm80 (A100)
    is_sm80 = major == 8 and minor == 0
    is_sm90 = major == 9 and minor == 0
    if head_dim <= 32:
Tri Dao's avatar
Tri Dao committed
25
        return 128
Tri Dao's avatar
Tri Dao committed
26
    if head_dim <= 64:
Tri Dao's avatar
Tri Dao committed
27
        return 128 if not is_dropout else 64
Tri Dao's avatar
Tri Dao committed
28
    elif head_dim <= 96:
Tri Dao's avatar
Tri Dao committed
29
        return 64
Tri Dao's avatar
Tri Dao committed
30
31
    elif head_dim <= 128:
        if is_sm8x:
Tri Dao's avatar
Tri Dao committed
32
            return 64 if (not is_dropout and is_causal) else 32
Tri Dao's avatar
Tri Dao committed
33
        else:
Tri Dao's avatar
Tri Dao committed
34
            return 64 if not is_dropout else 32
Tri Dao's avatar
Tri Dao committed
35
36
    elif head_dim <= 160:
        if is_sm8x:
Tri Dao's avatar
Tri Dao committed
37
            return 64
Tri Dao's avatar
Tri Dao committed
38
        else:
Tri Dao's avatar
Tri Dao committed
39
            return 32
Tri Dao's avatar
Tri Dao committed
40
    elif head_dim <= 192:
Tri Dao's avatar
Tri Dao committed
41
        return 64
Tri Dao's avatar
Tri Dao committed
42
    elif head_dim <= 224:
Tri Dao's avatar
Tri Dao committed
43
        return 64
Tri Dao's avatar
Tri Dao committed
44
    elif head_dim <= 256:
Tri Dao's avatar
Tri Dao committed
45
        return 64
Tri Dao's avatar
Tri Dao committed
46
47


Tri Dao's avatar
Tri Dao committed
48
def _flash_attn_forward(
    q, k, v, dropout_p, softmax_scale, causal, window_size, softcap, alibi_slopes, return_softmax, *, out=None
):
    """Invoke the fixed-length flash-attention forward kernel.

    q, k and v are first given a unit last-dim stride via ``maybe_contiguous``.
    ``out`` is handed to ``flash_attn_cuda.fwd`` as-is.

    Returns:
        The 8 values produced by the kernel, as a tuple:
        (out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state).
    """
    q, k, v = map(maybe_contiguous, (q, k, v))
    results = flash_attn_cuda.fwd(
        q,
        k,
        v,
        out,
        alibi_slopes,
        dropout_p,
        softmax_scale,
        causal,
        window_size[0],  # left window extent
        window_size[1],  # right window extent
        softcap,
        return_softmax,
        None,  # NOTE(review): presumably the RNG generator slot; confirm against the C++ binding
    )
    out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = results
    return out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state
Tri Dao's avatar
Tri Dao committed
68
69


Tri Dao's avatar
Tri Dao committed
70
71
72
73
74
75
76
77
78
79
80
def _flash_attn_varlen_forward(
    q,
    k,
    v,
    cu_seqlens_q,
    cu_seqlens_k,
    max_seqlen_q,
    max_seqlen_k,
    dropout_p,
    softmax_scale,
    causal,
    window_size,
    softcap,
    alibi_slopes,
    return_softmax,
    block_table,
    *,
    out=None
):
    """Invoke the variable-length flash-attention forward kernel.

    q, k and v are first given a unit last-dim stride via ``maybe_contiguous``.
    ``cu_seqlens_*``, ``max_seqlen_*``, ``block_table`` and ``out`` are passed
    to ``flash_attn_cuda.varlen_fwd`` unchanged.

    Returns:
        The 8 values produced by the kernel, as a tuple:
        (out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state).
    """
    q, k, v = map(maybe_contiguous, (q, k, v))
    results = flash_attn_cuda.varlen_fwd(
        q,
        k,
        v,
        out,
        cu_seqlens_q,
        cu_seqlens_k,
        None,  # NOTE(review): optional slot always left empty here; confirm meaning in the C++ binding
        block_table,
        alibi_slopes,
        max_seqlen_q,
        max_seqlen_k,
        dropout_p,
        softmax_scale,
        False,  # NOTE(review): hard-coded flag; confirm meaning in the C++ binding
        causal,
        window_size[0],  # left window extent
        window_size[1],  # right window extent
        softcap,
        return_softmax,
        None,  # NOTE(review): presumably the RNG generator slot; confirm
    )
    out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = results
    return out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state
Tri Dao's avatar
Tri Dao committed
115
116


Tri Dao's avatar
Tri Dao committed
117
def _flash_attn_backward(
    dout,
    q,
    k,
    v,
    out,
    softmax_lse,
    dq,
    dk,
    dv,
    dropout_p,
    softmax_scale,
    causal,
    window_size,
    softcap,
    alibi_slopes,
    deterministic,
    rng_state=None,
):
    """Invoke the fixed-length flash-attention backward kernel.

    ``dq``, ``dk`` and ``dv`` are caller-provided gradient buffers handed to
    ``flash_attn_cuda.bwd``. Only the incoming tensors (dout, q, k, v, out)
    are normalized with ``maybe_contiguous``.

    Returns:
        (dq, dk, dv, softmax_d) as returned by the kernel.
    """
    # The gradient buffers are allocated by our callers, so they skip the
    # contiguity normalization applied to the inputs.
    dout, q, k, v, out = map(maybe_contiguous, (dout, q, k, v, out))
    grads = flash_attn_cuda.bwd(
        dout,
        q,
        k,
        v,
        out,
        softmax_lse,
        dq,
        dk,
        dv,
        alibi_slopes,
        dropout_p,
        softmax_scale,
        causal,
        window_size[0],  # left window extent
        window_size[1],  # right window extent
        softcap,
        deterministic,
        None,  # NOTE(review): presumably the RNG generator slot; confirm
        rng_state,
    )
    dq, dk, dv, softmax_d = grads
    return dq, dk, dv, softmax_d


Tri Dao's avatar
Tri Dao committed
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
def _flash_attn_varlen_backward(
    dout,
    q,
    k,
    v,
    out,
    softmax_lse,
    dq,
    dk,
    dv,
    cu_seqlens_q,
    cu_seqlens_k,
    max_seqlen_q,
    max_seqlen_k,
    dropout_p,
    softmax_scale,
    causal,
    window_size,
    softcap,
    alibi_slopes,
    deterministic,
    rng_state=None,
):
    """Invoke the variable-length flash-attention backward kernel.

    ``dq``, ``dk`` and ``dv`` are caller-provided gradient buffers handed to
    ``flash_attn_cuda.varlen_bwd``. Only the incoming tensors (dout, q, k, v,
    out) are normalized with ``maybe_contiguous``.

    Returns:
        (dq, dk, dv, softmax_d) as returned by the kernel.
    """
    # The gradient buffers are allocated by our callers, so they skip the
    # contiguity normalization applied to the inputs.
    dout, q, k, v, out = map(maybe_contiguous, (dout, q, k, v, out))
    grads = flash_attn_cuda.varlen_bwd(
        dout,
        q,
        k,
        v,
        out,
        softmax_lse,
        dq,
        dk,
        dv,
        cu_seqlens_q,
        cu_seqlens_k,
        alibi_slopes,
        max_seqlen_q,
        max_seqlen_k,
        dropout_p,
        softmax_scale,
        False,  # NOTE(review): hard-coded flag; confirm meaning in the C++ binding
        causal,
        window_size[0],  # left window extent
        window_size[1],  # right window extent
        softcap,
        deterministic,
        None,  # NOTE(review): presumably the RNG generator slot; confirm
        rng_state,
    )
    dq, dk, dv, softmax_d = grads
    return dq, dk, dv, softmax_d
Tri Dao's avatar
Tri Dao committed
226
227


Tri Dao's avatar
Tri Dao committed
228
class FlashAttnQKVPackedFunc(torch.autograd.Function):
    """Autograd function for flash attention on a packed ``qkv`` tensor.

    The forward pass slices ``qkv`` along dim 2 into q, k and v. The backward
    pass writes the three gradients into views of one packed ``dqkv`` buffer,
    so no explicit concatenation of dq, dk, dv is needed.
    """

    @staticmethod
    def forward(
        ctx,
        qkv,
        dropout_p,
        softmax_scale,
        causal,
        window_size,
        softcap,
        alibi_slopes,
        deterministic,
        return_softmax,
        out=None,
    ):
        # ``out`` is positional-or-keyword on purpose. flash_attn_qkvpacked_func
        # passes it to ``apply`` positionally, which a keyword-only parameter
        # rejects with a TypeError. Keyword callers are unaffected.
        if softmax_scale is None:
            # Default scaling: 1 / sqrt(headdim).
            softmax_scale = qkv.shape[-1] ** (-0.5)
        out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = _flash_attn_forward(
            qkv[:, :, 0],
            qkv[:, :, 1],
            qkv[:, :, 2],
            dropout_p,
            softmax_scale,
            causal=causal,
            window_size=window_size,
            softcap=softcap,
            alibi_slopes=alibi_slopes,
            # S_dmask is only requested from the kernel when dropout is active.
            return_softmax=return_softmax and dropout_p > 0,
            out=out,
        )
        ctx.save_for_backward(q, k, v, out_padded, softmax_lse, rng_state)
        ctx.dropout_p = dropout_p
        ctx.softmax_scale = softmax_scale
        ctx.causal = causal
        ctx.window_size = window_size
        ctx.softcap = softcap
        ctx.alibi_slopes = alibi_slopes
        ctx.deterministic = deterministic
        return out if not return_softmax else (out, softmax_lse, S_dmask)

    @staticmethod
    def backward(ctx, dout, *args):
        q, k, v, out, softmax_lse, rng_state = ctx.saved_tensors
        qkv_shape = q.shape[:-2] + (3, *q.shape[-2:])
        dqkv = torch.empty(qkv_shape, dtype=q.dtype, device=q.device)
        _flash_attn_backward(
            dout,
            q,
            k,
            v,
            out,
            softmax_lse,
            dqkv[:, :, 0],
            dqkv[:, :, 1],
            dqkv[:, :, 2],
            ctx.dropout_p,
            ctx.softmax_scale,
            ctx.causal,
            ctx.window_size,
            ctx.softcap,
            ctx.alibi_slopes,
            ctx.deterministic,
            rng_state=rng_state,
        )
        dqkv = dqkv[..., : dout.shape[-1]]  # We could have padded the head dimension
        # Autograd needs exactly one entry per forward() input (10 after ctx).
        # Only qkv is differentiable; the 9 option arguments (incl. ``out``) get None.
        return dqkv, None, None, None, None, None, None, None, None, None
Tri Dao's avatar
Tri Dao committed
295
296
297
298


class FlashAttnVarlenQKVPackedFunc(torch.autograd.Function):
    """Autograd function for variable-length flash attention on packed ``qkv``.

    The forward pass slices ``qkv`` along dim 1 into q, k and v and uses the
    same ``cu_seqlens`` / ``max_seqlen`` for queries and keys. The backward
    pass writes the gradients into views of one packed ``dqkv`` buffer.
    """

    @staticmethod
    def forward(
        ctx,
        qkv,
        cu_seqlens,
        max_seqlen,
        dropout_p,
        softmax_scale,
        causal,
        window_size,
        softcap,
        alibi_slopes,
        deterministic,
        return_softmax,
        out=None,
    ):
        # ``out`` is positional-or-keyword, like the sibling autograd functions,
        # because ``Function.apply`` callers may pass it positionally. A
        # keyword-only parameter would reject that with a TypeError.
        if softmax_scale is None:
            # Default scaling: 1 / sqrt(headdim).
            softmax_scale = qkv.shape[-1] ** (-0.5)
        out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = _flash_attn_varlen_forward(
            qkv[:, 0],
            qkv[:, 1],
            qkv[:, 2],
            cu_seqlens,
            cu_seqlens,
            max_seqlen,
            max_seqlen,
            dropout_p,
            softmax_scale,
            causal=causal,
            window_size=window_size,
            softcap=softcap,
            alibi_slopes=alibi_slopes,
            # S_dmask is only requested from the kernel when dropout is active.
            return_softmax=return_softmax and dropout_p > 0,
            block_table=None,
            out=out,
        )
        ctx.save_for_backward(q, k, v, out_padded, softmax_lse, cu_seqlens, rng_state)
        ctx.dropout_p = dropout_p
        ctx.max_seqlen = max_seqlen
        ctx.softmax_scale = softmax_scale
        ctx.causal = causal
        ctx.window_size = window_size
        ctx.softcap = softcap
        ctx.alibi_slopes = alibi_slopes
        ctx.deterministic = deterministic
        return out if not return_softmax else (out, softmax_lse, S_dmask)

    @staticmethod
    def backward(ctx, dout, *args):
        q, k, v, out, softmax_lse, cu_seqlens, rng_state = ctx.saved_tensors
        qkv_shape = q.shape[:-2] + (3, *q.shape[-2:])
        dqkv = torch.empty(qkv_shape, dtype=q.dtype, device=q.device)
        _flash_attn_varlen_backward(
            dout,
            q,
            k,
            v,
            out,
            softmax_lse,
            dqkv[:, 0],
            dqkv[:, 1],
            dqkv[:, 2],
            cu_seqlens,
            cu_seqlens,
            ctx.max_seqlen,
            ctx.max_seqlen,
            ctx.dropout_p,
            ctx.softmax_scale,
            ctx.causal,
            ctx.window_size,
            ctx.softcap,
            ctx.alibi_slopes,
            ctx.deterministic,
            rng_state=rng_state,
        )
        dqkv = dqkv[..., : dout.shape[-1]]  # We could have padded the head dimension
        # Autograd needs one entry per forward() input (12 after ctx, incl. ``out``).
        # Trailing Nones beyond the actual input count are accepted and trimmed.
        return dqkv, None, None, None, None, None, None, None, None, None, None, None
Tri Dao's avatar
Tri Dao committed
376
377


Tri Dao's avatar
Tri Dao committed
378
class FlashAttnKVPackedFunc(torch.autograd.Function):
    """Autograd function for flash attention with a separate q and packed ``kv``.

    ``kv`` is sliced along dim 2 into k and v. The backward pass writes dk and
    dv into views of one packed ``dkv`` buffer.
    """

    @staticmethod
    def forward(
        ctx,
        q,
        kv,
        dropout_p,
        softmax_scale,
        causal,
        window_size,
        softcap,
        alibi_slopes,
        deterministic,
        return_softmax,
        out=None,
    ):
        if softmax_scale is None:
            # Default scaling: 1 / sqrt(headdim).
            softmax_scale = q.shape[-1] ** (-0.5)
        out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = _flash_attn_forward(
            q,
            kv[:, :, 0],
            kv[:, :, 1],
            dropout_p,
            softmax_scale,
            causal=causal,
            window_size=window_size,
            softcap=softcap,
            alibi_slopes=alibi_slopes,
            # S_dmask is only requested from the kernel when dropout is active.
            return_softmax=return_softmax and dropout_p > 0,
            out=out,
        )
        ctx.save_for_backward(q, k, v, out_padded, softmax_lse, rng_state)
        ctx.dropout_p = dropout_p
        ctx.softmax_scale = softmax_scale
        ctx.causal = causal
        ctx.window_size = window_size
        ctx.softcap = softcap
        ctx.alibi_slopes = alibi_slopes
        ctx.deterministic = deterministic
        return out if not return_softmax else (out, softmax_lse, S_dmask)

    @staticmethod
    def backward(ctx, dout, *args):
        q, k, v, out, softmax_lse, rng_state = ctx.saved_tensors
        dq = torch.empty_like(q)
        kv_shape = k.shape[:-2] + (2, *k.shape[-2:])
        dkv = torch.empty(kv_shape, dtype=k.dtype, device=k.device)
        _flash_attn_backward(
            dout,
            q,
            k,
            v,
            out,
            softmax_lse,
            dq,
            dkv[:, :, 0],
            dkv[:, :, 1],
            ctx.dropout_p,
            ctx.softmax_scale,
            ctx.causal,
            ctx.window_size,
            ctx.softcap,
            ctx.alibi_slopes,
            ctx.deterministic,
            rng_state=rng_state,
        )
        dq = dq[..., : dout.shape[-1]]  # We could have padded the head dimension
        dkv = dkv[..., : dout.shape[-1]]
        # Autograd needs one entry per forward() input (11 after ctx). Only q and
        # kv are differentiable; the 9 option arguments (incl. ``out``) get None.
        return dq, dkv, None, None, None, None, None, None, None, None, None
Tri Dao's avatar
Tri Dao committed
447
448


Tri Dao's avatar
Tri Dao committed
449
class FlashAttnVarlenKVPackedFunc(torch.autograd.Function):
    """Autograd function for variable-length flash attention with packed ``kv``.

    ``kv`` is sliced along dim 1 into k and v. Queries and keys carry their own
    ``cu_seqlens_*`` / ``max_seqlen_*``. The backward pass writes dk and dv
    into views of one packed ``dkv`` buffer.
    """

    @staticmethod
    def forward(
        ctx,
        q,
        kv,
        cu_seqlens_q,
        cu_seqlens_k,
        max_seqlen_q,
        max_seqlen_k,
        dropout_p,
        softmax_scale,
        causal,
        window_size,
        softcap,
        alibi_slopes,
        deterministic,
        return_softmax,
        out=None,
    ):
        if softmax_scale is None:
            # Default scaling: 1 / sqrt(headdim).
            softmax_scale = q.shape[-1] ** (-0.5)
        out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = _flash_attn_varlen_forward(
            q,
            kv[:, 0],
            kv[:, 1],
            cu_seqlens_q,
            cu_seqlens_k,
            max_seqlen_q,
            max_seqlen_k,
            dropout_p,
            softmax_scale,
            causal=causal,
            window_size=window_size,
            softcap=softcap,
            alibi_slopes=alibi_slopes,
            # S_dmask is only requested from the kernel when dropout is active.
            return_softmax=return_softmax and dropout_p > 0,
            block_table=None,
            out=out,
        )
        ctx.save_for_backward(
            q, k, v, out_padded, softmax_lse, cu_seqlens_q, cu_seqlens_k, rng_state
        )
        ctx.dropout_p = dropout_p
        ctx.max_seqlen_q = max_seqlen_q
        ctx.max_seqlen_k = max_seqlen_k
        ctx.softmax_scale = softmax_scale
        ctx.causal = causal
        ctx.window_size = window_size
        ctx.softcap = softcap
        ctx.alibi_slopes = alibi_slopes
        ctx.deterministic = deterministic
        return out if not return_softmax else (out, softmax_lse, S_dmask)

    @staticmethod
    def backward(ctx, dout, *args):
        q, k, v, out, softmax_lse, cu_seqlens_q, cu_seqlens_k, rng_state = ctx.saved_tensors
        dq = torch.empty_like(q)
        kv_shape = k.shape[:-2] + (2, *k.shape[-2:])
        dkv = torch.empty(kv_shape, dtype=k.dtype, device=k.device)
        _flash_attn_varlen_backward(
            dout,
            q,
            k,
            v,
            out,
            softmax_lse,
            dq,
            dkv[:, 0],
            dkv[:, 1],
            cu_seqlens_q,
            cu_seqlens_k,
            ctx.max_seqlen_q,
            ctx.max_seqlen_k,
            ctx.dropout_p,
            ctx.softmax_scale,
            ctx.causal,
            ctx.window_size,
            ctx.softcap,
            ctx.alibi_slopes,
            ctx.deterministic,
            rng_state=rng_state,
        )
        dq = dq[..., : dout.shape[-1]]  # We could have padded the head dimension
        dkv = dkv[..., : dout.shape[-1]]
        # Autograd needs one entry per forward() input (15 after ctx). Only q and
        # kv are differentiable; the 13 remaining arguments (incl. ``out``) get None.
        return (
            dq,
            dkv,
            None, None, None, None, None, None, None,
            None, None, None, None, None, None,
        )
Tri Dao's avatar
Tri Dao committed
535
536
537
538


class FlashAttnFunc(torch.autograd.Function):
    """Autograd function for flash attention on separate q, k and v tensors."""

    @staticmethod
    def forward(
        ctx,
        q,
        k,
        v,
        dropout_p,
        softmax_scale,
        causal,
        window_size,
        softcap,
        alibi_slopes,
        deterministic,
        return_softmax,
        out=None,
    ):
        if softmax_scale is None:
            # Default scaling: 1 / sqrt(headdim).
            softmax_scale = q.shape[-1] ** (-0.5)
        out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = _flash_attn_forward(
            q,
            k,
            v,
            dropout_p,
            softmax_scale,
            causal=causal,
            window_size=window_size,
            softcap=softcap,
            alibi_slopes=alibi_slopes,
            # S_dmask is only requested from the kernel when dropout is active.
            return_softmax=return_softmax and dropout_p > 0,
            out=out,
        )
        ctx.save_for_backward(q, k, v, out_padded, softmax_lse, rng_state)
        ctx.dropout_p = dropout_p
        ctx.softmax_scale = softmax_scale
        ctx.causal = causal
        ctx.window_size = window_size
        ctx.softcap = softcap
        ctx.alibi_slopes = alibi_slopes
        ctx.deterministic = deterministic
        return out if not return_softmax else (out, softmax_lse, S_dmask)

    @staticmethod
    def backward(ctx, dout, *args):
        q, k, v, out, softmax_lse, rng_state = ctx.saved_tensors
        dq, dk, dv = torch.empty_like(q), torch.empty_like(k), torch.empty_like(v)
        _flash_attn_backward(
            dout,
            q,
            k,
            v,
            out,
            softmax_lse,
            dq,
            dk,
            dv,
            ctx.dropout_p,
            ctx.softmax_scale,
            ctx.causal,
            ctx.window_size,
            ctx.softcap,
            ctx.alibi_slopes,
            ctx.deterministic,
            rng_state=rng_state,
        )
        dq = dq[..., : dout.shape[-1]]  # We could have padded the head dimension
        dk = dk[..., : dout.shape[-1]]
        dv = dv[..., : dout.shape[-1]]
        # Autograd needs one entry per forward() input (12 after ctx). Only q, k
        # and v are differentiable; the 9 option arguments (incl. ``out``) get None.
        return dq, dk, dv, None, None, None, None, None, None, None, None, None
Tri Dao's avatar
Tri Dao committed
606
607


Tri Dao's avatar
Tri Dao committed
608
class FlashAttnVarlenFunc(torch.autograd.Function):
    """Autograd function for variable-length flash attention on q, k and v.

    Queries and keys carry their own ``cu_seqlens_*`` / ``max_seqlen_*``.
    ``block_table`` is passed through to the forward kernel only.
    """

    @staticmethod
    def forward(
        ctx,
        q,
        k,
        v,
        cu_seqlens_q,
        cu_seqlens_k,
        max_seqlen_q,
        max_seqlen_k,
        dropout_p,
        softmax_scale,
        causal,
        window_size,
        softcap,
        alibi_slopes,
        deterministic,
        return_softmax,
        block_table,
        out=None,
    ):
        if softmax_scale is None:
            # Default scaling: 1 / sqrt(headdim).
            softmax_scale = q.shape[-1] ** (-0.5)
        out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = _flash_attn_varlen_forward(
            q,
            k,
            v,
            cu_seqlens_q,
            cu_seqlens_k,
            max_seqlen_q,
            max_seqlen_k,
            dropout_p,
            softmax_scale,
            causal=causal,
            window_size=window_size,
            softcap=softcap,
            alibi_slopes=alibi_slopes,
            # S_dmask is only requested from the kernel when dropout is active.
            return_softmax=return_softmax and dropout_p > 0,
            block_table=block_table,
            out=out,
        )
        ctx.save_for_backward(
            q, k, v, out_padded, softmax_lse, cu_seqlens_q, cu_seqlens_k, rng_state
        )
        ctx.dropout_p = dropout_p
        ctx.max_seqlen_q = max_seqlen_q
        ctx.max_seqlen_k = max_seqlen_k
        ctx.softmax_scale = softmax_scale
        ctx.causal = causal
        ctx.window_size = window_size
        ctx.softcap = softcap
        ctx.alibi_slopes = alibi_slopes
        ctx.deterministic = deterministic
        return out if not return_softmax else (out, softmax_lse, S_dmask)

    @staticmethod
    def backward(ctx, dout, *args):
        q, k, v, out, softmax_lse, cu_seqlens_q, cu_seqlens_k, rng_state = ctx.saved_tensors
        dq, dk, dv = torch.empty_like(q), torch.empty_like(k), torch.empty_like(v)
        _flash_attn_varlen_backward(
            dout,
            q,
            k,
            v,
            out,
            softmax_lse,
            dq,
            dk,
            dv,
            cu_seqlens_q,
            cu_seqlens_k,
            ctx.max_seqlen_q,
            ctx.max_seqlen_k,
            ctx.dropout_p,
            ctx.softmax_scale,
            ctx.causal,
            ctx.window_size,
            ctx.softcap,
            ctx.alibi_slopes,
            ctx.deterministic,
            rng_state=rng_state,
        )
        dq = dq[..., : dout.shape[-1]]  # We could have padded the head dimension
        dk = dk[..., : dout.shape[-1]]
        dv = dv[..., : dout.shape[-1]]
        # Autograd needs one entry per forward() input (17 after ctx). Only q, k
        # and v are differentiable; the 14 remaining arguments (incl.
        # ``block_table`` and ``out``) get None.
        return (
            dq,
            dk,
            dv,
            None, None, None, None, None, None, None,
            None, None, None, None, None, None, None,
        )
695
696


Tri Dao's avatar
Tri Dao committed
697
def flash_attn_qkvpacked_func(
    qkv,
    dropout_p=0.0,
    softmax_scale=None,
    causal=False,
    window_size=(-1, -1),  # -1 means infinite context window
    softcap=0.0,  # <=0.0 means deactivate
    alibi_slopes=None,
    deterministic=False,
    return_attn_probs=False,
    *,
    out=None,
):
    """Flash attention on Q, K, V that are already stacked into one tensor.

    dropout_p should be set to 0.0 during evaluation.

    Using a packed ``qkv`` is faster than calling flash_attn_func on separate
    Q, K, V, because the backward pass avoids explicitly concatenating the
    gradients of Q, K and V. For multi-query and grouped-query attention
    (MQA/GQA), use flash_attn_kvpacked_func or flash_attn_func instead.

    If window_size != (-1, -1), sliding window local attention is used: the
    query at position i only attends to keys in
    [i - window_size[0], i + window_size[1]] inclusive.

    Arguments:
        qkv: (batch_size, seqlen, 3, nheads, headdim)
        dropout_p: float. Dropout probability.
        softmax_scale: float. Scaling applied to QK^T before the softmax.
            Defaults to 1 / sqrt(headdim).
        causal: bool. Whether to apply a causal attention mask (e.g. for
            auto-regressive modeling).
        window_size: (left, right). If not (-1, -1), enables sliding window
            local attention.
        softcap: float. Any value > 0 activates softcapping attention.
        alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of
            (-alibi_slope * |i - j|) is added to the attention score of
            query i and key j.
        deterministic: bool. Whether to use the deterministic implementation
            of the backward pass, which is slightly slower and uses more
            memory. The forward pass is always deterministic.
        return_attn_probs: bool. Whether to also return the attention
            probabilities. This option is for testing only; the returned
            probabilities are not guaranteed to be correct (they might not
            have the right scaling).
        out: optional keyword-only tensor, forwarded to the CUDA forward
            kernel as its output argument. NOTE(review): presumably a
            preallocated buffer shaped like the returned ``out``; confirm.

    Return:
        out: (batch_size, seqlen, nheads, headdim).
        softmax_lse [optional, if return_attn_probs=True]:
            (batch_size, nheads, seqlen). The logsumexp of each row of the
            matrix QK^T * scaling (e.g. log of the softmax normalization
            factor).
        S_dmask [optional, if return_attn_probs=True]:
            (batch_size, nheads, seqlen, seqlen). The output of softmax
            (possibly with different scaling). It also encodes the dropout
            pattern (negative means that location was dropped, nonnegative
            means it was kept).
    """
    # autograd.Function.apply takes its arguments positionally, in the order
    # of FlashAttnQKVPackedFunc.forward (after ctx).
    apply_args = (
        qkv,
        dropout_p,
        softmax_scale,
        causal,
        window_size,
        softcap,
        alibi_slopes,
        deterministic,
        return_attn_probs,
        out,
    )
    return FlashAttnQKVPackedFunc.apply(*apply_args)
Tri Dao's avatar
Tri Dao committed
756
757


Tri Dao's avatar
Tri Dao committed
758
def flash_attn_kvpacked_func(
    q,
    kv,
    dropout_p=0.0,
    softmax_scale=None,
    causal=False,
    window_size=(-1, -1),  # -1 means infinite context window
    softcap=0.0,  # 0.0 means deactivated
    alibi_slopes=None,
    deterministic=False,
    return_attn_probs=False,
    *,
    out=None,
):
    """dropout_p should be set to 0.0 during evaluation.
    If K, V are already stacked into 1 tensor, this function will be faster than
    calling flash_attn_func on Q, K, V since the backward pass avoids explicit concatenation
    of the gradients of K, V.
    Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads
    than Q. Note that the number of heads in Q must be divisible by the number of heads in KV.
    For example, if Q has 6 heads and K, V have 2 heads, heads 0, 1, 2 of Q will attend to head
    0 of K, V, and heads 3, 4, 5 of Q will attend to head 1 of K, V.

    If causal=True, the causal mask is aligned to the bottom right corner of the attention matrix.
    For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is:
        1 1 1 1 0
        1 1 1 1 1
    If seqlen_q = 5 and seqlen_k = 2, the causal mask is:
        0 0
        0 0
        0 0
        1 0
        1 1
    If a row of the mask is all zeros, the output for that row will be zero.

    If window_size != (-1, -1), implements sliding window local attention. Query at position i
    will only attend to keys between
    [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive.

    Arguments:
        q: (batch_size, seqlen_q, nheads, headdim)
        kv: (batch_size, seqlen_k, 2, nheads_k, headdim)
        dropout_p: float. Dropout probability.
        softmax_scale: float. The scaling of QK^T before applying softmax.
            Defaults to 1 / sqrt(headdim).
        causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
        window_size: (left, right). If not (-1, -1), implements sliding window local attention.
        softcap: float. Anything > 0 activates softcapping attention.
        alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of
            (-alibi_slope * |i + seqlen_k - seqlen_q - j|)
            is added to the attention score of query i and key j.
        deterministic: bool. Whether to use the deterministic implementation of the backward pass,
            which is slightly slower and uses more memory. The forward pass is always deterministic.
        return_attn_probs: bool. Whether to return the attention probabilities. This option is for
           testing only. The returned probabilities are not guaranteed to be correct
           (they might not have the right scaling).
        out: Optional keyword-only tensor, forwarded unchanged to FlashAttnKVPackedFunc.
            Presumably a pre-allocated buffer that receives the attention output
            (TODO: confirm the required shape/dtype against the kernel).
    Return:
        out: (batch_size, seqlen_q, nheads, headdim).
        softmax_lse [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen_q). The
            logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
            normalization factor).
        S_dmask [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen, seqlen).
            The output of softmax (possibly with different scaling). It also encodes the dropout
            pattern (negative means that location was dropped, nonnegative means it was kept).
    """
    return FlashAttnKVPackedFunc.apply(
        q,
        kv,
        dropout_p,
        softmax_scale,
        causal,
        window_size,
        softcap,
        alibi_slopes,
        deterministic,
        return_attn_probs,
        out,
    )
Tri Dao's avatar
Tri Dao committed
836
837


Tri Dao's avatar
Tri Dao committed
838
def flash_attn_func(
    q,
    k,
    v,
    dropout_p=0.0,
    softmax_scale=None,
    causal=False,
    window_size=(-1, -1),  # -1 means infinite context window
    softcap=0.0,  # 0.0 means deactivated
    alibi_slopes=None,
    deterministic=False,
    return_attn_probs=False,
    *,
    out=None,
):
    """dropout_p should be set to 0.0 during evaluation.
    Supports multi-query and grouped-query attention (MQA/GQA) by passing in K, V with fewer heads
    than Q. Note that the number of heads in Q must be divisible by the number of heads in K, V.
    For example, if Q has 6 heads and K, V have 2 heads, heads 0, 1, 2 of Q will attend to head
    0 of K, V, and heads 3, 4, 5 of Q will attend to head 1 of K, V.

    If causal=True, the causal mask is aligned to the bottom right corner of the attention matrix.
    For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is:
        1 1 1 1 0
        1 1 1 1 1
    If seqlen_q = 5 and seqlen_k = 2, the causal mask is:
        0 0
        0 0
        0 0
        1 0
        1 1
    If a row of the mask is all zeros, the output for that row will be zero.

    If window_size != (-1, -1), implements sliding window local attention. Query at position i
    will only attend to keys between
    [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive.

    Arguments:
        q: (batch_size, seqlen_q, nheads, headdim)
        k: (batch_size, seqlen_k, nheads_k, headdim)
        v: (batch_size, seqlen_k, nheads_k, headdim)
        dropout_p: float. Dropout probability.
        softmax_scale: float. The scaling of QK^T before applying softmax.
            Defaults to 1 / sqrt(headdim).
        causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
        window_size: (left, right). If not (-1, -1), implements sliding window local attention.
        softcap: float. Anything > 0 activates softcapping attention.
        alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of
            (-alibi_slope * |i + seqlen_k - seqlen_q - j|)
            is added to the attention score of query i and key j.
        deterministic: bool. Whether to use the deterministic implementation of the backward pass,
            which is slightly slower and uses more memory. The forward pass is always deterministic.
        return_attn_probs: bool. Whether to return the attention probabilities. This option is for
           testing only. The returned probabilities are not guaranteed to be correct
           (they might not have the right scaling).
        out: Optional keyword-only tensor, forwarded unchanged to FlashAttnFunc.
            Presumably a pre-allocated buffer that receives the attention output
            (TODO: confirm the required shape/dtype against the kernel).
    Return:
        out: (batch_size, seqlen_q, nheads, headdim).
        softmax_lse [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen_q). The
            logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
            normalization factor).
        S_dmask [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen, seqlen).
            The output of softmax (possibly with different scaling). It also encodes the dropout
            pattern (negative means that location was dropped, nonnegative means it was kept).
    """
    return FlashAttnFunc.apply(
        q,
        k,
        v,
        dropout_p,
        softmax_scale,
        causal,
        window_size,
        softcap,
        alibi_slopes,
        deterministic,
        return_attn_probs,
        out,
    )
Tri Dao's avatar
Tri Dao committed
915
916


Tri Dao's avatar
Tri Dao committed
917
918
919
920
921
922
923
def flash_attn_varlen_qkvpacked_func(
    qkv,
    cu_seqlens,
    max_seqlen,
    dropout_p=0.0,
    softmax_scale=None,
    causal=False,
    window_size=(-1, -1),  # -1 means infinite context window
    softcap=0.0,  # 0.0 means deactivated
    alibi_slopes=None,
    deterministic=False,
    return_attn_probs=False,
    *,
    out=None,
):
    """dropout_p should be set to 0.0 during evaluation.
    If Q, K, V are already stacked into 1 tensor, this function will be faster than
    calling flash_attn_varlen_func on Q, K, V since the backward pass avoids explicit concatenation
    of the gradients of Q, K, V.
    For multi-query and grouped-query attention (MQA/GQA), please see
    flash_attn_varlen_kvpacked_func and flash_attn_varlen_func.

    If window_size != (-1, -1), implements sliding window local attention. Query at position i
    will only attend to keys between [i - window_size[0], i + window_size[1]] inclusive.

    Arguments:
        qkv: (total, 3, nheads, headdim), where total = total number of tokens in the batch.
        cu_seqlens: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
           of the sequences in the batch, used to index into qkv.
        max_seqlen: int. Maximum sequence length in the batch.
        dropout_p: float. Dropout probability.
        softmax_scale: float. The scaling of QK^T before applying softmax.
            Defaults to 1 / sqrt(headdim).
        causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
        window_size: (left, right). If not (-1, -1), implements sliding window local attention.
        softcap: float. Anything > 0 activates softcapping attention.
        alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of (-alibi_slope * |i - j|)
            is added to the attention score of query i and key j.
        deterministic: bool. Whether to use the deterministic implementation of the backward pass,
            which is slightly slower and uses more memory. The forward pass is always deterministic.
        return_attn_probs: bool. Whether to return the attention probabilities. This option is for
           testing only. The returned probabilities are not guaranteed to be correct
           (they might not have the right scaling).
        out: Optional keyword-only tensor, forwarded unchanged to FlashAttnVarlenQKVPackedFunc.
            Presumably a pre-allocated buffer that receives the attention output
            (TODO: confirm the required shape/dtype against the kernel).
    Return:
        out: (total, nheads, headdim).
        softmax_lse [optional, if return_attn_probs=True]: (nheads, total_q_seqlen). The
            logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
            normalization factor).
        S_dmask [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen, seqlen).
            The output of softmax (possibly with different scaling). It also encodes the dropout
            pattern (negative means that location was dropped, nonnegative means it was kept).
    """
    return FlashAttnVarlenQKVPackedFunc.apply(
        qkv,
        cu_seqlens,
        max_seqlen,
        dropout_p,
        softmax_scale,
        causal,
        window_size,
        softcap,
        alibi_slopes,
        deterministic,
        return_attn_probs,
        out,
    )
Tri Dao's avatar
Tri Dao committed
983
984


Tri Dao's avatar
Tri Dao committed
985
986
987
988
989
990
991
992
993
994
def flash_attn_varlen_kvpacked_func(
    q,
    kv,
    cu_seqlens_q,
    cu_seqlens_k,
    max_seqlen_q,
    max_seqlen_k,
    dropout_p=0.0,
    softmax_scale=None,
    causal=False,
    window_size=(-1, -1),  # -1 means infinite context window
    softcap=0.0,  # 0.0 means deactivated
    alibi_slopes=None,
    deterministic=False,
    return_attn_probs=False,
    *,
    out=None,
):
    """dropout_p should be set to 0.0 during evaluation.
    If K, V are already stacked into 1 tensor, this function will be faster than
    calling flash_attn_varlen_func on Q, K, V since the backward pass avoids explicit
    concatenation of the gradients of K, V.
    Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads
    than Q. Note that the number of heads in Q must be divisible by the number of heads in KV.
    For example, if Q has 6 heads and K, V have 2 heads, heads 0, 1, 2 of Q will attend to head
    0 of K, V, and heads 3, 4, 5 of Q will attend to head 1 of K, V.

    If causal=True, the causal mask is aligned to the bottom right corner of the attention matrix.
    For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is:
        1 1 1 1 0
        1 1 1 1 1
    If seqlen_q = 5 and seqlen_k = 2, the causal mask is:
        0 0
        0 0
        0 0
        1 0
        1 1
    If a row of the mask is all zeros, the output for that row will be zero.

    If window_size != (-1, -1), implements sliding window local attention. Query at position i
    will only attend to keys between
    [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive.

    Arguments:
        q: (total_q, nheads, headdim), where total_q = total number of query tokens in the batch.
        kv: (total_k, 2, nheads_k, headdim), where total_k = total number of key tokens in the batch.
        cu_seqlens_q: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
           of the sequences in the batch, used to index into q.
        cu_seqlens_k: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
           of the sequences in the batch, used to index into kv.
        max_seqlen_q: int. Maximum query sequence length in the batch.
        max_seqlen_k: int. Maximum key sequence length in the batch.
        dropout_p: float. Dropout probability.
        softmax_scale: float. The scaling of QK^T before applying softmax.
            Defaults to 1 / sqrt(headdim).
        causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
        window_size: (left, right). If not (-1, -1), implements sliding window local attention.
        softcap: float. Anything > 0 activates softcapping attention.
        alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of
            (-alibi_slope * |i + seqlen_k - seqlen_q - j|)
            is added to the attention score of query i and key j.
        deterministic: bool. Whether to use the deterministic implementation of the backward pass,
            which is slightly slower and uses more memory. The forward pass is always deterministic.
        return_attn_probs: bool. Whether to return the attention probabilities. This option is for
           testing only. The returned probabilities are not guaranteed to be correct
           (they might not have the right scaling).
        out: Optional keyword-only tensor, forwarded unchanged to FlashAttnVarlenKVPackedFunc.
            Presumably a pre-allocated buffer that receives the attention output
            (TODO: confirm the required shape/dtype against the kernel).
    Return:
        out: (total_q, nheads, headdim).
        softmax_lse [optional, if return_attn_probs=True]: (nheads, total_q_seqlen). The
            logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
            normalization factor).
        S_dmask [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen, seqlen).
            The output of softmax (possibly with different scaling). It also encodes the dropout
            pattern (negative means that location was dropped, nonnegative means it was kept).
    """
    return FlashAttnVarlenKVPackedFunc.apply(
        q,
        kv,
        cu_seqlens_q,
        cu_seqlens_k,
        max_seqlen_q,
        max_seqlen_k,
        dropout_p,
        softmax_scale,
        causal,
        window_size,
        softcap,
        alibi_slopes,
        deterministic,
        return_attn_probs,
        out,
    )
Tri Dao's avatar
Tri Dao committed
1077

1078

Tri Dao's avatar
Tri Dao committed
1079
1080
1081
1082
1083
1084
1085
1086
1087
1088
1089
def flash_attn_varlen_func(
    q,
    k,
    v,
    cu_seqlens_q,
    cu_seqlens_k,
    max_seqlen_q,
    max_seqlen_k,
    dropout_p=0.0,
    softmax_scale=None,
    causal=False,
    window_size=(-1, -1),  # -1 means infinite context window
    softcap=0.0,  # 0.0 means deactivated
    alibi_slopes=None,
    deterministic=False,
    return_attn_probs=False,
    block_table=None,
    *,
    out=None,
):
    """dropout_p should be set to 0.0 during evaluation.
    Supports multi-query and grouped-query attention (MQA/GQA) by passing in K, V with fewer heads
    than Q. Note that the number of heads in Q must be divisible by the number of heads in K, V.
    For example, if Q has 6 heads and K, V have 2 heads, heads 0, 1, 2 of Q will attend to head
    0 of K, V, and heads 3, 4, 5 of Q will attend to head 1 of K, V.

    If causal=True, the causal mask is aligned to the bottom right corner of the attention matrix.
    For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is:
        1 1 1 1 0
        1 1 1 1 1
    If seqlen_q = 5 and seqlen_k = 2, the causal mask is:
        0 0
        0 0
        0 0
        1 0
        1 1
    If a row of the mask is all zeros, the output for that row will be zero.

    If window_size != (-1, -1), implements sliding window local attention. Query at position i
    will only attend to keys between
    [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive.

    Arguments:
        q: (total_q, nheads, headdim), where total_q = total number of query tokens in the batch.
        k: (total_k, nheads_k, headdim), where total_k = total number of key tokens in the batch.
        v: (total_k, nheads_k, headdim), where total_k = total number of key tokens in the batch.
        cu_seqlens_q: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
           of the sequences in the batch, used to index into q.
        cu_seqlens_k: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
           of the sequences in the batch, used to index into k and v.
        max_seqlen_q: int. Maximum query sequence length in the batch.
        max_seqlen_k: int. Maximum key sequence length in the batch.
        dropout_p: float. Dropout probability.
        softmax_scale: float. The scaling of QK^T before applying softmax.
            Defaults to 1 / sqrt(headdim).
        causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
        window_size: (left, right). If not (-1, -1), implements sliding window local attention.
        softcap: float. Anything > 0 activates softcapping attention.
        alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of
            (-alibi_slope * |i + seqlen_k - seqlen_q - j|)
            is added to the attention score of query i and key j.
        deterministic: bool. Whether to use the deterministic implementation of the backward pass,
            which is slightly slower and uses more memory. The forward pass is always deterministic.
        return_attn_probs: bool. Whether to return the attention probabilities. This option is for
           testing only. The returned probabilities are not guaranteed to be correct
           (they might not have the right scaling).
        block_table: Optional. Forwarded unchanged to FlashAttnVarlenFunc. Presumably the
            paged-KV block table (as in flash_attn_with_kvcache), in which case k / v would be
            laid out as cache blocks rather than (total_k, ...) (TODO: confirm shape/dtype).
        out: Optional keyword-only tensor, forwarded unchanged to FlashAttnVarlenFunc.
            Presumably a pre-allocated buffer that receives the attention output
            (TODO: confirm the required shape/dtype against the kernel).
    Return:
        out: (total_q, nheads, headdim).
        softmax_lse [optional, if return_attn_probs=True]: (nheads, total_q_seqlen). The
            logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
            normalization factor).
        S_dmask [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen, seqlen).
            The output of softmax (possibly with different scaling). It also encodes the dropout
            pattern (negative means that location was dropped, nonnegative means it was kept).
    """
    return FlashAttnVarlenFunc.apply(
        q,
        k,
        v,
        cu_seqlens_q,
        cu_seqlens_k,
        max_seqlen_q,
        max_seqlen_k,
        dropout_p,
        softmax_scale,
        causal,
        window_size,
        softcap,
        alibi_slopes,
        deterministic,
        return_attn_probs,
        block_table,
        out,
    )
Tri Dao's avatar
Tri Dao committed
1173
1174
1175
1176
1177
1178
1179
1180


def flash_attn_with_kvcache(
    q,
    k_cache,
    v_cache,
    k=None,
    v=None,
    rotary_cos=None,
    rotary_sin=None,
    cache_seqlens: Optional[Union[(int, torch.Tensor)]] = None,
    cache_batch_idx: Optional[torch.Tensor] = None,
    block_table: Optional[torch.Tensor] = None,
    softmax_scale=None,
    causal=False,
    window_size=(-1, -1),  # -1 means infinite context window
    softcap=0.0,  # 0.0 means deactivated
    rotary_interleaved=True,
    alibi_slopes=None,
    num_splits=0,
    return_softmax_lse=False,
    *,
    out=None,
):
    """
    If k and v are not None, k_cache and v_cache will be updated *inplace* with the new values from
    k and v. This is useful for incremental decoding: you can pass in the cached keys/values from
    the previous step, and update them with the new keys/values from the current step, and do
    attention with the updated cache, all in 1 kernel.

    If you pass in k / v, you must make sure that the cache is large enough to hold the new values.
    For example, the KV cache could be pre-allocated with the max sequence length, and you can use
    cache_seqlens to keep track of the current sequence lengths of each sequence in the batch.

    Also apply rotary embedding if rotary_cos and rotary_sin are passed in. The key @k will be
    rotated by rotary_cos and rotary_sin at indices cache_seqlens, cache_seqlens + 1, etc.
    If causal or local (i.e., window_size != (-1, -1)), the query @q will be rotated by rotary_cos
    and rotary_sin at indices cache_seqlens, cache_seqlens + 1, etc.
    If not causal and not local, the query @q will be rotated by rotary_cos and rotary_sin at
    indices cache_seqlens only (i.e. we consider all tokens in @q to be at position cache_seqlens).

    See tests/test_flash_attn.py::test_flash_attn_kvcache for examples of how to use this function.

    Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads
    than Q. Note that the number of heads in Q must be divisible by the number of heads in KV.
    For example, if Q has 6 heads and K, V have 2 heads, heads 0, 1, 2 of Q will attend to head
    0 of K, V, and heads 3, 4, 5 of Q will attend to head 1 of K, V.

    If causal=True, the causal mask is aligned to the bottom right corner of the attention matrix.
    For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is:
        1 1 1 1 0
        1 1 1 1 1
    If seqlen_q = 5 and seqlen_k = 2, the causal mask is:
        0 0
        0 0
        0 0
        1 0
        1 1
    If the row of the mask is all zero, the output will be zero.

    If window_size != (-1, -1), implements sliding window local attention. Query at position i
    will only attend to keys between
    [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive.

    Note: Does not support backward pass.

    Arguments:
        q: (batch_size, seqlen, nheads, headdim)
        k_cache: (batch_size_cache, seqlen_cache, nheads_k, headdim) if there's no block_table,
            or (num_blocks, page_block_size, nheads_k, headdim) if there's a block_table (i.e. paged KV cache)
            page_block_size must be a multiple of 256.
        v_cache: (batch_size_cache, seqlen_cache, nheads_k, headdim) if there's no block_table,
            or (num_blocks, page_block_size, nheads_k, headdim) if there's a block_table (i.e. paged KV cache)
        k [optional]: (batch_size, seqlen_new, nheads_k, headdim). If not None, we concatenate
            k with k_cache, starting at the indices specified by cache_seqlens.
        v [optional]: (batch_size, seqlen_new, nheads_k, headdim). Similar to k.
        rotary_cos [optional]: (seqlen_ro, rotary_dim / 2). If not None, we apply rotary embedding
            to k and q. Only applicable if k and v are passed in. rotary_dim must be divisible by 16.
        rotary_sin [optional]: (seqlen_ro, rotary_dim / 2). Similar to rotary_cos.
        cache_seqlens: int, or (batch_size,), dtype torch.int32. The sequence lengths of the
            KV cache. An int is broadcast to an int32 tensor of shape (batch_size,), where
            batch_size is q.shape[0] (not k_cache.shape[0], which differs for paged caches
            and when cache_batch_idx is used).
        block_table [optional]: (batch_size, max_num_blocks_per_seq), dtype torch.int32.
        cache_batch_idx: (batch_size,), dtype torch.int32. The indices used to index into the KV cache.
            If None, we assume that the batch indices are [0, 1, 2, ..., batch_size - 1].
            If the indices are not distinct, and k and v are provided, the values updated in the cache
                 might come from any of the duplicate indices.
        softmax_scale: float. The scaling of QK^T before applying softmax.
            Defaults to 1 / sqrt(headdim).
        causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
        window_size: (left, right). If not (-1, -1), implements sliding window local attention.
        softcap: float. Anything > 0 activates softcapping attention.
        rotary_interleaved: bool. Only applicable if rotary_cos and rotary_sin are passed in.
            If True, rotary embedding will combine dimensions 0 & 1, 2 & 3, etc. If False,
            rotary embedding will combine dimensions 0 & rotary_dim / 2, 1 & rotary_dim / 2 + 1
            (i.e. GPT-NeoX style).
        alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of
            (-alibi_slope * |i + seqlen_k - seqlen_q - j|)
            is added to the attention score of query i and key j.
        num_splits: int. If > 1, split the key/value into this many chunks along the sequence.
           If num_splits == 1, we don't split the key/value. If num_splits == 0, we use a heuristic
           to automatically determine the number of splits.
           Don't change this unless you know what you are doing.
        return_softmax_lse: bool. Whether to return the logsumexp of the attention scores.
        out [optional, keyword-only]: tensor forwarded unchanged to the CUDA kernel as its output
            argument. Presumably a pre-allocated buffer shaped like the returned `out`
            (batch_size, seqlen, nheads, headdim); it is not validated here. TODO confirm.

    Return:
        out: (batch_size, seqlen, nheads, headdim).
        softmax_lse [optional, if return_softmax_lse=True]: (batch_size, nheads, seqlen). The
            logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
            normalization factor).
    """
    assert k_cache.stride(-1) == 1, "k_cache must have contiguous last dimension"
    assert v_cache.stride(-1) == 1, "v_cache must have contiguous last dimension"
    q, k, v = [maybe_contiguous(x) for x in (q, k, v)]
    if softmax_scale is None:
        softmax_scale = q.shape[-1] ** (-0.5)

    if isinstance(cache_seqlens, int):
        # cache_seqlens is documented as (batch_size,), and batch_size is q's leading
        # dimension. k_cache.shape[0] is num_blocks for a paged cache (block_table given)
        # and batch_size_cache when cache_batch_idx remaps rows, so it cannot be used here.
        cache_seqlens = torch.full(
            (q.shape[0],), cache_seqlens, dtype=torch.int32, device=k_cache.device
        )
    # Normalize every optional index tensor, including a user-supplied cache_seqlens
    # (a tensor built by torch.full above is already contiguous).
    cache_seqlens = maybe_contiguous(cache_seqlens)
    cache_batch_idx = maybe_contiguous(cache_batch_idx)
    block_table = maybe_contiguous(block_table)

    out, softmax_lse = flash_attn_cuda.fwd_kvcache(
        q,
        k_cache,
        v_cache,
        k,
        v,
        cache_seqlens,
        rotary_cos,
        rotary_sin,
        cache_batch_idx,
        block_table,
        alibi_slopes,
        out,
        softmax_scale,
        causal,
        window_size[0],
        window_size[1],
        softcap,
        rotary_interleaved,
        num_splits,
    )
    return (out, softmax_lse) if return_softmax_lse else out