__init__.py 24.6 KB
Newer Older
1
import torch
Nicolas Patry's avatar
Nicolas Patry committed
2
import os
3

4
from loguru import logger
5
from transformers.configuration_utils import PretrainedConfig
6
from transformers.models.auto import modeling_auto
Nicolas Patry's avatar
Nicolas Patry committed
7
from huggingface_hub import hf_hub_download, HfApi
8
from typing import Optional
9
from pathlib import Path
10

Nicolas Patry's avatar
Nicolas Patry committed
11
from text_generation_server.utils.speculate import get_speculate, set_speculate
12
13
from text_generation_server.models.model import Model
from text_generation_server.models.causal_lm import CausalLM
14
from text_generation_server.models.flash_causal_lm import FlashCausalLM
15
from text_generation_server.models.bloom import BLOOMSharded
16
from text_generation_server.models.mpt import MPTSharded
17
from text_generation_server.models.seq2seq_lm import Seq2SeqLM
18
from text_generation_server.models.rw import RW
19
20
from text_generation_server.models.opt import OPTSharded
from text_generation_server.models.galactica import GalacticaSharded
21
22
from text_generation_server.models.santacoder import SantaCoder
from text_generation_server.models.t5 import T5Sharded
23
from text_generation_server.models.gpt_neox import GPTNeoxSharded
drbh's avatar
drbh committed
24
from text_generation_server.models.phi import Phi
25

26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
# Allow TensorFloat-32 in CUDA matmuls. PyTorch 1.12 and later ship with this
# switched off, so it has to be opted into explicitly.
torch.backends.cuda.matmul.allow_tf32 = True
# Same switch for cuDNN; PyTorch already defaults it to True, the assignment
# only makes the choice explicit.
torch.backends.cudnn.allow_tf32 = True

# Turn autograd off. Grad mode is thread-local in PyTorch, so this applies to
# the thread that imports this package.
torch.set_grad_enabled(False)

# Public names of this package for `import *`. Optional model classes are
# added further down in this file when their imports succeed.
__all__ = [
    "Model",
    "BLOOMSharded",
    "CausalLM",
    "GalacticaSharded",
    "Seq2SeqLM",
    "SantaCoder",
    "OPTSharded",
    "T5Sharded",
    "get_model",
]

# Error template for models that need flash attention; `{}` receives a short
# human-readable model description (e.g. "Sharded Llama").
FLASH_ATT_ERROR_MESSAGE = "{} requires Flash Attention enabled models."
# Flash-attention based models depend on optional compiled kernels. Try to
# import all of them; if any import fails, fall back to the non-flash
# implementations by clearing FLASH_ATTENTION (and HAS_FLASH_ATTN_V2_CUDA).
FLASH_ATTENTION = True

try:
    from text_generation_server.models.flash_rw import FlashRWSharded
    from text_generation_server.models.flash_gpt2 import FlashGPT2
    from text_generation_server.models.flash_neox import FlashNeoXSharded
    from text_generation_server.models.flash_llama import (
        FlashLlama,
    )
    from text_generation_server.models.flash_qwen2 import (
        FlashQwen2,
    )
    from text_generation_server.models.flash_cohere import (
        FlashCohere,
    )
    from text_generation_server.models.flash_gemma import (
        FlashGemma,
    )
    from text_generation_server.models.flash_santacoder import (
        FlashSantacoderSharded,
    )
    from text_generation_server.models.idefics import IDEFICSSharded
    from text_generation_server.models.llava_next import LlavaNext
    from text_generation_server.models.idefics2 import Idefics2
    from text_generation_server.models.flash_mistral import FlashMistral
    from text_generation_server.models.flash_mixtral import FlashMixtral
    from text_generation_server.models.flash_phi import FlashPhi
    from text_generation_server.models.flash_starcoder2 import FlashStarcoder2
    from text_generation_server.models.flash_dbrx import FlashDbrx
    from text_generation_server.utils.flash_attn import HAS_FLASH_ATTN_V2_CUDA

except ImportError as e:
    logger.warning(f"Could not import Flash Attention enabled models: {e}")
    FLASH_ATTENTION = False
    HAS_FLASH_ATTN_V2_CUDA = False

if FLASH_ATTENTION:
    # `__all__` must hold *names* (str), not the class objects themselves:
    # `from <package> import *` raises TypeError on non-string entries.
    __all__.extend(
        [
            "FlashGPT2",
            "FlashNeoXSharded",
            "FlashRWSharded",
            "FlashSantacoderSharded",
            "FlashLlama",
            "IDEFICSSharded",
            "FlashMistral",
            "FlashMixtral",
            "FlashDbrx",
            "FlashPhi",
            "FlashQwen2",
            "FlashStarcoder2",
            "FlashGemma",
            "FlashCohere",
        ]
    )

# Mamba (state-space model) support is optional as well.
MAMBA_AVAILABLE = True
try:
    from text_generation_server.models.mamba import Mamba
except ImportError as e:
    logger.warning(f"Could not import Mamba: {e}")
    MAMBA_AVAILABLE = False

if MAMBA_AVAILABLE:
    __all__.append("Mamba")
def _set_speculate_from_heads(
    speculate: Optional[int], num_heads: int, kind: str
) -> None:
    """Configure speculation for a speculator checkpoint with ``num_heads`` heads.

    When the caller did not request a value, every head is used. Otherwise the
    requested value is kept, as long as it does not exceed ``num_heads``.

    Args:
        speculate: User-requested number of speculative tokens, or None.
        num_heads: Number of heads shipped by the speculator checkpoint.
        kind: Speculator flavour ("medusa", "mlp_speculator"); only used in
            the error message.

    Raises:
        RuntimeError: if ``speculate`` is larger than ``num_heads``.
    """
    if speculate is None:
        set_speculate(num_heads)
    elif speculate > num_heads:
        raise RuntimeError(
            f"Speculate is set to `{speculate}` but this {kind} models only has `{num_heads}` heads, please make them match"
        )
    else:
        set_speculate(speculate)


def _is_speculator_weight_file(filename: str, extension: str = ".safetensors") -> bool:
    """Return True if ``filename`` should be loaded as MLP-speculator weights.

    Keeps top-level entries (no ``/`` in the name) ending in ``extension`` and
    skips anything whose name mentions "arguments", "args" or "training".
    """
    return (
        filename.endswith(extension)
        and len(filename.split("/")) == 1
        and "arguments" not in filename
        and "args" not in filename
        and "training" not in filename
    )


def _flash_sliding_window_supported(config_dict: dict) -> bool:
    """Return True if the flash implementation may serve this config.

    Used for architectures that may declare a sliding window. That is the case
    when the config declares no sliding window (key missing, ``None`` or
    ``-1``) and flash attention imported, or whenever
    ``HAS_FLASH_ATTN_V2_CUDA`` is set.
    """
    sliding_window = config_dict.get("sliding_window", -1)
    no_sliding_window = sliding_window is None or sliding_window == -1
    return bool((no_sliding_window and FLASH_ATTENTION) or HAS_FLASH_ATTN_V2_CUDA)


def get_model(
    model_id: str,
    revision: Optional[str],
    sharded: bool,
    quantize: Optional[str],
    speculate: Optional[int],
    dtype: Optional[str],
    trust_remote_code: bool,
) -> Model:
    """Instantiate the serving model for ``model_id``.

    Steps:
    - Resolve the requested dtype.
    - Configure speculative decoding. The default is n-gram. When ``model_id``
      is a Medusa or MLP-speculator checkpoint, its heads are used and the
      base model from ``base_model_name_or_path`` (revision "main") is loaded
      instead.
    - Auto-select GPTQ/AWQ quantization from the config.
    - Dispatch on the config's ``model_type``. A flash-attention
      implementation is preferred when available, falling back to the generic
      ``CausalLM`` / ``Seq2SeqLM`` wrappers.

    Args:
        model_id: Hub repository id or local path.
        revision: Hub revision of ``model_id``.
        sharded: Whether the model should be sharded; several architectures
            refuse to shard without flash attention.
        quantize: Quantization method name, or None (may be auto-detected).
        speculate: Requested number of speculative tokens, or None to use the
            speculator's head count (or 0 when there is no speculator).
        dtype: "float16", "bfloat16", or None to keep each model's default.
        trust_remote_code: Forwarded to config loading and model constructors.

    Returns:
        The instantiated model.

    Raises:
        RuntimeError: unknown dtype, speculation larger than the available
            heads, or a model type that cannot be determined.
        NotImplementedError: unsupported combination of sharding,
            quantization or attention backend, or Mamba not importable.
        ValueError: unsupported model type.
    """
    if dtype is None:
        # Keep it as default for now and let
        # every model resolve their own default dtype.
        pass
    elif dtype == "float16":
        dtype = torch.float16
    elif dtype == "bfloat16":
        dtype = torch.bfloat16
    else:
        raise RuntimeError(f"Unknown dtype {dtype}")

    # Start from the user's value (0 disables speculation); speculator
    # checkpoints handled below may override it.
    set_speculate(speculate if speculate is not None else 0)

    config_dict, _ = PretrainedConfig.get_config_dict(
        model_id, revision=revision, trust_remote_code=trust_remote_code
    )
    model_type = config_dict.get("model_type", None)

    speculator = None
    if "medusa_num_heads" in config_dict:
        medusa_model_id = model_id
        medusa_revision = revision
        # The medusa checkpoint only carries extra heads; the model that is
        # actually served is its base model, loaded at revision "main".
        model_id = config_dict["base_model_name_or_path"]
        revision = "main"
        _set_speculate_from_heads(speculate, config_dict["medusa_num_heads"], "medusa")

        config_dict, _ = PretrainedConfig.get_config_dict(
            model_id, revision=revision, trust_remote_code=trust_remote_code
        )
        # Reload model type from parent.
        model_type = config_dict.get("model_type", None)

        if Path(medusa_model_id).exists():
            speculator_path = Path(medusa_model_id)
        else:
            medusa_config = hf_hub_download(
                medusa_model_id, revision=medusa_revision, filename="config.json"
            )
            hf_hub_download(
                medusa_model_id,
                revision=medusa_revision,
                filename="medusa_lm_head.safetensors",
            )
            speculator_path = Path(medusa_config).parent
        speculator = {
            "path": speculator_path,
            "model_paths": ["medusa_lm_head.safetensors"],
        }
        method = "medusa"
    elif model_type == "mlp_speculator":
        mlp_model_id = model_id
        mlp_revision = revision
        # Same layout as medusa: serve the base model, speculate with the heads.
        model_id = config_dict["base_model_name_or_path"]
        revision = "main"
        _set_speculate_from_heads(speculate, config_dict["n_predict"], "mlp_speculator")

        config_dict, _ = PretrainedConfig.get_config_dict(
            model_id, revision=revision, trust_remote_code=trust_remote_code
        )
        # Reload model type from parent.
        model_type = config_dict.get("model_type", None)

        if Path(mlp_model_id).exists():
            speculator_path = Path(mlp_model_id)
            # Apply the same filter as for Hub checkpoints, so training/argument
            # files are never loaded as weights. Sort for a deterministic order
            # (os.listdir order is arbitrary).
            filenames = sorted(
                name
                for name in os.listdir(speculator_path)
                if _is_speculator_weight_file(name)
                and (speculator_path / name).is_file()
            )
        else:
            mlp_speculator_config = hf_hub_download(
                mlp_model_id, revision=mlp_revision, filename="config.json"
            )
            info = HfApi().model_info(mlp_model_id, revision=mlp_revision)
            filenames = [
                sibling.rfilename
                for sibling in info.siblings
                if _is_speculator_weight_file(sibling.rfilename)
            ]
            for filename in filenames:
                hf_hub_download(
                    mlp_model_id,
                    revision=mlp_revision,
                    filename=filename,
                )
            speculator_path = Path(mlp_speculator_config).parent
        speculator = {"path": speculator_path, "model_paths": filenames}
        method = "mlp_speculator"
    else:
        method = "n-gram"

    speculate = get_speculate()
    if speculate > 0:
        logger.info(f"Using speculation {method} with {speculate} input ids.")

    if model_type is None:
        # TODO: fix how we determine model type for Mamba
        if "ssm_cfg" in config_dict:
            # *only happens in Mamba case
            model_type = "ssm"
        else:
            raise RuntimeError(
                f"Could not determine model type for {model_id} revision {revision}"
            )

    # Pick up pre-quantized checkpoints when the user did not ask for a method.
    quantization_config = config_dict.get("quantization_config", None)
    if quantization_config is not None and quantize is None:
        quant_method = quantization_config.get("quant_method", None)
        if quant_method in {"gptq", "awq"}:
            logger.info(f"Auto selecting quantization method {quant_method}")
            quantize = quant_method
        else:
            logger.info(f"Unknown quantization method {quant_method}")

    # Keyword arguments shared by the model constructors below.
    model_kwargs = {
        "quantize": quantize,
        "speculator": speculator,
        "dtype": dtype,
        "trust_remote_code": trust_remote_code,
    }

    if model_type == "ssm":
        if not MAMBA_AVAILABLE:
            # Without this check the call below would fail with a bare
            # NameError because `Mamba` was never imported.
            raise NotImplementedError(
                "Mamba models are not available: importing "
                "`text_generation_server.models.mamba` failed"
            )
        return Mamba(model_id, revision, **model_kwargs)

    if model_id.startswith("facebook/galactica"):
        return GalacticaSharded(model_id, revision, **model_kwargs)

    if model_type == "gpt_bigcode" or (
        model_type == "gpt2" and model_id.startswith("bigcode/")
    ):
        if FLASH_ATTENTION:
            return FlashSantacoderSharded(model_id, revision, **model_kwargs)
        elif sharded:
            raise NotImplementedError(
                FLASH_ATT_ERROR_MESSAGE.format("Sharded Santacoder")
            )
        else:
            return SantaCoder(model_id, revision, **model_kwargs)

    if model_type == "bloom":
        return BLOOMSharded(model_id, revision, **model_kwargs)
    elif model_type == "mpt":
        return MPTSharded(model_id, revision, **model_kwargs)
    elif model_type == "gpt2":
        if FLASH_ATTENTION:
            return FlashGPT2(model_id, revision, **model_kwargs)
        elif sharded:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded GPT-2"))
        else:
            return CausalLM(model_id, revision, **model_kwargs)
    elif model_type == "gpt_neox":
        if FLASH_ATTENTION:
            return FlashNeoXSharded(model_id, revision, **model_kwargs)
        elif sharded:
            return GPTNeoxSharded(model_id, revision, **model_kwargs)
        else:
            return CausalLM(model_id, revision, **model_kwargs)
    elif model_type == "phi":
        if FLASH_ATTENTION:
            return FlashPhi(model_id, revision, **model_kwargs)
        else:
            return CausalLM(model_id, revision, **model_kwargs)
    elif model_type == "phi-msft":
        if FLASH_ATTENTION:
            raise NotImplementedError(
                "Legacy phi-msft is not supported with Flash Attention"
            )
        else:
            return Phi(model_id, revision, **model_kwargs)
    elif model_type == "llama" or model_type == "baichuan" or model_type == "phi3":
        if FLASH_ATTENTION:
            return FlashLlama(model_id, revision, **model_kwargs)
        elif sharded:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Llama"))
        else:
            return CausalLM(model_id, revision, **model_kwargs)

    if model_type == "gemma":
        if FLASH_ATTENTION:
            return FlashGemma(model_id, revision, **model_kwargs)
        elif sharded:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Gemma"))
        else:
            return CausalLM(model_id, revision, **model_kwargs)

    if model_type == "cohere":
        if FLASH_ATTENTION:
            return FlashCohere(model_id, revision, **model_kwargs)
        elif sharded:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Cohere"))
        else:
            return CausalLM(model_id, revision, **model_kwargs)

    if model_type == "dbrx":
        if FLASH_ATTENTION:
            return FlashDbrx(model_id, revision, **model_kwargs)
        elif sharded:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded DBRX"))
        else:
            return CausalLM(model_id, revision, **model_kwargs)

    if model_type in ["RefinedWeb", "RefinedWebModel", "falcon"]:
        if sharded:
            if FLASH_ATTENTION:
                # Configs with `alibi` enabled cannot be sharded.
                if config_dict.get("alibi", False):
                    raise NotImplementedError("sharded is not supported for this model")
                return FlashRWSharded(model_id, revision, **model_kwargs)
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Falcon"))
        elif FLASH_ATTENTION and not config_dict.get("alibi", False):
            return FlashRWSharded(model_id, revision, **model_kwargs)
        else:
            return RW(model_id, revision, **model_kwargs)

    if model_type == "mistral":
        if _flash_sliding_window_supported(config_dict):
            return FlashMistral(model_id, revision, **model_kwargs)
        elif sharded:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Mistral"))
        else:
            return CausalLM(model_id, revision, **model_kwargs)

    if model_type == "mixtral":
        if _flash_sliding_window_supported(config_dict):
            return FlashMixtral(model_id, revision, **model_kwargs)
        elif sharded:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Mixtral"))
        else:
            return CausalLM(model_id, revision, **model_kwargs)

    if model_type == "starcoder2":
        if _flash_sliding_window_supported(config_dict):
            # NOTE(review): `speculator` is not forwarded here (unchanged
            # behaviour); confirm FlashStarcoder2 accepts it before passing.
            return FlashStarcoder2(
                model_id,
                revision,
                quantize=quantize,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )
        elif sharded:
            raise NotImplementedError(
                FLASH_ATT_ERROR_MESSAGE.format("Sharded Starcoder2")
            )
        else:
            return CausalLM(model_id, revision, **model_kwargs)

    if model_type == "qwen2":
        if _flash_sliding_window_supported(config_dict):
            # NOTE(review): `speculator` is not forwarded here (unchanged
            # behaviour); confirm FlashQwen2 accepts it before passing.
            return FlashQwen2(
                model_id,
                revision,
                quantize=quantize,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )
        elif sharded:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Qwen2"))
        else:
            return CausalLM(model_id, revision, **model_kwargs)

    if model_type == "opt":
        return OPTSharded(model_id, revision, **model_kwargs)

    if model_type == "t5":
        return T5Sharded(model_id, revision, **model_kwargs)

    if model_type == "idefics":
        if FLASH_ATTENTION:
            return IDEFICSSharded(model_id, revision, **model_kwargs)
        else:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Idefics"))

    if model_type == "idefics2":
        if FLASH_ATTENTION:
            return Idefics2(model_id, revision, **model_kwargs)
        else:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Idefics2"))

    if model_type == "llava_next":
        if FLASH_ATTENTION:
            return LlavaNext(model_id, revision, **model_kwargs)
        else:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("LlavaNext"))

    # Generic transformers fallback: no sharding, and several quantization
    # methods are rejected.
    if sharded:
        raise NotImplementedError("sharded is not supported for AutoModel")
    if quantize == "gptq":
        raise NotImplementedError(
            "gptq quantization is not supported for AutoModel, you can try to quantize it with `text-generation-server quantize ORIGINAL_MODEL_ID NEW_MODEL_ID`"
        )
    if quantize == "awq":
        raise NotImplementedError("awq quantization is not supported for AutoModel")
    if quantize in ("bitsandbytes-fp4", "bitsandbytes-nf4"):
        raise NotImplementedError("4bit quantization is not supported for AutoModel")
    if quantize == "eetq":
        raise NotImplementedError("Eetq quantization is not supported for AutoModel")

    if model_type in modeling_auto.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES:
        return CausalLM(model_id, revision, **model_kwargs)
    if model_type in modeling_auto.MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES:
        return Seq2SeqLM(model_id, revision, **model_kwargs)

    # Remote-code models advertise their auto classes in `auto_map`.
    auto_map = config_dict.get("auto_map", None)
    if trust_remote_code and auto_map is not None:
        if "AutoModelForCausalLM" in auto_map:
            return CausalLM(model_id, revision, **model_kwargs)
        if "AutoModelForSeq2SeqLM" in auto_map:
            return Seq2SeqLM(model_id, revision, **model_kwargs)

    raise ValueError(f"Unsupported model type {model_type}")