rocm.py 8.56 KB
Newer Older
1
import os
2
from typing import Optional
3
import torch
4
from text_generation_server.layers.attention.kv_cache import KVCache
5
from text_generation_server.utils.import_utils import SYSTEM
Nicolas Patry's avatar
Nicolas Patry committed
6
from text_generation_server.layers.attention import Seqlen
7
from text_generation_server.utils.log import log_master
8
9
10
11
from loguru import logger

# Compute capability of the device, queried once when the module is imported.
(major, minor) = torch.cuda.get_device_capability()
# True only for capability 7.5.
is_sm75 = (major, minor) == (7, 5)
12
13
14

# Partition size `paged_attention` uses to compute the number of partitions
# when the ROCm custom kernel is NOT selected (vLLM V1/V2 kernels).
_PARTITION_SIZE_V1V2 = 512
# Partition size `paged_attention` uses when the ROCm custom kernel is selected.
_PARTITION_SIZE_CUSTOM = 256
15
16
17
18

# The Triton flash-attention kernel is opt-in: it is selected only when the
# ROCM_USE_FLASH_ATTN_V2_TRITON environment variable is "true" or "1"
# (compared case-insensitively). Otherwise the "ck" engine is used.
use_triton = os.getenv("ROCM_USE_FLASH_ATTN_V2_TRITON", "").lower() in ("true", "1")
if use_triton:
    ENGINE = "triton"
else:
    ENGINE = "ck"

19
20
21
22
23
24
25
26
27
28
29
# The ROCm custom paged-attention kernel is enabled unless the
# ROCM_USE_CUSTOM_PAGED_ATTN environment variable is set to "0".
use_rocm_custom_paged_attn = os.getenv("ROCM_USE_CUSTOM_PAGED_ATTN", "1") != "0"
if use_rocm_custom_paged_attn:
    try:
        from vllm._custom_C import paged_attention_custom
    except ImportError as e:
        # The kernel cannot be imported: log why and disable the custom path.
        log_master(
            logger.info,
            f"Custom Paged Attention not available. Complete error: {e}",
        )
        use_rocm_custom_paged_attn = False

30
31
32

def paged_attention(
    query: torch.Tensor,
    kv_cache: KVCache,
    kv_head_mapping: torch.Tensor,
    softmax_scale: float,
    block_tables: torch.Tensor,
    seqlen: Seqlen,
    max_s: int,
    softcap: Optional[float] = None,
):
    """Run paged attention for `query` against a paged KV cache.

    One of three kernels is selected: vLLM `paged_attention_v1`, vLLM
    `paged_attention_v2`, or the ROCm `paged_attention_custom` kernel. The
    kernel writes into a freshly allocated tensor shaped like `query`, which
    is returned.

    Args:
        query: Tensor whose shape unpacks as (num_seqs, num_heads, head_size).
        kv_cache: Cache object; `kv_cache.key.shape[1]` is read as the number
            of KV heads and `kv_cache.value.shape[3]` as the block size.
        kv_head_mapping: Forwarded to the V1/V2 kernels. The custom kernel
            receives the number of KV heads instead.
        softmax_scale: Forwarded unchanged to the selected kernel.
        block_tables: Forwarded unchanged to the selected kernel.
        seqlen: `seqlen.input_lengths + seqlen.cache_lengths` is passed to the
            kernel as the per-sequence length argument.
        max_s: Maximum sequence length; drives the partition count and the
            kernel-selection thresholds.
        softcap: Must be None; softcapping is not supported here.

    Returns:
        The output tensor, allocated with `torch.empty_like(query)`.

    Raises:
        RuntimeError: If `softcap` is not None.
        AssertionError: If the partition size is not a multiple of the block
            size when the V2/custom path is taken.
    """
    # Adapted from: https://github.com/vllm-project/vllm/blob/f8a1e39fae05ca610be8d5a78be9d40f5274e5fc/vllm/model_executor/layers/attention.py
    # Copyright 2023 The vLLM team. All rights
    # reserved.
    #
    # Licensed under the Apache License, Version 2.0 (the "License");
    # you may not use this file except in compliance with the License.
    # You may obtain a copy of the License at
    #
    #     http://www.apache.org/licenses/LICENSE-2.0
    #
    # Unless required by applicable law or agreed to in writing, software
    # distributed under the License is distributed on an "AS IS" BASIS,
    # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    # See the License for the specific language governing permissions and
    # limitations under the License.
    #

    if softcap is not None:
        raise RuntimeError("Paged attention doesn't support softcapping")

    # value_cache => [num_blocks, num_heads, head_size, block_size]
    block_size = kv_cache.value.shape[3]
    num_seqs, num_heads, head_size = query.shape

    num_kv_heads = kv_cache.key.shape[1]
    gqa_ratio = num_heads // num_kv_heads

    # The custom ROCm kernel is only used for the configurations listed here.
    use_custom = (
        use_rocm_custom_paged_attn
        and query.dtype in (torch.half, torch.bfloat16)
        and head_size in (64, 128)
        and block_size in (16, 32)
        and 1 <= gqa_ratio <= 16
        and max_s <= 32768
    )

    partition_size = _PARTITION_SIZE_CUSTOM if use_custom else _PARTITION_SIZE_V1V2

    # Ceiling division: number of partitions needed to cover `max_s`.
    max_num_partitions = (max_s + partition_size - 1) // partition_size
    input_lengths = seqlen.input_lengths + seqlen.cache_lengths

    out = torch.empty_like(query)

    # NOTE(woosuk): We use a simple heuristic to decide whether to use
    # PagedAttention V1 or V2. If the number of partitions is 1, we use
    # V1 to avoid the overhead of reduction. Also, if the number of
    # sequences or heads is large, we use V1 since there is enough work
    # to parallelize.
    import vllm._custom_ops as ops

    use_v1 = (
        max_s <= 8192
        and (max_num_partitions == 1 or num_seqs * num_heads > 512)
        and not use_custom
    )
    if use_v1:
        # NOTE(review): the trailing `None, "auto", 1.0` are presumably the
        # alibi slopes, KV-cache dtype and KV scale of the vLLM op - confirm
        # against the installed vLLM version.
        ops.paged_attention_v1(
            out,
            query,
            kv_cache.key,
            kv_cache.value,
            kv_head_mapping,
            softmax_scale,
            block_tables,
            input_lengths,
            block_size,
            max_s,
            None,
            "auto",
            1.0,
        )
    else:
        # Run PagedAttention V2 (or the custom kernel, which takes the same
        # scratch buffers).
        assert (
            partition_size % block_size == 0
        ), "partition size must be a multiple of the KV cache block size"
        tmp_output = torch.empty(
            size=(num_seqs, num_heads, max_num_partitions, head_size),
            dtype=out.dtype,
            device=out.device,
        )
        exp_sums = torch.empty(
            size=(num_seqs, num_heads, max_num_partitions),
            dtype=torch.float32,
            device=out.device,
        )
        max_logits = torch.empty_like(exp_sums)

        if not use_custom:
            ops.paged_attention_v2(
                out,
                exp_sums,
                max_logits,
                tmp_output,
                query,
                kv_cache.key,
                kv_cache.value,
                kv_head_mapping,
                softmax_scale,
                block_tables,
                input_lengths,
                block_size,
                max_s,
                None,
                "auto",
                1.0,
            )
        else:
            # Unlike V1/V2, the custom kernel is given the KV head count
            # rather than `kv_head_mapping`, and no trailing 1.0 argument.
            paged_attention_custom(
                out,
                exp_sums,
                max_logits,
                tmp_output,
                query,
                kv_cache.key,
                kv_cache.value,
                num_kv_heads,
                softmax_scale,
                block_tables,
                input_lengths,
                block_size,
                max_s,
                None,
                "auto",
            )

    return out
168
169
170
171
172
173


if ENGINE != "triton":
    # The non-Triton engine needs the compiled `flash_attn_2_cuda` extension.
    # When it is missing, fail at import time with a message that depends on
    # the device capability, always chaining the original ImportError.
    try:
        import flash_attn_2_cuda
    except ImportError as e:
        if major >= 8:
            architecture_suffix = f"-{SYSTEM}"
            raise ImportError(
                "Flash Attention V2 is not installed.\n"
                "Use the official Docker image (ghcr.io/huggingface/text-generation-inference:latest) "
                f"or install flash attention v2 with `cd server && make install install-flash-attention-v2{architecture_suffix}`"
            ) from e
        elif is_sm75:
            raise ImportError(
                "Flash Attention is not installed.\n"
                "Use the official Docker image (ghcr.io/huggingface/text-generation-inference:latest) "
                "or install flash attention with `cd server && make install install-flash-attention`"
            ) from e
        else:
            # Report the first visible device that is neither MI210 nor MI250.
            for idx in range(torch.cuda.device_count()):
                name = torch.cuda.get_device_name(idx)
                if "MI210" not in name and "MI250" not in name:
                    raise ImportError(
                        f"AMD GPU {name} does not support flash-attention"
                    ) from e
            raise ImportError(
                f"AMD GPU with ROCm capability {major} {minor} is not supported"
            ) from e
    else:
        # Only log once the import has actually succeeded; keeping this out of
        # the `try` body means the handler above only sees import failures of
        # `flash_attn_2_cuda` itself.
        log_master(
            logger.info,
            "ROCm: using Flash Attention 2 Composable Kernel implementation.",
        )


# Exported capability flag (listed in `__all__`).
# NOTE(review): presumably tells callers that this backend does not support
# sliding-window attention - confirm against the code that reads it.
SUPPORTS_WINDOWING = False
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220


def attention(
    *,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    kv_cache: KVCache,
    seqlen: Seqlen,
    block_tables: torch.Tensor,
    softmax_scale: float,
    window_size_left: int = -1,
    causal: bool = True,
    softcap: Optional[float] = None,
):
    """Run variable-length flash attention with the engine named by `ENGINE`.

    With `ENGINE == "ck"` the call goes to `flash_attn_2_cuda.varlen_fwd`;
    with `ENGINE == "triton"` it goes to `triton_attention`. Both receive a
    freshly allocated output buffer shaped like `query`.

    Args:
        query: Query tensor; also determines the shape of the output buffer.
        key: Key tensor, forwarded to the kernel.
        value: Value tensor, forwarded to the kernel.
        kv_cache: Accepted for interface compatibility; not used here.
        seqlen: Provides `cu_seqlen_q` (passed for both cumulative-length
            arguments), `max_q` and `max_k`.
        block_tables: Accepted for interface compatibility; not used here.
        softmax_scale: Forwarded to the kernel.
        window_size_left: Must be -1 or > 0 on the "ck" engine, where it is
            forwarded to the kernel. Not forwarded on the "triton" engine.
        causal: Forwarded to the kernel.
        softcap: On "ck", None is replaced by 0.0 before the call. On
            "triton", any non-None value is rejected.

    Returns:
        Element 0 of the `varlen_fwd` result on "ck", or the first element of
        the `triton_attention` result pair on "triton".

    Raises:
        ValueError: On "ck", if `window_size_left` is <= 0 and not -1.
        NotImplementedError: On "triton", if `softcap` is not None.
        RuntimeError: If `ENGINE` is neither "ck" nor "triton".
    """
    if ENGINE == "ck":
        if window_size_left <= 0 and window_size_left != -1:
            raise ValueError("`window_size_left` must be > 0 or -1")

        out = torch.empty_like(query)

        if softcap is None:
            softcap = 0.0

        # `window_size_left` was validated above and is forwarded as-is.
        # Arguments are positional; their order follows the signature of the
        # installed `flash_attn_2_cuda.varlen_fwd`.
        return flash_attn_2_cuda.varlen_fwd(
            query,
            key,
            value,
            out,
            seqlen.cu_seqlen_q,
            seqlen.cu_seqlen_q,
            None,
            None,
            None,
            None,
            seqlen.max_q,
            seqlen.max_k,
            0.0,
            softmax_scale,
            False,
            causal,
            window_size_left,
            0,
            softcap,
            False,
            None,
        )[0]

    elif ENGINE == "triton":
        from .flash_attn_triton import triton_attention

        if softcap is not None:
            raise NotImplementedError("softcap is only available with CK flash attn")

        out = torch.empty_like(query)

        # `window_size_left` is not passed to the Triton kernel.
        # NOTE(review): the original comment here states that windowing
        # support is already checked ahead of time at model load - confirm.
        output, _ = triton_attention(
            query,
            key,
            value,
            out,
            seqlen.cu_seqlen_q,
            seqlen.cu_seqlen_q,
            seqlen.max_q,
            seqlen.max_k,
            causal,
            softmax_scale,
        )
        return output

    else:
        raise RuntimeError(f"Unknown attention engine {ENGINE}")

280
281
282
283
284
285

# Public API of this module: the names exported by a star-import.
__all__ = [
    "SUPPORTS_WINDOWING",
    "attention",
    "paged_attention",
]