rocm.py 9.48 KB
Newer Older
1
import os
2
from typing import Optional
3
4
import torch
from text_generation_server.utils.import_utils import SYSTEM
5
from text_generation_server.models.globals import ATTENTION
Nicolas Patry's avatar
Nicolas Patry committed
6
from text_generation_server.layers.attention import Seqlen
7
from text_generation_server.utils.log import log_master
8
9
10
11
from loguru import logger

# Compute capability of the current device as reported by torch. Both values
# are used further down to pick the installation hint shown when the
# flash-attention extension cannot be imported.
major, minor = torch.cuda.get_device_capability()
is_sm75 = (major, minor) == (7, 5)
12
13
14

# Partition sizes consumed by `paged_attention` when it computes
# `max_num_partitions`: the first is used with the vLLM V1/V2 kernels, the
# second with the custom ROCm kernel.
_PARTITION_SIZE_V1V2 = 512
_PARTITION_SIZE_CUSTOM = 256

# Flash-attention engine selection. The Triton implementation is chosen only
# when ROCM_USE_FLASH_ATTN_V2_TRITON is "true" or "1" (case-insensitive);
# otherwise the Composable Kernel ("ck") implementation is used.
use_triton = os.environ.get("ROCM_USE_FLASH_ATTN_V2_TRITON", "").lower() in (
    "true",
    "1",
)
if use_triton:
    ENGINE = "triton"
else:
    ENGINE = "ck"

# Exported flag (see `__all__`); always False for this backend.
PREFILL_IN_KV_CACHE = False

21
22
23
24
25
26
27
28
29
30
31
# The custom ROCm paged-attention kernel is opt-out: it is enabled unless
# ROCM_USE_CUSTOM_PAGED_ATTN is set to exactly "0". When the kernel cannot be
# imported, the failure is logged and the flag is cleared so `paged_attention`
# falls back to the regular vLLM kernels.
use_rocm_custom_paged_attn = os.getenv("ROCM_USE_CUSTOM_PAGED_ATTN", "1") != "0"
if use_rocm_custom_paged_attn:
    try:
        from vllm._custom_C import paged_attention_custom
    except ImportError as e:
        log_master(
            logger.info,
            f"Custom Paged Attention not available. Complete error: {e}",
        )
        use_rocm_custom_paged_attn = False

32
# vLLM's custom ops provide the kernels called by `reshape_and_cache` and
# `paged_attention`, so the module cannot be used without them.
try:
    import vllm._custom_ops as ops
except Exception as e:
    # Deliberately broad: any failure while loading the extension is reported
    # as an ImportError. The original exception is chained so its traceback is
    # not lost.
    raise ImportError(
        f"Could not import vllm paged attention. Make sure your installation is correct. Complete error: {e}"
    ) from e


def reshape_and_cache(
    key: torch.Tensor,
    value: torch.Tensor,
    key_cache: torch.Tensor,
    value_cache: torch.Tensor,
    slots: torch.Tensor,
):
    """Write `key` and `value` into the KV cache at the given `slots`.

    The caches are modified in place and nothing is returned.

    Args:
        key: Key tensor to store.
        value: Value tensor to store.
        key_cache: Destination cache for `key`.
        value_cache: Destination cache for `value`.
        slots: Indices selecting where the new entries are written.
    """
    if ATTENTION != "flashdecoding":
        # Default path: delegate to the vLLM kernel.
        # NOTE(review): "auto" and 1.0 are presumably the KV-cache dtype and
        # scale expected by the kernel -- confirm against the vLLM version in use.
        ops.reshape_and_cache(key, value, key_cache, value_cache, slots, "auto", 1.0)
        return

    # Flash-decoding path: flatten all leading dimensions of each cache and
    # scatter the new entries by slot index. The trailing two dimensions are
    # taken from `key_cache` for both caches.
    trailing_dims = key_cache.shape[-2:]
    key_cache.view(-1, *trailing_dims)[slots] = key
    value_cache.view(-1, *trailing_dims)[slots] = value
53
54
55
56
57
58
59
60
61


def paged_attention(
    query: torch.Tensor,
    key_cache: torch.Tensor,
    value_cache: torch.Tensor,
    kv_head_mapping: torch.Tensor,
    softmax_scale: float,
    block_tables: torch.Tensor,
    seqlen: Seqlen,
    max_s: int,
    softcap: Optional[float] = None,
):
    """Run paged attention over the KV cache and return the attention output.

    One of three kernels is dispatched: vLLM PagedAttention V1, vLLM
    PagedAttention V2, or the custom ROCm kernel (`paged_attention_custom`).

    Args:
        query: 3-D tensor unpacked as (num_seqs, num_heads, head_size).
        key_cache: Key cache; dimension 1 is read as the number of KV heads.
        value_cache: Value cache laid out as
            [num_blocks, num_heads, head_size, block_size].
        kv_head_mapping: Forwarded to the vLLM V1/V2 kernels. The custom kernel
            receives the number of KV heads instead.
        softmax_scale: Scale forwarded to the kernel.
        block_tables: Block tables forwarded to the kernel.
        seqlen: Only `seqlen.input_lengths` is read here.
        max_s: Maximum sequence length; drives the partition count and the
            kernel choice.
        softcap: Must be None; softcapping is not supported.

    Returns:
        A new tensor with the same shape and dtype as `query`.

    Raises:
        RuntimeError: If `softcap` is not None.
    """
    # Adapted from: https://github.com/vllm-project/vllm/blob/f8a1e39fae05ca610be8d5a78be9d40f5274e5fc/vllm/model_executor/layers/attention.py
    # Copyright 2023 The vLLM team. All rights
    # reserved.
    #
    # Licensed under the Apache License, Version 2.0 (the "License");
    # you may not use this file except in compliance with the License.
    # You may obtain a copy of the License at
    #
    #     http://www.apache.org/licenses/LICENSE-2.0
    #
    # Unless required by applicable law or agreed to in writing, software
    # distributed under the License is distributed on an "AS IS" BASIS,
    # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    # See the License for the specific language governing permissions and
    # limitations under the License.
    #

    if softcap is not None:
        raise RuntimeError("Paged attention doesn't support softcapping")

    # value_cache => [num_blocks, num_heads, head_size, block_size]
    block_size = value_cache.shape[3]
    num_seqs, num_heads, head_size = query.shape

    num_kv_heads = key_cache.shape[1]
    gqa_ratio = num_heads // num_kv_heads

    # The custom ROCm kernel is only used for the configurations listed here,
    # and only when it was successfully imported at module load.
    use_custom = (
        use_rocm_custom_paged_attn
        and query.dtype in (torch.half, torch.bfloat16)
        and head_size in (64, 128)
        and block_size in (16, 32)
        and 1 <= gqa_ratio <= 16
        and max_s <= 32768
    )

    partition_size = _PARTITION_SIZE_CUSTOM if use_custom else _PARTITION_SIZE_V1V2

    # Ceiling division: number of partitions needed to cover `max_s`.
    max_num_partitions = (max_s + partition_size - 1) // partition_size
    input_lengths = seqlen.input_lengths

    out = torch.empty_like(query)

    # NOTE(woosuk): We use a simple heuristic to decide whether to use
    # PagedAttention V1 or V2. If the number of partitions is 1, we use
    # V1 to avoid the overhead of reduction. Also, if the number of
    # sequences or heads is large, we use V1 since there is enough work
    # to parallelize.
    use_v1 = (
        max_s <= 8192
        and (max_num_partitions == 1 or num_seqs * num_heads > 512)
        and not use_custom
    )

    # `ops` is the module-level `vllm._custom_ops` import; there is no need to
    # re-import it on every call.
    if use_v1:
        ops.paged_attention_v1(
            out,
            query,
            key_cache,
            value_cache,
            kv_head_mapping,
            softmax_scale,
            block_tables,
            input_lengths,
            block_size,
            max_s,
            None,
            "auto",
            1.0,
        )
        return out

    # Run PagedAttention V2 (or the custom kernel): both need scratch buffers
    # holding one partial result per partition.
    assert partition_size % block_size == 0
    tmp_output = torch.empty(
        size=(num_seqs, num_heads, max_num_partitions, head_size),
        dtype=out.dtype,
        device=out.device,
    )
    exp_sums = torch.empty(
        size=(num_seqs, num_heads, max_num_partitions),
        dtype=torch.float32,
        device=out.device,
    )
    max_logits = torch.empty_like(exp_sums)

    if not use_custom:
        ops.paged_attention_v2(
            out,
            exp_sums,
            max_logits,
            tmp_output,
            query,
            key_cache,
            value_cache,
            kv_head_mapping,
            softmax_scale,
            block_tables,
            input_lengths,
            block_size,
            max_s,
            None,
            "auto",
            1.0,
        )
    else:
        # The custom kernel takes the KV head count rather than the mapping
        # tensor, and has no trailing scale argument.
        paged_attention_custom(
            out,
            exp_sums,
            max_logits,
            tmp_output,
            query,
            key_cache,
            value_cache,
            num_kv_heads,
            softmax_scale,
            block_tables,
            input_lengths,
            block_size,
            max_s,
            None,
            "auto",
        )

    return out
193
194
195
196
197
198


# The Composable Kernel engine needs the `flash_attn_2_cuda` extension. When it
# is missing, raise an ImportError with the hint that best matches the detected
# device, always chained to the original import failure.
if ENGINE != "triton":
    try:
        import flash_attn_2_cuda
    except ImportError as e:
        if major >= 8:
            architecture_suffix = f"-{SYSTEM}"
            raise ImportError(
                "Flash Attention V2 is not installed.\n"
                "Use the official Docker image (ghcr.io/huggingface/text-generation-inference:latest) "
                f"or install flash attention v2 with `cd server && make install install-flash-attention-v2{architecture_suffix}`"
            ) from e
        elif is_sm75:
            raise ImportError(
                "Flash Attention is not installed.\n"
                "Use the official Docker image (ghcr.io/huggingface/text-generation-inference:latest) "
                "or install flash attention with `cd server && make install install-flash-attention`"
            ) from e
        else:
            # Report the first visible device that is neither an MI210 nor an
            # MI250; otherwise fall through to the generic capability message.
            for idx in range(torch.cuda.device_count()):
                name = torch.cuda.get_device_name(idx)
                if "MI210" not in name and "MI250" not in name:
                    raise ImportError(
                        f"AMD GPU {name} does not support flash-attention"
                    ) from e
            raise ImportError(
                f"AMD GPU with ROCm capability {major} {minor} is not supported"
            ) from e
    else:
        # Only reached when the import succeeded.
        log_master(
            logger.info,
            "ROCm: using Flash Attention 2 Composable Kernel implementation.",
        )


# Exported flag (see `__all__`); always False for this backend. The comments in
# `attention` describe `window_size_left` as not supported here and as being
# validated ahead of time, at model load.
SUPPORTS_WINDOWING = False
230
231
232
233
# `attention` is defined once, for whichever flash-attention engine was
# selected at import time.
if ENGINE == "ck":

    def attention(
        q,
        key_cache: torch.Tensor,
        value_cache: torch.Tensor,
        seqlen: Seqlen,
        block_tables: torch.Tensor,
        softmax_scale: float,
        window_size_left: int = -1,
        causal: bool = True,
        softcap: Optional[float] = 0.0,
    ):
        """Variable-length flash attention using the Composable Kernel engine.

        Args:
            q: Query tensor; the output has the same shape and dtype.
            key_cache: Key tensor passed to the kernel.
            value_cache: Value tensor passed to the kernel.
            seqlen: Supplies `cu_seqlen_q`, `max_q` and `max_k`.
            block_tables: Unused by this implementation.
            softmax_scale: Scale forwarded to the kernel.
            window_size_left: Must be -1 or strictly positive.
            causal: Forwarded to the kernel.
            softcap: Forwarded to the kernel. None is treated as 0.0, so
                callers written against the Triton signature (where None is
                the default) also work here.

        Returns:
            The first element of the tuple returned by
            `flash_attn_2_cuda.varlen_fwd`.

        Raises:
            ValueError: If `window_size_left` is neither -1 nor positive.
        """
        if window_size_left <= 0 and window_size_left != -1:
            raise ValueError("`window_size_left` must be > 0 or -1")

        if softcap is None:
            softcap = 0.0

        out = torch.empty_like(q)

        # Support for `window_size_left` is not re-checked here because it is
        # already checked ahead of time, at model load.
        # NOTE(review): arguments are positional and follow the
        # `flash_attn_2_cuda.varlen_fwd` binding -- confirm the meaning of the
        # literal placeholders against the installed flash-attn version.
        return flash_attn_2_cuda.varlen_fwd(
            q,
            key_cache,
            value_cache,
            out,
            # The same cumulative lengths are passed for queries and keys.
            seqlen.cu_seqlen_q,
            seqlen.cu_seqlen_q,
            None,
            None,
            None,
            None,
            seqlen.max_q,
            seqlen.max_k,
            0.0,
            softmax_scale,
            False,
            causal,
            window_size_left,
            0,
            softcap,
            False,
            None,
        )[0]

elif ENGINE == "triton":
    from .flash_attn_triton import triton_attention

    def attention(
        q,
        key_cache: torch.Tensor,
        value_cache: torch.Tensor,
        seqlen: Seqlen,
        block_tables: torch.Tensor,
        softmax_scale: float,
        window_size_left: int = -1,
        causal: bool = True,
        softcap: Optional[float] = None,
    ):
        """Variable-length flash attention using the Triton engine.

        Args:
            q: Query tensor; the output buffer has the same shape and dtype.
            key_cache: Key tensor passed to the kernel.
            value_cache: Value tensor passed to the kernel.
            seqlen: Supplies `cu_seqlen_q`, `max_q` and `max_k`.
            block_tables: Unused by this implementation.
            softmax_scale: Scale forwarded to the kernel.
            window_size_left: Unused by this implementation.
            causal: Forwarded to the kernel.
            softcap: Must be None; softcapping is only available with CK.

        Returns:
            The first element of the pair returned by `triton_attention`.

        Raises:
            NotImplementedError: If `softcap` is not None.
        """
        if softcap is not None:
            raise NotImplementedError("softcap is only available with CK flash attn")

        out = torch.empty_like(q)

        # We do not need to check window_size_left (not supported) here, since
        # it is already checked ahead of time, at model load.
        output, _ = triton_attention(
            q,
            key_cache,
            value_cache,
            out,
            # The same cumulative lengths are passed for queries and keys.
            seqlen.cu_seqlen_q,
            seqlen.cu_seqlen_q,
            seqlen.max_q,
            seqlen.max_k,
            causal,
            softmax_scale,
        )
        return output

else:
    raise RuntimeError(f"Unknown attention engine {ENGINE}")
309
310
311
312
313
314
315
316

# Names exported by `from <module> import *`; everything else defined in this
# module is left out of the star-import.
__all__ = [
    "PREFILL_IN_KV_CACHE",
    "SUPPORTS_WINDOWING",
    "attention",
    "paged_attention",
    "reshape_and_cache",
]