__init__.py 42.4 KB
Newer Older
1
import torch
2
import enum
Nicolas Patry's avatar
Nicolas Patry committed
3
import os
4

5
from loguru import logger
6
from transformers.configuration_utils import PretrainedConfig
7
from transformers.models.auto import modeling_auto
Nicolas Patry's avatar
Nicolas Patry committed
8
from huggingface_hub import hf_hub_download, HfApi
9
from typing import Optional, List, Dict
10
from pathlib import Path
11

Nicolas Patry's avatar
Nicolas Patry committed
12
from text_generation_server.utils.speculate import get_speculate, set_speculate
13
from text_generation_server.models.model import Model
14
15
16
17
18
from text_generation_server.models.causal_lm import CausalLM, CausalLMBatchKeysLast
from text_generation_server.models.custom_modeling.opt_modeling import OPTForCausalLM
from text_generation_server.models.custom_modeling.mpt_modeling import (
    MPTForCausalLM,
)
19
from text_generation_server.models.bloom import BloomCausalLMBatch
20
21
22
from text_generation_server.models.custom_modeling.bloom_modeling import (
    BloomForCausalLM,
)
23
from text_generation_server.models.seq2seq_lm import Seq2SeqLM
24
25
26
27
28
29
30
31
32
33
34
from text_generation_server.models.galactica import GalacticaCausalLMBatch
from text_generation_server.models.custom_modeling.neox_modeling import (
    GPTNeoxForCausalLM,
)
from text_generation_server.models.custom_modeling.phi_modeling import (
    PhiConfig,
    PhiForCausalLM,
)
from text_generation_server.models.custom_modeling.t5_modeling import (
    T5ForConditionalGeneration,
)
35

36
37
38
39
40
41
42
43
44
45

from text_generation_server.utils.adapter import (
    AdapterParameters,
    build_layer_weight_lookup,
    load_and_merge_adapters,
    AdapterInfo,
)
from text_generation_server.adapters.lora import LoraWeights


46
from text_generation_server.utils.import_utils import SYSTEM
47
from text_generation_server.utils.log import log_master
48

49
50
51
52
53
54
55
56
57
58
59
60
61
62
# Allow TF32 on CUDA matmuls. This flag defaults to False in PyTorch 1.12
# and later, so it is opted into explicitly here.
torch.backends.cuda.matmul.allow_tf32 = True

# Allow TF32 on cuDNN. This flag already defaults to True; setting it keeps
# the choice explicit and independent of future default changes.
torch.backends.cudnn.allow_tf32 = True

# Disable autograd gradient tracking (grad mode is thread-local in PyTorch,
# so this applies to the thread importing this package).
torch.set_grad_enabled(False)

# Public API of this package. Optional backends further down the module
# append their entry points to this list when their imports succeed.
__all__ = [
    "Model",
    "CausalLM",
    "Seq2SeqLM",
    "get_model_with_lora_adapters",
]

# Error template for configurations that need flash-attention models;
# `.format()` it with a human-readable description (e.g. "Sharded Llama").
FLASH_ATT_ERROR_MESSAGE = "{} requires Flash Attention enabled models."

# Assume flash-attention models are importable; the guarded import below
# flips this to False on ImportError.
FLASH_ATTENTION = True

# Flash-attention model implementations are optional: if any of them fails to
# import, log the reason and fall back to the non-flash code paths.
try:
    from text_generation_server.models.flash_causal_lm import FlashCausalLM
    from text_generation_server.models.vlm_causal_lm import VlmCausalLM
    from text_generation_server.models.custom_modeling.flash_deepseek_v2_modeling import (
        FlashDeepseekV2ForCausalLM,
        DeepseekV2Config,
    )
    from text_generation_server.models.custom_modeling.flash_llama_modeling import (
        FlashLlamaForCausalLM,
    )
    from text_generation_server.models.custom_modeling.flash_cohere_modeling import (
        FlashCohereForCausalLM,
    )
    from text_generation_server.models.custom_modeling.flash_gemma_modeling import (
        FlashGemmaForCausalLM,
    )
    from text_generation_server.models.custom_modeling.flash_gemma2_modeling import (
        FlashGemma2ForCausalLM,
    )
    from text_generation_server.models.custom_modeling.flash_dbrx_modeling import (
        FlashDbrxForCausalLM,
        DbrxConfig,
    )
    from text_generation_server.models.custom_modeling.flash_rw_modeling import (
        RWConfig,
        FlashRWForCausalLM,
    )
    from text_generation_server.models.custom_modeling.flash_neox_modeling import (
        FlashGPTNeoXForCausalLM,
    )
    from text_generation_server.models.pali_gemma import (
        PaliGemmaBatch,
    )
    from text_generation_server.models.custom_modeling.flash_pali_gemma_modeling import (
        PaliGemmaForConditionalGeneration,
    )
    from text_generation_server.models.custom_modeling.flash_phi_modeling import (
        FlashPhiForCausalLM,
    )
    from text_generation_server.models.idefics import IDEFICSSharded
    from text_generation_server.models.custom_modeling.llava_next import (
        LlavaNextForConditionalGeneration,
    )
    from text_generation_server.models.custom_modeling.flash_santacoder_modeling import (
        FlashSantacoderForCausalLM,
    )
    from text_generation_server.models.custom_modeling.flash_starcoder2_modeling import (
        FlashStarcoder2ForCausalLM,
    )
    from text_generation_server.models.custom_modeling.flash_qwen2_modeling import (
        Qwen2ForCausalLM,
    )
    from text_generation_server.models.custom_modeling.flash_mistral_modeling import (
        FlashMistralForCausalLM,
    )
    from text_generation_server.models.custom_modeling.flash_mixtral_modeling import (
        FlashMixtralForCausalLM,
    )
    from text_generation_server.models.custom_modeling.flash_gpt2_modeling import (
        FlashGPT2ForCausalLM,
    )
    from text_generation_server.models.custom_modeling.idefics2 import (
        Idefics2ForConditionalGeneration,
    )
    from text_generation_server.layers.attention import SUPPORTS_WINDOWING
except ImportError as e:
    log_master(logger.warning, f"Could not import Flash Attention enabled models: {e}")
    # Without the flash attention layers there is no sliding-window support.
    SUPPORTS_WINDOWING = False
    FLASH_ATTENTION = False

if FLASH_ATTENTION:
    # `__all__` must contain names (str), not objects: `from ... import *`
    # raises TypeError on non-string entries.
    __all__.append("FlashCausalLM")
    __all__.append("IDEFICSSharded")

# Mamba is optional: if its module cannot be imported, log the reason and
# record that it is unavailable instead of failing at package import time.
MAMBA_AVAILABLE = True
try:
    from text_generation_server.models.mamba import Mamba
except ImportError as e:
    log_master(logger.warning, f"Could not import Mamba: {e}")
    MAMBA_AVAILABLE = False

if MAMBA_AVAILABLE:
    # `__all__` entries must be names (str), otherwise `import *` raises TypeError.
    __all__.append("Mamba")


class ModelType(enum.Enum):
    """Catalogue of supported model architectures.

    Each member's value is a dict with:
      * ``"type"``: the ``model_type`` string matched against the model config,
      * ``"name"``: a human-readable display name,
      * ``"url"``: a reference checkpoint on the Hugging Face Hub,
      * ``"multimodal"`` (optional): ``True`` for vision-language models.
    """

    DEEPSEEK_V2 = {
        "type": "deepseek_v2",
        "name": "Deepseek V2",
        "url": "https://huggingface.co/deepseek-ai/DeepSeek-V2",
    }
    IDEFICS2 = {
        "type": "idefics2",
        "name": "Idefics 2",
        "url": "https://huggingface.co/HuggingFaceM4/idefics2-8b",
        "multimodal": True,
    }
    LLAVA_NEXT = {
        "type": "llava_next",
        "name": "Llava Next (1.6)",
        "url": "https://huggingface.co/llava-hf/llava-v1.6-vicuna-13b-hf",
        "multimodal": True,
    }
    LLAMA = {
        "type": "llama",
        "name": "Llama",
        "url": "https://huggingface.co/meta-llama/Meta-Llama-3-8B-Instruct",
    }
    PHI3 = {
        "type": "phi3",
        "name": "Phi 3",
        "url": "https://huggingface.co/microsoft/Phi-3-mini-4k-instruct",
    }
    GEMMA = {
        "type": "gemma",
        "name": "Gemma",
        "url": "https://huggingface.co/google/gemma-7b",
    }
    PALIGEMMA = {
        "type": "paligemma",
        "name": "PaliGemma",
        "url": "https://huggingface.co/google/paligemma-3b-pt-224",
    }
    GEMMA2 = {
        "type": "gemma2",
        "name": "Gemma2",
        # The Hub repository id is "gemma-2-9b" (with a dash).
        "url": "https://huggingface.co/google/gemma-2-9b",
    }
    COHERE = {
        "type": "cohere",
        "name": "Cohere",
        "url": "https://huggingface.co/CohereForAI/c4ai-command-r-plus",
    }
    DBRX = {
        "type": "dbrx",
        "name": "Dbrx",
        "url": "https://huggingface.co/databricks/dbrx-instruct",
    }
    MAMBA = {
        # Mamba configs carry no `model_type`; get_model maps them to "ssm".
        "type": "ssm",
        "name": "Mamba",
        "url": "https://huggingface.co/state-spaces/mamba-2.8b-slimpj",
    }
    MISTRAL = {
        "type": "mistral",
        "name": "Mistral",
        "url": "https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.2",
    }
    MIXTRAL = {
        "type": "mixtral",
        "name": "Mixtral",
        "url": "https://huggingface.co/mistralai/Mixtral-8x22B-Instruct-v0.1",
    }
    GPT_BIGCODE = {
        "type": "gpt_bigcode",
        "name": "Gpt Bigcode",
        "url": "https://huggingface.co/bigcode/gpt_bigcode-santacoder",
    }
    PHI = {
        "type": "phi",
        "name": "Phi",
        "url": "https://huggingface.co/microsoft/phi-1_5",
    }
    BAICHUAN = {
        "type": "baichuan",
        "name": "Baichuan",
        "url": "https://huggingface.co/baichuan-inc/Baichuan2-7B-Chat",
    }
    FALCON = {
        "type": "falcon",
        "name": "Falcon",
        "url": "https://huggingface.co/tiiuae/falcon-7b-instruct",
    }
    STARCODER2 = {
        "type": "starcoder2",
        "name": "StarCoder 2",
        "url": "https://huggingface.co/bigcode/starcoder2-15b-instruct-v0.1",
    }
    QWEN2 = {
        "type": "qwen2",
        "name": "Qwen 2",
        "url": "https://huggingface.co/collections/Qwen/qwen2-6659360b33528ced941e557f",
    }
    OPT = {
        "type": "opt",
        "name": "Opt",
        "url": "https://huggingface.co/facebook/opt-6.7b",
    }
    T5 = {
        "type": "t5",
        "name": "T5",
        "url": "https://huggingface.co/google/flan-t5-xxl",
    }
    GALACTICA = {
        "type": "galactica",
        "name": "Galactica",
        "url": "https://huggingface.co/facebook/galactica-120b",
    }
    SANTACODER = {
        "type": "santacoder",
        "name": "SantaCoder",
        "url": "https://huggingface.co/bigcode/santacoder",
    }
    BLOOM = {
        "type": "bloom",
        "name": "Bloom",
        "url": "https://huggingface.co/bigscience/bloom-560m",
    }
    MPT = {
        "type": "mpt",
        "name": "Mpt",
        "url": "https://huggingface.co/mosaicml/mpt-7b-instruct",
    }
    GPT2 = {
        "type": "gpt2",
        "name": "Gpt2",
        "url": "https://huggingface.co/openai-community/gpt2",
    }
    GPT_NEOX = {
        "type": "gpt_neox",
        "name": "Gpt Neox",
        "url": "https://huggingface.co/EleutherAI/gpt-neox-20b",
    }
    IDEFICS = {
        "type": "idefics",
        "name": "Idefics",
        "url": "https://huggingface.co/HuggingFaceM4/idefics-9b",
        "multimodal": True,
    }


# Expose each ModelType member's "type" string as a module-level constant named
# after the member (e.g. LLAMA == "llama"), so get_model can compare the config's
# `model_type` against readable names. Use globals() explicitly instead of
# relying on locals() aliasing it at module scope. The comprehension keeps its
# iteration variable out of the module namespace.
__GLOBALS = globals()
__GLOBALS.update({member.name: member.value["type"] for member in ModelType})


def get_model(
308
    model_id: str,
drbh's avatar
drbh committed
309
    lora_adapter_ids: Optional[List[str]],
310
311
312
    revision: Optional[str],
    sharded: bool,
    quantize: Optional[str],
Nicolas Patry's avatar
Nicolas Patry committed
313
    speculate: Optional[int],
314
    dtype: Optional[str],
315
    trust_remote_code: bool,
316
    max_input_tokens: int,
317
) -> Model:
318
    global FLASH_ATTENTION
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336

    config_dict, _ = PretrainedConfig.get_config_dict(
        model_id, revision=revision, trust_remote_code=trust_remote_code
    )
    model_type = config_dict.get("model_type", None)

    quantization_config = config_dict.get("quantization_config", None)
    if quantization_config is not None and quantize is None:
        method = quantization_config.get("quant_method", None)
        if method in {"gptq", "awq", "exl2"}:
            log_master(logger.info, f"Auto selecting quantization method {method}")
            quantize = method
        elif method == "fbgemm_fp8":
            log_master(logger.info, "Auto selecting quantization method fp8")
            quantize = "fp8"
        else:
            log_master(logger.warning, f"Unknown quantization method {method}")

337
    if dtype is None:
338
        if quantize in ["awq", "exl2", "gptq", "marlin"]:
339
340
            # These quantizers only work with float16 params.
            dtype = torch.float16
341
        elif quantize == "fp8":
342
            from text_generation_server.layers.fp8 import FBGEMM_DYN_AVAILABLE
343

344
            if FBGEMM_DYN_AVAILABLE:
345
346
                # fbgemm kernels are fp8xfp8->bf16
                dtype = torch.bfloat16
347
348
349
350
        else:
            # Keep it as default for now and let
            # every model resolve their own default dtype.
            dtype = None
351
352
353
354
355
356
357
    elif dtype == "float16":
        dtype = torch.float16
    elif dtype == "bfloat16":
        dtype = torch.bfloat16
    else:
        raise RuntimeError(f"Unknown dtype {dtype}")

Nicolas Patry's avatar
Nicolas Patry committed
358
359
360
361
362
    if speculate is not None:
        set_speculate(speculate)
    else:
        set_speculate(0)

Nicolas Patry's avatar
Nicolas Patry committed
363
    speculator = None
Nicolas Patry's avatar
Nicolas Patry committed
364
    if "medusa_num_heads" in config_dict:
365
366
        medusa_model_id = model_id
        medusa_revision = revision
Nicolas Patry's avatar
Nicolas Patry committed
367
368
369
370
371
        model_id = config_dict["base_model_name_or_path"]
        revision = "main"
        speculate_medusa = config_dict["medusa_num_heads"]
        if speculate is not None:
            if speculate > speculate_medusa:
OlivierDehaene's avatar
OlivierDehaene committed
372
                raise RuntimeError(
OlivierDehaene's avatar
OlivierDehaene committed
373
                    f"Speculate is set to `{speculate}` but this medusa models only has `{speculate_medusa}` heads, please make them match"
OlivierDehaene's avatar
OlivierDehaene committed
374
                )
Nicolas Patry's avatar
Nicolas Patry committed
375
376
377
378
379
380
381
382
            else:
                set_speculate(speculate)
        else:
            set_speculate(speculate_medusa)

        config_dict, _ = PretrainedConfig.get_config_dict(
            model_id, revision=revision, trust_remote_code=trust_remote_code
        )
Nicolas Patry's avatar
Nicolas Patry committed
383
384
        # Reload model type from parent.
        model_type = config_dict.get("model_type", None)
385
386
387
388
389
390
391
392
393
394
        is_local = Path(medusa_model_id).exists()
        if not is_local:
            medusa_config = hf_hub_download(
                medusa_model_id, revision=medusa_revision, filename="config.json"
            )
            hf_hub_download(
                medusa_model_id,
                revision=medusa_revision,
                filename="medusa_lm_head.safetensors",
            )
Nicolas Patry's avatar
Nicolas Patry committed
395
396
397
398
            speculator = {
                "path": Path(medusa_config).parent,
                "model_paths": ["medusa_lm_head.safetensors"],
            }
399
        else:
Nicolas Patry's avatar
Nicolas Patry committed
400
401
402
403
            speculator = {
                "path": Path(medusa_model_id),
                "model_paths": ["medusa_lm_head.safetensors"],
            }
404

Nicolas Patry's avatar
Nicolas Patry committed
405
        method = "medusa"
Nicolas Patry's avatar
Nicolas Patry committed
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
    elif model_type == "mlp_speculator":
        mlp_model_id = model_id
        mlp_revision = revision
        model_id = config_dict["base_model_name_or_path"]
        revision = "main"
        speculate_mlp = config_dict["n_predict"]
        if speculate is not None:
            if speculate > speculate_mlp:
                raise RuntimeError(
                    f"Speculate is set to `{speculate}` but this mlp_speculator models only has `{speculate_mlp}` heads, please make them match"
                )
            else:
                set_speculate(speculate)
        else:
            set_speculate(speculate_mlp)

        config_dict, _ = PretrainedConfig.get_config_dict(
            model_id, revision=revision, trust_remote_code=trust_remote_code
        )
        # Reload model type from parent.
        model_type = config_dict.get("model_type", None)
        is_local = Path(mlp_model_id).exists()
        extension = ".safetensors"
        if not is_local:
            mlp_speculator_config = hf_hub_download(
                mlp_model_id, revision=mlp_revision, filename="config.json"
            )
            api = HfApi()
            info = api.model_info(mlp_model_id, revision=mlp_revision)
            filenames = [
                s.rfilename
                for s in info.siblings
                if s.rfilename.endswith(extension)
                and len(s.rfilename.split("/")) == 1
                and "arguments" not in s.rfilename
                and "args" not in s.rfilename
                and "training" not in s.rfilename
            ]
            for filename in filenames:
                hf_hub_download(
                    mlp_model_id,
                    revision=mlp_revision,
                    filename=filename,
                )
            speculator = {
                "path": Path(mlp_speculator_config).parent,
                "model_paths": filenames,
            }
        else:
            speculator = Path(mlp_model_id)
            filenames = [p for p in os.listdir(speculator) if p.endswith(extension)]
            speculator = {"path": speculator, "model_paths": filenames}
        method = "mlp_speculator"
Nicolas Patry's avatar
Nicolas Patry committed
459
460
461
462
463
    else:
        method = "n-gram"

    speculate = get_speculate()
    if speculate > 0:
464
465
466
        log_master(
            logger.info, f"Using speculation {method} with {speculate} input ids."
        )
Nicolas Patry's avatar
Nicolas Patry committed
467

drbh's avatar
drbh committed
468
469
470
471
472
473
474
475
476
477
    if model_type is None:
        # TODO: fix how we determine model type for Mamba
        if "ssm_cfg" in config_dict:
            # *only happens in Mamba case
            model_type = "ssm"
        else:
            raise RuntimeError(
                f"Could not determine model type for {model_id} revision {revision}"
            )

478
479
480
481
    if quantize == "exl2" and sharded:
        raise RuntimeError(
            "Sharding is currently not supported with `exl2` quantization"
        )
482
    sliding_window = config_dict.get("sliding_window", -1)
483
484
485
486
487
488
489
490

    if (
        (sliding_window is not None and sliding_window != -1)
        and not SUPPORTS_WINDOWING
        and max_input_tokens > sliding_window
    ):
        raise ValueError(
            f"The backend {SYSTEM} does not support sliding window attention that is used by the model type {model_type}. To use this model nonetheless with the {SYSTEM} backend, please launch TGI with the argument `--max-input-tokens` smaller than sliding_window={sliding_window} (got here max_input_tokens={max_input_tokens})."
491
        )
492

493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
    if model_type == DEEPSEEK_V2:
        if FLASH_ATTENTION:
            head_size = max(
                config_dict.get("qk_nope_dim", 128)
                + config_dict.get("qk_rope_dim", 64),
                config_dict.get("v_head_dim", 128),
            )
            return FlashCausalLM(
                model_id=model_id,
                model_class=FlashDeepseekV2ForCausalLM,
                revision=revision,
                quantize=quantize,
                speculator=speculator,
                default_dtype=torch.bfloat16,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
                lora_adapter_ids=lora_adapter_ids,
                config_class=DeepseekV2Config,
                head_size=head_size,
            )
        elif sharded:
            raise NotImplementedError(
                FLASH_ATT_ERROR_MESSAGE.format("Sharded Deepseek V2")
            )
        else:
            return CausalLM.fallback(
                model_id,
                revision,
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )
    elif model_type == MAMBA:
drbh's avatar
drbh committed
527
528
529
530
        return Mamba(
            model_id,
            revision,
            quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
531
            speculator=speculator,
drbh's avatar
drbh committed
532
533
534
            dtype=dtype,
            trust_remote_code=trust_remote_code,
        )
535

OlivierDehaene's avatar
OlivierDehaene committed
536
    if model_id.startswith("facebook/galactica"):
537
538
539
540
541
        return CausalLM(
            model_id=model_id,
            # Yes galactica is just an OPT model.
            model_class=OPTForCausalLM,
            revision=revision,
OlivierDehaene's avatar
OlivierDehaene committed
542
            quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
543
            speculator=speculator,
OlivierDehaene's avatar
OlivierDehaene committed
544
545
            dtype=dtype,
            trust_remote_code=trust_remote_code,
546
            batch_class=GalacticaCausalLMBatch,
OlivierDehaene's avatar
OlivierDehaene committed
547
548
        )

549
    if (
550
551
        model_type == GPT_BIGCODE
        or model_type == GPT2
552
553
        and model_id.startswith("bigcode/")
    ):
554
        if FLASH_ATTENTION:
555
556
557
558
            return FlashCausalLM(
                model_id=model_id,
                model_class=FlashSantacoderForCausalLM,
                revision=revision,
559
                quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
560
                speculator=speculator,
561
                dtype=dtype,
562
                trust_remote_code=trust_remote_code,
563
564
565
                lora_adapter_ids=lora_adapter_ids,
                aliases={"transformer.wte.weight": ["lm_head.weight"]},
                num_kv_heads=1,
566
            )
567
568
569
570
        elif sharded:
            raise NotImplementedError(
                FLASH_ATT_ERROR_MESSAGE.format("Sharded Santacoder")
            )
571
        else:
572
573
574
            return CausalLM.fallback(
                model_id=model_id,
                revision=revision,
575
                quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
576
                speculator=speculator,
577
                dtype=dtype,
578
579
                trust_remote_code=trust_remote_code,
            )
580

581
    if model_type == BLOOM:
582
583
584
585
        return CausalLM(
            model_id=model_id,
            model_class=BloomForCausalLM,
            revision=revision,
586
            quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
587
            speculator=speculator,
588
589
            dtype=dtype,
            trust_remote_code=trust_remote_code,
590
            batch_class=BloomCausalLMBatch,
591
        )
592
    elif model_type == MPT:
593
594
595
596
        return CausalLM(
            model_id=model_id,
            model_class=MPTForCausalLM,
            revision=revision,
OlivierDehaene's avatar
OlivierDehaene committed
597
            quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
598
            speculator=speculator,
OlivierDehaene's avatar
OlivierDehaene committed
599
600
            dtype=dtype,
            trust_remote_code=trust_remote_code,
601
            batch_class=CausalLMBatchKeysLast,
602
        )
603
    elif model_type == GPT2:
604
        if FLASH_ATTENTION:
605
            try:
606
607
608
609
                return FlashCausalLM(
                    model_id=model_id,
                    model_class=FlashGPT2ForCausalLM,
                    revision=revision,
610
611
612
613
                    quantize=quantize,
                    speculator=speculator,
                    dtype=dtype,
                    trust_remote_code=trust_remote_code,
614
                    lora_adapter_ids=lora_adapter_ids,
615
616
617
                )
            except RuntimeError as e:
                # Lots of legacy models with various weight names.
618
                log_master(logger.warning, f"Couldn't load flash gpt2 variant: {e}")
619
                return CausalLM.fallback(
620
621
622
623
624
625
626
                    model_id,
                    revision,
                    quantize=quantize,
                    speculator=speculator,
                    dtype=dtype,
                    trust_remote_code=trust_remote_code,
                )
627
628
629
        elif sharded:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded GPT-2"))
        else:
630
            return CausalLM.fallback(
631
632
633
634
635
636
637
                model_id,
                revision,
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )
638
    elif model_type == GPT_NEOX:
639
        if FLASH_ATTENTION:
640
641
642
643
            from text_generation_server.models.custom_modeling.flash_neox_modeling import (
                GPTNeoXConfig,
            )

644
645
646
647
            return FlashCausalLM(
                model_id=model_id,
                model_class=FlashGPTNeoXForCausalLM,
                revision=revision,
648
                quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
649
                speculator=speculator,
650
                dtype=dtype,
651
                trust_remote_code=trust_remote_code,
652
                lora_adapter_ids=lora_adapter_ids,
653
                config_class=GPTNeoXConfig,
654
655
            )
        elif sharded:
656
657
658
659
            return CausalLM(
                model_id=model_id,
                model_class=GPTNeoxForCausalLM,
                revision=revision,
660
                quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
661
                speculator=speculator,
662
                dtype=dtype,
663
664
                trust_remote_code=trust_remote_code,
            )
665
        else:
666
            return CausalLM.fallback(
667
668
669
                model_id,
                revision,
                quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
670
                speculator=speculator,
671
                dtype=dtype,
672
673
                trust_remote_code=trust_remote_code,
            )
OlivierDehaene's avatar
OlivierDehaene committed
674

675
    elif model_type == PHI:
drbh's avatar
drbh committed
676
        if FLASH_ATTENTION:
677
678
679
680
            return FlashCausalLM(
                model_id=model_id,
                model_class=FlashPhiForCausalLM,
                revision=revision,
drbh's avatar
drbh committed
681
                quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
682
                speculator=speculator,
drbh's avatar
drbh committed
683
684
                dtype=dtype,
                trust_remote_code=trust_remote_code,
685
                lora_adapter_ids=lora_adapter_ids,
drbh's avatar
drbh committed
686
687
            )
        else:
688
            return CausalLM.fallback(
drbh's avatar
drbh committed
689
690
691
                model_id,
                revision,
                quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
692
                speculator=speculator,
drbh's avatar
drbh committed
693
694
695
696
697
698
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )

    elif model_type == "phi-msft":
        if FLASH_ATTENTION:
OlivierDehaene's avatar
OlivierDehaene committed
699
700
701
            raise NotImplementedError(
                "Legacy phi-msft is not supported with Flash Attention"
            )
drbh's avatar
drbh committed
702
        else:
703
704
705
706
707
            return CausalLM(
                model_id=model_id,
                model_class=PhiForCausalLM,
                config_class=PhiConfig,
                revision=revision,
drbh's avatar
drbh committed
708
                quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
709
                speculator=speculator,
drbh's avatar
drbh committed
710
711
712
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )
713

714
    elif model_type == LLAMA or model_type == BAICHUAN or model_type == PHI3:
715
        if FLASH_ATTENTION:
716
717
718
719
            return FlashCausalLM(
                model_id=model_id,
                model_class=FlashLlamaForCausalLM,
                revision=revision,
720
                quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
721
                speculator=speculator,
722
                dtype=dtype,
723
                trust_remote_code=trust_remote_code,
drbh's avatar
drbh committed
724
                lora_adapter_ids=lora_adapter_ids,
725
            )
726
727
        elif sharded:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Llama"))
728
        else:
729
            return CausalLM.fallback(
730
731
732
                model_id,
                revision,
                quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
733
                speculator=speculator,
734
                dtype=dtype,
735
736
                trust_remote_code=trust_remote_code,
            )
737
    if model_type == GEMMA:
738
        if FLASH_ATTENTION:
739
740
741
742
            return FlashCausalLM(
                model_id=model_id,
                model_class=FlashGemmaForCausalLM,
                revision=revision,
743
                quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
744
                speculator=speculator,
745
                dtype=dtype,
746
747
                # Works better for these models
                default_dtype=torch.bfloat16,
748
                trust_remote_code=trust_remote_code,
749
                lora_adapter_ids=lora_adapter_ids,
750
751
            )
        elif sharded:
OlivierDehaene's avatar
OlivierDehaene committed
752
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Gemma"))
753
        else:
754
            return CausalLM.fallback(
755
756
757
                model_id,
                revision,
                quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
758
                speculator=speculator,
759
760
761
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )
Nicolas Patry's avatar
Nicolas Patry committed
762
763
    elif model_type == GEMMA2:
        if FLASH_ATTENTION:
764
765
766
767
            return FlashCausalLM(
                model_id=model_id,
                model_class=FlashGemma2ForCausalLM,
                revision=revision,
Nicolas Patry's avatar
Nicolas Patry committed
768
769
770
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
771
772
                # Works better for these models
                default_dtype=torch.bfloat16,
Nicolas Patry's avatar
Nicolas Patry committed
773
                trust_remote_code=trust_remote_code,
774
                lora_adapter_ids=lora_adapter_ids,
Nicolas Patry's avatar
Nicolas Patry committed
775
776
777
778
            )
        elif sharded:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Gemma2"))
        else:
779
            return CausalLM.fallback(
Nicolas Patry's avatar
Nicolas Patry committed
780
781
782
783
784
785
786
                model_id,
                revision,
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )
787

788
    if model_type == COHERE:
OlivierDehaene's avatar
OlivierDehaene committed
789
        if FLASH_ATTENTION:
790
791
792
793
            return FlashCausalLM(
                model_id=model_id,
                model_class=FlashCohereForCausalLM,
                revision=revision,
OlivierDehaene's avatar
OlivierDehaene committed
794
                quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
795
                speculator=speculator,
OlivierDehaene's avatar
OlivierDehaene committed
796
797
                dtype=dtype,
                trust_remote_code=trust_remote_code,
798
                lora_adapter_ids=lora_adapter_ids,
OlivierDehaene's avatar
OlivierDehaene committed
799
800
801
802
            )
        elif sharded:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Cohere"))
        else:
803
            return CausalLM.fallback(
OlivierDehaene's avatar
OlivierDehaene committed
804
805
806
                model_id,
                revision,
                quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
807
                speculator=speculator,
OlivierDehaene's avatar
OlivierDehaene committed
808
809
810
811
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )

812
    if model_type == DBRX:
813
        if FLASH_ATTENTION:
814
815
816
817
            return FlashCausalLM(
                model_id=model_id,
                model_class=FlashDbrxForCausalLM,
                revision=revision,
818
                quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
819
                speculator=speculator,
820
                dtype=dtype,
821
822
                # Dbrx works better in bfloat16.
                default_dtype=torch.bfloat16,
823
                trust_remote_code=trust_remote_code,
824
825
                lora_adapter_ids=lora_adapter_ids,
                config_class=DbrxConfig,
826
827
828
829
            )
        elif sharded:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded DBRX"))
        else:
830
            return CausalLM.fallback(
831
832
833
                model_id,
                revision,
                quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
834
                speculator=speculator,
835
836
837
838
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )

839
    if model_type in ["RefinedWeb", "RefinedWebModel", FALCON]:
840
841
        if sharded:
            if FLASH_ATTENTION:
842
                if config_dict.get("alibi", False):
843
                    raise NotImplementedError("sharded is not supported for this model")
844
845
846
847
                return FlashCausalLM(
                    model_id=model_id,
                    model_class=FlashRWForCausalLM,
                    revision=revision,
848
                    quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
849
                    speculator=speculator,
850
                    dtype=dtype,
851
852
853
854
                    aliases={
                        "lm_head.weight": ["transformer.word_embeddings.weight"],
                        "transformer.word_embeddings.weight": ["lm_head.weight"],
                    },
855
                    trust_remote_code=trust_remote_code,
856
857
                    lora_adapter_ids=lora_adapter_ids,
                    config_class=RWConfig,
858
                )
859
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format(f"Sharded Falcon"))
860
        else:
861
            if FLASH_ATTENTION and not config_dict.get("alibi", False):
862
863
864
865
                return FlashCausalLM(
                    model_id=model_id,
                    model_class=FlashRWForCausalLM,
                    revision=revision,
866
                    quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
867
                    speculator=speculator,
868
                    dtype=dtype,
869
870
871
872
                    aliases={
                        "lm_head.weight": ["transformer.word_embeddings.weight"],
                        "transformer.word_embeddings.weight": ["lm_head.weight"],
                    },
873
                    trust_remote_code=trust_remote_code,
874
875
                    lora_adapter_ids=lora_adapter_ids,
                    config_class=RWConfig,
876
877
                )
            else:
878
                return CausalLM.fallback(
879
880
881
                    model_id,
                    revision,
                    quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
882
                    speculator=speculator,
883
                    dtype=dtype,
884
885
886
                    trust_remote_code=trust_remote_code,
                )

887
    if model_type == MISTRAL:
888
        if FLASH_ATTENTION:
889
890
891
892
            return FlashCausalLM(
                model_id=model_id,
                model_class=FlashMistralForCausalLM,
                revision=revision,
893
                quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
894
                speculator=speculator,
895
896
                dtype=dtype,
                trust_remote_code=trust_remote_code,
897
                lora_adapter_ids=lora_adapter_ids,
898
            )
OlivierDehaene's avatar
OlivierDehaene committed
899
900
901
        elif sharded:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Mistral"))
        else:
902
            return CausalLM.fallback(
OlivierDehaene's avatar
OlivierDehaene committed
903
904
905
                model_id,
                revision,
                quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
906
                speculator=speculator,
OlivierDehaene's avatar
OlivierDehaene committed
907
908
909
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )
OlivierDehaene's avatar
OlivierDehaene committed
910

911
    if model_type == MIXTRAL:
912
        if FLASH_ATTENTION:
913
914
915
916
            return FlashCausalLM(
                model_id=model_id,
                model_class=FlashMixtralForCausalLM,
                revision=revision,
OlivierDehaene's avatar
OlivierDehaene committed
917
                quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
918
                speculator=speculator,
OlivierDehaene's avatar
OlivierDehaene committed
919
                dtype=dtype,
OlivierDehaene's avatar
OlivierDehaene committed
920
                trust_remote_code=trust_remote_code,
921
                lora_adapter_ids=lora_adapter_ids,
OlivierDehaene's avatar
OlivierDehaene committed
922
            )
OlivierDehaene's avatar
OlivierDehaene committed
923
924
925
        elif sharded:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Mixtral"))
        else:
926
            return CausalLM.fallback(
OlivierDehaene's avatar
OlivierDehaene committed
927
928
929
                model_id,
                revision,
                quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
930
                speculator=speculator,
OlivierDehaene's avatar
OlivierDehaene committed
931
932
933
934
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )

935
    if model_type == STARCODER2:
936
        if FLASH_ATTENTION:
937
938
939
940
            return FlashCausalLM(
                model_id=model_id,
                model_class=FlashStarcoder2ForCausalLM,
                revision=revision,
OlivierDehaene's avatar
OlivierDehaene committed
941
                quantize=quantize,
942
                speculator=speculator,
OlivierDehaene's avatar
OlivierDehaene committed
943
                dtype=dtype,
OlivierDehaene's avatar
OlivierDehaene committed
944
                trust_remote_code=trust_remote_code,
945
                lora_adapter_ids=lora_adapter_ids,
OlivierDehaene's avatar
OlivierDehaene committed
946
947
948
949
950
951
            )
        elif sharded:
            raise NotImplementedError(
                FLASH_ATT_ERROR_MESSAGE.format("Sharded Starcoder2")
            )
        else:
952
            return CausalLM.fallback(
OlivierDehaene's avatar
OlivierDehaene committed
953
954
955
                model_id,
                revision,
                quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
956
                speculator=speculator,
OlivierDehaene's avatar
OlivierDehaene committed
957
958
959
960
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )

961
    if model_type == QWEN2:
962
        if FLASH_ATTENTION:
963
964
965
966
            return FlashCausalLM(
                model_id=model_id,
                model_class=Qwen2ForCausalLM,
                revision=revision,
OlivierDehaene's avatar
OlivierDehaene committed
967
                quantize=quantize,
968
                speculator=speculator,
OlivierDehaene's avatar
OlivierDehaene committed
969
970
                dtype=dtype,
                trust_remote_code=trust_remote_code,
971
                lora_adapter_ids=lora_adapter_ids,
OlivierDehaene's avatar
OlivierDehaene committed
972
973
974
975
            )
        elif sharded:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Qwen2"))
        else:
976
            return CausalLM.fallback(
OlivierDehaene's avatar
OlivierDehaene committed
977
978
979
                model_id,
                revision,
                quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
980
                speculator=speculator,
OlivierDehaene's avatar
OlivierDehaene committed
981
                dtype=dtype,
OlivierDehaene's avatar
OlivierDehaene committed
982
983
                trust_remote_code=trust_remote_code,
            )
984

985
    if model_type == OPT:
986
987
988
989
        return CausalLM(
            model_id=model_id,
            model_class=OPTForCausalLM,
            revision=revision,
990
            quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
991
            speculator=speculator,
992
993
            dtype=dtype,
            trust_remote_code=trust_remote_code,
994
        )
995

996
    if model_type == T5:
997
998
999
1000
        return Seq2SeqLM(
            model_id=model_id,
            model_class=T5ForConditionalGeneration,
            revision=revision,
1001
            quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
1002
            speculator=speculator,
1003
            dtype=dtype,
1004
            trust_remote_code=trust_remote_code,
1005
1006
1007
1008
1009
1010
            aliases={
                "shared.weight": [
                    "encoder.embed_tokens.weight",
                    "decoder.embed_tokens.weight",
                ]
            },
1011
        )
1012
    if model_type == IDEFICS:
1013
        if FLASH_ATTENTION:
OlivierDehaene's avatar
OlivierDehaene committed
1014
1015
1016
1017
            return IDEFICSSharded(
                model_id,
                revision,
                quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
1018
                speculator=speculator,
OlivierDehaene's avatar
OlivierDehaene committed
1019
1020
1021
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )
1022
1023
        else:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Idefics"))
1024
    if model_type == IDEFICS2:
Nicolas Patry's avatar
Nicolas Patry committed
1025
        if FLASH_ATTENTION:
1026
1027
1028
1029
            return VlmCausalLM(
                model_id=model_id,
                model_class=Idefics2ForConditionalGeneration,
                revision=revision,
Nicolas Patry's avatar
Nicolas Patry committed
1030
                quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
1031
                speculator=speculator,
Nicolas Patry's avatar
Nicolas Patry committed
1032
1033
                dtype=dtype,
                trust_remote_code=trust_remote_code,
1034
1035
1036
1037
                lora_adapter_ids=lora_adapter_ids,
                # XXX: Extremely important to cap resolution in order to limit
                # VRAM usage.
                processor_kwargs={"size": {"longest_edge": 448, "shortest_edge": 378}},
Nicolas Patry's avatar
Nicolas Patry committed
1038
1039
1040
            )
        else:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Idefics"))
1041
    if model_type == PALIGEMMA:
drbh's avatar
drbh committed
1042
        if FLASH_ATTENTION:
1043
1044
1045
1046
            return VlmCausalLM(
                model_id=model_id,
                model_class=PaliGemmaForConditionalGeneration,
                revision=revision,
drbh's avatar
drbh committed
1047
1048
1049
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
1050
1051
                # Works better for these models
                default_dtype=torch.bfloat16,
drbh's avatar
drbh committed
1052
                trust_remote_code=trust_remote_code,
1053
1054
                lora_adapter_ids=lora_adapter_ids,
                batch_class=PaliGemmaBatch,
drbh's avatar
drbh committed
1055
1056
1057
            )
        else:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Idefics"))
1058

1059
    if model_type == LLAVA_NEXT:
1060
        if FLASH_ATTENTION:
1061
1062
1063
1064
            return VlmCausalLM(
                model_class=LlavaNextForConditionalGeneration,
                model_id=model_id,
                revision=revision,
1065
                quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
1066
                speculator=speculator,
1067
1068
1069
1070
1071
1072
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )
        else:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("LlavaNext"))

1073
    if sharded:
1074
        raise NotImplementedError("sharded is not supported for AutoModel")
1075
    if quantize == "gptq":
1076
        raise NotImplementedError(
1077
1078
            "gptq quantization is not supported for AutoModel, you can try to quantize it with `text-generation-server quantize ORIGINAL_MODEL_ID NEW_MODEL_ID`"
        )
1079
    if quantize == "awq":
1080
        raise NotImplementedError("awq quantization is not supported for AutoModel")
Nicolas Patry's avatar
Nicolas Patry committed
1081
    elif (quantize == "bitsandbytes-fp4") or (quantize == "bitsandbytes-nf4"):
1082
        raise NotImplementedError("4bit quantization is not supported for AutoModel")
OlivierDehaene's avatar
OlivierDehaene committed
1083
    elif quantize == "eetq":
1084
        raise NotImplementedError("Eetq quantization is not supported for AutoModel")
1085
1086
    elif quantize == "exl2":
        raise NotImplementedError("exl2 quantization is not supported for AutoModel")
1087
    if model_type in modeling_auto.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES:
1088
        return CausalLM.fallback(
1089
1090
1091
            model_id,
            revision,
            quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
1092
            speculator=speculator,
1093
1094
            dtype=dtype,
            trust_remote_code=trust_remote_code,
1095
        )
1096
    if model_type in modeling_auto.MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES:
1097
        return Seq2SeqLM.fallback(
1098
1099
1100
            model_id,
            revision,
            quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
1101
            speculator=speculator,
1102
1103
            dtype=dtype,
            trust_remote_code=trust_remote_code,
1104
1105
        )

1106
    auto_map = config_dict.get("auto_map", None)
1107
1108
    if trust_remote_code and auto_map is not None:
        if "AutoModelForCausalLM" in auto_map.keys():
1109
            return CausalLM.fallback(
1110
1111
1112
                model_id,
                revision,
                quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
1113
                speculator=speculator,
1114
                dtype=dtype,
1115
1116
                trust_remote_code=trust_remote_code,
            )
1117
        if "AutoModelForSeq2SeqLM" in auto_map.keys():
1118
            return Seq2SeqLM.fallback(
1119
1120
1121
                model_id,
                revision,
                quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
1122
                speculator=speculator,
1123
                dtype=dtype,
1124
1125
                trust_remote_code=trust_remote_code,
            )
1126
1127

    raise ValueError(f"Unsupported model type {model_type}")
1128
1129
1130
1131
1132
1133
1134
1135
1136
1137
1138
1139
1140
1141
1142
1143
1144
1145
1146
1147
1148
1149
1150
1151
1152
1153
1154
1155
1156
1157
1158
1159
1160
1161
1162
1163
1164
1165
1166
1167
1168
1169
1170
1171
1172
1173
1174
1175
1176
1177
1178
1179
1180
1181
1182
1183
1184
1185
1186
1187
1188
1189
1190
1191
1192
1193
1194
1195
1196
1197
1198
1199
1200
1201
1202
1203
1204
1205
1206
1207
1208
1209
1210
1211
1212
1213
1214
1215
1216
1217
1218
1219
1220
1221
1222
1223
1224
1225
1226
1227
1228
1229
1230
1231
1232
1233
1234
1235
1236
1237
1238
1239
1240


def get_model_with_lora_adapters(
    model_id: str,
    lora_adapters: Optional[List[AdapterInfo]],
    revision: Optional[str],
    sharded: bool,
    quantize: Optional[str],
    speculate: Optional[int],
    dtype: Optional[str],
    trust_remote_code: bool,
    max_input_tokens: int,
    adapter_to_index: Dict[str, int],
):
    """Load a model with ``get_model`` and then load LoRA adapters into it.

    This wraps ``get_model`` and acts as a post-load hook: once the base model
    has been built, every adapter in ``lora_adapters`` is loaded and
    registered on the model under an adapter index. Indices start at 1, in
    the order of ``lora_adapters`` (index 0 is presumably reserved for the
    base model -- TODO confirm).

    Args:
        model_id: Base model identifier, forwarded to ``get_model``.
        lora_adapters: Adapters to load. ``None`` is treated as "no adapters".
        revision, sharded, quantize, speculate, dtype, trust_remote_code,
        max_input_tokens: Forwarded unchanged to ``get_model``.
        adapter_to_index: Mutated in place: receives ``adapter.id -> index``
            for every adapter that is loaded.

    Returns:
        The model returned by ``get_model``. For each adapter, its LoRA weights
        are added to ``model.layer_to_adapter_weights``, its tokenizer (if any)
        to ``model.tokenizers``, and its index to ``model.loaded_adapters``.
    """
    # The annotation allows None; treat it as an empty adapter list instead of
    # failing while building the id list below.
    lora_adapters = lora_adapters or []
    lora_adapter_ids = [adapter.id for adapter in lora_adapters]
    model = get_model(
        model_id,
        lora_adapter_ids,
        revision,
        sharded,
        quantize,
        speculate,
        dtype,
        trust_remote_code,
        max_input_tokens,
    )

    if not lora_adapters:
        return model

    target_to_layer = build_layer_weight_lookup(model.model)
    # Same for every adapter: the first element of each lookup entry is the
    # weight name handed to load_and_merge_adapters.
    weight_names = tuple(v[0] for v in target_to_layer.values())
    # Layer types for which LoRA weights are prepared and registered.
    adapter_layers = (
        "q_proj",
        "k_proj",
        "v_proj",
        "o_proj",
        "gate_proj",
        "up_proj",
        "down_proj",
    )

    for adapter_index, adapter in enumerate(lora_adapters, start=1):
        # The AdapterParameters object allows for merging multiple adapters into a single adapter.
        # At the moment, we only support loading a single adapter into the model, but we keep the
        # AdapterParameters object for easier extension in the future.
        adapter_parameters = AdapterParameters(
            adapter_info=[adapter],
            # when merging multiple adapters we can weight them differently
            # if this is not set, all adapters will be weighted equally
            # see: text_generation_server.utils.merges.strategies for impl
            weights=None,
            merge_strategy=0,
            density=1.0,
            majority_sign_method=0,
        )
        # Human-readable id list used in both log messages below. It is built
        # from `adapter_info`, which is the field populated above.
        adapter_label = ",".join(info.id for info in adapter_parameters.adapter_info)

        adapter_to_index[adapter.id] = adapter_index

        logger.info(f"Loading adapter weights into model: {adapter_label}")
        (
            module_map,
            adapter_config,
            adapter_weight_names,
            adapter_tokenizer,
        ) = load_and_merge_adapters(
            model.model_id,
            adapter_parameters,
            adapter_index,
            weight_names,
            # NOTE(review): presumably `trust_remote_code` of the helper -- TODO confirm
            False,
        )

        # prepare_weights receives this set and is expected to strip the names
        # it consumes; whatever remains is reported as unused below.
        unused_weight_names = adapter_weight_names.copy()

        for layer_name in adapter_layers:
            # "lm_head" is not currently in adapter_layers; the branch is kept
            # so adding it later gets the right layer count.
            nlayers = 1 if layer_name == "lm_head" else len(model.model.model.layers)
            adapter_weights = LoraWeights.prepare_weights(
                config=adapter_config,
                module_map=module_map,
                layer_type=layer_name,
                unused_weight_names=unused_weight_names,
                nlayers=nlayers,
                dtype=model.dtype,
                world_size=model.world_size,
                process_group=model.process_group,
                target_to_layer=target_to_layer,
            )

            if adapter_weights is None:
                continue

            model.layer_to_adapter_weights[layer_name].add_adapter(
                adapter_index, adapter_weights
            )

        if len(unused_weight_names) > 0:
            logger.warning(
                f"{adapter_label} unused adapter weights: {unused_weight_names}"
            )

        if adapter_tokenizer is not None:
            model.tokenizers.add_tokenizer(adapter_index, adapter_tokenizer)

        model.loaded_adapters.add(adapter_index)

    return model