"docs/vscode:/vscode.git/clone" did not exist on "1afc21855eb1f5575bd61037a7ee44522ccf401e"
__init__.py 44.6 KB
Newer Older
1
2
3
# ruff: noqa: F821
# the above line disables the `undefined-name` rule for the model type variables

4
import torch
5
import enum
Nicolas Patry's avatar
Nicolas Patry committed
6
import os
7

8
from loguru import logger
9
from transformers.configuration_utils import PretrainedConfig
10
from transformers.models.auto import modeling_auto
Nicolas Patry's avatar
Nicolas Patry committed
11
from huggingface_hub import hf_hub_download, HfApi
12
from typing import Optional, List, Dict
13
from pathlib import Path
14

Nicolas Patry's avatar
Nicolas Patry committed
15
from text_generation_server.utils.speculate import get_speculate, set_speculate
16
from text_generation_server.models.model import Model
17
18
19
20
21
from text_generation_server.models.causal_lm import CausalLM, CausalLMBatchKeysLast
from text_generation_server.models.custom_modeling.opt_modeling import OPTForCausalLM
from text_generation_server.models.custom_modeling.mpt_modeling import (
    MPTForCausalLM,
)
22
from text_generation_server.models.bloom import BloomCausalLMBatch
23
24
25
from text_generation_server.models.custom_modeling.bloom_modeling import (
    BloomForCausalLM,
)
26
from text_generation_server.models.seq2seq_lm import Seq2SeqLM
27
28
29
30
31
32
33
34
35
36
37
from text_generation_server.models.galactica import GalacticaCausalLMBatch
from text_generation_server.models.custom_modeling.neox_modeling import (
    GPTNeoxForCausalLM,
)
from text_generation_server.models.custom_modeling.phi_modeling import (
    PhiConfig,
    PhiForCausalLM,
)
from text_generation_server.models.custom_modeling.t5_modeling import (
    T5ForConditionalGeneration,
)
38

39
40
41
42
43
44
45
46
47
48

from text_generation_server.utils.adapter import (
    AdapterParameters,
    build_layer_weight_lookup,
    load_and_merge_adapters,
    AdapterInfo,
)
from text_generation_server.adapters.lora import LoraWeights


49
from text_generation_server.utils.import_utils import SYSTEM
50
from text_generation_server.utils.log import log_master
51

52
53
54
55
56
57
58
59
60
61
62
63
64
65
# Gradient tracking is switched off globally as soon as this package is
# imported.
torch.set_grad_enabled(False)

# Opt in to TensorFloat-32 math for CUDA matmuls (PyTorch >= 1.12 ships with
# this disabled) and for cuDNN (enabled by default already; it is set here
# explicitly so the intent does not depend on the library default).
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True

# Public API of this package. Every entry must be a *string* naming a
# module-level attribute: `from <package> import *` rejects non-str items.
__all__ = [
    "Model",
    "CausalLM",
    "Seq2SeqLM",
    "get_model_with_lora_adapters",
]

69
# Message template used when a model variant (e.g. a sharded one) can only be
# served through the Flash Attention model implementations.
FLASH_ATT_ERROR_MESSAGE = "{} requires Flash Attention enabled models."

# Optimistically assume the Flash Attention model implementations import
# cleanly; the `except` branch below flips this to False if any import fails.
FLASH_ATTENTION = True

try:
    from text_generation_server.models.flash_causal_lm import FlashCausalLM
    from text_generation_server.models.vlm_causal_lm import VlmCausalLM
    from text_generation_server.models.custom_modeling.flash_deepseek_v2_modeling import (
        FlashDeepseekV2ForCausalLM,
        DeepseekV2Config,
    )
    from text_generation_server.models.custom_modeling.flash_llama_modeling import (
        FlashLlamaForCausalLM,
    )
    from text_generation_server.models.custom_modeling.flash_cohere_modeling import (
        FlashCohereForCausalLM,
    )
    from text_generation_server.models.custom_modeling.flash_gemma_modeling import (
        FlashGemmaForCausalLM,
    )
    from text_generation_server.models.custom_modeling.flash_gemma2_modeling import (
        FlashGemma2ForCausalLM,
    )
    from text_generation_server.models.custom_modeling.flash_dbrx_modeling import (
        FlashDbrxForCausalLM,
        DbrxConfig,
    )
    from text_generation_server.models.custom_modeling.flash_rw_modeling import (
        RWConfig,
        FlashRWForCausalLM,
    )
    from text_generation_server.models.custom_modeling.flash_neox_modeling import (
        FlashGPTNeoXForCausalLM,
    )
    from text_generation_server.models.pali_gemma import (
        PaliGemmaBatch,
    )
    from text_generation_server.models.custom_modeling.flash_pali_gemma_modeling import (
        PaliGemmaForConditionalGeneration,
    )
    from text_generation_server.models.custom_modeling.flash_phi_modeling import (
        FlashPhiForCausalLM,
    )
    from text_generation_server.models.idefics import IDEFICSSharded
    from text_generation_server.models.custom_modeling.llava_next import (
        LlavaNextForConditionalGeneration,
    )
    from text_generation_server.models.custom_modeling.flash_santacoder_modeling import (
        FlashSantacoderForCausalLM,
    )
    from text_generation_server.models.custom_modeling.flash_starcoder2_modeling import (
        FlashStarcoder2ForCausalLM,
    )
    from text_generation_server.models.custom_modeling.flash_qwen2_modeling import (
        Qwen2ForCausalLM,
    )
    from text_generation_server.models.custom_modeling.flash_mistral_modeling import (
        FlashMistralForCausalLM,
    )
    from text_generation_server.models.custom_modeling.flash_mixtral_modeling import (
        FlashMixtralForCausalLM,
    )
    from text_generation_server.models.custom_modeling.flash_gpt2_modeling import (
        FlashGPT2ForCausalLM,
    )
    from text_generation_server.models.custom_modeling.flash_gptj_modeling import (
        FlashGPTJForCausalLM,
    )
    from text_generation_server.models.custom_modeling.idefics2 import (
        Idefics2ForConditionalGeneration,
    )
    from text_generation_server.layers.attention import SUPPORTS_WINDOWING
except ImportError as e:
    log_master(logger.warning, f"Could not import Flash Attention enabled models: {e}")
    # SUPPORTS_WINDOWING normally comes from the attention layer imported
    # above; when that import chain fails, report no windowing support.
    SUPPORTS_WINDOWING = False
    FLASH_ATTENTION = False
if FLASH_ATTENTION:
    # `__all__` must contain attribute *names*; appending the class objects
    # themselves makes a star-import of this package raise TypeError.
    __all__.extend(["FlashCausalLM", "IDEFICSSharded"])
# Mamba support is optional: if its import fails the error is only logged and
# MAMBA_AVAILABLE is set to False.
MAMBA_AVAILABLE = True
try:
    from text_generation_server.models.mamba import Mamba
except ImportError as e:
    log_master(logger.warning, f"Could not import Mamba: {e}")
    MAMBA_AVAILABLE = False

if MAMBA_AVAILABLE:
    # Export the name as a string; star-import rejects non-str `__all__` items.
    __all__.append("Mamba")
class ModelType(enum.Enum):
    """Catalog of the model architectures known to this package.

    Each member's value is a dict with the keys:

    * ``"type"``: the ``model_type`` string compared against the value read
      from a checkpoint's config.
    * ``"name"``: a human-readable display name.
    * ``"url"``: a reference Hugging Face Hub page for the architecture.
    * ``"multimodal"`` (optional): ``True`` for members flagged as multimodal.
    """

    DEEPSEEK_V2 = {
        "type": "deepseek_v2",
        "name": "Deepseek V2",
        "url": "https://huggingface.co/deepseek-ai/DeepSeek-V2",
    }
    IDEFICS2 = {
        "type": "idefics2",
        "name": "Idefics 2",
        "url": "https://huggingface.co/HuggingFaceM4/idefics2-8b",
        "multimodal": True,
    }
    LLAVA_NEXT = {
        "type": "llava_next",
        "name": "Llava Next (1.6)",
        "url": "https://huggingface.co/llava-hf/llava-v1.6-vicuna-13b-hf",
        "multimodal": True,
    }
    LLAMA = {
        "type": "llama",
        "name": "Llama",
        "url": "https://huggingface.co/collections/meta-llama/llama-31-669fc079a0c406a149a5738f",
    }
    PHI3 = {
        "type": "phi3",
        "name": "Phi 3",
        "url": "https://huggingface.co/microsoft/Phi-3-mini-4k-instruct",
    }
    GEMMA = {
        "type": "gemma",
        "name": "Gemma",
        "url": "https://huggingface.co/google/gemma-7b",
    }
    PALIGEMMA = {
        "type": "paligemma",
        "name": "PaliGemma",
        "url": "https://huggingface.co/google/paligemma-3b-pt-224",
    }
    GEMMA2 = {
        "type": "gemma2",
        "name": "Gemma2",
        "url": "https://huggingface.co/collections/google/gemma-2-release-667d6600fd5220e7b967f315",
    }
    COHERE = {
        "type": "cohere",
        "name": "Cohere",
        "url": "https://huggingface.co/CohereForAI/c4ai-command-r-plus",
    }
    DBRX = {
        "type": "dbrx",
        "name": "Dbrx",
        "url": "https://huggingface.co/databricks/dbrx-instruct",
    }
    # Note: the "type" key here ("ssm") differs from the member name.
    MAMBA = {
        "type": "ssm",
        "name": "Mamba",
        "url": "https://huggingface.co/state-spaces/mamba-2.8b-slimpj",
    }
    MISTRAL = {
        "type": "mistral",
        "name": "Mistral",
        "url": "https://huggingface.co/mistralai/Mistral-Nemo-Instruct-2407",
    }
    MIXTRAL = {
        "type": "mixtral",
        "name": "Mixtral",
        "url": "https://huggingface.co/mistralai/Mixtral-8x22B-Instruct-v0.1",
    }
    GPT_BIGCODE = {
        "type": "gpt_bigcode",
        "name": "Gpt Bigcode",
        "url": "https://huggingface.co/bigcode/gpt_bigcode-santacoder",
    }
    PHI = {
        "type": "phi",
        "name": "Phi",
        "url": "https://huggingface.co/microsoft/phi-1_5",
    }
    BAICHUAN = {
        "type": "baichuan",
        "name": "Baichuan",
        "url": "https://huggingface.co/baichuan-inc/Baichuan2-7B-Chat",
    }
    FALCON = {
        "type": "falcon",
        "name": "Falcon",
        "url": "https://huggingface.co/tiiuae/falcon-7b-instruct",
    }
    STARCODER2 = {
        "type": "starcoder2",
        "name": "StarCoder 2",
        "url": "https://huggingface.co/bigcode/starcoder2-15b-instruct-v0.1",
    }
    QWEN2 = {
        "type": "qwen2",
        "name": "Qwen 2",
        "url": "https://huggingface.co/collections/Qwen/qwen2-6659360b33528ced941e557f",
    }
    OPT = {
        "type": "opt",
        "name": "Opt",
        "url": "https://huggingface.co/facebook/opt-6.7b",
    }
    T5 = {
        "type": "t5",
        "name": "T5",
        "url": "https://huggingface.co/google/flan-t5-xxl",
    }
    GALACTICA = {
        "type": "galactica",
        "name": "Galactica",
        "url": "https://huggingface.co/facebook/galactica-120b",
    }
    SANTACODER = {
        "type": "santacoder",
        "name": "SantaCoder",
        "url": "https://huggingface.co/bigcode/santacoder",
    }
    BLOOM = {
        "type": "bloom",
        "name": "Bloom",
        "url": "https://huggingface.co/bigscience/bloom-560m",
    }
    MPT = {
        "type": "mpt",
        "name": "Mpt",
        "url": "https://huggingface.co/mosaicml/mpt-7b-instruct",
    }
    GPT2 = {
        "type": "gpt2",
        "name": "Gpt2",
        "url": "https://huggingface.co/openai-community/gpt2",
    }
    GPT_NEOX = {
        "type": "gpt_neox",
        "name": "Gpt Neox",
        "url": "https://huggingface.co/EleutherAI/gpt-neox-20b",
    }
    GPTJ = {
        "type": "gptj",
        "name": "Gptj",
        "url": "https://huggingface.co/EleutherAI/gpt-j-6b",
    }
    IDEFICS = {
        "type": "idefics",
        "name": "Idefics",
        "url": "https://huggingface.co/HuggingFaceM4/idefics-9b",
        "multimodal": True,
    }


# Publish each model-type string as a module-level constant named after its
# enum member (e.g. LLAMA == "llama", MAMBA == "ssm"), so code in this module
# can compare a config's `model_type` against bare names. `globals()` is used
# explicitly (the `locals()` dict is documented as not-to-be-modified), and a
# comprehension keeps the loop variable out of the module namespace.
globals().update({member.name: member.value["type"] for member in ModelType})


318
def get_model(
319
    model_id: str,
drbh's avatar
drbh committed
320
    lora_adapter_ids: Optional[List[str]],
321
322
323
    revision: Optional[str],
    sharded: bool,
    quantize: Optional[str],
Nicolas Patry's avatar
Nicolas Patry committed
324
    speculate: Optional[int],
325
    dtype: Optional[str],
326
    trust_remote_code: bool,
327
    max_input_tokens: int,
328
) -> Model:
329
    global FLASH_ATTENTION
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347

    config_dict, _ = PretrainedConfig.get_config_dict(
        model_id, revision=revision, trust_remote_code=trust_remote_code
    )
    model_type = config_dict.get("model_type", None)

    quantization_config = config_dict.get("quantization_config", None)
    if quantization_config is not None and quantize is None:
        method = quantization_config.get("quant_method", None)
        if method in {"gptq", "awq", "exl2"}:
            log_master(logger.info, f"Auto selecting quantization method {method}")
            quantize = method
        elif method == "fbgemm_fp8":
            log_master(logger.info, "Auto selecting quantization method fp8")
            quantize = "fp8"
        else:
            log_master(logger.warning, f"Unknown quantization method {method}")

348
    if dtype is None:
349
        if quantize in ["awq", "exl2", "gptq", "marlin"]:
350
351
            # These quantizers only work with float16 params.
            dtype = torch.float16
352
        elif quantize == "fp8":
353
            from text_generation_server.layers.fp8 import FBGEMM_DYN_AVAILABLE
354

355
            if FBGEMM_DYN_AVAILABLE:
356
357
                # fbgemm kernels are fp8xfp8->bf16
                dtype = torch.bfloat16
358
359
360
361
        else:
            # Keep it as default for now and let
            # every model resolve their own default dtype.
            dtype = None
362
363
364
365
366
367
368
    elif dtype == "float16":
        dtype = torch.float16
    elif dtype == "bfloat16":
        dtype = torch.bfloat16
    else:
        raise RuntimeError(f"Unknown dtype {dtype}")

Nicolas Patry's avatar
Nicolas Patry committed
369
370
371
372
373
    if speculate is not None:
        set_speculate(speculate)
    else:
        set_speculate(0)

Nicolas Patry's avatar
Nicolas Patry committed
374
    speculator = None
Nicolas Patry's avatar
Nicolas Patry committed
375
    if "medusa_num_heads" in config_dict:
376
377
        medusa_model_id = model_id
        medusa_revision = revision
Nicolas Patry's avatar
Nicolas Patry committed
378
379
380
381
382
        model_id = config_dict["base_model_name_or_path"]
        revision = "main"
        speculate_medusa = config_dict["medusa_num_heads"]
        if speculate is not None:
            if speculate > speculate_medusa:
OlivierDehaene's avatar
OlivierDehaene committed
383
                raise RuntimeError(
OlivierDehaene's avatar
OlivierDehaene committed
384
                    f"Speculate is set to `{speculate}` but this medusa models only has `{speculate_medusa}` heads, please make them match"
OlivierDehaene's avatar
OlivierDehaene committed
385
                )
Nicolas Patry's avatar
Nicolas Patry committed
386
387
388
389
390
391
392
393
            else:
                set_speculate(speculate)
        else:
            set_speculate(speculate_medusa)

        config_dict, _ = PretrainedConfig.get_config_dict(
            model_id, revision=revision, trust_remote_code=trust_remote_code
        )
Nicolas Patry's avatar
Nicolas Patry committed
394
395
        # Reload model type from parent.
        model_type = config_dict.get("model_type", None)
396
397
398
399
400
401
402
403
404
405
        is_local = Path(medusa_model_id).exists()
        if not is_local:
            medusa_config = hf_hub_download(
                medusa_model_id, revision=medusa_revision, filename="config.json"
            )
            hf_hub_download(
                medusa_model_id,
                revision=medusa_revision,
                filename="medusa_lm_head.safetensors",
            )
Nicolas Patry's avatar
Nicolas Patry committed
406
407
408
409
            speculator = {
                "path": Path(medusa_config).parent,
                "model_paths": ["medusa_lm_head.safetensors"],
            }
410
        else:
Nicolas Patry's avatar
Nicolas Patry committed
411
412
413
414
            speculator = {
                "path": Path(medusa_model_id),
                "model_paths": ["medusa_lm_head.safetensors"],
            }
415

Nicolas Patry's avatar
Nicolas Patry committed
416
        method = "medusa"
Nicolas Patry's avatar
Nicolas Patry committed
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
    elif model_type == "mlp_speculator":
        mlp_model_id = model_id
        mlp_revision = revision
        model_id = config_dict["base_model_name_or_path"]
        revision = "main"
        speculate_mlp = config_dict["n_predict"]
        if speculate is not None:
            if speculate > speculate_mlp:
                raise RuntimeError(
                    f"Speculate is set to `{speculate}` but this mlp_speculator models only has `{speculate_mlp}` heads, please make them match"
                )
            else:
                set_speculate(speculate)
        else:
            set_speculate(speculate_mlp)

        config_dict, _ = PretrainedConfig.get_config_dict(
            model_id, revision=revision, trust_remote_code=trust_remote_code
        )
        # Reload model type from parent.
        model_type = config_dict.get("model_type", None)
        is_local = Path(mlp_model_id).exists()
        extension = ".safetensors"
        if not is_local:
            mlp_speculator_config = hf_hub_download(
                mlp_model_id, revision=mlp_revision, filename="config.json"
            )
            api = HfApi()
            info = api.model_info(mlp_model_id, revision=mlp_revision)
            filenames = [
                s.rfilename
                for s in info.siblings
                if s.rfilename.endswith(extension)
                and len(s.rfilename.split("/")) == 1
                and "arguments" not in s.rfilename
                and "args" not in s.rfilename
                and "training" not in s.rfilename
            ]
            for filename in filenames:
                hf_hub_download(
                    mlp_model_id,
                    revision=mlp_revision,
                    filename=filename,
                )
461
462
463
464
465
            speculator_dir_path = Path(mlp_speculator_config).parent
            # if these are downloaded, they get converted to safetensors
            filenames.extend(
                [p for p in os.listdir(speculator_dir_path) if p.endswith(extension)]
            )
Nicolas Patry's avatar
Nicolas Patry committed
466
467
468
469
470
471
472
473
474
            speculator = {
                "path": Path(mlp_speculator_config).parent,
                "model_paths": filenames,
            }
        else:
            speculator = Path(mlp_model_id)
            filenames = [p for p in os.listdir(speculator) if p.endswith(extension)]
            speculator = {"path": speculator, "model_paths": filenames}
        method = "mlp_speculator"
Nicolas Patry's avatar
Nicolas Patry committed
475
476
477
478
479
    else:
        method = "n-gram"

    speculate = get_speculate()
    if speculate > 0:
480
481
482
        log_master(
            logger.info, f"Using speculation {method} with {speculate} input ids."
        )
Nicolas Patry's avatar
Nicolas Patry committed
483

drbh's avatar
drbh committed
484
485
486
487
488
489
490
491
492
493
    if model_type is None:
        # TODO: fix how we determine model type for Mamba
        if "ssm_cfg" in config_dict:
            # *only happens in Mamba case
            model_type = "ssm"
        else:
            raise RuntimeError(
                f"Could not determine model type for {model_id} revision {revision}"
            )

494
495
496
497
    if quantize == "exl2" and sharded:
        raise RuntimeError(
            "Sharding is currently not supported with `exl2` quantization"
        )
drbh's avatar
drbh committed
498
499
500
501
502
503

    sliding_window = (
        config_dict.get("sliding_window")
        if config_dict.get("sliding_window") is not None
        else -1
    )
504

505
506
507
    use_sliding_window = sliding_window is not None and sliding_window != -1
    needs_sliding_window = (
        max_input_tokens is not None and max_input_tokens > sliding_window
508
    )
509
510
511
512
    if use_sliding_window and needs_sliding_window and not SUPPORTS_WINDOWING:
        raise ValueError(
            f"The backend {SYSTEM} does not support sliding window attention that is used by the model type {model_type}. To use this model nonetheless with the {SYSTEM} backend, please launch TGI with the argument `--max-input-tokens` smaller than sliding_window={sliding_window} (got here max_input_tokens={max_input_tokens})."
        )
513

514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
    if model_type == DEEPSEEK_V2:
        if FLASH_ATTENTION:
            head_size = max(
                config_dict.get("qk_nope_dim", 128)
                + config_dict.get("qk_rope_dim", 64),
                config_dict.get("v_head_dim", 128),
            )
            return FlashCausalLM(
                model_id=model_id,
                model_class=FlashDeepseekV2ForCausalLM,
                revision=revision,
                quantize=quantize,
                speculator=speculator,
                default_dtype=torch.bfloat16,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
                lora_adapter_ids=lora_adapter_ids,
                config_class=DeepseekV2Config,
                head_size=head_size,
            )
        elif sharded:
            raise NotImplementedError(
                FLASH_ATT_ERROR_MESSAGE.format("Sharded Deepseek V2")
            )
        else:
            return CausalLM.fallback(
                model_id,
                revision,
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )
    elif model_type == MAMBA:
drbh's avatar
drbh committed
548
549
550
551
        return Mamba(
            model_id,
            revision,
            quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
552
            speculator=speculator,
drbh's avatar
drbh committed
553
554
555
            dtype=dtype,
            trust_remote_code=trust_remote_code,
        )
556

OlivierDehaene's avatar
OlivierDehaene committed
557
    if model_id.startswith("facebook/galactica"):
558
559
560
561
562
        return CausalLM(
            model_id=model_id,
            # Yes galactica is just an OPT model.
            model_class=OPTForCausalLM,
            revision=revision,
OlivierDehaene's avatar
OlivierDehaene committed
563
            quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
564
            speculator=speculator,
OlivierDehaene's avatar
OlivierDehaene committed
565
566
            dtype=dtype,
            trust_remote_code=trust_remote_code,
567
            batch_class=GalacticaCausalLMBatch,
OlivierDehaene's avatar
OlivierDehaene committed
568
569
        )

570
    if (
571
572
        model_type == GPT_BIGCODE
        or model_type == GPT2
573
574
        and model_id.startswith("bigcode/")
    ):
575
        if FLASH_ATTENTION:
576
577
578
579
            return FlashCausalLM(
                model_id=model_id,
                model_class=FlashSantacoderForCausalLM,
                revision=revision,
580
                quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
581
                speculator=speculator,
582
                dtype=dtype,
583
                trust_remote_code=trust_remote_code,
584
585
586
                lora_adapter_ids=lora_adapter_ids,
                aliases={"transformer.wte.weight": ["lm_head.weight"]},
                num_kv_heads=1,
587
            )
588
589
590
591
        elif sharded:
            raise NotImplementedError(
                FLASH_ATT_ERROR_MESSAGE.format("Sharded Santacoder")
            )
592
        else:
593
594
595
            return CausalLM.fallback(
                model_id=model_id,
                revision=revision,
596
                quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
597
                speculator=speculator,
598
                dtype=dtype,
599
600
                trust_remote_code=trust_remote_code,
            )
601

602
    if model_type == BLOOM:
603
604
605
606
        return CausalLM(
            model_id=model_id,
            model_class=BloomForCausalLM,
            revision=revision,
607
            quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
608
            speculator=speculator,
609
610
            dtype=dtype,
            trust_remote_code=trust_remote_code,
611
            batch_class=BloomCausalLMBatch,
612
        )
613
    elif model_type == MPT:
614
615
616
617
        return CausalLM(
            model_id=model_id,
            model_class=MPTForCausalLM,
            revision=revision,
OlivierDehaene's avatar
OlivierDehaene committed
618
            quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
619
            speculator=speculator,
OlivierDehaene's avatar
OlivierDehaene committed
620
621
            dtype=dtype,
            trust_remote_code=trust_remote_code,
622
            batch_class=CausalLMBatchKeysLast,
623
        )
624
    elif model_type == GPT2:
625
        if FLASH_ATTENTION:
626
            try:
627
628
629
630
                return FlashCausalLM(
                    model_id=model_id,
                    model_class=FlashGPT2ForCausalLM,
                    revision=revision,
631
632
633
634
                    quantize=quantize,
                    speculator=speculator,
                    dtype=dtype,
                    trust_remote_code=trust_remote_code,
635
                    lora_adapter_ids=lora_adapter_ids,
636
637
638
                )
            except RuntimeError as e:
                # Lots of legacy models with various weight names.
639
                log_master(logger.warning, f"Couldn't load flash gpt2 variant: {e}")
640
                return CausalLM.fallback(
641
642
643
644
645
646
647
                    model_id,
                    revision,
                    quantize=quantize,
                    speculator=speculator,
                    dtype=dtype,
                    trust_remote_code=trust_remote_code,
                )
648
649
650
        elif sharded:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded GPT-2"))
        else:
651
            return CausalLM.fallback(
652
653
654
655
656
657
658
                model_id,
                revision,
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
    elif model_type == GPTJ:
        if FLASH_ATTENTION:
            try:
                return FlashCausalLM(
                    model_id=model_id,
                    model_class=FlashGPTJForCausalLM,
                    revision=revision,
                    quantize=quantize,
                    speculator=speculator,
                    dtype=dtype,
                    trust_remote_code=trust_remote_code,
                    lora_adapter_ids=lora_adapter_ids,
                )
            except RuntimeError as e:
                # Lots of legacy models with various weight names.
                log_master(logger.warning, f"Couldn't load flash gptj variant: {e}")
                return CausalLM.fallback(
                    model_id,
                    revision,
                    quantize=quantize,
                    speculator=speculator,
                    dtype=dtype,
                    trust_remote_code=trust_remote_code,
                )
        elif sharded:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded GPT-J"))
        else:
            return CausalLM.fallback(
                model_id,
                revision,
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )
694
    elif model_type == GPT_NEOX:
695
        if FLASH_ATTENTION:
696
697
698
699
            from text_generation_server.models.custom_modeling.flash_neox_modeling import (
                GPTNeoXConfig,
            )

700
701
702
703
            return FlashCausalLM(
                model_id=model_id,
                model_class=FlashGPTNeoXForCausalLM,
                revision=revision,
704
                quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
705
                speculator=speculator,
706
                dtype=dtype,
707
                trust_remote_code=trust_remote_code,
708
                lora_adapter_ids=lora_adapter_ids,
709
                config_class=GPTNeoXConfig,
710
711
            )
        elif sharded:
712
713
714
715
            return CausalLM(
                model_id=model_id,
                model_class=GPTNeoxForCausalLM,
                revision=revision,
716
                quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
717
                speculator=speculator,
718
                dtype=dtype,
719
720
                trust_remote_code=trust_remote_code,
            )
721
        else:
722
            return CausalLM.fallback(
723
724
725
                model_id,
                revision,
                quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
726
                speculator=speculator,
727
                dtype=dtype,
728
729
                trust_remote_code=trust_remote_code,
            )
OlivierDehaene's avatar
OlivierDehaene committed
730

731
    elif model_type == PHI:
drbh's avatar
drbh committed
732
        if FLASH_ATTENTION:
733
734
735
736
            return FlashCausalLM(
                model_id=model_id,
                model_class=FlashPhiForCausalLM,
                revision=revision,
drbh's avatar
drbh committed
737
                quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
738
                speculator=speculator,
drbh's avatar
drbh committed
739
740
                dtype=dtype,
                trust_remote_code=trust_remote_code,
741
                lora_adapter_ids=lora_adapter_ids,
drbh's avatar
drbh committed
742
743
            )
        else:
744
            return CausalLM.fallback(
drbh's avatar
drbh committed
745
746
747
                model_id,
                revision,
                quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
748
                speculator=speculator,
drbh's avatar
drbh committed
749
750
751
752
753
754
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )

    elif model_type == "phi-msft":
        if FLASH_ATTENTION:
OlivierDehaene's avatar
OlivierDehaene committed
755
756
757
            raise NotImplementedError(
                "Legacy phi-msft is not supported with Flash Attention"
            )
drbh's avatar
drbh committed
758
        else:
759
760
761
762
763
            return CausalLM(
                model_id=model_id,
                model_class=PhiForCausalLM,
                config_class=PhiConfig,
                revision=revision,
drbh's avatar
drbh committed
764
                quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
765
                speculator=speculator,
drbh's avatar
drbh committed
766
767
768
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )
769

770
    elif model_type == LLAMA or model_type == BAICHUAN or model_type == PHI3:
771
        print(f">>> model_type: {model_type}")
772
        if FLASH_ATTENTION:
773
774
775
776
            return FlashCausalLM(
                model_id=model_id,
                model_class=FlashLlamaForCausalLM,
                revision=revision,
777
                quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
778
                speculator=speculator,
779
                dtype=dtype,
780
                trust_remote_code=trust_remote_code,
drbh's avatar
drbh committed
781
                lora_adapter_ids=lora_adapter_ids,
782
            )
783
784
        elif sharded:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Llama"))
785
        else:
786
            return CausalLM.fallback(
787
788
789
                model_id,
                revision,
                quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
790
                speculator=speculator,
791
                dtype=dtype,
792
793
                trust_remote_code=trust_remote_code,
            )
794
    if model_type == GEMMA:
795
        if FLASH_ATTENTION:
796
797
798
799
            return FlashCausalLM(
                model_id=model_id,
                model_class=FlashGemmaForCausalLM,
                revision=revision,
800
                quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
801
                speculator=speculator,
802
                dtype=dtype,
803
804
                # Works better for these models
                default_dtype=torch.bfloat16,
805
                trust_remote_code=trust_remote_code,
806
                lora_adapter_ids=lora_adapter_ids,
807
808
            )
        elif sharded:
OlivierDehaene's avatar
OlivierDehaene committed
809
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Gemma"))
810
        else:
811
            return CausalLM.fallback(
812
813
814
                model_id,
                revision,
                quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
815
                speculator=speculator,
816
817
818
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )
Nicolas Patry's avatar
Nicolas Patry committed
819
820
    elif model_type == GEMMA2:
        if FLASH_ATTENTION:
821
822
823
824
            return FlashCausalLM(
                model_id=model_id,
                model_class=FlashGemma2ForCausalLM,
                revision=revision,
Nicolas Patry's avatar
Nicolas Patry committed
825
826
827
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
828
829
                # Works better for these models
                default_dtype=torch.bfloat16,
Nicolas Patry's avatar
Nicolas Patry committed
830
                trust_remote_code=trust_remote_code,
831
                lora_adapter_ids=lora_adapter_ids,
Nicolas Patry's avatar
Nicolas Patry committed
832
833
834
835
            )
        elif sharded:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Gemma2"))
        else:
836
            return CausalLM.fallback(
Nicolas Patry's avatar
Nicolas Patry committed
837
838
839
840
841
842
843
                model_id,
                revision,
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )
844

845
    if model_type == COHERE:
OlivierDehaene's avatar
OlivierDehaene committed
846
        if FLASH_ATTENTION:
847
848
849
850
            return FlashCausalLM(
                model_id=model_id,
                model_class=FlashCohereForCausalLM,
                revision=revision,
OlivierDehaene's avatar
OlivierDehaene committed
851
                quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
852
                speculator=speculator,
OlivierDehaene's avatar
OlivierDehaene committed
853
854
                dtype=dtype,
                trust_remote_code=trust_remote_code,
855
                lora_adapter_ids=lora_adapter_ids,
OlivierDehaene's avatar
OlivierDehaene committed
856
857
858
859
            )
        elif sharded:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Cohere"))
        else:
860
            return CausalLM.fallback(
OlivierDehaene's avatar
OlivierDehaene committed
861
862
863
                model_id,
                revision,
                quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
864
                speculator=speculator,
OlivierDehaene's avatar
OlivierDehaene committed
865
866
867
868
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )

869
    if model_type == DBRX:
870
        if FLASH_ATTENTION:
871
872
873
874
            return FlashCausalLM(
                model_id=model_id,
                model_class=FlashDbrxForCausalLM,
                revision=revision,
875
                quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
876
                speculator=speculator,
877
                dtype=dtype,
878
879
                # Dbrx works better in bfloat16.
                default_dtype=torch.bfloat16,
880
                trust_remote_code=trust_remote_code,
881
882
                lora_adapter_ids=lora_adapter_ids,
                config_class=DbrxConfig,
883
884
885
886
            )
        elif sharded:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded DBRX"))
        else:
887
            return CausalLM.fallback(
888
889
890
                model_id,
                revision,
                quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
891
                speculator=speculator,
892
893
894
895
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )

896
    if model_type in ["RefinedWeb", "RefinedWebModel", FALCON]:
897
898
        if sharded:
            if FLASH_ATTENTION:
899
                if config_dict.get("alibi", False):
900
                    raise NotImplementedError("sharded is not supported for this model")
901
902
903
904
                return FlashCausalLM(
                    model_id=model_id,
                    model_class=FlashRWForCausalLM,
                    revision=revision,
905
                    quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
906
                    speculator=speculator,
907
                    dtype=dtype,
908
909
910
911
                    aliases={
                        "lm_head.weight": ["transformer.word_embeddings.weight"],
                        "transformer.word_embeddings.weight": ["lm_head.weight"],
                    },
912
                    trust_remote_code=trust_remote_code,
913
914
                    lora_adapter_ids=lora_adapter_ids,
                    config_class=RWConfig,
915
                )
916
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Falcon"))
917
        else:
918
            if FLASH_ATTENTION and not config_dict.get("alibi", False):
919
920
921
922
                return FlashCausalLM(
                    model_id=model_id,
                    model_class=FlashRWForCausalLM,
                    revision=revision,
923
                    quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
924
                    speculator=speculator,
925
                    dtype=dtype,
926
927
928
929
                    aliases={
                        "lm_head.weight": ["transformer.word_embeddings.weight"],
                        "transformer.word_embeddings.weight": ["lm_head.weight"],
                    },
930
                    trust_remote_code=trust_remote_code,
931
932
                    lora_adapter_ids=lora_adapter_ids,
                    config_class=RWConfig,
933
934
                )
            else:
935
                return CausalLM.fallback(
936
937
938
                    model_id,
                    revision,
                    quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
939
                    speculator=speculator,
940
                    dtype=dtype,
941
942
943
                    trust_remote_code=trust_remote_code,
                )

944
    if model_type == MISTRAL:
945
        if FLASH_ATTENTION:
946
947
948
949
            return FlashCausalLM(
                model_id=model_id,
                model_class=FlashMistralForCausalLM,
                revision=revision,
950
                quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
951
                speculator=speculator,
952
953
                dtype=dtype,
                trust_remote_code=trust_remote_code,
954
                lora_adapter_ids=lora_adapter_ids,
955
            )
OlivierDehaene's avatar
OlivierDehaene committed
956
957
958
        elif sharded:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Mistral"))
        else:
959
            return CausalLM.fallback(
OlivierDehaene's avatar
OlivierDehaene committed
960
961
962
                model_id,
                revision,
                quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
963
                speculator=speculator,
OlivierDehaene's avatar
OlivierDehaene committed
964
965
966
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )
OlivierDehaene's avatar
OlivierDehaene committed
967

968
    if model_type == MIXTRAL:
969
        if FLASH_ATTENTION:
970
971
972
973
            return FlashCausalLM(
                model_id=model_id,
                model_class=FlashMixtralForCausalLM,
                revision=revision,
OlivierDehaene's avatar
OlivierDehaene committed
974
                quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
975
                speculator=speculator,
OlivierDehaene's avatar
OlivierDehaene committed
976
                dtype=dtype,
OlivierDehaene's avatar
OlivierDehaene committed
977
                trust_remote_code=trust_remote_code,
978
                lora_adapter_ids=lora_adapter_ids,
OlivierDehaene's avatar
OlivierDehaene committed
979
            )
OlivierDehaene's avatar
OlivierDehaene committed
980
981
982
        elif sharded:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Mixtral"))
        else:
983
            return CausalLM.fallback(
OlivierDehaene's avatar
OlivierDehaene committed
984
985
986
                model_id,
                revision,
                quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
987
                speculator=speculator,
OlivierDehaene's avatar
OlivierDehaene committed
988
989
990
991
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )

992
    if model_type == STARCODER2:
993
        if FLASH_ATTENTION:
994
995
996
997
            return FlashCausalLM(
                model_id=model_id,
                model_class=FlashStarcoder2ForCausalLM,
                revision=revision,
OlivierDehaene's avatar
OlivierDehaene committed
998
                quantize=quantize,
999
                speculator=speculator,
OlivierDehaene's avatar
OlivierDehaene committed
1000
                dtype=dtype,
OlivierDehaene's avatar
OlivierDehaene committed
1001
                trust_remote_code=trust_remote_code,
1002
                lora_adapter_ids=lora_adapter_ids,
OlivierDehaene's avatar
OlivierDehaene committed
1003
1004
1005
1006
1007
1008
            )
        elif sharded:
            raise NotImplementedError(
                FLASH_ATT_ERROR_MESSAGE.format("Sharded Starcoder2")
            )
        else:
1009
            return CausalLM.fallback(
OlivierDehaene's avatar
OlivierDehaene committed
1010
1011
1012
                model_id,
                revision,
                quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
1013
                speculator=speculator,
OlivierDehaene's avatar
OlivierDehaene committed
1014
1015
1016
1017
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )

1018
    if model_type == QWEN2:
1019
        if FLASH_ATTENTION:
1020
1021
1022
1023
            return FlashCausalLM(
                model_id=model_id,
                model_class=Qwen2ForCausalLM,
                revision=revision,
OlivierDehaene's avatar
OlivierDehaene committed
1024
                quantize=quantize,
1025
                speculator=speculator,
OlivierDehaene's avatar
OlivierDehaene committed
1026
1027
                dtype=dtype,
                trust_remote_code=trust_remote_code,
1028
                lora_adapter_ids=lora_adapter_ids,
OlivierDehaene's avatar
OlivierDehaene committed
1029
1030
1031
1032
            )
        elif sharded:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Qwen2"))
        else:
1033
            return CausalLM.fallback(
OlivierDehaene's avatar
OlivierDehaene committed
1034
1035
1036
                model_id,
                revision,
                quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
1037
                speculator=speculator,
OlivierDehaene's avatar
OlivierDehaene committed
1038
                dtype=dtype,
OlivierDehaene's avatar
OlivierDehaene committed
1039
1040
                trust_remote_code=trust_remote_code,
            )
1041

1042
    if model_type == OPT:
1043
1044
1045
1046
        return CausalLM(
            model_id=model_id,
            model_class=OPTForCausalLM,
            revision=revision,
1047
            quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
1048
            speculator=speculator,
1049
1050
            dtype=dtype,
            trust_remote_code=trust_remote_code,
1051
        )
1052

1053
    if model_type == T5:
1054
1055
1056
1057
        return Seq2SeqLM(
            model_id=model_id,
            model_class=T5ForConditionalGeneration,
            revision=revision,
1058
            quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
1059
            speculator=speculator,
1060
            dtype=dtype,
1061
            trust_remote_code=trust_remote_code,
1062
1063
1064
1065
1066
1067
            aliases={
                "shared.weight": [
                    "encoder.embed_tokens.weight",
                    "decoder.embed_tokens.weight",
                ]
            },
1068
        )
1069
    if model_type == IDEFICS:
1070
        if FLASH_ATTENTION:
OlivierDehaene's avatar
OlivierDehaene committed
1071
1072
1073
1074
            return IDEFICSSharded(
                model_id,
                revision,
                quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
1075
                speculator=speculator,
OlivierDehaene's avatar
OlivierDehaene committed
1076
1077
1078
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )
1079
1080
        else:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Idefics"))
1081
    if model_type == IDEFICS2:
Nicolas Patry's avatar
Nicolas Patry committed
1082
        if FLASH_ATTENTION:
1083
1084
1085
1086
            return VlmCausalLM(
                model_id=model_id,
                model_class=Idefics2ForConditionalGeneration,
                revision=revision,
Nicolas Patry's avatar
Nicolas Patry committed
1087
                quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
1088
                speculator=speculator,
Nicolas Patry's avatar
Nicolas Patry committed
1089
1090
                dtype=dtype,
                trust_remote_code=trust_remote_code,
1091
1092
1093
1094
                lora_adapter_ids=lora_adapter_ids,
                # XXX: Extremely important to cap resolution in order to limit
                # VRAM usage.
                processor_kwargs={"size": {"longest_edge": 448, "shortest_edge": 378}},
Nicolas Patry's avatar
Nicolas Patry committed
1095
1096
1097
            )
        else:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Idefics"))
1098
    if model_type == PALIGEMMA:
drbh's avatar
drbh committed
1099
        if FLASH_ATTENTION:
1100
1101
1102
1103
            return VlmCausalLM(
                model_id=model_id,
                model_class=PaliGemmaForConditionalGeneration,
                revision=revision,
drbh's avatar
drbh committed
1104
1105
1106
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
1107
1108
                # Works better for these models
                default_dtype=torch.bfloat16,
drbh's avatar
drbh committed
1109
                trust_remote_code=trust_remote_code,
1110
1111
                lora_adapter_ids=lora_adapter_ids,
                batch_class=PaliGemmaBatch,
drbh's avatar
drbh committed
1112
1113
1114
            )
        else:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Idefics"))
1115

1116
    if model_type == LLAVA_NEXT:
1117
        if FLASH_ATTENTION:
1118
1119
1120
1121
            return VlmCausalLM(
                model_class=LlavaNextForConditionalGeneration,
                model_id=model_id,
                revision=revision,
1122
                quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
1123
                speculator=speculator,
1124
1125
1126
1127
1128
1129
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )
        else:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("LlavaNext"))

1130
    if sharded:
1131
        raise NotImplementedError("sharded is not supported for AutoModel")
1132
    if quantize == "gptq":
1133
        raise NotImplementedError(
1134
1135
            "gptq quantization is not supported for AutoModel, you can try to quantize it with `text-generation-server quantize ORIGINAL_MODEL_ID NEW_MODEL_ID`"
        )
1136
    if quantize == "awq":
1137
        raise NotImplementedError("awq quantization is not supported for AutoModel")
Nicolas Patry's avatar
Nicolas Patry committed
1138
    elif (quantize == "bitsandbytes-fp4") or (quantize == "bitsandbytes-nf4"):
1139
        raise NotImplementedError("4bit quantization is not supported for AutoModel")
OlivierDehaene's avatar
OlivierDehaene committed
1140
    elif quantize == "eetq":
1141
        raise NotImplementedError("Eetq quantization is not supported for AutoModel")
1142
1143
    elif quantize == "exl2":
        raise NotImplementedError("exl2 quantization is not supported for AutoModel")
1144
    if model_type in modeling_auto.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES:
1145
        return CausalLM.fallback(
1146
1147
1148
            model_id,
            revision,
            quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
1149
            speculator=speculator,
1150
1151
            dtype=dtype,
            trust_remote_code=trust_remote_code,
1152
        )
1153
    if model_type in modeling_auto.MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES:
1154
        return Seq2SeqLM.fallback(
1155
1156
1157
            model_id,
            revision,
            quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
1158
            speculator=speculator,
1159
1160
            dtype=dtype,
            trust_remote_code=trust_remote_code,
1161
1162
        )

1163
    auto_map = config_dict.get("auto_map", None)
1164
1165
    if trust_remote_code and auto_map is not None:
        if "AutoModelForCausalLM" in auto_map.keys():
1166
            return CausalLM.fallback(
1167
1168
1169
                model_id,
                revision,
                quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
1170
                speculator=speculator,
1171
                dtype=dtype,
1172
1173
                trust_remote_code=trust_remote_code,
            )
1174
        if "AutoModelForSeq2SeqLM" in auto_map.keys():
1175
            return Seq2SeqLM.fallback(
1176
1177
1178
                model_id,
                revision,
                quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
1179
                speculator=speculator,
1180
                dtype=dtype,
1181
1182
                trust_remote_code=trust_remote_code,
            )
1183
1184

    raise ValueError(f"Unsupported model type {model_type}")
1185
1186
1187
1188
1189
1190
1191
1192
1193
1194
1195
1196
1197
1198
1199
1200
1201
1202
1203
1204
1205
1206
1207
1208
1209
1210
1211
1212
1213
1214
1215
1216
1217
1218
1219
1220
1221
1222
1223
1224
1225
1226
1227
1228
1229
1230
1231
1232
1233
1234
1235
1236
1237
1238
1239
1240
1241
1242
1243
1244
1245
1246
1247
1248
1249
1250
1251
1252
1253
1254
1255
1256
1257
1258
1259
1260
1261
1262
1263
1264
1265
1266
1267
1268
1269
1270
1271
1272
1273
1274
1275
1276
1277
1278
1279
1280
1281
1282
1283
1284
1285
1286
1287
1288
1289
1290
1291
1292
1293
1294
1295
1296
1297


def get_model_with_lora_adapters(
    model_id: str,
    lora_adapters: Optional[List[AdapterInfo]],
    revision: Optional[str],
    sharded: bool,
    quantize: Optional[str],
    speculate: Optional[int],
    dtype: Optional[str],
    trust_remote_code: bool,
    max_input_tokens: int,
    adapter_to_index: Dict[str, int],
):
    """Load a model with ``get_model`` and then load LoRA adapters into it.

    This wraps the internal ``get_model`` function and acts as a post-load
    hook: once the base model has been built, each adapter in
    ``lora_adapters`` is loaded and its weights are registered on the model.

    Args:
        model_id: Identifier of the base model, forwarded to ``get_model``.
        lora_adapters: Adapters to load, in order. ``None`` is treated the
            same as an empty list (the plain model is returned).
        revision, sharded, quantize, speculate, dtype, trust_remote_code,
        max_input_tokens: Forwarded unchanged to ``get_model``.
        adapter_to_index: Mutated in place. For every loaded adapter,
            ``adapter_to_index[adapter.id]`` is set to the index the adapter
            was registered under. Indices start at 1; index 0 is never
            assigned here (presumably reserved for the base model — confirm).

    Returns:
        The model returned by ``get_model``, with adapter weights added to
        ``model.layer_to_adapter_weights`` and adapter indices recorded in
        ``model.loaded_adapters``.
    """
    # The parameter is annotated Optional; treat None as "no adapters"
    # instead of failing on iteration below.
    if lora_adapters is None:
        lora_adapters = []

    lora_adapter_ids = [adapter.id for adapter in lora_adapters]
    model = get_model(
        model_id,
        lora_adapter_ids,
        revision,
        sharded,
        quantize,
        speculate,
        dtype,
        trust_remote_code,
        max_input_tokens,
    )

    if not lora_adapters:
        return model

    target_to_layer = build_layer_weight_lookup(model.model)

    # Projection modules that adapter weights are prepared for.
    adapter_layers = (
        "q_proj",
        "k_proj",
        "v_proj",
        "o_proj",
        "gate_proj",
        "up_proj",
        "down_proj",
    )

    for adapter_index, adapter in enumerate(lora_adapters, start=1):
        # AdapterParameters can describe a merge of several adapters. Only a
        # single adapter is loaded per entry for now; the wrapper is kept so
        # multi-adapter merging is easy to add later.
        adapter_parameters = AdapterParameters(
            adapter_info=[adapter],
            # When merging multiple adapters they can be weighted differently;
            # None weights all adapters equally.
            # See text_generation_server.utils.merges.strategies for the impl.
            weights=None,
            merge_strategy=0,
            density=1.0,
            majority_sign_method=0,
        )
        # Built from `adapter_info` (the field set above) and shared by both
        # log messages below.
        adapter_names = ",".join(a.id for a in adapter_parameters.adapter_info)

        adapter_to_index[adapter.id] = adapter_index

        logger.info(f"Loading adapter weights into model: {adapter_names}")
        weight_names = tuple(v[0] for v in target_to_layer.values())
        (
            module_map,
            adapter_config,
            adapter_weight_names,
            adapter_tokenizer,
        ) = load_and_merge_adapters(
            model.model_id,
            adapter_parameters,
            adapter_index,
            weight_names,
            # NOTE(review): meaning of this positional flag is defined by
            # load_and_merge_adapters (possibly trust_remote_code) — confirm.
            False,
        )

        # Handed to prepare_weights below; whatever remains afterwards is
        # reported as unused (presumably prepare_weights removes the names it
        # consumes — confirm).
        unused_weight_names = adapter_weight_names.copy()

        for layer_name in adapter_layers:
            # "lm_head" is treated as a single module; every other target gets
            # one entry per element of model.model.model.layers.
            nlayers = (
                1 if layer_name == "lm_head" else len(model.model.model.layers)
            )
            adapter_weights = LoraWeights.prepare_weights(
                config=adapter_config,
                module_map=module_map,
                layer_type=layer_name,
                unused_weight_names=unused_weight_names,
                nlayers=nlayers,
                dtype=model.dtype,
                world_size=model.world_size,
                process_group=model.process_group,
                target_to_layer=target_to_layer,
            )

            # No weights for this layer type in the adapter: nothing to add.
            if adapter_weights is None:
                continue

            model.layer_to_adapter_weights[layer_name].add_adapter(
                adapter_index, adapter_weights
            )

        if len(unused_weight_names) > 0:
            logger.warning(
                f"{adapter_names} unused adapter weights: {unused_weight_names}"
            )

        if adapter_tokenizer is not None:
            model.tokenizers.add_tokenizer(adapter_index, adapter_tokenizer)

        model.loaded_adapters.add(adapter_index)

    return model