import torch

from loguru import logger
from transformers.configuration_utils import PretrainedConfig
from transformers.models.auto import modeling_auto
from huggingface_hub import hf_hub_download
from typing import Optional
from pathlib import Path

from text_generation_server.utils.speculate import get_speculate, set_speculate
from text_generation_server.models.model import Model
from text_generation_server.models.causal_lm import CausalLM
from text_generation_server.models.flash_causal_lm import FlashCausalLM
from text_generation_server.models.bloom import BLOOMSharded
from text_generation_server.models.mpt import MPTSharded
from text_generation_server.models.seq2seq_lm import Seq2SeqLM
from text_generation_server.models.rw import RW
from text_generation_server.models.opt import OPTSharded
from text_generation_server.models.galactica import GalacticaSharded
from text_generation_server.models.santacoder import SantaCoder
from text_generation_server.models.t5 import T5Sharded
from text_generation_server.models.gpt_neox import GPTNeoxSharded
from text_generation_server.models.phi import Phi

# The flag below controls whether to allow TF32 on matmul. This flag defaults to False
# in PyTorch 1.12 and later.
torch.backends.cuda.matmul.allow_tf32 = True

# The flag below controls whether to allow TF32 on cuDNN. This flag defaults to True.
torch.backends.cudnn.allow_tf32 = True

# Disable gradients
torch.set_grad_enabled(False)

__all__ = [
    "Model",
    "BLOOMSharded",
    "CausalLM",
    "GalacticaSharded",
    "Seq2SeqLM",
    "SantaCoder",
    "OPTSharded",
    "T5Sharded",
    "get_model",
]

FLASH_ATT_ERROR_MESSAGE = "{} requires Flash Attention enabled models."

FLASH_ATTENTION = True

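# The flash implementations are optional: importing them fails when the
# flash-attention / custom CUDA kernels are not installed, and the server then
# falls back to the non-flash implementations.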
try:
    from text_generation_server.models.flash_rw import FlashRWSharded
    from text_generation_server.models.flash_neox import FlashNeoXSharded
    from text_generation_server.models.flash_llama import (
        FlashLlama,
    )
    from text_generation_server.models.flash_qwen2 import (
        FlashQwen2,
    )
    from text_generation_server.models.flash_cohere import (
        FlashCohere,
    )
    from text_generation_server.models.flash_gemma import (
        FlashGemma,
    )
    from text_generation_server.models.flash_santacoder import (
        FlashSantacoderSharded,
    )
    from text_generation_server.models.idefics import IDEFICSSharded
    from text_generation_server.models.flash_mistral import FlashMistral
    from text_generation_server.models.flash_mixtral import FlashMixtral
    from text_generation_server.models.flash_phi import FlashPhi
    from text_generation_server.models.flash_starcoder2 import FlashStarcoder2
    from text_generation_server.models.flash_dbrx import FlashDbrx
    from text_generation_server.utils.flash_attn import HAS_FLASH_ATTN_V2_CUDA

except ImportError as e:
    logger.warning(f"Could not import Flash Attention enabled models: {e}")
    FLASH_ATTENTION = False
    HAS_FLASH_ATTN_V2_CUDA = False

if FLASH_ATTENTION:
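    # Only export the flash model classes when their imports succeeded above.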
    __all__.append(FlashNeoXSharded)
    __all__.append(FlashRWSharded)
    __all__.append(FlashSantacoderSharded)
    __all__.append(FlashLlama)
    __all__.append(IDEFICSSharded)
    __all__.append(FlashMistral)
    __all__.append(FlashMixtral)
    __all__.append(FlashDbrx)
    __all__.append(FlashPhi)
    __all__.append(FlashQwen2)
    __all__.append(FlashStarcoder2)
    __all__.append(FlashGemma)
    __all__.append(FlashCohere)

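# Mamba support is optional as well: guard its import separately so the rest of
# the server still loads when its extra dependencies are missing.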
MAMBA_AVAILABLE = True
try:
    from text_generation_server.models.mamba import Mamba
except ImportError as e:
    logger.warning(f"Could not import Mamba: {e}")
    MAMBA_AVAILABLE = False

if MAMBA_AVAILABLE:
    __all__.append(Mamba)


def get_model(
    model_id: str,
    revision: Optional[str],
    sharded: bool,
    quantize: Optional[str],
    speculate: Optional[int],
    dtype: Optional[str],
    trust_remote_code: bool,
) -> Model:
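    """Instantiate the model implementation that matches ``model_id``'s configuration.

    Illustrative call (hypothetical values; in practice the serving CLI supplies
    these arguments):

        model = get_model(
            "bigscience/bloom-560m",
            revision=None,
            sharded=False,
            quantize=None,
            speculate=None,
            dtype="float16",
            trust_remote_code=False,
        )
    """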
    if dtype is None:
        # Keep it as the default for now and let
        # each model resolve its own default dtype.
        dtype = None
    elif dtype == "float16":
        dtype = torch.float16
    elif dtype == "bfloat16":
        dtype = torch.bfloat16
    else:
        raise RuntimeError(f"Unknown dtype {dtype}")

    if speculate is not None:
        set_speculate(speculate)
    else:
        set_speculate(0)

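    # Fetch only the model configuration (no weights) to decide which implementation to use.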
    config_dict, _ = PretrainedConfig.get_config_dict(
        model_id, revision=revision, trust_remote_code=trust_remote_code
    )

    use_medusa = None
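    # Medusa speculative-decoding checkpoints carry "medusa_num_heads" and point at a
    # base model: swap model_id to that base model and fetch the medusa head weights.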
    if "medusa_num_heads" in config_dict:
        medusa_model_id = model_id
        medusa_revision = revision
        model_id = config_dict["base_model_name_or_path"]
        revision = "main"
        speculate_medusa = config_dict["medusa_num_heads"]
        if speculate is not None:
            if speculate > speculate_medusa:
                raise RuntimeError(
                    f"Speculate is set to `{speculate}` but this medusa model only has `{speculate_medusa}` heads, please make them match"
                )
            else:
                set_speculate(speculate)
        else:
            set_speculate(speculate_medusa)

        config_dict, _ = PretrainedConfig.get_config_dict(
            model_id, revision=revision, trust_remote_code=trust_remote_code
        )
        is_local = Path(medusa_model_id).exists()
        if not is_local:
            medusa_config = hf_hub_download(
                medusa_model_id, revision=medusa_revision, filename="config.json"
            )
            hf_hub_download(
                medusa_model_id,
                revision=medusa_revision,
                filename="medusa_lm_head.safetensors",
            )
            use_medusa = Path(medusa_config).parent
        else:
            use_medusa = Path(medusa_model_id)

        method = "medusa"
    else:
        method = "n-gram"

    speculate = get_speculate()
    if speculate > 0:
        logger.info(f"Using speculation {method} with {speculate} input ids.")

    model_type = config_dict.get("model_type", None)
    if model_type is None:
        # TODO: fix how we determine model type for Mamba
        if "ssm_cfg" in config_dict:
            # *only happens in Mamba case
            model_type = "ssm"
        else:
            raise RuntimeError(
                f"Could not determine model type for {model_id} revision {revision}"
            )
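
    # If the checkpoint ships its own quantization_config (e.g. a GPTQ or AWQ export)
    # and no quantize option was passed, pick that method up automatically.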
    quantization_config = config_dict.get("quantization_config", None)
    if quantization_config is not None and quantize is None:
        method = quantization_config.get("quant_method", None)
        if method in {"gptq", "awq"}:
            logger.info(f"Auto selecting quantization method {method}")
            quantize = method
        else:
            logger.info(f"Unknown quantization method {method}")

    if model_type == "ssm":
        return Mamba(
            model_id,
            revision,
            quantize=quantize,
            use_medusa=use_medusa,
            dtype=dtype,
            trust_remote_code=trust_remote_code,
        )

    if model_id.startswith("facebook/galactica"):
        return GalacticaSharded(
            model_id,
            revision,
            quantize=quantize,
            use_medusa=use_medusa,
            dtype=dtype,
            trust_remote_code=trust_remote_code,
        )

    if (
        model_type == "gpt_bigcode"
        or model_type == "gpt2"
        and model_id.startswith("bigcode/")
    ):
        if FLASH_ATTENTION:
            return FlashSantacoderSharded(
                model_id,
                revision,
                quantize=quantize,
                use_medusa=use_medusa,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )
        elif sharded:
            raise NotImplementedError(
                FLASH_ATT_ERROR_MESSAGE.format("Sharded Santacoder")
            )
        else:
            return SantaCoder(
                model_id,
                revision,
                quantize=quantize,
                use_medusa=use_medusa,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )

    if model_type == "bloom":
        return BLOOMSharded(
            model_id,
            revision,
            quantize=quantize,
            use_medusa=use_medusa,
            dtype=dtype,
            trust_remote_code=trust_remote_code,
        )
    elif model_type == "mpt":
        return MPTSharded(
            model_id,
            revision,
            quantize=quantize,
            use_medusa=use_medusa,
            dtype=dtype,
            trust_remote_code=trust_remote_code,
        )

    elif model_type == "gpt_neox":
        if FLASH_ATTENTION:
            return FlashNeoXSharded(
                model_id,
                revision,
                quantize=quantize,
                use_medusa=use_medusa,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )
        elif sharded:
            return GPTNeoxSharded(
                model_id,
                revision,
                quantize=quantize,
                use_medusa=use_medusa,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )
        else:
            return CausalLM(
                model_id,
                revision,
                quantize=quantize,
                use_medusa=use_medusa,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )

    elif model_type == "phi":
        if FLASH_ATTENTION:
            return FlashPhi(
                model_id,
                revision,
                quantize=quantize,
                use_medusa=use_medusa,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )
        else:
            return CausalLM(
                model_id,
                revision,
                quantize=quantize,
                use_medusa=use_medusa,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )

    elif model_type == "phi-msft":
        if FLASH_ATTENTION:
            raise NotImplementedError(
                "Legacy phi-msft is not supported with Flash Attention"
            )
        else:
            return Phi(
                model_id,
                revision,
                quantize=quantize,
                use_medusa=use_medusa,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )

    elif model_type == "llama" or model_type == "baichuan":
        if FLASH_ATTENTION:
            return FlashLlama(
                model_id,
                revision,
                quantize=quantize,
                use_medusa=use_medusa,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )
        elif sharded:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Llama"))
        else:
            return CausalLM(
                model_id,
                revision,
                quantize=quantize,
                use_medusa=use_medusa,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )
    if model_type == "gemma":
        if FLASH_ATTENTION:
            return FlashGemma(
                model_id,
                revision,
                quantize=quantize,
                use_medusa=use_medusa,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )
        elif sharded:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Gemma"))
        else:
            return CausalLM(
                model_id,
                revision,
                quantize=quantize,
                use_medusa=use_medusa,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )

    if model_type == "cohere":
        if FLASH_ATTENTION:
            return FlashCohere(
                model_id,
                revision,
                quantize=quantize,
                use_medusa=use_medusa,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )
        elif sharded:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Cohere"))
        else:
            return CausalLM(
                model_id,
                revision,
                quantize=quantize,
                use_medusa=use_medusa,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )

    if model_type == "dbrx":
        if FLASH_ATTENTION:
            return FlashDbrx(
                model_id,
                revision,
                quantize=quantize,
                use_medusa=use_medusa,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )
        elif sharded:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded DBRX"))
        else:
            return CausalLM(
                model_id,
                revision,
                quantize=quantize,
                use_medusa=use_medusa,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )

    if model_type in ["RefinedWeb", "RefinedWebModel", "falcon"]:
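        # Falcon variants that use alibi cannot run on the flash kernels: they fall
        # back to the non-flash RW implementation and cannot be sharded.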
        if sharded:
            if FLASH_ATTENTION:
                if config_dict.get("alibi", False):
                    raise NotImplementedError("sharded is not supported for this model")
                return FlashRWSharded(
                    model_id,
                    revision,
                    quantize=quantize,
                    use_medusa=use_medusa,
                    dtype=dtype,
                    trust_remote_code=trust_remote_code,
                )
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Falcon"))
        else:
            if FLASH_ATTENTION and not config_dict.get("alibi", False):
                return FlashRWSharded(
                    model_id,
                    revision,
                    quantize=quantize,
                    use_medusa=use_medusa,
                    dtype=dtype,
                    trust_remote_code=trust_remote_code,
                )
            else:
                return RW(
                    model_id,
                    revision,
                    quantize=quantize,
                    use_medusa=use_medusa,
                    dtype=dtype,
                    trust_remote_code=trust_remote_code,
                )

    if model_type == "mistral":
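        # Models configured with a sliding window need the Flash Attention v2 CUDA
        # kernels; without a sliding window the regular flash path is enough.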
        sliding_window = config_dict.get("sliding_window", -1)
        if (
            (sliding_window is None or sliding_window == -1) and FLASH_ATTENTION
        ) or HAS_FLASH_ATTN_V2_CUDA:
            return FlashMistral(
                model_id,
                revision,
                quantize=quantize,
                use_medusa=use_medusa,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )
        elif sharded:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Mistral"))
        else:
            return CausalLM(
                model_id,
                revision,
                quantize=quantize,
                use_medusa=use_medusa,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )

    if model_type == "mixtral":
        sliding_window = config_dict.get("sliding_window", -1)
        if (
            (sliding_window is None or sliding_window == -1) and FLASH_ATTENTION
        ) or HAS_FLASH_ATTN_V2_CUDA:
            return FlashMixtral(
                model_id,
                revision,
                quantize=quantize,
                use_medusa=use_medusa,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )
        elif sharded:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Mixtral"))
        else:
            return CausalLM(
                model_id,
                revision,
                quantize=quantize,
                use_medusa=use_medusa,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )

    if model_type == "starcoder2":
        sliding_window = config_dict.get("sliding_window", -1)
        if (
            (sliding_window is None or sliding_window == -1) and FLASH_ATTENTION
        ) or HAS_FLASH_ATTN_V2_CUDA:
            return FlashStarcoder2(
                model_id,
                revision,
                quantize=quantize,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )
        elif sharded:
            raise NotImplementedError(
                FLASH_ATT_ERROR_MESSAGE.format("Sharded Starcoder2")
            )
        else:
            return CausalLM(
                model_id,
                revision,
                quantize=quantize,
                use_medusa=use_medusa,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )

    if model_type == "qwen2":
        sliding_window = config_dict.get("sliding_window", -1)
        if (
            (sliding_window is None or sliding_window == -1) and FLASH_ATTENTION
        ) or HAS_FLASH_ATTN_V2_CUDA:
            return FlashQwen2(
                model_id,
                revision,
                quantize=quantize,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )
        elif sharded:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Qwen2"))
        else:
            return CausalLM(
                model_id,
                revision,
                quantize=quantize,
                use_medusa=use_medusa,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )

    if model_type == "opt":
        return OPTSharded(
            model_id,
            revision,
            quantize=quantize,
            use_medusa=use_medusa,
            dtype=dtype,
            trust_remote_code=trust_remote_code,
        )

    if model_type == "t5":
        return T5Sharded(
            model_id,
            revision,
            quantize=quantize,
            use_medusa=use_medusa,
            dtype=dtype,
            trust_remote_code=trust_remote_code,
        )
    if model_type == "idefics":
        if FLASH_ATTENTION:
            return IDEFICSSharded(
                model_id,
                revision,
                quantize=quantize,
                use_medusa=use_medusa,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )
        else:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Idefics"))

    # Nothing above matched, so fall back to the generic transformers Auto* classes,
    # which support neither sharding nor GPTQ/AWQ/4-bit/EETQ quantization.
    if sharded:
        raise NotImplementedError("sharded is not supported for AutoModel")
    if quantize == "gptq":
        raise NotImplementedError(
            "gptq quantization is not supported for AutoModel, you can try to quantize it with `text-generation-server quantize ORIGINAL_MODEL_ID NEW_MODEL_ID`"
        )
    if quantize == "awq":
        raise NotImplementedError("awq quantization is not supported for AutoModel")
    elif (quantize == "bitsandbytes-fp4") or (quantize == "bitsandbytes-nf4"):
        raise NotImplementedError("4bit quantization is not supported for AutoModel")
    elif quantize == "eetq":
        raise NotImplementedError("Eetq quantization is not supported for AutoModel")
    if model_type in modeling_auto.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES:
        return CausalLM(
            model_id,
            revision,
            quantize=quantize,
            use_medusa=use_medusa,
            dtype=dtype,
            trust_remote_code=trust_remote_code,
        )
    if model_type in modeling_auto.MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES:
        return Seq2SeqLM(
            model_id,
            revision,
            quantize=quantize,
            use_medusa=use_medusa,
            dtype=dtype,
            trust_remote_code=trust_remote_code,
        )

    auto_map = config_dict.get("auto_map", None)
    if trust_remote_code and auto_map is not None:
        if "AutoModelForCausalLM" in auto_map.keys():
            return CausalLM(
                model_id,
                revision,
                quantize=quantize,
                use_medusa=use_medusa,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )
        if "AutoModelForSeq2SeqLM" in auto_map.keys():
            return Seq2SeqLM(
                model_id,
                revision,
                quantize=quantize,
                use_medusa=use_medusa,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )

    raise ValueError(f"Unsupported model type {model_type}")