__init__.py 12.4 KB
Newer Older
1
2
import torch

3
from loguru import logger
4
from transformers.configuration_utils import PretrainedConfig
5
from transformers.models.auto import modeling_auto
6
7
from typing import Optional

Nicolas Patry's avatar
Nicolas Patry committed
8
from text_generation_server.utils.speculate import get_speculate, set_speculate
9
10
from text_generation_server.models.model import Model
from text_generation_server.models.causal_lm import CausalLM
11
from text_generation_server.models.flash_causal_lm import FlashCausalLM
12
from text_generation_server.models.bloom import BLOOMSharded
13
from text_generation_server.models.mpt import MPTSharded
14
from text_generation_server.models.seq2seq_lm import Seq2SeqLM
15
from text_generation_server.models.rw import RW
16
17
from text_generation_server.models.opt import OPTSharded
from text_generation_server.models.galactica import GalacticaSharded
18
19
from text_generation_server.models.santacoder import SantaCoder
from text_generation_server.models.t5 import T5Sharded
20
from text_generation_server.models.gpt_neox import GPTNeoxSharded
21

22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
# Allow TF32 on matmul. This flag defaults to False in PyTorch 1.12 and later.
torch.backends.cuda.matmul.allow_tf32 = True

# Allow TF32 on cuDNN. This flag defaults to True.
torch.backends.cudnn.allow_tf32 = True

# Disable gradient tracking (note: PyTorch grad mode is thread-local, so this
# applies to the importing thread).
torch.set_grad_enabled(False)

# Public API of this package. Optional model classes are appended further
# down in this module when their imports succeed.
__all__ = [
    "Model",
    "BLOOMSharded",
    "CausalLM",
    "FlashCausalLM",
    "GalacticaSharded",
    "GPTNeoxSharded",
    "MPTSharded",
    "OPTSharded",
    "RW",
    "SantaCoder",
    "Seq2SeqLM",
    "T5Sharded",
    "get_model",
]

45
FLASH_ATT_ERROR_MESSAGE = "{} requires Flash Attention enabled models."

# The flash-attention model implementations are optional. If any of these
# imports fails, FLASH_ATTENTION is set to False and none of them is used.
FLASH_ATTENTION = True
try:
    from text_generation_server.models.flash_rw import FlashRWSharded
    from text_generation_server.models.flash_neox import FlashNeoXSharded
    from text_generation_server.models.flash_llama import (
        FlashLlama,
    )
    from text_generation_server.models.flash_santacoder import (
        FlashSantacoderSharded,
    )
    from text_generation_server.models.idefics import IDEFICSSharded

except ImportError as e:
    logger.warning(f"Could not import Flash Attention enabled models: {e}")
    FLASH_ATTENTION = False

if FLASH_ATTENTION:
    # `__all__` must hold attribute *names* (strings); a class object here
    # makes `from ... import *` raise TypeError.
    __all__.extend(
        [
            "FlashNeoXSharded",
            "FlashRWSharded",
            "FlashSantacoderSharded",
            "FlashLlama",
            "IDEFICSSharded",
        ]
    )

70
71
72
73
74
75
76
77
78
79
# Mistral support is optional: record whether its flash implementation
# could be imported so `get_model` can report a clear error otherwise.
MISTRAL = True
try:
    from text_generation_server.models.flash_mistral import FlashMistral
except ImportError as e:
    logger.warning(f"Could not import Mistral model: {e}")
    MISTRAL = False

if MISTRAL:
    # Export by name: `__all__` entries must be strings.
    __all__.append("FlashMistral")

# Mixtral support is optional in the same way.
MIXTRAL = True
try:
    from text_generation_server.models.flash_mixtral import FlashMixtral
except ImportError as e:
    logger.warning(f"Could not import Mixtral model: {e}")
    MIXTRAL = False

if MIXTRAL:
    # Export by name: `__all__` entries must be strings.
    __all__.append("FlashMixtral")


91
def _resolve_dtype(dtype: Optional[str]) -> Optional[torch.dtype]:
    """Map a dtype name to a torch dtype.

    ``None`` is passed through unchanged so that every model resolves its
    own default dtype.

    Raises:
        RuntimeError: if ``dtype`` is not ``None``, ``"float16"`` or
            ``"bfloat16"``.
    """
    if dtype is None:
        return None
    if dtype == "float16":
        return torch.float16
    if dtype == "bfloat16":
        return torch.bfloat16
    raise RuntimeError(f"Unknown dtype {dtype}")


def _get_santacoder(
    model_id: str, revision: Optional[str], sharded: bool, **model_kwargs
) -> Model:
    """Build a Santacoder / gpt_bigcode model, preferring the flash version.

    ``model_kwargs`` are forwarded unchanged to the model constructor.

    Raises:
        NotImplementedError: if sharding is requested but the flash
            implementation could not be imported.
    """
    if FLASH_ATTENTION:
        return FlashSantacoderSharded(model_id, revision, **model_kwargs)
    if sharded:
        raise NotImplementedError(
            FLASH_ATT_ERROR_MESSAGE.format("Sharded Santacoder")
        )
    return SantaCoder(model_id, revision, **model_kwargs)


def get_model(
    model_id: str,
    revision: Optional[str],
    sharded: bool,
    quantize: Optional[str],
    speculate: Optional[int],
    dtype: Optional[str],
    trust_remote_code: bool,
) -> Model:
    """Instantiate the model implementation matching ``model_id``.

    Dispatch is first done on ``model_id`` itself (Galactica and
    ``bigcode/`` ids), then on the ``model_type`` of the model config.
    Flash-attention implementations are used when they could be imported.
    Unknown model types fall back to the generic ``CausalLM`` / ``Seq2SeqLM``
    wrappers.

    Args:
        model_id: model identifier, passed to the model constructors and to
            ``PretrainedConfig.get_config_dict``.
        revision: revision of ``model_id`` to load.
        sharded: whether a sharded model is requested.
        quantize: quantization method name (e.g. ``"gptq"``, ``"awq"``,
            ``"bitsandbytes-nf4"``, ``"eetq"``), or ``None``.
        speculate: number of speculative tokens, or ``None``. For Medusa
            configs (those with ``medusa_num_heads``), ``None`` means "use
            every Medusa head"; otherwise ``None`` means 0.
        dtype: ``"float16"``, ``"bfloat16"`` or ``None`` (model default).
        trust_remote_code: forwarded to config loading and model
            constructors.

    Returns:
        The instantiated model.

    Raises:
        RuntimeError: unknown ``dtype``, or ``speculate`` larger than the
            number of Medusa heads.
        NotImplementedError: the implementation required for this model /
            sharding combination could not be imported.
        ValueError: sharding or quantization unsupported by the AutoModel
            fallback, or an unsupported ``model_type``.
    """
    dtype = _resolve_dtype(dtype)
    # Keyword arguments shared by every model constructor below.
    model_kwargs = {
        "quantize": quantize,
        "dtype": dtype,
        "trust_remote_code": trust_remote_code,
    }

    set_speculate(speculate if speculate is not None else 0)

    if "facebook/galactica" in model_id:
        return GalacticaSharded(model_id, revision, **model_kwargs)

    if model_id.startswith("bigcode/"):
        return _get_santacoder(model_id, revision, sharded, **model_kwargs)

    config_dict, _ = PretrainedConfig.get_config_dict(
        model_id, revision=revision, trust_remote_code=trust_remote_code
    )

    use_medusa = None
    if "medusa_num_heads" in config_dict:
        # Medusa config: keep its id in `use_medusa` (forwarded to FlashLlama)
        # and switch to the base model named in the config.
        use_medusa = model_id
        model_id = config_dict["base_model_name_or_path"]
        revision = "main"
        speculate_medusa = config_dict["medusa_num_heads"]
        if speculate is None:
            set_speculate(speculate_medusa)
        elif speculate > speculate_medusa:
            raise RuntimeError(
                f"Speculate is set to `{speculate}` but this medusa model only has `{speculate_medusa}` heads, please make them match"
            )
        else:
            set_speculate(speculate)

        config_dict, _ = PretrainedConfig.get_config_dict(
            model_id, revision=revision, trust_remote_code=trust_remote_code
        )
        method = "medusa"
    else:
        method = "n-gram"

    speculate = get_speculate()
    if speculate > 0:
        logger.info(f"Using speculation {method} with {speculate} input ids.")

    model_type = config_dict["model_type"]

    if model_type == "gpt_bigcode":
        return _get_santacoder(model_id, revision, sharded, **model_kwargs)

    if model_type == "bloom":
        return BLOOMSharded(model_id, revision, **model_kwargs)

    if model_type == "mpt":
        return MPTSharded(model_id, revision, **model_kwargs)

    if model_type == "gpt_neox":
        if FLASH_ATTENTION:
            return FlashNeoXSharded(model_id, revision, **model_kwargs)
        if sharded:
            return GPTNeoxSharded(model_id, revision, **model_kwargs)
        return CausalLM(model_id, revision, **model_kwargs)

    if model_type in ("llama", "baichuan"):
        if FLASH_ATTENTION:
            return FlashLlama(
                model_id, revision, **model_kwargs, use_medusa=use_medusa
            )
        if sharded:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Llama"))
        return CausalLM(model_id, revision, **model_kwargs)

    if model_type in ("RefinedWeb", "RefinedWebModel", "falcon"):
        # The flash implementation is not used for ALiBi configs.
        alibi = config_dict.get("alibi", False)
        if sharded:
            if not FLASH_ATTENTION:
                raise NotImplementedError(
                    FLASH_ATT_ERROR_MESSAGE.format("Sharded Falcon")
                )
            if alibi:
                raise NotImplementedError("sharded is not supported for this model")
            return FlashRWSharded(model_id, revision, **model_kwargs)
        if FLASH_ATTENTION and not alibi:
            return FlashRWSharded(model_id, revision, **model_kwargs)
        return RW(model_id, revision, **model_kwargs)

    if model_type == "mistral":
        if MISTRAL:
            return FlashMistral(model_id, revision, **model_kwargs)
        raise NotImplementedError("Mistral models requires flash attention v2")

    if model_type == "mixtral":
        if MIXTRAL:
            return FlashMixtral(model_id, revision, **model_kwargs)
        raise NotImplementedError(
            "Mixtral models requires flash attention v2, stk and megablocks"
        )

    if model_type == "opt":
        return OPTSharded(model_id, revision, **model_kwargs)

    if model_type == "t5":
        return T5Sharded(model_id, revision, **model_kwargs)

    if model_type == "idefics":
        if FLASH_ATTENTION:
            return IDEFICSSharded(model_id, revision, **model_kwargs)
        raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Idefics"))

    # Everything below uses the generic AutoModel-based wrappers, which
    # reject sharding and the quantization methods checked here.
    if sharded:
        raise ValueError("sharded is not supported for AutoModel")
    if quantize == "gptq":
        raise ValueError(
            "gptq quantization is not supported for AutoModel, you can try to quantize it with `text-generation-server quantize ORIGINAL_MODEL_ID NEW_MODEL_ID`"
        )
    if quantize == "awq":
        raise ValueError("awq quantization is not supported for AutoModel")
    if quantize in ("bitsandbytes-fp4", "bitsandbytes-nf4"):
        raise ValueError("4bit quantization is not supported for AutoModel")
    if quantize == "eetq":
        raise ValueError("Eetq quantization is not supported for AutoModel")

    if model_type in modeling_auto.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES:
        return CausalLM(model_id, revision, **model_kwargs)
    if model_type in modeling_auto.MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES:
        return Seq2SeqLM(model_id, revision, **model_kwargs)

    # Last resort: remote-code models that declare their auto classes.
    auto_map = config_dict.get("auto_map", None)
    if trust_remote_code and auto_map is not None:
        if "AutoModelForCausalLM" in auto_map:
            return CausalLM(model_id, revision, **model_kwargs)
        if "AutoModelForSeq2SeqLM" in auto_map:
            return Seq2SeqLM(model_id, revision, **model_kwargs)

    raise ValueError(f"Unsupported model type {model_type}")