__init__.py 30.8 KB
Newer Older
1
import torch
2
import enum
Nicolas Patry's avatar
Nicolas Patry committed
3
import os
4

5
from loguru import logger
6
from transformers.configuration_utils import PretrainedConfig
7
from transformers.models.auto import modeling_auto
Nicolas Patry's avatar
Nicolas Patry committed
8
from huggingface_hub import hf_hub_download, HfApi
drbh's avatar
drbh committed
9
from typing import Optional, List
10
from pathlib import Path
11

Nicolas Patry's avatar
Nicolas Patry committed
12
from text_generation_server.utils.speculate import get_speculate, set_speculate
13
14
from text_generation_server.models.model import Model
from text_generation_server.models.causal_lm import CausalLM
15
from text_generation_server.models.bloom import BLOOMSharded
16
from text_generation_server.models.mpt import MPTSharded
17
from text_generation_server.models.seq2seq_lm import Seq2SeqLM
18
from text_generation_server.models.rw import RW
19
20
from text_generation_server.models.opt import OPTSharded
from text_generation_server.models.galactica import GalacticaSharded
21
22
from text_generation_server.models.santacoder import SantaCoder
from text_generation_server.models.t5 import T5Sharded
23
from text_generation_server.models.gpt_neox import GPTNeoxSharded
drbh's avatar
drbh committed
24
from text_generation_server.models.phi import Phi
25

26
27
from text_generation_server.utils.import_utils import SYSTEM

28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
# Allow TensorFloat-32 math inside cuDNN kernels. PyTorch already defaults this
# to True; it is set explicitly so the intent is visible here.
torch.backends.cudnn.allow_tf32 = True

# Allow TensorFloat-32 math for CUDA matmuls. PyTorch 1.12 and later default
# this flag to False, so it must be switched on explicitly.
torch.backends.cuda.matmul.allow_tf32 = True

# Turn off autograd gradient tracking for the whole process.
torch.set_grad_enabled(False)

# Public names re-exported by this package. Optional model classes are added
# further down in this module once their imports succeed.
__all__ = [
    "Model",
    "BLOOMSharded",
    "CausalLM",
    "GalacticaSharded",
    "Seq2SeqLM",
    "SantaCoder",
    "OPTSharded",
    "T5Sharded",
    "get_model",
]

50
# Template for the NotImplementedError raised when a sharded variant of an
# architecture is requested but only its flash-attention implementation
# supports sharding (filled with a human-readable model label).
FLASH_ATT_ERROR_MESSAGE = "{} requires Flash Attention enabled models."

# Assume the flash-attention model implementations are importable; flipped to
# False if any of the imports below fails.
FLASH_ATTENTION = True

try:
    from text_generation_server.models.flash_causal_lm import FlashCausalLM
    from text_generation_server.models.flash_rw import FlashRWSharded
    from text_generation_server.models.flash_gpt2 import FlashGPT2
    from text_generation_server.models.flash_neox import FlashNeoXSharded
    from text_generation_server.models.flash_llama import (
        FlashLlama,
    )
    from text_generation_server.models.flash_qwen2 import (
        FlashQwen2,
    )
    from text_generation_server.models.flash_cohere import (
        FlashCohere,
    )
    from text_generation_server.models.flash_gemma import (
        FlashGemma,
    )
    from text_generation_server.models.flash_gemma2 import (
        FlashGemma2,
    )
    from text_generation_server.models.pali_gemma import (
        PaliGemma,
    )
    from text_generation_server.models.flash_santacoder import (
        FlashSantacoderSharded,
    )
    from text_generation_server.models.idefics import IDEFICSSharded
    from text_generation_server.models.llava_next import LlavaNext
    from text_generation_server.models.idefics2 import Idefics2
    from text_generation_server.models.flash_mistral import FlashMistral
    from text_generation_server.models.flash_mixtral import FlashMixtral
    from text_generation_server.models.flash_phi import FlashPhi
    from text_generation_server.models.flash_starcoder2 import FlashStarcoder2
    from text_generation_server.models.flash_dbrx import FlashDbrx
    from text_generation_server.layers.attention import SUPPORTS_WINDOWING
except ImportError as e:
    logger.warning(f"Could not import Flash Attention enabled models: {e}")
    # Without the attention layer there is no sliding-window support either.
    SUPPORTS_WINDOWING = False
    FLASH_ATTENTION = False

if FLASH_ATTENTION:
    # `__all__` must hold names (str), not the class objects themselves:
    # `from <package> import *` raises TypeError on non-string entries.
    __all__.extend(
        [
            "FlashCausalLM",
            "FlashGPT2",
            "FlashNeoXSharded",
            "FlashRWSharded",
            "FlashSantacoderSharded",
            "FlashLlama",
            "IDEFICSSharded",
            "FlashMistral",
            "FlashMixtral",
            "FlashDbrx",
            "FlashPhi",
            "FlashQwen2",
            "FlashStarcoder2",
            "FlashGemma",
            "FlashGemma2",
            "FlashCohere",
        ]
    )
OlivierDehaene's avatar
OlivierDehaene committed
111

drbh's avatar
drbh committed
112
113
114
115
116
117
118
119
120
# Mamba is optional: its import may fail, in which case it is simply not offered.
try:
    from text_generation_server.models.mamba import Mamba
except ImportError as e:
    logger.warning(f"Could not import Mamba: {e}")
    MAMBA_AVAILABLE = False
else:
    MAMBA_AVAILABLE = True
    # Export the *name*; `from <package> import *` rejects non-str entries.
    __all__.append("Mamba")
OlivierDehaene's avatar
OlivierDehaene committed
121

122

123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
@enum.unique
class ModelType(enum.Enum):
    """Catalogue of supported model architectures.

    Each member's value is a dict with:

    * ``type``: the model-type string this architecture is identified by
      (``"ssm"`` for Mamba).
    * ``name``: a human-readable display name.
    * ``url``: a link to an example checkpoint or collection on the Hub.
    * ``multimodal`` (optional): ``True`` for vision-language models.

    ``@enum.unique`` guards against two members accidentally sharing an
    identical value. Otherwise the second member would silently become an alias,
    and iteration over ``ModelType`` skips aliases.
    """

    IDEFICS2 = {
        "type": "idefics2",
        "name": "Idefics 2",
        "url": "https://huggingface.co/HuggingFaceM4/idefics2-8b",
        "multimodal": True,
    }
    LLAVA_NEXT = {
        "type": "llava_next",
        "name": "Llava Next (1.6)",
        "url": "https://huggingface.co/llava-hf/llava-v1.6-vicuna-13b-hf",
        "multimodal": True,
    }
    LLAMA = {
        "type": "llama",
        "name": "Llama",
        "url": "https://huggingface.co/meta-llama/Meta-Llama-3-8B-Instruct",
    }
    PHI3 = {
        "type": "phi3",
        "name": "Phi 3",
        "url": "https://huggingface.co/microsoft/Phi-3-mini-4k-instruct",
    }
    GEMMA = {
        "type": "gemma",
        "name": "Gemma",
        "url": "https://huggingface.co/google/gemma-7b",
    }
    GEMMA2 = {
        "type": "gemma2",
        "name": "Gemma2",
        # The Gemma 2 repositories use a hyphen: google/gemma-2-9b.
        "url": "https://huggingface.co/google/gemma-2-9b",
    }
    COHERE = {
        "type": "cohere",
        "name": "Cohere",
        "url": "https://huggingface.co/CohereForAI/c4ai-command-r-plus",
    }
    DBRX = {
        "type": "dbrx",
        "name": "Dbrx",
        "url": "https://huggingface.co/databricks/dbrx-instruct",
    }
    MAMBA = {
        "type": "ssm",
        "name": "Mamba",
        "url": "https://huggingface.co/state-spaces/mamba-2.8b-slimpj",
    }
    MISTRAL = {
        "type": "mistral",
        "name": "Mistral",
        "url": "https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.2",
    }
    MIXTRAL = {
        "type": "mixtral",
        "name": "Mixtral",
        "url": "https://huggingface.co/mistralai/Mixtral-8x22B-Instruct-v0.1",
    }
    GPT_BIGCODE = {
        "type": "gpt_bigcode",
        "name": "Gpt Bigcode",
        "url": "https://huggingface.co/bigcode/gpt_bigcode-santacoder",
    }
    PHI = {
        "type": "phi",
        "name": "Phi",
        "url": "https://huggingface.co/microsoft/phi-1_5",
    }
    BAICHUAN = {
        "type": "baichuan",
        "name": "Baichuan",
        "url": "https://huggingface.co/baichuan-inc/Baichuan2-7B-Chat",
    }
    FALCON = {
        "type": "falcon",
        "name": "Falcon",
        "url": "https://huggingface.co/tiiuae/falcon-7b-instruct",
    }
    STARCODER2 = {
        "type": "starcoder2",
        "name": "StarCoder 2",
        "url": "https://huggingface.co/bigcode/starcoder2-15b-instruct-v0.1",
    }
    QWEN2 = {
        "type": "qwen2",
        "name": "Qwen 2",
        "url": "https://huggingface.co/collections/Qwen/qwen2-6659360b33528ced941e557f",
    }
    OPT = {
        "type": "opt",
        "name": "Opt",
        "url": "https://huggingface.co/facebook/opt-6.7b",
    }
    T5 = {
        "type": "t5",
        "name": "T5",
        "url": "https://huggingface.co/google/flan-t5-xxl",
    }
    GALACTICA = {
        "type": "galactica",
        "name": "Galactica",
        "url": "https://huggingface.co/facebook/galactica-120b",
    }
    SANTACODER = {
        "type": "santacoder",
        "name": "SantaCoder",
        "url": "https://huggingface.co/bigcode/santacoder",
    }
    BLOOM = {
        "type": "bloom",
        "name": "Bloom",
        "url": "https://huggingface.co/bigscience/bloom-560m",
    }
    MPT = {
        "type": "mpt",
        "name": "Mpt",
        "url": "https://huggingface.co/mosaicml/mpt-7b-instruct",
    }
    GPT2 = {
        "type": "gpt2",
        "name": "Gpt2",
        "url": "https://huggingface.co/openai-community/gpt2",
    }
    GPT_NEOX = {
        "type": "gpt_neox",
        "name": "Gpt Neox",
        "url": "https://huggingface.co/EleutherAI/gpt-neox-20b",
    }
    IDEFICS = {
        "type": "idefics",
        "name": "Idefics",
        "url": "https://huggingface.co/HuggingFaceM4/idefics-9b",
        "multimodal": True,
    }


# Expose every ModelType member as a module-level string constant holding its
# "type" (e.g. LLAMA == "llama", MAMBA == "ssm"). get_model compares the
# resolved model_type against these plain names.
# globals() is used explicitly: at module scope locals() is the same dict, but
# that is easy to misread. The comprehension also avoids leaking a loop
# variable into the module namespace.
__GLOBALS = globals()
__GLOBALS.update({member.name: member.value["type"] for member in ModelType})


264
def get_model(
265
    model_id: str,
drbh's avatar
drbh committed
266
    lora_adapter_ids: Optional[List[str]],
267
268
269
    revision: Optional[str],
    sharded: bool,
    quantize: Optional[str],
Nicolas Patry's avatar
Nicolas Patry committed
270
    speculate: Optional[int],
271
    dtype: Optional[str],
272
    trust_remote_code: bool,
273
    max_input_tokens: int,
274
) -> Model:
275
    global FLASH_ATTENTION
276
    if dtype is None:
277
        if quantize in ["awq", "exl2", "gptq", "marlin"]:
278
279
280
281
282
283
            # These quantizers only work with float16 params.
            dtype = torch.float16
        else:
            # Keep it as default for now and let
            # every model resolve their own default dtype.
            dtype = None
284
285
286
287
288
289
290
    elif dtype == "float16":
        dtype = torch.float16
    elif dtype == "bfloat16":
        dtype = torch.bfloat16
    else:
        raise RuntimeError(f"Unknown dtype {dtype}")

Nicolas Patry's avatar
Nicolas Patry committed
291
292
293
294
295
    if speculate is not None:
        set_speculate(speculate)
    else:
        set_speculate(0)

OlivierDehaene's avatar
v0.8.2  
OlivierDehaene committed
296
297
298
    config_dict, _ = PretrainedConfig.get_config_dict(
        model_id, revision=revision, trust_remote_code=trust_remote_code
    )
Nicolas Patry's avatar
Nicolas Patry committed
299
    model_type = config_dict.get("model_type", None)
Nicolas Patry's avatar
Nicolas Patry committed
300

Nicolas Patry's avatar
Nicolas Patry committed
301
    speculator = None
Nicolas Patry's avatar
Nicolas Patry committed
302
    if "medusa_num_heads" in config_dict:
303
304
        medusa_model_id = model_id
        medusa_revision = revision
Nicolas Patry's avatar
Nicolas Patry committed
305
306
307
308
309
        model_id = config_dict["base_model_name_or_path"]
        revision = "main"
        speculate_medusa = config_dict["medusa_num_heads"]
        if speculate is not None:
            if speculate > speculate_medusa:
OlivierDehaene's avatar
OlivierDehaene committed
310
                raise RuntimeError(
OlivierDehaene's avatar
OlivierDehaene committed
311
                    f"Speculate is set to `{speculate}` but this medusa models only has `{speculate_medusa}` heads, please make them match"
OlivierDehaene's avatar
OlivierDehaene committed
312
                )
Nicolas Patry's avatar
Nicolas Patry committed
313
314
315
316
317
318
319
320
            else:
                set_speculate(speculate)
        else:
            set_speculate(speculate_medusa)

        config_dict, _ = PretrainedConfig.get_config_dict(
            model_id, revision=revision, trust_remote_code=trust_remote_code
        )
Nicolas Patry's avatar
Nicolas Patry committed
321
322
        # Reload model type from parent.
        model_type = config_dict.get("model_type", None)
323
324
325
326
327
328
329
330
331
332
        is_local = Path(medusa_model_id).exists()
        if not is_local:
            medusa_config = hf_hub_download(
                medusa_model_id, revision=medusa_revision, filename="config.json"
            )
            hf_hub_download(
                medusa_model_id,
                revision=medusa_revision,
                filename="medusa_lm_head.safetensors",
            )
Nicolas Patry's avatar
Nicolas Patry committed
333
334
335
336
            speculator = {
                "path": Path(medusa_config).parent,
                "model_paths": ["medusa_lm_head.safetensors"],
            }
337
        else:
Nicolas Patry's avatar
Nicolas Patry committed
338
339
340
341
            speculator = {
                "path": Path(medusa_model_id),
                "model_paths": ["medusa_lm_head.safetensors"],
            }
342

Nicolas Patry's avatar
Nicolas Patry committed
343
        method = "medusa"
Nicolas Patry's avatar
Nicolas Patry committed
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
    elif model_type == "mlp_speculator":
        mlp_model_id = model_id
        mlp_revision = revision
        model_id = config_dict["base_model_name_or_path"]
        revision = "main"
        speculate_mlp = config_dict["n_predict"]
        if speculate is not None:
            if speculate > speculate_mlp:
                raise RuntimeError(
                    f"Speculate is set to `{speculate}` but this mlp_speculator models only has `{speculate_mlp}` heads, please make them match"
                )
            else:
                set_speculate(speculate)
        else:
            set_speculate(speculate_mlp)

        config_dict, _ = PretrainedConfig.get_config_dict(
            model_id, revision=revision, trust_remote_code=trust_remote_code
        )
        # Reload model type from parent.
        model_type = config_dict.get("model_type", None)
        is_local = Path(mlp_model_id).exists()
        extension = ".safetensors"
        if not is_local:
            mlp_speculator_config = hf_hub_download(
                mlp_model_id, revision=mlp_revision, filename="config.json"
            )
            api = HfApi()
            info = api.model_info(mlp_model_id, revision=mlp_revision)
            filenames = [
                s.rfilename
                for s in info.siblings
                if s.rfilename.endswith(extension)
                and len(s.rfilename.split("/")) == 1
                and "arguments" not in s.rfilename
                and "args" not in s.rfilename
                and "training" not in s.rfilename
            ]
            for filename in filenames:
                hf_hub_download(
                    mlp_model_id,
                    revision=mlp_revision,
                    filename=filename,
                )
            speculator = {
                "path": Path(mlp_speculator_config).parent,
                "model_paths": filenames,
            }
        else:
            speculator = Path(mlp_model_id)
            filenames = [p for p in os.listdir(speculator) if p.endswith(extension)]
            speculator = {"path": speculator, "model_paths": filenames}
        method = "mlp_speculator"
Nicolas Patry's avatar
Nicolas Patry committed
397
398
399
400
401
402
403
    else:
        method = "n-gram"

    speculate = get_speculate()
    if speculate > 0:
        logger.info(f"Using speculation {method} with {speculate} input ids.")

drbh's avatar
drbh committed
404
405
406
407
408
409
410
411
412
    if model_type is None:
        # TODO: fix how we determine model type for Mamba
        if "ssm_cfg" in config_dict:
            # *only happens in Mamba case
            model_type = "ssm"
        else:
            raise RuntimeError(
                f"Could not determine model type for {model_id} revision {revision}"
            )
413
414
415
    quantization_config = config_dict.get("quantization_config", None)
    if quantization_config is not None and quantize is None:
        method = quantization_config.get("quant_method", None)
416
        if method in {"gptq", "awq", "exl2"}:
417
418
419
420
            logger.info(f"Auto selecting quantization method {method}")
            quantize = method
        else:
            logger.info(f"Unknown quantization method {method}")
drbh's avatar
drbh committed
421

422
423
424
425
    if quantize == "exl2" and sharded:
        raise RuntimeError(
            "Sharding is currently not supported with `exl2` quantization"
        )
426
    sliding_window = config_dict.get("sliding_window", -1)
427
428
429
430
431
432
433
434

    if (
        (sliding_window is not None and sliding_window != -1)
        and not SUPPORTS_WINDOWING
        and max_input_tokens > sliding_window
    ):
        raise ValueError(
            f"The backend {SYSTEM} does not support sliding window attention that is used by the model type {model_type}. To use this model nonetheless with the {SYSTEM} backend, please launch TGI with the argument `--max-input-tokens` smaller than sliding_window={sliding_window} (got here max_input_tokens={max_input_tokens})."
435
        )
436

437
    if model_type == MAMBA:
drbh's avatar
drbh committed
438
439
440
441
        return Mamba(
            model_id,
            revision,
            quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
442
            speculator=speculator,
drbh's avatar
drbh committed
443
444
445
            dtype=dtype,
            trust_remote_code=trust_remote_code,
        )
446

OlivierDehaene's avatar
OlivierDehaene committed
447
448
449
450
451
    if model_id.startswith("facebook/galactica"):
        return GalacticaSharded(
            model_id,
            revision,
            quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
452
            speculator=speculator,
OlivierDehaene's avatar
OlivierDehaene committed
453
454
455
456
            dtype=dtype,
            trust_remote_code=trust_remote_code,
        )

457
    if (
458
459
        model_type == GPT_BIGCODE
        or model_type == GPT2
460
461
        and model_id.startswith("bigcode/")
    ):
462
        if FLASH_ATTENTION:
463
464
465
466
            return FlashSantacoderSharded(
                model_id,
                revision,
                quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
467
                speculator=speculator,
468
                dtype=dtype,
469
470
                trust_remote_code=trust_remote_code,
            )
471
472
473
474
        elif sharded:
            raise NotImplementedError(
                FLASH_ATT_ERROR_MESSAGE.format("Sharded Santacoder")
            )
475
        else:
476
            return SantaCoder(
477
478
479
                model_id,
                revision,
                quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
480
                speculator=speculator,
481
                dtype=dtype,
482
483
                trust_remote_code=trust_remote_code,
            )
484

485
    if model_type == BLOOM:
486
        return BLOOMSharded(
487
488
489
            model_id,
            revision,
            quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
490
            speculator=speculator,
491
492
            dtype=dtype,
            trust_remote_code=trust_remote_code,
493
        )
494
    elif model_type == MPT:
495
        return MPTSharded(
OlivierDehaene's avatar
OlivierDehaene committed
496
497
498
            model_id,
            revision,
            quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
499
            speculator=speculator,
OlivierDehaene's avatar
OlivierDehaene committed
500
501
            dtype=dtype,
            trust_remote_code=trust_remote_code,
502
        )
503
    elif model_type == GPT2:
504
        if FLASH_ATTENTION:
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
            try:
                return FlashGPT2(
                    model_id,
                    revision,
                    quantize=quantize,
                    speculator=speculator,
                    dtype=dtype,
                    trust_remote_code=trust_remote_code,
                )
            except RuntimeError as e:
                # Lots of legacy models with various weight names.
                logger.warning(f"Couldn't load flash gpt2 variant: {e}")
                return CausalLM(
                    model_id,
                    revision,
                    quantize=quantize,
                    speculator=speculator,
                    dtype=dtype,
                    trust_remote_code=trust_remote_code,
                )
525
526
527
528
529
530
531
532
533
534
535
        elif sharded:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded GPT-2"))
        else:
            return CausalLM(
                model_id,
                revision,
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )
536
    elif model_type == GPT_NEOX:
537
538
539
540
541
        if FLASH_ATTENTION:
            return FlashNeoXSharded(
                model_id,
                revision,
                quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
542
                speculator=speculator,
543
                dtype=dtype,
544
545
546
547
                trust_remote_code=trust_remote_code,
            )
        elif sharded:
            return GPTNeoxSharded(
548
549
550
                model_id,
                revision,
                quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
551
                speculator=speculator,
552
                dtype=dtype,
553
554
                trust_remote_code=trust_remote_code,
            )
555
        else:
556
            return CausalLM(
557
558
559
                model_id,
                revision,
                quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
560
                speculator=speculator,
561
                dtype=dtype,
562
563
                trust_remote_code=trust_remote_code,
            )
OlivierDehaene's avatar
OlivierDehaene committed
564

565
    elif model_type == PHI:
drbh's avatar
drbh committed
566
567
568
569
570
        if FLASH_ATTENTION:
            return FlashPhi(
                model_id,
                revision,
                quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
571
                speculator=speculator,
drbh's avatar
drbh committed
572
573
574
575
576
577
578
579
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )
        else:
            return CausalLM(
                model_id,
                revision,
                quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
580
                speculator=speculator,
drbh's avatar
drbh committed
581
582
583
584
585
586
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )

    elif model_type == "phi-msft":
        if FLASH_ATTENTION:
OlivierDehaene's avatar
OlivierDehaene committed
587
588
589
            raise NotImplementedError(
                "Legacy phi-msft is not supported with Flash Attention"
            )
drbh's avatar
drbh committed
590
591
592
593
594
        else:
            return Phi(
                model_id,
                revision,
                quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
595
                speculator=speculator,
drbh's avatar
drbh committed
596
597
598
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )
599

600
    elif model_type == LLAMA or model_type == BAICHUAN or model_type == PHI3:
601
602
        if FLASH_ATTENTION:
            return FlashLlama(
603
604
605
                model_id,
                revision,
                quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
606
                speculator=speculator,
607
                dtype=dtype,
608
                trust_remote_code=trust_remote_code,
drbh's avatar
drbh committed
609
                lora_adapter_ids=lora_adapter_ids,
610
            )
611
612
        elif sharded:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Llama"))
613
        else:
614
            return CausalLM(
615
616
617
                model_id,
                revision,
                quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
618
                speculator=speculator,
619
                dtype=dtype,
620
621
                trust_remote_code=trust_remote_code,
            )
622
    if model_type == GEMMA:
623
624
625
626
627
        if FLASH_ATTENTION:
            return FlashGemma(
                model_id,
                revision,
                quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
628
                speculator=speculator,
629
630
631
632
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )
        elif sharded:
OlivierDehaene's avatar
OlivierDehaene committed
633
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Gemma"))
634
635
636
637
638
        else:
            return CausalLM(
                model_id,
                revision,
                quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
639
                speculator=speculator,
640
641
642
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )
Nicolas Patry's avatar
Nicolas Patry committed
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
    elif model_type == GEMMA2:
        if FLASH_ATTENTION:
            return FlashGemma2(
                model_id,
                revision,
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )
        elif sharded:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Gemma2"))
        else:
            return CausalLM(
                model_id,
                revision,
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )
664

665
    if model_type == COHERE:
OlivierDehaene's avatar
OlivierDehaene committed
666
667
668
669
670
        if FLASH_ATTENTION:
            return FlashCohere(
                model_id,
                revision,
                quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
671
                speculator=speculator,
OlivierDehaene's avatar
OlivierDehaene committed
672
673
674
675
676
677
678
679
680
681
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )
        elif sharded:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Cohere"))
        else:
            return CausalLM(
                model_id,
                revision,
                quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
682
                speculator=speculator,
OlivierDehaene's avatar
OlivierDehaene committed
683
684
685
686
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )

687
    if model_type == DBRX:
688
689
690
691
692
        if FLASH_ATTENTION:
            return FlashDbrx(
                model_id,
                revision,
                quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
693
                speculator=speculator,
694
695
696
697
698
699
700
701
702
703
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )
        elif sharded:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded DBRX"))
        else:
            return CausalLM(
                model_id,
                revision,
                quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
704
                speculator=speculator,
705
706
707
708
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )

709
    if model_type in ["RefinedWeb", "RefinedWebModel", FALCON]:
710
711
        if sharded:
            if FLASH_ATTENTION:
712
                if config_dict.get("alibi", False):
713
714
715
716
717
                    raise NotImplementedError("sharded is not supported for this model")
                return FlashRWSharded(
                    model_id,
                    revision,
                    quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
718
                    speculator=speculator,
719
                    dtype=dtype,
720
721
                    trust_remote_code=trust_remote_code,
                )
722
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format(f"Sharded Falcon"))
723
        else:
724
            if FLASH_ATTENTION and not config_dict.get("alibi", False):
725
                return FlashRWSharded(
726
727
728
                    model_id,
                    revision,
                    quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
729
                    speculator=speculator,
730
                    dtype=dtype,
731
732
733
734
735
736
737
                    trust_remote_code=trust_remote_code,
                )
            else:
                return RW(
                    model_id,
                    revision,
                    quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
738
                    speculator=speculator,
739
                    dtype=dtype,
740
741
742
                    trust_remote_code=trust_remote_code,
                )

743
    if model_type == MISTRAL:
744
        if FLASH_ATTENTION:
745
746
747
748
            return FlashMistral(
                model_id,
                revision,
                quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
749
                speculator=speculator,
750
751
752
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )
OlivierDehaene's avatar
OlivierDehaene committed
753
754
755
756
757
758
759
        elif sharded:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Mistral"))
        else:
            return CausalLM(
                model_id,
                revision,
                quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
760
                speculator=speculator,
OlivierDehaene's avatar
OlivierDehaene committed
761
762
763
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )
OlivierDehaene's avatar
OlivierDehaene committed
764

765
    if model_type == MIXTRAL:
766
        if FLASH_ATTENTION:
OlivierDehaene's avatar
OlivierDehaene committed
767
768
769
770
            return FlashMixtral(
                model_id,
                revision,
                quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
771
                speculator=speculator,
OlivierDehaene's avatar
OlivierDehaene committed
772
                dtype=dtype,
OlivierDehaene's avatar
OlivierDehaene committed
773
774
                trust_remote_code=trust_remote_code,
            )
OlivierDehaene's avatar
OlivierDehaene committed
775
776
777
778
779
780
781
        elif sharded:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Mixtral"))
        else:
            return CausalLM(
                model_id,
                revision,
                quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
782
                speculator=speculator,
OlivierDehaene's avatar
OlivierDehaene committed
783
784
785
786
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )

787
    if model_type == STARCODER2:
788
        if FLASH_ATTENTION:
OlivierDehaene's avatar
OlivierDehaene committed
789
790
791
792
793
            return FlashStarcoder2(
                model_id,
                revision,
                quantize=quantize,
                dtype=dtype,
OlivierDehaene's avatar
OlivierDehaene committed
794
795
796
797
798
799
800
801
802
803
804
                trust_remote_code=trust_remote_code,
            )
        elif sharded:
            raise NotImplementedError(
                FLASH_ATT_ERROR_MESSAGE.format("Sharded Starcoder2")
            )
        else:
            return CausalLM(
                model_id,
                revision,
                quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
805
                speculator=speculator,
OlivierDehaene's avatar
OlivierDehaene committed
806
807
808
809
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )

810
    if model_type == QWEN2:
811
        if FLASH_ATTENTION:
OlivierDehaene's avatar
OlivierDehaene committed
812
813
814
815
816
817
818
819
820
821
822
823
824
825
            return FlashQwen2(
                model_id,
                revision,
                quantize=quantize,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )
        elif sharded:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Qwen2"))
        else:
            return CausalLM(
                model_id,
                revision,
                quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
826
                speculator=speculator,
OlivierDehaene's avatar
OlivierDehaene committed
827
                dtype=dtype,
OlivierDehaene's avatar
OlivierDehaene committed
828
829
                trust_remote_code=trust_remote_code,
            )
830

831
    if model_type == OPT:
832
        return OPTSharded(
833
834
835
            model_id,
            revision,
            quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
836
            speculator=speculator,
837
838
            dtype=dtype,
            trust_remote_code=trust_remote_code,
839
        )
840

841
    if model_type == T5:
842
843
844
845
        return T5Sharded(
            model_id,
            revision,
            quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
846
            speculator=speculator,
847
            dtype=dtype,
848
849
            trust_remote_code=trust_remote_code,
        )
850
    if model_type == IDEFICS:
851
        if FLASH_ATTENTION:
OlivierDehaene's avatar
OlivierDehaene committed
852
853
854
855
            return IDEFICSSharded(
                model_id,
                revision,
                quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
856
                speculator=speculator,
OlivierDehaene's avatar
OlivierDehaene committed
857
858
859
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )
860
861
        else:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Idefics"))
862
    if model_type == IDEFICS2:
Nicolas Patry's avatar
Nicolas Patry committed
863
864
865
866
867
        if FLASH_ATTENTION:
            return Idefics2(
                model_id,
                revision,
                quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
868
                speculator=speculator,
Nicolas Patry's avatar
Nicolas Patry committed
869
870
871
872
873
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )
        else:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Idefics"))
drbh's avatar
drbh committed
874
875
876
877
878
879
880
881
882
883
884
885
    if model_type == "paligemma":
        if FLASH_ATTENTION:
            return PaliGemma(
                model_id,
                revision,
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )
        else:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Idefics"))
886

887
    if model_type == LLAVA_NEXT:
888
889
890
891
892
        if FLASH_ATTENTION:
            return LlavaNext(
                model_id,
                revision,
                quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
893
                speculator=speculator,
894
895
896
897
898
899
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )
        else:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("LlavaNext"))

900
    if sharded:
901
        raise NotImplementedError("sharded is not supported for AutoModel")
902
    if quantize == "gptq":
903
        raise NotImplementedError(
904
905
            "gptq quantization is not supported for AutoModel, you can try to quantize it with `text-generation-server quantize ORIGINAL_MODEL_ID NEW_MODEL_ID`"
        )
906
    if quantize == "awq":
907
        raise NotImplementedError("awq quantization is not supported for AutoModel")
Nicolas Patry's avatar
Nicolas Patry committed
908
    elif (quantize == "bitsandbytes-fp4") or (quantize == "bitsandbytes-nf4"):
909
        raise NotImplementedError("4bit quantization is not supported for AutoModel")
OlivierDehaene's avatar
OlivierDehaene committed
910
    elif quantize == "eetq":
911
        raise NotImplementedError("Eetq quantization is not supported for AutoModel")
912
913
    elif quantize == "exl2":
        raise NotImplementedError("exl2 quantization is not supported for AutoModel")
914
    if model_type in modeling_auto.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES:
915
        return CausalLM(
916
917
918
            model_id,
            revision,
            quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
919
            speculator=speculator,
920
921
            dtype=dtype,
            trust_remote_code=trust_remote_code,
922
        )
923
    if model_type in modeling_auto.MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES:
924
        return Seq2SeqLM(
925
926
927
            model_id,
            revision,
            quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
928
            speculator=speculator,
929
930
            dtype=dtype,
            trust_remote_code=trust_remote_code,
931
932
        )

933
    auto_map = config_dict.get("auto_map", None)
934
935
936
937
938
939
    if trust_remote_code and auto_map is not None:
        if "AutoModelForCausalLM" in auto_map.keys():
            return CausalLM(
                model_id,
                revision,
                quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
940
                speculator=speculator,
941
                dtype=dtype,
942
943
                trust_remote_code=trust_remote_code,
            )
944
        if "AutoModelForSeq2SeqLM" in auto_map.keys():
945
946
947
948
            return Seq2SeqLM(
                model_id,
                revision,
                quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
949
                speculator=speculator,
950
                dtype=dtype,
951
952
                trust_remote_code=trust_remote_code,
            )
953
954

    raise ValueError(f"Unsupported model type {model_type}")