__init__.py 38.2 KB
Newer Older
1
import torch
2
import enum
Nicolas Patry's avatar
Nicolas Patry committed
3
import os
4

5
from loguru import logger
6
from transformers.configuration_utils import PretrainedConfig
7
from transformers.models.auto import modeling_auto
Nicolas Patry's avatar
Nicolas Patry committed
8
from huggingface_hub import hf_hub_download, HfApi
drbh's avatar
drbh committed
9
from typing import Optional, List
10
from pathlib import Path
11

Nicolas Patry's avatar
Nicolas Patry committed
12
from text_generation_server.utils.speculate import get_speculate, set_speculate
13
from text_generation_server.models.model import Model
14
15
16
17
18
from text_generation_server.models.causal_lm import CausalLM, CausalLMBatchKeysLast
from text_generation_server.models.custom_modeling.opt_modeling import OPTForCausalLM
from text_generation_server.models.custom_modeling.mpt_modeling import (
    MPTForCausalLM,
)
19
from text_generation_server.models.bloom import BloomCausalLMBatch
20
21
22
from text_generation_server.models.custom_modeling.bloom_modeling import (
    BloomForCausalLM,
)
23
from text_generation_server.models.seq2seq_lm import Seq2SeqLM
24
25
26
27
28
29
30
31
32
33
34
from text_generation_server.models.galactica import GalacticaCausalLMBatch
from text_generation_server.models.custom_modeling.neox_modeling import (
    GPTNeoxForCausalLM,
)
from text_generation_server.models.custom_modeling.phi_modeling import (
    PhiConfig,
    PhiForCausalLM,
)
from text_generation_server.models.custom_modeling.t5_modeling import (
    T5ForConditionalGeneration,
)
35

36
from text_generation_server.utils.import_utils import SYSTEM
37
from text_generation_server.utils.log import log_master
38

39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
# Matmuls on CUDA may use TF32 tensor cores. PyTorch 1.12+ ships with this
# switched off, so it is opted into explicitly here.
torch.backends.cuda.matmul.allow_tf32 = True

# Same permission for cuDNN; PyTorch already defaults this one to True, the
# assignment just makes the choice explicit.
torch.backends.cudnn.allow_tf32 = True

# Turn off gradient tracking from import time on. PyTorch grad mode is
# thread-local, so this applies to the thread that imports this module.
torch.set_grad_enabled(False)

# Names exported by ``from <this package> import *``. Optional backends are
# appended later in this module once their imports succeed.
__all__ = ["Model", "CausalLM", "Seq2SeqLM", "get_model"]

# Template for errors raised when a configuration is only supported by the
# Flash Attention implementations; ``{}`` receives a model description.
FLASH_ATT_ERROR_MESSAGE = "{} requires Flash Attention enabled models."

# Assume the Flash Attention stack is available until its guarded import
# fails; the ImportError handler resets this to False.
FLASH_ATTENTION = True
60
try:
61
    from text_generation_server.models.flash_causal_lm import FlashCausalLM
62
    from text_generation_server.models.vlm_causal_lm import VlmCausalLM
63
64
65
66
    from text_generation_server.models.custom_modeling.flash_deepseek_v2_modeling import (
        FlashDeepseekV2ForCausalLM,
        DeepseekV2Config,
    )
67
68
    from text_generation_server.models.custom_modeling.flash_llama_modeling import (
        FlashLlamaForCausalLM,
69
    )
70
71
    from text_generation_server.models.custom_modeling.flash_cohere_modeling import (
        FlashCohereForCausalLM,
OlivierDehaene's avatar
OlivierDehaene committed
72
    )
73
74
    from text_generation_server.models.custom_modeling.flash_gemma_modeling import (
        FlashGemmaForCausalLM,
OlivierDehaene's avatar
OlivierDehaene committed
75
    )
76
77
    from text_generation_server.models.custom_modeling.flash_gemma2_modeling import (
        FlashGemma2ForCausalLM,
78
    )
79
80
81
82
83
84
85
86
87
88
    from text_generation_server.models.custom_modeling.flash_dbrx_modeling import (
        FlashDbrxForCausalLM,
        DbrxConfig,
    )
    from text_generation_server.models.custom_modeling.flash_rw_modeling import (
        RWConfig,
        FlashRWForCausalLM,
    )
    from text_generation_server.models.custom_modeling.flash_neox_modeling import (
        FlashGPTNeoXForCausalLM,
Nicolas Patry's avatar
Nicolas Patry committed
89
    )
drbh's avatar
drbh committed
90
    from text_generation_server.models.pali_gemma import (
91
        PaliGemmaBatch,
drbh's avatar
drbh committed
92
    )
93
94
95
96
97
    from text_generation_server.models.custom_modeling.flash_pali_gemma_modeling import (
        PaliGemmaForConditionalGeneration,
    )
    from text_generation_server.models.custom_modeling.flash_phi_modeling import (
        FlashPhiForCausalLM,
98
    )
99
    from text_generation_server.models.idefics import IDEFICSSharded
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
    from text_generation_server.models.custom_modeling.llava_next import (
        LlavaNextForConditionalGeneration,
    )

    from text_generation_server.models.custom_modeling.flash_santacoder_modeling import (
        FlashSantacoderForCausalLM,
    )
    from text_generation_server.models.custom_modeling.flash_starcoder2_modeling import (
        FlashStarcoder2ForCausalLM,
    )
    from text_generation_server.models.custom_modeling.flash_qwen2_modeling import (
        Qwen2ForCausalLM,
    )
    from text_generation_server.models.custom_modeling.flash_mistral_modeling import (
        FlashMistralForCausalLM,
    )
    from text_generation_server.models.custom_modeling.flash_mixtral_modeling import (
        FlashMixtralForCausalLM,
    )
    from text_generation_server.models.custom_modeling.flash_gpt2_modeling import (
        FlashGPT2ForCausalLM,
    )
    from text_generation_server.models.custom_modeling.idefics2 import (
        Idefics2ForConditionalGeneration,
    )
125
    from text_generation_server.layers.attention import SUPPORTS_WINDOWING
126
except ImportError as e:
127
    log_master(logger.warning, f"Could not import Flash Attention enabled models: {e}")
128
    SUPPORTS_WINDOWING = False
129
    FLASH_ATTENTION = False
130

131
if FLASH_ATTENTION:
    # Export the Flash Attention entry points only when their import succeeded.
    # ``__all__`` must contain *names* (str): appending the class objects
    # themselves makes ``from <this package> import *`` raise TypeError.
    __all__.append("FlashCausalLM")
    __all__.append("IDEFICSSharded")
OlivierDehaene's avatar
OlivierDehaene committed
134

drbh's avatar
drbh committed
135
136
137
138
# Mamba support is optional: if its module cannot be imported, log the reason
# and leave it out of the exported API instead of failing package import.
MAMBA_AVAILABLE = True
try:
    from text_generation_server.models.mamba import Mamba
except ImportError as e:
    log_master(logger.warning, f"Could not import Mamba: {e}")
    MAMBA_AVAILABLE = False

if MAMBA_AVAILABLE:
    # ``__all__`` entries must be strings, not the class object, or
    # ``from <this package> import *`` raises TypeError.
    __all__.append("Mamba")
OlivierDehaene's avatar
OlivierDehaene committed
144

145

146
class ModelType(enum.Enum):
    """Catalogue of model architectures recognised by this module.

    Each member's value is a dict with:

    - ``type``: the string compared against the checkpoint's ``model_type``
      (it is also exported as a module-level constant named after the member);
    - ``name``: a human-readable architecture name;
    - ``url``: a reference checkpoint on the Hugging Face Hub;
    - ``multimodal`` (optional): ``True`` for vision-language models.
    """

    DEEPSEEK_V2 = {
        "type": "deepseek_v2",
        "name": "Deepseek V2",
        "url": "https://huggingface.co/deepseek-ai/DeepSeek-V2",
    }
    IDEFICS2 = {
        "type": "idefics2",
        "name": "Idefics 2",
        "url": "https://huggingface.co/HuggingFaceM4/idefics2-8b",
        "multimodal": True,
    }
    LLAVA_NEXT = {
        "type": "llava_next",
        "name": "Llava Next (1.6)",
        "url": "https://huggingface.co/llava-hf/llava-v1.6-vicuna-13b-hf",
        "multimodal": True,
    }
    LLAMA = {
        "type": "llama",
        "name": "Llama",
        "url": "https://huggingface.co/meta-llama/Meta-Llama-3-8B-Instruct",
    }
    PHI3 = {
        "type": "phi3",
        "name": "Phi 3",
        "url": "https://huggingface.co/microsoft/Phi-3-mini-4k-instruct",
    }
    GEMMA = {
        "type": "gemma",
        "name": "Gemma",
        "url": "https://huggingface.co/google/gemma-7b",
    }
    PALIGEMMA = {
        "type": "paligemma",
        "name": "PaliGemma",
        "url": "https://huggingface.co/google/paligemma-3b-pt-224",
    }
    GEMMA2 = {
        "type": "gemma2",
        "name": "Gemma2",
        # The Hub repository id is ``google/gemma-2-9b`` (with a hyphen).
        "url": "https://huggingface.co/google/gemma-2-9b",
    }
    COHERE = {
        "type": "cohere",
        "name": "Cohere",
        "url": "https://huggingface.co/CohereForAI/c4ai-command-r-plus",
    }
    DBRX = {
        "type": "dbrx",
        "name": "Dbrx",
        "url": "https://huggingface.co/databricks/dbrx-instruct",
    }
    MAMBA = {
        # Mamba configs carry no ``model_type``; get_model assigns "ssm" itself.
        "type": "ssm",
        "name": "Mamba",
        "url": "https://huggingface.co/state-spaces/mamba-2.8b-slimpj",
    }
    MISTRAL = {
        "type": "mistral",
        "name": "Mistral",
        "url": "https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.2",
    }
    MIXTRAL = {
        "type": "mixtral",
        "name": "Mixtral",
        "url": "https://huggingface.co/mistralai/Mixtral-8x22B-Instruct-v0.1",
    }
    GPT_BIGCODE = {
        "type": "gpt_bigcode",
        "name": "Gpt Bigcode",
        "url": "https://huggingface.co/bigcode/gpt_bigcode-santacoder",
    }
    PHI = {
        "type": "phi",
        "name": "Phi",
        "url": "https://huggingface.co/microsoft/phi-1_5",
    }
    BAICHUAN = {
        "type": "baichuan",
        "name": "Baichuan",
        "url": "https://huggingface.co/baichuan-inc/Baichuan2-7B-Chat",
    }
    FALCON = {
        "type": "falcon",
        "name": "Falcon",
        "url": "https://huggingface.co/tiiuae/falcon-7b-instruct",
    }
    STARCODER2 = {
        "type": "starcoder2",
        "name": "StarCoder 2",
        "url": "https://huggingface.co/bigcode/starcoder2-15b-instruct-v0.1",
    }
    QWEN2 = {
        "type": "qwen2",
        "name": "Qwen 2",
        "url": "https://huggingface.co/collections/Qwen/qwen2-6659360b33528ced941e557f",
    }
    OPT = {
        "type": "opt",
        "name": "Opt",
        "url": "https://huggingface.co/facebook/opt-6.7b",
    }
    T5 = {
        "type": "t5",
        "name": "T5",
        "url": "https://huggingface.co/google/flan-t5-xxl",
    }
    GALACTICA = {
        "type": "galactica",
        "name": "Galactica",
        "url": "https://huggingface.co/facebook/galactica-120b",
    }
    SANTACODER = {
        "type": "santacoder",
        "name": "SantaCoder",
        "url": "https://huggingface.co/bigcode/santacoder",
    }
    BLOOM = {
        "type": "bloom",
        "name": "Bloom",
        "url": "https://huggingface.co/bigscience/bloom-560m",
    }
    MPT = {
        "type": "mpt",
        "name": "Mpt",
        "url": "https://huggingface.co/mosaicml/mpt-7b-instruct",
    }
    GPT2 = {
        "type": "gpt2",
        "name": "Gpt2",
        "url": "https://huggingface.co/openai-community/gpt2",
    }
    GPT_NEOX = {
        "type": "gpt_neox",
        "name": "Gpt Neox",
        "url": "https://huggingface.co/EleutherAI/gpt-neox-20b",
    }
    IDEFICS = {
        "type": "idefics",
        "name": "Idefics",
        "url": "https://huggingface.co/HuggingFaceM4/idefics-9b",
        "multimodal": True,
    }


# Publish each architecture's ``type`` string as a module-level constant named
# after its enum member (e.g. ``LLAMA = "llama"``, ``MAMBA = "ssm"``), so that
# ``get_model`` can compare ``model_type == LLAMA``. ``globals()`` is used
# explicitly rather than mutating ``locals()``, and the comprehension avoids
# leaking a loop variable into the module namespace.
globals().update({member.name: member.value["type"] for member in ModelType})


297
def get_model(
298
    model_id: str,
drbh's avatar
drbh committed
299
    lora_adapter_ids: Optional[List[str]],
300
301
302
    revision: Optional[str],
    sharded: bool,
    quantize: Optional[str],
Nicolas Patry's avatar
Nicolas Patry committed
303
    speculate: Optional[int],
304
    dtype: Optional[str],
305
    trust_remote_code: bool,
306
    max_input_tokens: int,
307
) -> Model:
308
    global FLASH_ATTENTION
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326

    config_dict, _ = PretrainedConfig.get_config_dict(
        model_id, revision=revision, trust_remote_code=trust_remote_code
    )
    model_type = config_dict.get("model_type", None)

    quantization_config = config_dict.get("quantization_config", None)
    if quantization_config is not None and quantize is None:
        method = quantization_config.get("quant_method", None)
        if method in {"gptq", "awq", "exl2"}:
            log_master(logger.info, f"Auto selecting quantization method {method}")
            quantize = method
        elif method == "fbgemm_fp8":
            log_master(logger.info, "Auto selecting quantization method fp8")
            quantize = "fp8"
        else:
            log_master(logger.warning, f"Unknown quantization method {method}")

327
    if dtype is None:
328
        if quantize in ["awq", "exl2", "gptq", "marlin"]:
329
330
            # These quantizers only work with float16 params.
            dtype = torch.float16
331
        elif quantize == "fp8":
332
            from text_generation_server.layers.fp8 import FBGEMM_DYN_AVAILABLE
333

334
            if FBGEMM_DYN_AVAILABLE:
335
336
                # fbgemm kernels are fp8xfp8->bf16
                dtype = torch.bfloat16
337
338
339
340
        else:
            # Keep it as default for now and let
            # every model resolve their own default dtype.
            dtype = None
341
342
343
344
345
346
347
    elif dtype == "float16":
        dtype = torch.float16
    elif dtype == "bfloat16":
        dtype = torch.bfloat16
    else:
        raise RuntimeError(f"Unknown dtype {dtype}")

Nicolas Patry's avatar
Nicolas Patry committed
348
349
350
351
352
    if speculate is not None:
        set_speculate(speculate)
    else:
        set_speculate(0)

Nicolas Patry's avatar
Nicolas Patry committed
353
    speculator = None
Nicolas Patry's avatar
Nicolas Patry committed
354
    if "medusa_num_heads" in config_dict:
355
356
        medusa_model_id = model_id
        medusa_revision = revision
Nicolas Patry's avatar
Nicolas Patry committed
357
358
359
360
361
        model_id = config_dict["base_model_name_or_path"]
        revision = "main"
        speculate_medusa = config_dict["medusa_num_heads"]
        if speculate is not None:
            if speculate > speculate_medusa:
OlivierDehaene's avatar
OlivierDehaene committed
362
                raise RuntimeError(
OlivierDehaene's avatar
OlivierDehaene committed
363
                    f"Speculate is set to `{speculate}` but this medusa models only has `{speculate_medusa}` heads, please make them match"
OlivierDehaene's avatar
OlivierDehaene committed
364
                )
Nicolas Patry's avatar
Nicolas Patry committed
365
366
367
368
369
370
371
372
            else:
                set_speculate(speculate)
        else:
            set_speculate(speculate_medusa)

        config_dict, _ = PretrainedConfig.get_config_dict(
            model_id, revision=revision, trust_remote_code=trust_remote_code
        )
Nicolas Patry's avatar
Nicolas Patry committed
373
374
        # Reload model type from parent.
        model_type = config_dict.get("model_type", None)
375
376
377
378
379
380
381
382
383
384
        is_local = Path(medusa_model_id).exists()
        if not is_local:
            medusa_config = hf_hub_download(
                medusa_model_id, revision=medusa_revision, filename="config.json"
            )
            hf_hub_download(
                medusa_model_id,
                revision=medusa_revision,
                filename="medusa_lm_head.safetensors",
            )
Nicolas Patry's avatar
Nicolas Patry committed
385
386
387
388
            speculator = {
                "path": Path(medusa_config).parent,
                "model_paths": ["medusa_lm_head.safetensors"],
            }
389
        else:
Nicolas Patry's avatar
Nicolas Patry committed
390
391
392
393
            speculator = {
                "path": Path(medusa_model_id),
                "model_paths": ["medusa_lm_head.safetensors"],
            }
394

Nicolas Patry's avatar
Nicolas Patry committed
395
        method = "medusa"
Nicolas Patry's avatar
Nicolas Patry committed
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
    elif model_type == "mlp_speculator":
        mlp_model_id = model_id
        mlp_revision = revision
        model_id = config_dict["base_model_name_or_path"]
        revision = "main"
        speculate_mlp = config_dict["n_predict"]
        if speculate is not None:
            if speculate > speculate_mlp:
                raise RuntimeError(
                    f"Speculate is set to `{speculate}` but this mlp_speculator models only has `{speculate_mlp}` heads, please make them match"
                )
            else:
                set_speculate(speculate)
        else:
            set_speculate(speculate_mlp)

        config_dict, _ = PretrainedConfig.get_config_dict(
            model_id, revision=revision, trust_remote_code=trust_remote_code
        )
        # Reload model type from parent.
        model_type = config_dict.get("model_type", None)
        is_local = Path(mlp_model_id).exists()
        extension = ".safetensors"
        if not is_local:
            mlp_speculator_config = hf_hub_download(
                mlp_model_id, revision=mlp_revision, filename="config.json"
            )
            api = HfApi()
            info = api.model_info(mlp_model_id, revision=mlp_revision)
            filenames = [
                s.rfilename
                for s in info.siblings
                if s.rfilename.endswith(extension)
                and len(s.rfilename.split("/")) == 1
                and "arguments" not in s.rfilename
                and "args" not in s.rfilename
                and "training" not in s.rfilename
            ]
            for filename in filenames:
                hf_hub_download(
                    mlp_model_id,
                    revision=mlp_revision,
                    filename=filename,
                )
            speculator = {
                "path": Path(mlp_speculator_config).parent,
                "model_paths": filenames,
            }
        else:
            speculator = Path(mlp_model_id)
            filenames = [p for p in os.listdir(speculator) if p.endswith(extension)]
            speculator = {"path": speculator, "model_paths": filenames}
        method = "mlp_speculator"
Nicolas Patry's avatar
Nicolas Patry committed
449
450
451
452
453
    else:
        method = "n-gram"

    speculate = get_speculate()
    if speculate > 0:
454
455
456
        log_master(
            logger.info, f"Using speculation {method} with {speculate} input ids."
        )
Nicolas Patry's avatar
Nicolas Patry committed
457

drbh's avatar
drbh committed
458
459
460
461
462
463
464
465
466
467
    if model_type is None:
        # TODO: fix how we determine model type for Mamba
        if "ssm_cfg" in config_dict:
            # *only happens in Mamba case
            model_type = "ssm"
        else:
            raise RuntimeError(
                f"Could not determine model type for {model_id} revision {revision}"
            )

468
469
470
471
    if quantize == "exl2" and sharded:
        raise RuntimeError(
            "Sharding is currently not supported with `exl2` quantization"
        )
472
    sliding_window = config_dict.get("sliding_window", -1)
473
474
475
476
477
478
479
480

    if (
        (sliding_window is not None and sliding_window != -1)
        and not SUPPORTS_WINDOWING
        and max_input_tokens > sliding_window
    ):
        raise ValueError(
            f"The backend {SYSTEM} does not support sliding window attention that is used by the model type {model_type}. To use this model nonetheless with the {SYSTEM} backend, please launch TGI with the argument `--max-input-tokens` smaller than sliding_window={sliding_window} (got here max_input_tokens={max_input_tokens})."
481
        )
482

483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
    if model_type == DEEPSEEK_V2:
        if FLASH_ATTENTION:
            head_size = max(
                config_dict.get("qk_nope_dim", 128)
                + config_dict.get("qk_rope_dim", 64),
                config_dict.get("v_head_dim", 128),
            )
            return FlashCausalLM(
                model_id=model_id,
                model_class=FlashDeepseekV2ForCausalLM,
                revision=revision,
                quantize=quantize,
                speculator=speculator,
                default_dtype=torch.bfloat16,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
                lora_adapter_ids=lora_adapter_ids,
                config_class=DeepseekV2Config,
                head_size=head_size,
            )
        elif sharded:
            raise NotImplementedError(
                FLASH_ATT_ERROR_MESSAGE.format("Sharded Deepseek V2")
            )
        else:
            return CausalLM.fallback(
                model_id,
                revision,
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )
    elif model_type == MAMBA:
drbh's avatar
drbh committed
517
518
519
520
        return Mamba(
            model_id,
            revision,
            quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
521
            speculator=speculator,
drbh's avatar
drbh committed
522
523
524
            dtype=dtype,
            trust_remote_code=trust_remote_code,
        )
525

OlivierDehaene's avatar
OlivierDehaene committed
526
    if model_id.startswith("facebook/galactica"):
527
528
529
530
531
        return CausalLM(
            model_id=model_id,
            # Yes galactica is just an OPT model.
            model_class=OPTForCausalLM,
            revision=revision,
OlivierDehaene's avatar
OlivierDehaene committed
532
            quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
533
            speculator=speculator,
OlivierDehaene's avatar
OlivierDehaene committed
534
535
            dtype=dtype,
            trust_remote_code=trust_remote_code,
536
            batch_class=GalacticaCausalLMBatch,
OlivierDehaene's avatar
OlivierDehaene committed
537
538
        )

539
    if (
540
541
        model_type == GPT_BIGCODE
        or model_type == GPT2
542
543
        and model_id.startswith("bigcode/")
    ):
544
        if FLASH_ATTENTION:
545
546
547
548
            return FlashCausalLM(
                model_id=model_id,
                model_class=FlashSantacoderForCausalLM,
                revision=revision,
549
                quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
550
                speculator=speculator,
551
                dtype=dtype,
552
                trust_remote_code=trust_remote_code,
553
554
555
                lora_adapter_ids=lora_adapter_ids,
                aliases={"transformer.wte.weight": ["lm_head.weight"]},
                num_kv_heads=1,
556
            )
557
558
559
560
        elif sharded:
            raise NotImplementedError(
                FLASH_ATT_ERROR_MESSAGE.format("Sharded Santacoder")
            )
561
        else:
562
563
564
            return CausalLM.fallback(
                model_id=model_id,
                revision=revision,
565
                quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
566
                speculator=speculator,
567
                dtype=dtype,
568
569
                trust_remote_code=trust_remote_code,
            )
570

571
    if model_type == BLOOM:
572
573
574
575
        return CausalLM(
            model_id=model_id,
            model_class=BloomForCausalLM,
            revision=revision,
576
            quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
577
            speculator=speculator,
578
579
            dtype=dtype,
            trust_remote_code=trust_remote_code,
580
            batch_class=BloomCausalLMBatch,
581
        )
582
    elif model_type == MPT:
583
584
585
586
        return CausalLM(
            model_id=model_id,
            model_class=MPTForCausalLM,
            revision=revision,
OlivierDehaene's avatar
OlivierDehaene committed
587
            quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
588
            speculator=speculator,
OlivierDehaene's avatar
OlivierDehaene committed
589
590
            dtype=dtype,
            trust_remote_code=trust_remote_code,
591
            batch_class=CausalLMBatchKeysLast,
592
        )
593
    elif model_type == GPT2:
594
        if FLASH_ATTENTION:
595
            try:
596
597
598
599
                return FlashCausalLM(
                    model_id=model_id,
                    model_class=FlashGPT2ForCausalLM,
                    revision=revision,
600
601
602
603
                    quantize=quantize,
                    speculator=speculator,
                    dtype=dtype,
                    trust_remote_code=trust_remote_code,
604
                    lora_adapter_ids=lora_adapter_ids,
605
606
607
                )
            except RuntimeError as e:
                # Lots of legacy models with various weight names.
608
                log_master(logger.warning, f"Couldn't load flash gpt2 variant: {e}")
609
                return CausalLM.fallback(
610
611
612
613
614
615
616
                    model_id,
                    revision,
                    quantize=quantize,
                    speculator=speculator,
                    dtype=dtype,
                    trust_remote_code=trust_remote_code,
                )
617
618
619
        elif sharded:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded GPT-2"))
        else:
620
            return CausalLM.fallback(
621
622
623
624
625
626
627
                model_id,
                revision,
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )
628
    elif model_type == GPT_NEOX:
629
        if FLASH_ATTENTION:
630
631
632
633
            from text_generation_server.models.custom_modeling.flash_neox_modeling import (
                GPTNeoXConfig,
            )

634
635
636
637
            return FlashCausalLM(
                model_id=model_id,
                model_class=FlashGPTNeoXForCausalLM,
                revision=revision,
638
                quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
639
                speculator=speculator,
640
                dtype=dtype,
641
                trust_remote_code=trust_remote_code,
642
                lora_adapter_ids=lora_adapter_ids,
643
                config_class=GPTNeoXConfig,
644
645
            )
        elif sharded:
646
647
648
649
            return CausalLM(
                model_id=model_id,
                model_class=GPTNeoxForCausalLM,
                revision=revision,
650
                quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
651
                speculator=speculator,
652
                dtype=dtype,
653
654
                trust_remote_code=trust_remote_code,
            )
655
        else:
656
            return CausalLM.fallback(
657
658
659
                model_id,
                revision,
                quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
660
                speculator=speculator,
661
                dtype=dtype,
662
663
                trust_remote_code=trust_remote_code,
            )
OlivierDehaene's avatar
OlivierDehaene committed
664

665
    elif model_type == PHI:
drbh's avatar
drbh committed
666
        if FLASH_ATTENTION:
667
668
669
670
            return FlashCausalLM(
                model_id=model_id,
                model_class=FlashPhiForCausalLM,
                revision=revision,
drbh's avatar
drbh committed
671
                quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
672
                speculator=speculator,
drbh's avatar
drbh committed
673
674
                dtype=dtype,
                trust_remote_code=trust_remote_code,
675
                lora_adapter_ids=lora_adapter_ids,
drbh's avatar
drbh committed
676
677
            )
        else:
678
            return CausalLM.fallback(
drbh's avatar
drbh committed
679
680
681
                model_id,
                revision,
                quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
682
                speculator=speculator,
drbh's avatar
drbh committed
683
684
685
686
687
688
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )

    elif model_type == "phi-msft":
        if FLASH_ATTENTION:
OlivierDehaene's avatar
OlivierDehaene committed
689
690
691
            raise NotImplementedError(
                "Legacy phi-msft is not supported with Flash Attention"
            )
drbh's avatar
drbh committed
692
        else:
693
694
695
696
697
            return CausalLM(
                model_id=model_id,
                model_class=PhiForCausalLM,
                config_class=PhiConfig,
                revision=revision,
drbh's avatar
drbh committed
698
                quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
699
                speculator=speculator,
drbh's avatar
drbh committed
700
701
702
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )
703

704
    elif model_type == LLAMA or model_type == BAICHUAN or model_type == PHI3:
705
        if FLASH_ATTENTION:
706
707
708
709
            return FlashCausalLM(
                model_id=model_id,
                model_class=FlashLlamaForCausalLM,
                revision=revision,
710
                quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
711
                speculator=speculator,
712
                dtype=dtype,
713
                trust_remote_code=trust_remote_code,
drbh's avatar
drbh committed
714
                lora_adapter_ids=lora_adapter_ids,
715
            )
716
717
        elif sharded:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Llama"))
718
        else:
719
            return CausalLM.fallback(
720
721
722
                model_id,
                revision,
                quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
723
                speculator=speculator,
724
                dtype=dtype,
725
726
                trust_remote_code=trust_remote_code,
            )
727
    if model_type == GEMMA:
728
        if FLASH_ATTENTION:
729
730
731
732
            return FlashCausalLM(
                model_id=model_id,
                model_class=FlashGemmaForCausalLM,
                revision=revision,
733
                quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
734
                speculator=speculator,
735
                dtype=dtype,
736
737
                # Works better for these models
                default_dtype=torch.bfloat16,
738
                trust_remote_code=trust_remote_code,
739
                lora_adapter_ids=lora_adapter_ids,
740
741
            )
        elif sharded:
OlivierDehaene's avatar
OlivierDehaene committed
742
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Gemma"))
743
        else:
744
            return CausalLM.fallback(
745
746
747
                model_id,
                revision,
                quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
748
                speculator=speculator,
749
750
751
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )
Nicolas Patry's avatar
Nicolas Patry committed
752
753
    elif model_type == GEMMA2:
        if FLASH_ATTENTION:
754
755
756
757
            return FlashCausalLM(
                model_id=model_id,
                model_class=FlashGemma2ForCausalLM,
                revision=revision,
Nicolas Patry's avatar
Nicolas Patry committed
758
759
760
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
761
762
                # Works better for these models
                default_dtype=torch.bfloat16,
Nicolas Patry's avatar
Nicolas Patry committed
763
                trust_remote_code=trust_remote_code,
764
                lora_adapter_ids=lora_adapter_ids,
Nicolas Patry's avatar
Nicolas Patry committed
765
766
767
768
            )
        elif sharded:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Gemma2"))
        else:
769
            return CausalLM.fallback(
Nicolas Patry's avatar
Nicolas Patry committed
770
771
772
773
774
775
776
                model_id,
                revision,
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )
777

778
    if model_type == COHERE:
OlivierDehaene's avatar
OlivierDehaene committed
779
        if FLASH_ATTENTION:
780
781
782
783
            return FlashCausalLM(
                model_id=model_id,
                model_class=FlashCohereForCausalLM,
                revision=revision,
OlivierDehaene's avatar
OlivierDehaene committed
784
                quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
785
                speculator=speculator,
OlivierDehaene's avatar
OlivierDehaene committed
786
787
                dtype=dtype,
                trust_remote_code=trust_remote_code,
788
                lora_adapter_ids=lora_adapter_ids,
OlivierDehaene's avatar
OlivierDehaene committed
789
790
791
792
            )
        elif sharded:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Cohere"))
        else:
793
            return CausalLM.fallback(
OlivierDehaene's avatar
OlivierDehaene committed
794
795
796
                model_id,
                revision,
                quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
797
                speculator=speculator,
OlivierDehaene's avatar
OlivierDehaene committed
798
799
800
801
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )

802
    if model_type == DBRX:
803
        if FLASH_ATTENTION:
804
805
806
807
            return FlashCausalLM(
                model_id=model_id,
                model_class=FlashDbrxForCausalLM,
                revision=revision,
808
                quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
809
                speculator=speculator,
810
                dtype=dtype,
811
812
                # Dbrx works better in bfloat16.
                default_dtype=torch.bfloat16,
813
                trust_remote_code=trust_remote_code,
814
815
                lora_adapter_ids=lora_adapter_ids,
                config_class=DbrxConfig,
816
817
818
819
            )
        elif sharded:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded DBRX"))
        else:
820
            return CausalLM.fallback(
821
822
823
                model_id,
                revision,
                quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
824
                speculator=speculator,
825
826
827
828
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )

829
    if model_type in ["RefinedWeb", "RefinedWebModel", FALCON]:
830
831
        if sharded:
            if FLASH_ATTENTION:
832
                if config_dict.get("alibi", False):
833
                    raise NotImplementedError("sharded is not supported for this model")
834
835
836
837
                return FlashCausalLM(
                    model_id=model_id,
                    model_class=FlashRWForCausalLM,
                    revision=revision,
838
                    quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
839
                    speculator=speculator,
840
                    dtype=dtype,
841
842
843
844
                    aliases={
                        "lm_head.weight": ["transformer.word_embeddings.weight"],
                        "transformer.word_embeddings.weight": ["lm_head.weight"],
                    },
845
                    trust_remote_code=trust_remote_code,
846
847
                    lora_adapter_ids=lora_adapter_ids,
                    config_class=RWConfig,
848
                )
849
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format(f"Sharded Falcon"))
850
        else:
851
            if FLASH_ATTENTION and not config_dict.get("alibi", False):
852
853
854
855
                return FlashCausalLM(
                    model_id=model_id,
                    model_class=FlashRWForCausalLM,
                    revision=revision,
856
                    quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
857
                    speculator=speculator,
858
                    dtype=dtype,
859
860
861
862
                    aliases={
                        "lm_head.weight": ["transformer.word_embeddings.weight"],
                        "transformer.word_embeddings.weight": ["lm_head.weight"],
                    },
863
                    trust_remote_code=trust_remote_code,
864
865
                    lora_adapter_ids=lora_adapter_ids,
                    config_class=RWConfig,
866
867
                )
            else:
868
                return CausalLM.fallback(
869
870
871
                    model_id,
                    revision,
                    quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
872
                    speculator=speculator,
873
                    dtype=dtype,
874
875
876
                    trust_remote_code=trust_remote_code,
                )

877
    if model_type == MISTRAL:
878
        if FLASH_ATTENTION:
879
880
881
882
            return FlashCausalLM(
                model_id=model_id,
                model_class=FlashMistralForCausalLM,
                revision=revision,
883
                quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
884
                speculator=speculator,
885
886
                dtype=dtype,
                trust_remote_code=trust_remote_code,
887
                lora_adapter_ids=lora_adapter_ids,
888
            )
OlivierDehaene's avatar
OlivierDehaene committed
889
890
891
        elif sharded:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Mistral"))
        else:
892
            return CausalLM.fallback(
OlivierDehaene's avatar
OlivierDehaene committed
893
894
895
                model_id,
                revision,
                quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
896
                speculator=speculator,
OlivierDehaene's avatar
OlivierDehaene committed
897
898
899
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )
OlivierDehaene's avatar
OlivierDehaene committed
900

901
    if model_type == MIXTRAL:
902
        if FLASH_ATTENTION:
903
904
905
906
            return FlashCausalLM(
                model_id=model_id,
                model_class=FlashMixtralForCausalLM,
                revision=revision,
OlivierDehaene's avatar
OlivierDehaene committed
907
                quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
908
                speculator=speculator,
OlivierDehaene's avatar
OlivierDehaene committed
909
                dtype=dtype,
OlivierDehaene's avatar
OlivierDehaene committed
910
                trust_remote_code=trust_remote_code,
911
                lora_adapter_ids=lora_adapter_ids,
OlivierDehaene's avatar
OlivierDehaene committed
912
            )
OlivierDehaene's avatar
OlivierDehaene committed
913
914
915
        elif sharded:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Mixtral"))
        else:
916
            return CausalLM.fallback(
OlivierDehaene's avatar
OlivierDehaene committed
917
918
919
                model_id,
                revision,
                quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
920
                speculator=speculator,
OlivierDehaene's avatar
OlivierDehaene committed
921
922
923
924
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )

925
    if model_type == STARCODER2:
926
        if FLASH_ATTENTION:
927
928
929
930
            return FlashCausalLM(
                model_id=model_id,
                model_class=FlashStarcoder2ForCausalLM,
                revision=revision,
OlivierDehaene's avatar
OlivierDehaene committed
931
                quantize=quantize,
932
                speculator=speculator,
OlivierDehaene's avatar
OlivierDehaene committed
933
                dtype=dtype,
OlivierDehaene's avatar
OlivierDehaene committed
934
                trust_remote_code=trust_remote_code,
935
                lora_adapter_ids=lora_adapter_ids,
OlivierDehaene's avatar
OlivierDehaene committed
936
937
938
939
940
941
            )
        elif sharded:
            raise NotImplementedError(
                FLASH_ATT_ERROR_MESSAGE.format("Sharded Starcoder2")
            )
        else:
942
            return CausalLM.fallback(
OlivierDehaene's avatar
OlivierDehaene committed
943
944
945
                model_id,
                revision,
                quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
946
                speculator=speculator,
OlivierDehaene's avatar
OlivierDehaene committed
947
948
949
950
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )

951
    if model_type == QWEN2:
952
        if FLASH_ATTENTION:
953
954
955
956
            return FlashCausalLM(
                model_id=model_id,
                model_class=Qwen2ForCausalLM,
                revision=revision,
OlivierDehaene's avatar
OlivierDehaene committed
957
                quantize=quantize,
958
                speculator=speculator,
OlivierDehaene's avatar
OlivierDehaene committed
959
960
                dtype=dtype,
                trust_remote_code=trust_remote_code,
961
                lora_adapter_ids=lora_adapter_ids,
OlivierDehaene's avatar
OlivierDehaene committed
962
963
964
965
            )
        elif sharded:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Qwen2"))
        else:
966
            return CausalLM.fallback(
OlivierDehaene's avatar
OlivierDehaene committed
967
968
969
                model_id,
                revision,
                quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
970
                speculator=speculator,
OlivierDehaene's avatar
OlivierDehaene committed
971
                dtype=dtype,
OlivierDehaene's avatar
OlivierDehaene committed
972
973
                trust_remote_code=trust_remote_code,
            )
974

975
    if model_type == OPT:
976
977
978
979
        return CausalLM(
            model_id=model_id,
            model_class=OPTForCausalLM,
            revision=revision,
980
            quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
981
            speculator=speculator,
982
983
            dtype=dtype,
            trust_remote_code=trust_remote_code,
984
        )
985

986
    if model_type == T5:
987
988
989
990
        return Seq2SeqLM(
            model_id=model_id,
            model_class=T5ForConditionalGeneration,
            revision=revision,
991
            quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
992
            speculator=speculator,
993
            dtype=dtype,
994
            trust_remote_code=trust_remote_code,
995
996
997
998
999
1000
            aliases={
                "shared.weight": [
                    "encoder.embed_tokens.weight",
                    "decoder.embed_tokens.weight",
                ]
            },
1001
        )
1002
    if model_type == IDEFICS:
1003
        if FLASH_ATTENTION:
OlivierDehaene's avatar
OlivierDehaene committed
1004
1005
1006
1007
            return IDEFICSSharded(
                model_id,
                revision,
                quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
1008
                speculator=speculator,
OlivierDehaene's avatar
OlivierDehaene committed
1009
1010
1011
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )
1012
1013
        else:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Idefics"))
1014
    if model_type == IDEFICS2:
Nicolas Patry's avatar
Nicolas Patry committed
1015
        if FLASH_ATTENTION:
1016
1017
1018
1019
            return VlmCausalLM(
                model_id=model_id,
                model_class=Idefics2ForConditionalGeneration,
                revision=revision,
Nicolas Patry's avatar
Nicolas Patry committed
1020
                quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
1021
                speculator=speculator,
Nicolas Patry's avatar
Nicolas Patry committed
1022
1023
                dtype=dtype,
                trust_remote_code=trust_remote_code,
1024
1025
1026
1027
                lora_adapter_ids=lora_adapter_ids,
                # XXX: Extremely important to cap resolution in order to limit
                # VRAM usage.
                processor_kwargs={"size": {"longest_edge": 448, "shortest_edge": 378}},
Nicolas Patry's avatar
Nicolas Patry committed
1028
1029
1030
            )
        else:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Idefics"))
1031
    if model_type == PALIGEMMA:
drbh's avatar
drbh committed
1032
        if FLASH_ATTENTION:
1033
1034
1035
1036
            return VlmCausalLM(
                model_id=model_id,
                model_class=PaliGemmaForConditionalGeneration,
                revision=revision,
drbh's avatar
drbh committed
1037
1038
1039
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
1040
1041
                # Works better for these models
                default_dtype=torch.bfloat16,
drbh's avatar
drbh committed
1042
                trust_remote_code=trust_remote_code,
1043
1044
                lora_adapter_ids=lora_adapter_ids,
                batch_class=PaliGemmaBatch,
drbh's avatar
drbh committed
1045
1046
1047
            )
        else:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Idefics"))
1048

1049
    if model_type == LLAVA_NEXT:
1050
        if FLASH_ATTENTION:
1051
1052
1053
1054
            return VlmCausalLM(
                model_class=LlavaNextForConditionalGeneration,
                model_id=model_id,
                revision=revision,
1055
                quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
1056
                speculator=speculator,
1057
1058
1059
1060
1061
1062
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )
        else:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("LlavaNext"))

1063
    if sharded:
1064
        raise NotImplementedError("sharded is not supported for AutoModel")
1065
    if quantize == "gptq":
1066
        raise NotImplementedError(
1067
1068
            "gptq quantization is not supported for AutoModel, you can try to quantize it with `text-generation-server quantize ORIGINAL_MODEL_ID NEW_MODEL_ID`"
        )
1069
    if quantize == "awq":
1070
        raise NotImplementedError("awq quantization is not supported for AutoModel")
Nicolas Patry's avatar
Nicolas Patry committed
1071
    elif (quantize == "bitsandbytes-fp4") or (quantize == "bitsandbytes-nf4"):
1072
        raise NotImplementedError("4bit quantization is not supported for AutoModel")
OlivierDehaene's avatar
OlivierDehaene committed
1073
    elif quantize == "eetq":
1074
        raise NotImplementedError("Eetq quantization is not supported for AutoModel")
1075
1076
    elif quantize == "exl2":
        raise NotImplementedError("exl2 quantization is not supported for AutoModel")
1077
    if model_type in modeling_auto.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES:
1078
        return CausalLM.fallback(
1079
1080
1081
            model_id,
            revision,
            quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
1082
            speculator=speculator,
1083
1084
            dtype=dtype,
            trust_remote_code=trust_remote_code,
1085
        )
1086
    if model_type in modeling_auto.MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES:
1087
        return Seq2SeqLM.fallback(
1088
1089
1090
            model_id,
            revision,
            quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
1091
            speculator=speculator,
1092
1093
            dtype=dtype,
            trust_remote_code=trust_remote_code,
1094
1095
        )

1096
    auto_map = config_dict.get("auto_map", None)
1097
1098
    if trust_remote_code and auto_map is not None:
        if "AutoModelForCausalLM" in auto_map.keys():
1099
            return CausalLM.fallback(
1100
1101
1102
                model_id,
                revision,
                quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
1103
                speculator=speculator,
1104
                dtype=dtype,
1105
1106
                trust_remote_code=trust_remote_code,
            )
1107
        if "AutoModelForSeq2SeqLM" in auto_map.keys():
1108
            return Seq2SeqLM.fallback(
1109
1110
1111
                model_id,
                revision,
                quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
1112
                speculator=speculator,
1113
                dtype=dtype,
1114
1115
                trust_remote_code=trust_remote_code,
            )
1116
1117

    raise ValueError(f"Unsupported model type {model_type}")