"tests/configs/det_r50_vd_db.yml" did not exist on "6eb7dd5563f7f02732ff12f1e0778486e946d272"
flash_attn_interface.py 46.7 KB
Newer Older
1
2
# Copyright (c) 2023, Tri Dao.

Tri Dao's avatar
Tri Dao committed
3
4
from typing import Optional, Union

Tri Dao's avatar
Tri Dao committed
5
6
import torch
import torch.nn as nn
Tri Dao's avatar
Tri Dao committed
7

8
9
# isort: off
# We need to import the CUDA kernels after importing torch
10
11
# Use relative import to support build-from-source installation in vLLM
from . import vllm_flash_attn_c # noqa: F401
Tri Dao's avatar
Tri Dao committed
12

13
14
# isort: on

15
16
def maybe_contiguous(x):
    """Ensure ``x`` has unit stride along its last dimension.

    ``None`` and tensors whose last-dimension stride is already 1 are handed
    back unchanged (the very same object). Every other tensor is replaced by
    ``x.contiguous()``.
    """
    if x is None or x.stride(-1) == 1:
        return x
    return x.contiguous()
Tri Dao's avatar
Tri Dao committed
17

Tri Dao's avatar
Tri Dao committed
18
def _get_block_size_n(device, head_dim, is_dropout, is_causal):
Tri Dao's avatar
Tri Dao committed
19
20
21
22
23
24
25
    # This should match the block sizes in the CUDA kernel
    assert head_dim <= 256
    major, minor = torch.cuda.get_device_capability(device)
    is_sm8x = major == 8 and minor > 0  # Only include sm86 and sm89, exclude sm80 (A100)
    is_sm80 = major == 8 and minor == 0
    is_sm90 = major == 9 and minor == 0
    if head_dim <= 32:
Tri Dao's avatar
Tri Dao committed
26
        return 128
Tri Dao's avatar
Tri Dao committed
27
    if head_dim <= 64:
Tri Dao's avatar
Tri Dao committed
28
        return 128 if not is_dropout else 64
Tri Dao's avatar
Tri Dao committed
29
    elif head_dim <= 96:
Tri Dao's avatar
Tri Dao committed
30
        return 64
Tri Dao's avatar
Tri Dao committed
31
32
    elif head_dim <= 128:
        if is_sm8x:
Tri Dao's avatar
Tri Dao committed
33
            return 64 if (not is_dropout and is_causal) else 32
Tri Dao's avatar
Tri Dao committed
34
        else:
Tri Dao's avatar
Tri Dao committed
35
            return 64 if not is_dropout else 32
Tri Dao's avatar
Tri Dao committed
36
37
    elif head_dim <= 160:
        if is_sm8x:
Tri Dao's avatar
Tri Dao committed
38
            return 64
Tri Dao's avatar
Tri Dao committed
39
        else:
Tri Dao's avatar
Tri Dao committed
40
            return 32
Tri Dao's avatar
Tri Dao committed
41
    elif head_dim <= 192:
Tri Dao's avatar
Tri Dao committed
42
        return 64
Tri Dao's avatar
Tri Dao committed
43
    elif head_dim <= 224:
Tri Dao's avatar
Tri Dao committed
44
        return 64
Tri Dao's avatar
Tri Dao committed
45
    elif head_dim <= 256:
Tri Dao's avatar
Tri Dao committed
46
        return 64
Tri Dao's avatar
Tri Dao committed
47
48


Tri Dao's avatar
Tri Dao committed
49
def _flash_attn_forward(
    q, k, v, dropout_p, softmax_scale, causal, window_size, softcap, alibi_slopes, return_softmax, *, out=None
):
    """Invoke the fixed-length FlashAttention forward op.

    ``q``, ``k`` and ``v`` are first passed through ``maybe_contiguous``. The
    call is then forwarded to ``torch.ops.vllm_flash_attn_c.fwd``, with
    ``window_size`` indexed as ``(left, right)`` and ``out`` handed through
    unchanged.

    Returns:
        8-tuple ``(out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state)``:

        - ``q``, ``k``, ``v`` are the possibly re-laid-out inputs.
        - ``out_padded`` is the same tensor as ``out``.
        - ``S_dmask`` and ``rng_state`` are always ``None``.
    """
    q, k, v = (maybe_contiguous(t) for t in (q, k, v))
    fwd_op = torch.ops.vllm_flash_attn_c.fwd
    out, softmax_lse = fwd_op(
        q,
        k,
        v,
        out,
        alibi_slopes,
        dropout_p,
        softmax_scale,
        causal,
        window_size[0],  # left window
        window_size[1],  # right window
        softcap,
        return_softmax,
        None,  # NOTE(review): trailing optional op argument, always left unset here
    )
    # NOTE(woosuk): S_dmask and rng_state are None because only the forward
    # pass is used in vLLM; out_padded is simply the output tensor itself.
    out_padded = out
    s_dmask = rng_state = None
    return out, q, k, v, out_padded, softmax_lse, s_dmask, rng_state
Tri Dao's avatar
Tri Dao committed
71
72


Tri Dao's avatar
Tri Dao committed
73
74
75
76
77
78
79
80
81
82
83
def _flash_attn_varlen_forward(
    q,
    k,
    v,
    cu_seqlens_q,
    cu_seqlens_k,
    max_seqlen_q,
    max_seqlen_k,
    dropout_p,
    softmax_scale,
    causal,
    window_size,
    softcap,
    alibi_slopes,
    return_softmax,
    block_table,
    *,
    out=None,
):
    """Invoke the variable-length FlashAttention forward op.

    ``q``, ``k`` and ``v`` are first passed through ``maybe_contiguous``. All
    arguments are then forwarded to ``torch.ops.vllm_flash_attn_c.varlen_fwd``,
    together with ``block_table`` and the optional ``out`` tensor.

    Returns:
        8-tuple ``(out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state)``,
        where ``out_padded``, ``S_dmask`` and ``rng_state`` are always
        ``None``.
    """
    q, k, v = (maybe_contiguous(t) for t in (q, k, v))
    varlen_fwd = torch.ops.vllm_flash_attn_c.varlen_fwd
    out, softmax_lse = varlen_fwd(
        q,
        k,
        v,
        out,
        cu_seqlens_q,
        cu_seqlens_k,
        None,  # NOTE(review): optional op argument left unset; confirm against the op schema
        block_table,
        alibi_slopes,
        max_seqlen_q,
        max_seqlen_k,
        dropout_p,
        softmax_scale,
        False,  # NOTE(review): boolean flag preceding `causal`, always False here
        causal,
        window_size[0],  # left window
        window_size[1],  # right window
        softcap,
        return_softmax,
        None,
    )
    # NOTE(woosuk): out_padded, S_dmask, and rng_state are None
    # because we only use the forward pass in the vLLM.
    out_padded = s_dmask = rng_state = None
    return out, q, k, v, out_padded, softmax_lse, s_dmask, rng_state
Tri Dao's avatar
Tri Dao committed
120
121


Tri Dao's avatar
Tri Dao committed
122
def _flash_attn_backward(
    dout,
    q,
    k,
    v,
    out,
    softmax_lse,
    dq,
    dk,
    dv,
    dropout_p,
    softmax_scale,
    causal,
    window_size,
    softcap,
    alibi_slopes,
    deterministic,
    rng_state=None,
):
    """Invoke the fixed-length FlashAttention backward op.

    ``dq``, ``dk`` and ``dv`` are gradient buffers passed straight to
    ``torch.ops.vllm_flash_attn_c.bwd``. ``dout``, ``q``, ``k``, ``v`` and
    ``out`` are passed through ``maybe_contiguous`` first. ``window_size`` is
    indexed as ``(left, right)``.

    Returns:
        The 4 values produced by the op, as ``(dq, dk, dv, softmax_d)``.
    """
    # dq, dk, dv are allocated by the caller, so they should already be contiguous.
    dout, q, k, v, out = (maybe_contiguous(t) for t in (dout, q, k, v, out))
    bwd_op = torch.ops.vllm_flash_attn_c.bwd
    grads = bwd_op(
        dout,
        q,
        k,
        v,
        out,
        softmax_lse,
        dq,
        dk,
        dv,
        alibi_slopes,
        dropout_p,
        softmax_scale,
        causal,
        window_size[0],  # left window
        window_size[1],  # right window
        softcap,
        deterministic,
        None,
        rng_state,
    )
    dq, dk, dv, softmax_d = grads
    return dq, dk, dv, softmax_d


Tri Dao's avatar
Tri Dao committed
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
def _flash_attn_varlen_backward(
    dout,
    q,
    k,
    v,
    out,
    softmax_lse,
    dq,
    dk,
    dv,
    cu_seqlens_q,
    cu_seqlens_k,
    max_seqlen_q,
    max_seqlen_k,
    dropout_p,
    softmax_scale,
    causal,
    window_size,
    softcap,
    alibi_slopes,
    deterministic,
    rng_state=None,
):
    """Invoke the variable-length FlashAttention backward op.

    ``dq``, ``dk`` and ``dv`` are gradient buffers passed straight to
    ``torch.ops.vllm_flash_attn_c.varlen_bwd``. ``dout``, ``q``, ``k``, ``v``
    and ``out`` are passed through ``maybe_contiguous`` first. ``window_size``
    is indexed as ``(left, right)``.

    Returns:
        The 4 values produced by the op, as ``(dq, dk, dv, softmax_d)``.
    """
    # dq, dk, dv are allocated by the caller, so they should already be contiguous.
    dout, q, k, v, out = (maybe_contiguous(t) for t in (dout, q, k, v, out))
    varlen_bwd = torch.ops.vllm_flash_attn_c.varlen_bwd
    grads = varlen_bwd(
        dout,
        q,
        k,
        v,
        out,
        softmax_lse,
        dq,
        dk,
        dv,
        cu_seqlens_q,
        cu_seqlens_k,
        alibi_slopes,
        max_seqlen_q,
        max_seqlen_k,
        dropout_p,
        softmax_scale,
        False,  # NOTE(review): boolean flag preceding `causal`, always False here
        causal,
        window_size[0],  # left window
        window_size[1],  # right window
        softcap,
        deterministic,
        None,
        rng_state,
    )
    dq, dk, dv, softmax_d = grads
    return dq, dk, dv, softmax_d
Tri Dao's avatar
Tri Dao committed
231
232


Tri Dao's avatar
Tri Dao committed
233
class FlashAttnQKVPackedFunc(torch.autograd.Function):
    """Autograd function for fixed-length attention on a packed QKV tensor.

    Q, K and V are read as ``qkv[:, :, 0]``, ``qkv[:, :, 1]`` and
    ``qkv[:, :, 2]``. The backward pass writes their gradients into one packed
    ``dqkv`` buffer.
    """

    @staticmethod
    def forward(
        ctx,
        qkv,
        dropout_p,
        softmax_scale,
        causal,
        window_size,
        softcap,
        alibi_slopes,
        deterministic,
        return_softmax,
        out=None,
    ):
        """Run the forward kernel and save what ``backward`` needs.

        ``out`` must be accepted positionally: ``flash_attn_qkvpacked_func``
        passes it as the 10th positional argument of ``apply``. If it were
        keyword-only, that call would raise TypeError.

        ``softmax_scale`` defaults to ``qkv.shape[-1] ** -0.5``. ``S_dmask`` is
        only requested from the kernel when ``dropout_p > 0``.

        Returns:
            ``out``, or ``(out, softmax_lse, S_dmask)`` when ``return_softmax``.
        """
        if softmax_scale is None:
            softmax_scale = qkv.shape[-1] ** (-0.5)
        out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = _flash_attn_forward(
            qkv[:, :, 0],
            qkv[:, :, 1],
            qkv[:, :, 2],
            dropout_p,
            softmax_scale,
            causal=causal,
            window_size=window_size,
            softcap=softcap,
            alibi_slopes=alibi_slopes,
            return_softmax=return_softmax and dropout_p > 0,
            out=out,
        )
        ctx.save_for_backward(q, k, v, out_padded, softmax_lse, rng_state)
        ctx.dropout_p = dropout_p
        ctx.softmax_scale = softmax_scale
        ctx.causal = causal
        ctx.window_size = window_size
        ctx.softcap = softcap
        ctx.alibi_slopes = alibi_slopes
        ctx.deterministic = deterministic
        return out if not return_softmax else (out, softmax_lse, S_dmask)

    @staticmethod
    def backward(ctx, dout, *args):
        """Return one gradient slot per ``forward`` input (``ctx`` excluded).

        Only ``qkv`` receives a gradient. The other nine inputs
        (``dropout_p`` through ``out``) get ``None``.
        """
        q, k, v, out, softmax_lse, rng_state = ctx.saved_tensors
        # Packed gradient buffer: a size-3 axis inserted before q's last two dims.
        qkv_shape = q.shape[:-2] + (3, *q.shape[-2:])
        dqkv = torch.empty(qkv_shape, dtype=q.dtype, device=q.device)
        _flash_attn_backward(
            dout,
            q,
            k,
            v,
            out,
            softmax_lse,
            dqkv[:, :, 0],
            dqkv[:, :, 1],
            dqkv[:, :, 2],
            ctx.dropout_p,
            ctx.softmax_scale,
            ctx.causal,
            ctx.window_size,
            ctx.softcap,
            ctx.alibi_slopes,
            ctx.deterministic,
            rng_state=rng_state,
        )
        dqkv = dqkv[..., : dout.shape[-1]]  # We could have padded the head dimension
        # forward has 10 inputs: qkv + 9 non-differentiable arguments.
        return (dqkv,) + (None,) * 9
Tri Dao's avatar
Tri Dao committed
300
301
302
303


class FlashAttnVarlenQKVPackedFunc(torch.autograd.Function):
    """Autograd function for variable-length attention on a packed QKV tensor.

    Q, K and V are read as ``qkv[:, 0]``, ``qkv[:, 1]`` and ``qkv[:, 2]``.
    The same ``cu_seqlens`` and ``max_seqlen`` are used for both queries and
    keys.
    """

    @staticmethod
    def forward(
        ctx,
        qkv,
        cu_seqlens,
        max_seqlen,
        dropout_p,
        softmax_scale,
        causal,
        window_size,
        softcap,
        alibi_slopes,
        deterministic,
        return_softmax,
        out=None,
    ):
        """Run the varlen forward kernel and save what ``backward`` needs.

        ``out`` is positional-or-keyword so it can be supplied through
        ``Function.apply``, which is called with positional arguments.
        ``softmax_scale`` defaults to ``qkv.shape[-1] ** -0.5``.

        Returns:
            ``out``, or ``(out, softmax_lse, S_dmask)`` when ``return_softmax``.
        """
        if softmax_scale is None:
            softmax_scale = qkv.shape[-1] ** (-0.5)
        out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = _flash_attn_varlen_forward(
            qkv[:, 0],
            qkv[:, 1],
            qkv[:, 2],
            cu_seqlens,
            cu_seqlens,
            max_seqlen,
            max_seqlen,
            dropout_p,
            softmax_scale,
            causal=causal,
            window_size=window_size,
            softcap=softcap,
            alibi_slopes=alibi_slopes,
            return_softmax=return_softmax and dropout_p > 0,
            block_table=None,
            out=out,
        )
        ctx.save_for_backward(q, k, v, out_padded, softmax_lse, cu_seqlens, rng_state)
        ctx.dropout_p = dropout_p
        ctx.max_seqlen = max_seqlen
        ctx.softmax_scale = softmax_scale
        ctx.causal = causal
        ctx.window_size = window_size
        ctx.softcap = softcap
        ctx.alibi_slopes = alibi_slopes
        ctx.deterministic = deterministic
        return out if not return_softmax else (out, softmax_lse, S_dmask)

    @staticmethod
    def backward(ctx, dout, *args):
        """Return one gradient slot per ``forward`` input (``ctx`` excluded).

        Only ``qkv`` receives a gradient; the other 11 inputs get ``None``.
        """
        # NOTE(review): _flash_attn_varlen_forward returns None for out_padded,
        # so `out` below may be None; confirm the backward path is supported.
        q, k, v, out, softmax_lse, cu_seqlens, rng_state = ctx.saved_tensors
        qkv_shape = q.shape[:-2] + (3, *q.shape[-2:])
        dqkv = torch.empty(qkv_shape, dtype=q.dtype, device=q.device)
        _flash_attn_varlen_backward(
            dout,
            q,
            k,
            v,
            out,
            softmax_lse,
            dqkv[:, 0],
            dqkv[:, 1],
            dqkv[:, 2],
            cu_seqlens,
            cu_seqlens,
            ctx.max_seqlen,
            ctx.max_seqlen,
            ctx.dropout_p,
            ctx.softmax_scale,
            ctx.causal,
            ctx.window_size,
            ctx.softcap,
            ctx.alibi_slopes,
            ctx.deterministic,
            rng_state=rng_state,
        )
        dqkv = dqkv[..., : dout.shape[-1]]  # We could have padded the head dimension
        # forward has 12 inputs: qkv + 11 non-differentiable arguments.
        return (dqkv,) + (None,) * 11
Tri Dao's avatar
Tri Dao committed
381
382


Tri Dao's avatar
Tri Dao committed
383
class FlashAttnKVPackedFunc(torch.autograd.Function):
    """Autograd function for fixed-length attention with a packed KV tensor.

    K and V are read as ``kv[:, :, 0]`` and ``kv[:, :, 1]``. Their gradients
    are written into one packed ``dkv`` buffer.
    """

    @staticmethod
    def forward(
        ctx,
        q,
        kv,
        dropout_p,
        softmax_scale,
        causal,
        window_size,
        softcap,
        alibi_slopes,
        deterministic,
        return_softmax,
        out=None,
    ):
        """Run the forward kernel and save what ``backward`` needs.

        ``softmax_scale`` defaults to ``q.shape[-1] ** -0.5``.

        Returns:
            ``out``, or ``(out, softmax_lse, S_dmask)`` when ``return_softmax``.
        """
        if softmax_scale is None:
            softmax_scale = q.shape[-1] ** (-0.5)
        out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = _flash_attn_forward(
            q,
            kv[:, :, 0],
            kv[:, :, 1],
            dropout_p,
            softmax_scale,
            causal=causal,
            window_size=window_size,
            softcap=softcap,
            alibi_slopes=alibi_slopes,
            return_softmax=return_softmax and dropout_p > 0,
            out=out,
        )
        ctx.save_for_backward(q, k, v, out_padded, softmax_lse, rng_state)
        ctx.dropout_p = dropout_p
        ctx.softmax_scale = softmax_scale
        ctx.causal = causal
        ctx.window_size = window_size
        ctx.softcap = softcap
        ctx.alibi_slopes = alibi_slopes
        ctx.deterministic = deterministic
        return out if not return_softmax else (out, softmax_lse, S_dmask)

    @staticmethod
    def backward(ctx, dout, *args):
        """Return one gradient slot per ``forward`` input (``ctx`` excluded).

        ``q`` and ``kv`` receive gradients; the other 9 inputs get ``None``.
        """
        q, k, v, out, softmax_lse, rng_state = ctx.saved_tensors
        dq = torch.empty_like(q)
        # Packed gradient buffer: a size-2 axis inserted before k's last two dims.
        kv_shape = k.shape[:-2] + (2, *k.shape[-2:])
        dkv = torch.empty(kv_shape, dtype=k.dtype, device=k.device)
        _flash_attn_backward(
            dout,
            q,
            k,
            v,
            out,
            softmax_lse,
            dq,
            dkv[:, :, 0],
            dkv[:, :, 1],
            ctx.dropout_p,
            ctx.softmax_scale,
            ctx.causal,
            ctx.window_size,
            ctx.softcap,
            ctx.alibi_slopes,
            ctx.deterministic,
            rng_state=rng_state,
        )
        dq = dq[..., : dout.shape[-1]]  # We could have padded the head dimension
        dkv = dkv[..., : dout.shape[-1]]
        # forward has 11 inputs: q, kv + 9 non-differentiable arguments.
        return (dq, dkv) + (None,) * 9
Tri Dao's avatar
Tri Dao committed
452
453


Tri Dao's avatar
Tri Dao committed
454
class FlashAttnVarlenKVPackedFunc(torch.autograd.Function):
    """Autograd function for variable-length attention with a packed KV tensor.

    K and V are read as ``kv[:, 0]`` and ``kv[:, 1]``. Their gradients are
    written into one packed ``dkv`` buffer.
    """

    @staticmethod
    def forward(
        ctx,
        q,
        kv,
        cu_seqlens_q,
        cu_seqlens_k,
        max_seqlen_q,
        max_seqlen_k,
        dropout_p,
        softmax_scale,
        causal,
        window_size,
        softcap,
        alibi_slopes,
        deterministic,
        return_softmax,
        out=None,
    ):
        """Run the varlen forward kernel and save what ``backward`` needs.

        ``softmax_scale`` defaults to ``q.shape[-1] ** -0.5``.

        Returns:
            ``out``, or ``(out, softmax_lse, S_dmask)`` when ``return_softmax``.
        """
        if softmax_scale is None:
            softmax_scale = q.shape[-1] ** (-0.5)
        out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = _flash_attn_varlen_forward(
            q,
            kv[:, 0],
            kv[:, 1],
            cu_seqlens_q,
            cu_seqlens_k,
            max_seqlen_q,
            max_seqlen_k,
            dropout_p,
            softmax_scale,
            causal=causal,
            window_size=window_size,
            softcap=softcap,
            alibi_slopes=alibi_slopes,
            return_softmax=return_softmax and dropout_p > 0,
            block_table=None,
            out=out,
        )
        ctx.save_for_backward(
            q, k, v, out_padded, softmax_lse, cu_seqlens_q, cu_seqlens_k, rng_state
        )
        ctx.dropout_p = dropout_p
        ctx.max_seqlen_q = max_seqlen_q
        ctx.max_seqlen_k = max_seqlen_k
        ctx.softmax_scale = softmax_scale
        ctx.causal = causal
        ctx.window_size = window_size
        ctx.softcap = softcap
        ctx.alibi_slopes = alibi_slopes
        ctx.deterministic = deterministic
        return out if not return_softmax else (out, softmax_lse, S_dmask)

    @staticmethod
    def backward(ctx, dout, *args):
        """Return one gradient slot per ``forward`` input (``ctx`` excluded).

        ``q`` and ``kv`` receive gradients; the other 13 inputs get ``None``.
        """
        # NOTE(review): _flash_attn_varlen_forward returns None for out_padded,
        # so `out` below may be None; confirm the backward path is supported.
        q, k, v, out, softmax_lse, cu_seqlens_q, cu_seqlens_k, rng_state = ctx.saved_tensors
        dq = torch.empty_like(q)
        kv_shape = k.shape[:-2] + (2, *k.shape[-2:])
        dkv = torch.empty(kv_shape, dtype=k.dtype, device=k.device)
        _flash_attn_varlen_backward(
            dout,
            q,
            k,
            v,
            out,
            softmax_lse,
            dq,
            dkv[:, 0],
            dkv[:, 1],
            cu_seqlens_q,
            cu_seqlens_k,
            ctx.max_seqlen_q,
            ctx.max_seqlen_k,
            ctx.dropout_p,
            ctx.softmax_scale,
            ctx.causal,
            ctx.window_size,
            ctx.softcap,
            ctx.alibi_slopes,
            ctx.deterministic,
            rng_state=rng_state,
        )
        dq = dq[..., : dout.shape[-1]]  # We could have padded the head dimension
        dkv = dkv[..., : dout.shape[-1]]
        # forward has 15 inputs: q, kv + 13 non-differentiable arguments.
        return (dq, dkv) + (None,) * 13
Tri Dao's avatar
Tri Dao committed
540
541
542
543


class FlashAttnFunc(torch.autograd.Function):
    """Autograd function for fixed-length attention on separate Q, K, V tensors."""

    @staticmethod
    def forward(
        ctx,
        q,
        k,
        v,
        dropout_p,
        softmax_scale,
        causal,
        window_size,
        softcap,
        alibi_slopes,
        deterministic,
        return_softmax,
        out=None,
    ):
        """Run the forward kernel and save what ``backward`` needs.

        ``softmax_scale`` defaults to ``q.shape[-1] ** -0.5``.

        Returns:
            ``out``, or ``(out, softmax_lse, S_dmask)`` when ``return_softmax``.
        """
        if softmax_scale is None:
            softmax_scale = q.shape[-1] ** (-0.5)
        out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = _flash_attn_forward(
            q,
            k,
            v,
            dropout_p,
            softmax_scale,
            causal=causal,
            window_size=window_size,
            softcap=softcap,
            alibi_slopes=alibi_slopes,
            return_softmax=return_softmax and dropout_p > 0,
            out=out,
        )
        ctx.save_for_backward(q, k, v, out_padded, softmax_lse, rng_state)
        ctx.dropout_p = dropout_p
        ctx.softmax_scale = softmax_scale
        ctx.causal = causal
        ctx.window_size = window_size
        ctx.softcap = softcap
        ctx.alibi_slopes = alibi_slopes
        ctx.deterministic = deterministic
        return out if not return_softmax else (out, softmax_lse, S_dmask)

    @staticmethod
    def backward(ctx, dout, *args):
        """Return one gradient slot per ``forward`` input (``ctx`` excluded).

        ``q``, ``k`` and ``v`` receive gradients; the other 9 inputs get ``None``.
        """
        q, k, v, out, softmax_lse, rng_state = ctx.saved_tensors
        dq, dk, dv = torch.empty_like(q), torch.empty_like(k), torch.empty_like(v)
        _flash_attn_backward(
            dout,
            q,
            k,
            v,
            out,
            softmax_lse,
            dq,
            dk,
            dv,
            ctx.dropout_p,
            ctx.softmax_scale,
            ctx.causal,
            ctx.window_size,
            ctx.softcap,
            ctx.alibi_slopes,
            ctx.deterministic,
            rng_state=rng_state,
        )
        dq = dq[..., : dout.shape[-1]]  # We could have padded the head dimension
        dk = dk[..., : dout.shape[-1]]
        dv = dv[..., : dout.shape[-1]]
        # forward has 12 inputs: q, k, v + 9 non-differentiable arguments.
        return (dq, dk, dv) + (None,) * 9
Tri Dao's avatar
Tri Dao committed
611
612


Tri Dao's avatar
Tri Dao committed
613
class FlashAttnVarlenFunc(torch.autograd.Function):
    """Autograd function for variable-length attention on separate Q, K, V tensors.

    Unlike the packed variants, ``block_table`` is forwarded to the kernel.
    """

    @staticmethod
    def forward(
        ctx,
        q,
        k,
        v,
        cu_seqlens_q,
        cu_seqlens_k,
        max_seqlen_q,
        max_seqlen_k,
        dropout_p,
        softmax_scale,
        causal,
        window_size,
        softcap,
        alibi_slopes,
        deterministic,
        return_softmax,
        block_table,
        out=None,
    ):
        """Run the varlen forward kernel and save what ``backward`` needs.

        ``softmax_scale`` defaults to ``q.shape[-1] ** -0.5``.

        Returns:
            ``out``, or ``(out, softmax_lse, S_dmask)`` when ``return_softmax``.
        """
        if softmax_scale is None:
            softmax_scale = q.shape[-1] ** (-0.5)
        out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = _flash_attn_varlen_forward(
            q,
            k,
            v,
            cu_seqlens_q,
            cu_seqlens_k,
            max_seqlen_q,
            max_seqlen_k,
            dropout_p,
            softmax_scale,
            causal=causal,
            window_size=window_size,
            softcap=softcap,
            alibi_slopes=alibi_slopes,
            return_softmax=return_softmax and dropout_p > 0,
            block_table=block_table,
            out=out,
        )
        ctx.save_for_backward(
            q, k, v, out_padded, softmax_lse, cu_seqlens_q, cu_seqlens_k, rng_state
        )
        ctx.dropout_p = dropout_p
        ctx.max_seqlen_q = max_seqlen_q
        ctx.max_seqlen_k = max_seqlen_k
        ctx.softmax_scale = softmax_scale
        ctx.causal = causal
        ctx.window_size = window_size
        ctx.softcap = softcap
        ctx.alibi_slopes = alibi_slopes
        ctx.deterministic = deterministic
        return out if not return_softmax else (out, softmax_lse, S_dmask)

    @staticmethod
    def backward(ctx, dout, *args):
        """Return one gradient slot per ``forward`` input (``ctx`` excluded).

        ``q``, ``k`` and ``v`` receive gradients; the other 14 inputs get ``None``.
        """
        # NOTE(review): _flash_attn_varlen_forward returns None for out_padded,
        # so `out` below may be None; confirm the backward path is supported.
        q, k, v, out, softmax_lse, cu_seqlens_q, cu_seqlens_k, rng_state = ctx.saved_tensors
        dq, dk, dv = torch.empty_like(q), torch.empty_like(k), torch.empty_like(v)
        _flash_attn_varlen_backward(
            dout,
            q,
            k,
            v,
            out,
            softmax_lse,
            dq,
            dk,
            dv,
            cu_seqlens_q,
            cu_seqlens_k,
            ctx.max_seqlen_q,
            ctx.max_seqlen_k,
            ctx.dropout_p,
            ctx.softmax_scale,
            ctx.causal,
            ctx.window_size,
            ctx.softcap,
            ctx.alibi_slopes,
            ctx.deterministic,
            rng_state=rng_state,
        )
        dq = dq[..., : dout.shape[-1]]  # We could have padded the head dimension
        dk = dk[..., : dout.shape[-1]]
        dv = dv[..., : dout.shape[-1]]
        # forward has 17 inputs: q, k, v + 14 non-differentiable arguments.
        return (dq, dk, dv) + (None,) * 14
700
701


Tri Dao's avatar
Tri Dao committed
702
def flash_attn_qkvpacked_func(
    qkv,
    dropout_p=0.0,
    softmax_scale=None,
    causal=False,
    window_size=(-1, -1),  # -1 means infinite context window
    softcap=0.0,  # <=0.0 means deactivate
    alibi_slopes=None,
    deterministic=False,
    return_attn_probs=False,
    *,
    out=None,
):
    """Attention over a QKV tensor that is already stacked into one tensor.

    dropout_p should be set to 0.0 during evaluation.

    Because Q, K and V share one tensor, the backward pass can write their
    gradients into a single buffer. It therefore avoids the explicit
    concatenation that calling flash_attn_func on separate Q, K, V would
    need. For multi-query and grouped-query attention (MQA/GQA), use
    flash_attn_kvpacked_func or flash_attn_func instead.

    If window_size != (-1, -1), sliding window local attention is used: query
    i only attends to keys in [i - window_size[0], i + window_size[1]]
    inclusive.

    Arguments:
        qkv: (batch_size, seqlen, 3, nheads, headdim)
        dropout_p: float. Dropout probability.
        softmax_scale: float. The scaling of QK^T before applying softmax.
            Defaults to 1 / sqrt(headdim).
        causal: bool. Whether to apply a causal attention mask (e.g., for
            auto-regressive modeling).
        window_size: (left, right). If not (-1, -1), enables sliding window
            local attention.
        softcap: float. Any value > 0 enables softcapping attention.
        alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of
            (-alibi_slope * |i - j|) is added to the attention score of query i
            and key j.
        deterministic: bool. Whether to use the deterministic implementation of
            the backward pass. It is slightly slower and uses more memory. The
            forward pass is always deterministic.
        return_attn_probs: bool. Whether to return the attention probabilities.
            This option is for testing only. The returned probabilities are
            not guaranteed to be correct (they might not have the right
            scaling).
        out: optional tensor forwarded positionally to the autograd function
            and on to the kernel (presumably a preallocated output buffer).
    Return:
        out: (batch_size, seqlen, nheads, headdim).
        softmax_lse [optional, if return_attn_probs=True]: (batch_size, nheads,
            seqlen). The logsumexp of each row of the matrix QK^T * scaling
            (e.g., log of the softmax normalization factor).
        S_dmask [optional, if return_attn_probs=True]: (batch_size, nheads,
            seqlen, seqlen). The output of softmax (possibly with different
            scaling). It also encodes the dropout pattern (negative means that
            location was dropped, nonnegative means it was kept).
    """
    apply_args = (
        qkv,
        dropout_p,
        softmax_scale,
        causal,
        window_size,
        softcap,
        alibi_slopes,
        deterministic,
        return_attn_probs,
        out,
    )
    return FlashAttnQKVPackedFunc.apply(*apply_args)
Tri Dao's avatar
Tri Dao committed
761
762


def flash_attn_kvpacked_func(
    q,
    kv,
    dropout_p=0.0,
    softmax_scale=None,
    causal=False,
    window_size=(-1, -1),  # -1 means infinite context window
    softcap=0.0,  # 0.0 means deactivated
    alibi_slopes=None,
    deterministic=False,
    return_attn_probs=False,
    *,
    out=None,
):
    """dropout_p should be set to 0.0 during evaluation
    If K, V are already stacked into 1 tensor, this function will be faster than
    calling flash_attn_func on Q, K, V since the backward pass avoids explicit concatenation
    of the gradients of K, V.
    Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads
    than Q. Note that the number of heads in Q must be divisible by the number of heads in KV.
    For example, if Q has 6 heads and K, V have 2 heads, head 0, 1, 2 of Q will attend to head
    0 of K, V, and head 3, 4, 5 of Q will attend to head 1 of K, V.

    If causal=True, the causal mask is aligned to the bottom right corner of the attention matrix.
    For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is:
        1 1 1 1 0
        1 1 1 1 1
    If seqlen_q = 5 and seqlen_k = 2, the causal mask is:
        0 0
        0 0
        0 0
        1 0
        1 1
    If the row of the mask is all zero, the output will be zero.

    If window_size != (-1, -1), implements sliding window local attention. Query at position i
    will only attend to keys between
    [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive.

    Arguments:
        q: (batch_size, seqlen, nheads, headdim)
        kv: (batch_size, seqlen, 2, nheads_k, headdim)
        dropout_p: float. Dropout probability.
        softmax_scale: float. The scaling of QK^T before applying softmax.
            Default to 1 / sqrt(headdim).
        causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
        window_size: (left, right). If not (-1, -1), implements sliding window local attention.
        softcap: float. Anything > 0 activates softcapping attention.
        alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of
            (-alibi_slope * |i + seqlen_k - seqlen_q - j|)
            is added to the attention score of query i and key j.
        deterministic: bool. Whether to use the deterministic implementation of the backward pass,
            which is slightly slower and uses more memory. The forward pass is always deterministic.
        return_attn_probs: bool. Whether to return the attention probabilities. This option is for
           testing only. The returned probabilities are not guaranteed to be correct
           (they might not have the right scaling).
        out: keyword-only, optional. Forwarded unchanged to FlashAttnKVPackedFunc.apply.
            Presumably a preallocated (batch_size, seqlen, nheads, headdim) output buffer
            that the kernel writes into; confirm against FlashAttnKVPackedFunc.
    Return:
        out: (batch_size, seqlen, nheads, headdim).
        softmax_lse [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen). The
            logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
            normalization factor).
        S_dmask [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen, seqlen).
            The output of softmax (possibly with different scaling). It also encodes the dropout
            pattern (negative means that location was dropped, nonnegative means it was kept).
    """
    # autograd.Function.apply forwards its arguments positionally to
    # FlashAttnKVPackedFunc.forward, so this order must match that signature.
    return FlashAttnKVPackedFunc.apply(
        q,
        kv,
        dropout_p,
        softmax_scale,
        causal,
        window_size,
        softcap,
        alibi_slopes,
        deterministic,
        return_attn_probs,
        out,
    )


def flash_attn_func(
    q,
    k,
    v,
    dropout_p=0.0,
    softmax_scale=None,
    causal=False,
    window_size=(-1, -1),  # -1 means infinite context window
    softcap=0.0,  # 0.0 means deactivated
    alibi_slopes=None,
    deterministic=False,
    return_attn_probs=False,
    *,
    out=None,
):
    """dropout_p should be set to 0.0 during evaluation
    Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads
    than Q. Note that the number of heads in Q must be divisible by the number of heads in KV.
    For example, if Q has 6 heads and K, V have 2 heads, head 0, 1, 2 of Q will attend to head
    0 of K, V, and head 3, 4, 5 of Q will attend to head 1 of K, V.

    If causal=True, the causal mask is aligned to the bottom right corner of the attention matrix.
    For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is:
        1 1 1 1 0
        1 1 1 1 1
    If seqlen_q = 5 and seqlen_k = 2, the causal mask is:
        0 0
        0 0
        0 0
        1 0
        1 1
    If the row of the mask is all zero, the output will be zero.

    If window_size != (-1, -1), implements sliding window local attention. Query at position i
    will only attend to keys between
    [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive.

    Arguments:
        q: (batch_size, seqlen, nheads, headdim)
        k: (batch_size, seqlen, nheads_k, headdim)
        v: (batch_size, seqlen, nheads_k, headdim)
        dropout_p: float. Dropout probability.
        softmax_scale: float. The scaling of QK^T before applying softmax.
            Default to 1 / sqrt(headdim).
        causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
        window_size: (left, right). If not (-1, -1), implements sliding window local attention.
        softcap: float. Anything > 0 activates softcapping attention.
        alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of
            (-alibi_slope * |i + seqlen_k - seqlen_q - j|)
            is added to the attention score of query i and key j.
        deterministic: bool. Whether to use the deterministic implementation of the backward pass,
            which is slightly slower and uses more memory. The forward pass is always deterministic.
        return_attn_probs: bool. Whether to return the attention probabilities. This option is for
           testing only. The returned probabilities are not guaranteed to be correct
           (they might not have the right scaling).
        out: keyword-only, optional. Forwarded unchanged to FlashAttnFunc.apply.
            Presumably a preallocated (batch_size, seqlen, nheads, headdim) output buffer
            that the kernel writes into; confirm against FlashAttnFunc.
    Return:
        out: (batch_size, seqlen, nheads, headdim).
        softmax_lse [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen). The
            logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
            normalization factor).
        S_dmask [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen, seqlen).
            The output of softmax (possibly with different scaling). It also encodes the dropout
            pattern (negative means that location was dropped, nonnegative means it was kept).
    """
    # autograd.Function.apply forwards its arguments positionally to
    # FlashAttnFunc.forward, so this order must match that signature.
    return FlashAttnFunc.apply(
        q,
        k,
        v,
        dropout_p,
        softmax_scale,
        causal,
        window_size,
        softcap,
        alibi_slopes,
        deterministic,
        return_attn_probs,
        out,
    )


def flash_attn_varlen_qkvpacked_func(
    qkv,
    cu_seqlens,
    max_seqlen,
    dropout_p=0.0,
    softmax_scale=None,
    causal=False,
    window_size=(-1, -1),  # -1 means infinite context window
    softcap=0.0,  # 0.0 means deactivated
    alibi_slopes=None,
    deterministic=False,
    return_attn_probs=False,
    *,
    out=None,
):
    """dropout_p should be set to 0.0 during evaluation
    If Q, K, V are already stacked into 1 tensor, this function will be faster than
    calling flash_attn_varlen_func on Q, K, V since the backward pass avoids explicit concatenation
    of the gradients of Q, K, V.
    For multi-query and grouped-query attention (MQA/GQA), please see
    flash_attn_varlen_kvpacked_func and flash_attn_varlen_func.

    If window_size != (-1, -1), implements sliding window local attention. Query at position i
    will only attend to keys between [i - window_size[0], i + window_size[1]] inclusive.

    Arguments:
        qkv: (total, 3, nheads, headdim), where total = total number of tokens in the batch.
        cu_seqlens: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
           of the sequences in the batch, used to index into qkv.
        max_seqlen: int. Maximum sequence length in the batch.
        dropout_p: float. Dropout probability.
        softmax_scale: float. The scaling of QK^T before applying softmax.
            Default to 1 / sqrt(headdim).
        causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
        window_size: (left, right). If not (-1, -1), implements sliding window local attention.
        softcap: float. Anything > 0 activates softcapping attention.
        alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of (-alibi_slope * |i - j|)
            is added to the attention score of query i and key j.
        deterministic: bool. Whether to use the deterministic implementation of the backward pass,
            which is slightly slower and uses more memory. The forward pass is always deterministic.
        return_attn_probs: bool. Whether to return the attention probabilities. This option is for
           testing only. The returned probabilities are not guaranteed to be correct
           (they might not have the right scaling).
        out: keyword-only, optional. Forwarded unchanged to FlashAttnVarlenQKVPackedFunc.apply.
            Presumably a preallocated (total, nheads, headdim) output buffer that the kernel
            writes into; confirm against FlashAttnVarlenQKVPackedFunc.
    Return:
        out: (total, nheads, headdim).
        softmax_lse [optional, if return_attn_probs=True]: (nheads, total_q_seqlen). The
            logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
            normalization factor).
        S_dmask [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen, seqlen).
            The output of softmax (possibly with different scaling). It also encodes the dropout
            pattern (negative means that location was dropped, nonnegative means it was kept).
    """
    # autograd.Function.apply forwards its arguments positionally to
    # FlashAttnVarlenQKVPackedFunc.forward, so this order must match that signature.
    return FlashAttnVarlenQKVPackedFunc.apply(
        qkv,
        cu_seqlens,
        max_seqlen,
        dropout_p,
        softmax_scale,
        causal,
        window_size,
        softcap,
        alibi_slopes,
        deterministic,
        return_attn_probs,
        out,
    )


def flash_attn_varlen_kvpacked_func(
    q,
    kv,
    cu_seqlens_q,
    cu_seqlens_k,
    max_seqlen_q,
    max_seqlen_k,
    dropout_p=0.0,
    softmax_scale=None,
    causal=False,
    window_size=(-1, -1),  # -1 means infinite context window
    softcap=0.0,  # 0.0 means deactivated
    alibi_slopes=None,
    deterministic=False,
    return_attn_probs=False,
    *,
    out=None,
):
    """dropout_p should be set to 0.0 during evaluation
    If K, V are already stacked into 1 tensor, this function will be faster than
    calling flash_attn_varlen_func on Q, K, V since the backward pass avoids explicit
    concatenation of the gradients of K, V.
    Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads
    than Q. Note that the number of heads in Q must be divisible by the number of heads in KV.
    For example, if Q has 6 heads and K, V have 2 heads, head 0, 1, 2 of Q will attend to head
    0 of K, V, and head 3, 4, 5 of Q will attend to head 1 of K, V.

    If causal=True, the causal mask is aligned to the bottom right corner of the attention matrix.
    For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is:
        1 1 1 1 0
        1 1 1 1 1
    If seqlen_q = 5 and seqlen_k = 2, the causal mask is:
        0 0
        0 0
        0 0
        1 0
        1 1
    If the row of the mask is all zero, the output will be zero.

    If window_size != (-1, -1), implements sliding window local attention. Query at position i
    will only attend to keys between
    [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive.

    Arguments:
        q: (total_q, nheads, headdim), where total_q = total number of query tokens in the batch.
        kv: (total_k, 2, nheads_k, headdim), where total_k = total number of key tokens in the batch.
        cu_seqlens_q: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
           of the sequences in the batch, used to index into q.
        cu_seqlens_k: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
           of the sequences in the batch, used to index into kv.
        max_seqlen_q: int. Maximum query sequence length in the batch.
        max_seqlen_k: int. Maximum key sequence length in the batch.
        dropout_p: float. Dropout probability.
        softmax_scale: float. The scaling of QK^T before applying softmax.
            Default to 1 / sqrt(headdim).
        causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
        window_size: (left, right). If not (-1, -1), implements sliding window local attention.
        softcap: float. Anything > 0 activates softcapping attention.
        alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of
            (-alibi_slope * |i + seqlen_k - seqlen_q - j|)
            is added to the attention score of query i and key j.
        deterministic: bool. Whether to use the deterministic implementation of the backward pass,
            which is slightly slower and uses more memory. The forward pass is always deterministic.
        return_attn_probs: bool. Whether to return the attention probabilities. This option is for
           testing only. The returned probabilities are not guaranteed to be correct
           (they might not have the right scaling).
        out: keyword-only, optional. Forwarded unchanged to FlashAttnVarlenKVPackedFunc.apply.
            Presumably a preallocated (total_q, nheads, headdim) output buffer that the kernel
            writes into; confirm against FlashAttnVarlenKVPackedFunc.
    Return:
        out: (total_q, nheads, headdim).
        softmax_lse [optional, if return_attn_probs=True]: (nheads, total_q_seqlen). The
            logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
            normalization factor).
        S_dmask [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen, seqlen).
            The output of softmax (possibly with different scaling). It also encodes the dropout
            pattern (negative means that location was dropped, nonnegative means it was kept).
    """
    # autograd.Function.apply forwards its arguments positionally to
    # FlashAttnVarlenKVPackedFunc.forward, so this order must match that signature.
    return FlashAttnVarlenKVPackedFunc.apply(
        q,
        kv,
        cu_seqlens_q,
        cu_seqlens_k,
        max_seqlen_q,
        max_seqlen_k,
        dropout_p,
        softmax_scale,
        causal,
        window_size,
        softcap,
        alibi_slopes,
        deterministic,
        return_attn_probs,
        out,
    )


def flash_attn_varlen_func(
    q,
    k,
    v,
    cu_seqlens_q,
    cu_seqlens_k,
    max_seqlen_q,
    max_seqlen_k,
    dropout_p=0.0,
    softmax_scale=None,
    causal=False,
    window_size=(-1, -1),  # -1 means infinite context window
    softcap=0.0,  # 0.0 means deactivated
    alibi_slopes=None,
    deterministic=False,
    return_attn_probs=False,
    block_table=None,
    *,
    out=None,
):
    """dropout_p should be set to 0.0 during evaluation
    Supports multi-query and grouped-query attention (MQA/GQA) by passing in K, V with fewer heads
    than Q. Note that the number of heads in Q must be divisible by the number of heads in KV.
    For example, if Q has 6 heads and K, V have 2 heads, head 0, 1, 2 of Q will attend to head
    0 of K, V, and head 3, 4, 5 of Q will attend to head 1 of K, V.

    If causal=True, the causal mask is aligned to the bottom right corner of the attention matrix.
    For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is:
        1 1 1 1 0
        1 1 1 1 1
    If seqlen_q = 5 and seqlen_k = 2, the causal mask is:
        0 0
        0 0
        0 0
        1 0
        1 1
    If the row of the mask is all zero, the output will be zero.

    If window_size != (-1, -1), implements sliding window local attention. Query at position i
    will only attend to keys between
    [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive.

    Arguments:
        q: (total_q, nheads, headdim), where total_q = total number of query tokens in the batch.
        k: (total_k, nheads_k, headdim), where total_k = total number of key tokens in the batch.
        v: (total_k, nheads_k, headdim), where total_k = total number of key tokens in the batch.
        cu_seqlens_q: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
           of the sequences in the batch, used to index into q.
        cu_seqlens_k: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
           of the sequences in the batch, used to index into k and v.
        max_seqlen_q: int. Maximum query sequence length in the batch.
        max_seqlen_k: int. Maximum key sequence length in the batch.
        dropout_p: float. Dropout probability.
        softmax_scale: float. The scaling of QK^T before applying softmax.
            Default to 1 / sqrt(headdim).
        causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
        window_size: (left, right). If not (-1, -1), implements sliding window local attention.
        softcap: float. Anything > 0 activates softcapping attention.
        alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of
            (-alibi_slope * |i + seqlen_k - seqlen_q - j|)
            is added to the attention score of query i and key j.
        deterministic: bool. Whether to use the deterministic implementation of the backward pass,
            which is slightly slower and uses more memory. The forward pass is always deterministic.
        return_attn_probs: bool. Whether to return the attention probabilities. This option is for
           testing only. The returned probabilities are not guaranteed to be correct
           (they might not have the right scaling).
        block_table: optional. Forwarded unchanged to FlashAttnVarlenFunc.apply.
            Presumably a paged-KV-cache block table of shape
            (batch_size, max_num_blocks_per_seq), dtype torch.int32; confirm against
            FlashAttnVarlenFunc.
        out: keyword-only, optional. Forwarded unchanged to FlashAttnVarlenFunc.apply.
            Presumably a preallocated (total_q, nheads, headdim) output buffer that the kernel
            writes into; confirm against FlashAttnVarlenFunc.
    Return:
        out: (total_q, nheads, headdim).
        softmax_lse [optional, if return_attn_probs=True]: (nheads, total_q_seqlen). The
            logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
            normalization factor).
        S_dmask [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen, seqlen).
            The output of softmax (possibly with different scaling). It also encodes the dropout
            pattern (negative means that location was dropped, nonnegative means it was kept).
    """
    # autograd.Function.apply forwards its arguments positionally to
    # FlashAttnVarlenFunc.forward, so this order must match that signature.
    return FlashAttnVarlenFunc.apply(
        q,
        k,
        v,
        cu_seqlens_q,
        cu_seqlens_k,
        max_seqlen_q,
        max_seqlen_k,
        dropout_p,
        softmax_scale,
        causal,
        window_size,
        softcap,
        alibi_slopes,
        deterministic,
        return_attn_probs,
        block_table,
        out,
    )


def flash_attn_with_kvcache(
    q,
    k_cache,
    v_cache,
    k=None,
    v=None,
    rotary_cos=None,
    rotary_sin=None,
    cache_seqlens: Optional[Union[int, torch.Tensor]] = None,
    cache_batch_idx: Optional[torch.Tensor] = None,
    block_table: Optional[torch.Tensor] = None,
    softmax_scale=None,
    causal=False,
    window_size=(-1, -1),  # -1 means infinite context window
    softcap=0.0,  # 0.0 means deactivated
    rotary_interleaved=True,
    alibi_slopes=None,
    num_splits=0,
    return_softmax_lse=False,
    *,
    out=None,
):
    """
    If k and v are not None, k_cache and v_cache will be updated *inplace* with the new values from
    k and v. This is useful for incremental decoding: you can pass in the cached keys/values from
    the previous step, and update them with the new keys/values from the current step, and do
    attention with the updated cache, all in 1 kernel.

    If you pass in k / v, you must make sure that the cache is large enough to hold the new values.
    For example, the KV cache could be pre-allocated with the max sequence length, and you can use
    cache_seqlens to keep track of the current sequence lengths of each sequence in the batch.

    Also apply rotary embedding if rotary_cos and rotary_sin are passed in. The key @k will be
    rotated by rotary_cos and rotary_sin at indices cache_seqlens, cache_seqlens + 1, etc.
    If causal or local (i.e., window_size != (-1, -1)), the query @q will be rotated by rotary_cos
    and rotary_sin at indices cache_seqlens, cache_seqlens + 1, etc.
    If not causal and not local, the query @q will be rotated by rotary_cos and rotary_sin at
    indices cache_seqlens only (i.e. we consider all tokens in @q to be at position cache_seqlens).

    See tests/test_flash_attn.py::test_flash_attn_kvcache for examples of how to use this function.

    Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads
    than Q. Note that the number of heads in Q must be divisible by the number of heads in KV.
    For example, if Q has 6 heads and K, V have 2 heads, heads 0, 1, 2 of Q will attend to head
    0 of K, V, and heads 3, 4, 5 of Q will attend to head 1 of K, V.

    If causal=True, the causal mask is aligned to the bottom right corner of the attention matrix.
    For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is:
        1 1 1 1 0
        1 1 1 1 1
    If seqlen_q = 5 and seqlen_k = 2, the causal mask is:
        0 0
        0 0
        0 0
        1 0
        1 1
    If the row of the mask is all zero, the output will be zero.

    If window_size != (-1, -1), implements sliding window local attention. Query at position i
    will only attend to keys between
    [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive.

    Note: Does not support backward pass.

    Arguments:
        q: (batch_size, seqlen, nheads, headdim)
        k_cache: (batch_size_cache, seqlen_cache, nheads_k, headdim) if there's no block_table,
            or (num_blocks, page_block_size, nheads_k, headdim) if there's a block_table (i.e. paged KV cache)
            page_block_size must be a multiple of 256.
        v_cache: (batch_size_cache, seqlen_cache, nheads_k, headdim) if there's no block_table,
            or (num_blocks, page_block_size, nheads_k, headdim) if there's a block_table (i.e. paged KV cache)
        k [optional]: (batch_size, seqlen_new, nheads_k, headdim). If not None, we concatenate
            k with k_cache, starting at the indices specified by cache_seqlens.
        v [optional]: (batch_size, seqlen_new, nheads_k, headdim). Similar to k.
        rotary_cos [optional]: (seqlen_ro, rotary_dim / 2). If not None, we apply rotary embedding
            to k and q. Only applicable if k and v are passed in. rotary_dim must be divisible by 16.
        rotary_sin [optional]: (seqlen_ro, rotary_dim / 2). Similar to rotary_cos.
        cache_seqlens: int, or (batch_size,), dtype torch.int32. The sequence lengths of the
            KV cache. An int is broadcast to a (batch_size,) int32 tensor, where batch_size is
            q.shape[0]. A tensor whose last dimension is not contiguous is made contiguous.
        block_table [optional]: (batch_size, max_num_blocks_per_seq), dtype torch.int32.
        cache_batch_idx: (batch_size,), dtype torch.int32. The indices used to index into the KV cache.
            If None, we assume that the batch indices are [0, 1, 2, ..., batch_size - 1].
            If the indices are not distinct, and k and v are provided, the values updated in the cache
                 might come from any of the duplicate indices.
        softmax_scale: float. The scaling of QK^T before applying softmax.
            Default to 1 / sqrt(headdim).
        causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
        window_size: (left, right). If not (-1, -1), implements sliding window local attention.
        softcap: float. Anything > 0 activates softcapping attention.
        rotary_interleaved: bool. Only applicable if rotary_cos and rotary_sin are passed in.
            If True, rotary embedding will combine dimensions 0 & 1, 2 & 3, etc. If False,
            rotary embedding will combine dimensions 0 & rotary_dim / 2, 1 & rotary_dim / 2 + 1
            (i.e. GPT-NeoX style).
        alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of
            (-alibi_slope * |i + seqlen_k - seqlen_q - j|)
            is added to the attention score of query i and key j.
        num_splits: int. If > 1, split the key/value into this many chunks along the sequence.
           If num_splits == 1, we don't split the key/value. If num_splits == 0, we use a heuristic
           to automatically determine the number of splits.
           Don't change this unless you know what you are doing.
        return_softmax_lse: bool. Whether to return the logsumexp of the attention scores.
        out [optional, keyword-only]: tensor forwarded to the kernel as its output buffer.
            NOTE(review): presumably it must match the output shape (batch_size, seqlen, nheads,
            headdim) and q's dtype/device; confirm against the kernel.

    Return:
        out: (batch_size, seqlen, nheads, headdim).
        softmax_lse [optional, if return_softmax_lse=True]: (batch_size, nheads, seqlen). The
            logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
            normalization factor).
    """
    assert k_cache.stride(-1) == 1, "k_cache must have contiguous last dimension"
    assert v_cache.stride(-1) == 1, "v_cache must have contiguous last dimension"
    q, k, v = [maybe_contiguous(x) for x in (q, k, v)]
    if softmax_scale is None:
        softmax_scale = q.shape[-1] ** (-0.5)
    if isinstance(cache_seqlens, int):
        # Broadcast a scalar length to one entry per *query* batch element, matching the
        # documented (batch_size,) shape. k_cache.shape[0] is not the batch size in general:
        # it is num_blocks with a block_table and batch_size_cache with cache_batch_idx.
        cache_seqlens = torch.full(
            (q.shape[0],), cache_seqlens, dtype=torch.int32, device=k_cache.device
        )
    # Normalize caller-supplied tensors too, not just the one built above
    # (maybe_contiguous passes None through unchanged).
    cache_seqlens = maybe_contiguous(cache_seqlens)
    cache_batch_idx = maybe_contiguous(cache_batch_idx)
    block_table = maybe_contiguous(block_table)
    out, softmax_lse = torch.ops.vllm_flash_attn_c.fwd_kvcache(
        q,
        k_cache,
        v_cache,
        k,
        v,
        cache_seqlens,
        rotary_cos,
        rotary_sin,
        cache_batch_idx,
        block_table,
        alibi_slopes,
        out,
        softmax_scale,
        causal,
        window_size[0],
        window_size[1],
        softcap,
        rotary_interleaved,
        num_splits,
    )
    return (out, softmax_lse) if return_softmax_lse else out