"""Model registry for the text generation server.

Exposes the supported model implementations and ``get_model``, which picks
and instantiates the right implementation for a given model id.
"""
import os
from typing import Optional

import torch
from loguru import logger
from transformers.configuration_utils import PretrainedConfig
from transformers.models.auto import modeling_auto

from text_generation_server.models.model import Model
from text_generation_server.models.causal_lm import CausalLM
from text_generation_server.models.flash_causal_lm import FlashCausalLM
from text_generation_server.models.bloom import BLOOMSharded
from text_generation_server.models.mpt import MPTSharded
from text_generation_server.models.seq2seq_lm import Seq2SeqLM
from text_generation_server.models.rw import RW
from text_generation_server.models.opt import OPTSharded
from text_generation_server.models.galactica import GalacticaSharded
from text_generation_server.models.santacoder import SantaCoder
from text_generation_server.models.t5 import T5Sharded
from text_generation_server.models.gpt_neox import GPTNeoxSharded

# The flag below controls whether to allow TF32 on matmul. This flag defaults to False
# in PyTorch 1.12 and later.
torch.backends.cuda.matmul.allow_tf32 = True

# The flag below controls whether to allow TF32 on cuDNN. This flag defaults to True.
torch.backends.cudnn.allow_tf32 = True

# Disable gradient tracking.
torch.set_grad_enabled(False)

# Public names for `from text_generation_server.models import *`. Flash
# Attention variants are appended further down only if they import cleanly.
__all__ = [
    "Model",
    "BLOOMSharded",
    "CausalLM",
    "FlashCausalLM",
    "GalacticaSharded",
    "Seq2SeqLM",
    "SantaCoder",
    "OPTSharded",
    "T5Sharded",
    "get_model",
]

FLASH_ATT_ERROR_MESSAGE = "{} requires Flash Attention enabled models."

# The Flash Attention model implementations are optional. If any of them fails
# to import, FLASH_ATTENTION is cleared and `get_model` falls back to the
# non-flash implementations (or raises for models that require flash).
FLASH_ATTENTION = True
try:
    from text_generation_server.models.flash_rw import FlashRWSharded
    from text_generation_server.models.flash_neox import FlashNeoXSharded
    from text_generation_server.models.flash_llama import FlashLlama
    from text_generation_server.models.flash_santacoder import (
        FlashSantacoderSharded,
    )
    from text_generation_server.models.idefics import IDEFICSSharded
except ImportError as e:
    logger.warning(f"Could not import Flash Attention enabled models: {e}")
    FLASH_ATTENTION = False

if FLASH_ATTENTION:
    # `__all__` must contain names (str), not the objects themselves:
    # `from ... import *` raises TypeError on non-string entries.
    __all__.extend(
        [
            "FlashNeoXSharded",
            "FlashRWSharded",
            "FlashSantacoderSharded",
            "FlashLlama",
            "IDEFICSSharded",
        ]
    )


def _get_santacoder(
    model_id: str,
    revision: Optional[str],
    sharded: bool,
    *,
    quantize: Optional[str],
    dtype: torch.dtype,
    trust_remote_code: bool,
) -> Model:
    """Build a SantaCoder / ``gpt_bigcode`` model.

    Uses ``FlashSantacoderSharded`` when Flash Attention models are
    available; otherwise falls back to the single-shard ``SantaCoder``.

    Raises:
        NotImplementedError: if ``sharded`` is requested without Flash
            Attention models available.
    """
    if FLASH_ATTENTION:
        return FlashSantacoderSharded(
            model_id,
            revision,
            quantize=quantize,
            dtype=dtype,
            trust_remote_code=trust_remote_code,
        )
    if sharded:
        raise NotImplementedError(
            FLASH_ATT_ERROR_MESSAGE.format("Sharded Santacoder")
        )
    return SantaCoder(
        model_id,
        revision,
        quantize=quantize,
        dtype=dtype,
        trust_remote_code=trust_remote_code,
    )


def get_model(
    model_id: str,
    revision: Optional[str],
    sharded: bool,
    quantize: Optional[str],
    dtype: Optional[str],
    trust_remote_code: bool,
) -> Model:
    """Select and instantiate the ``Model`` implementation for ``model_id``.

    Galactica checkpoints (``"facebook/galactica"`` in the id) and ``bigcode/``
    checkpoints are matched by name before any config is fetched. Otherwise
    the ``model_type`` of the checkpoint config picks a dedicated
    implementation, preferring Flash Attention variants when they imported
    successfully. Unknown types fall back to the transformers auto-model
    mappings, and finally to the config's ``auto_map`` when
    ``trust_remote_code`` is set.

    Args:
        model_id: Identifier passed to ``PretrainedConfig.get_config_dict``
            and to the model constructors.
        revision: Checkpoint revision, forwarded unchanged.
        sharded: Whether a sharded model is requested.
        quantize: Quantization mode, forwarded to the constructor. ``"gptq"``
            and ``"bitsandbytes-fp4"``/``"bitsandbytes-nf4"`` are rejected on
            the generic AutoModel fallback.
        dtype: ``None`` or ``"float16"`` -> ``torch.float16``;
            ``"bfloat16"`` -> ``torch.bfloat16``.
        trust_remote_code: Forwarded to config loading and model constructors.

    Returns:
        The instantiated model.

    Raises:
        RuntimeError: if ``dtype`` is not one of the accepted strings.
        KeyError: if the fetched config has no ``model_type`` entry.
        NotImplementedError: if the selected architecture needs Flash
            Attention (or cannot be sharded) and that is unavailable.
        ValueError: if the AutoModel fallback is reached with ``sharded`` or an
            unsupported ``quantize``, or the model type is unsupported.
    """
    if dtype is None or dtype == "float16":
        dtype = torch.float16
    elif dtype == "bfloat16":
        dtype = torch.bfloat16
    else:
        raise RuntimeError(f"Unknown dtype {dtype}")

    # Every model constructor below is called as Cls(model_id, revision, **model_kwargs).
    model_kwargs = {
        "quantize": quantize,
        "dtype": dtype,
        "trust_remote_code": trust_remote_code,
    }

    if "facebook/galactica" in model_id:
        return GalacticaSharded(model_id, revision, **model_kwargs)

    if model_id.startswith("bigcode/"):
        return _get_santacoder(model_id, revision, sharded, **model_kwargs)

    config_dict, _ = PretrainedConfig.get_config_dict(
        model_id, revision=revision, trust_remote_code=trust_remote_code
    )
    model_type = config_dict["model_type"]

    if model_type == "gpt_bigcode":
        return _get_santacoder(model_id, revision, sharded, **model_kwargs)

    if model_type == "bloom":
        return BLOOMSharded(model_id, revision, **model_kwargs)

    if model_type == "mpt":
        return MPTSharded(model_id, revision, **model_kwargs)

    if model_type == "gpt_neox":
        if FLASH_ATTENTION:
            return FlashNeoXSharded(model_id, revision, **model_kwargs)
        if sharded:
            return GPTNeoxSharded(model_id, revision, **model_kwargs)
        return CausalLM(model_id, revision, **model_kwargs)

    if model_type in ("llama", "baichuan"):
        if FLASH_ATTENTION:
            return FlashLlama(model_id, revision, **model_kwargs)
        if sharded:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Llama"))
        return CausalLM(model_id, revision, **model_kwargs)

    if model_type in ("RefinedWeb", "RefinedWebModel", "falcon"):
        # FlashRWSharded is never used for configs with a truthy "alibi" entry;
        # those only load unsharded through RW.
        uses_alibi = config_dict.get("alibi", False)
        if sharded:
            if not FLASH_ATTENTION:
                raise NotImplementedError(
                    FLASH_ATT_ERROR_MESSAGE.format("Sharded Falcon")
                )
            if uses_alibi:
                raise NotImplementedError("sharded is not supported for this model")
            return FlashRWSharded(model_id, revision, **model_kwargs)
        if FLASH_ATTENTION and not uses_alibi:
            return FlashRWSharded(model_id, revision, **model_kwargs)
        return RW(model_id, revision, **model_kwargs)

    if model_type == "opt":
        return OPTSharded(model_id, revision, **model_kwargs)

    if model_type == "t5":
        return T5Sharded(model_id, revision, **model_kwargs)

    if model_type == "idefics":
        if FLASH_ATTENTION:
            return IDEFICSSharded(model_id, revision, **model_kwargs)
        raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Idefics"))

    # Generic transformers AutoModel fallback: single shard, no GPTQ / 4-bit.
    if sharded:
        raise ValueError("sharded is not supported for AutoModel")
    if quantize == "gptq":
        raise ValueError(
            "gptq quantization is not supported for AutoModel, you can try to quantize it with `text-generation-server quantize ORIGINAL_MODEL_ID NEW_MODEL_ID`"
        )
    if quantize in ("bitsandbytes-fp4", "bitsandbytes-nf4"):
        raise ValueError("4bit quantization is not supported for AutoModel")

    if model_type in modeling_auto.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES:
        return CausalLM(model_id, revision, **model_kwargs)
    if model_type in modeling_auto.MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES:
        return Seq2SeqLM(model_id, revision, **model_kwargs)

    # Remote-code models advertise their auto classes in the config's auto_map.
    auto_map = config_dict.get("auto_map", None)
    if trust_remote_code and auto_map is not None:
        if "AutoModelForCausalLM" in auto_map:
            return CausalLM(model_id, revision, **model_kwargs)
        if "AutoModelForSeq2SeqLM" in auto_map:
            return Seq2SeqLM(model_id, revision, **model_kwargs)

    raise ValueError(f"Unsupported model type {model_type}")