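"""text_generation_server.models: model selection for the server.

`get_model` inspects a checkpoint's Hugging Face config and returns the most
capable implementation available: sharded Flash Attention variants when the
custom kernels import successfully, otherwise the generic CausalLM and
Seq2SeqLM fallbacks.
"""
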
import os

import torch

from loguru import logger
from transformers.configuration_utils import PretrainedConfig
from transformers.models.auto import modeling_auto

from typing import Optional

from text_generation_server.utils.speculate import get_speculate, set_speculate

from text_generation_server.models.model import Model
from text_generation_server.models.causal_lm import CausalLM
from text_generation_server.models.flash_causal_lm import FlashCausalLM
from text_generation_server.models.bloom import BLOOMSharded
from text_generation_server.models.mpt import MPTSharded
from text_generation_server.models.seq2seq_lm import Seq2SeqLM
from text_generation_server.models.rw import RW
from text_generation_server.models.opt import OPTSharded
from text_generation_server.models.galactica import GalacticaSharded
from text_generation_server.models.santacoder import SantaCoder
from text_generation_server.models.t5 import T5Sharded
from text_generation_server.models.gpt_neox import GPTNeoxSharded

# The flag below controls whether to allow TF32 on matmul. This flag defaults to False
# in PyTorch 1.12 and later.
torch.backends.cuda.matmul.allow_tf32 = True

# The flag below controls whether to allow TF32 on cuDNN. This flag defaults to True.
torch.backends.cudnn.allow_tf32 = True

# Disable gradients
torch.set_grad_enabled(False)

__all__ = [
    "Model",
    "BLOOMSharded",
    "CausalLM",
    "FlashCausalLM",
    "GalacticaSharded",
    "Seq2SeqLM",
    "SantaCoder",
    "OPTSharded",
    "T5Sharded",
    "get_model",
]
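
# Model classes that depend on optional kernels (Flash Attention, Mistral)
# are appended to __all__ below only when their imports succeed.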

FLASH_ATT_ERROR_MESSAGE = "{} requires Flash Attention enabled models."

FLASH_ATTENTION = True
try:
    from text_generation_server.models.flash_rw import FlashRWSharded
    from text_generation_server.models.flash_neox import FlashNeoXSharded
    from text_generation_server.models.flash_llama import (
        FlashLlama,
    )
    from text_generation_server.models.flash_santacoder import (
        FlashSantacoderSharded,
    )
    from text_generation_server.models.idefics import IDEFICSSharded
except ImportError as e:
    logger.warning(f"Could not import Flash Attention enabled models: {e}")
    FLASH_ATTENTION = False

if FLASH_ATTENTION:
    __all__.append("FlashNeoXSharded")
    __all__.append("FlashRWSharded")
    __all__.append("FlashSantacoderSharded")
    __all__.append("FlashLlama")
    __all__.append("IDEFICSSharded")

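# FlashMistral has its own import guard: as the error raised in get_model
# notes, it requires the Flash Attention v2 kernels.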
MISTRAL = True
try:
    from text_generation_server.models.flash_mistral import FlashMistral
except ImportError as e:
    logger.warning(f"Could not import Mistral model: {e}")
    MISTRAL = False

if MISTRAL:
    __all__.append("FlashMistral")

def get_model(
    model_id: str,
    revision: Optional[str],
    sharded: bool,
    quantize: Optional[str],
    speculate: Optional[int],
    dtype: Optional[str],
    trust_remote_code: bool,
) -> Model:
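    """Instantiate the requested model with the best available implementation.

    Dispatch is based on the checkpoint's config (`model_type`, Medusa fields,
    alibi, ...), whether the Flash Attention kernels imported successfully,
    and the sharding / quantization / dtype / speculation options passed by
    the caller.

    Illustrative call (model id and options below are placeholders):

        model = get_model(
            "bigscience/bloom-560m",
            revision=None,
            sharded=False,
            quantize=None,
            speculate=None,
            dtype=None,
            trust_remote_code=False,
        )
    """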
    if dtype is None:
        # Keep it as default for now and let
        # every model resolve their own default dtype.
        dtype = None
    elif dtype == "float16":
        dtype = torch.float16
    elif dtype == "bfloat16":
        dtype = torch.bfloat16
    else:
        raise RuntimeError(f"Unknown dtype {dtype}")

    if speculate is not None:
        set_speculate(speculate)
    else:
        set_speculate(0)
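    # The value set here may be overridden below if the model config declares
    # Medusa speculation heads.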

    if "facebook/galactica" in model_id:
        return GalacticaSharded(
            model_id,
            revision,
            quantize=quantize,
            dtype=dtype,
            trust_remote_code=trust_remote_code,
        )

    if model_id.startswith("bigcode/"):
        if FLASH_ATTENTION:
            return FlashSantacoderSharded(
                model_id,
                revision,
                quantize=quantize,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )
        elif sharded:
            raise NotImplementedError(
                FLASH_ATT_ERROR_MESSAGE.format("Sharded Santacoder")
            )
        else:
            return SantaCoder(
                model_id,
                revision,
                quantize=quantize,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )

    config_dict, _ = PretrainedConfig.get_config_dict(
        model_id, revision=revision, trust_remote_code=trust_remote_code
    )

    use_medusa = None
    if "medusa_num_heads" in config_dict:
        # Medusa checkpoints reference their base model; load the base model
        # and keep the Medusa checkpoint id around for the speculation heads.
        use_medusa = model_id
        medusa_config = config_dict
        model_id = config_dict["base_model_name_or_path"]
        revision = "main"
        speculate_medusa = config_dict["medusa_num_heads"]
        if speculate is not None:
            if speculate > speculate_medusa:
                raise RuntimeError(
                    f"Speculate is set to `{speculate}` but this Medusa model only has `{speculate_medusa}` heads, please make them match"
                )
            else:
                set_speculate(speculate)
        else:
            set_speculate(speculate_medusa)

        config_dict, _ = PretrainedConfig.get_config_dict(
            model_id, revision=revision, trust_remote_code=trust_remote_code
        )
        method = "medusa"
    else:
        method = "n-gram"

    speculate = get_speculate()
    if speculate > 0:
        logger.info(f"Using speculation {method} with {speculate} input ids.")

    model_type = config_dict["model_type"]

    if model_type == "gpt_bigcode":
        if FLASH_ATTENTION:
            return FlashSantacoderSharded(
                model_id,
                revision,
                quantize=quantize,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )
        elif sharded:
            raise NotImplementedError(
                FLASH_ATT_ERROR_MESSAGE.format("Sharded Santacoder")
            )
        else:
            return SantaCoder(
                model_id,
                revision,
                quantize=quantize,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )

    if model_type == "bloom":
        return BLOOMSharded(
            model_id,
            revision,
            quantize=quantize,
            dtype=dtype,
            trust_remote_code=trust_remote_code,
        )
    elif model_type == "mpt":
        return MPTSharded(
            model_id,
            revision,
            quantize=quantize,
            dtype=dtype,
            trust_remote_code=trust_remote_code,
        )

    elif model_type == "gpt_neox":
        if FLASH_ATTENTION:
            return FlashNeoXSharded(
                model_id,
                revision,
                quantize=quantize,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )
        elif sharded:
            return GPTNeoxSharded(
                model_id,
                revision,
                quantize=quantize,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )
        else:
            return CausalLM(
                model_id,
                revision,
                quantize=quantize,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )

    elif model_type == "llama" or model_type == "baichuan":
        if FLASH_ATTENTION:
            return FlashLlama(
                model_id,
                revision,
                quantize=quantize,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
                use_medusa=use_medusa,
            )
        elif sharded:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Llama"))
        else:
            return CausalLM(
                model_id,
                revision,
                quantize=quantize,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )

    if model_type in ["RefinedWeb", "RefinedWebModel", "falcon"]:
        if sharded:
            if FLASH_ATTENTION:
                if config_dict.get("alibi", False):
                    raise NotImplementedError("sharded is not supported for this model")
                return FlashRWSharded(
                    model_id,
                    revision,
                    quantize=quantize,
                    dtype=dtype,
                    trust_remote_code=trust_remote_code,
                )
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Falcon"))
        else:
            if FLASH_ATTENTION and not config_dict.get("alibi", False):
                return FlashRWSharded(
                    model_id,
                    revision,
                    quantize=quantize,
                    dtype=dtype,
                    trust_remote_code=trust_remote_code,
                )
            else:
                return RW(
                    model_id,
                    revision,
                    quantize=quantize,
                    dtype=dtype,
                    trust_remote_code=trust_remote_code,
                )

    if model_type == "mistral":
        if MISTRAL:
            return FlashMistral(
                model_id,
                revision,
                quantize=quantize,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )
        raise NotImplementedError("Mistral model requires flash attention v2")

    if model_type == "opt":
        return OPTSharded(
            model_id,
            revision,
            quantize=quantize,
            dtype=dtype,
            trust_remote_code=trust_remote_code,
        )

    if model_type == "t5":
        return T5Sharded(
            model_id,
            revision,
            quantize=quantize,
            dtype=dtype,
            trust_remote_code=trust_remote_code,
        )
    if model_type == "idefics":
        if FLASH_ATTENTION:
            return IDEFICSSharded(
                model_id,
                revision,
                quantize=quantize,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )
        else:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Idefics"))

    # No dedicated implementation matched `model_type`: fall back to the
    # generic CausalLM / Seq2SeqLM paths (AutoModel), which support fewer
    # of the options handled above.
    if sharded:
        raise ValueError("sharded is not supported for AutoModel")
    if quantize == "gptq":
        raise ValueError(
            "gptq quantization is not supported for AutoModel, you can try to quantize it with `text-generation-server quantize ORIGINAL_MODEL_ID NEW_MODEL_ID`"
        )
    if quantize == "awq":
        raise ValueError("awq quantization is not supported for AutoModel")
    elif (quantize == "bitsandbytes-fp4") or (quantize == "bitsandbytes-nf4"):
        raise ValueError("4bit quantization is not supported for AutoModel")
    elif quantize == "eetq":
        raise ValueError("Eetq quantization is not supported for AutoModel")
    if model_type in modeling_auto.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES:
        return CausalLM(
            model_id,
            revision,
            quantize=quantize,
            dtype=dtype,
            trust_remote_code=trust_remote_code,
        )
    if model_type in modeling_auto.MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES:
        return Seq2SeqLM(
            model_id,
            revision,
            quantize=quantize,
            dtype=dtype,
            trust_remote_code=trust_remote_code,
        )

    auto_map = config_dict.get("auto_map", None)
    if trust_remote_code and auto_map is not None:
        if "AutoModelForCausalLM" in auto_map.keys():
            return CausalLM(
                model_id,
                revision,
                quantize=quantize,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )
        if "AutoModelForSeq2SeqLM" in auto_map.keys():
            return Seq2SeqLM(
                model_id,
                revision,
                quantize=quantize,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )

    raise ValueError(f"Unsupported model type {model_type}")