import os
from typing import Optional
import torch
from text_generation_server.layers.attention.kv_cache import KVCache, KVScales
from text_generation_server.utils.import_utils import SYSTEM
from text_generation_server.layers.attention import Seqlen
from text_generation_server.utils.log import log_master
from loguru import logger

major, minor = torch.cuda.get_device_capability()
is_sm75 = major == 7 and minor == 5

_PARTITION_SIZE_V1V2 = 512
_PARTITION_SIZE_CUSTOM = 256
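# The V1/V2 paged-attention kernels split each sequence into 512-token
# partitions; the ROCm custom kernel (whose dispatch is currently commented
# out below) uses 256-token partitions.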

use_triton = os.getenv("ROCM_USE_FLASH_ATTN_V2_TRITON", "").lower() in {"true", "1"}
ENGINE = "triton" if use_triton else "ck"
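# ENGINE selects the prefill flash-attention backend: the Composable Kernel
# ("ck") build by default, or the Triton kernel when
# ROCM_USE_FLASH_ATTN_V2_TRITON is set to "true"/"1".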

use_rocm_custom_paged_attn = os.getenv("ROCM_USE_CUSTOM_PAGED_ATTN", "1") != "0"
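# Probe for vLLM's ROCm custom paged-attention op; if the import fails, the
# flag is cleared and only the standard V1/V2 kernels are used.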
try:
    if use_rocm_custom_paged_attn:
        # from vllm._custom_ops import paged_attention_custom
        from vllm import _custom_ops
except ImportError as e:
    log_master(
        logger.info,
        f"Custom Paged Attention not available. Complete error: {e}",
    )
    use_rocm_custom_paged_attn = False


def paged_attention(
    query: torch.Tensor,
    kv_cache: KVCache,
    kv_head_mapping: torch.Tensor,
    softmax_scale: float,
    block_tables: torch.Tensor,
    seqlen: Seqlen,
    max_s: int,
    *,
    kv_scales: KVScales,
    softcap: Optional[float] = None,
):
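    """Decode-time paged attention backed by vLLM's ROCm paged-attention kernels.

    Dispatches between ``paged_attention_v1`` (single pass) and
    ``paged_attention_v2`` (partitioned, with a final reduction) using the
    heuristic below. ``softcap`` is not supported and raises if provided;
    ``kv_scales`` is accepted for interface parity but not used here.
    """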
    # Adapted from: https://github.com/vllm-project/vllm/blob/f8a1e39fae05ca610be8d5a78be9d40f5274e5fc/vllm/model_executor/layers/attention.py
    # Copyright 2023 The vLLM team. All rights
    # reserved.
    #
    # Licensed under the Apache License, Version 2.0 (the "License");
    # you may not use this file except in compliance with the License.
    # You may obtain a copy of the License at
    #
    #     http://www.apache.org/licenses/LICENSE-2.0
    #
    # Unless required by applicable law or agreed to in writing, software
    # distributed under the License is distributed on an "AS IS" BASIS,
    # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    # See the License for the specific language governing permissions and
    # limitations under the License.
    #

    if softcap is not None:
        raise RuntimeError("Paged attention doesn't support softcapping")

    # value_cache => [num_blocks, num_heads, head_size, block_size]
    block_size = kv_cache.value.shape[3]
    num_seqs, num_heads, head_size = query.shape

    num_kv_heads = kv_cache.key.shape[1]

    # gqa_ratio = num_heads // num_kv_heads
    # use_custom = (
    #     use_rocm_custom_paged_attn
    #     and (query.dtype == torch.half or query.dtype == torch.bfloat16)
    #     and (head_size == 128 or head_size == 64)
    #     and (block_size == 16 or block_size == 32)
    #     and (gqa_ratio >= 1 and gqa_ratio <= 16)
    #     and max_s <= 32768
    # )

    # if not use_custom:
    #     _PARTITION_SIZE = _PARTITION_SIZE_V1V2
    # else:
    #     _PARTITION_SIZE = _PARTITION_SIZE_CUSTOM

    _PARTITION_SIZE = _PARTITION_SIZE_V1V2

    max_num_partitions = (max_s + _PARTITION_SIZE - 1) // _PARTITION_SIZE
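    # Total context length per sequence: previously cached tokens plus the
    # tokens of the current step.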
    input_lengths = seqlen.input_lengths + seqlen.cache_lengths

    out = torch.empty_like(query)

    # NOTE(woosuk): We use a simple heuristic to decide whether to use
    # PagedAttention V1 or V2. If the number of partitions is 1, we use
    # V1 to avoid the overhead of reduction. Also, if the number of
    # sequences or heads is large, we use V1 since there is enough work
    # to parallelize.
    # import vllm._custom_ops as ops
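    # Imported here as well because the module-level import above is skipped
    # when custom paged attention is disabled or unavailable.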
    from vllm import _custom_ops

    use_v1 = (
        max_s <= 8192
        and (max_num_partitions == 1 or num_seqs * num_heads > 512)
    )

    if use_v1:
        _custom_ops.paged_attention_v1(
            out,
            query,
            kv_cache.key,
            kv_cache.value,
            # kv_head_mapping,
            num_kv_heads,  # the vLLM kernel expects the KV head count, not the head mapping
            softmax_scale,
            block_tables,
            input_lengths,
            block_size,
            max_s,
            None,
            "auto",
            1.0,
        )
    else:
        # Run PagedAttention V2.
        assert _PARTITION_SIZE % block_size == 0
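        # V2 splits each sequence into _PARTITION_SIZE-token partitions; these
        # buffers hold the per-partition partial outputs, exp sums and max
        # logits that the kernel reduces into `out`.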
        tmp_output = torch.empty(
            size=(num_seqs, num_heads, max_num_partitions, head_size),
            dtype=out.dtype,
            device=out.device,
        )
        exp_sums = torch.empty(
            size=(num_seqs, num_heads, max_num_partitions),
            dtype=torch.float32,
            device=out.device,
        )
        max_logits = torch.empty_like(exp_sums)

        # if not use_custom:
        _custom_ops.paged_attention_v2(
            out,
            exp_sums,
            max_logits,
            tmp_output,
            query,
            kv_cache.key,
            kv_cache.value,
            # kv_head_mapping,
            num_kv_heads,  # the vLLM kernel expects the KV head count, not the head mapping
            softmax_scale,
            block_tables,
            input_lengths,
            block_size,
            max_s,
            None,
            "auto",
            1.0,
        )

        # else:
        #     paged_attention_custom(
        #         out,
        #         exp_sums,
        #         max_logits,
        #         tmp_output,
        #         query,
        #         kv_cache.key,
        #         kv_cache.value,
        #         num_kv_heads,
        #         softmax_scale,
        #         block_tables,
        #         input_lengths,
        #         block_size,
        #         max_s,
        #         None,
        #         "auto",
        #     )

    return out


if ENGINE != "triton":
    try:
        import flash_attn_2_cuda

        log_master(
            logger.info,
            "ROCm: using Flash Attention 2 Composable Kernel implementation.",
        )
    except ImportError as e:
        if major >= 8:
            architecture_suffix = f"-{SYSTEM}"
            raise ImportError(
                "Flash Attention V2 is not installed.\n"
                "Use the official Docker image (ghcr.io/huggingface/text-generation-inference:latest) "
                f"or install flash attention v2 with `cd server && make install install-flash-attention-v2{architecture_suffix}`"
            )
        elif is_sm75:
            raise ImportError(
                "Flash Attention is not installed.\n"
                "Use the official Docker image (ghcr.io/huggingface/text-generation-inference:latest) "
                "or install flash attention with `cd server && make install install-flash-attention`"
            ) from e
        else:
            for idx in range(torch.cuda.device_count()):
                name = torch.cuda.get_device_name(idx)
                if "MI210" not in name and "MI250" not in name:
                    raise ImportError(
                        f"AMD GPU {torch.cuda.get_device_name(idx)} does not support flash-attention"
                    )
            raise ImportError(
                f"AMD GPU with ROCm capability {major} {minor} is not supported"
            ) from e


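# Neither ROCm backend supports sliding-window attention.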
SUPPORTS_WINDOWING = False


def attention(
    *,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    kv_cache: KVCache,
    kv_scales: KVScales,
    seqlen: Seqlen,
    block_tables: torch.Tensor,
    softmax_scale: float,
    window_size_left: int = -1,
    causal: bool = True,
    softcap: Optional[float] = None,
):
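    """Prefill (variable-length) attention on ROCm.

    Uses the Composable Kernel flash-attention build when ENGINE is "ck" and
    the Triton kernel when ENGINE is "triton". ``softcap`` is only supported
    by the CK backend; ``kv_cache`` and ``kv_scales`` are accepted for
    interface parity but not used by these kernels.
    """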
    if ENGINE == "ck":
        if window_size_left <= 0 and window_size_left != -1:
            raise ValueError("`window_size_left` must be > 0 or -1")

        out = torch.empty_like(query)

        if softcap is None:
            softcap = 0.0

        # We do not need to check window_size_left here; it is already validated ahead of time at model load.
        return flash_attn_2_cuda.varlen_fwd(
            query,
            key,
            value,
            out,
            seqlen.cu_seqlen_q,
            seqlen.cu_seqlen_q,
            None,
            None,
            None,
            None,
            seqlen.max_q,
            seqlen.max_k,
            0.0,
            softmax_scale,
            False,
            causal,
            window_size_left,
            0,
            softcap,
            False,
            None,
        )[0]

    elif ENGINE == "triton":
        from .flash_attn_triton import triton_attention

        if softcap is not None:
            raise NotImplementedError("softcap is only available with CK flash attn")

        out = torch.empty_like(query)

        # We do not need to check window_size_left here; it is already validated ahead of time at model load.
        output, _ = triton_attention(
            query,
            key,
            value,
            out,
            seqlen.cu_seqlen_q,
            seqlen.cu_seqlen_q,
            seqlen.max_q,
            seqlen.max_k,
            causal,
            softmax_scale,
        )
        return output

    else:
        raise RuntimeError(f"Unknown attention engine {ENGINE}")


__all__ = [
    "SUPPORTS_WINDOWING",
    "attention",
    "paged_attention",
]