flash_attn_interface.py 44.8 KB
Newer Older
1
2
# Copyright (c) 2023, Tri Dao.

Tri Dao's avatar
Tri Dao committed
3
4
from typing import Optional, Union

Tri Dao's avatar
Tri Dao committed
5
6
import torch
import torch.nn as nn
Tri Dao's avatar
Tri Dao committed
7

8
9
# isort: off
# We need to import the CUDA kernels after importing torch
Woosuk Kwon's avatar
Woosuk Kwon committed
10
import vllm_flash_attn_2_cuda as flash_attn_cuda
Tri Dao's avatar
Tri Dao committed
11

12
13
# isort: on

Tri Dao's avatar
Tri Dao committed
14

Tri Dao's avatar
Tri Dao committed
15
def _get_block_size_n(device, head_dim, is_dropout, is_causal):
Tri Dao's avatar
Tri Dao committed
16
17
18
19
20
21
22
    # This should match the block sizes in the CUDA kernel
    assert head_dim <= 256
    major, minor = torch.cuda.get_device_capability(device)
    is_sm8x = major == 8 and minor > 0  # Only include sm86 and sm89, exclude sm80 (A100)
    is_sm80 = major == 8 and minor == 0
    is_sm90 = major == 9 and minor == 0
    if head_dim <= 32:
Tri Dao's avatar
Tri Dao committed
23
        return 128
Tri Dao's avatar
Tri Dao committed
24
    if head_dim <= 64:
Tri Dao's avatar
Tri Dao committed
25
        return 128 if not is_dropout else 64
Tri Dao's avatar
Tri Dao committed
26
    elif head_dim <= 96:
Tri Dao's avatar
Tri Dao committed
27
        return 64
Tri Dao's avatar
Tri Dao committed
28
29
    elif head_dim <= 128:
        if is_sm8x:
Tri Dao's avatar
Tri Dao committed
30
            return 64 if (not is_dropout and is_causal) else 32
Tri Dao's avatar
Tri Dao committed
31
        else:
Tri Dao's avatar
Tri Dao committed
32
            return 64 if not is_dropout else 32
Tri Dao's avatar
Tri Dao committed
33
34
    elif head_dim <= 160:
        if is_sm8x:
Tri Dao's avatar
Tri Dao committed
35
            return 64
Tri Dao's avatar
Tri Dao committed
36
        else:
Tri Dao's avatar
Tri Dao committed
37
            return 32
Tri Dao's avatar
Tri Dao committed
38
    elif head_dim <= 192:
Tri Dao's avatar
Tri Dao committed
39
        return 64
Tri Dao's avatar
Tri Dao committed
40
    elif head_dim <= 224:
Tri Dao's avatar
Tri Dao committed
41
        return 64
Tri Dao's avatar
Tri Dao committed
42
    elif head_dim <= 256:
Tri Dao's avatar
Tri Dao committed
43
        return 64
Tri Dao's avatar
Tri Dao committed
44
45


Tri Dao's avatar
Tri Dao committed
46
def _flash_attn_forward(
    q, k, v, dropout_p, softmax_scale, causal, window_size, alibi_slopes, return_softmax, *, out=None
):
    """Call the dense forward kernel ``flash_attn_cuda.fwd``.

    q, k and v are swapped for contiguous copies when their last-dimension
    stride is not 1; every other argument is handed to the extension as given.

    Returns:
        The 8-tuple ``(out, q, k, v, out_padded, softmax_lse, S_dmask,
        rng_state)`` unpacked from the extension call.
    """
    q, k, v = (t if t.stride(-1) == 1 else t.contiguous() for t in (q, k, v))
    window_left = window_size[0]
    window_right = window_size[1]
    attn_out, q, k, v, padded_out, lse, s_dmask, rng = flash_attn_cuda.fwd(
        q,
        k,
        v,
        out,
        alibi_slopes,
        dropout_p,
        softmax_scale,
        causal,
        window_left,
        window_right,
        return_softmax,
        # NOTE(review): final slot is always None here; presumably an optional
        # RNG generator -- confirm against the extension binding.
        None,
    )
    return attn_out, q, k, v, padded_out, lse, s_dmask, rng
Tri Dao's avatar
Tri Dao committed
66
67


Tri Dao's avatar
Tri Dao committed
68
69
70
71
72
73
74
75
76
77
78
def _flash_attn_varlen_forward(
    q,
    k,
    v,
    cu_seqlens_q,
    cu_seqlens_k,
    max_seqlen_q,
    max_seqlen_k,
    dropout_p,
    softmax_scale,
    causal,
    window_size,
    alibi_slopes,
    return_softmax,
    block_table,
    *,
    out=None
):
    """Call the variable-length forward kernel ``flash_attn_cuda.varlen_fwd``.

    q, k and v are swapped for contiguous copies when their last-dimension
    stride is not 1; every other argument is handed to the extension as given.

    Returns:
        The 8-tuple ``(out, q, k, v, out_padded, softmax_lse, S_dmask,
        rng_state)`` unpacked from the extension call.
    """
    q, k, v = (t if t.stride(-1) == 1 else t.contiguous() for t in (q, k, v))
    window_left = window_size[0]
    window_right = window_size[1]
    attn_out, q, k, v, padded_out, lse, s_dmask, rng = flash_attn_cuda.varlen_fwd(
        q,
        k,
        v,
        out,
        cu_seqlens_q,
        cu_seqlens_k,
        # NOTE(review): always None; presumably an optional per-sequence
        # "used length" tensor -- confirm against the extension binding.
        None,
        block_table,
        alibi_slopes,
        max_seqlen_q,
        max_seqlen_k,
        dropout_p,
        softmax_scale,
        # NOTE(review): hard-coded False; presumably a "zero tensors" flag --
        # confirm against the extension binding.
        False,
        causal,
        window_left,
        window_right,
        return_softmax,
        # NOTE(review): always None; presumably an optional RNG generator.
        None,
    )
    return attn_out, q, k, v, padded_out, lse, s_dmask, rng
Tri Dao's avatar
Tri Dao committed
112
113


Tri Dao's avatar
Tri Dao committed
114
def _flash_attn_backward(
    dout,
    q,
    k,
    v,
    out,
    softmax_lse,
    dq,
    dk,
    dv,
    dropout_p,
    softmax_scale,
    causal,
    window_size,
    alibi_slopes,
    deterministic,
    rng_state=None,
):
    """Call the dense backward kernel ``flash_attn_cuda.bwd``.

    dout, q, k, v and out are swapped for contiguous copies when their
    last-dimension stride is not 1. dq, dk and dv are passed through untouched:
    the callers allocate them, so they are expected to be contiguous already.

    Returns:
        The 4-tuple ``(dq, dk, dv, softmax_d)`` unpacked from the extension call.
    """
    dout, q, k, v, out = (
        t if t.stride(-1) == 1 else t.contiguous() for t in (dout, q, k, v, out)
    )
    window_left = window_size[0]
    window_right = window_size[1]
    grad_q, grad_k, grad_v, softmax_d = flash_attn_cuda.bwd(
        dout,
        q,
        k,
        v,
        out,
        softmax_lse,
        dq,
        dk,
        dv,
        alibi_slopes,
        dropout_p,
        softmax_scale,
        causal,
        window_left,
        window_right,
        deterministic,
        # NOTE(review): always None; presumably an optional RNG generator --
        # confirm against the extension binding.
        None,
        rng_state,
    )
    return grad_q, grad_k, grad_v, softmax_d


Tri Dao's avatar
Tri Dao committed
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
def _flash_attn_varlen_backward(
    dout,
    q,
    k,
    v,
    out,
    softmax_lse,
    dq,
    dk,
    dv,
    cu_seqlens_q,
    cu_seqlens_k,
    max_seqlen_q,
    max_seqlen_k,
    dropout_p,
    softmax_scale,
    causal,
    window_size,
    alibi_slopes,
    deterministic,
    rng_state=None,
):
    """Call the variable-length backward kernel ``flash_attn_cuda.varlen_bwd``.

    dout, q, k, v and out are swapped for contiguous copies when their
    last-dimension stride is not 1. dq, dk and dv are passed through untouched:
    the callers allocate them, so they are expected to be contiguous already.

    Returns:
        The 4-tuple ``(dq, dk, dv, softmax_d)`` unpacked from the extension call.
    """
    dout, q, k, v, out = (
        t if t.stride(-1) == 1 else t.contiguous() for t in (dout, q, k, v, out)
    )
    window_left = window_size[0]
    window_right = window_size[1]
    grad_q, grad_k, grad_v, softmax_d = flash_attn_cuda.varlen_bwd(
        dout,
        q,
        k,
        v,
        out,
        softmax_lse,
        dq,
        dk,
        dv,
        cu_seqlens_q,
        cu_seqlens_k,
        alibi_slopes,
        max_seqlen_q,
        max_seqlen_k,
        dropout_p,
        softmax_scale,
        # NOTE(review): hard-coded False; presumably a "zero tensors" flag --
        # confirm against the extension binding.
        False,
        causal,
        window_left,
        window_right,
        deterministic,
        # NOTE(review): always None; presumably an optional RNG generator.
        None,
        rng_state,
    )
    return grad_q, grad_k, grad_v, softmax_d
Tri Dao's avatar
Tri Dao committed
211
212


Tri Dao's avatar
Tri Dao committed
213
class FlashAttnQKVPackedFunc(torch.autograd.Function):
    """Autograd function for attention on a packed ``qkv`` tensor.

    q, k and v are taken as ``qkv[:, :, 0]``, ``qkv[:, :, 1]`` and
    ``qkv[:, :, 2]``; the backward pass writes their gradients into one packed
    ``dqkv`` buffer with the same layout.
    """

    @staticmethod
    def forward(
        ctx,
        qkv,
        dropout_p,
        softmax_scale,
        causal,
        window_size,
        alibi_slopes,
        deterministic,
        return_softmax,
        # Not keyword-only: autograd.Function.apply() forwards every argument
        # positionally, and flash_attn_qkvpacked_func passes `out` that way.
        out=None,
    ):
        """Run the forward kernel and stash what backward needs on ``ctx``.

        Returns:
            ``out``, or ``(out, softmax_lse, S_dmask)`` when ``return_softmax``
            is truthy.
        """
        if softmax_scale is None:
            softmax_scale = qkv.shape[-1] ** (-0.5)
        out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = _flash_attn_forward(
            qkv[:, :, 0],
            qkv[:, :, 1],
            qkv[:, :, 2],
            dropout_p,
            softmax_scale,
            causal=causal,
            window_size=window_size,
            alibi_slopes=alibi_slopes,
            # The kernel is only asked for the softmax when dropout is active.
            return_softmax=return_softmax and dropout_p > 0,
            out=out,
        )
        ctx.save_for_backward(q, k, v, out_padded, softmax_lse, rng_state)
        ctx.dropout_p = dropout_p
        ctx.softmax_scale = softmax_scale
        ctx.causal = causal
        ctx.window_size = window_size
        ctx.alibi_slopes = alibi_slopes
        ctx.deterministic = deterministic
        return out if not return_softmax else (out, softmax_lse, S_dmask)

    @staticmethod
    def backward(ctx, dout, *args):
        """Compute the packed gradient ``dqkv``; all other inputs get ``None``."""
        q, k, v, out, softmax_lse, rng_state = ctx.saved_tensors
        qkv_shape = q.shape[:-2] + (3, *q.shape[-2:])
        dqkv = torch.empty(qkv_shape, dtype=q.dtype, device=q.device)
        # The kernel writes dq/dk/dv in place into the three slices of dqkv.
        _flash_attn_backward(
            dout,
            q,
            k,
            v,
            out,
            softmax_lse,
            dqkv[:, :, 0],
            dqkv[:, :, 1],
            dqkv[:, :, 2],
            ctx.dropout_p,
            ctx.softmax_scale,
            ctx.causal,
            ctx.window_size,
            ctx.alibi_slopes,
            ctx.deterministic,
            rng_state=rng_state,
        )
        dqkv = dqkv[..., : dout.shape[-1]]  # We could have padded the head dimension
        # One gradient per forward input after ctx: qkv plus 8 non-differentiable
        # arguments (dropout_p ... return_softmax, out). Autograd rejects a
        # shorter tuple; surplus trailing Nones are tolerated.
        return (dqkv,) + (None,) * 8
Tri Dao's avatar
Tri Dao committed
276
277
278
279


class FlashAttnVarlenQKVPackedFunc(torch.autograd.Function):
    """Autograd function for variable-length attention on a packed ``qkv`` tensor.

    q, k and v are taken as ``qkv[:, 0]``, ``qkv[:, 1]`` and ``qkv[:, 2]``, and
    the same ``cu_seqlens`` / ``max_seqlen`` are used for queries and keys. The
    backward pass writes the gradients into one packed ``dqkv`` buffer.
    """

    @staticmethod
    def forward(
        ctx,
        qkv,
        cu_seqlens,
        max_seqlen,
        dropout_p,
        softmax_scale,
        causal,
        window_size,
        alibi_slopes,
        deterministic,
        return_softmax,
        # Not keyword-only: autograd.Function.apply() forwards every argument
        # positionally, matching the other Function classes in this file.
        out=None,
    ):
        """Run the varlen forward kernel and stash what backward needs on ``ctx``.

        Returns:
            ``out``, or ``(out, softmax_lse, S_dmask)`` when ``return_softmax``
            is truthy.
        """
        if softmax_scale is None:
            softmax_scale = qkv.shape[-1] ** (-0.5)
        out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = _flash_attn_varlen_forward(
            qkv[:, 0],
            qkv[:, 1],
            qkv[:, 2],
            cu_seqlens,
            cu_seqlens,
            max_seqlen,
            max_seqlen,
            dropout_p,
            softmax_scale,
            causal=causal,
            window_size=window_size,
            alibi_slopes=alibi_slopes,
            # The kernel is only asked for the softmax when dropout is active.
            return_softmax=return_softmax and dropout_p > 0,
            block_table=None,
            out=out,
        )
        ctx.save_for_backward(q, k, v, out_padded, softmax_lse, cu_seqlens, rng_state)
        ctx.dropout_p = dropout_p
        ctx.max_seqlen = max_seqlen
        ctx.softmax_scale = softmax_scale
        ctx.causal = causal
        ctx.window_size = window_size
        ctx.alibi_slopes = alibi_slopes
        ctx.deterministic = deterministic
        return out if not return_softmax else (out, softmax_lse, S_dmask)

    @staticmethod
    def backward(ctx, dout, *args):
        """Compute the packed gradient ``dqkv``; all other inputs get ``None``."""
        q, k, v, out, softmax_lse, cu_seqlens, rng_state = ctx.saved_tensors
        qkv_shape = q.shape[:-2] + (3, *q.shape[-2:])
        dqkv = torch.empty(qkv_shape, dtype=q.dtype, device=q.device)
        # The kernel writes dq/dk/dv in place into the three slices of dqkv.
        _flash_attn_varlen_backward(
            dout,
            q,
            k,
            v,
            out,
            softmax_lse,
            dqkv[:, 0],
            dqkv[:, 1],
            dqkv[:, 2],
            cu_seqlens,
            cu_seqlens,
            ctx.max_seqlen,
            ctx.max_seqlen,
            ctx.dropout_p,
            ctx.softmax_scale,
            ctx.causal,
            ctx.window_size,
            ctx.alibi_slopes,
            ctx.deterministic,
            rng_state=rng_state,
        )
        dqkv = dqkv[..., : dout.shape[-1]]  # We could have padded the head dimension
        # One gradient per forward input after ctx: qkv plus 10 non-differentiable
        # arguments (cu_seqlens ... return_softmax, out). Autograd rejects a
        # shorter tuple; surplus trailing Nones are tolerated.
        return (dqkv,) + (None,) * 10
Tri Dao's avatar
Tri Dao committed
353
354


Tri Dao's avatar
Tri Dao committed
355
class FlashAttnKVPackedFunc(torch.autograd.Function):
    """Autograd function for attention with a separate ``q`` and a packed ``kv``.

    k and v are taken as ``kv[:, :, 0]`` and ``kv[:, :, 1]``; the backward pass
    writes their gradients into one packed ``dkv`` buffer with the same layout.
    """

    @staticmethod
    def forward(
        ctx,
        q,
        kv,
        dropout_p,
        softmax_scale,
        causal,
        window_size,
        alibi_slopes,
        deterministic,
        return_softmax,
        out=None,
    ):
        """Run the forward kernel and stash what backward needs on ``ctx``.

        Returns:
            ``out``, or ``(out, softmax_lse, S_dmask)`` when ``return_softmax``
            is truthy.
        """
        if softmax_scale is None:
            softmax_scale = q.shape[-1] ** (-0.5)
        out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = _flash_attn_forward(
            q,
            kv[:, :, 0],
            kv[:, :, 1],
            dropout_p,
            softmax_scale,
            causal=causal,
            window_size=window_size,
            alibi_slopes=alibi_slopes,
            # The kernel is only asked for the softmax when dropout is active.
            return_softmax=return_softmax and dropout_p > 0,
            out=out,
        )
        ctx.save_for_backward(q, k, v, out_padded, softmax_lse, rng_state)
        ctx.dropout_p = dropout_p
        ctx.softmax_scale = softmax_scale
        ctx.causal = causal
        ctx.window_size = window_size
        ctx.alibi_slopes = alibi_slopes
        ctx.deterministic = deterministic
        return out if not return_softmax else (out, softmax_lse, S_dmask)

    @staticmethod
    def backward(ctx, dout, *args):
        """Compute ``dq`` and the packed ``dkv``; all other inputs get ``None``."""
        q, k, v, out, softmax_lse, rng_state = ctx.saved_tensors
        dq = torch.empty_like(q)
        kv_shape = k.shape[:-2] + (2, *k.shape[-2:])
        dkv = torch.empty(kv_shape, dtype=k.dtype, device=k.device)
        # The kernel writes dk/dv in place into the two slices of dkv.
        _flash_attn_backward(
            dout,
            q,
            k,
            v,
            out,
            softmax_lse,
            dq,
            dkv[:, :, 0],
            dkv[:, :, 1],
            ctx.dropout_p,
            ctx.softmax_scale,
            ctx.causal,
            ctx.window_size,
            ctx.alibi_slopes,
            ctx.deterministic,
            rng_state=rng_state,
        )
        dq = dq[..., : dout.shape[-1]]  # We could have padded the head dimension
        dkv = dkv[..., : dout.shape[-1]]
        # One gradient per forward input after ctx: q, kv plus 8 non-differentiable
        # arguments (dropout_p ... return_softmax, out). Autograd rejects a
        # shorter tuple; surplus trailing Nones are tolerated.
        return (dq, dkv) + (None,) * 8
Tri Dao's avatar
Tri Dao committed
420
421


Tri Dao's avatar
Tri Dao committed
422
class FlashAttnVarlenKVPackedFunc(torch.autograd.Function):
    """Autograd function for variable-length attention with a packed ``kv``.

    k and v are taken as ``kv[:, 0]`` and ``kv[:, 1]``; the backward pass
    writes their gradients into one packed ``dkv`` buffer with the same layout.
    """

    @staticmethod
    def forward(
        ctx,
        q,
        kv,
        cu_seqlens_q,
        cu_seqlens_k,
        max_seqlen_q,
        max_seqlen_k,
        dropout_p,
        softmax_scale,
        causal,
        window_size,
        alibi_slopes,
        deterministic,
        return_softmax,
        out=None,
    ):
        """Run the varlen forward kernel and stash what backward needs on ``ctx``.

        Returns:
            ``out``, or ``(out, softmax_lse, S_dmask)`` when ``return_softmax``
            is truthy.
        """
        if softmax_scale is None:
            softmax_scale = q.shape[-1] ** (-0.5)
        out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = _flash_attn_varlen_forward(
            q,
            kv[:, 0],
            kv[:, 1],
            cu_seqlens_q,
            cu_seqlens_k,
            max_seqlen_q,
            max_seqlen_k,
            dropout_p,
            softmax_scale,
            causal=causal,
            window_size=window_size,
            alibi_slopes=alibi_slopes,
            # The kernel is only asked for the softmax when dropout is active.
            return_softmax=return_softmax and dropout_p > 0,
            block_table=None,
            out=out,
        )
        ctx.save_for_backward(
            q, k, v, out_padded, softmax_lse, cu_seqlens_q, cu_seqlens_k, rng_state
        )
        ctx.dropout_p = dropout_p
        ctx.max_seqlen_q = max_seqlen_q
        ctx.max_seqlen_k = max_seqlen_k
        ctx.softmax_scale = softmax_scale
        ctx.causal = causal
        ctx.window_size = window_size
        ctx.alibi_slopes = alibi_slopes
        ctx.deterministic = deterministic
        return out if not return_softmax else (out, softmax_lse, S_dmask)

    @staticmethod
    def backward(ctx, dout, *args):
        """Compute ``dq`` and the packed ``dkv``; all other inputs get ``None``."""
        q, k, v, out, softmax_lse, cu_seqlens_q, cu_seqlens_k, rng_state = ctx.saved_tensors
        dq = torch.empty_like(q)
        kv_shape = k.shape[:-2] + (2, *k.shape[-2:])
        dkv = torch.empty(kv_shape, dtype=k.dtype, device=k.device)
        # The kernel writes dk/dv in place into the two slices of dkv.
        _flash_attn_varlen_backward(
            dout,
            q,
            k,
            v,
            out,
            softmax_lse,
            dq,
            dkv[:, 0],
            dkv[:, 1],
            cu_seqlens_q,
            cu_seqlens_k,
            ctx.max_seqlen_q,
            ctx.max_seqlen_k,
            ctx.dropout_p,
            ctx.softmax_scale,
            ctx.causal,
            ctx.window_size,
            ctx.alibi_slopes,
            ctx.deterministic,
            rng_state=rng_state,
        )
        dq = dq[..., : dout.shape[-1]]  # We could have padded the head dimension
        dkv = dkv[..., : dout.shape[-1]]
        # One gradient per forward input after ctx: q, kv plus 12 non-differentiable
        # arguments (cu_seqlens_q ... return_softmax, out). Autograd rejects a
        # shorter tuple; surplus trailing Nones are tolerated.
        return (dq, dkv) + (None,) * 12
Tri Dao's avatar
Tri Dao committed
504
505
506
507


class FlashAttnFunc(torch.autograd.Function):
    """Autograd function for attention with separate ``q``, ``k`` and ``v`` tensors."""

    @staticmethod
    def forward(
        ctx,
        q,
        k,
        v,
        dropout_p,
        softmax_scale,
        causal,
        window_size,
        alibi_slopes,
        deterministic,
        return_softmax,
        out=None,
    ):
        """Run the forward kernel and stash what backward needs on ``ctx``.

        Returns:
            ``out``, or ``(out, softmax_lse, S_dmask)`` when ``return_softmax``
            is truthy.
        """
        if softmax_scale is None:
            softmax_scale = q.shape[-1] ** (-0.5)
        out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = _flash_attn_forward(
            q,
            k,
            v,
            dropout_p,
            softmax_scale,
            causal=causal,
            window_size=window_size,
            alibi_slopes=alibi_slopes,
            # The kernel is only asked for the softmax when dropout is active.
            return_softmax=return_softmax and dropout_p > 0,
            out=out,
        )
        ctx.save_for_backward(q, k, v, out_padded, softmax_lse, rng_state)
        ctx.dropout_p = dropout_p
        ctx.softmax_scale = softmax_scale
        ctx.causal = causal
        ctx.window_size = window_size
        ctx.alibi_slopes = alibi_slopes
        ctx.deterministic = deterministic
        return out if not return_softmax else (out, softmax_lse, S_dmask)

    @staticmethod
    def backward(ctx, dout, *args):
        """Compute ``dq``, ``dk`` and ``dv``; all other inputs get ``None``."""
        q, k, v, out, softmax_lse, rng_state = ctx.saved_tensors
        dq, dk, dv = torch.empty_like(q), torch.empty_like(k), torch.empty_like(v)
        _flash_attn_backward(
            dout,
            q,
            k,
            v,
            out,
            softmax_lse,
            dq,
            dk,
            dv,
            ctx.dropout_p,
            ctx.softmax_scale,
            ctx.causal,
            ctx.window_size,
            ctx.alibi_slopes,
            ctx.deterministic,
            rng_state=rng_state,
        )
        dq = dq[..., : dout.shape[-1]]  # We could have padded the head dimension
        dk = dk[..., : dout.shape[-1]]
        dv = dv[..., : dout.shape[-1]]
        # One gradient per forward input after ctx: q, k, v plus 8
        # non-differentiable arguments (dropout_p ... return_softmax, out).
        # Autograd rejects a shorter tuple; surplus trailing Nones are tolerated.
        return (dq, dk, dv) + (None,) * 8
Tri Dao's avatar
Tri Dao committed
571
572


Tri Dao's avatar
Tri Dao committed
573
class FlashAttnVarlenFunc(torch.autograd.Function):
    """Autograd function for variable-length attention with separate q, k and v.

    Unlike the packed varlen variants in this file, ``block_table`` is
    forwarded to the forward kernel instead of being fixed to ``None``.
    """

    @staticmethod
    def forward(
        ctx,
        q,
        k,
        v,
        cu_seqlens_q,
        cu_seqlens_k,
        max_seqlen_q,
        max_seqlen_k,
        dropout_p,
        softmax_scale,
        causal,
        window_size,
        alibi_slopes,
        deterministic,
        return_softmax,
        block_table,
        out=None,
    ):
        """Run the varlen forward kernel and stash what backward needs on ``ctx``.

        Returns:
            ``out``, or ``(out, softmax_lse, S_dmask)`` when ``return_softmax``
            is truthy.
        """
        if softmax_scale is None:
            softmax_scale = q.shape[-1] ** (-0.5)
        out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = _flash_attn_varlen_forward(
            q,
            k,
            v,
            cu_seqlens_q,
            cu_seqlens_k,
            max_seqlen_q,
            max_seqlen_k,
            dropout_p,
            softmax_scale,
            causal=causal,
            window_size=window_size,
            alibi_slopes=alibi_slopes,
            # The kernel is only asked for the softmax when dropout is active.
            return_softmax=return_softmax and dropout_p > 0,
            block_table=block_table,
            out=out,
        )
        ctx.save_for_backward(
            q, k, v, out_padded, softmax_lse, cu_seqlens_q, cu_seqlens_k, rng_state
        )
        ctx.dropout_p = dropout_p
        ctx.max_seqlen_q = max_seqlen_q
        ctx.max_seqlen_k = max_seqlen_k
        ctx.softmax_scale = softmax_scale
        ctx.causal = causal
        ctx.window_size = window_size
        ctx.alibi_slopes = alibi_slopes
        ctx.deterministic = deterministic
        return out if not return_softmax else (out, softmax_lse, S_dmask)

    @staticmethod
    def backward(ctx, dout, *args):
        """Compute ``dq``, ``dk`` and ``dv``; all other inputs get ``None``."""
        q, k, v, out, softmax_lse, cu_seqlens_q, cu_seqlens_k, rng_state = ctx.saved_tensors
        dq, dk, dv = torch.empty_like(q), torch.empty_like(k), torch.empty_like(v)
        _flash_attn_varlen_backward(
            dout,
            q,
            k,
            v,
            out,
            softmax_lse,
            dq,
            dk,
            dv,
            cu_seqlens_q,
            cu_seqlens_k,
            ctx.max_seqlen_q,
            ctx.max_seqlen_k,
            ctx.dropout_p,
            ctx.softmax_scale,
            ctx.causal,
            ctx.window_size,
            ctx.alibi_slopes,
            ctx.deterministic,
            rng_state=rng_state,
        )
        dq = dq[..., : dout.shape[-1]]  # We could have padded the head dimension
        dk = dk[..., : dout.shape[-1]]
        dv = dv[..., : dout.shape[-1]]
        # One gradient per forward input after ctx: q, k, v plus 13
        # non-differentiable arguments (cu_seqlens_q ... block_table, out).
        # Autograd rejects a shorter tuple; surplus trailing Nones are tolerated.
        return (dq, dk, dv) + (None,) * 13
656
657


Tri Dao's avatar
Tri Dao committed
658
def flash_attn_qkvpacked_func(
    qkv,
    dropout_p=0.0,
    softmax_scale=None,
    causal=False,
    window_size=(-1, -1),  # -1 means infinite context window
    alibi_slopes=None,
    deterministic=False,
    return_attn_probs=False,
    *,
    out=None,
):
    """Attention on a single packed QKV tensor.

    dropout_p should be set to 0.0 during evaluation.

    When Q, K, V are already stacked into one tensor, this is faster than
    calling flash_attn_func on Q, K, V, because the backward pass avoids an
    explicit concatenation of the gradients of Q, K, V.

    For multi-query and grouped-query attention (MQA/GQA), see
    flash_attn_kvpacked_func and flash_attn_func.

    If window_size != (-1, -1), sliding window local attention is used: the
    query at position i only attends to keys in
    [i - window_size[0], i + window_size[1]] inclusive.

    Arguments:
        qkv: (batch_size, seqlen, 3, nheads, headdim)
        dropout_p: float. Dropout probability.
        softmax_scale: float. The scaling of QK^T before applying softmax.
            Defaults to 1 / sqrt(headdim).
        causal: bool. Whether to apply a causal attention mask (e.g., for
            auto-regressive modeling).
        window_size: (left, right). If not (-1, -1), implements sliding window
            local attention.
        alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of
            (-alibi_slope * |i - j|) is added to the attention score of query i
            and key j.
        deterministic: bool. Whether to use the deterministic implementation of
            the backward pass, which is slightly slower and uses more memory.
            The forward pass is always deterministic.
        return_attn_probs: bool. Whether to return the attention probabilities.
            This option is for testing only. The returned probabilities are not
            guaranteed to be correct (they might not have the right scaling).
        out: keyword-only, optional. Forwarded unchanged to
            FlashAttnQKVPackedFunc (presumably a preallocated output tensor --
            confirm against the CUDA extension).
    Return:
        out: (batch_size, seqlen, nheads, headdim).
        softmax_lse [optional, if return_attn_probs=True]: (batch_size, nheads,
            seqlen). The logsumexp of each row of the matrix QK^T * scaling
            (e.g., log of the softmax normalization factor).
        S_dmask [optional, if return_attn_probs=True]: (batch_size, nheads,
            seqlen, seqlen). The output of softmax (possibly with different
            scaling). It also encodes the dropout pattern (negative means that
            location was dropped, nonnegative means it was kept).
    """
    # apply() is called positionally, so the order mirrors the parameters of
    # FlashAttnQKVPackedFunc.forward after ctx.
    forward_args = (
        qkv,
        dropout_p,
        softmax_scale,
        causal,
        window_size,
        alibi_slopes,
        deterministic,
        return_attn_probs,
        out,
    )
    return FlashAttnQKVPackedFunc.apply(*forward_args)
Tri Dao's avatar
Tri Dao committed
714
715


Tri Dao's avatar
Tri Dao committed
716
def flash_attn_kvpacked_func(
    q,
    kv,
    dropout_p=0.0,
    softmax_scale=None,
    causal=False,
    window_size=(-1, -1),  # -1 means infinite context window
    alibi_slopes=None,
    deterministic=False,
    return_attn_probs=False,
    *,
    out=None,
):
    """Attention for a query tensor and a single packed key/value tensor.

    dropout_p should be set to 0.0 during evaluation.

    When K and V are already stacked into one tensor, prefer this function over
    flash_attn_func: the backward pass then avoids explicitly concatenating the
    gradients of K and V.

    Multi-query and grouped-query attention (MQA/GQA) are supported by giving KV fewer
    heads than Q. The number of heads in Q must be divisible by the number of heads in KV.
    For example, with 6 Q heads and 2 K/V heads, Q heads 0, 1, 2 attend to K/V head 0 and
    Q heads 3, 4, 5 attend to K/V head 1.

    With causal=True the causal mask is aligned to the bottom right corner of the
    attention matrix. For seqlen_q = 2 and seqlen_k = 5 the mask (1 = keep, 0 = masked out) is:
        1 1 1 1 0
        1 1 1 1 1
    For seqlen_q = 5 and seqlen_k = 2 the mask is:
        0 0
        0 0
        0 0
        1 0
        1 1
    A row whose mask is entirely zero produces an all-zero output.

    With window_size != (-1, -1), sliding window local attention is used: the query at
    position i only attends to keys in
    [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]],
    both ends inclusive.

    Arguments:
        q: (batch_size, seqlen, nheads, headdim)
        kv: (batch_size, seqlen, 2, nheads_k, headdim)
        dropout_p: float. Dropout probability.
        softmax_scale: float. Scaling applied to QK^T before the softmax.
            Defaults to 1 / sqrt(headdim).
        causal: bool. Whether to apply a causal attention mask (e.g., for auto-regressive
            modeling).
        window_size: (left, right). If not (-1, -1), implements sliding window local attention.
        alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of
            (-alibi_slope * |i + seqlen_k - seqlen_q - j|)
            is added to the attention score of query i and key j.
        deterministic: bool. Whether to use the deterministic implementation of the backward
            pass, which is slightly slower and uses more memory. The forward pass is always
            deterministic.
        return_attn_probs: bool. Whether to return the attention probabilities. For testing
            only; the returned probabilities are not guaranteed to be correct (they might
            not have the right scaling).
        out: keyword-only, optional. Forwarded unchanged to FlashAttnKVPackedFunc.apply;
            presumably a preallocated output tensor -- confirm against that function.
    Return:
        out: (batch_size, seqlen, nheads, headdim).
        softmax_lse [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen). The
            logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
            normalization factor).
        S_dmask [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen, seqlen).
            The output of softmax (possibly with different scaling). It also encodes the
            dropout pattern (negative means that location was dropped, nonnegative means it
            was kept).
    """
    # Everything is forwarded positionally, in exactly this order.
    forwarded = (
        q,
        kv,
        dropout_p,
        softmax_scale,
        causal,
        window_size,
        alibi_slopes,
        deterministic,
        return_attn_probs,
        out,
    )
    return FlashAttnKVPackedFunc.apply(*forwarded)
Tri Dao's avatar
Tri Dao committed
791
792


Tri Dao's avatar
Tri Dao committed
793
def flash_attn_func(
    q,
    k,
    v,
    dropout_p=0.0,
    softmax_scale=None,
    causal=False,
    window_size=(-1, -1),  # -1 means infinite context window
    alibi_slopes=None,
    deterministic=False,
    return_attn_probs=False,
    *,
    out=None,
):
    """Attention for separate query, key and value tensors.

    dropout_p should be set to 0.0 during evaluation.

    Multi-query and grouped-query attention (MQA/GQA) are supported by giving K and V fewer
    heads than Q. The number of heads in Q must be divisible by the number of heads in KV.
    For example, with 6 Q heads and 2 K/V heads, Q heads 0, 1, 2 attend to K/V head 0 and
    Q heads 3, 4, 5 attend to K/V head 1.

    With causal=True the causal mask is aligned to the bottom right corner of the
    attention matrix. For seqlen_q = 2 and seqlen_k = 5 the mask (1 = keep, 0 = masked out) is:
        1 1 1 1 0
        1 1 1 1 1
    For seqlen_q = 5 and seqlen_k = 2 the mask is:
        0 0
        0 0
        0 0
        1 0
        1 1
    A row whose mask is entirely zero produces an all-zero output.

    With window_size != (-1, -1), sliding window local attention is used: the query at
    position i only attends to keys in
    [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]],
    both ends inclusive.

    Arguments:
        q: (batch_size, seqlen, nheads, headdim)
        k: (batch_size, seqlen, nheads_k, headdim)
        v: (batch_size, seqlen, nheads_k, headdim)
        dropout_p: float. Dropout probability.
        softmax_scale: float. Scaling applied to QK^T before the softmax.
            Defaults to 1 / sqrt(headdim).
        causal: bool. Whether to apply a causal attention mask (e.g., for auto-regressive
            modeling).
        window_size: (left, right). If not (-1, -1), implements sliding window local attention.
        alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of
            (-alibi_slope * |i + seqlen_k - seqlen_q - j|)
            is added to the attention score of query i and key j.
        deterministic: bool. Whether to use the deterministic implementation of the backward
            pass, which is slightly slower and uses more memory. The forward pass is always
            deterministic.
        return_attn_probs: bool. Whether to return the attention probabilities. For testing
            only; the returned probabilities are not guaranteed to be correct (they might
            not have the right scaling).
        out: keyword-only, optional. Forwarded unchanged to FlashAttnFunc.apply;
            presumably a preallocated output tensor -- confirm against that function.
    Return:
        out: (batch_size, seqlen, nheads, headdim).
        softmax_lse [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen). The
            logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
            normalization factor).
        S_dmask [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen, seqlen).
            The output of softmax (possibly with different scaling). It also encodes the
            dropout pattern (negative means that location was dropped, nonnegative means it
            was kept).
    """
    # Everything is forwarded positionally, in exactly this order.
    forwarded = (
        q,
        k,
        v,
        dropout_p,
        softmax_scale,
        causal,
        window_size,
        alibi_slopes,
        deterministic,
        return_attn_probs,
        out,
    )
    return FlashAttnFunc.apply(*forwarded)
Tri Dao's avatar
Tri Dao committed
868
869


Tri Dao's avatar
Tri Dao committed
870
871
872
873
874
875
876
def flash_attn_varlen_qkvpacked_func(
    qkv,
    cu_seqlens,
    max_seqlen,
    dropout_p=0.0,
    softmax_scale=None,
    causal=False,
    window_size=(-1, -1),  # -1 means infinite context window
    alibi_slopes=None,
    deterministic=False,
    return_attn_probs=False,
    *,
    out=None,
):
    """Variable-length attention for a single packed QKV tensor.

    dropout_p should be set to 0.0 during evaluation.

    When Q, K and V are already stacked into one tensor, prefer this function over
    flash_attn_varlen_func: the backward pass then avoids explicitly concatenating the
    gradients of Q, K and V.

    For multi-query and grouped-query attention (MQA/GQA), please see
    flash_attn_varlen_kvpacked_func and flash_attn_varlen_func.

    With window_size != (-1, -1), sliding window local attention is used: the query at
    position i only attends to keys in [i - window_size[0], i + window_size[1]],
    both ends inclusive.

    Arguments:
        qkv: (total, 3, nheads, headdim), where total = total number of tokens in the batch.
        cu_seqlens: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
            of the sequences in the batch, used to index into qkv.
        max_seqlen: int. Maximum sequence length in the batch.
        dropout_p: float. Dropout probability.
        softmax_scale: float. Scaling applied to QK^T before the softmax.
            Defaults to 1 / sqrt(headdim).
        causal: bool. Whether to apply a causal attention mask (e.g., for auto-regressive
            modeling).
        window_size: (left, right). If not (-1, -1), implements sliding window local attention.
        alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of
            (-alibi_slope * |i - j|) is added to the attention score of query i and key j.
        deterministic: bool. Whether to use the deterministic implementation of the backward
            pass, which is slightly slower and uses more memory. The forward pass is always
            deterministic.
        return_attn_probs: bool. Whether to return the attention probabilities. For testing
            only; the returned probabilities are not guaranteed to be correct (they might
            not have the right scaling).
        out: keyword-only, optional. Forwarded unchanged to
            FlashAttnVarlenQKVPackedFunc.apply; presumably a preallocated output tensor --
            confirm against that function.
    Return:
        out: (total, nheads, headdim).
        softmax_lse [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen). The
            logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
            normalization factor).
        S_dmask [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen, seqlen).
            The output of softmax (possibly with different scaling). It also encodes the
            dropout pattern (negative means that location was dropped, nonnegative means it
            was kept).
    """
    # Everything is forwarded positionally, in exactly this order.
    forwarded = (
        qkv,
        cu_seqlens,
        max_seqlen,
        dropout_p,
        softmax_scale,
        causal,
        window_size,
        alibi_slopes,
        deterministic,
        return_attn_probs,
        out,
    )
    return FlashAttnVarlenQKVPackedFunc.apply(*forwarded)
Tri Dao's avatar
Tri Dao committed
933
934


Tri Dao's avatar
Tri Dao committed
935
936
937
938
939
940
941
942
943
944
def flash_attn_varlen_kvpacked_func(
    q,
    kv,
    cu_seqlens_q,
    cu_seqlens_k,
    max_seqlen_q,
    max_seqlen_k,
    dropout_p=0.0,
    softmax_scale=None,
    causal=False,
    window_size=(-1, -1),  # -1 means infinite context window
    alibi_slopes=None,
    deterministic=False,
    return_attn_probs=False,
    *,
    out=None,
):
    """Variable-length attention for a query tensor and a packed key/value tensor.

    dropout_p should be set to 0.0 during evaluation.

    When K and V are already stacked into one tensor, prefer this function over the
    unpacked variant (flash_attn_varlen_func): the backward pass then avoids explicitly
    concatenating the gradients of K and V.

    Multi-query and grouped-query attention (MQA/GQA) are supported by giving KV fewer
    heads than Q. The number of heads in Q must be divisible by the number of heads in KV.
    For example, with 6 Q heads and 2 K/V heads, Q heads 0, 1, 2 attend to K/V head 0 and
    Q heads 3, 4, 5 attend to K/V head 1.

    With causal=True the causal mask is aligned to the bottom right corner of the
    attention matrix. For seqlen_q = 2 and seqlen_k = 5 the mask (1 = keep, 0 = masked out) is:
        1 1 1 1 0
        1 1 1 1 1
    For seqlen_q = 5 and seqlen_k = 2 the mask is:
        0 0
        0 0
        0 0
        1 0
        1 1
    A row whose mask is entirely zero produces an all-zero output.

    With window_size != (-1, -1), sliding window local attention is used: the query at
    position i only attends to keys in
    [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]],
    both ends inclusive.

    Arguments:
        q: (total_q, nheads, headdim), where total_q = total number of query tokens in the
            batch.
        kv: (total_k, 2, nheads_k, headdim), where total_k = total number of key tokens in
            the batch.
        cu_seqlens_q: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
            of the sequences in the batch, used to index into q.
        cu_seqlens_k: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
            of the sequences in the batch, used to index into kv.
        max_seqlen_q: int. Maximum query sequence length in the batch.
        max_seqlen_k: int. Maximum key sequence length in the batch.
        dropout_p: float. Dropout probability.
        softmax_scale: float. Scaling applied to QK^T before the softmax.
            Defaults to 1 / sqrt(headdim).
        causal: bool. Whether to apply a causal attention mask (e.g., for auto-regressive
            modeling).
        window_size: (left, right). If not (-1, -1), implements sliding window local attention.
        alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of
            (-alibi_slope * |i + seqlen_k - seqlen_q - j|)
            is added to the attention score of query i and key j.
        deterministic: bool. Whether to use the deterministic implementation of the backward
            pass, which is slightly slower and uses more memory. The forward pass is always
            deterministic.
        return_attn_probs: bool. Whether to return the attention probabilities. For testing
            only; the returned probabilities are not guaranteed to be correct (they might
            not have the right scaling).
        out: keyword-only, optional. Forwarded unchanged to
            FlashAttnVarlenKVPackedFunc.apply; presumably a preallocated output tensor --
            confirm against that function.
    Return:
        out: (total, nheads, headdim).
        softmax_lse [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen). The
            logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
            normalization factor).
        S_dmask [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen, seqlen).
            The output of softmax (possibly with different scaling). It also encodes the
            dropout pattern (negative means that location was dropped, nonnegative means it
            was kept).
    """
    # Everything is forwarded positionally, in exactly this order.
    forwarded = (
        q,
        kv,
        cu_seqlens_q,
        cu_seqlens_k,
        max_seqlen_q,
        max_seqlen_k,
        dropout_p,
        softmax_scale,
        causal,
        window_size,
        alibi_slopes,
        deterministic,
        return_attn_probs,
        out,
    )
    return FlashAttnVarlenKVPackedFunc.apply(*forwarded)
Tri Dao's avatar
Tri Dao committed
1024

1025

Tri Dao's avatar
Tri Dao committed
1026
1027
1028
1029
1030
1031
1032
1033
1034
1035
1036
def flash_attn_varlen_func(
    q,
    k,
    v,
    cu_seqlens_q,
    cu_seqlens_k,
    max_seqlen_q,
    max_seqlen_k,
    dropout_p=0.0,
    softmax_scale=None,
    causal=False,
    window_size=(-1, -1),  # -1 means infinite context window
    alibi_slopes=None,
    deterministic=False,
    return_attn_probs=False,
    block_table=None,
    *,
    out=None,
):
    """Variable-length attention for separate query, key and value tensors.

    dropout_p should be set to 0.0 during evaluation.

    Multi-query and grouped-query attention (MQA/GQA) are supported by giving K and V fewer
    heads than Q. The number of heads in Q must be divisible by the number of heads in KV.
    For example, with 6 Q heads and 2 K/V heads, Q heads 0, 1, 2 attend to K/V head 0 and
    Q heads 3, 4, 5 attend to K/V head 1.

    With causal=True the causal mask is aligned to the bottom right corner of the
    attention matrix. For seqlen_q = 2 and seqlen_k = 5 the mask (1 = keep, 0 = masked out) is:
        1 1 1 1 0
        1 1 1 1 1
    For seqlen_q = 5 and seqlen_k = 2 the mask is:
        0 0
        0 0
        0 0
        1 0
        1 1
    A row whose mask is entirely zero produces an all-zero output.

    With window_size != (-1, -1), sliding window local attention is used: the query at
    position i only attends to keys in
    [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]],
    both ends inclusive.

    Arguments:
        q: (total_q, nheads, headdim), where total_q = total number of query tokens in the
            batch.
        k: (total_k, nheads_k, headdim), where total_k = total number of key tokens in the
            batch.
        v: (total_k, nheads_k, headdim), where total_k = total number of key tokens in the
            batch.
        cu_seqlens_q: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
            of the sequences in the batch, used to index into q.
        cu_seqlens_k: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
            of the sequences in the batch, used to index into kv.
        max_seqlen_q: int. Maximum query sequence length in the batch.
        max_seqlen_k: int. Maximum key sequence length in the batch.
        dropout_p: float. Dropout probability.
        softmax_scale: float. Scaling applied to QK^T before the softmax.
            Defaults to 1 / sqrt(headdim).
        causal: bool. Whether to apply a causal attention mask (e.g., for auto-regressive
            modeling).
        window_size: (left, right). If not (-1, -1), implements sliding window local attention.
        alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of
            (-alibi_slope * |i + seqlen_k - seqlen_q - j|)
            is added to the attention score of query i and key j.
        deterministic: bool. Whether to use the deterministic implementation of the backward
            pass, which is slightly slower and uses more memory. The forward pass is always
            deterministic.
        return_attn_probs: bool. Whether to return the attention probabilities. For testing
            only; the returned probabilities are not guaranteed to be correct (they might
            not have the right scaling).
        block_table: optional. Forwarded unchanged to FlashAttnVarlenFunc.apply; presumably
            the block table of a paged KV cache -- confirm against that function.
        out: keyword-only, optional. Forwarded unchanged to FlashAttnVarlenFunc.apply;
            presumably a preallocated output tensor -- confirm against that function.
    Return:
        out: (total, nheads, headdim).
        softmax_lse [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen). The
            logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
            normalization factor).
        S_dmask [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen, seqlen).
            The output of softmax (possibly with different scaling). It also encodes the
            dropout pattern (negative means that location was dropped, nonnegative means it
            was kept).
    """
    # Everything is forwarded positionally, in exactly this order.
    forwarded = (
        q,
        k,
        v,
        cu_seqlens_q,
        cu_seqlens_k,
        max_seqlen_q,
        max_seqlen_k,
        dropout_p,
        softmax_scale,
        causal,
        window_size,
        alibi_slopes,
        deterministic,
        return_attn_probs,
        block_table,
        out,
    )
    return FlashAttnVarlenFunc.apply(*forwarded)
Tri Dao's avatar
Tri Dao committed
1117
1118
1119
1120
1121
1122
1123
1124


def flash_attn_with_kvcache(
    q,
    k_cache,
    v_cache,
    k=None,
    v=None,
    rotary_cos=None,
    rotary_sin=None,
    cache_seqlens: Optional[Union[(int, torch.Tensor)]] = None,
    cache_batch_idx: Optional[torch.Tensor] = None,
    block_table: Optional[torch.Tensor] = None,
    softmax_scale=None,
    causal=False,
    window_size=(-1, -1),  # -1 means infinite context window
    rotary_interleaved=True,
    alibi_slopes=None,
    num_splits=0,
    *,
    out=None,
):
    """Attention against a KV cache, optionally appending new keys/values to the cache.

    If k and v are not None, k_cache and v_cache will be updated *inplace* with the new values from
    k and v. This is useful for incremental decoding: you can pass in the cached keys/values from
    the previous step, and update them with the new keys/values from the current step, and do
    attention with the updated cache, all in 1 kernel.

    If you pass in k / v, you must make sure that the cache is large enough to hold the new values.
    For example, the KV cache could be pre-allocated with the max sequence length, and you can use
    cache_seqlens to keep track of the current sequence lengths of each sequence in the batch.

    Also apply rotary embedding if rotary_cos and rotary_sin are passed in. The key @k will be
    rotated by rotary_cos and rotary_sin at indices cache_seqlens, cache_seqlens + 1, etc.
    If causal or local (i.e., window_size != (-1, -1)), the query @q will be rotated by rotary_cos
    and rotary_sin at indices cache_seqlens, cache_seqlens + 1, etc.
    If not causal and not local, the query @q will be rotated by rotary_cos and rotary_sin at
    indices cache_seqlens only (i.e. we consider all tokens in @q to be at position cache_seqlens).

    See tests/test_flash_attn.py::test_flash_attn_kvcache for examples of how to use this function.

    Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads
    than Q. Note that the number of heads in Q must be divisible by the number of heads in KV.
    For example, if Q has 6 heads and K, V have 2 heads, head 0, 1, 2 of Q will attend to head
    0 of K, V, and head 3, 4, 5 of Q will attend to head 1 of K, V.

    If causal=True, the causal mask is aligned to the bottom right corner of the attention matrix.
    For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is:
        1 1 1 1 0
        1 1 1 1 1
    If seqlen_q = 5 and seqlen_k = 2, the causal mask is:
        0 0
        0 0
        0 0
        1 0
        1 1
    If the row of the mask is all zero, the output will be zero.

    If window_size != (-1, -1), implements sliding window local attention. Query at position i
    will only attend to keys between
    [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive.

    Note: Does not support backward pass.

    Arguments:
        q: (batch_size, seqlen, nheads, headdim)
        k_cache: (batch_size_cache, seqlen_cache, nheads_k, headdim) if there's no block_table,
            or (num_blocks, page_block_size, nheads_k, headdim) if there's a block_table (i.e. paged KV cache)
            page_block_size must be a multiple of 256.
        v_cache: (batch_size_cache, seqlen_cache, nheads_k, headdim) if there's no block_table,
            or (num_blocks, page_block_size, nheads_k, headdim) if there's a block_table (i.e. paged KV cache)
        k [optional]: (batch_size, seqlen_new, nheads_k, headdim). If not None, we concatenate
            k with k_cache, starting at the indices specified by cache_seqlens.
        v [optional]: (batch_size, seqlen_new, nheads_k, headdim). Similar to k.
        rotary_cos [optional]: (seqlen_ro, rotary_dim / 2). If not None, we apply rotary embedding
            to k and q. Only applicable if k and v are passed in. rotary_dim must be divisible by 16.
        rotary_sin [optional]: (seqlen_ro, rotary_dim / 2). Similar to rotary_cos.
        cache_seqlens: int, or (batch_size,), dtype torch.int32. The sequence lengths of the
            KV cache. An int is expanded to an int32 tensor of length batch_size (taken from q)
            on k_cache's device.
        cache_batch_idx: (batch_size,), dtype torch.int32. The indices used to index into the KV cache.
            If None, we assume that the batch indices are [0, 1, 2, ..., batch_size - 1].
            If the indices are not distinct, and k and v are provided, the values updated in the cache
                 might come from any of the duplicate indices.
        block_table [optional]: (batch_size, max_num_blocks_per_seq), dtype torch.int32.
        softmax_scale: float. The scaling of QK^T before applying softmax.
            Default to 1 / sqrt(headdim).
        causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
        window_size: (left, right). If not (-1, -1), implements sliding window local attention.
        rotary_interleaved: bool. Only applicable if rotary_cos and rotary_sin are passed in.
            If True, rotary embedding will combine dimensions 0 & 1, 2 & 3, etc. If False,
            rotary embedding will combine dimensions 0 & rotary_dim / 2, 1 & rotary_dim / 2 + 1
            (i.e. GPT-NeoX style).
        alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of
            (-alibi_slope * |i + seqlen_k - seqlen_q - j|)
            is added to the attention score of query i and key j.
        num_splits: int. If > 1, split the key/value into this many chunks along the sequence.
           If num_splits == 1, we don't split the key/value. If num_splits == 0, we use a heuristic
           to automatically determine the number of splits.
           Don't change this unless you know what you are doing.
        out: keyword-only, optional. Forwarded unchanged to flash_attn_cuda.fwd_kvcache;
            presumably a preallocated output tensor -- confirm against the kernel binding.

    Return:
        out: (batch_size, seqlen, nheads, headdim).
    """
    assert k_cache.stride(-1) == 1, "k_cache must have contiguous last dimension"
    assert v_cache.stride(-1) == 1, "v_cache must have contiguous last dimension"

    def maybe_contiguous(x):
        # Only copy when the innermost dimension is strided; None passes through untouched.
        return x.contiguous() if x is not None and x.stride(-1) != 1 else x

    q, k, v = [maybe_contiguous(x) for x in (q, k, v)]
    if softmax_scale is None:
        softmax_scale = q.shape[-1] ** (-0.5)
    if isinstance(cache_seqlens, int):
        # cache_seqlens is documented as (batch_size,), i.e. one entry per query sequence.
        # k_cache.shape[0] is batch_size_cache (or num_blocks for a paged cache), which can
        # differ from batch_size, so the length is taken from q instead.
        cache_seqlens = torch.full(
            (q.shape[0],), cache_seqlens, dtype=torch.int32, device=k_cache.device
        )
    # Normalise caller-supplied tensors as well (the freshly built one is already contiguous),
    # matching the treatment of cache_batch_idx and block_table below.
    cache_seqlens = maybe_contiguous(cache_seqlens)
    cache_batch_idx = maybe_contiguous(cache_batch_idx)
    block_table = maybe_contiguous(block_table)
    # The kernel also returns the softmax log-sum-exp; this wrapper only exposes the output.
    out, _softmax_lse = flash_attn_cuda.fwd_kvcache(
        q,
        k_cache,
        v_cache,
        k,
        v,
        cache_seqlens,
        rotary_cos,
        rotary_sin,
        cache_batch_idx,
        block_table,
        alibi_slopes,
        out,
        softmax_scale,
        causal,
        window_size[0],
        window_size[1],
        rotary_interleaved,
        num_splits,
    )
    return out