".github/vscode:/vscode.git/clone" did not exist on "111ffc426be7cbb9efcf485053c9237a94f58fa3"
__init__.py 37.7 KB
Newer Older
1
import torch
import enum
import os

from loguru import logger
from transformers.configuration_utils import PretrainedConfig
from transformers.models.auto import modeling_auto
from huggingface_hub import hf_hub_download, HfApi
from typing import Optional, List
from pathlib import Path

from text_generation_server.utils.speculate import get_speculate, set_speculate
from text_generation_server.models.model import Model
from text_generation_server.models.causal_lm import CausalLM, CausalLMBatchKeysLast
from text_generation_server.models.custom_modeling.opt_modeling import OPTForCausalLM
from text_generation_server.models.custom_modeling.mpt_modeling import (
    MPTForCausalLM,
)
from text_generation_server.models.bloom import BloomCausalLMBatch
from text_generation_server.models.custom_modeling.bloom_modeling import (
    BloomForCausalLM,
)
from text_generation_server.models.seq2seq_lm import Seq2SeqLM
from text_generation_server.models.galactica import GalacticaCausalLMBatch
from text_generation_server.models.custom_modeling.neox_modeling import (
    GPTNeoxForCausalLM,
)
from text_generation_server.models.custom_modeling.phi_modeling import (
    PhiConfig,
    PhiForCausalLM,
)
from text_generation_server.models.custom_modeling.t5_modeling import (
    T5ForConditionalGeneration,
)

from text_generation_server.utils.import_utils import SYSTEM

# The flag below controls whether to allow TF32 on matmul. This flag defaults to False
# in PyTorch 1.12 and later.
torch.backends.cuda.matmul.allow_tf32 = True

# The flag below controls whether to allow TF32 on cuDNN. This flag defaults to True.
torch.backends.cudnn.allow_tf32 = True

# Disable gradients
torch.set_grad_enabled(False)

__all__ = [
    "Model",
    "BLOOMSharded",
    "CausalLM",
    "GalacticaSharded",
    "Seq2SeqLM",
    "get_model",
]

FLASH_ATT_ERROR_MESSAGE = "{} requires Flash Attention enabled models."

FLASH_ATTENTION = True

try:
    from text_generation_server.models.flash_causal_lm import FlashCausalLM
    from text_generation_server.models.vlm_causal_lm import VlmCausalLM
    from text_generation_server.models.custom_modeling.flash_deepseek_v2_modeling import (
        FlashDeepseekV2ForCausalLM,
        DeepseekV2Config,
    )
    from text_generation_server.models.custom_modeling.flash_llama_modeling import (
        FlashLlamaForCausalLM,
    )
    from text_generation_server.models.custom_modeling.flash_cohere_modeling import (
        FlashCohereForCausalLM,
    )
    from text_generation_server.models.custom_modeling.flash_gemma_modeling import (
        FlashGemmaForCausalLM,
    )
    from text_generation_server.models.custom_modeling.flash_gemma2_modeling import (
        FlashGemma2ForCausalLM,
    )
    from text_generation_server.models.custom_modeling.flash_dbrx_modeling import (
        FlashDbrxForCausalLM,
        DbrxConfig,
    )
    from text_generation_server.models.custom_modeling.flash_rw_modeling import (
        RWConfig,
        FlashRWForCausalLM,
    )
    from text_generation_server.models.custom_modeling.flash_neox_modeling import (
        FlashGPTNeoXForCausalLM,
    )
    from text_generation_server.models.pali_gemma import (
        PaliGemmaBatch,
    )
    from text_generation_server.models.custom_modeling.flash_pali_gemma_modeling import (
        PaliGemmaForConditionalGeneration,
    )
    from text_generation_server.models.custom_modeling.flash_phi_modeling import (
        FlashPhiForCausalLM,
    )
    from text_generation_server.models.idefics import IDEFICSSharded
    from text_generation_server.models.custom_modeling.llava_next import (
        LlavaNextForConditionalGeneration,
    )

    from text_generation_server.models.custom_modeling.flash_santacoder_modeling import (
        FlashSantacoderForCausalLM,
    )
    from text_generation_server.models.custom_modeling.flash_starcoder2_modeling import (
        FlashStarcoder2ForCausalLM,
    )
    from text_generation_server.models.custom_modeling.flash_qwen2_modeling import (
        Qwen2ForCausalLM,
    )
    from text_generation_server.models.custom_modeling.flash_mistral_modeling import (
        FlashMistralForCausalLM,
    )
    from text_generation_server.models.custom_modeling.flash_mixtral_modeling import (
        FlashMixtralForCausalLM,
    )
    from text_generation_server.models.custom_modeling.flash_gpt2_modeling import (
        FlashGPT2ForCausalLM,
    )
    from text_generation_server.models.custom_modeling.idefics2 import (
        Idefics2ForConditionalGeneration,
    )
    from text_generation_server.layers.attention import SUPPORTS_WINDOWING
except ImportError as e:
    logger.warning(f"Could not import Flash Attention enabled models: {e}")
    SUPPORTS_WINDOWING = False
    FLASH_ATTENTION = False

if FLASH_ATTENTION:
    __all__.append(FlashCausalLM)
    __all__.append(IDEFICSSharded)

MAMBA_AVAILABLE = True
try:
    from text_generation_server.models.mamba import Mamba
except ImportError as e:
    logger.warning(f"Could not import Mamba: {e}")
    MAMBA_AVAILABLE = False

if MAMBA_AVAILABLE:
    __all__.append(Mamba)


class ModelType(enum.Enum):
    DEEPSEEK_V2 = {
        "type": "deepseek_v2",
        "name": "Deepseek V2",
        "url": "https://huggingface.co/deepseek-ai/DeepSeek-V2",
    }
    IDEFICS2 = {
        "type": "idefics2",
        "name": "Idefics 2",
        "url": "https://huggingface.co/HuggingFaceM4/idefics2-8b",
        "multimodal": True,
    }
    LLAVA_NEXT = {
        "type": "llava_next",
        "name": "Llava Next (1.6)",
        "url": "https://huggingface.co/llava-hf/llava-v1.6-vicuna-13b-hf",
        "multimodal": True,
    }
    LLAMA = {
        "type": "llama",
        "name": "Llama",
        "url": "https://huggingface.co/meta-llama/Meta-Llama-3-8B-Instruct",
    }
    PHI3 = {
        "type": "phi3",
        "name": "Phi 3",
        "url": "https://huggingface.co/microsoft/Phi-3-mini-4k-instruct",
    }
    GEMMA = {
        "type": "gemma",
        "name": "Gemma",
        "url": "https://huggingface.co/google/gemma-7b",
    }
    PALIGEMMA = {
        "type": "paligemma",
        "name": "PaliGemma",
        "url": "https://huggingface.co/google/paligemma-3b-pt-224",
    }
    GEMMA2 = {
        "type": "gemma2",
        "name": "Gemma2",
        "url": "https://huggingface.co/google/gemma2-9b",
    }
    COHERE = {
        "type": "cohere",
        "name": "Cohere",
        "url": "https://huggingface.co/CohereForAI/c4ai-command-r-plus",
    }
    DBRX = {
        "type": "dbrx",
        "name": "Dbrx",
        "url": "https://huggingface.co/databricks/dbrx-instruct",
    }
    MAMBA = {
        "type": "ssm",
        "name": "Mamba",
        "url": "https://huggingface.co/state-spaces/mamba-2.8b-slimpj",
    }
    MISTRAL = {
        "type": "mistral",
        "name": "Mistral",
        "url": "https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.2",
    }
    MIXTRAL = {
        "type": "mixtral",
        "name": "Mixtral",
        "url": "https://huggingface.co/mistralai/Mixtral-8x22B-Instruct-v0.1",
    }
    GPT_BIGCODE = {
        "type": "gpt_bigcode",
        "name": "Gpt Bigcode",
        "url": "https://huggingface.co/bigcode/gpt_bigcode-santacoder",
    }
    PHI = {
        "type": "phi",
        "name": "Phi",
        "url": "https://huggingface.co/microsoft/phi-1_5",
    }
    BAICHUAN = {
        "type": "baichuan",
        "name": "Baichuan",
        "url": "https://huggingface.co/baichuan-inc/Baichuan2-7B-Chat",
    }
    FALCON = {
        "type": "falcon",
        "name": "Falcon",
        "url": "https://huggingface.co/tiiuae/falcon-7b-instruct",
    }
    STARCODER2 = {
        "type": "starcoder2",
        "name": "StarCoder 2",
        "url": "https://huggingface.co/bigcode/starcoder2-15b-instruct-v0.1",
    }
    QWEN2 = {
        "type": "qwen2",
        "name": "Qwen 2",
        "url": "https://huggingface.co/collections/Qwen/qwen2-6659360b33528ced941e557f",
    }
    OPT = {
        "type": "opt",
        "name": "Opt",
        "url": "https://huggingface.co/facebook/opt-6.7b",
    }
    T5 = {
        "type": "t5",
        "name": "T5",
        "url": "https://huggingface.co/google/flan-t5-xxl",
    }
    GALACTICA = {
        "type": "galactica",
        "name": "Galactica",
        "url": "https://huggingface.co/facebook/galactica-120b",
    }
    SANTACODER = {
        "type": "santacoder",
        "name": "SantaCoder",
        "url": "https://huggingface.co/bigcode/santacoder",
    }
    BLOOM = {
        "type": "bloom",
        "name": "Bloom",
        "url": "https://huggingface.co/bigscience/bloom-560m",
    }
    MPT = {
        "type": "mpt",
        "name": "Mpt",
        "url": "https://huggingface.co/mosaicml/mpt-7b-instruct",
    }
    GPT2 = {
        "type": "gpt2",
        "name": "Gpt2",
        "url": "https://huggingface.co/openai-community/gpt2",
    }
    GPT_NEOX = {
        "type": "gpt_neox",
        "name": "Gpt Neox",
        "url": "https://huggingface.co/EleutherAI/gpt-neox-20b",
    }
    IDEFICS = {
        "type": "idefics",
        "name": "Idefics",
        "url": "https://huggingface.co/HuggingFaceM4/idefics-9b",
        "multimodal": True,
    }


__GLOBALS = locals()
for data in ModelType:
    __GLOBALS[data.name] = data.value["type"]
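# Each ModelType member is also exposed at module level as its "type" string
# (e.g. LLAMA == "llama"); the `model_type == ...` comparisons in `get_model`
# below rely on these aliases.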


def get_model(
    model_id: str,
    lora_adapter_ids: Optional[List[str]],
    revision: Optional[str],
    sharded: bool,
    quantize: Optional[str],
    speculate: Optional[int],
    dtype: Optional[str],
    trust_remote_code: bool,
    max_input_tokens: int,
) -> Model:
    global FLASH_ATTENTION
    if dtype is None:
        if quantize in ["awq", "exl2", "gptq", "marlin"]:
            # These quantizers only work with float16 params.
            dtype = torch.float16
        else:
            # Keep it as default for now and let
            # every model resolve their own default dtype.
            dtype = None
    elif dtype == "float16":
        dtype = torch.float16
    elif dtype == "bfloat16":
        dtype = torch.bfloat16
    else:
        raise RuntimeError(f"Unknown dtype {dtype}")

    if speculate is not None:
        set_speculate(speculate)
    else:
        set_speculate(0)

    config_dict, _ = PretrainedConfig.get_config_dict(
        model_id, revision=revision, trust_remote_code=trust_remote_code
    )
    model_type = config_dict.get("model_type", None)

    speculator = None
    if "medusa_num_heads" in config_dict:
        medusa_model_id = model_id
        medusa_revision = revision
        model_id = config_dict["base_model_name_or_path"]
        revision = "main"
        speculate_medusa = config_dict["medusa_num_heads"]
        if speculate is not None:
            if speculate > speculate_medusa:
                raise RuntimeError(
                    f"Speculate is set to `{speculate}` but this medusa model only has `{speculate_medusa}` heads, please make them match"
                )
            else:
                set_speculate(speculate)
        else:
            set_speculate(speculate_medusa)

        config_dict, _ = PretrainedConfig.get_config_dict(
            model_id, revision=revision, trust_remote_code=trust_remote_code
        )
        # Reload model type from parent.
        model_type = config_dict.get("model_type", None)
        is_local = Path(medusa_model_id).exists()
        if not is_local:
            medusa_config = hf_hub_download(
                medusa_model_id, revision=medusa_revision, filename="config.json"
            )
            hf_hub_download(
                medusa_model_id,
                revision=medusa_revision,
                filename="medusa_lm_head.safetensors",
            )
            speculator = {
                "path": Path(medusa_config).parent,
                "model_paths": ["medusa_lm_head.safetensors"],
            }
        else:
            speculator = {
                "path": Path(medusa_model_id),
                "model_paths": ["medusa_lm_head.safetensors"],
            }

        method = "medusa"
    elif model_type == "mlp_speculator":
        mlp_model_id = model_id
        mlp_revision = revision
        model_id = config_dict["base_model_name_or_path"]
        revision = "main"
        speculate_mlp = config_dict["n_predict"]
        if speculate is not None:
            if speculate > speculate_mlp:
                raise RuntimeError(
                    f"Speculate is set to `{speculate}` but this mlp_speculator models only has `{speculate_mlp}` heads, please make them match"
                )
            else:
                set_speculate(speculate)
        else:
            set_speculate(speculate_mlp)

        config_dict, _ = PretrainedConfig.get_config_dict(
            model_id, revision=revision, trust_remote_code=trust_remote_code
        )
        # Reload model type from parent.
        model_type = config_dict.get("model_type", None)
        is_local = Path(mlp_model_id).exists()
        extension = ".safetensors"
        if not is_local:
            mlp_speculator_config = hf_hub_download(
                mlp_model_id, revision=mlp_revision, filename="config.json"
            )
            api = HfApi()
            info = api.model_info(mlp_model_id, revision=mlp_revision)
            filenames = [
                s.rfilename
                for s in info.siblings
                if s.rfilename.endswith(extension)
                and len(s.rfilename.split("/")) == 1
                and "arguments" not in s.rfilename
                and "args" not in s.rfilename
                and "training" not in s.rfilename
            ]
            for filename in filenames:
                hf_hub_download(
                    mlp_model_id,
                    revision=mlp_revision,
                    filename=filename,
                )
            speculator = {
                "path": Path(mlp_speculator_config).parent,
                "model_paths": filenames,
            }
        else:
            speculator = Path(mlp_model_id)
            filenames = [p for p in os.listdir(speculator) if p.endswith(extension)]
            speculator = {"path": speculator, "model_paths": filenames}
        method = "mlp_speculator"
    else:
        method = "n-gram"

    speculate = get_speculate()
    if speculate > 0:
        logger.info(f"Using speculation {method} with {speculate} input ids.")

    if model_type is None:
        # TODO: fix how we determine model type for Mamba
        if "ssm_cfg" in config_dict:
            # *only happens in Mamba case
            model_type = "ssm"
        else:
            raise RuntimeError(
                f"Could not determine model type for {model_id} revision {revision}"
            )
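    # If the checkpoint ships its own quantization_config and no explicit `quantize`
    # argument was given, adopt the quantization method declared by the model config.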
    quantization_config = config_dict.get("quantization_config", None)
    if quantization_config is not None and quantize is None:
        method = quantization_config.get("quant_method", None)
        if method in {"gptq", "awq", "exl2"}:
            logger.info(f"Auto selecting quantization method {method}")
            quantize = method
        else:
            logger.info(f"Unknown quantization method {method}")

    if quantize == "exl2" and sharded:
        raise RuntimeError(
            "Sharding is currently not supported with `exl2` quantization"
        )
    sliding_window = config_dict.get("sliding_window", -1)
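    # Models relying on sliding-window attention can only run on backends that
    # support windowing, unless the requested input length fits inside the window.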
    if (
        (sliding_window is not None and sliding_window != -1)
        and not SUPPORTS_WINDOWING
        and max_input_tokens > sliding_window
    ):
        raise ValueError(
            f"The backend {SYSTEM} does not support sliding window attention that is used by the model type {model_type}. To use this model nonetheless with the {SYSTEM} backend, please launch TGI with the argument `--max-input-tokens` smaller than sliding_window={sliding_window} (got here max_input_tokens={max_input_tokens})."
        )

    if model_type == DEEPSEEK_V2:
        if FLASH_ATTENTION:
            head_size = max(
                config_dict.get("qk_nope_dim", 128)
                + config_dict.get("qk_rope_dim", 64),
                config_dict.get("v_head_dim", 128),
            )
            return FlashCausalLM(
                model_id=model_id,
                model_class=FlashDeepseekV2ForCausalLM,
                revision=revision,
                quantize=quantize,
                speculator=speculator,
                default_dtype=torch.bfloat16,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
                lora_adapter_ids=lora_adapter_ids,
                config_class=DeepseekV2Config,
                head_size=head_size,
            )
        elif sharded:
            raise NotImplementedError(
                FLASH_ATT_ERROR_MESSAGE.format("Sharded Deepseek V2")
            )
        else:
            return CausalLM.fallback(
                model_id,
                revision,
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )
    elif model_type == MAMBA:
        return Mamba(
            model_id,
            revision,
            quantize=quantize,
            speculator=speculator,
            dtype=dtype,
            trust_remote_code=trust_remote_code,
        )

    if model_id.startswith("facebook/galactica"):
        return CausalLM(
            model_id=model_id,
            # Yes galactica is just an OPT model.
            model_class=OPTForCausalLM,
            revision=revision,
            quantize=quantize,
            speculator=speculator,
            dtype=dtype,
            trust_remote_code=trust_remote_code,
            batch_class=GalacticaCausalLMBatch,
        )

    if (
        model_type == GPT_BIGCODE
        or model_type == GPT2
        and model_id.startswith("bigcode/")
    ):
        if FLASH_ATTENTION:
            return FlashCausalLM(
                model_id=model_id,
                model_class=FlashSantacoderForCausalLM,
                revision=revision,
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
                lora_adapter_ids=lora_adapter_ids,
                aliases={"transformer.wte.weight": ["lm_head.weight"]},
                num_kv_heads=1,
            )
        elif sharded:
            raise NotImplementedError(
                FLASH_ATT_ERROR_MESSAGE.format("Sharded Santacoder")
            )
        else:
            return CausalLM.fallback(
                model_id=model_id,
                revision=revision,
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )

    if model_type == BLOOM:
        return CausalLM(
            model_id=model_id,
            model_class=BloomForCausalLM,
            revision=revision,
            quantize=quantize,
            speculator=speculator,
            dtype=dtype,
            trust_remote_code=trust_remote_code,
            batch_class=BloomCausalLMBatch,
        )
    elif model_type == MPT:
        return CausalLM(
            model_id=model_id,
            model_class=MPTForCausalLM,
            revision=revision,
            quantize=quantize,
            speculator=speculator,
            dtype=dtype,
            trust_remote_code=trust_remote_code,
            batch_class=CausalLMBatchKeysLast,
        )
    elif model_type == GPT2:
        if FLASH_ATTENTION:
            try:
                return FlashCausalLM(
                    model_id=model_id,
                    model_class=FlashGPT2ForCausalLM,
                    revision=revision,
                    quantize=quantize,
                    speculator=speculator,
                    dtype=dtype,
                    trust_remote_code=trust_remote_code,
                    lora_adapter_ids=lora_adapter_ids,
                )
            except RuntimeError as e:
                # Lots of legacy models with various weight names.
                logger.warning(f"Couldn't load flash gpt2 variant: {e}")
                return CausalLM.fallback(
                    model_id,
                    revision,
                    quantize=quantize,
                    speculator=speculator,
                    dtype=dtype,
                    trust_remote_code=trust_remote_code,
                )
        elif sharded:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded GPT-2"))
        else:
            return CausalLM.fallback(
                model_id,
                revision,
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )
    elif model_type == GPT_NEOX:
        if FLASH_ATTENTION:
            from text_generation_server.models.custom_modeling.flash_neox_modeling import (
                GPTNeoXConfig,
            )

            return FlashCausalLM(
                model_id=model_id,
                model_class=FlashGPTNeoXForCausalLM,
                revision=revision,
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
                lora_adapter_ids=lora_adapter_ids,
                config_class=GPTNeoXConfig,
            )
        elif sharded:
            return CausalLM(
                model_id=model_id,
                model_class=GPTNeoxForCausalLM,
                revision=revision,
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )
        else:
            return CausalLM.fallback(
                model_id,
                revision,
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )

    elif model_type == PHI:
        if FLASH_ATTENTION:
            return FlashCausalLM(
                model_id=model_id,
                model_class=FlashPhiForCausalLM,
                revision=revision,
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
                lora_adapter_ids=lora_adapter_ids,
            )
        else:
            return CausalLM.fallback(
                model_id,
                revision,
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )

    elif model_type == "phi-msft":
        if FLASH_ATTENTION:
            raise NotImplementedError(
                "Legacy phi-msft is not supported with Flash Attention"
            )
        else:
            return CausalLM(
                model_id=model_id,
                model_class=PhiForCausalLM,
                config_class=PhiConfig,
                revision=revision,
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )

    elif model_type == LLAMA or model_type == BAICHUAN or model_type == PHI3:
        if FLASH_ATTENTION:
            return FlashCausalLM(
                model_id=model_id,
                model_class=FlashLlamaForCausalLM,
                revision=revision,
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
                lora_adapter_ids=lora_adapter_ids,
            )
        elif sharded:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Llama"))
        else:
            return CausalLM.fallback(
                model_id,
                revision,
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )
    if model_type == GEMMA:
        if FLASH_ATTENTION:
            return FlashCausalLM(
                model_id=model_id,
                model_class=FlashGemmaForCausalLM,
                revision=revision,
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
                # Works better for these models
                default_dtype=torch.bfloat16,
                trust_remote_code=trust_remote_code,
                lora_adapter_ids=lora_adapter_ids,
            )
        elif sharded:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Gemma"))
        else:
            return CausalLM.fallback(
                model_id,
                revision,
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )
    elif model_type == GEMMA2:
        if FLASH_ATTENTION:
            return FlashCausalLM(
                model_id=model_id,
                model_class=FlashGemma2ForCausalLM,
                revision=revision,
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
                # Works better for these models
                default_dtype=torch.bfloat16,
                trust_remote_code=trust_remote_code,
                lora_adapter_ids=lora_adapter_ids,
            )
        elif sharded:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Gemma2"))
        else:
            return CausalLM.fallback(
                model_id,
                revision,
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )

    if model_type == COHERE:
        if FLASH_ATTENTION:
            return FlashCausalLM(
                model_id=model_id,
                model_class=FlashCohereForCausalLM,
                revision=revision,
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
                lora_adapter_ids=lora_adapter_ids,
            )
        elif sharded:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Cohere"))
        else:
            return CausalLM.fallback(
                model_id,
                revision,
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )

    if model_type == DBRX:
        if FLASH_ATTENTION:
            return FlashCausalLM(
                model_id=model_id,
                model_class=FlashDbrxForCausalLM,
                revision=revision,
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
                # Dbrx works better in bfloat16.
                default_dtype=torch.bfloat16,
                trust_remote_code=trust_remote_code,
                lora_adapter_ids=lora_adapter_ids,
                config_class=DbrxConfig,
            )
        elif sharded:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded DBRX"))
        else:
            return CausalLM.fallback(
                model_id,
                revision,
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )

    if model_type in ["RefinedWeb", "RefinedWebModel", FALCON]:
        if sharded:
            if FLASH_ATTENTION:
                if config_dict.get("alibi", False):
                    raise NotImplementedError("sharded is not supported for this model")
                return FlashCausalLM(
                    model_id=model_id,
                    model_class=FlashRWForCausalLM,
                    revision=revision,
                    quantize=quantize,
                    speculator=speculator,
                    dtype=dtype,
                    aliases={
                        "lm_head.weight": ["transformer.word_embeddings.weight"],
                        "transformer.word_embeddings.weight": ["lm_head.weight"],
                    },
                    trust_remote_code=trust_remote_code,
                    lora_adapter_ids=lora_adapter_ids,
                    config_class=RWConfig,
                )
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Falcon"))
        else:
            if FLASH_ATTENTION and not config_dict.get("alibi", False):
                return FlashCausalLM(
                    model_id=model_id,
                    model_class=FlashRWForCausalLM,
                    revision=revision,
                    quantize=quantize,
                    speculator=speculator,
                    dtype=dtype,
                    aliases={
                        "lm_head.weight": ["transformer.word_embeddings.weight"],
                        "transformer.word_embeddings.weight": ["lm_head.weight"],
                    },
                    trust_remote_code=trust_remote_code,
                    lora_adapter_ids=lora_adapter_ids,
                    config_class=RWConfig,
                )
            else:
                return CausalLM.fallback(
                    model_id,
                    revision,
                    quantize=quantize,
                    speculator=speculator,
                    dtype=dtype,
                    trust_remote_code=trust_remote_code,
                )

    if model_type == MISTRAL:
        if FLASH_ATTENTION:
            return FlashCausalLM(
                model_id=model_id,
                model_class=FlashMistralForCausalLM,
                revision=revision,
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
                lora_adapter_ids=lora_adapter_ids,
            )
        elif sharded:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Mistral"))
        else:
            return CausalLM.fallback(
                model_id,
                revision,
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )

    if model_type == MIXTRAL:
        if FLASH_ATTENTION:
            return FlashCausalLM(
                model_id=model_id,
                model_class=FlashMixtralForCausalLM,
                revision=revision,
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
                lora_adapter_ids=lora_adapter_ids,
            )
        elif sharded:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Mixtral"))
        else:
            return CausalLM.fallback(
                model_id,
                revision,
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )

    if model_type == STARCODER2:
        if FLASH_ATTENTION:
            return FlashCausalLM(
                model_id=model_id,
                model_class=FlashStarcoder2ForCausalLM,
                revision=revision,
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
                lora_adapter_ids=lora_adapter_ids,
            )
        elif sharded:
            raise NotImplementedError(
                FLASH_ATT_ERROR_MESSAGE.format("Sharded Starcoder2")
            )
        else:
            return CausalLM.fallback(
                model_id,
                revision,
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )

    if model_type == QWEN2:
        if FLASH_ATTENTION:
            return FlashCausalLM(
                model_id=model_id,
                model_class=Qwen2ForCausalLM,
                revision=revision,
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
                lora_adapter_ids=lora_adapter_ids,
            )
        elif sharded:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Qwen2"))
        else:
            return CausalLM.fallback(
                model_id,
                revision,
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )

    if model_type == OPT:
        return CausalLM(
            model_id=model_id,
            model_class=OPTForCausalLM,
            revision=revision,
            quantize=quantize,
            speculator=speculator,
            dtype=dtype,
            trust_remote_code=trust_remote_code,
        )

    if model_type == T5:
        return Seq2SeqLM(
            model_id=model_id,
            model_class=T5ForConditionalGeneration,
            revision=revision,
            quantize=quantize,
            speculator=speculator,
            dtype=dtype,
            trust_remote_code=trust_remote_code,
            aliases={
                "shared.weight": [
                    "encoder.embed_tokens.weight",
                    "decoder.embed_tokens.weight",
                ]
            },
        )
    if model_type == IDEFICS:
        if FLASH_ATTENTION:
            return IDEFICSSharded(
                model_id,
                revision,
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )
        else:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Idefics"))
    if model_type == IDEFICS2:
        if FLASH_ATTENTION:
            return VlmCausalLM(
                model_id=model_id,
                model_class=Idefics2ForConditionalGeneration,
                revision=revision,
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
                lora_adapter_ids=lora_adapter_ids,
                # XXX: Extremely important to cap resolution in order to limit
                # VRAM usage.
                processor_kwargs={"size": {"longest_edge": 448, "shortest_edge": 378}},
            )
        else:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Idefics"))
    if model_type == PALIGEMMA:
        if FLASH_ATTENTION:
            return VlmCausalLM(
                model_id=model_id,
                model_class=PaliGemmaForConditionalGeneration,
                revision=revision,
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
                # Works better for these models
                default_dtype=torch.bfloat16,
                trust_remote_code=trust_remote_code,
                lora_adapter_ids=lora_adapter_ids,
                batch_class=PaliGemmaBatch,
            )
        else:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("PaliGemma"))

    if model_type == LLAVA_NEXT:
        if FLASH_ATTENTION:
            return VlmCausalLM(
                model_class=LlavaNextForConditionalGeneration,
                model_id=model_id,
                revision=revision,
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )
        else:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("LlavaNext"))

    if sharded:
        raise NotImplementedError("sharded is not supported for AutoModel")
    if quantize == "gptq":
        raise NotImplementedError(
            "gptq quantization is not supported for AutoModel, you can try to quantize it with `text-generation-server quantize ORIGINAL_MODEL_ID NEW_MODEL_ID`"
        )
    if quantize == "awq":
        raise NotImplementedError("awq quantization is not supported for AutoModel")
    elif (quantize == "bitsandbytes-fp4") or (quantize == "bitsandbytes-nf4"):
        raise NotImplementedError("4bit quantization is not supported for AutoModel")
    elif quantize == "eetq":
        raise NotImplementedError("Eetq quantization is not supported for AutoModel")
    elif quantize == "exl2":
        raise NotImplementedError("exl2 quantization is not supported for AutoModel")
    if model_type in modeling_auto.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES:
        return CausalLM.fallback(
            model_id,
            revision,
            quantize=quantize,
            speculator=speculator,
            dtype=dtype,
            trust_remote_code=trust_remote_code,
        )
    if model_type in modeling_auto.MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES:
        return Seq2SeqLM.fallback(
            model_id,
            revision,
            quantize=quantize,
            speculator=speculator,
            dtype=dtype,
            trust_remote_code=trust_remote_code,
        )

    auto_map = config_dict.get("auto_map", None)
    if trust_remote_code and auto_map is not None:
        if "AutoModelForCausalLM" in auto_map.keys():
            return CausalLM.fallback(
                model_id,
                revision,
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )
        if "AutoModelForSeq2SeqLM" in auto_map.keys():
            return Seq2SeqLM.fallback(
                model_id,
                revision,
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )

    raise ValueError(f"Unsupported model type {model_type}")