__init__.py 29.9 KB
Newer Older
1
import torch
2
import enum
Nicolas Patry's avatar
Nicolas Patry committed
3
import os
4

5
from loguru import logger
6
from transformers.configuration_utils import PretrainedConfig
7
from transformers.models.auto import modeling_auto
Nicolas Patry's avatar
Nicolas Patry committed
8
from huggingface_hub import hf_hub_download, HfApi
drbh's avatar
drbh committed
9
from typing import Optional, List
10
from pathlib import Path
11

Nicolas Patry's avatar
Nicolas Patry committed
12
from text_generation_server.utils.speculate import get_speculate, set_speculate
13
14
from text_generation_server.models.model import Model
from text_generation_server.models.causal_lm import CausalLM
15
from text_generation_server.models.flash_causal_lm import FlashCausalLM
16
from text_generation_server.models.bloom import BLOOMSharded
17
from text_generation_server.models.mpt import MPTSharded
18
from text_generation_server.models.seq2seq_lm import Seq2SeqLM
19
from text_generation_server.models.rw import RW
20
21
from text_generation_server.models.opt import OPTSharded
from text_generation_server.models.galactica import GalacticaSharded
22
23
from text_generation_server.models.santacoder import SantaCoder
from text_generation_server.models.t5 import T5Sharded
24
from text_generation_server.models.gpt_neox import GPTNeoxSharded
drbh's avatar
drbh committed
25
from text_generation_server.models.phi import Phi
26

27
28
from text_generation_server.utils.import_utils import SYSTEM

29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
# Turn off gradient tracking.
torch.set_grad_enabled(False)

# Allow TF32 for CUDA matmuls. PyTorch 1.12 and later leave this flag set to
# False unless it is enabled explicitly.
torch.backends.cuda.matmul.allow_tf32 = True

# Allow TF32 inside cuDNN as well. This flag already defaults to True; it is
# set here so that both TF32 switches are stated in one place.
torch.backends.cudnn.allow_tf32 = True

# Names exported unconditionally. Optional model families extend this list
# further down, once their imports are known to have succeeded.
__all__ = [
    "Model",
    "BLOOMSharded",
    "CausalLM",
    "GalacticaSharded",
    "Seq2SeqLM",
    "SantaCoder",
    "OPTSharded",
    "T5Sharded",
    "get_model",
]

51
# Message template used when a model can only be served through its Flash
# Attention implementation; `{}` receives a short description of the model.
FLASH_ATT_ERROR_MESSAGE = "{} requires Flash Attention enabled models."

# Optimistically assume the Flash Attention models can be imported; the
# `except` branch below flips this to False if any of the imports fails.
FLASH_ATTENTION = True

try:
    from text_generation_server.models.flash_rw import FlashRWSharded
    from text_generation_server.models.flash_gpt2 import FlashGPT2
    from text_generation_server.models.flash_neox import FlashNeoXSharded
    from text_generation_server.models.flash_llama import (
        FlashLlama,
    )
    from text_generation_server.models.flash_qwen2 import (
        FlashQwen2,
    )
    from text_generation_server.models.flash_cohere import (
        FlashCohere,
    )
    from text_generation_server.models.flash_gemma import (
        FlashGemma,
    )
    from text_generation_server.models.pali_gemma import (
        PaliGemma,
    )
    from text_generation_server.models.flash_santacoder import (
        FlashSantacoderSharded,
    )
    from text_generation_server.models.idefics import IDEFICSSharded
    from text_generation_server.models.llava_next import LlavaNext
    from text_generation_server.models.idefics2 import Idefics2
    from text_generation_server.models.flash_mistral import FlashMistral
    from text_generation_server.models.flash_mixtral import FlashMixtral
    from text_generation_server.models.flash_phi import FlashPhi
    from text_generation_server.models.flash_starcoder2 import FlashStarcoder2
    from text_generation_server.models.flash_dbrx import FlashDbrx
    from text_generation_server.layers.attention import SUPPORTS_WINDOWING
except ImportError as e:
    # A single failing import disables the whole Flash Attention family:
    # the names imported after the failure point are left undefined, so every
    # later use of them must be guarded by `FLASH_ATTENTION`.
    logger.warning(f"Could not import Flash Attention enabled models: {e}")
    SUPPORTS_WINDOWING = False
    FLASH_ATTENTION = False
90

91
if FLASH_ATTENTION:
    # `__all__` must contain the *names* of the exported objects as strings.
    # Appending the class objects themselves makes `from <package> import *`
    # fail with a TypeError, so register the string names instead.
    __all__.extend(
        [
            "FlashGPT2",
            "FlashNeoXSharded",
            "FlashRWSharded",
            "FlashSantacoderSharded",
            "FlashLlama",
            "IDEFICSSharded",
            "FlashMistral",
            "FlashMixtral",
            "FlashDbrx",
            "FlashPhi",
            "FlashQwen2",
            "FlashStarcoder2",
            "FlashGemma",
            "FlashCohere",
        ]
    )
OlivierDehaene's avatar
OlivierDehaene committed
106

drbh's avatar
drbh committed
107
108
109
110
111
112
113
114
115
# Mamba is optional: record whether its import succeeded so later code can
# check the flag instead of touching a possibly undefined `Mamba` name.
MAMBA_AVAILABLE = True
try:
    from text_generation_server.models.mamba import Mamba
except ImportError as e:
    logger.warning(f"Could not import Mamba: {e}")
    MAMBA_AVAILABLE = False

if MAMBA_AVAILABLE:
    # `__all__` entries must be strings; appending the class object would make
    # `from <package> import *` raise a TypeError.
    __all__.append("Mamba")
OlivierDehaene's avatar
OlivierDehaene committed
116

117

118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
class ModelType(enum.Enum):
    """Catalogue of the model families known to this module.

    Each member's value is a plain dict with the keys:

    * ``"type"``: the identifier string for the family.
    * ``"name"``: a human-readable display name.
    * ``"url"``: a Hugging Face URL for a representative model or collection.
    * ``"multimodal"`` (optional): present and ``True`` for the families
      flagged as multimodal; absent otherwise.
    """

    IDEFICS2 = {
        "type": "idefics2",
        "name": "Idefics 2",
        "url": "https://huggingface.co/HuggingFaceM4/idefics2-8b",
        "multimodal": True,
    }
    LLAVA_NEXT = {
        "type": "llava_next",
        "name": "Llava Next (1.6)",
        "url": "https://huggingface.co/llava-hf/llava-v1.6-vicuna-13b-hf",
        "multimodal": True,
    }
    LLAMA = {
        "type": "llama",
        "name": "Llama",
        "url": "https://huggingface.co/meta-llama/Meta-Llama-3-8B-Instruct",
    }
    PHI3 = {
        "type": "phi3",
        "name": "Phi 3",
        "url": "https://huggingface.co/microsoft/Phi-3-mini-4k-instruct",
    }
    GEMMA = {
        "type": "gemma",
        "name": "Gemma",
        "url": "https://huggingface.co/google/gemma-7b",
    }
    COHERE = {
        "type": "cohere",
        "name": "Cohere",
        "url": "https://huggingface.co/CohereForAI/c4ai-command-r-plus",
    }
    DBRX = {
        "type": "dbrx",
        "name": "Dbrx",
        "url": "https://huggingface.co/databricks/dbrx-instruct",
    }
    # Note the type string is "ssm", not "mamba".
    MAMBA = {
        "type": "ssm",
        "name": "Mamba",
        "url": "https://huggingface.co/state-spaces/mamba-2.8b-slimpj",
    }
    MISTRAL = {
        "type": "mistral",
        "name": "Mistral",
        "url": "https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.2",
    }
    MIXTRAL = {
        "type": "mixtral",
        "name": "Mixtral",
        "url": "https://huggingface.co/mistralai/Mixtral-8x22B-Instruct-v0.1",
    }
    GPT_BIGCODE = {
        "type": "gpt_bigcode",
        "name": "Gpt Bigcode",
        "url": "https://huggingface.co/bigcode/gpt_bigcode-santacoder",
    }
    PHI = {
        "type": "phi",
        "name": "Phi",
        "url": "https://huggingface.co/microsoft/phi-1_5",
    }
    BAICHUAN = {
        "type": "baichuan",
        "name": "Baichuan",
        "url": "https://huggingface.co/baichuan-inc/Baichuan2-7B-Chat",
    }
    FALCON = {
        "type": "falcon",
        "name": "Falcon",
        "url": "https://huggingface.co/tiiuae/falcon-7b-instruct",
    }
    STARCODER2 = {
        "type": "starcoder2",
        "name": "StarCoder 2",
        "url": "https://huggingface.co/bigcode/starcoder2-15b-instruct-v0.1",
    }
    QWEN2 = {
        "type": "qwen2",
        "name": "Qwen 2",
        "url": "https://huggingface.co/collections/Qwen/qwen2-6659360b33528ced941e557f",
    }
    OPT = {
        "type": "opt",
        "name": "Opt",
        "url": "https://huggingface.co/facebook/opt-6.7b",
    }
    T5 = {
        "type": "t5",
        "name": "T5",
        "url": "https://huggingface.co/google/flan-t5-xxl",
    }
    GALACTICA = {
        "type": "galactica",
        "name": "Galactica",
        "url": "https://huggingface.co/facebook/galactica-120b",
    }
    SANTACODER = {
        "type": "santacoder",
        "name": "SantaCoder",
        "url": "https://huggingface.co/bigcode/santacoder",
    }
    BLOOM = {
        "type": "bloom",
        "name": "Bloom",
        "url": "https://huggingface.co/bigscience/bloom-560m",
    }
    MPT = {
        "type": "mpt",
        "name": "Mpt",
        "url": "https://huggingface.co/mosaicml/mpt-7b-instruct",
    }
    GPT2 = {
        "type": "gpt2",
        "name": "Gpt2",
        "url": "https://huggingface.co/openai-community/gpt2",
    }
    GPT_NEOX = {
        "type": "gpt_neox",
        "name": "Gpt Neox",
        "url": "https://huggingface.co/EleutherAI/gpt-neox-20b",
    }
    IDEFICS = {
        "type": "idefics",
        "name": "Idefics",
        "url": "https://huggingface.co/HuggingFaceM4/idefics-9b",
        "multimodal": True,
    }


# Publish each member's "type" string as a module-level constant named after
# the member (e.g. `LLAMA = "llama"`, `MAMBA = "ssm"`). At module scope
# `locals()` returns the module namespace, so assigning into it defines
# globals. Both `__GLOBALS` and the loop variable `data` stay defined in the
# module namespace after the loop finishes.
__GLOBALS = locals()
for data in ModelType:
    __GLOBALS[data.name] = data.value["type"]


254
def get_model(
255
    model_id: str,
drbh's avatar
drbh committed
256
    lora_adapter_ids: Optional[List[str]],
257
258
259
    revision: Optional[str],
    sharded: bool,
    quantize: Optional[str],
Nicolas Patry's avatar
Nicolas Patry committed
260
    speculate: Optional[int],
261
    dtype: Optional[str],
262
    trust_remote_code: bool,
263
    max_input_tokens: int,
264
) -> Model:
265
    global FLASH_ATTENTION
266
    if dtype is None:
267
        if quantize in ["awq", "exl2", "gptq", "marlin"]:
268
269
270
271
272
273
            # These quantizers only work with float16 params.
            dtype = torch.float16
        else:
            # Keep it as default for now and let
            # every model resolve their own default dtype.
            dtype = None
274
275
276
277
278
279
280
    elif dtype == "float16":
        dtype = torch.float16
    elif dtype == "bfloat16":
        dtype = torch.bfloat16
    else:
        raise RuntimeError(f"Unknown dtype {dtype}")

Nicolas Patry's avatar
Nicolas Patry committed
281
282
283
284
285
    if speculate is not None:
        set_speculate(speculate)
    else:
        set_speculate(0)

OlivierDehaene's avatar
v0.8.2  
OlivierDehaene committed
286
287
288
    config_dict, _ = PretrainedConfig.get_config_dict(
        model_id, revision=revision, trust_remote_code=trust_remote_code
    )
Nicolas Patry's avatar
Nicolas Patry committed
289
    model_type = config_dict.get("model_type", None)
Nicolas Patry's avatar
Nicolas Patry committed
290

Nicolas Patry's avatar
Nicolas Patry committed
291
    speculator = None
Nicolas Patry's avatar
Nicolas Patry committed
292
    if "medusa_num_heads" in config_dict:
293
294
        medusa_model_id = model_id
        medusa_revision = revision
Nicolas Patry's avatar
Nicolas Patry committed
295
296
297
298
299
        model_id = config_dict["base_model_name_or_path"]
        revision = "main"
        speculate_medusa = config_dict["medusa_num_heads"]
        if speculate is not None:
            if speculate > speculate_medusa:
OlivierDehaene's avatar
OlivierDehaene committed
300
                raise RuntimeError(
OlivierDehaene's avatar
OlivierDehaene committed
301
                    f"Speculate is set to `{speculate}` but this medusa models only has `{speculate_medusa}` heads, please make them match"
OlivierDehaene's avatar
OlivierDehaene committed
302
                )
Nicolas Patry's avatar
Nicolas Patry committed
303
304
305
306
307
308
309
310
            else:
                set_speculate(speculate)
        else:
            set_speculate(speculate_medusa)

        config_dict, _ = PretrainedConfig.get_config_dict(
            model_id, revision=revision, trust_remote_code=trust_remote_code
        )
Nicolas Patry's avatar
Nicolas Patry committed
311
312
        # Reload model type from parent.
        model_type = config_dict.get("model_type", None)
313
314
315
316
317
318
319
320
321
322
        is_local = Path(medusa_model_id).exists()
        if not is_local:
            medusa_config = hf_hub_download(
                medusa_model_id, revision=medusa_revision, filename="config.json"
            )
            hf_hub_download(
                medusa_model_id,
                revision=medusa_revision,
                filename="medusa_lm_head.safetensors",
            )
Nicolas Patry's avatar
Nicolas Patry committed
323
324
325
326
            speculator = {
                "path": Path(medusa_config).parent,
                "model_paths": ["medusa_lm_head.safetensors"],
            }
327
        else:
Nicolas Patry's avatar
Nicolas Patry committed
328
329
330
331
            speculator = {
                "path": Path(medusa_model_id),
                "model_paths": ["medusa_lm_head.safetensors"],
            }
332

Nicolas Patry's avatar
Nicolas Patry committed
333
        method = "medusa"
Nicolas Patry's avatar
Nicolas Patry committed
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
    elif model_type == "mlp_speculator":
        mlp_model_id = model_id
        mlp_revision = revision
        model_id = config_dict["base_model_name_or_path"]
        revision = "main"
        speculate_mlp = config_dict["n_predict"]
        if speculate is not None:
            if speculate > speculate_mlp:
                raise RuntimeError(
                    f"Speculate is set to `{speculate}` but this mlp_speculator models only has `{speculate_mlp}` heads, please make them match"
                )
            else:
                set_speculate(speculate)
        else:
            set_speculate(speculate_mlp)

        config_dict, _ = PretrainedConfig.get_config_dict(
            model_id, revision=revision, trust_remote_code=trust_remote_code
        )
        # Reload model type from parent.
        model_type = config_dict.get("model_type", None)
        is_local = Path(mlp_model_id).exists()
        extension = ".safetensors"
        if not is_local:
            mlp_speculator_config = hf_hub_download(
                mlp_model_id, revision=mlp_revision, filename="config.json"
            )
            api = HfApi()
            info = api.model_info(mlp_model_id, revision=mlp_revision)
            filenames = [
                s.rfilename
                for s in info.siblings
                if s.rfilename.endswith(extension)
                and len(s.rfilename.split("/")) == 1
                and "arguments" not in s.rfilename
                and "args" not in s.rfilename
                and "training" not in s.rfilename
            ]
            for filename in filenames:
                hf_hub_download(
                    mlp_model_id,
                    revision=mlp_revision,
                    filename=filename,
                )
            speculator = {
                "path": Path(mlp_speculator_config).parent,
                "model_paths": filenames,
            }
        else:
            speculator = Path(mlp_model_id)
            filenames = [p for p in os.listdir(speculator) if p.endswith(extension)]
            speculator = {"path": speculator, "model_paths": filenames}
        method = "mlp_speculator"
Nicolas Patry's avatar
Nicolas Patry committed
387
388
389
390
391
392
393
    else:
        method = "n-gram"

    speculate = get_speculate()
    if speculate > 0:
        logger.info(f"Using speculation {method} with {speculate} input ids.")

drbh's avatar
drbh committed
394
395
396
397
398
399
400
401
402
    if model_type is None:
        # TODO: fix how we determine model type for Mamba
        if "ssm_cfg" in config_dict:
            # *only happens in Mamba case
            model_type = "ssm"
        else:
            raise RuntimeError(
                f"Could not determine model type for {model_id} revision {revision}"
            )
403
404
405
    quantization_config = config_dict.get("quantization_config", None)
    if quantization_config is not None and quantize is None:
        method = quantization_config.get("quant_method", None)
406
        if method in {"gptq", "awq", "exl2"}:
407
408
409
410
            logger.info(f"Auto selecting quantization method {method}")
            quantize = method
        else:
            logger.info(f"Unknown quantization method {method}")
drbh's avatar
drbh committed
411

412
413
414
415
    if quantize == "exl2" and sharded:
        raise RuntimeError(
            "Sharding is currently not supported with `exl2` quantization"
        )
416
    sliding_window = config_dict.get("sliding_window", -1)
417
418
419
420
421
422
423
424

    if (
        (sliding_window is not None and sliding_window != -1)
        and not SUPPORTS_WINDOWING
        and max_input_tokens > sliding_window
    ):
        raise ValueError(
            f"The backend {SYSTEM} does not support sliding window attention that is used by the model type {model_type}. To use this model nonetheless with the {SYSTEM} backend, please launch TGI with the argument `--max-input-tokens` smaller than sliding_window={sliding_window} (got here max_input_tokens={max_input_tokens})."
425
        )
426

427
    if model_type == MAMBA:
drbh's avatar
drbh committed
428
429
430
431
        return Mamba(
            model_id,
            revision,
            quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
432
            speculator=speculator,
drbh's avatar
drbh committed
433
434
435
            dtype=dtype,
            trust_remote_code=trust_remote_code,
        )
436

OlivierDehaene's avatar
OlivierDehaene committed
437
438
439
440
441
    if model_id.startswith("facebook/galactica"):
        return GalacticaSharded(
            model_id,
            revision,
            quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
442
            speculator=speculator,
OlivierDehaene's avatar
OlivierDehaene committed
443
444
445
446
            dtype=dtype,
            trust_remote_code=trust_remote_code,
        )

447
    if (
448
449
        model_type == GPT_BIGCODE
        or model_type == GPT2
450
451
        and model_id.startswith("bigcode/")
    ):
452
        if FLASH_ATTENTION:
453
454
455
456
            return FlashSantacoderSharded(
                model_id,
                revision,
                quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
457
                speculator=speculator,
458
                dtype=dtype,
459
460
                trust_remote_code=trust_remote_code,
            )
461
462
463
464
        elif sharded:
            raise NotImplementedError(
                FLASH_ATT_ERROR_MESSAGE.format("Sharded Santacoder")
            )
465
        else:
466
            return SantaCoder(
467
468
469
                model_id,
                revision,
                quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
470
                speculator=speculator,
471
                dtype=dtype,
472
473
                trust_remote_code=trust_remote_code,
            )
474

475
    if model_type == BLOOM:
476
        return BLOOMSharded(
477
478
479
            model_id,
            revision,
            quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
480
            speculator=speculator,
481
482
            dtype=dtype,
            trust_remote_code=trust_remote_code,
483
        )
484
    elif model_type == MPT:
485
        return MPTSharded(
OlivierDehaene's avatar
OlivierDehaene committed
486
487
488
            model_id,
            revision,
            quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
489
            speculator=speculator,
OlivierDehaene's avatar
OlivierDehaene committed
490
491
            dtype=dtype,
            trust_remote_code=trust_remote_code,
492
        )
493
    elif model_type == GPT2:
494
        if FLASH_ATTENTION:
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
            try:
                return FlashGPT2(
                    model_id,
                    revision,
                    quantize=quantize,
                    speculator=speculator,
                    dtype=dtype,
                    trust_remote_code=trust_remote_code,
                )
            except RuntimeError as e:
                # Lots of legacy models with various weight names.
                logger.warning(f"Couldn't load flash gpt2 variant: {e}")
                return CausalLM(
                    model_id,
                    revision,
                    quantize=quantize,
                    speculator=speculator,
                    dtype=dtype,
                    trust_remote_code=trust_remote_code,
                )
515
516
517
518
519
520
521
522
523
524
525
        elif sharded:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded GPT-2"))
        else:
            return CausalLM(
                model_id,
                revision,
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )
526
    elif model_type == GPT_NEOX:
527
528
529
530
531
        if FLASH_ATTENTION:
            return FlashNeoXSharded(
                model_id,
                revision,
                quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
532
                speculator=speculator,
533
                dtype=dtype,
534
535
536
537
                trust_remote_code=trust_remote_code,
            )
        elif sharded:
            return GPTNeoxSharded(
538
539
540
                model_id,
                revision,
                quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
541
                speculator=speculator,
542
                dtype=dtype,
543
544
                trust_remote_code=trust_remote_code,
            )
545
        else:
546
            return CausalLM(
547
548
549
                model_id,
                revision,
                quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
550
                speculator=speculator,
551
                dtype=dtype,
552
553
                trust_remote_code=trust_remote_code,
            )
OlivierDehaene's avatar
OlivierDehaene committed
554

555
    elif model_type == PHI:
drbh's avatar
drbh committed
556
557
558
559
560
        if FLASH_ATTENTION:
            return FlashPhi(
                model_id,
                revision,
                quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
561
                speculator=speculator,
drbh's avatar
drbh committed
562
563
564
565
566
567
568
569
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )
        else:
            return CausalLM(
                model_id,
                revision,
                quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
570
                speculator=speculator,
drbh's avatar
drbh committed
571
572
573
574
575
576
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )

    elif model_type == "phi-msft":
        if FLASH_ATTENTION:
OlivierDehaene's avatar
OlivierDehaene committed
577
578
579
            raise NotImplementedError(
                "Legacy phi-msft is not supported with Flash Attention"
            )
drbh's avatar
drbh committed
580
581
582
583
584
        else:
            return Phi(
                model_id,
                revision,
                quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
585
                speculator=speculator,
drbh's avatar
drbh committed
586
587
588
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )
589

590
    elif model_type == LLAMA or model_type == BAICHUAN or model_type == PHI3:
591
592
        if FLASH_ATTENTION:
            return FlashLlama(
593
594
595
                model_id,
                revision,
                quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
596
                speculator=speculator,
597
                dtype=dtype,
598
                trust_remote_code=trust_remote_code,
drbh's avatar
drbh committed
599
                lora_adapter_ids=lora_adapter_ids,
600
            )
601
602
        elif sharded:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Llama"))
603
        else:
604
            return CausalLM(
605
606
607
                model_id,
                revision,
                quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
608
                speculator=speculator,
609
                dtype=dtype,
610
611
                trust_remote_code=trust_remote_code,
            )
612
    if model_type == GEMMA:
613
614
615
616
617
        if FLASH_ATTENTION:
            return FlashGemma(
                model_id,
                revision,
                quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
618
                speculator=speculator,
619
620
621
622
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )
        elif sharded:
OlivierDehaene's avatar
OlivierDehaene committed
623
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Gemma"))
624
625
626
627
628
        else:
            return CausalLM(
                model_id,
                revision,
                quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
629
                speculator=speculator,
630
631
632
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )
633

634
    if model_type == COHERE:
OlivierDehaene's avatar
OlivierDehaene committed
635
636
637
638
639
        if FLASH_ATTENTION:
            return FlashCohere(
                model_id,
                revision,
                quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
640
                speculator=speculator,
OlivierDehaene's avatar
OlivierDehaene committed
641
642
643
644
645
646
647
648
649
650
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )
        elif sharded:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Cohere"))
        else:
            return CausalLM(
                model_id,
                revision,
                quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
651
                speculator=speculator,
OlivierDehaene's avatar
OlivierDehaene committed
652
653
654
655
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )

656
    if model_type == DBRX:
657
658
659
660
661
        if FLASH_ATTENTION:
            return FlashDbrx(
                model_id,
                revision,
                quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
662
                speculator=speculator,
663
664
665
666
667
668
669
670
671
672
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )
        elif sharded:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded DBRX"))
        else:
            return CausalLM(
                model_id,
                revision,
                quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
673
                speculator=speculator,
674
675
676
677
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )

678
    if model_type in ["RefinedWeb", "RefinedWebModel", FALCON]:
679
680
        if sharded:
            if FLASH_ATTENTION:
681
                if config_dict.get("alibi", False):
682
683
684
685
686
                    raise NotImplementedError("sharded is not supported for this model")
                return FlashRWSharded(
                    model_id,
                    revision,
                    quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
687
                    speculator=speculator,
688
                    dtype=dtype,
689
690
                    trust_remote_code=trust_remote_code,
                )
691
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format(f"Sharded Falcon"))
692
        else:
693
            if FLASH_ATTENTION and not config_dict.get("alibi", False):
694
                return FlashRWSharded(
695
696
697
                    model_id,
                    revision,
                    quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
698
                    speculator=speculator,
699
                    dtype=dtype,
700
701
702
703
704
705
706
                    trust_remote_code=trust_remote_code,
                )
            else:
                return RW(
                    model_id,
                    revision,
                    quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
707
                    speculator=speculator,
708
                    dtype=dtype,
709
710
711
                    trust_remote_code=trust_remote_code,
                )

712
    if model_type == MISTRAL:
713
        if FLASH_ATTENTION:
714
715
716
717
            return FlashMistral(
                model_id,
                revision,
                quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
718
                speculator=speculator,
719
720
721
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )
OlivierDehaene's avatar
OlivierDehaene committed
722
723
724
725
726
727
728
        elif sharded:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Mistral"))
        else:
            return CausalLM(
                model_id,
                revision,
                quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
729
                speculator=speculator,
OlivierDehaene's avatar
OlivierDehaene committed
730
731
732
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )
OlivierDehaene's avatar
OlivierDehaene committed
733

734
    if model_type == MIXTRAL:
735
        if FLASH_ATTENTION:
OlivierDehaene's avatar
OlivierDehaene committed
736
737
738
739
            return FlashMixtral(
                model_id,
                revision,
                quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
740
                speculator=speculator,
OlivierDehaene's avatar
OlivierDehaene committed
741
                dtype=dtype,
OlivierDehaene's avatar
OlivierDehaene committed
742
743
                trust_remote_code=trust_remote_code,
            )
OlivierDehaene's avatar
OlivierDehaene committed
744
745
746
747
748
749
750
        elif sharded:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Mixtral"))
        else:
            return CausalLM(
                model_id,
                revision,
                quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
751
                speculator=speculator,
OlivierDehaene's avatar
OlivierDehaene committed
752
753
754
755
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )

756
    if model_type == STARCODER2:
757
        if FLASH_ATTENTION:
OlivierDehaene's avatar
OlivierDehaene committed
758
759
760
761
762
            return FlashStarcoder2(
                model_id,
                revision,
                quantize=quantize,
                dtype=dtype,
OlivierDehaene's avatar
OlivierDehaene committed
763
764
765
766
767
768
769
770
771
772
773
                trust_remote_code=trust_remote_code,
            )
        elif sharded:
            raise NotImplementedError(
                FLASH_ATT_ERROR_MESSAGE.format("Sharded Starcoder2")
            )
        else:
            return CausalLM(
                model_id,
                revision,
                quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
774
                speculator=speculator,
OlivierDehaene's avatar
OlivierDehaene committed
775
776
777
778
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )

779
    if model_type == QWEN2:
780
        if FLASH_ATTENTION:
OlivierDehaene's avatar
OlivierDehaene committed
781
782
783
784
785
786
787
788
789
790
791
792
793
794
            return FlashQwen2(
                model_id,
                revision,
                quantize=quantize,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )
        elif sharded:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Qwen2"))
        else:
            return CausalLM(
                model_id,
                revision,
                quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
795
                speculator=speculator,
OlivierDehaene's avatar
OlivierDehaene committed
796
                dtype=dtype,
OlivierDehaene's avatar
OlivierDehaene committed
797
798
                trust_remote_code=trust_remote_code,
            )
799

800
    if model_type == OPT:
801
        return OPTSharded(
802
803
804
            model_id,
            revision,
            quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
805
            speculator=speculator,
806
807
            dtype=dtype,
            trust_remote_code=trust_remote_code,
808
        )
809

810
    if model_type == T5:
811
812
813
814
        return T5Sharded(
            model_id,
            revision,
            quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
815
            speculator=speculator,
816
            dtype=dtype,
817
818
            trust_remote_code=trust_remote_code,
        )
819
    if model_type == IDEFICS:
820
        if FLASH_ATTENTION:
OlivierDehaene's avatar
OlivierDehaene committed
821
822
823
824
            return IDEFICSSharded(
                model_id,
                revision,
                quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
825
                speculator=speculator,
OlivierDehaene's avatar
OlivierDehaene committed
826
827
828
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )
829
830
        else:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Idefics"))
831
    if model_type == IDEFICS2:
Nicolas Patry's avatar
Nicolas Patry committed
832
833
834
835
836
        if FLASH_ATTENTION:
            return Idefics2(
                model_id,
                revision,
                quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
837
                speculator=speculator,
Nicolas Patry's avatar
Nicolas Patry committed
838
839
840
841
842
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )
        else:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Idefics"))
drbh's avatar
drbh committed
843
844
845
846
847
848
849
850
851
852
853
854
    if model_type == "paligemma":
        if FLASH_ATTENTION:
            return PaliGemma(
                model_id,
                revision,
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )
        else:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Idefics"))
855

856
    if model_type == LLAVA_NEXT:
857
858
859
860
861
        if FLASH_ATTENTION:
            return LlavaNext(
                model_id,
                revision,
                quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
862
                speculator=speculator,
863
864
865
866
867
868
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )
        else:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("LlavaNext"))

869
    if sharded:
870
        raise NotImplementedError("sharded is not supported for AutoModel")
871
    if quantize == "gptq":
872
        raise NotImplementedError(
873
874
            "gptq quantization is not supported for AutoModel, you can try to quantize it with `text-generation-server quantize ORIGINAL_MODEL_ID NEW_MODEL_ID`"
        )
875
    if quantize == "awq":
876
        raise NotImplementedError("awq quantization is not supported for AutoModel")
Nicolas Patry's avatar
Nicolas Patry committed
877
    elif (quantize == "bitsandbytes-fp4") or (quantize == "bitsandbytes-nf4"):
878
        raise NotImplementedError("4bit quantization is not supported for AutoModel")
OlivierDehaene's avatar
OlivierDehaene committed
879
    elif quantize == "eetq":
880
        raise NotImplementedError("Eetq quantization is not supported for AutoModel")
881
882
    elif quantize == "exl2":
        raise NotImplementedError("exl2 quantization is not supported for AutoModel")
883
    if model_type in modeling_auto.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES:
884
        return CausalLM(
885
886
887
            model_id,
            revision,
            quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
888
            speculator=speculator,
889
890
            dtype=dtype,
            trust_remote_code=trust_remote_code,
891
        )
892
    if model_type in modeling_auto.MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES:
893
        return Seq2SeqLM(
894
895
896
            model_id,
            revision,
            quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
897
            speculator=speculator,
898
899
            dtype=dtype,
            trust_remote_code=trust_remote_code,
900
901
        )

902
    auto_map = config_dict.get("auto_map", None)
903
904
905
906
907
908
    if trust_remote_code and auto_map is not None:
        if "AutoModelForCausalLM" in auto_map.keys():
            return CausalLM(
                model_id,
                revision,
                quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
909
                speculator=speculator,
910
                dtype=dtype,
911
912
                trust_remote_code=trust_remote_code,
            )
913
        if "AutoModelForSeq2SeqLM" in auto_map.keys():
914
915
916
917
            return Seq2SeqLM(
                model_id,
                revision,
                quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
918
                speculator=speculator,
919
                dtype=dtype,
920
921
                trust_remote_code=trust_remote_code,
            )
922
923

    raise ValueError(f"Unsupported model type {model_type}")