__init__.py 42.5 KB
Newer Older
1
2
3
# ruff: noqa: F821
# the above line disables the `undefined-name` rule for the model type variables

4
import torch
5
import enum
Nicolas Patry's avatar
Nicolas Patry committed
6
import os
7

8
from loguru import logger
9
from transformers.configuration_utils import PretrainedConfig
10
from transformers.models.auto import modeling_auto
Nicolas Patry's avatar
Nicolas Patry committed
11
from huggingface_hub import hf_hub_download, HfApi
12
from typing import Optional, List, Dict
13
from pathlib import Path
14

Nicolas Patry's avatar
Nicolas Patry committed
15
from text_generation_server.utils.speculate import get_speculate, set_speculate
16
from text_generation_server.models.model import Model
17
18
19
20
21
from text_generation_server.models.causal_lm import CausalLM, CausalLMBatchKeysLast
from text_generation_server.models.custom_modeling.opt_modeling import OPTForCausalLM
from text_generation_server.models.custom_modeling.mpt_modeling import (
    MPTForCausalLM,
)
22
from text_generation_server.models.bloom import BloomCausalLMBatch
23
24
25
from text_generation_server.models.custom_modeling.bloom_modeling import (
    BloomForCausalLM,
)
26
from text_generation_server.models.seq2seq_lm import Seq2SeqLM
27
28
29
30
31
32
33
34
35
36
37
from text_generation_server.models.galactica import GalacticaCausalLMBatch
from text_generation_server.models.custom_modeling.neox_modeling import (
    GPTNeoxForCausalLM,
)
from text_generation_server.models.custom_modeling.phi_modeling import (
    PhiConfig,
    PhiForCausalLM,
)
from text_generation_server.models.custom_modeling.t5_modeling import (
    T5ForConditionalGeneration,
)
38

39
40
41
42
43
44
45
46
47
48

from text_generation_server.utils.adapter import (
    AdapterParameters,
    build_layer_weight_lookup,
    load_and_merge_adapters,
    AdapterInfo,
)
from text_generation_server.adapters.lora import LoraWeights


49
from text_generation_server.utils.import_utils import SYSTEM
50
from text_generation_server.utils.log import log_master
51

# Allow TensorFloat-32 (TF32) math for both CUDA matmuls and cuDNN.
# PyTorch 1.12+ ships with TF32 disabled for matmuls; cuDNN already
# defaults to enabled, but both knobs are set here so the policy is explicit.
# Chained assignment sets the matmul flag first, then the cuDNN flag.
torch.backends.cuda.matmul.allow_tf32 = torch.backends.cudnn.allow_tf32 = True

# Turn off gradient tracking (autograd) at import time.
torch.set_grad_enabled(False)

# Public API of the package. Entries must be attribute *names* (strings).
# Optional model classes are appended further down only when their
# imports succeed.
# NOTE(review): `get_model_with_lora_adapters` is not visible in this chunk;
# presumably it is defined later in the file -- confirm.
__all__ = [
    "Model",
    "CausalLM",
    "Seq2SeqLM",
    "get_model_with_lora_adapters",
]

# Error template for model types that cannot be served (e.g. sharded)
# without the flash-attention implementations.
FLASH_ATT_ERROR_MESSAGE = "{} requires Flash Attention enabled models."

# Optimistically assume flash-attention models are importable; the except
# branch below flips this to False if any import in the group fails.
FLASH_ATTENTION = True

try:
    from text_generation_server.models.flash_causal_lm import FlashCausalLM
    from text_generation_server.models.vlm_causal_lm import VlmCausalLM
    from text_generation_server.models.custom_modeling.flash_deepseek_v2_modeling import (
        FlashDeepseekV2ForCausalLM,
        DeepseekV2Config,
    )
    from text_generation_server.models.custom_modeling.flash_llama_modeling import (
        FlashLlamaForCausalLM,
    )
    from text_generation_server.models.custom_modeling.flash_cohere_modeling import (
        FlashCohereForCausalLM,
    )
    from text_generation_server.models.custom_modeling.flash_gemma_modeling import (
        FlashGemmaForCausalLM,
    )
    from text_generation_server.models.custom_modeling.flash_gemma2_modeling import (
        FlashGemma2ForCausalLM,
    )
    from text_generation_server.models.custom_modeling.flash_dbrx_modeling import (
        FlashDbrxForCausalLM,
        DbrxConfig,
    )
    from text_generation_server.models.custom_modeling.flash_rw_modeling import (
        RWConfig,
        FlashRWForCausalLM,
    )
    from text_generation_server.models.custom_modeling.flash_neox_modeling import (
        FlashGPTNeoXForCausalLM,
    )
    from text_generation_server.models.pali_gemma import (
        PaliGemmaBatch,
    )
    from text_generation_server.models.custom_modeling.flash_pali_gemma_modeling import (
        PaliGemmaForConditionalGeneration,
    )
    from text_generation_server.models.custom_modeling.flash_phi_modeling import (
        FlashPhiForCausalLM,
    )
    from text_generation_server.models.idefics import IDEFICSSharded
    from text_generation_server.models.custom_modeling.llava_next import (
        LlavaNextForConditionalGeneration,
    )

    from text_generation_server.models.custom_modeling.flash_santacoder_modeling import (
        FlashSantacoderForCausalLM,
    )
    from text_generation_server.models.custom_modeling.flash_starcoder2_modeling import (
        FlashStarcoder2ForCausalLM,
    )
    from text_generation_server.models.custom_modeling.flash_qwen2_modeling import (
        Qwen2ForCausalLM,
    )
    from text_generation_server.models.custom_modeling.flash_mistral_modeling import (
        FlashMistralForCausalLM,
    )
    from text_generation_server.models.custom_modeling.flash_mixtral_modeling import (
        FlashMixtralForCausalLM,
    )
    from text_generation_server.models.custom_modeling.flash_gpt2_modeling import (
        FlashGPT2ForCausalLM,
    )
    from text_generation_server.models.custom_modeling.idefics2 import (
        Idefics2ForConditionalGeneration,
    )
    # Imported last: if any import above fails, this name is never bound,
    # so the except branch must provide a fallback value for it.
    from text_generation_server.layers.attention import SUPPORTS_WINDOWING
except ImportError as e:
    log_master(logger.warning, f"Could not import Flash Attention enabled models: {e}")
    # Without the flash stack, assume no sliding-window support.
    SUPPORTS_WINDOWING = False
    FLASH_ATTENTION = False
if FLASH_ATTENTION:
    # `__all__` must contain attribute *names*. Appending the class objects
    # themselves makes a star-import of this package raise
    # `TypeError: Item in ...__all__ must be str`.
    __all__.append("FlashCausalLM")
    __all__.append("IDEFICSSharded")
# Mamba is optional: if its module cannot be imported, log a warning and
# leave it out of the public API.
MAMBA_AVAILABLE = True
try:
    from text_generation_server.models.mamba import Mamba
except ImportError as e:
    log_master(logger.warning, f"Could not import Mamba: {e}")
    MAMBA_AVAILABLE = False

if MAMBA_AVAILABLE:
    # Export the name (a string), not the class object, so that a
    # star-import of this package does not raise TypeError.
    __all__.append("Mamba")
class ModelType(enum.Enum):
    """Catalogue of model architectures known to this package.

    Each member's value is a dict with:

    * ``type`` -- the ``model_type`` string that ``get_model`` compares
      against the value read from the model's config (the loop after this
      class binds each member name to this string at module level);
    * ``name`` -- a human-readable display name;
    * ``url`` -- a reference page (checkpoint or collection) on the
      Hugging Face Hub;
    * ``multimodal`` (optional) -- flag set to ``True`` on multimodal entries.
    """

    DEEPSEEK_V2 = {
        "type": "deepseek_v2",
        "name": "Deepseek V2",
        "url": "https://huggingface.co/deepseek-ai/DeepSeek-V2",
    }
    IDEFICS2 = {
        "type": "idefics2",
        "name": "Idefics 2",
        "url": "https://huggingface.co/HuggingFaceM4/idefics2-8b",
        "multimodal": True,
    }
    LLAVA_NEXT = {
        "type": "llava_next",
        "name": "Llava Next (1.6)",
        "url": "https://huggingface.co/llava-hf/llava-v1.6-vicuna-13b-hf",
        "multimodal": True,
    }
    LLAMA = {
        "type": "llama",
        "name": "Llama",
        "url": "https://huggingface.co/meta-llama/Meta-Llama-3-8B-Instruct",
    }
    PHI3 = {
        "type": "phi3",
        "name": "Phi 3",
        "url": "https://huggingface.co/microsoft/Phi-3-mini-4k-instruct",
    }
    GEMMA = {
        "type": "gemma",
        "name": "Gemma",
        "url": "https://huggingface.co/google/gemma-7b",
    }
    PALIGEMMA = {
        "type": "paligemma",
        "name": "PaliGemma",
        "url": "https://huggingface.co/google/paligemma-3b-pt-224",
    }
    GEMMA2 = {
        "type": "gemma2",
        "name": "Gemma2",
        # The Hub repository is `gemma-2-9b`; `gemma2-9b` does not exist.
        "url": "https://huggingface.co/google/gemma-2-9b",
    }
    COHERE = {
        "type": "cohere",
        "name": "Cohere",
        "url": "https://huggingface.co/CohereForAI/c4ai-command-r-plus",
    }
    DBRX = {
        "type": "dbrx",
        "name": "Dbrx",
        "url": "https://huggingface.co/databricks/dbrx-instruct",
    }
    MAMBA = {
        "type": "ssm",
        "name": "Mamba",
        "url": "https://huggingface.co/state-spaces/mamba-2.8b-slimpj",
    }
    MISTRAL = {
        "type": "mistral",
        "name": "Mistral",
        "url": "https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.2",
    }
    MIXTRAL = {
        "type": "mixtral",
        "name": "Mixtral",
        "url": "https://huggingface.co/mistralai/Mixtral-8x22B-Instruct-v0.1",
    }
    GPT_BIGCODE = {
        "type": "gpt_bigcode",
        "name": "Gpt Bigcode",
        "url": "https://huggingface.co/bigcode/gpt_bigcode-santacoder",
    }
    PHI = {
        "type": "phi",
        "name": "Phi",
        "url": "https://huggingface.co/microsoft/phi-1_5",
    }
    BAICHUAN = {
        "type": "baichuan",
        "name": "Baichuan",
        "url": "https://huggingface.co/baichuan-inc/Baichuan2-7B-Chat",
    }
    FALCON = {
        "type": "falcon",
        "name": "Falcon",
        "url": "https://huggingface.co/tiiuae/falcon-7b-instruct",
    }
    STARCODER2 = {
        "type": "starcoder2",
        "name": "StarCoder 2",
        "url": "https://huggingface.co/bigcode/starcoder2-15b-instruct-v0.1",
    }
    QWEN2 = {
        "type": "qwen2",
        "name": "Qwen 2",
        "url": "https://huggingface.co/collections/Qwen/qwen2-6659360b33528ced941e557f",
    }
    OPT = {
        "type": "opt",
        "name": "Opt",
        "url": "https://huggingface.co/facebook/opt-6.7b",
    }
    T5 = {
        "type": "t5",
        "name": "T5",
        "url": "https://huggingface.co/google/flan-t5-xxl",
    }
    GALACTICA = {
        "type": "galactica",
        "name": "Galactica",
        "url": "https://huggingface.co/facebook/galactica-120b",
    }
    SANTACODER = {
        "type": "santacoder",
        "name": "SantaCoder",
        "url": "https://huggingface.co/bigcode/santacoder",
    }
    BLOOM = {
        "type": "bloom",
        "name": "Bloom",
        "url": "https://huggingface.co/bigscience/bloom-560m",
    }
    MPT = {
        "type": "mpt",
        "name": "Mpt",
        "url": "https://huggingface.co/mosaicml/mpt-7b-instruct",
    }
    GPT2 = {
        "type": "gpt2",
        "name": "Gpt2",
        "url": "https://huggingface.co/openai-community/gpt2",
    }
    GPT_NEOX = {
        "type": "gpt_neox",
        "name": "Gpt Neox",
        "url": "https://huggingface.co/EleutherAI/gpt-neox-20b",
    }
    IDEFICS = {
        "type": "idefics",
        "name": "Idefics",
        "url": "https://huggingface.co/HuggingFaceM4/idefics-9b",
        "multimodal": True,
    }


# Bind every ModelType member name to its `type` string at module level
# (e.g. LLAMA == "llama") so `get_model` can compare `model_type` against
# bare constants. This is why the file disables ruff's F821 rule at the top.
# `globals()` is used explicitly instead of mutating `locals()`, and the
# comprehension keeps the loop variable from leaking into the package
# namespace.
globals().update({member.name: member.value["type"] for member in ModelType})


310
def get_model(
311
    model_id: str,
drbh's avatar
drbh committed
312
    lora_adapter_ids: Optional[List[str]],
313
314
315
    revision: Optional[str],
    sharded: bool,
    quantize: Optional[str],
Nicolas Patry's avatar
Nicolas Patry committed
316
    speculate: Optional[int],
317
    dtype: Optional[str],
318
    trust_remote_code: bool,
319
    max_input_tokens: int,
320
) -> Model:
321
    global FLASH_ATTENTION
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339

    config_dict, _ = PretrainedConfig.get_config_dict(
        model_id, revision=revision, trust_remote_code=trust_remote_code
    )
    model_type = config_dict.get("model_type", None)

    quantization_config = config_dict.get("quantization_config", None)
    if quantization_config is not None and quantize is None:
        method = quantization_config.get("quant_method", None)
        if method in {"gptq", "awq", "exl2"}:
            log_master(logger.info, f"Auto selecting quantization method {method}")
            quantize = method
        elif method == "fbgemm_fp8":
            log_master(logger.info, "Auto selecting quantization method fp8")
            quantize = "fp8"
        else:
            log_master(logger.warning, f"Unknown quantization method {method}")

340
    if dtype is None:
341
        if quantize in ["awq", "exl2", "gptq", "marlin"]:
342
343
            # These quantizers only work with float16 params.
            dtype = torch.float16
344
        elif quantize == "fp8":
345
            from text_generation_server.layers.fp8 import FBGEMM_DYN_AVAILABLE
346

347
            if FBGEMM_DYN_AVAILABLE:
348
349
                # fbgemm kernels are fp8xfp8->bf16
                dtype = torch.bfloat16
350
351
352
353
        else:
            # Keep it as default for now and let
            # every model resolve their own default dtype.
            dtype = None
354
355
356
357
358
359
360
    elif dtype == "float16":
        dtype = torch.float16
    elif dtype == "bfloat16":
        dtype = torch.bfloat16
    else:
        raise RuntimeError(f"Unknown dtype {dtype}")

Nicolas Patry's avatar
Nicolas Patry committed
361
362
363
364
365
    if speculate is not None:
        set_speculate(speculate)
    else:
        set_speculate(0)

Nicolas Patry's avatar
Nicolas Patry committed
366
    speculator = None
Nicolas Patry's avatar
Nicolas Patry committed
367
    if "medusa_num_heads" in config_dict:
368
369
        medusa_model_id = model_id
        medusa_revision = revision
Nicolas Patry's avatar
Nicolas Patry committed
370
371
372
373
374
        model_id = config_dict["base_model_name_or_path"]
        revision = "main"
        speculate_medusa = config_dict["medusa_num_heads"]
        if speculate is not None:
            if speculate > speculate_medusa:
OlivierDehaene's avatar
OlivierDehaene committed
375
                raise RuntimeError(
OlivierDehaene's avatar
OlivierDehaene committed
376
                    f"Speculate is set to `{speculate}` but this medusa models only has `{speculate_medusa}` heads, please make them match"
OlivierDehaene's avatar
OlivierDehaene committed
377
                )
Nicolas Patry's avatar
Nicolas Patry committed
378
379
380
381
382
383
384
385
            else:
                set_speculate(speculate)
        else:
            set_speculate(speculate_medusa)

        config_dict, _ = PretrainedConfig.get_config_dict(
            model_id, revision=revision, trust_remote_code=trust_remote_code
        )
Nicolas Patry's avatar
Nicolas Patry committed
386
387
        # Reload model type from parent.
        model_type = config_dict.get("model_type", None)
388
389
390
391
392
393
394
395
396
397
        is_local = Path(medusa_model_id).exists()
        if not is_local:
            medusa_config = hf_hub_download(
                medusa_model_id, revision=medusa_revision, filename="config.json"
            )
            hf_hub_download(
                medusa_model_id,
                revision=medusa_revision,
                filename="medusa_lm_head.safetensors",
            )
Nicolas Patry's avatar
Nicolas Patry committed
398
399
400
401
            speculator = {
                "path": Path(medusa_config).parent,
                "model_paths": ["medusa_lm_head.safetensors"],
            }
402
        else:
Nicolas Patry's avatar
Nicolas Patry committed
403
404
405
406
            speculator = {
                "path": Path(medusa_model_id),
                "model_paths": ["medusa_lm_head.safetensors"],
            }
407

Nicolas Patry's avatar
Nicolas Patry committed
408
        method = "medusa"
Nicolas Patry's avatar
Nicolas Patry committed
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
    elif model_type == "mlp_speculator":
        mlp_model_id = model_id
        mlp_revision = revision
        model_id = config_dict["base_model_name_or_path"]
        revision = "main"
        speculate_mlp = config_dict["n_predict"]
        if speculate is not None:
            if speculate > speculate_mlp:
                raise RuntimeError(
                    f"Speculate is set to `{speculate}` but this mlp_speculator models only has `{speculate_mlp}` heads, please make them match"
                )
            else:
                set_speculate(speculate)
        else:
            set_speculate(speculate_mlp)

        config_dict, _ = PretrainedConfig.get_config_dict(
            model_id, revision=revision, trust_remote_code=trust_remote_code
        )
        # Reload model type from parent.
        model_type = config_dict.get("model_type", None)
        is_local = Path(mlp_model_id).exists()
        extension = ".safetensors"
        if not is_local:
            mlp_speculator_config = hf_hub_download(
                mlp_model_id, revision=mlp_revision, filename="config.json"
            )
            api = HfApi()
            info = api.model_info(mlp_model_id, revision=mlp_revision)
            filenames = [
                s.rfilename
                for s in info.siblings
                if s.rfilename.endswith(extension)
                and len(s.rfilename.split("/")) == 1
                and "arguments" not in s.rfilename
                and "args" not in s.rfilename
                and "training" not in s.rfilename
            ]
            for filename in filenames:
                hf_hub_download(
                    mlp_model_id,
                    revision=mlp_revision,
                    filename=filename,
                )
            speculator = {
                "path": Path(mlp_speculator_config).parent,
                "model_paths": filenames,
            }
        else:
            speculator = Path(mlp_model_id)
            filenames = [p for p in os.listdir(speculator) if p.endswith(extension)]
            speculator = {"path": speculator, "model_paths": filenames}
        method = "mlp_speculator"
Nicolas Patry's avatar
Nicolas Patry committed
462
463
464
465
466
    else:
        method = "n-gram"

    speculate = get_speculate()
    if speculate > 0:
467
468
469
        log_master(
            logger.info, f"Using speculation {method} with {speculate} input ids."
        )
Nicolas Patry's avatar
Nicolas Patry committed
470

drbh's avatar
drbh committed
471
472
473
474
475
476
477
478
479
480
    if model_type is None:
        # TODO: fix how we determine model type for Mamba
        if "ssm_cfg" in config_dict:
            # *only happens in Mamba case
            model_type = "ssm"
        else:
            raise RuntimeError(
                f"Could not determine model type for {model_id} revision {revision}"
            )

481
482
483
484
    if quantize == "exl2" and sharded:
        raise RuntimeError(
            "Sharding is currently not supported with `exl2` quantization"
        )
485
    sliding_window = config_dict.get("sliding_window", -1)
486
487
488
489
490
491
492
493

    if (
        (sliding_window is not None and sliding_window != -1)
        and not SUPPORTS_WINDOWING
        and max_input_tokens > sliding_window
    ):
        raise ValueError(
            f"The backend {SYSTEM} does not support sliding window attention that is used by the model type {model_type}. To use this model nonetheless with the {SYSTEM} backend, please launch TGI with the argument `--max-input-tokens` smaller than sliding_window={sliding_window} (got here max_input_tokens={max_input_tokens})."
494
        )
495

496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
    if model_type == DEEPSEEK_V2:
        if FLASH_ATTENTION:
            head_size = max(
                config_dict.get("qk_nope_dim", 128)
                + config_dict.get("qk_rope_dim", 64),
                config_dict.get("v_head_dim", 128),
            )
            return FlashCausalLM(
                model_id=model_id,
                model_class=FlashDeepseekV2ForCausalLM,
                revision=revision,
                quantize=quantize,
                speculator=speculator,
                default_dtype=torch.bfloat16,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
                lora_adapter_ids=lora_adapter_ids,
                config_class=DeepseekV2Config,
                head_size=head_size,
            )
        elif sharded:
            raise NotImplementedError(
                FLASH_ATT_ERROR_MESSAGE.format("Sharded Deepseek V2")
            )
        else:
            return CausalLM.fallback(
                model_id,
                revision,
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )
    elif model_type == MAMBA:
drbh's avatar
drbh committed
530
531
532
533
        return Mamba(
            model_id,
            revision,
            quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
534
            speculator=speculator,
drbh's avatar
drbh committed
535
536
537
            dtype=dtype,
            trust_remote_code=trust_remote_code,
        )
538

OlivierDehaene's avatar
OlivierDehaene committed
539
    if model_id.startswith("facebook/galactica"):
540
541
542
543
544
        return CausalLM(
            model_id=model_id,
            # Yes galactica is just an OPT model.
            model_class=OPTForCausalLM,
            revision=revision,
OlivierDehaene's avatar
OlivierDehaene committed
545
            quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
546
            speculator=speculator,
OlivierDehaene's avatar
OlivierDehaene committed
547
548
            dtype=dtype,
            trust_remote_code=trust_remote_code,
549
            batch_class=GalacticaCausalLMBatch,
OlivierDehaene's avatar
OlivierDehaene committed
550
551
        )

552
    if (
553
554
        model_type == GPT_BIGCODE
        or model_type == GPT2
555
556
        and model_id.startswith("bigcode/")
    ):
557
        if FLASH_ATTENTION:
558
559
560
561
            return FlashCausalLM(
                model_id=model_id,
                model_class=FlashSantacoderForCausalLM,
                revision=revision,
562
                quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
563
                speculator=speculator,
564
                dtype=dtype,
565
                trust_remote_code=trust_remote_code,
566
567
568
                lora_adapter_ids=lora_adapter_ids,
                aliases={"transformer.wte.weight": ["lm_head.weight"]},
                num_kv_heads=1,
569
            )
570
571
572
573
        elif sharded:
            raise NotImplementedError(
                FLASH_ATT_ERROR_MESSAGE.format("Sharded Santacoder")
            )
574
        else:
575
576
577
            return CausalLM.fallback(
                model_id=model_id,
                revision=revision,
578
                quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
579
                speculator=speculator,
580
                dtype=dtype,
581
582
                trust_remote_code=trust_remote_code,
            )
583

584
    if model_type == BLOOM:
585
586
587
588
        return CausalLM(
            model_id=model_id,
            model_class=BloomForCausalLM,
            revision=revision,
589
            quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
590
            speculator=speculator,
591
592
            dtype=dtype,
            trust_remote_code=trust_remote_code,
593
            batch_class=BloomCausalLMBatch,
594
        )
595
    elif model_type == MPT:
596
597
598
599
        return CausalLM(
            model_id=model_id,
            model_class=MPTForCausalLM,
            revision=revision,
OlivierDehaene's avatar
OlivierDehaene committed
600
            quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
601
            speculator=speculator,
OlivierDehaene's avatar
OlivierDehaene committed
602
603
            dtype=dtype,
            trust_remote_code=trust_remote_code,
604
            batch_class=CausalLMBatchKeysLast,
605
        )
606
    elif model_type == GPT2:
607
        if FLASH_ATTENTION:
608
            try:
609
610
611
612
                return FlashCausalLM(
                    model_id=model_id,
                    model_class=FlashGPT2ForCausalLM,
                    revision=revision,
613
614
615
616
                    quantize=quantize,
                    speculator=speculator,
                    dtype=dtype,
                    trust_remote_code=trust_remote_code,
617
                    lora_adapter_ids=lora_adapter_ids,
618
619
620
                )
            except RuntimeError as e:
                # Lots of legacy models with various weight names.
621
                log_master(logger.warning, f"Couldn't load flash gpt2 variant: {e}")
622
                return CausalLM.fallback(
623
624
625
626
627
628
629
                    model_id,
                    revision,
                    quantize=quantize,
                    speculator=speculator,
                    dtype=dtype,
                    trust_remote_code=trust_remote_code,
                )
630
631
632
        elif sharded:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded GPT-2"))
        else:
633
            return CausalLM.fallback(
634
635
636
637
638
639
640
                model_id,
                revision,
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )
641
    elif model_type == GPT_NEOX:
642
        if FLASH_ATTENTION:
643
644
645
646
            from text_generation_server.models.custom_modeling.flash_neox_modeling import (
                GPTNeoXConfig,
            )

647
648
649
650
            return FlashCausalLM(
                model_id=model_id,
                model_class=FlashGPTNeoXForCausalLM,
                revision=revision,
651
                quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
652
                speculator=speculator,
653
                dtype=dtype,
654
                trust_remote_code=trust_remote_code,
655
                lora_adapter_ids=lora_adapter_ids,
656
                config_class=GPTNeoXConfig,
657
658
            )
        elif sharded:
659
660
661
662
            return CausalLM(
                model_id=model_id,
                model_class=GPTNeoxForCausalLM,
                revision=revision,
663
                quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
664
                speculator=speculator,
665
                dtype=dtype,
666
667
                trust_remote_code=trust_remote_code,
            )
668
        else:
669
            return CausalLM.fallback(
670
671
672
                model_id,
                revision,
                quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
673
                speculator=speculator,
674
                dtype=dtype,
675
676
                trust_remote_code=trust_remote_code,
            )
OlivierDehaene's avatar
OlivierDehaene committed
677

678
    elif model_type == PHI:
drbh's avatar
drbh committed
679
        if FLASH_ATTENTION:
680
681
682
683
            return FlashCausalLM(
                model_id=model_id,
                model_class=FlashPhiForCausalLM,
                revision=revision,
drbh's avatar
drbh committed
684
                quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
685
                speculator=speculator,
drbh's avatar
drbh committed
686
687
                dtype=dtype,
                trust_remote_code=trust_remote_code,
688
                lora_adapter_ids=lora_adapter_ids,
drbh's avatar
drbh committed
689
690
            )
        else:
691
            return CausalLM.fallback(
drbh's avatar
drbh committed
692
693
694
                model_id,
                revision,
                quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
695
                speculator=speculator,
drbh's avatar
drbh committed
696
697
698
699
700
701
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )

    elif model_type == "phi-msft":
        if FLASH_ATTENTION:
OlivierDehaene's avatar
OlivierDehaene committed
702
703
704
            raise NotImplementedError(
                "Legacy phi-msft is not supported with Flash Attention"
            )
drbh's avatar
drbh committed
705
        else:
706
707
708
709
710
            return CausalLM(
                model_id=model_id,
                model_class=PhiForCausalLM,
                config_class=PhiConfig,
                revision=revision,
drbh's avatar
drbh committed
711
                quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
712
                speculator=speculator,
drbh's avatar
drbh committed
713
714
715
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )
716

717
    elif model_type == LLAMA or model_type == BAICHUAN or model_type == PHI3:
718
        print(f">>> model_type: {model_type}")
719
        if FLASH_ATTENTION:
720
721
722
723
            return FlashCausalLM(
                model_id=model_id,
                model_class=FlashLlamaForCausalLM,
                revision=revision,
724
                quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
725
                speculator=speculator,
726
                dtype=dtype,
727
                trust_remote_code=trust_remote_code,
drbh's avatar
drbh committed
728
                lora_adapter_ids=lora_adapter_ids,
729
            )
730
731
        elif sharded:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Llama"))
732
        else:
733
            return CausalLM.fallback(
734
735
736
                model_id,
                revision,
                quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
737
                speculator=speculator,
738
                dtype=dtype,
739
740
                trust_remote_code=trust_remote_code,
            )
741
    if model_type == GEMMA:
742
        if FLASH_ATTENTION:
743
744
745
746
            return FlashCausalLM(
                model_id=model_id,
                model_class=FlashGemmaForCausalLM,
                revision=revision,
747
                quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
748
                speculator=speculator,
749
                dtype=dtype,
750
751
                # Works better for these models
                default_dtype=torch.bfloat16,
752
                trust_remote_code=trust_remote_code,
753
                lora_adapter_ids=lora_adapter_ids,
754
755
            )
        elif sharded:
OlivierDehaene's avatar
OlivierDehaene committed
756
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Gemma"))
757
        else:
758
            return CausalLM.fallback(
759
760
761
                model_id,
                revision,
                quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
762
                speculator=speculator,
763
764
765
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )
Nicolas Patry's avatar
Nicolas Patry committed
766
767
    elif model_type == GEMMA2:
        if FLASH_ATTENTION:
768
769
770
771
            return FlashCausalLM(
                model_id=model_id,
                model_class=FlashGemma2ForCausalLM,
                revision=revision,
Nicolas Patry's avatar
Nicolas Patry committed
772
773
774
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
775
776
                # Works better for these models
                default_dtype=torch.bfloat16,
Nicolas Patry's avatar
Nicolas Patry committed
777
                trust_remote_code=trust_remote_code,
778
                lora_adapter_ids=lora_adapter_ids,
Nicolas Patry's avatar
Nicolas Patry committed
779
780
781
782
            )
        elif sharded:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Gemma2"))
        else:
783
            return CausalLM.fallback(
Nicolas Patry's avatar
Nicolas Patry committed
784
785
786
787
788
789
790
                model_id,
                revision,
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )
791

792
    if model_type == COHERE:
OlivierDehaene's avatar
OlivierDehaene committed
793
        if FLASH_ATTENTION:
794
795
796
797
            return FlashCausalLM(
                model_id=model_id,
                model_class=FlashCohereForCausalLM,
                revision=revision,
OlivierDehaene's avatar
OlivierDehaene committed
798
                quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
799
                speculator=speculator,
OlivierDehaene's avatar
OlivierDehaene committed
800
801
                dtype=dtype,
                trust_remote_code=trust_remote_code,
802
                lora_adapter_ids=lora_adapter_ids,
OlivierDehaene's avatar
OlivierDehaene committed
803
804
805
806
            )
        elif sharded:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Cohere"))
        else:
807
            return CausalLM.fallback(
OlivierDehaene's avatar
OlivierDehaene committed
808
809
810
                model_id,
                revision,
                quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
811
                speculator=speculator,
OlivierDehaene's avatar
OlivierDehaene committed
812
813
814
815
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )

816
    if model_type == DBRX:
817
        if FLASH_ATTENTION:
818
819
820
821
            return FlashCausalLM(
                model_id=model_id,
                model_class=FlashDbrxForCausalLM,
                revision=revision,
822
                quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
823
                speculator=speculator,
824
                dtype=dtype,
825
826
                # Dbrx works better in bfloat16.
                default_dtype=torch.bfloat16,
827
                trust_remote_code=trust_remote_code,
828
829
                lora_adapter_ids=lora_adapter_ids,
                config_class=DbrxConfig,
830
831
832
833
            )
        elif sharded:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded DBRX"))
        else:
834
            return CausalLM.fallback(
835
836
837
                model_id,
                revision,
                quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
838
                speculator=speculator,
839
840
841
842
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )

843
    if model_type in ["RefinedWeb", "RefinedWebModel", FALCON]:
844
845
        if sharded:
            if FLASH_ATTENTION:
846
                if config_dict.get("alibi", False):
847
                    raise NotImplementedError("sharded is not supported for this model")
848
849
850
851
                return FlashCausalLM(
                    model_id=model_id,
                    model_class=FlashRWForCausalLM,
                    revision=revision,
852
                    quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
853
                    speculator=speculator,
854
                    dtype=dtype,
855
856
857
858
                    aliases={
                        "lm_head.weight": ["transformer.word_embeddings.weight"],
                        "transformer.word_embeddings.weight": ["lm_head.weight"],
                    },
859
                    trust_remote_code=trust_remote_code,
860
861
                    lora_adapter_ids=lora_adapter_ids,
                    config_class=RWConfig,
862
                )
863
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Falcon"))
864
        else:
865
            if FLASH_ATTENTION and not config_dict.get("alibi", False):
866
867
868
869
                return FlashCausalLM(
                    model_id=model_id,
                    model_class=FlashRWForCausalLM,
                    revision=revision,
870
                    quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
871
                    speculator=speculator,
872
                    dtype=dtype,
873
874
875
876
                    aliases={
                        "lm_head.weight": ["transformer.word_embeddings.weight"],
                        "transformer.word_embeddings.weight": ["lm_head.weight"],
                    },
877
                    trust_remote_code=trust_remote_code,
878
879
                    lora_adapter_ids=lora_adapter_ids,
                    config_class=RWConfig,
880
881
                )
            else:
882
                return CausalLM.fallback(
883
884
885
                    model_id,
                    revision,
                    quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
886
                    speculator=speculator,
887
                    dtype=dtype,
888
889
890
                    trust_remote_code=trust_remote_code,
                )

891
    if model_type == MISTRAL:
892
        if FLASH_ATTENTION:
893
894
895
896
            return FlashCausalLM(
                model_id=model_id,
                model_class=FlashMistralForCausalLM,
                revision=revision,
897
                quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
898
                speculator=speculator,
899
900
                dtype=dtype,
                trust_remote_code=trust_remote_code,
901
                lora_adapter_ids=lora_adapter_ids,
902
            )
OlivierDehaene's avatar
OlivierDehaene committed
903
904
905
        elif sharded:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Mistral"))
        else:
906
            return CausalLM.fallback(
OlivierDehaene's avatar
OlivierDehaene committed
907
908
909
                model_id,
                revision,
                quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
910
                speculator=speculator,
OlivierDehaene's avatar
OlivierDehaene committed
911
912
913
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )
OlivierDehaene's avatar
OlivierDehaene committed
914

915
    if model_type == MIXTRAL:
916
        if FLASH_ATTENTION:
917
918
919
920
            return FlashCausalLM(
                model_id=model_id,
                model_class=FlashMixtralForCausalLM,
                revision=revision,
OlivierDehaene's avatar
OlivierDehaene committed
921
                quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
922
                speculator=speculator,
OlivierDehaene's avatar
OlivierDehaene committed
923
                dtype=dtype,
OlivierDehaene's avatar
OlivierDehaene committed
924
                trust_remote_code=trust_remote_code,
925
                lora_adapter_ids=lora_adapter_ids,
OlivierDehaene's avatar
OlivierDehaene committed
926
            )
OlivierDehaene's avatar
OlivierDehaene committed
927
928
929
        elif sharded:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Mixtral"))
        else:
930
            return CausalLM.fallback(
OlivierDehaene's avatar
OlivierDehaene committed
931
932
933
                model_id,
                revision,
                quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
934
                speculator=speculator,
OlivierDehaene's avatar
OlivierDehaene committed
935
936
937
938
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )

939
    if model_type == STARCODER2:
940
        if FLASH_ATTENTION:
941
942
943
944
            return FlashCausalLM(
                model_id=model_id,
                model_class=FlashStarcoder2ForCausalLM,
                revision=revision,
OlivierDehaene's avatar
OlivierDehaene committed
945
                quantize=quantize,
946
                speculator=speculator,
OlivierDehaene's avatar
OlivierDehaene committed
947
                dtype=dtype,
OlivierDehaene's avatar
OlivierDehaene committed
948
                trust_remote_code=trust_remote_code,
949
                lora_adapter_ids=lora_adapter_ids,
OlivierDehaene's avatar
OlivierDehaene committed
950
951
952
953
954
955
            )
        elif sharded:
            raise NotImplementedError(
                FLASH_ATT_ERROR_MESSAGE.format("Sharded Starcoder2")
            )
        else:
956
            return CausalLM.fallback(
OlivierDehaene's avatar
OlivierDehaene committed
957
958
959
                model_id,
                revision,
                quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
960
                speculator=speculator,
OlivierDehaene's avatar
OlivierDehaene committed
961
962
963
964
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )

965
    if model_type == QWEN2:
966
        if FLASH_ATTENTION:
967
968
969
970
            return FlashCausalLM(
                model_id=model_id,
                model_class=Qwen2ForCausalLM,
                revision=revision,
OlivierDehaene's avatar
OlivierDehaene committed
971
                quantize=quantize,
972
                speculator=speculator,
OlivierDehaene's avatar
OlivierDehaene committed
973
974
                dtype=dtype,
                trust_remote_code=trust_remote_code,
975
                lora_adapter_ids=lora_adapter_ids,
OlivierDehaene's avatar
OlivierDehaene committed
976
977
978
979
            )
        elif sharded:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Qwen2"))
        else:
980
            return CausalLM.fallback(
OlivierDehaene's avatar
OlivierDehaene committed
981
982
983
                model_id,
                revision,
                quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
984
                speculator=speculator,
OlivierDehaene's avatar
OlivierDehaene committed
985
                dtype=dtype,
OlivierDehaene's avatar
OlivierDehaene committed
986
987
                trust_remote_code=trust_remote_code,
            )
988

989
    if model_type == OPT:
990
991
992
993
        return CausalLM(
            model_id=model_id,
            model_class=OPTForCausalLM,
            revision=revision,
994
            quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
995
            speculator=speculator,
996
997
            dtype=dtype,
            trust_remote_code=trust_remote_code,
998
        )
999

1000
    if model_type == T5:
1001
1002
1003
1004
        return Seq2SeqLM(
            model_id=model_id,
            model_class=T5ForConditionalGeneration,
            revision=revision,
1005
            quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
1006
            speculator=speculator,
1007
            dtype=dtype,
1008
            trust_remote_code=trust_remote_code,
1009
1010
1011
1012
1013
1014
            aliases={
                "shared.weight": [
                    "encoder.embed_tokens.weight",
                    "decoder.embed_tokens.weight",
                ]
            },
1015
        )
1016
    if model_type == IDEFICS:
1017
        if FLASH_ATTENTION:
OlivierDehaene's avatar
OlivierDehaene committed
1018
1019
1020
1021
            return IDEFICSSharded(
                model_id,
                revision,
                quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
1022
                speculator=speculator,
OlivierDehaene's avatar
OlivierDehaene committed
1023
1024
1025
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )
1026
1027
        else:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Idefics"))
1028
    if model_type == IDEFICS2:
Nicolas Patry's avatar
Nicolas Patry committed
1029
        if FLASH_ATTENTION:
1030
1031
1032
1033
            return VlmCausalLM(
                model_id=model_id,
                model_class=Idefics2ForConditionalGeneration,
                revision=revision,
Nicolas Patry's avatar
Nicolas Patry committed
1034
                quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
1035
                speculator=speculator,
Nicolas Patry's avatar
Nicolas Patry committed
1036
1037
                dtype=dtype,
                trust_remote_code=trust_remote_code,
1038
1039
1040
1041
                lora_adapter_ids=lora_adapter_ids,
                # XXX: Extremely important to cap resolution in order to limit
                # VRAM usage.
                processor_kwargs={"size": {"longest_edge": 448, "shortest_edge": 378}},
Nicolas Patry's avatar
Nicolas Patry committed
1042
1043
1044
            )
        else:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Idefics"))
1045
    if model_type == PALIGEMMA:
drbh's avatar
drbh committed
1046
        if FLASH_ATTENTION:
1047
1048
1049
1050
            return VlmCausalLM(
                model_id=model_id,
                model_class=PaliGemmaForConditionalGeneration,
                revision=revision,
drbh's avatar
drbh committed
1051
1052
1053
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
1054
1055
                # Works better for these models
                default_dtype=torch.bfloat16,
drbh's avatar
drbh committed
1056
                trust_remote_code=trust_remote_code,
1057
1058
                lora_adapter_ids=lora_adapter_ids,
                batch_class=PaliGemmaBatch,
drbh's avatar
drbh committed
1059
1060
1061
            )
        else:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Idefics"))
1062

1063
    if model_type == LLAVA_NEXT:
1064
        if FLASH_ATTENTION:
1065
1066
1067
1068
            return VlmCausalLM(
                model_class=LlavaNextForConditionalGeneration,
                model_id=model_id,
                revision=revision,
1069
                quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
1070
                speculator=speculator,
1071
1072
1073
1074
1075
1076
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )
        else:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("LlavaNext"))

1077
    if sharded:
1078
        raise NotImplementedError("sharded is not supported for AutoModel")
1079
    if quantize == "gptq":
1080
        raise NotImplementedError(
1081
1082
            "gptq quantization is not supported for AutoModel, you can try to quantize it with `text-generation-server quantize ORIGINAL_MODEL_ID NEW_MODEL_ID`"
        )
1083
    if quantize == "awq":
1084
        raise NotImplementedError("awq quantization is not supported for AutoModel")
Nicolas Patry's avatar
Nicolas Patry committed
1085
    elif (quantize == "bitsandbytes-fp4") or (quantize == "bitsandbytes-nf4"):
1086
        raise NotImplementedError("4bit quantization is not supported for AutoModel")
OlivierDehaene's avatar
OlivierDehaene committed
1087
    elif quantize == "eetq":
1088
        raise NotImplementedError("Eetq quantization is not supported for AutoModel")
1089
1090
    elif quantize == "exl2":
        raise NotImplementedError("exl2 quantization is not supported for AutoModel")
1091
    if model_type in modeling_auto.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES:
1092
        return CausalLM.fallback(
1093
1094
1095
            model_id,
            revision,
            quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
1096
            speculator=speculator,
1097
1098
            dtype=dtype,
            trust_remote_code=trust_remote_code,
1099
        )
1100
    if model_type in modeling_auto.MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES:
1101
        return Seq2SeqLM.fallback(
1102
1103
1104
            model_id,
            revision,
            quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
1105
            speculator=speculator,
1106
1107
            dtype=dtype,
            trust_remote_code=trust_remote_code,
1108
1109
        )

1110
    auto_map = config_dict.get("auto_map", None)
1111
1112
    if trust_remote_code and auto_map is not None:
        if "AutoModelForCausalLM" in auto_map.keys():
1113
            return CausalLM.fallback(
1114
1115
1116
                model_id,
                revision,
                quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
1117
                speculator=speculator,
1118
                dtype=dtype,
1119
1120
                trust_remote_code=trust_remote_code,
            )
1121
        if "AutoModelForSeq2SeqLM" in auto_map.keys():
1122
            return Seq2SeqLM.fallback(
1123
1124
1125
                model_id,
                revision,
                quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
1126
                speculator=speculator,
1127
                dtype=dtype,
1128
1129
                trust_remote_code=trust_remote_code,
            )
1130
1131

    raise ValueError(f"Unsupported model type {model_type}")
1132
1133
1134
1135
1136
1137
1138
1139
1140
1141
1142
1143
1144
1145
1146
1147
1148
1149
1150
1151
1152
1153
1154
1155
1156
1157
1158
1159
1160
1161
1162
1163
1164
1165
1166
1167
1168
1169
1170
1171
1172
1173
1174
1175
1176
1177
1178
1179
1180
1181
1182
1183
1184
1185
1186
1187
1188
1189
1190
1191
1192
1193
1194
1195
1196
1197
1198
1199
1200
1201
1202
1203
1204
1205
1206
1207
1208
1209
1210
1211
1212
1213
1214
1215
1216
1217
1218
1219
1220
1221
1222
1223
1224
1225
1226
1227
1228
1229
1230
1231
1232
1233
1234
1235
1236
1237
1238
1239
1240
1241
1242
1243
1244


def get_model_with_lora_adapters(
    model_id: str,
    lora_adapters: Optional[List[AdapterInfo]],
    revision: Optional[str],
    sharded: bool,
    quantize: Optional[str],
    speculate: Optional[int],
    dtype: Optional[str],
    trust_remote_code: bool,
    max_input_tokens: int,
    adapter_to_index: Dict[str, int],
):
    """Load a model with ``get_model`` and then load LoRA adapters into it.

    This wraps ``get_model`` and acts as a post-load hook. The base model is
    built first, and it is given the adapter ids up front. Then each adapter
    in ``lora_adapters`` is loaded, its weights are prepared for every target
    projection layer, and the results are registered on the model.

    Args:
        model_id: Model identifier forwarded to ``get_model`` and to the
            adapter loader.
        lora_adapters: Adapters to load, in order. ``None`` is treated the
            same as an empty list, so no adapters are loaded.
        revision, sharded, quantize, speculate, dtype, trust_remote_code,
        max_input_tokens: Forwarded unchanged to ``get_model``.
        adapter_to_index: Mutated in place. Each adapter id is mapped to the
            index its weights are registered under. Indices start at 1 and
            follow the order of ``lora_adapters``.

    Returns:
        The model returned by ``get_model``. Any adapter weights, adapter
        tokenizers and loaded-adapter indices are registered on it.
    """
    # The parameter is Optional: treat None as "no adapters" instead of
    # failing when iterating / taking len() below.
    lora_adapters = lora_adapters or []
    lora_adapter_ids = [adapter.id for adapter in lora_adapters]
    model = get_model(
        model_id,
        lora_adapter_ids,
        revision,
        sharded,
        quantize,
        speculate,
        dtype,
        trust_remote_code,
        max_input_tokens,
    )

    if not lora_adapters:
        return model

    # The values below are the same for every adapter, so compute them once.
    target_to_layer = build_layer_weight_lookup(model.model)
    # First element of each lookup entry; handed to the adapter loader.
    weight_names = tuple(entry[0] for entry in target_to_layer.values())
    # Projection layers that LoRA weights are prepared for.
    adapter_layers = (
        "q_proj",
        "k_proj",
        "v_proj",
        "o_proj",
        "gate_proj",
        "up_proj",
        "down_proj",
    )
    num_layers = len(model.model.model.layers)

    # Adapters are numbered from 1; index 0 is never assigned to an adapter.
    # NOTE(review): presumably 0 is reserved for the base model - confirm.
    for adapter_index, adapter in enumerate(lora_adapters, start=1):
        # The AdapterParameters object allows for merging multiple adapters into a single adapter.
        # At the moment, we only support loading a single adapter into the model, but we keep the
        # AdapterParameters object for easier extension in the future.
        adapter_parameters = AdapterParameters(
            adapter_info=[adapter],
            # when merging multiple adapters we can weight them differently
            # if this is not set, all adapters will be weighted equally
            # see: text_generation_server.utils.merges.strategies for impl
            weights=None,
            merge_strategy=0,
            density=1.0,
            majority_sign_method=0,
        )

        adapter_to_index[adapter.id] = adapter_index

        # Build the id string from ``adapter_info`` (the field this object is
        # constructed with) and reuse it in both log messages. The previous
        # warning read a non-constructed ``adapter_ids`` attribute.
        adapter_ids_str = ",".join(
            info.id for info in adapter_parameters.adapter_info
        )
        logger.info(f"Loading adapter weights into model: {adapter_ids_str}")

        (
            module_map,
            adapter_config,
            adapter_weight_names,
            adapter_tokenizer,
        ) = load_and_merge_adapters(
            model.model_id,
            adapter_parameters,
            adapter_index,
            weight_names,
            # NOTE(review): presumably the loader's trust_remote_code flag - confirm.
            False,
        )

        # Copied so adapter_weight_names itself is not modified. Whatever is
        # still in this collection after preparing all layers is reported as
        # unused (presumably prepare_weights removes the names it consumes).
        unused_weight_names = adapter_weight_names.copy()

        for layer_name in adapter_layers:
            # "lm_head" is not in adapter_layers today. The check is kept so
            # that adding it later prepares a single layer instead of one per
            # transformer layer.
            nlayers = 1 if layer_name == "lm_head" else num_layers
            adapter_weights = LoraWeights.prepare_weights(
                config=adapter_config,
                module_map=module_map,
                layer_type=layer_name,
                unused_weight_names=unused_weight_names,
                nlayers=nlayers,
                dtype=model.dtype,
                world_size=model.world_size,
                process_group=model.process_group,
                target_to_layer=target_to_layer,
            )

            # Nothing to register for this layer type.
            if adapter_weights is None:
                continue

            model.layer_to_adapter_weights[layer_name].add_adapter(
                adapter_index, adapter_weights
            )

        if unused_weight_names:
            logger.warning(
                f"{adapter_ids_str} unused adapter weights: {unused_weight_names}"
            )

        if adapter_tokenizer is not None:
            model.tokenizers.add_tokenizer(adapter_index, adapter_tokenizer)

        model.loaded_adapters.add(adapter_index)

    return model