__init__.py 46.4 KB
Newer Older
1
2
3
# ruff: noqa: F821
# the above line disables the `undefined-name` rule for the model type variables

4
import torch
5
import enum
Nicolas Patry's avatar
Nicolas Patry committed
6
import os
7

8
from loguru import logger
9
from transformers.configuration_utils import PretrainedConfig
10
from transformers.models.auto import modeling_auto
Nicolas Patry's avatar
Nicolas Patry committed
11
from huggingface_hub import hf_hub_download, HfApi
12
from typing import Optional, List, Dict
13
from pathlib import Path
14

Nicolas Patry's avatar
Nicolas Patry committed
15
from text_generation_server.utils.speculate import get_speculate, set_speculate
16
from text_generation_server.models.model import Model
17
18
19
20
21
from text_generation_server.models.causal_lm import CausalLM, CausalLMBatchKeysLast
from text_generation_server.models.custom_modeling.opt_modeling import OPTForCausalLM
from text_generation_server.models.custom_modeling.mpt_modeling import (
    MPTForCausalLM,
)
22
from text_generation_server.models.bloom import BloomCausalLMBatch
23
24
25
from text_generation_server.models.custom_modeling.bloom_modeling import (
    BloomForCausalLM,
)
26
from text_generation_server.models.seq2seq_lm import Seq2SeqLM
27
28
29
30
31
32
33
34
from text_generation_server.models.galactica import GalacticaCausalLMBatch
from text_generation_server.models.custom_modeling.neox_modeling import (
    GPTNeoxForCausalLM,
)
from text_generation_server.models.custom_modeling.phi_modeling import (
    PhiConfig,
    PhiForCausalLM,
)
drbh's avatar
drbh committed
35
36
37
from text_generation_server.models.custom_modeling.flash_phi_moe_modeling import (
    PhiMoEConfig,
)
38
39
40
from text_generation_server.models.custom_modeling.t5_modeling import (
    T5ForConditionalGeneration,
)
41

42
43
44
45
46
47
48
49
50
51

from text_generation_server.utils.adapter import (
    AdapterParameters,
    build_layer_weight_lookup,
    load_and_merge_adapters,
    AdapterInfo,
)
from text_generation_server.adapters.lora import LoraWeights


52
from text_generation_server.utils.import_utils import SYSTEM
53
from text_generation_server.utils.log import log_master
54

55
56
57
58
59
60
61
62
63
64
65
66
67
68
# Allow TF32 on CUDA matmuls. This flag defaults to False
# in PyTorch 1.12 and later, so it is switched on explicitly here.
torch.backends.cuda.matmul.allow_tf32 = True

# Allow TF32 on cuDNN. This flag already defaults to True; it is set
# explicitly so both TF32 switches are visible in one place.
torch.backends.cudnn.allow_tf32 = True

# Disable gradient tracking globally as soon as this module is imported.
torch.set_grad_enabled(False)

# Names exported by a star-import of this module. Optional backends are
# appended further down, only when their imports succeed.
__all__ = [
    "Model",
    "CausalLM",
    "Seq2SeqLM",
    "get_model_with_lora_adapters",
]

72
FLASH_ATT_ERROR_MESSAGE = "{} requires Flash Attention enabled models."
73

74
FLASH_ATTENTION = True
75

76
try:
77
    from text_generation_server.models.flash_causal_lm import FlashCausalLM
78
    from text_generation_server.models.vlm_causal_lm import VlmCausalLM
79
80
81
82
    from text_generation_server.models.custom_modeling.flash_deepseek_v2_modeling import (
        FlashDeepseekV2ForCausalLM,
        DeepseekV2Config,
    )
83
84
    from text_generation_server.models.custom_modeling.flash_llama_modeling import (
        FlashLlamaForCausalLM,
85
    )
86
87
    from text_generation_server.models.custom_modeling.flash_cohere_modeling import (
        FlashCohereForCausalLM,
OlivierDehaene's avatar
OlivierDehaene committed
88
    )
89
90
    from text_generation_server.models.custom_modeling.flash_gemma_modeling import (
        FlashGemmaForCausalLM,
OlivierDehaene's avatar
OlivierDehaene committed
91
    )
92
93
    from text_generation_server.models.custom_modeling.flash_gemma2_modeling import (
        FlashGemma2ForCausalLM,
94
    )
95
96
97
98
99
100
101
102
103
104
    from text_generation_server.models.custom_modeling.flash_dbrx_modeling import (
        FlashDbrxForCausalLM,
        DbrxConfig,
    )
    from text_generation_server.models.custom_modeling.flash_rw_modeling import (
        RWConfig,
        FlashRWForCausalLM,
    )
    from text_generation_server.models.custom_modeling.flash_neox_modeling import (
        FlashGPTNeoXForCausalLM,
Nicolas Patry's avatar
Nicolas Patry committed
105
    )
drbh's avatar
drbh committed
106
    from text_generation_server.models.pali_gemma import (
107
        PaliGemmaBatch,
drbh's avatar
drbh committed
108
    )
109
110
111
112
113
    from text_generation_server.models.custom_modeling.flash_pali_gemma_modeling import (
        PaliGemmaForConditionalGeneration,
    )
    from text_generation_server.models.custom_modeling.flash_phi_modeling import (
        FlashPhiForCausalLM,
114
    )
115
    from text_generation_server.models.idefics import IDEFICSSharded
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
    from text_generation_server.models.custom_modeling.llava_next import (
        LlavaNextForConditionalGeneration,
    )

    from text_generation_server.models.custom_modeling.flash_santacoder_modeling import (
        FlashSantacoderForCausalLM,
    )
    from text_generation_server.models.custom_modeling.flash_starcoder2_modeling import (
        FlashStarcoder2ForCausalLM,
    )
    from text_generation_server.models.custom_modeling.flash_qwen2_modeling import (
        Qwen2ForCausalLM,
    )
    from text_generation_server.models.custom_modeling.flash_mistral_modeling import (
        FlashMistralForCausalLM,
    )
    from text_generation_server.models.custom_modeling.flash_mixtral_modeling import (
        FlashMixtralForCausalLM,
    )
    from text_generation_server.models.custom_modeling.flash_gpt2_modeling import (
        FlashGPT2ForCausalLM,
    )
138
139
140
    from text_generation_server.models.custom_modeling.flash_gptj_modeling import (
        FlashGPTJForCausalLM,
    )
141
142
143
    from text_generation_server.models.custom_modeling.idefics2 import (
        Idefics2ForConditionalGeneration,
    )
144
    from text_generation_server.layers.attention import SUPPORTS_WINDOWING
145
except ImportError as e:
146
    log_master(logger.warning, f"Could not import Flash Attention enabled models: {e}")
147
    SUPPORTS_WINDOWING = False
148
    FLASH_ATTENTION = False
149

150
if FLASH_ATTENTION:
    # `__all__` must contain *names* (strings), not the objects themselves:
    # a star-import raises TypeError on any non-str entry.
    __all__.append("FlashCausalLM")
    __all__.append("IDEFICSSharded")
OlivierDehaene's avatar
OlivierDehaene committed
153

drbh's avatar
drbh committed
154
155
156
157
# Mamba is optional: its import may fail when the supporting packages are
# missing. Record availability instead of failing the whole module import.
MAMBA_AVAILABLE = True
try:
    from text_generation_server.models.mamba import Mamba
except ImportError as e:
    log_master(logger.warning, f"Could not import Mamba: {e}")
    MAMBA_AVAILABLE = False

if MAMBA_AVAILABLE:
    # `__all__` entries must be strings; appending the class object would
    # make a star-import of this module raise TypeError.
    __all__.append("Mamba")
OlivierDehaene's avatar
OlivierDehaene committed
163

164

165
class ModelType(enum.Enum):
    """Registry of the model architectures this module knows about.

    Every member's value is a dict with three string entries:

    * ``type`` - the architecture identifier; this is the string the module
      later republishes as a module-level constant named after the member.
    * ``name`` - a human-readable display name.
    * ``url``  - a Hugging Face page for a representative checkpoint.

    Some members additionally carry ``multimodal=True``; the key is simply
    absent for the others.
    """

    DEEPSEEK_V2 = dict(
        type="deepseek_v2",
        name="Deepseek V2",
        url="https://huggingface.co/deepseek-ai/DeepSeek-V2",
    )
    IDEFICS2 = dict(
        type="idefics2",
        name="Idefics 2",
        url="https://huggingface.co/HuggingFaceM4/idefics2-8b",
        multimodal=True,
    )
    LLAVA_NEXT = dict(
        type="llava_next",
        name="Llava Next (1.6)",
        url="https://huggingface.co/llava-hf/llava-v1.6-vicuna-13b-hf",
        multimodal=True,
    )
    LLAMA = dict(
        type="llama",
        name="Llama",
        url="https://huggingface.co/collections/meta-llama/llama-31-669fc079a0c406a149a5738f",
    )
    PHI3 = dict(
        type="phi3",
        name="Phi 3",
        url="https://huggingface.co/microsoft/Phi-3-mini-4k-instruct",
    )
    GEMMA = dict(
        type="gemma",
        name="Gemma",
        url="https://huggingface.co/google/gemma-7b",
    )
    PALIGEMMA = dict(
        type="paligemma",
        name="PaliGemma",
        url="https://huggingface.co/google/paligemma-3b-pt-224",
    )
    GEMMA2 = dict(
        type="gemma2",
        name="Gemma2",
        url="https://huggingface.co/collections/google/gemma-2-release-667d6600fd5220e7b967f315",
    )
    COHERE = dict(
        type="cohere",
        name="Cohere",
        url="https://huggingface.co/CohereForAI/c4ai-command-r-plus",
    )
    DBRX = dict(
        type="dbrx",
        name="Dbrx",
        url="https://huggingface.co/databricks/dbrx-instruct",
    )
    # Note: the Mamba architecture is identified by the type string "ssm".
    MAMBA = dict(
        type="ssm",
        name="Mamba",
        url="https://huggingface.co/state-spaces/mamba-2.8b-slimpj",
    )
    MISTRAL = dict(
        type="mistral",
        name="Mistral",
        url="https://huggingface.co/mistralai/Mistral-Nemo-Instruct-2407",
    )
    MIXTRAL = dict(
        type="mixtral",
        name="Mixtral",
        url="https://huggingface.co/mistralai/Mixtral-8x22B-Instruct-v0.1",
    )
    GPT_BIGCODE = dict(
        type="gpt_bigcode",
        name="Gpt Bigcode",
        url="https://huggingface.co/bigcode/gpt_bigcode-santacoder",
    )
    PHI = dict(
        type="phi",
        name="Phi",
        url="https://huggingface.co/microsoft/phi-1_5",
    )
    PHI_MOE = dict(
        type="phimoe",
        name="PhiMoe",
        url="https://huggingface.co/microsoft/Phi-3.5-MoE-instruct",
    )
    BAICHUAN = dict(
        type="baichuan",
        name="Baichuan",
        url="https://huggingface.co/baichuan-inc/Baichuan2-7B-Chat",
    )
    FALCON = dict(
        type="falcon",
        name="Falcon",
        url="https://huggingface.co/tiiuae/falcon-7b-instruct",
    )
    STARCODER2 = dict(
        type="starcoder2",
        name="StarCoder 2",
        url="https://huggingface.co/bigcode/starcoder2-15b-instruct-v0.1",
    )
    QWEN2 = dict(
        type="qwen2",
        name="Qwen 2",
        url="https://huggingface.co/collections/Qwen/qwen2-6659360b33528ced941e557f",
    )
    OPT = dict(
        type="opt",
        name="Opt",
        url="https://huggingface.co/facebook/opt-6.7b",
    )
    T5 = dict(
        type="t5",
        name="T5",
        url="https://huggingface.co/google/flan-t5-xxl",
    )
    GALACTICA = dict(
        type="galactica",
        name="Galactica",
        url="https://huggingface.co/facebook/galactica-120b",
    )
    SANTACODER = dict(
        type="santacoder",
        name="SantaCoder",
        url="https://huggingface.co/bigcode/santacoder",
    )
    BLOOM = dict(
        type="bloom",
        name="Bloom",
        url="https://huggingface.co/bigscience/bloom-560m",
    )
    MPT = dict(
        type="mpt",
        name="Mpt",
        url="https://huggingface.co/mosaicml/mpt-7b-instruct",
    )
    GPT2 = dict(
        type="gpt2",
        name="Gpt2",
        url="https://huggingface.co/openai-community/gpt2",
    )
    GPT_NEOX = dict(
        type="gpt_neox",
        name="Gpt Neox",
        url="https://huggingface.co/EleutherAI/gpt-neox-20b",
    )
    GPTJ = dict(
        type="gptj",
        name="Gptj",
        url="https://huggingface.co/EleutherAI/gpt-j-6b",
    )
    IDEFICS = dict(
        type="idefics",
        name="Idefics",
        url="https://huggingface.co/HuggingFaceM4/idefics-9b",
        multimodal=True,
    )


# Publish every ModelType member as a module-level string constant named
# after the member and holding its "type" entry (e.g. LLAMA == "llama"), so
# `get_model` can compare `model_type` against bare names such as DEEPSEEK_V2.
# At module scope `locals()` is the module namespace, so writing into it
# defines globals. Static linters cannot see these names, which is why the
# file starts with `# ruff: noqa: F821`.
__GLOBALS = locals()
for data in ModelType:
    __GLOBALS[data.name] = data.value["type"]


326
def get_model(
327
    model_id: str,
drbh's avatar
drbh committed
328
    lora_adapter_ids: Optional[List[str]],
329
330
331
    revision: Optional[str],
    sharded: bool,
    quantize: Optional[str],
Nicolas Patry's avatar
Nicolas Patry committed
332
    speculate: Optional[int],
333
    dtype: Optional[str],
334
    trust_remote_code: bool,
335
    max_input_tokens: int,
336
) -> Model:
337
    global FLASH_ATTENTION
338
339
340
341
342
343
344

    config_dict, _ = PretrainedConfig.get_config_dict(
        model_id, revision=revision, trust_remote_code=trust_remote_code
    )
    model_type = config_dict.get("model_type", None)

    quantization_config = config_dict.get("quantization_config", None)
345
    compression_config = config_dict.get("compression_config", None)
346
347
348
349
350
351
352
353
354
355
    if quantization_config is not None and quantize is None:
        method = quantization_config.get("quant_method", None)
        if method in {"gptq", "awq", "exl2"}:
            log_master(logger.info, f"Auto selecting quantization method {method}")
            quantize = method
        elif method == "fbgemm_fp8":
            log_master(logger.info, "Auto selecting quantization method fp8")
            quantize = "fp8"
        else:
            log_master(logger.warning, f"Unknown quantization method {method}")
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
    elif compression_config is not None:
        # TODO: at some point we should probably fully parse the compression
        # configuration to know which parameters are compressed.
        config_groups = compression_config.get("config_groups")
        if config_groups is not None:
            for _, group in config_groups.items():
                weights_config = group.get("weights")
                if weights_config is not None:
                    if (
                        weights_config["type"] == "float"
                        and weights_config["num_bits"] == 8
                    ):
                        log_master(
                            logger.info, "Auto selecting quantization method fp8"
                        )
                        quantize = "fp8"
                        break
373

374
    if dtype is None:
375
        if quantize in ["awq", "exl2", "gptq", "marlin"]:
376
377
            # These quantizers only work with float16 params.
            dtype = torch.float16
378
        elif quantize == "fp8":
379
            from text_generation_server.layers.fp8 import FBGEMM_DYN_AVAILABLE
380

381
            if FBGEMM_DYN_AVAILABLE:
382
383
                # fbgemm kernels are fp8xfp8->bf16
                dtype = torch.bfloat16
384
385
386
387
        else:
            # Keep it as default for now and let
            # every model resolve their own default dtype.
            dtype = None
388
389
390
391
392
393
394
    elif dtype == "float16":
        dtype = torch.float16
    elif dtype == "bfloat16":
        dtype = torch.bfloat16
    else:
        raise RuntimeError(f"Unknown dtype {dtype}")

Nicolas Patry's avatar
Nicolas Patry committed
395
396
397
398
399
    if speculate is not None:
        set_speculate(speculate)
    else:
        set_speculate(0)

Nicolas Patry's avatar
Nicolas Patry committed
400
    speculator = None
Nicolas Patry's avatar
Nicolas Patry committed
401
    if "medusa_num_heads" in config_dict:
402
403
        medusa_model_id = model_id
        medusa_revision = revision
Nicolas Patry's avatar
Nicolas Patry committed
404
405
406
407
408
        model_id = config_dict["base_model_name_or_path"]
        revision = "main"
        speculate_medusa = config_dict["medusa_num_heads"]
        if speculate is not None:
            if speculate > speculate_medusa:
OlivierDehaene's avatar
OlivierDehaene committed
409
                raise RuntimeError(
OlivierDehaene's avatar
OlivierDehaene committed
410
                    f"Speculate is set to `{speculate}` but this medusa models only has `{speculate_medusa}` heads, please make them match"
OlivierDehaene's avatar
OlivierDehaene committed
411
                )
Nicolas Patry's avatar
Nicolas Patry committed
412
413
414
415
416
417
418
419
            else:
                set_speculate(speculate)
        else:
            set_speculate(speculate_medusa)

        config_dict, _ = PretrainedConfig.get_config_dict(
            model_id, revision=revision, trust_remote_code=trust_remote_code
        )
Nicolas Patry's avatar
Nicolas Patry committed
420
421
        # Reload model type from parent.
        model_type = config_dict.get("model_type", None)
422
423
424
425
426
427
428
429
430
431
        is_local = Path(medusa_model_id).exists()
        if not is_local:
            medusa_config = hf_hub_download(
                medusa_model_id, revision=medusa_revision, filename="config.json"
            )
            hf_hub_download(
                medusa_model_id,
                revision=medusa_revision,
                filename="medusa_lm_head.safetensors",
            )
Nicolas Patry's avatar
Nicolas Patry committed
432
433
434
435
            speculator = {
                "path": Path(medusa_config).parent,
                "model_paths": ["medusa_lm_head.safetensors"],
            }
436
        else:
Nicolas Patry's avatar
Nicolas Patry committed
437
438
439
440
            speculator = {
                "path": Path(medusa_model_id),
                "model_paths": ["medusa_lm_head.safetensors"],
            }
441

Nicolas Patry's avatar
Nicolas Patry committed
442
        method = "medusa"
Nicolas Patry's avatar
Nicolas Patry committed
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
    elif model_type == "mlp_speculator":
        mlp_model_id = model_id
        mlp_revision = revision
        model_id = config_dict["base_model_name_or_path"]
        revision = "main"
        speculate_mlp = config_dict["n_predict"]
        if speculate is not None:
            if speculate > speculate_mlp:
                raise RuntimeError(
                    f"Speculate is set to `{speculate}` but this mlp_speculator models only has `{speculate_mlp}` heads, please make them match"
                )
            else:
                set_speculate(speculate)
        else:
            set_speculate(speculate_mlp)

        config_dict, _ = PretrainedConfig.get_config_dict(
            model_id, revision=revision, trust_remote_code=trust_remote_code
        )
        # Reload model type from parent.
        model_type = config_dict.get("model_type", None)
        is_local = Path(mlp_model_id).exists()
        extension = ".safetensors"
        if not is_local:
            mlp_speculator_config = hf_hub_download(
                mlp_model_id, revision=mlp_revision, filename="config.json"
            )
            api = HfApi()
            info = api.model_info(mlp_model_id, revision=mlp_revision)
            filenames = [
                s.rfilename
                for s in info.siblings
                if s.rfilename.endswith(extension)
                and len(s.rfilename.split("/")) == 1
                and "arguments" not in s.rfilename
                and "args" not in s.rfilename
                and "training" not in s.rfilename
            ]
            for filename in filenames:
                hf_hub_download(
                    mlp_model_id,
                    revision=mlp_revision,
                    filename=filename,
                )
487
488
489
490
491
            speculator_dir_path = Path(mlp_speculator_config).parent
            # if these are downloaded, they get converted to safetensors
            filenames.extend(
                [p for p in os.listdir(speculator_dir_path) if p.endswith(extension)]
            )
Nicolas Patry's avatar
Nicolas Patry committed
492
493
494
495
496
497
498
499
500
            speculator = {
                "path": Path(mlp_speculator_config).parent,
                "model_paths": filenames,
            }
        else:
            speculator = Path(mlp_model_id)
            filenames = [p for p in os.listdir(speculator) if p.endswith(extension)]
            speculator = {"path": speculator, "model_paths": filenames}
        method = "mlp_speculator"
Nicolas Patry's avatar
Nicolas Patry committed
501
502
503
504
505
    else:
        method = "n-gram"

    speculate = get_speculate()
    if speculate > 0:
506
507
508
        log_master(
            logger.info, f"Using speculation {method} with {speculate} input ids."
        )
Nicolas Patry's avatar
Nicolas Patry committed
509

drbh's avatar
drbh committed
510
511
512
513
514
515
516
517
518
519
    if model_type is None:
        # TODO: fix how we determine model type for Mamba
        if "ssm_cfg" in config_dict:
            # *only happens in Mamba case
            model_type = "ssm"
        else:
            raise RuntimeError(
                f"Could not determine model type for {model_id} revision {revision}"
            )

520
521
522
523
    if quantize == "exl2" and sharded:
        raise RuntimeError(
            "Sharding is currently not supported with `exl2` quantization"
        )
drbh's avatar
drbh committed
524
525
526
527
528
529

    sliding_window = (
        config_dict.get("sliding_window")
        if config_dict.get("sliding_window") is not None
        else -1
    )
530

531
532
533
    use_sliding_window = sliding_window is not None and sliding_window != -1
    needs_sliding_window = (
        max_input_tokens is not None and max_input_tokens > sliding_window
534
    )
535
536
537
538
    if use_sliding_window and needs_sliding_window and not SUPPORTS_WINDOWING:
        raise ValueError(
            f"The backend {SYSTEM} does not support sliding window attention that is used by the model type {model_type}. To use this model nonetheless with the {SYSTEM} backend, please launch TGI with the argument `--max-input-tokens` smaller than sliding_window={sliding_window} (got here max_input_tokens={max_input_tokens})."
        )
539

540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
    if model_type == DEEPSEEK_V2:
        if FLASH_ATTENTION:
            head_size = max(
                config_dict.get("qk_nope_dim", 128)
                + config_dict.get("qk_rope_dim", 64),
                config_dict.get("v_head_dim", 128),
            )
            return FlashCausalLM(
                model_id=model_id,
                model_class=FlashDeepseekV2ForCausalLM,
                revision=revision,
                quantize=quantize,
                speculator=speculator,
                default_dtype=torch.bfloat16,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
                lora_adapter_ids=lora_adapter_ids,
                config_class=DeepseekV2Config,
                head_size=head_size,
            )
        elif sharded:
            raise NotImplementedError(
                FLASH_ATT_ERROR_MESSAGE.format("Sharded Deepseek V2")
            )
        else:
            return CausalLM.fallback(
                model_id,
                revision,
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )
    elif model_type == MAMBA:
drbh's avatar
drbh committed
574
575
576
577
        return Mamba(
            model_id,
            revision,
            quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
578
            speculator=speculator,
drbh's avatar
drbh committed
579
580
581
            dtype=dtype,
            trust_remote_code=trust_remote_code,
        )
582

OlivierDehaene's avatar
OlivierDehaene committed
583
    if model_id.startswith("facebook/galactica"):
584
585
586
587
588
        return CausalLM(
            model_id=model_id,
            # Yes galactica is just an OPT model.
            model_class=OPTForCausalLM,
            revision=revision,
OlivierDehaene's avatar
OlivierDehaene committed
589
            quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
590
            speculator=speculator,
OlivierDehaene's avatar
OlivierDehaene committed
591
592
            dtype=dtype,
            trust_remote_code=trust_remote_code,
593
            batch_class=GalacticaCausalLMBatch,
OlivierDehaene's avatar
OlivierDehaene committed
594
595
        )

596
    if (
597
598
        model_type == GPT_BIGCODE
        or model_type == GPT2
599
600
        and model_id.startswith("bigcode/")
    ):
601
        if FLASH_ATTENTION:
602
603
604
605
            return FlashCausalLM(
                model_id=model_id,
                model_class=FlashSantacoderForCausalLM,
                revision=revision,
606
                quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
607
                speculator=speculator,
608
                dtype=dtype,
609
                trust_remote_code=trust_remote_code,
610
611
612
                lora_adapter_ids=lora_adapter_ids,
                aliases={"transformer.wte.weight": ["lm_head.weight"]},
                num_kv_heads=1,
613
            )
614
615
616
617
        elif sharded:
            raise NotImplementedError(
                FLASH_ATT_ERROR_MESSAGE.format("Sharded Santacoder")
            )
618
        else:
619
620
621
            return CausalLM.fallback(
                model_id=model_id,
                revision=revision,
622
                quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
623
                speculator=speculator,
624
                dtype=dtype,
625
626
                trust_remote_code=trust_remote_code,
            )
627

628
    if model_type == BLOOM:
629
630
631
632
        return CausalLM(
            model_id=model_id,
            model_class=BloomForCausalLM,
            revision=revision,
633
            quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
634
            speculator=speculator,
635
636
            dtype=dtype,
            trust_remote_code=trust_remote_code,
637
            batch_class=BloomCausalLMBatch,
638
        )
639
    elif model_type == MPT:
640
641
642
643
        return CausalLM(
            model_id=model_id,
            model_class=MPTForCausalLM,
            revision=revision,
OlivierDehaene's avatar
OlivierDehaene committed
644
            quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
645
            speculator=speculator,
OlivierDehaene's avatar
OlivierDehaene committed
646
647
            dtype=dtype,
            trust_remote_code=trust_remote_code,
648
            batch_class=CausalLMBatchKeysLast,
649
        )
650
    elif model_type == GPT2:
651
        if FLASH_ATTENTION:
652
            try:
653
654
655
656
                return FlashCausalLM(
                    model_id=model_id,
                    model_class=FlashGPT2ForCausalLM,
                    revision=revision,
657
658
659
660
                    quantize=quantize,
                    speculator=speculator,
                    dtype=dtype,
                    trust_remote_code=trust_remote_code,
661
                    lora_adapter_ids=lora_adapter_ids,
662
663
664
                )
            except RuntimeError as e:
                # Lots of legacy models with various weight names.
665
                log_master(logger.warning, f"Couldn't load flash gpt2 variant: {e}")
666
                return CausalLM.fallback(
667
668
669
670
671
672
673
                    model_id,
                    revision,
                    quantize=quantize,
                    speculator=speculator,
                    dtype=dtype,
                    trust_remote_code=trust_remote_code,
                )
674
675
676
        elif sharded:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded GPT-2"))
        else:
677
            return CausalLM.fallback(
678
679
680
681
682
683
684
                model_id,
                revision,
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
    elif model_type == GPTJ:
        if FLASH_ATTENTION:
            try:
                return FlashCausalLM(
                    model_id=model_id,
                    model_class=FlashGPTJForCausalLM,
                    revision=revision,
                    quantize=quantize,
                    speculator=speculator,
                    dtype=dtype,
                    trust_remote_code=trust_remote_code,
                    lora_adapter_ids=lora_adapter_ids,
                )
            except RuntimeError as e:
                # Lots of legacy models with various weight names.
                log_master(logger.warning, f"Couldn't load flash gptj variant: {e}")
                return CausalLM.fallback(
                    model_id,
                    revision,
                    quantize=quantize,
                    speculator=speculator,
                    dtype=dtype,
                    trust_remote_code=trust_remote_code,
                )
        elif sharded:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded GPT-J"))
        else:
            return CausalLM.fallback(
                model_id,
                revision,
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )
720
    elif model_type == GPT_NEOX:
721
        if FLASH_ATTENTION:
722
723
724
725
            from text_generation_server.models.custom_modeling.flash_neox_modeling import (
                GPTNeoXConfig,
            )

726
727
728
729
            return FlashCausalLM(
                model_id=model_id,
                model_class=FlashGPTNeoXForCausalLM,
                revision=revision,
730
                quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
731
                speculator=speculator,
732
                dtype=dtype,
733
                trust_remote_code=trust_remote_code,
734
                lora_adapter_ids=lora_adapter_ids,
735
                config_class=GPTNeoXConfig,
736
737
            )
        elif sharded:
738
739
740
741
            return CausalLM(
                model_id=model_id,
                model_class=GPTNeoxForCausalLM,
                revision=revision,
742
                quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
743
                speculator=speculator,
744
                dtype=dtype,
745
746
                trust_remote_code=trust_remote_code,
            )
747
        else:
748
            return CausalLM.fallback(
749
750
751
                model_id,
                revision,
                quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
752
                speculator=speculator,
753
                dtype=dtype,
754
755
                trust_remote_code=trust_remote_code,
            )
OlivierDehaene's avatar
OlivierDehaene committed
756

757
    elif model_type == PHI:
drbh's avatar
drbh committed
758
        if FLASH_ATTENTION:
759
760
761
762
            return FlashCausalLM(
                model_id=model_id,
                model_class=FlashPhiForCausalLM,
                revision=revision,
drbh's avatar
drbh committed
763
                quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
764
                speculator=speculator,
drbh's avatar
drbh committed
765
766
                dtype=dtype,
                trust_remote_code=trust_remote_code,
767
                lora_adapter_ids=lora_adapter_ids,
drbh's avatar
drbh committed
768
769
            )
        else:
770
            return CausalLM.fallback(
drbh's avatar
drbh committed
771
772
773
                model_id,
                revision,
                quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
774
                speculator=speculator,
drbh's avatar
drbh committed
775
776
777
778
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )

drbh's avatar
drbh committed
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
    elif model_type == PHI_MOE:
        if FLASH_ATTENTION:
            return FlashCausalLM(
                model_id=model_id,
                model_class=FlashLlamaForCausalLM,
                config_class=PhiMoEConfig,
                revision=revision,
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
                lora_adapter_ids=lora_adapter_ids,
            )
        else:
            return CausalLM.fallback(
                model_id,
                revision,
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )

drbh's avatar
drbh committed
802
803
    elif model_type == "phi-msft":
        if FLASH_ATTENTION:
OlivierDehaene's avatar
OlivierDehaene committed
804
805
806
            raise NotImplementedError(
                "Legacy phi-msft is not supported with Flash Attention"
            )
drbh's avatar
drbh committed
807
        else:
808
809
810
811
812
            return CausalLM(
                model_id=model_id,
                model_class=PhiForCausalLM,
                config_class=PhiConfig,
                revision=revision,
drbh's avatar
drbh committed
813
                quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
814
                speculator=speculator,
drbh's avatar
drbh committed
815
816
817
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )
818

819
    elif model_type == LLAMA or model_type == BAICHUAN or model_type == PHI3:
820
        if FLASH_ATTENTION:
821
822
823
824
            return FlashCausalLM(
                model_id=model_id,
                model_class=FlashLlamaForCausalLM,
                revision=revision,
825
                quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
826
                speculator=speculator,
827
                dtype=dtype,
828
                trust_remote_code=trust_remote_code,
drbh's avatar
drbh committed
829
                lora_adapter_ids=lora_adapter_ids,
830
            )
831
832
        elif sharded:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Llama"))
833
        else:
834
            return CausalLM.fallback(
835
836
837
                model_id,
                revision,
                quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
838
                speculator=speculator,
839
                dtype=dtype,
840
841
                trust_remote_code=trust_remote_code,
            )
842
    if model_type == GEMMA:
843
        if FLASH_ATTENTION:
844
845
846
847
            return FlashCausalLM(
                model_id=model_id,
                model_class=FlashGemmaForCausalLM,
                revision=revision,
848
                quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
849
                speculator=speculator,
850
                dtype=dtype,
851
852
                # Works better for these models
                default_dtype=torch.bfloat16,
853
                trust_remote_code=trust_remote_code,
854
                lora_adapter_ids=lora_adapter_ids,
855
856
            )
        elif sharded:
OlivierDehaene's avatar
OlivierDehaene committed
857
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Gemma"))
858
        else:
859
            return CausalLM.fallback(
860
861
862
                model_id,
                revision,
                quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
863
                speculator=speculator,
864
865
866
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )
Nicolas Patry's avatar
Nicolas Patry committed
867
868
    elif model_type == GEMMA2:
        if FLASH_ATTENTION:
869
870
871
872
            return FlashCausalLM(
                model_id=model_id,
                model_class=FlashGemma2ForCausalLM,
                revision=revision,
Nicolas Patry's avatar
Nicolas Patry committed
873
874
875
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
876
877
                # Works better for these models
                default_dtype=torch.bfloat16,
Nicolas Patry's avatar
Nicolas Patry committed
878
                trust_remote_code=trust_remote_code,
879
                lora_adapter_ids=lora_adapter_ids,
Nicolas Patry's avatar
Nicolas Patry committed
880
881
882
883
            )
        elif sharded:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Gemma2"))
        else:
884
            return CausalLM.fallback(
Nicolas Patry's avatar
Nicolas Patry committed
885
886
887
888
889
890
891
                model_id,
                revision,
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )
892

893
    if model_type == COHERE:
OlivierDehaene's avatar
OlivierDehaene committed
894
        if FLASH_ATTENTION:
895
896
897
898
            return FlashCausalLM(
                model_id=model_id,
                model_class=FlashCohereForCausalLM,
                revision=revision,
OlivierDehaene's avatar
OlivierDehaene committed
899
                quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
900
                speculator=speculator,
OlivierDehaene's avatar
OlivierDehaene committed
901
902
                dtype=dtype,
                trust_remote_code=trust_remote_code,
903
                lora_adapter_ids=lora_adapter_ids,
OlivierDehaene's avatar
OlivierDehaene committed
904
905
906
907
            )
        elif sharded:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Cohere"))
        else:
908
            return CausalLM.fallback(
OlivierDehaene's avatar
OlivierDehaene committed
909
910
911
                model_id,
                revision,
                quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
912
                speculator=speculator,
OlivierDehaene's avatar
OlivierDehaene committed
913
914
915
916
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )

917
    if model_type == DBRX:
918
        if FLASH_ATTENTION:
919
920
921
922
            return FlashCausalLM(
                model_id=model_id,
                model_class=FlashDbrxForCausalLM,
                revision=revision,
923
                quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
924
                speculator=speculator,
925
                dtype=dtype,
926
927
                # Dbrx works better in bfloat16.
                default_dtype=torch.bfloat16,
928
                trust_remote_code=trust_remote_code,
929
930
                lora_adapter_ids=lora_adapter_ids,
                config_class=DbrxConfig,
931
932
933
934
            )
        elif sharded:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded DBRX"))
        else:
935
            return CausalLM.fallback(
936
937
938
                model_id,
                revision,
                quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
939
                speculator=speculator,
940
941
942
943
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )

944
    if model_type in ["RefinedWeb", "RefinedWebModel", FALCON]:
945
946
        if sharded:
            if FLASH_ATTENTION:
947
                if config_dict.get("alibi", False):
948
                    raise NotImplementedError("sharded is not supported for this model")
949
950
951
952
                return FlashCausalLM(
                    model_id=model_id,
                    model_class=FlashRWForCausalLM,
                    revision=revision,
953
                    quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
954
                    speculator=speculator,
955
                    dtype=dtype,
956
957
958
959
                    aliases={
                        "lm_head.weight": ["transformer.word_embeddings.weight"],
                        "transformer.word_embeddings.weight": ["lm_head.weight"],
                    },
960
                    trust_remote_code=trust_remote_code,
961
962
                    lora_adapter_ids=lora_adapter_ids,
                    config_class=RWConfig,
963
                )
964
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Falcon"))
965
        else:
966
            if FLASH_ATTENTION and not config_dict.get("alibi", False):
967
968
969
970
                return FlashCausalLM(
                    model_id=model_id,
                    model_class=FlashRWForCausalLM,
                    revision=revision,
971
                    quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
972
                    speculator=speculator,
973
                    dtype=dtype,
974
975
976
977
                    aliases={
                        "lm_head.weight": ["transformer.word_embeddings.weight"],
                        "transformer.word_embeddings.weight": ["lm_head.weight"],
                    },
978
                    trust_remote_code=trust_remote_code,
979
980
                    lora_adapter_ids=lora_adapter_ids,
                    config_class=RWConfig,
981
982
                )
            else:
983
                return CausalLM.fallback(
984
985
986
                    model_id,
                    revision,
                    quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
987
                    speculator=speculator,
988
                    dtype=dtype,
989
990
991
                    trust_remote_code=trust_remote_code,
                )

992
    if model_type == MISTRAL:
993
        if FLASH_ATTENTION:
994
995
996
997
            return FlashCausalLM(
                model_id=model_id,
                model_class=FlashMistralForCausalLM,
                revision=revision,
998
                quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
999
                speculator=speculator,
1000
1001
                dtype=dtype,
                trust_remote_code=trust_remote_code,
1002
                lora_adapter_ids=lora_adapter_ids,
1003
            )
OlivierDehaene's avatar
OlivierDehaene committed
1004
1005
1006
        elif sharded:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Mistral"))
        else:
1007
            return CausalLM.fallback(
OlivierDehaene's avatar
OlivierDehaene committed
1008
1009
1010
                model_id,
                revision,
                quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
1011
                speculator=speculator,
OlivierDehaene's avatar
OlivierDehaene committed
1012
1013
1014
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )
OlivierDehaene's avatar
OlivierDehaene committed
1015

1016
    if model_type == MIXTRAL:
1017
        if FLASH_ATTENTION:
1018
1019
1020
1021
            return FlashCausalLM(
                model_id=model_id,
                model_class=FlashMixtralForCausalLM,
                revision=revision,
OlivierDehaene's avatar
OlivierDehaene committed
1022
                quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
1023
                speculator=speculator,
OlivierDehaene's avatar
OlivierDehaene committed
1024
                dtype=dtype,
OlivierDehaene's avatar
OlivierDehaene committed
1025
                trust_remote_code=trust_remote_code,
1026
                lora_adapter_ids=lora_adapter_ids,
OlivierDehaene's avatar
OlivierDehaene committed
1027
            )
OlivierDehaene's avatar
OlivierDehaene committed
1028
1029
1030
        elif sharded:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Mixtral"))
        else:
1031
            return CausalLM.fallback(
OlivierDehaene's avatar
OlivierDehaene committed
1032
1033
1034
                model_id,
                revision,
                quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
1035
                speculator=speculator,
OlivierDehaene's avatar
OlivierDehaene committed
1036
1037
1038
1039
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )

1040
    if model_type == STARCODER2:
1041
        if FLASH_ATTENTION:
1042
1043
1044
1045
            return FlashCausalLM(
                model_id=model_id,
                model_class=FlashStarcoder2ForCausalLM,
                revision=revision,
OlivierDehaene's avatar
OlivierDehaene committed
1046
                quantize=quantize,
1047
                speculator=speculator,
OlivierDehaene's avatar
OlivierDehaene committed
1048
                dtype=dtype,
OlivierDehaene's avatar
OlivierDehaene committed
1049
                trust_remote_code=trust_remote_code,
1050
                lora_adapter_ids=lora_adapter_ids,
OlivierDehaene's avatar
OlivierDehaene committed
1051
1052
1053
1054
1055
1056
            )
        elif sharded:
            raise NotImplementedError(
                FLASH_ATT_ERROR_MESSAGE.format("Sharded Starcoder2")
            )
        else:
1057
            return CausalLM.fallback(
OlivierDehaene's avatar
OlivierDehaene committed
1058
1059
1060
                model_id,
                revision,
                quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
1061
                speculator=speculator,
OlivierDehaene's avatar
OlivierDehaene committed
1062
1063
1064
1065
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )

1066
    if model_type == QWEN2:
1067
        if FLASH_ATTENTION:
1068
1069
1070
1071
            return FlashCausalLM(
                model_id=model_id,
                model_class=Qwen2ForCausalLM,
                revision=revision,
OlivierDehaene's avatar
OlivierDehaene committed
1072
                quantize=quantize,
1073
                speculator=speculator,
OlivierDehaene's avatar
OlivierDehaene committed
1074
1075
                dtype=dtype,
                trust_remote_code=trust_remote_code,
1076
                lora_adapter_ids=lora_adapter_ids,
OlivierDehaene's avatar
OlivierDehaene committed
1077
1078
1079
1080
            )
        elif sharded:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Qwen2"))
        else:
1081
            return CausalLM.fallback(
OlivierDehaene's avatar
OlivierDehaene committed
1082
1083
1084
                model_id,
                revision,
                quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
1085
                speculator=speculator,
OlivierDehaene's avatar
OlivierDehaene committed
1086
                dtype=dtype,
OlivierDehaene's avatar
OlivierDehaene committed
1087
1088
                trust_remote_code=trust_remote_code,
            )
1089

1090
    if model_type == OPT:
1091
1092
1093
1094
        return CausalLM(
            model_id=model_id,
            model_class=OPTForCausalLM,
            revision=revision,
1095
            quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
1096
            speculator=speculator,
1097
1098
            dtype=dtype,
            trust_remote_code=trust_remote_code,
1099
        )
1100

1101
    if model_type == T5:
1102
1103
1104
1105
        return Seq2SeqLM(
            model_id=model_id,
            model_class=T5ForConditionalGeneration,
            revision=revision,
1106
            quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
1107
            speculator=speculator,
1108
            dtype=dtype,
1109
            trust_remote_code=trust_remote_code,
1110
1111
1112
1113
1114
1115
            aliases={
                "shared.weight": [
                    "encoder.embed_tokens.weight",
                    "decoder.embed_tokens.weight",
                ]
            },
1116
        )
1117
    if model_type == IDEFICS:
1118
        if FLASH_ATTENTION:
OlivierDehaene's avatar
OlivierDehaene committed
1119
1120
1121
1122
            return IDEFICSSharded(
                model_id,
                revision,
                quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
1123
                speculator=speculator,
OlivierDehaene's avatar
OlivierDehaene committed
1124
1125
1126
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )
1127
1128
        else:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Idefics"))
1129
    if model_type == IDEFICS2:
Nicolas Patry's avatar
Nicolas Patry committed
1130
        if FLASH_ATTENTION:
1131
1132
1133
1134
            return VlmCausalLM(
                model_id=model_id,
                model_class=Idefics2ForConditionalGeneration,
                revision=revision,
Nicolas Patry's avatar
Nicolas Patry committed
1135
                quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
1136
                speculator=speculator,
Nicolas Patry's avatar
Nicolas Patry committed
1137
1138
                dtype=dtype,
                trust_remote_code=trust_remote_code,
1139
1140
1141
1142
                lora_adapter_ids=lora_adapter_ids,
                # XXX: Extremely important to cap resolution in order to limit
                # VRAM usage.
                processor_kwargs={"size": {"longest_edge": 448, "shortest_edge": 378}},
Nicolas Patry's avatar
Nicolas Patry committed
1143
1144
1145
            )
        else:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Idefics"))
1146
    if model_type == PALIGEMMA:
drbh's avatar
drbh committed
1147
        if FLASH_ATTENTION:
1148
1149
1150
1151
            return VlmCausalLM(
                model_id=model_id,
                model_class=PaliGemmaForConditionalGeneration,
                revision=revision,
drbh's avatar
drbh committed
1152
1153
1154
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
1155
1156
                # Works better for these models
                default_dtype=torch.bfloat16,
drbh's avatar
drbh committed
1157
                trust_remote_code=trust_remote_code,
1158
1159
                lora_adapter_ids=lora_adapter_ids,
                batch_class=PaliGemmaBatch,
drbh's avatar
drbh committed
1160
1161
1162
            )
        else:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Idefics"))
1163

1164
    if model_type == LLAVA_NEXT:
1165
        if FLASH_ATTENTION:
1166
1167
1168
1169
            return VlmCausalLM(
                model_class=LlavaNextForConditionalGeneration,
                model_id=model_id,
                revision=revision,
1170
                quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
1171
                speculator=speculator,
1172
1173
1174
1175
1176
1177
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )
        else:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("LlavaNext"))

1178
    if sharded:
1179
        raise NotImplementedError("sharded is not supported for AutoModel")
1180
    if quantize == "gptq":
1181
        raise NotImplementedError(
1182
1183
            "gptq quantization is not supported for AutoModel, you can try to quantize it with `text-generation-server quantize ORIGINAL_MODEL_ID NEW_MODEL_ID`"
        )
1184
    if quantize == "awq":
1185
        raise NotImplementedError("awq quantization is not supported for AutoModel")
Nicolas Patry's avatar
Nicolas Patry committed
1186
    elif (quantize == "bitsandbytes-fp4") or (quantize == "bitsandbytes-nf4"):
1187
        raise NotImplementedError("4bit quantization is not supported for AutoModel")
OlivierDehaene's avatar
OlivierDehaene committed
1188
    elif quantize == "eetq":
1189
        raise NotImplementedError("Eetq quantization is not supported for AutoModel")
1190
1191
    elif quantize == "exl2":
        raise NotImplementedError("exl2 quantization is not supported for AutoModel")
1192
    if model_type in modeling_auto.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES:
1193
        return CausalLM.fallback(
1194
1195
1196
            model_id,
            revision,
            quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
1197
            speculator=speculator,
1198
1199
            dtype=dtype,
            trust_remote_code=trust_remote_code,
1200
        )
1201
    if model_type in modeling_auto.MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES:
1202
        return Seq2SeqLM.fallback(
1203
1204
1205
            model_id,
            revision,
            quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
1206
            speculator=speculator,
1207
1208
            dtype=dtype,
            trust_remote_code=trust_remote_code,
1209
1210
        )

1211
    auto_map = config_dict.get("auto_map", None)
1212
1213
    if trust_remote_code and auto_map is not None:
        if "AutoModelForCausalLM" in auto_map.keys():
1214
            return CausalLM.fallback(
1215
1216
1217
                model_id,
                revision,
                quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
1218
                speculator=speculator,
1219
                dtype=dtype,
1220
1221
                trust_remote_code=trust_remote_code,
            )
1222
        if "AutoModelForSeq2SeqLM" in auto_map.keys():
1223
            return Seq2SeqLM.fallback(
1224
1225
1226
                model_id,
                revision,
                quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
1227
                speculator=speculator,
1228
                dtype=dtype,
1229
1230
                trust_remote_code=trust_remote_code,
            )
1231
1232

    raise ValueError(f"Unsupported model type {model_type}")
1233
1234
1235
1236
1237
1238
1239
1240
1241
1242
1243
1244
1245
1246
1247
1248
1249
1250
1251
1252
1253
1254
1255
1256
1257
1258
1259
1260
1261
1262
1263
1264
1265
1266
1267
1268
1269
1270
1271
1272
1273
1274
1275
1276
1277
1278
1279
1280
1281
1282
1283
1284
1285
1286
1287
1288
1289
1290
1291
1292
1293
1294
1295
1296
1297
1298
1299
1300
1301
1302
1303
1304
1305
1306
1307
1308
1309


# get_model_with_lora_adapters wraps the internal get_model function and adds support for loading adapters
# this provides a post model loading hook to load adapters into the model after the model has been loaded
def get_model_with_lora_adapters(
    model_id: str,
    lora_adapters: Optional[List[AdapterInfo]],
    revision: Optional[str],
    sharded: bool,
    quantize: Optional[str],
    speculate: Optional[int],
    dtype: Optional[str],
    trust_remote_code: bool,
    max_input_tokens: int,
    adapter_to_index: Dict[str, int],
):
    """Load a model through ``get_model`` and attach the requested LoRA adapters.

    Args:
        model_id: Identifier forwarded to ``get_model``.
        lora_adapters: Adapters to load once the model exists. ``None`` and an
            empty list both mean "no adapters".
        revision: Forwarded to ``get_model``.
        sharded: Forwarded to ``get_model``.
        quantize: Forwarded to ``get_model``.
        speculate: Forwarded to ``get_model``.
        dtype: Forwarded to ``get_model``.
        trust_remote_code: Forwarded to ``get_model``.
        max_input_tokens: Forwarded to ``get_model``.
        adapter_to_index: Mutated in place; every loaded adapter id is mapped
            to the 1-based index it was registered under.

    Returns:
        The model returned by ``get_model``, with each adapter's weights added
        to ``model.layer_to_adapter_weights``, its tokenizer (if any) added to
        ``model.tokenizers`` and its index recorded in ``model.loaded_adapters``.
    """
    # The signature allows ``None``; normalise it so the comprehension and the
    # emptiness check below do not raise a TypeError.
    if lora_adapters is None:
        lora_adapters = []

    lora_adapter_ids = [adapter.id for adapter in lora_adapters]
    model = get_model(
        model_id,
        lora_adapter_ids,
        revision,
        sharded,
        quantize,
        speculate,
        dtype,
        trust_remote_code,
        max_input_tokens,
    )

    if not lora_adapters:
        return model

    target_to_layer = build_layer_weight_lookup(model.model)

    # Both values are the same for every adapter, so build them only once.
    weight_names = tuple(v[0] for v in target_to_layer.values())
    adapter_layers = [
        "q_proj",
        "k_proj",
        "v_proj",
        "o_proj",
        "gate_proj",
        "up_proj",
        "down_proj",
        "qkv_proj",
    ]

    for index, adapter in enumerate(lora_adapters):
        # The AdapterParameters object allows for merging multiple adapters into a single adapter.
        # At the moment, we only support loading a single adapter into the model, but we keep the
        # AdapterParameters object for easier extension in the future.
        adapter_parameters = AdapterParameters(
            adapter_info=[adapter],
            # when merging multiple adapters we can weight them differently
            # if this is not set, all adapters will be weighted equally
            # see: text_generation_server.utils.merges.strategies for impl
            weights=None,
            merge_strategy=0,
            density=1.0,
            majority_sign_method=0,
        )

        # Adapter indices are 1-based: the first adapter is registered as 1.
        adapter_index = index + 1
        adapter_to_index[adapter.id] = adapter_index

        logger.info(
            f"Loading adapter weights into model: {','.join([adapter.id for adapter in adapter_parameters.adapter_info])}"
        )
        # NOTE(review): the trailing positional ``False`` is presumably the
        # helper's trust_remote_code flag -- confirm against its signature.
        (
            module_map,
            adapter_config,
            adapter_weight_names,
            adapter_tokenizer,
        ) = load_and_merge_adapters(
            model.model_id,
            adapter_parameters,
            adapter_index,
            weight_names,
            False,
        )

        # Handed to ``prepare_weights`` for every layer type; whatever is left
        # afterwards is reported as unused below.
        unused_weight_names = adapter_weight_names.copy()

        for layer_name in adapter_layers:
            nlayers = (
                1 if layer_name == "lm_head" else len(model.model.model.layers)
            )
            adapter_weights = LoraWeights.prepare_weights(
                config=adapter_config,
                module_map=module_map,
                layer_type=layer_name,
                unused_weight_names=unused_weight_names,
                nlayers=nlayers,
                dtype=model.dtype,
                world_size=model.world_size,
                process_group=model.process_group,
                target_to_layer=target_to_layer,
            )

            if adapter_weights is None:
                continue

            model.layer_to_adapter_weights[layer_name].add_adapter(
                adapter_index, adapter_weights
            )

        if len(unused_weight_names) > 0:
            # Name only the adapter being processed: the leftover weights
            # belong to it, not to every adapter in the request.
            logger.warning(
                f"{adapter.id} unused adapter weights: {unused_weight_names}"
            )

        if adapter_tokenizer is not None:
            model.tokenizers.add_tokenizer(adapter_index, adapter_tokenizer)

        model.loaded_adapters.add(adapter_index)

    return model