cuda.py 10.9 KB
Newer Older
1
import torch
2
from text_generation_server.layers.attention.kv_cache import KVCache, KVScales
3
from text_generation_server.utils.import_utils import SYSTEM
4
from text_generation_server.models.globals import (
5
    ATTENTION,
6
7
    BLOCK_SIZE,
)
8
from text_generation_server.layers.attention import Seqlen
9
from typing import Optional
10

11

12
13
14
15
16
17
18
# CUDA compute capability of the current device as two ints.
major, minor = torch.cuda.get_device_capability()
# sm75 is singled out in the flash-attention import fallback further down,
# which gives it a dedicated v1 install hint.
is_sm75 = (major, minor) == (7, 5)
# Partition length used by the vLLM paged-attention V2 kernel; it also
# determines how many partitions a sequence of length `max_s` is split into.
_PARTITION_SIZE = 512


def paged_attention(
    query: torch.Tensor,
    kv_cache: KVCache,
    kv_head_mapping: torch.Tensor,
    softmax_scale: float,
    block_tables: torch.Tensor,
    seqlen: Seqlen,
    max_s: int,
    *,
    kv_scales: KVScales,
    softcap: Optional[float] = None,
):
    """Attend ``query`` against the paged KV cache (decode step).

    The kernel is chosen by the module-level ``ATTENTION`` setting:

    * ``"flashinfer"``: delegates to the flashinfer ``decode_state``.
    * ``"flashdecoding"``: calls ``flash_attn_2_cuda.varlen_fwd`` on the paged
      KV cache with ``block_tables`` and ``max_q = 1``.
    * anything else: vLLM's ``paged_attention_v1`` / ``paged_attention_v2``.

    Args:
        query: query tensor, unpacked below as ``(num_seqs, num_heads, head_size)``.
        kv_cache: cache whose ``key`` / ``value`` tensors are attended over.
        kv_head_mapping: forwarded to the vLLM kernels only.
        softmax_scale: scale applied to the attention logits.
        block_tables: paged-cache block tables (not passed on the flashinfer path).
        seqlen: sequence-length metadata for the batch.
        max_s: maximum sequence length in the batch.
        kv_scales: KV-cache scales. Only the flashinfer path applies them,
            and only when ``kv_cache.can_scale(kv_scales)`` is true.
        softcap: optional logit soft-cap; ``None`` disables it.

    Returns:
        The attention output tensor.

    Raises:
        RuntimeError: if ``softcap`` is set on the vLLM paged path, which
            does not support soft-capping.
    """
    # Adapted from: https://github.com/vllm-project/vllm/blob/f8a1e39fae05ca610be8d5a78be9d40f5274e5fc/vllm/model_executor/layers/attention.py
    # Copyright 2023 The vLLM team. All rights
    # reserved.
    #
    # Licensed under the Apache License, Version 2.0 (the "License");
    # you may not use this file except in compliance with the License.
    # You may obtain a copy of the License at
    #
    #     http://www.apache.org/licenses/LICENSE-2.0
    #
    # Unless required by applicable law or agreed to in writing, software
    # distributed under the License is distributed on an "AS IS" BASIS,
    # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    # See the License for the specific language governing permissions and
    # limitations under the License.
    #

    # value_cache => [num_blocks, num_heads, head_size, block_size]
    # block_size = value_cache.shape[3]
    block_size = BLOCK_SIZE
    num_seqs, num_heads, head_size = query.shape
    # Ceiling division: number of _PARTITION_SIZE chunks needed to cover max_s.
    max_num_partitions = (max_s + _PARTITION_SIZE - 1) // _PARTITION_SIZE

    can_scale = kv_cache.can_scale(kv_scales)

    if ATTENTION == "flashinfer":
        from text_generation_server.layers.attention.flashinfer import decode_state

        return decode_state.get().forward(
            # TODO: remove `contiguous` call once https://github.com/flashinfer-ai/flashinfer/pull/553 is merged.
            query.contiguous(),
            paged_kv_cache=(kv_cache.key, kv_cache.value),
            logits_soft_cap=softcap,
            sm_scale=softmax_scale,
            k_scale=kv_scales.key_scale_cpu if can_scale else 1.0,
            v_scale=kv_scales.value_scale_cpu if can_scale else 1.0,
        )
    elif ATTENTION == "flashdecoding":
        # Decoding: each sequence contributes exactly one query token.
        max_q = 1
        max_k = max_s
        import flash_attn_2_cuda

        # TODO fixme when flash contains the fix.
        # Number of splits is not correctly handled
        # by the current path
        # https://github.com/Dao-AILab/flash-attention/blob/320fb59487658f033f56711efd3d61b7c7a6f8f3/csrc/flash_attn/flash_api.cpp#L577
        # This fails because we're using causal, therefore window_right is set to 0 and the split logic is never applied.
        if softcap is None:
            softcap = 0.0
        out = flash_attn_2_cuda.varlen_fwd(
            query,
            kv_cache.key,
            kv_cache.value,
            None,  # out
            seqlen.cu_seqlen_q,
            seqlen.cu_seqlen_k,
            None,  # pad_k
            None,
            block_tables,
            None,
            max_q,
            max_k,
            0.0,  # dropout
            softmax_scale,
            False,  # zero_tensors
            True,  # causal
            -1,  # Window_left
            -1,  # Window right
            softcap,
            False,  # return softmax
            None,  # generator
        )
        return out[0]
    else:
        if softcap is not None:
            raise RuntimeError("Paged attention doesn't support softcapping")
        # Total attended length per sequence: new tokens plus cached prefix.
        input_lengths = seqlen.input_lengths + seqlen.cache_lengths
        from vllm._C import ops

        out = torch.empty_like(query)

        # NOTE(woosuk): We use a simple heuristic to decide whether to use
        # PagedAttention V1 or V2. If the number of partitions is 1, we use
        # V1 to avoid the overhead of reduction. Also, if the number of
        # sequences or heads is large, we use V1 since there is enough work
        # to parallelize. V1 is additionally restricted to max_s <= 8192.
        use_v1 = max_s <= 8192 and (
            max_num_partitions == 1 or num_seqs * num_heads > 512
        )
        # NOTE(review): the trailing `None, "auto", 1.0` arguments below are
        # fixed, so kv_scales is ignored on this path; presumably they are
        # alibi slopes, kv-cache dtype and k scale — confirm against the vLLM
        # op signature.
        if use_v1:
            ops.paged_attention_v1(
                out,
                query,
                kv_cache.key,
                kv_cache.value,
                kv_head_mapping,
                softmax_scale,
                block_tables,
                input_lengths,
                block_size,
                max_s,
                None,
                "auto",
                1.0,
            )
        else:
            # Run PagedAttention V2: per-partition partial results are written
            # to scratch buffers and reduced into `out` by the kernel.
            assert _PARTITION_SIZE % block_size == 0
            tmp_output = torch.empty(
                size=(num_seqs, num_heads, max_num_partitions, head_size),
                dtype=out.dtype,
                device=out.device,
            )
            exp_sums = torch.empty(
                size=(num_seqs, num_heads, max_num_partitions),
                dtype=torch.float32,
                device=out.device,
            )
            max_logits = torch.empty_like(exp_sums)

            ops.paged_attention_v2(
                out,
                exp_sums,
                max_logits,
                tmp_output,
                query,
                kv_cache.key,
                kv_cache.value,
                kv_head_mapping,
                softmax_scale,
                block_tables,
                input_lengths,
                block_size,
                max_s,
                None,
                "auto",
                1.0,
            )
        return out
168
169
170


try:
    # flash-attn v2 is only attempted on compute capability >= 8.0. Raising
    # ImportError here is deliberate: it falls through to the v1 fallback.
    is_ampere_or_newer = major >= 8
    if not is_ampere_or_newer:
        raise ImportError("FlashAttention only supports Ampere GPUs or newer.")

    import flash_attn_2_cuda

    V2 = True
except ImportError:
    try:
        import flash_attn_cuda

        V2 = False
    except ImportError as e:
        # Neither v2 nor v1 is importable: raise an install hint matching the
        # device capability, chaining the original import failure.
        if major >= 8:
            architecture_suffix = f"-{SYSTEM}"
            raise ImportError(
                "Flash Attention V2 is not installed.\n"
                "Use the official Docker image (ghcr.io/huggingface/text-generation-inference:latest) "
                f"or install flash attention v2 with `cd server && make install install-flash-attention-v2{architecture_suffix}`"
            ) from e
        elif is_sm75:
            raise ImportError(
                "Flash Attention is not installed.\n"
                "Use the official Docker image (ghcr.io/huggingface/text-generation-inference:latest) "
                "or install flash attention with `cd server && make install install-flash-attention`"
            ) from e
        else:
            raise ImportError(
                f"GPU with CUDA capability {major} {minor} is not supported"
            ) from e


203
204
205
# The flashdecoding backend calls flash_attn_2_cuda, so it cannot run when
# only flash-attn v1 was importable.
if not V2 and ATTENTION == "flashdecoding":
    raise ValueError("Flash Decoding requires Flash Attention V2")

# Sliding-window attention (`window_size_left`) is only wired up for v2.
SUPPORTS_WINDOWING = V2
207

208
209
210
211
212
213
214

def _expand_kv_heads(tensor: torch.Tensor, num_heads: int) -> torch.Tensor:
    """Repeat the heads (dim 1) of a key/value tensor up to ``num_heads``.

    Used by the flash-attention v1 path, which requires q, k and v to have the
    same number of heads. Assumes ``tensor`` is ``(tokens, kv_heads, head_size)``
    and ``num_heads`` is a multiple of ``kv_heads`` — TODO confirm at callers.
    """
    kv_heads = tensor.shape[1]
    if kv_heads == num_heads:
        return tensor
    if kv_heads == 1:
        # MQA: broadcast the single KV head as an expanded view.
        return tensor.expand(-1, num_heads, -1)
    # Grouped attention: each KV head is repeated for its group of query heads.
    original_shape = tensor.shape
    return (
        tensor.unsqueeze(2)
        .expand(-1, -1, num_heads // kv_heads, -1)
        .reshape(original_shape[0], -1, original_shape[2])
    )


def attention(
    *,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    kv_cache: KVCache,
    kv_scales: KVScales,
    seqlen: Seqlen,
    block_tables: torch.Tensor,
    softmax_scale: float,
    window_size_left: int = -1,
    causal: bool = True,
    softcap: Optional[float] = None,
):
    """Prefill attention for the configured backend.

    * ``ATTENTION == "flashinfer"``: flashinfer's paged-KV prefill state.
    * flash-attn v2 available: ``flash_attn_2_cuda.varlen_fwd``. With
      ``"flashdecoding"`` it reads the paged KV cache via ``block_tables``;
      otherwise it attends over the freshly computed ``key`` / ``value``.
    * otherwise: flash-attn v1 (``flash_attn_cuda.fwd``) over ``key`` /
      ``value``, with KV heads expanded to match the query heads.

    Args:
        query, key, value: attention inputs.
        kv_cache: paged KV cache (used by flashinfer and flashdecoding).
        kv_scales: KV-cache scales; applied on the flashinfer path when
            ``kv_cache.can_scale(kv_scales)`` is true.
        seqlen: sequence-length metadata (cumulative lengths, max lengths).
        block_tables: paged-cache block tables (flashdecoding only).
        softmax_scale: scale applied to the attention logits.
        window_size_left: left sliding-window size, or -1 for no window.
        causal: whether to apply causal masking.
        softcap: optional logit soft-cap; ``None`` disables it.

    Returns:
        The attention output tensor.

    Raises:
        ValueError: on the v2 path, if ``window_size_left`` is neither > 0 nor -1.
        NotImplementedError: on the v1 path, if a window or softcap is requested.
    """
    can_scale = kv_cache.can_scale(kv_scales)

    if ATTENTION == "flashinfer":
        from text_generation_server.layers.attention.flashinfer import (
            prefill_with_paged_kv_state,
        )

        if softcap is None:
            softcap = 0.0

        return prefill_with_paged_kv_state.get().forward(
            # TODO: remove `contiguous` call once https://github.com/flashinfer-ai/flashinfer/pull/553 is merged.
            query.contiguous(),
            causal=causal,
            paged_kv_cache=(kv_cache.key, kv_cache.value),
            logits_soft_cap=softcap,
            sm_scale=softmax_scale,
            window_left=window_size_left,
            k_scale=kv_scales.key_scale_cpu if can_scale else 1.0,
            v_scale=kv_scales.value_scale_cpu if can_scale else 1.0,
        )

    # If we are using flashdecoding or paged, we always use flash-attn for
    # the prefill. We have to branch on whether we use flash-attn v1 or v2.
    elif V2:
        out = torch.empty_like(query)
        if window_size_left <= 0 and window_size_left != -1:
            raise ValueError("`window_size_left` must be > 0 or -1")

        if softcap is None:
            softcap = 0.0

        # flash-attn treats (window_size_left < 0, window_size_right == 0) as
        # causal masking regardless of its `is_causal` flag, so always passing
        # 0 made causal=False calls silently causal. Use -1 (no right limit)
        # when the caller asked for non-causal attention. For causal=True,
        # flash-attn forces the right window to 0 anyway.
        window_size_right = 0 if causal else -1

        return flash_attn_2_cuda.varlen_fwd(
            query,
            # flashdecoding: pass the KV caches, paged: pass the KV.
            kv_cache.key if ATTENTION == "flashdecoding" else key,
            kv_cache.value if ATTENTION == "flashdecoding" else value,
            out,
            seqlen.cu_seqlen_q,
            seqlen.cu_seqlen_k,
            None,
            None,
            block_tables if ATTENTION == "flashdecoding" else None,
            None,
            seqlen.max_q,
            seqlen.max_k,
            0.0,  # dropout
            softmax_scale,
            False,  # zero_tensors
            causal,
            window_size_left,
            window_size_right,
            softcap,
            False,  # return softmax
            None,  # generator
        )[0]

    else:
        if window_size_left != -1:
            raise NotImplementedError(
                "window_size_left is only available with flash attn v2"
            )
        if softcap is not None:
            raise NotImplementedError("softcap is not available in flash attn v1")

        # Flash attention v1 requires q, k and v to have the same number of heads
        num_heads = query.shape[1]
        key = _expand_kv_heads(key, num_heads)
        value = _expand_kv_heads(value, num_heads)

        out = torch.empty_like(query)
        flash_attn_cuda.fwd(
            query,
            key,
            value,
            out,
            seqlen.cu_seqlen_q,
            # NOTE(review): the key cumulative lengths reuse cu_seqlen_q while
            # max_k comes from seqlen.max_k; presumably keys cover exactly the
            # query tokens on this path — confirm.
            seqlen.cu_seqlen_q,
            seqlen.max_q,
            seqlen.max_k,
            0.0,
            softmax_scale,
            False,
            causal,
            False,
            0,
            None,
        )
        return out
333

334
335
336
337
338
339

# Names exported by `from ... import *`: the windowing capability flag and
# the prefill / paged-decode attention entry points.
__all__ = ["SUPPORTS_WINDOWING", "attention", "paged_attention"]