flash_attn_interface.py 46.7 KB
Newer Older
1
2
# Copyright (c) 2023, Tri Dao.

Tri Dao's avatar
Tri Dao committed
3
4
from typing import Optional, Union

Tri Dao's avatar
Tri Dao committed
5
6
import torch
import torch.nn as nn
Tri Dao's avatar
Tri Dao committed
7

8
9
# isort: off
# We need to import the CUDA kernels after importing torch
10
11
# Use relative import to support build-from-source installation in vLLM
from . import vllm_flash_attn_c # noqa: F401
Tri Dao's avatar
Tri Dao committed
12

13
14
# isort: on

15
16
def maybe_contiguous(x):
    """Ensure the last dimension of ``x`` has unit stride.

    ``None`` and tensors whose last-dim stride is already 1 are returned
    unchanged (same object); anything else is replaced by ``x.contiguous()``.
    """
    if x is None or x.stride(-1) == 1:
        return x
    return x.contiguous()
Tri Dao's avatar
Tri Dao committed
17

Tri Dao's avatar
Tri Dao committed
18
def _get_block_size_n(device, head_dim, is_dropout, is_causal):
Tri Dao's avatar
Tri Dao committed
19
20
21
22
23
24
25
    # This should match the block sizes in the CUDA kernel
    assert head_dim <= 256
    major, minor = torch.cuda.get_device_capability(device)
    is_sm8x = major == 8 and minor > 0  # Only include sm86 and sm89, exclude sm80 (A100)
    is_sm80 = major == 8 and minor == 0
    is_sm90 = major == 9 and minor == 0
    if head_dim <= 32:
Tri Dao's avatar
Tri Dao committed
26
        return 128
Tri Dao's avatar
Tri Dao committed
27
    if head_dim <= 64:
Tri Dao's avatar
Tri Dao committed
28
        return 128 if not is_dropout else 64
Tri Dao's avatar
Tri Dao committed
29
    elif head_dim <= 96:
Tri Dao's avatar
Tri Dao committed
30
        return 64
Tri Dao's avatar
Tri Dao committed
31
32
    elif head_dim <= 128:
        if is_sm8x:
Tri Dao's avatar
Tri Dao committed
33
            return 64 if (not is_dropout and is_causal) else 32
Tri Dao's avatar
Tri Dao committed
34
        else:
Tri Dao's avatar
Tri Dao committed
35
            return 64 if not is_dropout else 32
Tri Dao's avatar
Tri Dao committed
36
37
    elif head_dim <= 160:
        if is_sm8x:
Tri Dao's avatar
Tri Dao committed
38
            return 64
Tri Dao's avatar
Tri Dao committed
39
        else:
Tri Dao's avatar
Tri Dao committed
40
            return 32
Tri Dao's avatar
Tri Dao committed
41
    elif head_dim <= 192:
Tri Dao's avatar
Tri Dao committed
42
        return 64
Tri Dao's avatar
Tri Dao committed
43
    elif head_dim <= 224:
Tri Dao's avatar
Tri Dao committed
44
        return 64
Tri Dao's avatar
Tri Dao committed
45
    elif head_dim <= 256:
Tri Dao's avatar
Tri Dao committed
46
        return 64
Tri Dao's avatar
Tri Dao committed
47
48


Tri Dao's avatar
Tri Dao committed
49
def _flash_attn_forward(
    q, k, v, dropout_p, softmax_scale, causal, window_size, softcap, alibi_slopes, return_softmax, *, out=None
):
    """Call the dense (fixed-length) FlashAttention forward op.

    Thin wrapper over ``torch.ops.vllm_flash_attn_c.fwd``. ``q``, ``k`` and
    ``v`` are first made unit-stride in their last dimension, and
    ``window_size`` is split into its left/right components.

    Returns:
        The 8-tuple ``(out, q, k, v, out_padded, softmax_lse, S_dmask,
        rng_state)`` produced by the op.
    """
    # The kernel wants a unit stride in the last dimension of q, k and v.
    q, k, v = map(maybe_contiguous, (q, k, v))
    (
        out,
        q,
        k,
        v,
        out_padded,
        softmax_lse,
        S_dmask,
        rng_state,
    ) = torch.ops.vllm_flash_attn_c.fwd(
        q, k, v, out, alibi_slopes,
        dropout_p, softmax_scale, causal,
        window_size[0], window_size[1],
        softcap,
        return_softmax,
        None,  # presumably the optional RNG generator slot — confirm against the op schema
    )
    return out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state
Tri Dao's avatar
Tri Dao committed
69
70


Tri Dao's avatar
Tri Dao committed
71
72
73
74
75
76
77
78
79
80
81
def _flash_attn_varlen_forward(
    q,
    k,
    v,
    cu_seqlens_q,
    cu_seqlens_k,
    max_seqlen_q,
    max_seqlen_k,
    dropout_p,
    softmax_scale,
    causal,
    window_size,
    softcap,
    alibi_slopes,
    return_softmax,
    block_table,
    *,
    out=None
):
    """Call the variable-length FlashAttention forward op.

    Thin wrapper over ``torch.ops.vllm_flash_attn_c.varlen_fwd``.
    ``cu_seqlens_q`` / ``cu_seqlens_k`` (presumably cumulative sequence-length
    offsets — confirm) and ``block_table`` are forwarded to the op unchanged.

    Returns:
        An 8-tuple laid out like ``_flash_attn_forward``'s result:
        ``(out, q, k, v, None, softmax_lse, None, None)``. Only ``out`` and
        ``softmax_lse`` come from the op; the ``None`` slots stand in for
        ``out_padded``, ``S_dmask`` and ``rng_state``.
    """
    # The kernel wants a unit stride in the last dimension of q, k and v.
    q, k, v = map(maybe_contiguous, (q, k, v))
    out, softmax_lse = torch.ops.vllm_flash_attn_c.varlen_fwd(
        q, k, v, out,
        cu_seqlens_q, cu_seqlens_k,
        None,  # presumably upstream's optional ``seqused_k`` — confirm against the op schema
        block_table,
        alibi_slopes,
        max_seqlen_q, max_seqlen_k,
        dropout_p, softmax_scale,
        False,  # presumably upstream's ``zero_tensors`` flag — confirm
        causal,
        window_size[0], window_size[1],
        softcap,
        return_softmax,
        None,  # presumably the optional RNG generator slot — confirm
    )
    # NOTE(woosuk): out_padded, S_dmask, and rng_state are None
    # because we only use the forward pass in the vLLM.
    return out, q, k, v, None, softmax_lse, None, None
Tri Dao's avatar
Tri Dao committed
118
119


Tri Dao's avatar
Tri Dao committed
120
def _flash_attn_backward(
    dout,
    q,
    k,
    v,
    out,
    softmax_lse,
    dq,
    dk,
    dv,
    dropout_p,
    softmax_scale,
    causal,
    window_size,
    softcap,
    alibi_slopes,
    deterministic,
    rng_state=None,
):
    """Call the dense (fixed-length) FlashAttention backward op.

    Thin wrapper over ``torch.ops.vllm_flash_attn_c.bwd``. ``dq``, ``dk`` and
    ``dv`` are caller-allocated buffers passed straight to the op.

    Returns:
        The 4-tuple ``(dq, dk, dv, softmax_d)`` produced by the op.
    """
    # Only the inputs need fixing up: dq/dk/dv are allocated by the callers
    # and should already be contiguous.
    dout, q, k, v, out = map(maybe_contiguous, (dout, q, k, v, out))
    dq, dk, dv, softmax_d = torch.ops.vllm_flash_attn_c.bwd(
        dout, q, k, v, out, softmax_lse,
        dq, dk, dv,
        alibi_slopes,
        dropout_p, softmax_scale, causal,
        window_size[0], window_size[1],
        softcap,
        deterministic,
        None,  # presumably the optional RNG generator slot — confirm against the op schema
        rng_state,
    )
    return dq, dk, dv, softmax_d


Tri Dao's avatar
Tri Dao committed
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
def _flash_attn_varlen_backward(
    dout,
    q,
    k,
    v,
    out,
    softmax_lse,
    dq,
    dk,
    dv,
    cu_seqlens_q,
    cu_seqlens_k,
    max_seqlen_q,
    max_seqlen_k,
    dropout_p,
    softmax_scale,
    causal,
    window_size,
    softcap,
    alibi_slopes,
    deterministic,
    rng_state=None,
):
    """Call the variable-length FlashAttention backward op.

    Thin wrapper over ``torch.ops.vllm_flash_attn_c.varlen_bwd``. ``dq``,
    ``dk`` and ``dv`` are caller-allocated buffers passed straight to the op.

    Returns:
        The 4-tuple ``(dq, dk, dv, softmax_d)`` produced by the op.
    """
    # Only the inputs need fixing up: dq/dk/dv are allocated by the callers
    # and should already be contiguous.
    dout, q, k, v, out = map(maybe_contiguous, (dout, q, k, v, out))
    dq, dk, dv, softmax_d = torch.ops.vllm_flash_attn_c.varlen_bwd(
        dout, q, k, v, out, softmax_lse,
        dq, dk, dv,
        cu_seqlens_q, cu_seqlens_k,
        alibi_slopes,
        max_seqlen_q, max_seqlen_k,
        dropout_p, softmax_scale,
        False,  # presumably upstream's ``zero_tensors`` flag — confirm against the op schema
        causal,
        window_size[0], window_size[1],
        softcap,
        deterministic,
        None,  # presumably the optional RNG generator slot — confirm
        rng_state,
    )
    return dq, dk, dv, softmax_d
Tri Dao's avatar
Tri Dao committed
229
230


Tri Dao's avatar
Tri Dao committed
231
class FlashAttnQKVPackedFunc(torch.autograd.Function):
    """Autograd function for dense FlashAttention on a packed ``qkv`` tensor.

    q, k and v are the slices ``qkv[:, :, 0]``, ``qkv[:, :, 1]`` and
    ``qkv[:, :, 2]``. The backward pass writes all three gradients into one
    packed buffer of matching layout, which avoids an explicit concatenation.
    """

    @staticmethod
    def forward(
        ctx,
        qkv,
        dropout_p,
        softmax_scale,
        causal,
        window_size,
        softcap,
        alibi_slopes,
        deterministic,
        return_softmax,
        out=None,
    ):
        """Run the forward kernel and stash what backward needs on ``ctx``.

        ``out`` is deliberately positional-or-keyword rather than keyword-only.
        ``Function.apply`` does not accept keyword arguments, so callers such as
        ``flash_attn_qkvpacked_func`` must pass ``out`` positionally.

        Returns:
            ``out``, or ``(out, softmax_lse, S_dmask)`` when ``return_softmax``.
        """
        if softmax_scale is None:
            softmax_scale = qkv.shape[-1] ** (-0.5)
        out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = _flash_attn_forward(
            qkv[:, :, 0],
            qkv[:, :, 1],
            qkv[:, :, 2],
            dropout_p,
            softmax_scale,
            causal=causal,
            window_size=window_size,
            softcap=softcap,
            alibi_slopes=alibi_slopes,
            return_softmax=return_softmax and dropout_p > 0,
            out=out,
        )
        ctx.save_for_backward(q, k, v, out_padded, softmax_lse, rng_state)
        ctx.dropout_p = dropout_p
        ctx.softmax_scale = softmax_scale
        ctx.causal = causal
        ctx.window_size = window_size
        ctx.softcap = softcap
        ctx.alibi_slopes = alibi_slopes
        ctx.deterministic = deterministic
        return out if not return_softmax else (out, softmax_lse, S_dmask)

    @staticmethod
    def backward(ctx, dout, *args):
        """Return the packed ``dqkv`` gradient plus ``None`` for every other input."""
        q, k, v, out, softmax_lse, rng_state = ctx.saved_tensors
        qkv_shape = q.shape[:-2] + (3, *q.shape[-2:])
        dqkv = torch.empty(qkv_shape, dtype=q.dtype, device=q.device)
        _flash_attn_backward(
            dout,
            q,
            k,
            v,
            out,
            softmax_lse,
            dqkv[:, :, 0],
            dqkv[:, :, 1],
            dqkv[:, :, 2],
            ctx.dropout_p,
            ctx.softmax_scale,
            ctx.causal,
            ctx.window_size,
            ctx.softcap,
            ctx.alibi_slopes,
            ctx.deterministic,
            rng_state=rng_state,
        )
        dqkv = dqkv[..., : dout.shape[-1]]  # We could have padded the head dimension
        # Autograd requires exactly one gradient per forward input. The input
        # count depends on whether ``out`` was passed, so derive it from ctx
        # instead of hard-coding it (the hard-coded count had fallen out of date).
        return (dqkv,) + (None,) * (len(ctx.needs_input_grad) - 1)
Tri Dao's avatar
Tri Dao committed
298
299
300
301


class FlashAttnVarlenQKVPackedFunc(torch.autograd.Function):
    """Autograd function for variable-length FlashAttention on a packed ``qkv``.

    q, k and v are the slices ``qkv[:, 0]``, ``qkv[:, 1]`` and ``qkv[:, 2]``.
    The same ``cu_seqlens`` / ``max_seqlen`` are used for queries and keys.

    NOTE(review): ``_flash_attn_varlen_forward`` in this file returns ``None``
    for ``out_padded`` and ``rng_state`` (vLLM uses the forward pass only), so
    ``backward`` would receive ``out=None``. Presumably backward is unsupported
    here — confirm before relying on it.
    """

    @staticmethod
    def forward(
        ctx,
        qkv,
        cu_seqlens,
        max_seqlen,
        dropout_p,
        softmax_scale,
        causal,
        window_size,
        softcap,
        alibi_slopes,
        deterministic,
        return_softmax,
        out=None,
    ):
        """Run the varlen forward kernel and stash what backward needs on ``ctx``.

        ``out`` is deliberately positional-or-keyword rather than keyword-only.
        ``Function.apply`` does not accept keyword arguments, so callers must
        pass ``out`` positionally.

        Returns:
            ``out``, or ``(out, softmax_lse, S_dmask)`` when ``return_softmax``.
        """
        if softmax_scale is None:
            softmax_scale = qkv.shape[-1] ** (-0.5)
        out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = _flash_attn_varlen_forward(
            qkv[:, 0],
            qkv[:, 1],
            qkv[:, 2],
            cu_seqlens,
            cu_seqlens,
            max_seqlen,
            max_seqlen,
            dropout_p,
            softmax_scale,
            causal=causal,
            window_size=window_size,
            softcap=softcap,
            alibi_slopes=alibi_slopes,
            return_softmax=return_softmax and dropout_p > 0,
            block_table=None,
            out=out,
        )
        ctx.save_for_backward(q, k, v, out_padded, softmax_lse, cu_seqlens, rng_state)
        ctx.dropout_p = dropout_p
        ctx.max_seqlen = max_seqlen
        ctx.softmax_scale = softmax_scale
        ctx.causal = causal
        ctx.window_size = window_size
        ctx.softcap = softcap
        ctx.alibi_slopes = alibi_slopes
        ctx.deterministic = deterministic
        return out if not return_softmax else (out, softmax_lse, S_dmask)

    @staticmethod
    def backward(ctx, dout, *args):
        """Return the packed ``dqkv`` gradient plus ``None`` for every other input."""
        q, k, v, out, softmax_lse, cu_seqlens, rng_state = ctx.saved_tensors
        qkv_shape = q.shape[:-2] + (3, *q.shape[-2:])
        dqkv = torch.empty(qkv_shape, dtype=q.dtype, device=q.device)
        _flash_attn_varlen_backward(
            dout,
            q,
            k,
            v,
            out,
            softmax_lse,
            dqkv[:, 0],
            dqkv[:, 1],
            dqkv[:, 2],
            cu_seqlens,
            cu_seqlens,
            ctx.max_seqlen,
            ctx.max_seqlen,
            ctx.dropout_p,
            ctx.softmax_scale,
            ctx.causal,
            ctx.window_size,
            ctx.softcap,
            ctx.alibi_slopes,
            ctx.deterministic,
            rng_state=rng_state,
        )
        dqkv = dqkv[..., : dout.shape[-1]]  # We could have padded the head dimension
        # Autograd requires exactly one gradient per forward input; derive the
        # count from ctx so it stays correct whether or not ``out`` was passed.
        return (dqkv,) + (None,) * (len(ctx.needs_input_grad) - 1)
Tri Dao's avatar
Tri Dao committed
379
380


Tri Dao's avatar
Tri Dao committed
381
class FlashAttnKVPackedFunc(torch.autograd.Function):
    """Autograd function for dense FlashAttention with a packed ``kv`` tensor.

    k and v are the slices ``kv[:, :, 0]`` and ``kv[:, :, 1]``. The backward
    pass writes both key/value gradients into one packed ``dkv`` buffer.
    """

    @staticmethod
    def forward(
        ctx,
        q,
        kv,
        dropout_p,
        softmax_scale,
        causal,
        window_size,
        softcap,
        alibi_slopes,
        deterministic,
        return_softmax,
        out=None,
    ):
        """Run the forward kernel and stash what backward needs on ``ctx``.

        Returns:
            ``out``, or ``(out, softmax_lse, S_dmask)`` when ``return_softmax``.
        """
        if softmax_scale is None:
            softmax_scale = q.shape[-1] ** (-0.5)
        out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = _flash_attn_forward(
            q,
            kv[:, :, 0],
            kv[:, :, 1],
            dropout_p,
            softmax_scale,
            causal=causal,
            window_size=window_size,
            softcap=softcap,
            alibi_slopes=alibi_slopes,
            return_softmax=return_softmax and dropout_p > 0,
            out=out,
        )
        ctx.save_for_backward(q, k, v, out_padded, softmax_lse, rng_state)
        ctx.dropout_p = dropout_p
        ctx.softmax_scale = softmax_scale
        ctx.causal = causal
        ctx.window_size = window_size
        ctx.softcap = softcap
        ctx.alibi_slopes = alibi_slopes
        ctx.deterministic = deterministic
        return out if not return_softmax else (out, softmax_lse, S_dmask)

    @staticmethod
    def backward(ctx, dout, *args):
        """Return ``dq``, packed ``dkv``, and ``None`` for every other input."""
        q, k, v, out, softmax_lse, rng_state = ctx.saved_tensors
        dq = torch.empty_like(q)
        kv_shape = k.shape[:-2] + (2, *k.shape[-2:])
        dkv = torch.empty(kv_shape, dtype=k.dtype, device=k.device)
        _flash_attn_backward(
            dout,
            q,
            k,
            v,
            out,
            softmax_lse,
            dq,
            dkv[:, :, 0],
            dkv[:, :, 1],
            ctx.dropout_p,
            ctx.softmax_scale,
            ctx.causal,
            ctx.window_size,
            ctx.softcap,
            ctx.alibi_slopes,
            ctx.deterministic,
            rng_state=rng_state,
        )
        dq = dq[..., : dout.shape[-1]]  # We could have padded the head dimension
        dkv = dkv[..., : dout.shape[-1]]
        # Autograd requires exactly one gradient per forward input. The previous
        # hard-coded count missed the ``out`` input, so derive it from ctx.
        return (dq, dkv) + (None,) * (len(ctx.needs_input_grad) - 2)
Tri Dao's avatar
Tri Dao committed
450
451


Tri Dao's avatar
Tri Dao committed
452
class FlashAttnVarlenKVPackedFunc(torch.autograd.Function):
    """Autograd function for variable-length FlashAttention with a packed ``kv``.

    k and v are the slices ``kv[:, 0]`` and ``kv[:, 1]``.

    NOTE(review): ``_flash_attn_varlen_forward`` in this file returns ``None``
    for ``out_padded`` and ``rng_state`` (vLLM uses the forward pass only), so
    ``backward`` would receive ``out=None``. Presumably backward is unsupported
    here — confirm before relying on it.
    """

    @staticmethod
    def forward(
        ctx,
        q,
        kv,
        cu_seqlens_q,
        cu_seqlens_k,
        max_seqlen_q,
        max_seqlen_k,
        dropout_p,
        softmax_scale,
        causal,
        window_size,
        softcap,
        alibi_slopes,
        deterministic,
        return_softmax,
        out=None,
    ):
        """Run the varlen forward kernel and stash what backward needs on ``ctx``.

        Returns:
            ``out``, or ``(out, softmax_lse, S_dmask)`` when ``return_softmax``.
        """
        if softmax_scale is None:
            softmax_scale = q.shape[-1] ** (-0.5)
        out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = _flash_attn_varlen_forward(
            q,
            kv[:, 0],
            kv[:, 1],
            cu_seqlens_q,
            cu_seqlens_k,
            max_seqlen_q,
            max_seqlen_k,
            dropout_p,
            softmax_scale,
            causal=causal,
            window_size=window_size,
            softcap=softcap,
            alibi_slopes=alibi_slopes,
            return_softmax=return_softmax and dropout_p > 0,
            block_table=None,
            out=out,
        )
        ctx.save_for_backward(
            q, k, v, out_padded, softmax_lse, cu_seqlens_q, cu_seqlens_k, rng_state
        )
        ctx.dropout_p = dropout_p
        ctx.max_seqlen_q = max_seqlen_q
        ctx.max_seqlen_k = max_seqlen_k
        ctx.softmax_scale = softmax_scale
        ctx.causal = causal
        ctx.window_size = window_size
        ctx.softcap = softcap
        ctx.alibi_slopes = alibi_slopes
        ctx.deterministic = deterministic
        return out if not return_softmax else (out, softmax_lse, S_dmask)

    @staticmethod
    def backward(ctx, dout, *args):
        """Return ``dq``, packed ``dkv``, and ``None`` for every other input."""
        q, k, v, out, softmax_lse, cu_seqlens_q, cu_seqlens_k, rng_state = ctx.saved_tensors
        dq = torch.empty_like(q)
        kv_shape = k.shape[:-2] + (2, *k.shape[-2:])
        dkv = torch.empty(kv_shape, dtype=k.dtype, device=k.device)
        _flash_attn_varlen_backward(
            dout,
            q,
            k,
            v,
            out,
            softmax_lse,
            dq,
            dkv[:, 0],
            dkv[:, 1],
            cu_seqlens_q,
            cu_seqlens_k,
            ctx.max_seqlen_q,
            ctx.max_seqlen_k,
            ctx.dropout_p,
            ctx.softmax_scale,
            ctx.causal,
            ctx.window_size,
            ctx.softcap,
            ctx.alibi_slopes,
            ctx.deterministic,
            rng_state=rng_state,
        )
        dq = dq[..., : dout.shape[-1]]  # We could have padded the head dimension
        dkv = dkv[..., : dout.shape[-1]]
        # Autograd requires exactly one gradient per forward input. The previous
        # hard-coded count missed the ``out`` input, so derive it from ctx.
        return (dq, dkv) + (None,) * (len(ctx.needs_input_grad) - 2)
Tri Dao's avatar
Tri Dao committed
538
539
540
541


class FlashAttnFunc(torch.autograd.Function):
    """Autograd function for dense FlashAttention on separate q, k, v tensors."""

    @staticmethod
    def forward(
        ctx,
        q,
        k,
        v,
        dropout_p,
        softmax_scale,
        causal,
        window_size,
        softcap,
        alibi_slopes,
        deterministic,
        return_softmax,
        out=None,
    ):
        """Run the forward kernel and stash what backward needs on ``ctx``.

        Returns:
            ``out``, or ``(out, softmax_lse, S_dmask)`` when ``return_softmax``.
        """
        if softmax_scale is None:
            softmax_scale = q.shape[-1] ** (-0.5)
        out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = _flash_attn_forward(
            q,
            k,
            v,
            dropout_p,
            softmax_scale,
            causal=causal,
            window_size=window_size,
            softcap=softcap,
            alibi_slopes=alibi_slopes,
            return_softmax=return_softmax and dropout_p > 0,
            out=out,
        )
        ctx.save_for_backward(q, k, v, out_padded, softmax_lse, rng_state)
        ctx.dropout_p = dropout_p
        ctx.softmax_scale = softmax_scale
        ctx.causal = causal
        ctx.window_size = window_size
        ctx.softcap = softcap
        ctx.alibi_slopes = alibi_slopes
        ctx.deterministic = deterministic
        return out if not return_softmax else (out, softmax_lse, S_dmask)

    @staticmethod
    def backward(ctx, dout, *args):
        """Return ``dq``, ``dk``, ``dv`` and ``None`` for every other input."""
        q, k, v, out, softmax_lse, rng_state = ctx.saved_tensors
        dq, dk, dv = torch.empty_like(q), torch.empty_like(k), torch.empty_like(v)
        _flash_attn_backward(
            dout,
            q,
            k,
            v,
            out,
            softmax_lse,
            dq,
            dk,
            dv,
            ctx.dropout_p,
            ctx.softmax_scale,
            ctx.causal,
            ctx.window_size,
            ctx.softcap,
            ctx.alibi_slopes,
            ctx.deterministic,
            rng_state=rng_state,
        )
        dq = dq[..., : dout.shape[-1]]  # We could have padded the head dimension
        dk = dk[..., : dout.shape[-1]]
        dv = dv[..., : dout.shape[-1]]
        # Autograd requires exactly one gradient per forward input. The previous
        # hard-coded count missed the ``out`` input, so derive it from ctx.
        return (dq, dk, dv) + (None,) * (len(ctx.needs_input_grad) - 3)
Tri Dao's avatar
Tri Dao committed
609
610


Tri Dao's avatar
Tri Dao committed
611
class FlashAttnVarlenFunc(torch.autograd.Function):
    """Autograd function for variable-length FlashAttention on separate q, k, v.

    ``block_table`` is forwarded to ``_flash_attn_varlen_forward`` unchanged.

    NOTE(review): ``_flash_attn_varlen_forward`` in this file returns ``None``
    for ``out_padded`` and ``rng_state`` (vLLM uses the forward pass only), so
    ``backward`` would receive ``out=None``. Presumably backward is unsupported
    here — confirm before relying on it.
    """

    @staticmethod
    def forward(
        ctx,
        q,
        k,
        v,
        cu_seqlens_q,
        cu_seqlens_k,
        max_seqlen_q,
        max_seqlen_k,
        dropout_p,
        softmax_scale,
        causal,
        window_size,
        softcap,
        alibi_slopes,
        deterministic,
        return_softmax,
        block_table,
        out=None,
    ):
        """Run the varlen forward kernel and stash what backward needs on ``ctx``.

        Returns:
            ``out``, or ``(out, softmax_lse, S_dmask)`` when ``return_softmax``.
        """
        if softmax_scale is None:
            softmax_scale = q.shape[-1] ** (-0.5)
        out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = _flash_attn_varlen_forward(
            q,
            k,
            v,
            cu_seqlens_q,
            cu_seqlens_k,
            max_seqlen_q,
            max_seqlen_k,
            dropout_p,
            softmax_scale,
            causal=causal,
            window_size=window_size,
            softcap=softcap,
            alibi_slopes=alibi_slopes,
            return_softmax=return_softmax and dropout_p > 0,
            block_table=block_table,
            out=out,
        )
        ctx.save_for_backward(
            q, k, v, out_padded, softmax_lse, cu_seqlens_q, cu_seqlens_k, rng_state
        )
        ctx.dropout_p = dropout_p
        ctx.max_seqlen_q = max_seqlen_q
        ctx.max_seqlen_k = max_seqlen_k
        ctx.softmax_scale = softmax_scale
        ctx.causal = causal
        ctx.window_size = window_size
        ctx.softcap = softcap
        ctx.alibi_slopes = alibi_slopes
        ctx.deterministic = deterministic
        return out if not return_softmax else (out, softmax_lse, S_dmask)

    @staticmethod
    def backward(ctx, dout, *args):
        """Return ``dq``, ``dk``, ``dv`` and ``None`` for every other input."""
        q, k, v, out, softmax_lse, cu_seqlens_q, cu_seqlens_k, rng_state = ctx.saved_tensors
        dq, dk, dv = torch.empty_like(q), torch.empty_like(k), torch.empty_like(v)
        _flash_attn_varlen_backward(
            dout,
            q,
            k,
            v,
            out,
            softmax_lse,
            dq,
            dk,
            dv,
            cu_seqlens_q,
            cu_seqlens_k,
            ctx.max_seqlen_q,
            ctx.max_seqlen_k,
            ctx.dropout_p,
            ctx.softmax_scale,
            ctx.causal,
            ctx.window_size,
            ctx.softcap,
            ctx.alibi_slopes,
            ctx.deterministic,
            rng_state=rng_state,
        )
        dq = dq[..., : dout.shape[-1]]  # We could have padded the head dimension
        dk = dk[..., : dout.shape[-1]]
        dv = dv[..., : dout.shape[-1]]
        # Autograd requires exactly one gradient per forward input. The previous
        # hard-coded count missed the ``out`` input, so derive it from ctx.
        return (dq, dk, dv) + (None,) * (len(ctx.needs_input_grad) - 3)
698
699


Tri Dao's avatar
Tri Dao committed
700
def flash_attn_qkvpacked_func(
    qkv,
    dropout_p=0.0,
    softmax_scale=None,
    causal=False,
    window_size=(-1, -1),  # -1 means infinite context window
    softcap=0.0,  # <=0.0 means deactivate
    alibi_slopes=None,
    deterministic=False,
    return_attn_probs=False,
    *,
    out=None,
):
    """FlashAttention on a single packed QKV tensor.

    Set dropout_p to 0.0 during evaluation.

    When Q, K and V are already stacked into one tensor, this is faster than
    calling flash_attn_func on Q, K, V, because the backward pass avoids
    explicitly concatenating the gradients of Q, K and V.

    For multi-query and grouped-query attention (MQA/GQA), see
    flash_attn_kvpacked_func and flash_attn_func.

    If window_size != (-1, -1), sliding-window local attention is used: the
    query at position i attends only to keys in
    [i - window_size[0], i + window_size[1]], inclusive.

    Arguments:
        qkv: (batch_size, seqlen, 3, nheads, headdim)
        dropout_p: float. Dropout probability.
        softmax_scale: float. Scaling applied to QK^T before the softmax.
            Defaults to 1 / sqrt(headdim).
        causal: bool. Whether to apply a causal attention mask (e.g. for
            auto-regressive modeling).
        window_size: (left, right). If not (-1, -1), enables sliding-window
            local attention.
        softcap: float. Any value > 0 enables softcapping attention.
        alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of
            (-alibi_slope * |i - j|) is added to the attention score of
            query i and key j.
        deterministic: bool. Whether to use the deterministic implementation
            of the backward pass, which is slightly slower and uses more
            memory. The forward pass is always deterministic.
        return_attn_probs: bool. Whether to return the attention
            probabilities. Intended for testing only; the returned
            probabilities are not guaranteed to be correct (they may not have
            the right scaling).
        out: optional tensor forwarded to the kernel as its output argument
            (keyword-only). Presumably a preallocated buffer for the result —
            confirm against the kernel.
    Return:
        out: (batch_size, seqlen, nheads, headdim).
        softmax_lse [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen).
            The logsumexp of each row of QK^T * scaling (i.e. the log of the
            softmax normalization factor).
        S_dmask [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen, seqlen).
            The softmax output (possibly with a different scaling). It also
            encodes the dropout pattern: negative means the location was
            dropped, nonnegative means it was kept.
    """
    # autograd.Function.apply takes positional arguments only, so every option,
    # ``out`` included, is forwarded positionally in forward()'s parameter order.
    apply_args = (
        qkv,
        dropout_p,
        softmax_scale,
        causal,
        window_size,
        softcap,
        alibi_slopes,
        deterministic,
        return_attn_probs,
        out,
    )
    return FlashAttnQKVPackedFunc.apply(*apply_args)
Tri Dao's avatar
Tri Dao committed
759
760


Tri Dao's avatar
Tri Dao committed
761
def flash_attn_kvpacked_func(
    q,
    kv,
    dropout_p=0.0,
    softmax_scale=None,
    causal=False,
    window_size=(-1, -1),  # -1 means infinite context window
    softcap=0.0,  # 0.0 means deactivated
    alibi_slopes=None,
    deterministic=False,
    return_attn_probs=False,
    *,
    out=None,
):
    """dropout_p should be set to 0.0 during evaluation
    If K, V are already stacked into 1 tensor, this function will be faster than
    calling flash_attn_func on Q, K, V since the backward pass avoids explicit concatenation
    of the gradients of K, V.
    Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads
    than Q. Note that the number of heads in Q must be divisible by the number of heads in KV.
    For example, if Q has 6 heads and K, V have 2 heads, heads 0, 1, 2 of Q will attend to head
    0 of K, V, and heads 3, 4, 5 of Q will attend to head 1 of K, V.

    If causal=True, the causal mask is aligned to the bottom right corner of the attention matrix.
    For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is:
        1 1 1 1 0
        1 1 1 1 1
    If seqlen_q = 5 and seqlen_k = 2, the causal mask is:
        0 0
        0 0
        0 0
        1 0
        1 1
    If the row of the mask is all zero, the output will be zero.

    If window_size != (-1, -1), implements sliding window local attention. Query at position i
    will only attend to keys between
    [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive.

    Arguments:
        q: (batch_size, seqlen, nheads, headdim)
        kv: (batch_size, seqlen, 2, nheads_k, headdim)
        dropout_p: float. Dropout probability.
        softmax_scale: float. The scaling of QK^T before applying softmax.
            Default to 1 / sqrt(headdim).
        causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
        window_size: (left, right). If not (-1, -1), implements sliding window local attention.
        softcap: float. Anything > 0 activates softcapping attention.
        alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of
            (-alibi_slope * |i + seqlen_k - seqlen_q - j|)
            is added to the attention score of query i and key j.
        deterministic: bool. Whether to use the deterministic implementation of the backward pass,
            which is slightly slower and uses more memory. The forward pass is always deterministic.
        return_attn_probs: bool. Whether to return the attention probabilities. This option is for
           testing only. The returned probabilities are not guaranteed to be correct
           (they might not have the right scaling).
        out: keyword-only, optional. Forwarded unchanged to FlashAttnKVPackedFunc; presumably a
            preallocated (batch_size, seqlen, nheads, headdim) buffer that receives the output
            (TODO confirm against the kernel binding).
    Return:
        out: (batch_size, seqlen, nheads, headdim).
        softmax_lse [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen). The
            logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
            normalization factor).
        S_dmask [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen, seqlen).
            The output of softmax (possibly with different scaling). It also encodes the dropout
            pattern (negative means that location was dropped, nonnegative means it was kept).
    """
    # Thin wrapper: all validation and dispatch happen inside the autograd Function.
    return FlashAttnKVPackedFunc.apply(
        q,
        kv,
        dropout_p,
        softmax_scale,
        causal,
        window_size,
        softcap,
        alibi_slopes,
        deterministic,
        return_attn_probs,
        out,
    )
Tri Dao's avatar
Tri Dao committed
839
840


Tri Dao's avatar
Tri Dao committed
841
def flash_attn_func(
    q,
    k,
    v,
    dropout_p=0.0,
    softmax_scale=None,
    causal=False,
    window_size=(-1, -1),  # -1 means infinite context window
    softcap=0.0,  # 0.0 means deactivated
    alibi_slopes=None,
    deterministic=False,
    return_attn_probs=False,
    *,
    out=None,
):
    """dropout_p should be set to 0.0 during evaluation
    Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads
    than Q. Note that the number of heads in Q must be divisible by the number of heads in KV.
    For example, if Q has 6 heads and K, V have 2 heads, heads 0, 1, 2 of Q will attend to head
    0 of K, V, and heads 3, 4, 5 of Q will attend to head 1 of K, V.

    If causal=True, the causal mask is aligned to the bottom right corner of the attention matrix.
    For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is:
        1 1 1 1 0
        1 1 1 1 1
    If seqlen_q = 5 and seqlen_k = 2, the causal mask is:
        0 0
        0 0
        0 0
        1 0
        1 1
    If the row of the mask is all zero, the output will be zero.

    If window_size != (-1, -1), implements sliding window local attention. Query at position i
    will only attend to keys between
    [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive.

    Arguments:
        q: (batch_size, seqlen, nheads, headdim)
        k: (batch_size, seqlen, nheads_k, headdim)
        v: (batch_size, seqlen, nheads_k, headdim)
        dropout_p: float. Dropout probability.
        softmax_scale: float. The scaling of QK^T before applying softmax.
            Default to 1 / sqrt(headdim).
        causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
        window_size: (left, right). If not (-1, -1), implements sliding window local attention.
        softcap: float. Anything > 0 activates softcapping attention.
        alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of
            (-alibi_slope * |i + seqlen_k - seqlen_q - j|)
            is added to the attention score of query i and key j.
        deterministic: bool. Whether to use the deterministic implementation of the backward pass,
            which is slightly slower and uses more memory. The forward pass is always deterministic.
        return_attn_probs: bool. Whether to return the attention probabilities. This option is for
           testing only. The returned probabilities are not guaranteed to be correct
           (they might not have the right scaling).
        out: keyword-only, optional. Forwarded unchanged to FlashAttnFunc; presumably a
            preallocated (batch_size, seqlen, nheads, headdim) buffer that receives the output
            (TODO confirm against the kernel binding).
    Return:
        out: (batch_size, seqlen, nheads, headdim).
        softmax_lse [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen). The
            logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
            normalization factor).
        S_dmask [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen, seqlen).
            The output of softmax (possibly with different scaling). It also encodes the dropout
            pattern (negative means that location was dropped, nonnegative means it was kept).
    """
    # Thin wrapper: all validation and dispatch happen inside the autograd Function.
    return FlashAttnFunc.apply(
        q,
        k,
        v,
        dropout_p,
        softmax_scale,
        causal,
        window_size,
        softcap,
        alibi_slopes,
        deterministic,
        return_attn_probs,
        out,
    )
Tri Dao's avatar
Tri Dao committed
918
919


Tri Dao's avatar
Tri Dao committed
920
921
922
923
924
925
926
def flash_attn_varlen_qkvpacked_func(
    qkv,
    cu_seqlens,
    max_seqlen,
    dropout_p=0.0,
    softmax_scale=None,
    causal=False,
    window_size=(-1, -1),  # -1 means infinite context window
    softcap=0.0,  # 0.0 means deactivated
    alibi_slopes=None,
    deterministic=False,
    return_attn_probs=False,
    *,
    out=None,
):
    """dropout_p should be set to 0.0 during evaluation
    If Q, K, V are already stacked into 1 tensor, this function will be faster than
    calling flash_attn_varlen_func on Q, K, V since the backward pass avoids explicit concatenation
    of the gradients of Q, K, V.
    For multi-query and grouped-query attention (MQA/GQA), please see
    flash_attn_varlen_kvpacked_func and flash_attn_varlen_func.

    If window_size != (-1, -1), implements sliding window local attention. Query at position i
    will only attend to keys between [i - window_size[0], i + window_size[1]] inclusive.

    Arguments:
        qkv: (total, 3, nheads, headdim), where total = total number of tokens in the batch.
        cu_seqlens: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
           of the sequences in the batch, used to index into qkv.
        max_seqlen: int. Maximum sequence length in the batch.
        dropout_p: float. Dropout probability.
        softmax_scale: float. The scaling of QK^T before applying softmax.
            Default to 1 / sqrt(headdim).
        causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
        window_size: (left, right). If not (-1, -1), implements sliding window local attention.
        softcap: float. Anything > 0 activates softcapping attention.
        alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of (-alibi_slope * |i - j|)
            is added to the attention score of query i and key j.
        deterministic: bool. Whether to use the deterministic implementation of the backward pass,
            which is slightly slower and uses more memory. The forward pass is always deterministic.
        return_attn_probs: bool. Whether to return the attention probabilities. This option is for
           testing only. The returned probabilities are not guaranteed to be correct
           (they might not have the right scaling).
        out: keyword-only, optional. Forwarded unchanged to FlashAttnVarlenQKVPackedFunc;
            presumably a preallocated (total, nheads, headdim) buffer that receives the output
            (TODO confirm against the kernel binding).
    Return:
        out: (total, nheads, headdim), where total matches the first dimension of qkv.
        softmax_lse [optional, if return_attn_probs=True]: (nheads, total_q_seqlen). The
            logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
            normalization factor).
        S_dmask [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen, seqlen).
            The output of softmax (possibly with different scaling). It also encodes the dropout
            pattern (negative means that location was dropped, nonnegative means it was kept).
    """
    # Thin wrapper: all validation and dispatch happen inside the autograd Function.
    return FlashAttnVarlenQKVPackedFunc.apply(
        qkv,
        cu_seqlens,
        max_seqlen,
        dropout_p,
        softmax_scale,
        causal,
        window_size,
        softcap,
        alibi_slopes,
        deterministic,
        return_attn_probs,
        out,
    )
Tri Dao's avatar
Tri Dao committed
986
987


Tri Dao's avatar
Tri Dao committed
988
989
990
991
992
993
994
995
996
997
def flash_attn_varlen_kvpacked_func(
    q,
    kv,
    cu_seqlens_q,
    cu_seqlens_k,
    max_seqlen_q,
    max_seqlen_k,
    dropout_p=0.0,
    softmax_scale=None,
    causal=False,
    window_size=(-1, -1),  # -1 means infinite context window
    softcap=0.0,  # 0.0 means deactivated
    alibi_slopes=None,
    deterministic=False,
    return_attn_probs=False,
    *,
    out=None,
):
    """dropout_p should be set to 0.0 during evaluation
    If K, V are already stacked into 1 tensor, this function will be faster than
    calling flash_attn_varlen_func on Q, K, V since the backward pass avoids explicit
    concatenation of the gradients of K, V.
    Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads
    than Q. Note that the number of heads in Q must be divisible by the number of heads in KV.
    For example, if Q has 6 heads and K, V have 2 heads, heads 0, 1, 2 of Q will attend to head
    0 of K, V, and heads 3, 4, 5 of Q will attend to head 1 of K, V.

    If causal=True, the causal mask is aligned to the bottom right corner of the attention matrix.
    For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is:
        1 1 1 1 0
        1 1 1 1 1
    If seqlen_q = 5 and seqlen_k = 2, the causal mask is:
        0 0
        0 0
        0 0
        1 0
        1 1
    If the row of the mask is all zero, the output will be zero.

    If window_size != (-1, -1), implements sliding window local attention. Query at position i
    will only attend to keys between
    [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive.

    Arguments:
        q: (total_q, nheads, headdim), where total_q = total number of query tokens in the batch.
        kv: (total_k, 2, nheads_k, headdim), where total_k = total number of key tokens in the batch.
        cu_seqlens_q: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
           of the sequences in the batch, used to index into q.
        cu_seqlens_k: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
           of the sequences in the batch, used to index into kv.
        max_seqlen_q: int. Maximum query sequence length in the batch.
        max_seqlen_k: int. Maximum key sequence length in the batch.
        dropout_p: float. Dropout probability.
        softmax_scale: float. The scaling of QK^T before applying softmax.
            Default to 1 / sqrt(headdim).
        causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
        window_size: (left, right). If not (-1, -1), implements sliding window local attention.
        softcap: float. Anything > 0 activates softcapping attention.
        alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of
            (-alibi_slope * |i + seqlen_k - seqlen_q - j|)
            is added to the attention score of query i and key j.
        deterministic: bool. Whether to use the deterministic implementation of the backward pass,
            which is slightly slower and uses more memory. The forward pass is always deterministic.
        return_attn_probs: bool. Whether to return the attention probabilities. This option is for
           testing only. The returned probabilities are not guaranteed to be correct
           (they might not have the right scaling).
        out: keyword-only, optional. Forwarded unchanged to FlashAttnVarlenKVPackedFunc;
            presumably a preallocated (total_q, nheads, headdim) buffer that receives the output
            (TODO confirm against the kernel binding).
    Return:
        out: (total_q, nheads, headdim).
        softmax_lse [optional, if return_attn_probs=True]: (nheads, total_q_seqlen). The
            logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
            normalization factor).
        S_dmask [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen, seqlen).
            The output of softmax (possibly with different scaling). It also encodes the dropout
            pattern (negative means that location was dropped, nonnegative means it was kept).
    """
    # Thin wrapper: all validation and dispatch happen inside the autograd Function.
    return FlashAttnVarlenKVPackedFunc.apply(
        q,
        kv,
        cu_seqlens_q,
        cu_seqlens_k,
        max_seqlen_q,
        max_seqlen_k,
        dropout_p,
        softmax_scale,
        causal,
        window_size,
        softcap,
        alibi_slopes,
        deterministic,
        return_attn_probs,
        out,
    )
Tri Dao's avatar
Tri Dao committed
1080

1081

Tri Dao's avatar
Tri Dao committed
1082
1083
1084
1085
1086
1087
1088
1089
1090
1091
1092
def flash_attn_varlen_func(
    q,
    k,
    v,
    cu_seqlens_q,
    cu_seqlens_k,
    max_seqlen_q,
    max_seqlen_k,
    dropout_p=0.0,
    softmax_scale=None,
    causal=False,
    window_size=(-1, -1),  # -1 means infinite context window
    softcap=0.0,  # 0.0 means deactivated
    alibi_slopes=None,
    deterministic=False,
    return_attn_probs=False,
    block_table=None,
    *,
    out=None,
):
    """dropout_p should be set to 0.0 during evaluation
    Supports multi-query and grouped-query attention (MQA/GQA) by passing in K, V with fewer heads
    than Q. Note that the number of heads in Q must be divisible by the number of heads in KV.
    For example, if Q has 6 heads and K, V have 2 heads, heads 0, 1, 2 of Q will attend to head
    0 of K, V, and heads 3, 4, 5 of Q will attend to head 1 of K, V.

    If causal=True, the causal mask is aligned to the bottom right corner of the attention matrix.
    For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is:
        1 1 1 1 0
        1 1 1 1 1
    If seqlen_q = 5 and seqlen_k = 2, the causal mask is:
        0 0
        0 0
        0 0
        1 0
        1 1
    If the row of the mask is all zero, the output will be zero.

    If window_size != (-1, -1), implements sliding window local attention. Query at position i
    will only attend to keys between
    [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive.

    Arguments:
        q: (total_q, nheads, headdim), where total_q = total number of query tokens in the batch.
        k: (total_k, nheads_k, headdim), where total_k = total number of key tokens in the batch.
        v: (total_k, nheads_k, headdim), where total_k = total number of key tokens in the batch.
        cu_seqlens_q: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
           of the sequences in the batch, used to index into q.
        cu_seqlens_k: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
           of the sequences in the batch, used to index into kv.
        max_seqlen_q: int. Maximum query sequence length in the batch.
        max_seqlen_k: int. Maximum key sequence length in the batch.
        dropout_p: float. Dropout probability.
        softmax_scale: float. The scaling of QK^T before applying softmax.
            Default to 1 / sqrt(headdim).
        causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
        window_size: (left, right). If not (-1, -1), implements sliding window local attention.
        softcap: float. Anything > 0 activates softcapping attention.
        alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of
            (-alibi_slope * |i + seqlen_k - seqlen_q - j|)
            is added to the attention score of query i and key j.
        deterministic: bool. Whether to use the deterministic implementation of the backward pass,
            which is slightly slower and uses more memory. The forward pass is always deterministic.
        return_attn_probs: bool. Whether to return the attention probabilities. This option is for
           testing only. The returned probabilities are not guaranteed to be correct
           (they might not have the right scaling).
        block_table: optional. Forwarded unchanged to FlashAttnVarlenFunc. Presumably a
            (batch_size, max_num_blocks_per_seq) int32 page table for a paged KV cache, in which
            case k / v hold cache blocks rather than packed tokens (TODO confirm against the
            kernel binding).
        out: keyword-only, optional. Forwarded unchanged to FlashAttnVarlenFunc; presumably a
            preallocated (total_q, nheads, headdim) buffer that receives the output
            (TODO confirm against the kernel binding).
    Return:
        out: (total_q, nheads, headdim).
        softmax_lse [optional, if return_attn_probs=True]: (nheads, total_q_seqlen). The
            logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
            normalization factor).
        S_dmask [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen, seqlen).
            The output of softmax (possibly with different scaling). It also encodes the dropout
            pattern (negative means that location was dropped, nonnegative means it was kept).
    """
    # Thin wrapper: all validation and dispatch happen inside the autograd Function.
    return FlashAttnVarlenFunc.apply(
        q,
        k,
        v,
        cu_seqlens_q,
        cu_seqlens_k,
        max_seqlen_q,
        max_seqlen_k,
        dropout_p,
        softmax_scale,
        causal,
        window_size,
        softcap,
        alibi_slopes,
        deterministic,
        return_attn_probs,
        block_table,
        out,
    )
Tri Dao's avatar
Tri Dao committed
1176
1177
1178
1179
1180
1181
1182
1183


def flash_attn_with_kvcache(
    q,
    k_cache,
    v_cache,
    k=None,
    v=None,
    rotary_cos=None,
    rotary_sin=None,
    cache_seqlens: Optional[Union[(int, torch.Tensor)]] = None,
    cache_batch_idx: Optional[torch.Tensor] = None,
    block_table: Optional[torch.Tensor] = None,
    softmax_scale=None,
    causal=False,
    window_size=(-1, -1),  # -1 means infinite context window
    softcap=0.0,  # 0.0 means deactivated
    rotary_interleaved=True,
    alibi_slopes=None,
    num_splits=0,
    return_softmax_lse=False,
    *,
    out=None,
):
    """
    If k and v are not None, k_cache and v_cache will be updated *inplace* with the new values from
    k and v. This is useful for incremental decoding: you can pass in the cached keys/values from
    the previous step, and update them with the new keys/values from the current step, and do
    attention with the updated cache, all in 1 kernel.

    If you pass in k / v, you must make sure that the cache is large enough to hold the new values.
    For example, the KV cache could be pre-allocated with the max sequence length, and you can use
    cache_seqlens to keep track of the current sequence lengths of each sequence in the batch.

    Also apply rotary embedding if rotary_cos and rotary_sin are passed in. The key @k will be
    rotated by rotary_cos and rotary_sin at indices cache_seqlens, cache_seqlens + 1, etc.
    If causal or local (i.e., window_size != (-1, -1)), the query @q will be rotated by rotary_cos
    and rotary_sin at indices cache_seqlens, cache_seqlens + 1, etc.
    If not causal and not local, the query @q will be rotated by rotary_cos and rotary_sin at
    indices cache_seqlens only (i.e. we consider all tokens in @q to be at position cache_seqlens).

    See tests/test_flash_attn.py::test_flash_attn_kvcache for examples of how to use this function.

    Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads
    than Q. Note that the number of heads in Q must be divisible by the number of heads in KV.
    For example, if Q has 6 heads and K, V have 2 heads, head 0, 1, 2 of Q will attend to head
    0 of K, V, and head 3, 4, 5 of Q will attend to head 1 of K, V.

    If causal=True, the causal mask is aligned to the bottom right corner of the attention matrix.
    For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is:
        1 1 1 1 0
        1 1 1 1 1
    If seqlen_q = 5 and seqlen_k = 2, the causal mask is:
        0 0
        0 0
        0 0
        1 0
        1 1
    If the row of the mask is all zero, the output will be zero.

    If window_size != (-1, -1), implements sliding window local attention. Query at position i
    will only attend to keys between
    [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive.

    Note: Does not support backward pass.

    Arguments:
        q: (batch_size, seqlen, nheads, headdim)
        k_cache: (batch_size_cache, seqlen_cache, nheads_k, headdim) if there's no block_table,
            or (num_blocks, page_block_size, nheads_k, headdim) if there's a block_table (i.e. paged KV cache)
            page_block_size must be a multiple of 256.
        v_cache: (batch_size_cache, seqlen_cache, nheads_k, headdim) if there's no block_table,
            or (num_blocks, page_block_size, nheads_k, headdim) if there's a block_table (i.e. paged KV cache)
        k [optional]: (batch_size, seqlen_new, nheads_k, headdim). If not None, we concatenate
            k with k_cache, starting at the indices specified by cache_seqlens.
        v [optional]: (batch_size, seqlen_new, nheads_k, headdim). Similar to k.
        rotary_cos [optional]: (seqlen_ro, rotary_dim / 2). If not None, we apply rotary embedding
            to k and q. Only applicable if k and v are passed in. rotary_dim must be divisible by 16.
        rotary_sin [optional]: (seqlen_ro, rotary_dim / 2). Similar to rotary_cos.
        cache_seqlens: int, or (batch_size,), dtype torch.int32. The sequence lengths of the
            KV cache. An int is broadcast to a (batch_size,) int32 tensor, where batch_size is
            q.shape[0].
        cache_batch_idx: (batch_size,), dtype torch.int32. The indices used to index into the KV cache.
            If None, we assume that the batch indices are [0, 1, 2, ..., batch_size - 1].
            If the indices are not distinct, and k and v are provided, the values updated in the cache
                 might come from any of the duplicate indices.
        block_table [optional]: (batch_size, max_num_blocks_per_seq), dtype torch.int32.
            When given, k_cache / v_cache are treated as a paged KV cache of shape
            (num_blocks, page_block_size, nheads_k, headdim), and each row lists the cache blocks
            that make up that sequence.
        softmax_scale: float. The scaling of QK^T before applying softmax.
            Default to 1 / sqrt(headdim).
        causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
        window_size: (left, right). If not (-1, -1), implements sliding window local attention.
        softcap: float. Anything > 0 activates softcapping attention.
        rotary_interleaved: bool. Only applicable if rotary_cos and rotary_sin are passed in.
            If True, rotary embedding will combine dimensions 0 & 1, 2 & 3, etc. If False,
            rotary embedding will combine dimensions 0 & rotary_dim / 2, 1 & rotary_dim / 2 + 1
            (i.e. GPT-NeoX style).
        alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of
            (-alibi_slope * |i + seqlen_k - seqlen_q - j|)
            is added to the attention score of query i and key j.
        num_splits: int. If > 1, split the key/value into this many chunks along the sequence.
           If num_splits == 1, we don't split the key/value. If num_splits == 0, we use a heuristic
           to automatically determine the number of splits.
           Don't change this unless you know what you are doing.
        return_softmax_lse: bool. Whether to return the logsumexp of the attention scores.
        out [optional, keyword-only]: tensor forwarded to the kernel as its `out` argument
            (presumably a preallocated output buffer shaped like q -- confirm against the kernel).

    Return:
        out: (batch_size, seqlen, nheads, headdim).
        softmax_lse [optional, if return_softmax_lse=True]: (batch_size, nheads, seqlen). The
            logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
            normalization factor).
    """
    assert k_cache.stride(-1) == 1, "k_cache must have contiguous last dimension"
    assert v_cache.stride(-1) == 1, "v_cache must have contiguous last dimension"
    q, k, v = [maybe_contiguous(x) for x in (q, k, v)]
    if softmax_scale is None:
        softmax_scale = q.shape[-1] ** (-0.5)
    if isinstance(cache_seqlens, int):
        # cache_seqlens is documented as (batch_size,), where batch_size is q's leading dim.
        # Do NOT use k_cache.shape[0]: with a block_table that is num_blocks, and with
        # cache_batch_idx it is batch_size_cache; neither equals batch_size in general.
        cache_seqlens = torch.full(
            (q.shape[0],), cache_seqlens, dtype=torch.int32, device=k_cache.device
        )
    # Normalize caller-supplied tensors too (e.g. a strided slice), not only the tensor
    # built above, which is already contiguous.
    cache_seqlens = maybe_contiguous(cache_seqlens)
    cache_batch_idx = maybe_contiguous(cache_batch_idx)
    block_table = maybe_contiguous(block_table)
    out, softmax_lse = torch.ops.vllm_flash_attn_c.fwd_kvcache(
        q,
        k_cache,
        v_cache,
        k,
        v,
        cache_seqlens,
        rotary_cos,
        rotary_sin,
        cache_batch_idx,
        block_table,
        alibi_slopes,
        out,
        softmax_scale,
        causal,
        window_size[0],
        window_size[1],
        softcap,
        rotary_interleaved,
        num_splits,
    )
    return (out, softmax_lse) if return_softmax_lse else out