# Copyright (c) 2023, Tri Dao.

from typing import Optional, Union

import torch
import torch.nn as nn

# isort: off
# We need to import the CUDA kernels after importing torch
# Use relative import to support build-from-source installation in vLLM
from . import vllm_flash_attn_c  # noqa: F401

# isort: on


def maybe_contiguous(x):
    return x.contiguous() if x is not None and x.stride(-1) != 1 else x


def _get_block_size_n(device, head_dim, is_dropout, is_causal):
    # This should match the block sizes in the CUDA kernel
    assert head_dim <= 256
    major, minor = torch.cuda.get_device_capability(device)
    is_sm8x = major == 8 and minor > 0  # Only include sm86 and sm89, exclude sm80 (A100)
    is_sm80 = major == 8 and minor == 0
    is_sm90 = major == 9 and minor == 0
    if head_dim <= 32:
        return 128
    if head_dim <= 64:
        return 128 if not is_dropout else 64
    elif head_dim <= 96:
        return 64
    elif head_dim <= 128:
        if is_sm8x:
            return 64 if (not is_dropout and is_causal) else 32
        else:
            return 64 if not is_dropout else 32
    elif head_dim <= 160:
        if is_sm8x:
            return 64
        else:
            return 32
    elif head_dim <= 192:
        return 64
    elif head_dim <= 224:
        return 64
    elif head_dim <= 256:
        return 64


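# Illustrative sketch (not part of the original module and never called at
# import time): querying the block-size heuristic above for the current CUDA
# device. The returned value simply mirrors the table encoded in
# _get_block_size_n, e.g. for head_dim=128 without dropout and with a causal mask.
def _example_get_block_size_n():
    device = torch.cuda.current_device()
    return _get_block_size_n(device, head_dim=128, is_dropout=False, is_causal=True)
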
def _flash_attn_forward(
    q, k, v, dropout_p, softmax_scale, causal, window_size, softcap, alibi_slopes, return_softmax, *, out=None
):
    q, k, v = [maybe_contiguous(x) for x in (q, k, v)]
    out, softmax_lse = torch.ops.vllm_flash_attn_c.fwd(
        q,
        k,
        v,
        out,
        alibi_slopes,
        dropout_p,
        softmax_scale,
        causal,
        window_size[0],
        window_size[1],
        softcap,
        return_softmax,
        None,
    )
    return out, softmax_lse


def _flash_attn_varlen_forward(
    q,
    k,
    v,
    cu_seqlens_q,
    cu_seqlens_k,
    max_seqlen_q,
    max_seqlen_k,
    dropout_p,
    softmax_scale,
    causal,
    window_size,
    softcap,
    alibi_slopes,
    return_softmax,
    block_table,
    *,
    out=None
):
    q, k, v = [maybe_contiguous(x) for x in (q, k, v)]
    out, softmax_lse = torch.ops.vllm_flash_attn_c.varlen_fwd(
        q,
        k,
        v,
        out,
        cu_seqlens_q,
        cu_seqlens_k,
        None,
        block_table,
        alibi_slopes,
        max_seqlen_q,
        max_seqlen_k,
        dropout_p,
        softmax_scale,
        False,
        causal,
        window_size[0],
        window_size[1],
        softcap,
        return_softmax,
        None,
    )
    return out, softmax_lse


def flash_attn_func(
    q,
    k,
    v,
    dropout_p=0.0,
    softmax_scale=None,
    causal=False,
    window_size=(-1, -1),  # -1 means infinite context window
    softcap=0.0,  # 0.0 means deactivated
    alibi_slopes=None,
    deterministic=False,
    return_attn_probs=False,
    *,
    return_softmax_lse=False,
    out=None,
):
    """dropout_p should be set to 0.0 during evaluation
    Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads
    than Q. Note that the number of heads in Q must be divisible by the number of heads in KV.
    For example, if Q has 6 heads and K, V have 2 heads, head 0, 1, 2 of Q will attend to head
    0 of K, V, and head 3, 4, 5 of Q will attend to head 1 of K, V.

    If causal=True, the causal mask is aligned to the bottom right corner of the attention matrix.
    For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is:
        1 1 1 1 0
        1 1 1 1 1
    If seqlen_q = 5 and seqlen_k = 2, the causal mask is:
        0 0
        0 0
        0 0
        1 0
        1 1
    If the row of the mask is all zero, the output will be zero.

    If window_size != (-1, -1), implements sliding window local attention. Query at position i
    will only attend to keys between
    [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive.

    Arguments:
        q: (batch_size, seqlen, nheads, headdim)
        k: (batch_size, seqlen, nheads_k, headdim)
        v: (batch_size, seqlen, nheads_k, headdim)
        dropout_p: float. Dropout probability.
        softmax_scale: float. The scaling of QK^T before applying softmax.
            Default to 1 / sqrt(headdim).
        causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
        window_size: (left, right). If not (-1, -1), implements sliding window local attention.
        softcap: float. Anything > 0 activates softcapping attention.
        alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of
            (-alibi_slope * |i + seqlen_k - seqlen_q - j|)
            is added to the attention score of query i and key j.
        deterministic: bool. Whether to use the deterministic implementation of the backward pass,
            which is slightly slower and uses more memory. The forward pass is always deterministic.
        return_attn_probs: bool. Whether to return the attention probabilities. This option is for
           testing only. The returned probabilities are not guaranteed to be correct
           (they might not have the right scaling).
    Return:
        out: (batch_size, seqlen, nheads, headdim).
        softmax_lse [optional, if return_softmax_lse=True]: (batch_size, nheads, seqlen). The
            logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
            normalization factor).
    """
    if softmax_scale is None:
        softmax_scale = q.shape[-1] ** (-0.5)
    out, softmax_lse = _flash_attn_forward(
        q,
        k,
        v,
        dropout_p,
        softmax_scale,
        causal=causal,
        window_size=window_size,
        softcap=softcap,
        alibi_slopes=alibi_slopes,
        return_softmax=return_attn_probs and dropout_p > 0,
        out=out,
    )
    return (out, softmax_lse) if return_softmax_lse else out
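
# Illustrative usage sketch (not part of the original module and never called at
# import time): batched attention with grouped-query attention (6 query heads
# sharing 2 KV heads), a bottom-right-aligned causal mask, and a sliding window
# of 64 keys to the left. Assumes a CUDA device and fp16 inputs.
def _example_flash_attn_func():
    batch_size, seqlen, nheads, nheads_k, headdim = 2, 128, 6, 2, 64
    q = torch.randn(batch_size, seqlen, nheads, headdim, device="cuda", dtype=torch.float16)
    k = torch.randn(batch_size, seqlen, nheads_k, headdim, device="cuda", dtype=torch.float16)
    v = torch.randn(batch_size, seqlen, nheads_k, headdim, device="cuda", dtype=torch.float16)
    # With causal=True, query i attends to keys in [i - 64, i].
    out = flash_attn_func(q, k, v, causal=True, window_size=(64, 0))
    assert out.shape == (batch_size, seqlen, nheads, headdim)
    return out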


def flash_attn_varlen_func(
    q,
    k,
    v,
    cu_seqlens_q,
    cu_seqlens_k,
    max_seqlen_q,
    max_seqlen_k,
    dropout_p=0.0,
    softmax_scale=None,
    causal=False,
    window_size=(-1, -1),  # -1 means infinite context window
    softcap=0.0,  # 0.0 means deactivated
    alibi_slopes=None,
    deterministic=False,
    return_attn_probs=False,
    block_table=None,
    *,
    return_softmax_lse=False,
    out=None,
):
    """dropout_p should be set to 0.0 during evaluation
    Supports multi-query and grouped-query attention (MQA/GQA) by passing in K, V with fewer heads
    than Q. Note that the number of heads in Q must be divisible by the number of heads in KV.
    For example, if Q has 6 heads and K, V have 2 heads, head 0, 1, 2 of Q will attend to head
    0 of K, V, and head 3, 4, 5 of Q will attend to head 1 of K, V.

    If causal=True, the causal mask is aligned to the bottom right corner of the attention matrix.
    For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is:
        1 1 1 1 0
        1 1 1 1 1
    If seqlen_q = 5 and seqlen_k = 2, the causal mask is:
        0 0
        0 0
        0 0
        1 0
        1 1
    If the row of the mask is all zero, the output will be zero.

    If window_size != (-1, -1), implements sliding window local attention. Query at position i
    will only attend to keys between
    [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive.

    Arguments:
        q: (total_q, nheads, headdim), where total_q = total number of query tokens in the batch.
        k: (total_k, nheads_k, headdim), where total_k = total number of key tokens in the batch.
        v: (total_k, nheads_k, headdim), where total_k = total number of key tokens in the batch.
        cu_seqlens_q: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
           of the sequences in the batch, used to index into q.
        cu_seqlens_k: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
           of the sequences in the batch, used to index into kv.
        max_seqlen_q: int. Maximum query sequence length in the batch.
        max_seqlen_k: int. Maximum key sequence length in the batch.
        dropout_p: float. Dropout probability.
        softmax_scale: float. The scaling of QK^T before applying softmax.
            Default to 1 / sqrt(headdim).
        causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
        window_size: (left, right). If not (-1, -1), implements sliding window local attention.
        softcap: float. Anything > 0 activates softcapping attention.
        alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of
            (-alibi_slope * |i + seqlen_k - seqlen_q - j|)
            is added to the attention score of query i and key j.
        deterministic: bool. Whether to use the deterministic implementation of the backward pass,
            which is slightly slower and uses more memory. The forward pass is always deterministic.
        return_attn_probs: bool. Whether to return the attention probabilities. This option is for
           testing only. The returned probabilities are not guaranteed to be correct
           (they might not have the right scaling).
    Return:
        out: (total, nheads, headdim).
        softmax_lse [optional, if return_softmax_lse=True]: (nheads, total_q_seqlen). The
            logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
            normalization factor).
    """
    if softmax_scale is None:
        softmax_scale = q.shape[-1] ** (-0.5)
    out, softmax_lse = _flash_attn_varlen_forward(
        q,
        k,
        v,
        cu_seqlens_q,
        cu_seqlens_k,
        max_seqlen_q,
        max_seqlen_k,
        dropout_p,
        softmax_scale,
        causal=causal,
        window_size=window_size,
        softcap=softcap,
        alibi_slopes=alibi_slopes,
        return_softmax=return_attn_probs and dropout_p > 0,
        block_table=block_table,
        out=out,
    )
    return (out, softmax_lse) if return_softmax_lse else out
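
# Illustrative usage sketch (not part of the original module and never called at
# import time): two variable-length sequences packed into a single
# (total_tokens, nheads, headdim) tensor, with cu_seqlens built as the exclusive
# prefix sum of the per-sequence lengths. Assumes a CUDA device and fp16 inputs.
def _example_flash_attn_varlen_func():
    seqlens = torch.tensor([5, 11], dtype=torch.int32, device="cuda")
    total_tokens, nheads, nheads_k, headdim = int(seqlens.sum()), 8, 8, 64
    q = torch.randn(total_tokens, nheads, headdim, device="cuda", dtype=torch.float16)
    k = torch.randn(total_tokens, nheads_k, headdim, device="cuda", dtype=torch.float16)
    v = torch.randn(total_tokens, nheads_k, headdim, device="cuda", dtype=torch.float16)
    # cu_seqlens = [0, 5, 16]: cumulative token offsets used to index into q / k / v.
    cu_seqlens = torch.nn.functional.pad(torch.cumsum(seqlens, dim=0, dtype=torch.int32), (1, 0))
    out = flash_attn_varlen_func(
        q,
        k,
        v,
        cu_seqlens_q=cu_seqlens,
        cu_seqlens_k=cu_seqlens,
        max_seqlen_q=int(seqlens.max()),
        max_seqlen_k=int(seqlens.max()),
        causal=True,
    )
    assert out.shape == (total_tokens, nheads, headdim)
    return out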


def flash_attn_with_kvcache(
    q,
    k_cache,
    v_cache,
    k=None,
    v=None,
    rotary_cos=None,
    rotary_sin=None,
    cache_seqlens: Optional[Union[int, torch.Tensor]] = None,
    cache_batch_idx: Optional[torch.Tensor] = None,
    block_table: Optional[torch.Tensor] = None,
    softmax_scale=None,
    causal=False,
    window_size=(-1, -1),  # -1 means infinite context window
    softcap=0.0,  # 0.0 means deactivated
    rotary_interleaved=True,
    alibi_slopes=None,
    num_splits=0,
    return_softmax_lse=False,
    *,
    out=None,
):
    """
    If k and v are not None, k_cache and v_cache will be updated *inplace* with the new values from
    k and v. This is useful for incremental decoding: you can pass in the cached keys/values from
    the previous step, and update them with the new keys/values from the current step, and do
    attention with the updated cache, all in 1 kernel.

    If you pass in k / v, you must make sure that the cache is large enough to hold the new values.
    For example, the KV cache could be pre-allocated with the max sequence length, and you can use
    cache_seqlens to keep track of the current sequence lengths of each sequence in the batch.

    Also apply rotary embedding if rotary_cos and rotary_sin are passed in. The key @k will be
    rotated by rotary_cos and rotary_sin at indices cache_seqlens, cache_seqlens + 1, etc.
    If causal or local (i.e., window_size != (-1, -1)), the query @q will be rotated by rotary_cos
    and rotary_sin at indices cache_seqlens, cache_seqlens + 1, etc.
    If not causal and not local, the query @q will be rotated by rotary_cos and rotary_sin at
    indices cache_seqlens only (i.e. we consider all tokens in @q to be at position cache_seqlens).

    See tests/test_flash_attn.py::test_flash_attn_kvcache for examples of how to use this function.

    Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads
    than Q. Note that the number of heads in Q must be divisible by the number of heads in KV.
    For example, if Q has 6 heads and K, V have 2 heads, head 0, 1, 2 of Q will attend to head
    0 of K, V, and head 3, 4, 5 of Q will attend to head 1 of K, V.

    If causal=True, the causal mask is aligned to the bottom right corner of the attention matrix.
    For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is:
        1 1 1 1 0
        1 1 1 1 1
    If seqlen_q = 5 and seqlen_k = 2, the causal mask is:
        0 0
        0 0
        0 0
        1 0
        1 1
    If the row of the mask is all zero, the output will be zero.

    If window_size != (-1, -1), implements sliding window local attention. Query at position i
    will only attend to keys between
    [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive.

    Note: Does not support backward pass.

    Arguments:
        q: (batch_size, seqlen, nheads, headdim)
        k_cache: (batch_size_cache, seqlen_cache, nheads_k, headdim) if there's no block_table,
            or (num_blocks, page_block_size, nheads_k, headdim) if there's a block_table (i.e. paged KV cache)
            page_block_size must be a multiple of 256.
        v_cache: (batch_size_cache, seqlen_cache, nheads_k, headdim) if there's no block_table,
            or (num_blocks, page_block_size, nheads_k, headdim) if there's a block_table (i.e. paged KV cache)
        k [optional]: (batch_size, seqlen_new, nheads_k, headdim). If not None, we concatenate
            k with k_cache, starting at the indices specified by cache_seqlens.
        v [optional]: (batch_size, seqlen_new, nheads_k, headdim). Similar to k.
        rotary_cos [optional]: (seqlen_ro, rotary_dim / 2). If not None, we apply rotary embedding
            to k and q. Only applicable if k and v are passed in. rotary_dim must be divisible by 16.
        rotary_sin [optional]: (seqlen_ro, rotary_dim / 2). Similar to rotary_cos.
        cache_seqlens: int, or (batch_size,), dtype torch.int32. The sequence lengths of the
            KV cache.
        block_table [optional]: (batch_size, max_num_blocks_per_seq), dtype torch.int32.
        cache_batch_idx: (batch_size,), dtype torch.int32. The indices used to index into the KV cache.
            If None, we assume that the batch indices are [0, 1, 2, ..., batch_size - 1].
            If the indices are not distinct, and k and v are provided, the values updated in the cache
                 might come from any of the duplicate indices.
        softmax_scale: float. The scaling of QK^T before applying softmax.
            Default to 1 / sqrt(headdim).
        causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
        window_size: (left, right). If not (-1, -1), implements sliding window local attention.
        softcap: float. Anything > 0 activates softcapping attention.
        rotary_interleaved: bool. Only applicable if rotary_cos and rotary_sin are passed in.
            If True, rotary embedding will combine dimensions 0 & 1, 2 & 3, etc. If False,
            rotary embedding will combine dimensions 0 & rotary_dim / 2, 1 & rotary_dim / 2 + 1
            (i.e. GPT-NeoX style).
        alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of
            (-alibi_slope * |i + seqlen_k - seqlen_q - j|)
            is added to the attention score of query i and key j.
        num_splits: int. If > 1, split the key/value into this many chunks along the sequence.
           If num_splits == 1, we don't split the key/value. If num_splits == 0, we use a heuristic
           to automatically determine the number of splits.
           Don't change this unless you know what you are doing.
        return_softmax_lse: bool. Whether to return the logsumexp of the attention scores.

    Return:
        out: (batch_size, seqlen, nheads, headdim).
        softmax_lse [optional, if return_softmax_lse=True]: (batch_size, nheads, seqlen). The
            logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
            normalization factor).
    """
    assert k_cache.stride(-1) == 1, "k_cache must have contiguous last dimension"
    assert v_cache.stride(-1) == 1, "v_cache must have contiguous last dimension"
    q, k, v = [maybe_contiguous(x) for x in (q, k, v)]
    if softmax_scale is None:
        softmax_scale = q.shape[-1] ** (-0.5)
    if cache_seqlens is not None and isinstance(cache_seqlens, int):
        cache_seqlens = torch.full(
            (k_cache.shape[0],), cache_seqlens, dtype=torch.int32, device=k_cache.device
        )
        cache_seqlens = maybe_contiguous(cache_seqlens)
    cache_batch_idx = maybe_contiguous(cache_batch_idx)
    block_table = maybe_contiguous(block_table)
    out, softmax_lse = torch.ops.vllm_flash_attn_c.fwd_kvcache(
        q,
        k_cache,
        v_cache,
        k,
        v,
        cache_seqlens,
        rotary_cos,
        rotary_sin,
        cache_batch_idx,
        block_table,
        alibi_slopes,
        out,
        softmax_scale,
        causal,
        window_size[0],
        window_size[1],
        softcap,
        rotary_interleaved,
        num_splits,
    )
    return (out, softmax_lse) if return_softmax_lse else out
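

# Illustrative usage sketches (not part of the original module and never called
# at import time). Both assume a CUDA device and fp16 tensors.
#
# 1) Single-token decode step against a contiguous, pre-allocated KV cache:
#    the new key/value is written into k_cache / v_cache in-place at position
#    cache_seqlens before attention is computed against the updated cache.
def _example_flash_attn_with_kvcache_decode():
    batch_size, max_seqlen, nheads, nheads_k, headdim = 2, 512, 8, 8, 64
    k_cache = torch.zeros(batch_size, max_seqlen, nheads_k, headdim, device="cuda", dtype=torch.float16)
    v_cache = torch.zeros_like(k_cache)
    cache_seqlens = torch.tensor([10, 37], dtype=torch.int32, device="cuda")  # valid tokens per sequence
    q = torch.randn(batch_size, 1, nheads, headdim, device="cuda", dtype=torch.float16)
    k_new = torch.randn(batch_size, 1, nheads_k, headdim, device="cuda", dtype=torch.float16)
    v_new = torch.randn(batch_size, 1, nheads_k, headdim, device="cuda", dtype=torch.float16)
    out = flash_attn_with_kvcache(
        q, k_cache, v_cache, k=k_new, v=v_new, cache_seqlens=cache_seqlens, causal=True
    )
    assert out.shape == (batch_size, 1, nheads, headdim)
    return out


# 2) The same decode step with a paged KV cache: the cache is stored as
#    (num_blocks, page_block_size, nheads_k, headdim) and block_table maps each
#    sequence to the blocks it owns (page_block_size chosen per the docstring above).
def _example_flash_attn_with_kvcache_paged():
    batch_size, nheads, nheads_k, headdim = 2, 8, 8, 64
    page_block_size, blocks_per_seq = 256, 4
    num_blocks = batch_size * blocks_per_seq
    k_cache = torch.zeros(num_blocks, page_block_size, nheads_k, headdim, device="cuda", dtype=torch.float16)
    v_cache = torch.zeros_like(k_cache)
    block_table = torch.arange(num_blocks, dtype=torch.int32, device="cuda").reshape(batch_size, blocks_per_seq)
    cache_seqlens = torch.tensor([10, 300], dtype=torch.int32, device="cuda")
    q = torch.randn(batch_size, 1, nheads, headdim, device="cuda", dtype=torch.float16)
    out = flash_attn_with_kvcache(
        q, k_cache, v_cache, cache_seqlens=cache_seqlens, block_table=block_table, causal=True
    )
    assert out.shape == (batch_size, 1, nheads, headdim)
    return out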