# __init__.py
1
import os
2
3
import torch

4
from loguru import logger
5
from transformers.configuration_utils import PretrainedConfig
6
from transformers.models.auto import modeling_auto
7
8
from typing import Optional

9
10
from text_generation_server.models.model import Model
from text_generation_server.models.causal_lm import CausalLM
11
from text_generation_server.models.flash_causal_lm import FlashCausalLM
12
from text_generation_server.models.bloom import BLOOMSharded
13
from text_generation_server.models.mpt import MPTSharded
14
from text_generation_server.models.seq2seq_lm import Seq2SeqLM
15
from text_generation_server.models.rw import RW
16
17
from text_generation_server.models.opt import OPTSharded
from text_generation_server.models.galactica import GalacticaSharded
18
19
from text_generation_server.models.santacoder import SantaCoder
from text_generation_server.models.t5 import T5Sharded
20
from text_generation_server.models.gpt_neox import GPTNeoxSharded
21

# --- Global torch configuration applied at import time ---------------------
#
# TensorFloat-32 (TF32): PyTorch 1.12+ turns TF32 off by default for matmuls,
# whereas cuDNN keeps it on by default. Both are enabled explicitly so the
# setting does not depend on the installed PyTorch version's defaults.
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True

# Turn off autograd gradient tracking (note: PyTorch grad mode is per-thread).
torch.set_grad_enabled(False)

# Public names re-exported by this package. Optional model classes are
# appended further below, only when their imports succeed.
__all__ = [
    "Model",
    "BLOOMSharded",
    "CausalLM",
    "FlashCausalLM",
    "GalacticaSharded",
    "Seq2SeqLM",
    "SantaCoder",
    "OPTSharded",
    "T5Sharded",
    "get_model",
]

45
FLASH_ATT_ERROR_MESSAGE = "{} requires Flash Attention enabled models."
46

47
FLASH_ATTENTION = True
48
try:
49
50
51
52
53
54
55
    from text_generation_server.models.flash_rw import FlashRWSharded
    from text_generation_server.models.flash_neox import FlashNeoXSharded
    from text_generation_server.models.flash_llama import (
        FlashLlama,
    )
    from text_generation_server.models.flash_santacoder import (
        FlashSantacoderSharded,
56
    )
57
    from text_generation_server.models.idefics import IDEFICSSharded
58
59
60

except ImportError as e:
    logger.warning(f"Could not import Flash Attention enabled models: {e}")
61
    FLASH_ATTENTION = False
62

63
if FLASH_ATTENTION:
64
    __all__.append(FlashNeoXSharded)
65
    __all__.append(FlashRWSharded)
66
    __all__.append(FlashSantacoderSharded)
67
    __all__.append(FlashLlama)
68
    __all__.append(IDEFICSSharded)
69

70
71
72
73
74
75
76
77
78
79
# Mistral support is optional and tracked separately from FLASH_ATTENTION.
# When its import fails, get_model raises NotImplementedError for
# model_type == "mistral".
MISTRAL = True
try:
    from text_generation_server.models.flash_mistral import FlashMistral
except ImportError as e:
    logger.warning(f"Could not import Mistral model: {e}")
    MISTRAL = False

if MISTRAL:
    # Export the name, not the class object: __all__ entries must be str.
    __all__.append("FlashMistral")


def _get_santacoder(
    model_id: str,
    revision: Optional[str],
    sharded: bool,
    model_kwargs: dict,
) -> Model:
    """Build a Santacoder-family (``gpt_bigcode``) model.

    Prefers the flash-attention implementation. Without it, only the
    unsharded ``SantaCoder`` fallback is available.

    Raises:
        NotImplementedError: if ``sharded`` is requested but flash attention
            is unavailable.
    """
    if FLASH_ATTENTION:
        return FlashSantacoderSharded(model_id, revision, **model_kwargs)
    if sharded:
        raise NotImplementedError(
            FLASH_ATT_ERROR_MESSAGE.format("Sharded Santacoder")
        )
    return SantaCoder(model_id, revision, **model_kwargs)


def get_model(
    model_id: str,
    revision: Optional[str],
    sharded: bool,
    quantize: Optional[str],
    dtype: Optional[str],
    trust_remote_code: bool,
) -> Model:
    """Pick and instantiate the serving implementation for ``model_id``.

    Dispatch runs in three stages:
      1. Well-known model-id patterns (``facebook/galactica``, ``bigcode/``).
      2. The ``model_type`` field of the model's config.
      3. The generic transformers-backed ``CausalLM`` / ``Seq2SeqLM``
         wrappers ("AutoModel" fallback).

    Args:
        model_id: Model identifier, forwarded to the model classes and to
            ``PretrainedConfig.get_config_dict``.
        revision: Optional revision forwarded alongside ``model_id``.
        sharded: Whether a sharded deployment was requested.
        quantize: Quantization method name (e.g. "gptq", "awq",
            "bitsandbytes-fp4", "bitsandbytes-nf4", "eetq") or None.
        dtype: "float16", "bfloat16", or None (treated as float16).
        trust_remote_code: Forwarded to config loading and to every model
            class. It also gates the ``auto_map`` fallback at the end.

    Returns:
        The instantiated model (annotated as ``Model``).

    Raises:
        RuntimeError: for an unknown ``dtype``.
        KeyError: if the loaded config has no ``model_type`` entry.
        NotImplementedError: when the requested combination needs flash
            attention / Mistral support that is unavailable, or is
            unsupported for that architecture (e.g. sharded ALiBi Falcon).
        ValueError: when the AutoModel fallback cannot honour ``sharded`` or
            ``quantize``, or the model type is not supported at all.
    """
    torch_dtypes = {
        None: torch.float16,
        "float16": torch.float16,
        "bfloat16": torch.bfloat16,
    }
    if dtype not in torch_dtypes:
        raise RuntimeError(f"Unknown dtype {dtype}")

    # Keyword arguments shared by every model constructor below.
    model_kwargs = {
        "quantize": quantize,
        "dtype": torch_dtypes[dtype],
        "trust_remote_code": trust_remote_code,
    }

    # Stage 1: model ids recognised by name alone, before any config fetch.
    if "facebook/galactica" in model_id:
        return GalacticaSharded(model_id, revision, **model_kwargs)

    if model_id.startswith("bigcode/"):
        return _get_santacoder(model_id, revision, sharded, model_kwargs)

    # Stage 2: dispatch on the config's declared architecture.
    config_dict, _ = PretrainedConfig.get_config_dict(
        model_id, revision=revision, trust_remote_code=trust_remote_code
    )
    model_type = config_dict["model_type"]

    if model_type == "gpt_bigcode":
        return _get_santacoder(model_id, revision, sharded, model_kwargs)

    if model_type == "bloom":
        return BLOOMSharded(model_id, revision, **model_kwargs)

    if model_type == "mpt":
        return MPTSharded(model_id, revision, **model_kwargs)

    if model_type == "gpt_neox":
        if FLASH_ATTENTION:
            return FlashNeoXSharded(model_id, revision, **model_kwargs)
        if sharded:
            return GPTNeoxSharded(model_id, revision, **model_kwargs)
        return CausalLM(model_id, revision, **model_kwargs)

    if model_type in ("llama", "baichuan"):
        if FLASH_ATTENTION:
            return FlashLlama(model_id, revision, **model_kwargs)
        if sharded:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Llama"))
        return CausalLM(model_id, revision, **model_kwargs)

    if model_type in ("RefinedWeb", "RefinedWebModel", "falcon"):
        # The flash implementation is never chosen for ALiBi variants.
        uses_alibi = config_dict.get("alibi", False)
        if sharded:
            if not FLASH_ATTENTION:
                raise NotImplementedError(
                    FLASH_ATT_ERROR_MESSAGE.format("Sharded Falcon")
                )
            if uses_alibi:
                raise NotImplementedError("sharded is not supported for this model")
            return FlashRWSharded(model_id, revision, **model_kwargs)
        if FLASH_ATTENTION and not uses_alibi:
            return FlashRWSharded(model_id, revision, **model_kwargs)
        return RW(model_id, revision, **model_kwargs)

    if model_type == "mistral":
        if MISTRAL:
            return FlashMistral(model_id, revision, **model_kwargs)
        raise NotImplementedError("Mistral model requires flash attention v2")

    if model_type == "opt":
        return OPTSharded(model_id, revision, **model_kwargs)

    if model_type == "t5":
        return T5Sharded(model_id, revision, **model_kwargs)

    if model_type == "idefics":
        if FLASH_ATTENTION:
            return IDEFICSSharded(model_id, revision, **model_kwargs)
        raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Idefics"))

    # Stage 3: generic transformers fallback ("AutoModel"). It supports no
    # sharding and only a subset of the quantization methods.
    if sharded:
        raise ValueError("sharded is not supported for AutoModel")
    if quantize == "gptq":
        raise ValueError(
            "gptq quantization is not supported for AutoModel, you can try to quantize it with `text-generation-server quantize ORIGINAL_MODEL_ID NEW_MODEL_ID`"
        )
    if quantize == "awq":
        raise ValueError("awq quantization is not supported for AutoModel")
    if quantize in ("bitsandbytes-fp4", "bitsandbytes-nf4"):
        raise ValueError("4bit quantization is not supported for AutoModel")
    if quantize == "eetq":
        raise ValueError("Eetq quantization is not supported for AutoModel")

    if model_type in modeling_auto.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES:
        return CausalLM(model_id, revision, **model_kwargs)
    if model_type in modeling_auto.MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES:
        return Seq2SeqLM(model_id, revision, **model_kwargs)

    # Remote-code models that declare their auto classes in the config's
    # "auto_map"; only honoured when remote code is trusted.
    auto_map = config_dict.get("auto_map", None)
    if trust_remote_code and auto_map is not None:
        if "AutoModelForCausalLM" in auto_map:
            return CausalLM(model_id, revision, **model_kwargs)
        if "AutoModelForSeq2SeqLM" in auto_map:
            return Seq2SeqLM(model_id, revision, **model_kwargs)

    raise ValueError(f"Unsupported model type {model_type}")