import os

import torch

from loguru import logger
from transformers.configuration_utils import PretrainedConfig
from transformers.models.auto import modeling_auto

from typing import Optional

from text_generation_server.models.model import Model
from text_generation_server.models.causal_lm import CausalLM
from text_generation_server.models.flash_causal_lm import FlashCausalLM
from text_generation_server.models.bloom import BLOOMSharded
from text_generation_server.models.mpt import MPTSharded
from text_generation_server.models.seq2seq_lm import Seq2SeqLM
from text_generation_server.models.rw import RW
from text_generation_server.models.opt import OPTSharded
from text_generation_server.models.galactica import GalacticaSharded
from text_generation_server.models.santacoder import SantaCoder
from text_generation_server.models.t5 import T5Sharded
from text_generation_server.models.gpt_neox import GPTNeoxSharded

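# Inference-only process-wide settings: TF32 matmuls are allowed for extra
# throughput on Ampere and newer GPUs, and autograd is disabled since the
# server never computes gradients.
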
# The flag below controls whether to allow TF32 on matmul. This flag defaults to False
# in PyTorch 1.12 and later.
torch.backends.cuda.matmul.allow_tf32 = True

# The flag below controls whether to allow TF32 on cuDNN. This flag defaults to True.
torch.backends.cudnn.allow_tf32 = True

# Disable gradients
torch.set_grad_enabled(False)

__all__ = [
    "Model",
    "BLOOMSharded",
    "CausalLM",
    "FlashCausalLM",
    "GalacticaSharded",
    "Seq2SeqLM",
    "SantaCoder",
    "OPTSharded",
    "T5Sharded",
    "get_model",
]

FLASH_ATT_ERROR_MESSAGE = "{} requires Flash Attention enabled models."
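
# The Flash Attention model classes are optional: if the custom kernels are not
# installed, the imports below fail and the server falls back to the standard
# (non-flash) implementations.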
FLASH_ATTENTION = True
try:
    from text_generation_server.models.flash_rw import FlashRWSharded
    from text_generation_server.models.flash_neox import FlashNeoXSharded
    from text_generation_server.models.flash_llama import (
        FlashLlama,
    )
    from text_generation_server.models.flash_santacoder import (
        FlashSantacoderSharded,
    )
    from text_generation_server.models.idefics import IDEFICSSharded

except ImportError as e:
    logger.warning(f"Could not import Flash Attention enabled models: {e}")
    FLASH_ATTENTION = False

if FLASH_ATTENTION:
    __all__.append("FlashNeoXSharded")
    __all__.append("FlashRWSharded")
    __all__.append("FlashSantacoderSharded")
    __all__.append("FlashLlama")
    __all__.append("IDEFICSSharded")

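# Mistral is gated separately from the other flash models because it needs the
# Flash Attention v2 kernels (see the error raised for it in `get_model` below).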
MISTRAL = True
try:
    from text_generation_server.models.flash_mistral import FlashMistral
except ImportError as e:
    logger.warning(f"Could not import Mistral model: {e}")
    MISTRAL = False

if MISTRAL:
    __all__.append("FlashMistral")


def get_model(
    model_id: str,
    revision: Optional[str],
    sharded: bool,
    quantize: Optional[str],
    dtype: Optional[str],
    trust_remote_code: bool,
) -> Model:
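    """Instantiate the most appropriate `Model` implementation for `model_id`.

    Dispatch first checks a few special-cased repository prefixes, then the
    `model_type` declared in the checkpoint config, and finally falls back to
    the generic transformers auto classes (`CausalLM` / `Seq2SeqLM`).

    Illustrative call (any supported Hub id works; this one is only an example):

        model = get_model(
            "bigscience/bloom-560m",
            revision=None,
            sharded=False,
            quantize=None,
            dtype="float16",
            trust_remote_code=False,
        )
    """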
    if dtype is None:
        # Keep it as default for now and let
        # every model resolve their own default dtype.
        dtype = None
    elif dtype == "float16":
        dtype = torch.float16
    elif dtype == "bfloat16":
        dtype = torch.bfloat16
    else:
        raise RuntimeError(f"Unknown dtype {dtype}")

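    # Galactica checkpoints are matched by repository name before the
    # config-based dispatch below so that they always use the dedicated
    # GalacticaSharded implementation.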
    if "facebook/galactica" in model_id:
        return GalacticaSharded(
            model_id,
            revision,
            quantize=quantize,
            dtype=dtype,
            trust_remote_code=trust_remote_code,
        )

    if model_id.startswith("bigcode/"):
        if FLASH_ATTENTION:
            return FlashSantacoderSharded(
                model_id,
                revision,
                quantize=quantize,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )
        elif sharded:
            raise NotImplementedError(
                FLASH_ATT_ERROR_MESSAGE.format("Sharded Santacoder")
            )
        else:
            return SantaCoder(
                model_id,
                revision,
                quantize=quantize,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )

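    # All remaining dispatch is based on the `model_type` declared in the
    # checkpoint's config.json.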
    config_dict, _ = PretrainedConfig.get_config_dict(
        model_id, revision=revision, trust_remote_code=trust_remote_code
    )
    model_type = config_dict["model_type"]

    if model_type == "gpt_bigcode":
        if FLASH_ATTENTION:
            return FlashSantacoderSharded(
                model_id,
                revision,
                quantize=quantize,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )
        elif sharded:
            raise NotImplementedError(
                FLASH_ATT_ERROR_MESSAGE.format("Sharded Santacoder")
            )
        else:
            return SantaCoder(
                model_id,
                revision,
                quantize=quantize,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )

    if model_type == "bloom":
        return BLOOMSharded(
            model_id,
            revision,
            quantize=quantize,
            dtype=dtype,
            trust_remote_code=trust_remote_code,
        )
    elif model_type == "mpt":
        return MPTSharded(
            model_id,
            revision,
            quantize=quantize,
            dtype=dtype,
            trust_remote_code=trust_remote_code,
        )
    elif model_type == "gpt_neox":
        if FLASH_ATTENTION:
            return FlashNeoXSharded(
                model_id,
                revision,
                quantize=quantize,
181
                dtype=dtype,
182
183
184
185
                trust_remote_code=trust_remote_code,
            )
        elif sharded:
            return GPTNeoxSharded(
186
187
188
                model_id,
                revision,
                quantize=quantize,
189
                dtype=dtype,
190
191
                trust_remote_code=trust_remote_code,
            )
192
        else:
193
            return CausalLM(
194
195
196
                model_id,
                revision,
                quantize=quantize,
197
                dtype=dtype,
198
199
                trust_remote_code=trust_remote_code,
            )
200

xiaobin's avatar
xiaobin committed
201
    elif model_type == "llama" or model_type == "baichuan":
202
203
        if FLASH_ATTENTION:
            return FlashLlama(
204
205
206
                model_id,
                revision,
                quantize=quantize,
207
                dtype=dtype,
208
209
                trust_remote_code=trust_remote_code,
            )
210
211
        elif sharded:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Llama"))
212
        else:
213
            return CausalLM(
214
215
216
                model_id,
                revision,
                quantize=quantize,
217
                dtype=dtype,
218
219
                trust_remote_code=trust_remote_code,
            )
220

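    # Falcon checkpoints have shipped with several `model_type` values over time;
    # "RefinedWeb", "RefinedWebModel" and "falcon" are all handled by this branch.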
    if model_type in ["RefinedWeb", "RefinedWebModel", "falcon"]:
        if sharded:
            if FLASH_ATTENTION:
                if config_dict.get("alibi", False):
                    raise NotImplementedError("sharded is not supported for this model")
                return FlashRWSharded(
                    model_id,
                    revision,
                    quantize=quantize,
                    dtype=dtype,
                    trust_remote_code=trust_remote_code,
                )
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Falcon"))
        else:
            if FLASH_ATTENTION and not config_dict.get("alibi", False):
                return FlashRWSharded(
                    model_id,
                    revision,
                    quantize=quantize,
                    dtype=dtype,
                    trust_remote_code=trust_remote_code,
                )
            else:
                return RW(
                    model_id,
                    revision,
                    quantize=quantize,
                    dtype=dtype,
                    trust_remote_code=trust_remote_code,
                )

    if model_type == "mistral":
        if MISTRAL:
            return FlashMistral(
                model_id,
                revision,
                quantize=quantize,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )
        raise NotImplementedError("Mistral model requires flash attention v2")

    if model_type == "opt":
        return OPTSharded(
            model_id,
            revision,
            quantize=quantize,
            dtype=dtype,
            trust_remote_code=trust_remote_code,
        )

    if model_type == "t5":
        return T5Sharded(
            model_id,
            revision,
            quantize=quantize,
            dtype=dtype,
            trust_remote_code=trust_remote_code,
        )
    if model_type == "idefics":
        if FLASH_ATTENTION:
            return IDEFICSSharded(
                model_id,
                revision,
                quantize=quantize,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )
        else:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Idefics"))

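    # No specialized implementation matched: fall back to the generic
    # transformers auto classes, which support neither sharding nor the
    # quantization modes rejected below.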
    if sharded:
        raise ValueError("sharded is not supported for AutoModel")
    if quantize == "gptq":
        raise ValueError(
            "gptq quantization is not supported for AutoModel, you can try to quantize it with `text-generation-server quantize ORIGINAL_MODEL_ID NEW_MODEL_ID`"
        )
    if quantize == "awq":
        raise ValueError("awq quantization is not supported for AutoModel")
    elif (quantize == "bitsandbytes-fp4") or (quantize == "bitsandbytes-nf4"):
        raise ValueError("4bit quantization is not supported for AutoModel")
    elif quantize == "eetq":
        raise ValueError("Eetq quantization is not supported for AutoModel")
    if model_type in modeling_auto.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES:
        return CausalLM(
            model_id,
            revision,
            quantize=quantize,
            dtype=dtype,
            trust_remote_code=trust_remote_code,
        )
    if model_type in modeling_auto.MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES:
        return Seq2SeqLM(
            model_id,
            revision,
            quantize=quantize,
            dtype=dtype,
            trust_remote_code=trust_remote_code,
        )

    auto_map = config_dict.get("auto_map", None)
    if trust_remote_code and auto_map is not None:
        if "AutoModelForCausalLM" in auto_map.keys():
            return CausalLM(
                model_id,
                revision,
                quantize=quantize,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )
        if "AutoModelForSeq2SeqLM" in auto_map.keys():
            return Seq2SeqLM(
                model_id,
                revision,
                quantize=quantize,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )

    raise ValueError(f"Unsupported model type {model_type}")