__init__.py 30.8 KB
Newer Older
1
import torch
2
import enum
Nicolas Patry's avatar
Nicolas Patry committed
3
import os
4

5
from loguru import logger
6
from transformers.configuration_utils import PretrainedConfig
7
from transformers.models.auto import modeling_auto
Nicolas Patry's avatar
Nicolas Patry committed
8
from huggingface_hub import hf_hub_download, HfApi
drbh's avatar
drbh committed
9
from typing import Optional, List
10
from pathlib import Path
11

Nicolas Patry's avatar
Nicolas Patry committed
12
from text_generation_server.utils.speculate import get_speculate, set_speculate
13
14
from text_generation_server.models.model import Model
from text_generation_server.models.causal_lm import CausalLM
15
from text_generation_server.models.flash_causal_lm import FlashCausalLM
16
from text_generation_server.models.bloom import BLOOMSharded
17
from text_generation_server.models.mpt import MPTSharded
18
from text_generation_server.models.seq2seq_lm import Seq2SeqLM
19
from text_generation_server.models.rw import RW
20
21
from text_generation_server.models.opt import OPTSharded
from text_generation_server.models.galactica import GalacticaSharded
22
23
from text_generation_server.models.santacoder import SantaCoder
from text_generation_server.models.t5 import T5Sharded
24
from text_generation_server.models.gpt_neox import GPTNeoxSharded
drbh's avatar
drbh committed
25
from text_generation_server.models.phi import Phi
26

27
28
from text_generation_server.utils.import_utils import SYSTEM

29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
# Disable autograd gradient tracking up front, before any model is built by
# this module.
torch.set_grad_enabled(False)

# Opt in to TensorFloat-32 math for CUDA matmuls and for cuDNN:
# - matmul: PyTorch 1.12 and later default this flag to False, so enable it.
# - cuDNN: this flag already defaults to True; it is set explicitly anyway.
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True

# Names exported by ``from text_generation_server.models import *``.
# Optional (flash / mamba) implementations are appended further down once
# their imports succeed.
__all__ = [
    "Model",
    "BLOOMSharded",
    "CausalLM",
    "GalacticaSharded",
    "Seq2SeqLM",
    "SantaCoder",
    "OPTSharded",
    "T5Sharded",
    "get_model",
]

51
FLASH_ATT_ERROR_MESSAGE = "{} requires Flash Attention enabled models."

# Assume the flash-attention model family is importable; the ``except``
# branch below flips this to False if any single import fails.
FLASH_ATTENTION = True

try:
    from text_generation_server.models.flash_rw import FlashRWSharded
    from text_generation_server.models.flash_gpt2 import FlashGPT2
    from text_generation_server.models.flash_neox import FlashNeoXSharded
    from text_generation_server.models.flash_llama import FlashLlama
    from text_generation_server.models.flash_qwen2 import FlashQwen2
    from text_generation_server.models.flash_cohere import FlashCohere
    from text_generation_server.models.flash_gemma import FlashGemma
    from text_generation_server.models.flash_gemma2 import FlashGemma2
    from text_generation_server.models.pali_gemma import PaliGemma
    from text_generation_server.models.flash_santacoder import FlashSantacoderSharded
    from text_generation_server.models.idefics import IDEFICSSharded
    from text_generation_server.models.llava_next import LlavaNext
    from text_generation_server.models.idefics2 import Idefics2
    from text_generation_server.models.flash_mistral import FlashMistral
    from text_generation_server.models.flash_mixtral import FlashMixtral
    from text_generation_server.models.flash_phi import FlashPhi
    from text_generation_server.models.flash_starcoder2 import FlashStarcoder2
    from text_generation_server.models.flash_dbrx import FlashDbrx
    from text_generation_server.layers.attention import SUPPORTS_WINDOWING
except ImportError as err:
    # One failing import disables the whole flash family.  ``get_model``
    # checks FLASH_ATTENTION and either falls back to a non-flash class or
    # raises NotImplementedError for sharded use.
    logger.warning(f"Could not import Flash Attention enabled models: {err}")
    SUPPORTS_WINDOWING = False
    FLASH_ATTENTION = False
93

94
if FLASH_ATTENTION:
    # ``__all__`` must contain *names* (str).  Appending the class objects
    # themselves makes ``from text_generation_server.models import *`` fail
    # with ``TypeError: Item in ...__all__ must be str, not type``.
    # PaliGemma, LlavaNext and Idefics2 are imported alongside the other
    # flash models above, so they are exported here too.
    __all__.extend(
        [
            "FlashGPT2",
            "FlashNeoXSharded",
            "FlashRWSharded",
            "FlashSantacoderSharded",
            "FlashLlama",
            "IDEFICSSharded",
            "FlashMistral",
            "FlashMixtral",
            "FlashDbrx",
            "FlashPhi",
            "FlashQwen2",
            "FlashStarcoder2",
            "FlashGemma",
            "FlashGemma2",
            "FlashCohere",
            "PaliGemma",
            "LlavaNext",
            "Idefics2",
        ]
    )
OlivierDehaene committed
110

drbh's avatar
drbh committed
111
112
113
114
115
116
117
118
119
# Mamba is optional: if its implementation cannot be imported, the module
# still loads and only this architecture becomes unavailable.
MAMBA_AVAILABLE = True
try:
    from text_generation_server.models.mamba import Mamba
except ImportError as e:
    logger.warning(f"Could not import Mamba: {e}")
    MAMBA_AVAILABLE = False

if MAMBA_AVAILABLE:
    # Export the *name*: non-str items in ``__all__`` make star-imports
    # raise TypeError.
    __all__.append("Mamba")
OlivierDehaene's avatar
OlivierDehaene committed
120

121

122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
class ModelType(enum.Enum):
    """Catalogue of model architectures known to this module.

    Each member's value is a dict with these keys:

    * ``"type"``: the ``model_type`` string.  It is re-exported as a
      module-level constant named after the member (e.g. ``MAMBA == "ssm"``)
      and compared against the checkpoint config in ``get_model``.
    * ``"name"``: a human-readable name.
    * ``"url"``: an example checkpoint (or collection) on the Hugging Face Hub.
    * ``"multimodal"`` (optional): ``True`` for models that also take images.
    """

    IDEFICS2 = {
        "type": "idefics2",
        "name": "Idefics 2",
        "url": "https://huggingface.co/HuggingFaceM4/idefics2-8b",
        "multimodal": True,
    }
    LLAVA_NEXT = {
        "type": "llava_next",
        "name": "Llava Next (1.6)",
        "url": "https://huggingface.co/llava-hf/llava-v1.6-vicuna-13b-hf",
        "multimodal": True,
    }
    LLAMA = {
        "type": "llama",
        "name": "Llama",
        "url": "https://huggingface.co/meta-llama/Meta-Llama-3-8B-Instruct",
    }
    PHI3 = {
        "type": "phi3",
        "name": "Phi 3",
        "url": "https://huggingface.co/microsoft/Phi-3-mini-4k-instruct",
    }
    GEMMA = {
        "type": "gemma",
        "name": "Gemma",
        "url": "https://huggingface.co/google/gemma-7b",
    }
    GEMMA2 = {
        "type": "gemma2",
        "name": "Gemma2",
        # The Hub repository is ``gemma-2-9b``; ``gemma2-9b`` does not exist.
        "url": "https://huggingface.co/google/gemma-2-9b",
    }
    COHERE = {
        "type": "cohere",
        "name": "Cohere",
        "url": "https://huggingface.co/CohereForAI/c4ai-command-r-plus",
    }
    DBRX = {
        "type": "dbrx",
        "name": "Dbrx",
        "url": "https://huggingface.co/databricks/dbrx-instruct",
    }
    MAMBA = {
        # Mamba configs carry no ``model_type``; ``get_model`` maps configs
        # containing ``ssm_cfg`` to this type.
        "type": "ssm",
        "name": "Mamba",
        "url": "https://huggingface.co/state-spaces/mamba-2.8b-slimpj",
    }
    MISTRAL = {
        "type": "mistral",
        "name": "Mistral",
        "url": "https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.2",
    }
    MIXTRAL = {
        "type": "mixtral",
        "name": "Mixtral",
        "url": "https://huggingface.co/mistralai/Mixtral-8x22B-Instruct-v0.1",
    }
    GPT_BIGCODE = {
        "type": "gpt_bigcode",
        "name": "Gpt Bigcode",
        "url": "https://huggingface.co/bigcode/gpt_bigcode-santacoder",
    }
    PHI = {
        "type": "phi",
        "name": "Phi",
        "url": "https://huggingface.co/microsoft/phi-1_5",
    }
    BAICHUAN = {
        "type": "baichuan",
        "name": "Baichuan",
        "url": "https://huggingface.co/baichuan-inc/Baichuan2-7B-Chat",
    }
    FALCON = {
        "type": "falcon",
        "name": "Falcon",
        "url": "https://huggingface.co/tiiuae/falcon-7b-instruct",
    }
    STARCODER2 = {
        "type": "starcoder2",
        "name": "StarCoder 2",
        "url": "https://huggingface.co/bigcode/starcoder2-15b-instruct-v0.1",
    }
    QWEN2 = {
        "type": "qwen2",
        "name": "Qwen 2",
        "url": "https://huggingface.co/collections/Qwen/qwen2-6659360b33528ced941e557f",
    }
    OPT = {
        "type": "opt",
        "name": "Opt",
        "url": "https://huggingface.co/facebook/opt-6.7b",
    }
    T5 = {
        "type": "t5",
        "name": "T5",
        "url": "https://huggingface.co/google/flan-t5-xxl",
    }
    GALACTICA = {
        "type": "galactica",
        "name": "Galactica",
        "url": "https://huggingface.co/facebook/galactica-120b",
    }
    SANTACODER = {
        "type": "santacoder",
        "name": "SantaCoder",
        "url": "https://huggingface.co/bigcode/santacoder",
    }
    BLOOM = {
        "type": "bloom",
        "name": "Bloom",
        "url": "https://huggingface.co/bigscience/bloom-560m",
    }
    MPT = {
        "type": "mpt",
        "name": "Mpt",
        "url": "https://huggingface.co/mosaicml/mpt-7b-instruct",
    }
    GPT2 = {
        "type": "gpt2",
        "name": "Gpt2",
        "url": "https://huggingface.co/openai-community/gpt2",
    }
    GPT_NEOX = {
        "type": "gpt_neox",
        "name": "Gpt Neox",
        "url": "https://huggingface.co/EleutherAI/gpt-neox-20b",
    }
    IDEFICS = {
        "type": "idefics",
        "name": "Idefics",
        "url": "https://huggingface.co/HuggingFaceM4/idefics-9b",
        "multimodal": True,
    }


# Publish each ModelType "type" string as a module-level constant named after
# its enum member (e.g. ``LLAMA = "llama"``, ``MAMBA = "ssm"``), so that
# ``get_model`` can compare ``model_type`` against readable names.
# ``globals()`` is used explicitly instead of relying on module-level
# ``locals()`` aliasing the module namespace.  The comprehension also keeps
# the loop variable from leaking into the module as a stray public name.
globals().update({member.name: member.value["type"] for member in ModelType})


263
def get_model(
264
    model_id: str,
drbh's avatar
drbh committed
265
    lora_adapter_ids: Optional[List[str]],
266
267
268
    revision: Optional[str],
    sharded: bool,
    quantize: Optional[str],
Nicolas Patry's avatar
Nicolas Patry committed
269
    speculate: Optional[int],
270
    dtype: Optional[str],
271
    trust_remote_code: bool,
272
    max_input_tokens: int,
273
) -> Model:
274
    global FLASH_ATTENTION
275
    if dtype is None:
276
        if quantize in ["awq", "exl2", "gptq", "marlin"]:
277
278
279
280
281
282
            # These quantizers only work with float16 params.
            dtype = torch.float16
        else:
            # Keep it as default for now and let
            # every model resolve their own default dtype.
            dtype = None
283
284
285
286
287
288
289
    elif dtype == "float16":
        dtype = torch.float16
    elif dtype == "bfloat16":
        dtype = torch.bfloat16
    else:
        raise RuntimeError(f"Unknown dtype {dtype}")

Nicolas Patry's avatar
Nicolas Patry committed
290
291
292
293
294
    if speculate is not None:
        set_speculate(speculate)
    else:
        set_speculate(0)

OlivierDehaene's avatar
v0.8.2  
OlivierDehaene committed
295
296
297
    config_dict, _ = PretrainedConfig.get_config_dict(
        model_id, revision=revision, trust_remote_code=trust_remote_code
    )
Nicolas Patry's avatar
Nicolas Patry committed
298
    model_type = config_dict.get("model_type", None)
Nicolas Patry's avatar
Nicolas Patry committed
299

Nicolas Patry's avatar
Nicolas Patry committed
300
    speculator = None
Nicolas Patry's avatar
Nicolas Patry committed
301
    if "medusa_num_heads" in config_dict:
302
303
        medusa_model_id = model_id
        medusa_revision = revision
Nicolas Patry's avatar
Nicolas Patry committed
304
305
306
307
308
        model_id = config_dict["base_model_name_or_path"]
        revision = "main"
        speculate_medusa = config_dict["medusa_num_heads"]
        if speculate is not None:
            if speculate > speculate_medusa:
OlivierDehaene's avatar
OlivierDehaene committed
309
                raise RuntimeError(
OlivierDehaene's avatar
OlivierDehaene committed
310
                    f"Speculate is set to `{speculate}` but this medusa models only has `{speculate_medusa}` heads, please make them match"
OlivierDehaene's avatar
OlivierDehaene committed
311
                )
Nicolas Patry's avatar
Nicolas Patry committed
312
313
314
315
316
317
318
319
            else:
                set_speculate(speculate)
        else:
            set_speculate(speculate_medusa)

        config_dict, _ = PretrainedConfig.get_config_dict(
            model_id, revision=revision, trust_remote_code=trust_remote_code
        )
Nicolas Patry's avatar
Nicolas Patry committed
320
321
        # Reload model type from parent.
        model_type = config_dict.get("model_type", None)
322
323
324
325
326
327
328
329
330
331
        is_local = Path(medusa_model_id).exists()
        if not is_local:
            medusa_config = hf_hub_download(
                medusa_model_id, revision=medusa_revision, filename="config.json"
            )
            hf_hub_download(
                medusa_model_id,
                revision=medusa_revision,
                filename="medusa_lm_head.safetensors",
            )
Nicolas Patry's avatar
Nicolas Patry committed
332
333
334
335
            speculator = {
                "path": Path(medusa_config).parent,
                "model_paths": ["medusa_lm_head.safetensors"],
            }
336
        else:
Nicolas Patry's avatar
Nicolas Patry committed
337
338
339
340
            speculator = {
                "path": Path(medusa_model_id),
                "model_paths": ["medusa_lm_head.safetensors"],
            }
341

Nicolas Patry's avatar
Nicolas Patry committed
342
        method = "medusa"
Nicolas Patry's avatar
Nicolas Patry committed
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
    elif model_type == "mlp_speculator":
        mlp_model_id = model_id
        mlp_revision = revision
        model_id = config_dict["base_model_name_or_path"]
        revision = "main"
        speculate_mlp = config_dict["n_predict"]
        if speculate is not None:
            if speculate > speculate_mlp:
                raise RuntimeError(
                    f"Speculate is set to `{speculate}` but this mlp_speculator models only has `{speculate_mlp}` heads, please make them match"
                )
            else:
                set_speculate(speculate)
        else:
            set_speculate(speculate_mlp)

        config_dict, _ = PretrainedConfig.get_config_dict(
            model_id, revision=revision, trust_remote_code=trust_remote_code
        )
        # Reload model type from parent.
        model_type = config_dict.get("model_type", None)
        is_local = Path(mlp_model_id).exists()
        extension = ".safetensors"
        if not is_local:
            mlp_speculator_config = hf_hub_download(
                mlp_model_id, revision=mlp_revision, filename="config.json"
            )
            api = HfApi()
            info = api.model_info(mlp_model_id, revision=mlp_revision)
            filenames = [
                s.rfilename
                for s in info.siblings
                if s.rfilename.endswith(extension)
                and len(s.rfilename.split("/")) == 1
                and "arguments" not in s.rfilename
                and "args" not in s.rfilename
                and "training" not in s.rfilename
            ]
            for filename in filenames:
                hf_hub_download(
                    mlp_model_id,
                    revision=mlp_revision,
                    filename=filename,
                )
            speculator = {
                "path": Path(mlp_speculator_config).parent,
                "model_paths": filenames,
            }
        else:
            speculator = Path(mlp_model_id)
            filenames = [p for p in os.listdir(speculator) if p.endswith(extension)]
            speculator = {"path": speculator, "model_paths": filenames}
        method = "mlp_speculator"
Nicolas Patry's avatar
Nicolas Patry committed
396
397
398
399
400
401
402
    else:
        method = "n-gram"

    speculate = get_speculate()
    if speculate > 0:
        logger.info(f"Using speculation {method} with {speculate} input ids.")

drbh's avatar
drbh committed
403
404
405
406
407
408
409
410
411
    if model_type is None:
        # TODO: fix how we determine model type for Mamba
        if "ssm_cfg" in config_dict:
            # *only happens in Mamba case
            model_type = "ssm"
        else:
            raise RuntimeError(
                f"Could not determine model type for {model_id} revision {revision}"
            )
412
413
414
    quantization_config = config_dict.get("quantization_config", None)
    if quantization_config is not None and quantize is None:
        method = quantization_config.get("quant_method", None)
415
        if method in {"gptq", "awq", "exl2"}:
416
417
418
419
            logger.info(f"Auto selecting quantization method {method}")
            quantize = method
        else:
            logger.info(f"Unknown quantization method {method}")
drbh's avatar
drbh committed
420

421
422
423
424
    if quantize == "exl2" and sharded:
        raise RuntimeError(
            "Sharding is currently not supported with `exl2` quantization"
        )
425
    sliding_window = config_dict.get("sliding_window", -1)
426
427
428
429
430
431
432
433

    if (
        (sliding_window is not None and sliding_window != -1)
        and not SUPPORTS_WINDOWING
        and max_input_tokens > sliding_window
    ):
        raise ValueError(
            f"The backend {SYSTEM} does not support sliding window attention that is used by the model type {model_type}. To use this model nonetheless with the {SYSTEM} backend, please launch TGI with the argument `--max-input-tokens` smaller than sliding_window={sliding_window} (got here max_input_tokens={max_input_tokens})."
434
        )
435

436
    if model_type == MAMBA:
drbh's avatar
drbh committed
437
438
439
440
        return Mamba(
            model_id,
            revision,
            quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
441
            speculator=speculator,
drbh's avatar
drbh committed
442
443
444
            dtype=dtype,
            trust_remote_code=trust_remote_code,
        )
445

OlivierDehaene's avatar
OlivierDehaene committed
446
447
448
449
450
    if model_id.startswith("facebook/galactica"):
        return GalacticaSharded(
            model_id,
            revision,
            quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
451
            speculator=speculator,
OlivierDehaene's avatar
OlivierDehaene committed
452
453
454
455
            dtype=dtype,
            trust_remote_code=trust_remote_code,
        )

456
    if (
457
458
        model_type == GPT_BIGCODE
        or model_type == GPT2
459
460
        and model_id.startswith("bigcode/")
    ):
461
        if FLASH_ATTENTION:
462
463
464
465
            return FlashSantacoderSharded(
                model_id,
                revision,
                quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
466
                speculator=speculator,
467
                dtype=dtype,
468
469
                trust_remote_code=trust_remote_code,
            )
470
471
472
473
        elif sharded:
            raise NotImplementedError(
                FLASH_ATT_ERROR_MESSAGE.format("Sharded Santacoder")
            )
474
        else:
475
            return SantaCoder(
476
477
478
                model_id,
                revision,
                quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
479
                speculator=speculator,
480
                dtype=dtype,
481
482
                trust_remote_code=trust_remote_code,
            )
483

484
    if model_type == BLOOM:
485
        return BLOOMSharded(
486
487
488
            model_id,
            revision,
            quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
489
            speculator=speculator,
490
491
            dtype=dtype,
            trust_remote_code=trust_remote_code,
492
        )
493
    elif model_type == MPT:
494
        return MPTSharded(
OlivierDehaene's avatar
OlivierDehaene committed
495
496
497
            model_id,
            revision,
            quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
498
            speculator=speculator,
OlivierDehaene's avatar
OlivierDehaene committed
499
500
            dtype=dtype,
            trust_remote_code=trust_remote_code,
501
        )
502
    elif model_type == GPT2:
503
        if FLASH_ATTENTION:
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
            try:
                return FlashGPT2(
                    model_id,
                    revision,
                    quantize=quantize,
                    speculator=speculator,
                    dtype=dtype,
                    trust_remote_code=trust_remote_code,
                )
            except RuntimeError as e:
                # Lots of legacy models with various weight names.
                logger.warning(f"Couldn't load flash gpt2 variant: {e}")
                return CausalLM(
                    model_id,
                    revision,
                    quantize=quantize,
                    speculator=speculator,
                    dtype=dtype,
                    trust_remote_code=trust_remote_code,
                )
524
525
526
527
528
529
530
531
532
533
534
        elif sharded:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded GPT-2"))
        else:
            return CausalLM(
                model_id,
                revision,
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )
535
    elif model_type == GPT_NEOX:
536
537
538
539
540
        if FLASH_ATTENTION:
            return FlashNeoXSharded(
                model_id,
                revision,
                quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
541
                speculator=speculator,
542
                dtype=dtype,
543
544
545
546
                trust_remote_code=trust_remote_code,
            )
        elif sharded:
            return GPTNeoxSharded(
547
548
549
                model_id,
                revision,
                quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
550
                speculator=speculator,
551
                dtype=dtype,
552
553
                trust_remote_code=trust_remote_code,
            )
554
        else:
555
            return CausalLM(
556
557
558
                model_id,
                revision,
                quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
559
                speculator=speculator,
560
                dtype=dtype,
561
562
                trust_remote_code=trust_remote_code,
            )
OlivierDehaene's avatar
OlivierDehaene committed
563

564
    elif model_type == PHI:
drbh's avatar
drbh committed
565
566
567
568
569
        if FLASH_ATTENTION:
            return FlashPhi(
                model_id,
                revision,
                quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
570
                speculator=speculator,
drbh's avatar
drbh committed
571
572
573
574
575
576
577
578
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )
        else:
            return CausalLM(
                model_id,
                revision,
                quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
579
                speculator=speculator,
drbh's avatar
drbh committed
580
581
582
583
584
585
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )

    elif model_type == "phi-msft":
        if FLASH_ATTENTION:
OlivierDehaene's avatar
OlivierDehaene committed
586
587
588
            raise NotImplementedError(
                "Legacy phi-msft is not supported with Flash Attention"
            )
drbh's avatar
drbh committed
589
590
591
592
593
        else:
            return Phi(
                model_id,
                revision,
                quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
594
                speculator=speculator,
drbh's avatar
drbh committed
595
596
597
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )
598

599
    elif model_type == LLAMA or model_type == BAICHUAN or model_type == PHI3:
600
601
        if FLASH_ATTENTION:
            return FlashLlama(
602
603
604
                model_id,
                revision,
                quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
605
                speculator=speculator,
606
                dtype=dtype,
607
                trust_remote_code=trust_remote_code,
drbh's avatar
drbh committed
608
                lora_adapter_ids=lora_adapter_ids,
609
            )
610
611
        elif sharded:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Llama"))
612
        else:
613
            return CausalLM(
614
615
616
                model_id,
                revision,
                quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
617
                speculator=speculator,
618
                dtype=dtype,
619
620
                trust_remote_code=trust_remote_code,
            )
621
    if model_type == GEMMA:
622
623
624
625
626
        if FLASH_ATTENTION:
            return FlashGemma(
                model_id,
                revision,
                quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
627
                speculator=speculator,
628
629
630
631
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )
        elif sharded:
OlivierDehaene's avatar
OlivierDehaene committed
632
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Gemma"))
633
634
635
636
637
        else:
            return CausalLM(
                model_id,
                revision,
                quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
638
                speculator=speculator,
639
640
641
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )
Nicolas Patry's avatar
Nicolas Patry committed
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
    elif model_type == GEMMA2:
        if FLASH_ATTENTION:
            return FlashGemma2(
                model_id,
                revision,
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )
        elif sharded:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Gemma2"))
        else:
            return CausalLM(
                model_id,
                revision,
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )
663

664
    if model_type == COHERE:
OlivierDehaene's avatar
OlivierDehaene committed
665
666
667
668
669
        if FLASH_ATTENTION:
            return FlashCohere(
                model_id,
                revision,
                quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
670
                speculator=speculator,
OlivierDehaene's avatar
OlivierDehaene committed
671
672
673
674
675
676
677
678
679
680
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )
        elif sharded:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Cohere"))
        else:
            return CausalLM(
                model_id,
                revision,
                quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
681
                speculator=speculator,
OlivierDehaene's avatar
OlivierDehaene committed
682
683
684
685
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )

686
    if model_type == DBRX:
687
688
689
690
691
        if FLASH_ATTENTION:
            return FlashDbrx(
                model_id,
                revision,
                quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
692
                speculator=speculator,
693
694
695
696
697
698
699
700
701
702
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )
        elif sharded:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded DBRX"))
        else:
            return CausalLM(
                model_id,
                revision,
                quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
703
                speculator=speculator,
704
705
706
707
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )

708
    if model_type in ["RefinedWeb", "RefinedWebModel", FALCON]:
709
710
        if sharded:
            if FLASH_ATTENTION:
711
                if config_dict.get("alibi", False):
712
713
714
715
716
                    raise NotImplementedError("sharded is not supported for this model")
                return FlashRWSharded(
                    model_id,
                    revision,
                    quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
717
                    speculator=speculator,
718
                    dtype=dtype,
719
720
                    trust_remote_code=trust_remote_code,
                )
721
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format(f"Sharded Falcon"))
722
        else:
723
            if FLASH_ATTENTION and not config_dict.get("alibi", False):
724
                return FlashRWSharded(
725
726
727
                    model_id,
                    revision,
                    quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
728
                    speculator=speculator,
729
                    dtype=dtype,
730
731
732
733
734
735
736
                    trust_remote_code=trust_remote_code,
                )
            else:
                return RW(
                    model_id,
                    revision,
                    quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
737
                    speculator=speculator,
738
                    dtype=dtype,
739
740
741
                    trust_remote_code=trust_remote_code,
                )

742
    if model_type == MISTRAL:
743
        if FLASH_ATTENTION:
744
745
746
747
            return FlashMistral(
                model_id,
                revision,
                quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
748
                speculator=speculator,
749
750
751
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )
OlivierDehaene's avatar
OlivierDehaene committed
752
753
754
755
756
757
758
        elif sharded:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Mistral"))
        else:
            return CausalLM(
                model_id,
                revision,
                quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
759
                speculator=speculator,
OlivierDehaene's avatar
OlivierDehaene committed
760
761
762
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )
OlivierDehaene's avatar
OlivierDehaene committed
763

764
    if model_type == MIXTRAL:
765
        if FLASH_ATTENTION:
OlivierDehaene's avatar
OlivierDehaene committed
766
767
768
769
            return FlashMixtral(
                model_id,
                revision,
                quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
770
                speculator=speculator,
OlivierDehaene's avatar
OlivierDehaene committed
771
                dtype=dtype,
OlivierDehaene's avatar
OlivierDehaene committed
772
773
                trust_remote_code=trust_remote_code,
            )
OlivierDehaene's avatar
OlivierDehaene committed
774
775
776
777
778
779
780
        elif sharded:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Mixtral"))
        else:
            return CausalLM(
                model_id,
                revision,
                quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
781
                speculator=speculator,
OlivierDehaene's avatar
OlivierDehaene committed
782
783
784
785
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )

786
    if model_type == STARCODER2:
787
        if FLASH_ATTENTION:
OlivierDehaene's avatar
OlivierDehaene committed
788
789
790
791
792
            return FlashStarcoder2(
                model_id,
                revision,
                quantize=quantize,
                dtype=dtype,
OlivierDehaene's avatar
OlivierDehaene committed
793
794
795
796
797
798
799
800
801
802
803
                trust_remote_code=trust_remote_code,
            )
        elif sharded:
            raise NotImplementedError(
                FLASH_ATT_ERROR_MESSAGE.format("Sharded Starcoder2")
            )
        else:
            return CausalLM(
                model_id,
                revision,
                quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
804
                speculator=speculator,
OlivierDehaene's avatar
OlivierDehaene committed
805
806
807
808
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )

809
    if model_type == QWEN2:
810
        if FLASH_ATTENTION:
OlivierDehaene's avatar
OlivierDehaene committed
811
812
813
814
815
816
817
818
819
820
821
822
823
824
            return FlashQwen2(
                model_id,
                revision,
                quantize=quantize,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )
        elif sharded:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Qwen2"))
        else:
            return CausalLM(
                model_id,
                revision,
                quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
825
                speculator=speculator,
OlivierDehaene's avatar
OlivierDehaene committed
826
                dtype=dtype,
OlivierDehaene's avatar
OlivierDehaene committed
827
828
                trust_remote_code=trust_remote_code,
            )
829

830
    if model_type == OPT:
831
        return OPTSharded(
832
833
834
            model_id,
            revision,
            quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
835
            speculator=speculator,
836
837
            dtype=dtype,
            trust_remote_code=trust_remote_code,
838
        )
839

840
    if model_type == T5:
841
842
843
844
        return T5Sharded(
            model_id,
            revision,
            quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
845
            speculator=speculator,
846
            dtype=dtype,
847
848
            trust_remote_code=trust_remote_code,
        )
849
    if model_type == IDEFICS:
850
        if FLASH_ATTENTION:
OlivierDehaene's avatar
OlivierDehaene committed
851
852
853
854
            return IDEFICSSharded(
                model_id,
                revision,
                quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
855
                speculator=speculator,
OlivierDehaene's avatar
OlivierDehaene committed
856
857
858
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )
859
860
        else:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Idefics"))
861
    if model_type == IDEFICS2:
Nicolas Patry's avatar
Nicolas Patry committed
862
863
864
865
866
        if FLASH_ATTENTION:
            return Idefics2(
                model_id,
                revision,
                quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
867
                speculator=speculator,
Nicolas Patry's avatar
Nicolas Patry committed
868
869
870
871
872
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )
        else:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Idefics"))
drbh's avatar
drbh committed
873
874
875
876
877
878
879
880
881
882
883
884
    if model_type == "paligemma":
        if FLASH_ATTENTION:
            return PaliGemma(
                model_id,
                revision,
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )
        else:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Idefics"))
885

886
    if model_type == LLAVA_NEXT:
887
888
889
890
891
        if FLASH_ATTENTION:
            return LlavaNext(
                model_id,
                revision,
                quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
892
                speculator=speculator,
893
894
895
896
897
898
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )
        else:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("LlavaNext"))

899
    if sharded:
900
        raise NotImplementedError("sharded is not supported for AutoModel")
901
    if quantize == "gptq":
902
        raise NotImplementedError(
903
904
            "gptq quantization is not supported for AutoModel, you can try to quantize it with `text-generation-server quantize ORIGINAL_MODEL_ID NEW_MODEL_ID`"
        )
905
    if quantize == "awq":
906
        raise NotImplementedError("awq quantization is not supported for AutoModel")
Nicolas Patry's avatar
Nicolas Patry committed
907
    elif (quantize == "bitsandbytes-fp4") or (quantize == "bitsandbytes-nf4"):
908
        raise NotImplementedError("4bit quantization is not supported for AutoModel")
OlivierDehaene's avatar
OlivierDehaene committed
909
    elif quantize == "eetq":
910
        raise NotImplementedError("Eetq quantization is not supported for AutoModel")
911
912
    elif quantize == "exl2":
        raise NotImplementedError("exl2 quantization is not supported for AutoModel")
913
    if model_type in modeling_auto.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES:
914
        return CausalLM(
915
916
917
            model_id,
            revision,
            quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
918
            speculator=speculator,
919
920
            dtype=dtype,
            trust_remote_code=trust_remote_code,
921
        )
922
    if model_type in modeling_auto.MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES:
923
        return Seq2SeqLM(
924
925
926
            model_id,
            revision,
            quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
927
            speculator=speculator,
928
929
            dtype=dtype,
            trust_remote_code=trust_remote_code,
930
931
        )

932
    auto_map = config_dict.get("auto_map", None)
933
934
935
936
937
938
    if trust_remote_code and auto_map is not None:
        if "AutoModelForCausalLM" in auto_map.keys():
            return CausalLM(
                model_id,
                revision,
                quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
939
                speculator=speculator,
940
                dtype=dtype,
941
942
                trust_remote_code=trust_remote_code,
            )
943
        if "AutoModelForSeq2SeqLM" in auto_map.keys():
944
945
946
947
            return Seq2SeqLM(
                model_id,
                revision,
                quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
948
                speculator=speculator,
949
                dtype=dtype,
950
951
                trust_remote_code=trust_remote_code,
            )
952
953

    raise ValueError(f"Unsupported model type {model_type}")