# ruff: noqa: F821
# the above line disables the `undefined-name` rule for the model type variables

import torch
import enum
import os

from loguru import logger
from transformers.configuration_utils import PretrainedConfig
from transformers.models.auto import modeling_auto
from huggingface_hub import hf_hub_download, HfApi
from typing import Optional, List, Dict
from pathlib import Path

from text_generation_server.utils.speculate import get_speculate, set_speculate
from text_generation_server.models.model import Model
from text_generation_server.models.causal_lm import CausalLM, CausalLMBatchKeysLast
from text_generation_server.models.custom_modeling.opt_modeling import OPTForCausalLM
from text_generation_server.models.custom_modeling.mpt_modeling import (
    MPTForCausalLM,
)
from text_generation_server.models.bloom import BloomCausalLMBatch
from text_generation_server.models.custom_modeling.bloom_modeling import (
    BloomForCausalLM,
)
from text_generation_server.models.seq2seq_lm import Seq2SeqLM
from text_generation_server.models.galactica import GalacticaCausalLMBatch
from text_generation_server.models.custom_modeling.neox_modeling import (
    GPTNeoxForCausalLM,
)
from text_generation_server.models.custom_modeling.phi_modeling import (
    PhiConfig,
    PhiForCausalLM,
)
from text_generation_server.models.custom_modeling.flash_phi_moe_modeling import (
    PhiMoEConfig,
)
from text_generation_server.models.custom_modeling.t5_modeling import (
    T5ForConditionalGeneration,
)


from text_generation_server.utils.adapter import (
    AdapterParameters,
    build_layer_weight_lookup,
    load_and_merge_adapters,
    AdapterInfo,
)
from text_generation_server.adapters.lora import LoraWeights


from text_generation_server.utils.import_utils import SYSTEM
from text_generation_server.utils.log import log_master

# The flag below controls whether to allow TF32 on matmul. This flag defaults to False
# in PyTorch 1.12 and later.
torch.backends.cuda.matmul.allow_tf32 = True

# The flag below controls whether to allow TF32 on cuDNN. This flag defaults to True.
torch.backends.cudnn.allow_tf32 = True

# Disable gradients
torch.set_grad_enabled(False)

__all__ = [
    "Model",
    "CausalLM",
    "Seq2SeqLM",
    "get_model_with_lora_adapters",
]

FLASH_ATT_ERROR_MESSAGE = "{} requires Flash Attention enabled models."

FLASH_ATTENTION = True

try:
    from text_generation_server.models.flash_causal_lm import FlashCausalLM
    from text_generation_server.models.vlm_causal_lm import VlmCausalLM
    from text_generation_server.models.mllama_causal_lm import MllamaCausalLM
    from text_generation_server.models.custom_modeling.flash_deepseek_v2_modeling import (
        FlashDeepseekV2ForCausalLM,
        DeepseekV2Config,
    )
    from text_generation_server.models.custom_modeling.flash_llama_modeling import (
        FlashLlamaForCausalLM,
    )
    from text_generation_server.models.custom_modeling.flash_cohere_modeling import (
        FlashCohereForCausalLM,
    )
    from text_generation_server.models.custom_modeling.flash_gemma_modeling import (
        FlashGemmaForCausalLM,
    )
    from text_generation_server.models.custom_modeling.flash_gemma2_modeling import (
        FlashGemma2ForCausalLM,
    )
    from text_generation_server.models.custom_modeling.flash_dbrx_modeling import (
        FlashDbrxForCausalLM,
        DbrxConfig,
    )
    from text_generation_server.models.custom_modeling.flash_rw_modeling import (
        RWConfig,
        FlashRWForCausalLM,
    )
    from text_generation_server.models.custom_modeling.flash_neox_modeling import (
        FlashGPTNeoXForCausalLM,
    )
    from text_generation_server.models.pali_gemma import (
        PaliGemmaBatch,
    )
    from text_generation_server.models.custom_modeling.flash_pali_gemma_modeling import (
        PaliGemmaForConditionalGeneration,
    )
    from text_generation_server.models.custom_modeling.flash_phi_modeling import (
        FlashPhiForCausalLM,
    )
    from text_generation_server.models.idefics_causal_lm import IdeficsCausalLM
    from text_generation_server.models.mllama_causal_lm import MllamaCausalLMBatch
    from text_generation_server.models.custom_modeling.mllama import (
        MllamaForConditionalGeneration,
    )
    from text_generation_server.models.custom_modeling.llava_next import (
        LlavaNextForConditionalGeneration,
    )

    from text_generation_server.models.custom_modeling.flash_santacoder_modeling import (
        FlashSantacoderForCausalLM,
    )
    from text_generation_server.models.custom_modeling.flash_starcoder2_modeling import (
        FlashStarcoder2ForCausalLM,
    )
    from text_generation_server.models.custom_modeling.flash_qwen2_modeling import (
        Qwen2ForCausalLM,
    )
    from text_generation_server.models.custom_modeling.flash_mistral_modeling import (
        FlashMistralForCausalLM,
    )
    from text_generation_server.models.custom_modeling.flash_mixtral_modeling import (
        FlashMixtralForCausalLM,
    )
    from text_generation_server.models.custom_modeling.flash_gpt2_modeling import (
        FlashGPT2ForCausalLM,
    )
    from text_generation_server.models.custom_modeling.flash_gptj_modeling import (
        FlashGPTJForCausalLM,
    )
    from text_generation_server.models.custom_modeling.idefics2 import (
        Idefics2ForConditionalGeneration,
    )
    from text_generation_server.models.custom_modeling.qwen2_vl import (
        Qwen2VLForConditionalGeneration,
    )
    from text_generation_server.layers.attention import SUPPORTS_WINDOWING
except ImportError as e:
    log_master(logger.warning, f"Could not import Flash Attention enabled models: {e}")
    SUPPORTS_WINDOWING = False
    FLASH_ATTENTION = False

if FLASH_ATTENTION:
    __all__.append(FlashCausalLM)
    __all__.append(IdeficsCausalLM)
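    # These classes are only exported when the flash attention imports above succeeded.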

MAMBA_AVAILABLE = True
try:
    from text_generation_server.models.mamba import Mamba
except ImportError as e:
    log_master(logger.warning, f"Could not import Mamba: {e}")
    MAMBA_AVAILABLE = False

if MAMBA_AVAILABLE:
    __all__.append(Mamba)


class ModelType(enum.Enum):
    DEEPSEEK_V2 = {
        "type": "deepseek_v2",
        "name": "Deepseek V2",
        "url": "https://huggingface.co/deepseek-ai/DeepSeek-V2",
    }
    IDEFICS2 = {
        "type": "idefics2",
        "name": "Idefics 2",
        "url": "https://huggingface.co/HuggingFaceM4/idefics2-8b",
        "multimodal": True,
    }
    LLAVA_NEXT = {
        "type": "llava_next",
        "name": "Llava Next (1.6)",
        "url": "https://huggingface.co/llava-hf/llava-v1.6-vicuna-13b-hf",
        "multimodal": True,
    }
    LLAMA = {
        "type": "llama",
        "name": "Llama",
        "url": "https://huggingface.co/collections/meta-llama/llama-31-669fc079a0c406a149a5738f",
    }
    PHI3 = {
        "type": "phi3",
        "name": "Phi 3",
        "url": "https://huggingface.co/microsoft/Phi-3-mini-4k-instruct",
    }
    GRANITE = {
        "type": "granite",
        "name": "Granite",
        "url": "https://huggingface.co/ibm-granite/granite-3.0-8b-instruct",
    }
    GEMMA = {
        "type": "gemma",
        "name": "Gemma",
        "url": "https://huggingface.co/google/gemma-7b",
    }
    PALIGEMMA = {
        "type": "paligemma",
        "name": "PaliGemma",
        "url": "https://huggingface.co/google/paligemma-3b-pt-224",
    }
    GEMMA2 = {
        "type": "gemma2",
        "name": "Gemma2",
        "url": "https://huggingface.co/collections/google/gemma-2-release-667d6600fd5220e7b967f315",
    }
    COHERE = {
        "type": "cohere",
        "name": "Cohere",
        "url": "https://huggingface.co/CohereForAI/c4ai-command-r-plus",
    }
    DBRX = {
        "type": "dbrx",
        "name": "Dbrx",
        "url": "https://huggingface.co/databricks/dbrx-instruct",
    }
    MAMBA = {
        "type": "mamba",
        "name": "Mamba",
        "url": "https://huggingface.co/state-spaces/mamba-2.8b-slimpj",
    }
    MISTRAL = {
        "type": "mistral",
        "name": "Mistral",
        "url": "https://huggingface.co/mistralai/Mistral-Nemo-Instruct-2407",
    }
    MIXTRAL = {
        "type": "mixtral",
        "name": "Mixtral",
        "url": "https://huggingface.co/mistralai/Mixtral-8x22B-Instruct-v0.1",
    }
    GPT_BIGCODE = {
        "type": "gpt_bigcode",
        "name": "Gpt Bigcode",
        "url": "https://huggingface.co/bigcode/gpt_bigcode-santacoder",
    }
    PHI = {
        "type": "phi",
        "name": "Phi",
        "url": "https://huggingface.co/microsoft/phi-1_5",
    }
    PHI_MOE = {
        "type": "phimoe",
        "name": "PhiMoe",
        "url": "https://huggingface.co/microsoft/Phi-3.5-MoE-instruct",
    }
    BAICHUAN = {
        "type": "baichuan",
        "name": "Baichuan",
        "url": "https://huggingface.co/baichuan-inc/Baichuan2-7B-Chat",
    }
    FALCON = {
        "type": "falcon",
        "name": "Falcon",
        "url": "https://huggingface.co/tiiuae/falcon-7b-instruct",
    }
    STARCODER2 = {
        "type": "starcoder2",
        "name": "StarCoder 2",
        "url": "https://huggingface.co/bigcode/starcoder2-15b-instruct-v0.1",
    }
    QWEN2 = {
        "type": "qwen2",
        "name": "Qwen 2",
        "url": "https://huggingface.co/collections/Qwen/qwen2-6659360b33528ced941e557f",
    }
    QWEN2_VL = {
        "type": "qwen2_vl",
        "name": "Qwen 2 VL",
        "url": "https://huggingface.co/collections/Qwen/qwen2-vl-66cee7455501d7126940800d",
    }
    OPT = {
        "type": "opt",
        "name": "Opt",
        "url": "https://huggingface.co/facebook/opt-6.7b",
    }
    T5 = {
        "type": "t5",
        "name": "T5",
        "url": "https://huggingface.co/google/flan-t5-xxl",
    }
    GALACTICA = {
        "type": "galactica",
        "name": "Galactica",
        "url": "https://huggingface.co/facebook/galactica-120b",
    }
    SANTACODER = {
        "type": "santacoder",
        "name": "SantaCoder",
        "url": "https://huggingface.co/bigcode/santacoder",
    }
    BLOOM = {
        "type": "bloom",
        "name": "Bloom",
        "url": "https://huggingface.co/bigscience/bloom-560m",
    }
    MPT = {
        "type": "mpt",
        "name": "Mpt",
        "url": "https://huggingface.co/mosaicml/mpt-7b-instruct",
    }
    GPT2 = {
        "type": "gpt2",
        "name": "Gpt2",
        "url": "https://huggingface.co/openai-community/gpt2",
    }
    GPT_NEOX = {
        "type": "gpt_neox",
        "name": "Gpt Neox",
        "url": "https://huggingface.co/EleutherAI/gpt-neox-20b",
    }
    GPTJ = {
        "type": "gptj",
        "name": "Gptj",
        "url": "https://huggingface.co/EleutherAI/gpt-j-6b",
    }
    IDEFICS = {
        "type": "idefics",
        "name": "Idefics",
        "url": "https://huggingface.co/HuggingFaceM4/idefics-9b",
        "multimodal": True,
    }
    MLLAMA = {
        "type": "mllama",
        "name": "Mllama",
        "url": "https://huggingface.co/meta-llama/Llama-3.2-11B-Vision-Instruct",
        "multimodal": True,
    }


__GLOBALS = locals()
for data in ModelType:
    __GLOBALS[data.name] = data.value["type"]
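# After this loop, each ModelType member's name (e.g. LLAMA, GEMMA2) resolves to its
# "type" string at module level; these are the "model type variables" that the
# `ruff: noqa: F821` directive at the top of the file refers to.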


def get_model(
    model_id: str,
    lora_adapter_ids: Optional[List[str]],
    revision: Optional[str],
    sharded: bool,
    quantize: Optional[str],
    speculate: Optional[int],
    dtype: Optional[str],
    kv_cache_dtype: Optional[str],
    trust_remote_code: bool,
    max_input_tokens: int,
) -> Model:
    global FLASH_ATTENTION

    config_dict, _ = PretrainedConfig.get_config_dict(
        model_id, revision=revision, trust_remote_code=trust_remote_code
    )
    model_type = config_dict.get("model_type", None)
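    # When `quantize` is not set explicitly, the checks below try to pick a quantization
    # method from the checkpoint's `quantization_config` / `compression_config`.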

    quantization_config = config_dict.get("quantization_config", None)
    compression_config = config_dict.get("compression_config", None)
    if quantization_config is not None and quantize is None:
        method = quantization_config.get("quant_method", None)
        config_groups = quantization_config.get("config_groups", None)
        if method in {"gptq", "awq", "exl2"}:
            log_master(logger.info, f"Auto selecting quantization method {method}")
            quantize = method
        elif method == "fbgemm_fp8" or method == "fp8":
            log_master(logger.info, "Auto selecting quantization method fp8")
            quantize = "fp8"
        elif config_groups is not None:
            # TODO: at some point we should probably fully parse the compression
            # configuration to know which parameters are compressed.
            for _, group in config_groups.items():
                weights_config = group.get("weights")
                if weights_config is not None:
                    if (
                        weights_config["type"] == "float"
                        and weights_config["num_bits"] == 8
                    ):
                        log_master(
                            logger.info, "Auto selecting quantization method fp8"
                        )
                        quantize = "fp8"
                        break
        else:
            log_master(logger.warning, f"Unknown quantization method {method}")
    elif compression_config is not None:
        # `compression_config` renamed to `quantization_config`; support retained for backward compatibility.
        config_groups = compression_config.get("config_groups")
        if config_groups is not None:
            for _, group in config_groups.items():
                weights_config = group.get("weights")
                if weights_config is not None:
                    if (
                        weights_config["type"] == "float"
                        and weights_config["num_bits"] == 8
                    ):
                        log_master(
                            logger.info, "Auto selecting quantization method fp8"
                        )
                        quantize = "fp8"
                        break

    if dtype is None:
        if quantize in ["awq", "exl2", "gptq", "marlin"]:
            if SYSTEM == "ipex" and not hasattr(torch, "xpu"):
                dtype = torch.bfloat16
            else:
                # These quantizers only work with float16 params.
                dtype = torch.float16
        else:
            # Keep it as default for now and let
            # every model resolve their own default dtype.
            dtype = None
    elif dtype == "float16":
        dtype = torch.float16
    elif dtype == "bfloat16":
        dtype = torch.bfloat16
    else:
        raise RuntimeError(f"Unknown dtype {dtype}")

    if kv_cache_dtype is None:
        kv_cache_dtype = dtype
    elif kv_cache_dtype == "fp8_e4m3fn":
        kv_cache_dtype = torch.float8_e4m3fn
    elif kv_cache_dtype == "fp8_e5m2":
        kv_cache_dtype = torch.float8_e5m2
    else:
        raise RuntimeError(f"Unknown kv_cache_dtype: {kv_cache_dtype}")

    if speculate is not None:
        set_speculate(speculate)
    else:
        set_speculate(0)
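    # Default to no speculation; the Medusa / MLP-speculator branches below may override this.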

    speculator = None
    if "medusa_num_heads" in config_dict:
        medusa_model_id = model_id
        medusa_revision = revision
        model_id = config_dict["base_model_name_or_path"]
        revision = "main"
        speculate_medusa = config_dict["medusa_num_heads"]
        if speculate is not None:
            if speculate > speculate_medusa:
                raise RuntimeError(
                    f"Speculate is set to `{speculate}` but this medusa model only has `{speculate_medusa}` heads, please make them match"
                )
            else:
                set_speculate(speculate)
        else:
            set_speculate(speculate_medusa)

        config_dict, _ = PretrainedConfig.get_config_dict(
            model_id, revision=revision, trust_remote_code=trust_remote_code
        )
        # Reload model type from parent.
        model_type = config_dict.get("model_type", None)
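        # From here on, model selection is driven by the base model's config rather than
        # the Medusa checkpoint's.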
        is_local = Path(medusa_model_id).exists()
        if not is_local:
            medusa_config = hf_hub_download(
                medusa_model_id, revision=medusa_revision, filename="config.json"
            )
            hf_hub_download(
                medusa_model_id,
                revision=medusa_revision,
                filename="medusa_lm_head.safetensors",
            )
            speculator = {
                "path": Path(medusa_config).parent,
                "model_paths": ["medusa_lm_head.safetensors"],
            }
        else:
            speculator = {
                "path": Path(medusa_model_id),
                "model_paths": ["medusa_lm_head.safetensors"],
            }

        method = "medusa"
    elif model_type == "mlp_speculator":
        mlp_model_id = model_id
        mlp_revision = revision
        model_id = config_dict["base_model_name_or_path"]
        revision = "main"
        speculate_mlp = config_dict["n_predict"]
        if speculate is not None:
            if speculate > speculate_mlp:
                raise RuntimeError(
                    f"Speculate is set to `{speculate}` but this mlp_speculator models only has `{speculate_mlp}` heads, please make them match"
                )
            else:
                set_speculate(speculate)
        else:
            set_speculate(speculate_mlp)

        config_dict, _ = PretrainedConfig.get_config_dict(
            model_id, revision=revision, trust_remote_code=trust_remote_code
        )
        # Reload model type from parent.
        model_type = config_dict.get("model_type", None)
        is_local = Path(mlp_model_id).exists()
        extension = ".safetensors"
        if not is_local:
            mlp_speculator_config = hf_hub_download(
                mlp_model_id, revision=mlp_revision, filename="config.json"
            )
            api = HfApi()
            info = api.model_info(mlp_model_id, revision=mlp_revision)
            filenames = [
                s.rfilename
                for s in info.siblings
                if s.rfilename.endswith(extension)
                and len(s.rfilename.split("/")) == 1
                and "arguments" not in s.rfilename
                and "args" not in s.rfilename
                and "training" not in s.rfilename
            ]
            for filename in filenames:
                hf_hub_download(
                    mlp_model_id,
                    revision=mlp_revision,
                    filename=filename,
                )
            speculator_dir_path = Path(mlp_speculator_config).parent
            # if these are downloaded, they get converted to safetensors
            filenames.extend(
                [p for p in os.listdir(speculator_dir_path) if p.endswith(extension)]
            )
            speculator = {
                "path": Path(mlp_speculator_config).parent,
                "model_paths": filenames,
            }
        else:
            speculator = Path(mlp_model_id)
            filenames = [p for p in os.listdir(speculator) if p.endswith(extension)]
            speculator = {"path": speculator, "model_paths": filenames}
        method = "mlp_speculator"
    else:
        method = "n-gram"

    speculate = get_speculate()
    if speculate > 0:
        log_master(
            logger.info, f"Using speculation {method} with {speculate} input ids."
        )

    if model_type is None:
        # TODO: fix how we determine model type for Mamba
        if "ssm_cfg" in config_dict:
            # *only happens in Mamba case
            model_type = "ssm"
        else:
            raise RuntimeError(
                f"Could not determine model type for {model_id} revision {revision}"
            )

    if quantize == "exl2" and sharded:
        raise RuntimeError(
            "Sharding is currently not supported with `exl2` quantization"
        )

    sliding_window = (
        config_dict.get("sliding_window")
        if config_dict.get("sliding_window") is not None
        else -1
    )
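    # Models that use sliding-window attention need backend support for windowing unless
    # the requested context already fits inside the window (checked below).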

    use_sliding_window = sliding_window is not None and sliding_window != -1
    needs_sliding_window = (
        max_input_tokens is not None and max_input_tokens > sliding_window
    )
    if use_sliding_window and needs_sliding_window and not SUPPORTS_WINDOWING:
        raise ValueError(
            f"The backend {SYSTEM} does not support sliding window attention that is used by the model type {model_type}. To use this model nonetheless with the {SYSTEM} backend, please launch TGI with the argument `--max-input-tokens` smaller than sliding_window={sliding_window} (got here max_input_tokens={max_input_tokens})."
        )

    if model_type == DEEPSEEK_V2:
        if FLASH_ATTENTION:
            head_size = max(
                config_dict.get("qk_nope_dim", 128)
                + config_dict.get("qk_rope_dim", 64),
                config_dict.get("v_head_dim", 128),
            )
            return FlashCausalLM(
                model_id=model_id,
                model_class=FlashDeepseekV2ForCausalLM,
                revision=revision,
                quantize=quantize,
                speculator=speculator,
                default_dtype=torch.bfloat16,
                dtype=dtype,
                kv_cache_dtype=kv_cache_dtype,
                trust_remote_code=trust_remote_code,
                lora_adapter_ids=lora_adapter_ids,
                config_class=DeepseekV2Config,
                head_size=head_size,
            )
        elif sharded:
            raise NotImplementedError(
                FLASH_ATT_ERROR_MESSAGE.format("Sharded Deepseek V2")
            )
        else:
            return CausalLM.fallback(
                model_id,
                revision,
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )
    elif model_type == MAMBA:
        return Mamba(
            model_id,
            revision,
            quantize=quantize,
            speculator=speculator,
            dtype=dtype,
            trust_remote_code=trust_remote_code,
        )
    elif model_type == "ssm":
        raise RuntimeError(
            "`ssm` models have been deprecated in favor of `mamba` models, which follow standard HF formats. Check out a list here: https://huggingface.co/models?search=mamba%20-hf"
        )

    if model_id.startswith("facebook/galactica"):
        return CausalLM(
            model_id=model_id,
            # Yes galactica is just an OPT model.
            model_class=OPTForCausalLM,
            revision=revision,
            quantize=quantize,
            speculator=speculator,
            dtype=dtype,
            trust_remote_code=trust_remote_code,
            batch_class=GalacticaCausalLMBatch,
        )

    if (
        model_type == GPT_BIGCODE
        or model_type == GPT2
        and model_id.startswith("bigcode/")
    ):
        if FLASH_ATTENTION:
            return FlashCausalLM(
                model_id=model_id,
                model_class=FlashSantacoderForCausalLM,
                revision=revision,
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
                kv_cache_dtype=kv_cache_dtype,
                trust_remote_code=trust_remote_code,
                lora_adapter_ids=lora_adapter_ids,
                aliases={"transformer.wte.weight": ["lm_head.weight"]},
                num_kv_heads=1,
            )
        elif sharded:
            raise NotImplementedError(
                FLASH_ATT_ERROR_MESSAGE.format("Sharded Santacoder")
            )
        else:
            return CausalLM.fallback(
                model_id=model_id,
                revision=revision,
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )

    if model_type == BLOOM:
        return CausalLM(
            model_id=model_id,
            model_class=BloomForCausalLM,
            revision=revision,
            quantize=quantize,
            speculator=speculator,
            dtype=dtype,
            trust_remote_code=trust_remote_code,
            batch_class=BloomCausalLMBatch,
        )
    elif model_type == MPT:
        return CausalLM(
            model_id=model_id,
            model_class=MPTForCausalLM,
            revision=revision,
            quantize=quantize,
            speculator=speculator,
            dtype=dtype,
            trust_remote_code=trust_remote_code,
            batch_class=CausalLMBatchKeysLast,
        )
    elif model_type == GPT2:
        if FLASH_ATTENTION:
            try:
                return FlashCausalLM(
                    model_id=model_id,
                    model_class=FlashGPT2ForCausalLM,
                    revision=revision,
                    quantize=quantize,
                    speculator=speculator,
                    dtype=dtype,
                    kv_cache_dtype=kv_cache_dtype,
                    trust_remote_code=trust_remote_code,
                    lora_adapter_ids=lora_adapter_ids,
                )
            except RuntimeError as e:
                # Lots of legacy models with various weight names.
                log_master(logger.warning, f"Couldn't load flash gpt2 variant: {e}")
                return CausalLM.fallback(
                    model_id,
                    revision,
                    quantize=quantize,
                    speculator=speculator,
                    dtype=dtype,
                    trust_remote_code=trust_remote_code,
                )
        elif sharded:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded GPT-2"))
        else:
            return CausalLM.fallback(
                model_id,
                revision,
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )
    elif model_type == GPTJ:
        if FLASH_ATTENTION:
            try:
                return FlashCausalLM(
                    model_id=model_id,
                    model_class=FlashGPTJForCausalLM,
                    revision=revision,
                    quantize=quantize,
                    speculator=speculator,
                    dtype=dtype,
                    kv_cache_dtype=kv_cache_dtype,
                    trust_remote_code=trust_remote_code,
                    lora_adapter_ids=lora_adapter_ids,
                )
            except RuntimeError as e:
                # Lots of legacy models with various weight names.
                log_master(logger.warning, f"Couldn't load flash gptj variant: {e}")
                return CausalLM.fallback(
                    model_id,
                    revision,
                    quantize=quantize,
                    speculator=speculator,
                    dtype=dtype,
                    trust_remote_code=trust_remote_code,
                )
        elif sharded:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded GPT-J"))
        else:
            return CausalLM.fallback(
                model_id,
                revision,
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )
    elif model_type == GPT_NEOX:
        if FLASH_ATTENTION:
            from text_generation_server.models.custom_modeling.flash_neox_modeling import (
                GPTNeoXConfig,
            )

            return FlashCausalLM(
                model_id=model_id,
                model_class=FlashGPTNeoXForCausalLM,
                revision=revision,
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
                kv_cache_dtype=kv_cache_dtype,
                trust_remote_code=trust_remote_code,
                lora_adapter_ids=lora_adapter_ids,
                config_class=GPTNeoXConfig,
            )
        elif sharded:
            return CausalLM(
                model_id=model_id,
                model_class=GPTNeoxForCausalLM,
                revision=revision,
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )
        else:
            return CausalLM.fallback(
                model_id,
                revision,
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )

    elif model_type == PHI:
        if FLASH_ATTENTION:
            return FlashCausalLM(
                model_id=model_id,
                model_class=FlashPhiForCausalLM,
                revision=revision,
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
                kv_cache_dtype=kv_cache_dtype,
                trust_remote_code=trust_remote_code,
                lora_adapter_ids=lora_adapter_ids,
            )
        else:
            return CausalLM.fallback(
                model_id,
                revision,
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )

    elif model_type == PHI_MOE:
        if FLASH_ATTENTION:
            return FlashCausalLM(
                model_id=model_id,
                model_class=FlashLlamaForCausalLM,
                config_class=PhiMoEConfig,
                revision=revision,
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
                kv_cache_dtype=kv_cache_dtype,
drbh's avatar
drbh committed
846
847
848
849
850
851
852
853
854
855
856
857
858
                trust_remote_code=trust_remote_code,
                lora_adapter_ids=lora_adapter_ids,
            )
        else:
            return CausalLM.fallback(
                model_id,
                revision,
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )

    elif model_type == "phi-msft":
        if FLASH_ATTENTION:
            raise NotImplementedError(
                "Legacy phi-msft is not supported with Flash Attention"
            )
        else:
            return CausalLM(
                model_id=model_id,
                model_class=PhiForCausalLM,
                config_class=PhiConfig,
                revision=revision,
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )

    elif (
        model_type == LLAMA
        or model_type == BAICHUAN
        or model_type == PHI3
        or model_type == GRANITE
    ):
        if FLASH_ATTENTION:
            return FlashCausalLM(
                model_id=model_id,
                model_class=FlashLlamaForCausalLM,
                revision=revision,
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
                kv_cache_dtype=kv_cache_dtype,
                trust_remote_code=trust_remote_code,
                lora_adapter_ids=lora_adapter_ids,
            )
        elif sharded:
            raise NotImplementedError(
                FLASH_ATT_ERROR_MESSAGE.format(f"Sharded {model_type}")
            )
        else:
            return CausalLM.fallback(
                model_id,
                revision,
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )
    if model_type == GEMMA:
        if FLASH_ATTENTION:
            return FlashCausalLM(
                model_id=model_id,
                model_class=FlashGemmaForCausalLM,
                revision=revision,
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
                kv_cache_dtype=kv_cache_dtype,
                # Works better for these models
                default_dtype=torch.bfloat16,
                trust_remote_code=trust_remote_code,
                lora_adapter_ids=lora_adapter_ids,
            )
        elif sharded:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Gemma"))
        else:
            return CausalLM.fallback(
                model_id,
                revision,
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )
    elif model_type == GEMMA2:
        if FLASH_ATTENTION:
            return FlashCausalLM(
                model_id=model_id,
                model_class=FlashGemma2ForCausalLM,
                revision=revision,
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
                kv_cache_dtype=kv_cache_dtype,
                # Works better for these models
                default_dtype=torch.bfloat16,
                trust_remote_code=trust_remote_code,
                lora_adapter_ids=lora_adapter_ids,
            )
        elif sharded:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Gemma2"))
        else:
            return CausalLM.fallback(
                model_id,
                revision,
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )

    if model_type == COHERE:
        if FLASH_ATTENTION:
            return FlashCausalLM(
                model_id=model_id,
                model_class=FlashCohereForCausalLM,
                revision=revision,
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
                kv_cache_dtype=kv_cache_dtype,
                trust_remote_code=trust_remote_code,
                lora_adapter_ids=lora_adapter_ids,
            )
        elif sharded:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Cohere"))
        else:
            return CausalLM.fallback(
                model_id,
                revision,
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )

    if model_type == DBRX:
        if FLASH_ATTENTION:
            return FlashCausalLM(
                model_id=model_id,
                model_class=FlashDbrxForCausalLM,
                revision=revision,
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
                kv_cache_dtype=kv_cache_dtype,
                # Dbrx works better in bfloat16.
                default_dtype=torch.bfloat16,
                trust_remote_code=trust_remote_code,
                lora_adapter_ids=lora_adapter_ids,
                config_class=DbrxConfig,
            )
        elif sharded:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded DBRX"))
        else:
            return CausalLM.fallback(
                model_id,
                revision,
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )

    if model_type in ["RefinedWeb", "RefinedWebModel", FALCON]:
        if sharded:
            if FLASH_ATTENTION:
                if config_dict.get("alibi", False):
                    raise NotImplementedError("sharded is not supported for this model")
                return FlashCausalLM(
                    model_id=model_id,
                    model_class=FlashRWForCausalLM,
                    revision=revision,
                    quantize=quantize,
                    speculator=speculator,
                    dtype=dtype,
                    kv_cache_dtype=kv_cache_dtype,
                    aliases={
                        "lm_head.weight": ["transformer.word_embeddings.weight"],
                        "transformer.word_embeddings.weight": ["lm_head.weight"],
                    },
                    trust_remote_code=trust_remote_code,
                    lora_adapter_ids=lora_adapter_ids,
                    config_class=RWConfig,
                )
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Falcon"))
        else:
            if FLASH_ATTENTION and not config_dict.get("alibi", False):
                return FlashCausalLM(
                    model_id=model_id,
                    model_class=FlashRWForCausalLM,
                    revision=revision,
                    quantize=quantize,
                    speculator=speculator,
                    dtype=dtype,
                    kv_cache_dtype=kv_cache_dtype,
                    aliases={
                        "lm_head.weight": ["transformer.word_embeddings.weight"],
                        "transformer.word_embeddings.weight": ["lm_head.weight"],
                    },
                    trust_remote_code=trust_remote_code,
                    lora_adapter_ids=lora_adapter_ids,
                    config_class=RWConfig,
                )
            else:
                return CausalLM.fallback(
                    model_id,
                    revision,
                    quantize=quantize,
                    speculator=speculator,
                    dtype=dtype,
                    trust_remote_code=trust_remote_code,
                )

    if model_type == MISTRAL:
        if FLASH_ATTENTION:
            return FlashCausalLM(
                model_id=model_id,
                model_class=FlashMistralForCausalLM,
                revision=revision,
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
                kv_cache_dtype=kv_cache_dtype,
                trust_remote_code=trust_remote_code,
                lora_adapter_ids=lora_adapter_ids,
            )
        elif sharded:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Mistral"))
        else:
            return CausalLM.fallback(
                model_id,
                revision,
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )

    if model_type == MIXTRAL:
        if FLASH_ATTENTION:
            return FlashCausalLM(
                model_id=model_id,
                model_class=FlashMixtralForCausalLM,
                revision=revision,
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
                kv_cache_dtype=kv_cache_dtype,
                trust_remote_code=trust_remote_code,
                lora_adapter_ids=lora_adapter_ids,
            )
        elif sharded:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Mixtral"))
        else:
            return CausalLM.fallback(
                model_id,
                revision,
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )

    if model_type == STARCODER2:
        if FLASH_ATTENTION:
            return FlashCausalLM(
                model_id=model_id,
                model_class=FlashStarcoder2ForCausalLM,
                revision=revision,
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
                kv_cache_dtype=kv_cache_dtype,
                trust_remote_code=trust_remote_code,
                lora_adapter_ids=lora_adapter_ids,
            )
        elif sharded:
            raise NotImplementedError(
                FLASH_ATT_ERROR_MESSAGE.format("Sharded Starcoder2")
            )
        else:
            return CausalLM.fallback(
                model_id,
                revision,
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )

    if model_type == QWEN2:
        if FLASH_ATTENTION:
            return FlashCausalLM(
                model_id=model_id,
                model_class=Qwen2ForCausalLM,
                revision=revision,
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
                kv_cache_dtype=kv_cache_dtype,
                trust_remote_code=trust_remote_code,
                lora_adapter_ids=lora_adapter_ids,
            )
        elif sharded:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Qwen2"))
        else:
            return CausalLM.fallback(
                model_id,
                revision,
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )

    if model_type == OPT:
        return CausalLM(
            model_id=model_id,
            model_class=OPTForCausalLM,
            revision=revision,
            quantize=quantize,
            speculator=speculator,
            dtype=dtype,
            trust_remote_code=trust_remote_code,
        )

    if model_type == T5:
        return Seq2SeqLM(
            model_id=model_id,
            model_class=T5ForConditionalGeneration,
            revision=revision,
            quantize=quantize,
            speculator=speculator,
            dtype=dtype,
            trust_remote_code=trust_remote_code,
            aliases={
                "shared.weight": [
                    "encoder.embed_tokens.weight",
                    "decoder.embed_tokens.weight",
                ]
            },
        )
    if model_type == IDEFICS:
        if FLASH_ATTENTION:
            return IdeficsCausalLM(
                model_id,
                revision,
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )
        else:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Idefics"))
    if model_type == QWEN2_VL:
        return VlmCausalLM(
            model_id=model_id,
            model_class=Qwen2VLForConditionalGeneration,
            revision=revision,
            quantize=quantize,
            speculator=speculator,
            dtype=dtype,
            kv_cache_dtype=kv_cache_dtype,
            trust_remote_code=trust_remote_code,
            lora_adapter_ids=lora_adapter_ids,
        )
    if model_type == MLLAMA:
        if FLASH_ATTENTION:
            return MllamaCausalLM(
                model_id=model_id,
                model_class=MllamaForConditionalGeneration,
                batch_class=MllamaCausalLMBatch,
                revision=revision,
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
                default_dtype=torch.bfloat16,
                trust_remote_code=trust_remote_code,
                lora_adapter_ids=lora_adapter_ids,
            )
        else:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Mllama"))
    if model_type == IDEFICS2:
        if FLASH_ATTENTION:
            return VlmCausalLM(
                model_id=model_id,
                model_class=Idefics2ForConditionalGeneration,
                revision=revision,
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
                kv_cache_dtype=kv_cache_dtype,
                trust_remote_code=trust_remote_code,
                lora_adapter_ids=lora_adapter_ids,
                # XXX: Extremely important to cap resolution in order to limit
                # VRAM usage.
                processor_kwargs={"size": {"longest_edge": 448, "shortest_edge": 378}},
            )
        else:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Idefics"))
1250
    if model_type == PALIGEMMA:
drbh's avatar
drbh committed
1251
        if FLASH_ATTENTION:
1252
1253
1254
1255
            return VlmCausalLM(
                model_id=model_id,
                model_class=PaliGemmaForConditionalGeneration,
                revision=revision,
drbh's avatar
drbh committed
1256
1257
1258
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
1259
                kv_cache_dtype=kv_cache_dtype,
1260
1261
                # Works better for these models
                default_dtype=torch.bfloat16,
drbh's avatar
drbh committed
1262
                trust_remote_code=trust_remote_code,
1263
1264
                lora_adapter_ids=lora_adapter_ids,
                batch_class=PaliGemmaBatch,
drbh's avatar
drbh committed
1265
1266
1267
            )
        else:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Idefics"))

    if model_type == LLAVA_NEXT:
        if FLASH_ATTENTION:
            return VlmCausalLM(
                model_class=LlavaNextForConditionalGeneration,
                model_id=model_id,
                revision=revision,
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
                kv_cache_dtype=kv_cache_dtype,
                trust_remote_code=trust_remote_code,
            )
        else:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("LlavaNext"))

    if sharded:
        raise NotImplementedError("sharded is not supported for AutoModel")
    if quantize == "gptq":
        raise NotImplementedError(
            "gptq quantization is not supported for AutoModel. You can try quantizing the model with `text-generation-server quantize ORIGINAL_MODEL_ID NEW_MODEL_ID`."
        )
    if quantize == "awq":
        raise NotImplementedError("awq quantization is not supported for AutoModel")
    elif quantize in ("bitsandbytes-fp4", "bitsandbytes-nf4"):
        raise NotImplementedError("4bit quantization is not supported for AutoModel")
    elif quantize == "eetq":
        raise NotImplementedError("Eetq quantization is not supported for AutoModel")
    elif quantize == "exl2":
        raise NotImplementedError("exl2 quantization is not supported for AutoModel")
    if model_type in modeling_auto.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES:
        return CausalLM.fallback(
            model_id,
            revision,
            quantize=quantize,
            speculator=speculator,
            dtype=dtype,
            trust_remote_code=trust_remote_code,
        )
    if model_type in modeling_auto.MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES:
        return Seq2SeqLM.fallback(
            model_id,
            revision,
            quantize=quantize,
            speculator=speculator,
            dtype=dtype,
            trust_remote_code=trust_remote_code,
        )

    auto_map = config_dict.get("auto_map", None)
    if trust_remote_code and auto_map is not None:
        if "AutoModelForCausalLM" in auto_map.keys():
            return CausalLM.fallback(
                model_id,
                revision,
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )
        if "AutoModelForSeq2SeqLM" in auto_map.keys():
            return Seq2SeqLM.fallback(
                model_id,
                revision,
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )

    raise ValueError(f"Unsupported model type {model_type}")


# get_model_with_lora_adapters wraps the internal get_model function and adds support for loading LoRA adapters.
# It acts as a post-load hook: after the base model has been created, the requested adapters are loaded into it.
def get_model_with_lora_adapters(
    model_id: str,
    lora_adapters: Optional[List[AdapterInfo]],
    revision: Optional[str],
    sharded: bool,
    quantize: Optional[str],
    speculate: Optional[int],
    dtype: Optional[str],
    kv_cache_dtype: Optional[str],
    trust_remote_code: bool,
    max_input_tokens: int,
    adapter_to_index: Dict[str, int],
):
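    """Load the base model, then attach the requested LoRA adapters to it.

    Illustrative call (a minimal sketch; the model id and argument values below
    are hypothetical and depend on the deployment, and construction of the
    `AdapterInfo` entries passed as `lora_adapters` is elided):

        model = get_model_with_lora_adapters(
            model_id="some-org/some-llama-model",
            lora_adapters=adapters,
            revision=None,
            sharded=False,
            quantize=None,
            speculate=None,
            dtype=None,
            kv_cache_dtype=None,
            trust_remote_code=False,
            max_input_tokens=1024,
            adapter_to_index={},
        )

    Each adapter in `lora_adapters` is registered under an increasing index
    (starting at 1), and `adapter_to_index` is filled with the adapter-id-to-index
    mapping as a side effect.
    """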
    lora_adapter_ids = [adapter.id for adapter in lora_adapters]
    model = get_model(
        model_id,
        lora_adapter_ids,
        revision,
        sharded,
        quantize,
        speculate,
        dtype,
        kv_cache_dtype,
        trust_remote_code,
        max_input_tokens,
    )

    if len(lora_adapters) > 0:
        target_to_layer = build_layer_weight_lookup(model.model)

        for index, adapter in enumerate(lora_adapters):
            # The AdapterParameters object allows for merging multiple adapters into a single adapter.
            # At the moment, we only support loading a single adapter into the model, but we keep the
            # AdapterParameters object for easier extension in the future.
            adapter_parameters = AdapterParameters(
                adapter_info=[adapter],
                # when merging multiple adapters we can weight them differently
                # if this is not set, all adapters will be weighted equally
                # see: text_generation_server.utils.merges.strategies for impl
                weights=None,
                merge_strategy=0,
                density=1.0,
                majority_sign_method=0,
            )
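            # For reference only: a merged, multi-adapter configuration would be
            # expressed through the same object, roughly as in the sketch below.
            # The adapter variables and the 0.7/0.3 weighting are illustrative
            # assumptions, and this code path currently loads adapters one at a
            # time rather than merging them.
            #
            #     AdapterParameters(
            #         adapter_info=[adapter_a, adapter_b],
            #         weights=[0.7, 0.3],
            #         merge_strategy=0,
            #         density=1.0,
            #         majority_sign_method=0,
            #     )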

            adapter_index = index + 1
            adapter_to_index[adapter.id] = adapter_index

            logger.info(
                f"Loading adapter weights into model: {','.join([adapter.id for adapter in adapter_parameters.adapter_info])}"
            )
            weight_names = tuple([v[0] for v in target_to_layer.values()])
            (
                module_map,
                adapter_config,
                adapter_weight_names,
                adapter_tokenizer,
            ) = load_and_merge_adapters(
                model.model_id,
                adapter_parameters,
                adapter_index,
                weight_names,
                False,
            )

            unused_weight_names = adapter_weight_names.copy()

            adapter_layers = [
                "q_proj",
                "k_proj",
                "v_proj",
                "o_proj",
                "gate_proj",
                "up_proj",
                "down_proj",
                "qkv_proj",
            ]

            for layer_name in adapter_layers:
                nlayers = (
                    1 if layer_name == "lm_head" else len(model.model.model.layers)
                )
                adapter_weights = LoraWeights.prepare_weights(
                    config=adapter_config,
                    module_map=module_map,
                    layer_type=layer_name,
                    unused_weight_names=unused_weight_names,
                    nlayers=nlayers,
                    dtype=model.dtype,
                    world_size=model.world_size,
                    process_group=model.process_group,
                    target_to_layer=target_to_layer,
                )

                if adapter_weights is None:
                    continue

                model.layer_to_adapter_weights[layer_name].add_adapter(
                    adapter_index, adapter_weights
                )

            if len(unused_weight_names) > 0:
                logger.warning(
                    f"{','.join([a.id for a in lora_adapters])} unused adapter weights: {unused_weight_names}"
                )

            if adapter_tokenizer is not None:
                model.tokenizers.add_tokenizer(adapter_index, adapter_tokenizer)

            model.loaded_adapters.add(adapter_index)

    return model