__init__.py 50.6 KB
Newer Older
1
2
3
# ruff: noqa: F821
# the above line disables the `undefined-name` rule for the model type variables

4
5
6
7
8
from compressed_tensors.compressors.model_compressors.model_compressor import (
    QuantizationConfig,
)
from compressed_tensors.quantization import QuantizationType
from pydantic import ValidationError
9
import torch
10
import enum
Nicolas Patry's avatar
Nicolas Patry committed
11
import os
12

13
from loguru import logger
14
from transformers.configuration_utils import PretrainedConfig
15
from transformers.models.auto import modeling_auto
Nicolas Patry's avatar
Nicolas Patry committed
16
from huggingface_hub import hf_hub_download, HfApi
17
from typing import Optional, List, Dict
18
from pathlib import Path
19

Nicolas Patry's avatar
Nicolas Patry committed
20
from text_generation_server.utils.speculate import get_speculate, set_speculate
21
from text_generation_server.models.model import Model
22
23
24
25
26
from text_generation_server.models.causal_lm import CausalLM, CausalLMBatchKeysLast
from text_generation_server.models.custom_modeling.opt_modeling import OPTForCausalLM
from text_generation_server.models.custom_modeling.mpt_modeling import (
    MPTForCausalLM,
)
27
from text_generation_server.models.bloom import BloomCausalLMBatch
28
29
30
from text_generation_server.models.custom_modeling.bloom_modeling import (
    BloomForCausalLM,
)
31
from text_generation_server.models.globals import ATTENTION
32
from text_generation_server.models.seq2seq_lm import Seq2SeqLM
33
34
35
36
37
38
39
40
from text_generation_server.models.galactica import GalacticaCausalLMBatch
from text_generation_server.models.custom_modeling.neox_modeling import (
    GPTNeoxForCausalLM,
)
from text_generation_server.models.custom_modeling.phi_modeling import (
    PhiConfig,
    PhiForCausalLM,
)
drbh's avatar
drbh committed
41
42
43
from text_generation_server.models.custom_modeling.flash_phi_moe_modeling import (
    PhiMoEConfig,
)
44
45
46
from text_generation_server.models.custom_modeling.t5_modeling import (
    T5ForConditionalGeneration,
)
47

48
49
50
51
52
53
54
55
56
57

from text_generation_server.utils.adapter import (
    AdapterParameters,
    build_layer_weight_lookup,
    load_and_merge_adapters,
    AdapterInfo,
)
from text_generation_server.adapters.lora import LoraWeights


58
from text_generation_server.utils.import_utils import SYSTEM
59
from text_generation_server.utils.log import log_master
60

61
62
63
64
65
66
67
68
69
70
71
72
73
74
# The flag below controls whether to allow TF32 on matmul. This flag defaults to False
# in PyTorch 1.12 and later.
torch.backends.cuda.matmul.allow_tf32 = True

# The flag below controls whether to allow TF32 on cuDNN. This flag defaults to True.
torch.backends.cudnn.allow_tf32 = True

# Disable gradients globally: this module only builds models for inference.
torch.set_grad_enabled(False)

# Public API of this package. Optional model classes are appended further down
# only when their imports succeed.
__all__ = [
    "Model",
    "CausalLM",
    "Seq2SeqLM",
    "get_model_with_lora_adapters",
]

# Message template (formatted with a model variant name, e.g. "Sharded GPT-2")
# used when a requested configuration is only available with flash attention.
FLASH_ATT_ERROR_MESSAGE = "{} requires Flash Attention enabled models."

# Optimistically assume the flash-attention model implementations are importable;
# the guarded import block that follows resets this to False on ImportError.
FLASH_ATTENTION = True
81

82
try:
83
    from text_generation_server.models.flash_causal_lm import FlashCausalLM
84
    from text_generation_server.models.vlm_causal_lm import VlmCausalLM
Nicolas Patry's avatar
Nicolas Patry committed
85
    from text_generation_server.models.mllama_causal_lm import MllamaCausalLM
86
87
88
89
    from text_generation_server.models.custom_modeling.flash_deepseek_v2_modeling import (
        FlashDeepseekV2ForCausalLM,
        DeepseekV2Config,
    )
90
91
    from text_generation_server.models.custom_modeling.flash_llama_modeling import (
        FlashLlamaForCausalLM,
92
    )
93
94
    from text_generation_server.models.custom_modeling.flash_cohere_modeling import (
        FlashCohereForCausalLM,
OlivierDehaene's avatar
OlivierDehaene committed
95
    )
96
97
    from text_generation_server.models.custom_modeling.flash_gemma_modeling import (
        FlashGemmaForCausalLM,
OlivierDehaene's avatar
OlivierDehaene committed
98
    )
99
100
    from text_generation_server.models.custom_modeling.flash_gemma2_modeling import (
        FlashGemma2ForCausalLM,
101
    )
102
103
104
105
106
107
108
109
110
111
    from text_generation_server.models.custom_modeling.flash_dbrx_modeling import (
        FlashDbrxForCausalLM,
        DbrxConfig,
    )
    from text_generation_server.models.custom_modeling.flash_rw_modeling import (
        RWConfig,
        FlashRWForCausalLM,
    )
    from text_generation_server.models.custom_modeling.flash_neox_modeling import (
        FlashGPTNeoXForCausalLM,
Nicolas Patry's avatar
Nicolas Patry committed
112
    )
drbh's avatar
drbh committed
113
    from text_generation_server.models.pali_gemma import (
114
        PaliGemmaBatch,
drbh's avatar
drbh committed
115
    )
116
117
118
119
120
    from text_generation_server.models.custom_modeling.flash_pali_gemma_modeling import (
        PaliGemmaForConditionalGeneration,
    )
    from text_generation_server.models.custom_modeling.flash_phi_modeling import (
        FlashPhiForCausalLM,
121
    )
Nicolas Patry's avatar
Nicolas Patry committed
122
123
124
125
126
    from text_generation_server.models.idefics_causal_lm import IdeficsCausalLM
    from text_generation_server.models.mllama_causal_lm import MllamaCausalLMBatch
    from text_generation_server.models.custom_modeling.mllama import (
        MllamaForConditionalGeneration,
    )
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
    from text_generation_server.models.custom_modeling.llava_next import (
        LlavaNextForConditionalGeneration,
    )

    from text_generation_server.models.custom_modeling.flash_santacoder_modeling import (
        FlashSantacoderForCausalLM,
    )
    from text_generation_server.models.custom_modeling.flash_starcoder2_modeling import (
        FlashStarcoder2ForCausalLM,
    )
    from text_generation_server.models.custom_modeling.flash_qwen2_modeling import (
        Qwen2ForCausalLM,
    )
    from text_generation_server.models.custom_modeling.flash_mistral_modeling import (
        FlashMistralForCausalLM,
    )
    from text_generation_server.models.custom_modeling.flash_mixtral_modeling import (
        FlashMixtralForCausalLM,
    )
    from text_generation_server.models.custom_modeling.flash_gpt2_modeling import (
        FlashGPT2ForCausalLM,
    )
149
150
151
    from text_generation_server.models.custom_modeling.flash_gptj_modeling import (
        FlashGPTJForCausalLM,
    )
152
153
154
    from text_generation_server.models.custom_modeling.idefics2 import (
        Idefics2ForConditionalGeneration,
    )
drbh's avatar
drbh committed
155
156
157
    from text_generation_server.models.custom_modeling.qwen2_vl import (
        Qwen2VLForConditionalGeneration,
    )
158
    from text_generation_server.layers.attention import SUPPORTS_WINDOWING
159
except ImportError as e:
160
    log_master(logger.warning, f"Could not import Flash Attention enabled models: {e}")
161
    SUPPORTS_WINDOWING = False
162
    FLASH_ATTENTION = False
163

164
if FLASH_ATTENTION:
    # `__all__` must hold *names* (strings). Appending the class objects
    # themselves makes `from text_generation_server.models import *` fail with
    # "TypeError: Item in ...__all__ must be str, not type".
    __all__.extend(["FlashCausalLM", "IdeficsCausalLM"])

# Mamba support is optional: if its implementation cannot be imported, log a
# warning and carry on without it.
MAMBA_AVAILABLE = True
try:
    from text_generation_server.models.mamba import Mamba
except ImportError as e:
    log_master(logger.warning, f"Could not import Mamba: {e}")
    MAMBA_AVAILABLE = False

if MAMBA_AVAILABLE:
    # Export by name, for the same reason as above.
    __all__.append("Mamba")
OlivierDehaene's avatar
OlivierDehaene committed
177

178

179
class ModelType(enum.Enum):
    """Registry of the model architectures this module knows how to serve.

    Each member's value is a dict with:
      - "type": the ``model_type`` string that ``get_model`` compares against
        the checkpoint's config (via the module-level constants created below).
      - "name": a human-readable display name.
      - "url": a reference checkpoint or collection on the Hugging Face Hub.
      - "multimodal" (optional): ``True`` on some vision-language entries.
        NOTE(review): PALIGEMMA and QWEN2_VL do not carry this flag — confirm
        whether that is intentional.
    """

    DEEPSEEK_V2 = {
        "type": "deepseek_v2",
        "name": "Deepseek V2",
        "url": "https://huggingface.co/deepseek-ai/DeepSeek-V2",
    }
    IDEFICS2 = {
        "type": "idefics2",
        "name": "Idefics 2",
        "url": "https://huggingface.co/HuggingFaceM4/idefics2-8b",
        "multimodal": True,
    }
    LLAVA_NEXT = {
        "type": "llava_next",
        "name": "Llava Next (1.6)",
        "url": "https://huggingface.co/llava-hf/llava-v1.6-vicuna-13b-hf",
        "multimodal": True,
    }
    LLAMA = {
        "type": "llama",
        "name": "Llama",
        "url": "https://huggingface.co/collections/meta-llama/llama-31-669fc079a0c406a149a5738f",
    }
    PHI3 = {
        "type": "phi3",
        "name": "Phi 3",
        "url": "https://huggingface.co/microsoft/Phi-3-mini-4k-instruct",
    }
    GRANITE = {
        "type": "granite",
        "name": "Granite",
        "url": "https://huggingface.co/ibm-granite/granite-3.0-8b-instruct",
    }
    GEMMA = {
        "type": "gemma",
        "name": "Gemma",
        "url": "https://huggingface.co/google/gemma-7b",
    }
    PALIGEMMA = {
        "type": "paligemma",
        "name": "PaliGemma",
        "url": "https://huggingface.co/google/paligemma-3b-pt-224",
    }
    GEMMA2 = {
        "type": "gemma2",
        "name": "Gemma2",
        "url": "https://huggingface.co/collections/google/gemma-2-release-667d6600fd5220e7b967f315",
    }
    COHERE = {
        "type": "cohere",
        "name": "Cohere",
        "url": "https://huggingface.co/CohereForAI/c4ai-command-r-plus",
    }
    DBRX = {
        "type": "dbrx",
        "name": "Dbrx",
        "url": "https://huggingface.co/databricks/dbrx-instruct",
    }
    MAMBA = {
        "type": "mamba",
        "name": "Mamba",
        "url": "https://huggingface.co/state-spaces/mamba-2.8b-slimpj",
    }
    MISTRAL = {
        "type": "mistral",
        "name": "Mistral",
        "url": "https://huggingface.co/mistralai/Mistral-Nemo-Instruct-2407",
    }
    MIXTRAL = {
        "type": "mixtral",
        "name": "Mixtral",
        "url": "https://huggingface.co/mistralai/Mixtral-8x22B-Instruct-v0.1",
    }
    GPT_BIGCODE = {
        "type": "gpt_bigcode",
        "name": "Gpt Bigcode",
        "url": "https://huggingface.co/bigcode/gpt_bigcode-santacoder",
    }
    PHI = {
        "type": "phi",
        "name": "Phi",
        "url": "https://huggingface.co/microsoft/phi-1_5",
    }
    PHI_MOE = {
        "type": "phimoe",
        "name": "PhiMoe",
        "url": "https://huggingface.co/microsoft/Phi-3.5-MoE-instruct",
    }
    BAICHUAN = {
        "type": "baichuan",
        "name": "Baichuan",
        "url": "https://huggingface.co/baichuan-inc/Baichuan2-7B-Chat",
    }
    FALCON = {
        "type": "falcon",
        "name": "Falcon",
        "url": "https://huggingface.co/tiiuae/falcon-7b-instruct",
    }
    STARCODER2 = {
        "type": "starcoder2",
        "name": "StarCoder 2",
        "url": "https://huggingface.co/bigcode/starcoder2-15b-instruct-v0.1",
    }
    QWEN2 = {
        "type": "qwen2",
        "name": "Qwen 2",
        "url": "https://huggingface.co/collections/Qwen/qwen2-6659360b33528ced941e557f",
    }
    QWEN2_VL = {
        "type": "qwen2_vl",
        "name": "Qwen 2 VL",
        "url": "https://huggingface.co/collections/Qwen/qwen2-vl-66cee7455501d7126940800d",
    }
    OPT = {
        "type": "opt",
        "name": "Opt",
        "url": "https://huggingface.co/facebook/opt-6.7b",
    }
    T5 = {
        "type": "t5",
        "name": "T5",
        "url": "https://huggingface.co/google/flan-t5-xxl",
    }
    GALACTICA = {
        "type": "galactica",
        "name": "Galactica",
        "url": "https://huggingface.co/facebook/galactica-120b",
    }
    SANTACODER = {
        "type": "santacoder",
        "name": "SantaCoder",
        "url": "https://huggingface.co/bigcode/santacoder",
    }
    BLOOM = {
        "type": "bloom",
        "name": "Bloom",
        "url": "https://huggingface.co/bigscience/bloom-560m",
    }
    MPT = {
        "type": "mpt",
        "name": "Mpt",
        "url": "https://huggingface.co/mosaicml/mpt-7b-instruct",
    }
    GPT2 = {
        "type": "gpt2",
        "name": "Gpt2",
        "url": "https://huggingface.co/openai-community/gpt2",
    }
    GPT_NEOX = {
        "type": "gpt_neox",
        "name": "Gpt Neox",
        "url": "https://huggingface.co/EleutherAI/gpt-neox-20b",
    }
    GPTJ = {
        "type": "gptj",
        "name": "Gptj",
        "url": "https://huggingface.co/EleutherAI/gpt-j-6b",
    }
    IDEFICS = {
        "type": "idefics",
        "name": "Idefics",
        "url": "https://huggingface.co/HuggingFaceM4/idefics-9b",
        "multimodal": True,
    }
    MLLAMA = {
        "type": "mllama",
        "name": "Mllama",
        "url": "https://huggingface.co/meta-llama/Llama-3.2-11B-Vision-Instruct",
        "multimodal": True,
    }


# Expose each member's "type" string as a module-level constant named after the
# member (e.g. LLAMA == "llama"), so model dispatch can compare `model_type`
# against bare names. Static analyzers cannot see these names, which is why the
# file header disables ruff's undefined-name rule (F821). Using `globals()`
# explicitly (rather than relying on module-scope `locals()` aliasing it) and a
# comprehension avoids leaking a loop variable into the module namespace.
globals().update({member.name: member.value["type"] for member in ModelType})


356
def get_model(
357
    model_id: str,
drbh's avatar
drbh committed
358
    lora_adapter_ids: Optional[List[str]],
359
360
361
    revision: Optional[str],
    sharded: bool,
    quantize: Optional[str],
Nicolas Patry's avatar
Nicolas Patry committed
362
    speculate: Optional[int],
363
    dtype: Optional[str],
364
    kv_cache_dtype: Optional[str],
365
    trust_remote_code: bool,
366
    max_input_tokens: int,
367
) -> Model:
368
    global FLASH_ATTENTION
369
370
371
372
373
374
375

    config_dict, _ = PretrainedConfig.get_config_dict(
        model_id, revision=revision, trust_remote_code=trust_remote_code
    )
    model_type = config_dict.get("model_type", None)

    quantization_config = config_dict.get("quantization_config", None)
376
377
    if quantization_config is None:
        quantization_config = config_dict.get("compression_config", None)
378
379
380
381
382
    if quantization_config is not None and quantize is None:
        method = quantization_config.get("quant_method", None)
        if method in {"gptq", "awq", "exl2"}:
            log_master(logger.info, f"Auto selecting quantization method {method}")
            quantize = method
383
        elif method == "fbgemm_fp8" or method == "fp8":
384
385
            log_master(logger.info, "Auto selecting quantization method fp8")
            quantize = "fp8"
386
387
388
389
390
        if method == "compressed-tensors":
            log_master(
                logger.info, "Auto selecting quantization method compressed-tensors"
            )
            quantize = "compressed-tensors"
391

392
393
394
        else:
            log_master(logger.warning, f"Unknown quantization method {method}")

395
    if dtype is None:
396
        if quantize in ["awq", "exl2", "gptq", "marlin"]:
397
398
399
            if SYSTEM == "ipex" and not (
                hasattr(torch, "xpu") and torch.xpu.is_available()
            ):
Nicolas Patry's avatar
Nicolas Patry committed
400
401
402
403
                dtype = torch.bfloat16
            else:
                # These quantizers only work with float16 params.
                dtype = torch.float16
404
405
406
407
        else:
            # Keep it as default for now and let
            # every model resolve their own default dtype.
            dtype = None
408
409
410
411
412
413
414
    elif dtype == "float16":
        dtype = torch.float16
    elif dtype == "bfloat16":
        dtype = torch.bfloat16
    else:
        raise RuntimeError(f"Unknown dtype {dtype}")

415
416
417
418
419
420
421
422
423
    compressed_tensors_config = None
    if quantize == "compressed-tensors":
        try:
            compressed_tensors_config = QuantizationConfig.model_validate(
                quantization_config
            )
        except ValidationError as e:
            raise ValueError("Cannot parse compressed-tensors configuration") from e

424
    if kv_cache_dtype is None:
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
        kv_cache_scheme = (
            compressed_tensors_config.kv_cache_scheme
            if isinstance(compressed_tensors_config, QuantizationConfig)
            else None
        )
        if (
            kv_cache_scheme is not None
            and kv_cache_scheme.type == QuantizationType.FLOAT
            and kv_cache_scheme.num_bits == 8
            and SYSTEM == "cuda"
            and ATTENTION == "flashinfer"
        ):
            kv_cache_dtype = torch.float8_e4m3fn
        else:
            kv_cache_dtype = dtype
440
441
    elif kv_cache_dtype == "fp8_e4m3fn":
        kv_cache_dtype = torch.float8_e4m3fn
442
443
444
445
446
    elif kv_cache_dtype == "fp8_e5m2":
        kv_cache_dtype = torch.float8_e5m2
    else:
        raise RuntimeError(f"Unknown kv_cache_dtype: {kv_cache_dtype}")

Nicolas Patry's avatar
Nicolas Patry committed
447
448
449
450
451
    if speculate is not None:
        set_speculate(speculate)
    else:
        set_speculate(0)

Nicolas Patry's avatar
Nicolas Patry committed
452
    speculator = None
Nicolas Patry's avatar
Nicolas Patry committed
453
    if "medusa_num_heads" in config_dict:
454
455
        medusa_model_id = model_id
        medusa_revision = revision
Nicolas Patry's avatar
Nicolas Patry committed
456
457
458
459
460
        model_id = config_dict["base_model_name_or_path"]
        revision = "main"
        speculate_medusa = config_dict["medusa_num_heads"]
        if speculate is not None:
            if speculate > speculate_medusa:
OlivierDehaene's avatar
OlivierDehaene committed
461
                raise RuntimeError(
OlivierDehaene's avatar
OlivierDehaene committed
462
                    f"Speculate is set to `{speculate}` but this medusa models only has `{speculate_medusa}` heads, please make them match"
OlivierDehaene's avatar
OlivierDehaene committed
463
                )
Nicolas Patry's avatar
Nicolas Patry committed
464
465
466
467
468
469
470
471
            else:
                set_speculate(speculate)
        else:
            set_speculate(speculate_medusa)

        config_dict, _ = PretrainedConfig.get_config_dict(
            model_id, revision=revision, trust_remote_code=trust_remote_code
        )
Nicolas Patry's avatar
Nicolas Patry committed
472
473
        # Reload model type from parent.
        model_type = config_dict.get("model_type", None)
474
475
476
477
478
479
480
481
482
483
        is_local = Path(medusa_model_id).exists()
        if not is_local:
            medusa_config = hf_hub_download(
                medusa_model_id, revision=medusa_revision, filename="config.json"
            )
            hf_hub_download(
                medusa_model_id,
                revision=medusa_revision,
                filename="medusa_lm_head.safetensors",
            )
Nicolas Patry's avatar
Nicolas Patry committed
484
485
486
487
            speculator = {
                "path": Path(medusa_config).parent,
                "model_paths": ["medusa_lm_head.safetensors"],
            }
488
        else:
Nicolas Patry's avatar
Nicolas Patry committed
489
490
491
492
            speculator = {
                "path": Path(medusa_model_id),
                "model_paths": ["medusa_lm_head.safetensors"],
            }
493

Nicolas Patry's avatar
Nicolas Patry committed
494
        method = "medusa"
Nicolas Patry's avatar
Nicolas Patry committed
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
    elif model_type == "mlp_speculator":
        mlp_model_id = model_id
        mlp_revision = revision
        model_id = config_dict["base_model_name_or_path"]
        revision = "main"
        speculate_mlp = config_dict["n_predict"]
        if speculate is not None:
            if speculate > speculate_mlp:
                raise RuntimeError(
                    f"Speculate is set to `{speculate}` but this mlp_speculator models only has `{speculate_mlp}` heads, please make them match"
                )
            else:
                set_speculate(speculate)
        else:
            set_speculate(speculate_mlp)

        config_dict, _ = PretrainedConfig.get_config_dict(
            model_id, revision=revision, trust_remote_code=trust_remote_code
        )
        # Reload model type from parent.
        model_type = config_dict.get("model_type", None)
        is_local = Path(mlp_model_id).exists()
        extension = ".safetensors"
        if not is_local:
            mlp_speculator_config = hf_hub_download(
                mlp_model_id, revision=mlp_revision, filename="config.json"
            )
            api = HfApi()
            info = api.model_info(mlp_model_id, revision=mlp_revision)
            filenames = [
                s.rfilename
                for s in info.siblings
                if s.rfilename.endswith(extension)
                and len(s.rfilename.split("/")) == 1
                and "arguments" not in s.rfilename
                and "args" not in s.rfilename
                and "training" not in s.rfilename
            ]
            for filename in filenames:
                hf_hub_download(
                    mlp_model_id,
                    revision=mlp_revision,
                    filename=filename,
                )
539
540
541
542
543
            speculator_dir_path = Path(mlp_speculator_config).parent
            # if these are downloaded, they get converted to safetensors
            filenames.extend(
                [p for p in os.listdir(speculator_dir_path) if p.endswith(extension)]
            )
Nicolas Patry's avatar
Nicolas Patry committed
544
545
546
547
548
549
550
551
552
            speculator = {
                "path": Path(mlp_speculator_config).parent,
                "model_paths": filenames,
            }
        else:
            speculator = Path(mlp_model_id)
            filenames = [p for p in os.listdir(speculator) if p.endswith(extension)]
            speculator = {"path": speculator, "model_paths": filenames}
        method = "mlp_speculator"
Nicolas Patry's avatar
Nicolas Patry committed
553
554
555
556
557
    else:
        method = "n-gram"

    speculate = get_speculate()
    if speculate > 0:
558
559
560
        log_master(
            logger.info, f"Using speculation {method} with {speculate} input ids."
        )
Nicolas Patry's avatar
Nicolas Patry committed
561

drbh's avatar
drbh committed
562
563
564
565
    if model_type is None:
        # TODO: fix how we determine model type for Mamba
        if "ssm_cfg" in config_dict:
            # *only happens in Mamba case
566
            model_type = "mamba"
drbh's avatar
drbh committed
567
568
569
570
571
        else:
            raise RuntimeError(
                f"Could not determine model type for {model_id} revision {revision}"
            )

572
573
574
575
    if quantize == "exl2" and sharded:
        raise RuntimeError(
            "Sharding is currently not supported with `exl2` quantization"
        )
drbh's avatar
drbh committed
576
577
578
579
580
581

    sliding_window = (
        config_dict.get("sliding_window")
        if config_dict.get("sliding_window") is not None
        else -1
    )
582

583
584
585
    use_sliding_window = sliding_window is not None and sliding_window != -1
    needs_sliding_window = (
        max_input_tokens is not None and max_input_tokens > sliding_window
586
    )
587
588
589
590
    if use_sliding_window and needs_sliding_window and not SUPPORTS_WINDOWING:
        raise ValueError(
            f"The backend {SYSTEM} does not support sliding window attention that is used by the model type {model_type}. To use this model nonetheless with the {SYSTEM} backend, please launch TGI with the argument `--max-input-tokens` smaller than sliding_window={sliding_window} (got here max_input_tokens={max_input_tokens})."
        )
591

592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
    if model_type == DEEPSEEK_V2:
        if FLASH_ATTENTION:
            head_size = max(
                config_dict.get("qk_nope_dim", 128)
                + config_dict.get("qk_rope_dim", 64),
                config_dict.get("v_head_dim", 128),
            )
            return FlashCausalLM(
                model_id=model_id,
                model_class=FlashDeepseekV2ForCausalLM,
                revision=revision,
                quantize=quantize,
                speculator=speculator,
                default_dtype=torch.bfloat16,
                dtype=dtype,
607
                kv_cache_dtype=kv_cache_dtype,
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
                trust_remote_code=trust_remote_code,
                lora_adapter_ids=lora_adapter_ids,
                config_class=DeepseekV2Config,
                head_size=head_size,
            )
        elif sharded:
            raise NotImplementedError(
                FLASH_ATT_ERROR_MESSAGE.format("Sharded Deepseek V2")
            )
        else:
            return CausalLM.fallback(
                model_id,
                revision,
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )
    elif model_type == MAMBA:
drbh's avatar
drbh committed
627
628
629
630
        return Mamba(
            model_id,
            revision,
            quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
631
            speculator=speculator,
drbh's avatar
drbh committed
632
633
634
            dtype=dtype,
            trust_remote_code=trust_remote_code,
        )
635
636
637
638
    elif model_type == "ssm":
        raise RuntimeError(
            "`ssm` models have been deprecated in favor of `mamba` models, which follow standard HF formats. Check out a list here: https://huggingface.co/models?search=mamba%20-hf"
        )
639

OlivierDehaene's avatar
OlivierDehaene committed
640
    if model_id.startswith("facebook/galactica"):
641
642
643
644
645
        return CausalLM(
            model_id=model_id,
            # Yes galactica is just an OPT model.
            model_class=OPTForCausalLM,
            revision=revision,
OlivierDehaene's avatar
OlivierDehaene committed
646
            quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
647
            speculator=speculator,
OlivierDehaene's avatar
OlivierDehaene committed
648
649
            dtype=dtype,
            trust_remote_code=trust_remote_code,
650
            batch_class=GalacticaCausalLMBatch,
OlivierDehaene's avatar
OlivierDehaene committed
651
652
        )

653
    if (
654
655
        model_type == GPT_BIGCODE
        or model_type == GPT2
656
657
        and model_id.startswith("bigcode/")
    ):
658
        if FLASH_ATTENTION:
659
660
661
662
            return FlashCausalLM(
                model_id=model_id,
                model_class=FlashSantacoderForCausalLM,
                revision=revision,
663
                quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
664
                speculator=speculator,
665
                dtype=dtype,
666
                kv_cache_dtype=kv_cache_dtype,
667
                trust_remote_code=trust_remote_code,
668
669
670
                lora_adapter_ids=lora_adapter_ids,
                aliases={"transformer.wte.weight": ["lm_head.weight"]},
                num_kv_heads=1,
671
            )
672
673
674
675
        elif sharded:
            raise NotImplementedError(
                FLASH_ATT_ERROR_MESSAGE.format("Sharded Santacoder")
            )
676
        else:
677
678
679
            return CausalLM.fallback(
                model_id=model_id,
                revision=revision,
680
                quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
681
                speculator=speculator,
682
                dtype=dtype,
683
684
                trust_remote_code=trust_remote_code,
            )
685

686
    if model_type == BLOOM:
687
688
689
690
        return CausalLM(
            model_id=model_id,
            model_class=BloomForCausalLM,
            revision=revision,
691
            quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
692
            speculator=speculator,
693
694
            dtype=dtype,
            trust_remote_code=trust_remote_code,
695
            batch_class=BloomCausalLMBatch,
696
        )
697
    elif model_type == MPT:
698
699
700
701
        return CausalLM(
            model_id=model_id,
            model_class=MPTForCausalLM,
            revision=revision,
OlivierDehaene's avatar
OlivierDehaene committed
702
            quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
703
            speculator=speculator,
OlivierDehaene's avatar
OlivierDehaene committed
704
705
            dtype=dtype,
            trust_remote_code=trust_remote_code,
706
            batch_class=CausalLMBatchKeysLast,
707
        )
708
    elif model_type == GPT2:
709
        if FLASH_ATTENTION:
710
            try:
711
712
713
714
                return FlashCausalLM(
                    model_id=model_id,
                    model_class=FlashGPT2ForCausalLM,
                    revision=revision,
715
716
717
                    quantize=quantize,
                    speculator=speculator,
                    dtype=dtype,
718
                    kv_cache_dtype=kv_cache_dtype,
719
                    trust_remote_code=trust_remote_code,
720
                    lora_adapter_ids=lora_adapter_ids,
721
722
723
                )
            except RuntimeError as e:
                # Lots of legacy models with various weight names.
724
                log_master(logger.warning, f"Couldn't load flash gpt2 variant: {e}")
725
                return CausalLM.fallback(
726
727
728
729
730
731
732
                    model_id,
                    revision,
                    quantize=quantize,
                    speculator=speculator,
                    dtype=dtype,
                    trust_remote_code=trust_remote_code,
                )
733
734
735
        elif sharded:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded GPT-2"))
        else:
736
            return CausalLM.fallback(
737
738
739
740
741
742
743
                model_id,
                revision,
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )
744
745
746
747
748
749
750
751
752
753
    elif model_type == GPTJ:
        if FLASH_ATTENTION:
            try:
                return FlashCausalLM(
                    model_id=model_id,
                    model_class=FlashGPTJForCausalLM,
                    revision=revision,
                    quantize=quantize,
                    speculator=speculator,
                    dtype=dtype,
754
                    kv_cache_dtype=kv_cache_dtype,
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
                    trust_remote_code=trust_remote_code,
                    lora_adapter_ids=lora_adapter_ids,
                )
            except RuntimeError as e:
                # Lots of legacy models with various weight names.
                log_master(logger.warning, f"Couldn't load flash gptj variant: {e}")
                return CausalLM.fallback(
                    model_id,
                    revision,
                    quantize=quantize,
                    speculator=speculator,
                    dtype=dtype,
                    trust_remote_code=trust_remote_code,
                )
        elif sharded:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded GPT-J"))
        else:
            return CausalLM.fallback(
                model_id,
                revision,
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )
780
    elif model_type == GPT_NEOX:
781
        if FLASH_ATTENTION:
782
783
784
785
            from text_generation_server.models.custom_modeling.flash_neox_modeling import (
                GPTNeoXConfig,
            )

786
787
788
789
            return FlashCausalLM(
                model_id=model_id,
                model_class=FlashGPTNeoXForCausalLM,
                revision=revision,
790
                quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
791
                speculator=speculator,
792
                dtype=dtype,
793
                kv_cache_dtype=kv_cache_dtype,
794
                trust_remote_code=trust_remote_code,
795
                lora_adapter_ids=lora_adapter_ids,
796
                config_class=GPTNeoXConfig,
797
798
            )
        elif sharded:
799
800
801
802
            return CausalLM(
                model_id=model_id,
                model_class=GPTNeoxForCausalLM,
                revision=revision,
803
                quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
804
                speculator=speculator,
805
                dtype=dtype,
806
807
                trust_remote_code=trust_remote_code,
            )
808
        else:
809
            return CausalLM.fallback(
810
811
812
                model_id,
                revision,
                quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
813
                speculator=speculator,
814
                dtype=dtype,
815
816
                trust_remote_code=trust_remote_code,
            )
OlivierDehaene's avatar
OlivierDehaene committed
817

818
    elif model_type == PHI:
drbh's avatar
drbh committed
819
        if FLASH_ATTENTION:
820
821
822
823
            return FlashCausalLM(
                model_id=model_id,
                model_class=FlashPhiForCausalLM,
                revision=revision,
drbh's avatar
drbh committed
824
                quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
825
                speculator=speculator,
drbh's avatar
drbh committed
826
                dtype=dtype,
827
                kv_cache_dtype=kv_cache_dtype,
drbh's avatar
drbh committed
828
                trust_remote_code=trust_remote_code,
829
                lora_adapter_ids=lora_adapter_ids,
drbh's avatar
drbh committed
830
831
            )
        else:
832
            return CausalLM.fallback(
drbh's avatar
drbh committed
833
834
835
                model_id,
                revision,
                quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
836
                speculator=speculator,
drbh's avatar
drbh committed
837
838
839
840
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )

drbh's avatar
drbh committed
841
842
843
844
845
846
847
848
849
850
    elif model_type == PHI_MOE:
        if FLASH_ATTENTION:
            return FlashCausalLM(
                model_id=model_id,
                model_class=FlashLlamaForCausalLM,
                config_class=PhiMoEConfig,
                revision=revision,
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
851
                kv_cache_dtype=kv_cache_dtype,
drbh's avatar
drbh committed
852
853
854
855
856
857
858
859
860
861
862
863
864
                trust_remote_code=trust_remote_code,
                lora_adapter_ids=lora_adapter_ids,
            )
        else:
            return CausalLM.fallback(
                model_id,
                revision,
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )

drbh's avatar
drbh committed
865
866
    elif model_type == "phi-msft":
        if FLASH_ATTENTION:
OlivierDehaene's avatar
OlivierDehaene committed
867
868
869
            raise NotImplementedError(
                "Legacy phi-msft is not supported with Flash Attention"
            )
drbh's avatar
drbh committed
870
        else:
871
872
873
874
875
            return CausalLM(
                model_id=model_id,
                model_class=PhiForCausalLM,
                config_class=PhiConfig,
                revision=revision,
drbh's avatar
drbh committed
876
                quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
877
                speculator=speculator,
drbh's avatar
drbh committed
878
879
880
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )
881

882
883
884
885
886
887
    elif (
        model_type == LLAMA
        or model_type == BAICHUAN
        or model_type == PHI3
        or model_type == GRANITE
    ):
888
        if FLASH_ATTENTION:
889
890
891
892
            return FlashCausalLM(
                model_id=model_id,
                model_class=FlashLlamaForCausalLM,
                revision=revision,
893
                quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
894
                speculator=speculator,
895
                dtype=dtype,
896
                kv_cache_dtype=kv_cache_dtype,
897
                trust_remote_code=trust_remote_code,
drbh's avatar
drbh committed
898
                lora_adapter_ids=lora_adapter_ids,
899
            )
900
        elif sharded:
901
902
903
            raise NotImplementedError(
                FLASH_ATT_ERROR_MESSAGE.format(f"Sharded {model_type}")
            )
904
        else:
905
            return CausalLM.fallback(
906
907
908
                model_id,
                revision,
                quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
909
                speculator=speculator,
910
                dtype=dtype,
911
912
                trust_remote_code=trust_remote_code,
            )
913
    if model_type == GEMMA:
914
        if FLASH_ATTENTION:
915
916
917
918
            return FlashCausalLM(
                model_id=model_id,
                model_class=FlashGemmaForCausalLM,
                revision=revision,
919
                quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
920
                speculator=speculator,
921
                dtype=dtype,
922
                kv_cache_dtype=kv_cache_dtype,
923
924
                # Works better for these models
                default_dtype=torch.bfloat16,
925
                trust_remote_code=trust_remote_code,
926
                lora_adapter_ids=lora_adapter_ids,
927
928
            )
        elif sharded:
OlivierDehaene's avatar
OlivierDehaene committed
929
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Gemma"))
930
        else:
931
            return CausalLM.fallback(
932
933
934
                model_id,
                revision,
                quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
935
                speculator=speculator,
936
937
938
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )
Nicolas Patry's avatar
Nicolas Patry committed
939
940
    elif model_type == GEMMA2:
        if FLASH_ATTENTION:
941
942
943
944
            return FlashCausalLM(
                model_id=model_id,
                model_class=FlashGemma2ForCausalLM,
                revision=revision,
Nicolas Patry's avatar
Nicolas Patry committed
945
946
947
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
948
                kv_cache_dtype=kv_cache_dtype,
949
950
                # Works better for these models
                default_dtype=torch.bfloat16,
Nicolas Patry's avatar
Nicolas Patry committed
951
                trust_remote_code=trust_remote_code,
952
                lora_adapter_ids=lora_adapter_ids,
Nicolas Patry's avatar
Nicolas Patry committed
953
954
955
956
            )
        elif sharded:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Gemma2"))
        else:
957
            return CausalLM.fallback(
Nicolas Patry's avatar
Nicolas Patry committed
958
959
960
961
962
963
964
                model_id,
                revision,
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )
965

966
    if model_type == COHERE:
OlivierDehaene's avatar
OlivierDehaene committed
967
        if FLASH_ATTENTION:
968
969
970
971
            return FlashCausalLM(
                model_id=model_id,
                model_class=FlashCohereForCausalLM,
                revision=revision,
OlivierDehaene's avatar
OlivierDehaene committed
972
                quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
973
                speculator=speculator,
OlivierDehaene's avatar
OlivierDehaene committed
974
                dtype=dtype,
975
                kv_cache_dtype=kv_cache_dtype,
OlivierDehaene's avatar
OlivierDehaene committed
976
                trust_remote_code=trust_remote_code,
977
                lora_adapter_ids=lora_adapter_ids,
OlivierDehaene's avatar
OlivierDehaene committed
978
979
980
981
            )
        elif sharded:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Cohere"))
        else:
982
            return CausalLM.fallback(
OlivierDehaene's avatar
OlivierDehaene committed
983
984
985
                model_id,
                revision,
                quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
986
                speculator=speculator,
OlivierDehaene's avatar
OlivierDehaene committed
987
988
989
990
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )

991
    if model_type == DBRX:
992
        if FLASH_ATTENTION:
993
994
995
996
            return FlashCausalLM(
                model_id=model_id,
                model_class=FlashDbrxForCausalLM,
                revision=revision,
997
                quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
998
                speculator=speculator,
999
                dtype=dtype,
1000
                kv_cache_dtype=kv_cache_dtype,
1001
1002
                # Dbrx works better in bfloat16.
                default_dtype=torch.bfloat16,
1003
                trust_remote_code=trust_remote_code,
1004
1005
                lora_adapter_ids=lora_adapter_ids,
                config_class=DbrxConfig,
1006
1007
1008
1009
            )
        elif sharded:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded DBRX"))
        else:
1010
            return CausalLM.fallback(
1011
1012
1013
                model_id,
                revision,
                quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
1014
                speculator=speculator,
1015
1016
1017
1018
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )

1019
    if model_type in ["RefinedWeb", "RefinedWebModel", FALCON]:
1020
1021
        if sharded:
            if FLASH_ATTENTION:
1022
                if config_dict.get("alibi", False):
1023
                    raise NotImplementedError("sharded is not supported for this model")
1024
1025
1026
1027
                return FlashCausalLM(
                    model_id=model_id,
                    model_class=FlashRWForCausalLM,
                    revision=revision,
1028
                    quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
1029
                    speculator=speculator,
1030
                    dtype=dtype,
1031
                    kv_cache_dtype=kv_cache_dtype,
1032
1033
1034
1035
                    aliases={
                        "lm_head.weight": ["transformer.word_embeddings.weight"],
                        "transformer.word_embeddings.weight": ["lm_head.weight"],
                    },
1036
                    trust_remote_code=trust_remote_code,
1037
1038
                    lora_adapter_ids=lora_adapter_ids,
                    config_class=RWConfig,
1039
                )
1040
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Falcon"))
1041
        else:
1042
            if FLASH_ATTENTION and not config_dict.get("alibi", False):
1043
1044
1045
1046
                return FlashCausalLM(
                    model_id=model_id,
                    model_class=FlashRWForCausalLM,
                    revision=revision,
1047
                    quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
1048
                    speculator=speculator,
1049
                    dtype=dtype,
1050
                    kv_cache_dtype=kv_cache_dtype,
1051
1052
1053
1054
                    aliases={
                        "lm_head.weight": ["transformer.word_embeddings.weight"],
                        "transformer.word_embeddings.weight": ["lm_head.weight"],
                    },
1055
                    trust_remote_code=trust_remote_code,
1056
1057
                    lora_adapter_ids=lora_adapter_ids,
                    config_class=RWConfig,
1058
1059
                )
            else:
1060
                return CausalLM.fallback(
1061
1062
1063
                    model_id,
                    revision,
                    quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
1064
                    speculator=speculator,
1065
                    dtype=dtype,
1066
1067
1068
                    trust_remote_code=trust_remote_code,
                )

1069
    if model_type == MISTRAL:
1070
        if FLASH_ATTENTION:
1071
1072
1073
1074
            return FlashCausalLM(
                model_id=model_id,
                model_class=FlashMistralForCausalLM,
                revision=revision,
1075
                quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
1076
                speculator=speculator,
1077
                dtype=dtype,
1078
                kv_cache_dtype=kv_cache_dtype,
1079
                trust_remote_code=trust_remote_code,
1080
                lora_adapter_ids=lora_adapter_ids,
1081
            )
OlivierDehaene's avatar
OlivierDehaene committed
1082
1083
1084
        elif sharded:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Mistral"))
        else:
1085
            return CausalLM.fallback(
OlivierDehaene's avatar
OlivierDehaene committed
1086
1087
1088
                model_id,
                revision,
                quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
1089
                speculator=speculator,
OlivierDehaene's avatar
OlivierDehaene committed
1090
1091
1092
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )
OlivierDehaene's avatar
OlivierDehaene committed
1093

1094
    if model_type == MIXTRAL:
1095
        if FLASH_ATTENTION:
1096
1097
1098
1099
            return FlashCausalLM(
                model_id=model_id,
                model_class=FlashMixtralForCausalLM,
                revision=revision,
OlivierDehaene's avatar
OlivierDehaene committed
1100
                quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
1101
                speculator=speculator,
OlivierDehaene's avatar
OlivierDehaene committed
1102
                dtype=dtype,
1103
                kv_cache_dtype=kv_cache_dtype,
OlivierDehaene's avatar
OlivierDehaene committed
1104
                trust_remote_code=trust_remote_code,
1105
                lora_adapter_ids=lora_adapter_ids,
OlivierDehaene's avatar
OlivierDehaene committed
1106
            )
OlivierDehaene's avatar
OlivierDehaene committed
1107
1108
1109
        elif sharded:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Mixtral"))
        else:
1110
            return CausalLM.fallback(
OlivierDehaene's avatar
OlivierDehaene committed
1111
1112
1113
                model_id,
                revision,
                quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
1114
                speculator=speculator,
OlivierDehaene's avatar
OlivierDehaene committed
1115
1116
1117
1118
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )

1119
    if model_type == STARCODER2:
1120
        if FLASH_ATTENTION:
1121
1122
1123
1124
            return FlashCausalLM(
                model_id=model_id,
                model_class=FlashStarcoder2ForCausalLM,
                revision=revision,
OlivierDehaene's avatar
OlivierDehaene committed
1125
                quantize=quantize,
1126
                speculator=speculator,
OlivierDehaene's avatar
OlivierDehaene committed
1127
                dtype=dtype,
1128
                kv_cache_dtype=kv_cache_dtype,
OlivierDehaene's avatar
OlivierDehaene committed
1129
                trust_remote_code=trust_remote_code,
1130
                lora_adapter_ids=lora_adapter_ids,
OlivierDehaene's avatar
OlivierDehaene committed
1131
1132
1133
1134
1135
1136
            )
        elif sharded:
            raise NotImplementedError(
                FLASH_ATT_ERROR_MESSAGE.format("Sharded Starcoder2")
            )
        else:
1137
            return CausalLM.fallback(
OlivierDehaene's avatar
OlivierDehaene committed
1138
1139
1140
                model_id,
                revision,
                quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
1141
                speculator=speculator,
OlivierDehaene's avatar
OlivierDehaene committed
1142
1143
1144
1145
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )

1146
    if model_type == QWEN2:
1147
        if FLASH_ATTENTION:
1148
1149
1150
1151
            return FlashCausalLM(
                model_id=model_id,
                model_class=Qwen2ForCausalLM,
                revision=revision,
OlivierDehaene's avatar
OlivierDehaene committed
1152
                quantize=quantize,
1153
                speculator=speculator,
OlivierDehaene's avatar
OlivierDehaene committed
1154
                dtype=dtype,
1155
                kv_cache_dtype=kv_cache_dtype,
OlivierDehaene's avatar
OlivierDehaene committed
1156
                trust_remote_code=trust_remote_code,
1157
                lora_adapter_ids=lora_adapter_ids,
OlivierDehaene's avatar
OlivierDehaene committed
1158
1159
1160
1161
            )
        elif sharded:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Qwen2"))
        else:
1162
            return CausalLM.fallback(
OlivierDehaene's avatar
OlivierDehaene committed
1163
1164
1165
                model_id,
                revision,
                quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
1166
                speculator=speculator,
OlivierDehaene's avatar
OlivierDehaene committed
1167
                dtype=dtype,
OlivierDehaene's avatar
OlivierDehaene committed
1168
1169
                trust_remote_code=trust_remote_code,
            )
1170

1171
    if model_type == OPT:
1172
1173
1174
1175
        return CausalLM(
            model_id=model_id,
            model_class=OPTForCausalLM,
            revision=revision,
1176
            quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
1177
            speculator=speculator,
1178
1179
            dtype=dtype,
            trust_remote_code=trust_remote_code,
1180
        )
1181

1182
    if model_type == T5:
1183
1184
1185
1186
        return Seq2SeqLM(
            model_id=model_id,
            model_class=T5ForConditionalGeneration,
            revision=revision,
1187
            quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
1188
            speculator=speculator,
1189
            dtype=dtype,
1190
            trust_remote_code=trust_remote_code,
1191
1192
1193
1194
1195
1196
            aliases={
                "shared.weight": [
                    "encoder.embed_tokens.weight",
                    "decoder.embed_tokens.weight",
                ]
            },
1197
        )
1198
    if model_type == IDEFICS:
1199
        if FLASH_ATTENTION:
Nicolas Patry's avatar
Nicolas Patry committed
1200
            return IdeficsCausalLM(
OlivierDehaene's avatar
OlivierDehaene committed
1201
1202
1203
                model_id,
                revision,
                quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
1204
                speculator=speculator,
OlivierDehaene's avatar
OlivierDehaene committed
1205
1206
1207
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )
1208
1209
        else:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Idefics"))
drbh's avatar
drbh committed
1210
1211
1212
1213
1214
1215
1216
1217
1218
1219
1220
1221
    if model_type == QWEN2_VL:
        return VlmCausalLM(
            model_id=model_id,
            model_class=Qwen2VLForConditionalGeneration,
            revision=revision,
            quantize=quantize,
            speculator=speculator,
            dtype=dtype,
            kv_cache_dtype=kv_cache_dtype,
            trust_remote_code=trust_remote_code,
            lora_adapter_ids=lora_adapter_ids,
        )
Nicolas Patry's avatar
Nicolas Patry committed
1222
1223
1224
1225
1226
1227
1228
1229
1230
1231
1232
1233
1234
1235
1236
1237
    if model_type == MLLAMA:
        if FLASH_ATTENTION:
            return MllamaCausalLM(
                model_id=model_id,
                model_class=MllamaForConditionalGeneration,
                batch_class=MllamaCausalLMBatch,
                revision=revision,
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
                default_dtype=torch.bfloat16,
                trust_remote_code=trust_remote_code,
                lora_adapter_ids=lora_adapter_ids,
            )
        else:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Mllama"))
1238
    if model_type == IDEFICS2:
Nicolas Patry's avatar
Nicolas Patry committed
1239
        if FLASH_ATTENTION:
1240
1241
1242
1243
            return VlmCausalLM(
                model_id=model_id,
                model_class=Idefics2ForConditionalGeneration,
                revision=revision,
Nicolas Patry's avatar
Nicolas Patry committed
1244
                quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
1245
                speculator=speculator,
Nicolas Patry's avatar
Nicolas Patry committed
1246
                dtype=dtype,
1247
                kv_cache_dtype=kv_cache_dtype,
Nicolas Patry's avatar
Nicolas Patry committed
1248
                trust_remote_code=trust_remote_code,
1249
1250
1251
1252
                lora_adapter_ids=lora_adapter_ids,
                # XXX: Extremely important to cap resolution in order to limit
                # VRAM usage.
                processor_kwargs={"size": {"longest_edge": 448, "shortest_edge": 378}},
Nicolas Patry's avatar
Nicolas Patry committed
1253
1254
1255
            )
        else:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Idefics"))
1256
    if model_type == PALIGEMMA:
drbh's avatar
drbh committed
1257
        if FLASH_ATTENTION:
1258
1259
1260
1261
            return VlmCausalLM(
                model_id=model_id,
                model_class=PaliGemmaForConditionalGeneration,
                revision=revision,
drbh's avatar
drbh committed
1262
1263
1264
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
1265
                kv_cache_dtype=kv_cache_dtype,
1266
1267
                # Works better for these models
                default_dtype=torch.bfloat16,
drbh's avatar
drbh committed
1268
                trust_remote_code=trust_remote_code,
1269
1270
                lora_adapter_ids=lora_adapter_ids,
                batch_class=PaliGemmaBatch,
drbh's avatar
drbh committed
1271
1272
1273
            )
        else:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Idefics"))
1274

1275
    if model_type == LLAVA_NEXT:
1276
        if FLASH_ATTENTION:
1277
1278
1279
1280
            return VlmCausalLM(
                model_class=LlavaNextForConditionalGeneration,
                model_id=model_id,
                revision=revision,
1281
                quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
1282
                speculator=speculator,
1283
                dtype=dtype,
1284
                kv_cache_dtype=kv_cache_dtype,
1285
1286
1287
1288
1289
                trust_remote_code=trust_remote_code,
            )
        else:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("LlavaNext"))

1290
    if sharded:
1291
        raise NotImplementedError("sharded is not supported for AutoModel")
1292
    if quantize == "gptq":
1293
        raise NotImplementedError(
1294
1295
            "gptq quantization is not supported for AutoModel, you can try to quantize it with `text-generation-server quantize ORIGINAL_MODEL_ID NEW_MODEL_ID`"
        )
1296
    if quantize == "awq":
1297
        raise NotImplementedError("awq quantization is not supported for AutoModel")
Nicolas Patry's avatar
Nicolas Patry committed
1298
    elif (quantize == "bitsandbytes-fp4") or (quantize == "bitsandbytes-nf4"):
1299
        raise NotImplementedError("4bit quantization is not supported for AutoModel")
OlivierDehaene's avatar
OlivierDehaene committed
1300
    elif quantize == "eetq":
1301
        raise NotImplementedError("Eetq quantization is not supported for AutoModel")
1302
1303
    elif quantize == "exl2":
        raise NotImplementedError("exl2 quantization is not supported for AutoModel")
1304
    if model_type in modeling_auto.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES:
1305
        return CausalLM.fallback(
1306
1307
1308
            model_id,
            revision,
            quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
1309
            speculator=speculator,
1310
1311
            dtype=dtype,
            trust_remote_code=trust_remote_code,
1312
        )
1313
    if model_type in modeling_auto.MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES:
1314
        return Seq2SeqLM.fallback(
1315
1316
1317
            model_id,
            revision,
            quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
1318
            speculator=speculator,
1319
1320
            dtype=dtype,
            trust_remote_code=trust_remote_code,
1321
1322
        )

1323
    auto_map = config_dict.get("auto_map", None)
1324
1325
    if trust_remote_code and auto_map is not None:
        if "AutoModelForCausalLM" in auto_map.keys():
1326
            return CausalLM.fallback(
1327
1328
1329
                model_id,
                revision,
                quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
1330
                speculator=speculator,
1331
                dtype=dtype,
1332
1333
                trust_remote_code=trust_remote_code,
            )
1334
        if "AutoModelForSeq2SeqLM" in auto_map.keys():
1335
            return Seq2SeqLM.fallback(
1336
1337
1338
                model_id,
                revision,
                quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
1339
                speculator=speculator,
1340
                dtype=dtype,
1341
1342
                trust_remote_code=trust_remote_code,
            )
1343
1344

    raise ValueError(f"Unsupported model type {model_type}")
1345
1346
1347
1348
1349
1350
1351
1352
1353
1354
1355
1356


# get_model_with_lora_adapters wraps the internal get_model function and adds support for loading LoRA adapters.
# It acts as a post-load hook: the adapters are loaded into the model after the base model itself has been loaded.
def get_model_with_lora_adapters(
    model_id: str,
    lora_adapters: Optional[List[AdapterInfo]],
    revision: Optional[str],
    sharded: bool,
    quantize: Optional[str],
    speculate: Optional[int],
    dtype: Optional[str],
1357
    kv_cache_dtype: Optional[str],
1358
1359
1360
1361
1362
1363
1364
1365
1366
1367
1368
1369
1370
    trust_remote_code: bool,
    max_input_tokens: int,
    adapter_to_index: Dict[str, int],
):
    lora_adapter_ids = [adapter.id for adapter in lora_adapters]
    model = get_model(
        model_id,
        lora_adapter_ids,
        revision,
        sharded,
        quantize,
        speculate,
        dtype,
1371
        kv_cache_dtype,
1372
1373
1374
1375
1376
1377
1378
1379
1380
1381
1382
1383
1384
1385
1386
1387
1388
1389
1390
1391
1392
1393
1394
1395
1396
1397
1398
1399
1400
1401
1402
1403
1404
1405
1406
1407
1408
1409
1410
1411
1412
1413
1414
1415
1416
1417
1418
1419
1420
1421
1422
1423
        trust_remote_code,
        max_input_tokens,
    )

    if len(lora_adapters) > 0:
        target_to_layer = build_layer_weight_lookup(model.model)

        for index, adapter in enumerate(lora_adapters):
            # The AdapterParameters object allows for merging multiple adapters into a single adapter.
            # At the moment, we only support loading a single adapter into the model, but we keep the
            # AdapterParameters object for easier extension in the future.
            adapter_parameters = AdapterParameters(
                adapter_info=[adapter],
                # when merging multiple adapters we can weight them differently
                # if this is not set, all adapters will be weighted equally
                # see: text_generation_server.utils.merges.strategies for impl
                weights=None,
                merge_strategy=0,
                density=1.0,
                majority_sign_method=0,
            )

            adapter_index = index + 1
            adapter_to_index[adapter.id] = adapter_index

            logger.info(
                f"Loading adapter weights into model: {','.join([adapter.id for adapter in adapter_parameters.adapter_info])}"
            )
            weight_names = tuple([v[0] for v in target_to_layer.values()])
            (
                module_map,
                adapter_config,
                adapter_weight_names,
                adapter_tokenizer,
            ) = load_and_merge_adapters(
                model.model_id,
                adapter_parameters,
                adapter_index,
                weight_names,
                False,
            )

            unused_weight_names = adapter_weight_names.copy()

            adapter_layers = [
                "q_proj",
                "k_proj",
                "v_proj",
                "o_proj",
                "gate_proj",
                "up_proj",
                "down_proj",
1424
                "qkv_proj",
1425
1426
1427
1428
1429
1430
1431
1432
1433
1434
1435
1436
1437
1438
1439
1440
1441
1442
1443
1444
1445
1446
1447
1448
1449
1450
1451
            ]

            for layer_name in adapter_layers:
                nlayers = (
                    1 if layer_name == "lm_head" else len(model.model.model.layers)
                )
                adapter_weights = LoraWeights.prepare_weights(
                    config=adapter_config,
                    module_map=module_map,
                    layer_type=layer_name,
                    unused_weight_names=unused_weight_names,
                    nlayers=nlayers,
                    dtype=model.dtype,
                    world_size=model.world_size,
                    process_group=model.process_group,
                    target_to_layer=target_to_layer,
                )

                if adapter_weights is None:
                    continue

                model.layer_to_adapter_weights[layer_name].add_adapter(
                    adapter_index, adapter_weights
                )

            if len(unused_weight_names) > 0:
                logger.warning(
1452
                    f"{','.join([a.id for a in lora_adapters])} unused adapter weights: {unused_weight_names}"
1453
1454
1455
1456
1457
1458
1459
1460
                )

            if adapter_tokenizer is not None:
                model.tokenizers.add_tokenizer(adapter_index, adapter_tokenizer)

            model.loaded_adapters.add(adapter_index)

    return model