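"""Model factory for text_generation_server.

Maps a model id / config ``model_type`` to the concrete ``Model`` implementation,
preferring the Flash Attention variants when the CUDA kernels are available.
"""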
import os

import torch

from loguru import logger
from transformers.configuration_utils import PretrainedConfig
from transformers.models.auto import modeling_auto

from typing import Optional

from text_generation_server.models.model import Model
from text_generation_server.models.causal_lm import CausalLM
from text_generation_server.models.flash_causal_lm import FlashCausalLM
from text_generation_server.models.bloom import BLOOMSharded
from text_generation_server.models.seq2seq_lm import Seq2SeqLM
from text_generation_server.models.rw import RW
from text_generation_server.models.opt import OPTSharded
from text_generation_server.models.galactica import GalacticaSharded
from text_generation_server.models.santacoder import SantaCoder
from text_generation_server.models.t5 import T5Sharded
from text_generation_server.models.gpt_neox import GPTNeoxSharded

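# Probe whether the Flash Attention model classes can be used: CUDA must be
# available, USE_FLASH_ATTENTION must not be set to "false", and the GPU must be
# SM 7.5, 8.x or 9.0. Any failure (unsupported GPU or missing kernels) falls back
# to the non-flash implementations.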
try:
    if (
        torch.cuda.is_available()
        and not os.getenv("USE_FLASH_ATTENTION", "").lower() == "false"
    ):
        major, minor = torch.cuda.get_device_capability()
        is_sm75 = major == 7 and minor == 5
        is_sm8x = major == 8 and minor >= 0
        is_sm90 = major == 9 and minor == 0

        supported = is_sm75 or is_sm8x or is_sm90
        if not supported:
            raise ImportError(
                f"GPU with CUDA capability {major} {minor} is not supported"
            )

        from text_generation_server.models.flash_rw import FlashRWSharded
        from text_generation_server.models.flash_neox import FlashNeoXSharded
        from text_generation_server.models.flash_llama import (
            FlashLlama,
        )
        from text_generation_server.models.flash_santacoder import (
            FlashSantacoderSharded,
        )

        FLASH_ATTENTION = True
    else:
        FLASH_ATTENTION = False
except ImportError:
    logger.opt(exception=True).warning(
        "Could not import Flash Attention enabled models"
    )
    FLASH_ATTENTION = False

__all__ = [
    "Model",
    "BLOOMSharded",
    "CausalLM",
    "FlashCausalLM",
    "GalacticaSharded",
    "Seq2SeqLM",
    "SantaCoder",
    "OPTSharded",
    "T5Sharded",
    "get_model",
]

if FLASH_ATTENTION:
    # __all__ entries must be strings, not the classes themselves.
    __all__.append("FlashNeoXSharded")
    __all__.append("FlashRWSharded")
    __all__.append("FlashSantacoderSharded")
    __all__.append("FlashLlama")

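# Error message raised when a flash-only sharded configuration is requested but the
# Flash Attention kernels could not be imported.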
FLASH_ATT_ERROR_MESSAGE = (
    "{} requires Flash Attention CUDA kernels to be installed.\n"
    "Use the official Docker image (ghcr.io/huggingface/text-generation-inference:latest) "
    "or install flash attention with `cd server && make install install-flash-attention`"
)

# The flag below controls whether to allow TF32 on matmul. This flag defaults to False
# in PyTorch 1.12 and later.
torch.backends.cuda.matmul.allow_tf32 = True

# The flag below controls whether to allow TF32 on cuDNN. This flag defaults to True.
torch.backends.cudnn.allow_tf32 = True

# Disable gradients
torch.set_grad_enabled(False)

def get_model(
    model_id: str,
    revision: Optional[str],
    sharded: bool,
    quantize: Optional[str],
    trust_remote_code: bool,
) -> Model:
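    """Instantiate the concrete ``Model`` implementation for ``model_id``.

    Dispatch looks at the model id prefix first (facebook/galactica, bigcode/),
    then at the ``model_type`` field of the Hugging Face config, and finally
    falls back to the generic AutoModel code paths.
    """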
    if "facebook/galactica" in model_id:
        return GalacticaSharded(
            model_id, revision, quantize=quantize, trust_remote_code=trust_remote_code
        )

    if model_id.startswith("bigcode/"):
        if FLASH_ATTENTION:
            return FlashSantacoderSharded(
                model_id,
                revision,
                quantize=quantize,
                trust_remote_code=trust_remote_code,
            )
        elif sharded:
            raise NotImplementedError(
                FLASH_ATT_ERROR_MESSAGE.format("Sharded Santacoder")
            )
        else:
            return SantaCoder(
                model_id,
                revision,
                quantize=quantize,
                trust_remote_code=trust_remote_code,
            )

    config_dict, _ = PretrainedConfig.get_config_dict(
        model_id, revision=revision, trust_remote_code=trust_remote_code
    )
    model_type = config_dict["model_type"]

    if model_type == "gpt_bigcode":
        if FLASH_ATTENTION:
            return FlashSantacoderSharded(
                model_id,
                revision,
                quantize=quantize,
                trust_remote_code=trust_remote_code,
            )
        elif sharded:
            raise NotImplementedError(
                FLASH_ATT_ERROR_MESSAGE.format("Sharded Santacoder")
            )
        else:
            return SantaCoder(
                model_id,
                revision,
                quantize=quantize,
                trust_remote_code=trust_remote_code,
            )

    if model_type == "bloom":
        return BLOOMSharded(
            model_id, revision, quantize=quantize, trust_remote_code=trust_remote_code
        )

    elif model_type == "gpt_neox":
        if FLASH_ATTENTION:
            return FlashNeoXSharded(
                model_id,
                revision,
                quantize=quantize,
                trust_remote_code=trust_remote_code,
            )
        elif sharded:
            return GPTNeoxSharded(
                model_id,
                revision,
                quantize=quantize,
                trust_remote_code=trust_remote_code,
            )
        else:
            return CausalLM(
                model_id,
                revision,
                quantize=quantize,
                trust_remote_code=trust_remote_code,
            )

    elif model_type == "llama":
        if FLASH_ATTENTION:
            return FlashLlama(
                model_id,
                revision,
                quantize=quantize,
                trust_remote_code=trust_remote_code,
            )
        elif sharded:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Llama"))
        else:
            return CausalLM(
                model_id,
                revision,
                quantize=quantize,
                trust_remote_code=trust_remote_code,
            )

    if model_type in ["RefinedWeb", "RefinedWebModel"]:
        if sharded:
            if FLASH_ATTENTION:
                if config_dict.get("alibi", False) or (
                    model_type == "RefinedWebModel"
                    and config_dict.get("multi_query", True)
                ):
                    raise NotImplementedError("sharded is not supported for this model")
                return FlashRWSharded(
                    model_id,
                    revision,
                    quantize=quantize,
                    trust_remote_code=trust_remote_code,
                )
            raise NotImplementedError(
                FLASH_ATT_ERROR_MESSAGE.format("Sharded RefinedWeb")
            )
        else:
            if FLASH_ATTENTION and not config_dict.get("alibi", False):
                return FlashRWSharded(
                    model_id,
                    revision,
                    quantize=quantize,
                    trust_remote_code=trust_remote_code,
                )
            else:
                return RW(
                    model_id,
                    revision,
                    quantize=quantize,
                    trust_remote_code=trust_remote_code,
                )

    elif model_type == "opt":
        return OPTSharded(
            model_id, revision, quantize=quantize, trust_remote_code=trust_remote_code
        )

    elif model_type == "t5":
        if sharded:
            return T5Sharded(
                model_id,
                revision,
                quantize=quantize,
                trust_remote_code=trust_remote_code,
            )
        else:
            return Seq2SeqLM(
                model_id,
                revision,
                quantize=quantize,
                trust_remote_code=trust_remote_code,
            )

    if sharded:
        raise ValueError("sharded is not supported for AutoModel")

    if model_type in modeling_auto.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES:
        return CausalLM(
            model_id, revision, quantize=quantize, trust_remote_code=trust_remote_code
        )
    if model_type in modeling_auto.MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES:
        return Seq2SeqLM(
            model_id, revision, quantize=quantize, trust_remote_code=trust_remote_code
        )

    auto_map = config_dict.get("auto_map", None)
    if trust_remote_code and auto_map is not None:
        if "AutoModelForCausalLM" in auto_map.keys():
            return CausalLM(
                model_id,
                revision,
                quantize=quantize,
                trust_remote_code=trust_remote_code,
            )
        if "AutoModelForSeq2SeqLM" in auto_map.keys():
            return Seq2SeqLM(
                model_id,
                revision,
                quantize=quantize,
                trust_remote_code=trust_remote_code,
            )

    raise ValueError(f"Unsupported model type {model_type}")
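

# A minimal usage sketch (illustrative only; the model id below is an assumption,
# and `get_model` is normally invoked by the serving layer rather than called
# directly):
#
#     model = get_model(
#         "bigscience/bloom-560m",
#         revision=None,
#         sharded=False,
#         quantize=None,
#         trust_remote_code=False,
#     )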