import os
import torch

from loguru import logger
import math

from text_generation_server.utils.import_utils import SYSTEM

if SYSTEM != "xpu":
    from text_generation_server.utils.flash_attn_triton import triton_attention

if os.getenv("USE_FLASH_ATTENTION", "").lower() == "false":
    raise ImportError("`USE_FLASH_ATTENTION` is false.")
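
# Availability flags for the flash attention backends; the detection logic
# below fills them in.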
HAS_FLASH_ATTN = False
HAS_FLASH_ATTN_V2_CUDA = False
HAS_FLASH_ATTN_V2_ROCM = False
ROCM_USE_FLASH_ATTN_V2_CK = False
ROCM_USE_FLASH_ATTN_V2_TRITON = False


if SYSTEM in {"cuda", "rocm"}:
    if not torch.cuda.is_available():
        raise ImportError("CUDA is not available")

    major, minor = torch.cuda.get_device_capability()
    is_sm75 = major == 7 and minor == 5
    is_sm8x = major == 8 and minor >= 0
    is_sm90 = major == 9 and minor == 0
    is_sm94 = major == 9 and minor == 4

    if SYSTEM == "rocm":
        if (
            os.getenv("ROCM_USE_FLASH_ATTN_V2_TRITON", "").lower() == "true"
            or os.getenv("ROCM_USE_FLASH_ATTN_V2_TRITON", "0") == "1"
        ):
            ROCM_USE_FLASH_ATTN_V2_TRITON = True
            logger.info("ROCm: using Flash Attention 2 Triton implementation.")
        else:
            ROCM_USE_FLASH_ATTN_V2_CK = True
            logger.info(
                "ROCm: using Flash Attention 2 Composable Kernel implementation."
            )

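    # Try Flash Attention V2 first; fall back to V1 if it cannot be used.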
    try:
        try:
            import flash_attn_2_cuda
        except ImportError:
            architecture_suffix = f"-{SYSTEM}"
            raise ImportError(
                "Flash Attention V2 is not installed.\n"
                "Use the official Docker image (ghcr.io/huggingface/text-generation-inference:latest) "
                f"or install flash attention v2 with `cd server && make install install-flash-attention-v2{architecture_suffix}`"
            )
        if SYSTEM == "cuda" and not (is_sm8x or is_sm90):
            raise ImportError(
                f"GPU with CUDA capability {major} {minor} is not supported for "
                "Flash Attention V2"
            )
        elif SYSTEM == "rocm" and not (is_sm8x or is_sm90 or is_sm94):
            raise ImportError(
                f"AMD GPU with compute capability {major} {minor} is not supported for "
                "Flash Attention V2"
            )
        HAS_FLASH_ATTN_V2_CUDA = SYSTEM == "cuda"
        HAS_FLASH_ATTN_V2_ROCM = SYSTEM == "rocm"
    except ImportError as e:
        try:
            import flash_attn_cuda
        except ImportError:
            raise ImportError(
                "Flash Attention is not installed.\n"
                "Use the official Docker image (ghcr.io/huggingface/text-generation-inference:latest) "
                "or install flash attention with `cd server && make install install-flash-attention`"
            ) from e

        if SYSTEM == "cuda" and not (is_sm75 or is_sm8x or is_sm90):
            raise ImportError(
                f"GPU with CUDA capability {major} {minor} is not supported"
            ) from e
        elif SYSTEM == "rocm":
            for idx in range(torch.cuda.device_count()):
                if "MI210" not in torch.cuda.get_device_name(
                    idx
                ) and "MI250" not in torch.cuda.get_device_name(idx):
                    raise ImportError(
                        f"AMD GPU {torch.cuda.get_device_name(idx)} does not support flash-attention"
                    )

        logger.warning(f"Unable to use Flash Attention V2: {e}")
        HAS_FLASH_ATTN = True

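# Select the `attention` implementation matching the backend detected above.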
if SYSTEM == "xpu":
    import intel_extension_for_pytorch as ipex

    def attention(
        q,
        k,
        v,
        out,
        cu_seqlens,
        max_s,
        softmax_scale,
        window_size_left=-1,
    ):
        if window_size_left <= 0 and window_size_left != -1:
            raise ValueError("`window_size_left` must be > 0 or -1")

        if window_size_left != -1:
            raise ValueError(
                f"XPU version of Flash Attention does not support window attention (window_size_left != -1, got window_size_left={window_size_left})."
            )
        return ipex.llm.functional.varlen_attention(
            q,
            k,
            v,
            out,
            cu_seqlens,
            cu_seqlens,
            max_s,
            max_s,
            0.0,
            softmax_scale,
            False,
            True,
            False,
            None,
        )

elif HAS_FLASH_ATTN_V2_CUDA:

    def attention(
        q,
        k,
        v,
        out,
        cu_seqlens,
        max_s,
        softmax_scale,
        window_size_left=-1,
        causal=True,
    ):
        if window_size_left <= 0 and window_size_left != -1:
            raise ValueError("`window_size_left` must be > 0 or -1")
        return flash_attn_2_cuda.varlen_fwd(
            q,
            k,
            v,
            out,
            cu_seqlens,
            cu_seqlens,
            None,
            None,
            None,
            max_s,
            max_s,
            0.0,
            softmax_scale,
            False,
            causal,
            window_size_left,
            0,
            False,
            None,
        )

elif HAS_FLASH_ATTN_V2_ROCM and ROCM_USE_FLASH_ATTN_V2_CK:

    def attention(
        q,
        k,
        v,
        out,
        cu_seqlens,
        max_s,
        softmax_scale,
        window_size_left=-1,
        causal=True,
    ):
        if window_size_left <= 0 and window_size_left != -1:
            raise ValueError("`window_size_left` must be > 0 or -1")
        if window_size_left != -1:
            raise ValueError(
                f"RoCm version of Flash Attention v2 does not support window attention (window_size_left != -1, got window_size_left={window_size_left})."
            )

        # ROCm flash API does not take the window_size_left and window_size_right arguments.
        return flash_attn_2_cuda.varlen_fwd(
            q,
            k,
            v,
            out,
            cu_seqlens,
            cu_seqlens,
            max_s,
            max_s,
            0.0,
            softmax_scale,
            False,
            causal,
            False,
            None,
        )

elif HAS_FLASH_ATTN_V2_ROCM and ROCM_USE_FLASH_ATTN_V2_TRITON:

    def attention(
        q,
        k,
        v,
        out,
        cu_seqlens,
        max_s,
        softmax_scale,
        window_size_left=-1,
        causal=True,
    ):
        output, _ = triton_attention(
            q,
            k,
            v,
            out,
            cu_seqlens,
            cu_seqlens,
            max_s,
            max_s,
            causal,
            softmax_scale,
        )
        return output

elif HAS_FLASH_ATTN:

    def attention(
        q,
        k,
        v,
        out,
        cu_seqlens,
        max_s,
        softmax_scale,
        window_size_left=-1,
    ):
        if window_size_left != -1:
            raise NotImplementedError(
                "window_size_left is only available with flash attn v2"
            )

248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
        # Flash attention v1 requires q, k and v to have the same number of heads
        if k.shape[1] != q.shape[1]:
            # MQA expand
            if k.shape[1] == 1:
                k = k.expand(-1, q.shape[1], -1)
            # Grouped attention reshape
            else:
                original_shape = k.shape
                k = (
                    k.unsqueeze(2)
                    .expand(-1, -1, q.shape[1] // k.shape[1], -1)
                    .reshape(original_shape[0], -1, original_shape[2])
                )
        if v.shape[1] != q.shape[1]:
            # MQA expand
            if v.shape[1] == 1:
                v = v.expand(-1, q.shape[1], -1)
            # Grouped attention reshape
            else:
                original_shape = v.shape
                v = (
                    v.unsqueeze(2)
                    .expand(-1, -1, q.shape[1] // v.shape[1], -1)
                    .reshape(original_shape[0], -1, original_shape[2])
                )

        return flash_attn_cuda.fwd(
            q,
            k,
            v,
            out,
            cu_seqlens,
            cu_seqlens,
            max_s,
            max_s,
            0.0,
            softmax_scale,
            False,
            True,
            False,
            0,
            None,
        )

else:
    raise NotImplementedError("flash attention is not installed")
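
# Illustrative usage sketch (comments only, not executed; the variable names
# and shapes below are assumptions for documentation purposes). The selected
# `attention` kernel expects ragged sequences packed into tensors of shape
# [total_tokens, num_heads, head_dim], delimited by cumulative sequence
# lengths:
#
#   softmax_scale = 1.0 / math.sqrt(head_dim)
#   out = torch.empty_like(q)
#   cu_seqlens = torch.tensor([0, len_0, len_0 + len_1], dtype=torch.int32, device=q.device)
#   attention(q, k, v, out, cu_seqlens, max_s, softmax_scale)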