flash_attn_interface.py 43.8 KB
Newer Older
1
2
# Copyright (c) 2023, Tri Dao.

Tri Dao's avatar
Tri Dao committed
3
4
from typing import Optional, Union

Tri Dao's avatar
Tri Dao committed
5
6
import torch
import torch.nn as nn
Tri Dao's avatar
Tri Dao committed
7

8
9
10
# isort: off
# We need to import the CUDA kernels after importing torch
import flash_attn_2_cuda as flash_attn_cuda
Tri Dao's avatar
Tri Dao committed
11

12
13
# isort: on

Tri Dao's avatar
Tri Dao committed
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45

def _get_block_size(device, head_dim, is_dropout, is_causal):
    # This should match the block sizes in the CUDA kernel
    assert head_dim <= 256
    major, minor = torch.cuda.get_device_capability(device)
    is_sm8x = major == 8 and minor > 0  # Only include sm86 and sm89, exclude sm80 (A100)
    is_sm80 = major == 8 and minor == 0
    is_sm90 = major == 9 and minor == 0
    if head_dim <= 32:
        return 128, 128
    if head_dim <= 64:
        return (128, 128) if not is_dropout else (128, 64)
    elif head_dim <= 96:
        return (64, 64) if (is_sm8x and is_causal) else (128, 64)
    elif head_dim <= 128:
        if is_sm8x:
            return (64, 64) if (not is_dropout and is_causal) else (128, 32)
        else:
            return 128, (64 if not is_dropout else 32)
    elif head_dim <= 160:
        if is_sm8x:
            return (128, 64) if not is_causal else (64, 64)
        else:
            return 128, 32
    elif head_dim <= 192:
        return (128, 64) if not is_dropout else (64, 64)
    elif head_dim <= 224:
        return (128, 64) if (is_sm80 or is_sm90) else (64, 64)
    elif head_dim <= 256:
        return (128, 64) if is_sm80 else (64, 64)


Tri Dao's avatar
Tri Dao committed
46
47
48
def _flash_attn_forward(
    q, k, v, dropout_p, softmax_scale, causal, window_size, alibi_slopes, return_softmax
):
    """Run the fixed-length FlashAttention forward kernel (``flash_attn_cuda.fwd``).

    Args:
        q, k, v: query / key / value tensors. Any of them whose last dimension is
            not unit-stride is replaced by a contiguous copy before the call.
        dropout_p: dropout probability, forwarded to the kernel.
        softmax_scale: scale applied to QK^T before the softmax.
        causal: whether to apply the causal mask.
        window_size: ``(left, right)`` sliding-window sizes; the two entries are
            passed to the kernel as separate arguments.
        alibi_slopes: ALiBi slopes tensor, or None.
        return_softmax: forwarded to the kernel; controls whether ``S_dmask`` is
            produced.

    Returns:
        The 8-tuple returned by the kernel:
        ``(out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state)``.
    """

    def maybe_contiguous(x):
        # Only the last dimension must be unit-stride; other strides pass through.
        return x.contiguous() if x.stride(-1) != 1 else x

    q, k, v = [maybe_contiguous(x) for x in (q, k, v)]
    out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = flash_attn_cuda.fwd(
        q,
        k,
        v,
        None,  # NOTE(review): presumably an optional preallocated output buffer — confirm
        alibi_slopes,
        dropout_p,
        softmax_scale,
        causal,
        window_size[0],
        window_size[1],
        return_softmax,
        None,  # NOTE(review): presumably an optional RNG generator — confirm
    )
    return out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state
Tri Dao's avatar
Tri Dao committed
66
67


Tri Dao's avatar
Tri Dao committed
68
69
70
71
72
73
74
75
76
77
78
def _flash_attn_varlen_forward(
    q,
    k,
    v,
    cu_seqlens_q,
    cu_seqlens_k,
    max_seqlen_q,
    max_seqlen_k,
    dropout_p,
    softmax_scale,
    causal,
    window_size,
    alibi_slopes,
    return_softmax,
):
    """Run the variable-length FlashAttention forward kernel (``flash_attn_cuda.varlen_fwd``).

    Args:
        q, k, v: query / key / value tensors. Any of them whose last dimension is
            not unit-stride is replaced by a contiguous copy before the call.
        cu_seqlens_q, cu_seqlens_k: sequence-boundary tensors for queries and
            keys, forwarded to the kernel unchanged (presumably cumulative
            sequence lengths, as the names suggest — confirm against the kernel).
        max_seqlen_q, max_seqlen_k: maximum sequence lengths, forwarded unchanged.
        dropout_p: dropout probability.
        softmax_scale: scale applied to QK^T before the softmax.
        causal: whether to apply the causal mask.
        window_size: ``(left, right)`` sliding-window sizes.
        alibi_slopes: ALiBi slopes tensor, or None.
        return_softmax: forwarded to the kernel; controls whether ``S_dmask`` is
            produced.

    Returns:
        The 8-tuple returned by the kernel:
        ``(out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state)``.
    """

    def maybe_contiguous(x):
        # Only the last dimension must be unit-stride; other strides pass through.
        return x.contiguous() if x.stride(-1) != 1 else x

    q, k, v = [maybe_contiguous(x) for x in (q, k, v)]
    out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = flash_attn_cuda.varlen_fwd(
        q,
        k,
        v,
        None,  # NOTE(review): presumably an optional preallocated output buffer — confirm
        cu_seqlens_q,
        cu_seqlens_k,
        None,  # NOTE(review): meaning of this optional slot not visible here — confirm
        alibi_slopes,
        max_seqlen_q,
        max_seqlen_k,
        dropout_p,
        softmax_scale,
        False,  # NOTE(review): hard-coded kernel flag; meaning not visible here — confirm
        causal,
        window_size[0],
        window_size[1],
        return_softmax,
        None,  # NOTE(review): presumably an optional RNG generator — confirm
    )
    return out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state
Tri Dao's avatar
Tri Dao committed
108
109


Tri Dao's avatar
Tri Dao committed
110
def _flash_attn_backward(
    dout,
    q,
    k,
    v,
    out,
    softmax_lse,
    dq,
    dk,
    dv,
    dropout_p,
    softmax_scale,
    causal,
    window_size,
    alibi_slopes,
    deterministic,
    rng_state=None,
):
    """Run the fixed-length FlashAttention backward kernel (``flash_attn_cuda.bwd``).

    Args:
        dout: gradient of the attention output.
        q, k, v, out, softmax_lse: tensors saved from the forward pass.
        dq, dk, dv: caller-allocated gradient buffers passed to the kernel. The
            autograd wrappers in this module ignore this function's return value
            and read the gradients from these buffers.
        dropout_p, softmax_scale, causal, window_size, alibi_slopes: same meaning
            as in the forward pass.
        deterministic: whether to request the deterministic backward kernel.
        rng_state: RNG state returned by the forward kernel, or None.

    Returns:
        ``(dq, dk, dv, softmax_d)`` as returned by the kernel.
    """

    def maybe_contiguous(x):
        # Only the last dimension must be unit-stride; other strides pass through.
        return x.contiguous() if x.stride(-1) != 1 else x

    # dq, dk, dv are allocated by us so they should already be contiguous
    dout, q, k, v, out = [maybe_contiguous(x) for x in (dout, q, k, v, out)]
    dq, dk, dv, softmax_d = flash_attn_cuda.bwd(
        dout,
        q,
        k,
        v,
        out,
        softmax_lse,
        dq,
        dk,
        dv,
        alibi_slopes,
        dropout_p,
        softmax_scale,
        causal,
        window_size[0],
        window_size[1],
        deterministic,
        None,  # NOTE(review): presumably an optional RNG generator — confirm
        rng_state,
    )
    return dq, dk, dv, softmax_d


Tri Dao's avatar
Tri Dao committed
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
def _flash_attn_varlen_backward(
    dout,
    q,
    k,
    v,
    out,
    softmax_lse,
    dq,
    dk,
    dv,
    cu_seqlens_q,
    cu_seqlens_k,
    max_seqlen_q,
    max_seqlen_k,
    dropout_p,
    softmax_scale,
    causal,
    window_size,
    alibi_slopes,
    deterministic,
    rng_state=None,
):
    """Run the variable-length FlashAttention backward kernel (``flash_attn_cuda.varlen_bwd``).

    Args:
        dout: gradient of the attention output.
        q, k, v, out, softmax_lse: tensors saved from the forward pass.
        dq, dk, dv: caller-allocated gradient buffers passed to the kernel. The
            autograd wrappers in this module ignore this function's return value
            and read the gradients from these buffers.
        cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k: the same
            sequence-layout arguments given to the forward pass.
        dropout_p, softmax_scale, causal, window_size, alibi_slopes: same meaning
            as in the forward pass.
        deterministic: whether to request the deterministic backward kernel.
        rng_state: RNG state returned by the forward kernel, or None.

    Returns:
        ``(dq, dk, dv, softmax_d)`` as returned by the kernel.
    """

    def maybe_contiguous(x):
        # Only the last dimension must be unit-stride; other strides pass through.
        return x.contiguous() if x.stride(-1) != 1 else x

    # dq, dk, dv are allocated by us so they should already be contiguous
    dout, q, k, v, out = [maybe_contiguous(x) for x in (dout, q, k, v, out)]
    dq, dk, dv, softmax_d = flash_attn_cuda.varlen_bwd(
        dout,
        q,
        k,
        v,
        out,
        softmax_lse,
        dq,
        dk,
        dv,
        cu_seqlens_q,
        cu_seqlens_k,
        alibi_slopes,
        max_seqlen_q,
        max_seqlen_k,
        dropout_p,
        softmax_scale,
        False,  # NOTE(review): hard-coded kernel flag; meaning not visible here — confirm
        causal,
        window_size[0],
        window_size[1],
        deterministic,
        None,  # NOTE(review): presumably an optional RNG generator — confirm
        rng_state,
    )
    return dq, dk, dv, softmax_d
Tri Dao's avatar
Tri Dao committed
207
208


Tri Dao's avatar
Tri Dao committed
209
class FlashAttnQKVPackedFunc(torch.autograd.Function):
    """Autograd function for fixed-length attention on a packed QKV tensor.

    q, k, v are taken as ``qkv[:, :, 0]``, ``qkv[:, :, 1]``, ``qkv[:, :, 2]``.
    The backward pass writes the three gradients into slices of a single packed
    ``dqkv`` buffer, which avoids concatenating separate dq/dk/dv tensors.
    """

    @staticmethod
    def forward(
        ctx,
        qkv,
        dropout_p,
        softmax_scale,
        causal,
        window_size,
        alibi_slopes,
        deterministic,
        return_softmax,
    ):
        if softmax_scale is None:
            # Default scaling is 1 / sqrt(headdim).
            softmax_scale = qkv.shape[-1] ** (-0.5)
        out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = _flash_attn_forward(
            qkv[:, :, 0],
            qkv[:, :, 1],
            qkv[:, :, 2],
            dropout_p,
            softmax_scale,
            causal=causal,
            window_size=window_size,
            alibi_slopes=alibi_slopes,
            # The kernel is only asked for S_dmask when dropout is active.
            return_softmax=return_softmax and dropout_p > 0,
        )
        # Save the kernel-returned q/k/v and out_padded; backward consumes these.
        ctx.save_for_backward(q, k, v, out_padded, softmax_lse, rng_state)
        ctx.dropout_p = dropout_p
        ctx.softmax_scale = softmax_scale
        ctx.causal = causal
        ctx.window_size = window_size
        ctx.alibi_slopes = alibi_slopes
        ctx.deterministic = deterministic
        return out if not return_softmax else (out, softmax_lse, S_dmask)

    @staticmethod
    def backward(ctx, dout, *args):
        # *args absorbs gradients for the extra outputs (softmax_lse, S_dmask)
        # returned when return_softmax=True; they are ignored.
        # `out` below is the out_padded tensor saved in forward.
        q, k, v, out, softmax_lse, rng_state = ctx.saved_tensors
        qkv_shape = q.shape[:-2] + (3, *q.shape[-2:])
        dqkv = torch.empty(qkv_shape, dtype=q.dtype, device=q.device)
        # The return value is ignored: the gradients are read back from the three
        # dqkv slices passed as dq / dk / dv.
        _flash_attn_backward(
            dout,
            q,
            k,
            v,
            out,
            softmax_lse,
            dqkv[:, :, 0],
            dqkv[:, :, 1],
            dqkv[:, :, 2],
            ctx.dropout_p,
            ctx.softmax_scale,
            ctx.causal,
            ctx.window_size,
            ctx.alibi_slopes,
            ctx.deterministic,
            rng_state=rng_state,
        )
        dqkv = dqkv[..., : dout.shape[-1]]  # We could have padded the head dimension
        # One entry per forward input; only qkv receives a gradient.
        return dqkv, None, None, None, None, None, None, None
Tri Dao's avatar
Tri Dao committed
269
270
271
272


class FlashAttnVarlenQKVPackedFunc(torch.autograd.Function):
    """Autograd function for variable-length attention on a packed QKV tensor.

    q, k, v are taken as ``qkv[:, 0]``, ``qkv[:, 1]``, ``qkv[:, 2]``. The same
    ``cu_seqlens`` / ``max_seqlen`` are used for both queries and keys. The
    backward pass writes the gradients into slices of one packed ``dqkv`` buffer.
    """

    @staticmethod
    def forward(
        ctx,
        qkv,
        cu_seqlens,
        max_seqlen,
        dropout_p,
        softmax_scale,
        causal,
        window_size,
        alibi_slopes,
        deterministic,
        return_softmax,
    ):
        if softmax_scale is None:
            # Default scaling is 1 / sqrt(headdim).
            softmax_scale = qkv.shape[-1] ** (-0.5)
        out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = _flash_attn_varlen_forward(
            qkv[:, 0],
            qkv[:, 1],
            qkv[:, 2],
            cu_seqlens,
            cu_seqlens,
            max_seqlen,
            max_seqlen,
            dropout_p,
            softmax_scale,
            causal=causal,
            window_size=window_size,
            alibi_slopes=alibi_slopes,
            # The kernel is only asked for S_dmask when dropout is active.
            return_softmax=return_softmax and dropout_p > 0,
        )
        ctx.save_for_backward(q, k, v, out_padded, softmax_lse, cu_seqlens, rng_state)
        ctx.dropout_p = dropout_p
        ctx.max_seqlen = max_seqlen
        ctx.softmax_scale = softmax_scale
        ctx.causal = causal
        ctx.window_size = window_size
        ctx.alibi_slopes = alibi_slopes
        ctx.deterministic = deterministic
        return out if not return_softmax else (out, softmax_lse, S_dmask)

    @staticmethod
    def backward(ctx, dout, *args):
        # *args absorbs gradients for extra outputs (softmax_lse, S_dmask); ignored.
        # `out` below is the out_padded tensor saved in forward.
        q, k, v, out, softmax_lse, cu_seqlens, rng_state = ctx.saved_tensors
        qkv_shape = q.shape[:-2] + (3, *q.shape[-2:])
        dqkv = torch.empty(qkv_shape, dtype=q.dtype, device=q.device)
        # The return value is ignored: the gradients are read back from the dqkv slices.
        _flash_attn_varlen_backward(
            dout,
            q,
            k,
            v,
            out,
            softmax_lse,
            dqkv[:, 0],
            dqkv[:, 1],
            dqkv[:, 2],
            cu_seqlens,
            cu_seqlens,
            ctx.max_seqlen,
            ctx.max_seqlen,
            ctx.dropout_p,
            ctx.softmax_scale,
            ctx.causal,
            ctx.window_size,
            ctx.alibi_slopes,
            ctx.deterministic,
            rng_state=rng_state,
        )
        dqkv = dqkv[..., : dout.shape[-1]]  # We could have padded the head dimension
        # One entry per forward input (10); only qkv receives a gradient.
        return dqkv, None, None, None, None, None, None, None, None, None
Tri Dao's avatar
Tri Dao committed
342
343


Tri Dao's avatar
Tri Dao committed
344
class FlashAttnKVPackedFunc(torch.autograd.Function):
    """Autograd function for fixed-length attention with K and V packed together.

    k, v are taken as ``kv[:, :, 0]`` and ``kv[:, :, 1]``. The backward pass writes
    dk / dv into slices of one packed ``dkv`` buffer.
    """

    @staticmethod
    def forward(
        ctx,
        q,
        kv,
        dropout_p,
        softmax_scale,
        causal,
        window_size,
        alibi_slopes,
        deterministic,
        return_softmax,
    ):
        if softmax_scale is None:
            # Default scaling is 1 / sqrt(headdim).
            softmax_scale = q.shape[-1] ** (-0.5)
        out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = _flash_attn_forward(
            q,
            kv[:, :, 0],
            kv[:, :, 1],
            dropout_p,
            softmax_scale,
            causal=causal,
            window_size=window_size,
            alibi_slopes=alibi_slopes,
            # The kernel is only asked for S_dmask when dropout is active.
            return_softmax=return_softmax and dropout_p > 0,
        )
        ctx.save_for_backward(q, k, v, out_padded, softmax_lse, rng_state)
        ctx.dropout_p = dropout_p
        ctx.softmax_scale = softmax_scale
        ctx.causal = causal
        ctx.window_size = window_size
        ctx.alibi_slopes = alibi_slopes
        ctx.deterministic = deterministic
        return out if not return_softmax else (out, softmax_lse, S_dmask)

    @staticmethod
    def backward(ctx, dout, *args):
        # *args absorbs gradients for extra outputs (softmax_lse, S_dmask); ignored.
        # `out` below is the out_padded tensor saved in forward.
        q, k, v, out, softmax_lse, rng_state = ctx.saved_tensors
        dq = torch.empty_like(q)
        kv_shape = k.shape[:-2] + (2, *k.shape[-2:])
        dkv = torch.empty(kv_shape, dtype=k.dtype, device=k.device)
        # The return value is ignored: gradients are read back from dq and the dkv slices.
        _flash_attn_backward(
            dout,
            q,
            k,
            v,
            out,
            softmax_lse,
            dq,
            dkv[:, :, 0],
            dkv[:, :, 1],
            ctx.dropout_p,
            ctx.softmax_scale,
            ctx.causal,
            ctx.window_size,
            ctx.alibi_slopes,
            ctx.deterministic,
            rng_state=rng_state,
        )
        dq = dq[..., : dout.shape[-1]]  # We could have padded the head dimension
        dkv = dkv[..., : dout.shape[-1]]
        # One entry per forward input (9); only q and kv receive gradients.
        return dq, dkv, None, None, None, None, None, None, None
Tri Dao's avatar
Tri Dao committed
407
408


Tri Dao's avatar
Tri Dao committed
409
class FlashAttnVarlenKVPackedFunc(torch.autograd.Function):
    """Autograd function for variable-length attention with K and V packed together.

    k, v are taken as ``kv[:, 0]`` and ``kv[:, 1]``. Queries and keys carry their
    own ``cu_seqlens`` / ``max_seqlen``. The backward pass writes dk / dv into
    slices of one packed ``dkv`` buffer.
    """

    @staticmethod
    def forward(
        ctx,
        q,
        kv,
        cu_seqlens_q,
        cu_seqlens_k,
        max_seqlen_q,
        max_seqlen_k,
        dropout_p,
        softmax_scale,
        causal,
        window_size,
        alibi_slopes,
        deterministic,
        return_softmax,
    ):
        if softmax_scale is None:
            # Default scaling is 1 / sqrt(headdim).
            softmax_scale = q.shape[-1] ** (-0.5)
        out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = _flash_attn_varlen_forward(
            q,
            kv[:, 0],
            kv[:, 1],
            cu_seqlens_q,
            cu_seqlens_k,
            max_seqlen_q,
            max_seqlen_k,
            dropout_p,
            softmax_scale,
            causal=causal,
            window_size=window_size,
            alibi_slopes=alibi_slopes,
            # The kernel is only asked for S_dmask when dropout is active.
            return_softmax=return_softmax and dropout_p > 0,
        )
        ctx.save_for_backward(
            q, k, v, out_padded, softmax_lse, cu_seqlens_q, cu_seqlens_k, rng_state
        )
        ctx.dropout_p = dropout_p
        ctx.max_seqlen_q = max_seqlen_q
        ctx.max_seqlen_k = max_seqlen_k
        ctx.softmax_scale = softmax_scale
        ctx.causal = causal
        ctx.window_size = window_size
        ctx.alibi_slopes = alibi_slopes
        ctx.deterministic = deterministic
        return out if not return_softmax else (out, softmax_lse, S_dmask)

    @staticmethod
    def backward(ctx, dout, *args):
        # *args absorbs gradients for extra outputs (softmax_lse, S_dmask); ignored.
        # `out` below is the out_padded tensor saved in forward.
        q, k, v, out, softmax_lse, cu_seqlens_q, cu_seqlens_k, rng_state = ctx.saved_tensors
        dq = torch.empty_like(q)
        kv_shape = k.shape[:-2] + (2, *k.shape[-2:])
        dkv = torch.empty(kv_shape, dtype=k.dtype, device=k.device)
        # The return value is ignored: gradients are read back from dq and the dkv slices.
        _flash_attn_varlen_backward(
            dout,
            q,
            k,
            v,
            out,
            softmax_lse,
            dq,
            dkv[:, 0],
            dkv[:, 1],
            cu_seqlens_q,
            cu_seqlens_k,
            ctx.max_seqlen_q,
            ctx.max_seqlen_k,
            ctx.dropout_p,
            ctx.softmax_scale,
            ctx.causal,
            ctx.window_size,
            ctx.alibi_slopes,
            ctx.deterministic,
            rng_state=rng_state,
        )
        dq = dq[..., : dout.shape[-1]]  # We could have padded the head dimension
        dkv = dkv[..., : dout.shape[-1]]
        # One entry per forward input (13); only q and kv receive gradients.
        return dq, dkv, None, None, None, None, None, None, None, None, None, None, None
Tri Dao's avatar
Tri Dao committed
488
489
490
491


class FlashAttnFunc(torch.autograd.Function):
    """Autograd function for fixed-length attention on separate q, k, v tensors."""

    @staticmethod
    def forward(
        ctx,
        q,
        k,
        v,
        dropout_p,
        softmax_scale,
        causal,
        window_size,
        alibi_slopes,
        deterministic,
        return_softmax,
    ):
        if softmax_scale is None:
            # Default scaling is 1 / sqrt(headdim).
            softmax_scale = q.shape[-1] ** (-0.5)
        out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = _flash_attn_forward(
            q,
            k,
            v,
            dropout_p,
            softmax_scale,
            causal=causal,
            window_size=window_size,
            alibi_slopes=alibi_slopes,
            # The kernel is only asked for S_dmask when dropout is active.
            return_softmax=return_softmax and dropout_p > 0,
        )
        ctx.save_for_backward(q, k, v, out_padded, softmax_lse, rng_state)
        ctx.dropout_p = dropout_p
        ctx.softmax_scale = softmax_scale
        ctx.causal = causal
        ctx.window_size = window_size
        ctx.alibi_slopes = alibi_slopes
        ctx.deterministic = deterministic
        return out if not return_softmax else (out, softmax_lse, S_dmask)

    @staticmethod
    def backward(ctx, dout, *args):
        # *args absorbs gradients for extra outputs (softmax_lse, S_dmask); ignored.
        # `out` below is the out_padded tensor saved in forward.
        q, k, v, out, softmax_lse, rng_state = ctx.saved_tensors
        dq, dk, dv = torch.empty_like(q), torch.empty_like(k), torch.empty_like(v)
        # The return value is ignored: gradients are read back from dq / dk / dv.
        _flash_attn_backward(
            dout,
            q,
            k,
            v,
            out,
            softmax_lse,
            dq,
            dk,
            dv,
            ctx.dropout_p,
            ctx.softmax_scale,
            ctx.causal,
            ctx.window_size,
            ctx.alibi_slopes,
            ctx.deterministic,
            rng_state=rng_state,
        )
        dq = dq[..., : dout.shape[-1]]  # We could have padded the head dimension
        dk = dk[..., : dout.shape[-1]]
        dv = dv[..., : dout.shape[-1]]
        # One entry per forward input (10); only q, k, v receive gradients.
        return dq, dk, dv, None, None, None, None, None, None, None
Tri Dao's avatar
Tri Dao committed
553
554


Tri Dao's avatar
Tri Dao committed
555
class FlashAttnVarlenFunc(torch.autograd.Function):
    """Autograd function for variable-length attention on separate q, k, v tensors."""

    @staticmethod
    def forward(
        ctx,
        q,
        k,
        v,
        cu_seqlens_q,
        cu_seqlens_k,
        max_seqlen_q,
        max_seqlen_k,
        dropout_p,
        softmax_scale,
        causal,
        window_size,
        alibi_slopes,
        deterministic,
        return_softmax,
    ):
        if softmax_scale is None:
            # Default scaling is 1 / sqrt(headdim).
            softmax_scale = q.shape[-1] ** (-0.5)
        out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = _flash_attn_varlen_forward(
            q,
            k,
            v,
            cu_seqlens_q,
            cu_seqlens_k,
            max_seqlen_q,
            max_seqlen_k,
            dropout_p,
            softmax_scale,
            causal=causal,
            window_size=window_size,
            alibi_slopes=alibi_slopes,
            # The kernel is only asked for S_dmask when dropout is active.
            return_softmax=return_softmax and dropout_p > 0,
        )
        ctx.save_for_backward(
            q, k, v, out_padded, softmax_lse, cu_seqlens_q, cu_seqlens_k, rng_state
        )
        ctx.dropout_p = dropout_p
        ctx.max_seqlen_q = max_seqlen_q
        ctx.max_seqlen_k = max_seqlen_k
        ctx.softmax_scale = softmax_scale
        ctx.causal = causal
        ctx.window_size = window_size
        ctx.alibi_slopes = alibi_slopes
        ctx.deterministic = deterministic
        return out if not return_softmax else (out, softmax_lse, S_dmask)

    @staticmethod
    def backward(ctx, dout, *args):
        # *args absorbs gradients for extra outputs (softmax_lse, S_dmask); ignored.
        # `out` below is the out_padded tensor saved in forward.
        q, k, v, out, softmax_lse, cu_seqlens_q, cu_seqlens_k, rng_state = ctx.saved_tensors
        dq, dk, dv = torch.empty_like(q), torch.empty_like(k), torch.empty_like(v)
        # The return value is ignored: gradients are read back from dq / dk / dv.
        _flash_attn_varlen_backward(
            dout,
            q,
            k,
            v,
            out,
            softmax_lse,
            dq,
            dk,
            dv,
            cu_seqlens_q,
            cu_seqlens_k,
            ctx.max_seqlen_q,
            ctx.max_seqlen_k,
            ctx.dropout_p,
            ctx.softmax_scale,
            ctx.causal,
            ctx.window_size,
            ctx.alibi_slopes,
            ctx.deterministic,
            rng_state=rng_state,
        )
        dq = dq[..., : dout.shape[-1]]  # We could have padded the head dimension
        dk = dk[..., : dout.shape[-1]]
        dv = dv[..., : dout.shape[-1]]
        # One entry per forward input (14); only q, k, v receive gradients.
        return dq, dk, dv, None, None, None, None, None, None, None, None, None, None, None
634
635


Tri Dao's avatar
Tri Dao committed
636
def flash_attn_qkvpacked_func(
    qkv,
    dropout_p=0.0,
    softmax_scale=None,
    causal=False,
    window_size=(-1, -1),  # -1 means infinite context window
    alibi_slopes=None,
    deterministic=False,
    return_attn_probs=False,
):
    """Attention over a packed QKV tensor. dropout_p should be set to 0.0 during evaluation.

    If Q, K, V are already stacked into 1 tensor, this function will be faster than
    calling flash_attn_func on Q, K, V since the backward pass avoids explicit concatenation
    of the gradients of Q, K, V.
    For multi-query and grouped-query attention (MQA/GQA), please see
    flash_attn_kvpacked_func and flash_attn_func.

    If window_size != (-1, -1), implements sliding window local attention. Query at position i
    will only attend to keys between [i - window_size[0], i + window_size[1]] inclusive.

    Arguments:
        qkv: (batch_size, seqlen, 3, nheads, headdim)
        dropout_p: float. Dropout probability.
        softmax_scale: float. The scaling of QK^T before applying softmax.
            Defaults to 1 / sqrt(headdim).
        causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
        window_size: (left, right). If not (-1, -1), implements sliding window local attention.
        alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of (-alibi_slope * |i - j|)
            is added to the attention score of query i and key j.
        deterministic: bool. Whether to use the deterministic implementation of the backward pass,
            which is slightly slower and uses more memory. The forward pass is always deterministic.
        return_attn_probs: bool. Whether to return the attention probabilities. This option is for
            testing only. The returned probabilities are not guaranteed to be correct
            (they might not have the right scaling). The kernel is only asked to compute
            S_dmask when dropout_p > 0.
    Return:
        out: (batch_size, seqlen, nheads, headdim).
        softmax_lse [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen). The
            logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
            normalization factor).
        S_dmask [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen, seqlen).
            The output of softmax (possibly with different scaling). It also encodes the dropout
            pattern (negative means that location was dropped, nonnegative means it was kept).
    """
    return FlashAttnQKVPackedFunc.apply(
        qkv,
        dropout_p,
        softmax_scale,
        causal,
        window_size,
        alibi_slopes,
        deterministic,
        return_attn_probs,
    )
Tri Dao's avatar
Tri Dao committed
689
690


Tri Dao's avatar
Tri Dao committed
691
def flash_attn_kvpacked_func(
    q,
    kv,
    dropout_p=0.0,
    softmax_scale=None,
    causal=False,
    window_size=(-1, -1),  # -1 means infinite context window
    alibi_slopes=None,
    deterministic=False,
    return_attn_probs=False,
):
    """Attention with a packed KV tensor. dropout_p should be set to 0.0 during evaluation.

    If K, V are already stacked into 1 tensor, this function will be faster than
    calling flash_attn_func on Q, K, V since the backward pass avoids explicit concatenation
    of the gradients of K, V.
    Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads
    than Q. Note that the number of heads in Q must be divisible by the number of heads in KV.
    For example, if Q has 6 heads and K, V have 2 heads, heads 0, 1, 2 of Q will attend to head
    0 of K, V, and heads 3, 4, 5 of Q will attend to head 1 of K, V.

    If causal=True, the causal mask is aligned to the bottom right corner of the attention matrix.
    For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is:
        1 1 1 1 0
        1 1 1 1 1
    If seqlen_q = 5 and seqlen_k = 2, the causal mask is:
        0 0
        0 0
        0 0
        1 0
        1 1
    If the row of the mask is all zero, the output will be zero.

    If window_size != (-1, -1), implements sliding window local attention. Query at position i
    will only attend to keys between
    [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive.

    Arguments:
        q: (batch_size, seqlen_q, nheads, headdim)
        kv: (batch_size, seqlen_k, 2, nheads_k, headdim)
        dropout_p: float. Dropout probability.
        softmax_scale: float. The scaling of QK^T before applying softmax.
            Defaults to 1 / sqrt(headdim).
        causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
        window_size: (left, right). If not (-1, -1), implements sliding window local attention.
        alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of
            (-alibi_slope * |i + seqlen_k - seqlen_q - j|)
            is added to the attention score of query i and key j.
        deterministic: bool. Whether to use the deterministic implementation of the backward pass,
            which is slightly slower and uses more memory. The forward pass is always deterministic.
        return_attn_probs: bool. Whether to return the attention probabilities. This option is for
            testing only. The returned probabilities are not guaranteed to be correct
            (they might not have the right scaling). The kernel is only asked to compute
            S_dmask when dropout_p > 0.
    Return:
        out: (batch_size, seqlen_q, nheads, headdim).
        softmax_lse [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen_q). The
            logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
            normalization factor).
        S_dmask [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen_q, seqlen_k).
            The output of softmax (possibly with different scaling). It also encodes the dropout
            pattern (negative means that location was dropped, nonnegative means it was kept).
    """
    return FlashAttnKVPackedFunc.apply(
        q,
        kv,
        dropout_p,
        softmax_scale,
        causal,
        window_size,
        alibi_slopes,
        deterministic,
        return_attn_probs,
    )
Tri Dao's avatar
Tri Dao committed
763
764


Tri Dao's avatar
Tri Dao committed
765
def flash_attn_func(
    q,
    k,
    v,
    dropout_p=0.0,
    softmax_scale=None,
    causal=False,
    window_size=(-1, -1),  # -1 means infinite context window
    alibi_slopes=None,
    deterministic=False,
    return_attn_probs=False,
):
    """Run FlashAttention on batched, fixed-length q / k / v tensors.

    dropout_p should be set to 0.0 during evaluation.

    Multi-query and grouped-query attention (MQA/GQA) are supported by passing in K, V with
    fewer heads than Q; the number of heads in Q must be divisible by the number of heads in
    K, V. For example, if Q has 6 heads and K, V have 2 heads, heads 0, 1, 2 of Q attend to
    head 0 of K, V, and heads 3, 4, 5 of Q attend to head 1 of K, V.

    If causal=True, the causal mask is aligned to the bottom right corner of the attention
    matrix. For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask
    (1 = keep, 0 = masked out) is:
        1 1 1 1 0
        1 1 1 1 1
    If seqlen_q = 5 and seqlen_k = 2, the causal mask is:
        0 0
        0 0
        0 0
        1 0
        1 1
    A query row whose mask is entirely zero produces a zero output.

    If window_size != (-1, -1), sliding-window local attention is used: the query at
    position i only attends to keys in the inclusive range
    [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]].

    Arguments:
        q: (batch_size, seqlen, nheads, headdim)
        k: (batch_size, seqlen, nheads_k, headdim)
        v: (batch_size, seqlen, nheads_k, headdim)
        dropout_p: float. Dropout probability.
        softmax_scale: float. Scaling applied to QK^T before the softmax.
            Defaults to 1 / sqrt(headdim).
        causal: bool. Whether to apply a causal attention mask (e.g., for auto-regressive
            modeling).
        window_size: (left, right). If not (-1, -1), enables sliding-window local attention.
        alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of
            (-alibi_slope * |i + seqlen_k - seqlen_q - j|)
            is added to the attention score of query i and key j.
        deterministic: bool. Whether to use the deterministic backward-pass implementation,
            which is slightly slower and uses more memory. The forward pass is always
            deterministic.
        return_attn_probs: bool. Whether to also return the attention probabilities. Intended
            for testing only; the returned probabilities are not guaranteed to be correct
            (they might not have the right scaling).
    Return:
        out: (batch_size, seqlen, nheads, headdim).
        softmax_lse [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen). The
            logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
            normalization factor).
        S_dmask [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen, seqlen).
            The output of softmax (possibly with different scaling). It also encodes the
            dropout pattern (negative means that location was dropped, nonnegative means it
            was kept).
    """
    # The autograd Function receives its inputs positionally, so this tuple's order is the
    # contract with FlashAttnFunc.forward (defined elsewhere in this module).
    autograd_inputs = (
        q,
        k,
        v,
        dropout_p,
        softmax_scale,
        causal,
        window_size,
        alibi_slopes,
        deterministic,
        return_attn_probs,
    )
    return FlashAttnFunc.apply(*autograd_inputs)
Tri Dao's avatar
Tri Dao committed
837
838


Tri Dao's avatar
Tri Dao committed
839
840
841
842
843
844
845
def flash_attn_varlen_qkvpacked_func(
    qkv,
    cu_seqlens,
    max_seqlen,
    dropout_p=0.0,
    softmax_scale=None,
    causal=False,
    window_size=(-1, -1),  # -1 means infinite context window
    alibi_slopes=None,
    deterministic=False,
    return_attn_probs=False,
):
    """Run FlashAttention on variable-length sequences packed into a single qkv tensor.

    dropout_p should be set to 0.0 during evaluation.
    When Q, K, V are already stacked into one tensor, this function is faster than calling
    flash_attn_varlen_func on separate Q, K, V, because the backward pass avoids explicitly
    concatenating the gradients of Q, K, V.
    For multi-query and grouped-query attention (MQA/GQA), see
    flash_attn_varlen_kvpacked_func and flash_attn_varlen_func.

    If window_size != (-1, -1), sliding-window local attention is used: the query at
    position i only attends to keys in the inclusive range
    [i - window_size[0], i + window_size[1]].

    Arguments:
        qkv: (total, 3, nheads, headdim), where total = total number of tokens in the batch.
        cu_seqlens: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
            of the sequences in the batch, used to index into qkv.
        max_seqlen: int. Maximum sequence length in the batch.
        dropout_p: float. Dropout probability.
        softmax_scale: float. Scaling applied to QK^T before the softmax.
            Defaults to 1 / sqrt(headdim).
        causal: bool. Whether to apply a causal attention mask (e.g., for auto-regressive
            modeling).
        window_size: (left, right). If not (-1, -1), enables sliding-window local attention.
        alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of
            (-alibi_slope * |i - j|) is added to the attention score of query i and key j.
        deterministic: bool. Whether to use the deterministic backward-pass implementation,
            which is slightly slower and uses more memory. The forward pass is always
            deterministic.
        return_attn_probs: bool. Whether to also return the attention probabilities. Intended
            for testing only; the returned probabilities are not guaranteed to be correct
            (they might not have the right scaling).
    Return:
        out: (total, nheads, headdim).
        softmax_lse [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen). The
            logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
            normalization factor).
        S_dmask [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen, seqlen).
            The output of softmax (possibly with different scaling). It also encodes the
            dropout pattern (negative means that location was dropped, nonnegative means it
            was kept).
    """
    # Tensor/length inputs first, then the attention options; the concatenated order must
    # match FlashAttnVarlenQKVPackedFunc.forward since autograd Functions take positional args.
    packed_inputs = (qkv, cu_seqlens, max_seqlen)
    attn_options = (
        dropout_p,
        softmax_scale,
        causal,
        window_size,
        alibi_slopes,
        deterministic,
        return_attn_probs,
    )
    return FlashAttnVarlenQKVPackedFunc.apply(*packed_inputs, *attn_options)
Tri Dao's avatar
Tri Dao committed
899
900


Tri Dao's avatar
Tri Dao committed
901
902
903
904
905
906
907
908
909
910
def flash_attn_varlen_kvpacked_func(
    q,
    kv,
    cu_seqlens_q,
    cu_seqlens_k,
    max_seqlen_q,
    max_seqlen_k,
    dropout_p=0.0,
    softmax_scale=None,
    causal=False,
    window_size=(-1, -1),  # -1 means infinite context window
    alibi_slopes=None,
    deterministic=False,
    return_attn_probs=False,
):
    """Run FlashAttention on variable-length sequences with K and V packed into one tensor.

    dropout_p should be set to 0.0 during evaluation.
    When K, V are already stacked into one tensor, this function is faster than calling
    flash_attn_varlen_func on separate Q, K, V, because the backward pass avoids explicitly
    concatenating the gradients of K, V.
    Multi-query and grouped-query attention (MQA/GQA) are supported by passing in KV with
    fewer heads than Q; the number of heads in Q must be divisible by the number of heads in
    KV. For example, if Q has 6 heads and K, V have 2 heads, heads 0, 1, 2 of Q attend to
    head 0 of K, V, and heads 3, 4, 5 of Q attend to head 1 of K, V.

    If causal=True, the causal mask is aligned to the bottom right corner of the attention
    matrix. For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask
    (1 = keep, 0 = masked out) is:
        1 1 1 1 0
        1 1 1 1 1
    If seqlen_q = 5 and seqlen_k = 2, the causal mask is:
        0 0
        0 0
        0 0
        1 0
        1 1
    A query row whose mask is entirely zero produces a zero output.

    If window_size != (-1, -1), sliding-window local attention is used: the query at
    position i only attends to keys in the inclusive range
    [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]].

    Arguments:
        q: (total_q, nheads, headdim), where total_q = total number of query tokens in the batch.
        kv: (total_k, 2, nheads_k, headdim), where total_k = total number of key tokens in the batch.
        cu_seqlens_q: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
            of the sequences in the batch, used to index into q.
        cu_seqlens_k: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
            of the sequences in the batch, used to index into kv.
        max_seqlen_q: int. Maximum query sequence length in the batch.
        max_seqlen_k: int. Maximum key sequence length in the batch.
        dropout_p: float. Dropout probability.
        softmax_scale: float. Scaling applied to QK^T before the softmax.
            Defaults to 1 / sqrt(headdim).
        causal: bool. Whether to apply a causal attention mask (e.g., for auto-regressive
            modeling).
        window_size: (left, right). If not (-1, -1), enables sliding-window local attention.
        alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of
            (-alibi_slope * |i + seqlen_k - seqlen_q - j|)
            is added to the attention score of query i and key j.
        deterministic: bool. Whether to use the deterministic backward-pass implementation,
            which is slightly slower and uses more memory. The forward pass is always
            deterministic.
        return_attn_probs: bool. Whether to also return the attention probabilities. Intended
            for testing only; the returned probabilities are not guaranteed to be correct
            (they might not have the right scaling).
    Return:
        out: (total, nheads, headdim).
        softmax_lse [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen). The
            logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
            normalization factor).
        S_dmask [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen, seqlen).
            The output of softmax (possibly with different scaling). It also encodes the
            dropout pattern (negative means that location was dropped, nonnegative means it
            was kept).
    """
    # Group the varlen bookkeeping separately from the attention knobs; the concatenated
    # positional order is what FlashAttnVarlenKVPackedFunc.forward expects.
    tensors = (q, kv)
    varlen_meta = (cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k)
    attn_options = (
        dropout_p,
        softmax_scale,
        causal,
        window_size,
        alibi_slopes,
        deterministic,
        return_attn_probs,
    )
    return FlashAttnVarlenKVPackedFunc.apply(*tensors, *varlen_meta, *attn_options)
Tri Dao's avatar
Tri Dao committed
987

988

Tri Dao's avatar
Tri Dao committed
989
990
991
992
993
994
995
996
997
998
999
def flash_attn_varlen_func(
    q,
    k,
    v,
    cu_seqlens_q,
    cu_seqlens_k,
    max_seqlen_q,
    max_seqlen_k,
    dropout_p=0.0,
    softmax_scale=None,
    causal=False,
    window_size=(-1, -1),  # -1 means infinite context window
    alibi_slopes=None,
    deterministic=False,
    return_attn_probs=False,
):
    """Run FlashAttention on variable-length sequences with separate q, k, v tensors.

    dropout_p should be set to 0.0 during evaluation.
    Multi-query and grouped-query attention (MQA/GQA) are supported by passing in K, V with
    fewer heads than Q; the number of heads in Q must be divisible by the number of heads in
    K, V. For example, if Q has 6 heads and K, V have 2 heads, heads 0, 1, 2 of Q attend to
    head 0 of K, V, and heads 3, 4, 5 of Q attend to head 1 of K, V.

    If causal=True, the causal mask is aligned to the bottom right corner of the attention
    matrix. For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask
    (1 = keep, 0 = masked out) is:
        1 1 1 1 0
        1 1 1 1 1
    If seqlen_q = 5 and seqlen_k = 2, the causal mask is:
        0 0
        0 0
        0 0
        1 0
        1 1
    A query row whose mask is entirely zero produces a zero output.

    If window_size != (-1, -1), sliding-window local attention is used: the query at
    position i only attends to keys in the inclusive range
    [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]].

    Arguments:
        q: (total_q, nheads, headdim), where total_q = total number of query tokens in the batch.
        k: (total_k, nheads_k, headdim), where total_k = total number of key tokens in the batch.
        v: (total_k, nheads_k, headdim), where total_k = total number of key tokens in the batch.
        cu_seqlens_q: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
            of the sequences in the batch, used to index into q.
        cu_seqlens_k: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
            of the sequences in the batch, used to index into k and v.
        max_seqlen_q: int. Maximum query sequence length in the batch.
        max_seqlen_k: int. Maximum key sequence length in the batch.
        dropout_p: float. Dropout probability.
        softmax_scale: float. Scaling applied to QK^T before the softmax.
            Defaults to 1 / sqrt(headdim).
        causal: bool. Whether to apply a causal attention mask (e.g., for auto-regressive
            modeling).
        window_size: (left, right). If not (-1, -1), enables sliding-window local attention.
        alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of
            (-alibi_slope * |i + seqlen_k - seqlen_q - j|)
            is added to the attention score of query i and key j.
        deterministic: bool. Whether to use the deterministic backward-pass implementation,
            which is slightly slower and uses more memory. The forward pass is always
            deterministic.
        return_attn_probs: bool. Whether to also return the attention probabilities. Intended
            for testing only; the returned probabilities are not guaranteed to be correct
            (they might not have the right scaling).
    Return:
        out: (total, nheads, headdim).
        softmax_lse [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen). The
            logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
            normalization factor).
        S_dmask [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen, seqlen).
            The output of softmax (possibly with different scaling). It also encodes the
            dropout pattern (negative means that location was dropped, nonnegative means it
            was kept).
    """
    # Positional order must line up with FlashAttnVarlenFunc.forward.
    tensors = (q, k, v)
    varlen_meta = (cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k)
    attn_options = (
        dropout_p,
        softmax_scale,
        causal,
        window_size,
        alibi_slopes,
        deterministic,
        return_attn_probs,
    )
    return FlashAttnVarlenFunc.apply(*tensors, *varlen_meta, *attn_options)
Tri Dao's avatar
Tri Dao committed
1075
1076
1077
1078
1079
1080
1081
1082


def flash_attn_with_kvcache(
    q,
    k_cache,
    v_cache,
    k=None,
    v=None,
    rotary_cos=None,
    rotary_sin=None,
    cache_seqlens: Optional[Union[(int, torch.Tensor)]] = None,
    cache_batch_idx: Optional[torch.Tensor] = None,
    softmax_scale=None,
    causal=False,
    window_size=(-1, -1),  # -1 means infinite context window
    rotary_interleaved=True,
    alibi_slopes=None,
    num_splits=0,
):
    """
    If k and v are not None, k_cache and v_cache will be updated *inplace* with the new values from
    k and v. This is useful for incremental decoding: you can pass in the cached keys/values from
    the previous step, and update them with the new keys/values from the current step, and do
    attention with the updated cache, all in 1 kernel.

    If you pass in k / v, you must make sure that the cache is large enough to hold the new values.
    For example, the KV cache could be pre-allocated with the max sequence length, and you can use
    cache_seqlens to keep track of the current sequence lengths of each sequence in the batch.

    Also apply rotary embedding if rotary_cos and rotary_sin are passed in. The key @k will be
    rotated by rotary_cos and rotary_sin at indices cache_seqlens, cache_seqlens + 1, etc.
    If causal or local (i.e., window_size != (-1, -1)), the query @q will be rotated by rotary_cos
    and rotary_sin at indices cache_seqlens, cache_seqlens + 1, etc.
    If not causal and not local, the query @q will be rotated by rotary_cos and rotary_sin at
    indices cache_seqlens only (i.e. we consider all tokens in @q to be at position cache_seqlens).

    See tests/test_flash_attn.py::test_flash_attn_kvcache for examples of how to use this function.

    Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads
    than Q. Note that the number of heads in Q must be divisible by the number of heads in KV.
    For example, if Q has 6 heads and K, V have 2 heads, heads 0, 1, 2 of Q attend to head
    0 of K, V, and heads 3, 4, 5 of Q attend to head 1 of K, V.

    If causal=True, the causal mask is aligned to the bottom right corner of the attention matrix.
    For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is:
        1 1 1 1 0
        1 1 1 1 1
    If seqlen_q = 5 and seqlen_k = 2, the causal mask is:
        0 0
        0 0
        0 0
        1 0
        1 1
    If the row of the mask is all zero, the output will be zero.

    If window_size != (-1, -1), implements sliding window local attention. Query at position i
    will only attend to keys between
    [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive.

    Note: Does not support backward pass.

    Arguments:
        q: (batch_size, seqlen, nheads, headdim)
        k_cache: (batch_size_cache, seqlen_cache, nheads_k, headdim)
        v_cache: (batch_size_cache, seqlen_cache, nheads_k, headdim)
        k [optional]: (batch_size, seqlen_new, nheads_k, headdim). If not None, we concatenate
            k with k_cache, starting at the indices specified by cache_seqlens.
        v [optional]: (batch_size, seqlen_new, nheads_k, headdim). Similar to k.
        rotary_cos [optional]: (seqlen_ro, rotary_dim / 2). If not None, we apply rotary embedding
            to k and q. Only applicable if k and v are passed in. rotary_dim must be divisible by 16.
        rotary_sin [optional]: (seqlen_ro, rotary_dim / 2). Similar to rotary_cos.
        cache_seqlens: int, or (batch_size,), dtype torch.int32. The sequence lengths of the
            KV cache. An int is broadcast to an int32 tensor with one entry per query batch
            element (q.shape[0]), on k_cache's device. A tensor with a non-unit stride is
            copied to a contiguous one.
        cache_batch_idx: (batch_size,), dtype torch.int32. The indices used to index into the KV cache.
            If None, we assume that the batch indices are [0, 1, 2, ..., batch_size - 1].
            If the indices are not distinct, and k and v are provided, the values updated in the cache
                 might come from any of the duplicate indices.
        softmax_scale: float. The scaling of QK^T before applying softmax.
            Default to 1 / sqrt(headdim).
        causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
        window_size: (left, right). If not (-1, -1), implements sliding window local attention.
        rotary_interleaved: bool. Only applicable if rotary_cos and rotary_sin are passed in.
            If True, rotary embedding will combine dimensions 0 & 1, 2 & 3, etc. If False,
            rotary embedding will combine dimensions 0 & rotary_dim / 2, 1 & rotary_dim / 2 + 1
            (i.e. GPT-NeoX style).
        alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of
            (-alibi_slope * |i + seqlen_k - seqlen_q - j|)
            is added to the attention score of query i and key j.
        num_splits: int. If > 1, split the key/value into this many chunks along the sequence.
           If num_splits == 1, we don't split the key/value. If num_splits == 0, we use a heuristic
           to automatically determine the number of splits.
           Don't change this unless you know what you are doing.

    Return:
        out: (batch_size, seqlen, nheads, headdim).
    """
    # The caches are updated in place, so they cannot be silently replaced by contiguous
    # copies; require unit innermost stride up front instead.
    assert k_cache.stride(-1) == 1, "k_cache must have contiguous last dimension"
    assert v_cache.stride(-1) == 1, "v_cache must have contiguous last dimension"

    def maybe_contiguous(x):
        # Copy only when the innermost dimension is strided; None passes through unchanged.
        return x.contiguous() if x is not None and x.stride(-1) != 1 else x

    q, k, v = [maybe_contiguous(x) for x in (q, k, v)]
    if softmax_scale is None:
        softmax_scale = q.shape[-1] ** (-0.5)
    if isinstance(cache_seqlens, int):
        # cache_seqlens is documented as (batch_size,), i.e. one entry per *query* batch
        # element. k_cache's leading dim is batch_size_cache, which can differ when
        # cache_batch_idx selects a subset of cache rows, so size by q instead.
        cache_seqlens = torch.full(
            (q.shape[0],), cache_seqlens, dtype=torch.int32, device=k_cache.device
        )
    # Apply to caller-supplied tensors too, not only the freshly built one above.
    cache_seqlens = maybe_contiguous(cache_seqlens)
    cache_batch_idx = maybe_contiguous(cache_batch_idx)
    out, _softmax_lse = flash_attn_cuda.fwd_kvcache(
        q,
        k_cache,
        v_cache,
        k,
        v,
        cache_seqlens,
        rotary_cos,
        rotary_sin,
        cache_batch_idx,
        alibi_slopes,
        None,  # NOTE(review): presumably the optional preallocated `out` slot — confirm in C++ binding
        softmax_scale,
        causal,
        window_size[0],
        window_size[1],
        rotary_interleaved,
        num_splits,
    )
    return out