import os
import torch

from loguru import logger
import math

from text_generation_server.utils.import_utils import (
    IS_CUDA_SYSTEM,
    IS_ROCM_SYSTEM,
    IS_XPU_SYSTEM,
)

if os.getenv("USE_FLASH_ATTENTION", "").lower() == "false":
    raise ImportError("`USE_FLASH_ATTENTION` is false.")
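
# Backend capability flags, refined below based on the detected hardware and on
# which flash-attention kernels can actually be imported.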
HAS_FLASH_ATTN = True
HAS_FLASH_ATTN_V2_CUDA = False
HAS_FLASH_ATTN_V2_ROCM = False

if IS_XPU_SYSTEM:
    import intel_extension_for_pytorch as ipex

if IS_CUDA_SYSTEM or IS_ROCM_SYSTEM:
    if not torch.cuda.is_available():
        raise ImportError("CUDA is not available")

    major, minor = torch.cuda.get_device_capability()
    is_sm75 = major == 7 and minor == 5
    is_sm8x = major == 8 and minor >= 0
    is_sm90 = major == 9 and minor == 0

    HAS_FLASH_ATTN = False
    HAS_FLASH_ATTN_V2_CUDA = False
    HAS_FLASH_ATTN_V2_ROCM = False
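    # Probe for flash-attn v2 first; if it is missing or the GPU is unsupported,
    # fall back to flash-attn v1.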
    try:
        try:
            import flash_attn_2_cuda
        except ImportError:
            architecture_suffix = ""
            if IS_CUDA_SYSTEM:
                architecture_suffix = "-cuda"
            elif IS_ROCM_SYSTEM:
                architecture_suffix = "-rocm"
            raise ImportError(
                "Flash Attention V2 is not installed.\n"
                "Use the official Docker image (ghcr.io/huggingface/text-generation-inference:latest) "
                f"or install flash attention v2 with `cd server && make install install-flash-attention-v2{architecture_suffix}`"
            )
        if not (is_sm8x or is_sm90):
            raise ImportError(
                f"GPU with CUDA capability {major} {minor} is not supported for "
                "Flash Attention V2"
            )
        HAS_FLASH_ATTN_V2_CUDA = IS_CUDA_SYSTEM
        HAS_FLASH_ATTN_V2_ROCM = IS_ROCM_SYSTEM
    except ImportError as e:
        try:
            import flash_attn_cuda
        except ImportError:
            raise ImportError(
                "Flash Attention is not installed.\n"
                "Use the official Docker image (ghcr.io/huggingface/text-generation-inference:latest) "
                "or install flash attention with `cd server && make install install-flash-attention`"
            ) from e

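        # flash-attn v1 fallback: requires Turing (sm75) or newer on CUDA and is
        # limited to MI210/MI250 on ROCm.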
        if IS_CUDA_SYSTEM and not (is_sm75 or is_sm8x or is_sm90):
            raise ImportError(
                f"GPU with CUDA capability {major} {minor} is not supported"
            ) from e
        elif IS_ROCM_SYSTEM:
            for idx in range(torch.cuda.device_count()):
                if "MI210" not in torch.cuda.get_device_name(
                    idx
                ) and "MI250" not in torch.cuda.get_device_name(idx):
                    raise ImportError(
                        f"AMD GPU {torch.cuda.get_device_name(idx)} does not support flash-attention"
                    )

        logger.warning(f"Unable to use Flash Attention V2: {e}")
        HAS_FLASH_ATTN = True


def attention(
    q,
    k,
    v,
    out,
    cu_seqlens,
    max_s,
    softmax_scale,
    window_size_left=-1,
):
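    """Run a variable-length (ragged batch) attention forward pass with the best
    available backend: ipex on XPU, flash-attn v2 on CUDA/ROCm, flash-attn v1 as
    a fallback.

    `q`, `k`, `v` and `out` are packed `[total_tokens, num_heads, head_size]`
    tensors, `cu_seqlens` holds the cumulative sequence lengths of the packed
    batch and `max_s` is the longest sequence length. `window_size_left` enables
    sliding-window attention and is only supported by the CUDA flash-attn v2
    kernel.
    """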
    if window_size_left <= 0 and window_size_left != -1:
        raise ValueError("`window_size_left` must be > 0 or -1")

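    # XPU: delegate to the fused varlen attention kernel from
    # intel_extension_for_pytorch (ipex).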
    if IS_XPU_SYSTEM:
        if window_size_left != -1:
            raise ValueError(
                f"XPU version of Flash Attention does not support window attention (window_size_left != -1, got window_size_left={window_size_left})."
            )
        return ipex.llm.functional.varlen_attention(
            q,
            k,
            v,
            out,
            cu_seqlens,
            cu_seqlens,
            max_s,
            max_s,
            0.0,
            softmax_scale,
            False,
            True,
            False,
            None,
        )

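    # CUDA: flash-attn v2 varlen forward; the only backend here that honors
    # window_size_left (sliding-window attention).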
    if HAS_FLASH_ATTN_V2_CUDA:
        return flash_attn_2_cuda.varlen_fwd(
            q,
            k,
            v,
            out,
            cu_seqlens,
            cu_seqlens,
            None,
            None,
            None,
            max_s,
            max_s,
            0.0,
            softmax_scale,
            False,
            True,
            window_size_left,
            0,
            False,
            None,
        )
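    # ROCm: flash-attn v2 varlen forward (no sliding-window support).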
    elif HAS_FLASH_ATTN_V2_ROCM:
        if window_size_left != -1:
            raise ValueError(
                f"RoCm version of Flash Attention v2 does not support window attention (window_size_left != -1, got window_size_left={window_size_left})."
            )

        # The ROCm flash-attention API does not take the window_size_left and window_size_right arguments.
        return flash_attn_2_cuda.varlen_fwd(
            q,
            k,
            v,
            out,
            cu_seqlens,
            cu_seqlens,
            max_s,
            max_s,
            0.0,
            softmax_scale,
            False,
            True,
            False,
            None,
        )
    elif HAS_FLASH_ATTN:
        if window_size_left != -1:
            raise NotImplementedError(
                "window_size_left is only available with flash attn v2"
            )

        # Flash attention v1 requires q, k and v to have the same number of heads
        if k.shape[1] != q.shape[1]:
            # MQA expand
            if k.shape[1] == 1:
                k = k.expand(-1, q.shape[1], -1)
            # Grouped attention reshape
            else:
                original_shape = k.shape
                k = (
                    k.unsqueeze(2)
                    .expand(-1, -1, q.shape[1] // k.shape[1], -1)
                    .reshape(original_shape[0], -1, original_shape[2])
                )
        if v.shape[1] != q.shape[1]:
            # MQA expand
            if v.shape[1] == 1:
                v = v.expand(-1, q.shape[1], -1)
            # Grouped attention reshape
            else:
                original_shape = v.shape
                v = (
                    v.unsqueeze(2)
                    .expand(-1, -1, q.shape[1] // v.shape[1], -1)
                    .reshape(original_shape[0], -1, original_shape[2])
                )

        return flash_attn_cuda.fwd(
            q,
            k,
            v,
            out,
            cu_seqlens,
            cu_seqlens,
            max_s,
            max_s,
            0.0,
            softmax_scale,
            False,
            True,
            False,
            0,
            None,
        )

    raise NotImplementedError("flash attention is not installed")
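

if __name__ == "__main__":
    # Minimal usage sketch: assumes a CUDA device with flash-attn v2 installed;
    # shapes, dtype and sequence lengths below are purely illustrative.
    total_tokens, num_heads, head_size = 8, 16, 64
    q = torch.randn(
        total_tokens, num_heads, head_size, dtype=torch.float16, device="cuda"
    )
    k = torch.randn_like(q)
    v = torch.randn_like(q)
    out = torch.empty_like(q)
    # Two packed sequences of lengths 3 and 5 -> cumulative offsets [0, 3, 8].
    cu_seqlens = torch.tensor([0, 3, 8], dtype=torch.int32, device="cuda")
    attention(q, k, v, out, cu_seqlens, max_s=5, softmax_scale=head_size**-0.5)
    print(out.shape)  # torch.Size([8, 16, 64])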