__init__.py 15.7 KB
Newer Older
1
2
import torch

3
from loguru import logger
4
from transformers.configuration_utils import PretrainedConfig
5
from transformers.models.auto import modeling_auto
6
from huggingface_hub import hf_hub_download
7
from typing import Optional
8
from pathlib import Path
9

Nicolas Patry's avatar
Nicolas Patry committed
10
from text_generation_server.utils.speculate import get_speculate, set_speculate
11
12
from text_generation_server.models.model import Model
from text_generation_server.models.causal_lm import CausalLM
13
from text_generation_server.models.flash_causal_lm import FlashCausalLM
14
from text_generation_server.models.bloom import BLOOMSharded
15
from text_generation_server.models.mpt import MPTSharded
16
from text_generation_server.models.seq2seq_lm import Seq2SeqLM
17
from text_generation_server.models.rw import RW
18
19
from text_generation_server.models.opt import OPTSharded
from text_generation_server.models.galactica import GalacticaSharded
20
21
from text_generation_server.models.santacoder import SantaCoder
from text_generation_server.models.t5 import T5Sharded
22
from text_generation_server.models.gpt_neox import GPTNeoxSharded
drbh's avatar
drbh committed
23
from text_generation_server.models.phi import Phi
24

25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
# Turn off gradient tracking. torch.set_grad_enabled is thread-local, so this
# module-level call applies to the thread that imports this package.
torch.set_grad_enabled(False)

# Allow TensorFloat-32 math on GPUs that support it.
# cuDNN already allows TF32 by default; setting it here just makes it explicit.
torch.backends.cudnn.allow_tf32 = True
# Matmul TF32 is disabled by default in PyTorch 1.12 and later, so opt in.
torch.backends.cuda.matmul.allow_tf32 = True

# Public names of this package. The optional Flash Attention and Mamba
# implementations are added later in this file if their imports succeed.
__all__ = [
    "Model",
    "BLOOMSharded",
    "CausalLM",
    "FlashCausalLM",
    "GalacticaSharded",
    "Seq2SeqLM",
    "SantaCoder",
    "OPTSharded",
    "T5Sharded",
    "get_model",
]

FLASH_ATT_ERROR_MESSAGE = "{} requires Flash Attention enabled models."

# The Flash Attention implementations are optional. If any of these imports
# raises ImportError (presumably because the Flash Attention kernels are not
# installed), only the non-flash implementations are usable and
# HAS_FLASH_ATTN_V2_CUDA is forced to False.
FLASH_ATTENTION = True
try:
    from text_generation_server.models.flash_rw import FlashRWSharded
    from text_generation_server.models.flash_neox import FlashNeoXSharded
    from text_generation_server.models.flash_llama import (
        FlashLlama,
    )
    from text_generation_server.models.flash_gemma import (
        FlashGemma,
    )
    from text_generation_server.models.flash_santacoder import (
        FlashSantacoderSharded,
    )
    from text_generation_server.models.idefics import IDEFICSSharded
    from text_generation_server.models.flash_mistral import FlashMistral
    from text_generation_server.models.flash_mixtral import FlashMixtral
    from text_generation_server.models.flash_phi import FlashPhi
    from text_generation_server.utils.flash_attn import HAS_FLASH_ATTN_V2_CUDA

except ImportError as e:
    logger.warning(f"Could not import Flash Attention enabled models: {e}")
    FLASH_ATTENTION = False
    HAS_FLASH_ATTN_V2_CUDA = False

if FLASH_ATTENTION:
    # `__all__` entries must be name strings. Appending the class objects
    # makes `from ... import *` fail with TypeError.
    __all__.extend(
        [
            "FlashNeoXSharded",
            "FlashRWSharded",
            "FlashSantacoderSharded",
            "FlashLlama",
            "IDEFICSSharded",
            "FlashMistral",
            "FlashMixtral",
            "FlashPhi",
            "FlashGemma",
        ]
    )

# Mamba support is optional. MAMBA_AVAILABLE records whether its import
# succeeded.
MAMBA_AVAILABLE = True
try:
    from text_generation_server.models.mamba import Mamba
except ImportError as e:
    logger.warning(f"Could not import Mamba: {e}")
    MAMBA_AVAILABLE = False

if MAMBA_AVAILABLE:
    # Export the name as a string; `__all__` entries must be str.
    __all__.append("Mamba")


def get_model(
    model_id: str,
    revision: Optional[str],
    sharded: bool,
    quantize: Optional[str],
    speculate: Optional[int],
    dtype: Optional[str],
    trust_remote_code: bool,
) -> Model:
    """Build the serving ``Model`` implementation that matches ``model_id``.

    The model configuration is read to find its ``model_type``. A dedicated
    implementation (Flash Attention, sharded, ...) is returned when one
    applies. Otherwise the generic ``CausalLM`` / ``Seq2SeqLM`` wrappers are
    used.

    Args:
        model_id: Model id or local path. If its config contains
            ``medusa_num_heads`` it is treated as a Medusa head repository:
            the base model named by ``base_model_name_or_path`` (revision
            ``"main"``) is loaded, and the Medusa files are passed along via
            ``use_medusa``.
        revision: Revision of ``model_id`` to load.
        sharded: Whether a sharded implementation is required.
        quantize: Quantization method name, or ``None``.
        speculate: Number of speculative input ids. ``None`` means 0, or the
            number of Medusa heads for Medusa models.
        dtype: ``"float16"``, ``"bfloat16"`` or ``None`` (each model then
            resolves its own default).
        trust_remote_code: Forwarded to config loading and to every model
            constructor.

    Returns:
        An instance of the selected ``Model`` implementation.

    Raises:
        RuntimeError: Unknown ``dtype``, ``speculate`` larger than the number
            of Medusa heads, or a model type that cannot be determined.
        NotImplementedError: The requested sharding / quantization is not
            supported for this model type, or the needed optional
            implementation (Flash Attention, Mamba) could not be imported.
        ValueError: The model type is not supported at all.
    """
    if dtype is None:
        # Keep it as default for now and let every model resolve its own
        # default dtype.
        pass
    elif dtype == "float16":
        dtype = torch.float16
    elif dtype == "bfloat16":
        dtype = torch.bfloat16
    else:
        raise RuntimeError(f"Unknown dtype {dtype}")

    set_speculate(speculate if speculate is not None else 0)

    config_dict, _ = PretrainedConfig.get_config_dict(
        model_id, revision=revision, trust_remote_code=trust_remote_code
    )

    use_medusa = None
    if "medusa_num_heads" in config_dict:
        medusa_model_id = model_id
        medusa_revision = revision
        # The Medusa repository describes extra heads on top of a base model;
        # the base model is what actually gets loaded.
        model_id = config_dict["base_model_name_or_path"]
        revision = "main"
        speculate_medusa = config_dict["medusa_num_heads"]
        if speculate is not None:
            if speculate > speculate_medusa:
                raise RuntimeError(
                    f"Speculate is set to `{speculate}` but this medusa model only has `{speculate_medusa}` heads, please make them match"
                )
            set_speculate(speculate)
        else:
            set_speculate(speculate_medusa)

        config_dict, _ = PretrainedConfig.get_config_dict(
            model_id, revision=revision, trust_remote_code=trust_remote_code
        )
        is_local = Path(medusa_model_id).exists()
        if not is_local:
            medusa_config = hf_hub_download(
                medusa_model_id, revision=medusa_revision, filename="config.json"
            )
            hf_hub_download(
                medusa_model_id,
                revision=medusa_revision,
                filename="medusa_lm_head.safetensors",
            )
            use_medusa = Path(medusa_config).parent
        else:
            use_medusa = Path(medusa_model_id)

        method = "medusa"
    else:
        method = "n-gram"

    speculate = get_speculate()
    if speculate > 0:
        logger.info(f"Using speculation {method} with {speculate} input ids.")

    model_type = config_dict.get("model_type", None)
    if model_type is None:
        # TODO: fix how we determine model type for Mamba
        if "ssm_cfg" in config_dict:
            # *only happens in Mamba case
            model_type = "ssm"
        else:
            raise RuntimeError(
                f"Could not determine model type for {model_id} revision {revision}"
            )

    # Every implementation below is constructed as Cls(model_id, revision, **model_kwargs).
    model_kwargs = {
        "quantize": quantize,
        "use_medusa": use_medusa,
        "dtype": dtype,
        "trust_remote_code": trust_remote_code,
    }

    if model_type == "ssm":
        if not MAMBA_AVAILABLE:
            # `Mamba` is only defined when its import succeeded; fail with a
            # clear error instead of a NameError.
            raise NotImplementedError(
                "Mamba models are not available: importing text_generation_server.models.mamba failed"
            )
        return Mamba(model_id, revision, **model_kwargs)

    if model_type == "gpt_bigcode" or (
        model_type == "gpt2" and model_id.startswith("bigcode/")
    ):
        if FLASH_ATTENTION:
            return FlashSantacoderSharded(model_id, revision, **model_kwargs)
        if sharded:
            raise NotImplementedError(
                FLASH_ATT_ERROR_MESSAGE.format("Sharded Santacoder")
            )
        return SantaCoder(model_id, revision, **model_kwargs)

    if model_type == "bloom":
        return BLOOMSharded(model_id, revision, **model_kwargs)

    if model_type == "mpt":
        return MPTSharded(model_id, revision, **model_kwargs)

    if model_type == "gpt_neox":
        if FLASH_ATTENTION:
            return FlashNeoXSharded(model_id, revision, **model_kwargs)
        if sharded:
            return GPTNeoxSharded(model_id, revision, **model_kwargs)
        return CausalLM(model_id, revision, **model_kwargs)

    if model_type == "phi":
        if FLASH_ATTENTION:
            return FlashPhi(model_id, revision, **model_kwargs)
        return CausalLM(model_id, revision, **model_kwargs)

    if model_type == "phi-msft":
        if FLASH_ATTENTION:
            raise NotImplementedError(
                "Legacy phi-msft is not supported with Flash Attention"
            )
        return Phi(model_id, revision, **model_kwargs)

    if model_type == "llama" or model_type == "baichuan":
        if FLASH_ATTENTION:
            return FlashLlama(model_id, revision, **model_kwargs)
        if sharded:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Llama"))
        return CausalLM(model_id, revision, **model_kwargs)

    if model_type == "gemma":
        if FLASH_ATTENTION:
            return FlashGemma(model_id, revision, **model_kwargs)
        if sharded:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Gemma"))
        return CausalLM(model_id, revision, **model_kwargs)

    if model_type in ["RefinedWeb", "RefinedWebModel", "falcon"]:
        if sharded:
            if FLASH_ATTENTION:
                if config_dict.get("alibi", False):
                    raise NotImplementedError("sharded is not supported for this model")
                return FlashRWSharded(model_id, revision, **model_kwargs)
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Falcon"))
        if FLASH_ATTENTION and not config_dict.get("alibi", False):
            return FlashRWSharded(model_id, revision, **model_kwargs)
        return RW(model_id, revision, **model_kwargs)

    if model_type in ("mistral", "mixtral"):
        sliding_window = config_dict.get("sliding_window", -1)
        # Without a sliding window, any Flash Attention build is enough. With
        # one, the flash implementation is used only when
        # HAS_FLASH_ATTN_V2_CUDA is set. Otherwise this falls through to the
        # generic AutoModel path below.
        if (
            (sliding_window is None or sliding_window == -1) and FLASH_ATTENTION
        ) or HAS_FLASH_ATTN_V2_CUDA:
            flash_cls = FlashMistral if model_type == "mistral" else FlashMixtral
            return flash_cls(model_id, revision, **model_kwargs)

    if model_type == "opt":
        return OPTSharded(model_id, revision, **model_kwargs)

    if model_type == "t5":
        return T5Sharded(model_id, revision, **model_kwargs)

    if model_type == "idefics":
        if FLASH_ATTENTION:
            return IDEFICSSharded(model_id, revision, **model_kwargs)
        raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Idefics"))

    # Generic transformers AutoModel fallback: no sharding and no custom
    # quantization kernels.
    if sharded:
        raise NotImplementedError("sharded is not supported for AutoModel")
    if quantize == "gptq":
        raise NotImplementedError(
            "gptq quantization is not supported for AutoModel, you can try to quantize it with `text-generation-server quantize ORIGINAL_MODEL_ID NEW_MODEL_ID`"
        )
    if quantize == "awq":
        raise NotImplementedError("awq quantization is not supported for AutoModel")
    if quantize in ("bitsandbytes-fp4", "bitsandbytes-nf4"):
        raise NotImplementedError("4bit quantization is not supported for AutoModel")
    if quantize == "eetq":
        raise NotImplementedError("Eetq quantization is not supported for AutoModel")

    if model_type in modeling_auto.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES:
        return CausalLM(model_id, revision, **model_kwargs)
    if model_type in modeling_auto.MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES:
        return Seq2SeqLM(model_id, revision, **model_kwargs)

    # Remote-code models declare their classes in the config's `auto_map`.
    auto_map = config_dict.get("auto_map", None)
    if trust_remote_code and auto_map is not None:
        if "AutoModelForCausalLM" in auto_map:
            return CausalLM(model_id, revision, **model_kwargs)
        if "AutoModelForSeq2SeqLM" in auto_map:
            return Seq2SeqLM(model_id, revision, **model_kwargs)

    raise ValueError(f"Unsupported model type {model_type}")