import os
import torch

from loguru import logger

from text_generation_server.utils.import_utils import SYSTEM

if os.getenv("USE_FLASH_ATTENTION", "").lower() == "false":
    raise ImportError("`USE_FLASH_ATTENTION` is false.")

HAS_FLASH_ATTN = True
HAS_FLASH_ATTN_V2_CUDA = False
HAS_FLASH_ATTN_V2_ROCM = False
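# Select a single `attention` implementation at import time: the ipex kernel on
# XPU, flash-attn v2 (flash_attn_2_cuda) on CUDA/ROCm when available, and the
# flash-attn v1 kernel (flash_attn_cuda) as a fallback.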

if SYSTEM == "xpu":
    import intel_extension_for_pytorch as ipex

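    # XPU backend: thin wrapper around ipex.llm.functional.varlen_attention.
    # Sliding-window attention is rejected because this backend exposes no
    # window arguments.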
    def attention(
        q,
        k,
        v,
        out,
        cu_seqlens,
        max_s,
        softmax_scale,
        window_size_left=-1,
    ):
        if window_size_left <= 0 and window_size_left != -1:
            raise ValueError("`window_size_left` must be > 0 or -1")

        if window_size_left != -1:
            raise ValueError(
                f"XPU version of Flash Attention does not support window attention (window_size_left != -1, got window_size_left={window_size_left})."
            )
        return ipex.llm.functional.varlen_attention(
            q,
            k,
            v,
            out,
            cu_seqlens,
            cu_seqlens,
            max_s,
            max_s,
            0.0,
            softmax_scale,
            False,
            True,
            False,
            None,
        )


if SYSTEM in {"cuda", "rocm"}:
    if not torch.cuda.is_available():
        raise ImportError("CUDA is not available")

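    # Compute-capability gate: flash-attn v2 requires sm_8x (Ampere/Ada) or
    # sm_90 (Hopper); the v1 fallback below additionally accepts sm_75 (Turing).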
    major, minor = torch.cuda.get_device_capability()
    is_sm75 = major == 7 and minor == 5
    is_sm8x = major == 8 and minor >= 0
    is_sm90 = major == 9 and minor == 0

    HAS_FLASH_ATTN = False
    HAS_FLASH_ATTN_V2_CUDA = False
    HAS_FLASH_ATTN_V2_ROCM = False
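    # Probe flash-attn v2 first; on ImportError (missing wheel or unsupported
    # GPU) fall back to the v1 kernel probed in the except branch below.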
    try:
        try:
            import flash_attn_2_cuda
        except ImportError:
            architecture_suffix = f"-{SYSTEM}"
            raise ImportError(
                "Flash Attention V2 is not installed.\n"
                "Use the official Docker image (ghcr.io/huggingface/text-generation-inference:latest) "
                f"or install flash attention v2 with `cd server && make install install-flash-attention-v2{architecture_suffix}`"
            )
        if not (is_sm8x or is_sm90):
            raise ImportError(
                f"GPU with CUDA capability {major} {minor} is not supported for "
                "Flash Attention V2"
            )
        HAS_FLASH_ATTN_V2_CUDA = SYSTEM == "cuda"
        HAS_FLASH_ATTN_V2_ROCM = SYSTEM == "rocm"
    except ImportError as e:
        try:
            import flash_attn_cuda
        except ImportError:
            raise ImportError(
                "Flash Attention is not installed.\n"
                "Use the official Docker image (ghcr.io/huggingface/text-generation-inference:latest) "
                "or install flash attention with `cd server && make install install-flash-attention`"
            ) from e

        if SYSTEM == "cuda" and not (is_sm75 or is_sm8x or is_sm90):
            raise ImportError(
                f"GPU with CUDA capability {major} {minor} is not supported"
            ) from e
        elif SYSTEM == "rocm":
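            # The ROCm flash-attn v1 build is only expected to work on
            # MI210/MI250 accelerators, so any other device is rejected.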
            for idx in range(torch.cuda.device_count()):
                if "MI210" not in torch.cuda.get_device_name(
                    idx
                ) and "MI250" not in torch.cuda.get_device_name(idx):
                    raise ImportError(
                        f"AMD GPU {torch.cuda.get_device_name(idx)} does not support flash-attention"
                    )

        logger.warning(f"Unable to use Flash Attention V2: {e}")
        HAS_FLASH_ATTN = True


if HAS_FLASH_ATTN_V2_CUDA:

    def attention(
        q,
        k,
        v,
        out,
        cu_seqlens,
        max_s,
        softmax_scale,
        window_size_left=-1,
    ):
        if window_size_left <= 0 and window_size_left != -1:
            raise ValueError("`window_size_left` must be > 0 or -1")
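
        # Unlike the ROCm and v1 paths below, the CUDA v2 kernel supports
        # sliding-window attention, so window_size_left is forwarded directly.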
        return flash_attn_2_cuda.varlen_fwd(
            q,
            k,
            v,
            out,
            cu_seqlens,
            cu_seqlens,
            None,
            None,
            None,
            max_s,
            max_s,
            0.0,
            softmax_scale,
            False,
            True,
            window_size_left,
            0,
            False,
            None,
        )

elif HAS_FLASH_ATTN_V2_ROCM:

    def attention(
        q,
        k,
        v,
        out,
        cu_seqlens,
        max_s,
        softmax_scale,
        window_size_left=-1,
    ):
        if window_size_left <= 0 and window_size_left != -1:
            raise ValueError("`window_size_left` must be > 0 or -1")
        if window_size_left != -1:
            raise ValueError(
                f"RoCm version of Flash Attention v2 does not support window attention (window_size_left != -1, got window_size_left={window_size_left})."
            )

        # RoCm flash API does not take the window_size_left and window_size_right arguments.
        return flash_attn_2_cuda.varlen_fwd(
            q,
            k,
            v,
            out,
            cu_seqlens,
            cu_seqlens,
            max_s,
            max_s,
            0.0,
            softmax_scale,
            False,
            True,
            False,
            None,
        )

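# Flash-attn v1 fallback: no sliding-window support, and q/k/v must share the
# same number of heads, so MQA/GQA key/value heads are expanded below.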
elif HAS_FLASH_ATTN:

    def attention(
        q,
        k,
        v,
        out,
        cu_seqlens,
        max_s,
        softmax_scale,
        window_size_left=-1,
    ):
        if window_size_left != -1:
            raise NotImplementedError(
                "window_size_left is only available with flash attn v2"
            )

        # Flash attention v1 requires q, k and v to have the same number of heads
        if k.shape[1] != q.shape[1]:
            # MQA expand
            if k.shape[1] == 1:
                k = k.expand(-1, q.shape[1], -1)
            # Grouped attention reshape
            else:
                original_shape = k.shape
                k = (
                    k.unsqueeze(2)
                    .expand(-1, -1, q.shape[1] // k.shape[1], -1)
                    .reshape(original_shape[0], -1, original_shape[2])
                )
        if v.shape[1] != q.shape[1]:
            # MQA expand
            if v.shape[1] == 1:
                v = v.expand(-1, q.shape[1], -1)
            # Grouped attention reshape
            else:
                original_shape = v.shape
                v = (
                    v.unsqueeze(2)
                    .expand(-1, -1, q.shape[1] // v.shape[1], -1)
                    .reshape(original_shape[0], -1, original_shape[2])
                )

        return flash_attn_cuda.fwd(
            q,
            k,
            v,
            out,
            cu_seqlens,
            cu_seqlens,
            max_s,
            max_s,
            0.0,
            softmax_scale,
            False,
            True,
            False,
            0,
            None,
        )

else:
    raise NotImplementedError("flash attention is not installed")
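

# Minimal usage sketch (illustrative only). Tensors follow the varlen layout
# used above: q/k/v packed as [total_tokens, num_heads, head_dim] across
# sequences, cu_seqlens holding cumulative sequence lengths, and softmax_scale
# typically set to head_dim ** -0.5.
#
#   out = torch.empty_like(q)
#   attention(q, k, v, out, cu_seqlens, max_s, softmax_scale)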