__init__.py 49.8 KB
Newer Older
1
2
3
# ruff: noqa: F821
# the above line disables the `undefined-name` rule for the model type variables

4
import torch
5
import enum
Nicolas Patry's avatar
Nicolas Patry committed
6
import os
7

8
from loguru import logger
9
from transformers.configuration_utils import PretrainedConfig
10
from transformers.models.auto import modeling_auto
Nicolas Patry's avatar
Nicolas Patry committed
11
from huggingface_hub import hf_hub_download, HfApi
12
from typing import Optional, List, Dict
13
from pathlib import Path
14

Nicolas Patry's avatar
Nicolas Patry committed
15
from text_generation_server.utils.speculate import get_speculate, set_speculate
16
from text_generation_server.models.model import Model
17
18
19
20
21
from text_generation_server.models.causal_lm import CausalLM, CausalLMBatchKeysLast
from text_generation_server.models.custom_modeling.opt_modeling import OPTForCausalLM
from text_generation_server.models.custom_modeling.mpt_modeling import (
    MPTForCausalLM,
)
22
from text_generation_server.models.bloom import BloomCausalLMBatch
23
24
25
from text_generation_server.models.custom_modeling.bloom_modeling import (
    BloomForCausalLM,
)
26
from text_generation_server.models.seq2seq_lm import Seq2SeqLM
27
28
29
30
31
32
33
34
from text_generation_server.models.galactica import GalacticaCausalLMBatch
from text_generation_server.models.custom_modeling.neox_modeling import (
    GPTNeoxForCausalLM,
)
from text_generation_server.models.custom_modeling.phi_modeling import (
    PhiConfig,
    PhiForCausalLM,
)
drbh's avatar
drbh committed
35
36
37
from text_generation_server.models.custom_modeling.flash_phi_moe_modeling import (
    PhiMoEConfig,
)
38
39
40
from text_generation_server.models.custom_modeling.t5_modeling import (
    T5ForConditionalGeneration,
)
41

42
43
44
45
46
47
48
49
50
51

from text_generation_server.utils.adapter import (
    AdapterParameters,
    build_layer_weight_lookup,
    load_and_merge_adapters,
    AdapterInfo,
)
from text_generation_server.adapters.lora import LoraWeights


52
from text_generation_server.utils.import_utils import SYSTEM
53
from text_generation_server.utils.log import log_master
54

55
56
57
58
59
60
61
62
63
64
65
66
67
68
# The flag below controls whether to allow TF32 on matmul. This flag defaults to False
# in PyTorch 1.12 and later.
# It is switched on here as a module-level side effect of importing this package.
torch.backends.cuda.matmul.allow_tf32 = True

# The flag below controls whether to allow TF32 on cuDNN. This flag defaults to True.
# It is set explicitly so that both TF32 switches are visibly configured in one place.
torch.backends.cudnn.allow_tf32 = True

# Disable gradients
# Called as a plain function (not used as a context manager), so the setting
# stays in effect after this statement instead of being restored.
torch.set_grad_enabled(False)

# Public names exported by a star-import of this package.
# Every entry must be a *string* naming a module-level attribute; optional
# models are appended further down once their imports are known to work.
__all__ = [
    "Model",
    "CausalLM",
    "Seq2SeqLM",
    "get_model_with_lora_adapters",
]

72
# Message template for the `NotImplementedError` raised when a requested setup
# needs the flash-attention model implementations but they are unavailable.
# Format it with a short description of the setup, e.g. "Sharded GPT-2".
FLASH_ATT_ERROR_MESSAGE = "{} requires Flash Attention enabled models."
73

74
# Optimistic default: the guarded import block that follows resets this to
# False if importing the flash-attention model implementations raises ImportError.
FLASH_ATTENTION = True
75

76
try:
77
    from text_generation_server.models.flash_causal_lm import FlashCausalLM
78
    from text_generation_server.models.vlm_causal_lm import VlmCausalLM
Nicolas Patry's avatar
Nicolas Patry committed
79
    from text_generation_server.models.mllama_causal_lm import MllamaCausalLM
80
81
82
83
    from text_generation_server.models.custom_modeling.flash_deepseek_v2_modeling import (
        FlashDeepseekV2ForCausalLM,
        DeepseekV2Config,
    )
84
85
    from text_generation_server.models.custom_modeling.flash_llama_modeling import (
        FlashLlamaForCausalLM,
86
    )
87
88
    from text_generation_server.models.custom_modeling.flash_cohere_modeling import (
        FlashCohereForCausalLM,
OlivierDehaene's avatar
OlivierDehaene committed
89
    )
90
91
    from text_generation_server.models.custom_modeling.flash_gemma_modeling import (
        FlashGemmaForCausalLM,
OlivierDehaene's avatar
OlivierDehaene committed
92
    )
93
94
    from text_generation_server.models.custom_modeling.flash_gemma2_modeling import (
        FlashGemma2ForCausalLM,
95
    )
96
97
98
99
100
101
102
103
104
105
    from text_generation_server.models.custom_modeling.flash_dbrx_modeling import (
        FlashDbrxForCausalLM,
        DbrxConfig,
    )
    from text_generation_server.models.custom_modeling.flash_rw_modeling import (
        RWConfig,
        FlashRWForCausalLM,
    )
    from text_generation_server.models.custom_modeling.flash_neox_modeling import (
        FlashGPTNeoXForCausalLM,
Nicolas Patry's avatar
Nicolas Patry committed
106
    )
drbh's avatar
drbh committed
107
    from text_generation_server.models.pali_gemma import (
108
        PaliGemmaBatch,
drbh's avatar
drbh committed
109
    )
110
111
112
113
114
    from text_generation_server.models.custom_modeling.flash_pali_gemma_modeling import (
        PaliGemmaForConditionalGeneration,
    )
    from text_generation_server.models.custom_modeling.flash_phi_modeling import (
        FlashPhiForCausalLM,
115
    )
Nicolas Patry's avatar
Nicolas Patry committed
116
117
118
119
120
    from text_generation_server.models.idefics_causal_lm import IdeficsCausalLM
    from text_generation_server.models.mllama_causal_lm import MllamaCausalLMBatch
    from text_generation_server.models.custom_modeling.mllama import (
        MllamaForConditionalGeneration,
    )
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
    from text_generation_server.models.custom_modeling.llava_next import (
        LlavaNextForConditionalGeneration,
    )

    from text_generation_server.models.custom_modeling.flash_santacoder_modeling import (
        FlashSantacoderForCausalLM,
    )
    from text_generation_server.models.custom_modeling.flash_starcoder2_modeling import (
        FlashStarcoder2ForCausalLM,
    )
    from text_generation_server.models.custom_modeling.flash_qwen2_modeling import (
        Qwen2ForCausalLM,
    )
    from text_generation_server.models.custom_modeling.flash_mistral_modeling import (
        FlashMistralForCausalLM,
    )
    from text_generation_server.models.custom_modeling.flash_mixtral_modeling import (
        FlashMixtralForCausalLM,
    )
    from text_generation_server.models.custom_modeling.flash_gpt2_modeling import (
        FlashGPT2ForCausalLM,
    )
143
144
145
    from text_generation_server.models.custom_modeling.flash_gptj_modeling import (
        FlashGPTJForCausalLM,
    )
146
147
148
    from text_generation_server.models.custom_modeling.idefics2 import (
        Idefics2ForConditionalGeneration,
    )
149
    from text_generation_server.layers.attention import SUPPORTS_WINDOWING
150
except ImportError as e:
151
    log_master(logger.warning, f"Could not import Flash Attention enabled models: {e}")
152
    SUPPORTS_WINDOWING = False
153
    FLASH_ATTENTION = False
154

155
if FLASH_ATTENTION:
    # Export the flash-attention entry points only when their import succeeded.
    # `__all__` must contain attribute *names* (strings): appending the class
    # objects themselves makes a star-import of this package raise
    # `TypeError` ("Item in __all__ must be str").
    __all__.append("FlashCausalLM")
    __all__.append("IdeficsCausalLM")
OlivierDehaene's avatar
OlivierDehaene committed
158

drbh's avatar
drbh committed
159
160
161
162
# Mamba is an optional model: if importing it raises ImportError, a warning is
# logged, `MAMBA_AVAILABLE` is set to False and the name is not exported.
MAMBA_AVAILABLE = True
try:
    from text_generation_server.models.mamba import Mamba
except ImportError as e:
    log_master(logger.warning, f"Could not import Mamba: {e}")
    MAMBA_AVAILABLE = False

if MAMBA_AVAILABLE:
    # `__all__` must contain attribute *names* (strings); appending the class
    # object itself would make a star-import of this package raise `TypeError`.
    __all__.append("Mamba")
OlivierDehaene's avatar
OlivierDehaene committed
168

169

170
class ModelType(enum.Enum):
    """Catalogue of the model architectures known to this package.

    Each member's value is a plain dict with the keys:

    * ``"type"``: the architecture identifier. Right after the class
      definition every member is also published as a module-level string
      constant (``LLAMA == "llama"``), and ``get_model`` compares the
      ``model_type`` read from the model configuration against those constants.
    * ``"name"``: a human-readable display name.
    * ``"url"``: a Hugging Face link to a model or collection of that family.
    * ``"multimodal"`` (optional): present and ``True`` only on some members;
      absent otherwise.
      NOTE(review): PALIGEMMA carries no ``"multimodal"`` key although the
      other vision-language entries do; confirm whether that is intentional.
    """

    DEEPSEEK_V2 = {
        "type": "deepseek_v2",
        "name": "Deepseek V2",
        "url": "https://huggingface.co/deepseek-ai/DeepSeek-V2",
    }
    IDEFICS2 = {
        "type": "idefics2",
        "name": "Idefics 2",
        "url": "https://huggingface.co/HuggingFaceM4/idefics2-8b",
        "multimodal": True,
    }
    LLAVA_NEXT = {
        "type": "llava_next",
        "name": "Llava Next (1.6)",
        "url": "https://huggingface.co/llava-hf/llava-v1.6-vicuna-13b-hf",
        "multimodal": True,
    }
    LLAMA = {
        "type": "llama",
        "name": "Llama",
        "url": "https://huggingface.co/collections/meta-llama/llama-31-669fc079a0c406a149a5738f",
    }
    PHI3 = {
        "type": "phi3",
        "name": "Phi 3",
        "url": "https://huggingface.co/microsoft/Phi-3-mini-4k-instruct",
    }
    GEMMA = {
        "type": "gemma",
        "name": "Gemma",
        "url": "https://huggingface.co/google/gemma-7b",
    }
    PALIGEMMA = {
        "type": "paligemma",
        "name": "PaliGemma",
        "url": "https://huggingface.co/google/paligemma-3b-pt-224",
    }
    GEMMA2 = {
        "type": "gemma2",
        "name": "Gemma2",
        "url": "https://huggingface.co/collections/google/gemma-2-release-667d6600fd5220e7b967f315",
    }
    COHERE = {
        "type": "cohere",
        "name": "Cohere",
        "url": "https://huggingface.co/CohereForAI/c4ai-command-r-plus",
    }
    DBRX = {
        "type": "dbrx",
        "name": "Dbrx",
        "url": "https://huggingface.co/databricks/dbrx-instruct",
    }
    # "ssm" is the type `get_model` assigns when the config has no
    # `model_type` but does contain an `ssm_cfg` entry.
    MAMBA = {
        "type": "ssm",
        "name": "Mamba",
        "url": "https://huggingface.co/state-spaces/mamba-2.8b-slimpj",
    }
    MISTRAL = {
        "type": "mistral",
        "name": "Mistral",
        "url": "https://huggingface.co/mistralai/Mistral-Nemo-Instruct-2407",
    }
    MIXTRAL = {
        "type": "mixtral",
        "name": "Mixtral",
        "url": "https://huggingface.co/mistralai/Mixtral-8x22B-Instruct-v0.1",
    }
    GPT_BIGCODE = {
        "type": "gpt_bigcode",
        "name": "Gpt Bigcode",
        "url": "https://huggingface.co/bigcode/gpt_bigcode-santacoder",
    }
    PHI = {
        "type": "phi",
        "name": "Phi",
        "url": "https://huggingface.co/microsoft/phi-1_5",
    }
    PHI_MOE = {
        "type": "phimoe",
        "name": "PhiMoe",
        "url": "https://huggingface.co/microsoft/Phi-3.5-MoE-instruct",
    }
    BAICHUAN = {
        "type": "baichuan",
        "name": "Baichuan",
        "url": "https://huggingface.co/baichuan-inc/Baichuan2-7B-Chat",
    }
    FALCON = {
        "type": "falcon",
        "name": "Falcon",
        "url": "https://huggingface.co/tiiuae/falcon-7b-instruct",
    }
    STARCODER2 = {
        "type": "starcoder2",
        "name": "StarCoder 2",
        "url": "https://huggingface.co/bigcode/starcoder2-15b-instruct-v0.1",
    }
    QWEN2 = {
        "type": "qwen2",
        "name": "Qwen 2",
        "url": "https://huggingface.co/collections/Qwen/qwen2-6659360b33528ced941e557f",
    }
    OPT = {
        "type": "opt",
        "name": "Opt",
        "url": "https://huggingface.co/facebook/opt-6.7b",
    }
    T5 = {
        "type": "t5",
        "name": "T5",
        "url": "https://huggingface.co/google/flan-t5-xxl",
    }
    GALACTICA = {
        "type": "galactica",
        "name": "Galactica",
        "url": "https://huggingface.co/facebook/galactica-120b",
    }
    SANTACODER = {
        "type": "santacoder",
        "name": "SantaCoder",
        "url": "https://huggingface.co/bigcode/santacoder",
    }
    BLOOM = {
        "type": "bloom",
        "name": "Bloom",
        "url": "https://huggingface.co/bigscience/bloom-560m",
    }
    MPT = {
        "type": "mpt",
        "name": "Mpt",
        "url": "https://huggingface.co/mosaicml/mpt-7b-instruct",
    }
    GPT2 = {
        "type": "gpt2",
        "name": "Gpt2",
        "url": "https://huggingface.co/openai-community/gpt2",
    }
    GPT_NEOX = {
        "type": "gpt_neox",
        "name": "Gpt Neox",
        "url": "https://huggingface.co/EleutherAI/gpt-neox-20b",
    }
    GPTJ = {
        "type": "gptj",
        "name": "Gptj",
        "url": "https://huggingface.co/EleutherAI/gpt-j-6b",
    }
    IDEFICS = {
        "type": "idefics",
        "name": "Idefics",
        "url": "https://huggingface.co/HuggingFaceM4/idefics-9b",
        "multimodal": True,
    }
    MLLAMA = {
        "type": "mllama",
        "name": "Mllama",
        "url": "https://huggingface.co/meta-llama/Llama-3.2-11B-Vision-Instruct",
        "multimodal": True,
    }
330
331
332
333
334
335
336


# Publish every ModelType member as a module-level string constant named after
# the member and holding its "type" identifier (e.g. LLAMA == "llama"), so the
# dispatch code can write `model_type == LLAMA`.
# At module scope `globals()` is the same namespace dict that `locals()` returns.
__GLOBALS = globals()
for data in ModelType:
    __GLOBALS.update({data.name: data.value["type"]})


337
def get_model(
338
    model_id: str,
drbh's avatar
drbh committed
339
    lora_adapter_ids: Optional[List[str]],
340
341
342
    revision: Optional[str],
    sharded: bool,
    quantize: Optional[str],
Nicolas Patry's avatar
Nicolas Patry committed
343
    speculate: Optional[int],
344
    dtype: Optional[str],
345
    kv_cache_dtype: Optional[str],
346
    trust_remote_code: bool,
347
    max_input_tokens: int,
348
) -> Model:
349
    global FLASH_ATTENTION
350
351
352
353
354
355
356

    config_dict, _ = PretrainedConfig.get_config_dict(
        model_id, revision=revision, trust_remote_code=trust_remote_code
    )
    model_type = config_dict.get("model_type", None)

    quantization_config = config_dict.get("quantization_config", None)
357
    compression_config = config_dict.get("compression_config", None)
358
359
    if quantization_config is not None and quantize is None:
        method = quantization_config.get("quant_method", None)
360
        config_groups = quantization_config.get("config_groups", None)
361
362
363
        if method in {"gptq", "awq", "exl2"}:
            log_master(logger.info, f"Auto selecting quantization method {method}")
            quantize = method
364
        elif method == "fbgemm_fp8" or method == "fp8":
365
366
            log_master(logger.info, "Auto selecting quantization method fp8")
            quantize = "fp8"
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
        elif config_groups is not None:
            # TODO: at some point we should probably fully parse the compression
            # configuration to know which parameters are compressed.
            for _, group in config_groups.items():
                weights_config = group.get("weights")
                if weights_config is not None:
                    if (
                        weights_config["type"] == "float"
                        and weights_config["num_bits"] == 8
                    ):
                        log_master(
                            logger.info, "Auto selecting quantization method fp8"
                        )
                        quantize = "fp8"
                        break
382
383
        else:
            log_master(logger.warning, f"Unknown quantization method {method}")
384
    elif compression_config is not None:
385
        # `compression_config` renamed to `quantization_config`; support retained for backward compatibility.
386
387
388
389
390
391
392
393
394
395
396
397
398
399
        config_groups = compression_config.get("config_groups")
        if config_groups is not None:
            for _, group in config_groups.items():
                weights_config = group.get("weights")
                if weights_config is not None:
                    if (
                        weights_config["type"] == "float"
                        and weights_config["num_bits"] == 8
                    ):
                        log_master(
                            logger.info, "Auto selecting quantization method fp8"
                        )
                        quantize = "fp8"
                        break
400

401
    if dtype is None:
402
        if quantize in ["awq", "exl2", "gptq", "marlin"]:
Nicolas Patry's avatar
Nicolas Patry committed
403
404
405
406
407
            if SYSTEM == "ipex" and not hasattr(torch, "xpu"):
                dtype = torch.bfloat16
            else:
                # These quantizers only work with float16 params.
                dtype = torch.float16
408
        elif quantize == "fp8":
409
            from text_generation_server.layers.fp8 import FBGEMM_DYN_AVAILABLE
410

411
            if FBGEMM_DYN_AVAILABLE:
412
413
                # fbgemm kernels are fp8xfp8->bf16
                dtype = torch.bfloat16
414
415
416
417
        else:
            # Keep it as default for now and let
            # every model resolve their own default dtype.
            dtype = None
418
419
420
421
422
423
424
    elif dtype == "float16":
        dtype = torch.float16
    elif dtype == "bfloat16":
        dtype = torch.bfloat16
    else:
        raise RuntimeError(f"Unknown dtype {dtype}")

425
426
    if kv_cache_dtype is None:
        kv_cache_dtype = dtype
427
428
    elif kv_cache_dtype == "fp8_e4m3fn":
        kv_cache_dtype = torch.float8_e4m3fn
429
430
431
432
433
    elif kv_cache_dtype == "fp8_e5m2":
        kv_cache_dtype = torch.float8_e5m2
    else:
        raise RuntimeError(f"Unknown kv_cache_dtype: {kv_cache_dtype}")

Nicolas Patry's avatar
Nicolas Patry committed
434
435
436
437
438
    if speculate is not None:
        set_speculate(speculate)
    else:
        set_speculate(0)

Nicolas Patry's avatar
Nicolas Patry committed
439
    speculator = None
Nicolas Patry's avatar
Nicolas Patry committed
440
    if "medusa_num_heads" in config_dict:
441
442
        medusa_model_id = model_id
        medusa_revision = revision
Nicolas Patry's avatar
Nicolas Patry committed
443
444
445
446
447
        model_id = config_dict["base_model_name_or_path"]
        revision = "main"
        speculate_medusa = config_dict["medusa_num_heads"]
        if speculate is not None:
            if speculate > speculate_medusa:
OlivierDehaene's avatar
OlivierDehaene committed
448
                raise RuntimeError(
OlivierDehaene's avatar
OlivierDehaene committed
449
                    f"Speculate is set to `{speculate}` but this medusa models only has `{speculate_medusa}` heads, please make them match"
OlivierDehaene's avatar
OlivierDehaene committed
450
                )
Nicolas Patry's avatar
Nicolas Patry committed
451
452
453
454
455
456
457
458
            else:
                set_speculate(speculate)
        else:
            set_speculate(speculate_medusa)

        config_dict, _ = PretrainedConfig.get_config_dict(
            model_id, revision=revision, trust_remote_code=trust_remote_code
        )
Nicolas Patry's avatar
Nicolas Patry committed
459
460
        # Reload model type from parent.
        model_type = config_dict.get("model_type", None)
461
462
463
464
465
466
467
468
469
470
        is_local = Path(medusa_model_id).exists()
        if not is_local:
            medusa_config = hf_hub_download(
                medusa_model_id, revision=medusa_revision, filename="config.json"
            )
            hf_hub_download(
                medusa_model_id,
                revision=medusa_revision,
                filename="medusa_lm_head.safetensors",
            )
Nicolas Patry's avatar
Nicolas Patry committed
471
472
473
474
            speculator = {
                "path": Path(medusa_config).parent,
                "model_paths": ["medusa_lm_head.safetensors"],
            }
475
        else:
Nicolas Patry's avatar
Nicolas Patry committed
476
477
478
479
            speculator = {
                "path": Path(medusa_model_id),
                "model_paths": ["medusa_lm_head.safetensors"],
            }
480

Nicolas Patry's avatar
Nicolas Patry committed
481
        method = "medusa"
Nicolas Patry's avatar
Nicolas Patry committed
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
    elif model_type == "mlp_speculator":
        mlp_model_id = model_id
        mlp_revision = revision
        model_id = config_dict["base_model_name_or_path"]
        revision = "main"
        speculate_mlp = config_dict["n_predict"]
        if speculate is not None:
            if speculate > speculate_mlp:
                raise RuntimeError(
                    f"Speculate is set to `{speculate}` but this mlp_speculator models only has `{speculate_mlp}` heads, please make them match"
                )
            else:
                set_speculate(speculate)
        else:
            set_speculate(speculate_mlp)

        config_dict, _ = PretrainedConfig.get_config_dict(
            model_id, revision=revision, trust_remote_code=trust_remote_code
        )
        # Reload model type from parent.
        model_type = config_dict.get("model_type", None)
        is_local = Path(mlp_model_id).exists()
        extension = ".safetensors"
        if not is_local:
            mlp_speculator_config = hf_hub_download(
                mlp_model_id, revision=mlp_revision, filename="config.json"
            )
            api = HfApi()
            info = api.model_info(mlp_model_id, revision=mlp_revision)
            filenames = [
                s.rfilename
                for s in info.siblings
                if s.rfilename.endswith(extension)
                and len(s.rfilename.split("/")) == 1
                and "arguments" not in s.rfilename
                and "args" not in s.rfilename
                and "training" not in s.rfilename
            ]
            for filename in filenames:
                hf_hub_download(
                    mlp_model_id,
                    revision=mlp_revision,
                    filename=filename,
                )
526
527
528
529
530
            speculator_dir_path = Path(mlp_speculator_config).parent
            # if these are downloaded, they get converted to safetensors
            filenames.extend(
                [p for p in os.listdir(speculator_dir_path) if p.endswith(extension)]
            )
Nicolas Patry's avatar
Nicolas Patry committed
531
532
533
534
535
536
537
538
539
            speculator = {
                "path": Path(mlp_speculator_config).parent,
                "model_paths": filenames,
            }
        else:
            speculator = Path(mlp_model_id)
            filenames = [p for p in os.listdir(speculator) if p.endswith(extension)]
            speculator = {"path": speculator, "model_paths": filenames}
        method = "mlp_speculator"
Nicolas Patry's avatar
Nicolas Patry committed
540
541
542
543
544
    else:
        method = "n-gram"

    speculate = get_speculate()
    if speculate > 0:
545
546
547
        log_master(
            logger.info, f"Using speculation {method} with {speculate} input ids."
        )
Nicolas Patry's avatar
Nicolas Patry committed
548

drbh's avatar
drbh committed
549
550
551
552
553
554
555
556
557
558
    if model_type is None:
        # TODO: fix how we determine model type for Mamba
        if "ssm_cfg" in config_dict:
            # *only happens in Mamba case
            model_type = "ssm"
        else:
            raise RuntimeError(
                f"Could not determine model type for {model_id} revision {revision}"
            )

559
560
561
562
    if quantize == "exl2" and sharded:
        raise RuntimeError(
            "Sharding is currently not supported with `exl2` quantization"
        )
drbh's avatar
drbh committed
563
564
565
566
567
568

    sliding_window = (
        config_dict.get("sliding_window")
        if config_dict.get("sliding_window") is not None
        else -1
    )
569

570
571
572
    use_sliding_window = sliding_window is not None and sliding_window != -1
    needs_sliding_window = (
        max_input_tokens is not None and max_input_tokens > sliding_window
573
    )
574
575
576
577
    if use_sliding_window and needs_sliding_window and not SUPPORTS_WINDOWING:
        raise ValueError(
            f"The backend {SYSTEM} does not support sliding window attention that is used by the model type {model_type}. To use this model nonetheless with the {SYSTEM} backend, please launch TGI with the argument `--max-input-tokens` smaller than sliding_window={sliding_window} (got here max_input_tokens={max_input_tokens})."
        )
578

579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
    if model_type == DEEPSEEK_V2:
        if FLASH_ATTENTION:
            head_size = max(
                config_dict.get("qk_nope_dim", 128)
                + config_dict.get("qk_rope_dim", 64),
                config_dict.get("v_head_dim", 128),
            )
            return FlashCausalLM(
                model_id=model_id,
                model_class=FlashDeepseekV2ForCausalLM,
                revision=revision,
                quantize=quantize,
                speculator=speculator,
                default_dtype=torch.bfloat16,
                dtype=dtype,
594
                kv_cache_dtype=kv_cache_dtype,
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
                trust_remote_code=trust_remote_code,
                lora_adapter_ids=lora_adapter_ids,
                config_class=DeepseekV2Config,
                head_size=head_size,
            )
        elif sharded:
            raise NotImplementedError(
                FLASH_ATT_ERROR_MESSAGE.format("Sharded Deepseek V2")
            )
        else:
            return CausalLM.fallback(
                model_id,
                revision,
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )
    elif model_type == MAMBA:
drbh's avatar
drbh committed
614
615
616
617
        return Mamba(
            model_id,
            revision,
            quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
618
            speculator=speculator,
drbh's avatar
drbh committed
619
620
621
            dtype=dtype,
            trust_remote_code=trust_remote_code,
        )
622

OlivierDehaene's avatar
OlivierDehaene committed
623
    if model_id.startswith("facebook/galactica"):
624
625
626
627
628
        return CausalLM(
            model_id=model_id,
            # Yes galactica is just an OPT model.
            model_class=OPTForCausalLM,
            revision=revision,
OlivierDehaene's avatar
OlivierDehaene committed
629
            quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
630
            speculator=speculator,
OlivierDehaene's avatar
OlivierDehaene committed
631
632
            dtype=dtype,
            trust_remote_code=trust_remote_code,
633
            batch_class=GalacticaCausalLMBatch,
OlivierDehaene's avatar
OlivierDehaene committed
634
635
        )

636
    if (
637
638
        model_type == GPT_BIGCODE
        or model_type == GPT2
639
640
        and model_id.startswith("bigcode/")
    ):
641
        if FLASH_ATTENTION:
642
643
644
645
            return FlashCausalLM(
                model_id=model_id,
                model_class=FlashSantacoderForCausalLM,
                revision=revision,
646
                quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
647
                speculator=speculator,
648
                dtype=dtype,
649
                kv_cache_dtype=kv_cache_dtype,
650
                trust_remote_code=trust_remote_code,
651
652
653
                lora_adapter_ids=lora_adapter_ids,
                aliases={"transformer.wte.weight": ["lm_head.weight"]},
                num_kv_heads=1,
654
            )
655
656
657
658
        elif sharded:
            raise NotImplementedError(
                FLASH_ATT_ERROR_MESSAGE.format("Sharded Santacoder")
            )
659
        else:
660
661
662
            return CausalLM.fallback(
                model_id=model_id,
                revision=revision,
663
                quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
664
                speculator=speculator,
665
                dtype=dtype,
666
667
                trust_remote_code=trust_remote_code,
            )
668

669
    if model_type == BLOOM:
670
671
672
673
        return CausalLM(
            model_id=model_id,
            model_class=BloomForCausalLM,
            revision=revision,
674
            quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
675
            speculator=speculator,
676
677
            dtype=dtype,
            trust_remote_code=trust_remote_code,
678
            batch_class=BloomCausalLMBatch,
679
        )
680
    elif model_type == MPT:
681
682
683
684
        return CausalLM(
            model_id=model_id,
            model_class=MPTForCausalLM,
            revision=revision,
OlivierDehaene's avatar
OlivierDehaene committed
685
            quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
686
            speculator=speculator,
OlivierDehaene's avatar
OlivierDehaene committed
687
688
            dtype=dtype,
            trust_remote_code=trust_remote_code,
689
            batch_class=CausalLMBatchKeysLast,
690
        )
691
    elif model_type == GPT2:
692
        if FLASH_ATTENTION:
693
            try:
694
695
696
697
                return FlashCausalLM(
                    model_id=model_id,
                    model_class=FlashGPT2ForCausalLM,
                    revision=revision,
698
699
700
                    quantize=quantize,
                    speculator=speculator,
                    dtype=dtype,
701
                    kv_cache_dtype=kv_cache_dtype,
702
                    trust_remote_code=trust_remote_code,
703
                    lora_adapter_ids=lora_adapter_ids,
704
705
706
                )
            except RuntimeError as e:
                # Lots of legacy models with various weight names.
707
                log_master(logger.warning, f"Couldn't load flash gpt2 variant: {e}")
708
                return CausalLM.fallback(
709
710
711
712
713
714
715
                    model_id,
                    revision,
                    quantize=quantize,
                    speculator=speculator,
                    dtype=dtype,
                    trust_remote_code=trust_remote_code,
                )
716
717
718
        elif sharded:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded GPT-2"))
        else:
719
            return CausalLM.fallback(
720
721
722
723
724
725
726
                model_id,
                revision,
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )
727
728
729
730
731
732
733
734
735
736
    elif model_type == GPTJ:
        if FLASH_ATTENTION:
            try:
                return FlashCausalLM(
                    model_id=model_id,
                    model_class=FlashGPTJForCausalLM,
                    revision=revision,
                    quantize=quantize,
                    speculator=speculator,
                    dtype=dtype,
737
                    kv_cache_dtype=kv_cache_dtype,
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
                    trust_remote_code=trust_remote_code,
                    lora_adapter_ids=lora_adapter_ids,
                )
            except RuntimeError as e:
                # Lots of legacy models with various weight names.
                log_master(logger.warning, f"Couldn't load flash gptj variant: {e}")
                return CausalLM.fallback(
                    model_id,
                    revision,
                    quantize=quantize,
                    speculator=speculator,
                    dtype=dtype,
                    trust_remote_code=trust_remote_code,
                )
        elif sharded:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded GPT-J"))
        else:
            return CausalLM.fallback(
                model_id,
                revision,
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )
763
    elif model_type == GPT_NEOX:
764
        if FLASH_ATTENTION:
765
766
767
768
            from text_generation_server.models.custom_modeling.flash_neox_modeling import (
                GPTNeoXConfig,
            )

769
770
771
772
            return FlashCausalLM(
                model_id=model_id,
                model_class=FlashGPTNeoXForCausalLM,
                revision=revision,
773
                quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
774
                speculator=speculator,
775
                dtype=dtype,
776
                kv_cache_dtype=kv_cache_dtype,
777
                trust_remote_code=trust_remote_code,
778
                lora_adapter_ids=lora_adapter_ids,
779
                config_class=GPTNeoXConfig,
780
781
            )
        elif sharded:
782
783
784
785
            return CausalLM(
                model_id=model_id,
                model_class=GPTNeoxForCausalLM,
                revision=revision,
786
                quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
787
                speculator=speculator,
788
                dtype=dtype,
789
790
                trust_remote_code=trust_remote_code,
            )
791
        else:
792
            return CausalLM.fallback(
793
794
795
                model_id,
                revision,
                quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
796
                speculator=speculator,
797
                dtype=dtype,
798
799
                trust_remote_code=trust_remote_code,
            )
OlivierDehaene's avatar
OlivierDehaene committed
800

801
    elif model_type == PHI:
drbh's avatar
drbh committed
802
        if FLASH_ATTENTION:
803
804
805
806
            return FlashCausalLM(
                model_id=model_id,
                model_class=FlashPhiForCausalLM,
                revision=revision,
drbh's avatar
drbh committed
807
                quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
808
                speculator=speculator,
drbh's avatar
drbh committed
809
                dtype=dtype,
810
                kv_cache_dtype=kv_cache_dtype,
drbh's avatar
drbh committed
811
                trust_remote_code=trust_remote_code,
812
                lora_adapter_ids=lora_adapter_ids,
drbh's avatar
drbh committed
813
814
            )
        else:
815
            return CausalLM.fallback(
drbh's avatar
drbh committed
816
817
818
                model_id,
                revision,
                quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
819
                speculator=speculator,
drbh's avatar
drbh committed
820
821
822
823
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )

drbh's avatar
drbh committed
824
825
826
827
828
829
830
831
832
833
    elif model_type == PHI_MOE:
        if FLASH_ATTENTION:
            return FlashCausalLM(
                model_id=model_id,
                model_class=FlashLlamaForCausalLM,
                config_class=PhiMoEConfig,
                revision=revision,
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
834
                kv_cache_dtype=kv_cache_dtype,
drbh's avatar
drbh committed
835
836
837
838
839
840
841
842
843
844
845
846
847
                trust_remote_code=trust_remote_code,
                lora_adapter_ids=lora_adapter_ids,
            )
        else:
            return CausalLM.fallback(
                model_id,
                revision,
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )

drbh's avatar
drbh committed
848
849
    elif model_type == "phi-msft":
        if FLASH_ATTENTION:
OlivierDehaene's avatar
OlivierDehaene committed
850
851
852
            raise NotImplementedError(
                "Legacy phi-msft is not supported with Flash Attention"
            )
drbh's avatar
drbh committed
853
        else:
854
855
856
857
858
            return CausalLM(
                model_id=model_id,
                model_class=PhiForCausalLM,
                config_class=PhiConfig,
                revision=revision,
drbh's avatar
drbh committed
859
                quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
860
                speculator=speculator,
drbh's avatar
drbh committed
861
862
863
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )
864

865
    elif model_type == LLAMA or model_type == BAICHUAN or model_type == PHI3:
866
        if FLASH_ATTENTION:
867
868
869
870
            return FlashCausalLM(
                model_id=model_id,
                model_class=FlashLlamaForCausalLM,
                revision=revision,
871
                quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
872
                speculator=speculator,
873
                dtype=dtype,
874
                kv_cache_dtype=kv_cache_dtype,
875
                trust_remote_code=trust_remote_code,
drbh's avatar
drbh committed
876
                lora_adapter_ids=lora_adapter_ids,
877
            )
878
879
        elif sharded:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Llama"))
880
        else:
881
            return CausalLM.fallback(
882
883
884
                model_id,
                revision,
                quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
885
                speculator=speculator,
886
                dtype=dtype,
887
888
                trust_remote_code=trust_remote_code,
            )
889
    if model_type == GEMMA:
890
        if FLASH_ATTENTION:
891
892
893
894
            return FlashCausalLM(
                model_id=model_id,
                model_class=FlashGemmaForCausalLM,
                revision=revision,
895
                quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
896
                speculator=speculator,
897
                dtype=dtype,
898
                kv_cache_dtype=kv_cache_dtype,
899
900
                # Works better for these models
                default_dtype=torch.bfloat16,
901
                trust_remote_code=trust_remote_code,
902
                lora_adapter_ids=lora_adapter_ids,
903
904
            )
        elif sharded:
OlivierDehaene's avatar
OlivierDehaene committed
905
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Gemma"))
906
        else:
907
            return CausalLM.fallback(
908
909
910
                model_id,
                revision,
                quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
911
                speculator=speculator,
912
913
914
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )
Nicolas Patry's avatar
Nicolas Patry committed
915
916
    elif model_type == GEMMA2:
        if FLASH_ATTENTION:
917
918
919
920
            return FlashCausalLM(
                model_id=model_id,
                model_class=FlashGemma2ForCausalLM,
                revision=revision,
Nicolas Patry's avatar
Nicolas Patry committed
921
922
923
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
924
                kv_cache_dtype=kv_cache_dtype,
925
926
                # Works better for these models
                default_dtype=torch.bfloat16,
Nicolas Patry's avatar
Nicolas Patry committed
927
                trust_remote_code=trust_remote_code,
928
                lora_adapter_ids=lora_adapter_ids,
Nicolas Patry's avatar
Nicolas Patry committed
929
930
931
932
            )
        elif sharded:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Gemma2"))
        else:
933
            return CausalLM.fallback(
Nicolas Patry's avatar
Nicolas Patry committed
934
935
936
937
938
939
940
                model_id,
                revision,
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )
941

942
    if model_type == COHERE:
OlivierDehaene's avatar
OlivierDehaene committed
943
        if FLASH_ATTENTION:
944
945
946
947
            return FlashCausalLM(
                model_id=model_id,
                model_class=FlashCohereForCausalLM,
                revision=revision,
OlivierDehaene's avatar
OlivierDehaene committed
948
                quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
949
                speculator=speculator,
OlivierDehaene's avatar
OlivierDehaene committed
950
                dtype=dtype,
951
                kv_cache_dtype=kv_cache_dtype,
OlivierDehaene's avatar
OlivierDehaene committed
952
                trust_remote_code=trust_remote_code,
953
                lora_adapter_ids=lora_adapter_ids,
OlivierDehaene's avatar
OlivierDehaene committed
954
955
956
957
            )
        elif sharded:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Cohere"))
        else:
958
            return CausalLM.fallback(
OlivierDehaene's avatar
OlivierDehaene committed
959
960
961
                model_id,
                revision,
                quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
962
                speculator=speculator,
OlivierDehaene's avatar
OlivierDehaene committed
963
964
965
966
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )

967
    if model_type == DBRX:
968
        if FLASH_ATTENTION:
969
970
971
972
            return FlashCausalLM(
                model_id=model_id,
                model_class=FlashDbrxForCausalLM,
                revision=revision,
973
                quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
974
                speculator=speculator,
975
                dtype=dtype,
976
                kv_cache_dtype=kv_cache_dtype,
977
978
                # Dbrx works better in bfloat16.
                default_dtype=torch.bfloat16,
979
                trust_remote_code=trust_remote_code,
980
981
                lora_adapter_ids=lora_adapter_ids,
                config_class=DbrxConfig,
982
983
984
985
            )
        elif sharded:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded DBRX"))
        else:
986
            return CausalLM.fallback(
987
988
989
                model_id,
                revision,
                quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
990
                speculator=speculator,
991
992
993
994
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )

995
    if model_type in ["RefinedWeb", "RefinedWebModel", FALCON]:
996
997
        if sharded:
            if FLASH_ATTENTION:
998
                if config_dict.get("alibi", False):
999
                    raise NotImplementedError("sharded is not supported for this model")
1000
1001
1002
1003
                return FlashCausalLM(
                    model_id=model_id,
                    model_class=FlashRWForCausalLM,
                    revision=revision,
1004
                    quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
1005
                    speculator=speculator,
1006
                    dtype=dtype,
1007
                    kv_cache_dtype=kv_cache_dtype,
1008
1009
1010
1011
                    aliases={
                        "lm_head.weight": ["transformer.word_embeddings.weight"],
                        "transformer.word_embeddings.weight": ["lm_head.weight"],
                    },
1012
                    trust_remote_code=trust_remote_code,
1013
1014
                    lora_adapter_ids=lora_adapter_ids,
                    config_class=RWConfig,
1015
                )
1016
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Falcon"))
1017
        else:
1018
            if FLASH_ATTENTION and not config_dict.get("alibi", False):
1019
1020
1021
1022
                return FlashCausalLM(
                    model_id=model_id,
                    model_class=FlashRWForCausalLM,
                    revision=revision,
1023
                    quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
1024
                    speculator=speculator,
1025
                    dtype=dtype,
1026
                    kv_cache_dtype=kv_cache_dtype,
1027
1028
1029
1030
                    aliases={
                        "lm_head.weight": ["transformer.word_embeddings.weight"],
                        "transformer.word_embeddings.weight": ["lm_head.weight"],
                    },
1031
                    trust_remote_code=trust_remote_code,
1032
1033
                    lora_adapter_ids=lora_adapter_ids,
                    config_class=RWConfig,
1034
1035
                )
            else:
1036
                return CausalLM.fallback(
1037
1038
1039
                    model_id,
                    revision,
                    quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
1040
                    speculator=speculator,
1041
                    dtype=dtype,
1042
1043
1044
                    trust_remote_code=trust_remote_code,
                )

1045
    if model_type == MISTRAL:
1046
        if FLASH_ATTENTION:
1047
1048
1049
1050
            return FlashCausalLM(
                model_id=model_id,
                model_class=FlashMistralForCausalLM,
                revision=revision,
1051
                quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
1052
                speculator=speculator,
1053
                dtype=dtype,
1054
                kv_cache_dtype=kv_cache_dtype,
1055
                trust_remote_code=trust_remote_code,
1056
                lora_adapter_ids=lora_adapter_ids,
1057
            )
OlivierDehaene's avatar
OlivierDehaene committed
1058
1059
1060
        elif sharded:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Mistral"))
        else:
1061
            return CausalLM.fallback(
OlivierDehaene's avatar
OlivierDehaene committed
1062
1063
1064
                model_id,
                revision,
                quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
1065
                speculator=speculator,
OlivierDehaene's avatar
OlivierDehaene committed
1066
1067
1068
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )
OlivierDehaene's avatar
OlivierDehaene committed
1069

1070
    if model_type == MIXTRAL:
1071
        if FLASH_ATTENTION:
1072
1073
1074
1075
            return FlashCausalLM(
                model_id=model_id,
                model_class=FlashMixtralForCausalLM,
                revision=revision,
OlivierDehaene's avatar
OlivierDehaene committed
1076
                quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
1077
                speculator=speculator,
OlivierDehaene's avatar
OlivierDehaene committed
1078
                dtype=dtype,
1079
                kv_cache_dtype=kv_cache_dtype,
OlivierDehaene's avatar
OlivierDehaene committed
1080
                trust_remote_code=trust_remote_code,
1081
                lora_adapter_ids=lora_adapter_ids,
OlivierDehaene's avatar
OlivierDehaene committed
1082
            )
OlivierDehaene's avatar
OlivierDehaene committed
1083
1084
1085
        elif sharded:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Mixtral"))
        else:
1086
            return CausalLM.fallback(
OlivierDehaene's avatar
OlivierDehaene committed
1087
1088
1089
                model_id,
                revision,
                quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
1090
                speculator=speculator,
OlivierDehaene's avatar
OlivierDehaene committed
1091
1092
1093
1094
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )

1095
    if model_type == STARCODER2:
1096
        if FLASH_ATTENTION:
1097
1098
1099
1100
            return FlashCausalLM(
                model_id=model_id,
                model_class=FlashStarcoder2ForCausalLM,
                revision=revision,
OlivierDehaene's avatar
OlivierDehaene committed
1101
                quantize=quantize,
1102
                speculator=speculator,
OlivierDehaene's avatar
OlivierDehaene committed
1103
                dtype=dtype,
1104
                kv_cache_dtype=kv_cache_dtype,
OlivierDehaene's avatar
OlivierDehaene committed
1105
                trust_remote_code=trust_remote_code,
1106
                lora_adapter_ids=lora_adapter_ids,
OlivierDehaene's avatar
OlivierDehaene committed
1107
1108
1109
1110
1111
1112
            )
        elif sharded:
            raise NotImplementedError(
                FLASH_ATT_ERROR_MESSAGE.format("Sharded Starcoder2")
            )
        else:
1113
            return CausalLM.fallback(
OlivierDehaene's avatar
OlivierDehaene committed
1114
1115
1116
                model_id,
                revision,
                quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
1117
                speculator=speculator,
OlivierDehaene's avatar
OlivierDehaene committed
1118
1119
1120
1121
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )

1122
    if model_type == QWEN2:
1123
        if FLASH_ATTENTION:
1124
1125
1126
1127
            return FlashCausalLM(
                model_id=model_id,
                model_class=Qwen2ForCausalLM,
                revision=revision,
OlivierDehaene's avatar
OlivierDehaene committed
1128
                quantize=quantize,
1129
                speculator=speculator,
OlivierDehaene's avatar
OlivierDehaene committed
1130
                dtype=dtype,
1131
                kv_cache_dtype=kv_cache_dtype,
OlivierDehaene's avatar
OlivierDehaene committed
1132
                trust_remote_code=trust_remote_code,
1133
                lora_adapter_ids=lora_adapter_ids,
OlivierDehaene's avatar
OlivierDehaene committed
1134
1135
1136
1137
            )
        elif sharded:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Qwen2"))
        else:
1138
            return CausalLM.fallback(
OlivierDehaene's avatar
OlivierDehaene committed
1139
1140
1141
                model_id,
                revision,
                quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
1142
                speculator=speculator,
OlivierDehaene's avatar
OlivierDehaene committed
1143
                dtype=dtype,
OlivierDehaene's avatar
OlivierDehaene committed
1144
1145
                trust_remote_code=trust_remote_code,
            )
1146

1147
    if model_type == OPT:
1148
1149
1150
1151
        return CausalLM(
            model_id=model_id,
            model_class=OPTForCausalLM,
            revision=revision,
1152
            quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
1153
            speculator=speculator,
1154
1155
            dtype=dtype,
            trust_remote_code=trust_remote_code,
1156
        )
1157

1158
    if model_type == T5:
1159
1160
1161
1162
        return Seq2SeqLM(
            model_id=model_id,
            model_class=T5ForConditionalGeneration,
            revision=revision,
1163
            quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
1164
            speculator=speculator,
1165
            dtype=dtype,
1166
            trust_remote_code=trust_remote_code,
1167
1168
1169
1170
1171
1172
            aliases={
                "shared.weight": [
                    "encoder.embed_tokens.weight",
                    "decoder.embed_tokens.weight",
                ]
            },
1173
        )
1174
    if model_type == IDEFICS:
1175
        if FLASH_ATTENTION:
Nicolas Patry's avatar
Nicolas Patry committed
1176
            return IdeficsCausalLM(
OlivierDehaene's avatar
OlivierDehaene committed
1177
1178
1179
                model_id,
                revision,
                quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
1180
                speculator=speculator,
OlivierDehaene's avatar
OlivierDehaene committed
1181
1182
1183
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )
1184
1185
        else:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Idefics"))
Nicolas Patry's avatar
Nicolas Patry committed
1186
1187
1188
1189
1190
1191
1192
1193
1194
1195
1196
1197
1198
1199
1200
1201
    if model_type == MLLAMA:
        if FLASH_ATTENTION:
            return MllamaCausalLM(
                model_id=model_id,
                model_class=MllamaForConditionalGeneration,
                batch_class=MllamaCausalLMBatch,
                revision=revision,
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
                default_dtype=torch.bfloat16,
                trust_remote_code=trust_remote_code,
                lora_adapter_ids=lora_adapter_ids,
            )
        else:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Mllama"))
1202
    if model_type == IDEFICS2:
Nicolas Patry's avatar
Nicolas Patry committed
1203
        if FLASH_ATTENTION:
1204
1205
1206
1207
            return VlmCausalLM(
                model_id=model_id,
                model_class=Idefics2ForConditionalGeneration,
                revision=revision,
Nicolas Patry's avatar
Nicolas Patry committed
1208
                quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
1209
                speculator=speculator,
Nicolas Patry's avatar
Nicolas Patry committed
1210
                dtype=dtype,
1211
                kv_cache_dtype=kv_cache_dtype,
Nicolas Patry's avatar
Nicolas Patry committed
1212
                trust_remote_code=trust_remote_code,
1213
1214
1215
1216
                lora_adapter_ids=lora_adapter_ids,
                # XXX: Extremely important to cap resolution in order to limit
                # VRAM usage.
                processor_kwargs={"size": {"longest_edge": 448, "shortest_edge": 378}},
Nicolas Patry's avatar
Nicolas Patry committed
1217
1218
1219
            )
        else:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Idefics"))
1220
    if model_type == PALIGEMMA:
drbh's avatar
drbh committed
1221
        if FLASH_ATTENTION:
1222
1223
1224
1225
            return VlmCausalLM(
                model_id=model_id,
                model_class=PaliGemmaForConditionalGeneration,
                revision=revision,
drbh's avatar
drbh committed
1226
1227
1228
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
1229
                kv_cache_dtype=kv_cache_dtype,
1230
1231
                # Works better for these models
                default_dtype=torch.bfloat16,
drbh's avatar
drbh committed
1232
                trust_remote_code=trust_remote_code,
1233
1234
                lora_adapter_ids=lora_adapter_ids,
                batch_class=PaliGemmaBatch,
drbh's avatar
drbh committed
1235
1236
1237
            )
        else:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Idefics"))
1238

1239
    if model_type == LLAVA_NEXT:
1240
        if FLASH_ATTENTION:
1241
1242
1243
1244
            return VlmCausalLM(
                model_class=LlavaNextForConditionalGeneration,
                model_id=model_id,
                revision=revision,
1245
                quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
1246
                speculator=speculator,
1247
                dtype=dtype,
1248
                kv_cache_dtype=kv_cache_dtype,
1249
1250
1251
1252
1253
                trust_remote_code=trust_remote_code,
            )
        else:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("LlavaNext"))

1254
    if sharded:
1255
        raise NotImplementedError("sharded is not supported for AutoModel")
1256
    if quantize == "gptq":
1257
        raise NotImplementedError(
1258
1259
            "gptq quantization is not supported for AutoModel, you can try to quantize it with `text-generation-server quantize ORIGINAL_MODEL_ID NEW_MODEL_ID`"
        )
1260
    if quantize == "awq":
1261
        raise NotImplementedError("awq quantization is not supported for AutoModel")
Nicolas Patry's avatar
Nicolas Patry committed
1262
    elif (quantize == "bitsandbytes-fp4") or (quantize == "bitsandbytes-nf4"):
1263
        raise NotImplementedError("4bit quantization is not supported for AutoModel")
OlivierDehaene's avatar
OlivierDehaene committed
1264
    elif quantize == "eetq":
1265
        raise NotImplementedError("Eetq quantization is not supported for AutoModel")
1266
1267
    elif quantize == "exl2":
        raise NotImplementedError("exl2 quantization is not supported for AutoModel")
1268
    if model_type in modeling_auto.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES:
1269
        return CausalLM.fallback(
1270
1271
1272
            model_id,
            revision,
            quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
1273
            speculator=speculator,
1274
1275
            dtype=dtype,
            trust_remote_code=trust_remote_code,
1276
        )
1277
    if model_type in modeling_auto.MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES:
1278
        return Seq2SeqLM.fallback(
1279
1280
1281
            model_id,
            revision,
            quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
1282
            speculator=speculator,
1283
1284
            dtype=dtype,
            trust_remote_code=trust_remote_code,
1285
1286
        )

1287
    auto_map = config_dict.get("auto_map", None)
1288
1289
    if trust_remote_code and auto_map is not None:
        if "AutoModelForCausalLM" in auto_map.keys():
1290
            return CausalLM.fallback(
1291
1292
1293
                model_id,
                revision,
                quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
1294
                speculator=speculator,
1295
                dtype=dtype,
1296
1297
                trust_remote_code=trust_remote_code,
            )
1298
        if "AutoModelForSeq2SeqLM" in auto_map.keys():
1299
            return Seq2SeqLM.fallback(
1300
1301
1302
                model_id,
                revision,
                quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
1303
                speculator=speculator,
1304
                dtype=dtype,
1305
1306
                trust_remote_code=trust_remote_code,
            )
1307
1308

    raise ValueError(f"Unsupported model type {model_type}")
1309
1310
1311
1312
1313
1314
1315
1316
1317
1318
1319
1320


# get_model_with_lora_adapters wraps the internal get_model function and adds support for loading adapters.
# It acts as a post-load hook: once the model has been loaded, the adapters are loaded into it.
def get_model_with_lora_adapters(
    model_id: str,
    lora_adapters: Optional[List[AdapterInfo]],
    revision: Optional[str],
    sharded: bool,
    quantize: Optional[str],
    speculate: Optional[int],
    dtype: Optional[str],
    kv_cache_dtype: Optional[str],
    trust_remote_code: bool,
    max_input_tokens: int,
    adapter_to_index: Dict[str, int],
):
    """Load a model through `get_model`, then load each LoRA adapter into it.

    Args:
        model_id: Forwarded to `get_model`.
        lora_adapters: Adapters to load into the model. `None` is treated the
            same as an empty list (no adapters).
        revision: Forwarded to `get_model`.
        sharded: Forwarded to `get_model`.
        quantize: Forwarded to `get_model`.
        speculate: Forwarded to `get_model`.
        dtype: Forwarded to `get_model`.
        kv_cache_dtype: Forwarded to `get_model`.
        trust_remote_code: Forwarded to `get_model`.
        max_input_tokens: Forwarded to `get_model`.
        adapter_to_index: Mutated in place. Each adapter id is mapped to the
            index under which the adapter is registered on the model; indices
            start at 1 and follow the order of `lora_adapters`.

    Returns:
        The model returned by `get_model`, with every adapter registered on it.
    """
    # The parameter is annotated Optional, so do not fail when iterating None.
    if lora_adapters is None:
        lora_adapters = []

    lora_adapter_ids = [adapter.id for adapter in lora_adapters]
    model = get_model(
        model_id,
        lora_adapter_ids,
        revision,
        sharded,
        quantize,
        speculate,
        dtype,
        kv_cache_dtype,
        trust_remote_code,
        max_input_tokens,
    )

    if not lora_adapters:
        return model

    target_to_layer = build_layer_weight_lookup(model.model)

    # Layer types for which adapter weights are prepared. This is a constant,
    # so it is built once rather than once per adapter.
    adapter_layers = (
        "q_proj",
        "k_proj",
        "v_proj",
        "o_proj",
        "gate_proj",
        "up_proj",
        "down_proj",
        "qkv_proj",
    )

    for index, adapter in enumerate(lora_adapters):
        # The AdapterParameters object allows for merging multiple adapters into a single adapter.
        # At the moment, we only support loading a single adapter into the model, but we keep the
        # AdapterParameters object for easier extension in the future.
        adapter_parameters = AdapterParameters(
            adapter_info=[adapter],
            # when merging multiple adapters we can weight them differently
            # if this is not set, all adapters will be weighted equally
            # see: text_generation_server.utils.merges.strategies for impl
            weights=None,
            merge_strategy=0,
            density=1.0,
            majority_sign_method=0,
        )

        # Index 0 is never assigned to an adapter: numbering starts at 1.
        adapter_index = index + 1
        adapter_to_index[adapter.id] = adapter_index

        # Ids of the adapter(s) handled in this iteration, shared by both log
        # messages below so they always name the same adapter(s).
        adapter_ids = ",".join(info.id for info in adapter_parameters.adapter_info)

        logger.info(f"Loading adapter weights into model: {adapter_ids}")
        weight_names = tuple(v[0] for v in target_to_layer.values())
        (
            module_map,
            adapter_config,
            adapter_weight_names,
            adapter_tokenizer,
        ) = load_and_merge_adapters(
            model.model_id,
            adapter_parameters,
            adapter_index,
            weight_names,
            # NOTE(review): meaning of this positional flag is not visible
            # here — confirm against load_and_merge_adapters' signature.
            False,
        )

        # Work on a copy: prepare_weights receives this collection so that
        # whatever is left afterwards can be reported as unused.
        unused_weight_names = adapter_weight_names.copy()

        for layer_name in adapter_layers:
            # "lm_head" is not in adapter_layers at present, so every layer
            # type currently uses the number of model layers.
            nlayers = 1 if layer_name == "lm_head" else len(model.model.model.layers)
            adapter_weights = LoraWeights.prepare_weights(
                config=adapter_config,
                module_map=module_map,
                layer_type=layer_name,
                unused_weight_names=unused_weight_names,
                nlayers=nlayers,
                dtype=model.dtype,
                world_size=model.world_size,
                process_group=model.process_group,
                target_to_layer=target_to_layer,
            )

            if adapter_weights is None:
                continue

            model.layer_to_adapter_weights[layer_name].add_adapter(
                adapter_index, adapter_weights
            )

        if len(unused_weight_names) > 0:
            # Report against the adapter being loaded, not every adapter in
            # the request, so the warning points at the right weights.
            logger.warning(
                f"{adapter_ids} unused adapter weights: {unused_weight_names}"
            )

        if adapter_tokenizer is not None:
            model.tokenizers.add_tokenizer(adapter_index, adapter_tokenizer)

        model.loaded_adapters.add(adapter_index)

    return model