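# ruff: noqa: F821
# The line above disables ruff's `undefined-name` (F821) rule for this file: the
# model-type constants used in `get_model` (LLAMA, GPT2, MAMBA, ...) are injected
# into the module globals at runtime from `ModelType`, so static analysis cannot see them.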

4
import torch
5
import enum
Nicolas Patry's avatar
Nicolas Patry committed
6
import os
7

8
from loguru import logger
9
from transformers.configuration_utils import PretrainedConfig
10
from transformers.models.auto import modeling_auto
Nicolas Patry's avatar
Nicolas Patry committed
11
from huggingface_hub import hf_hub_download, HfApi
12
from typing import Optional, List, Dict
13
from pathlib import Path
14

Nicolas Patry's avatar
Nicolas Patry committed
15
from text_generation_server.utils.speculate import get_speculate, set_speculate
16
from text_generation_server.models.model import Model
17
18
19
20
21
from text_generation_server.models.causal_lm import CausalLM, CausalLMBatchKeysLast
from text_generation_server.models.custom_modeling.opt_modeling import OPTForCausalLM
from text_generation_server.models.custom_modeling.mpt_modeling import (
    MPTForCausalLM,
)
22
from text_generation_server.models.bloom import BloomCausalLMBatch
23
24
25
from text_generation_server.models.custom_modeling.bloom_modeling import (
    BloomForCausalLM,
)
26
from text_generation_server.models.seq2seq_lm import Seq2SeqLM
27
28
29
30
31
32
33
34
from text_generation_server.models.galactica import GalacticaCausalLMBatch
from text_generation_server.models.custom_modeling.neox_modeling import (
    GPTNeoxForCausalLM,
)
from text_generation_server.models.custom_modeling.phi_modeling import (
    PhiConfig,
    PhiForCausalLM,
)
drbh's avatar
drbh committed
35
36
37
from text_generation_server.models.custom_modeling.flash_phi_moe_modeling import (
    PhiMoEConfig,
)
38
39
40
from text_generation_server.models.custom_modeling.t5_modeling import (
    T5ForConditionalGeneration,
)
41

42
43
44
45
46
47
48
49
50
51

from text_generation_server.utils.adapter import (
    AdapterParameters,
    build_layer_weight_lookup,
    load_and_merge_adapters,
    AdapterInfo,
)
from text_generation_server.adapters.lora import LoraWeights


52
from text_generation_server.utils.import_utils import SYSTEM
53
from text_generation_server.utils.log import log_master
54

55
56
57
58
59
60
61
62
63
64
65
66
67
68
# Allow TensorFloat-32 math in both CUDA matmuls and cuDNN. PyTorch 1.12+ ships
# with TF32 disabled for matmul (cuDNN defaults to enabled); turning both on
# trades a little fp32 precision for faster kernels.
for _tf32_backend in (torch.backends.cuda.matmul, torch.backends.cudnn):
    _tf32_backend.allow_tf32 = True
del _tf32_backend

# Turn off autograd gradient tracking.
torch.set_grad_enabled(False)

# Public API of this package. Optional backends (flash attention, Mamba) extend
# this list further down, only when their imports succeed. Entries must be strings.
__all__ = [
    "Model",
    "CausalLM",
    "Seq2SeqLM",
    "get_model_with_lora_adapters",
]

72
# Error template for model types that can only be served sharded with flash attention.
FLASH_ATT_ERROR_MESSAGE = "{} requires Flash Attention enabled models."

# Optimistically assume the flash-attention stack is importable; flipped to False
# (together with SUPPORTS_WINDOWING) below if any of these imports fail.
FLASH_ATTENTION = True

try:
    from text_generation_server.models.flash_causal_lm import FlashCausalLM
    from text_generation_server.models.vlm_causal_lm import VlmCausalLM
    from text_generation_server.models.mllama_causal_lm import MllamaCausalLM
    from text_generation_server.models.custom_modeling.flash_deepseek_v2_modeling import (
        FlashDeepseekV2ForCausalLM,
        DeepseekV2Config,
    )
    from text_generation_server.models.custom_modeling.flash_llama_modeling import (
        FlashLlamaForCausalLM,
    )
    from text_generation_server.models.custom_modeling.flash_cohere_modeling import (
        FlashCohereForCausalLM,
    )
    from text_generation_server.models.custom_modeling.flash_gemma_modeling import (
        FlashGemmaForCausalLM,
    )
    from text_generation_server.models.custom_modeling.flash_gemma2_modeling import (
        FlashGemma2ForCausalLM,
    )
    from text_generation_server.models.custom_modeling.flash_dbrx_modeling import (
        FlashDbrxForCausalLM,
        DbrxConfig,
    )
    from text_generation_server.models.custom_modeling.flash_rw_modeling import (
        RWConfig,
        FlashRWForCausalLM,
    )
    from text_generation_server.models.custom_modeling.flash_neox_modeling import (
        FlashGPTNeoXForCausalLM,
    )
    from text_generation_server.models.pali_gemma import (
        PaliGemmaBatch,
    )
    from text_generation_server.models.custom_modeling.flash_pali_gemma_modeling import (
        PaliGemmaForConditionalGeneration,
    )
    from text_generation_server.models.custom_modeling.flash_phi_modeling import (
        FlashPhiForCausalLM,
    )
    from text_generation_server.models.idefics_causal_lm import IdeficsCausalLM
    from text_generation_server.models.mllama_causal_lm import MllamaCausalLMBatch
    from text_generation_server.models.custom_modeling.mllama import (
        MllamaForConditionalGeneration,
    )
    from text_generation_server.models.custom_modeling.llava_next import (
        LlavaNextForConditionalGeneration,
    )
    from text_generation_server.models.custom_modeling.flash_santacoder_modeling import (
        FlashSantacoderForCausalLM,
    )
    from text_generation_server.models.custom_modeling.flash_starcoder2_modeling import (
        FlashStarcoder2ForCausalLM,
    )
    from text_generation_server.models.custom_modeling.flash_qwen2_modeling import (
        Qwen2ForCausalLM,
    )
    from text_generation_server.models.custom_modeling.flash_mistral_modeling import (
        FlashMistralForCausalLM,
    )
    from text_generation_server.models.custom_modeling.flash_mixtral_modeling import (
        FlashMixtralForCausalLM,
    )
    from text_generation_server.models.custom_modeling.flash_gpt2_modeling import (
        FlashGPT2ForCausalLM,
    )
    from text_generation_server.models.custom_modeling.flash_gptj_modeling import (
        FlashGPTJForCausalLM,
    )
    from text_generation_server.models.custom_modeling.idefics2 import (
        Idefics2ForConditionalGeneration,
    )
    from text_generation_server.layers.attention import SUPPORTS_WINDOWING
except ImportError as e:
    log_master(logger.warning, f"Could not import Flash Attention enabled models: {e}")
    # Without the flash stack there is no windowed attention either; the
    # sliding-window check in get_model reads this flag.
    SUPPORTS_WINDOWING = False
    FLASH_ATTENTION = False

if FLASH_ATTENTION:
    # `__all__` must hold names (str), not objects: `from ... import *` raises
    # TypeError on non-string entries.
    __all__.append("FlashCausalLM")
    __all__.append("IdeficsCausalLM")
# Mamba support is optional: record whether its implementation could be imported.
MAMBA_AVAILABLE = True
try:
    from text_generation_server.models.mamba import Mamba
except ImportError as e:
    log_master(logger.warning, f"Could not import Mamba: {e}")
    MAMBA_AVAILABLE = False

if MAMBA_AVAILABLE:
    # Export by name; `__all__` entries must be strings for `import *` to work.
    __all__.append("Mamba")

169

170
class ModelType(enum.Enum):
    """Registry of the model architectures known to ``get_model``.

    Each member's value is a dict with:

    - ``"type"``: the ``model_type`` string that ``get_model`` compares against
      the checkpoint config. Every member's ``"type"`` is also published as a
      module-level constant named after the member (e.g. ``LLAMA == "llama"``).
    - ``"name"``: a human-readable display name.
    - ``"url"``: a reference link on the Hugging Face Hub (model or collection).
    - ``"multimodal"`` (optional): present and ``True`` on some members.

    Member order is the declaration order below and is what iteration yields.
    """

    DEEPSEEK_V2 = {
        "type": "deepseek_v2",
        "name": "Deepseek V2",
        "url": "https://huggingface.co/deepseek-ai/DeepSeek-V2",
    }
    IDEFICS2 = {
        "type": "idefics2",
        "name": "Idefics 2",
        "url": "https://huggingface.co/HuggingFaceM4/idefics2-8b",
        "multimodal": True,
    }
    LLAVA_NEXT = {
        "type": "llava_next",
        "name": "Llava Next (1.6)",
        "url": "https://huggingface.co/llava-hf/llava-v1.6-vicuna-13b-hf",
        "multimodal": True,
    }
    LLAMA = {
        "type": "llama",
        "name": "Llama",
        "url": "https://huggingface.co/collections/meta-llama/llama-31-669fc079a0c406a149a5738f",
    }
    PHI3 = {
        "type": "phi3",
        "name": "Phi 3",
        "url": "https://huggingface.co/microsoft/Phi-3-mini-4k-instruct",
    }
    GEMMA = {
        "type": "gemma",
        "name": "Gemma",
        "url": "https://huggingface.co/google/gemma-7b",
    }
    # NOTE(review): no "multimodal" key here, unlike IDEFICS2 / LLAVA_NEXT /
    # IDEFICS / MLLAMA — confirm this is intentional.
    PALIGEMMA = {
        "type": "paligemma",
        "name": "PaliGemma",
        "url": "https://huggingface.co/google/paligemma-3b-pt-224",
    }
    GEMMA2 = {
        "type": "gemma2",
        "name": "Gemma2",
        "url": "https://huggingface.co/collections/google/gemma-2-release-667d6600fd5220e7b967f315",
    }
    COHERE = {
        "type": "cohere",
        "name": "Cohere",
        "url": "https://huggingface.co/CohereForAI/c4ai-command-r-plus",
    }
    DBRX = {
        "type": "dbrx",
        "name": "Dbrx",
        "url": "https://huggingface.co/databricks/dbrx-instruct",
    }
    # "ssm" is not read from the config: get_model assigns it when the config
    # has no model_type but contains "ssm_cfg".
    MAMBA = {
        "type": "ssm",
        "name": "Mamba",
        "url": "https://huggingface.co/state-spaces/mamba-2.8b-slimpj",
    }
    MISTRAL = {
        "type": "mistral",
        "name": "Mistral",
        "url": "https://huggingface.co/mistralai/Mistral-Nemo-Instruct-2407",
    }
    MIXTRAL = {
        "type": "mixtral",
        "name": "Mixtral",
        "url": "https://huggingface.co/mistralai/Mixtral-8x22B-Instruct-v0.1",
    }
    GPT_BIGCODE = {
        "type": "gpt_bigcode",
        "name": "Gpt Bigcode",
        "url": "https://huggingface.co/bigcode/gpt_bigcode-santacoder",
    }
    PHI = {
        "type": "phi",
        "name": "Phi",
        "url": "https://huggingface.co/microsoft/phi-1_5",
    }
    PHI_MOE = {
        "type": "phimoe",
        "name": "PhiMoe",
        "url": "https://huggingface.co/microsoft/Phi-3.5-MoE-instruct",
    }
    BAICHUAN = {
        "type": "baichuan",
        "name": "Baichuan",
        "url": "https://huggingface.co/baichuan-inc/Baichuan2-7B-Chat",
    }
    FALCON = {
        "type": "falcon",
        "name": "Falcon",
        "url": "https://huggingface.co/tiiuae/falcon-7b-instruct",
    }
    STARCODER2 = {
        "type": "starcoder2",
        "name": "StarCoder 2",
        "url": "https://huggingface.co/bigcode/starcoder2-15b-instruct-v0.1",
    }
    QWEN2 = {
        "type": "qwen2",
        "name": "Qwen 2",
        "url": "https://huggingface.co/collections/Qwen/qwen2-6659360b33528ced941e557f",
    }
    OPT = {
        "type": "opt",
        "name": "Opt",
        "url": "https://huggingface.co/facebook/opt-6.7b",
    }
    T5 = {
        "type": "t5",
        "name": "T5",
        "url": "https://huggingface.co/google/flan-t5-xxl",
    }
    GALACTICA = {
        "type": "galactica",
        "name": "Galactica",
        "url": "https://huggingface.co/facebook/galactica-120b",
    }
    SANTACODER = {
        "type": "santacoder",
        "name": "SantaCoder",
        "url": "https://huggingface.co/bigcode/santacoder",
    }
    BLOOM = {
        "type": "bloom",
        "name": "Bloom",
        "url": "https://huggingface.co/bigscience/bloom-560m",
    }
    MPT = {
        "type": "mpt",
        "name": "Mpt",
        "url": "https://huggingface.co/mosaicml/mpt-7b-instruct",
    }
    GPT2 = {
        "type": "gpt2",
        "name": "Gpt2",
        "url": "https://huggingface.co/openai-community/gpt2",
    }
    GPT_NEOX = {
        "type": "gpt_neox",
        "name": "Gpt Neox",
        "url": "https://huggingface.co/EleutherAI/gpt-neox-20b",
    }
    GPTJ = {
        "type": "gptj",
        "name": "Gptj",
        "url": "https://huggingface.co/EleutherAI/gpt-j-6b",
    }
    IDEFICS = {
        "type": "idefics",
        "name": "Idefics",
        "url": "https://huggingface.co/HuggingFaceM4/idefics-9b",
        "multimodal": True,
    }
    MLLAMA = {
        "type": "mllama",
        "name": "Mllama",
        "url": "https://huggingface.co/meta-llama/Llama-3.2-11B-Vision-Instruct",
        "multimodal": True,
    }


# Publish each ModelType "type" string as a module-level constant named after its
# member (e.g. LLAMA = "llama", MAMBA = "ssm"); get_model compares the config's
# model_type against these names. globals() is used instead of locals(): at module
# scope both return the module namespace, but the locals() docs warn against
# modifying its result, so globals() states the intent explicitly.
__GLOBALS = globals()
for data in ModelType:
    __GLOBALS[data.name] = data.value["type"]


337
def get_model(
338
    model_id: str,
drbh's avatar
drbh committed
339
    lora_adapter_ids: Optional[List[str]],
340
341
342
    revision: Optional[str],
    sharded: bool,
    quantize: Optional[str],
Nicolas Patry's avatar
Nicolas Patry committed
343
    speculate: Optional[int],
344
    dtype: Optional[str],
345
    kv_cache_dtype: Optional[str],
346
    trust_remote_code: bool,
347
    max_input_tokens: int,
348
) -> Model:
349
    global FLASH_ATTENTION
350
351
352
353
354
355
356

    config_dict, _ = PretrainedConfig.get_config_dict(
        model_id, revision=revision, trust_remote_code=trust_remote_code
    )
    model_type = config_dict.get("model_type", None)

    quantization_config = config_dict.get("quantization_config", None)
357
    compression_config = config_dict.get("compression_config", None)
358
359
    if quantization_config is not None and quantize is None:
        method = quantization_config.get("quant_method", None)
360
        config_groups = quantization_config.get("config_groups", None)
361
362
363
        if method in {"gptq", "awq", "exl2"}:
            log_master(logger.info, f"Auto selecting quantization method {method}")
            quantize = method
364
        elif method == "fbgemm_fp8" or method == "fp8":
365
366
            log_master(logger.info, "Auto selecting quantization method fp8")
            quantize = "fp8"
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
        elif config_groups is not None:
            # TODO: at some point we should probably fully parse the compression
            # configuration to know which parameters are compressed.
            for _, group in config_groups.items():
                weights_config = group.get("weights")
                if weights_config is not None:
                    if (
                        weights_config["type"] == "float"
                        and weights_config["num_bits"] == 8
                    ):
                        log_master(
                            logger.info, "Auto selecting quantization method fp8"
                        )
                        quantize = "fp8"
                        break
382
383
        else:
            log_master(logger.warning, f"Unknown quantization method {method}")
384
    elif compression_config is not None:
385
        # `compression_config` renamed to `quantization_config`; support retained for backward compatibility.
386
387
388
389
390
391
392
393
394
395
396
397
398
399
        config_groups = compression_config.get("config_groups")
        if config_groups is not None:
            for _, group in config_groups.items():
                weights_config = group.get("weights")
                if weights_config is not None:
                    if (
                        weights_config["type"] == "float"
                        and weights_config["num_bits"] == 8
                    ):
                        log_master(
                            logger.info, "Auto selecting quantization method fp8"
                        )
                        quantize = "fp8"
                        break
400

401
    if dtype is None:
402
        if quantize in ["awq", "exl2", "gptq", "marlin"]:
403
404
            # These quantizers only work with float16 params.
            dtype = torch.float16
405
        elif quantize == "fp8":
406
            from text_generation_server.layers.fp8 import FBGEMM_DYN_AVAILABLE
407

408
            if FBGEMM_DYN_AVAILABLE:
409
410
                # fbgemm kernels are fp8xfp8->bf16
                dtype = torch.bfloat16
411
412
413
414
        else:
            # Keep it as default for now and let
            # every model resolve their own default dtype.
            dtype = None
415
416
417
418
419
420
421
    elif dtype == "float16":
        dtype = torch.float16
    elif dtype == "bfloat16":
        dtype = torch.bfloat16
    else:
        raise RuntimeError(f"Unknown dtype {dtype}")

422
423
    if kv_cache_dtype is None:
        kv_cache_dtype = dtype
424
425
    elif kv_cache_dtype == "fp8_e4m3fn":
        kv_cache_dtype = torch.float8_e4m3fn
426
427
428
429
430
    elif kv_cache_dtype == "fp8_e5m2":
        kv_cache_dtype = torch.float8_e5m2
    else:
        raise RuntimeError(f"Unknown kv_cache_dtype: {kv_cache_dtype}")

Nicolas Patry's avatar
Nicolas Patry committed
431
432
433
434
435
    if speculate is not None:
        set_speculate(speculate)
    else:
        set_speculate(0)

Nicolas Patry's avatar
Nicolas Patry committed
436
    speculator = None
Nicolas Patry's avatar
Nicolas Patry committed
437
    if "medusa_num_heads" in config_dict:
438
439
        medusa_model_id = model_id
        medusa_revision = revision
Nicolas Patry's avatar
Nicolas Patry committed
440
441
442
443
444
        model_id = config_dict["base_model_name_or_path"]
        revision = "main"
        speculate_medusa = config_dict["medusa_num_heads"]
        if speculate is not None:
            if speculate > speculate_medusa:
OlivierDehaene's avatar
OlivierDehaene committed
445
                raise RuntimeError(
OlivierDehaene's avatar
OlivierDehaene committed
446
                    f"Speculate is set to `{speculate}` but this medusa models only has `{speculate_medusa}` heads, please make them match"
OlivierDehaene's avatar
OlivierDehaene committed
447
                )
Nicolas Patry's avatar
Nicolas Patry committed
448
449
450
451
452
453
454
455
            else:
                set_speculate(speculate)
        else:
            set_speculate(speculate_medusa)

        config_dict, _ = PretrainedConfig.get_config_dict(
            model_id, revision=revision, trust_remote_code=trust_remote_code
        )
Nicolas Patry's avatar
Nicolas Patry committed
456
457
        # Reload model type from parent.
        model_type = config_dict.get("model_type", None)
458
459
460
461
462
463
464
465
466
467
        is_local = Path(medusa_model_id).exists()
        if not is_local:
            medusa_config = hf_hub_download(
                medusa_model_id, revision=medusa_revision, filename="config.json"
            )
            hf_hub_download(
                medusa_model_id,
                revision=medusa_revision,
                filename="medusa_lm_head.safetensors",
            )
Nicolas Patry's avatar
Nicolas Patry committed
468
469
470
471
            speculator = {
                "path": Path(medusa_config).parent,
                "model_paths": ["medusa_lm_head.safetensors"],
            }
472
        else:
Nicolas Patry's avatar
Nicolas Patry committed
473
474
475
476
            speculator = {
                "path": Path(medusa_model_id),
                "model_paths": ["medusa_lm_head.safetensors"],
            }
477

Nicolas Patry's avatar
Nicolas Patry committed
478
        method = "medusa"
Nicolas Patry's avatar
Nicolas Patry committed
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
    elif model_type == "mlp_speculator":
        mlp_model_id = model_id
        mlp_revision = revision
        model_id = config_dict["base_model_name_or_path"]
        revision = "main"
        speculate_mlp = config_dict["n_predict"]
        if speculate is not None:
            if speculate > speculate_mlp:
                raise RuntimeError(
                    f"Speculate is set to `{speculate}` but this mlp_speculator models only has `{speculate_mlp}` heads, please make them match"
                )
            else:
                set_speculate(speculate)
        else:
            set_speculate(speculate_mlp)

        config_dict, _ = PretrainedConfig.get_config_dict(
            model_id, revision=revision, trust_remote_code=trust_remote_code
        )
        # Reload model type from parent.
        model_type = config_dict.get("model_type", None)
        is_local = Path(mlp_model_id).exists()
        extension = ".safetensors"
        if not is_local:
            mlp_speculator_config = hf_hub_download(
                mlp_model_id, revision=mlp_revision, filename="config.json"
            )
            api = HfApi()
            info = api.model_info(mlp_model_id, revision=mlp_revision)
            filenames = [
                s.rfilename
                for s in info.siblings
                if s.rfilename.endswith(extension)
                and len(s.rfilename.split("/")) == 1
                and "arguments" not in s.rfilename
                and "args" not in s.rfilename
                and "training" not in s.rfilename
            ]
            for filename in filenames:
                hf_hub_download(
                    mlp_model_id,
                    revision=mlp_revision,
                    filename=filename,
                )
523
524
525
526
527
            speculator_dir_path = Path(mlp_speculator_config).parent
            # if these are downloaded, they get converted to safetensors
            filenames.extend(
                [p for p in os.listdir(speculator_dir_path) if p.endswith(extension)]
            )
Nicolas Patry's avatar
Nicolas Patry committed
528
529
530
531
532
533
534
535
536
            speculator = {
                "path": Path(mlp_speculator_config).parent,
                "model_paths": filenames,
            }
        else:
            speculator = Path(mlp_model_id)
            filenames = [p for p in os.listdir(speculator) if p.endswith(extension)]
            speculator = {"path": speculator, "model_paths": filenames}
        method = "mlp_speculator"
Nicolas Patry's avatar
Nicolas Patry committed
537
538
539
540
541
    else:
        method = "n-gram"

    speculate = get_speculate()
    if speculate > 0:
542
543
544
        log_master(
            logger.info, f"Using speculation {method} with {speculate} input ids."
        )
Nicolas Patry's avatar
Nicolas Patry committed
545

drbh's avatar
drbh committed
546
547
548
549
550
551
552
553
554
555
    if model_type is None:
        # TODO: fix how we determine model type for Mamba
        if "ssm_cfg" in config_dict:
            # *only happens in Mamba case
            model_type = "ssm"
        else:
            raise RuntimeError(
                f"Could not determine model type for {model_id} revision {revision}"
            )

556
557
558
559
    if quantize == "exl2" and sharded:
        raise RuntimeError(
            "Sharding is currently not supported with `exl2` quantization"
        )
drbh's avatar
drbh committed
560
561
562
563
564
565

    sliding_window = (
        config_dict.get("sliding_window")
        if config_dict.get("sliding_window") is not None
        else -1
    )
566

567
568
569
    use_sliding_window = sliding_window is not None and sliding_window != -1
    needs_sliding_window = (
        max_input_tokens is not None and max_input_tokens > sliding_window
570
    )
571
572
573
574
    if use_sliding_window and needs_sliding_window and not SUPPORTS_WINDOWING:
        raise ValueError(
            f"The backend {SYSTEM} does not support sliding window attention that is used by the model type {model_type}. To use this model nonetheless with the {SYSTEM} backend, please launch TGI with the argument `--max-input-tokens` smaller than sliding_window={sliding_window} (got here max_input_tokens={max_input_tokens})."
        )
575

576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
    if model_type == DEEPSEEK_V2:
        if FLASH_ATTENTION:
            head_size = max(
                config_dict.get("qk_nope_dim", 128)
                + config_dict.get("qk_rope_dim", 64),
                config_dict.get("v_head_dim", 128),
            )
            return FlashCausalLM(
                model_id=model_id,
                model_class=FlashDeepseekV2ForCausalLM,
                revision=revision,
                quantize=quantize,
                speculator=speculator,
                default_dtype=torch.bfloat16,
                dtype=dtype,
591
                kv_cache_dtype=kv_cache_dtype,
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
                trust_remote_code=trust_remote_code,
                lora_adapter_ids=lora_adapter_ids,
                config_class=DeepseekV2Config,
                head_size=head_size,
            )
        elif sharded:
            raise NotImplementedError(
                FLASH_ATT_ERROR_MESSAGE.format("Sharded Deepseek V2")
            )
        else:
            return CausalLM.fallback(
                model_id,
                revision,
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )
    elif model_type == MAMBA:
drbh's avatar
drbh committed
611
612
613
614
        return Mamba(
            model_id,
            revision,
            quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
615
            speculator=speculator,
drbh's avatar
drbh committed
616
617
618
            dtype=dtype,
            trust_remote_code=trust_remote_code,
        )
619

OlivierDehaene's avatar
OlivierDehaene committed
620
    if model_id.startswith("facebook/galactica"):
621
622
623
624
625
        return CausalLM(
            model_id=model_id,
            # Yes galactica is just an OPT model.
            model_class=OPTForCausalLM,
            revision=revision,
OlivierDehaene's avatar
OlivierDehaene committed
626
            quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
627
            speculator=speculator,
OlivierDehaene's avatar
OlivierDehaene committed
628
629
            dtype=dtype,
            trust_remote_code=trust_remote_code,
630
            batch_class=GalacticaCausalLMBatch,
OlivierDehaene's avatar
OlivierDehaene committed
631
632
        )

633
    if (
634
635
        model_type == GPT_BIGCODE
        or model_type == GPT2
636
637
        and model_id.startswith("bigcode/")
    ):
638
        if FLASH_ATTENTION:
639
640
641
642
            return FlashCausalLM(
                model_id=model_id,
                model_class=FlashSantacoderForCausalLM,
                revision=revision,
643
                quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
644
                speculator=speculator,
645
                dtype=dtype,
646
                kv_cache_dtype=kv_cache_dtype,
647
                trust_remote_code=trust_remote_code,
648
649
650
                lora_adapter_ids=lora_adapter_ids,
                aliases={"transformer.wte.weight": ["lm_head.weight"]},
                num_kv_heads=1,
651
            )
652
653
654
655
        elif sharded:
            raise NotImplementedError(
                FLASH_ATT_ERROR_MESSAGE.format("Sharded Santacoder")
            )
656
        else:
657
658
659
            return CausalLM.fallback(
                model_id=model_id,
                revision=revision,
660
                quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
661
                speculator=speculator,
662
                dtype=dtype,
663
664
                trust_remote_code=trust_remote_code,
            )
665

666
    if model_type == BLOOM:
667
668
669
670
        return CausalLM(
            model_id=model_id,
            model_class=BloomForCausalLM,
            revision=revision,
671
            quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
672
            speculator=speculator,
673
674
            dtype=dtype,
            trust_remote_code=trust_remote_code,
675
            batch_class=BloomCausalLMBatch,
676
        )
677
    elif model_type == MPT:
678
679
680
681
        return CausalLM(
            model_id=model_id,
            model_class=MPTForCausalLM,
            revision=revision,
OlivierDehaene's avatar
OlivierDehaene committed
682
            quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
683
            speculator=speculator,
OlivierDehaene's avatar
OlivierDehaene committed
684
685
            dtype=dtype,
            trust_remote_code=trust_remote_code,
686
            batch_class=CausalLMBatchKeysLast,
687
        )
688
    elif model_type == GPT2:
689
        if FLASH_ATTENTION:
690
            try:
691
692
693
694
                return FlashCausalLM(
                    model_id=model_id,
                    model_class=FlashGPT2ForCausalLM,
                    revision=revision,
695
696
697
                    quantize=quantize,
                    speculator=speculator,
                    dtype=dtype,
698
                    kv_cache_dtype=kv_cache_dtype,
699
                    trust_remote_code=trust_remote_code,
700
                    lora_adapter_ids=lora_adapter_ids,
701
702
703
                )
            except RuntimeError as e:
                # Lots of legacy models with various weight names.
704
                log_master(logger.warning, f"Couldn't load flash gpt2 variant: {e}")
705
                return CausalLM.fallback(
706
707
708
709
710
711
712
                    model_id,
                    revision,
                    quantize=quantize,
                    speculator=speculator,
                    dtype=dtype,
                    trust_remote_code=trust_remote_code,
                )
713
714
715
        elif sharded:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded GPT-2"))
        else:
716
            return CausalLM.fallback(
717
718
719
720
721
722
723
                model_id,
                revision,
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )
724
725
726
727
728
729
730
731
732
733
    elif model_type == GPTJ:
        if FLASH_ATTENTION:
            try:
                return FlashCausalLM(
                    model_id=model_id,
                    model_class=FlashGPTJForCausalLM,
                    revision=revision,
                    quantize=quantize,
                    speculator=speculator,
                    dtype=dtype,
734
                    kv_cache_dtype=kv_cache_dtype,
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
                    trust_remote_code=trust_remote_code,
                    lora_adapter_ids=lora_adapter_ids,
                )
            except RuntimeError as e:
                # Lots of legacy models with various weight names.
                log_master(logger.warning, f"Couldn't load flash gptj variant: {e}")
                return CausalLM.fallback(
                    model_id,
                    revision,
                    quantize=quantize,
                    speculator=speculator,
                    dtype=dtype,
                    trust_remote_code=trust_remote_code,
                )
        elif sharded:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded GPT-J"))
        else:
            return CausalLM.fallback(
                model_id,
                revision,
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )
760
    elif model_type == GPT_NEOX:
761
        if FLASH_ATTENTION:
762
763
764
765
            from text_generation_server.models.custom_modeling.flash_neox_modeling import (
                GPTNeoXConfig,
            )

766
767
768
769
            return FlashCausalLM(
                model_id=model_id,
                model_class=FlashGPTNeoXForCausalLM,
                revision=revision,
770
                quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
771
                speculator=speculator,
772
                dtype=dtype,
773
                kv_cache_dtype=kv_cache_dtype,
774
                trust_remote_code=trust_remote_code,
775
                lora_adapter_ids=lora_adapter_ids,
776
                config_class=GPTNeoXConfig,
777
778
            )
        elif sharded:
779
780
781
782
            return CausalLM(
                model_id=model_id,
                model_class=GPTNeoxForCausalLM,
                revision=revision,
783
                quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
784
                speculator=speculator,
785
                dtype=dtype,
786
787
                trust_remote_code=trust_remote_code,
            )
788
        else:
789
            return CausalLM.fallback(
790
791
792
                model_id,
                revision,
                quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
793
                speculator=speculator,
794
                dtype=dtype,
795
796
                trust_remote_code=trust_remote_code,
            )
OlivierDehaene's avatar
OlivierDehaene committed
797

798
    elif model_type == PHI:
drbh's avatar
drbh committed
799
        if FLASH_ATTENTION:
800
801
802
803
            return FlashCausalLM(
                model_id=model_id,
                model_class=FlashPhiForCausalLM,
                revision=revision,
drbh's avatar
drbh committed
804
                quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
805
                speculator=speculator,
drbh's avatar
drbh committed
806
                dtype=dtype,
807
                kv_cache_dtype=kv_cache_dtype,
drbh's avatar
drbh committed
808
                trust_remote_code=trust_remote_code,
809
                lora_adapter_ids=lora_adapter_ids,
drbh's avatar
drbh committed
810
811
            )
        else:
812
            return CausalLM.fallback(
drbh's avatar
drbh committed
813
814
815
                model_id,
                revision,
                quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
816
                speculator=speculator,
drbh's avatar
drbh committed
817
818
819
820
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )

drbh's avatar
drbh committed
821
822
823
824
825
826
827
828
829
830
    elif model_type == PHI_MOE:
        if FLASH_ATTENTION:
            return FlashCausalLM(
                model_id=model_id,
                model_class=FlashLlamaForCausalLM,
                config_class=PhiMoEConfig,
                revision=revision,
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
831
                kv_cache_dtype=kv_cache_dtype,
drbh's avatar
drbh committed
832
833
834
835
836
837
838
839
840
841
842
843
844
                trust_remote_code=trust_remote_code,
                lora_adapter_ids=lora_adapter_ids,
            )
        else:
            return CausalLM.fallback(
                model_id,
                revision,
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )

drbh's avatar
drbh committed
845
846
    elif model_type == "phi-msft":
        if FLASH_ATTENTION:
OlivierDehaene's avatar
OlivierDehaene committed
847
848
849
            raise NotImplementedError(
                "Legacy phi-msft is not supported with Flash Attention"
            )
drbh's avatar
drbh committed
850
        else:
851
852
853
854
855
            return CausalLM(
                model_id=model_id,
                model_class=PhiForCausalLM,
                config_class=PhiConfig,
                revision=revision,
drbh's avatar
drbh committed
856
                quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
857
                speculator=speculator,
drbh's avatar
drbh committed
858
859
860
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )
861

862
    elif model_type == LLAMA or model_type == BAICHUAN or model_type == PHI3:
863
        if FLASH_ATTENTION:
864
865
866
867
            return FlashCausalLM(
                model_id=model_id,
                model_class=FlashLlamaForCausalLM,
                revision=revision,
868
                quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
869
                speculator=speculator,
870
                dtype=dtype,
871
                kv_cache_dtype=kv_cache_dtype,
872
                trust_remote_code=trust_remote_code,
drbh's avatar
drbh committed
873
                lora_adapter_ids=lora_adapter_ids,
874
            )
875
876
        elif sharded:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Llama"))
877
        else:
878
            return CausalLM.fallback(
879
880
881
                model_id,
                revision,
                quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
882
                speculator=speculator,
883
                dtype=dtype,
884
885
                trust_remote_code=trust_remote_code,
            )
886
    if model_type == GEMMA:
887
        if FLASH_ATTENTION:
888
889
890
891
            return FlashCausalLM(
                model_id=model_id,
                model_class=FlashGemmaForCausalLM,
                revision=revision,
892
                quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
893
                speculator=speculator,
894
                dtype=dtype,
895
                kv_cache_dtype=kv_cache_dtype,
896
897
                # Works better for these models
                default_dtype=torch.bfloat16,
898
                trust_remote_code=trust_remote_code,
899
                lora_adapter_ids=lora_adapter_ids,
900
901
            )
        elif sharded:
OlivierDehaene's avatar
OlivierDehaene committed
902
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Gemma"))
903
        else:
904
            return CausalLM.fallback(
905
906
907
                model_id,
                revision,
                quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
908
                speculator=speculator,
909
910
911
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )
Nicolas Patry's avatar
Nicolas Patry committed
912
913
    elif model_type == GEMMA2:
        if FLASH_ATTENTION:
914
915
916
917
            return FlashCausalLM(
                model_id=model_id,
                model_class=FlashGemma2ForCausalLM,
                revision=revision,
Nicolas Patry's avatar
Nicolas Patry committed
918
919
920
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
921
                kv_cache_dtype=kv_cache_dtype,
922
923
                # Works better for these models
                default_dtype=torch.bfloat16,
Nicolas Patry's avatar
Nicolas Patry committed
924
                trust_remote_code=trust_remote_code,
925
                lora_adapter_ids=lora_adapter_ids,
Nicolas Patry's avatar
Nicolas Patry committed
926
927
928
929
            )
        elif sharded:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Gemma2"))
        else:
930
            return CausalLM.fallback(
Nicolas Patry's avatar
Nicolas Patry committed
931
932
933
934
935
936
937
                model_id,
                revision,
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )
938

939
    if model_type == COHERE:
OlivierDehaene's avatar
OlivierDehaene committed
940
        if FLASH_ATTENTION:
941
942
943
944
            return FlashCausalLM(
                model_id=model_id,
                model_class=FlashCohereForCausalLM,
                revision=revision,
OlivierDehaene's avatar
OlivierDehaene committed
945
                quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
946
                speculator=speculator,
OlivierDehaene's avatar
OlivierDehaene committed
947
                dtype=dtype,
948
                kv_cache_dtype=kv_cache_dtype,
OlivierDehaene's avatar
OlivierDehaene committed
949
                trust_remote_code=trust_remote_code,
950
                lora_adapter_ids=lora_adapter_ids,
OlivierDehaene's avatar
OlivierDehaene committed
951
952
953
954
            )
        elif sharded:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Cohere"))
        else:
955
            return CausalLM.fallback(
OlivierDehaene's avatar
OlivierDehaene committed
956
957
958
                model_id,
                revision,
                quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
959
                speculator=speculator,
OlivierDehaene's avatar
OlivierDehaene committed
960
961
962
963
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )

964
    if model_type == DBRX:
965
        if FLASH_ATTENTION:
966
967
968
969
            return FlashCausalLM(
                model_id=model_id,
                model_class=FlashDbrxForCausalLM,
                revision=revision,
970
                quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
971
                speculator=speculator,
972
                dtype=dtype,
973
                kv_cache_dtype=kv_cache_dtype,
974
975
                # Dbrx works better in bfloat16.
                default_dtype=torch.bfloat16,
976
                trust_remote_code=trust_remote_code,
977
978
                lora_adapter_ids=lora_adapter_ids,
                config_class=DbrxConfig,
979
980
981
982
            )
        elif sharded:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded DBRX"))
        else:
983
            return CausalLM.fallback(
984
985
986
                model_id,
                revision,
                quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
987
                speculator=speculator,
988
989
990
991
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )

992
    if model_type in ["RefinedWeb", "RefinedWebModel", FALCON]:
993
994
        if sharded:
            if FLASH_ATTENTION:
995
                if config_dict.get("alibi", False):
996
                    raise NotImplementedError("sharded is not supported for this model")
997
998
999
1000
                return FlashCausalLM(
                    model_id=model_id,
                    model_class=FlashRWForCausalLM,
                    revision=revision,
1001
                    quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
1002
                    speculator=speculator,
1003
                    dtype=dtype,
1004
                    kv_cache_dtype=kv_cache_dtype,
1005
1006
1007
1008
                    aliases={
                        "lm_head.weight": ["transformer.word_embeddings.weight"],
                        "transformer.word_embeddings.weight": ["lm_head.weight"],
                    },
1009
                    trust_remote_code=trust_remote_code,
1010
1011
                    lora_adapter_ids=lora_adapter_ids,
                    config_class=RWConfig,
1012
                )
1013
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Falcon"))
1014
        else:
1015
            if FLASH_ATTENTION and not config_dict.get("alibi", False):
1016
1017
1018
1019
                return FlashCausalLM(
                    model_id=model_id,
                    model_class=FlashRWForCausalLM,
                    revision=revision,
1020
                    quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
1021
                    speculator=speculator,
1022
                    dtype=dtype,
1023
                    kv_cache_dtype=kv_cache_dtype,
1024
1025
1026
1027
                    aliases={
                        "lm_head.weight": ["transformer.word_embeddings.weight"],
                        "transformer.word_embeddings.weight": ["lm_head.weight"],
                    },
1028
                    trust_remote_code=trust_remote_code,
1029
1030
                    lora_adapter_ids=lora_adapter_ids,
                    config_class=RWConfig,
1031
1032
                )
            else:
1033
                return CausalLM.fallback(
1034
1035
1036
                    model_id,
                    revision,
                    quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
1037
                    speculator=speculator,
1038
                    dtype=dtype,
1039
1040
1041
                    trust_remote_code=trust_remote_code,
                )

1042
    if model_type == MISTRAL:
1043
        if FLASH_ATTENTION:
1044
1045
1046
1047
            return FlashCausalLM(
                model_id=model_id,
                model_class=FlashMistralForCausalLM,
                revision=revision,
1048
                quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
1049
                speculator=speculator,
1050
                dtype=dtype,
1051
                kv_cache_dtype=kv_cache_dtype,
1052
                trust_remote_code=trust_remote_code,
1053
                lora_adapter_ids=lora_adapter_ids,
1054
            )
OlivierDehaene's avatar
OlivierDehaene committed
1055
1056
1057
        elif sharded:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Mistral"))
        else:
1058
            return CausalLM.fallback(
OlivierDehaene's avatar
OlivierDehaene committed
1059
1060
1061
                model_id,
                revision,
                quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
1062
                speculator=speculator,
OlivierDehaene's avatar
OlivierDehaene committed
1063
1064
1065
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )
OlivierDehaene's avatar
OlivierDehaene committed
1066

1067
    if model_type == MIXTRAL:
1068
        if FLASH_ATTENTION:
1069
1070
1071
1072
            return FlashCausalLM(
                model_id=model_id,
                model_class=FlashMixtralForCausalLM,
                revision=revision,
OlivierDehaene's avatar
OlivierDehaene committed
1073
                quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
1074
                speculator=speculator,
OlivierDehaene's avatar
OlivierDehaene committed
1075
                dtype=dtype,
1076
                kv_cache_dtype=kv_cache_dtype,
OlivierDehaene's avatar
OlivierDehaene committed
1077
                trust_remote_code=trust_remote_code,
1078
                lora_adapter_ids=lora_adapter_ids,
OlivierDehaene's avatar
OlivierDehaene committed
1079
            )
OlivierDehaene's avatar
OlivierDehaene committed
1080
1081
1082
        elif sharded:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Mixtral"))
        else:
1083
            return CausalLM.fallback(
OlivierDehaene's avatar
OlivierDehaene committed
1084
1085
1086
                model_id,
                revision,
                quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
1087
                speculator=speculator,
OlivierDehaene's avatar
OlivierDehaene committed
1088
1089
1090
1091
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )

1092
    if model_type == STARCODER2:
1093
        if FLASH_ATTENTION:
1094
1095
1096
1097
            return FlashCausalLM(
                model_id=model_id,
                model_class=FlashStarcoder2ForCausalLM,
                revision=revision,
OlivierDehaene's avatar
OlivierDehaene committed
1098
                quantize=quantize,
1099
                speculator=speculator,
OlivierDehaene's avatar
OlivierDehaene committed
1100
                dtype=dtype,
1101
                kv_cache_dtype=kv_cache_dtype,
OlivierDehaene's avatar
OlivierDehaene committed
1102
                trust_remote_code=trust_remote_code,
1103
                lora_adapter_ids=lora_adapter_ids,
OlivierDehaene's avatar
OlivierDehaene committed
1104
1105
1106
1107
1108
1109
            )
        elif sharded:
            raise NotImplementedError(
                FLASH_ATT_ERROR_MESSAGE.format("Sharded Starcoder2")
            )
        else:
1110
            return CausalLM.fallback(
OlivierDehaene's avatar
OlivierDehaene committed
1111
1112
1113
                model_id,
                revision,
                quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
1114
                speculator=speculator,
OlivierDehaene's avatar
OlivierDehaene committed
1115
1116
1117
1118
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )

1119
    if model_type == QWEN2:
1120
        if FLASH_ATTENTION:
1121
1122
1123
1124
            return FlashCausalLM(
                model_id=model_id,
                model_class=Qwen2ForCausalLM,
                revision=revision,
OlivierDehaene's avatar
OlivierDehaene committed
1125
                quantize=quantize,
1126
                speculator=speculator,
OlivierDehaene's avatar
OlivierDehaene committed
1127
                dtype=dtype,
1128
                kv_cache_dtype=kv_cache_dtype,
OlivierDehaene's avatar
OlivierDehaene committed
1129
                trust_remote_code=trust_remote_code,
1130
                lora_adapter_ids=lora_adapter_ids,
OlivierDehaene's avatar
OlivierDehaene committed
1131
1132
1133
1134
            )
        elif sharded:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Qwen2"))
        else:
1135
            return CausalLM.fallback(
OlivierDehaene's avatar
OlivierDehaene committed
1136
1137
1138
                model_id,
                revision,
                quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
1139
                speculator=speculator,
OlivierDehaene's avatar
OlivierDehaene committed
1140
                dtype=dtype,
OlivierDehaene's avatar
OlivierDehaene committed
1141
1142
                trust_remote_code=trust_remote_code,
            )
1143

1144
    if model_type == OPT:
1145
1146
1147
1148
        return CausalLM(
            model_id=model_id,
            model_class=OPTForCausalLM,
            revision=revision,
1149
            quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
1150
            speculator=speculator,
1151
1152
            dtype=dtype,
            trust_remote_code=trust_remote_code,
1153
        )
1154

1155
    if model_type == T5:
1156
1157
1158
1159
        return Seq2SeqLM(
            model_id=model_id,
            model_class=T5ForConditionalGeneration,
            revision=revision,
1160
            quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
1161
            speculator=speculator,
1162
            dtype=dtype,
1163
            trust_remote_code=trust_remote_code,
1164
1165
1166
1167
1168
1169
            aliases={
                "shared.weight": [
                    "encoder.embed_tokens.weight",
                    "decoder.embed_tokens.weight",
                ]
            },
1170
        )
1171
    if model_type == IDEFICS:
1172
        if FLASH_ATTENTION:
Nicolas Patry's avatar
Nicolas Patry committed
1173
            return IdeficsCausalLM(
OlivierDehaene's avatar
OlivierDehaene committed
1174
1175
1176
                model_id,
                revision,
                quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
1177
                speculator=speculator,
OlivierDehaene's avatar
OlivierDehaene committed
1178
1179
1180
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )
1181
1182
        else:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Idefics"))
Nicolas Patry's avatar
Nicolas Patry committed
1183
1184
1185
1186
1187
1188
1189
1190
1191
1192
1193
1194
1195
1196
1197
1198
    if model_type == MLLAMA:
        if FLASH_ATTENTION:
            return MllamaCausalLM(
                model_id=model_id,
                model_class=MllamaForConditionalGeneration,
                batch_class=MllamaCausalLMBatch,
                revision=revision,
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
                default_dtype=torch.bfloat16,
                trust_remote_code=trust_remote_code,
                lora_adapter_ids=lora_adapter_ids,
            )
        else:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Mllama"))
1199
    if model_type == IDEFICS2:
Nicolas Patry's avatar
Nicolas Patry committed
1200
        if FLASH_ATTENTION:
1201
1202
1203
1204
            return VlmCausalLM(
                model_id=model_id,
                model_class=Idefics2ForConditionalGeneration,
                revision=revision,
Nicolas Patry's avatar
Nicolas Patry committed
1205
                quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
1206
                speculator=speculator,
Nicolas Patry's avatar
Nicolas Patry committed
1207
                dtype=dtype,
1208
                kv_cache_dtype=kv_cache_dtype,
Nicolas Patry's avatar
Nicolas Patry committed
1209
                trust_remote_code=trust_remote_code,
1210
1211
1212
1213
                lora_adapter_ids=lora_adapter_ids,
                # XXX: Extremely important to cap resolution in order to limit
                # VRAM usage.
                processor_kwargs={"size": {"longest_edge": 448, "shortest_edge": 378}},
Nicolas Patry's avatar
Nicolas Patry committed
1214
1215
1216
            )
        else:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Idefics"))
1217
    if model_type == PALIGEMMA:
drbh's avatar
drbh committed
1218
        if FLASH_ATTENTION:
1219
1220
1221
1222
            return VlmCausalLM(
                model_id=model_id,
                model_class=PaliGemmaForConditionalGeneration,
                revision=revision,
drbh's avatar
drbh committed
1223
1224
1225
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
1226
                kv_cache_dtype=kv_cache_dtype,
1227
1228
                # Works better for these models
                default_dtype=torch.bfloat16,
drbh's avatar
drbh committed
1229
                trust_remote_code=trust_remote_code,
1230
1231
                lora_adapter_ids=lora_adapter_ids,
                batch_class=PaliGemmaBatch,
drbh's avatar
drbh committed
1232
1233
1234
            )
        else:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Idefics"))
1235

1236
    if model_type == LLAVA_NEXT:
1237
        if FLASH_ATTENTION:
1238
1239
1240
1241
            return VlmCausalLM(
                model_class=LlavaNextForConditionalGeneration,
                model_id=model_id,
                revision=revision,
1242
                quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
1243
                speculator=speculator,
1244
                dtype=dtype,
1245
                kv_cache_dtype=kv_cache_dtype,
1246
1247
1248
1249
1250
                trust_remote_code=trust_remote_code,
            )
        else:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("LlavaNext"))

1251
    if sharded:
1252
        raise NotImplementedError("sharded is not supported for AutoModel")
1253
    if quantize == "gptq":
1254
        raise NotImplementedError(
1255
1256
            "gptq quantization is not supported for AutoModel, you can try to quantize it with `text-generation-server quantize ORIGINAL_MODEL_ID NEW_MODEL_ID`"
        )
1257
    if quantize == "awq":
1258
        raise NotImplementedError("awq quantization is not supported for AutoModel")
Nicolas Patry's avatar
Nicolas Patry committed
1259
    elif (quantize == "bitsandbytes-fp4") or (quantize == "bitsandbytes-nf4"):
1260
        raise NotImplementedError("4bit quantization is not supported for AutoModel")
OlivierDehaene's avatar
OlivierDehaene committed
1261
    elif quantize == "eetq":
1262
        raise NotImplementedError("Eetq quantization is not supported for AutoModel")
1263
1264
    elif quantize == "exl2":
        raise NotImplementedError("exl2 quantization is not supported for AutoModel")
1265
    if model_type in modeling_auto.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES:
1266
        return CausalLM.fallback(
1267
1268
1269
            model_id,
            revision,
            quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
1270
            speculator=speculator,
1271
1272
            dtype=dtype,
            trust_remote_code=trust_remote_code,
1273
        )
1274
    if model_type in modeling_auto.MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES:
1275
        return Seq2SeqLM.fallback(
1276
1277
1278
            model_id,
            revision,
            quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
1279
            speculator=speculator,
1280
1281
            dtype=dtype,
            trust_remote_code=trust_remote_code,
1282
1283
        )

1284
    auto_map = config_dict.get("auto_map", None)
1285
1286
    if trust_remote_code and auto_map is not None:
        if "AutoModelForCausalLM" in auto_map.keys():
1287
            return CausalLM.fallback(
1288
1289
1290
                model_id,
                revision,
                quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
1291
                speculator=speculator,
1292
                dtype=dtype,
1293
1294
                trust_remote_code=trust_remote_code,
            )
1295
        if "AutoModelForSeq2SeqLM" in auto_map.keys():
1296
            return Seq2SeqLM.fallback(
1297
1298
1299
                model_id,
                revision,
                quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
1300
                speculator=speculator,
1301
                dtype=dtype,
1302
1303
                trust_remote_code=trust_remote_code,
            )
1304
1305

    raise ValueError(f"Unsupported model type {model_type}")
1306
1307
1308
1309
1310
1311
1312
1313
1314
1315
1316
1317


# get_model_with_lora_adapters wraps the internal get_model function and adds support for loading adapters
# this provides a post model loading hook to load adapters into the model after the model has been loaded
def get_model_with_lora_adapters(
    model_id: str,
    lora_adapters: Optional[List[AdapterInfo]],
    revision: Optional[str],
    sharded: bool,
    quantize: Optional[str],
    speculate: Optional[int],
    dtype: Optional[str],
    kv_cache_dtype: Optional[str],
    trust_remote_code: bool,
    max_input_tokens: int,
    adapter_to_index: Dict[str, int],
):
    """Load a model with ``get_model`` and then load LoRA adapters into it.

    Args:
        model_id: Base model identifier, forwarded to ``get_model``.
        lora_adapters: Adapters to load after the base model is built.
            ``None`` is treated the same as an empty list.
        revision, sharded, quantize, speculate, dtype, kv_cache_dtype,
        trust_remote_code, max_input_tokens: Forwarded unchanged to
            ``get_model``.
        adapter_to_index: Updated in place with ``adapter.id -> index`` for
            every adapter loaded. Indices start at 1 (presumably index 0 is
            reserved for the base model -- TODO confirm).

    Returns:
        The model returned by ``get_model``, with each adapter's weights
        registered in ``model.layer_to_adapter_weights``, its tokenizer (if
        any) registered in ``model.tokenizers``, and its index recorded in
        ``model.loaded_adapters``.
    """
    # The parameter is Optional; treat None like "no adapters" instead of
    # crashing on iteration / len().
    lora_adapters = lora_adapters or []
    lora_adapter_ids = [adapter.id for adapter in lora_adapters]
    model = get_model(
        model_id,
        lora_adapter_ids,
        revision,
        sharded,
        quantize,
        speculate,
        dtype,
        kv_cache_dtype,
        trust_remote_code,
        max_input_tokens,
    )

    if not lora_adapters:
        return model

    target_to_layer = build_layer_weight_lookup(model.model)

    # Layer types that may carry LoRA weights; identical for every adapter,
    # so built once outside the loop.
    adapter_layers = (
        "q_proj",
        "k_proj",
        "v_proj",
        "o_proj",
        "gate_proj",
        "up_proj",
        "down_proj",
        "qkv_proj",
    )

    for index, adapter in enumerate(lora_adapters):
        # The AdapterParameters object allows for merging multiple adapters into a single adapter.
        # At the moment, we only support loading a single adapter into the model, but we keep the
        # AdapterParameters object for easier extension in the future.
        adapter_parameters = AdapterParameters(
            adapter_info=[adapter],
            # when merging multiple adapters we can weight them differently
            # if this is not set, all adapters will be weighted equally
            # see: text_generation_server.utils.merges.strategies for impl
            weights=None,
            merge_strategy=0,
            density=1.0,
            majority_sign_method=0,
        )

        # Index 0 is not handed out here; adapters are numbered from 1.
        adapter_index = index + 1
        adapter_to_index[adapter.id] = adapter_index

        # Ids of the adapter(s) merged into this AdapterParameters; used for
        # both log messages so they refer to the same adapter.
        adapter_names = ",".join(a.id for a in adapter_parameters.adapter_info)

        logger.info(f"Loading adapter weights into model: {adapter_names}")
        weight_names = tuple(v[0] for v in target_to_layer.values())
        (
            module_map,
            adapter_config,
            adapter_weight_names,
            adapter_tokenizer,
        ) = load_and_merge_adapters(
            model.model_id,
            adapter_parameters,
            adapter_index,
            weight_names,
            False,
        )

        # prepare_weights removes the names it consumes from this copy, so
        # whatever remains afterwards was never matched to a layer.
        unused_weight_names = adapter_weight_names.copy()

        for layer_name in adapter_layers:
            nlayers = 1 if layer_name == "lm_head" else len(model.model.model.layers)
            adapter_weights = LoraWeights.prepare_weights(
                config=adapter_config,
                module_map=module_map,
                layer_type=layer_name,
                unused_weight_names=unused_weight_names,
                nlayers=nlayers,
                dtype=model.dtype,
                world_size=model.world_size,
                process_group=model.process_group,
                target_to_layer=target_to_layer,
            )

            if adapter_weights is None:
                continue

            model.layer_to_adapter_weights[layer_name].add_adapter(
                adapter_index, adapter_weights
            )

        if len(unused_weight_names) > 0:
            # Report only the adapter currently being loaded, not every
            # adapter in the request.
            logger.warning(
                f"{adapter_names} unused adapter weights: {unused_weight_names}"
            )

        if adapter_tokenizer is not None:
            model.tokenizers.add_tokenizer(adapter_index, adapter_tokenizer)

        model.loaded_adapters.add(adapter_index)

    return model