flash_attn_interface.py 46.6 KB
Newer Older
1
2
# Copyright (c) 2023, Tri Dao.

Tri Dao's avatar
Tri Dao committed
3
4
from typing import Optional, Union

Tri Dao's avatar
Tri Dao committed
5
6
import torch
import torch.nn as nn
Tri Dao's avatar
Tri Dao committed
7

8
9
# isort: off
# We need to import the CUDA kernels after importing torch
10
11
# Use relative import to support build-from-source installation in vLLM
from . import vllm_flash_attn_c # noqa: F401
Tri Dao's avatar
Tri Dao committed
12

13
14
# isort: on

15
16
def maybe_contiguous(x):
    """Return ``x`` with a unit stride along its last dimension.

    ``None`` is passed through untouched, as is any tensor whose innermost
    dimension is already densely packed (even if the tensor as a whole is not
    contiguous). Anything else is replaced by ``x.contiguous()``.
    """
    if x is None or x.stride(-1) == 1:
        return x
    return x.contiguous()
Tri Dao's avatar
Tri Dao committed
17

Tri Dao's avatar
Tri Dao committed
18
def _get_block_size_n(device, head_dim, is_dropout, is_causal):
Tri Dao's avatar
Tri Dao committed
19
20
21
22
23
24
25
    # This should match the block sizes in the CUDA kernel
    assert head_dim <= 256
    major, minor = torch.cuda.get_device_capability(device)
    is_sm8x = major == 8 and minor > 0  # Only include sm86 and sm89, exclude sm80 (A100)
    is_sm80 = major == 8 and minor == 0
    is_sm90 = major == 9 and minor == 0
    if head_dim <= 32:
Tri Dao's avatar
Tri Dao committed
26
        return 128
Tri Dao's avatar
Tri Dao committed
27
    if head_dim <= 64:
Tri Dao's avatar
Tri Dao committed
28
        return 128 if not is_dropout else 64
Tri Dao's avatar
Tri Dao committed
29
    elif head_dim <= 96:
Tri Dao's avatar
Tri Dao committed
30
        return 64
Tri Dao's avatar
Tri Dao committed
31
32
    elif head_dim <= 128:
        if is_sm8x:
Tri Dao's avatar
Tri Dao committed
33
            return 64 if (not is_dropout and is_causal) else 32
Tri Dao's avatar
Tri Dao committed
34
        else:
Tri Dao's avatar
Tri Dao committed
35
            return 64 if not is_dropout else 32
Tri Dao's avatar
Tri Dao committed
36
37
    elif head_dim <= 160:
        if is_sm8x:
Tri Dao's avatar
Tri Dao committed
38
            return 64
Tri Dao's avatar
Tri Dao committed
39
        else:
Tri Dao's avatar
Tri Dao committed
40
            return 32
Tri Dao's avatar
Tri Dao committed
41
    elif head_dim <= 192:
Tri Dao's avatar
Tri Dao committed
42
        return 64
Tri Dao's avatar
Tri Dao committed
43
    elif head_dim <= 224:
Tri Dao's avatar
Tri Dao committed
44
        return 64
Tri Dao's avatar
Tri Dao committed
45
    elif head_dim <= 256:
Tri Dao's avatar
Tri Dao committed
46
        return 64
Tri Dao's avatar
Tri Dao committed
47
48


Tri Dao's avatar
Tri Dao committed
49
def _flash_attn_forward(
    q, k, v, dropout_p, softmax_scale, causal, window_size, softcap, alibi_slopes, return_softmax, *, out=None
):
    """Invoke the fixed-length forward kernel ``torch.ops.vllm_flash_attn_c.fwd``.

    ``q``, ``k`` and ``v`` are first passed through ``maybe_contiguous`` so that
    their last dimension has unit stride. ``out`` is an optional output buffer
    forwarded to the kernel as-is.

    Returns:
        The kernel's 8-tuple
        ``(out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state)``.
    """
    q, k, v = (maybe_contiguous(t) for t in (q, k, v))
    (
        out,
        q,
        k,
        v,
        out_padded,
        softmax_lse,
        S_dmask,
        rng_state,
    ) = torch.ops.vllm_flash_attn_c.fwd(
        q, k, v, out, alibi_slopes,
        dropout_p, softmax_scale, causal,
        window_size[0], window_size[1],
        softcap, return_softmax,
        None,
    )
    return out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state
Tri Dao's avatar
Tri Dao committed
69
70


Tri Dao's avatar
Tri Dao committed
71
72
73
74
75
76
77
78
79
80
81
def _flash_attn_varlen_forward(
    q,
    k,
    v,
    cu_seqlens_q,
    cu_seqlens_k,
    max_seqlen_q,
    max_seqlen_k,
    dropout_p,
    softmax_scale,
    causal,
    window_size,
    softcap,
    alibi_slopes,
    return_softmax,
    block_table,
    *,
    out=None,
):
    """Invoke the variable-length forward kernel ``torch.ops.vllm_flash_attn_c.varlen_fwd``.

    ``q``, ``k`` and ``v`` are passed through ``maybe_contiguous`` first.
    ``block_table`` and ``out`` are forwarded to the kernel unchanged.

    Returns:
        The kernel's 8-tuple
        ``(out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state)``.
    """
    q, k, v = [maybe_contiguous(x) for x in (q, k, v)]
    out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = torch.ops.vllm_flash_attn_c.varlen_fwd(
        q,
        k,
        v,
        out,
        cu_seqlens_q,
        cu_seqlens_k,
        None,  # NOTE(review): optional op argument left unset (presumably seqused_k) — confirm
        block_table,
        alibi_slopes,
        max_seqlen_q,
        max_seqlen_k,
        dropout_p,
        softmax_scale,
        False,  # NOTE(review): presumably the op's zero_tensors flag — confirm
        causal,
        window_size[0],
        window_size[1],
        softcap,
        return_softmax,
        None,
    )
    return out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state
Tri Dao's avatar
Tri Dao committed
116
117


Tri Dao's avatar
Tri Dao committed
118
def _flash_attn_backward(
    dout,
    q,
    k,
    v,
    out,
    softmax_lse,
    dq,
    dk,
    dv,
    dropout_p,
    softmax_scale,
    causal,
    window_size,
    softcap,
    alibi_slopes,
    deterministic,
    rng_state=None,
):
    """Invoke the fixed-length backward kernel ``torch.ops.vllm_flash_attn_c.bwd``.

    ``dq``, ``dk`` and ``dv`` are caller-provided gradient buffers handed to the
    kernel. The autograd functions in this module ignore the return value and
    read the buffers afterwards.

    Returns:
        The kernel's 4-tuple ``(dq, dk, dv, softmax_d)``.
    """
    # Only the inputs are normalized here; dq, dk, dv are allocated by the
    # caller and are expected to be contiguous already.
    dout, q, k, v, out = map(maybe_contiguous, (dout, q, k, v, out))
    grads = torch.ops.vllm_flash_attn_c.bwd(
        dout, q, k, v, out, softmax_lse,
        dq, dk, dv,
        alibi_slopes,
        dropout_p, softmax_scale, causal,
        window_size[0], window_size[1],
        softcap, deterministic,
        None, rng_state,
    )
    dq, dk, dv, softmax_d = grads
    return dq, dk, dv, softmax_d


Tri Dao's avatar
Tri Dao committed
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
def _flash_attn_varlen_backward(
    dout,
    q,
    k,
    v,
    out,
    softmax_lse,
    dq,
    dk,
    dv,
    cu_seqlens_q,
    cu_seqlens_k,
    max_seqlen_q,
    max_seqlen_k,
    dropout_p,
    softmax_scale,
    causal,
    window_size,
    softcap,
    alibi_slopes,
    deterministic,
    rng_state=None,
):
    """Invoke the variable-length backward kernel ``torch.ops.vllm_flash_attn_c.varlen_bwd``.

    ``dq``, ``dk`` and ``dv`` are caller-provided gradient buffers handed to the
    kernel. The autograd functions in this module ignore the return value and
    read the buffers afterwards.

    Returns:
        The kernel's 4-tuple ``(dq, dk, dv, softmax_d)``.
    """
    # dq, dk, dv are allocated by us so they should already be contiguous
    dout, q, k, v, out = [maybe_contiguous(x) for x in (dout, q, k, v, out)]
    (
        dq,
        dk,
        dv,
        softmax_d,
    ) = torch.ops.vllm_flash_attn_c.varlen_bwd(
        dout,
        q,
        k,
        v,
        out,
        softmax_lse,
        dq,
        dk,
        dv,
        cu_seqlens_q,
        cu_seqlens_k,
        alibi_slopes,
        max_seqlen_q,
        max_seqlen_k,
        dropout_p,
        softmax_scale,
        False,  # NOTE(review): presumably the op's zero_tensors flag — confirm
        causal,
        window_size[0],
        window_size[1],
        softcap,
        deterministic,
        None,
        rng_state,
    )
    return dq, dk, dv, softmax_d
Tri Dao's avatar
Tri Dao committed
227
228


Tri Dao's avatar
Tri Dao committed
229
class FlashAttnQKVPackedFunc(torch.autograd.Function):
    """Autograd function for FlashAttention over a packed ``qkv`` tensor.

    q, k and v are taken as ``qkv[:, :, 0]``, ``qkv[:, :, 1]`` and ``qkv[:, :, 2]``.
    The backward pass writes their gradients directly into slices of a single
    ``dqkv`` buffer, so no concatenation is needed.
    """

    @staticmethod
    def forward(
        ctx,
        qkv,
        dropout_p,
        softmax_scale,
        causal,
        window_size,
        softcap,
        alibi_slopes,
        deterministic,
        return_softmax,
        out=None,
    ):
        """Compute attention and save what the backward pass needs.

        ``out`` is positional-or-keyword, not keyword-only, because
        ``flash_attn_qkvpacked_func`` passes it to ``apply`` as the last
        positional argument. If ``softmax_scale`` is None it defaults to
        ``1 / sqrt(headdim)``.

        Returns:
            ``out``, or ``(out, softmax_lse, S_dmask)`` when ``return_softmax``.
        """
        if softmax_scale is None:
            softmax_scale = qkv.shape[-1] ** (-0.5)
        out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = _flash_attn_forward(
            qkv[:, :, 0],
            qkv[:, :, 1],
            qkv[:, :, 2],
            dropout_p,
            softmax_scale,
            causal=causal,
            window_size=window_size,
            softcap=softcap,
            alibi_slopes=alibi_slopes,
            # The kernel is only asked for S_dmask when dropout is active.
            return_softmax=return_softmax and dropout_p > 0,
            out=out,
        )
        ctx.save_for_backward(q, k, v, out_padded, softmax_lse, rng_state)
        ctx.dropout_p = dropout_p
        ctx.softmax_scale = softmax_scale
        ctx.causal = causal
        ctx.window_size = window_size
        ctx.softcap = softcap
        ctx.alibi_slopes = alibi_slopes
        ctx.deterministic = deterministic
        return out if not return_softmax else (out, softmax_lse, S_dmask)

    @staticmethod
    def backward(ctx, dout, *args):
        """Return one gradient per forward input; only ``qkv`` gets a tensor.

        ``*args`` receives gradients for the extra outputs (softmax_lse,
        S_dmask) when ``return_softmax`` was set; they are ignored.
        """
        q, k, v, out, softmax_lse, rng_state = ctx.saved_tensors
        qkv_shape = q.shape[:-2] + (3, *q.shape[-2:])
        dqkv = torch.empty(qkv_shape, dtype=q.dtype, device=q.device)
        _flash_attn_backward(
            dout,
            q,
            k,
            v,
            out,
            softmax_lse,
            dqkv[:, :, 0],
            dqkv[:, :, 1],
            dqkv[:, :, 2],
            ctx.dropout_p,
            ctx.softmax_scale,
            ctx.causal,
            ctx.window_size,
            ctx.softcap,
            ctx.alibi_slopes,
            ctx.deterministic,
            rng_state=rng_state,
        )
        dqkv = dqkv[..., : dout.shape[-1]]  # We could have padded the head dimension
        # None for each of the 9 non-differentiable inputs: dropout_p,
        # softmax_scale, causal, window_size, softcap, alibi_slopes,
        # deterministic, return_softmax, out. The out slot was previously
        # missing, so passing ``out`` made autograd reject the gradient count.
        return (dqkv,) + (None,) * 9
Tri Dao's avatar
Tri Dao committed
296
297
298
299


class FlashAttnVarlenQKVPackedFunc(torch.autograd.Function):
    """Autograd function for variable-length FlashAttention over a packed ``qkv``.

    q, k and v are taken as ``qkv[:, 0]``, ``qkv[:, 1]`` and ``qkv[:, 2]``.
    The same ``cu_seqlens`` and ``max_seqlen`` are used for queries and keys.
    """

    @staticmethod
    def forward(
        ctx,
        qkv,
        cu_seqlens,
        max_seqlen,
        dropout_p,
        softmax_scale,
        causal,
        window_size,
        softcap,
        alibi_slopes,
        deterministic,
        return_softmax,
        out=None,
    ):
        """Compute attention and save what the backward pass needs.

        ``out`` is positional-or-keyword so it can be supplied through
        ``apply(...)`` positionally. If ``softmax_scale`` is None it defaults to
        ``1 / sqrt(headdim)``.

        Returns:
            ``out``, or ``(out, softmax_lse, S_dmask)`` when ``return_softmax``.
        """
        if softmax_scale is None:
            softmax_scale = qkv.shape[-1] ** (-0.5)
        out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = _flash_attn_varlen_forward(
            qkv[:, 0],
            qkv[:, 1],
            qkv[:, 2],
            cu_seqlens,
            cu_seqlens,
            max_seqlen,
            max_seqlen,
            dropout_p,
            softmax_scale,
            causal=causal,
            window_size=window_size,
            softcap=softcap,
            alibi_slopes=alibi_slopes,
            # The kernel is only asked for S_dmask when dropout is active.
            return_softmax=return_softmax and dropout_p > 0,
            block_table=None,
            out=out,
        )
        ctx.save_for_backward(q, k, v, out_padded, softmax_lse, cu_seqlens, rng_state)
        ctx.dropout_p = dropout_p
        ctx.max_seqlen = max_seqlen
        ctx.softmax_scale = softmax_scale
        ctx.causal = causal
        ctx.window_size = window_size
        ctx.softcap = softcap
        ctx.alibi_slopes = alibi_slopes
        ctx.deterministic = deterministic
        return out if not return_softmax else (out, softmax_lse, S_dmask)

    @staticmethod
    def backward(ctx, dout, *args):
        """Return one gradient per forward input; only ``qkv`` gets a tensor.

        ``*args`` (gradients for softmax_lse / S_dmask) is ignored.
        """
        q, k, v, out, softmax_lse, cu_seqlens, rng_state = ctx.saved_tensors
        qkv_shape = q.shape[:-2] + (3, *q.shape[-2:])
        dqkv = torch.empty(qkv_shape, dtype=q.dtype, device=q.device)
        _flash_attn_varlen_backward(
            dout,
            q,
            k,
            v,
            out,
            softmax_lse,
            dqkv[:, 0],
            dqkv[:, 1],
            dqkv[:, 2],
            cu_seqlens,
            cu_seqlens,
            ctx.max_seqlen,
            ctx.max_seqlen,
            ctx.dropout_p,
            ctx.softmax_scale,
            ctx.causal,
            ctx.window_size,
            ctx.softcap,
            ctx.alibi_slopes,
            ctx.deterministic,
            rng_state=rng_state,
        )
        dqkv = dqkv[..., : dout.shape[-1]]  # We could have padded the head dimension
        # None for each of the 11 non-differentiable inputs: cu_seqlens,
        # max_seqlen, dropout_p, softmax_scale, causal, window_size, softcap,
        # alibi_slopes, deterministic, return_softmax, out.
        return (dqkv,) + (None,) * 11
Tri Dao's avatar
Tri Dao committed
377
378


Tri Dao's avatar
Tri Dao committed
379
class FlashAttnKVPackedFunc(torch.autograd.Function):
    """Autograd function for FlashAttention with a separate ``q`` and a packed ``kv``.

    k and v are taken as ``kv[:, :, 0]`` and ``kv[:, :, 1]``. Their gradients
    are written into slices of a single ``dkv`` buffer.
    """

    @staticmethod
    def forward(
        ctx,
        q,
        kv,
        dropout_p,
        softmax_scale,
        causal,
        window_size,
        softcap,
        alibi_slopes,
        deterministic,
        return_softmax,
        out=None,
    ):
        """Compute attention and save what the backward pass needs.

        If ``softmax_scale`` is None it defaults to ``1 / sqrt(headdim)`` of ``q``.

        Returns:
            ``out``, or ``(out, softmax_lse, S_dmask)`` when ``return_softmax``.
        """
        if softmax_scale is None:
            softmax_scale = q.shape[-1] ** (-0.5)
        out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = _flash_attn_forward(
            q,
            kv[:, :, 0],
            kv[:, :, 1],
            dropout_p,
            softmax_scale,
            causal=causal,
            window_size=window_size,
            softcap=softcap,
            alibi_slopes=alibi_slopes,
            # The kernel is only asked for S_dmask when dropout is active.
            return_softmax=return_softmax and dropout_p > 0,
            out=out,
        )
        ctx.save_for_backward(q, k, v, out_padded, softmax_lse, rng_state)
        ctx.dropout_p = dropout_p
        ctx.softmax_scale = softmax_scale
        ctx.causal = causal
        ctx.window_size = window_size
        ctx.softcap = softcap
        ctx.alibi_slopes = alibi_slopes
        ctx.deterministic = deterministic
        return out if not return_softmax else (out, softmax_lse, S_dmask)

    @staticmethod
    def backward(ctx, dout, *args):
        """Return one gradient per forward input; ``q`` and ``kv`` get tensors.

        ``*args`` (gradients for softmax_lse / S_dmask) is ignored.
        """
        q, k, v, out, softmax_lse, rng_state = ctx.saved_tensors
        dq = torch.empty_like(q)
        kv_shape = k.shape[:-2] + (2, *k.shape[-2:])
        dkv = torch.empty(kv_shape, dtype=k.dtype, device=k.device)
        _flash_attn_backward(
            dout,
            q,
            k,
            v,
            out,
            softmax_lse,
            dq,
            dkv[:, :, 0],
            dkv[:, :, 1],
            ctx.dropout_p,
            ctx.softmax_scale,
            ctx.causal,
            ctx.window_size,
            ctx.softcap,
            ctx.alibi_slopes,
            ctx.deterministic,
            rng_state=rng_state,
        )
        dq = dq[..., : dout.shape[-1]]  # We could have padded the head dimension
        dkv = dkv[..., : dout.shape[-1]]
        # None for each of the 9 non-differentiable inputs: dropout_p,
        # softmax_scale, causal, window_size, softcap, alibi_slopes,
        # deterministic, return_softmax, out (the out slot was missing).
        return (dq, dkv) + (None,) * 9
Tri Dao's avatar
Tri Dao committed
448
449


Tri Dao's avatar
Tri Dao committed
450
class FlashAttnVarlenKVPackedFunc(torch.autograd.Function):
    """Autograd function for variable-length FlashAttention with a packed ``kv``.

    k and v are taken as ``kv[:, 0]`` and ``kv[:, 1]``. Queries and keys carry
    their own cumulative sequence lengths and maximum lengths.
    """

    @staticmethod
    def forward(
        ctx,
        q,
        kv,
        cu_seqlens_q,
        cu_seqlens_k,
        max_seqlen_q,
        max_seqlen_k,
        dropout_p,
        softmax_scale,
        causal,
        window_size,
        softcap,
        alibi_slopes,
        deterministic,
        return_softmax,
        out=None,
    ):
        """Compute attention and save what the backward pass needs.

        If ``softmax_scale`` is None it defaults to ``1 / sqrt(headdim)`` of ``q``.

        Returns:
            ``out``, or ``(out, softmax_lse, S_dmask)`` when ``return_softmax``.
        """
        if softmax_scale is None:
            softmax_scale = q.shape[-1] ** (-0.5)
        out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = _flash_attn_varlen_forward(
            q,
            kv[:, 0],
            kv[:, 1],
            cu_seqlens_q,
            cu_seqlens_k,
            max_seqlen_q,
            max_seqlen_k,
            dropout_p,
            softmax_scale,
            causal=causal,
            window_size=window_size,
            softcap=softcap,
            alibi_slopes=alibi_slopes,
            # The kernel is only asked for S_dmask when dropout is active.
            return_softmax=return_softmax and dropout_p > 0,
            block_table=None,
            out=out,
        )
        ctx.save_for_backward(
            q, k, v, out_padded, softmax_lse, cu_seqlens_q, cu_seqlens_k, rng_state
        )
        ctx.dropout_p = dropout_p
        ctx.max_seqlen_q = max_seqlen_q
        ctx.max_seqlen_k = max_seqlen_k
        ctx.softmax_scale = softmax_scale
        ctx.causal = causal
        ctx.window_size = window_size
        ctx.softcap = softcap
        ctx.alibi_slopes = alibi_slopes
        ctx.deterministic = deterministic
        return out if not return_softmax else (out, softmax_lse, S_dmask)

    @staticmethod
    def backward(ctx, dout, *args):
        """Return one gradient per forward input; ``q`` and ``kv`` get tensors.

        ``*args`` (gradients for softmax_lse / S_dmask) is ignored.
        """
        q, k, v, out, softmax_lse, cu_seqlens_q, cu_seqlens_k, rng_state = ctx.saved_tensors
        dq = torch.empty_like(q)
        kv_shape = k.shape[:-2] + (2, *k.shape[-2:])
        dkv = torch.empty(kv_shape, dtype=k.dtype, device=k.device)
        _flash_attn_varlen_backward(
            dout,
            q,
            k,
            v,
            out,
            softmax_lse,
            dq,
            dkv[:, 0],
            dkv[:, 1],
            cu_seqlens_q,
            cu_seqlens_k,
            ctx.max_seqlen_q,
            ctx.max_seqlen_k,
            ctx.dropout_p,
            ctx.softmax_scale,
            ctx.causal,
            ctx.window_size,
            ctx.softcap,
            ctx.alibi_slopes,
            ctx.deterministic,
            rng_state=rng_state,
        )
        dq = dq[..., : dout.shape[-1]]  # We could have padded the head dimension
        dkv = dkv[..., : dout.shape[-1]]
        # None for each of the 13 non-differentiable inputs: cu_seqlens_q,
        # cu_seqlens_k, max_seqlen_q, max_seqlen_k, dropout_p, softmax_scale,
        # causal, window_size, softcap, alibi_slopes, deterministic,
        # return_softmax, out (the out slot was missing).
        return (dq, dkv) + (None,) * 13
Tri Dao's avatar
Tri Dao committed
536
537
538
539


class FlashAttnFunc(torch.autograd.Function):
    """Autograd function for fixed-length FlashAttention with separate q, k, v."""

    @staticmethod
    def forward(
        ctx,
        q,
        k,
        v,
        dropout_p,
        softmax_scale,
        causal,
        window_size,
        softcap,
        alibi_slopes,
        deterministic,
        return_softmax,
        out=None,
    ):
        """Compute attention and save what the backward pass needs.

        If ``softmax_scale`` is None it defaults to ``1 / sqrt(headdim)`` of ``q``.

        Returns:
            ``out``, or ``(out, softmax_lse, S_dmask)`` when ``return_softmax``.
        """
        if softmax_scale is None:
            softmax_scale = q.shape[-1] ** (-0.5)
        out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = _flash_attn_forward(
            q,
            k,
            v,
            dropout_p,
            softmax_scale,
            causal=causal,
            window_size=window_size,
            softcap=softcap,
            alibi_slopes=alibi_slopes,
            # The kernel is only asked for S_dmask when dropout is active.
            return_softmax=return_softmax and dropout_p > 0,
            out=out,
        )
        ctx.save_for_backward(q, k, v, out_padded, softmax_lse, rng_state)
        ctx.dropout_p = dropout_p
        ctx.softmax_scale = softmax_scale
        ctx.causal = causal
        ctx.window_size = window_size
        ctx.softcap = softcap
        ctx.alibi_slopes = alibi_slopes
        ctx.deterministic = deterministic
        return out if not return_softmax else (out, softmax_lse, S_dmask)

    @staticmethod
    def backward(ctx, dout, *args):
        """Return one gradient per forward input; q, k and v get tensors.

        ``*args`` (gradients for softmax_lse / S_dmask) is ignored.
        """
        q, k, v, out, softmax_lse, rng_state = ctx.saved_tensors
        dq, dk, dv = torch.empty_like(q), torch.empty_like(k), torch.empty_like(v)
        _flash_attn_backward(
            dout,
            q,
            k,
            v,
            out,
            softmax_lse,
            dq,
            dk,
            dv,
            ctx.dropout_p,
            ctx.softmax_scale,
            ctx.causal,
            ctx.window_size,
            ctx.softcap,
            ctx.alibi_slopes,
            ctx.deterministic,
            rng_state=rng_state,
        )
        dq = dq[..., : dout.shape[-1]]  # We could have padded the head dimension
        dk = dk[..., : dout.shape[-1]]
        dv = dv[..., : dout.shape[-1]]
        # None for each of the 9 non-differentiable inputs: dropout_p,
        # softmax_scale, causal, window_size, softcap, alibi_slopes,
        # deterministic, return_softmax, out (the out slot was missing).
        return (dq, dk, dv) + (None,) * 9
Tri Dao's avatar
Tri Dao committed
607
608


Tri Dao's avatar
Tri Dao committed
609
class FlashAttnVarlenFunc(torch.autograd.Function):
    """Autograd function for variable-length FlashAttention with separate q, k, v.

    Unlike the packed variants, this one forwards a caller-supplied
    ``block_table`` to the forward kernel.
    """

    @staticmethod
    def forward(
        ctx,
        q,
        k,
        v,
        cu_seqlens_q,
        cu_seqlens_k,
        max_seqlen_q,
        max_seqlen_k,
        dropout_p,
        softmax_scale,
        causal,
        window_size,
        softcap,
        alibi_slopes,
        deterministic,
        return_softmax,
        block_table,
        out=None,
    ):
        """Compute attention and save what the backward pass needs.

        If ``softmax_scale`` is None it defaults to ``1 / sqrt(headdim)`` of ``q``.

        Returns:
            ``out``, or ``(out, softmax_lse, S_dmask)`` when ``return_softmax``.
        """
        if softmax_scale is None:
            softmax_scale = q.shape[-1] ** (-0.5)
        out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = _flash_attn_varlen_forward(
            q,
            k,
            v,
            cu_seqlens_q,
            cu_seqlens_k,
            max_seqlen_q,
            max_seqlen_k,
            dropout_p,
            softmax_scale,
            causal=causal,
            window_size=window_size,
            softcap=softcap,
            alibi_slopes=alibi_slopes,
            # The kernel is only asked for S_dmask when dropout is active.
            return_softmax=return_softmax and dropout_p > 0,
            block_table=block_table,
            out=out,
        )
        ctx.save_for_backward(
            q, k, v, out_padded, softmax_lse, cu_seqlens_q, cu_seqlens_k, rng_state
        )
        ctx.dropout_p = dropout_p
        ctx.max_seqlen_q = max_seqlen_q
        ctx.max_seqlen_k = max_seqlen_k
        ctx.softmax_scale = softmax_scale
        ctx.causal = causal
        ctx.window_size = window_size
        ctx.softcap = softcap
        ctx.alibi_slopes = alibi_slopes
        ctx.deterministic = deterministic
        return out if not return_softmax else (out, softmax_lse, S_dmask)

    @staticmethod
    def backward(ctx, dout, *args):
        """Return one gradient per forward input; q, k and v get tensors.

        ``*args`` (gradients for softmax_lse / S_dmask) is ignored.
        """
        q, k, v, out, softmax_lse, cu_seqlens_q, cu_seqlens_k, rng_state = ctx.saved_tensors
        dq, dk, dv = torch.empty_like(q), torch.empty_like(k), torch.empty_like(v)
        _flash_attn_varlen_backward(
            dout,
            q,
            k,
            v,
            out,
            softmax_lse,
            dq,
            dk,
            dv,
            cu_seqlens_q,
            cu_seqlens_k,
            ctx.max_seqlen_q,
            ctx.max_seqlen_k,
            ctx.dropout_p,
            ctx.softmax_scale,
            ctx.causal,
            ctx.window_size,
            ctx.softcap,
            ctx.alibi_slopes,
            ctx.deterministic,
            rng_state=rng_state,
        )
        dq = dq[..., : dout.shape[-1]]  # We could have padded the head dimension
        dk = dk[..., : dout.shape[-1]]
        dv = dv[..., : dout.shape[-1]]
        # None for each of the 14 non-differentiable inputs: cu_seqlens_q,
        # cu_seqlens_k, max_seqlen_q, max_seqlen_k, dropout_p, softmax_scale,
        # causal, window_size, softcap, alibi_slopes, deterministic,
        # return_softmax, block_table, out (the out slot was missing).
        return (dq, dk, dv) + (None,) * 14
696
697


Tri Dao's avatar
Tri Dao committed
698
def flash_attn_qkvpacked_func(
    qkv,
    dropout_p=0.0,
    softmax_scale=None,
    causal=False,
    window_size=(-1, -1),  # -1 means infinite context window
    softcap=0.0,  # <=0.0 means deactivate
    alibi_slopes=None,
    deterministic=False,
    return_attn_probs=False,
    *,
    out=None,
):
    """dropout_p should be set to 0.0 during evaluation

    If Q, K, V are already stacked into one tensor, this function is faster than
    calling flash_attn_func on Q, K, V, because the backward pass avoids an explicit
    concatenation of the gradients of Q, K, V.
    For multi-query and grouped-query attention (MQA/GQA), use
    flash_attn_kvpacked_func or flash_attn_func instead.

    If window_size != (-1, -1), implements sliding window local attention. The query
    at position i only attends to keys between [i - window_size[0], i + window_size[1]]
    inclusive.

    Arguments:
        qkv: (batch_size, seqlen, 3, nheads, headdim)
        dropout_p: float. Dropout probability.
        softmax_scale: float. The scaling of QK^T before applying softmax.
            Defaults to 1 / sqrt(headdim).
        causal: bool. Whether to apply a causal attention mask (e.g., for auto-regressive modeling).
        window_size: (left, right). If not (-1, -1), implements sliding window local attention.
        softcap: float. Anything > 0 activates softcapping attention.
        alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of (-alibi_slope * |i - j|)
            is added to the attention score of query i and key j.
        deterministic: bool. Whether to use the deterministic implementation of the backward pass,
            which is slightly slower and uses more memory. The forward pass is always deterministic.
        return_attn_probs: bool. Whether to return the attention probabilities. This option is for
           testing only. The returned probabilities are not guaranteed to be correct
           (they might not have the right scaling).
        out: optional, keyword-only. A preallocated output tensor that is forwarded to the
            forward kernel as its ``out`` argument. Presumably it has the same shape as the
            returned ``out`` (batch_size, seqlen, nheads, headdim); confirm against the kernel.
    Return:
        out: (batch_size, seqlen, nheads, headdim).
        softmax_lse [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen). The
            logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
            normalization factor).
        S_dmask [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen, seqlen).
            The output of softmax (possibly with different scaling). It also encodes the dropout
            pattern (negative means that location was dropped, nonnegative means it was kept).
    """
    # autograd.Function.apply takes its arguments positionally, in forward()'s order.
    apply_args = (
        qkv,
        dropout_p,
        softmax_scale,
        causal,
        window_size,
        softcap,
        alibi_slopes,
        deterministic,
        return_attn_probs,
        out,
    )
    return FlashAttnQKVPackedFunc.apply(*apply_args)
Tri Dao's avatar
Tri Dao committed
757
758


Tri Dao's avatar
Tri Dao committed
759
def flash_attn_kvpacked_func(
    q,
    kv,
    dropout_p=0.0,
    softmax_scale=None,
    causal=False,
    window_size=(-1, -1),  # -1 means infinite context window
    softcap=0.0,  # 0.0 means deactivated
    alibi_slopes=None,
    deterministic=False,
    return_attn_probs=False,
    *,
    out=None,
):
    """dropout_p should be set to 0.0 during evaluation
    If K, V are already stacked into 1 tensor, this function will be faster than
    calling flash_attn_func on Q, K, V since the backward pass avoids explicit concatenation
    of the gradients of K, V.
    Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads
    than Q. Note that the number of heads in Q must be divisible by the number of heads in KV.
    For example, if Q has 6 heads and K, V have 2 heads, heads 0, 1, 2 of Q will attend to head
    0 of K, V, and heads 3, 4, 5 of Q will attend to head 1 of K, V.

    If causal=True, the causal mask is aligned to the bottom right corner of the attention matrix.
    For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is:
        1 1 1 1 0
        1 1 1 1 1
    If seqlen_q = 5 and seqlen_k = 2, the causal mask is:
        0 0
        0 0
        0 0
        1 0
        1 1
    If the row of the mask is all zero, the output will be zero.

    If window_size != (-1, -1), implements sliding window local attention. Query at position i
    will only attend to keys between
    [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive.

    Arguments:
        q: (batch_size, seqlen_q, nheads, headdim)
        kv: (batch_size, seqlen_k, 2, nheads_k, headdim)
        dropout_p: float. Dropout probability.
        softmax_scale: float. The scaling of QK^T before applying softmax.
            Default to 1 / sqrt(headdim).
        causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
        window_size: (left, right). If not (-1, -1), implements sliding window local attention.
        softcap: float. Anything > 0 activates softcapping attention.
        alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of
            (-alibi_slope * |i + seqlen_k - seqlen_q - j|)
            is added to the attention score of query i and key j.
        deterministic: bool. Whether to use the deterministic implementation of the backward pass,
            which is slightly slower and uses more memory. The forward pass is always deterministic.
        return_attn_probs: bool. Whether to return the attention probabilities. This option is for
           testing only. The returned probabilities are not guaranteed to be correct
           (they might not have the right scaling).
        out: keyword-only, optional. Forwarded unchanged to FlashAttnKVPackedFunc.apply;
            presumably a preallocated output tensor shaped like the returned `out`
            (TODO confirm against the kernel binding).
    Return:
        out: (batch_size, seqlen_q, nheads, headdim).
        softmax_lse [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen_q). The
            logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
            normalization factor).
        S_dmask [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen_q, seqlen_k).
            The output of softmax (possibly with different scaling). It also encodes the dropout
            pattern (negative means that location was dropped, nonnegative means it was kept).
    """
    # Positional order must match FlashAttnKVPackedFunc.forward's signature.
    return FlashAttnKVPackedFunc.apply(
        q,
        kv,
        dropout_p,
        softmax_scale,
        causal,
        window_size,
        softcap,
        alibi_slopes,
        deterministic,
        return_attn_probs,
        out,
    )
Tri Dao's avatar
Tri Dao committed
837
838


Tri Dao's avatar
Tri Dao committed
839
def flash_attn_func(
    q,
    k,
    v,
    dropout_p=0.0,
    softmax_scale=None,
    causal=False,
    window_size=(-1, -1),  # -1 means infinite context window
    softcap=0.0,  # 0.0 means deactivated
    alibi_slopes=None,
    deterministic=False,
    return_attn_probs=False,
    *,
    out=None,
):
    """dropout_p should be set to 0.0 during evaluation
    Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads
    than Q. Note that the number of heads in Q must be divisible by the number of heads in KV.
    For example, if Q has 6 heads and K, V have 2 heads, heads 0, 1, 2 of Q will attend to head
    0 of K, V, and heads 3, 4, 5 of Q will attend to head 1 of K, V.

    If causal=True, the causal mask is aligned to the bottom right corner of the attention matrix.
    For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is:
        1 1 1 1 0
        1 1 1 1 1
    If seqlen_q = 5 and seqlen_k = 2, the causal mask is:
        0 0
        0 0
        0 0
        1 0
        1 1
    If the row of the mask is all zero, the output will be zero.

    If window_size != (-1, -1), implements sliding window local attention. Query at position i
    will only attend to keys between
    [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive.

    Arguments:
        q: (batch_size, seqlen_q, nheads, headdim)
        k: (batch_size, seqlen_k, nheads_k, headdim)
        v: (batch_size, seqlen_k, nheads_k, headdim)
        dropout_p: float. Dropout probability.
        softmax_scale: float. The scaling of QK^T before applying softmax.
            Default to 1 / sqrt(headdim).
        causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
        window_size: (left, right). If not (-1, -1), implements sliding window local attention.
        softcap: float. Anything > 0 activates softcapping attention.
        alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of
            (-alibi_slope * |i + seqlen_k - seqlen_q - j|)
            is added to the attention score of query i and key j.
        deterministic: bool. Whether to use the deterministic implementation of the backward pass,
            which is slightly slower and uses more memory. The forward pass is always deterministic.
        return_attn_probs: bool. Whether to return the attention probabilities. This option is for
           testing only. The returned probabilities are not guaranteed to be correct
           (they might not have the right scaling).
        out: keyword-only, optional. Forwarded unchanged to FlashAttnFunc.apply;
            presumably a preallocated output tensor shaped like the returned `out`
            (TODO confirm against the kernel binding).
    Return:
        out: (batch_size, seqlen_q, nheads, headdim).
        softmax_lse [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen_q). The
            logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
            normalization factor).
        S_dmask [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen_q, seqlen_k).
            The output of softmax (possibly with different scaling). It also encodes the dropout
            pattern (negative means that location was dropped, nonnegative means it was kept).
    """
    # Positional order must match FlashAttnFunc.forward's signature.
    return FlashAttnFunc.apply(
        q,
        k,
        v,
        dropout_p,
        softmax_scale,
        causal,
        window_size,
        softcap,
        alibi_slopes,
        deterministic,
        return_attn_probs,
        out,
    )
Tri Dao's avatar
Tri Dao committed
916
917


Tri Dao's avatar
Tri Dao committed
918
919
920
921
922
923
924
def flash_attn_varlen_qkvpacked_func(
    qkv,
    cu_seqlens,
    max_seqlen,
    dropout_p=0.0,
    softmax_scale=None,
    causal=False,
    window_size=(-1, -1),  # -1 means infinite context window
    softcap=0.0,  # 0.0 means deactivated
    alibi_slopes=None,
    deterministic=False,
    return_attn_probs=False,
    *,
    out=None,
):
    """dropout_p should be set to 0.0 during evaluation
    If Q, K, V are already stacked into 1 tensor, this function will be faster than
    calling flash_attn_varlen_func on Q, K, V since the backward pass avoids explicit concatenation
    of the gradients of Q, K, V.
    For multi-query and grouped-query attention (MQA/GQA), please see
    flash_attn_varlen_kvpacked_func and flash_attn_varlen_func.

    If window_size != (-1, -1), implements sliding window local attention. Query at position i
    will only attend to keys between [i - window_size[0], i + window_size[1]] inclusive.

    Arguments:
        qkv: (total, 3, nheads, headdim), where total = total number of tokens in the batch.
        cu_seqlens: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
           of the sequences in the batch, used to index into qkv.
        max_seqlen: int. Maximum sequence length in the batch.
        dropout_p: float. Dropout probability.
        softmax_scale: float. The scaling of QK^T before applying softmax.
            Default to 1 / sqrt(headdim).
        causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
        window_size: (left, right). If not (-1, -1), implements sliding window local attention.
        softcap: float. Anything > 0 activates softcapping attention.
        alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of (-alibi_slope * |i - j|)
            is added to the attention score of query i and key j.
        deterministic: bool. Whether to use the deterministic implementation of the backward pass,
            which is slightly slower and uses more memory. The forward pass is always deterministic.
        return_attn_probs: bool. Whether to return the attention probabilities. This option is for
           testing only. The returned probabilities are not guaranteed to be correct
           (they might not have the right scaling).
        out: keyword-only, optional. Forwarded unchanged to FlashAttnVarlenQKVPackedFunc.apply;
            presumably a preallocated output tensor shaped like the returned `out`
            (TODO confirm against the kernel binding).
    Return:
        out: (total, nheads, headdim).
        softmax_lse [optional, if return_attn_probs=True]: (nheads, total_q_seqlen). The
            logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
            normalization factor).
        S_dmask [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen, seqlen).
            The output of softmax (possibly with different scaling). It also encodes the dropout
            pattern (negative means that location was dropped, nonnegative means it was kept).
    """
    # Positional order must match FlashAttnVarlenQKVPackedFunc.forward's signature.
    return FlashAttnVarlenQKVPackedFunc.apply(
        qkv,
        cu_seqlens,
        max_seqlen,
        dropout_p,
        softmax_scale,
        causal,
        window_size,
        softcap,
        alibi_slopes,
        deterministic,
        return_attn_probs,
        out,
    )
Tri Dao's avatar
Tri Dao committed
984
985


Tri Dao's avatar
Tri Dao committed
986
987
988
989
990
991
992
993
994
995
def flash_attn_varlen_kvpacked_func(
    q,
    kv,
    cu_seqlens_q,
    cu_seqlens_k,
    max_seqlen_q,
    max_seqlen_k,
    dropout_p=0.0,
    softmax_scale=None,
    causal=False,
    window_size=(-1, -1),  # -1 means infinite context window
    softcap=0.0,  # 0.0 means deactivated
    alibi_slopes=None,
    deterministic=False,
    return_attn_probs=False,
    *,
    out=None,
):
    """dropout_p should be set to 0.0 during evaluation
    If K, V are already stacked into 1 tensor, this function will be faster than
    calling flash_attn_varlen_func on Q, K, V since the backward pass avoids explicit
    concatenation of the gradients of K, V.
    Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads
    than Q. Note that the number of heads in Q must be divisible by the number of heads in KV.
    For example, if Q has 6 heads and K, V have 2 heads, heads 0, 1, 2 of Q will attend to head
    0 of K, V, and heads 3, 4, 5 of Q will attend to head 1 of K, V.

    If causal=True, the causal mask is aligned to the bottom right corner of the attention matrix.
    For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is:
        1 1 1 1 0
        1 1 1 1 1
    If seqlen_q = 5 and seqlen_k = 2, the causal mask is:
        0 0
        0 0
        0 0
        1 0
        1 1
    If the row of the mask is all zero, the output will be zero.

    If window_size != (-1, -1), implements sliding window local attention. Query at position i
    will only attend to keys between
    [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive.

    Arguments:
        q: (total_q, nheads, headdim), where total_q = total number of query tokens in the batch.
        kv: (total_k, 2, nheads_k, headdim), where total_k = total number of key tokens in the batch.
        cu_seqlens_q: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
           of the sequences in the batch, used to index into q.
        cu_seqlens_k: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
           of the sequences in the batch, used to index into kv.
        max_seqlen_q: int. Maximum query sequence length in the batch.
        max_seqlen_k: int. Maximum key sequence length in the batch.
        dropout_p: float. Dropout probability.
        softmax_scale: float. The scaling of QK^T before applying softmax.
            Default to 1 / sqrt(headdim).
        causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
        window_size: (left, right). If not (-1, -1), implements sliding window local attention.
        softcap: float. Anything > 0 activates softcapping attention.
        alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of
            (-alibi_slope * |i + seqlen_k - seqlen_q - j|)
            is added to the attention score of query i and key j.
        deterministic: bool. Whether to use the deterministic implementation of the backward pass,
            which is slightly slower and uses more memory. The forward pass is always deterministic.
        return_attn_probs: bool. Whether to return the attention probabilities. This option is for
           testing only. The returned probabilities are not guaranteed to be correct
           (they might not have the right scaling).
        out: keyword-only, optional. Forwarded unchanged to FlashAttnVarlenKVPackedFunc.apply;
            presumably a preallocated output tensor shaped like the returned `out`
            (TODO confirm against the kernel binding).
    Return:
        out: (total_q, nheads, headdim).
        softmax_lse [optional, if return_attn_probs=True]: (nheads, total_q_seqlen). The
            logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
            normalization factor).
        S_dmask [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen, seqlen).
            The output of softmax (possibly with different scaling). It also encodes the dropout
            pattern (negative means that location was dropped, nonnegative means it was kept).
    """
    # Positional order must match FlashAttnVarlenKVPackedFunc.forward's signature.
    return FlashAttnVarlenKVPackedFunc.apply(
        q,
        kv,
        cu_seqlens_q,
        cu_seqlens_k,
        max_seqlen_q,
        max_seqlen_k,
        dropout_p,
        softmax_scale,
        causal,
        window_size,
        softcap,
        alibi_slopes,
        deterministic,
        return_attn_probs,
        out,
    )
Tri Dao's avatar
Tri Dao committed
1078

1079

Tri Dao's avatar
Tri Dao committed
1080
1081
1082
1083
1084
1085
1086
1087
1088
1089
1090
def flash_attn_varlen_func(
    q,
    k,
    v,
    cu_seqlens_q,
    cu_seqlens_k,
    max_seqlen_q,
    max_seqlen_k,
    dropout_p=0.0,
    softmax_scale=None,
    causal=False,
    window_size=(-1, -1),  # -1 means infinite context window
    softcap=0.0,  # 0.0 means deactivated
    alibi_slopes=None,
    deterministic=False,
    return_attn_probs=False,
    block_table=None,
    *,
    out=None,
):
    """dropout_p should be set to 0.0 during evaluation
    Supports multi-query and grouped-query attention (MQA/GQA) by passing in K, V with fewer heads
    than Q. Note that the number of heads in Q must be divisible by the number of heads in KV.
    For example, if Q has 6 heads and K, V have 2 heads, heads 0, 1, 2 of Q will attend to head
    0 of K, V, and heads 3, 4, 5 of Q will attend to head 1 of K, V.

    If causal=True, the causal mask is aligned to the bottom right corner of the attention matrix.
    For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is:
        1 1 1 1 0
        1 1 1 1 1
    If seqlen_q = 5 and seqlen_k = 2, the causal mask is:
        0 0
        0 0
        0 0
        1 0
        1 1
    If the row of the mask is all zero, the output will be zero.

    If window_size != (-1, -1), implements sliding window local attention. Query at position i
    will only attend to keys between
    [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive.

    Arguments:
        q: (total_q, nheads, headdim), where total_q = total number of query tokens in the batch.
        k: (total_k, nheads_k, headdim), where total_k = total number of key tokens in the batch.
        v: (total_k, nheads_k, headdim), where total_k = total number of key tokens in the batch.
        cu_seqlens_q: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
           of the sequences in the batch, used to index into q.
        cu_seqlens_k: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
           of the sequences in the batch, used to index into k and v.
        max_seqlen_q: int. Maximum query sequence length in the batch.
        max_seqlen_k: int. Maximum key sequence length in the batch.
        dropout_p: float. Dropout probability.
        softmax_scale: float. The scaling of QK^T before applying softmax.
            Default to 1 / sqrt(headdim).
        causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
        window_size: (left, right). If not (-1, -1), implements sliding window local attention.
        softcap: float. Anything > 0 activates softcapping attention.
        alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of
            (-alibi_slope * |i + seqlen_k - seqlen_q - j|)
            is added to the attention score of query i and key j.
        deterministic: bool. Whether to use the deterministic implementation of the backward pass,
            which is slightly slower and uses more memory. The forward pass is always deterministic.
        return_attn_probs: bool. Whether to return the attention probabilities. This option is for
           testing only. The returned probabilities are not guaranteed to be correct
           (they might not have the right scaling).
        block_table: optional. Forwarded unchanged to FlashAttnVarlenFunc.apply. Presumably a
            paged-KV block table mapping each sequence to cache blocks, in which case k and v
            would be paged caches rather than (total_k, ...) tensors (TODO confirm shape/dtype
            against the kernel binding).
        out: keyword-only, optional. Forwarded unchanged to FlashAttnVarlenFunc.apply;
            presumably a preallocated output tensor shaped like the returned `out`
            (TODO confirm against the kernel binding).
    Return:
        out: (total_q, nheads, headdim).
        softmax_lse [optional, if return_attn_probs=True]: (nheads, total_q_seqlen). The
            logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
            normalization factor).
        S_dmask [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen, seqlen).
            The output of softmax (possibly with different scaling). It also encodes the dropout
            pattern (negative means that location was dropped, nonnegative means it was kept).
    """
    # Positional order must match FlashAttnVarlenFunc.forward's signature.
    return FlashAttnVarlenFunc.apply(
        q,
        k,
        v,
        cu_seqlens_q,
        cu_seqlens_k,
        max_seqlen_q,
        max_seqlen_k,
        dropout_p,
        softmax_scale,
        causal,
        window_size,
        softcap,
        alibi_slopes,
        deterministic,
        return_attn_probs,
        block_table,
        out,
    )
Tri Dao's avatar
Tri Dao committed
1174
1175
1176
1177
1178
1179
1180
1181


def flash_attn_with_kvcache(
    q,
    k_cache,
    v_cache,
    k=None,
    v=None,
    rotary_cos=None,
    rotary_sin=None,
    cache_seqlens: Optional[Union[int, torch.Tensor]] = None,
    cache_batch_idx: Optional[torch.Tensor] = None,
    block_table: Optional[torch.Tensor] = None,
    softmax_scale=None,
    causal=False,
    window_size=(-1, -1),  # -1 means infinite context window
    softcap=0.0,  # 0.0 means deactivated
    rotary_interleaved=True,
    alibi_slopes=None,
    num_splits=0,
    return_softmax_lse=False,
    *,
    out=None,
):
    """
    If k and v are not None, k_cache and v_cache will be updated *inplace* with the new values from
    k and v. This is useful for incremental decoding: you can pass in the cached keys/values from
    the previous step, and update them with the new keys/values from the current step, and do
    attention with the updated cache, all in 1 kernel.

    If you pass in k / v, you must make sure that the cache is large enough to hold the new values.
    For example, the KV cache could be pre-allocated with the max sequence length, and you can use
    cache_seqlens to keep track of the current sequence lengths of each sequence in the batch.

    Also apply rotary embedding if rotary_cos and rotary_sin are passed in. The key @k will be
    rotated by rotary_cos and rotary_sin at indices cache_seqlens, cache_seqlens + 1, etc.
    If causal or local (i.e., window_size != (-1, -1)), the query @q will be rotated by rotary_cos
    and rotary_sin at indices cache_seqlens, cache_seqlens + 1, etc.
    If not causal and not local, the query @q will be rotated by rotary_cos and rotary_sin at
    indices cache_seqlens only (i.e. we consider all tokens in @q to be at position cache_seqlens).

    See tests/test_flash_attn.py::test_flash_attn_kvcache for examples of how to use this function.

    Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads
    than Q. Note that the number of heads in Q must be divisible by the number of heads in KV.
    For example, if Q has 6 heads and K, V have 2 heads, head 0, 1, 2 of Q will attend to head
    0 of K, V, and head 3, 4, 5 of Q will attend to head 1 of K, V.

    If causal=True, the causal mask is aligned to the bottom right corner of the attention matrix.
    For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is:
        1 1 1 1 0
        1 1 1 1 1
    If seqlen_q = 5 and seqlen_k = 2, the causal mask is:
        0 0
        0 0
        0 0
        1 0
        1 1
    If the row of the mask is all zero, the output will be zero.

    If window_size != (-1, -1), implements sliding window local attention. Query at position i
    will only attend to keys between
    [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive.

    Note: Does not support backward pass.

    Arguments:
        q: (batch_size, seqlen, nheads, headdim)
        k_cache: (batch_size_cache, seqlen_cache, nheads_k, headdim) if there's no block_table,
            or (num_blocks, page_block_size, nheads_k, headdim) if there's a block_table (i.e. paged KV cache)
            page_block_size must be a multiple of 256.
        v_cache: (batch_size_cache, seqlen_cache, nheads_k, headdim) if there's no block_table,
            or (num_blocks, page_block_size, nheads_k, headdim) if there's a block_table (i.e. paged KV cache)
        k [optional]: (batch_size, seqlen_new, nheads_k, headdim). If not None, we concatenate
            k with k_cache, starting at the indices specified by cache_seqlens.
        v [optional]: (batch_size, seqlen_new, nheads_k, headdim). Similar to k.
        rotary_cos [optional]: (seqlen_ro, rotary_dim / 2). If not None, we apply rotary embedding
            to k and q. Only applicable if k and v are passed in. rotary_dim must be divisible by 16.
        rotary_sin [optional]: (seqlen_ro, rotary_dim / 2). Similar to rotary_cos.
        cache_seqlens: int, or (batch_size,), dtype torch.int32. The sequence lengths of the
            KV cache. An int is broadcast to a (batch_size,) int32 tensor on k_cache's device.
        cache_batch_idx: (batch_size,), dtype torch.int32. The indices used to index into the KV cache.
            If None, we assume that the batch indices are [0, 1, 2, ..., batch_size - 1].
            If the indices are not distinct, and k and v are provided, the values updated in the cache
                 might come from any of the duplicate indices.
        block_table [optional]: (batch_size, max_num_blocks_per_seq), dtype torch.int32.
            When given, k_cache / v_cache are interpreted in the paged layout described above.
        softmax_scale: float. The scaling of QK^T before applying softmax.
            Default to 1 / sqrt(headdim).
        causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
        window_size: (left, right). If not (-1, -1), implements sliding window local attention.
        softcap: float. Anything > 0 activates softcapping attention.
        rotary_interleaved: bool. Only applicable if rotary_cos and rotary_sin are passed in.
            If True, rotary embedding will combine dimensions 0 & 1, 2 & 3, etc. If False,
            rotary embedding will combine dimensions 0 & rotary_dim / 2, 1 & rotary_dim / 2 + 1
            (i.e. GPT-NeoX style).
        alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of
            (-alibi_slope * |i + seqlen_k - seqlen_q - j|)
            is added to the attention score of query i and key j.
        num_splits: int. If > 1, split the key/value into this many chunks along the sequence.
           If num_splits == 1, we don't split the key/value. If num_splits == 0, we use a heuristic
           to automatically determine the number of splits.
           Don't change this unless you know what you are doing.
        return_softmax_lse: bool. Whether to return the logsumexp of the attention scores.
        out [optional, keyword-only]: passed straight through to the kernel as the output
            buffer. Presumably it must match the returned out's shape and dtype
            (batch_size, seqlen, nheads, headdim); TODO confirm against the kernel.

    Return:
        out: (batch_size, seqlen, nheads, headdim).
        softmax_lse [optional, if return_softmax_lse=True]: (batch_size, nheads, seqlen). The
            logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
            normalization factor).
    """
    assert k_cache.stride(-1) == 1, "k_cache must have contiguous last dimension"
    assert v_cache.stride(-1) == 1, "v_cache must have contiguous last dimension"
    q, k, v = [maybe_contiguous(x) for x in (q, k, v)]
    if softmax_scale is None:
        softmax_scale = q.shape[-1] ** (-0.5)
    if cache_seqlens is not None and isinstance(cache_seqlens, int):
        # Broadcast a scalar length to one entry per query sequence. The batch
        # size is taken from q, matching the documented (batch_size,) shape:
        # k_cache.shape[0] is num_blocks for a paged cache, and batch_size_cache
        # (which may differ from batch_size) when cache_batch_idx is used.
        cache_seqlens = torch.full(
            (q.shape[0],), cache_seqlens, dtype=torch.int32, device=k_cache.device
        )
    # Run at function level so that caller-supplied tensors with a non-unit
    # last stride are also made contiguous before reaching the kernel.
    cache_seqlens = maybe_contiguous(cache_seqlens)
    cache_batch_idx = maybe_contiguous(cache_batch_idx)
    block_table = maybe_contiguous(block_table)
    out, softmax_lse = torch.ops.vllm_flash_attn_c.fwd_kvcache(
        q,
        k_cache,
        v_cache,
        k,
        v,
        cache_seqlens,
        rotary_cos,
        rotary_sin,
        cache_batch_idx,
        block_table,
        alibi_slopes,
        out,
        softmax_scale,
        causal,
        window_size[0],
        window_size[1],
        softcap,
        rotary_interleaved,
        num_splits,
    )
    return (out, softmax_lse) if return_softmax_lse else out