# ruff: noqa: F821
# the above line disables the `undefined-name` rule for the model type variables

import torch
import enum
import os

from loguru import logger
from transformers.configuration_utils import PretrainedConfig
from transformers.models.auto import modeling_auto
from huggingface_hub import hf_hub_download, HfApi
from typing import Optional, List, Dict
from pathlib import Path

from text_generation_server.utils.speculate import get_speculate, set_speculate
from text_generation_server.models.model import Model
from text_generation_server.models.causal_lm import CausalLM, CausalLMBatchKeysLast
from text_generation_server.models.custom_modeling.opt_modeling import OPTForCausalLM
from text_generation_server.models.custom_modeling.mpt_modeling import (
    MPTForCausalLM,
)
from text_generation_server.models.bloom import BloomCausalLMBatch
from text_generation_server.models.custom_modeling.bloom_modeling import (
    BloomForCausalLM,
)
from text_generation_server.models.seq2seq_lm import Seq2SeqLM
from text_generation_server.models.galactica import GalacticaCausalLMBatch
from text_generation_server.models.custom_modeling.neox_modeling import (
    GPTNeoxForCausalLM,
)
from text_generation_server.models.custom_modeling.phi_modeling import (
    PhiConfig,
    PhiForCausalLM,
)
from text_generation_server.models.custom_modeling.t5_modeling import (
    T5ForConditionalGeneration,
)

from text_generation_server.utils.adapter import (
    AdapterParameters,
    build_layer_weight_lookup,
    load_and_merge_adapters,
    AdapterInfo,
)
from text_generation_server.adapters.lora import LoraWeights


from text_generation_server.utils.import_utils import SYSTEM
from text_generation_server.utils.log import log_master

# The flag below controls whether to allow TF32 on matmul. This flag defaults to False
# in PyTorch 1.12 and later.
torch.backends.cuda.matmul.allow_tf32 = True

# The flag below controls whether to allow TF32 on cuDNN. This flag defaults to True.
torch.backends.cudnn.allow_tf32 = True

# Disable gradients
torch.set_grad_enabled(False)

__all__ = [
    "Model",
    "CausalLM",
    "Seq2SeqLM",
    "get_model_with_lora_adapters",
]

FLASH_ATT_ERROR_MESSAGE = "{} requires Flash Attention enabled models."

FLASH_ATTENTION = True

try:
    from text_generation_server.models.flash_causal_lm import FlashCausalLM
    from text_generation_server.models.vlm_causal_lm import VlmCausalLM
    from text_generation_server.models.custom_modeling.flash_deepseek_v2_modeling import (
        FlashDeepseekV2ForCausalLM,
        DeepseekV2Config,
    )
    from text_generation_server.models.custom_modeling.flash_llama_modeling import (
        FlashLlamaForCausalLM,
    )
    from text_generation_server.models.custom_modeling.flash_cohere_modeling import (
        FlashCohereForCausalLM,
    )
    from text_generation_server.models.custom_modeling.flash_gemma_modeling import (
        FlashGemmaForCausalLM,
    )
    from text_generation_server.models.custom_modeling.flash_gemma2_modeling import (
        FlashGemma2ForCausalLM,
    )
    from text_generation_server.models.custom_modeling.flash_dbrx_modeling import (
        FlashDbrxForCausalLM,
        DbrxConfig,
    )
    from text_generation_server.models.custom_modeling.flash_rw_modeling import (
        RWConfig,
        FlashRWForCausalLM,
    )
    from text_generation_server.models.custom_modeling.flash_neox_modeling import (
        FlashGPTNeoXForCausalLM,
    )
    from text_generation_server.models.pali_gemma import (
        PaliGemmaBatch,
    )
    from text_generation_server.models.custom_modeling.flash_pali_gemma_modeling import (
        PaliGemmaForConditionalGeneration,
    )
    from text_generation_server.models.custom_modeling.flash_phi_modeling import (
        FlashPhiForCausalLM,
    )
    from text_generation_server.models.idefics import IDEFICSSharded
    from text_generation_server.models.custom_modeling.llava_next import (
        LlavaNextForConditionalGeneration,
    )

    from text_generation_server.models.custom_modeling.flash_santacoder_modeling import (
        FlashSantacoderForCausalLM,
    )
    from text_generation_server.models.custom_modeling.flash_starcoder2_modeling import (
        FlashStarcoder2ForCausalLM,
    )
    from text_generation_server.models.custom_modeling.flash_qwen2_modeling import (
        Qwen2ForCausalLM,
    )
    from text_generation_server.models.custom_modeling.flash_mistral_modeling import (
        FlashMistralForCausalLM,
    )
    from text_generation_server.models.custom_modeling.flash_mixtral_modeling import (
        FlashMixtralForCausalLM,
    )
    from text_generation_server.models.custom_modeling.flash_gpt2_modeling import (
        FlashGPT2ForCausalLM,
    )
    from text_generation_server.models.custom_modeling.idefics2 import (
        Idefics2ForConditionalGeneration,
    )
    from text_generation_server.layers.attention import SUPPORTS_WINDOWING
except ImportError as e:
    log_master(logger.warning, f"Could not import Flash Attention enabled models: {e}")
    SUPPORTS_WINDOWING = False
    FLASH_ATTENTION = False

if FLASH_ATTENTION:
    __all__.append(FlashCausalLM)
    __all__.append(IDEFICSSharded)

MAMBA_AVAILABLE = True
try:
    from text_generation_server.models.mamba import Mamba
except ImportError as e:
    log_master(logger.warning, f"Could not import Mamba: {e}")
    MAMBA_AVAILABLE = False

if MAMBA_AVAILABLE:
    __all__.append(Mamba)


class ModelType(enum.Enum):
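    """Enumeration of supported model architectures.

    Each member's value describes one architecture: its config "type" string,
    a human-readable "name", a reference "url", and an optional "multimodal" flag.
    """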
    DEEPSEEK_V2 = {
        "type": "deepseek_v2",
        "name": "Deepseek V2",
        "url": "https://huggingface.co/deepseek-ai/DeepSeek-V2",
    }
    IDEFICS2 = {
        "type": "idefics2",
        "name": "Idefics 2",
        "url": "https://huggingface.co/HuggingFaceM4/idefics2-8b",
        "multimodal": True,
    }
    LLAVA_NEXT = {
        "type": "llava_next",
        "name": "Llava Next (1.6)",
        "url": "https://huggingface.co/llava-hf/llava-v1.6-vicuna-13b-hf",
        "multimodal": True,
    }
    LLAMA = {
        "type": "llama",
        "name": "Llama",
        "url": "https://huggingface.co/meta-llama/Meta-Llama-3-8B-Instruct",
    }
    PHI3 = {
        "type": "phi3",
        "name": "Phi 3",
        "url": "https://huggingface.co/microsoft/Phi-3-mini-4k-instruct",
    }
    GEMMA = {
        "type": "gemma",
        "name": "Gemma",
        "url": "https://huggingface.co/google/gemma-7b",
    }
    PALIGEMMA = {
        "type": "paligemma",
        "name": "PaliGemma",
        "url": "https://huggingface.co/google/paligemma-3b-pt-224",
    }
    GEMMA2 = {
        "type": "gemma2",
        "name": "Gemma2",
        "url": "https://huggingface.co/google/gemma2-9b",
    }
    COHERE = {
        "type": "cohere",
        "name": "Cohere",
        "url": "https://huggingface.co/CohereForAI/c4ai-command-r-plus",
    }
    DBRX = {
        "type": "dbrx",
        "name": "Dbrx",
        "url": "https://huggingface.co/databricks/dbrx-instruct",
    }
    MAMBA = {
        "type": "ssm",
        "name": "Mamba",
        "url": "https://huggingface.co/state-spaces/mamba-2.8b-slimpj",
    }
    MISTRAL = {
        "type": "mistral",
        "name": "Mistral",
        "url": "https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.2",
    }
    MIXTRAL = {
        "type": "mixtral",
        "name": "Mixtral",
        "url": "https://huggingface.co/mistralai/Mixtral-8x22B-Instruct-v0.1",
    }
    GPT_BIGCODE = {
        "type": "gpt_bigcode",
        "name": "Gpt Bigcode",
        "url": "https://huggingface.co/bigcode/gpt_bigcode-santacoder",
    }
    PHI = {
        "type": "phi",
        "name": "Phi",
        "url": "https://huggingface.co/microsoft/phi-1_5",
    }
    BAICHUAN = {
        "type": "baichuan",
        "name": "Baichuan",
        "url": "https://huggingface.co/baichuan-inc/Baichuan2-7B-Chat",
    }
    FALCON = {
        "type": "falcon",
        "name": "Falcon",
        "url": "https://huggingface.co/tiiuae/falcon-7b-instruct",
    }
    STARCODER2 = {
        "type": "starcoder2",
        "name": "StarCoder 2",
        "url": "https://huggingface.co/bigcode/starcoder2-15b-instruct-v0.1",
    }
    QWEN2 = {
        "type": "qwen2",
        "name": "Qwen 2",
        "url": "https://huggingface.co/collections/Qwen/qwen2-6659360b33528ced941e557f",
    }
    OPT = {
        "type": "opt",
        "name": "Opt",
        "url": "https://huggingface.co/facebook/opt-6.7b",
    }
    T5 = {
        "type": "t5",
        "name": "T5",
        "url": "https://huggingface.co/google/flan-t5-xxl",
    }
    GALACTICA = {
        "type": "galactica",
        "name": "Galactica",
        "url": "https://huggingface.co/facebook/galactica-120b",
    }
    SANTACODER = {
        "type": "santacoder",
        "name": "SantaCoder",
        "url": "https://huggingface.co/bigcode/santacoder",
    }
    BLOOM = {
        "type": "bloom",
        "name": "Bloom",
        "url": "https://huggingface.co/bigscience/bloom-560m",
    }
    MPT = {
        "type": "mpt",
        "name": "Mpt",
        "url": "https://huggingface.co/mosaicml/mpt-7b-instruct",
    }
    GPT2 = {
        "type": "gpt2",
        "name": "Gpt2",
        "url": "https://huggingface.co/openai-community/gpt2",
    }
    GPT_NEOX = {
        "type": "gpt_neox",
        "name": "Gpt Neox",
        "url": "https://huggingface.co/EleutherAI/gpt-neox-20b",
    }
    IDEFICS = {
        "type": "idefics",
        "name": "Idefics",
        "url": "https://huggingface.co/HuggingFaceM4/idefics-9b",
        "multimodal": True,
    }


__GLOBALS = locals()
for data in ModelType:
    __GLOBALS[data.name] = data.value["type"]
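# The loop above re-exports every ModelType member as a module-level string
# constant (e.g. LLAMA = "llama"), which is why `undefined-name` is disabled
# at the top of this file.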


def get_model(
    model_id: str,
    lora_adapter_ids: Optional[List[str]],
    revision: Optional[str],
    sharded: bool,
    quantize: Optional[str],
    speculate: Optional[int],
    dtype: Optional[str],
    trust_remote_code: bool,
    max_input_tokens: int,
) -> Model:
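    """Instantiate the most suitable implementation of `model_id`.

    Inspects the checkpoint config to auto-select quantization, dtype and
    speculative decoding, then dispatches to a Flash Attention implementation
    when available, falling back to `CausalLM`/`Seq2SeqLM` otherwise.
    """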
    global FLASH_ATTENTION

    config_dict, _ = PretrainedConfig.get_config_dict(
        model_id, revision=revision, trust_remote_code=trust_remote_code
    )
    model_type = config_dict.get("model_type", None)

    quantization_config = config_dict.get("quantization_config", None)
    if quantization_config is not None and quantize is None:
        method = quantization_config.get("quant_method", None)
        if method in {"gptq", "awq", "exl2"}:
            log_master(logger.info, f"Auto selecting quantization method {method}")
            quantize = method
        elif method == "fbgemm_fp8":
            log_master(logger.info, "Auto selecting quantization method fp8")
            quantize = "fp8"
        else:
            log_master(logger.warning, f"Unknown quantization method {method}")

    if dtype is None:
        if quantize in ["awq", "exl2", "gptq", "marlin"]:
            # These quantizers only work with float16 params.
            dtype = torch.float16
        elif quantize == "fp8":
            from text_generation_server.layers.fp8 import FBGEMM_DYN_AVAILABLE

            if FBGEMM_DYN_AVAILABLE:
                # fbgemm kernels are fp8xfp8->bf16
                dtype = torch.bfloat16
        else:
            # Keep it as default for now and let
            # every model resolve their own default dtype.
            dtype = None
    elif dtype == "float16":
        dtype = torch.float16
    elif dtype == "bfloat16":
        dtype = torch.bfloat16
    else:
        raise RuntimeError(f"Unknown dtype {dtype}")

    if speculate is not None:
        set_speculate(speculate)
    else:
        set_speculate(0)

    speculator = None
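    # Medusa and MLP-speculator checkpoints point at a base model: detect them,
    # fetch the speculator weights and swap `model_id` for the base model.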
    if "medusa_num_heads" in config_dict:
        medusa_model_id = model_id
        medusa_revision = revision
        model_id = config_dict["base_model_name_or_path"]
        revision = "main"
        speculate_medusa = config_dict["medusa_num_heads"]
        if speculate is not None:
            if speculate > speculate_medusa:
                raise RuntimeError(
                    f"Speculate is set to `{speculate}` but this medusa model only has `{speculate_medusa}` heads, please make them match"
                )
            else:
                set_speculate(speculate)
        else:
            set_speculate(speculate_medusa)

        config_dict, _ = PretrainedConfig.get_config_dict(
            model_id, revision=revision, trust_remote_code=trust_remote_code
        )
        # Reload model type from parent.
        model_type = config_dict.get("model_type", None)
        is_local = Path(medusa_model_id).exists()
        if not is_local:
            medusa_config = hf_hub_download(
                medusa_model_id, revision=medusa_revision, filename="config.json"
            )
            hf_hub_download(
                medusa_model_id,
                revision=medusa_revision,
                filename="medusa_lm_head.safetensors",
            )
            speculator = {
                "path": Path(medusa_config).parent,
                "model_paths": ["medusa_lm_head.safetensors"],
            }
        else:
            speculator = {
                "path": Path(medusa_model_id),
                "model_paths": ["medusa_lm_head.safetensors"],
            }

        method = "medusa"
    elif model_type == "mlp_speculator":
        mlp_model_id = model_id
        mlp_revision = revision
        model_id = config_dict["base_model_name_or_path"]
        revision = "main"
        speculate_mlp = config_dict["n_predict"]
        if speculate is not None:
            if speculate > speculate_mlp:
                raise RuntimeError(
                    f"Speculate is set to `{speculate}` but this mlp_speculator models only has `{speculate_mlp}` heads, please make them match"
                )
            else:
                set_speculate(speculate)
        else:
            set_speculate(speculate_mlp)

        config_dict, _ = PretrainedConfig.get_config_dict(
            model_id, revision=revision, trust_remote_code=trust_remote_code
        )
        # Reload model type from parent.
        model_type = config_dict.get("model_type", None)
        is_local = Path(mlp_model_id).exists()
        extension = ".safetensors"
        if not is_local:
            mlp_speculator_config = hf_hub_download(
                mlp_model_id, revision=mlp_revision, filename="config.json"
            )
            api = HfApi()
            info = api.model_info(mlp_model_id, revision=mlp_revision)
            filenames = [
                s.rfilename
                for s in info.siblings
                if s.rfilename.endswith(extension)
                and len(s.rfilename.split("/")) == 1
                and "arguments" not in s.rfilename
                and "args" not in s.rfilename
                and "training" not in s.rfilename
            ]
            for filename in filenames:
                hf_hub_download(
                    mlp_model_id,
                    revision=mlp_revision,
                    filename=filename,
                )
            speculator = {
                "path": Path(mlp_speculator_config).parent,
                "model_paths": filenames,
            }
        else:
            speculator = Path(mlp_model_id)
            filenames = [p for p in os.listdir(speculator) if p.endswith(extension)]
            speculator = {"path": speculator, "model_paths": filenames}
        method = "mlp_speculator"
    else:
        method = "n-gram"

    speculate = get_speculate()
    if speculate > 0:
        log_master(
            logger.info, f"Using speculation {method} with {speculate} input ids."
        )

    if model_type is None:
        # TODO: fix how we determine model type for Mamba
        if "ssm_cfg" in config_dict:
            # *only happens in Mamba case
            model_type = "ssm"
        else:
            raise RuntimeError(
                f"Could not determine model type for {model_id} revision {revision}"
            )

    if quantize == "exl2" and sharded:
        raise RuntimeError(
            "Sharding is currently not supported with `exl2` quantization"
        )
    sliding_window = config_dict.get("sliding_window", -1)

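    # Windowing only matters when prompts can exceed the window; otherwise disable it.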
    if max_input_tokens is not None and max_input_tokens <= sliding_window:
        sliding_window = -1

    if (
        (sliding_window is not None and sliding_window != -1)
        and not SUPPORTS_WINDOWING
        and max_input_tokens > sliding_window
    ):
        raise ValueError(
            f"The backend {SYSTEM} does not support sliding window attention that is used by the model type {model_type}. To use this model nonetheless with the {SYSTEM} backend, please launch TGI with the argument `--max-input-tokens` smaller than sliding_window={sliding_window} (got here max_input_tokens={max_input_tokens})."
        )

    if model_type == DEEPSEEK_V2:
        if FLASH_ATTENTION:
            head_size = max(
                config_dict.get("qk_nope_dim", 128)
                + config_dict.get("qk_rope_dim", 64),
                config_dict.get("v_head_dim", 128),
            )
            return FlashCausalLM(
                model_id=model_id,
                model_class=FlashDeepseekV2ForCausalLM,
                revision=revision,
                quantize=quantize,
                speculator=speculator,
                default_dtype=torch.bfloat16,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
                lora_adapter_ids=lora_adapter_ids,
                config_class=DeepseekV2Config,
                head_size=head_size,
            )
        elif sharded:
            raise NotImplementedError(
                FLASH_ATT_ERROR_MESSAGE.format("Sharded Deepseek V2")
            )
        else:
            return CausalLM.fallback(
                model_id,
                revision,
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )
    elif model_type == MAMBA:
        return Mamba(
            model_id,
            revision,
            quantize=quantize,
            speculator=speculator,
            dtype=dtype,
            trust_remote_code=trust_remote_code,
        )

    if model_id.startswith("facebook/galactica"):
        return CausalLM(
            model_id=model_id,
            # Yes galactica is just an OPT model.
            model_class=OPTForCausalLM,
            revision=revision,
            quantize=quantize,
            speculator=speculator,
            dtype=dtype,
            trust_remote_code=trust_remote_code,
            batch_class=GalacticaCausalLMBatch,
        )

    if (
        model_type == GPT_BIGCODE
        or model_type == GPT2
        and model_id.startswith("bigcode/")
    ):
        if FLASH_ATTENTION:
            return FlashCausalLM(
                model_id=model_id,
                model_class=FlashSantacoderForCausalLM,
                revision=revision,
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
                lora_adapter_ids=lora_adapter_ids,
                aliases={"transformer.wte.weight": ["lm_head.weight"]},
                num_kv_heads=1,
            )
        elif sharded:
            raise NotImplementedError(
                FLASH_ATT_ERROR_MESSAGE.format("Sharded Santacoder")
            )
        else:
            return CausalLM.fallback(
                model_id=model_id,
                revision=revision,
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )

    if model_type == BLOOM:
        return CausalLM(
            model_id=model_id,
            model_class=BloomForCausalLM,
            revision=revision,
            quantize=quantize,
            speculator=speculator,
            dtype=dtype,
            trust_remote_code=trust_remote_code,
            batch_class=BloomCausalLMBatch,
        )
    elif model_type == MPT:
        return CausalLM(
            model_id=model_id,
            model_class=MPTForCausalLM,
            revision=revision,
            quantize=quantize,
            speculator=speculator,
            dtype=dtype,
            trust_remote_code=trust_remote_code,
            batch_class=CausalLMBatchKeysLast,
        )
    elif model_type == GPT2:
        if FLASH_ATTENTION:
            try:
                return FlashCausalLM(
                    model_id=model_id,
                    model_class=FlashGPT2ForCausalLM,
                    revision=revision,
                    quantize=quantize,
                    speculator=speculator,
                    dtype=dtype,
                    trust_remote_code=trust_remote_code,
                    lora_adapter_ids=lora_adapter_ids,
                )
            except RuntimeError as e:
                # Lots of legacy models with various weight names.
                log_master(logger.warning, f"Couldn't load flash gpt2 variant: {e}")
                return CausalLM.fallback(
                    model_id,
                    revision,
                    quantize=quantize,
                    speculator=speculator,
                    dtype=dtype,
                    trust_remote_code=trust_remote_code,
                )
        elif sharded:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded GPT-2"))
        else:
            return CausalLM.fallback(
                model_id,
                revision,
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )
    elif model_type == GPT_NEOX:
        if FLASH_ATTENTION:
            from text_generation_server.models.custom_modeling.flash_neox_modeling import (
                GPTNeoXConfig,
            )

            return FlashCausalLM(
                model_id=model_id,
                model_class=FlashGPTNeoXForCausalLM,
                revision=revision,
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
                lora_adapter_ids=lora_adapter_ids,
                config_class=GPTNeoXConfig,
            )
        elif sharded:
            return CausalLM(
                model_id=model_id,
                model_class=GPTNeoxForCausalLM,
                revision=revision,
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )
        else:
            return CausalLM.fallback(
                model_id,
                revision,
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )

    elif model_type == PHI:
        if FLASH_ATTENTION:
            return FlashCausalLM(
                model_id=model_id,
                model_class=FlashPhiForCausalLM,
                revision=revision,
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
                lora_adapter_ids=lora_adapter_ids,
            )
        else:
            return CausalLM.fallback(
                model_id,
                revision,
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )

    elif model_type == "phi-msft":
        if FLASH_ATTENTION:
            raise NotImplementedError(
                "Legacy phi-msft is not supported with Flash Attention"
            )
        else:
            return CausalLM(
                model_id=model_id,
                model_class=PhiForCausalLM,
                config_class=PhiConfig,
                revision=revision,
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )

    elif model_type == LLAMA or model_type == BAICHUAN or model_type == PHI3:
        log_master(logger.debug, f"Model type {model_type}")
        if FLASH_ATTENTION:
            return FlashCausalLM(
                model_id=model_id,
                model_class=FlashLlamaForCausalLM,
                revision=revision,
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
                lora_adapter_ids=lora_adapter_ids,
            )
        elif sharded:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Llama"))
        else:
            return CausalLM.fallback(
                model_id,
                revision,
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )
    if model_type == GEMMA:
        if FLASH_ATTENTION:
            return FlashCausalLM(
                model_id=model_id,
                model_class=FlashGemmaForCausalLM,
                revision=revision,
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
                # Works better for these models
                default_dtype=torch.bfloat16,
                trust_remote_code=trust_remote_code,
                lora_adapter_ids=lora_adapter_ids,
            )
        elif sharded:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Gemma"))
        else:
            return CausalLM.fallback(
                model_id,
                revision,
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )
    elif model_type == GEMMA2:
        if FLASH_ATTENTION:
            return FlashCausalLM(
                model_id=model_id,
                model_class=FlashGemma2ForCausalLM,
                revision=revision,
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
                # Works better for these models
                default_dtype=torch.bfloat16,
                trust_remote_code=trust_remote_code,
                lora_adapter_ids=lora_adapter_ids,
            )
        elif sharded:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Gemma2"))
        else:
            return CausalLM.fallback(
                model_id,
                revision,
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )

    if model_type == COHERE:
        if FLASH_ATTENTION:
            return FlashCausalLM(
                model_id=model_id,
                model_class=FlashCohereForCausalLM,
                revision=revision,
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
                lora_adapter_ids=lora_adapter_ids,
            )
        elif sharded:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Cohere"))
        else:
            return CausalLM.fallback(
                model_id,
                revision,
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )

    if model_type == DBRX:
        if FLASH_ATTENTION:
            return FlashCausalLM(
                model_id=model_id,
                model_class=FlashDbrxForCausalLM,
                revision=revision,
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
                # Dbrx works better in bfloat16.
                default_dtype=torch.bfloat16,
                trust_remote_code=trust_remote_code,
                lora_adapter_ids=lora_adapter_ids,
                config_class=DbrxConfig,
            )
        elif sharded:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded DBRX"))
        else:
            return CausalLM.fallback(
                model_id,
                revision,
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )

    if model_type in ["RefinedWeb", "RefinedWebModel", FALCON]:
        if sharded:
            if FLASH_ATTENTION:
                if config_dict.get("alibi", False):
                    raise NotImplementedError("sharded is not supported for this model")
                return FlashCausalLM(
                    model_id=model_id,
                    model_class=FlashRWForCausalLM,
                    revision=revision,
                    quantize=quantize,
                    speculator=speculator,
                    dtype=dtype,
                    aliases={
                        "lm_head.weight": ["transformer.word_embeddings.weight"],
                        "transformer.word_embeddings.weight": ["lm_head.weight"],
                    },
                    trust_remote_code=trust_remote_code,
                    lora_adapter_ids=lora_adapter_ids,
                    config_class=RWConfig,
                )
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Falcon"))
        else:
            if FLASH_ATTENTION and not config_dict.get("alibi", False):
                return FlashCausalLM(
                    model_id=model_id,
                    model_class=FlashRWForCausalLM,
                    revision=revision,
                    quantize=quantize,
                    speculator=speculator,
                    dtype=dtype,
                    aliases={
                        "lm_head.weight": ["transformer.word_embeddings.weight"],
                        "transformer.word_embeddings.weight": ["lm_head.weight"],
                    },
                    trust_remote_code=trust_remote_code,
                    lora_adapter_ids=lora_adapter_ids,
                    config_class=RWConfig,
                )
            else:
                return CausalLM.fallback(
                    model_id,
                    revision,
                    quantize=quantize,
                    speculator=speculator,
                    dtype=dtype,
                    trust_remote_code=trust_remote_code,
                )

    if model_type == MISTRAL:
        if FLASH_ATTENTION:
            return FlashCausalLM(
                model_id=model_id,
                model_class=FlashMistralForCausalLM,
                revision=revision,
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
                lora_adapter_ids=lora_adapter_ids,
            )
        elif sharded:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Mistral"))
        else:
            return CausalLM.fallback(
                model_id,
                revision,
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )

    if model_type == MIXTRAL:
        if FLASH_ATTENTION:
            return FlashCausalLM(
                model_id=model_id,
                model_class=FlashMixtralForCausalLM,
                revision=revision,
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
                lora_adapter_ids=lora_adapter_ids,
            )
        elif sharded:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Mixtral"))
        else:
            return CausalLM.fallback(
                model_id,
                revision,
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )

    if model_type == STARCODER2:
        if FLASH_ATTENTION:
            return FlashCausalLM(
                model_id=model_id,
                model_class=FlashStarcoder2ForCausalLM,
                revision=revision,
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
                lora_adapter_ids=lora_adapter_ids,
            )
        elif sharded:
            raise NotImplementedError(
                FLASH_ATT_ERROR_MESSAGE.format("Sharded Starcoder2")
            )
        else:
            return CausalLM.fallback(
                model_id,
                revision,
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )

    if model_type == QWEN2:
        if FLASH_ATTENTION:
            return FlashCausalLM(
                model_id=model_id,
                model_class=Qwen2ForCausalLM,
                revision=revision,
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
                lora_adapter_ids=lora_adapter_ids,
            )
        elif sharded:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Qwen2"))
        else:
            return CausalLM.fallback(
                model_id,
                revision,
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )

    if model_type == OPT:
        return CausalLM(
            model_id=model_id,
            model_class=OPTForCausalLM,
            revision=revision,
            quantize=quantize,
            speculator=speculator,
            dtype=dtype,
            trust_remote_code=trust_remote_code,
        )

    if model_type == T5:
        return Seq2SeqLM(
            model_id=model_id,
            model_class=T5ForConditionalGeneration,
            revision=revision,
            quantize=quantize,
            speculator=speculator,
            dtype=dtype,
            trust_remote_code=trust_remote_code,
            aliases={
                "shared.weight": [
                    "encoder.embed_tokens.weight",
                    "decoder.embed_tokens.weight",
                ]
            },
        )
    if model_type == IDEFICS:
        if FLASH_ATTENTION:
            return IDEFICSSharded(
                model_id,
                revision,
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )
        else:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Idefics"))
    if model_type == IDEFICS2:
        if FLASH_ATTENTION:
            return VlmCausalLM(
                model_id=model_id,
                model_class=Idefics2ForConditionalGeneration,
                revision=revision,
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
                lora_adapter_ids=lora_adapter_ids,
                # XXX: Extremely important to cap resolution in order to limit
                # VRAM usage.
                processor_kwargs={"size": {"longest_edge": 448, "shortest_edge": 378}},
            )
        else:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Idefics"))
    if model_type == PALIGEMMA:
        if FLASH_ATTENTION:
            return VlmCausalLM(
                model_id=model_id,
                model_class=PaliGemmaForConditionalGeneration,
                revision=revision,
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
                # Works better for these models
                default_dtype=torch.bfloat16,
                trust_remote_code=trust_remote_code,
                lora_adapter_ids=lora_adapter_ids,
                batch_class=PaliGemmaBatch,
            )
        else:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("PaliGemma"))

    if model_type == LLAVA_NEXT:
        if FLASH_ATTENTION:
            return VlmCausalLM(
                model_class=LlavaNextForConditionalGeneration,
                model_id=model_id,
                revision=revision,
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )
        else:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("LlavaNext"))

    if sharded:
        raise NotImplementedError("sharded is not supported for AutoModel")
    if quantize == "gptq":
        raise NotImplementedError(
            "gptq quantization is not supported for AutoModel, you can try to quantize it with `text-generation-server quantize ORIGINAL_MODEL_ID NEW_MODEL_ID`"
        )
    if quantize == "awq":
        raise NotImplementedError("awq quantization is not supported for AutoModel")
    elif (quantize == "bitsandbytes-fp4") or (quantize == "bitsandbytes-nf4"):
        raise NotImplementedError("4bit quantization is not supported for AutoModel")
    elif quantize == "eetq":
        raise NotImplementedError("Eetq quantization is not supported for AutoModel")
    elif quantize == "exl2":
        raise NotImplementedError("exl2 quantization is not supported for AutoModel")
    if model_type in modeling_auto.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES:
        return CausalLM.fallback(
            model_id,
            revision,
            quantize=quantize,
            speculator=speculator,
            dtype=dtype,
            trust_remote_code=trust_remote_code,
        )
    if model_type in modeling_auto.MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES:
        return Seq2SeqLM.fallback(
            model_id,
            revision,
            quantize=quantize,
            speculator=speculator,
            dtype=dtype,
            trust_remote_code=trust_remote_code,
        )

    auto_map = config_dict.get("auto_map", None)
    if trust_remote_code and auto_map is not None:
        if "AutoModelForCausalLM" in auto_map.keys():
            return CausalLM.fallback(
                model_id,
                revision,
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )
        if "AutoModelForSeq2SeqLM" in auto_map.keys():
            return Seq2SeqLM.fallback(
                model_id,
                revision,
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )

    raise ValueError(f"Unsupported model type {model_type}")


# get_model_with_lora_adapters wraps the internal get_model function and adds support for loading adapters
# this provides a post model loading hook to load adapters into the model after the model has been loaded
def get_model_with_lora_adapters(
    model_id: str,
    lora_adapters: Optional[List[AdapterInfo]],
    revision: Optional[str],
    sharded: bool,
    quantize: Optional[str],
    speculate: Optional[int],
    dtype: Optional[str],
    trust_remote_code: bool,
    max_input_tokens: int,
    adapter_to_index: Dict[str, int],
):
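    """Wrapper around `get_model` that loads the given LoRA adapters into the returned model."""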
    lora_adapter_ids = [adapter.id for adapter in lora_adapters]
    model = get_model(
        model_id,
        lora_adapter_ids,
        revision,
        sharded,
        quantize,
        speculate,
        dtype,
        trust_remote_code,
        max_input_tokens,
    )

    if len(lora_adapters) > 0:
        target_to_layer = build_layer_weight_lookup(model.model)

        for index, adapter in enumerate(lora_adapters):
            # The AdapterParameters object allows for merging multiple adapters into a single adapter.
            # At the moment, we only support loading a single adapter into the model, but we keep the
            # AdapterParameters object for easier extension in the future.
            adapter_parameters = AdapterParameters(
                adapter_info=[adapter],
                # when merging multiple adapters we can weight them differently
                # if this is not set, all adapters will be weighted equally
                # see: text_generation_server.utils.merges.strategies for impl
                weights=None,
                merge_strategy=0,
                density=1.0,
                majority_sign_method=0,
            )

            adapter_index = index + 1
            adapter_to_index[adapter.id] = adapter_index

            logger.info(
                f"Loading adapter weights into model: {','.join([adapter.id for adapter in adapter_parameters.adapter_info])}"
            )
            weight_names = tuple([v[0] for v in target_to_layer.values()])
            (
                module_map,
                adapter_config,
                adapter_weight_names,
                adapter_tokenizer,
            ) = load_and_merge_adapters(
                model.model_id,
                adapter_parameters,
                adapter_index,
                weight_names,
                False,
            )

            unused_weight_names = adapter_weight_names.copy()

            adapter_layers = [
                "q_proj",
                "k_proj",
                "v_proj",
                "o_proj",
                "gate_proj",
                "up_proj",
                "down_proj",
            ]

            for layer_name in adapter_layers:
                nlayers = (
                    1 if layer_name == "lm_head" else len(model.model.model.layers)
                )
                adapter_weights = LoraWeights.prepare_weights(
                    config=adapter_config,
                    module_map=module_map,
                    layer_type=layer_name,
                    unused_weight_names=unused_weight_names,
                    nlayers=nlayers,
                    dtype=model.dtype,
                    world_size=model.world_size,
                    process_group=model.process_group,
                    target_to_layer=target_to_layer,
                )

                if adapter_weights is None:
                    continue

                model.layer_to_adapter_weights[layer_name].add_adapter(
                    adapter_index, adapter_weights
                )

            if len(unused_weight_names) > 0:
                logger.warning(
                    f"{','.join(adapter_parameters.adapter_ids)} unused adapter weights: {unused_weight_names}"
                )

            if adapter_tokenizer is not None:
                model.tokenizers.add_tokenizer(adapter_index, adapter_tokenizer)

            model.loaded_adapters.add(adapter_index)

    return model