__init__.py 35.8 KB
Newer Older
1
import torch
2
import enum
Nicolas Patry's avatar
Nicolas Patry committed
3
import os
4

5
from loguru import logger
6
from transformers.configuration_utils import PretrainedConfig
7
from transformers.models.auto import modeling_auto
Nicolas Patry's avatar
Nicolas Patry committed
8
from huggingface_hub import hf_hub_download, HfApi
drbh's avatar
drbh committed
9
from typing import Optional, List
10
from pathlib import Path
11

Nicolas Patry's avatar
Nicolas Patry committed
12
from text_generation_server.utils.speculate import get_speculate, set_speculate
13
from text_generation_server.models.model import Model
14
15
16
17
18
19
20
21
from text_generation_server.models.causal_lm import CausalLM, CausalLMBatchKeysLast
from text_generation_server.models.custom_modeling.opt_modeling import OPTForCausalLM
from text_generation_server.models.custom_modeling.mpt_modeling import (
    MPTForCausalLM,
)
from text_generation_server.models.custom_modeling.bloom_modeling import (
    BloomForCausalLM,
)
22
from text_generation_server.models.seq2seq_lm import Seq2SeqLM
23
24
25
26
27
28
29
30
31
32
33
from text_generation_server.models.galactica import GalacticaCausalLMBatch
from text_generation_server.models.custom_modeling.neox_modeling import (
    GPTNeoxForCausalLM,
)
from text_generation_server.models.custom_modeling.phi_modeling import (
    PhiConfig,
    PhiForCausalLM,
)
from text_generation_server.models.custom_modeling.t5_modeling import (
    T5ForConditionalGeneration,
)
34

35
36
from text_generation_server.utils.import_utils import SYSTEM

37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
# Module-wide torch settings, applied as a side effect of importing this package.

# The flag below controls whether to allow TF32 on matmul. This flag defaults to False
# in PyTorch 1.12 and later.
torch.backends.cuda.matmul.allow_tf32 = True

# The flag below controls whether to allow TF32 on cuDNN. This flag defaults to True.
torch.backends.cudnn.allow_tf32 = True

# Disable gradients
torch.set_grad_enabled(False)

# Names exported by `from <this package> import *`. Every entry must be a
# string naming an attribute of this module.
# NOTE(review): "BLOOMSharded" and "GalacticaSharded" are not imported in the
# visible part of this file - confirm they are still defined somewhere,
# otherwise a star-import fails with AttributeError.
__all__ = [
    "Model",
    "BLOOMSharded",
    "CausalLM",
    "GalacticaSharded",
    "Seq2SeqLM",
    "get_model",
]

56
FLASH_ATT_ERROR_MESSAGE = "{} requires Flash Attention enabled models."
57

58
FLASH_ATTENTION = True
59

60
try:
61
    from text_generation_server.models.flash_causal_lm import FlashCausalLM
62
63
64
    from text_generation_server.models.vlm_causal_lm import VlmCausalLM
    from text_generation_server.models.custom_modeling.flash_llama_modeling import (
        FlashLlamaForCausalLM,
65
    )
66
67
    from text_generation_server.models.custom_modeling.flash_cohere_modeling import (
        FlashCohereForCausalLM,
OlivierDehaene's avatar
OlivierDehaene committed
68
    )
69
70
    from text_generation_server.models.custom_modeling.flash_gemma_modeling import (
        FlashGemmaForCausalLM,
OlivierDehaene's avatar
OlivierDehaene committed
71
    )
72
73
    from text_generation_server.models.custom_modeling.flash_gemma2_modeling import (
        FlashGemma2ForCausalLM,
74
    )
75
76
77
78
79
80
81
82
83
84
    from text_generation_server.models.custom_modeling.flash_dbrx_modeling import (
        FlashDbrxForCausalLM,
        DbrxConfig,
    )
    from text_generation_server.models.custom_modeling.flash_rw_modeling import (
        RWConfig,
        FlashRWForCausalLM,
    )
    from text_generation_server.models.custom_modeling.flash_neox_modeling import (
        FlashGPTNeoXForCausalLM,
Nicolas Patry's avatar
Nicolas Patry committed
85
    )
drbh's avatar
drbh committed
86
    from text_generation_server.models.pali_gemma import (
87
        PaliGemmaBatch,
drbh's avatar
drbh committed
88
    )
89
90
91
92
93
    from text_generation_server.models.custom_modeling.flash_pali_gemma_modeling import (
        PaliGemmaForConditionalGeneration,
    )
    from text_generation_server.models.custom_modeling.flash_phi_modeling import (
        FlashPhiForCausalLM,
94
    )
95
    from text_generation_server.models.idefics import IDEFICSSharded
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
    from text_generation_server.models.custom_modeling.llava_next import (
        LlavaNextForConditionalGeneration,
    )

    from text_generation_server.models.custom_modeling.flash_santacoder_modeling import (
        FlashSantacoderForCausalLM,
    )
    from text_generation_server.models.custom_modeling.flash_starcoder2_modeling import (
        FlashStarcoder2ForCausalLM,
    )
    from text_generation_server.models.custom_modeling.flash_qwen2_modeling import (
        Qwen2ForCausalLM,
    )
    from text_generation_server.models.custom_modeling.flash_mistral_modeling import (
        FlashMistralForCausalLM,
    )
    from text_generation_server.models.custom_modeling.flash_mixtral_modeling import (
        FlashMixtralForCausalLM,
    )
    from text_generation_server.models.custom_modeling.flash_gpt2_modeling import (
        FlashGPT2ForCausalLM,
    )
    from text_generation_server.models.custom_modeling.idefics2 import (
        Idefics2ForConditionalGeneration,
    )
121
    from text_generation_server.layers.attention import SUPPORTS_WINDOWING
122
123
except ImportError as e:
    logger.warning(f"Could not import Flash Attention enabled models: {e}")
124
    SUPPORTS_WINDOWING = False
125
    FLASH_ATTENTION = False
126

127
# Export the flash-attention model classes only when their imports succeeded.
# `__all__` must contain attribute *names* (strings); appending the class
# objects themselves makes `from ... import *` raise TypeError.
if FLASH_ATTENTION:
    __all__.append("FlashCausalLM")
    __all__.append("IDEFICSSharded")
OlivierDehaene's avatar
OlivierDehaene committed
130

drbh's avatar
drbh committed
131
132
133
134
135
136
137
138
139
# Mamba is optional: when its import fails, log a warning and record that it
# is unavailable instead of failing the import of this package.
MAMBA_AVAILABLE = True
try:
    from text_generation_server.models.mamba import Mamba
except ImportError as e:
    logger.warning(f"Could not import Mamba: {e}")
    MAMBA_AVAILABLE = False

# `__all__` must contain attribute *names* (strings), not the class object;
# a non-string entry makes `from ... import *` raise TypeError.
if MAMBA_AVAILABLE:
    __all__.append("Mamba")
OlivierDehaene's avatar
OlivierDehaene committed
140

141

142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
@enum.unique
class ModelType(enum.Enum):
    """Catalogue of the model architectures this module knows about.

    Each member's value is a dict with the keys:

    * ``"type"``: the identifier compared against the ``model_type`` read
      from a model's configuration.
    * ``"name"``: a human-readable display name.
    * ``"url"``: a Hugging Face URL for an example model or collection.
    * ``"multimodal"``: present (and ``True``) only on some entries.

    ``enum.unique`` makes an accidentally duplicated value fail at import
    time instead of silently turning the second member into an alias that is
    skipped when iterating over the enum.
    """

    IDEFICS2 = {
        "type": "idefics2",
        "name": "Idefics 2",
        "url": "https://huggingface.co/HuggingFaceM4/idefics2-8b",
        "multimodal": True,
    }
    LLAVA_NEXT = {
        "type": "llava_next",
        "name": "Llava Next (1.6)",
        "url": "https://huggingface.co/llava-hf/llava-v1.6-vicuna-13b-hf",
        "multimodal": True,
    }
    LLAMA = {
        "type": "llama",
        "name": "Llama",
        "url": "https://huggingface.co/meta-llama/Meta-Llama-3-8B-Instruct",
    }
    PHI3 = {
        "type": "phi3",
        "name": "Phi 3",
        "url": "https://huggingface.co/microsoft/Phi-3-mini-4k-instruct",
    }
    GEMMA = {
        "type": "gemma",
        "name": "Gemma",
        "url": "https://huggingface.co/google/gemma-7b",
    }
    PALIGEMMA = {
        "type": "paligemma",
        "name": "PaliGemma",
        "url": "https://huggingface.co/google/paligemma-3b-pt-224",
    }
    GEMMA2 = {
        "type": "gemma2",
        "name": "Gemma2",
        "url": "https://huggingface.co/google/gemma2-9b",
    }
    COHERE = {
        "type": "cohere",
        "name": "Cohere",
        "url": "https://huggingface.co/CohereForAI/c4ai-command-r-plus",
    }
    DBRX = {
        "type": "dbrx",
        "name": "Dbrx",
        "url": "https://huggingface.co/databricks/dbrx-instruct",
    }
    MAMBA = {
        "type": "ssm",
        "name": "Mamba",
        "url": "https://huggingface.co/state-spaces/mamba-2.8b-slimpj",
    }
    MISTRAL = {
        "type": "mistral",
        "name": "Mistral",
        "url": "https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.2",
    }
    MIXTRAL = {
        "type": "mixtral",
        "name": "Mixtral",
        "url": "https://huggingface.co/mistralai/Mixtral-8x22B-Instruct-v0.1",
    }
    GPT_BIGCODE = {
        "type": "gpt_bigcode",
        "name": "Gpt Bigcode",
        "url": "https://huggingface.co/bigcode/gpt_bigcode-santacoder",
    }
    PHI = {
        "type": "phi",
        "name": "Phi",
        "url": "https://huggingface.co/microsoft/phi-1_5",
    }
    BAICHUAN = {
        "type": "baichuan",
        "name": "Baichuan",
        "url": "https://huggingface.co/baichuan-inc/Baichuan2-7B-Chat",
    }
    FALCON = {
        "type": "falcon",
        "name": "Falcon",
        "url": "https://huggingface.co/tiiuae/falcon-7b-instruct",
    }
    STARCODER2 = {
        "type": "starcoder2",
        "name": "StarCoder 2",
        "url": "https://huggingface.co/bigcode/starcoder2-15b-instruct-v0.1",
    }
    QWEN2 = {
        "type": "qwen2",
        "name": "Qwen 2",
        "url": "https://huggingface.co/collections/Qwen/qwen2-6659360b33528ced941e557f",
    }
    OPT = {
        "type": "opt",
        "name": "Opt",
        "url": "https://huggingface.co/facebook/opt-6.7b",
    }
    T5 = {
        "type": "t5",
        "name": "T5",
        "url": "https://huggingface.co/google/flan-t5-xxl",
    }
    GALACTICA = {
        "type": "galactica",
        "name": "Galactica",
        "url": "https://huggingface.co/facebook/galactica-120b",
    }
    SANTACODER = {
        "type": "santacoder",
        "name": "SantaCoder",
        "url": "https://huggingface.co/bigcode/santacoder",
    }
    BLOOM = {
        "type": "bloom",
        "name": "Bloom",
        "url": "https://huggingface.co/bigscience/bloom-560m",
    }
    MPT = {
        "type": "mpt",
        "name": "Mpt",
        "url": "https://huggingface.co/mosaicml/mpt-7b-instruct",
    }
    GPT2 = {
        "type": "gpt2",
        "name": "Gpt2",
        "url": "https://huggingface.co/openai-community/gpt2",
    }
    GPT_NEOX = {
        "type": "gpt_neox",
        "name": "Gpt Neox",
        "url": "https://huggingface.co/EleutherAI/gpt-neox-20b",
    }
    IDEFICS = {
        "type": "idefics",
        "name": "Idefics",
        "url": "https://huggingface.co/HuggingFaceM4/idefics-9b",
        "multimodal": True,
    }


# Publish every ModelType member as a module-level string constant whose name
# is the member name and whose value is the member's "type" string (for
# example LLAMA == "llama", MAMBA == "ssm"), so later code can write
# `model_type == LLAMA`.
# At module scope `locals()` returns the module's own namespace dict, so
# writing into it defines real module globals.
__GLOBALS = locals()
for data in ModelType:
    __GLOBALS[data.name] = data.value["type"]


288
def get_model(
289
    model_id: str,
drbh's avatar
drbh committed
290
    lora_adapter_ids: Optional[List[str]],
291
292
293
    revision: Optional[str],
    sharded: bool,
    quantize: Optional[str],
Nicolas Patry's avatar
Nicolas Patry committed
294
    speculate: Optional[int],
295
    dtype: Optional[str],
296
    trust_remote_code: bool,
297
    max_input_tokens: int,
298
) -> Model:
299
    global FLASH_ATTENTION
300
    if dtype is None:
301
        if quantize in ["awq", "exl2", "gptq", "marlin"]:
302
303
304
305
306
307
            # These quantizers only work with float16 params.
            dtype = torch.float16
        else:
            # Keep it as default for now and let
            # every model resolve their own default dtype.
            dtype = None
308
309
310
311
312
313
314
    elif dtype == "float16":
        dtype = torch.float16
    elif dtype == "bfloat16":
        dtype = torch.bfloat16
    else:
        raise RuntimeError(f"Unknown dtype {dtype}")

Nicolas Patry's avatar
Nicolas Patry committed
315
316
317
318
319
    if speculate is not None:
        set_speculate(speculate)
    else:
        set_speculate(0)

OlivierDehaene's avatar
v0.8.2  
OlivierDehaene committed
320
321
322
    config_dict, _ = PretrainedConfig.get_config_dict(
        model_id, revision=revision, trust_remote_code=trust_remote_code
    )
Nicolas Patry's avatar
Nicolas Patry committed
323
    model_type = config_dict.get("model_type", None)
Nicolas Patry's avatar
Nicolas Patry committed
324

Nicolas Patry's avatar
Nicolas Patry committed
325
    speculator = None
Nicolas Patry's avatar
Nicolas Patry committed
326
    if "medusa_num_heads" in config_dict:
327
328
        medusa_model_id = model_id
        medusa_revision = revision
Nicolas Patry's avatar
Nicolas Patry committed
329
330
331
332
333
        model_id = config_dict["base_model_name_or_path"]
        revision = "main"
        speculate_medusa = config_dict["medusa_num_heads"]
        if speculate is not None:
            if speculate > speculate_medusa:
OlivierDehaene's avatar
OlivierDehaene committed
334
                raise RuntimeError(
OlivierDehaene's avatar
OlivierDehaene committed
335
                    f"Speculate is set to `{speculate}` but this medusa models only has `{speculate_medusa}` heads, please make them match"
OlivierDehaene's avatar
OlivierDehaene committed
336
                )
Nicolas Patry's avatar
Nicolas Patry committed
337
338
339
340
341
342
343
344
            else:
                set_speculate(speculate)
        else:
            set_speculate(speculate_medusa)

        config_dict, _ = PretrainedConfig.get_config_dict(
            model_id, revision=revision, trust_remote_code=trust_remote_code
        )
Nicolas Patry's avatar
Nicolas Patry committed
345
346
        # Reload model type from parent.
        model_type = config_dict.get("model_type", None)
347
348
349
350
351
352
353
354
355
356
        is_local = Path(medusa_model_id).exists()
        if not is_local:
            medusa_config = hf_hub_download(
                medusa_model_id, revision=medusa_revision, filename="config.json"
            )
            hf_hub_download(
                medusa_model_id,
                revision=medusa_revision,
                filename="medusa_lm_head.safetensors",
            )
Nicolas Patry's avatar
Nicolas Patry committed
357
358
359
360
            speculator = {
                "path": Path(medusa_config).parent,
                "model_paths": ["medusa_lm_head.safetensors"],
            }
361
        else:
Nicolas Patry's avatar
Nicolas Patry committed
362
363
364
365
            speculator = {
                "path": Path(medusa_model_id),
                "model_paths": ["medusa_lm_head.safetensors"],
            }
366

Nicolas Patry's avatar
Nicolas Patry committed
367
        method = "medusa"
Nicolas Patry's avatar
Nicolas Patry committed
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
    elif model_type == "mlp_speculator":
        mlp_model_id = model_id
        mlp_revision = revision
        model_id = config_dict["base_model_name_or_path"]
        revision = "main"
        speculate_mlp = config_dict["n_predict"]
        if speculate is not None:
            if speculate > speculate_mlp:
                raise RuntimeError(
                    f"Speculate is set to `{speculate}` but this mlp_speculator models only has `{speculate_mlp}` heads, please make them match"
                )
            else:
                set_speculate(speculate)
        else:
            set_speculate(speculate_mlp)

        config_dict, _ = PretrainedConfig.get_config_dict(
            model_id, revision=revision, trust_remote_code=trust_remote_code
        )
        # Reload model type from parent.
        model_type = config_dict.get("model_type", None)
        is_local = Path(mlp_model_id).exists()
        extension = ".safetensors"
        if not is_local:
            mlp_speculator_config = hf_hub_download(
                mlp_model_id, revision=mlp_revision, filename="config.json"
            )
            api = HfApi()
            info = api.model_info(mlp_model_id, revision=mlp_revision)
            filenames = [
                s.rfilename
                for s in info.siblings
                if s.rfilename.endswith(extension)
                and len(s.rfilename.split("/")) == 1
                and "arguments" not in s.rfilename
                and "args" not in s.rfilename
                and "training" not in s.rfilename
            ]
            for filename in filenames:
                hf_hub_download(
                    mlp_model_id,
                    revision=mlp_revision,
                    filename=filename,
                )
            speculator = {
                "path": Path(mlp_speculator_config).parent,
                "model_paths": filenames,
            }
        else:
            speculator = Path(mlp_model_id)
            filenames = [p for p in os.listdir(speculator) if p.endswith(extension)]
            speculator = {"path": speculator, "model_paths": filenames}
        method = "mlp_speculator"
Nicolas Patry's avatar
Nicolas Patry committed
421
422
423
424
425
426
427
    else:
        method = "n-gram"

    speculate = get_speculate()
    if speculate > 0:
        logger.info(f"Using speculation {method} with {speculate} input ids.")

drbh's avatar
drbh committed
428
429
430
431
432
433
434
435
436
    if model_type is None:
        # TODO: fix how we determine model type for Mamba
        if "ssm_cfg" in config_dict:
            # *only happens in Mamba case
            model_type = "ssm"
        else:
            raise RuntimeError(
                f"Could not determine model type for {model_id} revision {revision}"
            )
437
438
439
    quantization_config = config_dict.get("quantization_config", None)
    if quantization_config is not None and quantize is None:
        method = quantization_config.get("quant_method", None)
440
        if method in {"gptq", "awq", "exl2"}:
441
442
443
444
            logger.info(f"Auto selecting quantization method {method}")
            quantize = method
        else:
            logger.info(f"Unknown quantization method {method}")
drbh's avatar
drbh committed
445

446
447
448
449
    if quantize == "exl2" and sharded:
        raise RuntimeError(
            "Sharding is currently not supported with `exl2` quantization"
        )
450
    sliding_window = config_dict.get("sliding_window", -1)
451
452
453
454
455
456
457
458

    if (
        (sliding_window is not None and sliding_window != -1)
        and not SUPPORTS_WINDOWING
        and max_input_tokens > sliding_window
    ):
        raise ValueError(
            f"The backend {SYSTEM} does not support sliding window attention that is used by the model type {model_type}. To use this model nonetheless with the {SYSTEM} backend, please launch TGI with the argument `--max-input-tokens` smaller than sliding_window={sliding_window} (got here max_input_tokens={max_input_tokens})."
459
        )
460

461
    if model_type == MAMBA:
drbh's avatar
drbh committed
462
463
464
465
        return Mamba(
            model_id,
            revision,
            quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
466
            speculator=speculator,
drbh's avatar
drbh committed
467
468
469
            dtype=dtype,
            trust_remote_code=trust_remote_code,
        )
470

OlivierDehaene's avatar
OlivierDehaene committed
471
    if model_id.startswith("facebook/galactica"):
472
473
474
475
476
        return CausalLM(
            model_id=model_id,
            # Yes galactica is just an OPT model.
            model_class=OPTForCausalLM,
            revision=revision,
OlivierDehaene's avatar
OlivierDehaene committed
477
            quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
478
            speculator=speculator,
OlivierDehaene's avatar
OlivierDehaene committed
479
480
            dtype=dtype,
            trust_remote_code=trust_remote_code,
481
            batch_class=GalacticaCausalLMBatch,
OlivierDehaene's avatar
OlivierDehaene committed
482
483
        )

484
    if (
485
486
        model_type == GPT_BIGCODE
        or model_type == GPT2
487
488
        and model_id.startswith("bigcode/")
    ):
489
        if FLASH_ATTENTION:
490
491
492
493
            return FlashCausalLM(
                model_id=model_id,
                model_class=FlashSantacoderForCausalLM,
                revision=revision,
494
                quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
495
                speculator=speculator,
496
                dtype=dtype,
497
                trust_remote_code=trust_remote_code,
498
499
500
                lora_adapter_ids=lora_adapter_ids,
                aliases={"transformer.wte.weight": ["lm_head.weight"]},
                num_kv_heads=1,
501
            )
502
503
504
505
        elif sharded:
            raise NotImplementedError(
                FLASH_ATT_ERROR_MESSAGE.format("Sharded Santacoder")
            )
506
        else:
507
508
509
            return CausalLM.fallback(
                model_id=model_id,
                revision=revision,
510
                quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
511
                speculator=speculator,
512
                dtype=dtype,
513
514
                trust_remote_code=trust_remote_code,
            )
515

516
    if model_type == BLOOM:
517
518
519
520
        return CausalLM(
            model_id=model_id,
            model_class=BloomForCausalLM,
            revision=revision,
521
            quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
522
            speculator=speculator,
523
524
            dtype=dtype,
            trust_remote_code=trust_remote_code,
525
            batch_class=CausalLMBatchKeysLast,
526
        )
527
    elif model_type == MPT:
528
529
530
531
        return CausalLM(
            model_id=model_id,
            model_class=MPTForCausalLM,
            revision=revision,
OlivierDehaene's avatar
OlivierDehaene committed
532
            quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
533
            speculator=speculator,
OlivierDehaene's avatar
OlivierDehaene committed
534
535
            dtype=dtype,
            trust_remote_code=trust_remote_code,
536
            batch_class=CausalLMBatchKeysLast,
537
        )
538
    elif model_type == GPT2:
539
        if FLASH_ATTENTION:
540
            try:
541
542
543
544
                return FlashCausalLM(
                    model_id=model_id,
                    model_class=FlashGPT2ForCausalLM,
                    revision=revision,
545
546
547
548
                    quantize=quantize,
                    speculator=speculator,
                    dtype=dtype,
                    trust_remote_code=trust_remote_code,
549
                    lora_adapter_ids=lora_adapter_ids,
550
551
552
553
                )
            except RuntimeError as e:
                # Lots of legacy models with various weight names.
                logger.warning(f"Couldn't load flash gpt2 variant: {e}")
554
                return CausalLM.fallback(
555
556
557
558
559
560
561
                    model_id,
                    revision,
                    quantize=quantize,
                    speculator=speculator,
                    dtype=dtype,
                    trust_remote_code=trust_remote_code,
                )
562
563
564
        elif sharded:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded GPT-2"))
        else:
565
            return CausalLM.fallback(
566
567
568
569
570
571
572
                model_id,
                revision,
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )
573
    elif model_type == GPT_NEOX:
574
        if FLASH_ATTENTION:
575
576
577
578
            return FlashCausalLM(
                model_id=model_id,
                model_class=FlashGPTNeoXForCausalLM,
                revision=revision,
579
                quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
580
                speculator=speculator,
581
                dtype=dtype,
582
                trust_remote_code=trust_remote_code,
583
                lora_adapter_ids=lora_adapter_ids,
584
585
            )
        elif sharded:
586
587
588
589
            return CausalLM(
                model_id=model_id,
                model_class=GPTNeoxForCausalLM,
                revision=revision,
590
                quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
591
                speculator=speculator,
592
                dtype=dtype,
593
594
                trust_remote_code=trust_remote_code,
            )
595
        else:
596
            return CausalLM.fallback(
597
598
599
                model_id,
                revision,
                quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
600
                speculator=speculator,
601
                dtype=dtype,
602
603
                trust_remote_code=trust_remote_code,
            )
OlivierDehaene's avatar
OlivierDehaene committed
604

605
    elif model_type == PHI:
drbh's avatar
drbh committed
606
        if FLASH_ATTENTION:
607
608
609
610
            return FlashCausalLM(
                model_id=model_id,
                model_class=FlashPhiForCausalLM,
                revision=revision,
drbh's avatar
drbh committed
611
                quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
612
                speculator=speculator,
drbh's avatar
drbh committed
613
614
                dtype=dtype,
                trust_remote_code=trust_remote_code,
615
                lora_adapter_ids=lora_adapter_ids,
drbh's avatar
drbh committed
616
617
            )
        else:
618
            return CausalLM.fallback(
drbh's avatar
drbh committed
619
620
621
                model_id,
                revision,
                quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
622
                speculator=speculator,
drbh's avatar
drbh committed
623
624
625
626
627
628
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )

    elif model_type == "phi-msft":
        if FLASH_ATTENTION:
OlivierDehaene's avatar
OlivierDehaene committed
629
630
631
            raise NotImplementedError(
                "Legacy phi-msft is not supported with Flash Attention"
            )
drbh's avatar
drbh committed
632
        else:
633
634
635
636
637
            return CausalLM(
                model_id=model_id,
                model_class=PhiForCausalLM,
                config_class=PhiConfig,
                revision=revision,
drbh's avatar
drbh committed
638
                quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
639
                speculator=speculator,
drbh's avatar
drbh committed
640
641
642
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )
643

644
    elif model_type == LLAMA or model_type == BAICHUAN or model_type == PHI3:
645
        if FLASH_ATTENTION:
646
647
648
649
            return FlashCausalLM(
                model_id=model_id,
                model_class=FlashLlamaForCausalLM,
                revision=revision,
650
                quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
651
                speculator=speculator,
652
                dtype=dtype,
653
                trust_remote_code=trust_remote_code,
drbh's avatar
drbh committed
654
                lora_adapter_ids=lora_adapter_ids,
655
            )
656
657
        elif sharded:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Llama"))
658
        else:
659
            return CausalLM.fallback(
660
661
662
                model_id,
                revision,
                quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
663
                speculator=speculator,
664
                dtype=dtype,
665
666
                trust_remote_code=trust_remote_code,
            )
667
    if model_type == GEMMA:
668
        if FLASH_ATTENTION:
669
670
671
672
            return FlashCausalLM(
                model_id=model_id,
                model_class=FlashGemmaForCausalLM,
                revision=revision,
673
                quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
674
                speculator=speculator,
675
                dtype=dtype,
676
677
                # Works better for these models
                default_dtype=torch.bfloat16,
678
                trust_remote_code=trust_remote_code,
679
                lora_adapter_ids=lora_adapter_ids,
680
681
            )
        elif sharded:
OlivierDehaene's avatar
OlivierDehaene committed
682
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Gemma"))
683
        else:
684
            return CausalLM.fallback(
685
686
687
                model_id,
                revision,
                quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
688
                speculator=speculator,
689
690
691
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )
Nicolas Patry's avatar
Nicolas Patry committed
692
693
    elif model_type == GEMMA2:
        if FLASH_ATTENTION:
694
695
696
697
            return FlashCausalLM(
                model_id=model_id,
                model_class=FlashGemma2ForCausalLM,
                revision=revision,
Nicolas Patry's avatar
Nicolas Patry committed
698
699
700
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
701
702
                # Works better for these models
                default_dtype=torch.bfloat16,
Nicolas Patry's avatar
Nicolas Patry committed
703
                trust_remote_code=trust_remote_code,
704
                lora_adapter_ids=lora_adapter_ids,
Nicolas Patry's avatar
Nicolas Patry committed
705
706
707
708
            )
        elif sharded:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Gemma2"))
        else:
709
            return CausalLM.fallback(
Nicolas Patry's avatar
Nicolas Patry committed
710
711
712
713
714
715
716
                model_id,
                revision,
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )
717

718
    if model_type == COHERE:
OlivierDehaene's avatar
OlivierDehaene committed
719
        if FLASH_ATTENTION:
720
721
722
723
            return FlashCausalLM(
                model_id=model_id,
                model_class=FlashCohereForCausalLM,
                revision=revision,
OlivierDehaene's avatar
OlivierDehaene committed
724
                quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
725
                speculator=speculator,
OlivierDehaene's avatar
OlivierDehaene committed
726
727
                dtype=dtype,
                trust_remote_code=trust_remote_code,
728
                lora_adapter_ids=lora_adapter_ids,
OlivierDehaene's avatar
OlivierDehaene committed
729
730
731
732
            )
        elif sharded:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Cohere"))
        else:
733
            return CausalLM.fallback(
OlivierDehaene's avatar
OlivierDehaene committed
734
735
736
                model_id,
                revision,
                quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
737
                speculator=speculator,
OlivierDehaene's avatar
OlivierDehaene committed
738
739
740
741
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )

742
    if model_type == DBRX:
743
        if FLASH_ATTENTION:
744
745
746
747
            return FlashCausalLM(
                model_id=model_id,
                model_class=FlashDbrxForCausalLM,
                revision=revision,
748
                quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
749
                speculator=speculator,
750
                dtype=dtype,
751
752
                # Dbrx works better in bfloat16.
                default_dtype=torch.bfloat16,
753
                trust_remote_code=trust_remote_code,
754
755
                lora_adapter_ids=lora_adapter_ids,
                config_class=DbrxConfig,
756
757
758
759
            )
        elif sharded:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded DBRX"))
        else:
760
            return CausalLM.fallback(
761
762
763
                model_id,
                revision,
                quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
764
                speculator=speculator,
765
766
767
768
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )

769
    if model_type in ["RefinedWeb", "RefinedWebModel", FALCON]:
770
771
        if sharded:
            if FLASH_ATTENTION:
772
                if config_dict.get("alibi", False):
773
                    raise NotImplementedError("sharded is not supported for this model")
774
775
776
777
                return FlashCausalLM(
                    model_id=model_id,
                    model_class=FlashRWForCausalLM,
                    revision=revision,
778
                    quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
779
                    speculator=speculator,
780
                    dtype=dtype,
781
782
783
784
                    aliases={
                        "lm_head.weight": ["transformer.word_embeddings.weight"],
                        "transformer.word_embeddings.weight": ["lm_head.weight"],
                    },
785
                    trust_remote_code=trust_remote_code,
786
787
                    lora_adapter_ids=lora_adapter_ids,
                    config_class=RWConfig,
788
                )
789
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format(f"Sharded Falcon"))
790
        else:
791
            if FLASH_ATTENTION and not config_dict.get("alibi", False):
792
793
794
795
                return FlashCausalLM(
                    model_id=model_id,
                    model_class=FlashRWForCausalLM,
                    revision=revision,
796
                    quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
797
                    speculator=speculator,
798
                    dtype=dtype,
799
                    trust_remote_code=trust_remote_code,
800
801
                    lora_adapter_ids=lora_adapter_ids,
                    config_class=RWConfig,
802
803
                )
            else:
804
                return CausalLM.fallback(
805
806
807
                    model_id,
                    revision,
                    quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
808
                    speculator=speculator,
809
                    dtype=dtype,
810
811
812
                    trust_remote_code=trust_remote_code,
                )

813
    if model_type == MISTRAL:
814
        if FLASH_ATTENTION:
815
816
817
818
            return FlashCausalLM(
                model_id=model_id,
                model_class=FlashMistralForCausalLM,
                revision=revision,
819
                quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
820
                speculator=speculator,
821
822
                dtype=dtype,
                trust_remote_code=trust_remote_code,
823
                lora_adapter_ids=lora_adapter_ids,
824
            )
OlivierDehaene's avatar
OlivierDehaene committed
825
826
827
        elif sharded:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Mistral"))
        else:
828
            return CausalLM.fallback(
OlivierDehaene's avatar
OlivierDehaene committed
829
830
831
                model_id,
                revision,
                quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
832
                speculator=speculator,
OlivierDehaene's avatar
OlivierDehaene committed
833
834
835
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )
OlivierDehaene's avatar
OlivierDehaene committed
836

837
    if model_type == MIXTRAL:
838
        if FLASH_ATTENTION:
839
840
841
842
            return FlashCausalLM(
                model_id=model_id,
                model_class=FlashMixtralForCausalLM,
                revision=revision,
OlivierDehaene's avatar
OlivierDehaene committed
843
                quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
844
                speculator=speculator,
OlivierDehaene's avatar
OlivierDehaene committed
845
                dtype=dtype,
OlivierDehaene's avatar
OlivierDehaene committed
846
                trust_remote_code=trust_remote_code,
847
                lora_adapter_ids=lora_adapter_ids,
OlivierDehaene's avatar
OlivierDehaene committed
848
            )
OlivierDehaene's avatar
OlivierDehaene committed
849
850
851
        elif sharded:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Mixtral"))
        else:
852
            return CausalLM.fallback(
OlivierDehaene's avatar
OlivierDehaene committed
853
854
855
                model_id,
                revision,
                quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
856
                speculator=speculator,
OlivierDehaene's avatar
OlivierDehaene committed
857
858
859
860
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )

861
    if model_type == STARCODER2:
862
        if FLASH_ATTENTION:
863
864
865
866
            return FlashCausalLM(
                model_id=model_id,
                model_class=FlashStarcoder2ForCausalLM,
                revision=revision,
OlivierDehaene's avatar
OlivierDehaene committed
867
                quantize=quantize,
868
                speculator=speculator,
OlivierDehaene's avatar
OlivierDehaene committed
869
                dtype=dtype,
OlivierDehaene's avatar
OlivierDehaene committed
870
                trust_remote_code=trust_remote_code,
871
                lora_adapter_ids=lora_adapter_ids,
OlivierDehaene's avatar
OlivierDehaene committed
872
873
874
875
876
877
            )
        elif sharded:
            raise NotImplementedError(
                FLASH_ATT_ERROR_MESSAGE.format("Sharded Starcoder2")
            )
        else:
878
            return CausalLM.fallback(
OlivierDehaene's avatar
OlivierDehaene committed
879
880
881
                model_id,
                revision,
                quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
882
                speculator=speculator,
OlivierDehaene's avatar
OlivierDehaene committed
883
884
885
886
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )

887
    if model_type == QWEN2:
888
        if FLASH_ATTENTION:
889
890
891
892
            return FlashCausalLM(
                model_id=model_id,
                model_class=Qwen2ForCausalLM,
                revision=revision,
OlivierDehaene's avatar
OlivierDehaene committed
893
                quantize=quantize,
894
                speculator=speculator,
OlivierDehaene's avatar
OlivierDehaene committed
895
896
                dtype=dtype,
                trust_remote_code=trust_remote_code,
897
                lora_adapter_ids=lora_adapter_ids,
OlivierDehaene's avatar
OlivierDehaene committed
898
899
900
901
            )
        elif sharded:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Qwen2"))
        else:
902
            return CausalLM.fallback(
OlivierDehaene's avatar
OlivierDehaene committed
903
904
905
                model_id,
                revision,
                quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
906
                speculator=speculator,
OlivierDehaene's avatar
OlivierDehaene committed
907
                dtype=dtype,
OlivierDehaene's avatar
OlivierDehaene committed
908
909
                trust_remote_code=trust_remote_code,
            )
910

911
    if model_type == OPT:
912
913
914
915
        return CausalLM(
            model_id=model_id,
            model_class=OPTForCausalLM,
            revision=revision,
916
            quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
917
            speculator=speculator,
918
919
            dtype=dtype,
            trust_remote_code=trust_remote_code,
920
        )
921

922
    if model_type == T5:
923
924
925
926
        return Seq2SeqLM(
            model_id=model_id,
            model_class=T5ForConditionalGeneration,
            revision=revision,
927
            quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
928
            speculator=speculator,
929
            dtype=dtype,
930
            trust_remote_code=trust_remote_code,
931
932
933
934
935
936
            aliases={
                "shared.weight": [
                    "encoder.embed_tokens.weight",
                    "decoder.embed_tokens.weight",
                ]
            },
937
        )
938
    if model_type == IDEFICS:
939
        if FLASH_ATTENTION:
OlivierDehaene's avatar
OlivierDehaene committed
940
941
942
943
            return IDEFICSSharded(
                model_id,
                revision,
                quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
944
                speculator=speculator,
OlivierDehaene's avatar
OlivierDehaene committed
945
946
947
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )
948
949
        else:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Idefics"))
950
    if model_type == IDEFICS2:
Nicolas Patry's avatar
Nicolas Patry committed
951
        if FLASH_ATTENTION:
952
953
954
955
            return VlmCausalLM(
                model_id=model_id,
                model_class=Idefics2ForConditionalGeneration,
                revision=revision,
Nicolas Patry's avatar
Nicolas Patry committed
956
                quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
957
                speculator=speculator,
Nicolas Patry's avatar
Nicolas Patry committed
958
959
                dtype=dtype,
                trust_remote_code=trust_remote_code,
960
961
962
963
                lora_adapter_ids=lora_adapter_ids,
                # XXX: Extremely important to cap resolution in order to limit
                # VRAM usage.
                processor_kwargs={"size": {"longest_edge": 448, "shortest_edge": 378}},
Nicolas Patry's avatar
Nicolas Patry committed
964
965
966
            )
        else:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Idefics"))
967
    if model_type == PALIGEMMA:
drbh's avatar
drbh committed
968
        if FLASH_ATTENTION:
969
970
971
972
            return VlmCausalLM(
                model_id=model_id,
                model_class=PaliGemmaForConditionalGeneration,
                revision=revision,
drbh's avatar
drbh committed
973
974
975
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
976
977
                # Works better for these models
                default_dtype=torch.bfloat16,
drbh's avatar
drbh committed
978
                trust_remote_code=trust_remote_code,
979
980
                lora_adapter_ids=lora_adapter_ids,
                batch_class=PaliGemmaBatch,
drbh's avatar
drbh committed
981
982
983
            )
        else:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Idefics"))
984

985
    if model_type == LLAVA_NEXT:
986
        if FLASH_ATTENTION:
987
988
989
990
            return VlmCausalLM(
                model_class=LlavaNextForConditionalGeneration,
                model_id=model_id,
                revision=revision,
991
                quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
992
                speculator=speculator,
993
994
995
996
997
998
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )
        else:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("LlavaNext"))

999
    if sharded:
1000
        raise NotImplementedError("sharded is not supported for AutoModel")
1001
    if quantize == "gptq":
1002
        raise NotImplementedError(
1003
1004
            "gptq quantization is not supported for AutoModel, you can try to quantize it with `text-generation-server quantize ORIGINAL_MODEL_ID NEW_MODEL_ID`"
        )
1005
    if quantize == "awq":
1006
        raise NotImplementedError("awq quantization is not supported for AutoModel")
Nicolas Patry's avatar
Nicolas Patry committed
1007
    elif (quantize == "bitsandbytes-fp4") or (quantize == "bitsandbytes-nf4"):
1008
        raise NotImplementedError("4bit quantization is not supported for AutoModel")
OlivierDehaene's avatar
OlivierDehaene committed
1009
    elif quantize == "eetq":
1010
        raise NotImplementedError("Eetq quantization is not supported for AutoModel")
1011
1012
    elif quantize == "exl2":
        raise NotImplementedError("exl2 quantization is not supported for AutoModel")
1013
    if model_type in modeling_auto.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES:
1014
        return CausalLM.fallback(
1015
1016
1017
            model_id,
            revision,
            quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
1018
            speculator=speculator,
1019
1020
            dtype=dtype,
            trust_remote_code=trust_remote_code,
1021
        )
1022
    if model_type in modeling_auto.MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES:
1023
        return Seq2SeqLM.fallback(
1024
1025
1026
            model_id,
            revision,
            quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
1027
            speculator=speculator,
1028
1029
            dtype=dtype,
            trust_remote_code=trust_remote_code,
1030
1031
        )

1032
    auto_map = config_dict.get("auto_map", None)
1033
1034
    if trust_remote_code and auto_map is not None:
        if "AutoModelForCausalLM" in auto_map.keys():
1035
            return CausalLM.fallback(
1036
1037
1038
                model_id,
                revision,
                quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
1039
                speculator=speculator,
1040
                dtype=dtype,
1041
1042
                trust_remote_code=trust_remote_code,
            )
1043
        if "AutoModelForSeq2SeqLM" in auto_map.keys():
1044
            return Seq2SeqLM.fallback(
1045
1046
1047
                model_id,
                revision,
                quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
1048
                speculator=speculator,
1049
                dtype=dtype,
1050
1051
                trust_remote_code=trust_remote_code,
            )
1052
1053

    raise ValueError(f"Unsupported model type {model_type}")