flash_attn_interface.py 46 KB
Newer Older
1
2
# Copyright (c) 2023, Tri Dao.

Tri Dao's avatar
Tri Dao committed
3
4
from typing import Optional, Union

Tri Dao's avatar
Tri Dao committed
5
6
import torch
import torch.nn as nn
Tri Dao's avatar
Tri Dao committed
7

8
9
10
# isort: off
# We need to import the CUDA kernels after importing torch
import flash_attn_2_cuda as flash_attn_cuda
Tri Dao's avatar
Tri Dao committed
11

12
13
# isort: on

Tri Dao's avatar
Tri Dao committed
14

Tri Dao's avatar
Tri Dao committed
15
def _get_block_size_n(device, head_dim, is_dropout, is_causal):
Tri Dao's avatar
Tri Dao committed
16
17
18
19
20
21
22
    # This should match the block sizes in the CUDA kernel
    assert head_dim <= 256
    major, minor = torch.cuda.get_device_capability(device)
    is_sm8x = major == 8 and minor > 0  # Only include sm86 and sm89, exclude sm80 (A100)
    is_sm80 = major == 8 and minor == 0
    is_sm90 = major == 9 and minor == 0
    if head_dim <= 32:
Tri Dao's avatar
Tri Dao committed
23
        return 128
Tri Dao's avatar
Tri Dao committed
24
    if head_dim <= 64:
Tri Dao's avatar
Tri Dao committed
25
        return 128 if not is_dropout else 64
Tri Dao's avatar
Tri Dao committed
26
    elif head_dim <= 96:
Tri Dao's avatar
Tri Dao committed
27
        return 64
Tri Dao's avatar
Tri Dao committed
28
29
    elif head_dim <= 128:
        if is_sm8x:
Tri Dao's avatar
Tri Dao committed
30
            return 64 if (not is_dropout and is_causal) else 32
Tri Dao's avatar
Tri Dao committed
31
        else:
Tri Dao's avatar
Tri Dao committed
32
            return 64 if not is_dropout else 32
Tri Dao's avatar
Tri Dao committed
33
34
    elif head_dim <= 160:
        if is_sm8x:
Tri Dao's avatar
Tri Dao committed
35
            return 64
Tri Dao's avatar
Tri Dao committed
36
        else:
Tri Dao's avatar
Tri Dao committed
37
            return 32
Tri Dao's avatar
Tri Dao committed
38
    elif head_dim <= 192:
Tri Dao's avatar
Tri Dao committed
39
        return 64
Tri Dao's avatar
Tri Dao committed
40
    elif head_dim <= 224:
Tri Dao's avatar
Tri Dao committed
41
        return 64
Tri Dao's avatar
Tri Dao committed
42
    elif head_dim <= 256:
Tri Dao's avatar
Tri Dao committed
43
        return 64
Tri Dao's avatar
Tri Dao committed
44
45


Tri Dao's avatar
Tri Dao committed
46
def _flash_attn_forward(
    q, k, v, dropout_p, softmax_scale, causal, window_size, softcap, alibi_slopes, return_softmax
):
    """Call the ``flash_attn_cuda.fwd`` kernel.

    q, k and v are copied with ``.contiguous()`` only when their last dimension does not
    have unit stride. ``window_size`` is split into its two entries for the kernel.

    Return:
        The tuple (out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state) produced by
        the kernel.
    """

    def _unit_stride_last_dim(t):
        return t if t.stride(-1) == 1 else t.contiguous()

    q, k, v = _unit_stride_last_dim(q), _unit_stride_last_dim(k), _unit_stride_last_dim(v)
    # NOTE(review): the two None arguments are optional kernel inputs this wrapper never
    # supplies; their meaning is not visible from this file - confirm against the binding.
    kernel_outputs = flash_attn_cuda.fwd(
        q,
        k,
        v,
        None,
        alibi_slopes,
        dropout_p,
        softmax_scale,
        causal,
        window_size[0],
        window_size[1],
        softcap,
        return_softmax,
        None,
    )
    out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = kernel_outputs
    return out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state
Tri Dao's avatar
Tri Dao committed
67
68


Tri Dao's avatar
Tri Dao committed
69
70
71
72
73
74
75
76
77
78
79
def _flash_attn_varlen_forward(
    q,
    k,
    v,
    cu_seqlens_q,
    cu_seqlens_k,
    max_seqlen_q,
    max_seqlen_k,
    dropout_p,
    softmax_scale,
    causal,
    window_size,
    alibi_slopes,
    return_softmax,
    block_table,
    softcap=0.0,
):
    """Call the ``flash_attn_cuda.varlen_fwd`` kernel.

    q, k and v are copied with ``.contiguous()`` only when their last dimension does not
    have unit stride. ``window_size`` is split into its two entries for the kernel.

    ``softcap`` is accepted as a trailing argument with a default of 0.0 (deactivated) so
    that callers which pass ``softcap=...`` and callers which omit it both work.

    Return:
        The tuple (out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state) produced by
        the kernel.
    """
    q, k, v = [x if x.stride(-1) == 1 else x.contiguous() for x in (q, k, v)]
    # NOTE(review): the None / False arguments are kernel inputs this wrapper never varies;
    # their meaning is not visible from this file - confirm against the binding.
    out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = flash_attn_cuda.varlen_fwd(
        q,
        k,
        v,
        None,
        cu_seqlens_q,
        cu_seqlens_k,
        None,
        block_table,
        alibi_slopes,
        max_seqlen_q,
        max_seqlen_k,
        dropout_p,
        softmax_scale,
        False,
        causal,
        window_size[0],
        window_size[1],
        # NOTE(review): placed right after the window sizes, mirroring the fwd, bwd and
        # varlen_bwd kernel calls in this file - confirm the varlen_fwd binding accepts it.
        softcap,
        return_softmax,
        None,
    )
    return out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state
Tri Dao's avatar
Tri Dao committed
111
112


Tri Dao's avatar
Tri Dao committed
113
def _flash_attn_backward(
    dout,
    q,
    k,
    v,
    out,
    softmax_lse,
    dq,
    dk,
    dv,
    dropout_p,
    softmax_scale,
    causal,
    window_size,
    softcap,
    alibi_slopes,
    deterministic,
    rng_state=None,
):
    """Call the ``flash_attn_cuda.bwd`` kernel.

    dout, q, k, v and out are copied with ``.contiguous()`` only when their last dimension
    does not have unit stride. dq, dk and dv are buffers supplied by the caller and are
    passed through untouched.

    Return:
        The tuple (dq, dk, dv, softmax_d) produced by the kernel.
    """

    def _unit_stride_last_dim(t):
        return t if t.stride(-1) == 1 else t.contiguous()

    # dq, dk, dv are allocated by us so they should already be contiguous
    dout, q, k, v, out = map(_unit_stride_last_dim, (dout, q, k, v, out))
    # NOTE(review): the None argument is a kernel input this wrapper never supplies; its
    # meaning is not visible from this file - confirm against the binding.
    grads = flash_attn_cuda.bwd(
        dout,
        q,
        k,
        v,
        out,
        softmax_lse,
        dq,
        dk,
        dv,
        alibi_slopes,
        dropout_p,
        softmax_scale,
        causal,
        window_size[0],
        window_size[1],
        softcap,
        deterministic,
        None,
        rng_state,
    )
    dq, dk, dv, softmax_d = grads
    return dq, dk, dv, softmax_d


Tri Dao's avatar
Tri Dao committed
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
def _flash_attn_varlen_backward(
    dout,
    q,
    k,
    v,
    out,
    softmax_lse,
    dq,
    dk,
    dv,
    cu_seqlens_q,
    cu_seqlens_k,
    max_seqlen_q,
    max_seqlen_k,
    dropout_p,
    softmax_scale,
    causal,
    window_size,
    softcap,
    alibi_slopes,
    deterministic,
    rng_state=None,
):
    """Call the ``flash_attn_cuda.varlen_bwd`` kernel.

    dout, q, k, v and out are copied with ``.contiguous()`` only when their last dimension
    does not have unit stride. dq, dk and dv are buffers supplied by the caller and are
    passed through untouched.

    Return:
        The tuple (dq, dk, dv, softmax_d) produced by the kernel.
    """

    def _unit_stride_last_dim(t):
        return t if t.stride(-1) == 1 else t.contiguous()

    # dq, dk, dv are allocated by us so they should already be contiguous
    dout, q, k, v, out = map(_unit_stride_last_dim, (dout, q, k, v, out))
    # NOTE(review): the False / None arguments are kernel inputs this wrapper never varies;
    # their meaning is not visible from this file - confirm against the binding.
    grads = flash_attn_cuda.varlen_bwd(
        dout,
        q,
        k,
        v,
        out,
        softmax_lse,
        dq,
        dk,
        dv,
        cu_seqlens_q,
        cu_seqlens_k,
        alibi_slopes,
        max_seqlen_q,
        max_seqlen_k,
        dropout_p,
        softmax_scale,
        False,
        causal,
        window_size[0],
        window_size[1],
        softcap,
        deterministic,
        None,
        rng_state,
    )
    dq, dk, dv, softmax_d = grads
    return dq, dk, dv, softmax_d
Tri Dao's avatar
Tri Dao committed
224
225


Tri Dao's avatar
Tri Dao committed
226
class FlashAttnQKVPackedFunc(torch.autograd.Function):
    """Autograd function for attention on a packed QKV tensor.

    Q, K and V are read as ``qkv[:, :, 0]``, ``qkv[:, :, 1]`` and ``qkv[:, :, 2]``. The
    backward pass writes the three gradients into slices of a single packed buffer.
    """

    @staticmethod
    def forward(
        ctx,
        qkv,
        dropout_p,
        softmax_scale,
        causal,
        window_size,
        softcap,
        alibi_slopes,
        deterministic,
        return_softmax,
    ):
        """Run the forward kernel and stash what the backward pass needs on ``ctx``."""
        # Default scaling: (size of the last dimension) ** -0.5.
        softmax_scale = qkv.shape[-1] ** (-0.5) if softmax_scale is None else softmax_scale
        query, key, value = qkv[:, :, 0], qkv[:, :, 1], qkv[:, :, 2]
        out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = _flash_attn_forward(
            query,
            key,
            value,
            dropout_p,
            softmax_scale,
            causal=causal,
            window_size=window_size,
            softcap=softcap,
            alibi_slopes=alibi_slopes,
            # The kernel is only asked for the softmax when dropout is active.
            return_softmax=return_softmax and dropout_p > 0,
        )
        ctx.save_for_backward(q, k, v, out_padded, softmax_lse, rng_state)
        ctx.dropout_p = dropout_p
        ctx.softmax_scale = softmax_scale
        ctx.causal = causal
        ctx.window_size = window_size
        ctx.softcap = softcap
        ctx.alibi_slopes = alibi_slopes
        ctx.deterministic = deterministic
        if return_softmax:
            return out, softmax_lse, S_dmask
        return out

    @staticmethod
    def backward(ctx, dout, *args):
        """Compute the packed gradient of ``qkv``; all other inputs get no gradient."""
        q, k, v, out, softmax_lse, rng_state = ctx.saved_tensors
        # Same layout as the input: a size-3 axis inserted before the last two dims of q.
        packed_shape = (*q.shape[:-2], 3, *q.shape[-2:])
        dqkv = torch.empty(packed_shape, dtype=q.dtype, device=q.device)
        _flash_attn_backward(
            dout,
            q,
            k,
            v,
            out,
            softmax_lse,
            dqkv[:, :, 0],
            dqkv[:, :, 1],
            dqkv[:, :, 2],
            ctx.dropout_p,
            ctx.softmax_scale,
            ctx.causal,
            ctx.window_size,
            ctx.softcap,
            ctx.alibi_slopes,
            ctx.deterministic,
            rng_state=rng_state,
        )
        dqkv = dqkv[..., : dout.shape[-1]]  # We could have padded the head dimension
        # One slot per forward input: qkv, then the 8 non-differentiable arguments.
        return (dqkv,) + (None,) * 8
Tri Dao's avatar
Tri Dao committed
290
291
292
293


class FlashAttnVarlenQKVPackedFunc(torch.autograd.Function):
    """Autograd function for variable-length attention on a packed QKV tensor.

    Q, K and V are read as ``qkv[:, 0]``, ``qkv[:, 1]`` and ``qkv[:, 2]``; the same
    ``cu_seqlens`` / ``max_seqlen`` are used for queries and keys. The backward pass writes
    the three gradients into slices of a single packed buffer.

    This variant does not take a ``softcap`` argument, so softcapping is deactivated (0.0)
    in the backward pass.
    """

    @staticmethod
    def forward(
        ctx,
        qkv,
        cu_seqlens,
        max_seqlen,
        dropout_p,
        softmax_scale,
        causal,
        window_size,
        alibi_slopes,
        deterministic,
        return_softmax,
    ):
        """Run the forward kernel and stash what the backward pass needs on ``ctx``."""
        if softmax_scale is None:
            softmax_scale = qkv.shape[-1] ** (-0.5)
        out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = _flash_attn_varlen_forward(
            qkv[:, 0],
            qkv[:, 1],
            qkv[:, 2],
            cu_seqlens,
            cu_seqlens,
            max_seqlen,
            max_seqlen,
            dropout_p,
            softmax_scale,
            causal=causal,
            window_size=window_size,
            alibi_slopes=alibi_slopes,
            # The kernel is only asked for the softmax when dropout is active.
            return_softmax=return_softmax and dropout_p > 0,
            block_table=None,
        )
        ctx.save_for_backward(q, k, v, out_padded, softmax_lse, cu_seqlens, rng_state)
        ctx.dropout_p = dropout_p
        ctx.max_seqlen = max_seqlen
        ctx.softmax_scale = softmax_scale
        ctx.causal = causal
        ctx.window_size = window_size
        ctx.alibi_slopes = alibi_slopes
        ctx.deterministic = deterministic
        return out if not return_softmax else (out, softmax_lse, S_dmask)

    @staticmethod
    def backward(ctx, dout, *args):
        """Compute the packed gradient of ``qkv``; all other inputs get no gradient."""
        q, k, v, out, softmax_lse, cu_seqlens, rng_state = ctx.saved_tensors
        qkv_shape = q.shape[:-2] + (3, *q.shape[-2:])
        dqkv = torch.empty(qkv_shape, dtype=q.dtype, device=q.device)
        _flash_attn_varlen_backward(
            dout,
            q,
            k,
            v,
            out,
            softmax_lse,
            dqkv[:, 0],
            dqkv[:, 1],
            dqkv[:, 2],
            cu_seqlens,
            cu_seqlens,
            ctx.max_seqlen,
            ctx.max_seqlen,
            ctx.dropout_p,
            ctx.softmax_scale,
            ctx.causal,
            ctx.window_size,
            # softcap: the helper requires it positionally here. Without it every later
            # argument shifts by one slot. 0.0 means deactivated.
            0.0,
            ctx.alibi_slopes,
            ctx.deterministic,
            rng_state=rng_state,
        )
        dqkv = dqkv[..., : dout.shape[-1]]  # We could have padded the head dimension
        # One slot per forward input: qkv, then the 9 non-differentiable arguments.
        return dqkv, None, None, None, None, None, None, None, None, None
Tri Dao's avatar
Tri Dao committed
364
365


Tri Dao's avatar
Tri Dao committed
366
class FlashAttnKVPackedFunc(torch.autograd.Function):
    """Autograd function for attention with a separate Q and a packed KV tensor.

    K and V are read as ``kv[:, :, 0]`` and ``kv[:, :, 1]``. The backward pass writes their
    gradients into slices of a single packed buffer.

    This variant does not take a ``softcap`` argument, so softcapping is deactivated (0.0)
    in both the forward and backward kernels.
    """

    @staticmethod
    def forward(
        ctx,
        q,
        kv,
        dropout_p,
        softmax_scale,
        causal,
        window_size,
        alibi_slopes,
        deterministic,
        return_softmax,
    ):
        """Run the forward kernel and stash what the backward pass needs on ``ctx``."""
        if softmax_scale is None:
            softmax_scale = q.shape[-1] ** (-0.5)
        # _flash_attn_forward requires softcap; 0.0 means deactivated.
        softcap = 0.0
        out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = _flash_attn_forward(
            q,
            kv[:, :, 0],
            kv[:, :, 1],
            dropout_p,
            softmax_scale,
            causal=causal,
            window_size=window_size,
            softcap=softcap,
            alibi_slopes=alibi_slopes,
            # The kernel is only asked for the softmax when dropout is active.
            return_softmax=return_softmax and dropout_p > 0,
        )
        ctx.save_for_backward(q, k, v, out_padded, softmax_lse, rng_state)
        ctx.dropout_p = dropout_p
        ctx.softmax_scale = softmax_scale
        ctx.causal = causal
        ctx.window_size = window_size
        ctx.softcap = softcap
        ctx.alibi_slopes = alibi_slopes
        ctx.deterministic = deterministic
        return out if not return_softmax else (out, softmax_lse, S_dmask)

    @staticmethod
    def backward(ctx, dout, *args):
        """Compute gradients of ``q`` and packed ``kv``; other inputs get no gradient."""
        q, k, v, out, softmax_lse, rng_state = ctx.saved_tensors
        dq = torch.empty_like(q)
        kv_shape = k.shape[:-2] + (2, *k.shape[-2:])
        dkv = torch.empty(kv_shape, dtype=k.dtype, device=k.device)
        _flash_attn_backward(
            dout,
            q,
            k,
            v,
            out,
            softmax_lse,
            dq,
            dkv[:, :, 0],
            dkv[:, :, 1],
            ctx.dropout_p,
            ctx.softmax_scale,
            ctx.causal,
            ctx.window_size,
            # Required positionally by the helper. Without it every later argument shifts
            # by one slot.
            ctx.softcap,
            ctx.alibi_slopes,
            ctx.deterministic,
            rng_state=rng_state,
        )
        dq = dq[..., : dout.shape[-1]]  # We could have padded the head dimension
        dkv = dkv[..., : dout.shape[-1]]
        # One slot per forward input: q, kv, then the 7 non-differentiable arguments.
        return dq, dkv, None, None, None, None, None, None, None
Tri Dao's avatar
Tri Dao committed
429
430


Tri Dao's avatar
Tri Dao committed
431
class FlashAttnVarlenKVPackedFunc(torch.autograd.Function):
    """Autograd function for variable-length attention with separate Q and packed KV.

    K and V are read as ``kv[:, 0]`` and ``kv[:, 1]``. The backward pass writes their
    gradients into slices of a single packed buffer.
    """

    @staticmethod
    def forward(
        ctx,
        q,
        kv,
        cu_seqlens_q,
        cu_seqlens_k,
        max_seqlen_q,
        max_seqlen_k,
        dropout_p,
        softmax_scale,
        causal,
        window_size,
        softcap,
        alibi_slopes,
        deterministic,
        return_softmax,
    ):
        """Run the forward kernel and stash what the backward pass needs on ``ctx``."""
        # Default scaling: (size of the last dimension of q) ** -0.5.
        softmax_scale = q.shape[-1] ** (-0.5) if softmax_scale is None else softmax_scale
        key, value = kv[:, 0], kv[:, 1]
        out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = _flash_attn_varlen_forward(
            q,
            key,
            value,
            cu_seqlens_q,
            cu_seqlens_k,
            max_seqlen_q,
            max_seqlen_k,
            dropout_p,
            softmax_scale,
            causal=causal,
            window_size=window_size,
            softcap=softcap,
            alibi_slopes=alibi_slopes,
            # The kernel is only asked for the softmax when dropout is active.
            return_softmax=return_softmax and dropout_p > 0,
            block_table=None,
        )
        ctx.save_for_backward(
            q, k, v, out_padded, softmax_lse, cu_seqlens_q, cu_seqlens_k, rng_state
        )
        ctx.dropout_p = dropout_p
        ctx.max_seqlen_q = max_seqlen_q
        ctx.max_seqlen_k = max_seqlen_k
        ctx.softmax_scale = softmax_scale
        ctx.causal = causal
        ctx.window_size = window_size
        ctx.softcap = softcap
        ctx.alibi_slopes = alibi_slopes
        ctx.deterministic = deterministic
        if return_softmax:
            return out, softmax_lse, S_dmask
        return out

    @staticmethod
    def backward(ctx, dout, *args):
        """Compute gradients of ``q`` and packed ``kv``; other inputs get no gradient."""
        q, k, v, out, softmax_lse, cu_seqlens_q, cu_seqlens_k, rng_state = ctx.saved_tensors
        dq = torch.empty_like(q)
        # Same layout as the input: a size-2 axis inserted before the last two dims of k.
        packed_shape = (*k.shape[:-2], 2, *k.shape[-2:])
        dkv = torch.empty(packed_shape, dtype=k.dtype, device=k.device)
        _flash_attn_varlen_backward(
            dout,
            q,
            k,
            v,
            out,
            softmax_lse,
            dq,
            dkv[:, 0],
            dkv[:, 1],
            cu_seqlens_q,
            cu_seqlens_k,
            ctx.max_seqlen_q,
            ctx.max_seqlen_k,
            ctx.dropout_p,
            ctx.softmax_scale,
            ctx.causal,
            ctx.window_size,
            ctx.softcap,
            ctx.alibi_slopes,
            ctx.deterministic,
            rng_state=rng_state,
        )
        head_dim = dout.shape[-1]
        dq = dq[..., :head_dim]  # We could have padded the head dimension
        dkv = dkv[..., :head_dim]
        # One slot per forward input: q, kv, then the 12 non-differentiable arguments.
        return (dq, dkv) + (None,) * 12
Tri Dao's avatar
Tri Dao committed
515
516
517
518


class FlashAttnFunc(torch.autograd.Function):
    """Autograd function for attention with separate Q, K and V tensors."""

    @staticmethod
    def forward(
        ctx,
        q,
        k,
        v,
        dropout_p,
        softmax_scale,
        causal,
        window_size,
        softcap,
        alibi_slopes,
        deterministic,
        return_softmax,
    ):
        """Run the forward kernel and stash what the backward pass needs on ``ctx``."""
        # Default scaling: (size of the last dimension of q) ** -0.5.
        softmax_scale = q.shape[-1] ** (-0.5) if softmax_scale is None else softmax_scale
        out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = _flash_attn_forward(
            q,
            k,
            v,
            dropout_p,
            softmax_scale,
            causal=causal,
            window_size=window_size,
            softcap=softcap,
            alibi_slopes=alibi_slopes,
            # The kernel is only asked for the softmax when dropout is active.
            return_softmax=return_softmax and dropout_p > 0,
        )
        ctx.save_for_backward(q, k, v, out_padded, softmax_lse, rng_state)
        ctx.dropout_p = dropout_p
        ctx.softmax_scale = softmax_scale
        ctx.causal = causal
        ctx.window_size = window_size
        ctx.softcap = softcap
        ctx.alibi_slopes = alibi_slopes
        ctx.deterministic = deterministic
        if return_softmax:
            return out, softmax_lse, S_dmask
        return out

    @staticmethod
    def backward(ctx, dout, *args):
        """Compute gradients of ``q``, ``k`` and ``v``; other inputs get no gradient."""
        q, k, v, out, softmax_lse, rng_state = ctx.saved_tensors
        dq = torch.empty_like(q)
        dk = torch.empty_like(k)
        dv = torch.empty_like(v)
        _flash_attn_backward(
            dout,
            q,
            k,
            v,
            out,
            softmax_lse,
            dq,
            dk,
            dv,
            ctx.dropout_p,
            ctx.softmax_scale,
            ctx.causal,
            ctx.window_size,
            ctx.softcap,
            ctx.alibi_slopes,
            ctx.deterministic,
            rng_state=rng_state,
        )
        head_dim = dout.shape[-1]
        # We could have padded the head dimension
        dq, dk, dv = dq[..., :head_dim], dk[..., :head_dim], dv[..., :head_dim]
        # One slot per forward input: q, k, v, then the 8 non-differentiable arguments.
        return (dq, dk, dv) + (None,) * 8
Tri Dao's avatar
Tri Dao committed
584
585


Tri Dao's avatar
Tri Dao committed
586
class FlashAttnVarlenFunc(torch.autograd.Function):
    """Autograd function for variable-length attention with separate Q, K and V tensors."""

    @staticmethod
    def forward(
        ctx,
        q,
        k,
        v,
        cu_seqlens_q,
        cu_seqlens_k,
        max_seqlen_q,
        max_seqlen_k,
        dropout_p,
        softmax_scale,
        causal,
        window_size,
        softcap,
        alibi_slopes,
        deterministic,
        return_softmax,
        block_table,
    ):
        """Run the forward kernel and stash what the backward pass needs on ``ctx``."""
        # Default scaling: (size of the last dimension of q) ** -0.5.
        softmax_scale = q.shape[-1] ** (-0.5) if softmax_scale is None else softmax_scale
        out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = _flash_attn_varlen_forward(
            q,
            k,
            v,
            cu_seqlens_q,
            cu_seqlens_k,
            max_seqlen_q,
            max_seqlen_k,
            dropout_p,
            softmax_scale,
            causal=causal,
            window_size=window_size,
            softcap=softcap,
            alibi_slopes=alibi_slopes,
            # The kernel is only asked for the softmax when dropout is active.
            return_softmax=return_softmax and dropout_p > 0,
            block_table=block_table,
        )
        ctx.save_for_backward(
            q, k, v, out_padded, softmax_lse, cu_seqlens_q, cu_seqlens_k, rng_state
        )
        ctx.dropout_p = dropout_p
        ctx.max_seqlen_q = max_seqlen_q
        ctx.max_seqlen_k = max_seqlen_k
        ctx.softmax_scale = softmax_scale
        ctx.causal = causal
        ctx.window_size = window_size
        ctx.softcap = softcap
        ctx.alibi_slopes = alibi_slopes
        ctx.deterministic = deterministic
        if return_softmax:
            return out, softmax_lse, S_dmask
        return out

    @staticmethod
    def backward(ctx, dout, *args):
        """Compute gradients of ``q``, ``k`` and ``v``; other inputs get no gradient."""
        q, k, v, out, softmax_lse, cu_seqlens_q, cu_seqlens_k, rng_state = ctx.saved_tensors
        dq = torch.empty_like(q)
        dk = torch.empty_like(k)
        dv = torch.empty_like(v)
        _flash_attn_varlen_backward(
            dout,
            q,
            k,
            v,
            out,
            softmax_lse,
            dq,
            dk,
            dv,
            cu_seqlens_q,
            cu_seqlens_k,
            ctx.max_seqlen_q,
            ctx.max_seqlen_k,
            ctx.dropout_p,
            ctx.softmax_scale,
            ctx.causal,
            ctx.window_size,
            ctx.softcap,
            ctx.alibi_slopes,
            ctx.deterministic,
            rng_state=rng_state,
        )
        head_dim = dout.shape[-1]
        # We could have padded the head dimension
        dq, dk, dv = dq[..., :head_dim], dk[..., :head_dim], dv[..., :head_dim]
        # One slot per forward input: q, k, v, then the 13 non-differentiable arguments.
        return (dq, dk, dv) + (None,) * 13
671
672


Tri Dao's avatar
Tri Dao committed
673
def flash_attn_qkvpacked_func(
    qkv,
    dropout_p=0.0,
    softmax_scale=None,
    causal=False,
    window_size=(-1, -1),  # -1 means infinite context window
    softcap=0.0,  # <=0.0 means deactivated
    alibi_slopes=None,
    deterministic=False,
    return_attn_probs=False,
):
    """dropout_p should be set to 0.0 during evaluation
    If Q, K, V are already stacked into 1 tensor, this function will be faster than
    calling flash_attn_func on Q, K, V since the backward pass avoids explicit concatenation
    of the gradients of Q, K, V.
    For multi-query and grouped-query attention (MQA/GQA), please see
    flash_attn_kvpacked_func and flash_attn_func.

    If window_size != (-1, -1), implements sliding window local attention. Query at position i
    will only attend to keys between [i - window_size[0], i + window_size[1]] inclusive.

    Arguments:
        qkv: (batch_size, seqlen, 3, nheads, headdim)
        dropout_p: float. Dropout probability.
        softmax_scale: float. The scaling of QK^T before applying softmax.
            Default to 1 / sqrt(headdim).
        causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
        window_size: (left, right). If not (-1, -1), implements sliding window local attention.
        softcap: float. Anything > 0 activates softcapping attention.
        alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of (-alibi_slope * |i - j|) is
            added to the attention score of query i and key j.
        deterministic: bool. Whether to use the deterministic implementation of the backward pass,
            which is slightly slower and uses more memory. The forward pass is always deterministic.
        return_attn_probs: bool. Whether to return the attention probabilities. This option is for
           testing only. The returned probabilities are not guaranteed to be correct
           (they might not have the right scaling).
    Return:
        out: (batch_size, seqlen, nheads, headdim).
        softmax_lse [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen). The
            logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
            normalization factor).
        S_dmask [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen, seqlen).
            The output of softmax (possibly with different scaling). It also encodes the dropout
            pattern (negative means that location was dropped, nonnegative means it was kept).
    """
    return FlashAttnQKVPackedFunc.apply(
        qkv,
        dropout_p,
        softmax_scale,
        causal,
        window_size,
        # Forward the `softcap` parameter; the name previously used here was undefined.
        softcap,
        alibi_slopes,
        deterministic,
        return_attn_probs,
    )
Tri Dao's avatar
Tri Dao committed
729
730


Tri Dao's avatar
Tri Dao committed
731
def flash_attn_kvpacked_func(
    q,
    kv,
    dropout_p=0.0,
    softmax_scale=None,
    causal=False,
    window_size=(-1, -1),  # -1 means infinite context window
    softcap=0.0,  # 0.0 means deactivated
    alibi_slopes=None,
    deterministic=False,
    return_attn_probs=False,
):
    """Attention with K and V packed into a single tensor.

    dropout_p should be set to 0.0 during evaluation.
    If K, V are already stacked into 1 tensor, this function will be faster than
    calling flash_attn_func on Q, K, V since the backward pass avoids explicit concatenation
    of the gradients of K, V.
    Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads
    than Q. Note that the number of heads in Q must be divisible by the number of heads in KV.
    For example, if Q has 6 heads and K, V have 2 heads, head 0, 1, 2 of Q will attend to head
    0 of K, V, and head 3, 4, 5 of Q will attend to head 1 of K, V.

    If causal=True, the causal mask is aligned to the bottom right corner of the attention matrix.
    For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is:
        1 1 1 1 0
        1 1 1 1 1
    If seqlen_q = 5 and seqlen_k = 2, the causal mask is:
        0 0
        0 0
        0 0
        1 0
        1 1
    If the row of the mask is all zero, the output will be zero.

    If window_size != (-1, -1), implements sliding window local attention. Query at position i
    will only attend to keys between
    [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive.

    Arguments:
        q: (batch_size, seqlen, nheads, headdim)
        kv: (batch_size, seqlen, 2, nheads_k, headdim)
        dropout_p: float. Dropout probability.
        softmax_scale: float. The scaling of QK^T before applying softmax.
            Default to 1 / sqrt(headdim).
        causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
        window_size: (left, right). If not (-1, -1), implements sliding window local attention.
        softcap: float. Anything > 0 activates softcapping attention.
        alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of
            (-alibi_slope * |i + seqlen_k - seqlen_q - j|)
            is added to the attention score of query i and key j.
        deterministic: bool. Whether to use the deterministic implementation of the backward pass,
            which is slightly slower and uses more memory. The forward pass is always deterministic.
        return_attn_probs: bool. Whether to return the attention probabilities. This option is for
           testing only. The returned probabilities are not guaranteed to be correct
           (they might not have the right scaling).
    Return:
        out: (batch_size, seqlen, nheads, headdim).
        softmax_lse [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen). The
            logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
            normalization factor).
        S_dmask [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen, seqlen).
            The output of softmax (possibly with different scaling). It also encodes the dropout
            pattern (negative means that location was dropped, nonnegative means it was kept).
    """
    # Pure positional pass-through: the argument order must stay in sync with what
    # FlashAttnKVPackedFunc.apply expects (defined elsewhere in this file).
    return FlashAttnKVPackedFunc.apply(
        q,
        kv,
        dropout_p,
        softmax_scale,
        causal,
        window_size,
        softcap,
        alibi_slopes,
        deterministic,
        return_attn_probs,
    )
Tri Dao's avatar
Tri Dao committed
806
807


Tri Dao's avatar
Tri Dao committed
808
def flash_attn_func(
    q,
    k,
    v,
    dropout_p=0.0,
    softmax_scale=None,
    causal=False,
    window_size=(-1, -1),  # -1 means infinite context window
    softcap=0.0,  # 0.0 means deactivated
    alibi_slopes=None,
    deterministic=False,
    return_attn_probs=False,
):
    """Attention with separate Q, K and V tensors.

    dropout_p should be set to 0.0 during evaluation.
    Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads
    than Q. Note that the number of heads in Q must be divisible by the number of heads in KV.
    For example, if Q has 6 heads and K, V have 2 heads, head 0, 1, 2 of Q will attend to head
    0 of K, V, and head 3, 4, 5 of Q will attend to head 1 of K, V.

    If causal=True, the causal mask is aligned to the bottom right corner of the attention matrix.
    For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is:
        1 1 1 1 0
        1 1 1 1 1
    If seqlen_q = 5 and seqlen_k = 2, the causal mask is:
        0 0
        0 0
        0 0
        1 0
        1 1
    If the row of the mask is all zero, the output will be zero.

    If window_size != (-1, -1), implements sliding window local attention. Query at position i
    will only attend to keys between
    [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive.

    Arguments:
        q: (batch_size, seqlen, nheads, headdim)
        k: (batch_size, seqlen, nheads_k, headdim)
        v: (batch_size, seqlen, nheads_k, headdim)
        dropout_p: float. Dropout probability.
        softmax_scale: float. The scaling of QK^T before applying softmax.
            Default to 1 / sqrt(headdim).
        causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
        window_size: (left, right). If not (-1, -1), implements sliding window local attention.
        softcap: float. Anything > 0 activates softcapping attention.
        alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of
            (-alibi_slope * |i + seqlen_k - seqlen_q - j|)
            is added to the attention score of query i and key j.
        deterministic: bool. Whether to use the deterministic implementation of the backward pass,
            which is slightly slower and uses more memory. The forward pass is always deterministic.
        return_attn_probs: bool. Whether to return the attention probabilities. This option is for
           testing only. The returned probabilities are not guaranteed to be correct
           (they might not have the right scaling).
    Return:
        out: (batch_size, seqlen, nheads, headdim).
        softmax_lse [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen). The
            logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
            normalization factor).
        S_dmask [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen, seqlen).
            The output of softmax (possibly with different scaling). It also encodes the dropout
            pattern (negative means that location was dropped, nonnegative means it was kept).
    """
    # Pure positional pass-through: the argument order must stay in sync with what
    # FlashAttnFunc.apply expects (defined elsewhere in this file).
    return FlashAttnFunc.apply(
        q,
        k,
        v,
        dropout_p,
        softmax_scale,
        causal,
        window_size,
        softcap,
        alibi_slopes,
        deterministic,
        return_attn_probs,
    )
Tri Dao's avatar
Tri Dao committed
882
883


Tri Dao's avatar
Tri Dao committed
884
885
886
887
888
889
890
def flash_attn_varlen_qkvpacked_func(
    qkv,
    cu_seqlens,
    max_seqlen,
    dropout_p=0.0,
    softmax_scale=None,
    causal=False,
    window_size=(-1, -1),  # -1 means infinite context window
    softcap=0.0,  # 0.0 means deactivated
    alibi_slopes=None,
    deterministic=False,
    return_attn_probs=False,
):
    """Variable-length attention with Q, K and V packed into a single tensor.

    dropout_p should be set to 0.0 during evaluation.
    If Q, K, V are already stacked into 1 tensor, this function will be faster than
    calling flash_attn_varlen_func on Q, K, V since the backward pass avoids explicit concatenation
    of the gradients of Q, K, V.
    For multi-query and grouped-query attention (MQA/GQA), please see
    flash_attn_varlen_kvpacked_func and flash_attn_varlen_func.

    If window_size != (-1, -1), implements sliding window local attention. Query at position i
    will only attend to keys between [i - window_size[0], i + window_size[1]] inclusive.

    Arguments:
        qkv: (total, 3, nheads, headdim), where total = total number of tokens in the batch.
        cu_seqlens: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
           of the sequences in the batch, used to index into qkv.
        max_seqlen: int. Maximum sequence length in the batch.
        dropout_p: float. Dropout probability.
        softmax_scale: float. The scaling of QK^T before applying softmax.
            Default to 1 / sqrt(headdim).
        causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
        window_size: (left, right). If not (-1, -1), implements sliding window local attention.
        softcap: float. Anything > 0 activates softcapping attention.
        alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of (-alibi_slope * |i - j|)
            is added to the attention score of query i and key j.
        deterministic: bool. Whether to use the deterministic implementation of the backward pass,
            which is slightly slower and uses more memory. The forward pass is always deterministic.
        return_attn_probs: bool. Whether to return the attention probabilities. This option is for
           testing only. The returned probabilities are not guaranteed to be correct
           (they might not have the right scaling).
    Return:
        out: (total, nheads, headdim).
        softmax_lse [optional, if return_attn_probs=True]: (nheads, total_q_seqlen). The
            logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
            normalization factor).
        S_dmask [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen, seqlen).
            The output of softmax (possibly with different scaling). It also encodes the dropout
            pattern (negative means that location was dropped, nonnegative means it was kept).
    """
    # Pure positional pass-through: the argument order must stay in sync with what
    # FlashAttnVarlenQKVPackedFunc.apply expects (defined elsewhere in this file).
    return FlashAttnVarlenQKVPackedFunc.apply(
        qkv,
        cu_seqlens,
        max_seqlen,
        dropout_p,
        softmax_scale,
        causal,
        window_size,
        softcap,
        alibi_slopes,
        deterministic,
        return_attn_probs,
    )
Tri Dao's avatar
Tri Dao committed
947
948


Tri Dao's avatar
Tri Dao committed
949
950
951
952
953
954
955
956
957
958
def flash_attn_varlen_kvpacked_func(
    q,
    kv,
    cu_seqlens_q,
    cu_seqlens_k,
    max_seqlen_q,
    max_seqlen_k,
    dropout_p=0.0,
    softmax_scale=None,
    causal=False,
    window_size=(-1, -1),  # -1 means infinite context window
    softcap=0.0,  # 0.0 means deactivated
    alibi_slopes=None,
    deterministic=False,
    return_attn_probs=False,
):
    """Variable-length attention with K and V packed into a single tensor.

    dropout_p should be set to 0.0 during evaluation.
    If K, V are already stacked into 1 tensor, this function will be faster than
    calling flash_attn_varlen_func on Q, K, V since the backward pass avoids explicit
    concatenation of the gradients of K, V.
    Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads
    than Q. Note that the number of heads in Q must be divisible by the number of heads in KV.
    For example, if Q has 6 heads and K, V have 2 heads, head 0, 1, 2 of Q will attend to head
    0 of K, V, and head 3, 4, 5 of Q will attend to head 1 of K, V.

    If causal=True, the causal mask is aligned to the bottom right corner of the attention matrix.
    For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is:
        1 1 1 1 0
        1 1 1 1 1
    If seqlen_q = 5 and seqlen_k = 2, the causal mask is:
        0 0
        0 0
        0 0
        1 0
        1 1
    If the row of the mask is all zero, the output will be zero.

    If window_size != (-1, -1), implements sliding window local attention. Query at position i
    will only attend to keys between
    [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive.

    Arguments:
        q: (total_q, nheads, headdim), where total_q = total number of query tokens in the batch.
        kv: (total_k, 2, nheads_k, headdim), where total_k = total number of key tokens in the batch.
        cu_seqlens_q: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
           of the sequences in the batch, used to index into q.
        cu_seqlens_k: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
           of the sequences in the batch, used to index into kv.
        max_seqlen_q: int. Maximum query sequence length in the batch.
        max_seqlen_k: int. Maximum key sequence length in the batch.
        dropout_p: float. Dropout probability.
        softmax_scale: float. The scaling of QK^T before applying softmax.
            Default to 1 / sqrt(headdim).
        causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
        window_size: (left, right). If not (-1, -1), implements sliding window local attention.
        softcap: float. Anything > 0 activates softcapping attention.
        alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of
            (-alibi_slope * |i + seqlen_k - seqlen_q - j|)
            is added to the attention score of query i and key j.
        deterministic: bool. Whether to use the deterministic implementation of the backward pass,
            which is slightly slower and uses more memory. The forward pass is always deterministic.
        return_attn_probs: bool. Whether to return the attention probabilities. This option is for
           testing only. The returned probabilities are not guaranteed to be correct
           (they might not have the right scaling).
    Return:
        out: (total_q, nheads, headdim).
        softmax_lse [optional, if return_attn_probs=True]: (nheads, total_q_seqlen). The
            logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
            normalization factor).
        S_dmask [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen, seqlen).
            The output of softmax (possibly with different scaling). It also encodes the dropout
            pattern (negative means that location was dropped, nonnegative means it was kept).
    """
    # Pure positional pass-through: the argument order must stay in sync with what
    # FlashAttnVarlenKVPackedFunc.apply expects (defined elsewhere in this file).
    return FlashAttnVarlenKVPackedFunc.apply(
        q,
        kv,
        cu_seqlens_q,
        cu_seqlens_k,
        max_seqlen_q,
        max_seqlen_k,
        dropout_p,
        softmax_scale,
        causal,
        window_size,
        softcap,
        alibi_slopes,
        deterministic,
        return_attn_probs,
    )
Tri Dao's avatar
Tri Dao committed
1038

1039

Tri Dao's avatar
Tri Dao committed
1040
1041
1042
1043
1044
1045
1046
1047
1048
1049
1050
def flash_attn_varlen_func(
    q,
    k,
    v,
    cu_seqlens_q,
    cu_seqlens_k,
    max_seqlen_q,
    max_seqlen_k,
    dropout_p=0.0,
    softmax_scale=None,
    causal=False,
    window_size=(-1, -1),  # -1 means infinite context window
    softcap=0.0,  # 0.0 means deactivated
    alibi_slopes=None,
    deterministic=False,
    return_attn_probs=False,
    block_table=None,
):
    """Variable-length attention with separate Q, K and V tensors.

    dropout_p should be set to 0.0 during evaluation.
    Supports multi-query and grouped-query attention (MQA/GQA) by passing in K, V with fewer heads
    than Q. Note that the number of heads in Q must be divisible by the number of heads in KV.
    For example, if Q has 6 heads and K, V have 2 heads, head 0, 1, 2 of Q will attend to head
    0 of K, V, and head 3, 4, 5 of Q will attend to head 1 of K, V.

    If causal=True, the causal mask is aligned to the bottom right corner of the attention matrix.
    For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is:
        1 1 1 1 0
        1 1 1 1 1
    If seqlen_q = 5 and seqlen_k = 2, the causal mask is:
        0 0
        0 0
        0 0
        1 0
        1 1
    If the row of the mask is all zero, the output will be zero.

    If window_size != (-1, -1), implements sliding window local attention. Query at position i
    will only attend to keys between
    [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive.

    Arguments:
        q: (total_q, nheads, headdim), where total_q = total number of query tokens in the batch.
        k: (total_k, nheads_k, headdim), where total_k = total number of key tokens in the batch.
        v: (total_k, nheads_k, headdim), where total_k = total number of key tokens in the batch.
        cu_seqlens_q: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
           of the sequences in the batch, used to index into q.
        cu_seqlens_k: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
           of the sequences in the batch, used to index into k and v.
        max_seqlen_q: int. Maximum query sequence length in the batch.
        max_seqlen_k: int. Maximum key sequence length in the batch.
        dropout_p: float. Dropout probability.
        softmax_scale: float. The scaling of QK^T before applying softmax.
            Default to 1 / sqrt(headdim).
        causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
        window_size: (left, right). If not (-1, -1), implements sliding window local attention.
        softcap: float. Anything > 0 activates softcapping attention.
        alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of
            (-alibi_slope * |i + seqlen_k - seqlen_q - j|)
            is added to the attention score of query i and key j.
        deterministic: bool. Whether to use the deterministic implementation of the backward pass,
            which is slightly slower and uses more memory. The forward pass is always deterministic.
        return_attn_probs: bool. Whether to return the attention probabilities. This option is for
           testing only. The returned probabilities are not guaranteed to be correct
           (they might not have the right scaling).
        block_table: optional, default None. Forwarded unchanged to FlashAttnVarlenFunc.apply.
            NOTE(review): presumably the paged-KV-cache block table, as in
            flash_attn_with_kvcache; its shape/dtype is not visible here -- TODO confirm.
    Return:
        out: (total_q, nheads, headdim).
        softmax_lse [optional, if return_attn_probs=True]: (nheads, total_q_seqlen). The
            logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
            normalization factor).
        S_dmask [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen, seqlen).
            The output of softmax (possibly with different scaling). It also encodes the dropout
            pattern (negative means that location was dropped, nonnegative means it was kept).
    """
    # Pure positional pass-through: the argument order must stay in sync with what
    # FlashAttnVarlenFunc.apply expects (defined elsewhere in this file).
    return FlashAttnVarlenFunc.apply(
        q,
        k,
        v,
        cu_seqlens_q,
        cu_seqlens_k,
        max_seqlen_q,
        max_seqlen_k,
        dropout_p,
        softmax_scale,
        causal,
        window_size,
        softcap,
        alibi_slopes,
        deterministic,
        return_attn_probs,
        block_table,
    )
Tri Dao's avatar
Tri Dao committed
1131
1132
1133
1134
1135
1136
1137
1138


def flash_attn_with_kvcache(
    q,
    k_cache,
    v_cache,
    k=None,
    v=None,
    rotary_cos=None,
    rotary_sin=None,
    cache_seqlens: Optional[Union[int, torch.Tensor]] = None,
    cache_batch_idx: Optional[torch.Tensor] = None,
    block_table: Optional[torch.Tensor] = None,
    softmax_scale=None,
    causal=False,
    window_size=(-1, -1),  # -1 means infinite context window
    softcap=0.0,  # 0.0 means deactivated
    rotary_interleaved=True,
    alibi_slopes=None,
    num_splits=0,
    return_softmax_lse=False,
):
    """
    Attention over a KV cache, optionally appending new keys/values to the cache first.

    If k and v are not None, k_cache and v_cache will be updated *inplace* with the new values from
    k and v. This is useful for incremental decoding: you can pass in the cached keys/values from
    the previous step, and update them with the new keys/values from the current step, and do
    attention with the updated cache, all in 1 kernel.

    If you pass in k / v, you must make sure that the cache is large enough to hold the new values.
    For example, the KV cache could be pre-allocated with the max sequence length, and you can use
    cache_seqlens to keep track of the current sequence lengths of each sequence in the batch.

    Also apply rotary embedding if rotary_cos and rotary_sin are passed in. The key @k will be
    rotated by rotary_cos and rotary_sin at indices cache_seqlens, cache_seqlens + 1, etc.
    If causal or local (i.e., window_size != (-1, -1)), the query @q will be rotated by rotary_cos
    and rotary_sin at indices cache_seqlens, cache_seqlens + 1, etc.
    If not causal and not local, the query @q will be rotated by rotary_cos and rotary_sin at
    indices cache_seqlens only (i.e. we consider all tokens in @q to be at position cache_seqlens).

    See tests/test_flash_attn.py::test_flash_attn_kvcache for examples of how to use this function.

    Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads
    than Q. Note that the number of heads in Q must be divisible by the number of heads in KV.
    For example, if Q has 6 heads and K, V have 2 heads, head 0, 1, 2 of Q will attend to head
    0 of K, V, and head 3, 4, 5 of Q will attend to head 1 of K, V.

    If causal=True, the causal mask is aligned to the bottom right corner of the attention matrix.
    For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is:
        1 1 1 1 0
        1 1 1 1 1
    If seqlen_q = 5 and seqlen_k = 2, the causal mask is:
        0 0
        0 0
        0 0
        1 0
        1 1
    If the row of the mask is all zero, the output will be zero.

    If window_size != (-1, -1), implements sliding window local attention. Query at position i
    will only attend to keys between
    [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive.

    Note: Does not support backward pass.

    Arguments:
        q: (batch_size, seqlen, nheads, headdim)
        k_cache: (batch_size_cache, seqlen_cache, nheads_k, headdim) if there's no block_table,
            or (num_blocks, page_block_size, nheads_k, headdim) if there's a block_table (i.e. paged KV cache)
            page_block_size must be a multiple of 256.
        v_cache: (batch_size_cache, seqlen_cache, nheads_k, headdim) if there's no block_table,
            or (num_blocks, page_block_size, nheads_k, headdim) if there's a block_table (i.e. paged KV cache)
        k [optional]: (batch_size, seqlen_new, nheads_k, headdim). If not None, we concatenate
            k with k_cache, starting at the indices specified by cache_seqlens.
        v [optional]: (batch_size, seqlen_new, nheads_k, headdim). Similar to k.
        rotary_cos [optional]: (seqlen_ro, rotary_dim / 2). If not None, we apply rotary embedding
            to k and q. Only applicable if k and v are passed in. rotary_dim must be divisible by 16.
        rotary_sin [optional]: (seqlen_ro, rotary_dim / 2). Similar to rotary_cos.
        cache_seqlens: int, or (batch_size,), dtype torch.int32. The sequence lengths of the
            KV cache. An int is broadcast to every sequence in the batch of q.
        block_table [optional]: (batch_size, max_num_blocks_per_seq), dtype torch.int32.
            When given, k_cache / v_cache use the paged layout described above.
        cache_batch_idx: (batch_size,), dtype torch.int32. The indices used to index into the KV cache.
            If None, we assume that the batch indices are [0, 1, 2, ..., batch_size - 1].
            If the indices are not distinct, and k and v are provided, the values updated in the cache
                 might come from any of the duplicate indices.
        softmax_scale: float. The scaling of QK^T before applying softmax.
            Default to 1 / sqrt(headdim).
        causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
        window_size: (left, right). If not (-1, -1), implements sliding window local attention.
        softcap: float. Anything > 0 activates softcapping attention.
        rotary_interleaved: bool. Only applicable if rotary_cos and rotary_sin are passed in.
            If True, rotary embedding will combine dimensions 0 & 1, 2 & 3, etc. If False,
            rotary embedding will combine dimensions 0 & rotary_dim / 2, 1 & rotary_dim / 2 + 1
            (i.e. GPT-NeoX style).
        alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of
            (-alibi_slope * |i + seqlen_k - seqlen_q - j|)
            is added to the attention score of query i and key j.
        num_splits: int. If > 1, split the key/value into this many chunks along the sequence.
           If num_splits == 1, we don't split the key/value. If num_splits == 0, we use a heuristic
           to automatically determine the number of splits.
           Don't change this unless you know what you are doing.
        return_softmax_lse: bool. Whether to return the logsumexp of the attention scores.

    Return:
        out: (batch_size, seqlen, nheads, headdim).
        softmax_lse [optional, if return_softmax_lse=True]: (batch_size, nheads, seqlen). The
            logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
            normalization factor).

    Raises:
        AssertionError: if k_cache or v_cache does not have a contiguous last dimension.
    """
    # The caches are updated in place, so they cannot be silently copied the way q/k/v are below.
    assert k_cache.stride(-1) == 1, "k_cache must have contiguous last dimension"
    assert v_cache.stride(-1) == 1, "v_cache must have contiguous last dimension"

    def maybe_contiguous(x):
        # Copy only when the innermost stride is not 1; None passes straight through.
        return x.contiguous() if x is not None and x.stride(-1) != 1 else x

    q, k, v = [maybe_contiguous(x) for x in (q, k, v)]
    if softmax_scale is None:
        softmax_scale = q.shape[-1] ** (-0.5)
    if cache_seqlens is not None and isinstance(cache_seqlens, int):
        # cache_seqlens is documented as (batch_size,), i.e. one entry per sequence in q.
        # k_cache.shape[0] is batch_size_cache (or num_blocks for a paged cache) and may
        # differ from batch_size, so the broadcast length is taken from q.
        cache_seqlens = torch.full(
            (q.shape[0],), cache_seqlens, dtype=torch.int32, device=k_cache.device
        )
    # Normalise caller-supplied tensors as well; a tensor freshly built by torch.full is
    # already contiguous, so this is a no-op for the int case.
    cache_seqlens = maybe_contiguous(cache_seqlens)
    cache_batch_idx = maybe_contiguous(cache_batch_idx)
    block_table = maybe_contiguous(block_table)
    out, softmax_lse = flash_attn_cuda.fwd_kvcache(
        q,
        k_cache,
        v_cache,
        k,
        v,
        cache_seqlens,
        rotary_cos,
        rotary_sin,
        cache_batch_idx,
        block_table,
        alibi_slopes,
        None,  # NOTE(review): presumably the optional preallocated-output slot; confirm against the extension
        softmax_scale,
        causal,
        window_size[0],
        window_size[1],
        softcap,
        rotary_interleaved,
        num_splits,
    )
    return (out, softmax_lse) if return_softmax_lse else out