# ruff: noqa: F821
# the above line disables the `undefined-name` rule for the model type variables

import torch
import enum
import os

from loguru import logger
from transformers.configuration_utils import PretrainedConfig
from transformers.models.auto import modeling_auto
from huggingface_hub import hf_hub_download, HfApi
from typing import Optional, List, Dict
from pathlib import Path

from text_generation_server.utils.speculate import get_speculate, set_speculate
from text_generation_server.models.model import Model
from text_generation_server.models.causal_lm import CausalLM, CausalLMBatchKeysLast
from text_generation_server.models.custom_modeling.opt_modeling import OPTForCausalLM
from text_generation_server.models.custom_modeling.mpt_modeling import (
    MPTForCausalLM,
)
from text_generation_server.models.bloom import BloomCausalLMBatch
from text_generation_server.models.custom_modeling.bloom_modeling import (
    BloomForCausalLM,
)
from text_generation_server.models.seq2seq_lm import Seq2SeqLM
from text_generation_server.models.galactica import GalacticaCausalLMBatch
from text_generation_server.models.custom_modeling.neox_modeling import (
    GPTNeoxForCausalLM,
)
from text_generation_server.models.custom_modeling.phi_modeling import (
    PhiConfig,
    PhiForCausalLM,
)
from text_generation_server.models.custom_modeling.t5_modeling import (
    T5ForConditionalGeneration,
)

from text_generation_server.utils.adapter import (
    AdapterParameters,
    build_layer_weight_lookup,
    load_and_merge_adapters,
    AdapterInfo,
)
from text_generation_server.adapters.lora import LoraWeights

from text_generation_server.utils.import_utils import SYSTEM
from text_generation_server.utils.log import log_master

# The flag below controls whether to allow TF32 on matmul. This flag defaults to False
# in PyTorch 1.12 and later.
torch.backends.cuda.matmul.allow_tf32 = True

# The flag below controls whether to allow TF32 on cuDNN. This flag defaults to True.
torch.backends.cudnn.allow_tf32 = True

# Disable gradients
torch.set_grad_enabled(False)

__all__ = [
    "Model",
    "CausalLM",
    "Seq2SeqLM",
    "get_model_with_lora_adapters",
]

FLASH_ATT_ERROR_MESSAGE = "{} requires Flash Attention enabled models."

FLASH_ATTENTION = True

try:
    from text_generation_server.models.flash_causal_lm import FlashCausalLM
    from text_generation_server.models.vlm_causal_lm import VlmCausalLM
    from text_generation_server.models.custom_modeling.flash_deepseek_v2_modeling import (
        FlashDeepseekV2ForCausalLM,
        DeepseekV2Config,
    )
    from text_generation_server.models.custom_modeling.flash_llama_modeling import (
        FlashLlamaForCausalLM,
    )
    from text_generation_server.models.custom_modeling.flash_cohere_modeling import (
        FlashCohereForCausalLM,
    )
    from text_generation_server.models.custom_modeling.flash_gemma_modeling import (
        FlashGemmaForCausalLM,
    )
    from text_generation_server.models.custom_modeling.flash_gemma2_modeling import (
        FlashGemma2ForCausalLM,
    )
    from text_generation_server.models.custom_modeling.flash_dbrx_modeling import (
        FlashDbrxForCausalLM,
        DbrxConfig,
    )
    from text_generation_server.models.custom_modeling.flash_rw_modeling import (
        RWConfig,
        FlashRWForCausalLM,
    )
    from text_generation_server.models.custom_modeling.flash_neox_modeling import (
        FlashGPTNeoXForCausalLM,
    )
    from text_generation_server.models.pali_gemma import (
        PaliGemmaBatch,
    )
    from text_generation_server.models.custom_modeling.flash_pali_gemma_modeling import (
        PaliGemmaForConditionalGeneration,
    )
    from text_generation_server.models.custom_modeling.flash_phi_modeling import (
        FlashPhiForCausalLM,
    )
    from text_generation_server.models.idefics import IDEFICSSharded
    from text_generation_server.models.custom_modeling.llava_next import (
        LlavaNextForConditionalGeneration,
    )

    from text_generation_server.models.custom_modeling.flash_santacoder_modeling import (
        FlashSantacoderForCausalLM,
    )
    from text_generation_server.models.custom_modeling.flash_starcoder2_modeling import (
        FlashStarcoder2ForCausalLM,
    )
    from text_generation_server.models.custom_modeling.flash_qwen2_modeling import (
        Qwen2ForCausalLM,
    )
    from text_generation_server.models.custom_modeling.flash_mistral_modeling import (
        FlashMistralForCausalLM,
    )
    from text_generation_server.models.custom_modeling.flash_mixtral_modeling import (
        FlashMixtralForCausalLM,
    )
    from text_generation_server.models.custom_modeling.flash_gpt2_modeling import (
        FlashGPT2ForCausalLM,
    )
    from text_generation_server.models.custom_modeling.flash_gptj_modeling import (
        FlashGPTJForCausalLM,
    )
    from text_generation_server.models.custom_modeling.idefics2 import (
        Idefics2ForConditionalGeneration,
    )
    from text_generation_server.layers.attention import SUPPORTS_WINDOWING
except ImportError as e:
    log_master(logger.warning, f"Could not import Flash Attention enabled models: {e}")
    SUPPORTS_WINDOWING = False
    FLASH_ATTENTION = False

if FLASH_ATTENTION:
    __all__.append(FlashCausalLM)
    __all__.append(IDEFICSSharded)

MAMBA_AVAILABLE = True
try:
    from text_generation_server.models.mamba import Mamba
except ImportError as e:
    log_master(logger.warning, f"Could not import Mamba: {e}")
    MAMBA_AVAILABLE = False

if MAMBA_AVAILABLE:
    __all__.append(Mamba)

class ModelType(enum.Enum):
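    # Each member maps a `model_type` string (as it appears in a checkpoint's
    # config.json) to a human-readable name and a reference repository on the Hub.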
    DEEPSEEK_V2 = {
        "type": "deepseek_v2",
        "name": "Deepseek V2",
        "url": "https://huggingface.co/deepseek-ai/DeepSeek-V2",
    }
    IDEFICS2 = {
        "type": "idefics2",
        "name": "Idefics 2",
        "url": "https://huggingface.co/HuggingFaceM4/idefics2-8b",
        "multimodal": True,
    }
    LLAVA_NEXT = {
        "type": "llava_next",
        "name": "Llava Next (1.6)",
        "url": "https://huggingface.co/llava-hf/llava-v1.6-vicuna-13b-hf",
        "multimodal": True,
    }
    LLAMA = {
        "type": "llama",
        "name": "Llama",
        "url": "https://huggingface.co/collections/meta-llama/llama-31-669fc079a0c406a149a5738f",
    }
    PHI3 = {
        "type": "phi3",
        "name": "Phi 3",
        "url": "https://huggingface.co/microsoft/Phi-3-mini-4k-instruct",
    }
    GEMMA = {
        "type": "gemma",
        "name": "Gemma",
        "url": "https://huggingface.co/google/gemma-7b",
    }
    PALIGEMMA = {
        "type": "paligemma",
        "name": "PaliGemma",
        "url": "https://huggingface.co/google/paligemma-3b-pt-224",
    }
    GEMMA2 = {
        "type": "gemma2",
        "name": "Gemma2",
        "url": "https://huggingface.co/collections/google/gemma-2-release-667d6600fd5220e7b967f315",
    }
    COHERE = {
        "type": "cohere",
        "name": "Cohere",
        "url": "https://huggingface.co/CohereForAI/c4ai-command-r-plus",
    }
    DBRX = {
        "type": "dbrx",
        "name": "Dbrx",
        "url": "https://huggingface.co/databricks/dbrx-instruct",
    }
    MAMBA = {
        "type": "ssm",
        "name": "Mamba",
        "url": "https://huggingface.co/state-spaces/mamba-2.8b-slimpj",
    }
    MISTRAL = {
        "type": "mistral",
        "name": "Mistral",
        "url": "https://huggingface.co/mistralai/Mistral-Nemo-Instruct-2407",
    }
    MIXTRAL = {
        "type": "mixtral",
        "name": "Mixtral",
        "url": "https://huggingface.co/mistralai/Mixtral-8x22B-Instruct-v0.1",
    }
    GPT_BIGCODE = {
        "type": "gpt_bigcode",
        "name": "Gpt Bigcode",
        "url": "https://huggingface.co/bigcode/gpt_bigcode-santacoder",
    }
    PHI = {
        "type": "phi",
        "name": "Phi",
        "url": "https://huggingface.co/microsoft/phi-1_5",
    }
    BAICHUAN = {
        "type": "baichuan",
        "name": "Baichuan",
        "url": "https://huggingface.co/baichuan-inc/Baichuan2-7B-Chat",
    }
    FALCON = {
        "type": "falcon",
        "name": "Falcon",
        "url": "https://huggingface.co/tiiuae/falcon-7b-instruct",
    }
    STARCODER2 = {
        "type": "starcoder2",
        "name": "StarCoder 2",
        "url": "https://huggingface.co/bigcode/starcoder2-15b-instruct-v0.1",
    }
    QWEN2 = {
        "type": "qwen2",
        "name": "Qwen 2",
        "url": "https://huggingface.co/collections/Qwen/qwen2-6659360b33528ced941e557f",
    }
    OPT = {
        "type": "opt",
        "name": "Opt",
        "url": "https://huggingface.co/facebook/opt-6.7b",
    }
    T5 = {
        "type": "t5",
        "name": "T5",
        "url": "https://huggingface.co/google/flan-t5-xxl",
    }
    GALACTICA = {
        "type": "galactica",
        "name": "Galactica",
        "url": "https://huggingface.co/facebook/galactica-120b",
    }
    SANTACODER = {
        "type": "santacoder",
        "name": "SantaCoder",
        "url": "https://huggingface.co/bigcode/santacoder",
    }
    BLOOM = {
        "type": "bloom",
        "name": "Bloom",
        "url": "https://huggingface.co/bigscience/bloom-560m",
    }
    MPT = {
        "type": "mpt",
        "name": "Mpt",
        "url": "https://huggingface.co/mosaicml/mpt-7b-instruct",
    }
    GPT2 = {
        "type": "gpt2",
        "name": "Gpt2",
        "url": "https://huggingface.co/openai-community/gpt2",
    }
    GPT_NEOX = {
        "type": "gpt_neox",
        "name": "Gpt Neox",
        "url": "https://huggingface.co/EleutherAI/gpt-neox-20b",
    }
    GPTJ = {
        "type": "gptj",
        "name": "Gptj",
        "url": "https://huggingface.co/EleutherAI/gpt-j-6b",
    }
    IDEFICS = {
        "type": "idefics",
        "name": "Idefics",
        "url": "https://huggingface.co/HuggingFaceM4/idefics-9b",
        "multimodal": True,
    }


__GLOBALS = locals()
for data in ModelType:
    __GLOBALS[data.name] = data.value["type"]
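# The loop above mirrors every ModelType entry into module globals as a plain string
# (e.g. LLAMA == "llama"), so the dispatch in `get_model` below can compare
# `model_type` directly against these names; this is what the `# ruff: noqa: F821`
# at the top of the file accounts for, since the names only exist at runtime.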


def get_model(
    model_id: str,
    lora_adapter_ids: Optional[List[str]],
    revision: Optional[str],
    sharded: bool,
    quantize: Optional[str],
    speculate: Optional[int],
    dtype: Optional[str],
    trust_remote_code: bool,
    max_input_tokens: int,
) -> Model:
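    """Instantiate the appropriate Model subclass for `model_id`.

    Resolves dtype and quantization from the checkpoint config, configures
    speculative decoding (Medusa, MLP speculator, or n-gram), then dispatches on
    `model_type`, preferring Flash Attention implementations when available and
    falling back to the generic transformers-based implementations otherwise.
    """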
    global FLASH_ATTENTION

    config_dict, _ = PretrainedConfig.get_config_dict(
        model_id, revision=revision, trust_remote_code=trust_remote_code
    )
    model_type = config_dict.get("model_type", None)

    quantization_config = config_dict.get("quantization_config", None)
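    # For reference, a GPTQ/AWQ checkpoint usually ships something like
    # `"quantization_config": {"quant_method": "gptq", "bits": 4, ...}` in its
    # config.json; only `quant_method` is inspected below (the other fields vary
    # by method).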
    if quantization_config is not None and quantize is None:
        method = quantization_config.get("quant_method", None)
        if method in {"gptq", "awq", "exl2"}:
            log_master(logger.info, f"Auto selecting quantization method {method}")
            quantize = method
        elif method == "fbgemm_fp8":
            log_master(logger.info, "Auto selecting quantization method fp8")
            quantize = "fp8"
        else:
            log_master(logger.warning, f"Unknown quantization method {method}")

    if dtype is None:
        if quantize in ["awq", "exl2", "gptq", "marlin"]:
            # These quantizers only work with float16 params.
            dtype = torch.float16
        elif quantize == "fp8":
            from text_generation_server.layers.fp8 import FBGEMM_DYN_AVAILABLE

            if FBGEMM_DYN_AVAILABLE:
                # fbgemm kernels are fp8xfp8->bf16
                dtype = torch.bfloat16
        else:
            # Keep it as default for now and let
            # every model resolve their own default dtype.
            dtype = None
    elif dtype == "float16":
        dtype = torch.float16
    elif dtype == "bfloat16":
        dtype = torch.bfloat16
    else:
        raise RuntimeError(f"Unknown dtype {dtype}")

    if speculate is not None:
        set_speculate(speculate)
    else:
        set_speculate(0)

    speculator = None
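    # The branches below resolve speculative-decoding weights: a Medusa or MLP-speculator
    # checkpoint points back at its base model via `base_model_name_or_path`, so
    # `model_id`/`revision` are swapped for the base model, the config is reloaded, and
    # the speculator weights are downloaded alongside it. With neither, n-gram
    # speculation is used.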
    if "medusa_num_heads" in config_dict:
        medusa_model_id = model_id
        medusa_revision = revision
        model_id = config_dict["base_model_name_or_path"]
        revision = "main"
        speculate_medusa = config_dict["medusa_num_heads"]
        if speculate is not None:
            if speculate > speculate_medusa:
                raise RuntimeError(
                    f"Speculate is set to `{speculate}` but this medusa model only has `{speculate_medusa}` heads, please make them match"
                )
            else:
                set_speculate(speculate)
        else:
            set_speculate(speculate_medusa)

        config_dict, _ = PretrainedConfig.get_config_dict(
            model_id, revision=revision, trust_remote_code=trust_remote_code
        )
        # Reload model type from parent.
        model_type = config_dict.get("model_type", None)
        is_local = Path(medusa_model_id).exists()
        if not is_local:
            medusa_config = hf_hub_download(
                medusa_model_id, revision=medusa_revision, filename="config.json"
            )
            hf_hub_download(
                medusa_model_id,
                revision=medusa_revision,
                filename="medusa_lm_head.safetensors",
            )
            speculator = {
                "path": Path(medusa_config).parent,
                "model_paths": ["medusa_lm_head.safetensors"],
            }
        else:
            speculator = {
                "path": Path(medusa_model_id),
                "model_paths": ["medusa_lm_head.safetensors"],
            }

        method = "medusa"
    elif model_type == "mlp_speculator":
        mlp_model_id = model_id
        mlp_revision = revision
        model_id = config_dict["base_model_name_or_path"]
        revision = "main"
        speculate_mlp = config_dict["n_predict"]
        if speculate is not None:
            if speculate > speculate_mlp:
                raise RuntimeError(
                    f"Speculate is set to `{speculate}` but this mlp_speculator model only has `{speculate_mlp}` heads, please make them match"
                )
            else:
                set_speculate(speculate)
        else:
            set_speculate(speculate_mlp)

        config_dict, _ = PretrainedConfig.get_config_dict(
            model_id, revision=revision, trust_remote_code=trust_remote_code
        )
        # Reload model type from parent.
        model_type = config_dict.get("model_type", None)
        is_local = Path(mlp_model_id).exists()
        extension = ".safetensors"
        if not is_local:
            mlp_speculator_config = hf_hub_download(
                mlp_model_id, revision=mlp_revision, filename="config.json"
            )
            api = HfApi()
            info = api.model_info(mlp_model_id, revision=mlp_revision)
            filenames = [
                s.rfilename
                for s in info.siblings
                if s.rfilename.endswith(extension)
                and len(s.rfilename.split("/")) == 1
                and "arguments" not in s.rfilename
                and "args" not in s.rfilename
                and "training" not in s.rfilename
            ]
            for filename in filenames:
                hf_hub_download(
                    mlp_model_id,
                    revision=mlp_revision,
                    filename=filename,
                )
            speculator = {
                "path": Path(mlp_speculator_config).parent,
                "model_paths": filenames,
            }
        else:
            speculator = Path(mlp_model_id)
            filenames = [p for p in os.listdir(speculator) if p.endswith(extension)]
            speculator = {"path": speculator, "model_paths": filenames}
        method = "mlp_speculator"
    else:
        method = "n-gram"

    speculate = get_speculate()
    if speculate > 0:
        log_master(
            logger.info, f"Using speculation {method} with {speculate} input ids."
        )

    if model_type is None:
        # TODO: fix how we determine model type for Mamba
        if "ssm_cfg" in config_dict:
            # *only happens in Mamba case
            model_type = "ssm"
        else:
            raise RuntimeError(
                f"Could not determine model type for {model_id} revision {revision}"
            )

    if quantize == "exl2" and sharded:
        raise RuntimeError(
            "Sharding is currently not supported with `exl2` quantization"
        )

    sliding_window = (
        config_dict.get("sliding_window")
        if config_dict.get("sliding_window") is not None
        else -1
    )
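    # `-1` means "no sliding window"; the guard below only triggers when the model uses
    # windowed attention and the requested context exceeds the window size on a backend
    # without windowing support.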

    use_sliding_window = sliding_window is not None and sliding_window != -1
    needs_sliding_window = (
        max_input_tokens is not None and max_input_tokens > sliding_window
    )
    if use_sliding_window and needs_sliding_window and not SUPPORTS_WINDOWING:
        raise ValueError(
            f"The backend {SYSTEM} does not support sliding window attention that is used by the model type {model_type}. To use this model nonetheless with the {SYSTEM} backend, please launch TGI with the argument `--max-input-tokens` smaller than sliding_window={sliding_window} (got here max_input_tokens={max_input_tokens})."
        )
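    # From here on, dispatch on `model_type`. Each branch follows the same pattern:
    # prefer the Flash Attention implementation when available, refuse to shard when
    # only the non-flash path exists, and otherwise fall back to the generic
    # transformers-based `CausalLM.fallback` / `Seq2SeqLM.fallback`.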

    if model_type == DEEPSEEK_V2:
        if FLASH_ATTENTION:
            head_size = max(
                config_dict.get("qk_nope_dim", 128)
                + config_dict.get("qk_rope_dim", 64),
                config_dict.get("v_head_dim", 128),
            )
            return FlashCausalLM(
                model_id=model_id,
                model_class=FlashDeepseekV2ForCausalLM,
                revision=revision,
                quantize=quantize,
                speculator=speculator,
                default_dtype=torch.bfloat16,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
                lora_adapter_ids=lora_adapter_ids,
                config_class=DeepseekV2Config,
                head_size=head_size,
            )
        elif sharded:
            raise NotImplementedError(
                FLASH_ATT_ERROR_MESSAGE.format("Sharded Deepseek V2")
            )
        else:
            return CausalLM.fallback(
                model_id,
                revision,
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )
    elif model_type == MAMBA:
        return Mamba(
            model_id,
            revision,
            quantize=quantize,
            speculator=speculator,
            dtype=dtype,
            trust_remote_code=trust_remote_code,
        )

    if model_id.startswith("facebook/galactica"):
        return CausalLM(
            model_id=model_id,
            # Yes galactica is just an OPT model.
            model_class=OPTForCausalLM,
            revision=revision,
            quantize=quantize,
            speculator=speculator,
            dtype=dtype,
            trust_remote_code=trust_remote_code,
            batch_class=GalacticaCausalLMBatch,
        )

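    # Note: `and` binds tighter than `or`, so this branch matches every GPT_BIGCODE
    # model, plus GPT2-architecture checkpoints published under the `bigcode/`
    # namespace (Santacoder).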
    if (
        model_type == GPT_BIGCODE
        or model_type == GPT2
        and model_id.startswith("bigcode/")
    ):
        if FLASH_ATTENTION:
            return FlashCausalLM(
                model_id=model_id,
                model_class=FlashSantacoderForCausalLM,
                revision=revision,
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
                lora_adapter_ids=lora_adapter_ids,
                aliases={"transformer.wte.weight": ["lm_head.weight"]},
                num_kv_heads=1,
            )
        elif sharded:
            raise NotImplementedError(
                FLASH_ATT_ERROR_MESSAGE.format("Sharded Santacoder")
            )
        else:
            return CausalLM.fallback(
                model_id=model_id,
                revision=revision,
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )

    if model_type == BLOOM:
        return CausalLM(
            model_id=model_id,
            model_class=BloomForCausalLM,
            revision=revision,
            quantize=quantize,
            speculator=speculator,
            dtype=dtype,
            trust_remote_code=trust_remote_code,
            batch_class=BloomCausalLMBatch,
        )
    elif model_type == MPT:
        return CausalLM(
            model_id=model_id,
            model_class=MPTForCausalLM,
            revision=revision,
            quantize=quantize,
            speculator=speculator,
            dtype=dtype,
            trust_remote_code=trust_remote_code,
            batch_class=CausalLMBatchKeysLast,
        )
    elif model_type == GPT2:
        if FLASH_ATTENTION:
            try:
                return FlashCausalLM(
                    model_id=model_id,
                    model_class=FlashGPT2ForCausalLM,
                    revision=revision,
                    quantize=quantize,
                    speculator=speculator,
                    dtype=dtype,
                    trust_remote_code=trust_remote_code,
                    lora_adapter_ids=lora_adapter_ids,
                )
            except RuntimeError as e:
                # Lots of legacy models with various weight names.
                log_master(logger.warning, f"Couldn't load flash gpt2 variant: {e}")
                return CausalLM.fallback(
                    model_id,
                    revision,
                    quantize=quantize,
                    speculator=speculator,
                    dtype=dtype,
                    trust_remote_code=trust_remote_code,
                )
        elif sharded:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded GPT-2"))
        else:
            return CausalLM.fallback(
                model_id,
                revision,
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )
    elif model_type == GPTJ:
        if FLASH_ATTENTION:
            try:
                return FlashCausalLM(
                    model_id=model_id,
                    model_class=FlashGPTJForCausalLM,
                    revision=revision,
                    quantize=quantize,
                    speculator=speculator,
                    dtype=dtype,
                    trust_remote_code=trust_remote_code,
                    lora_adapter_ids=lora_adapter_ids,
                )
            except RuntimeError as e:
                # Lots of legacy models with various weight names.
                log_master(logger.warning, f"Couldn't load flash gptj variant: {e}")
                return CausalLM.fallback(
                    model_id,
                    revision,
                    quantize=quantize,
                    speculator=speculator,
                    dtype=dtype,
                    trust_remote_code=trust_remote_code,
                )
        elif sharded:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded GPT-J"))
        else:
            return CausalLM.fallback(
                model_id,
                revision,
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )
    elif model_type == GPT_NEOX:
        if FLASH_ATTENTION:
            from text_generation_server.models.custom_modeling.flash_neox_modeling import (
                GPTNeoXConfig,
            )

            return FlashCausalLM(
                model_id=model_id,
                model_class=FlashGPTNeoXForCausalLM,
                revision=revision,
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
                lora_adapter_ids=lora_adapter_ids,
                config_class=GPTNeoXConfig,
            )
        elif sharded:
            return CausalLM(
                model_id=model_id,
                model_class=GPTNeoxForCausalLM,
                revision=revision,
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )
        else:
            return CausalLM.fallback(
                model_id,
                revision,
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )

    elif model_type == PHI:
        if FLASH_ATTENTION:
            return FlashCausalLM(
                model_id=model_id,
                model_class=FlashPhiForCausalLM,
                revision=revision,
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
                lora_adapter_ids=lora_adapter_ids,
            )
        else:
            return CausalLM.fallback(
                model_id,
                revision,
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )

    elif model_type == "phi-msft":
        if FLASH_ATTENTION:
            raise NotImplementedError(
                "Legacy phi-msft is not supported with Flash Attention"
            )
        else:
            return CausalLM(
                model_id=model_id,
                model_class=PhiForCausalLM,
                config_class=PhiConfig,
                revision=revision,
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )

    elif model_type == LLAMA or model_type == BAICHUAN or model_type == PHI3:
        if FLASH_ATTENTION:
            return FlashCausalLM(
                model_id=model_id,
                model_class=FlashLlamaForCausalLM,
                revision=revision,
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
                lora_adapter_ids=lora_adapter_ids,
            )
        elif sharded:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Llama"))
        else:
            return CausalLM.fallback(
                model_id,
                revision,
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )
    if model_type == GEMMA:
        if FLASH_ATTENTION:
            return FlashCausalLM(
                model_id=model_id,
                model_class=FlashGemmaForCausalLM,
                revision=revision,
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
                # Works better for these models
                default_dtype=torch.bfloat16,
                trust_remote_code=trust_remote_code,
                lora_adapter_ids=lora_adapter_ids,
            )
        elif sharded:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Gemma"))
        else:
            return CausalLM.fallback(
                model_id,
                revision,
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )
    elif model_type == GEMMA2:
        if FLASH_ATTENTION:
            return FlashCausalLM(
                model_id=model_id,
                model_class=FlashGemma2ForCausalLM,
                revision=revision,
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
                # Works better for these models
                default_dtype=torch.bfloat16,
                trust_remote_code=trust_remote_code,
                lora_adapter_ids=lora_adapter_ids,
            )
        elif sharded:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Gemma2"))
        else:
            return CausalLM.fallback(
                model_id,
                revision,
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )

    if model_type == COHERE:
        if FLASH_ATTENTION:
            return FlashCausalLM(
                model_id=model_id,
                model_class=FlashCohereForCausalLM,
                revision=revision,
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
                lora_adapter_ids=lora_adapter_ids,
            )
        elif sharded:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Cohere"))
        else:
            return CausalLM.fallback(
                model_id,
                revision,
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )

    if model_type == DBRX:
        if FLASH_ATTENTION:
            return FlashCausalLM(
                model_id=model_id,
                model_class=FlashDbrxForCausalLM,
                revision=revision,
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
                # Dbrx works better in bfloat16.
                default_dtype=torch.bfloat16,
                trust_remote_code=trust_remote_code,
                lora_adapter_ids=lora_adapter_ids,
                config_class=DbrxConfig,
            )
        elif sharded:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded DBRX"))
        else:
            return CausalLM.fallback(
                model_id,
                revision,
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )

    if model_type in ["RefinedWeb", "RefinedWebModel", FALCON]:
        if sharded:
            if FLASH_ATTENTION:
                if config_dict.get("alibi", False):
                    raise NotImplementedError("sharded is not supported for this model")
                return FlashCausalLM(
                    model_id=model_id,
                    model_class=FlashRWForCausalLM,
                    revision=revision,
                    quantize=quantize,
                    speculator=speculator,
                    dtype=dtype,
                    aliases={
                        "lm_head.weight": ["transformer.word_embeddings.weight"],
                        "transformer.word_embeddings.weight": ["lm_head.weight"],
                    },
                    trust_remote_code=trust_remote_code,
                    lora_adapter_ids=lora_adapter_ids,
                    config_class=RWConfig,
                )
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Falcon"))
        else:
            if FLASH_ATTENTION and not config_dict.get("alibi", False):
                return FlashCausalLM(
                    model_id=model_id,
                    model_class=FlashRWForCausalLM,
                    revision=revision,
                    quantize=quantize,
                    speculator=speculator,
                    dtype=dtype,
                    aliases={
                        "lm_head.weight": ["transformer.word_embeddings.weight"],
                        "transformer.word_embeddings.weight": ["lm_head.weight"],
                    },
                    trust_remote_code=trust_remote_code,
                    lora_adapter_ids=lora_adapter_ids,
                    config_class=RWConfig,
                )
            else:
                return CausalLM.fallback(
                    model_id,
                    revision,
                    quantize=quantize,
                    speculator=speculator,
                    dtype=dtype,
                    trust_remote_code=trust_remote_code,
                )

    if model_type == MISTRAL:
        if FLASH_ATTENTION:
            return FlashCausalLM(
                model_id=model_id,
                model_class=FlashMistralForCausalLM,
                revision=revision,
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
                lora_adapter_ids=lora_adapter_ids,
            )
        elif sharded:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Mistral"))
        else:
            return CausalLM.fallback(
                model_id,
                revision,
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )

    if model_type == MIXTRAL:
        if FLASH_ATTENTION:
            return FlashCausalLM(
                model_id=model_id,
                model_class=FlashMixtralForCausalLM,
                revision=revision,
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
                lora_adapter_ids=lora_adapter_ids,
            )
        elif sharded:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Mixtral"))
        else:
            return CausalLM.fallback(
                model_id,
                revision,
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )

    if model_type == STARCODER2:
        if FLASH_ATTENTION:
            return FlashCausalLM(
                model_id=model_id,
                model_class=FlashStarcoder2ForCausalLM,
                revision=revision,
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
                lora_adapter_ids=lora_adapter_ids,
            )
        elif sharded:
            raise NotImplementedError(
                FLASH_ATT_ERROR_MESSAGE.format("Sharded Starcoder2")
            )
        else:
            return CausalLM.fallback(
                model_id,
                revision,
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )

    if model_type == QWEN2:
        if FLASH_ATTENTION:
            return FlashCausalLM(
                model_id=model_id,
                model_class=Qwen2ForCausalLM,
                revision=revision,
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
                lora_adapter_ids=lora_adapter_ids,
            )
        elif sharded:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Qwen2"))
        else:
            return CausalLM.fallback(
                model_id,
                revision,
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )

    if model_type == OPT:
        return CausalLM(
            model_id=model_id,
            model_class=OPTForCausalLM,
            revision=revision,
            quantize=quantize,
            speculator=speculator,
            dtype=dtype,
            trust_remote_code=trust_remote_code,
        )

    if model_type == T5:
        return Seq2SeqLM(
            model_id=model_id,
            model_class=T5ForConditionalGeneration,
            revision=revision,
            quantize=quantize,
            speculator=speculator,
            dtype=dtype,
            trust_remote_code=trust_remote_code,
            aliases={
                "shared.weight": [
                    "encoder.embed_tokens.weight",
                    "decoder.embed_tokens.weight",
                ]
            },
        )
    if model_type == IDEFICS:
        if FLASH_ATTENTION:
            return IDEFICSSharded(
                model_id,
                revision,
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )
        else:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Idefics"))
    if model_type == IDEFICS2:
        if FLASH_ATTENTION:
            return VlmCausalLM(
                model_id=model_id,
                model_class=Idefics2ForConditionalGeneration,
                revision=revision,
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
                lora_adapter_ids=lora_adapter_ids,
                # XXX: Extremely important to cap resolution in order to limit
                # VRAM usage.
                processor_kwargs={"size": {"longest_edge": 448, "shortest_edge": 378}},
            )
        else:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Idefics2"))
    if model_type == PALIGEMMA:
        if FLASH_ATTENTION:
            return VlmCausalLM(
                model_id=model_id,
                model_class=PaliGemmaForConditionalGeneration,
                revision=revision,
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
                # Works better for these models
                default_dtype=torch.bfloat16,
                trust_remote_code=trust_remote_code,
                lora_adapter_ids=lora_adapter_ids,
                batch_class=PaliGemmaBatch,
            )
        else:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("PaliGemma"))

    if model_type == LLAVA_NEXT:
        if FLASH_ATTENTION:
            return VlmCausalLM(
                model_class=LlavaNextForConditionalGeneration,
                model_id=model_id,
                revision=revision,
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )
        else:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("LlavaNext"))

    if sharded:
        raise NotImplementedError("sharded is not supported for AutoModel")
    if quantize == "gptq":
        raise NotImplementedError(
            "gptq quantization is not supported for AutoModel, you can try to quantize it with `text-generation-server quantize ORIGINAL_MODEL_ID NEW_MODEL_ID`"
        )
    if quantize == "awq":
        raise NotImplementedError("awq quantization is not supported for AutoModel")
    elif (quantize == "bitsandbytes-fp4") or (quantize == "bitsandbytes-nf4"):
        raise NotImplementedError("4bit quantization is not supported for AutoModel")
    elif quantize == "eetq":
        raise NotImplementedError("Eetq quantization is not supported for AutoModel")
    elif quantize == "exl2":
        raise NotImplementedError("exl2 quantization is not supported for AutoModel")
    if model_type in modeling_auto.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES:
        return CausalLM.fallback(
            model_id,
            revision,
            quantize=quantize,
            speculator=speculator,
            dtype=dtype,
            trust_remote_code=trust_remote_code,
        )
    if model_type in modeling_auto.MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES:
        return Seq2SeqLM.fallback(
            model_id,
            revision,
            quantize=quantize,
            speculator=speculator,
            dtype=dtype,
            trust_remote_code=trust_remote_code,
        )

    auto_map = config_dict.get("auto_map", None)
    if trust_remote_code and auto_map is not None:
        if "AutoModelForCausalLM" in auto_map.keys():
            return CausalLM.fallback(
                model_id,
                revision,
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )
        if "AutoModelForSeq2SeqLM" in auto_map.keys():
            return Seq2SeqLM.fallback(
                model_id,
                revision,
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )

    raise ValueError(f"Unsupported model type {model_type}")


# get_model_with_lora_adapters wraps the internal get_model function and adds support for loading adapters:
# it provides a post-model-loading hook that loads the adapter weights into the model once it has been created.
def get_model_with_lora_adapters(
    model_id: str,
    lora_adapters: Optional[List[AdapterInfo]],
    revision: Optional[str],
    sharded: bool,
    quantize: Optional[str],
    speculate: Optional[int],
    dtype: Optional[str],
    trust_remote_code: bool,
    max_input_tokens: int,
    adapter_to_index: Dict[str, int],
):
    lora_adapter_ids = [adapter.id for adapter in lora_adapters]
    model = get_model(
        model_id,
        lora_adapter_ids,
        revision,
        sharded,
        quantize,
        speculate,
        dtype,
        trust_remote_code,
        max_input_tokens,
    )

    if len(lora_adapters) > 0:
        target_to_layer = build_layer_weight_lookup(model.model)

        for index, adapter in enumerate(lora_adapters):
            # The AdapterParameters object allows for merging multiple adapters into a single adapter.
            # At the moment, we only support loading a single adapter into the model, but we keep the
            # AdapterParameters object for easier extension in the future.
            adapter_parameters = AdapterParameters(
                adapter_info=[adapter],
                # when merging multiple adapters we can weight them differently
                # if this is not set, all adapters will be weighted equally
                # see: text_generation_server.utils.merges.strategies for impl
                weights=None,
                merge_strategy=0,
                density=1.0,
                majority_sign_method=0,
            )

            adapter_index = index + 1
            adapter_to_index[adapter.id] = adapter_index

            logger.info(
                f"Loading adapter weights into model: {','.join([adapter.id for adapter in adapter_parameters.adapter_info])}"
            )
            weight_names = tuple([v[0] for v in target_to_layer.values()])
            (
                module_map,
                adapter_config,
                adapter_weight_names,
                adapter_tokenizer,
            ) = load_and_merge_adapters(
                model.model_id,
                adapter_parameters,
                adapter_index,
                weight_names,
                False,
            )

            unused_weight_names = adapter_weight_names.copy()

            adapter_layers = [
                "q_proj",
                "k_proj",
                "v_proj",
                "o_proj",
                "gate_proj",
                "up_proj",
                "down_proj",
            ]

            for layer_name in adapter_layers:
                nlayers = (
                    1 if layer_name == "lm_head" else len(model.model.model.layers)
                )
                adapter_weights = LoraWeights.prepare_weights(
                    config=adapter_config,
                    module_map=module_map,
                    layer_type=layer_name,
                    unused_weight_names=unused_weight_names,
                    nlayers=nlayers,
                    dtype=model.dtype,
                    world_size=model.world_size,
                    process_group=model.process_group,
                    target_to_layer=target_to_layer,
                )

                if adapter_weights is None:
                    continue

                model.layer_to_adapter_weights[layer_name].add_adapter(
                    adapter_index, adapter_weights
                )

            if len(unused_weight_names) > 0:
                logger.warning(
                    f"{','.join(adapter_parameters.adapter_ids)} unused adapter weights: {unused_weight_names}"
                )

            if adapter_tokenizer is not None:
                model.tokenizers.add_tokenizer(adapter_index, adapter_tokenizer)

            model.loaded_adapters.add(adapter_index)

    return model
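

# Illustrative only (not called from here): the serving layer is expected to invoke
# `get_model_with_lora_adapters` roughly as sketched below; the model id and token
# budget are hypothetical placeholders.
#
#     adapter_to_index: Dict[str, int] = {}
#     model = get_model_with_lora_adapters(
#         model_id="meta-llama/Llama-3.1-8B-Instruct",
#         lora_adapters=[],
#         revision=None,
#         sharded=False,
#         quantize=None,
#         speculate=None,
#         dtype=None,
#         trust_remote_code=False,
#         max_input_tokens=4096,
#         adapter_to_index=adapter_to_index,
#     )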