__init__.py 30 KB
Newer Older
1
import torch
2
import enum
Nicolas Patry's avatar
Nicolas Patry committed
3
import os
4

5
from loguru import logger
6
from transformers.configuration_utils import PretrainedConfig
7
from transformers.models.auto import modeling_auto
Nicolas Patry's avatar
Nicolas Patry committed
8
from huggingface_hub import hf_hub_download, HfApi
9
from typing import Optional
10
from pathlib import Path
11

Nicolas Patry's avatar
Nicolas Patry committed
12
from text_generation_server.utils.speculate import get_speculate, set_speculate
13
14
from text_generation_server.models.model import Model
from text_generation_server.models.causal_lm import CausalLM
15
from text_generation_server.models.flash_causal_lm import FlashCausalLM
16
from text_generation_server.models.bloom import BLOOMSharded
17
from text_generation_server.models.mpt import MPTSharded
18
from text_generation_server.models.seq2seq_lm import Seq2SeqLM
19
from text_generation_server.models.rw import RW
20
21
from text_generation_server.models.opt import OPTSharded
from text_generation_server.models.galactica import GalacticaSharded
22
23
from text_generation_server.models.santacoder import SantaCoder
from text_generation_server.models.t5 import T5Sharded
24
from text_generation_server.models.gpt_neox import GPTNeoxSharded
drbh's avatar
drbh committed
25
from text_generation_server.models.phi import Phi
26

27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
# Opt in to TF32 for CUDA matmuls (PyTorch 1.12+ leaves this off by default)
# and for cuDNN (already on by default; set explicitly so intent is visible).
torch.backends.cudnn.allow_tf32 = True
torch.backends.cuda.matmul.allow_tf32 = True

# Gradient tracking is not needed by anything in this module; turn it off.
torch.set_grad_enabled(False)

# Public names re-exported by this package. Optional model classes are appended
# further down, only when their (backend-dependent) imports succeed.
__all__ = [
    "Model",
    "BLOOMSharded",
    "CausalLM",
    "GalacticaSharded",
    "Seq2SeqLM",
    "SantaCoder",
    "OPTSharded",
    "T5Sharded",
    "get_model",
]

49
# Template used when a model variant needs Flash Attention but it is unavailable.
FLASH_ATT_ERROR_MESSAGE = "{} requires Flash Attention enabled models."

# Set to False below if any Flash-Attention-backed model module fails to import.
FLASH_ATTENTION = True

try:
    from text_generation_server.models.flash_rw import FlashRWSharded
    from text_generation_server.models.flash_gpt2 import FlashGPT2
    from text_generation_server.models.flash_neox import FlashNeoXSharded
    from text_generation_server.models.flash_llama import (
        FlashLlama,
    )
    from text_generation_server.models.flash_qwen2 import (
        FlashQwen2,
    )
    from text_generation_server.models.flash_cohere import (
        FlashCohere,
    )
    from text_generation_server.models.flash_gemma import (
        FlashGemma,
    )
    from text_generation_server.models.pali_gemma import (
        PaliGemma,
    )
    from text_generation_server.models.flash_santacoder import (
        FlashSantacoderSharded,
    )
    from text_generation_server.models.idefics import IDEFICSSharded
    from text_generation_server.models.llava_next import LlavaNext
    from text_generation_server.models.idefics2 import Idefics2
    from text_generation_server.models.flash_mistral import FlashMistral
    from text_generation_server.models.flash_mixtral import FlashMixtral
    from text_generation_server.models.flash_phi import FlashPhi
    from text_generation_server.models.flash_starcoder2 import FlashStarcoder2
    from text_generation_server.models.flash_dbrx import FlashDbrx
    from text_generation_server.utils.flash_attn import (
        HAS_FLASH_ATTN_V2_CUDA,
        HAS_FLASH_ATTN_V2_ROCM,
    )

except ImportError as e:
    # Any failure disables the whole Flash family; callers fall back to the
    # non-flash implementations or raise FLASH_ATT_ERROR_MESSAGE.
    logger.warning(f"Could not import Flash Attention enabled models: {e}")
    FLASH_ATTENTION = False
    HAS_FLASH_ATTN_V2_CUDA = False
    HAS_FLASH_ATTN_V2_ROCM = False

if FLASH_ATTENTION:
    # `__all__` must contain attribute *names* (strings): `from ... import *`
    # raises TypeError when it meets a class object in `__all__`.
    __all__.extend(
        [
            "FlashGPT2",
            "FlashNeoXSharded",
            "FlashRWSharded",
            "FlashSantacoderSharded",
            "FlashLlama",
            "IDEFICSSharded",
            "FlashMistral",
            "FlashMixtral",
            "FlashDbrx",
            "FlashPhi",
            "FlashQwen2",
            "FlashStarcoder2",
            "FlashGemma",
            "FlashCohere",
        ]
    )
# Mamba is optional: it depends on extra kernels that may not be installed.
MAMBA_AVAILABLE = True
try:
    from text_generation_server.models.mamba import Mamba
except ImportError as e:
    logger.warning(f"Could not import Mamba: {e}")
    MAMBA_AVAILABLE = False

if MAMBA_AVAILABLE:
    # Export the *name*; `from ... import *` rejects non-str entries in `__all__`.
    __all__.append("Mamba")
class ModelType(enum.Enum):
    """Catalogue of supported model architectures.

    Each member's value is a dict with:
      * ``type``: the ``model_type`` string matched against the model config,
      * ``name``: a human-readable display name,
      * ``url``: an example checkpoint on the Hugging Face Hub,
      * ``multimodal`` (optional): ``True`` for vision-language models.
    """

    IDEFICS2 = {
        "type": "idefics2",
        "name": "Idefics 2",
        "url": "https://huggingface.co/HuggingFaceM4/idefics2-8b",
        "multimodal": True,
    }
    LLAVA_NEXT = {
        "type": "llava_next",
        "name": "Llava Next (1.6)",
        "url": "https://huggingface.co/llava-hf/llava-v1.6-vicuna-13b-hf",
        "multimodal": True,
    }
    LLAMA = {
        "type": "llama",
        "name": "Llama",
        "url": "https://huggingface.co/meta-llama/Meta-Llama-3-8B-Instruct",
    }
    PHI3 = {
        "type": "phi3",
        "name": "Phi 3",
        "url": "https://huggingface.co/microsoft/Phi-3-mini-4k-instruct",
    }
    GEMMA = {
        "type": "gemma",
        "name": "Gemma",
        "url": "https://huggingface.co/google/gemma-7b",
    }
    COHERE = {
        "type": "cohere",
        "name": "Cohere",
        "url": "https://huggingface.co/CohereForAI/c4ai-command-r-plus",
    }
    DBRX = {
        "type": "dbrx",
        "name": "Dbrx",
        "url": "https://huggingface.co/databricks/dbrx-instruct",
    }
    MAMBA = {
        # Mamba configs carry no `model_type`; "ssm" is the synthetic type used for them.
        "type": "ssm",
        "name": "Mamba",
        "url": "https://huggingface.co/state-spaces/mamba-2.8b-slimpj",
    }
    MISTRAL = {
        "type": "mistral",
        "name": "Mistral",
        "url": "https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.2",
    }
    MIXTRAL = {
        "type": "mixtral",
        "name": "Mixtral",
        "url": "https://huggingface.co/mistralai/Mixtral-8x22B-Instruct-v0.1",
    }
    GPT_BIGCODE = {
        "type": "gpt_bigcode",
        "name": "Gpt Bigcode",
        "url": "https://huggingface.co/bigcode/gpt_bigcode-santacoder",
    }
    PHI = {
        "type": "phi",
        "name": "Phi",
        "url": "https://huggingface.co/microsoft/phi-1_5",
    }
    BAICHUAN = {
        "type": "baichuan",
        "name": "Baichuan",
        "url": "https://huggingface.co/baichuan-inc/Baichuan2-7B-Chat",
    }
    FALCON = {
        "type": "falcon",
        "name": "Falcon",
        "url": "https://huggingface.co/tiiuae/falcon-7b-instruct",
    }
    STARCODER2 = {
        "type": "starcoder2",
        "name": "StarCoder 2",
        "url": "https://huggingface.co/bigcode/starcoder2-15b-instruct-v0.1",
    }
    QWEN2 = {
        "type": "qwen2",
        "name": "Qwen 2",
        # Previously copy-pasted from STARCODER2; point at an actual Qwen2 checkpoint.
        "url": "https://huggingface.co/Qwen/Qwen2-7B-Instruct",
    }
    OPT = {
        "type": "opt",
        "name": "Opt",
        "url": "https://huggingface.co/facebook/opt-6.7b",
    }
    T5 = {
        "type": "t5",
        "name": "T5",
        "url": "https://huggingface.co/google/flan-t5-xxl",
    }
    GALACTICA = {
        "type": "galactica",
        "name": "Galactica",
        "url": "https://huggingface.co/facebook/galactica-120b",
    }
    SANTACODER = {
        "type": "santacoder",
        "name": "SantaCoder",
        "url": "https://huggingface.co/bigcode/santacoder",
    }
    BLOOM = {
        "type": "bloom",
        "name": "Bloom",
        "url": "https://huggingface.co/bigscience/bloom-560m",
    }
    MPT = {
        "type": "mpt",
        "name": "Mpt",
        "url": "https://huggingface.co/mosaicml/mpt-7b-instruct",
    }
    GPT2 = {
        "type": "gpt2",
        "name": "Gpt2",
        "url": "https://huggingface.co/openai-community/gpt2",
    }
    GPT_NEOX = {
        "type": "gpt_neox",
        "name": "Gpt Neox",
        "url": "https://huggingface.co/EleutherAI/gpt-neox-20b",
    }
    IDEFICS = {
        "type": "idefics",
        "name": "Idefics",
        "url": "https://huggingface.co/HuggingFaceM4/idefics-9b",
        "multimodal": True,
    }


# Publish every member's `type` string as a module-level constant (e.g.
# `LLAMA == "llama"`), so `model_type` can be compared against plain names.
# `globals()` is used rather than `locals()`: at module scope they are the same
# dict, but the documented contract of `locals()` says its mapping should not be
# modified. The comprehension also avoids leaking a loop variable into the module.
__GLOBALS = globals()
__GLOBALS.update({member.name: member.value["type"] for member in ModelType})


256
def get_model(
257
258
259
260
    model_id: str,
    revision: Optional[str],
    sharded: bool,
    quantize: Optional[str],
Nicolas Patry's avatar
Nicolas Patry committed
261
    speculate: Optional[int],
262
    dtype: Optional[str],
263
    trust_remote_code: bool,
264
) -> Model:
265
    if dtype is None:
266
        if quantize in ["awq", "exl2", "gptq"]:
267
268
269
270
271
272
            # These quantizers only work with float16 params.
            dtype = torch.float16
        else:
            # Keep it as default for now and let
            # every model resolve their own default dtype.
            dtype = None
273
274
275
276
277
278
279
    elif dtype == "float16":
        dtype = torch.float16
    elif dtype == "bfloat16":
        dtype = torch.bfloat16
    else:
        raise RuntimeError(f"Unknown dtype {dtype}")

Nicolas Patry's avatar
Nicolas Patry committed
280
281
282
283
284
    if speculate is not None:
        set_speculate(speculate)
    else:
        set_speculate(0)

OlivierDehaene's avatar
v0.8.2  
OlivierDehaene committed
285
286
287
    config_dict, _ = PretrainedConfig.get_config_dict(
        model_id, revision=revision, trust_remote_code=trust_remote_code
    )
Nicolas Patry's avatar
Nicolas Patry committed
288
    model_type = config_dict.get("model_type", None)
Nicolas Patry's avatar
Nicolas Patry committed
289

Nicolas Patry's avatar
Nicolas Patry committed
290
    speculator = None
Nicolas Patry's avatar
Nicolas Patry committed
291
    if "medusa_num_heads" in config_dict:
292
293
        medusa_model_id = model_id
        medusa_revision = revision
Nicolas Patry's avatar
Nicolas Patry committed
294
295
296
297
298
        model_id = config_dict["base_model_name_or_path"]
        revision = "main"
        speculate_medusa = config_dict["medusa_num_heads"]
        if speculate is not None:
            if speculate > speculate_medusa:
OlivierDehaene's avatar
OlivierDehaene committed
299
                raise RuntimeError(
OlivierDehaene's avatar
OlivierDehaene committed
300
                    f"Speculate is set to `{speculate}` but this medusa models only has `{speculate_medusa}` heads, please make them match"
OlivierDehaene's avatar
OlivierDehaene committed
301
                )
Nicolas Patry's avatar
Nicolas Patry committed
302
303
304
305
306
307
308
309
            else:
                set_speculate(speculate)
        else:
            set_speculate(speculate_medusa)

        config_dict, _ = PretrainedConfig.get_config_dict(
            model_id, revision=revision, trust_remote_code=trust_remote_code
        )
Nicolas Patry's avatar
Nicolas Patry committed
310
311
        # Reload model type from parent.
        model_type = config_dict.get("model_type", None)
312
313
314
315
316
317
318
319
320
321
        is_local = Path(medusa_model_id).exists()
        if not is_local:
            medusa_config = hf_hub_download(
                medusa_model_id, revision=medusa_revision, filename="config.json"
            )
            hf_hub_download(
                medusa_model_id,
                revision=medusa_revision,
                filename="medusa_lm_head.safetensors",
            )
Nicolas Patry's avatar
Nicolas Patry committed
322
323
324
325
            speculator = {
                "path": Path(medusa_config).parent,
                "model_paths": ["medusa_lm_head.safetensors"],
            }
326
        else:
Nicolas Patry's avatar
Nicolas Patry committed
327
328
329
330
            speculator = {
                "path": Path(medusa_model_id),
                "model_paths": ["medusa_lm_head.safetensors"],
            }
331

Nicolas Patry's avatar
Nicolas Patry committed
332
        method = "medusa"
Nicolas Patry's avatar
Nicolas Patry committed
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
    elif model_type == "mlp_speculator":
        mlp_model_id = model_id
        mlp_revision = revision
        model_id = config_dict["base_model_name_or_path"]
        revision = "main"
        speculate_mlp = config_dict["n_predict"]
        if speculate is not None:
            if speculate > speculate_mlp:
                raise RuntimeError(
                    f"Speculate is set to `{speculate}` but this mlp_speculator models only has `{speculate_mlp}` heads, please make them match"
                )
            else:
                set_speculate(speculate)
        else:
            set_speculate(speculate_mlp)

        config_dict, _ = PretrainedConfig.get_config_dict(
            model_id, revision=revision, trust_remote_code=trust_remote_code
        )
        # Reload model type from parent.
        model_type = config_dict.get("model_type", None)
        is_local = Path(mlp_model_id).exists()
        extension = ".safetensors"
        if not is_local:
            mlp_speculator_config = hf_hub_download(
                mlp_model_id, revision=mlp_revision, filename="config.json"
            )
            api = HfApi()
            info = api.model_info(mlp_model_id, revision=mlp_revision)
            filenames = [
                s.rfilename
                for s in info.siblings
                if s.rfilename.endswith(extension)
                and len(s.rfilename.split("/")) == 1
                and "arguments" not in s.rfilename
                and "args" not in s.rfilename
                and "training" not in s.rfilename
            ]
            for filename in filenames:
                hf_hub_download(
                    mlp_model_id,
                    revision=mlp_revision,
                    filename=filename,
                )
            speculator = {
                "path": Path(mlp_speculator_config).parent,
                "model_paths": filenames,
            }
        else:
            speculator = Path(mlp_model_id)
            filenames = [p for p in os.listdir(speculator) if p.endswith(extension)]
            speculator = {"path": speculator, "model_paths": filenames}
        method = "mlp_speculator"
Nicolas Patry's avatar
Nicolas Patry committed
386
387
388
389
390
391
392
    else:
        method = "n-gram"

    speculate = get_speculate()
    if speculate > 0:
        logger.info(f"Using speculation {method} with {speculate} input ids.")

drbh's avatar
drbh committed
393
394
395
396
397
398
399
400
401
    if model_type is None:
        # TODO: fix how we determine model type for Mamba
        if "ssm_cfg" in config_dict:
            # *only happens in Mamba case
            model_type = "ssm"
        else:
            raise RuntimeError(
                f"Could not determine model type for {model_id} revision {revision}"
            )
402
403
404
    quantization_config = config_dict.get("quantization_config", None)
    if quantization_config is not None and quantize is None:
        method = quantization_config.get("quant_method", None)
405
        if method in {"gptq", "awq", "exl2"}:
406
407
408
409
            logger.info(f"Auto selecting quantization method {method}")
            quantize = method
        else:
            logger.info(f"Unknown quantization method {method}")
drbh's avatar
drbh committed
410

411
412
413
414
415
    if quantize == "exl2" and sharded:
        raise RuntimeError(
            "Sharding is currently not supported with `exl2` quantization"
        )

416
    if model_type == MAMBA:
drbh's avatar
drbh committed
417
418
419
420
        return Mamba(
            model_id,
            revision,
            quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
421
            speculator=speculator,
drbh's avatar
drbh committed
422
423
424
            dtype=dtype,
            trust_remote_code=trust_remote_code,
        )
425

OlivierDehaene's avatar
OlivierDehaene committed
426
427
428
429
430
    if model_id.startswith("facebook/galactica"):
        return GalacticaSharded(
            model_id,
            revision,
            quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
431
            speculator=speculator,
OlivierDehaene's avatar
OlivierDehaene committed
432
433
434
435
            dtype=dtype,
            trust_remote_code=trust_remote_code,
        )

436
    if (
437
438
        model_type == GPT_BIGCODE
        or model_type == GPT2
439
440
        and model_id.startswith("bigcode/")
    ):
441
        if FLASH_ATTENTION:
442
443
444
445
            return FlashSantacoderSharded(
                model_id,
                revision,
                quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
446
                speculator=speculator,
447
                dtype=dtype,
448
449
                trust_remote_code=trust_remote_code,
            )
450
451
452
453
        elif sharded:
            raise NotImplementedError(
                FLASH_ATT_ERROR_MESSAGE.format("Sharded Santacoder")
            )
454
        else:
455
            return SantaCoder(
456
457
458
                model_id,
                revision,
                quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
459
                speculator=speculator,
460
                dtype=dtype,
461
462
                trust_remote_code=trust_remote_code,
            )
463

464
    if model_type == BLOOM:
465
        return BLOOMSharded(
466
467
468
            model_id,
            revision,
            quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
469
            speculator=speculator,
470
471
            dtype=dtype,
            trust_remote_code=trust_remote_code,
472
        )
473
    elif model_type == MPT:
474
        return MPTSharded(
OlivierDehaene's avatar
OlivierDehaene committed
475
476
477
            model_id,
            revision,
            quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
478
            speculator=speculator,
OlivierDehaene's avatar
OlivierDehaene committed
479
480
            dtype=dtype,
            trust_remote_code=trust_remote_code,
481
        )
482
    elif model_type == GPT2:
483
        if FLASH_ATTENTION:
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
            try:
                return FlashGPT2(
                    model_id,
                    revision,
                    quantize=quantize,
                    speculator=speculator,
                    dtype=dtype,
                    trust_remote_code=trust_remote_code,
                )
            except RuntimeError as e:
                # Lots of legacy models with various weight names.
                logger.warning(f"Couldn't load flash gpt2 variant: {e}")
                return CausalLM(
                    model_id,
                    revision,
                    quantize=quantize,
                    speculator=speculator,
                    dtype=dtype,
                    trust_remote_code=trust_remote_code,
                )
504
505
506
507
508
509
510
511
512
513
514
        elif sharded:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded GPT-2"))
        else:
            return CausalLM(
                model_id,
                revision,
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )
515
    elif model_type == GPT_NEOX:
516
517
518
519
520
        if FLASH_ATTENTION:
            return FlashNeoXSharded(
                model_id,
                revision,
                quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
521
                speculator=speculator,
522
                dtype=dtype,
523
524
525
526
                trust_remote_code=trust_remote_code,
            )
        elif sharded:
            return GPTNeoxSharded(
527
528
529
                model_id,
                revision,
                quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
530
                speculator=speculator,
531
                dtype=dtype,
532
533
                trust_remote_code=trust_remote_code,
            )
534
        else:
535
            return CausalLM(
536
537
538
                model_id,
                revision,
                quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
539
                speculator=speculator,
540
                dtype=dtype,
541
542
                trust_remote_code=trust_remote_code,
            )
OlivierDehaene's avatar
OlivierDehaene committed
543

544
    elif model_type == PHI:
drbh's avatar
drbh committed
545
546
547
548
549
        if FLASH_ATTENTION:
            return FlashPhi(
                model_id,
                revision,
                quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
550
                speculator=speculator,
drbh's avatar
drbh committed
551
552
553
554
555
556
557
558
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )
        else:
            return CausalLM(
                model_id,
                revision,
                quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
559
                speculator=speculator,
drbh's avatar
drbh committed
560
561
562
563
564
565
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )

    elif model_type == "phi-msft":
        if FLASH_ATTENTION:
OlivierDehaene's avatar
OlivierDehaene committed
566
567
568
            raise NotImplementedError(
                "Legacy phi-msft is not supported with Flash Attention"
            )
drbh's avatar
drbh committed
569
570
571
572
573
        else:
            return Phi(
                model_id,
                revision,
                quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
574
                speculator=speculator,
drbh's avatar
drbh committed
575
576
577
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )
578

579
    elif model_type == LLAMA or model_type == BAICHUAN or model_type == PHI3:
580
581
        if FLASH_ATTENTION:
            return FlashLlama(
582
583
584
                model_id,
                revision,
                quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
585
                speculator=speculator,
586
                dtype=dtype,
587
588
                trust_remote_code=trust_remote_code,
            )
589
590
        elif sharded:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Llama"))
591
        else:
592
            return CausalLM(
593
594
595
                model_id,
                revision,
                quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
596
                speculator=speculator,
597
                dtype=dtype,
598
599
                trust_remote_code=trust_remote_code,
            )
600
    if model_type == GEMMA:
601
602
603
604
605
        if FLASH_ATTENTION:
            return FlashGemma(
                model_id,
                revision,
                quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
606
                speculator=speculator,
607
608
609
610
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )
        elif sharded:
OlivierDehaene's avatar
OlivierDehaene committed
611
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Gemma"))
612
613
614
615
616
        else:
            return CausalLM(
                model_id,
                revision,
                quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
617
                speculator=speculator,
618
619
620
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )
621

622
    if model_type == COHERE:
OlivierDehaene's avatar
OlivierDehaene committed
623
624
625
626
627
        if FLASH_ATTENTION:
            return FlashCohere(
                model_id,
                revision,
                quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
628
                speculator=speculator,
OlivierDehaene's avatar
OlivierDehaene committed
629
630
631
632
633
634
635
636
637
638
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )
        elif sharded:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Cohere"))
        else:
            return CausalLM(
                model_id,
                revision,
                quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
639
                speculator=speculator,
OlivierDehaene's avatar
OlivierDehaene committed
640
641
642
643
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )

644
    if model_type == DBRX:
645
646
647
648
649
        if FLASH_ATTENTION:
            return FlashDbrx(
                model_id,
                revision,
                quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
650
                speculator=speculator,
651
652
653
654
655
656
657
658
659
660
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )
        elif sharded:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded DBRX"))
        else:
            return CausalLM(
                model_id,
                revision,
                quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
661
                speculator=speculator,
662
663
664
665
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )

666
    if model_type in ["RefinedWeb", "RefinedWebModel", FALCON]:
667
668
        if sharded:
            if FLASH_ATTENTION:
669
                if config_dict.get("alibi", False):
670
671
672
673
674
                    raise NotImplementedError("sharded is not supported for this model")
                return FlashRWSharded(
                    model_id,
                    revision,
                    quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
675
                    speculator=speculator,
676
                    dtype=dtype,
677
678
                    trust_remote_code=trust_remote_code,
                )
679
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format(f"Sharded Falcon"))
680
        else:
681
            if FLASH_ATTENTION and not config_dict.get("alibi", False):
682
                return FlashRWSharded(
683
684
685
                    model_id,
                    revision,
                    quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
686
                    speculator=speculator,
687
                    dtype=dtype,
688
689
690
691
692
693
694
                    trust_remote_code=trust_remote_code,
                )
            else:
                return RW(
                    model_id,
                    revision,
                    quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
695
                    speculator=speculator,
696
                    dtype=dtype,
697
698
699
                    trust_remote_code=trust_remote_code,
                )

700
    if model_type == MISTRAL:
701
702
        sliding_window = config_dict.get("sliding_window", -1)
        if (
fxmarty's avatar
fxmarty committed
703
704
705
706
            ((sliding_window is None or sliding_window == -1) and FLASH_ATTENTION)
            or HAS_FLASH_ATTN_V2_CUDA
            or HAS_FLASH_ATTN_V2_ROCM
        ):
707
708
709
710
            return FlashMistral(
                model_id,
                revision,
                quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
711
                speculator=speculator,
712
713
714
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )
OlivierDehaene's avatar
OlivierDehaene committed
715
716
717
718
719
720
721
        elif sharded:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Mistral"))
        else:
            return CausalLM(
                model_id,
                revision,
                quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
722
                speculator=speculator,
OlivierDehaene's avatar
OlivierDehaene committed
723
724
725
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )
OlivierDehaene's avatar
OlivierDehaene committed
726

727
    if model_type == MIXTRAL:
728
729
        sliding_window = config_dict.get("sliding_window", -1)
        if (
fxmarty's avatar
fxmarty committed
730
731
732
733
            ((sliding_window is None or sliding_window == -1) and FLASH_ATTENTION)
            or HAS_FLASH_ATTN_V2_CUDA
            or HAS_FLASH_ATTN_V2_ROCM
        ):
OlivierDehaene's avatar
OlivierDehaene committed
734
735
736
737
            return FlashMixtral(
                model_id,
                revision,
                quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
738
                speculator=speculator,
OlivierDehaene's avatar
OlivierDehaene committed
739
                dtype=dtype,
OlivierDehaene's avatar
OlivierDehaene committed
740
741
                trust_remote_code=trust_remote_code,
            )
OlivierDehaene's avatar
OlivierDehaene committed
742
743
744
745
746
747
748
        elif sharded:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Mixtral"))
        else:
            return CausalLM(
                model_id,
                revision,
                quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
749
                speculator=speculator,
OlivierDehaene's avatar
OlivierDehaene committed
750
751
752
753
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )

754
    if model_type == STARCODER2:
OlivierDehaene's avatar
OlivierDehaene committed
755
756
        sliding_window = config_dict.get("sliding_window", -1)
        if (
fxmarty's avatar
fxmarty committed
757
758
759
760
            ((sliding_window is None or sliding_window == -1) and FLASH_ATTENTION)
            or HAS_FLASH_ATTN_V2_CUDA
            or HAS_FLASH_ATTN_V2_ROCM
        ):
OlivierDehaene's avatar
OlivierDehaene committed
761
762
763
764
765
            return FlashStarcoder2(
                model_id,
                revision,
                quantize=quantize,
                dtype=dtype,
OlivierDehaene's avatar
OlivierDehaene committed
766
767
768
769
770
771
772
773
774
775
776
                trust_remote_code=trust_remote_code,
            )
        elif sharded:
            raise NotImplementedError(
                FLASH_ATT_ERROR_MESSAGE.format("Sharded Starcoder2")
            )
        else:
            return CausalLM(
                model_id,
                revision,
                quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
777
                speculator=speculator,
OlivierDehaene's avatar
OlivierDehaene committed
778
779
780
781
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )

782
    if model_type == QWEN2:
OlivierDehaene's avatar
OlivierDehaene committed
783
784
        sliding_window = config_dict.get("sliding_window", -1)
        if (
fxmarty's avatar
fxmarty committed
785
786
787
788
            ((sliding_window is None or sliding_window == -1) and FLASH_ATTENTION)
            or HAS_FLASH_ATTN_V2_CUDA
            or HAS_FLASH_ATTN_V2_ROCM
        ):
OlivierDehaene's avatar
OlivierDehaene committed
789
790
791
792
793
794
795
796
797
798
799
800
801
802
            return FlashQwen2(
                model_id,
                revision,
                quantize=quantize,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )
        elif sharded:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Qwen2"))
        else:
            return CausalLM(
                model_id,
                revision,
                quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
803
                speculator=speculator,
OlivierDehaene's avatar
OlivierDehaene committed
804
                dtype=dtype,
OlivierDehaene's avatar
OlivierDehaene committed
805
806
                trust_remote_code=trust_remote_code,
            )
807

808
    if model_type == OPT:
809
        return OPTSharded(
810
811
812
            model_id,
            revision,
            quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
813
            speculator=speculator,
814
815
            dtype=dtype,
            trust_remote_code=trust_remote_code,
816
        )
817

818
    if model_type == T5:
819
820
821
822
        return T5Sharded(
            model_id,
            revision,
            quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
823
            speculator=speculator,
824
            dtype=dtype,
825
826
            trust_remote_code=trust_remote_code,
        )
827
    if model_type == IDEFICS:
828
        if FLASH_ATTENTION:
OlivierDehaene's avatar
OlivierDehaene committed
829
830
831
832
            return IDEFICSSharded(
                model_id,
                revision,
                quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
833
                speculator=speculator,
OlivierDehaene's avatar
OlivierDehaene committed
834
835
836
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )
837
838
        else:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Idefics"))
839
    if model_type == IDEFICS2:
Nicolas Patry's avatar
Nicolas Patry committed
840
841
842
843
844
        if FLASH_ATTENTION:
            return Idefics2(
                model_id,
                revision,
                quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
845
                speculator=speculator,
Nicolas Patry's avatar
Nicolas Patry committed
846
847
848
849
850
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )
        else:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Idefics"))
drbh's avatar
drbh committed
851
852
853
854
855
856
857
858
859
860
861
862
    if model_type == "paligemma":
        if FLASH_ATTENTION:
            return PaliGemma(
                model_id,
                revision,
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )
        else:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Idefics"))
863

864
    if model_type == LLAVA_NEXT:
865
866
867
868
869
        if FLASH_ATTENTION:
            return LlavaNext(
                model_id,
                revision,
                quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
870
                speculator=speculator,
871
872
873
874
875
876
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )
        else:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("LlavaNext"))

877
    if sharded:
878
        raise NotImplementedError("sharded is not supported for AutoModel")
879
    if quantize == "gptq":
880
        raise NotImplementedError(
881
882
            "gptq quantization is not supported for AutoModel, you can try to quantize it with `text-generation-server quantize ORIGINAL_MODEL_ID NEW_MODEL_ID`"
        )
883
    if quantize == "awq":
884
        raise NotImplementedError("awq quantization is not supported for AutoModel")
Nicolas Patry's avatar
Nicolas Patry committed
885
    elif (quantize == "bitsandbytes-fp4") or (quantize == "bitsandbytes-nf4"):
886
        raise NotImplementedError("4bit quantization is not supported for AutoModel")
OlivierDehaene's avatar
OlivierDehaene committed
887
    elif quantize == "eetq":
888
        raise NotImplementedError("Eetq quantization is not supported for AutoModel")
889
890
    elif quantize == "exl2":
        raise NotImplementedError("exl2 quantization is not supported for AutoModel")
891
    if model_type in modeling_auto.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES:
892
        return CausalLM(
893
894
895
            model_id,
            revision,
            quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
896
            speculator=speculator,
897
898
            dtype=dtype,
            trust_remote_code=trust_remote_code,
899
        )
900
    if model_type in modeling_auto.MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES:
901
        return Seq2SeqLM(
902
903
904
            model_id,
            revision,
            quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
905
            speculator=speculator,
906
907
            dtype=dtype,
            trust_remote_code=trust_remote_code,
908
909
        )

910
    auto_map = config_dict.get("auto_map", None)
911
912
913
914
915
916
    if trust_remote_code and auto_map is not None:
        if "AutoModelForCausalLM" in auto_map.keys():
            return CausalLM(
                model_id,
                revision,
                quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
917
                speculator=speculator,
918
                dtype=dtype,
919
920
                trust_remote_code=trust_remote_code,
            )
921
        if "AutoModelForSeq2SeqLM" in auto_map.keys():
922
923
924
925
            return Seq2SeqLM(
                model_id,
                revision,
                quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
926
                speculator=speculator,
927
                dtype=dtype,
928
929
                trust_remote_code=trust_remote_code,
            )
930
931

    raise ValueError(f"Unsupported model type {model_type}")