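# Model factory for text-generation-server: `get_model` below inspects a checkpoint's
# Hugging Face config (model_type, quantization_config, speculator settings, ...) and
# returns the matching TGI implementation (FlashCausalLM, CausalLM, Seq2SeqLM,
# VlmCausalLM, Mamba, ...), preferring Flash Attention variants when available.
#
# Typical use (illustrative only; the argument values here are hypothetical):
#
#     model = get_model(
#         model_id="meta-llama/Meta-Llama-3-8B-Instruct",
#         lora_adapter_ids=[],
#         revision=None,
#         sharded=False,
#         quantize=None,
#         speculate=None,
#         dtype="bfloat16",
#         trust_remote_code=False,
#         max_input_tokens=4096,
#     )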
import torch
import enum
import os

from loguru import logger
from transformers.configuration_utils import PretrainedConfig
from transformers.models.auto import modeling_auto
from huggingface_hub import hf_hub_download, HfApi
from typing import Optional, List
from pathlib import Path

from text_generation_server.utils.speculate import get_speculate, set_speculate
from text_generation_server.models.model import Model
from text_generation_server.models.causal_lm import CausalLM, CausalLMBatchKeysLast
from text_generation_server.models.custom_modeling.opt_modeling import OPTForCausalLM
from text_generation_server.models.custom_modeling.mpt_modeling import (
    MPTForCausalLM,
)
from text_generation_server.models.bloom import BloomCausalLMBatch
from text_generation_server.models.custom_modeling.bloom_modeling import (
    BloomForCausalLM,
)
from text_generation_server.models.seq2seq_lm import Seq2SeqLM
from text_generation_server.models.galactica import GalacticaCausalLMBatch
from text_generation_server.models.custom_modeling.neox_modeling import (
    GPTNeoxForCausalLM,
)
from text_generation_server.models.custom_modeling.phi_modeling import (
    PhiConfig,
    PhiForCausalLM,
)
from text_generation_server.models.custom_modeling.t5_modeling import (
    T5ForConditionalGeneration,
)

from text_generation_server.utils.import_utils import SYSTEM

# The flag below controls whether to allow TF32 on matmul. This flag defaults to False
# in PyTorch 1.12 and later.
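# Enabling TF32 here is an inference-time trade-off: matmuls lose a little precision
# but generally run noticeably faster on Ampere-and-newer GPUs.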
torch.backends.cuda.matmul.allow_tf32 = True

# The flag below controls whether to allow TF32 on cuDNN. This flag defaults to True.
torch.backends.cudnn.allow_tf32 = True

# Disable gradients
torch.set_grad_enabled(False)

__all__ = [
    "Model",
    "BLOOMSharded",
    "CausalLM",
    "GalacticaSharded",
    "Seq2SeqLM",
    "get_model",
]

FLASH_ATT_ERROR_MESSAGE = "{} requires Flash Attention enabled models."

FLASH_ATTENTION = True

try:
    from text_generation_server.models.flash_causal_lm import FlashCausalLM
    from text_generation_server.models.vlm_causal_lm import VlmCausalLM
    from text_generation_server.models.custom_modeling.flash_llama_modeling import (
        FlashLlamaForCausalLM,
    )
    from text_generation_server.models.custom_modeling.flash_cohere_modeling import (
        FlashCohereForCausalLM,
    )
    from text_generation_server.models.custom_modeling.flash_gemma_modeling import (
        FlashGemmaForCausalLM,
    )
    from text_generation_server.models.custom_modeling.flash_gemma2_modeling import (
        FlashGemma2ForCausalLM,
    )
    from text_generation_server.models.custom_modeling.flash_dbrx_modeling import (
        FlashDbrxForCausalLM,
        DbrxConfig,
    )
    from text_generation_server.models.custom_modeling.flash_rw_modeling import (
        RWConfig,
        FlashRWForCausalLM,
    )
    from text_generation_server.models.custom_modeling.flash_neox_modeling import (
        FlashGPTNeoXForCausalLM,
    )
    from text_generation_server.models.pali_gemma import (
        PaliGemmaBatch,
    )
    from text_generation_server.models.custom_modeling.flash_pali_gemma_modeling import (
        PaliGemmaForConditionalGeneration,
    )
    from text_generation_server.models.custom_modeling.flash_phi_modeling import (
        FlashPhiForCausalLM,
    )
    from text_generation_server.models.idefics import IDEFICSSharded
    from text_generation_server.models.custom_modeling.llava_next import (
        LlavaNextForConditionalGeneration,
    )

    from text_generation_server.models.custom_modeling.flash_santacoder_modeling import (
        FlashSantacoderForCausalLM,
    )
    from text_generation_server.models.custom_modeling.flash_starcoder2_modeling import (
        FlashStarcoder2ForCausalLM,
    )
    from text_generation_server.models.custom_modeling.flash_qwen2_modeling import (
        Qwen2ForCausalLM,
    )
    from text_generation_server.models.custom_modeling.flash_mistral_modeling import (
        FlashMistralForCausalLM,
    )
    from text_generation_server.models.custom_modeling.flash_mixtral_modeling import (
        FlashMixtralForCausalLM,
    )
    from text_generation_server.models.custom_modeling.flash_gpt2_modeling import (
        FlashGPT2ForCausalLM,
    )
    from text_generation_server.models.custom_modeling.idefics2 import (
        Idefics2ForConditionalGeneration,
    )
    from text_generation_server.layers.attention import SUPPORTS_WINDOWING
except ImportError as e:
    logger.warning(f"Could not import Flash Attention enabled models: {e}")
    SUPPORTS_WINDOWING = False
    FLASH_ATTENTION = False

if FLASH_ATTENTION:
    # `__all__` entries must be attribute names (strings).
    __all__.append("FlashCausalLM")
    __all__.append("IDEFICSSharded")

MAMBA_AVAILABLE = True
try:
    from text_generation_server.models.mamba import Mamba
except ImportError as e:
    logger.warning(f"Could not import Mamba: {e}")
    MAMBA_AVAILABLE = False

if MAMBA_AVAILABLE:
    __all__.append("Mamba")


class ModelType(enum.Enum):
    IDEFICS2 = {
        "type": "idefics2",
        "name": "Idefics 2",
        "url": "https://huggingface.co/HuggingFaceM4/idefics2-8b",
        "multimodal": True,
    }
    LLAVA_NEXT = {
        "type": "llava_next",
        "name": "Llava Next (1.6)",
        "url": "https://huggingface.co/llava-hf/llava-v1.6-vicuna-13b-hf",
        "multimodal": True,
    }
    LLAMA = {
        "type": "llama",
        "name": "Llama",
        "url": "https://huggingface.co/meta-llama/Meta-Llama-3-8B-Instruct",
    }
    PHI3 = {
        "type": "phi3",
        "name": "Phi 3",
        "url": "https://huggingface.co/microsoft/Phi-3-mini-4k-instruct",
    }
    GEMMA = {
        "type": "gemma",
        "name": "Gemma",
        "url": "https://huggingface.co/google/gemma-7b",
    }
    PALIGEMMA = {
        "type": "paligemma",
        "name": "PaliGemma",
        "url": "https://huggingface.co/google/paligemma-3b-pt-224",
    }
    GEMMA2 = {
        "type": "gemma2",
        "name": "Gemma2",
        "url": "https://huggingface.co/google/gemma2-9b",
    }
    COHERE = {
        "type": "cohere",
        "name": "Cohere",
        "url": "https://huggingface.co/CohereForAI/c4ai-command-r-plus",
    }
    DBRX = {
        "type": "dbrx",
        "name": "Dbrx",
        "url": "https://huggingface.co/databricks/dbrx-instruct",
    }
    MAMBA = {
        "type": "ssm",
        "name": "Mamba",
        "url": "https://huggingface.co/state-spaces/mamba-2.8b-slimpj",
    }
    MISTRAL = {
        "type": "mistral",
        "name": "Mistral",
        "url": "https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.2",
    }
    MIXTRAL = {
        "type": "mixtral",
        "name": "Mixtral",
        "url": "https://huggingface.co/mistralai/Mixtral-8x22B-Instruct-v0.1",
    }
    GPT_BIGCODE = {
        "type": "gpt_bigcode",
        "name": "Gpt Bigcode",
        "url": "https://huggingface.co/bigcode/gpt_bigcode-santacoder",
    }
    PHI = {
        "type": "phi",
        "name": "Phi",
        "url": "https://huggingface.co/microsoft/phi-1_5",
    }
    BAICHUAN = {
        "type": "baichuan",
        "name": "Baichuan",
        "url": "https://huggingface.co/baichuan-inc/Baichuan2-7B-Chat",
    }
    FALCON = {
        "type": "falcon",
        "name": "Falcon",
        "url": "https://huggingface.co/tiiuae/falcon-7b-instruct",
    }
    STARCODER2 = {
        "type": "starcoder2",
        "name": "StarCoder 2",
        "url": "https://huggingface.co/bigcode/starcoder2-15b-instruct-v0.1",
    }
    QWEN2 = {
        "type": "qwen2",
        "name": "Qwen 2",
        "url": "https://huggingface.co/collections/Qwen/qwen2-6659360b33528ced941e557f",
    }
    OPT = {
        "type": "opt",
        "name": "Opt",
        "url": "https://huggingface.co/facebook/opt-6.7b",
    }
    T5 = {
        "type": "t5",
        "name": "T5",
        "url": "https://huggingface.co/google/flan-t5-xxl",
    }
    GALACTICA = {
        "type": "galactica",
        "name": "Galactica",
        "url": "https://huggingface.co/facebook/galactica-120b",
    }
    SANTACODER = {
        "type": "santacoder",
        "name": "SantaCoder",
        "url": "https://huggingface.co/bigcode/santacoder",
    }
    BLOOM = {
        "type": "bloom",
        "name": "Bloom",
        "url": "https://huggingface.co/bigscience/bloom-560m",
    }
    MPT = {
        "type": "mpt",
        "name": "Mpt",
        "url": "https://huggingface.co/mosaicml/mpt-7b-instruct",
    }
    GPT2 = {
        "type": "gpt2",
        "name": "Gpt2",
        "url": "https://huggingface.co/openai-community/gpt2",
    }
    GPT_NEOX = {
        "type": "gpt_neox",
        "name": "Gpt Neox",
        "url": "https://huggingface.co/EleutherAI/gpt-neox-20b",
    }
    IDEFICS = {
        "type": "idefics",
        "name": "Idefics",
        "url": "https://huggingface.co/HuggingFaceM4/idefics-9b",
        "multimodal": True,
    }


__GLOBALS = locals()
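# Export each ModelType's "type" string as a module-level constant (e.g. LLAMA = "llama"),
# so the dispatch in `get_model` below can compare the config's `model_type` against bare names.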
for data in ModelType:
    __GLOBALS[data.name] = data.value["type"]


def get_model(
    model_id: str,
    lora_adapter_ids: Optional[List[str]],
    revision: Optional[str],
    sharded: bool,
    quantize: Optional[str],
    speculate: Optional[int],
    dtype: Optional[str],
    trust_remote_code: bool,
    max_input_tokens: int,
) -> Model:
    global FLASH_ATTENTION
    if dtype is None:
        if quantize in ["awq", "exl2", "gptq", "marlin"]:
            # These quantizers only work with float16 params.
            dtype = torch.float16
        else:
            # Keep it as default for now and let
            # every model resolve their own default dtype.
            dtype = None
    elif dtype == "float16":
        dtype = torch.float16
    elif dtype == "bfloat16":
        dtype = torch.bfloat16
    else:
        raise RuntimeError(f"Unknown dtype {dtype}")

    if speculate is not None:
        set_speculate(speculate)
    else:
        set_speculate(0)

    config_dict, _ = PretrainedConfig.get_config_dict(
        model_id, revision=revision, trust_remote_code=trust_remote_code
    )
    model_type = config_dict.get("model_type", None)

    speculator = None
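    # Speculative decoding: if the repo at `model_id` is itself a Medusa (or, below, an
    # MLP-speculator) checkpoint, its config points at a base model. `model_id` is then
    # swapped to that base model and the speculator weights are recorded in `speculator`
    # so they can be loaded alongside it.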
    if "medusa_num_heads" in config_dict:
        medusa_model_id = model_id
        medusa_revision = revision
        model_id = config_dict["base_model_name_or_path"]
        revision = "main"
        speculate_medusa = config_dict["medusa_num_heads"]
        if speculate is not None:
            if speculate > speculate_medusa:
                raise RuntimeError(
                    f"Speculate is set to `{speculate}` but this Medusa model only has `{speculate_medusa}` heads, please make them match"
                )
            else:
                set_speculate(speculate)
        else:
            set_speculate(speculate_medusa)

        config_dict, _ = PretrainedConfig.get_config_dict(
            model_id, revision=revision, trust_remote_code=trust_remote_code
        )
        # Reload model type from parent.
        model_type = config_dict.get("model_type", None)
        is_local = Path(medusa_model_id).exists()
        if not is_local:
            medusa_config = hf_hub_download(
                medusa_model_id, revision=medusa_revision, filename="config.json"
            )
            hf_hub_download(
                medusa_model_id,
                revision=medusa_revision,
                filename="medusa_lm_head.safetensors",
            )
            speculator = {
                "path": Path(medusa_config).parent,
                "model_paths": ["medusa_lm_head.safetensors"],
            }
        else:
            speculator = {
                "path": Path(medusa_model_id),
                "model_paths": ["medusa_lm_head.safetensors"],
            }

        method = "medusa"
    elif model_type == "mlp_speculator":
        mlp_model_id = model_id
        mlp_revision = revision
        model_id = config_dict["base_model_name_or_path"]
        revision = "main"
        speculate_mlp = config_dict["n_predict"]
        if speculate is not None:
            if speculate > speculate_mlp:
                raise RuntimeError(
                    f"Speculate is set to `{speculate}` but this mlp_speculator models only has `{speculate_mlp}` heads, please make them match"
                )
            else:
                set_speculate(speculate)
        else:
            set_speculate(speculate_mlp)

        config_dict, _ = PretrainedConfig.get_config_dict(
            model_id, revision=revision, trust_remote_code=trust_remote_code
        )
        # Reload model type from parent.
        model_type = config_dict.get("model_type", None)
        is_local = Path(mlp_model_id).exists()
        extension = ".safetensors"
        if not is_local:
            mlp_speculator_config = hf_hub_download(
                mlp_model_id, revision=mlp_revision, filename="config.json"
            )
            api = HfApi()
            info = api.model_info(mlp_model_id, revision=mlp_revision)
            filenames = [
                s.rfilename
                for s in info.siblings
                if s.rfilename.endswith(extension)
                and len(s.rfilename.split("/")) == 1
                and "arguments" not in s.rfilename
                and "args" not in s.rfilename
                and "training" not in s.rfilename
            ]
            for filename in filenames:
                hf_hub_download(
                    mlp_model_id,
                    revision=mlp_revision,
                    filename=filename,
                )
            speculator = {
                "path": Path(mlp_speculator_config).parent,
                "model_paths": filenames,
            }
        else:
            speculator = Path(mlp_model_id)
            filenames = [p for p in os.listdir(speculator) if p.endswith(extension)]
            speculator = {"path": speculator, "model_paths": filenames}
        method = "mlp_speculator"
    else:
        method = "n-gram"

    speculate = get_speculate()
    if speculate > 0:
        logger.info(f"Using speculation {method} with {speculate} input ids.")

    if model_type is None:
        # TODO: fix how we determine model type for Mamba
        if "ssm_cfg" in config_dict:
            # *only happens in Mamba case
            model_type = "ssm"
        else:
            raise RuntimeError(
                f"Could not determine model type for {model_id} revision {revision}"
            )
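    # Checkpoints quantized offline ship a `quantization_config`; honor its quant_method
    # when the caller did not pass an explicit quantization option.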
    quantization_config = config_dict.get("quantization_config", None)
    if quantization_config is not None and quantize is None:
        method = quantization_config.get("quant_method", None)
        if method in {"gptq", "awq", "exl2"}:
            logger.info(f"Auto selecting quantization method {method}")
            quantize = method
        else:
            logger.info(f"Unknown quantization method {method}")

    if quantize == "exl2" and sharded:
        raise RuntimeError(
            "Sharding is currently not supported with `exl2` quantization"
        )
    sliding_window = config_dict.get("sliding_window", -1)

    if (
        (sliding_window is not None and sliding_window != -1)
        and not SUPPORTS_WINDOWING
        and max_input_tokens > sliding_window
    ):
        raise ValueError(
            f"The backend {SYSTEM} does not support sliding window attention that is used by the model type {model_type}. To use this model nonetheless with the {SYSTEM} backend, please launch TGI with the argument `--max-input-tokens` smaller than sliding_window={sliding_window} (got here max_input_tokens={max_input_tokens})."
        )
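    # Dispatch on model_type. Most branches follow the same pattern: use the Flash
    # Attention implementation when available, refuse sharding when it is not, and
    # otherwise fall back to the generic transformers-based implementation.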
    if model_type == MAMBA:
        return Mamba(
            model_id,
            revision,
            quantize=quantize,
            speculator=speculator,
            dtype=dtype,
            trust_remote_code=trust_remote_code,
        )

    if model_id.startswith("facebook/galactica"):
        return CausalLM(
            model_id=model_id,
            # Yes galactica is just an OPT model.
            model_class=OPTForCausalLM,
            revision=revision,
            quantize=quantize,
            speculator=speculator,
            dtype=dtype,
            trust_remote_code=trust_remote_code,
            batch_class=GalacticaCausalLMBatch,
        )

    if (
        model_type == GPT_BIGCODE
        or (model_type == GPT2 and model_id.startswith("bigcode/"))
    ):
        if FLASH_ATTENTION:
            return FlashCausalLM(
                model_id=model_id,
                model_class=FlashSantacoderForCausalLM,
                revision=revision,
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
                lora_adapter_ids=lora_adapter_ids,
                aliases={"transformer.wte.weight": ["lm_head.weight"]},
                num_kv_heads=1,
            )
        elif sharded:
            raise NotImplementedError(
                FLASH_ATT_ERROR_MESSAGE.format("Sharded Santacoder")
            )
        else:
            return CausalLM.fallback(
                model_id=model_id,
                revision=revision,
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )
    if model_type == BLOOM:
        return CausalLM(
            model_id=model_id,
            model_class=BloomForCausalLM,
            revision=revision,
            quantize=quantize,
            speculator=speculator,
            dtype=dtype,
            trust_remote_code=trust_remote_code,
            batch_class=BloomCausalLMBatch,
        )
    elif model_type == MPT:
        return CausalLM(
            model_id=model_id,
            model_class=MPTForCausalLM,
            revision=revision,
            quantize=quantize,
            speculator=speculator,
            dtype=dtype,
            trust_remote_code=trust_remote_code,
            batch_class=CausalLMBatchKeysLast,
        )
    elif model_type == GPT2:
        if FLASH_ATTENTION:
            try:
                return FlashCausalLM(
                    model_id=model_id,
                    model_class=FlashGPT2ForCausalLM,
                    revision=revision,
                    quantize=quantize,
                    speculator=speculator,
                    dtype=dtype,
                    trust_remote_code=trust_remote_code,
                    lora_adapter_ids=lora_adapter_ids,
                )
            except RuntimeError as e:
                # Lots of legacy models with various weight names.
                logger.warning(f"Couldn't load flash gpt2 variant: {e}")
                return CausalLM.fallback(
                    model_id,
                    revision,
                    quantize=quantize,
                    speculator=speculator,
                    dtype=dtype,
                    trust_remote_code=trust_remote_code,
                )
        elif sharded:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded GPT-2"))
        else:
            return CausalLM.fallback(
                model_id,
                revision,
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )
    elif model_type == GPT_NEOX:
        if FLASH_ATTENTION:
            from text_generation_server.models.custom_modeling.flash_neox_modeling import (
                GPTNeoXConfig,
            )

            return FlashCausalLM(
                model_id=model_id,
                model_class=FlashGPTNeoXForCausalLM,
                revision=revision,
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
                lora_adapter_ids=lora_adapter_ids,
                config_class=GPTNeoXConfig,
            )
        elif sharded:
            return CausalLM(
                model_id=model_id,
                model_class=GPTNeoxForCausalLM,
                revision=revision,
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )
        else:
            return CausalLM.fallback(
                model_id,
                revision,
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )
OlivierDehaene's avatar
OlivierDehaene committed
610

611
    elif model_type == PHI:
drbh's avatar
drbh committed
612
        if FLASH_ATTENTION:
613
614
615
616
            return FlashCausalLM(
                model_id=model_id,
                model_class=FlashPhiForCausalLM,
                revision=revision,
drbh's avatar
drbh committed
617
                quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
618
                speculator=speculator,
drbh's avatar
drbh committed
619
620
                dtype=dtype,
                trust_remote_code=trust_remote_code,
621
                lora_adapter_ids=lora_adapter_ids,
drbh's avatar
drbh committed
622
623
            )
        else:
624
            return CausalLM.fallback(
drbh's avatar
drbh committed
625
626
627
                model_id,
                revision,
                quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
628
                speculator=speculator,
drbh's avatar
drbh committed
629
630
631
632
633
634
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )

    elif model_type == "phi-msft":
        if FLASH_ATTENTION:
OlivierDehaene's avatar
OlivierDehaene committed
635
636
637
            raise NotImplementedError(
                "Legacy phi-msft is not supported with Flash Attention"
            )
drbh's avatar
drbh committed
638
        else:
639
640
641
642
643
            return CausalLM(
                model_id=model_id,
                model_class=PhiForCausalLM,
                config_class=PhiConfig,
                revision=revision,
drbh's avatar
drbh committed
644
                quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
645
                speculator=speculator,
drbh's avatar
drbh committed
646
647
648
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )
    elif model_type == LLAMA or model_type == BAICHUAN or model_type == PHI3:
        if FLASH_ATTENTION:
            return FlashCausalLM(
                model_id=model_id,
                model_class=FlashLlamaForCausalLM,
                revision=revision,
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
                lora_adapter_ids=lora_adapter_ids,
            )
        elif sharded:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Llama"))
        else:
            return CausalLM.fallback(
                model_id,
                revision,
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )
    if model_type == GEMMA:
        if FLASH_ATTENTION:
            return FlashCausalLM(
                model_id=model_id,
                model_class=FlashGemmaForCausalLM,
                revision=revision,
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
                # Works better for these models
                default_dtype=torch.bfloat16,
                trust_remote_code=trust_remote_code,
                lora_adapter_ids=lora_adapter_ids,
            )
        elif sharded:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Gemma"))
        else:
            return CausalLM.fallback(
                model_id,
                revision,
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )
    elif model_type == GEMMA2:
        if FLASH_ATTENTION:
            return FlashCausalLM(
                model_id=model_id,
                model_class=FlashGemma2ForCausalLM,
                revision=revision,
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
                # Works better for these models
                default_dtype=torch.bfloat16,
                trust_remote_code=trust_remote_code,
                lora_adapter_ids=lora_adapter_ids,
            )
        elif sharded:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Gemma2"))
        else:
            return CausalLM.fallback(
                model_id,
                revision,
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )

    if model_type == COHERE:
        if FLASH_ATTENTION:
            return FlashCausalLM(
                model_id=model_id,
                model_class=FlashCohereForCausalLM,
                revision=revision,
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
                lora_adapter_ids=lora_adapter_ids,
            )
        elif sharded:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Cohere"))
        else:
            return CausalLM.fallback(
                model_id,
                revision,
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )

    if model_type == DBRX:
        if FLASH_ATTENTION:
            return FlashCausalLM(
                model_id=model_id,
                model_class=FlashDbrxForCausalLM,
                revision=revision,
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
                # Dbrx works better in bfloat16.
                default_dtype=torch.bfloat16,
                trust_remote_code=trust_remote_code,
                lora_adapter_ids=lora_adapter_ids,
                config_class=DbrxConfig,
            )
        elif sharded:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded DBRX"))
        else:
            return CausalLM.fallback(
                model_id,
                revision,
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )

    if model_type in ["RefinedWeb", "RefinedWebModel", FALCON]:
        if sharded:
            if FLASH_ATTENTION:
                if config_dict.get("alibi", False):
                    raise NotImplementedError("sharded is not supported for this model")
                return FlashCausalLM(
                    model_id=model_id,
                    model_class=FlashRWForCausalLM,
                    revision=revision,
                    quantize=quantize,
                    speculator=speculator,
                    dtype=dtype,
                    aliases={
                        "lm_head.weight": ["transformer.word_embeddings.weight"],
                        "transformer.word_embeddings.weight": ["lm_head.weight"],
                    },
                    trust_remote_code=trust_remote_code,
                    lora_adapter_ids=lora_adapter_ids,
                    config_class=RWConfig,
                )
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Falcon"))
        else:
            if FLASH_ATTENTION and not config_dict.get("alibi", False):
                return FlashCausalLM(
                    model_id=model_id,
                    model_class=FlashRWForCausalLM,
                    revision=revision,
                    quantize=quantize,
                    speculator=speculator,
                    dtype=dtype,
                    aliases={
                        "lm_head.weight": ["transformer.word_embeddings.weight"],
                        "transformer.word_embeddings.weight": ["lm_head.weight"],
                    },
                    trust_remote_code=trust_remote_code,
                    lora_adapter_ids=lora_adapter_ids,
                    config_class=RWConfig,
                )
            else:
                return CausalLM.fallback(
                    model_id,
                    revision,
                    quantize=quantize,
                    speculator=speculator,
                    dtype=dtype,
                    trust_remote_code=trust_remote_code,
                )

    if model_type == MISTRAL:
        if FLASH_ATTENTION:
            return FlashCausalLM(
                model_id=model_id,
                model_class=FlashMistralForCausalLM,
                revision=revision,
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
                lora_adapter_ids=lora_adapter_ids,
            )
        elif sharded:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Mistral"))
        else:
            return CausalLM.fallback(
                model_id,
                revision,
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )

    if model_type == MIXTRAL:
        if FLASH_ATTENTION:
            return FlashCausalLM(
                model_id=model_id,
                model_class=FlashMixtralForCausalLM,
                revision=revision,
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
                lora_adapter_ids=lora_adapter_ids,
            )
        elif sharded:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Mixtral"))
        else:
            return CausalLM.fallback(
                model_id,
                revision,
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )

    if model_type == STARCODER2:
        if FLASH_ATTENTION:
            return FlashCausalLM(
                model_id=model_id,
                model_class=FlashStarcoder2ForCausalLM,
                revision=revision,
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
                lora_adapter_ids=lora_adapter_ids,
            )
        elif sharded:
            raise NotImplementedError(
                FLASH_ATT_ERROR_MESSAGE.format("Sharded Starcoder2")
            )
        else:
            return CausalLM.fallback(
                model_id,
                revision,
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )

    if model_type == QWEN2:
        if FLASH_ATTENTION:
            return FlashCausalLM(
                model_id=model_id,
                model_class=Qwen2ForCausalLM,
                revision=revision,
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
                lora_adapter_ids=lora_adapter_ids,
            )
        elif sharded:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Qwen2"))
        else:
            return CausalLM.fallback(
                model_id,
                revision,
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )

    if model_type == OPT:
        return CausalLM(
            model_id=model_id,
            model_class=OPTForCausalLM,
            revision=revision,
            quantize=quantize,
            speculator=speculator,
            dtype=dtype,
            trust_remote_code=trust_remote_code,
        )

    if model_type == T5:
        return Seq2SeqLM(
            model_id=model_id,
            model_class=T5ForConditionalGeneration,
            revision=revision,
            quantize=quantize,
            speculator=speculator,
            dtype=dtype,
            trust_remote_code=trust_remote_code,
            aliases={
                "shared.weight": [
                    "encoder.embed_tokens.weight",
                    "decoder.embed_tokens.weight",
                ]
            },
        )
    if model_type == IDEFICS:
        if FLASH_ATTENTION:
            return IDEFICSSharded(
                model_id,
                revision,
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )
        else:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Idefics"))
    if model_type == IDEFICS2:
        if FLASH_ATTENTION:
            return VlmCausalLM(
                model_id=model_id,
                model_class=Idefics2ForConditionalGeneration,
                revision=revision,
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
                lora_adapter_ids=lora_adapter_ids,
                # XXX: Extremely important to cap resolution in order to limit
                # VRAM usage.
                processor_kwargs={"size": {"longest_edge": 448, "shortest_edge": 378}},
            )
        else:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Idefics2"))
    if model_type == PALIGEMMA:
        if FLASH_ATTENTION:
            return VlmCausalLM(
                model_id=model_id,
                model_class=PaliGemmaForConditionalGeneration,
                revision=revision,
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
                # Works better for these models
                default_dtype=torch.bfloat16,
                trust_remote_code=trust_remote_code,
                lora_adapter_ids=lora_adapter_ids,
                batch_class=PaliGemmaBatch,
            )
        else:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("PaliGemma"))

    if model_type == LLAVA_NEXT:
        if FLASH_ATTENTION:
            return VlmCausalLM(
                model_class=LlavaNextForConditionalGeneration,
                model_id=model_id,
                revision=revision,
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )
        else:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("LlavaNext"))
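    # Nothing above matched: fall back to transformers' AutoModel classes below, which
    # do not support sharding and accept only a limited set of quantization modes.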

    if sharded:
        raise NotImplementedError("sharded is not supported for AutoModel")
    if quantize == "gptq":
        raise NotImplementedError(
            "gptq quantization is not supported for AutoModel, you can try to quantize it with `text-generation-server quantize ORIGINAL_MODEL_ID NEW_MODEL_ID`"
        )
    if quantize == "awq":
        raise NotImplementedError("awq quantization is not supported for AutoModel")
    elif (quantize == "bitsandbytes-fp4") or (quantize == "bitsandbytes-nf4"):
        raise NotImplementedError("4bit quantization is not supported for AutoModel")
    elif quantize == "eetq":
        raise NotImplementedError("Eetq quantization is not supported for AutoModel")
    elif quantize == "exl2":
        raise NotImplementedError("exl2 quantization is not supported for AutoModel")
    if model_type in modeling_auto.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES:
        return CausalLM.fallback(
            model_id,
            revision,
            quantize=quantize,
            speculator=speculator,
            dtype=dtype,
            trust_remote_code=trust_remote_code,
        )
    if model_type in modeling_auto.MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES:
        return Seq2SeqLM.fallback(
            model_id,
            revision,
            quantize=quantize,
            speculator=speculator,
            dtype=dtype,
            trust_remote_code=trust_remote_code,
        )

    auto_map = config_dict.get("auto_map", None)
    if trust_remote_code and auto_map is not None:
        if "AutoModelForCausalLM" in auto_map.keys():
            return CausalLM.fallback(
                model_id,
                revision,
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )
        if "AutoModelForSeq2SeqLM" in auto_map.keys():
            return Seq2SeqLM.fallback(
                model_id,
                revision,
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )

    raise ValueError(f"Unsupported model type {model_type}")