import torch

from loguru import logger
from transformers.configuration_utils import PretrainedConfig
from transformers.models.auto import modeling_auto

from typing import Optional

from text_generation_server.utils.speculate import get_speculate, set_speculate

from text_generation_server.models.model import Model
from text_generation_server.models.causal_lm import CausalLM
from text_generation_server.models.flash_causal_lm import FlashCausalLM
from text_generation_server.models.bloom import BLOOMSharded
from text_generation_server.models.mpt import MPTSharded
from text_generation_server.models.seq2seq_lm import Seq2SeqLM
from text_generation_server.models.rw import RW
from text_generation_server.models.opt import OPTSharded
from text_generation_server.models.galactica import GalacticaSharded
from text_generation_server.models.santacoder import SantaCoder
from text_generation_server.models.t5 import T5Sharded
from text_generation_server.models.gpt_neox import GPTNeoxSharded

# The flag below controls whether to allow TF32 on matmul. This flag defaults to False
# in PyTorch 1.12 and later.
torch.backends.cuda.matmul.allow_tf32 = True

# The flag below controls whether to allow TF32 on cuDNN. This flag defaults to True.
torch.backends.cudnn.allow_tf32 = True

# Disable gradients
torch.set_grad_enabled(False)

__all__ = [
    "Model",
    "BLOOMSharded",
    "CausalLM",
    "FlashCausalLM",
    "GalacticaSharded",
    "Seq2SeqLM",
    "SantaCoder",
    "OPTSharded",
    "T5Sharded",
    "get_model",
]

FLASH_ATT_ERROR_MESSAGE = "{} requires Flash Attention enabled models."

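# The flash attention imports are optional: they fail when the custom kernels are
# not installed for the current hardware, and the server then falls back to the
# non-flash implementations selected in `get_model` below.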
FLASH_ATTENTION = True
try:
    from text_generation_server.models.flash_rw import FlashRWSharded
    from text_generation_server.models.flash_neox import FlashNeoXSharded
    from text_generation_server.models.flash_llama import (
        FlashLlama,
    )
    from text_generation_server.models.flash_santacoder import (
        FlashSantacoderSharded,
    )
    from text_generation_server.models.idefics import IDEFICSSharded

except ImportError as e:
    logger.warning(f"Could not import Flash Attention enabled models: {e}")
    FLASH_ATTENTION = False

if FLASH_ATTENTION:
    __all__.append("FlashNeoXSharded")
    __all__.append("FlashRWSharded")
    __all__.append("FlashSantacoderSharded")
    __all__.append("FlashLlama")
    __all__.append("IDEFICSSharded")

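# FlashMistral is also optional: it needs the flash attention v2 kernels
# (see the matching error raised in `get_model` when they are missing).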
MISTRAL = True
try:
    from text_generation_server.models.flash_mistral import FlashMistral
except ImportError as e:
    logger.warning(f"Could not import Mistral model: {e}")
    MISTRAL = False

if MISTRAL:
    __all__.append("FlashMistral")

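# FlashMixtral additionally depends on stk and megablocks; see the matching
# error raised in `get_model` when the import fails.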
MIXTRAL = True
try:
    from text_generation_server.models.flash_mixtral import FlashMixtral
except ImportError as e:
    logger.warning(f"Could not import Mixtral model: {e}")
    MIXTRAL = False

if MIXTRAL:
    __all__.append("FlashMixtral")


def get_model(
    model_id: str,
    revision: Optional[str],
    sharded: bool,
    quantize: Optional[str],
    speculate: Optional[int],
    dtype: Optional[str],
    trust_remote_code: bool,
) -> Model:
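    """Instantiate the right ``Model`` implementation for ``model_id``.

    Routing is based on the checkpoint config (``model_type``), whether the
    optional flash attention / Mistral / Mixtral imports succeeded, and the
    requested ``sharded``, ``quantize``, ``dtype`` and ``speculate`` options.
    When no dedicated implementation matches, the generic ``CausalLM`` /
    ``Seq2SeqLM`` wrappers around the ``transformers`` auto classes are used.
    """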
    if dtype is None:
        # Keep it as default for now and let
        # every model resolve its own default dtype.
        dtype = None
    elif dtype == "float16":
        dtype = torch.float16
    elif dtype == "bfloat16":
        dtype = torch.bfloat16
    else:
        raise RuntimeError(f"Unknown dtype {dtype}")

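    # Default to no speculative decoding unless a value was passed in;
    # a Medusa checkpoint below may override this.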
    if speculate is not None:
        set_speculate(speculate)
    else:
        set_speculate(0)

    if "facebook/galactica" in model_id:
        return GalacticaSharded(
            model_id,
            revision,
            quantize=quantize,
            dtype=dtype,
            trust_remote_code=trust_remote_code,
        )

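    # Shortcut for bigcode (santacoder) checkpoints, resolved from the model id
    # before the config is fetched.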
    if model_id.startswith("bigcode/"):
        if FLASH_ATTENTION:
            return FlashSantacoderSharded(
                model_id,
                revision,
                quantize=quantize,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )
        elif sharded:
            raise NotImplementedError(
                FLASH_ATT_ERROR_MESSAGE.format("Sharded Santacoder")
            )
        else:
            return SantaCoder(
                model_id,
                revision,
                quantize=quantize,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )

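    # Fetch the checkpoint config: everything below routes on its `model_type`.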
    config_dict, _ = PretrainedConfig.get_config_dict(
        model_id, revision=revision, trust_remote_code=trust_remote_code
    )

    use_medusa = None
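    # A Medusa config points at a base model: load that base model and keep the
    # original `model_id` around (`use_medusa`) for speculative decoding.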
    if "medusa_num_heads" in config_dict:
        use_medusa = model_id
        model_id = config_dict["base_model_name_or_path"]
        revision = "main"
        speculate_medusa = config_dict["medusa_num_heads"]
        if speculate is not None:
            if speculate > speculate_medusa:
                raise RuntimeError(
                    f"Speculate is set to `{speculate}` but this medusa model only "
                    f"has `{speculate_medusa}` heads, please make them match"
                )
            else:
                set_speculate(speculate)
        else:
            set_speculate(speculate_medusa)

        config_dict, _ = PretrainedConfig.get_config_dict(
            model_id, revision=revision, trust_remote_code=trust_remote_code
        )
        method = "medusa"
    else:
        method = "n-gram"

    speculate = get_speculate()
    if speculate > 0:
        logger.info(f"Using speculation {method} with {speculate} input ids.")

    model_type = config_dict["model_type"]

    if model_type == "gpt_bigcode":
        if FLASH_ATTENTION:
            return FlashSantacoderSharded(
                model_id,
                revision,
                quantize=quantize,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )
        elif sharded:
            raise NotImplementedError(
                FLASH_ATT_ERROR_MESSAGE.format("Sharded Santacoder")
            )
        else:
            return SantaCoder(
                model_id,
                revision,
                quantize=quantize,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )

    if model_type == "bloom":
        return BLOOMSharded(
            model_id,
            revision,
            quantize=quantize,
            dtype=dtype,
            trust_remote_code=trust_remote_code,
        )
    elif model_type == "mpt":
        return MPTSharded(
            model_id,
            revision,
            quantize=quantize,
            dtype=dtype,
            trust_remote_code=trust_remote_code,
        )

    elif model_type == "gpt_neox":
        if FLASH_ATTENTION:
            return FlashNeoXSharded(
                model_id,
                revision,
                quantize=quantize,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )
        elif sharded:
            return GPTNeoxSharded(
                model_id,
                revision,
                quantize=quantize,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )
        else:
            return CausalLM(
                model_id,
                revision,
                quantize=quantize,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )

    elif model_type == "llama" or model_type == "baichuan":
        if FLASH_ATTENTION:
            return FlashLlama(
                model_id,
                revision,
                quantize=quantize,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
                use_medusa=use_medusa,
            )
        elif sharded:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Llama"))
        else:
            return CausalLM(
                model_id,
                revision,
                quantize=quantize,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )

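    # Falcon / RefinedWeb: alibi checkpoints are not supported when sharded and
    # take the non-flash `RW` path otherwise.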
    if model_type in ["RefinedWeb", "RefinedWebModel", "falcon"]:
        if sharded:
            if FLASH_ATTENTION:
                if config_dict.get("alibi", False):
                    raise NotImplementedError("sharded is not supported for this model")
                return FlashRWSharded(
                    model_id,
                    revision,
                    quantize=quantize,
                    dtype=dtype,
                    trust_remote_code=trust_remote_code,
                )
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Falcon"))
        else:
            if FLASH_ATTENTION and not config_dict.get("alibi", False):
                return FlashRWSharded(
                    model_id,
                    revision,
                    quantize=quantize,
                    dtype=dtype,
                    trust_remote_code=trust_remote_code,
                )
            else:
                return RW(
                    model_id,
                    revision,
                    quantize=quantize,
                    dtype=dtype,
                    trust_remote_code=trust_remote_code,
                )

    if model_type == "mistral":
        if MISTRAL:
            return FlashMistral(
                model_id,
                revision,
                quantize=quantize,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )
        raise NotImplementedError("Mistral models require flash attention v2")

    if model_type == "mixtral":
        if MIXTRAL:
            return FlashMixtral(
                model_id,
                revision,
                quantize=quantize,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )
        raise NotImplementedError("Mixtral models requires flash attention v2, stk and megablocks")
317
318

    if model_type == "opt":
319
        return OPTSharded(
320
321
322
323
324
            model_id,
            revision,
            quantize=quantize,
            dtype=dtype,
            trust_remote_code=trust_remote_code,
325
        )
326

327
    if model_type == "t5":
328
329
330
331
        return T5Sharded(
            model_id,
            revision,
            quantize=quantize,
332
            dtype=dtype,
333
334
            trust_remote_code=trust_remote_code,
        )
335
    if model_type == "idefics":
336
        if FLASH_ATTENTION:
OlivierDehaene's avatar
OlivierDehaene committed
337
338
339
340
341
342
343
            return IDEFICSSharded(
                model_id,
                revision,
                quantize=quantize,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )
344
345
        else:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Idefics"))
346
347
348

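    # No dedicated implementation matched: fall back to the generic transformers
    # auto classes, which support neither sharding nor these quantization schemes.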
    if sharded:
        raise ValueError("sharded is not supported for AutoModel")
    if quantize == "gptq":
        raise ValueError(
            "gptq quantization is not supported for AutoModel, you can try to quantize it with `text-generation-server quantize ORIGINAL_MODEL_ID NEW_MODEL_ID`"
        )
    if quantize == "awq":
        raise ValueError("awq quantization is not supported for AutoModel")
    elif quantize in ("bitsandbytes-fp4", "bitsandbytes-nf4"):
        raise ValueError("4bit quantization is not supported for AutoModel")
    elif quantize == "eetq":
        raise ValueError("Eetq quantization is not supported for AutoModel")
    if model_type in modeling_auto.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES:
        return CausalLM(
            model_id,
            revision,
            quantize=quantize,
            dtype=dtype,
            trust_remote_code=trust_remote_code,
        )
    if model_type in modeling_auto.MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES:
        return Seq2SeqLM(
            model_id,
            revision,
            quantize=quantize,
            dtype=dtype,
            trust_remote_code=trust_remote_code,
        )

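    # Last resort for `trust_remote_code` checkpoints that declare custom model
    # classes in the config's `auto_map`.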
    auto_map = config_dict.get("auto_map", None)
    if trust_remote_code and auto_map is not None:
        if "AutoModelForCausalLM" in auto_map.keys():
            return CausalLM(
                model_id,
                revision,
                quantize=quantize,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )
        if "AutoModelForSeq2SeqLM" in auto_map.keys():
            return Seq2SeqLM(
                model_id,
                revision,
                quantize=quantize,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )

    raise ValueError(f"Unsupported model type {model_type}")