"src/diffusers/schedulers/scheduling_ddpm_parallel.py" did not exist on "8b451eb63b0f101e7fcc72365fe0d683808b22cd"
__init__.py 25.1 KB
Newer Older
1
import torch
Nicolas Patry's avatar
Nicolas Patry committed
2
import os
3

4
from loguru import logger
5
from transformers.configuration_utils import PretrainedConfig
6
from transformers.models.auto import modeling_auto
Nicolas Patry's avatar
Nicolas Patry committed
7
from huggingface_hub import hf_hub_download, HfApi
8
from typing import Optional
9
from pathlib import Path
10

Nicolas Patry's avatar
Nicolas Patry committed
11
from text_generation_server.utils.speculate import get_speculate, set_speculate
12
13
from text_generation_server.models.model import Model
from text_generation_server.models.causal_lm import CausalLM
14
from text_generation_server.models.flash_causal_lm import FlashCausalLM
15
from text_generation_server.models.bloom import BLOOMSharded
16
from text_generation_server.models.mpt import MPTSharded
17
from text_generation_server.models.seq2seq_lm import Seq2SeqLM
18
from text_generation_server.models.rw import RW
19
20
from text_generation_server.models.opt import OPTSharded
from text_generation_server.models.galactica import GalacticaSharded
21
22
from text_generation_server.models.santacoder import SantaCoder
from text_generation_server.models.t5 import T5Sharded
23
from text_generation_server.models.gpt_neox import GPTNeoxSharded
drbh's avatar
drbh committed
24
from text_generation_server.models.phi import Phi
25

# Let CUDA matmuls and cuDNN use TensorFloat-32. The matmul flag defaults to
# False in PyTorch 1.12 and later; the cuDNN flag already defaults to True and
# is set here explicitly anyway.
torch.backends.cuda.matmul.allow_tf32 = torch.backends.cudnn.allow_tf32 = True

# Turn off gradient tracking (autograd bookkeeping) globally.
torch.set_grad_enabled(False)

# Public names of this package; optional backends are added further down once
# their imports succeed.
__all__ = [
    "Model",
    "BLOOMSharded",
    "CausalLM",
    "GalacticaSharded",
    "Seq2SeqLM",
    "SantaCoder",
    "OPTSharded",
    "T5Sharded",
    "get_model",
]

FLASH_ATT_ERROR_MESSAGE = "{} requires Flash Attention enabled models."

# Flash-attention model implementations are optional. If any of the imports
# below fails, FLASH_ATTENTION is cleared and `get_model` falls back to the
# non-flash implementations (or raises) instead.
FLASH_ATTENTION = True

try:
    from text_generation_server.models.flash_rw import FlashRWSharded
    from text_generation_server.models.flash_gpt2 import FlashGPT2
    from text_generation_server.models.flash_neox import FlashNeoXSharded
    from text_generation_server.models.flash_llama import FlashLlama
    from text_generation_server.models.flash_qwen2 import FlashQwen2
    from text_generation_server.models.flash_cohere import FlashCohere
    from text_generation_server.models.flash_gemma import FlashGemma
    from text_generation_server.models.pali_gemma import PaliGemma
    from text_generation_server.models.flash_santacoder import (
        FlashSantacoderSharded,
    )
    from text_generation_server.models.idefics import IDEFICSSharded
    from text_generation_server.models.llava_next import LlavaNext
    from text_generation_server.models.idefics2 import Idefics2
    from text_generation_server.models.flash_mistral import FlashMistral
    from text_generation_server.models.flash_mixtral import FlashMixtral
    from text_generation_server.models.flash_phi import FlashPhi
    from text_generation_server.models.flash_starcoder2 import FlashStarcoder2
    from text_generation_server.models.flash_dbrx import FlashDbrx
    from text_generation_server.utils.flash_attn import HAS_FLASH_ATTN_V2_CUDA

except ImportError as e:
    logger.warning(f"Could not import Flash Attention enabled models: {e}")
    FLASH_ATTENTION = False
    HAS_FLASH_ATTN_V2_CUDA = False

if FLASH_ATTENTION:
    # `__all__` entries must be name strings: `from ... import *` raises
    # TypeError on any non-str item.
    __all__.extend(
        [
            "FlashGPT2",
            "FlashNeoXSharded",
            "FlashRWSharded",
            "FlashSantacoderSharded",
            "FlashLlama",
            "IDEFICSSharded",
            "FlashMistral",
            "FlashMixtral",
            "FlashDbrx",
            "FlashPhi",
            "FlashQwen2",
            "FlashStarcoder2",
            "FlashGemma",
            "FlashCohere",
            "PaliGemma",
            "LlavaNext",
            "Idefics2",
        ]
    )

# Mamba support is optional; `get_model` checks MAMBA_AVAILABLE before using it.
MAMBA_AVAILABLE = True
try:
    from text_generation_server.models.mamba import Mamba
except ImportError as e:
    logger.warning(f"Could not import Mamba: {e}")
    MAMBA_AVAILABLE = False

if MAMBA_AVAILABLE:
    # `__all__` entries must be name strings for `from ... import *`.
    __all__.append("Mamba")


def get_model(
    model_id: str,
    revision: Optional[str],
    sharded: bool,
    quantize: Optional[str],
    speculate: Optional[int],
    dtype: Optional[str],
    trust_remote_code: bool,
) -> Model:
    """Instantiate the ``Model`` implementation that serves ``model_id``.

    The architecture is read from the checkpoint config (``model_type``).

    When ``model_id`` is a Medusa checkpoint (config has ``medusa_num_heads``)
    or an MLP speculator (``model_type == "mlp_speculator"``), the base model
    named by ``base_model_name_or_path`` is loaded instead, at revision
    ``"main"``. The speculator weights are then passed to the model as
    ``speculator``.

    Flash-attention implementations are used whenever they imported
    successfully (``FLASH_ATTENTION``).

    Args:
        model_id: Hub id or local path of the model, or of a speculator.
        revision: Revision of ``model_id``.
        sharded: Whether the model is served across several shards.
        quantize: Quantization method. When ``None``, it is auto-selected from
            the config's ``quantization_config`` (``gptq``/``awq`` only).
        speculate: Number of speculative tokens. ``None`` means the
            speculator's own head count, or 0 when there is no speculator.
        dtype: ``None`` (model default), ``"float16"`` or ``"bfloat16"``.
        trust_remote_code: Forwarded to config loading and model constructors.

    Returns:
        An instance of the selected ``Model`` subclass.

    Raises:
        RuntimeError: Unknown ``dtype``, ``speculate`` larger than the
            speculator supports, or a model type that cannot be determined.
        NotImplementedError: The requested combination (sharding, quantization,
            missing flash attention or Mamba support) is not supported.
        ValueError: The model type is not supported at all.
    """
    if dtype is None:
        # Keep it as default for now and let every model resolve its own
        # default dtype.
        pass
    elif dtype == "float16":
        dtype = torch.float16
    elif dtype == "bfloat16":
        dtype = torch.bfloat16
    else:
        raise RuntimeError(f"Unknown dtype {dtype}")

    # Baseline speculation depth; a speculator checkpoint below may override it.
    set_speculate(speculate if speculate is not None else 0)

    config_dict, _ = PretrainedConfig.get_config_dict(
        model_id, revision=revision, trust_remote_code=trust_remote_code
    )
    model_type = config_dict.get("model_type", None)

    speculator = None
    if "medusa_num_heads" in config_dict:
        # `model_id` holds Medusa heads: serve their base model instead and
        # attach the heads as the speculator.
        medusa_model_id, medusa_revision = model_id, revision
        model_id = config_dict["base_model_name_or_path"]
        revision = "main"
        _apply_speculator_depth(speculate, config_dict["medusa_num_heads"], "medusa")

        # Reload config and model type from the base model.
        config_dict, _ = PretrainedConfig.get_config_dict(
            model_id, revision=revision, trust_remote_code=trust_remote_code
        )
        model_type = config_dict.get("model_type", None)
        speculator = _medusa_speculator(medusa_model_id, medusa_revision)
        method = "medusa"
    elif model_type == "mlp_speculator":
        # Same swap as above, for an MLP speculator checkpoint.
        mlp_model_id, mlp_revision = model_id, revision
        model_id = config_dict["base_model_name_or_path"]
        revision = "main"
        _apply_speculator_depth(speculate, config_dict["n_predict"], "mlp_speculator")

        # Reload config and model type from the base model.
        config_dict, _ = PretrainedConfig.get_config_dict(
            model_id, revision=revision, trust_remote_code=trust_remote_code
        )
        model_type = config_dict.get("model_type", None)
        speculator = _mlp_speculator(mlp_model_id, mlp_revision)
        method = "mlp_speculator"
    else:
        method = "n-gram"

    speculate = get_speculate()
    if speculate > 0:
        logger.info(f"Using speculation {method} with {speculate} input ids.")

    if model_type is None:
        # TODO: fix how we determine model type for Mamba
        if "ssm_cfg" in config_dict:
            # *only happens in Mamba case
            model_type = "ssm"
        else:
            raise RuntimeError(
                f"Could not determine model type for {model_id} revision {revision}"
            )

    quantization_config = config_dict.get("quantization_config", None)
    if quantization_config is not None and quantize is None:
        quant_method = quantization_config.get("quant_method", None)
        if quant_method in {"gptq", "awq"}:
            logger.info(f"Auto selecting quantization method {quant_method}")
            quantize = quant_method
        else:
            logger.info(f"Unknown quantization method {quant_method}")

    # Keyword arguments shared by every model constructor below. Building them
    # once guarantees that no branch silently drops one (e.g. the speculator).
    model_kwargs = {
        "quantize": quantize,
        "speculator": speculator,
        "dtype": dtype,
        "trust_remote_code": trust_remote_code,
    }

    if model_type == "ssm":
        if not MAMBA_AVAILABLE:
            raise NotImplementedError(
                "Mamba models are not supported: "
                "`text_generation_server.models.mamba` could not be imported"
            )
        return Mamba(model_id, revision, **model_kwargs)

    if model_id.startswith("facebook/galactica"):
        return GalacticaSharded(model_id, revision, **model_kwargs)

    if model_type == "gpt_bigcode" or (
        model_type == "gpt2" and model_id.startswith("bigcode/")
    ):
        if FLASH_ATTENTION:
            return FlashSantacoderSharded(model_id, revision, **model_kwargs)
        if sharded:
            raise NotImplementedError(
                FLASH_ATT_ERROR_MESSAGE.format("Sharded Santacoder")
            )
        return SantaCoder(model_id, revision, **model_kwargs)

    if model_type == "bloom":
        return BLOOMSharded(model_id, revision, **model_kwargs)

    if model_type == "mpt":
        return MPTSharded(model_id, revision, **model_kwargs)

    if model_type == "gpt2":
        if FLASH_ATTENTION:
            return FlashGPT2(model_id, revision, **model_kwargs)
        if sharded:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded GPT-2"))
        return CausalLM(model_id, revision, **model_kwargs)

    if model_type == "gpt_neox":
        if FLASH_ATTENTION:
            return FlashNeoXSharded(model_id, revision, **model_kwargs)
        if sharded:
            return GPTNeoxSharded(model_id, revision, **model_kwargs)
        return CausalLM(model_id, revision, **model_kwargs)

    if model_type == "phi":
        if FLASH_ATTENTION:
            return FlashPhi(model_id, revision, **model_kwargs)
        return CausalLM(model_id, revision, **model_kwargs)

    if model_type == "phi-msft":
        if FLASH_ATTENTION:
            raise NotImplementedError(
                "Legacy phi-msft is not supported with Flash Attention"
            )
        return Phi(model_id, revision, **model_kwargs)

    if model_type in ("llama", "baichuan", "phi3"):
        if FLASH_ATTENTION:
            return FlashLlama(model_id, revision, **model_kwargs)
        if sharded:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Llama"))
        return CausalLM(model_id, revision, **model_kwargs)

    if model_type == "gemma":
        if FLASH_ATTENTION:
            return FlashGemma(model_id, revision, **model_kwargs)
        if sharded:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Gemma"))
        return CausalLM(model_id, revision, **model_kwargs)

    if model_type == "cohere":
        if FLASH_ATTENTION:
            return FlashCohere(model_id, revision, **model_kwargs)
        if sharded:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Cohere"))
        return CausalLM(model_id, revision, **model_kwargs)

    if model_type == "dbrx":
        if FLASH_ATTENTION:
            return FlashDbrx(model_id, revision, **model_kwargs)
        if sharded:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded DBRX"))
        return CausalLM(model_id, revision, **model_kwargs)

    if model_type in ("RefinedWeb", "RefinedWebModel", "falcon"):
        if sharded:
            if FLASH_ATTENTION:
                # The flash Falcon implementation cannot shard ALiBi variants.
                if config_dict.get("alibi", False):
                    raise NotImplementedError("sharded is not supported for this model")
                return FlashRWSharded(model_id, revision, **model_kwargs)
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Falcon"))
        if FLASH_ATTENTION and not config_dict.get("alibi", False):
            return FlashRWSharded(model_id, revision, **model_kwargs)
        return RW(model_id, revision, **model_kwargs)

    if model_type == "mistral":
        if _flash_supports_sliding_window(config_dict):
            return FlashMistral(model_id, revision, **model_kwargs)
        if sharded:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Mistral"))
        return CausalLM(model_id, revision, **model_kwargs)

    if model_type == "mixtral":
        if _flash_supports_sliding_window(config_dict):
            return FlashMixtral(model_id, revision, **model_kwargs)
        if sharded:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Mixtral"))
        return CausalLM(model_id, revision, **model_kwargs)

    if model_type == "starcoder2":
        if _flash_supports_sliding_window(config_dict):
            return FlashStarcoder2(model_id, revision, **model_kwargs)
        if sharded:
            raise NotImplementedError(
                FLASH_ATT_ERROR_MESSAGE.format("Sharded Starcoder2")
            )
        return CausalLM(model_id, revision, **model_kwargs)

    if model_type == "qwen2":
        if _flash_supports_sliding_window(config_dict):
            return FlashQwen2(model_id, revision, **model_kwargs)
        if sharded:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Qwen2"))
        return CausalLM(model_id, revision, **model_kwargs)

    if model_type == "opt":
        return OPTSharded(model_id, revision, **model_kwargs)

    if model_type == "t5":
        return T5Sharded(model_id, revision, **model_kwargs)

    # Vision-language models below only have flash-attention implementations.
    if model_type == "idefics":
        if FLASH_ATTENTION:
            return IDEFICSSharded(model_id, revision, **model_kwargs)
        raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Idefics"))

    if model_type == "idefics2":
        if FLASH_ATTENTION:
            return Idefics2(model_id, revision, **model_kwargs)
        raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Idefics2"))

    if model_type == "paligemma":
        if FLASH_ATTENTION:
            return PaliGemma(model_id, revision, **model_kwargs)
        raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("PaliGemma"))

    if model_type == "llava_next":
        if FLASH_ATTENTION:
            return LlavaNext(model_id, revision, **model_kwargs)
        raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("LlavaNext"))

    # Generic transformers AutoModel fallback: no sharding, limited quantization.
    if sharded:
        raise NotImplementedError("sharded is not supported for AutoModel")
    if quantize == "gptq":
        raise NotImplementedError(
            "gptq quantization is not supported for AutoModel, you can try to quantize it with `text-generation-server quantize ORIGINAL_MODEL_ID NEW_MODEL_ID`"
        )
    if quantize == "awq":
        raise NotImplementedError("awq quantization is not supported for AutoModel")
    if quantize in ("bitsandbytes-fp4", "bitsandbytes-nf4"):
        raise NotImplementedError("4bit quantization is not supported for AutoModel")
    if quantize == "eetq":
        raise NotImplementedError("Eetq quantization is not supported for AutoModel")

    if model_type in modeling_auto.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES:
        return CausalLM(model_id, revision, **model_kwargs)
    if model_type in modeling_auto.MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES:
        return Seq2SeqLM(model_id, revision, **model_kwargs)

    # Custom architectures registered through the config's `auto_map`.
    auto_map = config_dict.get("auto_map", None)
    if trust_remote_code and auto_map is not None:
        if "AutoModelForCausalLM" in auto_map:
            return CausalLM(model_id, revision, **model_kwargs)
        if "AutoModelForSeq2SeqLM" in auto_map:
            return Seq2SeqLM(model_id, revision, **model_kwargs)

    raise ValueError(f"Unsupported model type {model_type}")


def _apply_speculator_depth(
    speculate: Optional[int], num_heads: int, kind: str
) -> None:
    """Set the global speculation depth for a speculator with ``num_heads`` heads.

    If ``speculate`` is ``None``, every head is used. Otherwise ``speculate``
    is used, provided the speculator has that many heads.

    Raises:
        RuntimeError: If ``speculate`` exceeds ``num_heads``.
    """
    if speculate is None:
        set_speculate(num_heads)
    elif speculate > num_heads:
        raise RuntimeError(
            f"Speculate is set to `{speculate}` but this {kind} models only has `{num_heads}` heads, please make them match"
        )
    else:
        set_speculate(speculate)


def _medusa_speculator(medusa_model_id: str, medusa_revision: Optional[str]) -> dict:
    """Locate (downloading from the Hub if not local) the Medusa head weights.

    Returns a speculator description:
    ``{"path": <directory>, "model_paths": [<weight file names>]}``.
    """
    weights = "medusa_lm_head.safetensors"
    if Path(medusa_model_id).exists():
        return {"path": Path(medusa_model_id), "model_paths": [weights]}

    medusa_config = hf_hub_download(
        medusa_model_id, revision=medusa_revision, filename="config.json"
    )
    hf_hub_download(medusa_model_id, revision=medusa_revision, filename=weights)
    return {"path": Path(medusa_config).parent, "model_paths": [weights]}


def _is_mlp_speculator_weight_file(filename: str) -> bool:
    """Return True for top-level ``.safetensors`` files that are not training artifacts."""
    return (
        filename.endswith(".safetensors")
        and len(filename.split("/")) == 1
        and "arguments" not in filename
        and "args" not in filename
        and "training" not in filename
    )


def _mlp_speculator(mlp_model_id: str, mlp_revision: Optional[str]) -> dict:
    """Locate (downloading from the Hub if not local) the MLP speculator weights.

    Local directories and Hub repositories use the same file filter, so
    training/argument checkpoints are skipped in both cases.

    Returns a speculator description:
    ``{"path": <directory>, "model_paths": [<weight file names>]}``.
    """
    if Path(mlp_model_id).exists():
        path = Path(mlp_model_id)
        # Sorted so that the file order does not depend on the filesystem.
        filenames = sorted(
            name for name in os.listdir(path) if _is_mlp_speculator_weight_file(name)
        )
        return {"path": path, "model_paths": filenames}

    mlp_speculator_config = hf_hub_download(
        mlp_model_id, revision=mlp_revision, filename="config.json"
    )
    info = HfApi().model_info(mlp_model_id, revision=mlp_revision)
    filenames = [
        sibling.rfilename
        for sibling in info.siblings
        if _is_mlp_speculator_weight_file(sibling.rfilename)
    ]
    for filename in filenames:
        hf_hub_download(mlp_model_id, revision=mlp_revision, filename=filename)
    return {"path": Path(mlp_speculator_config).parent, "model_paths": filenames}


def _flash_supports_sliding_window(config_dict: dict) -> bool:
    """Decide whether the flash implementation of a Mistral-style model may be used.

    The flash implementation is used if either condition holds:

    * the config has no sliding window (missing, ``None`` or ``-1``) and flash
      attention imported successfully;
    * ``HAS_FLASH_ATTN_V2_CUDA`` is set, which is accepted even when a sliding
      window is configured.
    """
    sliding_window = config_dict.get("sliding_window", -1)
    no_sliding_window = sliding_window is None or sliding_window == -1
    return bool((no_sliding_window and FLASH_ATTENTION) or HAS_FLASH_ATTN_V2_CUDA)