import os
import torch

from loguru import logger

from text_generation_server.utils.import_utils import IS_CUDA_SYSTEM, IS_ROCM_SYSTEM

if os.getenv("USE_FLASH_ATTENTION", "").lower() == "false":
    raise ImportError("`USE_FLASH_ATTENTION` is false.")

if not torch.cuda.is_available():
    raise ImportError("CUDA is not available")

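# Query the compute capability of the active GPU: flash attention v1 runs on
# sm75, sm8x and sm90, while v2 requires sm8x or sm90 (checked below).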
major, minor = torch.cuda.get_device_capability()
is_sm75 = major == 7 and minor == 5
is_sm8x = major == 8 and minor >= 0
is_sm90 = major == 9 and minor == 0

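# Availability flags, filled in by the import probing below.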
HAS_FLASH_ATTN = False
HAS_FLASH_ATTN_V2_CUDA = False
HAS_FLASH_ATTN_V2_ROCM = False
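# Probe for flash attention v2 first; any ImportError (missing package or
# unsupported GPU) falls through to the v1 fallback below.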
try:
    try:
        import flash_attn_2_cuda
    except ImportError:
        raise ImportError(
            "Flash Attention V2 is not installed.\n"
            "Use the official Docker image (ghcr.io/huggingface/text-generation-inference:latest) "
            "or install flash attention v2 with `cd server && make install install-flash-attention-v2`"
        )
    if not (is_sm8x or is_sm90):
        raise ImportError(
            f"GPU with CUDA capability {major} {minor} is not supported for "
            "Flash Attention V2"
        )
    HAS_FLASH_ATTN_V2_CUDA = IS_CUDA_SYSTEM
    HAS_FLASH_ATTN_V2_ROCM = IS_ROCM_SYSTEM
except ImportError as e:
    try:
        import flash_attn_cuda
    except ImportError:
        raise ImportError(
            "Flash Attention is not installed.\n"
            "Use the official Docker image (ghcr.io/huggingface/text-generation-inference:latest) "
            "or install flash attention with `cd server && make install install-flash-attention`"
        ) from e

    if IS_CUDA_SYSTEM and not (is_sm75 or is_sm8x or is_sm90):
        raise ImportError(
            f"GPU with CUDA capability {major} {minor} is not supported"
        ) from e
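    # On ROCm, only MI210 / MI250 GPUs are supported by the flash attention kernels.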
    elif IS_ROCM_SYSTEM:
        for idx in range(torch.cuda.device_count()):
            if "MI210" not in torch.cuda.get_device_name(
                idx
            ) and "MI250" not in torch.cuda.get_device_name(idx):
                raise ImportError(
                    f"AMD GPU {torch.cuda.get_device_name(idx)} does not support flash-attention"
                )

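    # v2 could not be used but v1 imported successfully: warn and fall back to v1.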
    logger.warning(f"Unable to use Flash Attention V2: {e}")
    HAS_FLASH_ATTN = True


def attention(
    q,
    k,
    v,
    out,
    cu_seqlens,
    max_s,
    softmax_scale,
    window_size_left=-1,
):
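    # q, k and v are packed "varlen" tensors of shape (total_tokens, num_heads, head_dim);
    # cu_seqlens holds the cumulative sequence lengths of the batch and max_s the longest
    # sequence. window_size_left enables sliding-window attention (CUDA v2 kernel only).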
    if HAS_FLASH_ATTN_V2_CUDA:
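        # Self-attention over the packed batch: the same cu_seqlens / max_s are passed
        # for q and k. The remaining arguments request no dropout, causal masking and a
        # left sliding window of window_size_left tokens.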
        return flash_attn_2_cuda.varlen_fwd(
            q,
            k,
            v,
            out,
            cu_seqlens,
            cu_seqlens,
            max_s,
            max_s,
            0.0,
            softmax_scale,
            False,
            True,
            window_size_left,
            0,
            False,
            None,
        )
    elif HAS_FLASH_ATTN_V2_ROCM:
        if window_size_left != -1:
            raise ValueError(
                f"RoCm version of Flash Attention v2 does not support window attention (window_size_left != -1, got window_size_left={window_size_left})."
            )

        # The ROCm flash attention API does not take the window_size_left and window_size_right arguments.
        return flash_attn_2_cuda.varlen_fwd(
            q,
            k,
            v,
            out,
            cu_seqlens,
            cu_seqlens,
            max_s,
            max_s,
            0.0,
            softmax_scale,
            False,
            True,
            False,
            None,
        )
    elif HAS_FLASH_ATTN:
        if window_size_left != -1:
            raise NotImplementedError(
                "window_size_left is only available with flash attn v2"
            )

        # Flash attention v1 requires q, k and v to have the same number of heads
        if k.shape[1] != q.shape[1]:
            # MQA expand
            if k.shape[1] == 1:
                k = k.expand(-1, q.shape[1], -1)
            # Grouped attention reshape
            else:
                original_shape = k.shape
                k = (
                    k.unsqueeze(2)
                    .expand(-1, -1, q.shape[1] // k.shape[1], -1)
                    .reshape(original_shape[0], -1, original_shape[2])
                )
        if v.shape[1] != q.shape[1]:
            # MQA expand
            if v.shape[1] == 1:
                v = v.expand(-1, q.shape[1], -1)
            # Grouped attention reshape
            else:
                original_shape = v.shape
                v = (
                    v.unsqueeze(2)
                    .expand(-1, -1, q.shape[1] // v.shape[1], -1)
                    .reshape(original_shape[0], -1, original_shape[2])
                )

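        # Flash attention v1 forward over the packed batch: no dropout, causal masking.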
        return flash_attn_cuda.fwd(
            q,
            k,
            v,
            out,
            cu_seqlens,
            cu_seqlens,
            max_s,
            max_s,
            0.0,
            softmax_scale,
            False,
            True,
            False,
            0,
            None,
        )

    raise NotImplementedError("flash attention is not installed")