flash_attn_interface.py 44.5 KB
Newer Older
1
2
# Copyright (c) 2023, Tri Dao.

Tri Dao's avatar
Tri Dao committed
3
4
from typing import Optional, Union

Tri Dao's avatar
Tri Dao committed
5
6
import torch
import torch.nn as nn
Tri Dao's avatar
Tri Dao committed
7

8
9
# isort: off
# We need to import the CUDA kernels after importing torch
Woosuk Kwon's avatar
Woosuk Kwon committed
10
import vllm_flash_attn_2_cuda as flash_attn_cuda
Tri Dao's avatar
Tri Dao committed
11

12
13
# isort: on

14
15
def maybe_contiguous(x):
    """Return ``x`` with a unit stride in its last dimension.

    ``None`` is passed through untouched. A tensor whose last-dimension stride
    is already 1 is returned as-is (no copy); otherwise ``x.contiguous()`` is
    returned.
    """
    if x is None or x.stride(-1) == 1:
        return x
    return x.contiguous()
Tri Dao's avatar
Tri Dao committed
16

Tri Dao's avatar
Tri Dao committed
17
def _get_block_size_n(device, head_dim, is_dropout, is_causal):
Tri Dao's avatar
Tri Dao committed
18
19
20
21
22
23
24
    # This should match the block sizes in the CUDA kernel
    assert head_dim <= 256
    major, minor = torch.cuda.get_device_capability(device)
    is_sm8x = major == 8 and minor > 0  # Only include sm86 and sm89, exclude sm80 (A100)
    is_sm80 = major == 8 and minor == 0
    is_sm90 = major == 9 and minor == 0
    if head_dim <= 32:
Tri Dao's avatar
Tri Dao committed
25
        return 128
Tri Dao's avatar
Tri Dao committed
26
    if head_dim <= 64:
Tri Dao's avatar
Tri Dao committed
27
        return 128 if not is_dropout else 64
Tri Dao's avatar
Tri Dao committed
28
    elif head_dim <= 96:
Tri Dao's avatar
Tri Dao committed
29
        return 64
Tri Dao's avatar
Tri Dao committed
30
31
    elif head_dim <= 128:
        if is_sm8x:
Tri Dao's avatar
Tri Dao committed
32
            return 64 if (not is_dropout and is_causal) else 32
Tri Dao's avatar
Tri Dao committed
33
        else:
Tri Dao's avatar
Tri Dao committed
34
            return 64 if not is_dropout else 32
Tri Dao's avatar
Tri Dao committed
35
36
    elif head_dim <= 160:
        if is_sm8x:
Tri Dao's avatar
Tri Dao committed
37
            return 64
Tri Dao's avatar
Tri Dao committed
38
        else:
Tri Dao's avatar
Tri Dao committed
39
            return 32
Tri Dao's avatar
Tri Dao committed
40
    elif head_dim <= 192:
Tri Dao's avatar
Tri Dao committed
41
        return 64
Tri Dao's avatar
Tri Dao committed
42
    elif head_dim <= 224:
Tri Dao's avatar
Tri Dao committed
43
        return 64
Tri Dao's avatar
Tri Dao committed
44
    elif head_dim <= 256:
Tri Dao's avatar
Tri Dao committed
45
        return 64
Tri Dao's avatar
Tri Dao committed
46
47


Tri Dao's avatar
Tri Dao committed
48
def _flash_attn_forward(
    q, k, v, dropout_p, softmax_scale, causal, window_size, alibi_slopes, return_softmax, *, out=None
):
    """Invoke the fixed-length forward kernel ``flash_attn_cuda.fwd``.

    q, k and v are first passed through ``maybe_contiguous``. ``window_size`` is
    indexed as ``window_size[0]`` / ``window_size[1]`` and forwarded as two
    separate arguments. ``out`` is handed to the kernel unchanged (presumably an
    optional preallocated output buffer -- TODO confirm against the extension).

    Returns:
        The 8 values produced by the kernel, in order:
        (out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state).
    """
    q, k, v = map(maybe_contiguous, (q, k, v))
    window_left = window_size[0]
    window_right = window_size[1]
    (
        out,
        q,
        k,
        v,
        out_padded,
        softmax_lse,
        S_dmask,
        rng_state,
    ) = flash_attn_cuda.fwd(
        q,
        k,
        v,
        out,
        alibi_slopes,
        dropout_p,
        softmax_scale,
        causal,
        window_left,
        window_right,
        return_softmax,
        None,  # trailing optional argument left unset (presumably the RNG generator -- confirm)
    )
    return out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state
Tri Dao's avatar
Tri Dao committed
67
68


Tri Dao's avatar
Tri Dao committed
69
70
71
72
73
74
75
76
77
78
79
def _flash_attn_varlen_forward(
    q,
    k,
    v,
    cu_seqlens_q,
    cu_seqlens_k,
    max_seqlen_q,
    max_seqlen_k,
    dropout_p,
    softmax_scale,
    causal,
    window_size,
    alibi_slopes,
    return_softmax,
    block_table,
    *,
    out=None
):
    """Invoke the variable-length forward kernel ``flash_attn_cuda.varlen_fwd``.

    q, k and v are first passed through ``maybe_contiguous``. ``block_table`` and
    ``out`` are forwarded unchanged; ``window_size`` is split into its two
    elements. Two fixed positional arguments are sent as well: ``None`` after
    ``cu_seqlens_k`` and ``False`` after ``softmax_scale`` (their meaning is
    defined by the extension -- TODO confirm).

    Returns:
        The 8 values produced by the kernel, in order:
        (out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state).
    """
    q, k, v = map(maybe_contiguous, (q, k, v))
    window_left = window_size[0]
    window_right = window_size[1]
    (
        out,
        q,
        k,
        v,
        out_padded,
        softmax_lse,
        S_dmask,
        rng_state,
    ) = flash_attn_cuda.varlen_fwd(
        q,
        k,
        v,
        out,
        cu_seqlens_q,
        cu_seqlens_k,
        None,
        block_table,
        alibi_slopes,
        max_seqlen_q,
        max_seqlen_k,
        dropout_p,
        softmax_scale,
        False,
        causal,
        window_left,
        window_right,
        return_softmax,
        None,
    )
    return out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state
Tri Dao's avatar
Tri Dao committed
112
113


Tri Dao's avatar
Tri Dao committed
114
def _flash_attn_backward(
    dout,
    q,
    k,
    v,
    out,
    softmax_lse,
    dq,
    dk,
    dv,
    dropout_p,
    softmax_scale,
    causal,
    window_size,
    alibi_slopes,
    deterministic,
    rng_state=None,
):
    """Invoke the fixed-length backward kernel ``flash_attn_cuda.bwd``.

    dout, q, k, v and out are passed through ``maybe_contiguous``. dq, dk and dv
    are not, because the callers in this module allocate them with
    ``torch.empty``/``torch.empty_like`` (or as slices of such a buffer).

    Returns:
        (dq, dk, dv, softmax_d), exactly as unpacked from the kernel's 4 results.
    """
    dout, q, k, v, out = map(maybe_contiguous, (dout, q, k, v, out))
    window_left = window_size[0]
    window_right = window_size[1]
    grads = flash_attn_cuda.bwd(
        dout,
        q,
        k,
        v,
        out,
        softmax_lse,
        dq,
        dk,
        dv,
        alibi_slopes,
        dropout_p,
        softmax_scale,
        causal,
        window_left,
        window_right,
        deterministic,
        None,
        rng_state,
    )
    dq, dk, dv, softmax_d = grads
    return dq, dk, dv, softmax_d


Tri Dao's avatar
Tri Dao committed
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
def _flash_attn_varlen_backward(
    dout,
    q,
    k,
    v,
    out,
    softmax_lse,
    dq,
    dk,
    dv,
    cu_seqlens_q,
    cu_seqlens_k,
    max_seqlen_q,
    max_seqlen_k,
    dropout_p,
    softmax_scale,
    causal,
    window_size,
    alibi_slopes,
    deterministic,
    rng_state=None,
):
    """Invoke the variable-length backward kernel ``flash_attn_cuda.varlen_bwd``.

    dout, q, k, v and out are passed through ``maybe_contiguous``. The dq/dk/dv
    buffers are allocated by the callers in this module and forwarded as-is. A
    fixed ``False`` is sent after ``softmax_scale`` and ``None`` before
    ``rng_state`` (meaning defined by the extension -- TODO confirm).

    Returns:
        (dq, dk, dv, softmax_d), exactly as unpacked from the kernel's 4 results.
    """
    dout, q, k, v, out = map(maybe_contiguous, (dout, q, k, v, out))
    window_left = window_size[0]
    window_right = window_size[1]
    grads = flash_attn_cuda.varlen_bwd(
        dout,
        q,
        k,
        v,
        out,
        softmax_lse,
        dq,
        dk,
        dv,
        cu_seqlens_q,
        cu_seqlens_k,
        alibi_slopes,
        max_seqlen_q,
        max_seqlen_k,
        dropout_p,
        softmax_scale,
        False,
        causal,
        window_left,
        window_right,
        deterministic,
        None,
        rng_state,
    )
    dq, dk, dv, softmax_d = grads
    return dq, dk, dv, softmax_d
Tri Dao's avatar
Tri Dao committed
209
210


Tri Dao's avatar
Tri Dao committed
211
class FlashAttnQKVPackedFunc(torch.autograd.Function):
    """Autograd function for fixed-length attention on a packed QKV tensor.

    q, k and v are taken as ``qkv[:, :, 0]``, ``qkv[:, :, 1]`` and ``qkv[:, :, 2]``;
    the backward pass writes their gradients into one packed ``dqkv`` buffer.
    """

    @staticmethod
    def forward(
        ctx,
        qkv,
        dropout_p,
        softmax_scale,
        causal,
        window_size,
        alibi_slopes,
        deterministic,
        return_softmax,
        out=None,
    ):
        """Run the forward kernel and stash what backward needs on ``ctx``.

        ``out`` is positional-or-keyword on purpose: ``flash_attn_qkvpacked_func``
        passes it positionally through ``FlashAttnQKVPackedFunc.apply``, which
        would raise a TypeError if ``out`` were keyword-only.

        Returns:
            ``out``, or ``(out, softmax_lse, S_dmask)`` when ``return_softmax`` is truthy.
        """
        if softmax_scale is None:
            softmax_scale = qkv.shape[-1] ** (-0.5)
        out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = _flash_attn_forward(
            qkv[:, :, 0],
            qkv[:, :, 1],
            qkv[:, :, 2],
            dropout_p,
            softmax_scale,
            causal=causal,
            window_size=window_size,
            alibi_slopes=alibi_slopes,
            return_softmax=return_softmax and dropout_p > 0,
            out=out,
        )
        ctx.save_for_backward(q, k, v, out_padded, softmax_lse, rng_state)
        ctx.dropout_p = dropout_p
        ctx.softmax_scale = softmax_scale
        ctx.causal = causal
        ctx.window_size = window_size
        ctx.alibi_slopes = alibi_slopes
        ctx.deterministic = deterministic
        return out if not return_softmax else (out, softmax_lse, S_dmask)

    @staticmethod
    def backward(ctx, dout, *args):
        """Compute the packed gradient ``dqkv``; all other inputs get ``None``."""
        q, k, v, out, softmax_lse, rng_state = ctx.saved_tensors
        qkv_shape = q.shape[:-2] + (3, *q.shape[-2:])
        dqkv = torch.empty(qkv_shape, dtype=q.dtype, device=q.device)
        _flash_attn_backward(
            dout,
            q,
            k,
            v,
            out,
            softmax_lse,
            dqkv[:, :, 0],
            dqkv[:, :, 1],
            dqkv[:, :, 2],
            ctx.dropout_p,
            ctx.softmax_scale,
            ctx.causal,
            ctx.window_size,
            ctx.alibi_slopes,
            ctx.deterministic,
            rng_state=rng_state,
        )
        dqkv = dqkv[..., : dout.shape[-1]]  # We could have padded the head dimension
        # autograd requires one entry per forward input (9 after ctx, including
        # `out`); only qkv is differentiable, so the other 8 are None.
        return (dqkv,) + (None,) * 8
Tri Dao's avatar
Tri Dao committed
274
275
276
277


class FlashAttnVarlenQKVPackedFunc(torch.autograd.Function):
    """Autograd function for variable-length attention on a packed QKV tensor.

    q, k and v are taken as ``qkv[:, 0]``, ``qkv[:, 1]`` and ``qkv[:, 2]``. The same
    ``cu_seqlens`` / ``max_seqlen`` are used for both the query and key sides.
    """

    @staticmethod
    def forward(
        ctx,
        qkv,
        cu_seqlens,
        max_seqlen,
        dropout_p,
        softmax_scale,
        causal,
        window_size,
        alibi_slopes,
        deterministic,
        return_softmax,
        out=None,
    ):
        """Run the varlen forward kernel and stash what backward needs on ``ctx``.

        ``out`` is positional-or-keyword so it can be supplied through
        ``FlashAttnVarlenQKVPackedFunc.apply`` positionally, like the other
        autograd functions in this module accept it.

        Returns:
            ``out``, or ``(out, softmax_lse, S_dmask)`` when ``return_softmax`` is truthy.
        """
        if softmax_scale is None:
            softmax_scale = qkv.shape[-1] ** (-0.5)
        out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = _flash_attn_varlen_forward(
            qkv[:, 0],
            qkv[:, 1],
            qkv[:, 2],
            cu_seqlens,
            cu_seqlens,
            max_seqlen,
            max_seqlen,
            dropout_p,
            softmax_scale,
            causal=causal,
            window_size=window_size,
            alibi_slopes=alibi_slopes,
            return_softmax=return_softmax and dropout_p > 0,
            block_table=None,
            out=out,
        )
        ctx.save_for_backward(q, k, v, out_padded, softmax_lse, cu_seqlens, rng_state)
        ctx.dropout_p = dropout_p
        ctx.max_seqlen = max_seqlen
        ctx.softmax_scale = softmax_scale
        ctx.causal = causal
        ctx.window_size = window_size
        ctx.alibi_slopes = alibi_slopes
        ctx.deterministic = deterministic
        return out if not return_softmax else (out, softmax_lse, S_dmask)

    @staticmethod
    def backward(ctx, dout, *args):
        """Compute the packed gradient ``dqkv``; all other inputs get ``None``."""
        q, k, v, out, softmax_lse, cu_seqlens, rng_state = ctx.saved_tensors
        qkv_shape = q.shape[:-2] + (3, *q.shape[-2:])
        dqkv = torch.empty(qkv_shape, dtype=q.dtype, device=q.device)
        _flash_attn_varlen_backward(
            dout,
            q,
            k,
            v,
            out,
            softmax_lse,
            dqkv[:, 0],
            dqkv[:, 1],
            dqkv[:, 2],
            cu_seqlens,
            cu_seqlens,
            ctx.max_seqlen,
            ctx.max_seqlen,
            ctx.dropout_p,
            ctx.softmax_scale,
            ctx.causal,
            ctx.window_size,
            ctx.alibi_slopes,
            ctx.deterministic,
            rng_state=rng_state,
        )
        dqkv = dqkv[..., : dout.shape[-1]]  # We could have padded the head dimension
        # autograd requires one entry per forward input (11 after ctx, including
        # `out`); only qkv is differentiable, so the other 10 are None.
        return (dqkv,) + (None,) * 10
Tri Dao's avatar
Tri Dao committed
351
352


Tri Dao's avatar
Tri Dao committed
353
class FlashAttnKVPackedFunc(torch.autograd.Function):
    """Autograd function for fixed-length attention with a packed KV tensor.

    k and v are taken as ``kv[:, :, 0]`` and ``kv[:, :, 1]``; their gradients are
    written into one packed ``dkv`` buffer in backward.
    """

    @staticmethod
    def forward(
        ctx,
        q,
        kv,
        dropout_p,
        softmax_scale,
        causal,
        window_size,
        alibi_slopes,
        deterministic,
        return_softmax,
        out=None,
    ):
        """Run the forward kernel and stash what backward needs on ``ctx``.

        Returns:
            ``out``, or ``(out, softmax_lse, S_dmask)`` when ``return_softmax`` is truthy.
        """
        if softmax_scale is None:
            softmax_scale = q.shape[-1] ** (-0.5)
        out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = _flash_attn_forward(
            q,
            kv[:, :, 0],
            kv[:, :, 1],
            dropout_p,
            softmax_scale,
            causal=causal,
            window_size=window_size,
            alibi_slopes=alibi_slopes,
            return_softmax=return_softmax and dropout_p > 0,
            out=out,
        )
        ctx.save_for_backward(q, k, v, out_padded, softmax_lse, rng_state)
        ctx.dropout_p = dropout_p
        ctx.softmax_scale = softmax_scale
        ctx.causal = causal
        ctx.window_size = window_size
        ctx.alibi_slopes = alibi_slopes
        ctx.deterministic = deterministic
        return out if not return_softmax else (out, softmax_lse, S_dmask)

    @staticmethod
    def backward(ctx, dout, *args):
        """Compute ``dq`` and the packed ``dkv``; all other inputs get ``None``."""
        q, k, v, out, softmax_lse, rng_state = ctx.saved_tensors
        dq = torch.empty_like(q)
        kv_shape = k.shape[:-2] + (2, *k.shape[-2:])
        dkv = torch.empty(kv_shape, dtype=k.dtype, device=k.device)
        _flash_attn_backward(
            dout,
            q,
            k,
            v,
            out,
            softmax_lse,
            dq,
            dkv[:, :, 0],
            dkv[:, :, 1],
            ctx.dropout_p,
            ctx.softmax_scale,
            ctx.causal,
            ctx.window_size,
            ctx.alibi_slopes,
            ctx.deterministic,
            rng_state=rng_state,
        )
        dq = dq[..., : dout.shape[-1]]  # We could have padded the head dimension
        dkv = dkv[..., : dout.shape[-1]]
        # autograd requires one entry per forward input (10 after ctx, including
        # `out`); q and kv get gradients, the other 8 are None.
        return (dq, dkv) + (None,) * 8
Tri Dao's avatar
Tri Dao committed
418
419


Tri Dao's avatar
Tri Dao committed
420
class FlashAttnVarlenKVPackedFunc(torch.autograd.Function):
    """Autograd function for variable-length attention with a packed KV tensor.

    k and v are taken as ``kv[:, 0]`` and ``kv[:, 1]``; their gradients are written
    into one packed ``dkv`` buffer in backward.
    """

    @staticmethod
    def forward(
        ctx,
        q,
        kv,
        cu_seqlens_q,
        cu_seqlens_k,
        max_seqlen_q,
        max_seqlen_k,
        dropout_p,
        softmax_scale,
        causal,
        window_size,
        alibi_slopes,
        deterministic,
        return_softmax,
        out=None,
    ):
        """Run the varlen forward kernel and stash what backward needs on ``ctx``.

        Returns:
            ``out``, or ``(out, softmax_lse, S_dmask)`` when ``return_softmax`` is truthy.
        """
        if softmax_scale is None:
            softmax_scale = q.shape[-1] ** (-0.5)
        out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = _flash_attn_varlen_forward(
            q,
            kv[:, 0],
            kv[:, 1],
            cu_seqlens_q,
            cu_seqlens_k,
            max_seqlen_q,
            max_seqlen_k,
            dropout_p,
            softmax_scale,
            causal=causal,
            window_size=window_size,
            alibi_slopes=alibi_slopes,
            return_softmax=return_softmax and dropout_p > 0,
            block_table=None,
            out=out,
        )
        ctx.save_for_backward(
            q, k, v, out_padded, softmax_lse, cu_seqlens_q, cu_seqlens_k, rng_state
        )
        ctx.dropout_p = dropout_p
        ctx.max_seqlen_q = max_seqlen_q
        ctx.max_seqlen_k = max_seqlen_k
        ctx.softmax_scale = softmax_scale
        ctx.causal = causal
        ctx.window_size = window_size
        ctx.alibi_slopes = alibi_slopes
        ctx.deterministic = deterministic
        return out if not return_softmax else (out, softmax_lse, S_dmask)

    @staticmethod
    def backward(ctx, dout, *args):
        """Compute ``dq`` and the packed ``dkv``; all other inputs get ``None``."""
        q, k, v, out, softmax_lse, cu_seqlens_q, cu_seqlens_k, rng_state = ctx.saved_tensors
        dq = torch.empty_like(q)
        kv_shape = k.shape[:-2] + (2, *k.shape[-2:])
        dkv = torch.empty(kv_shape, dtype=k.dtype, device=k.device)
        _flash_attn_varlen_backward(
            dout,
            q,
            k,
            v,
            out,
            softmax_lse,
            dq,
            dkv[:, 0],
            dkv[:, 1],
            cu_seqlens_q,
            cu_seqlens_k,
            ctx.max_seqlen_q,
            ctx.max_seqlen_k,
            ctx.dropout_p,
            ctx.softmax_scale,
            ctx.causal,
            ctx.window_size,
            ctx.alibi_slopes,
            ctx.deterministic,
            rng_state=rng_state,
        )
        dq = dq[..., : dout.shape[-1]]  # We could have padded the head dimension
        dkv = dkv[..., : dout.shape[-1]]
        # autograd requires one entry per forward input (14 after ctx, including
        # `out`); q and kv get gradients, the other 12 are None.
        return (dq, dkv) + (None,) * 12
Tri Dao's avatar
Tri Dao committed
502
503
504
505


class FlashAttnFunc(torch.autograd.Function):
    """Autograd function for fixed-length attention on separate q, k, v tensors."""

    @staticmethod
    def forward(
        ctx,
        q,
        k,
        v,
        dropout_p,
        softmax_scale,
        causal,
        window_size,
        alibi_slopes,
        deterministic,
        return_softmax,
        out=None,
    ):
        """Run the forward kernel and stash what backward needs on ``ctx``.

        Returns:
            ``out``, or ``(out, softmax_lse, S_dmask)`` when ``return_softmax`` is truthy.
        """
        if softmax_scale is None:
            softmax_scale = q.shape[-1] ** (-0.5)
        out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = _flash_attn_forward(
            q,
            k,
            v,
            dropout_p,
            softmax_scale,
            causal=causal,
            window_size=window_size,
            alibi_slopes=alibi_slopes,
            return_softmax=return_softmax and dropout_p > 0,
            out=out,
        )
        ctx.save_for_backward(q, k, v, out_padded, softmax_lse, rng_state)
        ctx.dropout_p = dropout_p
        ctx.softmax_scale = softmax_scale
        ctx.causal = causal
        ctx.window_size = window_size
        ctx.alibi_slopes = alibi_slopes
        ctx.deterministic = deterministic
        return out if not return_softmax else (out, softmax_lse, S_dmask)

    @staticmethod
    def backward(ctx, dout, *args):
        """Compute ``dq``, ``dk`` and ``dv``; all other inputs get ``None``."""
        q, k, v, out, softmax_lse, rng_state = ctx.saved_tensors
        dq, dk, dv = torch.empty_like(q), torch.empty_like(k), torch.empty_like(v)
        _flash_attn_backward(
            dout,
            q,
            k,
            v,
            out,
            softmax_lse,
            dq,
            dk,
            dv,
            ctx.dropout_p,
            ctx.softmax_scale,
            ctx.causal,
            ctx.window_size,
            ctx.alibi_slopes,
            ctx.deterministic,
            rng_state=rng_state,
        )
        dq = dq[..., : dout.shape[-1]]  # We could have padded the head dimension
        dk = dk[..., : dout.shape[-1]]
        dv = dv[..., : dout.shape[-1]]
        # autograd requires one entry per forward input (11 after ctx, including
        # `out`); q, k and v get gradients, the other 8 are None.
        return (dq, dk, dv) + (None,) * 8
Tri Dao's avatar
Tri Dao committed
569
570


Tri Dao's avatar
Tri Dao committed
571
class FlashAttnVarlenFunc(torch.autograd.Function):
    """Autograd function for variable-length attention on separate q, k, v tensors.

    Unlike the packed variants, ``block_table`` is taken from the caller and
    forwarded to the varlen forward kernel.
    """

    @staticmethod
    def forward(
        ctx,
        q,
        k,
        v,
        cu_seqlens_q,
        cu_seqlens_k,
        max_seqlen_q,
        max_seqlen_k,
        dropout_p,
        softmax_scale,
        causal,
        window_size,
        alibi_slopes,
        deterministic,
        return_softmax,
        block_table,
        out=None,
    ):
        """Run the varlen forward kernel and stash what backward needs on ``ctx``.

        Returns:
            ``out``, or ``(out, softmax_lse, S_dmask)`` when ``return_softmax`` is truthy.
        """
        if softmax_scale is None:
            softmax_scale = q.shape[-1] ** (-0.5)
        out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = _flash_attn_varlen_forward(
            q,
            k,
            v,
            cu_seqlens_q,
            cu_seqlens_k,
            max_seqlen_q,
            max_seqlen_k,
            dropout_p,
            softmax_scale,
            causal=causal,
            window_size=window_size,
            alibi_slopes=alibi_slopes,
            return_softmax=return_softmax and dropout_p > 0,
            block_table=block_table,
            out=out,
        )
        ctx.save_for_backward(
            q, k, v, out_padded, softmax_lse, cu_seqlens_q, cu_seqlens_k, rng_state
        )
        ctx.dropout_p = dropout_p
        ctx.max_seqlen_q = max_seqlen_q
        ctx.max_seqlen_k = max_seqlen_k
        ctx.softmax_scale = softmax_scale
        ctx.causal = causal
        ctx.window_size = window_size
        ctx.alibi_slopes = alibi_slopes
        ctx.deterministic = deterministic
        return out if not return_softmax else (out, softmax_lse, S_dmask)

    @staticmethod
    def backward(ctx, dout, *args):
        """Compute ``dq``, ``dk`` and ``dv``; all other inputs get ``None``."""
        q, k, v, out, softmax_lse, cu_seqlens_q, cu_seqlens_k, rng_state = ctx.saved_tensors
        dq, dk, dv = torch.empty_like(q), torch.empty_like(k), torch.empty_like(v)
        _flash_attn_varlen_backward(
            dout,
            q,
            k,
            v,
            out,
            softmax_lse,
            dq,
            dk,
            dv,
            cu_seqlens_q,
            cu_seqlens_k,
            ctx.max_seqlen_q,
            ctx.max_seqlen_k,
            ctx.dropout_p,
            ctx.softmax_scale,
            ctx.causal,
            ctx.window_size,
            ctx.alibi_slopes,
            ctx.deterministic,
            rng_state=rng_state,
        )
        dq = dq[..., : dout.shape[-1]]  # We could have padded the head dimension
        dk = dk[..., : dout.shape[-1]]
        dv = dv[..., : dout.shape[-1]]
        # autograd requires one entry per forward input (16 after ctx, including
        # `block_table` and `out`); q, k and v get gradients, the other 13 are None.
        return (dq, dk, dv) + (None,) * 13
654
655


Tri Dao's avatar
Tri Dao committed
656
def flash_attn_qkvpacked_func(
    qkv,
    dropout_p=0.0,
    softmax_scale=None,
    causal=False,
    window_size=(-1, -1),  # -1 means infinite context window
    alibi_slopes=None,
    deterministic=False,
    return_attn_probs=False,
    *,
    out=None,
):
    """Attention over a single stacked QKV tensor. Set dropout_p to 0.0 during evaluation.

    When Q, K and V already live in one tensor, this is faster than calling
    flash_attn_func on Q, K, V, because the backward pass avoids explicitly
    concatenating the gradients of Q, K and V.

    For multi-query and grouped-query attention (MQA/GQA), use
    flash_attn_kvpacked_func or flash_attn_func instead.

    If window_size != (-1, -1), sliding-window local attention is used: the query
    at position i only attends to keys in [i - window_size[0], i + window_size[1]]
    inclusive.

    Arguments:
        qkv: (batch_size, seqlen, 3, nheads, headdim)
        dropout_p: float. Dropout probability.
        softmax_scale: float. Scaling applied to QK^T before the softmax.
            Defaults to 1 / sqrt(headdim).
        causal: bool. Whether to apply a causal attention mask (e.g. for
            auto-regressive modeling).
        window_size: (left, right). If not (-1, -1), implements sliding window
            local attention.
        alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of
            (-alibi_slope * |i - j|) is added to the attention score of query i
            and key j.
        deterministic: bool. Whether to use the deterministic implementation of
            the backward pass, which is slightly slower and uses more memory. The
            forward pass is always deterministic.
        return_attn_probs: bool. Whether to return the attention probabilities.
            This option is for testing only. The returned probabilities are not
            guaranteed to be correct (they might not have the right scaling).
        out: optional tensor, keyword-only. It is forwarded unchanged to the
            forward kernel (presumably a preallocated output buffer -- TODO confirm).
    Return:
        out: (batch_size, seqlen, nheads, headdim).
        softmax_lse [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen).
            The logsumexp of each row of the matrix QK^T * scaling (e.g., log of
            the softmax normalization factor).
        S_dmask [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen, seqlen).
            The output of softmax (possibly with different scaling). It also
            encodes the dropout pattern (negative means that location was dropped,
            nonnegative means it was kept).
    """
    # Function.apply receives every argument positionally, in forward()'s order.
    apply_args = (
        qkv,
        dropout_p,
        softmax_scale,
        causal,
        window_size,
        alibi_slopes,
        deterministic,
        return_attn_probs,
        out,
    )
    return FlashAttnQKVPackedFunc.apply(*apply_args)
Tri Dao's avatar
Tri Dao committed
712
713


Tri Dao's avatar
Tri Dao committed
714
def flash_attn_kvpacked_func(
    q,
    kv,
    dropout_p=0.0,
    softmax_scale=None,
    causal=False,
    window_size=(-1, -1),  # -1 means infinite context window
    alibi_slopes=None,
    deterministic=False,
    return_attn_probs=False,
    *,
    out=None,
):
    """dropout_p should be set to 0.0 during evaluation
    If K, V are already stacked into 1 tensor, this function will be faster than
    calling flash_attn_func on Q, K, V since the backward pass avoids explicit concatenation
    of the gradients of K, V.
    Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads
    than Q. Note that the number of heads in Q must be divisible by the number of heads in KV.
    For example, if Q has 6 heads and K, V have 2 heads, heads 0, 1, 2 of Q will attend to head
    0 of K, V, and heads 3, 4, 5 of Q will attend to head 1 of K, V.

    If causal=True, the causal mask is aligned to the bottom right corner of the attention matrix.
    For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is:
        1 1 1 1 0
        1 1 1 1 1
    If seqlen_q = 5 and seqlen_k = 2, the causal mask is:
        0 0
        0 0
        0 0
        1 0
        1 1
    If the row of the mask is all zero, the output will be zero.

    If window_size != (-1, -1), implements sliding window local attention. Query at position i
    will only attend to keys between
    [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive.

    Arguments:
        q: (batch_size, seqlen_q, nheads, headdim)
        kv: (batch_size, seqlen_k, 2, nheads_k, headdim)
        dropout_p: float. Dropout probability.
        softmax_scale: float. The scaling of QK^T before applying softmax.
            Default to 1 / sqrt(headdim).
        causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
        window_size: (left, right). If not (-1, -1), implements sliding window local attention.
        alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of
            (-alibi_slope * |i + seqlen_k - seqlen_q - j|)
            is added to the attention score of query i and key j.
        deterministic: bool. Whether to use the deterministic implementation of the backward pass,
            which is slightly slower and uses more memory. The forward pass is always deterministic.
        return_attn_probs: bool. Whether to return the attention probabilities. This option is for
           testing only. The returned probabilities are not guaranteed to be correct
           (they might not have the right scaling).
        out [optional, keyword-only]: forwarded unchanged to FlashAttnKVPackedFunc.apply as its
            last positional argument.
    Return:
        out: (batch_size, seqlen_q, nheads, headdim).
        softmax_lse [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen_q). The
            logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
            normalization factor).
        S_dmask [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen_q, seqlen_k).
            The output of softmax (possibly with different scaling). It also encodes the dropout
            pattern (negative means that location was dropped, nonnegative means it was kept).
    """
    # NOTE(review): `out` looks like a caller-supplied output buffer; confirm the expected
    # shape/dtype against FlashAttnKVPackedFunc.forward.
    return FlashAttnKVPackedFunc.apply(
        q,
        kv,
        dropout_p,
        softmax_scale,
        causal,
        window_size,
        alibi_slopes,
        deterministic,
        return_attn_probs,
        out,
    )
Tri Dao's avatar
Tri Dao committed
789
790


Tri Dao's avatar
Tri Dao committed
791
def flash_attn_func(
    q,
    k,
    v,
    dropout_p=0.0,
    softmax_scale=None,
    causal=False,
    window_size=(-1, -1),  # -1 means infinite context window
    alibi_slopes=None,
    deterministic=False,
    return_attn_probs=False,
    *,
    out=None,
):
    """dropout_p should be set to 0.0 during evaluation
    Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads
    than Q. Note that the number of heads in Q must be divisible by the number of heads in KV.
    For example, if Q has 6 heads and K, V have 2 heads, heads 0, 1, 2 of Q will attend to head
    0 of K, V, and heads 3, 4, 5 of Q will attend to head 1 of K, V.

    If causal=True, the causal mask is aligned to the bottom right corner of the attention matrix.
    For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is:
        1 1 1 1 0
        1 1 1 1 1
    If seqlen_q = 5 and seqlen_k = 2, the causal mask is:
        0 0
        0 0
        0 0
        1 0
        1 1
    If the row of the mask is all zero, the output will be zero.

    If window_size != (-1, -1), implements sliding window local attention. Query at position i
    will only attend to keys between
    [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive.

    Arguments:
        q: (batch_size, seqlen_q, nheads, headdim)
        k: (batch_size, seqlen_k, nheads_k, headdim)
        v: (batch_size, seqlen_k, nheads_k, headdim)
        dropout_p: float. Dropout probability.
        softmax_scale: float. The scaling of QK^T before applying softmax.
            Default to 1 / sqrt(headdim).
        causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
        window_size: (left, right). If not (-1, -1), implements sliding window local attention.
        alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of
            (-alibi_slope * |i + seqlen_k - seqlen_q - j|)
            is added to the attention score of query i and key j.
        deterministic: bool. Whether to use the deterministic implementation of the backward pass,
            which is slightly slower and uses more memory. The forward pass is always deterministic.
        return_attn_probs: bool. Whether to return the attention probabilities. This option is for
           testing only. The returned probabilities are not guaranteed to be correct
           (they might not have the right scaling).
        out [optional, keyword-only]: forwarded unchanged to FlashAttnFunc.apply as its last
            positional argument.
    Return:
        out: (batch_size, seqlen_q, nheads, headdim).
        softmax_lse [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen_q). The
            logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
            normalization factor).
        S_dmask [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen_q, seqlen_k).
            The output of softmax (possibly with different scaling). It also encodes the dropout
            pattern (negative means that location was dropped, nonnegative means it was kept).
    """
    # NOTE(review): `out` looks like a caller-supplied output buffer; confirm the expected
    # shape/dtype against FlashAttnFunc.forward.
    return FlashAttnFunc.apply(
        q,
        k,
        v,
        dropout_p,
        softmax_scale,
        causal,
        window_size,
        alibi_slopes,
        deterministic,
        return_attn_probs,
        out,
    )
Tri Dao's avatar
Tri Dao committed
866
867


Tri Dao's avatar
Tri Dao committed
868
869
870
871
872
873
874
def flash_attn_varlen_qkvpacked_func(
    qkv,
    cu_seqlens,
    max_seqlen,
    dropout_p=0.0,
    softmax_scale=None,
    causal=False,
    window_size=(-1, -1),  # -1 means infinite context window
    alibi_slopes=None,
    deterministic=False,
    return_attn_probs=False,
    *,
    out=None,
):
    """dropout_p should be set to 0.0 during evaluation
    If Q, K, V are already stacked into 1 tensor, this function will be faster than
    calling flash_attn_varlen_func on Q, K, V since the backward pass avoids explicit concatenation
    of the gradients of Q, K, V.
    For multi-query and grouped-query attention (MQA/GQA), please see
    flash_attn_varlen_kvpacked_func and flash_attn_varlen_func.

    If window_size != (-1, -1), implements sliding window local attention. Query at position i
    will only attend to keys between [i - window_size[0], i + window_size[1]] inclusive.

    Arguments:
        qkv: (total, 3, nheads, headdim), where total = total number of tokens in the batch.
        cu_seqlens: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
           of the sequences in the batch, used to index into qkv.
        max_seqlen: int. Maximum sequence length in the batch.
        dropout_p: float. Dropout probability.
        softmax_scale: float. The scaling of QK^T before applying softmax.
            Default to 1 / sqrt(headdim).
        causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
        window_size: (left, right). If not (-1, -1), implements sliding window local attention.
        alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of (-alibi_slope * |i - j|)
            is added to the attention score of query i and key j.
        deterministic: bool. Whether to use the deterministic implementation of the backward pass,
            which is slightly slower and uses more memory. The forward pass is always deterministic.
        return_attn_probs: bool. Whether to return the attention probabilities. This option is for
           testing only. The returned probabilities are not guaranteed to be correct
           (they might not have the right scaling).
        out [optional, keyword-only]: forwarded unchanged to FlashAttnVarlenQKVPackedFunc.apply
            as its last positional argument.
    Return:
        out: (total, nheads, headdim).
        softmax_lse [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen). The
            logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
            normalization factor).
        S_dmask [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen, seqlen).
            The output of softmax (possibly with different scaling). It also encodes the dropout
            pattern (negative means that location was dropped, nonnegative means it was kept).
    """
    # NOTE(review): `out` looks like a caller-supplied output buffer; confirm the expected
    # shape/dtype against FlashAttnVarlenQKVPackedFunc.forward.
    return FlashAttnVarlenQKVPackedFunc.apply(
        qkv,
        cu_seqlens,
        max_seqlen,
        dropout_p,
        softmax_scale,
        causal,
        window_size,
        alibi_slopes,
        deterministic,
        return_attn_probs,
        out,
    )
Tri Dao's avatar
Tri Dao committed
931
932


Tri Dao's avatar
Tri Dao committed
933
934
935
936
937
938
939
940
941
942
def flash_attn_varlen_kvpacked_func(
    q,
    kv,
    cu_seqlens_q,
    cu_seqlens_k,
    max_seqlen_q,
    max_seqlen_k,
    dropout_p=0.0,
    softmax_scale=None,
    causal=False,
    window_size=(-1, -1),  # -1 means infinite context window
    alibi_slopes=None,
    deterministic=False,
    return_attn_probs=False,
    *,
    out=None,
):
    """dropout_p should be set to 0.0 during evaluation
    If K, V are already stacked into 1 tensor, this function will be faster than
    calling flash_attn_varlen_func on Q, K, V since the backward pass avoids explicit
    concatenation of the gradients of K, V.
    Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads
    than Q. Note that the number of heads in Q must be divisible by the number of heads in KV.
    For example, if Q has 6 heads and K, V have 2 heads, heads 0, 1, 2 of Q will attend to head
    0 of K, V, and heads 3, 4, 5 of Q will attend to head 1 of K, V.

    If causal=True, the causal mask is aligned to the bottom right corner of the attention matrix.
    For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is:
        1 1 1 1 0
        1 1 1 1 1
    If seqlen_q = 5 and seqlen_k = 2, the causal mask is:
        0 0
        0 0
        0 0
        1 0
        1 1
    If the row of the mask is all zero, the output will be zero.

    If window_size != (-1, -1), implements sliding window local attention. Query at position i
    will only attend to keys between
    [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive.

    Arguments:
        q: (total_q, nheads, headdim), where total_q = total number of query tokens in the batch.
        kv: (total_k, 2, nheads_k, headdim), where total_k = total number of key tokens in the batch.
        cu_seqlens_q: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
           of the sequences in the batch, used to index into q.
        cu_seqlens_k: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
           of the sequences in the batch, used to index into kv.
        max_seqlen_q: int. Maximum query sequence length in the batch.
        max_seqlen_k: int. Maximum key sequence length in the batch.
        dropout_p: float. Dropout probability.
        softmax_scale: float. The scaling of QK^T before applying softmax.
            Default to 1 / sqrt(headdim).
        causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
        window_size: (left, right). If not (-1, -1), implements sliding window local attention.
        alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of
            (-alibi_slope * |i + seqlen_k - seqlen_q - j|)
            is added to the attention score of query i and key j.
        deterministic: bool. Whether to use the deterministic implementation of the backward pass,
            which is slightly slower and uses more memory. The forward pass is always deterministic.
        return_attn_probs: bool. Whether to return the attention probabilities. This option is for
           testing only. The returned probabilities are not guaranteed to be correct
           (they might not have the right scaling).
        out [optional, keyword-only]: forwarded unchanged to FlashAttnVarlenKVPackedFunc.apply
            as its last positional argument.
    Return:
        out: (total_q, nheads, headdim).
        softmax_lse [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen). The
            logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
            normalization factor).
        S_dmask [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen, seqlen).
            The output of softmax (possibly with different scaling). It also encodes the dropout
            pattern (negative means that location was dropped, nonnegative means it was kept).
    """
    # NOTE(review): `out` looks like a caller-supplied output buffer; confirm the expected
    # shape/dtype against FlashAttnVarlenKVPackedFunc.forward.
    return FlashAttnVarlenKVPackedFunc.apply(
        q,
        kv,
        cu_seqlens_q,
        cu_seqlens_k,
        max_seqlen_q,
        max_seqlen_k,
        dropout_p,
        softmax_scale,
        causal,
        window_size,
        alibi_slopes,
        deterministic,
        return_attn_probs,
        out,
    )
Tri Dao's avatar
Tri Dao committed
1022

1023

Tri Dao's avatar
Tri Dao committed
1024
1025
1026
1027
1028
1029
1030
1031
1032
1033
1034
def flash_attn_varlen_func(
    q,
    k,
    v,
    cu_seqlens_q,
    cu_seqlens_k,
    max_seqlen_q,
    max_seqlen_k,
    dropout_p=0.0,
    softmax_scale=None,
    causal=False,
    window_size=(-1, -1),  # -1 means infinite context window
    alibi_slopes=None,
    deterministic=False,
    return_attn_probs=False,
    block_table=None,
    *,
    out=None,
):
    """dropout_p should be set to 0.0 during evaluation
    Supports multi-query and grouped-query attention (MQA/GQA) by passing in K, V with fewer heads
    than Q. Note that the number of heads in Q must be divisible by the number of heads in KV.
    For example, if Q has 6 heads and K, V have 2 heads, heads 0, 1, 2 of Q will attend to head
    0 of K, V, and heads 3, 4, 5 of Q will attend to head 1 of K, V.

    If causal=True, the causal mask is aligned to the bottom right corner of the attention matrix.
    For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is:
        1 1 1 1 0
        1 1 1 1 1
    If seqlen_q = 5 and seqlen_k = 2, the causal mask is:
        0 0
        0 0
        0 0
        1 0
        1 1
    If the row of the mask is all zero, the output will be zero.

    If window_size != (-1, -1), implements sliding window local attention. Query at position i
    will only attend to keys between
    [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive.

    Arguments:
        q: (total_q, nheads, headdim), where total_q = total number of query tokens in the batch.
        k: (total_k, nheads_k, headdim), where total_k = total number of key tokens in the batch.
        v: (total_k, nheads_k, headdim), where total_k = total number of key tokens in the batch.
        cu_seqlens_q: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
           of the sequences in the batch, used to index into q.
        cu_seqlens_k: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
           of the sequences in the batch, used to index into k and v.
        max_seqlen_q: int. Maximum query sequence length in the batch.
        max_seqlen_k: int. Maximum key sequence length in the batch.
        dropout_p: float. Dropout probability.
        softmax_scale: float. The scaling of QK^T before applying softmax.
            Default to 1 / sqrt(headdim).
        causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
        window_size: (left, right). If not (-1, -1), implements sliding window local attention.
        alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of
            (-alibi_slope * |i + seqlen_k - seqlen_q - j|)
            is added to the attention score of query i and key j.
        deterministic: bool. Whether to use the deterministic implementation of the backward pass,
            which is slightly slower and uses more memory. The forward pass is always deterministic.
        return_attn_probs: bool. Whether to return the attention probabilities. This option is for
           testing only. The returned probabilities are not guaranteed to be correct
           (they might not have the right scaling).
        block_table [optional]: forwarded unchanged to FlashAttnVarlenFunc.apply.
        out [optional, keyword-only]: forwarded unchanged to FlashAttnVarlenFunc.apply as its
            last positional argument.
    Return:
        out: (total_q, nheads, headdim).
        softmax_lse [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen). The
            logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
            normalization factor).
        S_dmask [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen, seqlen).
            The output of softmax (possibly with different scaling). It also encodes the dropout
            pattern (negative means that location was dropped, nonnegative means it was kept).
    """
    # NOTE(review): `block_table` appears to be a paged-KV block table, like the one
    # flash_attn_with_kvcache accepts ((batch_size, max_num_blocks_per_seq), int32). If so,
    # k/v would then be paged caches rather than packed (total_k, ...) tensors. Confirm against
    # FlashAttnVarlenFunc.forward. Likewise confirm the expected shape/dtype of `out`.
    return FlashAttnVarlenFunc.apply(
        q,
        k,
        v,
        cu_seqlens_q,
        cu_seqlens_k,
        max_seqlen_q,
        max_seqlen_k,
        dropout_p,
        softmax_scale,
        causal,
        window_size,
        alibi_slopes,
        deterministic,
        return_attn_probs,
        block_table,
        out,
    )
Tri Dao's avatar
Tri Dao committed
1115
1116
1117
1118
1119
1120
1121
1122


def flash_attn_with_kvcache(
    q,
    k_cache,
    v_cache,
    k=None,
    v=None,
    rotary_cos=None,
    rotary_sin=None,
    cache_seqlens: Optional[Union[int, torch.Tensor]] = None,
    cache_batch_idx: Optional[torch.Tensor] = None,
    block_table: Optional[torch.Tensor] = None,
    softmax_scale=None,
    causal=False,
    window_size=(-1, -1),  # -1 means infinite context window
    rotary_interleaved=True,
    alibi_slopes=None,
    num_splits=0,
    *,
    out=None,
):
    """
    If k and v are not None, k_cache and v_cache will be updated *inplace* with the new values from
    k and v. This is useful for incremental decoding: you can pass in the cached keys/values from
    the previous step, and update them with the new keys/values from the current step, and do
    attention with the updated cache, all in 1 kernel.

    If you pass in k / v, you must make sure that the cache is large enough to hold the new values.
    For example, the KV cache could be pre-allocated with the max sequence length, and you can use
    cache_seqlens to keep track of the current sequence lengths of each sequence in the batch.

    Also apply rotary embedding if rotary_cos and rotary_sin are passed in. The key @k will be
    rotated by rotary_cos and rotary_sin at indices cache_seqlens, cache_seqlens + 1, etc.
    If causal or local (i.e., window_size != (-1, -1)), the query @q will be rotated by rotary_cos
    and rotary_sin at indices cache_seqlens, cache_seqlens + 1, etc.
    If not causal and not local, the query @q will be rotated by rotary_cos and rotary_sin at
    indices cache_seqlens only (i.e. we consider all tokens in @q to be at position cache_seqlens).

    See tests/test_flash_attn.py::test_flash_attn_kvcache for examples of how to use this function.

    Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads
    than Q. Note that the number of heads in Q must be divisible by the number of heads in KV.
    For example, if Q has 6 heads and K, V have 2 heads, heads 0, 1, 2 of Q will attend to head
    0 of K, V, and heads 3, 4, 5 of Q will attend to head 1 of K, V.

    If causal=True, the causal mask is aligned to the bottom right corner of the attention matrix.
    For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is:
        1 1 1 1 0
        1 1 1 1 1
    If seqlen_q = 5 and seqlen_k = 2, the causal mask is:
        0 0
        0 0
        0 0
        1 0
        1 1
    If the row of the mask is all zero, the output will be zero.

    If window_size != (-1, -1), implements sliding window local attention. Query at position i
    will only attend to keys between
    [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive.

    Note: Does not support backward pass.

    Arguments:
        q: (batch_size, seqlen, nheads, headdim)
        k_cache: (batch_size_cache, seqlen_cache, nheads_k, headdim) if there's no block_table,
            or (num_blocks, page_block_size, nheads_k, headdim) if there's a block_table (i.e. paged KV cache)
            page_block_size must be a multiple of 256.
        v_cache: (batch_size_cache, seqlen_cache, nheads_k, headdim) if there's no block_table,
            or (num_blocks, page_block_size, nheads_k, headdim) if there's a block_table (i.e. paged KV cache)
        k [optional]: (batch_size, seqlen_new, nheads_k, headdim). If not None, we concatenate
            k with k_cache, starting at the indices specified by cache_seqlens.
        v [optional]: (batch_size, seqlen_new, nheads_k, headdim). Similar to k.
        rotary_cos [optional]: (seqlen_ro, rotary_dim / 2). If not None, we apply rotary embedding
            to k and q. Only applicable if k and v are passed in. rotary_dim must be divisible by 16.
        rotary_sin [optional]: (seqlen_ro, rotary_dim / 2). Similar to rotary_cos.
        cache_seqlens: int, or (batch_size,), dtype torch.int32. The sequence lengths of the
            KV cache. An int is broadcast to a (batch_size,) int32 tensor, where batch_size is
            q.shape[0]. A tensor is made contiguous if its last stride is not 1.
        block_table [optional]: (batch_size, max_num_blocks_per_seq), dtype torch.int32.
        cache_batch_idx: (batch_size,), dtype torch.int32. The indices used to index into the KV cache.
            If None, we assume that the batch indices are [0, 1, 2, ..., batch_size - 1].
            If the indices are not distinct, and k and v are provided, the values updated in the cache
            might come from any of the duplicate indices.
        softmax_scale: float. The scaling of QK^T before applying softmax.
            Default to 1 / sqrt(headdim).
        causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
        window_size: (left, right). If not (-1, -1), implements sliding window local attention.
        rotary_interleaved: bool. Only applicable if rotary_cos and rotary_sin are passed in.
            If True, rotary embedding will combine dimensions 0 & 1, 2 & 3, etc. If False,
            rotary embedding will combine dimensions 0 & rotary_dim / 2, 1 & rotary_dim / 2 + 1
            (i.e. GPT-NeoX style).
        alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of
            (-alibi_slope * |i + seqlen_k - seqlen_q - j|)
            is added to the attention score of query i and key j.
        num_splits: int. If > 1, split the key/value into this many chunks along the sequence.
           If num_splits == 1, we don't split the key/value. If num_splits == 0, we use a heuristic
           to automatically determine the number of splits.
           Don't change this unless you know what you are doing.
        out [optional, keyword-only]: forwarded unchanged to flash_attn_cuda.fwd_kvcache.

    Return:
        out: (batch_size, seqlen, nheads, headdim).
    """
    assert k_cache.stride(-1) == 1, "k_cache must have contiguous last dimension"
    assert v_cache.stride(-1) == 1, "v_cache must have contiguous last dimension"
    q, k, v = [maybe_contiguous(x) for x in (q, k, v)]
    if softmax_scale is None:
        softmax_scale = q.shape[-1] ** (-0.5)
    if isinstance(cache_seqlens, int):
        # cache_seqlens must be (batch_size,), i.e. one entry per row of q. k_cache.shape[0] is
        # not the batch size in general: it is num_blocks for a paged cache (block_table) and
        # batch_size_cache when cache_batch_idx remaps batch rows.
        cache_seqlens = torch.full(
            (q.shape[0],), cache_seqlens, dtype=torch.int32, device=k_cache.device
        )
    # Applied outside the int branch so that caller-supplied tensors are also made contiguous
    # (maybe_contiguous passes None and already-contiguous tensors through unchanged).
    cache_seqlens = maybe_contiguous(cache_seqlens)
    cache_batch_idx = maybe_contiguous(cache_batch_idx)
    block_table = maybe_contiguous(block_table)
    # The second return value (softmax_lse) is not exposed by this API.
    out, _ = flash_attn_cuda.fwd_kvcache(
        q,
        k_cache,
        v_cache,
        k,
        v,
        cache_seqlens,
        rotary_cos,
        rotary_sin,
        cache_batch_idx,
        block_table,
        alibi_slopes,
        out,
        softmax_scale,
        causal,
        window_size[0],
        window_size[1],
        rotary_interleaved,
        num_splits,
    )
    return out