import os
import torch

from loguru import logger
import math

from text_generation_server.utils.import_utils import SYSTEM
from text_generation_server.utils.flash_attn_triton import triton_attention

if os.getenv("USE_FLASH_ATTENTION", "").lower() == "false":
    raise ImportError("`USE_FLASH_ATTENTION` is false.")
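
# Backend availability flags, filled in at import time. They select which
# module-level `attention` wrapper is defined at the bottom of this file.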
HAS_FLASH_ATTN = False
HAS_FLASH_ATTN_V2_CUDA = False
HAS_FLASH_ATTN_V2_ROCM = False
ROCM_USE_FLASH_ATTN_V2_CK = False
ROCM_USE_FLASH_ATTN_V2_TRITON = False

if SYSTEM == "xpu":
    import intel_extension_for_pytorch as ipex

    def attention(
        q,
        k,
        v,
        out,
        cu_seqlens,
        max_s,
        softmax_scale,
        window_size_left=-1,
    ):
        if window_size_left <= 0 and window_size_left != -1:
            raise ValueError("`window_size_left` must be > 0 or -1")

        if window_size_left != -1:
            raise ValueError(
                f"XPU version of Flash Attention does not support window attention (window_size_left != -1, got window_size_left={window_size_left})."
            )
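        # Call into IPEX's variable-length attention kernel. The trailing
        # positional arguments (0.0, softmax_scale, False, True, False, None) are
        # assumed to map to dropout probability, softmax scale, zero_tensors,
        # is_causal, return_softmax and the RNG generator, mirroring the
        # flash-attn style API; the exact signature may differ across IPEX releases.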
        return ipex.llm.functional.varlen_attention(
            q,
            k,
            v,
            out,
            cu_seqlens,
            cu_seqlens,
            max_s,
            max_s,
            0.0,
            softmax_scale,
            False,
            True,
            False,
            None,
        )


if SYSTEM in {"cuda", "rocm"}:
    if not torch.cuda.is_available():
        raise ImportError("CUDA is not available")

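    # Query the device compute capability; the is_sm* flags below gate which
    # flash-attention kernels are supported on this GPU.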
    major, minor = torch.cuda.get_device_capability()
    is_sm75 = major == 7 and minor == 5
    is_sm8x = major == 8 and minor >= 0
    is_sm90 = major == 9 and minor == 0
    is_sm94 = major == 9 and minor == 4

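    # On ROCm, the Triton flash-attention v2 implementation can be opted into via
    # the ROCM_USE_FLASH_ATTN_V2_TRITON environment variable; the Composable
    # Kernel (CK) build is used otherwise.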
    if SYSTEM == "rocm":
        if (
            os.getenv("ROCM_USE_FLASH_ATTN_V2_TRITON", "").lower() == "true"
            or os.getenv("ROCM_USE_FLASH_ATTN_V2_TRITON", "0") == "1"
        ):
            ROCM_USE_FLASH_ATTN_V2_TRITON = True
            logger.info("ROCm: using Flash Attention 2 Triton implementation.")
        else:
            ROCM_USE_FLASH_ATTN_V2_CK = True
            logger.info(
                "ROCm: using Flash Attention 2 Composable Kernel implementation."
            )

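    # Prefer flash-attention v2; if its extension is missing or the GPU is not
    # supported, fall back to flash-attention v1 (flash_attn_cuda) below.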
    try:
        try:
            import flash_attn_2_cuda
        except ImportError:
            architecture_suffix = f"-{SYSTEM}"
            raise ImportError(
                "Flash Attention V2 is not installed.\n"
                "Use the official Docker image (ghcr.io/huggingface/text-generation-inference:latest) "
                f"or install flash attention v2 with `cd server && make install install-flash-attention-v2{architecture_suffix}`"
            )
        if SYSTEM == "cuda" and not (is_sm8x or is_sm90):
            raise ImportError(
                f"GPU with CUDA capability {major} {minor} is not supported for "
                "Flash Attention V2"
            )
        elif SYSTEM == "rocm" and not (is_sm8x or is_sm90 or is_sm94):
            raise ImportError(
                f"AMD GPU with compute capability {major} {minor} is not supported for "
                "Flash Attention V2"
            )
        HAS_FLASH_ATTN_V2_CUDA = SYSTEM == "cuda"
        HAS_FLASH_ATTN_V2_ROCM = SYSTEM == "rocm"
    except ImportError as e:
        try:
            import flash_attn_cuda
        except ImportError:
            raise ImportError(
                "Flash Attention is not installed.\n"
                "Use the official Docker image (ghcr.io/huggingface/text-generation-inference:latest) "
                "or install flash attention with `cd server && make install install-flash-attention`"
            ) from e

        if SYSTEM == "cuda" and not (is_sm75 or is_sm8x or is_sm90):
            raise ImportError(
                f"GPU with CUDA capability {major} {minor} is not supported"
            ) from e
        elif SYSTEM == "rocm":
            for idx in range(torch.cuda.device_count()):
                if "MI210" not in torch.cuda.get_device_name(
                    idx
                ) and "MI250" not in torch.cuda.get_device_name(idx):
                    raise ImportError(
                        f"AMD GPU {torch.cuda.get_device_name(idx)} does not support flash-attention"
                    )

        logger.warning(f"Unable to use Flash Attention V2: {e}")
        HAS_FLASH_ATTN = True


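# Define the module-level `attention` wrapper for the backend detected above:
# flash-attention v2 on CUDA, v2 on ROCm (CK or Triton), or v1 as a fallback.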
if HAS_FLASH_ATTN_V2_CUDA:

    def attention(
        q,
        k,
        v,
        out,
        cu_seqlens,
        max_s,
        softmax_scale,
        window_size_left=-1,
        causal=True,
    ):
        if window_size_left <= 0 and window_size_left != -1:
            raise ValueError("`window_size_left` must be > 0 or -1")
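        # Positional call into the flash-attn v2 varlen forward kernel. Assuming
        # the flash-attn 2.5.x C++ signature, the three `None`s are the optional
        # seqused_k, block_table and alibi_slopes tensors, and the `0` after
        # window_size_left is window_size_right; the argument order can change
        # between flash-attn releases.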
        return flash_attn_2_cuda.varlen_fwd(
            q,
            k,
            v,
            out,
            cu_seqlens,
            cu_seqlens,
            None,
            None,
            None,
            max_s,
            max_s,
            0.0,
            softmax_scale,
            False,
            causal,
            window_size_left,
            0,
            False,
            None,
        )

elif HAS_FLASH_ATTN_V2_ROCM and ROCM_USE_FLASH_ATTN_V2_CK:

    def attention(
        q,
        k,
        v,
        out,
        cu_seqlens,
        max_s,
        softmax_scale,
        window_size_left=-1,
        causal=True,
    ):
        if window_size_left <= 0 and window_size_left != -1:
            raise ValueError("`window_size_left` must be > 0 or -1")
        if window_size_left != -1:
            raise ValueError(
                f"ROCm version of Flash Attention v2 does not support window attention (window_size_left != -1, got window_size_left={window_size_left})."
            )

        # ROCm flash API does not take the window_size_left and window_size_right arguments.
        return flash_attn_2_cuda.varlen_fwd(
            q,
            k,
            v,
            out,
            cu_seqlens,
            cu_seqlens,
            max_s,
            max_s,
            0.0,
            softmax_scale,
            False,
            causal,
            False,
            None,
        )

elif HAS_FLASH_ATTN_V2_ROCM and ROCM_USE_FLASH_ATTN_V2_TRITON:

    def attention(
        q,
        k,
        v,
        out,
        cu_seqlens,
        max_s,
        softmax_scale,
        window_size_left=-1,
        causal=True,
    ):
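        # The Triton kernel returns a tuple; only the attention output is used and
        # the auxiliary value (e.g. softmax statistics) is discarded.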
        output, _ = triton_attention(
            q,
            k,
            v,
            out,
            cu_seqlens,
            cu_seqlens,
            max_s,
            max_s,
            causal,
            softmax_scale,
        )
        return output

elif HAS_FLASH_ATTN:

    def attention(
        q,
        k,
        v,
        out,
        cu_seqlens,
        max_s,
        softmax_scale,
        window_size_left=-1,
    ):
        if window_size_left != -1:
            raise NotImplementedError(
                "window_size_left is only available with flash attn v2"
            )

        # Flash attention v1 requires q, k and v to have the same number of heads
        if k.shape[1] != q.shape[1]:
            # MQA expand
            if k.shape[1] == 1:
                k = k.expand(-1, q.shape[1], -1)
            # Grouped attention reshape
            else:
                original_shape = k.shape
                k = (
                    k.unsqueeze(2)
                    .expand(-1, -1, q.shape[1] // k.shape[1], -1)
                    .reshape(original_shape[0], -1, original_shape[2])
                )
        if v.shape[1] != q.shape[1]:
            # MQA expand
            if v.shape[1] == 1:
                v = v.expand(-1, q.shape[1], -1)
            # Grouped attention reshape
            else:
                original_shape = v.shape
                v = (
                    v.unsqueeze(2)
                    .expand(-1, -1, q.shape[1] // v.shape[1], -1)
                    .reshape(original_shape[0], -1, original_shape[2])
                )

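        # Flash-attention v1 forward. Assuming the flash_attn_cuda v1 signature,
        # the trailing arguments are dropout (0.0), softmax_scale, zero_tensors,
        # is_causal (always True here), return_softmax, num_splits and the RNG
        # generator.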
        return flash_attn_cuda.fwd(
            q,
            k,
            v,
            out,
            cu_seqlens,
            cu_seqlens,
            max_s,
            max_s,
            0.0,
            softmax_scale,
            False,
            True,
            False,
            0,
            None,
        )

else:
    raise NotImplementedError("flash attention is not installed")
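

# Usage sketch (illustrative only; names and shapes below are assumptions, not
# taken from this repository's tests): the selected `attention` wrapper operates
# on packed variable-length sequences and fills the preallocated `out` buffer.
#
#   total_tokens, num_heads, head_dim = 7, 8, 64  # two sequences of lengths 3 and 4
#   q = torch.randn(total_tokens, num_heads, head_dim, dtype=torch.float16, device="cuda")
#   k, v = torch.randn_like(q), torch.randn_like(q)
#   out = torch.empty_like(q)
#   cu_seqlens = torch.tensor([0, 3, 7], dtype=torch.int32, device="cuda")
#   attention(q, k, v, out, cu_seqlens, max_s=4, softmax_scale=head_dim**-0.5)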