__init__.py 29.7 KB
Newer Older
1
import torch
2
import enum
Nicolas Patry's avatar
Nicolas Patry committed
3
import os
4

5
from loguru import logger
6
from transformers.configuration_utils import PretrainedConfig
7
from transformers.models.auto import modeling_auto
Nicolas Patry's avatar
Nicolas Patry committed
8
from huggingface_hub import hf_hub_download, HfApi
9
from typing import Optional
10
from pathlib import Path
11

Nicolas Patry's avatar
Nicolas Patry committed
12
from text_generation_server.utils.speculate import get_speculate, set_speculate
13
14
from text_generation_server.models.model import Model
from text_generation_server.models.causal_lm import CausalLM
15
from text_generation_server.models.flash_causal_lm import FlashCausalLM
16
from text_generation_server.models.bloom import BLOOMSharded
17
from text_generation_server.models.mpt import MPTSharded
18
from text_generation_server.models.seq2seq_lm import Seq2SeqLM
19
from text_generation_server.models.rw import RW
20
21
from text_generation_server.models.opt import OPTSharded
from text_generation_server.models.galactica import GalacticaSharded
22
23
from text_generation_server.models.santacoder import SantaCoder
from text_generation_server.models.t5 import T5Sharded
24
from text_generation_server.models.gpt_neox import GPTNeoxSharded
drbh's avatar
drbh committed
25
from text_generation_server.models.phi import Phi
26

27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
# Allow TF32 in CUDA matmuls. The flag below controls whether to allow TF32 on
# matmul; it defaults to False in PyTorch 1.12 and later, so it is switched on
# explicitly here.
torch.backends.cuda.matmul.allow_tf32 = True

# The flag below controls whether to allow TF32 on cuDNN. This flag defaults to
# True; it is set explicitly so both TF32 switches are stated side by side.
torch.backends.cudnn.allow_tf32 = True

# Disable gradients globally at import time (this module only loads models for
# generation; nothing here runs a backward pass).
torch.set_grad_enabled(False)

# Names exported by `from <this package> import *`. Entries are attribute
# *names* (strings). The optional flash-attention / Mamba sections further down
# the file extend this list when their imports succeed.
__all__ = [
    "Model",
    "BLOOMSharded",
    "CausalLM",
    "GalacticaSharded",
    "Seq2SeqLM",
    "SantaCoder",
    "OPTSharded",
    "T5Sharded",
    "get_model",
]

49
# Template for the error raised when a model/configuration can only be served
# through a flash-attention implementation; `{}` receives a description of it.
FLASH_ATT_ERROR_MESSAGE = "{} requires Flash Attention enabled models."

# Optimistically assume the flash-attention model classes can be imported; the
# `except` branch below flips this to False when any of the imports fails.
FLASH_ATTENTION = True

try:
    # Import order is kept exactly as before: the first failing import decides
    # which error message ends up in the warning.
    from text_generation_server.models.flash_rw import FlashRWSharded
    from text_generation_server.models.flash_gpt2 import FlashGPT2
    from text_generation_server.models.flash_neox import FlashNeoXSharded
    from text_generation_server.models.flash_llama import FlashLlama
    from text_generation_server.models.flash_qwen2 import FlashQwen2
    from text_generation_server.models.flash_cohere import FlashCohere
    from text_generation_server.models.flash_gemma import FlashGemma
    from text_generation_server.models.pali_gemma import PaliGemma
    from text_generation_server.models.flash_santacoder import FlashSantacoderSharded
    from text_generation_server.models.idefics import IDEFICSSharded
    from text_generation_server.models.llava_next import LlavaNext
    from text_generation_server.models.idefics2 import Idefics2
    from text_generation_server.models.flash_mistral import FlashMistral
    from text_generation_server.models.flash_mixtral import FlashMixtral
    from text_generation_server.models.flash_phi import FlashPhi
    from text_generation_server.models.flash_starcoder2 import FlashStarcoder2
    from text_generation_server.models.flash_dbrx import FlashDbrx
    from text_generation_server.layers.attention import SUPPORTS_WINDOWING
except ImportError as e:
    # Any single failure disables the whole flash-attention path: the names
    # imported after the failing line are left undefined, so the rest of the
    # module must consult FLASH_ATTENTION before touching them.
    logger.warning(f"Could not import Flash Attention enabled models: {e}")
    FLASH_ATTENTION = False
    SUPPORTS_WINDOWING = False
88

89
if FLASH_ATTENTION:
    # Export the flash-attention model classes only when their imports
    # succeeded. `__all__` must contain attribute *names* (strings): putting
    # the class objects themselves in the list makes
    # `from text_generation_server.models import *` raise TypeError, and hides
    # the names from tools that read `__all__` statically.
    __all__.extend(
        [
            "FlashGPT2",
            "FlashNeoXSharded",
            "FlashRWSharded",
            "FlashSantacoderSharded",
            "FlashLlama",
            "IDEFICSSharded",
            "FlashMistral",
            "FlashMixtral",
            "FlashDbrx",
            "FlashPhi",
            "FlashQwen2",
            "FlashStarcoder2",
            "FlashGemma",
            "FlashCohere",
        ]
    )
OlivierDehaene's avatar
OlivierDehaene committed
104

drbh's avatar
drbh committed
105
106
107
108
109
110
111
112
113
# Mamba is optional: assume it is importable and flip the flag to False when
# the import fails, so the rest of the module can check MAMBA_AVAILABLE.
MAMBA_AVAILABLE = True
try:
    from text_generation_server.models.mamba import Mamba
except ImportError as e:
    logger.warning(f"Could not import Mamba: {e}")
    MAMBA_AVAILABLE = False

if MAMBA_AVAILABLE:
    # `__all__` entries must be strings (attribute names), not the class
    # object itself; a non-string entry breaks `from ... import *`.
    __all__.append("Mamba")
OlivierDehaene's avatar
OlivierDehaene committed
114

115

116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
class ModelType(enum.Enum):
    """Catalogue of the model architectures this module knows how to dispatch.

    Each member's value is a dict with:

    * ``"type"``: the string compared against the ``model_type`` field of a
      model's configuration,
    * ``"name"``: a human-readable display name,
    * ``"url"``: a link to a representative model page,
    * ``"multimodal"`` (optional): ``True`` for the image+text entries; the key
      is absent on the other members.
    """

    IDEFICS2 = {
        "type": "idefics2",
        "name": "Idefics 2",
        "url": "https://huggingface.co/HuggingFaceM4/idefics2-8b",
        "multimodal": True,
    }
    LLAVA_NEXT = {
        "type": "llava_next",
        "name": "Llava Next (1.6)",
        "url": "https://huggingface.co/llava-hf/llava-v1.6-vicuna-13b-hf",
        "multimodal": True,
    }
    LLAMA = {
        "type": "llama",
        "name": "Llama",
        "url": "https://huggingface.co/meta-llama/Meta-Llama-3-8B-Instruct",
    }
    PHI3 = {
        "type": "phi3",
        "name": "Phi 3",
        "url": "https://huggingface.co/microsoft/Phi-3-mini-4k-instruct",
    }
    GEMMA = {
        "type": "gemma",
        "name": "Gemma",
        "url": "https://huggingface.co/google/gemma-7b",
    }
    COHERE = {
        "type": "cohere",
        "name": "Cohere",
        "url": "https://huggingface.co/CohereForAI/c4ai-command-r-plus",
    }
    DBRX = {
        "type": "dbrx",
        "name": "Dbrx",
        "url": "https://huggingface.co/databricks/dbrx-instruct",
    }
    # Mamba checkpoints are identified by the "ssm" type string rather than by
    # the display name.
    MAMBA = {
        "type": "ssm",
        "name": "Mamba",
        "url": "https://huggingface.co/state-spaces/mamba-2.8b-slimpj",
    }
    MISTRAL = {
        "type": "mistral",
        "name": "Mistral",
        "url": "https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.2",
    }
    MIXTRAL = {
        "type": "mixtral",
        "name": "Mixtral",
        "url": "https://huggingface.co/mistralai/Mixtral-8x22B-Instruct-v0.1",
    }
    GPT_BIGCODE = {
        "type": "gpt_bigcode",
        "name": "Gpt Bigcode",
        "url": "https://huggingface.co/bigcode/gpt_bigcode-santacoder",
    }
    PHI = {
        "type": "phi",
        "name": "Phi",
        "url": "https://huggingface.co/microsoft/phi-1_5",
    }
    BAICHUAN = {
        "type": "baichuan",
        "name": "Baichuan",
        "url": "https://huggingface.co/baichuan-inc/Baichuan2-7B-Chat",
    }
    FALCON = {
        "type": "falcon",
        "name": "Falcon",
        "url": "https://huggingface.co/tiiuae/falcon-7b-instruct",
    }
    STARCODER2 = {
        "type": "starcoder2",
        "name": "StarCoder 2",
        "url": "https://huggingface.co/bigcode/starcoder2-15b-instruct-v0.1",
    }
    # The URL here used to be a copy of the StarCoder 2 link; it now points at
    # a Qwen 2 model page.
    QWEN2 = {
        "type": "qwen2",
        "name": "Qwen 2",
        "url": "https://huggingface.co/Qwen/Qwen2-7B-Instruct",
    }
    OPT = {
        "type": "opt",
        "name": "Opt",
        "url": "https://huggingface.co/facebook/opt-6.7b",
    }
    T5 = {
        "type": "t5",
        "name": "T5",
        "url": "https://huggingface.co/google/flan-t5-xxl",
    }
    GALACTICA = {
        "type": "galactica",
        "name": "Galactica",
        "url": "https://huggingface.co/facebook/galactica-120b",
    }
    SANTACODER = {
        "type": "santacoder",
        "name": "SantaCoder",
        "url": "https://huggingface.co/bigcode/santacoder",
    }
    BLOOM = {
        "type": "bloom",
        "name": "Bloom",
        "url": "https://huggingface.co/bigscience/bloom-560m",
    }
    MPT = {
        "type": "mpt",
        "name": "Mpt",
        "url": "https://huggingface.co/mosaicml/mpt-7b-instruct",
    }
    GPT2 = {
        "type": "gpt2",
        "name": "Gpt2",
        "url": "https://huggingface.co/openai-community/gpt2",
    }
    GPT_NEOX = {
        "type": "gpt_neox",
        "name": "Gpt Neox",
        "url": "https://huggingface.co/EleutherAI/gpt-neox-20b",
    }
    IDEFICS = {
        "type": "idefics",
        "name": "Idefics",
        "url": "https://huggingface.co/HuggingFaceM4/idefics-9b",
        "multimodal": True,
    }


# Publish one module-level constant per ModelType member, named after the
# member and bound to its "type" string (e.g. LLAMA == "llama"), so the
# dispatch code can compare a config's model type against bare names.
# At module scope `locals()` is the module namespace itself.
__GLOBALS = locals()
for data in ModelType:
    __GLOBALS.update({data.name: data.value["type"]})


252
def get_model(
253
254
255
256
    model_id: str,
    revision: Optional[str],
    sharded: bool,
    quantize: Optional[str],
Nicolas Patry's avatar
Nicolas Patry committed
257
    speculate: Optional[int],
258
    dtype: Optional[str],
259
    trust_remote_code: bool,
260
) -> Model:
261
    global FLASH_ATTENTION
262
    if dtype is None:
263
        if quantize in ["awq", "exl2", "gptq"]:
264
265
266
267
268
269
            # These quantizers only work with float16 params.
            dtype = torch.float16
        else:
            # Keep it as default for now and let
            # every model resolve their own default dtype.
            dtype = None
270
271
272
273
274
275
276
    elif dtype == "float16":
        dtype = torch.float16
    elif dtype == "bfloat16":
        dtype = torch.bfloat16
    else:
        raise RuntimeError(f"Unknown dtype {dtype}")

Nicolas Patry's avatar
Nicolas Patry committed
277
278
279
280
281
    if speculate is not None:
        set_speculate(speculate)
    else:
        set_speculate(0)

OlivierDehaene's avatar
v0.8.2  
OlivierDehaene committed
282
283
284
    config_dict, _ = PretrainedConfig.get_config_dict(
        model_id, revision=revision, trust_remote_code=trust_remote_code
    )
Nicolas Patry's avatar
Nicolas Patry committed
285
    model_type = config_dict.get("model_type", None)
Nicolas Patry's avatar
Nicolas Patry committed
286

Nicolas Patry's avatar
Nicolas Patry committed
287
    speculator = None
Nicolas Patry's avatar
Nicolas Patry committed
288
    if "medusa_num_heads" in config_dict:
289
290
        medusa_model_id = model_id
        medusa_revision = revision
Nicolas Patry's avatar
Nicolas Patry committed
291
292
293
294
295
        model_id = config_dict["base_model_name_or_path"]
        revision = "main"
        speculate_medusa = config_dict["medusa_num_heads"]
        if speculate is not None:
            if speculate > speculate_medusa:
OlivierDehaene's avatar
OlivierDehaene committed
296
                raise RuntimeError(
OlivierDehaene's avatar
OlivierDehaene committed
297
                    f"Speculate is set to `{speculate}` but this medusa models only has `{speculate_medusa}` heads, please make them match"
OlivierDehaene's avatar
OlivierDehaene committed
298
                )
Nicolas Patry's avatar
Nicolas Patry committed
299
300
301
302
303
304
305
306
            else:
                set_speculate(speculate)
        else:
            set_speculate(speculate_medusa)

        config_dict, _ = PretrainedConfig.get_config_dict(
            model_id, revision=revision, trust_remote_code=trust_remote_code
        )
Nicolas Patry's avatar
Nicolas Patry committed
307
308
        # Reload model type from parent.
        model_type = config_dict.get("model_type", None)
309
310
311
312
313
314
315
316
317
318
        is_local = Path(medusa_model_id).exists()
        if not is_local:
            medusa_config = hf_hub_download(
                medusa_model_id, revision=medusa_revision, filename="config.json"
            )
            hf_hub_download(
                medusa_model_id,
                revision=medusa_revision,
                filename="medusa_lm_head.safetensors",
            )
Nicolas Patry's avatar
Nicolas Patry committed
319
320
321
322
            speculator = {
                "path": Path(medusa_config).parent,
                "model_paths": ["medusa_lm_head.safetensors"],
            }
323
        else:
Nicolas Patry's avatar
Nicolas Patry committed
324
325
326
327
            speculator = {
                "path": Path(medusa_model_id),
                "model_paths": ["medusa_lm_head.safetensors"],
            }
328

Nicolas Patry's avatar
Nicolas Patry committed
329
        method = "medusa"
Nicolas Patry's avatar
Nicolas Patry committed
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
    elif model_type == "mlp_speculator":
        mlp_model_id = model_id
        mlp_revision = revision
        model_id = config_dict["base_model_name_or_path"]
        revision = "main"
        speculate_mlp = config_dict["n_predict"]
        if speculate is not None:
            if speculate > speculate_mlp:
                raise RuntimeError(
                    f"Speculate is set to `{speculate}` but this mlp_speculator models only has `{speculate_mlp}` heads, please make them match"
                )
            else:
                set_speculate(speculate)
        else:
            set_speculate(speculate_mlp)

        config_dict, _ = PretrainedConfig.get_config_dict(
            model_id, revision=revision, trust_remote_code=trust_remote_code
        )
        # Reload model type from parent.
        model_type = config_dict.get("model_type", None)
        is_local = Path(mlp_model_id).exists()
        extension = ".safetensors"
        if not is_local:
            mlp_speculator_config = hf_hub_download(
                mlp_model_id, revision=mlp_revision, filename="config.json"
            )
            api = HfApi()
            info = api.model_info(mlp_model_id, revision=mlp_revision)
            filenames = [
                s.rfilename
                for s in info.siblings
                if s.rfilename.endswith(extension)
                and len(s.rfilename.split("/")) == 1
                and "arguments" not in s.rfilename
                and "args" not in s.rfilename
                and "training" not in s.rfilename
            ]
            for filename in filenames:
                hf_hub_download(
                    mlp_model_id,
                    revision=mlp_revision,
                    filename=filename,
                )
            speculator = {
                "path": Path(mlp_speculator_config).parent,
                "model_paths": filenames,
            }
        else:
            speculator = Path(mlp_model_id)
            filenames = [p for p in os.listdir(speculator) if p.endswith(extension)]
            speculator = {"path": speculator, "model_paths": filenames}
        method = "mlp_speculator"
Nicolas Patry's avatar
Nicolas Patry committed
383
384
385
386
387
388
389
    else:
        method = "n-gram"

    speculate = get_speculate()
    if speculate > 0:
        logger.info(f"Using speculation {method} with {speculate} input ids.")

drbh's avatar
drbh committed
390
391
392
393
394
395
396
397
398
    if model_type is None:
        # TODO: fix how we determine model type for Mamba
        if "ssm_cfg" in config_dict:
            # *only happens in Mamba case
            model_type = "ssm"
        else:
            raise RuntimeError(
                f"Could not determine model type for {model_id} revision {revision}"
            )
399
400
401
    quantization_config = config_dict.get("quantization_config", None)
    if quantization_config is not None and quantize is None:
        method = quantization_config.get("quant_method", None)
402
        if method in {"gptq", "awq", "exl2"}:
403
404
405
406
            logger.info(f"Auto selecting quantization method {method}")
            quantize = method
        else:
            logger.info(f"Unknown quantization method {method}")
drbh's avatar
drbh committed
407

408
409
410
411
    if quantize == "exl2" and sharded:
        raise RuntimeError(
            "Sharding is currently not supported with `exl2` quantization"
        )
412
413
414
415
416
417
    sliding_window = config_dict.get("sliding_window", -1)
    if sliding_window != -1 and not SUPPORTS_WINDOWING:
        logger.warning(
            f"Flash attention is available, but doesn't support windowing which is required by model {model_id}"
        )
        FLASH_ATTENTION = False
418

419
    if model_type == MAMBA:
drbh's avatar
drbh committed
420
421
422
423
        return Mamba(
            model_id,
            revision,
            quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
424
            speculator=speculator,
drbh's avatar
drbh committed
425
426
427
            dtype=dtype,
            trust_remote_code=trust_remote_code,
        )
428

OlivierDehaene's avatar
OlivierDehaene committed
429
430
431
432
433
    if model_id.startswith("facebook/galactica"):
        return GalacticaSharded(
            model_id,
            revision,
            quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
434
            speculator=speculator,
OlivierDehaene's avatar
OlivierDehaene committed
435
436
437
438
            dtype=dtype,
            trust_remote_code=trust_remote_code,
        )

439
    if (
440
441
        model_type == GPT_BIGCODE
        or model_type == GPT2
442
443
        and model_id.startswith("bigcode/")
    ):
444
        if FLASH_ATTENTION:
445
446
447
448
            return FlashSantacoderSharded(
                model_id,
                revision,
                quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
449
                speculator=speculator,
450
                dtype=dtype,
451
452
                trust_remote_code=trust_remote_code,
            )
453
454
455
456
        elif sharded:
            raise NotImplementedError(
                FLASH_ATT_ERROR_MESSAGE.format("Sharded Santacoder")
            )
457
        else:
458
            return SantaCoder(
459
460
461
                model_id,
                revision,
                quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
462
                speculator=speculator,
463
                dtype=dtype,
464
465
                trust_remote_code=trust_remote_code,
            )
466

467
    if model_type == BLOOM:
468
        return BLOOMSharded(
469
470
471
            model_id,
            revision,
            quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
472
            speculator=speculator,
473
474
            dtype=dtype,
            trust_remote_code=trust_remote_code,
475
        )
476
    elif model_type == MPT:
477
        return MPTSharded(
OlivierDehaene's avatar
OlivierDehaene committed
478
479
480
            model_id,
            revision,
            quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
481
            speculator=speculator,
OlivierDehaene's avatar
OlivierDehaene committed
482
483
            dtype=dtype,
            trust_remote_code=trust_remote_code,
484
        )
485
    elif model_type == GPT2:
486
        if FLASH_ATTENTION:
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
            try:
                return FlashGPT2(
                    model_id,
                    revision,
                    quantize=quantize,
                    speculator=speculator,
                    dtype=dtype,
                    trust_remote_code=trust_remote_code,
                )
            except RuntimeError as e:
                # Lots of legacy models with various weight names.
                logger.warning(f"Couldn't load flash gpt2 variant: {e}")
                return CausalLM(
                    model_id,
                    revision,
                    quantize=quantize,
                    speculator=speculator,
                    dtype=dtype,
                    trust_remote_code=trust_remote_code,
                )
507
508
509
510
511
512
513
514
515
516
517
        elif sharded:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded GPT-2"))
        else:
            return CausalLM(
                model_id,
                revision,
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )
518
    elif model_type == GPT_NEOX:
519
520
521
522
523
        if FLASH_ATTENTION:
            return FlashNeoXSharded(
                model_id,
                revision,
                quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
524
                speculator=speculator,
525
                dtype=dtype,
526
527
528
529
                trust_remote_code=trust_remote_code,
            )
        elif sharded:
            return GPTNeoxSharded(
530
531
532
                model_id,
                revision,
                quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
533
                speculator=speculator,
534
                dtype=dtype,
535
536
                trust_remote_code=trust_remote_code,
            )
537
        else:
538
            return CausalLM(
539
540
541
                model_id,
                revision,
                quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
542
                speculator=speculator,
543
                dtype=dtype,
544
545
                trust_remote_code=trust_remote_code,
            )
OlivierDehaene's avatar
OlivierDehaene committed
546

547
    elif model_type == PHI:
drbh's avatar
drbh committed
548
549
550
551
552
        if FLASH_ATTENTION:
            return FlashPhi(
                model_id,
                revision,
                quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
553
                speculator=speculator,
drbh's avatar
drbh committed
554
555
556
557
558
559
560
561
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )
        else:
            return CausalLM(
                model_id,
                revision,
                quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
562
                speculator=speculator,
drbh's avatar
drbh committed
563
564
565
566
567
568
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )

    elif model_type == "phi-msft":
        if FLASH_ATTENTION:
OlivierDehaene's avatar
OlivierDehaene committed
569
570
571
            raise NotImplementedError(
                "Legacy phi-msft is not supported with Flash Attention"
            )
drbh's avatar
drbh committed
572
573
574
575
576
        else:
            return Phi(
                model_id,
                revision,
                quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
577
                speculator=speculator,
drbh's avatar
drbh committed
578
579
580
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )
581

582
    elif model_type == LLAMA or model_type == BAICHUAN or model_type == PHI3:
583
584
        if FLASH_ATTENTION:
            return FlashLlama(
585
586
587
                model_id,
                revision,
                quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
588
                speculator=speculator,
589
                dtype=dtype,
590
591
                trust_remote_code=trust_remote_code,
            )
592
593
        elif sharded:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Llama"))
594
        else:
595
            return CausalLM(
596
597
598
                model_id,
                revision,
                quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
599
                speculator=speculator,
600
                dtype=dtype,
601
602
                trust_remote_code=trust_remote_code,
            )
603
    if model_type == GEMMA:
604
605
606
607
608
        if FLASH_ATTENTION:
            return FlashGemma(
                model_id,
                revision,
                quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
609
                speculator=speculator,
610
611
612
613
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )
        elif sharded:
OlivierDehaene's avatar
OlivierDehaene committed
614
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Gemma"))
615
616
617
618
619
        else:
            return CausalLM(
                model_id,
                revision,
                quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
620
                speculator=speculator,
621
622
623
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )
624

625
    if model_type == COHERE:
OlivierDehaene's avatar
OlivierDehaene committed
626
627
628
629
630
        if FLASH_ATTENTION:
            return FlashCohere(
                model_id,
                revision,
                quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
631
                speculator=speculator,
OlivierDehaene's avatar
OlivierDehaene committed
632
633
634
635
636
637
638
639
640
641
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )
        elif sharded:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Cohere"))
        else:
            return CausalLM(
                model_id,
                revision,
                quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
642
                speculator=speculator,
OlivierDehaene's avatar
OlivierDehaene committed
643
644
645
646
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )

647
    if model_type == DBRX:
648
649
650
651
652
        if FLASH_ATTENTION:
            return FlashDbrx(
                model_id,
                revision,
                quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
653
                speculator=speculator,
654
655
656
657
658
659
660
661
662
663
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )
        elif sharded:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded DBRX"))
        else:
            return CausalLM(
                model_id,
                revision,
                quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
664
                speculator=speculator,
665
666
667
668
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )

669
    if model_type in ["RefinedWeb", "RefinedWebModel", FALCON]:
670
671
        if sharded:
            if FLASH_ATTENTION:
672
                if config_dict.get("alibi", False):
673
674
675
676
677
                    raise NotImplementedError("sharded is not supported for this model")
                return FlashRWSharded(
                    model_id,
                    revision,
                    quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
678
                    speculator=speculator,
679
                    dtype=dtype,
680
681
                    trust_remote_code=trust_remote_code,
                )
682
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format(f"Sharded Falcon"))
683
        else:
684
            if FLASH_ATTENTION and not config_dict.get("alibi", False):
685
                return FlashRWSharded(
686
687
688
                    model_id,
                    revision,
                    quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
689
                    speculator=speculator,
690
                    dtype=dtype,
691
692
693
694
695
696
697
                    trust_remote_code=trust_remote_code,
                )
            else:
                return RW(
                    model_id,
                    revision,
                    quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
698
                    speculator=speculator,
699
                    dtype=dtype,
700
701
702
                    trust_remote_code=trust_remote_code,
                )

703
    if model_type == MISTRAL:
704
        sliding_window = config_dict.get("sliding_window", -1)
705
        if FLASH_ATTENTION:
706
707
708
709
            return FlashMistral(
                model_id,
                revision,
                quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
710
                speculator=speculator,
711
712
713
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )
OlivierDehaene's avatar
OlivierDehaene committed
714
715
716
717
718
719
720
        elif sharded:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Mistral"))
        else:
            return CausalLM(
                model_id,
                revision,
                quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
721
                speculator=speculator,
OlivierDehaene's avatar
OlivierDehaene committed
722
723
724
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )
OlivierDehaene's avatar
OlivierDehaene committed
725

726
    if model_type == MIXTRAL:
727
        sliding_window = config_dict.get("sliding_window", -1)
728
        if FLASH_ATTENTION:
OlivierDehaene's avatar
OlivierDehaene committed
729
730
731
732
            return FlashMixtral(
                model_id,
                revision,
                quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
733
                speculator=speculator,
OlivierDehaene's avatar
OlivierDehaene committed
734
                dtype=dtype,
OlivierDehaene's avatar
OlivierDehaene committed
735
736
                trust_remote_code=trust_remote_code,
            )
OlivierDehaene's avatar
OlivierDehaene committed
737
738
739
740
741
742
743
        elif sharded:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Mixtral"))
        else:
            return CausalLM(
                model_id,
                revision,
                quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
744
                speculator=speculator,
OlivierDehaene's avatar
OlivierDehaene committed
745
746
747
748
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )

749
    if model_type == STARCODER2:
OlivierDehaene's avatar
OlivierDehaene committed
750
        sliding_window = config_dict.get("sliding_window", -1)
751
        if FLASH_ATTENTION:
OlivierDehaene's avatar
OlivierDehaene committed
752
753
754
755
756
            return FlashStarcoder2(
                model_id,
                revision,
                quantize=quantize,
                dtype=dtype,
OlivierDehaene's avatar
OlivierDehaene committed
757
758
759
760
761
762
763
764
765
766
767
                trust_remote_code=trust_remote_code,
            )
        elif sharded:
            raise NotImplementedError(
                FLASH_ATT_ERROR_MESSAGE.format("Sharded Starcoder2")
            )
        else:
            return CausalLM(
                model_id,
                revision,
                quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
768
                speculator=speculator,
OlivierDehaene's avatar
OlivierDehaene committed
769
770
771
772
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )

773
    if model_type == QWEN2:
OlivierDehaene's avatar
OlivierDehaene committed
774
        sliding_window = config_dict.get("sliding_window", -1)
775
        if (sliding_window is None or sliding_window != -1) and SUPPORTS_WINDOWING:
OlivierDehaene's avatar
OlivierDehaene committed
776
777
778
779
780
781
782
783
784
785
786
787
788
789
            return FlashQwen2(
                model_id,
                revision,
                quantize=quantize,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )
        elif sharded:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Qwen2"))
        else:
            return CausalLM(
                model_id,
                revision,
                quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
790
                speculator=speculator,
OlivierDehaene's avatar
OlivierDehaene committed
791
                dtype=dtype,
OlivierDehaene's avatar
OlivierDehaene committed
792
793
                trust_remote_code=trust_remote_code,
            )
794

795
    if model_type == OPT:
796
        return OPTSharded(
797
798
799
            model_id,
            revision,
            quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
800
            speculator=speculator,
801
802
            dtype=dtype,
            trust_remote_code=trust_remote_code,
803
        )
804

805
    if model_type == T5:
806
807
808
809
        return T5Sharded(
            model_id,
            revision,
            quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
810
            speculator=speculator,
811
            dtype=dtype,
812
813
            trust_remote_code=trust_remote_code,
        )
814
    if model_type == IDEFICS:
815
        if FLASH_ATTENTION:
OlivierDehaene's avatar
OlivierDehaene committed
816
817
818
819
            return IDEFICSSharded(
                model_id,
                revision,
                quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
820
                speculator=speculator,
OlivierDehaene's avatar
OlivierDehaene committed
821
822
823
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )
824
825
        else:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Idefics"))
826
    if model_type == IDEFICS2:
Nicolas Patry's avatar
Nicolas Patry committed
827
828
829
830
831
        if FLASH_ATTENTION:
            return Idefics2(
                model_id,
                revision,
                quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
832
                speculator=speculator,
Nicolas Patry's avatar
Nicolas Patry committed
833
834
835
836
837
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )
        else:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Idefics"))
drbh's avatar
drbh committed
838
839
840
841
842
843
844
845
846
847
848
849
    if model_type == "paligemma":
        if FLASH_ATTENTION:
            return PaliGemma(
                model_id,
                revision,
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )
        else:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Idefics"))
850

851
    if model_type == LLAVA_NEXT:
852
853
854
855
856
        if FLASH_ATTENTION:
            return LlavaNext(
                model_id,
                revision,
                quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
857
                speculator=speculator,
858
859
860
861
862
863
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )
        else:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("LlavaNext"))

864
    if sharded:
865
        raise NotImplementedError("sharded is not supported for AutoModel")
866
    if quantize == "gptq":
867
        raise NotImplementedError(
868
869
            "gptq quantization is not supported for AutoModel, you can try to quantize it with `text-generation-server quantize ORIGINAL_MODEL_ID NEW_MODEL_ID`"
        )
870
    if quantize == "awq":
871
        raise NotImplementedError("awq quantization is not supported for AutoModel")
Nicolas Patry's avatar
Nicolas Patry committed
872
    elif (quantize == "bitsandbytes-fp4") or (quantize == "bitsandbytes-nf4"):
873
        raise NotImplementedError("4bit quantization is not supported for AutoModel")
OlivierDehaene's avatar
OlivierDehaene committed
874
    elif quantize == "eetq":
875
        raise NotImplementedError("Eetq quantization is not supported for AutoModel")
876
877
    elif quantize == "exl2":
        raise NotImplementedError("exl2 quantization is not supported for AutoModel")
878
    if model_type in modeling_auto.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES:
879
        return CausalLM(
880
881
882
            model_id,
            revision,
            quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
883
            speculator=speculator,
884
885
            dtype=dtype,
            trust_remote_code=trust_remote_code,
886
        )
887
    if model_type in modeling_auto.MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES:
888
        return Seq2SeqLM(
889
890
891
            model_id,
            revision,
            quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
892
            speculator=speculator,
893
894
            dtype=dtype,
            trust_remote_code=trust_remote_code,
895
896
        )

897
    auto_map = config_dict.get("auto_map", None)
898
899
900
901
902
903
    if trust_remote_code and auto_map is not None:
        if "AutoModelForCausalLM" in auto_map.keys():
            return CausalLM(
                model_id,
                revision,
                quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
904
                speculator=speculator,
905
                dtype=dtype,
906
907
                trust_remote_code=trust_remote_code,
            )
908
        if "AutoModelForSeq2SeqLM" in auto_map.keys():
909
910
911
912
            return Seq2SeqLM(
                model_id,
                revision,
                quantize=quantize,
Nicolas Patry's avatar
Nicolas Patry committed
913
                speculator=speculator,
914
                dtype=dtype,
915
916
                trust_remote_code=trust_remote_code,
            )
917
918

    raise ValueError(f"Unsupported model type {model_type}")