# ruff: noqa: F821
# the above line disables the `undefined-name` rule for the model type variables

import torch
import enum
import os

from loguru import logger
from transformers.configuration_utils import PretrainedConfig
from transformers.models.auto import modeling_auto
from huggingface_hub import hf_hub_download, HfApi
from typing import Optional, List, Dict
from pathlib import Path

from text_generation_server.utils.speculate import get_speculate, set_speculate
from text_generation_server.models.model import Model
from text_generation_server.models.causal_lm import CausalLM, CausalLMBatchKeysLast
from text_generation_server.models.custom_modeling.opt_modeling import OPTForCausalLM
from text_generation_server.models.custom_modeling.mpt_modeling import (
    MPTForCausalLM,
)
from text_generation_server.models.bloom import BloomCausalLMBatch
from text_generation_server.models.custom_modeling.bloom_modeling import (
    BloomForCausalLM,
)
from text_generation_server.models.seq2seq_lm import Seq2SeqLM
from text_generation_server.models.galactica import GalacticaCausalLMBatch
from text_generation_server.models.custom_modeling.neox_modeling import (
    GPTNeoxForCausalLM,
)
from text_generation_server.models.custom_modeling.phi_modeling import (
    PhiConfig,
    PhiForCausalLM,
)
from text_generation_server.models.custom_modeling.t5_modeling import (
    T5ForConditionalGeneration,
)


from text_generation_server.utils.adapter import (
    AdapterParameters,
    build_layer_weight_lookup,
    load_and_merge_adapters,
    AdapterInfo,
)
from text_generation_server.adapters.lora import LoraWeights


from text_generation_server.utils.import_utils import SYSTEM
from text_generation_server.utils.log import log_master

# The flag below controls whether to allow TF32 on matmul. This flag defaults to False
# in PyTorch 1.12 and later.
torch.backends.cuda.matmul.allow_tf32 = True

# The flag below controls whether to allow TF32 on cuDNN. This flag defaults to True.
torch.backends.cudnn.allow_tf32 = True

# Disable gradients
torch.set_grad_enabled(False)

__all__ = [
    "Model",
    "CausalLM",
    "Seq2SeqLM",
    "get_model_with_lora_adapters",
]

FLASH_ATT_ERROR_MESSAGE = "{} requires Flash Attention enabled models."

FLASH_ATTENTION = True
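# The flash-attention model classes below are optional: if any of these imports
# fails (e.g. missing CUDA kernels), the error is logged and `get_model` falls
# back to the non-flash implementations further down.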

try:
    from text_generation_server.models.flash_causal_lm import FlashCausalLM
    from text_generation_server.models.vlm_causal_lm import VlmCausalLM
    from text_generation_server.models.custom_modeling.flash_deepseek_v2_modeling import (
        FlashDeepseekV2ForCausalLM,
        DeepseekV2Config,
    )
    from text_generation_server.models.custom_modeling.flash_llama_modeling import (
        FlashLlamaForCausalLM,
    )
    from text_generation_server.models.custom_modeling.flash_cohere_modeling import (
        FlashCohereForCausalLM,
    )
    from text_generation_server.models.custom_modeling.flash_gemma_modeling import (
        FlashGemmaForCausalLM,
    )
    from text_generation_server.models.custom_modeling.flash_gemma2_modeling import (
        FlashGemma2ForCausalLM,
    )
    from text_generation_server.models.custom_modeling.flash_dbrx_modeling import (
        FlashDbrxForCausalLM,
        DbrxConfig,
    )
    from text_generation_server.models.custom_modeling.flash_rw_modeling import (
        RWConfig,
        FlashRWForCausalLM,
    )
    from text_generation_server.models.custom_modeling.flash_neox_modeling import (
        FlashGPTNeoXForCausalLM,
    )
    from text_generation_server.models.pali_gemma import (
        PaliGemmaBatch,
    )
    from text_generation_server.models.custom_modeling.flash_pali_gemma_modeling import (
        PaliGemmaForConditionalGeneration,
    )
    from text_generation_server.models.custom_modeling.flash_phi_modeling import (
        FlashPhiForCausalLM,
    )
    from text_generation_server.models.idefics import IDEFICSSharded
    from text_generation_server.models.custom_modeling.llava_next import (
        LlavaNextForConditionalGeneration,
    )

    from text_generation_server.models.custom_modeling.flash_santacoder_modeling import (
        FlashSantacoderForCausalLM,
    )
    from text_generation_server.models.custom_modeling.flash_starcoder2_modeling import (
        FlashStarcoder2ForCausalLM,
    )
    from text_generation_server.models.custom_modeling.flash_qwen2_modeling import (
        Qwen2ForCausalLM,
    )
    from text_generation_server.models.custom_modeling.flash_mistral_modeling import (
        FlashMistralForCausalLM,
    )
    from text_generation_server.models.custom_modeling.flash_mixtral_modeling import (
        FlashMixtralForCausalLM,
    )
    from text_generation_server.models.custom_modeling.flash_gpt2_modeling import (
        FlashGPT2ForCausalLM,
    )
    from text_generation_server.models.custom_modeling.flash_gptj_modeling import (
        FlashGPTJForCausalLM,
    )
    from text_generation_server.models.custom_modeling.idefics2 import (
        Idefics2ForConditionalGeneration,
    )
    from text_generation_server.layers.attention import SUPPORTS_WINDOWING
except ImportError as e:
    log_master(logger.warning, f"Could not import Flash Attention enabled models: {e}")
    SUPPORTS_WINDOWING = False
    FLASH_ATTENTION = False

if FLASH_ATTENTION:
    __all__.append("FlashCausalLM")
    __all__.append("IDEFICSSharded")

MAMBA_AVAILABLE = True
try:
    from text_generation_server.models.mamba import Mamba
except ImportError as e:
    log_master(logger.warning, f"Could not import Mamba: {e}")
    MAMBA_AVAILABLE = False

if MAMBA_AVAILABLE:
    __all__.append("Mamba")


class ModelType(enum.Enum):
    DEEPSEEK_V2 = {
        "type": "deepseek_v2",
        "name": "Deepseek V2",
        "url": "https://huggingface.co/deepseek-ai/DeepSeek-V2",
    }
    IDEFICS2 = {
        "type": "idefics2",
        "name": "Idefics 2",
        "url": "https://huggingface.co/HuggingFaceM4/idefics2-8b",
        "multimodal": True,
    }
    LLAVA_NEXT = {
        "type": "llava_next",
        "name": "Llava Next (1.6)",
        "url": "https://huggingface.co/llava-hf/llava-v1.6-vicuna-13b-hf",
        "multimodal": True,
    }
    LLAMA = {
        "type": "llama",
        "name": "Llama",
        "url": "https://huggingface.co/collections/meta-llama/llama-31-669fc079a0c406a149a5738f",
    }
    PHI3 = {
        "type": "phi3",
        "name": "Phi 3",
        "url": "https://huggingface.co/microsoft/Phi-3-mini-4k-instruct",
    }
    GEMMA = {
        "type": "gemma",
        "name": "Gemma",
        "url": "https://huggingface.co/google/gemma-7b",
    }
    PALIGEMMA = {
        "type": "paligemma",
        "name": "PaliGemma",
        "url": "https://huggingface.co/google/paligemma-3b-pt-224",
    }
    GEMMA2 = {
        "type": "gemma2",
        "name": "Gemma2",
        "url": "https://huggingface.co/collections/google/gemma-2-release-667d6600fd5220e7b967f315",
    }
    COHERE = {
        "type": "cohere",
        "name": "Cohere",
        "url": "https://huggingface.co/CohereForAI/c4ai-command-r-plus",
    }
    DBRX = {
        "type": "dbrx",
        "name": "Dbrx",
        "url": "https://huggingface.co/databricks/dbrx-instruct",
    }
    MAMBA = {
        "type": "ssm",
        "name": "Mamba",
        "url": "https://huggingface.co/state-spaces/mamba-2.8b-slimpj",
    }
    MISTRAL = {
        "type": "mistral",
        "name": "Mistral",
        "url": "https://huggingface.co/mistralai/Mistral-Nemo-Instruct-2407",
    }
    MIXTRAL = {
        "type": "mixtral",
        "name": "Mixtral",
        "url": "https://huggingface.co/mistralai/Mixtral-8x22B-Instruct-v0.1",
    }
    GPT_BIGCODE = {
        "type": "gpt_bigcode",
        "name": "Gpt Bigcode",
        "url": "https://huggingface.co/bigcode/gpt_bigcode-santacoder",
    }
    PHI = {
        "type": "phi",
        "name": "Phi",
        "url": "https://huggingface.co/microsoft/phi-1_5",
    }
    BAICHUAN = {
        "type": "baichuan",
        "name": "Baichuan",
        "url": "https://huggingface.co/baichuan-inc/Baichuan2-7B-Chat",
    }
    FALCON = {
        "type": "falcon",
        "name": "Falcon",
        "url": "https://huggingface.co/tiiuae/falcon-7b-instruct",
    }
    STARCODER2 = {
        "type": "starcoder2",
        "name": "StarCoder 2",
        "url": "https://huggingface.co/bigcode/starcoder2-15b-instruct-v0.1",
    }
    QWEN2 = {
        "type": "qwen2",
        "name": "Qwen 2",
        "url": "https://huggingface.co/collections/Qwen/qwen2-6659360b33528ced941e557f",
    }
    OPT = {
        "type": "opt",
        "name": "Opt",
        "url": "https://huggingface.co/facebook/opt-6.7b",
    }
    T5 = {
        "type": "t5",
        "name": "T5",
        "url": "https://huggingface.co/google/flan-t5-xxl",
    }
    GALACTICA = {
        "type": "galactica",
        "name": "Galactica",
        "url": "https://huggingface.co/facebook/galactica-120b",
    }
    SANTACODER = {
        "type": "santacoder",
        "name": "SantaCoder",
        "url": "https://huggingface.co/bigcode/santacoder",
    }
    BLOOM = {
        "type": "bloom",
        "name": "Bloom",
        "url": "https://huggingface.co/bigscience/bloom-560m",
    }
    MPT = {
        "type": "mpt",
        "name": "Mpt",
        "url": "https://huggingface.co/mosaicml/mpt-7b-instruct",
    }
    GPT2 = {
        "type": "gpt2",
        "name": "Gpt2",
        "url": "https://huggingface.co/openai-community/gpt2",
    }
    GPT_NEOX = {
        "type": "gpt_neox",
        "name": "Gpt Neox",
        "url": "https://huggingface.co/EleutherAI/gpt-neox-20b",
    }
    GPTJ = {
        "type": "gptj",
        "name": "Gptj",
        "url": "https://huggingface.co/EleutherAI/gpt-j-6b",
    }
    IDEFICS = {
        "type": "idefics",
        "name": "Idefics",
        "url": "https://huggingface.co/HuggingFaceM4/idefics-9b",
        "multimodal": True,
    }


__GLOBALS = locals()
for data in ModelType:
    __GLOBALS[data.name] = data.value["type"]
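# The loop above injects every ModelType's "type" string into the module
# namespace (e.g. LLAMA == "llama", MISTRAL == "mistral"), so `get_model` can
# compare `model_type` against those bare names below. This is what the
# `# ruff: noqa: F821` at the top of the file accounts for.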


def get_model(
    model_id: str,
    lora_adapter_ids: Optional[List[str]],
    revision: Optional[str],
    sharded: bool,
    quantize: Optional[str],
    speculate: Optional[int],
    dtype: Optional[str],
    trust_remote_code: bool,
    max_input_tokens: int,
) -> Model:
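    """Instantiate the `Model` implementation that matches `model_id`.

    The function reads the checkpoint's config to auto-detect quantization and
    dtype, resolves speculative decoding (Medusa, MLP speculator or n-gram),
    then dispatches on `model_type` to the matching flash-attention class when
    available, falling back to the generic CausalLM/Seq2SeqLM implementations
    otherwise.
    """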
    global FLASH_ATTENTION

    config_dict, _ = PretrainedConfig.get_config_dict(
        model_id, revision=revision, trust_remote_code=trust_remote_code
    )
    model_type = config_dict.get("model_type", None)

    quantization_config = config_dict.get("quantization_config", None)
    if quantization_config is not None and quantize is None:
        method = quantization_config.get("quant_method", None)
        if method in {"gptq", "awq", "exl2"}:
            log_master(logger.info, f"Auto selecting quantization method {method}")
            quantize = method
        elif method == "fbgemm_fp8":
            log_master(logger.info, "Auto selecting quantization method fp8")
            quantize = "fp8"
        else:
            log_master(logger.warning, f"Unknown quantization method {method}")
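    # For example, a GPTQ checkpoint usually ships a config.json containing
    # `"quantization_config": {"quant_method": "gptq", ...}`; in that case
    # `quantize` is auto-set to "gptq" unless a method was passed explicitly.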

    if dtype is None:
        if quantize in ["awq", "exl2", "gptq", "marlin"]:
            # These quantizers only work with float16 params.
            dtype = torch.float16
        elif quantize == "fp8":
            from text_generation_server.layers.fp8 import FBGEMM_DYN_AVAILABLE

            if FBGEMM_DYN_AVAILABLE:
                # fbgemm kernels are fp8xfp8->bf16
                dtype = torch.bfloat16
        else:
            # Keep it as default for now and let
            # every model resolve their own default dtype.
            dtype = None
    elif dtype == "float16":
        dtype = torch.float16
    elif dtype == "bfloat16":
        dtype = torch.bfloat16
    else:
        raise RuntimeError(f"Unknown dtype {dtype}")

    if speculate is not None:
        set_speculate(speculate)
    else:
        set_speculate(0)

    speculator = None
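    # Medusa checkpoints point back to their base model via
    # `base_model_name_or_path`: the base weights are loaded under `model_id`
    # while the Medusa head weights are kept around as the speculator.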
    if "medusa_num_heads" in config_dict:
        medusa_model_id = model_id
        medusa_revision = revision
        model_id = config_dict["base_model_name_or_path"]
        revision = "main"
        speculate_medusa = config_dict["medusa_num_heads"]
        if speculate is not None:
            if speculate > speculate_medusa:
                raise RuntimeError(
                    f"Speculate is set to `{speculate}` but this Medusa model only has `{speculate_medusa}` heads; please make them match"
                )
            else:
                set_speculate(speculate)
        else:
            set_speculate(speculate_medusa)

        config_dict, _ = PretrainedConfig.get_config_dict(
            model_id, revision=revision, trust_remote_code=trust_remote_code
        )
        # Reload model type from parent.
        model_type = config_dict.get("model_type", None)
        is_local = Path(medusa_model_id).exists()
        if not is_local:
            medusa_config = hf_hub_download(
                medusa_model_id, revision=medusa_revision, filename="config.json"
            )
            hf_hub_download(
                medusa_model_id,
                revision=medusa_revision,
                filename="medusa_lm_head.safetensors",
            )
            speculator = {
                "path": Path(medusa_config).parent,
                "model_paths": ["medusa_lm_head.safetensors"],
            }
        else:
            speculator = {
                "path": Path(medusa_model_id),
                "model_paths": ["medusa_lm_head.safetensors"],
            }

        method = "medusa"
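    # MLP speculators follow the same pattern below: resolve the base model
    # from `base_model_name_or_path` and download the speculator's
    # *.safetensors files next to its config.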
    elif model_type == "mlp_speculator":
        mlp_model_id = model_id
        mlp_revision = revision
        model_id = config_dict["base_model_name_or_path"]
        revision = "main"
        speculate_mlp = config_dict["n_predict"]
        if speculate is not None:
            if speculate > speculate_mlp:
                raise RuntimeError(
                    f"Speculate is set to `{speculate}` but this mlp_speculator model only has `{speculate_mlp}` heads; please make them match"
                )
            else:
                set_speculate(speculate)
        else:
            set_speculate(speculate_mlp)

        config_dict, _ = PretrainedConfig.get_config_dict(
            model_id, revision=revision, trust_remote_code=trust_remote_code
        )
        # Reload model type from parent.
        model_type = config_dict.get("model_type", None)
        is_local = Path(mlp_model_id).exists()
        extension = ".safetensors"
        if not is_local:
            mlp_speculator_config = hf_hub_download(
                mlp_model_id, revision=mlp_revision, filename="config.json"
            )
            api = HfApi()
            info = api.model_info(mlp_model_id, revision=mlp_revision)
            filenames = [
                s.rfilename
                for s in info.siblings
                if s.rfilename.endswith(extension)
                and len(s.rfilename.split("/")) == 1
                and "arguments" not in s.rfilename
                and "args" not in s.rfilename
                and "training" not in s.rfilename
            ]
            for filename in filenames:
                hf_hub_download(
                    mlp_model_id,
                    revision=mlp_revision,
                    filename=filename,
                )
            speculator = {
                "path": Path(mlp_speculator_config).parent,
                "model_paths": filenames,
            }
        else:
            speculator = Path(mlp_model_id)
            filenames = [p for p in os.listdir(speculator) if p.endswith(extension)]
            speculator = {"path": speculator, "model_paths": filenames}
        method = "mlp_speculator"
    else:
        method = "n-gram"

    speculate = get_speculate()
    if speculate > 0:
        log_master(
            logger.info, f"Using speculation {method} with {speculate} input ids."
        )

    if model_type is None:
        # TODO: fix how we determine model type for Mamba
        if "ssm_cfg" in config_dict:
            # *only happens in Mamba case
            model_type = "ssm"
        else:
            raise RuntimeError(
                f"Could not determine model type for {model_id} revision {revision}"
            )

    if quantize == "exl2" and sharded:
        raise RuntimeError(
            "Sharding is currently not supported with `exl2` quantization"
        )

    sliding_window = (
        config_dict.get("sliding_window")
        if config_dict.get("sliding_window") is not None
        else -1
    )
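    # If the whole prompt fits inside the window, windowed attention is never
    # exercised, so backends without windowing support can still load the model.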

    if max_input_tokens is not None and max_input_tokens <= sliding_window:
        sliding_window = -1

    if (
        (sliding_window is not None and sliding_window != -1)
        and not SUPPORTS_WINDOWING
        and max_input_tokens > sliding_window
    ):
        raise ValueError(
            f"The backend {SYSTEM} does not support sliding window attention that is used by the model type {model_type}. To use this model nonetheless with the {SYSTEM} backend, please launch TGI with the argument `--max-input-tokens` smaller than sliding_window={sliding_window} (got here max_input_tokens={max_input_tokens})."
        )

    if model_type == DEEPSEEK_V2:
        if FLASH_ATTENTION:
            head_size = max(
                config_dict.get("qk_nope_dim", 128)
                + config_dict.get("qk_rope_dim", 64),
                config_dict.get("v_head_dim", 128),
            )
            return FlashCausalLM(
                model_id=model_id,
                model_class=FlashDeepseekV2ForCausalLM,
                revision=revision,
                quantize=quantize,
                speculator=speculator,
                default_dtype=torch.bfloat16,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
                lora_adapter_ids=lora_adapter_ids,
                config_class=DeepseekV2Config,
                head_size=head_size,
            )
        elif sharded:
            raise NotImplementedError(
                FLASH_ATT_ERROR_MESSAGE.format("Sharded Deepseek V2")
            )
        else:
            return CausalLM.fallback(
                model_id,
                revision,
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )
    elif model_type == MAMBA:
        return Mamba(
            model_id,
            revision,
            quantize=quantize,
            speculator=speculator,
            dtype=dtype,
            trust_remote_code=trust_remote_code,
        )

    if model_id.startswith("facebook/galactica"):
        return CausalLM(
            model_id=model_id,
            # Yes galactica is just an OPT model.
            model_class=OPTForCausalLM,
            revision=revision,
            quantize=quantize,
            speculator=speculator,
            dtype=dtype,
            trust_remote_code=trust_remote_code,
            batch_class=GalacticaCausalLMBatch,
        )

    if (
        model_type == GPT_BIGCODE
        or (model_type == GPT2 and model_id.startswith("bigcode/"))
    ):
        if FLASH_ATTENTION:
            return FlashCausalLM(
                model_id=model_id,
                model_class=FlashSantacoderForCausalLM,
                revision=revision,
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
                lora_adapter_ids=lora_adapter_ids,
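                # Santacoder ties the token embedding to the LM head and uses
                # multi-query attention, hence the alias and the single KV head.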
                aliases={"transformer.wte.weight": ["lm_head.weight"]},
                num_kv_heads=1,
            )
        elif sharded:
            raise NotImplementedError(
                FLASH_ATT_ERROR_MESSAGE.format("Sharded Santacoder")
            )
        else:
            return CausalLM.fallback(
                model_id=model_id,
                revision=revision,
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )

    if model_type == BLOOM:
        return CausalLM(
            model_id=model_id,
            model_class=BloomForCausalLM,
            revision=revision,
            quantize=quantize,
            speculator=speculator,
            dtype=dtype,
            trust_remote_code=trust_remote_code,
            batch_class=BloomCausalLMBatch,
        )
    elif model_type == MPT:
        return CausalLM(
            model_id=model_id,
            model_class=MPTForCausalLM,
            revision=revision,
            quantize=quantize,
            speculator=speculator,
            dtype=dtype,
            trust_remote_code=trust_remote_code,
            batch_class=CausalLMBatchKeysLast,
        )
    elif model_type == GPT2:
        if FLASH_ATTENTION:
            try:
                return FlashCausalLM(
                    model_id=model_id,
                    model_class=FlashGPT2ForCausalLM,
                    revision=revision,
                    quantize=quantize,
                    speculator=speculator,
                    dtype=dtype,
                    trust_remote_code=trust_remote_code,
                    lora_adapter_ids=lora_adapter_ids,
                )
            except RuntimeError as e:
                # Lots of legacy models with various weight names.
                log_master(logger.warning, f"Couldn't load flash gpt2 variant: {e}")
                return CausalLM.fallback(
                    model_id,
                    revision,
                    quantize=quantize,
                    speculator=speculator,
                    dtype=dtype,
                    trust_remote_code=trust_remote_code,
                )
        elif sharded:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded GPT-2"))
        else:
            return CausalLM.fallback(
                model_id,
                revision,
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )
    elif model_type == GPTJ:
        if FLASH_ATTENTION:
            try:
                return FlashCausalLM(
                    model_id=model_id,
                    model_class=FlashGPTJForCausalLM,
                    revision=revision,
                    quantize=quantize,
                    speculator=speculator,
                    dtype=dtype,
                    trust_remote_code=trust_remote_code,
                    lora_adapter_ids=lora_adapter_ids,
                )
            except RuntimeError as e:
                # Lots of legacy models with various weight names.
                log_master(logger.warning, f"Couldn't load flash gptj variant: {e}")
                return CausalLM.fallback(
                    model_id,
                    revision,
                    quantize=quantize,
                    speculator=speculator,
                    dtype=dtype,
                    trust_remote_code=trust_remote_code,
                )
        elif sharded:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded GPT-J"))
        else:
            return CausalLM.fallback(
                model_id,
                revision,
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )
    elif model_type == GPT_NEOX:
        if FLASH_ATTENTION:
            from text_generation_server.models.custom_modeling.flash_neox_modeling import (
                GPTNeoXConfig,
            )

            return FlashCausalLM(
                model_id=model_id,
                model_class=FlashGPTNeoXForCausalLM,
                revision=revision,
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
                lora_adapter_ids=lora_adapter_ids,
                config_class=GPTNeoXConfig,
            )
        elif sharded:
            return CausalLM(
                model_id=model_id,
                model_class=GPTNeoxForCausalLM,
                revision=revision,
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )
        else:
            return CausalLM.fallback(
                model_id,
                revision,
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )

    elif model_type == PHI:
        if FLASH_ATTENTION:
            return FlashCausalLM(
                model_id=model_id,
                model_class=FlashPhiForCausalLM,
                revision=revision,
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
                lora_adapter_ids=lora_adapter_ids,
            )
        else:
            return CausalLM.fallback(
                model_id,
                revision,
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )

    elif model_type == "phi-msft":
        if FLASH_ATTENTION:
            raise NotImplementedError(
                "Legacy phi-msft is not supported with Flash Attention"
            )
        else:
            return CausalLM(
                model_id=model_id,
                model_class=PhiForCausalLM,
                config_class=PhiConfig,
                revision=revision,
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )

    elif model_type == LLAMA or model_type == BAICHUAN or model_type == PHI3:
        log_master(logger.debug, f"model_type: {model_type}")
        if FLASH_ATTENTION:
            return FlashCausalLM(
                model_id=model_id,
                model_class=FlashLlamaForCausalLM,
                revision=revision,
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
                lora_adapter_ids=lora_adapter_ids,
            )
        elif sharded:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Llama"))
        else:
            return CausalLM.fallback(
                model_id,
                revision,
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )
    if model_type == GEMMA:
        if FLASH_ATTENTION:
            return FlashCausalLM(
                model_id=model_id,
                model_class=FlashGemmaForCausalLM,
                revision=revision,
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
                # Works better for these models
                default_dtype=torch.bfloat16,
                trust_remote_code=trust_remote_code,
                lora_adapter_ids=lora_adapter_ids,
            )
        elif sharded:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Gemma"))
        else:
            return CausalLM.fallback(
                model_id,
                revision,
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )
    elif model_type == GEMMA2:
        if FLASH_ATTENTION:
            return FlashCausalLM(
                model_id=model_id,
                model_class=FlashGemma2ForCausalLM,
                revision=revision,
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
                # Works better for these models
                default_dtype=torch.bfloat16,
                trust_remote_code=trust_remote_code,
                lora_adapter_ids=lora_adapter_ids,
            )
        elif sharded:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Gemma2"))
        else:
            return CausalLM.fallback(
                model_id,
                revision,
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )

    if model_type == COHERE:
        if FLASH_ATTENTION:
            return FlashCausalLM(
                model_id=model_id,
                model_class=FlashCohereForCausalLM,
                revision=revision,
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
                lora_adapter_ids=lora_adapter_ids,
            )
        elif sharded:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Cohere"))
        else:
            return CausalLM.fallback(
                model_id,
                revision,
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )

    if model_type == DBRX:
        if FLASH_ATTENTION:
            return FlashCausalLM(
                model_id=model_id,
                model_class=FlashDbrxForCausalLM,
                revision=revision,
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
                # Dbrx works better in bfloat16.
                default_dtype=torch.bfloat16,
                trust_remote_code=trust_remote_code,
                lora_adapter_ids=lora_adapter_ids,
                config_class=DbrxConfig,
            )
        elif sharded:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded DBRX"))
        else:
            return CausalLM.fallback(
                model_id,
                revision,
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )

    if model_type in ["RefinedWeb", "RefinedWebModel", FALCON]:
        if sharded:
            if FLASH_ATTENTION:
                if config_dict.get("alibi", False):
                    raise NotImplementedError("sharded is not supported for this model")
                return FlashCausalLM(
                    model_id=model_id,
                    model_class=FlashRWForCausalLM,
                    revision=revision,
                    quantize=quantize,
                    speculator=speculator,
                    dtype=dtype,
                    aliases={
                        "lm_head.weight": ["transformer.word_embeddings.weight"],
                        "transformer.word_embeddings.weight": ["lm_head.weight"],
                    },
                    trust_remote_code=trust_remote_code,
                    lora_adapter_ids=lora_adapter_ids,
                    config_class=RWConfig,
                )
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Falcon"))
        else:
            if FLASH_ATTENTION and not config_dict.get("alibi", False):
                return FlashCausalLM(
                    model_id=model_id,
                    model_class=FlashRWForCausalLM,
                    revision=revision,
                    quantize=quantize,
                    speculator=speculator,
                    dtype=dtype,
                    aliases={
                        "lm_head.weight": ["transformer.word_embeddings.weight"],
                        "transformer.word_embeddings.weight": ["lm_head.weight"],
                    },
                    trust_remote_code=trust_remote_code,
                    lora_adapter_ids=lora_adapter_ids,
                    config_class=RWConfig,
                )
            else:
                return CausalLM.fallback(
                    model_id,
                    revision,
                    quantize=quantize,
                    speculator=speculator,
                    dtype=dtype,
                    trust_remote_code=trust_remote_code,
                )

    if model_type == MISTRAL:
        if FLASH_ATTENTION:
            return FlashCausalLM(
                model_id=model_id,
                model_class=FlashMistralForCausalLM,
                revision=revision,
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
                lora_adapter_ids=lora_adapter_ids,
            )
        elif sharded:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Mistral"))
        else:
            return CausalLM.fallback(
                model_id,
                revision,
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )

    if model_type == MIXTRAL:
        if FLASH_ATTENTION:
            return FlashCausalLM(
                model_id=model_id,
                model_class=FlashMixtralForCausalLM,
                revision=revision,
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
                lora_adapter_ids=lora_adapter_ids,
            )
        elif sharded:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Mixtral"))
        else:
            return CausalLM.fallback(
                model_id,
                revision,
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )

    if model_type == STARCODER2:
        if FLASH_ATTENTION:
            return FlashCausalLM(
                model_id=model_id,
                model_class=FlashStarcoder2ForCausalLM,
                revision=revision,
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
                lora_adapter_ids=lora_adapter_ids,
            )
        elif sharded:
            raise NotImplementedError(
                FLASH_ATT_ERROR_MESSAGE.format("Sharded Starcoder2")
            )
        else:
            return CausalLM.fallback(
                model_id,
                revision,
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )

    if model_type == QWEN2:
        if FLASH_ATTENTION:
            return FlashCausalLM(
                model_id=model_id,
                model_class=Qwen2ForCausalLM,
                revision=revision,
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
                lora_adapter_ids=lora_adapter_ids,
            )
        elif sharded:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Qwen2"))
        else:
            return CausalLM.fallback(
                model_id,
                revision,
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )

    if model_type == OPT:
        return CausalLM(
            model_id=model_id,
            model_class=OPTForCausalLM,
            revision=revision,
            quantize=quantize,
            speculator=speculator,
            dtype=dtype,
            trust_remote_code=trust_remote_code,
        )

    if model_type == T5:
        return Seq2SeqLM(
            model_id=model_id,
            model_class=T5ForConditionalGeneration,
            revision=revision,
            quantize=quantize,
            speculator=speculator,
            dtype=dtype,
            trust_remote_code=trust_remote_code,
            aliases={
                "shared.weight": [
                    "encoder.embed_tokens.weight",
                    "decoder.embed_tokens.weight",
                ]
            },
        )
    if model_type == IDEFICS:
        if FLASH_ATTENTION:
            return IDEFICSSharded(
                model_id,
                revision,
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )
        else:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Idefics"))
    if model_type == IDEFICS2:
        if FLASH_ATTENTION:
            return VlmCausalLM(
                model_id=model_id,
                model_class=Idefics2ForConditionalGeneration,
                revision=revision,
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
                lora_adapter_ids=lora_adapter_ids,
                # XXX: Extremely important to cap resolution in order to limit
                # VRAM usage.
                processor_kwargs={"size": {"longest_edge": 448, "shortest_edge": 378}},
            )
        else:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Idefics2"))
    if model_type == PALIGEMMA:
        if FLASH_ATTENTION:
            return VlmCausalLM(
                model_id=model_id,
                model_class=PaliGemmaForConditionalGeneration,
                revision=revision,
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
                # Works better for these models
                default_dtype=torch.bfloat16,
                trust_remote_code=trust_remote_code,
                lora_adapter_ids=lora_adapter_ids,
                batch_class=PaliGemmaBatch,
            )
        else:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("PaliGemma"))

    if model_type == LLAVA_NEXT:
        if FLASH_ATTENTION:
            return VlmCausalLM(
                model_class=LlavaNextForConditionalGeneration,
                model_id=model_id,
                revision=revision,
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )
        else:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("LlavaNext"))
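    # Nothing above matched: fall back to the generic transformers AutoModel
    # path, which supports neither sharding nor most quantization schemes.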

    if sharded:
        raise NotImplementedError("sharded is not supported for AutoModel")
    if quantize == "gptq":
        raise NotImplementedError(
            "gptq quantization is not supported for AutoModel, you can try to quantize it with `text-generation-server quantize ORIGINAL_MODEL_ID NEW_MODEL_ID`"
        )
    if quantize == "awq":
        raise NotImplementedError("awq quantization is not supported for AutoModel")
    elif (quantize == "bitsandbytes-fp4") or (quantize == "bitsandbytes-nf4"):
        raise NotImplementedError("4bit quantization is not supported for AutoModel")
    elif quantize == "eetq":
        raise NotImplementedError("Eetq quantization is not supported for AutoModel")
    elif quantize == "exl2":
        raise NotImplementedError("exl2 quantization is not supported for AutoModel")
    if model_type in modeling_auto.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES:
        return CausalLM.fallback(
            model_id,
            revision,
            quantize=quantize,
            speculator=speculator,
            dtype=dtype,
            trust_remote_code=trust_remote_code,
        )
    if model_type in modeling_auto.MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES:
        return Seq2SeqLM.fallback(
            model_id,
            revision,
            quantize=quantize,
            speculator=speculator,
            dtype=dtype,
            trust_remote_code=trust_remote_code,
        )

    auto_map = config_dict.get("auto_map", None)
    if trust_remote_code and auto_map is not None:
        if "AutoModelForCausalLM" in auto_map.keys():
            return CausalLM.fallback(
                model_id,
                revision,
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )
        if "AutoModelForSeq2SeqLM" in auto_map.keys():
            return Seq2SeqLM.fallback(
                model_id,
                revision,
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )

    raise ValueError(f"Unsupported model type {model_type}")


# get_model_with_lora_adapters wraps the internal get_model function and adds support for loading adapters.
# It acts as a post-load hook: once the model is constructed, adapter weights are merged into it.
def get_model_with_lora_adapters(
    model_id: str,
    lora_adapters: Optional[List[AdapterInfo]],
    revision: Optional[str],
    sharded: bool,
    quantize: Optional[str],
    speculate: Optional[int],
    dtype: Optional[str],
    trust_remote_code: bool,
    max_input_tokens: int,
    adapter_to_index: Dict[str, int],
):
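    """Load the model via `get_model`, then attach any requested LoRA adapters.

    Each adapter's weights are merged and registered layer by layer on the
    loaded model; `adapter_to_index` is updated in place so adapter ids can be
    mapped to their numeric indices. Returns the ready-to-serve model.
    """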
    lora_adapter_ids = [adapter.id for adapter in lora_adapters]
    model = get_model(
        model_id,
        lora_adapter_ids,
        revision,
        sharded,
        quantize,
        speculate,
        dtype,
        trust_remote_code,
        max_input_tokens,
    )

    if len(lora_adapters) > 0:
        target_to_layer = build_layer_weight_lookup(model.model)

        for index, adapter in enumerate(lora_adapters):
            # The AdapterParameters object allows for merging multiple adapters into a single adapter.
            # At the moment, we only support loading a single adapter into the model, but we keep the
            # AdapterParameters object for easier extension in the future.
            adapter_parameters = AdapterParameters(
                adapter_info=[adapter],
                # when merging multiple adapters we can weight them differently
                # if this is not set, all adapters will be weighted equally
                # see: text_generation_server.utils.merges.strategies for impl
                weights=None,
                merge_strategy=0,
                density=1.0,
                majority_sign_method=0,
            )

            adapter_index = index + 1
            adapter_to_index[adapter.id] = adapter_index

            logger.info(
                f"Loading adapter weights into model: {','.join([adapter.id for adapter in adapter_parameters.adapter_info])}"
            )
            weight_names = tuple([v[0] for v in target_to_layer.values()])
            (
                module_map,
                adapter_config,
                adapter_weight_names,
                adapter_tokenizer,
            ) = load_and_merge_adapters(
                model.model_id,
                adapter_parameters,
                adapter_index,
                weight_names,
                False,
            )

            unused_weight_names = adapter_weight_names.copy()

            adapter_layers = [
                "q_proj",
                "k_proj",
                "v_proj",
                "o_proj",
                "gate_proj",
                "up_proj",
                "down_proj",
            ]

            for layer_name in adapter_layers:
                nlayers = (
                    1 if layer_name == "lm_head" else len(model.model.model.layers)
                )
                adapter_weights = LoraWeights.prepare_weights(
                    config=adapter_config,
                    module_map=module_map,
                    layer_type=layer_name,
                    unused_weight_names=unused_weight_names,
                    nlayers=nlayers,
                    dtype=model.dtype,
                    world_size=model.world_size,
                    process_group=model.process_group,
                    target_to_layer=target_to_layer,
                )

                if adapter_weights is None:
                    continue

                model.layer_to_adapter_weights[layer_name].add_adapter(
                    adapter_index, adapter_weights
                )

            if len(unused_weight_names) > 0:
                logger.warning(
                    f"{','.join(adapter_parameters.adapter_ids)} unused adapter weights: {unused_weight_names}"
                )

            if adapter_tokenizer is not None:
                model.tokenizers.add_tokenizer(adapter_index, adapter_tokenizer)

            model.loaded_adapters.add(adapter_index)

    return model