import torch
import enum
import os

from loguru import logger
from transformers.configuration_utils import PretrainedConfig
from transformers.models.auto import modeling_auto
from huggingface_hub import hf_hub_download, HfApi
from typing import Optional, List
from pathlib import Path

from text_generation_server.utils.speculate import get_speculate, set_speculate
from text_generation_server.models.model import Model
from text_generation_server.models.causal_lm import CausalLM, CausalLMBatchKeysLast
from text_generation_server.models.custom_modeling.opt_modeling import OPTForCausalLM
from text_generation_server.models.custom_modeling.mpt_modeling import (
    MPTForCausalLM,
)
from text_generation_server.models.bloom import BloomCausalLMBatch
from text_generation_server.models.custom_modeling.bloom_modeling import (
    BloomForCausalLM,
)
from text_generation_server.models.seq2seq_lm import Seq2SeqLM
from text_generation_server.models.galactica import GalacticaCausalLMBatch
from text_generation_server.models.custom_modeling.neox_modeling import (
    GPTNeoxForCausalLM,
)
from text_generation_server.models.custom_modeling.phi_modeling import (
    PhiConfig,
    PhiForCausalLM,
)
from text_generation_server.models.custom_modeling.t5_modeling import (
    T5ForConditionalGeneration,
)

from text_generation_server.utils.import_utils import SYSTEM

# The flag below controls whether to allow TF32 on matmul. This flag defaults to False
# in PyTorch 1.12 and later.
torch.backends.cuda.matmul.allow_tf32 = True

# The flag below controls whether to allow TF32 on cuDNN. This flag defaults to True.
torch.backends.cudnn.allow_tf32 = True

# Disable gradients
torch.set_grad_enabled(False)

__all__ = [
    "Model",
    "BLOOMSharded",
    "CausalLM",
    "GalacticaSharded",
    "Seq2SeqLM",
    "get_model",
]

FLASH_ATT_ERROR_MESSAGE = "{} requires Flash Attention enabled models."

FLASH_ATTENTION = True

try:
    from text_generation_server.models.flash_causal_lm import FlashCausalLM
    from text_generation_server.models.vlm_causal_lm import VlmCausalLM
    from text_generation_server.models.custom_modeling.flash_llama_modeling import (
        FlashLlamaForCausalLM,
    )
    from text_generation_server.models.custom_modeling.flash_cohere_modeling import (
        FlashCohereForCausalLM,
    )
    from text_generation_server.models.custom_modeling.flash_gemma_modeling import (
        FlashGemmaForCausalLM,
    )
    from text_generation_server.models.custom_modeling.flash_gemma2_modeling import (
        FlashGemma2ForCausalLM,
    )
    from text_generation_server.models.custom_modeling.flash_dbrx_modeling import (
        FlashDbrxForCausalLM,
        DbrxConfig,
    )
    from text_generation_server.models.custom_modeling.flash_rw_modeling import (
        RWConfig,
        FlashRWForCausalLM,
    )
    from text_generation_server.models.custom_modeling.flash_neox_modeling import (
        FlashGPTNeoXForCausalLM,
    )
    from text_generation_server.models.pali_gemma import (
        PaliGemmaBatch,
    )
    from text_generation_server.models.custom_modeling.flash_pali_gemma_modeling import (
        PaliGemmaForConditionalGeneration,
    )
    from text_generation_server.models.custom_modeling.flash_phi_modeling import (
        FlashPhiForCausalLM,
    )
    from text_generation_server.models.idefics import IDEFICSSharded
    from text_generation_server.models.custom_modeling.llava_next import (
        LlavaNextForConditionalGeneration,
    )

    from text_generation_server.models.custom_modeling.flash_santacoder_modeling import (
        FlashSantacoderForCausalLM,
    )
    from text_generation_server.models.custom_modeling.flash_starcoder2_modeling import (
        FlashStarcoder2ForCausalLM,
    )
    from text_generation_server.models.custom_modeling.flash_qwen2_modeling import (
        Qwen2ForCausalLM,
    )
    from text_generation_server.models.custom_modeling.flash_mistral_modeling import (
        FlashMistralForCausalLM,
    )
    from text_generation_server.models.custom_modeling.flash_mixtral_modeling import (
        FlashMixtralForCausalLM,
    )
    from text_generation_server.models.custom_modeling.flash_gpt2_modeling import (
        FlashGPT2ForCausalLM,
    )
    from text_generation_server.models.custom_modeling.idefics2 import (
        Idefics2ForConditionalGeneration,
    )
    from text_generation_server.layers.attention import SUPPORTS_WINDOWING
except ImportError as e:
    logger.warning(f"Could not import Flash Attention enabled models: {e}")
    SUPPORTS_WINDOWING = False
    FLASH_ATTENTION = False

if FLASH_ATTENTION:
    __all__.append(FlashCausalLM)
    __all__.append(IDEFICSSharded)

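# The Mamba implementation has its own optional dependencies; if importing it
# fails, keep serving the other architectures and only register Mamba when it
# is actually available.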
MAMBA_AVAILABLE = True
try:
    from text_generation_server.models.mamba import Mamba
except ImportError as e:
    logger.warning(f"Could not import Mamba: {e}")
    MAMBA_AVAILABLE = False

if MAMBA_AVAILABLE:
    __all__.append(Mamba)


class ModelType(enum.Enum):
    IDEFICS2 = {
        "type": "idefics2",
        "name": "Idefics 2",
        "url": "https://huggingface.co/HuggingFaceM4/idefics2-8b",
        "multimodal": True,
    }
    LLAVA_NEXT = {
        "type": "llava_next",
        "name": "Llava Next (1.6)",
        "url": "https://huggingface.co/llava-hf/llava-v1.6-vicuna-13b-hf",
        "multimodal": True,
    }
    LLAMA = {
        "type": "llama",
        "name": "Llama",
        "url": "https://huggingface.co/meta-llama/Meta-Llama-3-8B-Instruct",
    }
    PHI3 = {
        "type": "phi3",
        "name": "Phi 3",
        "url": "https://huggingface.co/microsoft/Phi-3-mini-4k-instruct",
    }
    GEMMA = {
        "type": "gemma",
        "name": "Gemma",
        "url": "https://huggingface.co/google/gemma-7b",
    }
    PALIGEMMA = {
        "type": "paligemma",
        "name": "PaliGemma",
        "url": "https://huggingface.co/google/paligemma-3b-pt-224",
    }
    GEMMA2 = {
        "type": "gemma2",
        "name": "Gemma2",
        "url": "https://huggingface.co/google/gemma2-9b",
    }
    COHERE = {
        "type": "cohere",
        "name": "Cohere",
        "url": "https://huggingface.co/CohereForAI/c4ai-command-r-plus",
    }
    DBRX = {
        "type": "dbrx",
        "name": "Dbrx",
        "url": "https://huggingface.co/databricks/dbrx-instruct",
    }
    MAMBA = {
        "type": "ssm",
        "name": "Mamba",
        "url": "https://huggingface.co/state-spaces/mamba-2.8b-slimpj",
    }
    MISTRAL = {
        "type": "mistral",
        "name": "Mistral",
        "url": "https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.2",
    }
    MIXTRAL = {
        "type": "mixtral",
        "name": "Mixtral",
        "url": "https://huggingface.co/mistralai/Mixtral-8x22B-Instruct-v0.1",
    }
    GPT_BIGCODE = {
        "type": "gpt_bigcode",
        "name": "Gpt Bigcode",
        "url": "https://huggingface.co/bigcode/gpt_bigcode-santacoder",
    }
    PHI = {
        "type": "phi",
        "name": "Phi",
        "url": "https://huggingface.co/microsoft/phi-1_5",
    }
    BAICHUAN = {
        "type": "baichuan",
        "name": "Baichuan",
        "url": "https://huggingface.co/baichuan-inc/Baichuan2-7B-Chat",
    }
    FALCON = {
        "type": "falcon",
        "name": "Falcon",
        "url": "https://huggingface.co/tiiuae/falcon-7b-instruct",
    }
    STARCODER2 = {
        "type": "starcoder2",
        "name": "StarCoder 2",
        "url": "https://huggingface.co/bigcode/starcoder2-15b-instruct-v0.1",
    }
    QWEN2 = {
        "type": "qwen2",
        "name": "Qwen 2",
        "url": "https://huggingface.co/collections/Qwen/qwen2-6659360b33528ced941e557f",
    }
    OPT = {
        "type": "opt",
        "name": "Opt",
        "url": "https://huggingface.co/facebook/opt-6.7b",
    }
    T5 = {
        "type": "t5",
        "name": "T5",
        "url": "https://huggingface.co/google/flan-t5-xxl",
    }
    GALACTICA = {
        "type": "galactica",
        "name": "Galactica",
        "url": "https://huggingface.co/facebook/galactica-120b",
    }
    SANTACODER = {
        "type": "santacoder",
        "name": "SantaCoder",
        "url": "https://huggingface.co/bigcode/santacoder",
    }
    BLOOM = {
        "type": "bloom",
        "name": "Bloom",
        "url": "https://huggingface.co/bigscience/bloom-560m",
    }
    MPT = {
        "type": "mpt",
        "name": "Mpt",
        "url": "https://huggingface.co/mosaicml/mpt-7b-instruct",
    }
    GPT2 = {
        "type": "gpt2",
        "name": "Gpt2",
        "url": "https://huggingface.co/openai-community/gpt2",
    }
    GPT_NEOX = {
        "type": "gpt_neox",
        "name": "Gpt Neox",
        "url": "https://huggingface.co/EleutherAI/gpt-neox-20b",
    }
    IDEFICS = {
        "type": "idefics",
        "name": "Idefics",
        "url": "https://huggingface.co/HuggingFaceM4/idefics-9b",
        "multimodal": True,
    }


__GLOBALS = locals()
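# Expose each ModelType's "type" string as a module-level constant
# (e.g. LLAMA == "llama") so get_model() can compare the config's
# model_type against the bare names used below.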
for data in ModelType:
    __GLOBALS[data.name] = data.value["type"]


def get_model(
    model_id: str,
    lora_adapter_ids: Optional[List[str]],
    revision: Optional[str],
    sharded: bool,
    quantize: Optional[str],
    speculate: Optional[int],
    dtype: Optional[str],
    trust_remote_code: bool,
    max_input_tokens: int,
) -> Model:
    global FLASH_ATTENTION
    if dtype is None:
        if quantize in ["awq", "exl2", "gptq", "marlin"]:
            # These quantizers only work with float16 params.
            dtype = torch.float16
        else:
            # Keep it as default for now and let
            # every model resolve their own default dtype.
            dtype = None
    elif dtype == "float16":
        dtype = torch.float16
    elif dtype == "bfloat16":
        dtype = torch.bfloat16
    else:
        raise RuntimeError(f"Unknown dtype {dtype}")

    if speculate is not None:
        set_speculate(speculate)
    else:
        set_speculate(0)

    config_dict, _ = PretrainedConfig.get_config_dict(
        model_id, revision=revision, trust_remote_code=trust_remote_code
    )
    model_type = config_dict.get("model_type", None)

    speculator = None
    if "medusa_num_heads" in config_dict:
        medusa_model_id = model_id
        medusa_revision = revision
        model_id = config_dict["base_model_name_or_path"]
        revision = "main"
        speculate_medusa = config_dict["medusa_num_heads"]
        if speculate is not None:
            if speculate > speculate_medusa:
                raise RuntimeError(
                    f"Speculate is set to `{speculate}` but this medusa models only has `{speculate_medusa}` heads, please make them match"
                )
            else:
                set_speculate(speculate)
        else:
            set_speculate(speculate_medusa)

        config_dict, _ = PretrainedConfig.get_config_dict(
            model_id, revision=revision, trust_remote_code=trust_remote_code
        )
        # Reload model type from parent.
        model_type = config_dict.get("model_type", None)
        is_local = Path(medusa_model_id).exists()
        if not is_local:
            medusa_config = hf_hub_download(
                medusa_model_id, revision=medusa_revision, filename="config.json"
            )
            hf_hub_download(
                medusa_model_id,
                revision=medusa_revision,
                filename="medusa_lm_head.safetensors",
            )
            speculator = {
                "path": Path(medusa_config).parent,
                "model_paths": ["medusa_lm_head.safetensors"],
            }
        else:
            speculator = {
                "path": Path(medusa_model_id),
                "model_paths": ["medusa_lm_head.safetensors"],
            }

        method = "medusa"
    elif model_type == "mlp_speculator":
        mlp_model_id = model_id
        mlp_revision = revision
        model_id = config_dict["base_model_name_or_path"]
        revision = "main"
        speculate_mlp = config_dict["n_predict"]
        if speculate is not None:
            if speculate > speculate_mlp:
                raise RuntimeError(
                    f"Speculate is set to `{speculate}` but this mlp_speculator models only has `{speculate_mlp}` heads, please make them match"
                )
            else:
                set_speculate(speculate)
        else:
            set_speculate(speculate_mlp)

        config_dict, _ = PretrainedConfig.get_config_dict(
            model_id, revision=revision, trust_remote_code=trust_remote_code
        )
        # Reload model type from parent.
        model_type = config_dict.get("model_type", None)
        is_local = Path(mlp_model_id).exists()
        extension = ".safetensors"
        if not is_local:
            mlp_speculator_config = hf_hub_download(
                mlp_model_id, revision=mlp_revision, filename="config.json"
            )
            api = HfApi()
            info = api.model_info(mlp_model_id, revision=mlp_revision)
            filenames = [
                s.rfilename
                for s in info.siblings
                if s.rfilename.endswith(extension)
                and len(s.rfilename.split("/")) == 1
                and "arguments" not in s.rfilename
                and "args" not in s.rfilename
                and "training" not in s.rfilename
            ]
            for filename in filenames:
                hf_hub_download(
                    mlp_model_id,
                    revision=mlp_revision,
                    filename=filename,
                )
            speculator = {
                "path": Path(mlp_speculator_config).parent,
                "model_paths": filenames,
            }
        else:
            speculator = Path(mlp_model_id)
            filenames = [p for p in os.listdir(speculator) if p.endswith(extension)]
            speculator = {"path": speculator, "model_paths": filenames}
        method = "mlp_speculator"
    else:
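        # No medusa or mlp_speculator config was found: fall back to n-gram
        # speculation (a no-op unless `speculate` ends up greater than 0).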
        method = "n-gram"

    speculate = get_speculate()
    if speculate > 0:
        logger.info(f"Using speculation {method} with {speculate} input ids.")

    if model_type is None:
        # TODO: fix how we determine model type for Mamba
        if "ssm_cfg" in config_dict:
            # *only happens in Mamba case
            model_type = "ssm"
        else:
            raise RuntimeError(
                f"Could not determine model type for {model_id} revision {revision}"
            )
    quantization_config = config_dict.get("quantization_config", None)
    if quantization_config is not None and quantize is None:
        method = quantization_config.get("quant_method", None)
        if method in {"gptq", "awq", "exl2"}:
            logger.info(f"Auto selecting quantization method {method}")
            quantize = method
        else:
            logger.info(f"Unknown quantization method {method}")

    if quantize == "exl2" and sharded:
        raise RuntimeError(
            "Sharding is currently not supported with `exl2` quantization"
        )
    sliding_window = config_dict.get("sliding_window", -1)
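    # Models relying on sliding-window attention need a backend with windowing
    # support; otherwise the input length must stay within the window.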

    if (
        (sliding_window is not None and sliding_window != -1)
        and not SUPPORTS_WINDOWING
        and max_input_tokens > sliding_window
    ):
        raise ValueError(
            f"The backend {SYSTEM} does not support sliding window attention that is used by the model type {model_type}. To use this model nonetheless with the {SYSTEM} backend, please launch TGI with the argument `--max-input-tokens` smaller than sliding_window={sliding_window} (got here max_input_tokens={max_input_tokens})."
        )

    if model_type == MAMBA:
        return Mamba(
            model_id,
            revision,
            quantize=quantize,
            speculator=speculator,
            dtype=dtype,
            trust_remote_code=trust_remote_code,
        )

    if model_id.startswith("facebook/galactica"):
        return CausalLM(
            model_id=model_id,
            # Yes galactica is just an OPT model.
            model_class=OPTForCausalLM,
            revision=revision,
            quantize=quantize,
            speculator=speculator,
            dtype=dtype,
            trust_remote_code=trust_remote_code,
            batch_class=GalacticaCausalLMBatch,
        )

    if (
        model_type == GPT_BIGCODE
        or model_type == GPT2
        and model_id.startswith("bigcode/")
    ):
        if FLASH_ATTENTION:
            return FlashCausalLM(
                model_id=model_id,
                model_class=FlashSantacoderForCausalLM,
                revision=revision,
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
                lora_adapter_ids=lora_adapter_ids,
                aliases={"transformer.wte.weight": ["lm_head.weight"]},
                num_kv_heads=1,
            )
        elif sharded:
            raise NotImplementedError(
                FLASH_ATT_ERROR_MESSAGE.format("Sharded Santacoder")
            )
        else:
            return CausalLM.fallback(
                model_id=model_id,
                revision=revision,
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )

    if model_type == BLOOM:
        return CausalLM(
            model_id=model_id,
            model_class=BloomForCausalLM,
            revision=revision,
            quantize=quantize,
            speculator=speculator,
            dtype=dtype,
            trust_remote_code=trust_remote_code,
            batch_class=BloomCausalLMBatch,
        )
    elif model_type == MPT:
        return CausalLM(
            model_id=model_id,
            model_class=MPTForCausalLM,
            revision=revision,
            quantize=quantize,
            speculator=speculator,
            dtype=dtype,
            trust_remote_code=trust_remote_code,
            batch_class=CausalLMBatchKeysLast,
        )
    elif model_type == GPT2:
        if FLASH_ATTENTION:
            try:
                return FlashCausalLM(
                    model_id=model_id,
                    model_class=FlashGPT2ForCausalLM,
                    revision=revision,
                    quantize=quantize,
                    speculator=speculator,
                    dtype=dtype,
                    trust_remote_code=trust_remote_code,
                    lora_adapter_ids=lora_adapter_ids,
                )
            except RuntimeError as e:
                # Lots of legacy models with various weight names.
                logger.warning(f"Couldn't load flash gpt2 variant: {e}")
                return CausalLM.fallback(
                    model_id,
                    revision,
                    quantize=quantize,
                    speculator=speculator,
                    dtype=dtype,
                    trust_remote_code=trust_remote_code,
                )
        elif sharded:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded GPT-2"))
        else:
            return CausalLM.fallback(
                model_id,
                revision,
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )
    elif model_type == GPT_NEOX:
        if FLASH_ATTENTION:
            return FlashCausalLM(
                model_id=model_id,
                model_class=FlashGPTNeoXForCausalLM,
                revision=revision,
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
                lora_adapter_ids=lora_adapter_ids,
            )
        elif sharded:
            return CausalLM(
                model_id=model_id,
                model_class=GPTNeoxForCausalLM,
                revision=revision,
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )
        else:
            return CausalLM.fallback(
                model_id,
                revision,
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )

    elif model_type == PHI:
        if FLASH_ATTENTION:
            return FlashCausalLM(
                model_id=model_id,
                model_class=FlashPhiForCausalLM,
                revision=revision,
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
                lora_adapter_ids=lora_adapter_ids,
            )
        else:
            return CausalLM.fallback(
                model_id,
                revision,
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )

    elif model_type == "phi-msft":
        if FLASH_ATTENTION:
            raise NotImplementedError(
                "Legacy phi-msft is not supported with Flash Attention"
            )
        else:
            return CausalLM(
                model_id=model_id,
                model_class=PhiForCausalLM,
                config_class=PhiConfig,
                revision=revision,
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )

    elif model_type == LLAMA or model_type == BAICHUAN or model_type == PHI3:
        if FLASH_ATTENTION:
            return FlashCausalLM(
                model_id=model_id,
                model_class=FlashLlamaForCausalLM,
                revision=revision,
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
                lora_adapter_ids=lora_adapter_ids,
            )
        elif sharded:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Llama"))
        else:
            return CausalLM.fallback(
                model_id,
                revision,
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )
    if model_type == GEMMA:
        if FLASH_ATTENTION:
            return FlashCausalLM(
                model_id=model_id,
                model_class=FlashGemmaForCausalLM,
                revision=revision,
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
                # Works better for these models
                default_dtype=torch.bfloat16,
                trust_remote_code=trust_remote_code,
                lora_adapter_ids=lora_adapter_ids,
            )
        elif sharded:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Gemma"))
        else:
            return CausalLM.fallback(
                model_id,
                revision,
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )
    elif model_type == GEMMA2:
        if FLASH_ATTENTION:
            return FlashCausalLM(
                model_id=model_id,
                model_class=FlashGemma2ForCausalLM,
                revision=revision,
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
                # Works better for these models
                default_dtype=torch.bfloat16,
                trust_remote_code=trust_remote_code,
                lora_adapter_ids=lora_adapter_ids,
            )
        elif sharded:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Gemma2"))
        else:
            return CausalLM.fallback(
                model_id,
                revision,
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )

    if model_type == COHERE:
        if FLASH_ATTENTION:
            return FlashCausalLM(
                model_id=model_id,
                model_class=FlashCohereForCausalLM,
                revision=revision,
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
                lora_adapter_ids=lora_adapter_ids,
            )
        elif sharded:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Cohere"))
        else:
            return CausalLM.fallback(
                model_id,
                revision,
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )

    if model_type == DBRX:
        if FLASH_ATTENTION:
            return FlashCausalLM(
                model_id=model_id,
                model_class=FlashDbrxForCausalLM,
                revision=revision,
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
                # Dbrx works better in bfloat16.
                default_dtype=torch.bfloat16,
                trust_remote_code=trust_remote_code,
                lora_adapter_ids=lora_adapter_ids,
                config_class=DbrxConfig,
            )
        elif sharded:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded DBRX"))
        else:
            return CausalLM.fallback(
                model_id,
                revision,
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )

    if model_type in ["RefinedWeb", "RefinedWebModel", FALCON]:
        if sharded:
            if FLASH_ATTENTION:
                if config_dict.get("alibi", False):
                    raise NotImplementedError("sharded is not supported for this model")
                return FlashCausalLM(
                    model_id=model_id,
                    model_class=FlashRWForCausalLM,
                    revision=revision,
                    quantize=quantize,
                    speculator=speculator,
                    dtype=dtype,
                    aliases={
                        "lm_head.weight": ["transformer.word_embeddings.weight"],
                        "transformer.word_embeddings.weight": ["lm_head.weight"],
                    },
                    trust_remote_code=trust_remote_code,
                    lora_adapter_ids=lora_adapter_ids,
                    config_class=RWConfig,
                )
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format(f"Sharded Falcon"))
        else:
            if FLASH_ATTENTION and not config_dict.get("alibi", False):
                return FlashCausalLM(
                    model_id=model_id,
                    model_class=FlashRWForCausalLM,
                    revision=revision,
                    quantize=quantize,
                    speculator=speculator,
                    dtype=dtype,
                    trust_remote_code=trust_remote_code,
                    lora_adapter_ids=lora_adapter_ids,
                    config_class=RWConfig,
                )
            else:
                return CausalLM.fallback(
                    model_id,
                    revision,
                    quantize=quantize,
                    speculator=speculator,
                    dtype=dtype,
                    trust_remote_code=trust_remote_code,
                )

    if model_type == MISTRAL:
        if FLASH_ATTENTION:
            return FlashCausalLM(
                model_id=model_id,
                model_class=FlashMistralForCausalLM,
                revision=revision,
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
                lora_adapter_ids=lora_adapter_ids,
            )
        elif sharded:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Mistral"))
        else:
            return CausalLM.fallback(
                model_id,
                revision,
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )

    if model_type == MIXTRAL:
        if FLASH_ATTENTION:
            return FlashCausalLM(
                model_id=model_id,
                model_class=FlashMixtralForCausalLM,
                revision=revision,
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
                lora_adapter_ids=lora_adapter_ids,
            )
        elif sharded:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Mixtral"))
        else:
            return CausalLM.fallback(
                model_id,
                revision,
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )

    if model_type == STARCODER2:
        if FLASH_ATTENTION:
            return FlashCausalLM(
                model_id=model_id,
                model_class=FlashStarcoder2ForCausalLM,
                revision=revision,
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
                lora_adapter_ids=lora_adapter_ids,
            )
        elif sharded:
            raise NotImplementedError(
                FLASH_ATT_ERROR_MESSAGE.format("Sharded Starcoder2")
            )
        else:
            return CausalLM.fallback(
                model_id,
                revision,
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )

    if model_type == QWEN2:
        if FLASH_ATTENTION:
            return FlashCausalLM(
                model_id=model_id,
                model_class=Qwen2ForCausalLM,
                revision=revision,
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
                lora_adapter_ids=lora_adapter_ids,
            )
        elif sharded:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Qwen2"))
        else:
            return CausalLM.fallback(
                model_id,
                revision,
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )

    if model_type == OPT:
        return CausalLM(
            model_id=model_id,
            model_class=OPTForCausalLM,
            revision=revision,
            quantize=quantize,
            speculator=speculator,
            dtype=dtype,
            trust_remote_code=trust_remote_code,
        )

    if model_type == T5:
        return Seq2SeqLM(
            model_id=model_id,
            model_class=T5ForConditionalGeneration,
            revision=revision,
            quantize=quantize,
            speculator=speculator,
            dtype=dtype,
            trust_remote_code=trust_remote_code,
            aliases={
                "shared.weight": [
                    "encoder.embed_tokens.weight",
                    "decoder.embed_tokens.weight",
                ]
            },
        )
    if model_type == IDEFICS:
        if FLASH_ATTENTION:
            return IDEFICSSharded(
                model_id,
                revision,
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )
        else:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Idefics"))
    if model_type == IDEFICS2:
        if FLASH_ATTENTION:
            return VlmCausalLM(
                model_id=model_id,
                model_class=Idefics2ForConditionalGeneration,
                revision=revision,
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
                lora_adapter_ids=lora_adapter_ids,
                # XXX: Extremely important to cap resolution in order to limit
                # VRAM usage.
                processor_kwargs={"size": {"longest_edge": 448, "shortest_edge": 378}},
            )
        else:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Idefics"))
    if model_type == PALIGEMMA:
        if FLASH_ATTENTION:
            return VlmCausalLM(
                model_id=model_id,
                model_class=PaliGemmaForConditionalGeneration,
                revision=revision,
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
                # Works better for these models
                default_dtype=torch.bfloat16,
                trust_remote_code=trust_remote_code,
                lora_adapter_ids=lora_adapter_ids,
                batch_class=PaliGemmaBatch,
            )
        else:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Idefics"))

    if model_type == LLAVA_NEXT:
        if FLASH_ATTENTION:
            return VlmCausalLM(
                model_class=LlavaNextForConditionalGeneration,
                model_id=model_id,
                revision=revision,
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )
        else:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("LlavaNext"))

    if sharded:
        raise NotImplementedError("sharded is not supported for AutoModel")
    if quantize == "gptq":
        raise NotImplementedError(
            "gptq quantization is not supported for AutoModel, you can try to quantize it with `text-generation-server quantize ORIGINAL_MODEL_ID NEW_MODEL_ID`"
        )
    if quantize == "awq":
        raise NotImplementedError("awq quantization is not supported for AutoModel")
    elif (quantize == "bitsandbytes-fp4") or (quantize == "bitsandbytes-nf4"):
        raise NotImplementedError("4bit quantization is not supported for AutoModel")
    elif quantize == "eetq":
        raise NotImplementedError("Eetq quantization is not supported for AutoModel")
    elif quantize == "exl2":
        raise NotImplementedError("exl2 quantization is not supported for AutoModel")
    if model_type in modeling_auto.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES:
        return CausalLM.fallback(
            model_id,
            revision,
            quantize=quantize,
            speculator=speculator,
            dtype=dtype,
            trust_remote_code=trust_remote_code,
        )
    if model_type in modeling_auto.MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES:
        return Seq2SeqLM.fallback(
            model_id,
            revision,
            quantize=quantize,
            speculator=speculator,
            dtype=dtype,
            trust_remote_code=trust_remote_code,
        )

    auto_map = config_dict.get("auto_map", None)
    if trust_remote_code and auto_map is not None:
        if "AutoModelForCausalLM" in auto_map.keys():
            return CausalLM.fallback(
                model_id,
                revision,
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )
        if "AutoModelForSeq2SeqLM" in auto_map.keys():
            return Seq2SeqLM.fallback(
                model_id,
                revision,
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )

    raise ValueError(f"Unsupported model type {model_type}")