__init__.py 49.7 KB
Newer Older
1
2
3
# ruff: noqa: F821
# the above line disables the `undefined-name` rule for the model type variables

4
import torch
5
import enum
Nicolas Patry's avatar
Nicolas Patry committed
6
import os
7

8
from loguru import logger
9
from transformers.configuration_utils import PretrainedConfig
10
from transformers.models.auto import modeling_auto
Nicolas Patry's avatar
Nicolas Patry committed
11
from huggingface_hub import hf_hub_download, HfApi
12
from typing import Optional, List, Dict
13
from pathlib import Path
14

Nicolas Patry's avatar
Nicolas Patry committed
15
from text_generation_server.utils.speculate import get_speculate, set_speculate
16
from text_generation_server.models.model import Model
17
18
19
20
21
from text_generation_server.models.causal_lm import CausalLM, CausalLMBatchKeysLast
from text_generation_server.models.custom_modeling.opt_modeling import OPTForCausalLM
from text_generation_server.models.custom_modeling.mpt_modeling import (
    MPTForCausalLM,
)
22
from text_generation_server.models.bloom import BloomCausalLMBatch
23
24
25
from text_generation_server.models.custom_modeling.bloom_modeling import (
    BloomForCausalLM,
)
26
from text_generation_server.models.seq2seq_lm import Seq2SeqLM
27
28
29
30
31
32
33
34
from text_generation_server.models.galactica import GalacticaCausalLMBatch
from text_generation_server.models.custom_modeling.neox_modeling import (
    GPTNeoxForCausalLM,
)
from text_generation_server.models.custom_modeling.phi_modeling import (
    PhiConfig,
    PhiForCausalLM,
)
drbh's avatar
drbh committed
35
36
37
from text_generation_server.models.custom_modeling.flash_phi_moe_modeling import (
    PhiMoEConfig,
)
38
39
40
from text_generation_server.models.custom_modeling.t5_modeling import (
    T5ForConditionalGeneration,
)
41

42
43
44
45
46
47
48
49
50
51

from text_generation_server.utils.adapter import (
    AdapterParameters,
    build_layer_weight_lookup,
    load_and_merge_adapters,
    AdapterInfo,
)
from text_generation_server.adapters.lora import LoraWeights


52
from text_generation_server.utils.import_utils import SYSTEM
53
from text_generation_server.utils.log import log_master
54

55
56
57
58
59
60
61
62
63
64
65
66
67
68
# The flag below controls whether to allow TF32 on matmul. This flag defaults to False
# in PyTorch 1.12 and later.
torch.backends.cuda.matmul.allow_tf32 = True

# The flag below controls whether to allow TF32 on cuDNN. This flag defaults to True.
torch.backends.cudnn.allow_tf32 = True

# This module only serves inference, so autograd is disabled process-wide.
torch.set_grad_enabled(False)

# Public API of this package. Optional model classes (flash attention, Mamba)
# are appended further down only when their imports succeed.
__all__ = [
    "Model",
    "CausalLM",
    "Seq2SeqLM",
    "get_model_with_lora_adapters",
]

# Error template; `{}` is filled with the model/feature name that needs
# flash attention (e.g. "Sharded Deepseek V2").
FLASH_ATT_ERROR_MESSAGE = "{} requires Flash Attention enabled models."

# Optimistically assume the flash attention models are importable; the
# import block that follows resets this to False on ImportError.
FLASH_ATTENTION = True

# Flash-attention-backed model implementations are optional: any ImportError
# (missing kernels, unsupported backend, ...) disables them as a group and
# logs the reason instead of failing the whole package import.
try:
    from text_generation_server.models.flash_causal_lm import FlashCausalLM
    from text_generation_server.models.vlm_causal_lm import VlmCausalLM
    from text_generation_server.models.mllama_causal_lm import (
        MllamaCausalLM,
        MllamaCausalLMBatch,
    )
    from text_generation_server.models.custom_modeling.flash_deepseek_v2_modeling import (
        FlashDeepseekV2ForCausalLM,
        DeepseekV2Config,
    )
    from text_generation_server.models.custom_modeling.flash_llama_modeling import (
        FlashLlamaForCausalLM,
    )
    from text_generation_server.models.custom_modeling.flash_cohere_modeling import (
        FlashCohereForCausalLM,
    )
    from text_generation_server.models.custom_modeling.flash_gemma_modeling import (
        FlashGemmaForCausalLM,
    )
    from text_generation_server.models.custom_modeling.flash_gemma2_modeling import (
        FlashGemma2ForCausalLM,
    )
    from text_generation_server.models.custom_modeling.flash_dbrx_modeling import (
        FlashDbrxForCausalLM,
        DbrxConfig,
    )
    from text_generation_server.models.custom_modeling.flash_rw_modeling import (
        RWConfig,
        FlashRWForCausalLM,
    )
    from text_generation_server.models.custom_modeling.flash_neox_modeling import (
        FlashGPTNeoXForCausalLM,
    )
    from text_generation_server.models.pali_gemma import (
        PaliGemmaBatch,
    )
    from text_generation_server.models.custom_modeling.flash_pali_gemma_modeling import (
        PaliGemmaForConditionalGeneration,
    )
    from text_generation_server.models.custom_modeling.flash_phi_modeling import (
        FlashPhiForCausalLM,
    )
    from text_generation_server.models.idefics_causal_lm import IdeficsCausalLM
    from text_generation_server.models.custom_modeling.mllama import (
        MllamaForConditionalGeneration,
    )
    from text_generation_server.models.custom_modeling.llava_next import (
        LlavaNextForConditionalGeneration,
    )
    from text_generation_server.models.custom_modeling.flash_santacoder_modeling import (
        FlashSantacoderForCausalLM,
    )
    from text_generation_server.models.custom_modeling.flash_starcoder2_modeling import (
        FlashStarcoder2ForCausalLM,
    )
    from text_generation_server.models.custom_modeling.flash_qwen2_modeling import (
        Qwen2ForCausalLM,
    )
    from text_generation_server.models.custom_modeling.flash_mistral_modeling import (
        FlashMistralForCausalLM,
    )
    from text_generation_server.models.custom_modeling.flash_mixtral_modeling import (
        FlashMixtralForCausalLM,
    )
    from text_generation_server.models.custom_modeling.flash_gpt2_modeling import (
        FlashGPT2ForCausalLM,
    )
    from text_generation_server.models.custom_modeling.flash_gptj_modeling import (
        FlashGPTJForCausalLM,
    )
    from text_generation_server.models.custom_modeling.idefics2 import (
        Idefics2ForConditionalGeneration,
    )
    from text_generation_server.models.custom_modeling.qwen2_vl import (
        Qwen2VLForConditionalGeneration,
    )
    # Whether the active attention backend supports sliding-window attention.
    from text_generation_server.layers.attention import SUPPORTS_WINDOWING
except ImportError as e:
    log_master(logger.warning, f"Could not import Flash Attention enabled models: {e}")
    # Without the flash stack, assume no sliding-window support either.
    SUPPORTS_WINDOWING = False
    FLASH_ATTENTION = False

if FLASH_ATTENTION:
    # `__all__` must hold names (str), not the objects themselves:
    # `from package import *` raises TypeError on non-string entries.
    __all__.append("FlashCausalLM")
    __all__.append("IdeficsCausalLM")

# Mamba is optional: on ImportError, log a warning and mark it unavailable
# rather than failing the package import.
MAMBA_AVAILABLE = True
try:
    from text_generation_server.models.mamba import Mamba
except ImportError as e:
    log_master(logger.warning, f"Could not import Mamba: {e}")
    MAMBA_AVAILABLE = False

if MAMBA_AVAILABLE:
    # Export the name (a str); `__all__` entries must be strings.
    __all__.append("Mamba")


class ModelType(enum.Enum):
    """Model architectures recognised by ``get_model``.

    Each member's value is a dict with:

    * ``type``: the ``model_type`` string that ``get_model`` compares against
      the checkpoint's config (exported below as a module-level constant);
    * ``name``: a human-readable display name;
    * ``url``: a reference checkpoint or collection on the Hugging Face Hub;
    * ``multimodal`` (optional): present and ``True`` on some image+text
      models (not set on every such member, e.g. PaliGemma / Qwen 2 VL).
    """

    DEEPSEEK_V2 = {
        "type": "deepseek_v2",
        "name": "Deepseek V2",
        "url": "https://huggingface.co/deepseek-ai/DeepSeek-V2",
    }
    IDEFICS2 = {
        "type": "idefics2",
        "name": "Idefics 2",
        "url": "https://huggingface.co/HuggingFaceM4/idefics2-8b",
        "multimodal": True,
    }
    LLAVA_NEXT = {
        "type": "llava_next",
        "name": "Llava Next (1.6)",
        "url": "https://huggingface.co/llava-hf/llava-v1.6-vicuna-13b-hf",
        "multimodal": True,
    }
    LLAMA = {
        "type": "llama",
        "name": "Llama",
        "url": "https://huggingface.co/collections/meta-llama/llama-31-669fc079a0c406a149a5738f",
    }
    PHI3 = {
        "type": "phi3",
        "name": "Phi 3",
        "url": "https://huggingface.co/microsoft/Phi-3-mini-4k-instruct",
    }
    GRANITE = {
        "type": "granite",
        "name": "Granite",
        "url": "https://huggingface.co/ibm-granite/granite-3.0-8b-instruct",
    }
    GEMMA = {
        "type": "gemma",
        "name": "Gemma",
        "url": "https://huggingface.co/google/gemma-7b",
    }
    PALIGEMMA = {
        "type": "paligemma",
        "name": "PaliGemma",
        "url": "https://huggingface.co/google/paligemma-3b-pt-224",
    }
    GEMMA2 = {
        "type": "gemma2",
        "name": "Gemma2",
        "url": "https://huggingface.co/collections/google/gemma-2-release-667d6600fd5220e7b967f315",
    }
    COHERE = {
        "type": "cohere",
        "name": "Cohere",
        "url": "https://huggingface.co/CohereForAI/c4ai-command-r-plus",
    }
    DBRX = {
        "type": "dbrx",
        "name": "Dbrx",
        "url": "https://huggingface.co/databricks/dbrx-instruct",
    }
    MAMBA = {
        "type": "mamba",
        "name": "Mamba",
        "url": "https://huggingface.co/state-spaces/mamba-2.8b-slimpj",
    }
    MISTRAL = {
        "type": "mistral",
        "name": "Mistral",
        "url": "https://huggingface.co/mistralai/Mistral-Nemo-Instruct-2407",
    }
    MIXTRAL = {
        "type": "mixtral",
        "name": "Mixtral",
        "url": "https://huggingface.co/mistralai/Mixtral-8x22B-Instruct-v0.1",
    }
    GPT_BIGCODE = {
        "type": "gpt_bigcode",
        "name": "Gpt Bigcode",
        "url": "https://huggingface.co/bigcode/gpt_bigcode-santacoder",
    }
    PHI = {
        "type": "phi",
        "name": "Phi",
        "url": "https://huggingface.co/microsoft/phi-1_5",
    }
    PHI_MOE = {
        "type": "phimoe",
        "name": "PhiMoe",
        "url": "https://huggingface.co/microsoft/Phi-3.5-MoE-instruct",
    }
    BAICHUAN = {
        "type": "baichuan",
        "name": "Baichuan",
        "url": "https://huggingface.co/baichuan-inc/Baichuan2-7B-Chat",
    }
    FALCON = {
        "type": "falcon",
        "name": "Falcon",
        "url": "https://huggingface.co/tiiuae/falcon-7b-instruct",
    }
    STARCODER2 = {
        "type": "starcoder2",
        "name": "StarCoder 2",
        "url": "https://huggingface.co/bigcode/starcoder2-15b-instruct-v0.1",
    }
    QWEN2 = {
        "type": "qwen2",
        "name": "Qwen 2",
        "url": "https://huggingface.co/collections/Qwen/qwen2-6659360b33528ced941e557f",
    }
    QWEN2_VL = {
        "type": "qwen2_vl",
        "name": "Qwen 2 VL",
        "url": "https://huggingface.co/collections/Qwen/qwen2-vl-66cee7455501d7126940800d",
    }
    OPT = {
        "type": "opt",
        "name": "Opt",
        "url": "https://huggingface.co/facebook/opt-6.7b",
    }
    T5 = {
        "type": "t5",
        "name": "T5",
        "url": "https://huggingface.co/google/flan-t5-xxl",
    }
    GALACTICA = {
        "type": "galactica",
        "name": "Galactica",
        "url": "https://huggingface.co/facebook/galactica-120b",
    }
    SANTACODER = {
        "type": "santacoder",
        "name": "SantaCoder",
        "url": "https://huggingface.co/bigcode/santacoder",
    }
    BLOOM = {
        "type": "bloom",
        "name": "Bloom",
        "url": "https://huggingface.co/bigscience/bloom-560m",
    }
    MPT = {
        "type": "mpt",
        "name": "Mpt",
        "url": "https://huggingface.co/mosaicml/mpt-7b-instruct",
    }
    GPT2 = {
        "type": "gpt2",
        "name": "Gpt2",
        "url": "https://huggingface.co/openai-community/gpt2",
    }
    GPT_NEOX = {
        "type": "gpt_neox",
        "name": "Gpt Neox",
        "url": "https://huggingface.co/EleutherAI/gpt-neox-20b",
    }
    GPTJ = {
        "type": "gptj",
        "name": "Gptj",
        "url": "https://huggingface.co/EleutherAI/gpt-j-6b",
    }
    IDEFICS = {
        "type": "idefics",
        "name": "Idefics",
        "url": "https://huggingface.co/HuggingFaceM4/idefics-9b",
        "multimodal": True,
    }
    MLLAMA = {
        "type": "mllama",
        "name": "Mllama",
        "url": "https://huggingface.co/meta-llama/Llama-3.2-11B-Vision-Instruct",
        "multimodal": True,
    }


# Expose each member's `type` string as a module-level constant
# (e.g. `LLAMA == "llama"`) so model-type checks can compare against plain
# names. At module scope `globals()` is the module namespace; a dict
# comprehension is used so no loop variable is left behind as a module attribute.
__GLOBALS = globals()
__GLOBALS.update({member.name: member.value["type"] for member in ModelType})

def get_model(
351
    model_id: str,
drbh's avatar
drbh committed
352
    lora_adapter_ids: Optional[List[str]],
353
354
355
    revision: Optional[str],
    sharded: bool,
    quantize: Optional[str],
Nicolas Patry's avatar
Nicolas Patry committed
356
    speculate: Optional[int],
357
    dtype: Optional[str],
358
    kv_cache_dtype: Optional[str],
359
    trust_remote_code: bool,
360
    max_input_tokens: int,
361
) -> Model:
362
    global FLASH_ATTENTION
363
364
365
366
367
368
369

    config_dict, _ = PretrainedConfig.get_config_dict(
        model_id, revision=revision, trust_remote_code=trust_remote_code
    )
    model_type = config_dict.get("model_type", None)

    quantization_config = config_dict.get("quantization_config", None)
370
    compression_config = config_dict.get("compression_config", None)
371
372
373
374
375
    if quantization_config is not None and quantize is None:
        method = quantization_config.get("quant_method", None)
        if method in {"gptq", "awq", "exl2"}:
            log_master(logger.info, f"Auto selecting quantization method {method}")
            quantize = method
376
        elif method == "fbgemm_fp8" or method == "fp8":
377
378
            log_master(logger.info, "Auto selecting quantization method fp8")
            quantize = "fp8"
379
380
381
382
383
        if method == "compressed-tensors":
            log_master(
                logger.info, "Auto selecting quantization method compressed-tensors"
            )
            quantize = "compressed-tensors"
384
385
        else:
            log_master(logger.warning, f"Unknown quantization method {method}")
386
    elif compression_config is not None:
387
        # `compression_config` renamed to `quantization_config`; support retained for backward compatibility.
388
389
        log_master(logger.info, "Auto selecting quantization method compressed-tensors")
        quantize = "compressed-tensors"
390

391
    if dtype is None:
392
        if quantize in ["awq", "exl2", "gptq", "marlin"]:
Nicolas Patry's avatar
Nicolas Patry committed
393
394
395
396
397
            if SYSTEM == "ipex" and not hasattr(torch, "xpu"):
                dtype = torch.bfloat16
            else:
                # These quantizers only work with float16 params.
                dtype = torch.float16
398
399
400
401
        else:
            # Keep it as default for now and let
            # every model resolve their own default dtype.
            dtype = None
402
403
404
405
406
407
408
    elif dtype == "float16":
        dtype = torch.float16
    elif dtype == "bfloat16":
        dtype = torch.bfloat16
    else:
        raise RuntimeError(f"Unknown dtype {dtype}")

409
410
    if kv_cache_dtype is None:
        kv_cache_dtype = dtype
411
412
    elif kv_cache_dtype == "fp8_e4m3fn":
        kv_cache_dtype = torch.float8_e4m3fn
413
414
415
416
417
    elif kv_cache_dtype == "fp8_e5m2":
        kv_cache_dtype = torch.float8_e5m2
    else:
        raise RuntimeError(f"Unknown kv_cache_dtype: {kv_cache_dtype}")

Nicolas Patry's avatar
Nicolas Patry committed
418
419
420
421
422
    if speculate is not None:
        set_speculate(speculate)
    else:
        set_speculate(0)

Nicolas Patry's avatar
Nicolas Patry committed
423
    speculator = None
Nicolas Patry's avatar
Nicolas Patry committed
424
    if "medusa_num_heads" in config_dict:
425
426
        medusa_model_id = model_id
        medusa_revision = revision
Nicolas Patry's avatar
Nicolas Patry committed
427
428
429
430
431
        model_id = config_dict["base_model_name_or_path"]
        revision = "main"
        speculate_medusa = config_dict["medusa_num_heads"]
        if speculate is not None:
            if speculate > speculate_medusa:
OlivierDehaene's avatar
OlivierDehaene committed
432
                raise RuntimeError(
OlivierDehaene's avatar
OlivierDehaene committed
433
                    f"Speculate is set to `{speculate}` but this medusa models only has `{speculate_medusa}` heads, please make them match"
OlivierDehaene's avatar
OlivierDehaene committed
434
                )
Nicolas Patry's avatar
Nicolas Patry committed
435
436
437
438
439
440
441
442
            else:
                set_speculate(speculate)
        else:
            set_speculate(speculate_medusa)

        config_dict, _ = PretrainedConfig.get_config_dict(
            model_id, revision=revision, trust_remote_code=trust_remote_code
        )
Nicolas Patry's avatar
Nicolas Patry committed
443
444
        # Reload model type from parent.
        model_type = config_dict.get("model_type", None)
445
446
447
448
449
450
451
452
453
454
        is_local = Path(medusa_model_id).exists()
        if not is_local:
            medusa_config = hf_hub_download(
                medusa_model_id, revision=medusa_revision, filename="config.json"
            )
            hf_hub_download(
                medusa_model_id,
                revision=medusa_revision,
                filename="medusa_lm_head.safetensors",
            )
Nicolas Patry's avatar
Nicolas Patry committed
455
456
457
458
            speculator = {
                "path": Path(medusa_config).parent,
                "model_paths": ["medusa_lm_head.safetensors"],
            }
459
        else:
Nicolas Patry's avatar
Nicolas Patry committed
460
461
462
463
            speculator = {
                "path": Path(medusa_model_id),
                "model_paths": ["medusa_lm_head.safetensors"],
            }
464

Nicolas Patry's avatar
Nicolas Patry committed
465
        method = "medusa"
Nicolas Patry's avatar
Nicolas Patry committed
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
    elif model_type == "mlp_speculator":
        mlp_model_id = model_id
        mlp_revision = revision
        model_id = config_dict["base_model_name_or_path"]
        revision = "main"
        speculate_mlp = config_dict["n_predict"]
        if speculate is not None:
            if speculate > speculate_mlp:
                raise RuntimeError(
                    f"Speculate is set to `{speculate}` but this mlp_speculator models only has `{speculate_mlp}` heads, please make them match"
                )
            else:
                set_speculate(speculate)
        else:
            set_speculate(speculate_mlp)

        config_dict, _ = PretrainedConfig.get_config_dict(
            model_id, revision=revision, trust_remote_code=trust_remote_code
        )
        # Reload model type from parent.
        model_type = config_dict.get("model_type", None)
        is_local = Path(mlp_model_id).exists()
        extension = ".safetensors"
        if not is_local:
            mlp_speculator_config = hf_hub_download(
                mlp_model_id, revision=mlp_revision, filename="config.json"
            )
            api = HfApi()
            info = api.model_info(mlp_model_id, revision=mlp_revision)
            filenames = [
                s.rfilename
                for s in info.siblings
                if s.rfilename.endswith(extension)
                and len(s.rfilename.split("/")) == 1
                and "arguments" not in s.rfilename
                and "args" not in s.rfilename
                and "training" not in s.rfilename
            ]
            for filename in filenames:
                hf_hub_download(
                    mlp_model_id,
                    revision=mlp_revision,
                    filename=filename,
                )
510
511
512
513
514
            speculator_dir_path = Path(mlp_speculator_config).parent
            # if these are downloaded, they get converted to safetensors
            filenames.extend(
                [p for p in os.listdir(speculator_dir_path) if p.endswith(extension)]
            )
Nicolas Patry's avatar
Nicolas Patry committed
515
516
517
518
519
520
521
522
523
            speculator = {
                "path": Path(mlp_speculator_config).parent,
                "model_paths": filenames,
            }
        else:
            speculator = Path(mlp_model_id)
            filenames = [p for p in os.listdir(speculator) if p.endswith(extension)]
            speculator = {"path": speculator, "model_paths": filenames}
        method = "mlp_speculator"
Nicolas Patry's avatar
Nicolas Patry committed
524
525
526
527
528
    else:
        method = "n-gram"

    speculate = get_speculate()
    if speculate > 0:
529
530
531
        log_master(
            logger.info, f"Using speculation {method} with {speculate} input ids."
        )
Nicolas Patry's avatar
Nicolas Patry committed
532

drbh's avatar
drbh committed
533
534
535
536
    if model_type is None:
        # TODO: fix how we determine model type for Mamba
        if "ssm_cfg" in config_dict:
            # *only happens in Mamba case
537
            model_type = "mamba"
drbh's avatar
drbh committed
538
539
540
541
542
        else:
            raise RuntimeError(
                f"Could not determine model type for {model_id} revision {revision}"
            )

543
544
545
546
    if quantize == "exl2" and sharded:
        raise RuntimeError(
            "Sharding is currently not supported with `exl2` quantization"
        )
drbh's avatar
drbh committed
547
548
549
550
551
552

    sliding_window = (
        config_dict.get("sliding_window")
        if config_dict.get("sliding_window") is not None
        else -1
    )
553

554
555
556
    use_sliding_window = sliding_window is not None and sliding_window != -1
    needs_sliding_window = (
        max_input_tokens is not None and max_input_tokens > sliding_window
557
    )
558
559
560
561
    if use_sliding_window and needs_sliding_window and not SUPPORTS_WINDOWING:
        raise ValueError(
            f"The backend {SYSTEM} does not support sliding window attention that is used by the model type {model_type}. To use this model nonetheless with the {SYSTEM} backend, please launch TGI with the argument `--max-input-tokens` smaller than sliding_window={sliding_window} (got here max_input_tokens={max_input_tokens})."
        )
562

563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
    if model_type == DEEPSEEK_V2:
        if FLASH_ATTENTION:
            head_size = max(
                config_dict.get("qk_nope_dim", 128)
                + config_dict.get("qk_rope_dim", 64),
                config_dict.get("v_head_dim", 128),
            )
            return FlashCausalLM(
                model_id=model_id,
                model_class=FlashDeepseekV2ForCausalLM,
                revision=revision,
                quantize=quantize,
                speculator=speculator,
                default_dtype=torch.bfloat16,
                dtype=dtype,
578
                kv_cache_dtype=kv_cache_dtype,
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
                trust_remote_code=trust_remote_code,
                lora_adapter_ids=lora_adapter_ids,
                config_class=DeepseekV2Config,
                head_size=head_size,
            )
        elif sharded:
            raise NotImplementedError(
                FLASH_ATT_ERROR_MESSAGE.format("Sharded Deepseek V2")
            )
        else:
            return CausalLM.fallback(
                model_id,
                revision,
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )
    elif model_type == MAMBA:
drbh's avatar
drbh committed
598
599
600
601
        return Mamba(
            model_id,
            revision,
            quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
602
            speculator=speculator,
drbh's avatar
drbh committed
603
604
605
            dtype=dtype,
            trust_remote_code=trust_remote_code,
        )
606
607
608
609
    elif model_type == "ssm":
        raise RuntimeError(
            "`ssm` models have been deprecated in favor of `mamba` models, which follow standard HF formats. Check out a list here: https://huggingface.co/models?search=mamba%20-hf"
        )
610

OlivierDehaene's avatar
OlivierDehaene committed
611
    if model_id.startswith("facebook/galactica"):
612
613
614
615
616
        return CausalLM(
            model_id=model_id,
            # Yes galactica is just an OPT model.
            model_class=OPTForCausalLM,
            revision=revision,
OlivierDehaene's avatar
OlivierDehaene committed
617
            quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
618
            speculator=speculator,
OlivierDehaene's avatar
OlivierDehaene committed
619
620
            dtype=dtype,
            trust_remote_code=trust_remote_code,
621
            batch_class=GalacticaCausalLMBatch,
OlivierDehaene's avatar
OlivierDehaene committed
622
623
        )

624
    if (
625
626
        model_type == GPT_BIGCODE
        or model_type == GPT2
627
628
        and model_id.startswith("bigcode/")
    ):
629
        if FLASH_ATTENTION:
630
631
632
633
            return FlashCausalLM(
                model_id=model_id,
                model_class=FlashSantacoderForCausalLM,
                revision=revision,
634
                quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
635
                speculator=speculator,
636
                dtype=dtype,
637
                kv_cache_dtype=kv_cache_dtype,
638
                trust_remote_code=trust_remote_code,
639
640
641
                lora_adapter_ids=lora_adapter_ids,
                aliases={"transformer.wte.weight": ["lm_head.weight"]},
                num_kv_heads=1,
642
            )
643
644
645
646
        elif sharded:
            raise NotImplementedError(
                FLASH_ATT_ERROR_MESSAGE.format("Sharded Santacoder")
            )
647
        else:
648
649
650
            return CausalLM.fallback(
                model_id=model_id,
                revision=revision,
651
                quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
652
                speculator=speculator,
653
                dtype=dtype,
654
655
                trust_remote_code=trust_remote_code,
            )
656

657
    if model_type == BLOOM:
658
659
660
661
        return CausalLM(
            model_id=model_id,
            model_class=BloomForCausalLM,
            revision=revision,
662
            quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
663
            speculator=speculator,
664
665
            dtype=dtype,
            trust_remote_code=trust_remote_code,
666
            batch_class=BloomCausalLMBatch,
667
        )
668
    elif model_type == MPT:
669
670
671
672
        return CausalLM(
            model_id=model_id,
            model_class=MPTForCausalLM,
            revision=revision,
OlivierDehaene's avatar
OlivierDehaene committed
673
            quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
674
            speculator=speculator,
OlivierDehaene's avatar
OlivierDehaene committed
675
676
            dtype=dtype,
            trust_remote_code=trust_remote_code,
677
            batch_class=CausalLMBatchKeysLast,
678
        )
679
    elif model_type == GPT2:
680
        if FLASH_ATTENTION:
681
            try:
682
683
684
685
                return FlashCausalLM(
                    model_id=model_id,
                    model_class=FlashGPT2ForCausalLM,
                    revision=revision,
686
687
688
                    quantize=quantize,
                    speculator=speculator,
                    dtype=dtype,
689
                    kv_cache_dtype=kv_cache_dtype,
690
                    trust_remote_code=trust_remote_code,
691
                    lora_adapter_ids=lora_adapter_ids,
692
693
694
                )
            except RuntimeError as e:
                # Lots of legacy models with various weight names.
695
                log_master(logger.warning, f"Couldn't load flash gpt2 variant: {e}")
696
                return CausalLM.fallback(
697
698
699
700
701
702
703
                    model_id,
                    revision,
                    quantize=quantize,
                    speculator=speculator,
                    dtype=dtype,
                    trust_remote_code=trust_remote_code,
                )
704
705
706
        elif sharded:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded GPT-2"))
        else:
707
            return CausalLM.fallback(
708
709
710
711
712
713
714
                model_id,
                revision,
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )
715
716
717
718
719
720
721
722
723
724
    elif model_type == GPTJ:
        if FLASH_ATTENTION:
            try:
                return FlashCausalLM(
                    model_id=model_id,
                    model_class=FlashGPTJForCausalLM,
                    revision=revision,
                    quantize=quantize,
                    speculator=speculator,
                    dtype=dtype,
725
                    kv_cache_dtype=kv_cache_dtype,
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
                    trust_remote_code=trust_remote_code,
                    lora_adapter_ids=lora_adapter_ids,
                )
            except RuntimeError as e:
                # Lots of legacy models with various weight names.
                log_master(logger.warning, f"Couldn't load flash gptj variant: {e}")
                return CausalLM.fallback(
                    model_id,
                    revision,
                    quantize=quantize,
                    speculator=speculator,
                    dtype=dtype,
                    trust_remote_code=trust_remote_code,
                )
        elif sharded:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded GPT-J"))
        else:
            return CausalLM.fallback(
                model_id,
                revision,
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )
751
    elif model_type == GPT_NEOX:
752
        if FLASH_ATTENTION:
753
754
755
756
            from text_generation_server.models.custom_modeling.flash_neox_modeling import (
                GPTNeoXConfig,
            )

757
758
759
760
            return FlashCausalLM(
                model_id=model_id,
                model_class=FlashGPTNeoXForCausalLM,
                revision=revision,
761
                quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
762
                speculator=speculator,
763
                dtype=dtype,
764
                kv_cache_dtype=kv_cache_dtype,
765
                trust_remote_code=trust_remote_code,
766
                lora_adapter_ids=lora_adapter_ids,
767
                config_class=GPTNeoXConfig,
768
769
            )
        elif sharded:
770
771
772
773
            return CausalLM(
                model_id=model_id,
                model_class=GPTNeoxForCausalLM,
                revision=revision,
774
                quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
775
                speculator=speculator,
776
                dtype=dtype,
777
778
                trust_remote_code=trust_remote_code,
            )
779
        else:
780
            return CausalLM.fallback(
781
782
783
                model_id,
                revision,
                quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
784
                speculator=speculator,
785
                dtype=dtype,
786
787
                trust_remote_code=trust_remote_code,
            )
OlivierDehaene's avatar
OlivierDehaene committed
788

789
    elif model_type == PHI:
drbh's avatar
drbh committed
790
        if FLASH_ATTENTION:
791
792
793
794
            return FlashCausalLM(
                model_id=model_id,
                model_class=FlashPhiForCausalLM,
                revision=revision,
drbh's avatar
drbh committed
795
                quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
796
                speculator=speculator,
drbh's avatar
drbh committed
797
                dtype=dtype,
798
                kv_cache_dtype=kv_cache_dtype,
drbh's avatar
drbh committed
799
                trust_remote_code=trust_remote_code,
800
                lora_adapter_ids=lora_adapter_ids,
drbh's avatar
drbh committed
801
802
            )
        else:
803
            return CausalLM.fallback(
drbh's avatar
drbh committed
804
805
806
                model_id,
                revision,
                quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
807
                speculator=speculator,
drbh's avatar
drbh committed
808
809
810
811
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )

drbh's avatar
drbh committed
812
813
814
815
816
817
818
819
820
821
    elif model_type == PHI_MOE:
        if FLASH_ATTENTION:
            return FlashCausalLM(
                model_id=model_id,
                model_class=FlashLlamaForCausalLM,
                config_class=PhiMoEConfig,
                revision=revision,
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
822
                kv_cache_dtype=kv_cache_dtype,
drbh's avatar
drbh committed
823
824
825
826
827
828
829
830
831
832
833
834
835
                trust_remote_code=trust_remote_code,
                lora_adapter_ids=lora_adapter_ids,
            )
        else:
            return CausalLM.fallback(
                model_id,
                revision,
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )

drbh's avatar
drbh committed
836
837
    elif model_type == "phi-msft":
        if FLASH_ATTENTION:
OlivierDehaene's avatar
OlivierDehaene committed
838
839
840
            raise NotImplementedError(
                "Legacy phi-msft is not supported with Flash Attention"
            )
drbh's avatar
drbh committed
841
        else:
842
843
844
845
846
            return CausalLM(
                model_id=model_id,
                model_class=PhiForCausalLM,
                config_class=PhiConfig,
                revision=revision,
drbh's avatar
drbh committed
847
                quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
848
                speculator=speculator,
drbh's avatar
drbh committed
849
850
851
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )
852

853
854
855
856
857
858
    elif (
        model_type == LLAMA
        or model_type == BAICHUAN
        or model_type == PHI3
        or model_type == GRANITE
    ):
859
        if FLASH_ATTENTION:
860
861
862
863
            return FlashCausalLM(
                model_id=model_id,
                model_class=FlashLlamaForCausalLM,
                revision=revision,
864
                quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
865
                speculator=speculator,
866
                dtype=dtype,
867
                kv_cache_dtype=kv_cache_dtype,
868
                trust_remote_code=trust_remote_code,
drbh's avatar
drbh committed
869
                lora_adapter_ids=lora_adapter_ids,
870
            )
871
        elif sharded:
872
873
874
            raise NotImplementedError(
                FLASH_ATT_ERROR_MESSAGE.format(f"Sharded {model_type}")
            )
875
        else:
876
            return CausalLM.fallback(
877
878
879
                model_id,
                revision,
                quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
880
                speculator=speculator,
881
                dtype=dtype,
882
883
                trust_remote_code=trust_remote_code,
            )
884
    if model_type == GEMMA:
885
        if FLASH_ATTENTION:
886
887
888
889
            return FlashCausalLM(
                model_id=model_id,
                model_class=FlashGemmaForCausalLM,
                revision=revision,
890
                quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
891
                speculator=speculator,
892
                dtype=dtype,
893
                kv_cache_dtype=kv_cache_dtype,
894
895
                # Works better for these models
                default_dtype=torch.bfloat16,
896
                trust_remote_code=trust_remote_code,
897
                lora_adapter_ids=lora_adapter_ids,
898
899
            )
        elif sharded:
OlivierDehaene's avatar
OlivierDehaene committed
900
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Gemma"))
901
        else:
902
            return CausalLM.fallback(
903
904
905
                model_id,
                revision,
                quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
906
                speculator=speculator,
907
908
909
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )
Nicolas Patry's avatar
Nicolas Patry committed
910
911
    elif model_type == GEMMA2:
        if FLASH_ATTENTION:
912
913
914
915
            return FlashCausalLM(
                model_id=model_id,
                model_class=FlashGemma2ForCausalLM,
                revision=revision,
Nicolas Patry's avatar
Nicolas Patry committed
916
917
918
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
919
                kv_cache_dtype=kv_cache_dtype,
920
921
                # Works better for these models
                default_dtype=torch.bfloat16,
Nicolas Patry's avatar
Nicolas Patry committed
922
                trust_remote_code=trust_remote_code,
923
                lora_adapter_ids=lora_adapter_ids,
Nicolas Patry's avatar
Nicolas Patry committed
924
925
926
927
            )
        elif sharded:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Gemma2"))
        else:
928
            return CausalLM.fallback(
Nicolas Patry's avatar
Nicolas Patry committed
929
930
931
932
933
934
935
                model_id,
                revision,
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )
936

937
    if model_type == COHERE:
OlivierDehaene's avatar
OlivierDehaene committed
938
        if FLASH_ATTENTION:
939
940
941
942
            return FlashCausalLM(
                model_id=model_id,
                model_class=FlashCohereForCausalLM,
                revision=revision,
OlivierDehaene's avatar
OlivierDehaene committed
943
                quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
944
                speculator=speculator,
OlivierDehaene's avatar
OlivierDehaene committed
945
                dtype=dtype,
946
                kv_cache_dtype=kv_cache_dtype,
OlivierDehaene's avatar
OlivierDehaene committed
947
                trust_remote_code=trust_remote_code,
948
                lora_adapter_ids=lora_adapter_ids,
OlivierDehaene's avatar
OlivierDehaene committed
949
950
951
952
            )
        elif sharded:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Cohere"))
        else:
953
            return CausalLM.fallback(
OlivierDehaene's avatar
OlivierDehaene committed
954
955
956
                model_id,
                revision,
                quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
957
                speculator=speculator,
OlivierDehaene's avatar
OlivierDehaene committed
958
959
960
961
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )

962
    if model_type == DBRX:
963
        if FLASH_ATTENTION:
964
965
966
967
            return FlashCausalLM(
                model_id=model_id,
                model_class=FlashDbrxForCausalLM,
                revision=revision,
968
                quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
969
                speculator=speculator,
970
                dtype=dtype,
971
                kv_cache_dtype=kv_cache_dtype,
972
973
                # Dbrx works better in bfloat16.
                default_dtype=torch.bfloat16,
974
                trust_remote_code=trust_remote_code,
975
976
                lora_adapter_ids=lora_adapter_ids,
                config_class=DbrxConfig,
977
978
979
980
            )
        elif sharded:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded DBRX"))
        else:
981
            return CausalLM.fallback(
982
983
984
                model_id,
                revision,
                quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
985
                speculator=speculator,
986
987
988
989
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )

990
    if model_type in ["RefinedWeb", "RefinedWebModel", FALCON]:
991
992
        if sharded:
            if FLASH_ATTENTION:
993
                if config_dict.get("alibi", False):
994
                    raise NotImplementedError("sharded is not supported for this model")
995
996
997
998
                return FlashCausalLM(
                    model_id=model_id,
                    model_class=FlashRWForCausalLM,
                    revision=revision,
999
                    quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
1000
                    speculator=speculator,
1001
                    dtype=dtype,
1002
                    kv_cache_dtype=kv_cache_dtype,
1003
1004
1005
1006
                    aliases={
                        "lm_head.weight": ["transformer.word_embeddings.weight"],
                        "transformer.word_embeddings.weight": ["lm_head.weight"],
                    },
1007
                    trust_remote_code=trust_remote_code,
1008
1009
                    lora_adapter_ids=lora_adapter_ids,
                    config_class=RWConfig,
1010
                )
1011
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Falcon"))
1012
        else:
1013
            if FLASH_ATTENTION and not config_dict.get("alibi", False):
1014
1015
1016
1017
                return FlashCausalLM(
                    model_id=model_id,
                    model_class=FlashRWForCausalLM,
                    revision=revision,
1018
                    quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
1019
                    speculator=speculator,
1020
                    dtype=dtype,
1021
                    kv_cache_dtype=kv_cache_dtype,
1022
1023
1024
1025
                    aliases={
                        "lm_head.weight": ["transformer.word_embeddings.weight"],
                        "transformer.word_embeddings.weight": ["lm_head.weight"],
                    },
1026
                    trust_remote_code=trust_remote_code,
1027
1028
                    lora_adapter_ids=lora_adapter_ids,
                    config_class=RWConfig,
1029
1030
                )
            else:
1031
                return CausalLM.fallback(
1032
1033
1034
                    model_id,
                    revision,
                    quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
1035
                    speculator=speculator,
1036
                    dtype=dtype,
1037
1038
1039
                    trust_remote_code=trust_remote_code,
                )

1040
    if model_type == MISTRAL:
1041
        if FLASH_ATTENTION:
1042
1043
1044
1045
            return FlashCausalLM(
                model_id=model_id,
                model_class=FlashMistralForCausalLM,
                revision=revision,
1046
                quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
1047
                speculator=speculator,
1048
                dtype=dtype,
1049
                kv_cache_dtype=kv_cache_dtype,
1050
                trust_remote_code=trust_remote_code,
1051
                lora_adapter_ids=lora_adapter_ids,
1052
            )
OlivierDehaene's avatar
OlivierDehaene committed
1053
1054
1055
        elif sharded:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Mistral"))
        else:
1056
            return CausalLM.fallback(
OlivierDehaene's avatar
OlivierDehaene committed
1057
1058
1059
                model_id,
                revision,
                quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
1060
                speculator=speculator,
OlivierDehaene's avatar
OlivierDehaene committed
1061
1062
1063
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )
OlivierDehaene's avatar
OlivierDehaene committed
1064

1065
    if model_type == MIXTRAL:
1066
        if FLASH_ATTENTION:
1067
1068
1069
1070
            return FlashCausalLM(
                model_id=model_id,
                model_class=FlashMixtralForCausalLM,
                revision=revision,
OlivierDehaene's avatar
OlivierDehaene committed
1071
                quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
1072
                speculator=speculator,
OlivierDehaene's avatar
OlivierDehaene committed
1073
                dtype=dtype,
1074
                kv_cache_dtype=kv_cache_dtype,
OlivierDehaene's avatar
OlivierDehaene committed
1075
                trust_remote_code=trust_remote_code,
1076
                lora_adapter_ids=lora_adapter_ids,
OlivierDehaene's avatar
OlivierDehaene committed
1077
            )
OlivierDehaene's avatar
OlivierDehaene committed
1078
1079
1080
        elif sharded:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Mixtral"))
        else:
1081
            return CausalLM.fallback(
OlivierDehaene's avatar
OlivierDehaene committed
1082
1083
1084
                model_id,
                revision,
                quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
1085
                speculator=speculator,
OlivierDehaene's avatar
OlivierDehaene committed
1086
1087
1088
1089
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )

1090
    if model_type == STARCODER2:
1091
        if FLASH_ATTENTION:
1092
1093
1094
1095
            return FlashCausalLM(
                model_id=model_id,
                model_class=FlashStarcoder2ForCausalLM,
                revision=revision,
OlivierDehaene's avatar
OlivierDehaene committed
1096
                quantize=quantize,
1097
                speculator=speculator,
OlivierDehaene's avatar
OlivierDehaene committed
1098
                dtype=dtype,
1099
                kv_cache_dtype=kv_cache_dtype,
OlivierDehaene's avatar
OlivierDehaene committed
1100
                trust_remote_code=trust_remote_code,
1101
                lora_adapter_ids=lora_adapter_ids,
OlivierDehaene's avatar
OlivierDehaene committed
1102
1103
1104
1105
1106
1107
            )
        elif sharded:
            raise NotImplementedError(
                FLASH_ATT_ERROR_MESSAGE.format("Sharded Starcoder2")
            )
        else:
1108
            return CausalLM.fallback(
OlivierDehaene's avatar
OlivierDehaene committed
1109
1110
1111
                model_id,
                revision,
                quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
1112
                speculator=speculator,
OlivierDehaene's avatar
OlivierDehaene committed
1113
1114
1115
1116
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )

1117
    if model_type == QWEN2:
1118
        if FLASH_ATTENTION:
1119
1120
1121
1122
            return FlashCausalLM(
                model_id=model_id,
                model_class=Qwen2ForCausalLM,
                revision=revision,
OlivierDehaene's avatar
OlivierDehaene committed
1123
                quantize=quantize,
1124
                speculator=speculator,
OlivierDehaene's avatar
OlivierDehaene committed
1125
                dtype=dtype,
1126
                kv_cache_dtype=kv_cache_dtype,
OlivierDehaene's avatar
OlivierDehaene committed
1127
                trust_remote_code=trust_remote_code,
1128
                lora_adapter_ids=lora_adapter_ids,
OlivierDehaene's avatar
OlivierDehaene committed
1129
1130
1131
1132
            )
        elif sharded:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Qwen2"))
        else:
1133
            return CausalLM.fallback(
OlivierDehaene's avatar
OlivierDehaene committed
1134
1135
1136
                model_id,
                revision,
                quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
1137
                speculator=speculator,
OlivierDehaene's avatar
OlivierDehaene committed
1138
                dtype=dtype,
OlivierDehaene's avatar
OlivierDehaene committed
1139
1140
                trust_remote_code=trust_remote_code,
            )
1141

1142
    if model_type == OPT:
1143
1144
1145
1146
        return CausalLM(
            model_id=model_id,
            model_class=OPTForCausalLM,
            revision=revision,
1147
            quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
1148
            speculator=speculator,
1149
1150
            dtype=dtype,
            trust_remote_code=trust_remote_code,
1151
        )
1152

1153
    if model_type == T5:
1154
1155
1156
1157
        return Seq2SeqLM(
            model_id=model_id,
            model_class=T5ForConditionalGeneration,
            revision=revision,
1158
            quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
1159
            speculator=speculator,
1160
            dtype=dtype,
1161
            trust_remote_code=trust_remote_code,
1162
1163
1164
1165
1166
1167
            aliases={
                "shared.weight": [
                    "encoder.embed_tokens.weight",
                    "decoder.embed_tokens.weight",
                ]
            },
1168
        )
1169
    if model_type == IDEFICS:
1170
        if FLASH_ATTENTION:
Nicolas Patry's avatar
Nicolas Patry committed
1171
            return IdeficsCausalLM(
OlivierDehaene's avatar
OlivierDehaene committed
1172
1173
1174
                model_id,
                revision,
                quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
1175
                speculator=speculator,
OlivierDehaene's avatar
OlivierDehaene committed
1176
1177
1178
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )
1179
1180
        else:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Idefics"))
drbh's avatar
drbh committed
1181
1182
1183
1184
1185
1186
1187
1188
1189
1190
1191
1192
    if model_type == QWEN2_VL:
        return VlmCausalLM(
            model_id=model_id,
            model_class=Qwen2VLForConditionalGeneration,
            revision=revision,
            quantize=quantize,
            speculator=speculator,
            dtype=dtype,
            kv_cache_dtype=kv_cache_dtype,
            trust_remote_code=trust_remote_code,
            lora_adapter_ids=lora_adapter_ids,
        )
Nicolas Patry's avatar
Nicolas Patry committed
1193
1194
1195
1196
1197
1198
1199
1200
1201
1202
1203
1204
1205
1206
1207
1208
    if model_type == MLLAMA:
        if FLASH_ATTENTION:
            return MllamaCausalLM(
                model_id=model_id,
                model_class=MllamaForConditionalGeneration,
                batch_class=MllamaCausalLMBatch,
                revision=revision,
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
                default_dtype=torch.bfloat16,
                trust_remote_code=trust_remote_code,
                lora_adapter_ids=lora_adapter_ids,
            )
        else:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Mllama"))
1209
    if model_type == IDEFICS2:
Nicolas Patry's avatar
Nicolas Patry committed
1210
        if FLASH_ATTENTION:
1211
1212
1213
1214
            return VlmCausalLM(
                model_id=model_id,
                model_class=Idefics2ForConditionalGeneration,
                revision=revision,
Nicolas Patry's avatar
Nicolas Patry committed
1215
                quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
1216
                speculator=speculator,
Nicolas Patry's avatar
Nicolas Patry committed
1217
                dtype=dtype,
1218
                kv_cache_dtype=kv_cache_dtype,
Nicolas Patry's avatar
Nicolas Patry committed
1219
                trust_remote_code=trust_remote_code,
1220
1221
1222
1223
                lora_adapter_ids=lora_adapter_ids,
                # XXX: Extremely important to cap resolution in order to limit
                # VRAM usage.
                processor_kwargs={"size": {"longest_edge": 448, "shortest_edge": 378}},
Nicolas Patry's avatar
Nicolas Patry committed
1224
1225
1226
            )
        else:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Idefics"))
1227
    if model_type == PALIGEMMA:
drbh's avatar
drbh committed
1228
        if FLASH_ATTENTION:
1229
1230
1231
1232
            return VlmCausalLM(
                model_id=model_id,
                model_class=PaliGemmaForConditionalGeneration,
                revision=revision,
drbh's avatar
drbh committed
1233
1234
1235
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
1236
                kv_cache_dtype=kv_cache_dtype,
1237
1238
                # Works better for these models
                default_dtype=torch.bfloat16,
drbh's avatar
drbh committed
1239
                trust_remote_code=trust_remote_code,
1240
1241
                lora_adapter_ids=lora_adapter_ids,
                batch_class=PaliGemmaBatch,
drbh's avatar
drbh committed
1242
1243
1244
            )
        else:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Idefics"))
1245

1246
    if model_type == LLAVA_NEXT:
1247
        if FLASH_ATTENTION:
1248
1249
1250
1251
            return VlmCausalLM(
                model_class=LlavaNextForConditionalGeneration,
                model_id=model_id,
                revision=revision,
1252
                quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
1253
                speculator=speculator,
1254
                dtype=dtype,
1255
                kv_cache_dtype=kv_cache_dtype,
1256
1257
1258
1259
1260
                trust_remote_code=trust_remote_code,
            )
        else:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("LlavaNext"))

1261
    if sharded:
1262
        raise NotImplementedError("sharded is not supported for AutoModel")
1263
    if quantize == "gptq":
1264
        raise NotImplementedError(
1265
1266
            "gptq quantization is not supported for AutoModel, you can try to quantize it with `text-generation-server quantize ORIGINAL_MODEL_ID NEW_MODEL_ID`"
        )
1267
    if quantize == "awq":
1268
        raise NotImplementedError("awq quantization is not supported for AutoModel")
Nicolas Patry's avatar
Nicolas Patry committed
1269
    elif (quantize == "bitsandbytes-fp4") or (quantize == "bitsandbytes-nf4"):
1270
        raise NotImplementedError("4bit quantization is not supported for AutoModel")
OlivierDehaene's avatar
OlivierDehaene committed
1271
    elif quantize == "eetq":
1272
        raise NotImplementedError("Eetq quantization is not supported for AutoModel")
1273
1274
    elif quantize == "exl2":
        raise NotImplementedError("exl2 quantization is not supported for AutoModel")
1275
    if model_type in modeling_auto.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES:
1276
        return CausalLM.fallback(
1277
1278
1279
            model_id,
            revision,
            quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
1280
            speculator=speculator,
1281
1282
            dtype=dtype,
            trust_remote_code=trust_remote_code,
1283
        )
1284
    if model_type in modeling_auto.MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES:
1285
        return Seq2SeqLM.fallback(
1286
1287
1288
            model_id,
            revision,
            quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
1289
            speculator=speculator,
1290
1291
            dtype=dtype,
            trust_remote_code=trust_remote_code,
1292
1293
        )

1294
    auto_map = config_dict.get("auto_map", None)
1295
1296
    if trust_remote_code and auto_map is not None:
        if "AutoModelForCausalLM" in auto_map.keys():
1297
            return CausalLM.fallback(
1298
1299
1300
                model_id,
                revision,
                quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
1301
                speculator=speculator,
1302
                dtype=dtype,
1303
1304
                trust_remote_code=trust_remote_code,
            )
1305
        if "AutoModelForSeq2SeqLM" in auto_map.keys():
1306
            return Seq2SeqLM.fallback(
1307
1308
1309
                model_id,
                revision,
                quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
1310
                speculator=speculator,
1311
                dtype=dtype,
1312
1313
                trust_remote_code=trust_remote_code,
            )
1314
1315

    raise ValueError(f"Unsupported model type {model_type}")
1316
1317
1318
1319
1320
1321
1322
1323
1324
1325
1326
1327


def get_model_with_lora_adapters(
    model_id: str,
    lora_adapters: Optional[List[AdapterInfo]],
    revision: Optional[str],
    sharded: bool,
    quantize: Optional[str],
    speculate: Optional[int],
    dtype: Optional[str],
    kv_cache_dtype: Optional[str],
    trust_remote_code: bool,
    max_input_tokens: int,
    adapter_to_index: Dict[str, int],
):
    """Load a model with `get_model`, then load LoRA adapters into it.

    This wraps the internal `get_model` function and acts as a post-load hook:
    the base model is built first, then each adapter in `lora_adapters` is
    loaded and its weights are registered on `model.layer_to_adapter_weights`.

    Args:
        model_id: Identifier of the base model, forwarded to `get_model`.
        lora_adapters: Adapters to load, in order. ``None`` is treated the
            same as an empty list (no adapters are loaded).
        revision, sharded, quantize, speculate, dtype, kv_cache_dtype,
        trust_remote_code, max_input_tokens: Forwarded unchanged to
            `get_model`.
        adapter_to_index: Mutated in place. Each loaded adapter's id is mapped
            to its 1-based adapter index. Index 0 is never assigned here
            (presumably reserved for "no adapter" -- confirm against callers).

    Returns:
        The model returned by `get_model`, with any adapters loaded into it.
    """
    # The declared type allows None; normalize so the code below can iterate.
    if lora_adapters is None:
        lora_adapters = []

    lora_adapter_ids = [adapter.id for adapter in lora_adapters]
    model = get_model(
        model_id,
        lora_adapter_ids,
        revision,
        sharded,
        quantize,
        speculate,
        dtype,
        kv_cache_dtype,
        trust_remote_code,
        max_input_tokens,
    )

    if not lora_adapters:
        return model

    # Layer types for which adapter weights are prepared. The list is the same
    # for every adapter, so it is built once instead of on every iteration.
    adapter_layers = (
        "q_proj",
        "k_proj",
        "v_proj",
        "o_proj",
        "gate_proj",
        "up_proj",
        "down_proj",
        "qkv_proj",
    )

    target_to_layer = build_layer_weight_lookup(model.model)

    # Adapter indices start at 1; index 0 is not handed out here.
    for adapter_index, adapter in enumerate(lora_adapters, start=1):
        # The AdapterParameters object allows for merging multiple adapters into a single adapter.
        # At the moment, we only support loading a single adapter into the model, but we keep the
        # AdapterParameters object for easier extension in the future.
        adapter_parameters = AdapterParameters(
            adapter_info=[adapter],
            # when merging multiple adapters we can weight them differently
            # if this is not set, all adapters will be weighted equally
            # see: text_generation_server.utils.merges.strategies for impl
            weights=None,
            merge_strategy=0,
            density=1.0,
            majority_sign_method=0,
        )

        adapter_to_index[adapter.id] = adapter_index

        logger.info(
            f"Loading adapter weights into model: {','.join(info.id for info in adapter_parameters.adapter_info)}"
        )
        # Recomputed per adapter rather than hoisted: SOURCE does not show
        # whether prepare_weights mutates target_to_layer.
        weight_names = tuple(v[0] for v in target_to_layer.values())
        (
            module_map,
            adapter_config,
            adapter_weight_names,
            adapter_tokenizer,
        ) = load_and_merge_adapters(
            model.model_id,
            adapter_parameters,
            adapter_index,
            weight_names,
            False,
        )

        # prepare_weights receives this copy via unused_weight_names; whatever
        # remains in it afterwards is reported as unused below.
        unused_weight_names = adapter_weight_names.copy()

        for layer_name in adapter_layers:
            nlayers = (
                1 if layer_name == "lm_head" else len(model.model.model.layers)
            )
            adapter_weights = LoraWeights.prepare_weights(
                config=adapter_config,
                module_map=module_map,
                layer_type=layer_name,
                unused_weight_names=unused_weight_names,
                nlayers=nlayers,
                dtype=model.dtype,
                world_size=model.world_size,
                process_group=model.process_group,
                target_to_layer=target_to_layer,
            )

            if adapter_weights is None:
                continue

            model.layer_to_adapter_weights[layer_name].add_adapter(
                adapter_index, adapter_weights
            )

        if unused_weight_names:
            # Name only the adapter being loaded; the previous message listed
            # every adapter id and so misattributed the unused weights.
            logger.warning(
                f"{adapter.id} unused adapter weights: {unused_weight_names}"
            )

        if adapter_tokenizer is not None:
            model.tokenizers.add_tokenizer(adapter_index, adapter_tokenizer)

        model.loaded_adapters.add(adapter_index)

    return model