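"""Model registry for text_generation_server.

Imports the available model implementations (plain, sharded, and Flash
Attention variants) and exposes ``get_model``, which inspects a checkpoint's
Hugging Face config and returns the matching implementation.
"""
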
import torch

from loguru import logger
from transformers.configuration_utils import PretrainedConfig
from transformers.models.auto import modeling_auto
from huggingface_hub import hf_hub_download
from typing import Optional
from pathlib import Path

from text_generation_server.utils.speculate import get_speculate, set_speculate

from text_generation_server.models.model import Model
from text_generation_server.models.causal_lm import CausalLM
from text_generation_server.models.flash_causal_lm import FlashCausalLM
from text_generation_server.models.bloom import BLOOMSharded
from text_generation_server.models.mpt import MPTSharded
from text_generation_server.models.seq2seq_lm import Seq2SeqLM
from text_generation_server.models.rw import RW
from text_generation_server.models.opt import OPTSharded
from text_generation_server.models.galactica import GalacticaSharded
from text_generation_server.models.santacoder import SantaCoder
from text_generation_server.models.t5 import T5Sharded
from text_generation_server.models.gpt_neox import GPTNeoxSharded
from text_generation_server.models.phi import Phi

# The flag below controls whether to allow TF32 on matmul. This flag defaults to False
# in PyTorch 1.12 and later.
torch.backends.cuda.matmul.allow_tf32 = True

# The flag below controls whether to allow TF32 on cuDNN. This flag defaults to True.
torch.backends.cudnn.allow_tf32 = True

# Disable gradients
torch.set_grad_enabled(False)

__all__ = [
    "Model",
    "BLOOMSharded",
    "CausalLM",
    "GalacticaSharded",
    "Seq2SeqLM",
    "SantaCoder",
    "OPTSharded",
    "T5Sharded",
    "get_model",
]

FLASH_ATT_ERROR_MESSAGE = "{} requires Flash Attention enabled models."

FLASH_ATTENTION = True

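# Flash Attention model classes are imported optimistically; if any import
# fails (for example, missing CUDA kernels), FLASH_ATTENTION is set to False
# and get_model falls back to the non-flash implementations below.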
try:
    from text_generation_server.models.flash_rw import FlashRWSharded
    from text_generation_server.models.flash_neox import FlashNeoXSharded
    from text_generation_server.models.flash_llama import (
        FlashLlama,
    )
    from text_generation_server.models.flash_qwen2 import (
        FlashQwen2,
    )
    from text_generation_server.models.flash_cohere import (
        FlashCohere,
    )
    from text_generation_server.models.flash_gemma import (
        FlashGemma,
    )
    from text_generation_server.models.flash_santacoder import (
        FlashSantacoderSharded,
    )
    from text_generation_server.models.idefics import IDEFICSSharded
    from text_generation_server.models.flash_mistral import FlashMistral
    from text_generation_server.models.flash_mixtral import FlashMixtral
    from text_generation_server.models.flash_phi import FlashPhi
    from text_generation_server.models.flash_starcoder2 import FlashStarcoder2
    from text_generation_server.utils.flash_attn import HAS_FLASH_ATTN_V2_CUDA

except ImportError as e:
    logger.warning(f"Could not import Flash Attention enabled models: {e}")
    FLASH_ATTENTION = False
    HAS_FLASH_ATTN_V2_CUDA = False

if FLASH_ATTENTION:
    __all__.append("FlashNeoXSharded")
    __all__.append("FlashRWSharded")
    __all__.append("FlashSantacoderSharded")
    __all__.append("FlashLlama")
    __all__.append("IDEFICSSharded")
    __all__.append("FlashMistral")
    __all__.append("FlashMixtral")
    __all__.append("FlashPhi")
    __all__.append("FlashQwen2")
    __all__.append("FlashStarcoder2")
    __all__.append("FlashGemma")
    __all__.append("FlashCohere")

MAMBA_AVAILABLE = True
try:
    from text_generation_server.models.mamba import Mamba
except ImportError as e:
    logger.warning(f"Could not import Mamba: {e}")
    MAMBA_AVAILABLE = False

if MAMBA_AVAILABLE:
    __all__.append("Mamba")


def get_model(
    model_id: str,
    revision: Optional[str],
    sharded: bool,
    quantize: Optional[str],
    speculate: Optional[int],
    dtype: Optional[str],
    trust_remote_code: bool,
) -> Model:
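    """Instantiate and return the model implementation that matches `model_id`.

    The checkpoint's Hugging Face config is inspected (model_type, Medusa
    heads, sliding window, alibi, ...) to choose between the Flash Attention,
    sharded, and plain CausalLM / Seq2SeqLM backends; unsupported combinations
    raise NotImplementedError or ValueError.
    """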
    if dtype is None:
        # Keep it as default for now and let
        # every model resolve their own default dtype.
        dtype = None
    elif dtype == "float16":
        dtype = torch.float16
    elif dtype == "bfloat16":
        dtype = torch.bfloat16
    else:
        raise RuntimeError(f"Unknown dtype {dtype}")

    if speculate is not None:
        set_speculate(speculate)
    else:
        set_speculate(0)

    config_dict, _ = PretrainedConfig.get_config_dict(
        model_id, revision=revision, trust_remote_code=trust_remote_code
    )

    use_medusa = None
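    # Medusa speculative-decoding checkpoints only carry the extra heads: swap
    # model_id for the underlying base model and fetch the head weights
    # (medusa_lm_head.safetensors) separately.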
    if "medusa_num_heads" in config_dict:
        medusa_model_id = model_id
        medusa_revision = revision
        model_id = config_dict["base_model_name_or_path"]
        revision = "main"
        speculate_medusa = config_dict["medusa_num_heads"]
        if speculate is not None:
            if speculate > speculate_medusa:
                raise RuntimeError(
                    f"Speculate is set to `{speculate}` but this Medusa model only has `{speculate_medusa}` heads; please make them match"
                )
            else:
                set_speculate(speculate)
        else:
            set_speculate(speculate_medusa)

        config_dict, _ = PretrainedConfig.get_config_dict(
            model_id, revision=revision, trust_remote_code=trust_remote_code
        )
        is_local = Path(medusa_model_id).exists()
        if not is_local:
            medusa_config = hf_hub_download(
                medusa_model_id, revision=medusa_revision, filename="config.json"
            )
            hf_hub_download(
                medusa_model_id,
                revision=medusa_revision,
                filename="medusa_lm_head.safetensors",
            )
            use_medusa = Path(medusa_config).parent
        else:
            use_medusa = Path(medusa_model_id)

        method = "medusa"
    else:
        method = "n-gram"

    speculate = get_speculate()
    if speculate > 0:
        logger.info(f"Using speculation {method} with {speculate} input ids.")

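    # Dispatch on model_type. Each branch below prefers a Flash Attention
    # implementation when available, then a sharded implementation, and finally
    # falls back to the generic CausalLM / Seq2SeqLM wrappers.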
    model_type = config_dict.get("model_type", None)
    if model_type is None:
        # TODO: fix how we determine model type for Mamba
        if "ssm_cfg" in config_dict:
            # only happens in the Mamba case
            model_type = "ssm"
        else:
            raise RuntimeError(
                f"Could not determine model type for {model_id} revision {revision}"
            )

    if model_type == "ssm":
        return Mamba(
            model_id,
            revision,
            quantize=quantize,
            use_medusa=use_medusa,
            dtype=dtype,
            trust_remote_code=trust_remote_code,
        )

    if model_id.startswith("facebook/galactica"):
        return GalacticaSharded(
            model_id,
            revision,
            quantize=quantize,
            use_medusa=use_medusa,
            dtype=dtype,
            trust_remote_code=trust_remote_code,
        )

    if (
        model_type == "gpt_bigcode"
        or model_type == "gpt2"
        and model_id.startswith("bigcode/")
    ):
        if FLASH_ATTENTION:
            return FlashSantacoderSharded(
                model_id,
                revision,
                quantize=quantize,
                use_medusa=use_medusa,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )
        elif sharded:
            raise NotImplementedError(
                FLASH_ATT_ERROR_MESSAGE.format("Sharded Santacoder")
            )
        else:
            return SantaCoder(
                model_id,
                revision,
                quantize=quantize,
                use_medusa=use_medusa,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )

    if model_type == "bloom":
        return BLOOMSharded(
            model_id,
            revision,
            quantize=quantize,
            use_medusa=use_medusa,
            dtype=dtype,
            trust_remote_code=trust_remote_code,
        )
    elif model_type == "mpt":
        return MPTSharded(
            model_id,
            revision,
            quantize=quantize,
            use_medusa=use_medusa,
            dtype=dtype,
            trust_remote_code=trust_remote_code,
        )

    elif model_type == "gpt_neox":
        if FLASH_ATTENTION:
            return FlashNeoXSharded(
                model_id,
                revision,
                quantize=quantize,
                use_medusa=use_medusa,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )
        elif sharded:
            return GPTNeoxSharded(
                model_id,
                revision,
                quantize=quantize,
                use_medusa=use_medusa,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )
        else:
            return CausalLM(
                model_id,
                revision,
                quantize=quantize,
                use_medusa=use_medusa,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )

    elif model_type == "phi":
        if FLASH_ATTENTION:
            return FlashPhi(
                model_id,
                revision,
                quantize=quantize,
                use_medusa=use_medusa,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )
        else:
            return CausalLM(
                model_id,
                revision,
                quantize=quantize,
                use_medusa=use_medusa,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )

    elif model_type == "phi-msft":
        if FLASH_ATTENTION:
            raise NotImplementedError(
                "Legacy phi-msft is not supported with Flash Attention"
            )
        else:
            return Phi(
                model_id,
                revision,
                quantize=quantize,
                use_medusa=use_medusa,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )

    elif model_type == "llama" or model_type == "baichuan":
        if FLASH_ATTENTION:
            return FlashLlama(
                model_id,
                revision,
                quantize=quantize,
                use_medusa=use_medusa,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )
        elif sharded:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Llama"))
        else:
            return CausalLM(
                model_id,
                revision,
                quantize=quantize,
                use_medusa=use_medusa,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )
    if model_type == "gemma":
        if FLASH_ATTENTION:
            return FlashGemma(
                model_id,
                revision,
                quantize=quantize,
                use_medusa=use_medusa,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )
        elif sharded:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Gemma"))
        else:
            return CausalLM(
                model_id,
                revision,
                quantize=quantize,
                use_medusa=use_medusa,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )

    if model_type == "cohere":
        if FLASH_ATTENTION:
            return FlashCohere(
                model_id,
                revision,
                quantize=quantize,
                use_medusa=use_medusa,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )
        elif sharded:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Cohere"))
        else:
            return CausalLM(
                model_id,
                revision,
                quantize=quantize,
                use_medusa=use_medusa,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )

    if model_type in ["RefinedWeb", "RefinedWebModel", "falcon"]:
        if sharded:
            if FLASH_ATTENTION:
                if config_dict.get("alibi", False):
                    raise NotImplementedError("sharded is not supported for this model")
                return FlashRWSharded(
                    model_id,
                    revision,
                    quantize=quantize,
                    use_medusa=use_medusa,
                    dtype=dtype,
                    trust_remote_code=trust_remote_code,
                )
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Falcon"))
        else:
            if FLASH_ATTENTION and not config_dict.get("alibi", False):
                return FlashRWSharded(
                    model_id,
                    revision,
                    quantize=quantize,
                    use_medusa=use_medusa,
                    dtype=dtype,
                    trust_remote_code=trust_remote_code,
                )
            else:
                return RW(
                    model_id,
                    revision,
                    quantize=quantize,
                    use_medusa=use_medusa,
                    dtype=dtype,
                    trust_remote_code=trust_remote_code,
                )

    if model_type == "mistral":
        sliding_window = config_dict.get("sliding_window", -1)
        if (
            (sliding_window is None or sliding_window == -1) and FLASH_ATTENTION
        ) or HAS_FLASH_ATTN_V2_CUDA:
            return FlashMistral(
                model_id,
                revision,
                quantize=quantize,
                use_medusa=use_medusa,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )
        elif sharded:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Mistral"))
        else:
            return CausalLM(
                model_id,
                revision,
                quantize=quantize,
                use_medusa=use_medusa,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )

    if model_type == "mixtral":
        sliding_window = config_dict.get("sliding_window", -1)
        if (
            (sliding_window is None or sliding_window == -1) and FLASH_ATTENTION
        ) or HAS_FLASH_ATTN_V2_CUDA:
            return FlashMixtral(
                model_id,
                revision,
                quantize=quantize,
                use_medusa=use_medusa,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )
        elif sharded:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Mixtral"))
        else:
            return CausalLM(
                model_id,
                revision,
                quantize=quantize,
                use_medusa=use_medusa,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )

    if model_type == "starcoder2":
        sliding_window = config_dict.get("sliding_window", -1)
        if (
            (sliding_window is None or sliding_window == -1) and FLASH_ATTENTION
        ) or HAS_FLASH_ATTN_V2_CUDA:
            return FlashStarcoder2(
                model_id,
                revision,
                quantize=quantize,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )
        elif sharded:
            raise NotImplementedError(
                FLASH_ATT_ERROR_MESSAGE.format("Sharded Starcoder2")
            )
        else:
            return CausalLM(
                model_id,
                revision,
                quantize=quantize,
                use_medusa=use_medusa,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )

    if model_type == "qwen2":
        sliding_window = config_dict.get("sliding_window", -1)
        if (
            (sliding_window is None or sliding_window == -1) and FLASH_ATTENTION
        ) or HAS_FLASH_ATTN_V2_CUDA:
            return FlashQwen2(
                model_id,
                revision,
                quantize=quantize,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )
        elif sharded:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Qwen2"))
        else:
            return CausalLM(
                model_id,
                revision,
                quantize=quantize,
                use_medusa=use_medusa,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )

    if model_type == "opt":
        return OPTSharded(
            model_id,
            revision,
            quantize=quantize,
            use_medusa=use_medusa,
            dtype=dtype,
            trust_remote_code=trust_remote_code,
        )

    if model_type == "t5":
        return T5Sharded(
            model_id,
            revision,
            quantize=quantize,
            use_medusa=use_medusa,
            dtype=dtype,
            trust_remote_code=trust_remote_code,
        )
    if model_type == "idefics":
        if FLASH_ATTENTION:
            return IDEFICSSharded(
                model_id,
                revision,
                quantize=quantize,
                use_medusa=use_medusa,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )
        else:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Idefics"))

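    # Nothing above matched: fall back to transformers AutoModel. Sharding and
    # the quantization schemes checked below are not supported on this path.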
    if sharded:
        raise NotImplementedError("sharded is not supported for AutoModel")
    if quantize == "gptq":
        raise NotImplementedError(
            "gptq quantization is not supported for AutoModel, you can try to quantize it with `text-generation-server quantize ORIGINAL_MODEL_ID NEW_MODEL_ID`"
        )
    if quantize == "awq":
        raise NotImplementedError("awq quantization is not supported for AutoModel")
    elif (quantize == "bitsandbytes-fp4") or (quantize == "bitsandbytes-nf4"):
        raise NotImplementedError("4bit quantization is not supported for AutoModel")
    elif quantize == "eetq":
        raise NotImplementedError("Eetq quantization is not supported for AutoModel")
    if model_type in modeling_auto.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES:
        return CausalLM(
            model_id,
            revision,
            quantize=quantize,
            use_medusa=use_medusa,
            dtype=dtype,
            trust_remote_code=trust_remote_code,
        )
    if model_type in modeling_auto.MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES:
        return Seq2SeqLM(
            model_id,
            revision,
            quantize=quantize,
            use_medusa=use_medusa,
            dtype=dtype,
            trust_remote_code=trust_remote_code,
        )

    auto_map = config_dict.get("auto_map", None)
    if trust_remote_code and auto_map is not None:
        if "AutoModelForCausalLM" in auto_map.keys():
            return CausalLM(
                model_id,
                revision,
                quantize=quantize,
                use_medusa=use_medusa,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )
        if "AutoModelForSeq2SeqLM" in auto_map.keys():
            return Seq2SeqLM(
                model_id,
                revision,
                quantize=quantize,
                use_medusa=use_medusa,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )

    raise ValueError(f"Unsupported model type {model_type}")
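
# Illustrative only: a caller (for example, the server start-up path) might
# resolve a model roughly like the hypothetical sketch below; the argument
# values are assumptions, not part of this module.
#
#   model = get_model(
#       model_id="bigscience/bloom-560m",
#       revision=None,
#       sharded=False,
#       quantize=None,
#       speculate=None,
#       dtype="float16",
#       trust_remote_code=False,
#   )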