"vscode:/vscode.git/clone" did not exist on "f3aea78fb642967838e7b5b1940a25fe67f4f7a9"
__init__.py 25.4 KB
Newer Older
1
import torch
import os

from loguru import logger
from transformers.configuration_utils import PretrainedConfig
from transformers.models.auto import modeling_auto
from huggingface_hub import hf_hub_download, HfApi
from typing import Optional
from pathlib import Path

from text_generation_server.utils.speculate import get_speculate, set_speculate
from text_generation_server.models.model import Model
from text_generation_server.models.causal_lm import CausalLM
from text_generation_server.models.flash_causal_lm import FlashCausalLM
from text_generation_server.models.bloom import BLOOMSharded
from text_generation_server.models.mpt import MPTSharded
from text_generation_server.models.seq2seq_lm import Seq2SeqLM
from text_generation_server.models.rw import RW
from text_generation_server.models.opt import OPTSharded
from text_generation_server.models.galactica import GalacticaSharded
from text_generation_server.models.santacoder import SantaCoder
from text_generation_server.models.t5 import T5Sharded
from text_generation_server.models.gpt_neox import GPTNeoxSharded
from text_generation_server.models.phi import Phi

# The flag below controls whether to allow TF32 on matmul. This flag defaults to False
# in PyTorch 1.12 and later.
torch.backends.cuda.matmul.allow_tf32 = True

# The flag below controls whether to allow TF32 on cuDNN. This flag defaults to True.
torch.backends.cudnn.allow_tf32 = True

# Disable gradients
torch.set_grad_enabled(False)

__all__ = [
    "Model",
    "BLOOMSharded",
    "CausalLM",
    "GalacticaSharded",
    "Seq2SeqLM",
    "SantaCoder",
    "OPTSharded",
    "T5Sharded",
    "get_model",
]

FLASH_ATT_ERROR_MESSAGE = "{} requires Flash Attention enabled models."

FLASH_ATTENTION = True

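# Flash Attention model classes are optional: importing them can fail on
# hardware or builds without the required kernels, in which case we fall
# back to the non-flash implementations below.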
try:
    from text_generation_server.models.flash_rw import FlashRWSharded
    from text_generation_server.models.flash_gpt2 import FlashGPT2
    from text_generation_server.models.flash_neox import FlashNeoXSharded
    from text_generation_server.models.flash_llama import (
        FlashLlama,
    )
    from text_generation_server.models.flash_qwen2 import (
        FlashQwen2,
    )
    from text_generation_server.models.flash_cohere import (
        FlashCohere,
    )
    from text_generation_server.models.flash_gemma import (
        FlashGemma,
    )
    from text_generation_server.models.pali_gemma import (
        PaliGemma,
    )
    from text_generation_server.models.flash_santacoder import (
        FlashSantacoderSharded,
    )
    from text_generation_server.models.idefics import IDEFICSSharded
    from text_generation_server.models.llava_next import LlavaNext
    from text_generation_server.models.idefics2 import Idefics2
    from text_generation_server.models.flash_mistral import FlashMistral
    from text_generation_server.models.flash_mixtral import FlashMixtral
    from text_generation_server.models.flash_phi import FlashPhi
    from text_generation_server.models.flash_starcoder2 import FlashStarcoder2
    from text_generation_server.models.flash_dbrx import FlashDbrx
    from text_generation_server.utils.flash_attn import (
        HAS_FLASH_ATTN_V2_CUDA,
        HAS_FLASH_ATTN_V2_ROCM,
    )
except ImportError as e:
    logger.warning(f"Could not import Flash Attention enabled models: {e}")
    FLASH_ATTENTION = False
    HAS_FLASH_ATTN_V2_CUDA = False
    HAS_FLASH_ATTN_V2_ROCM = False

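# Only export the flash models whose imports actually succeeded; entries in
# __all__ must be strings for `from ... import *` to work.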
if FLASH_ATTENTION:
    __all__.append("FlashGPT2")
    __all__.append("FlashNeoXSharded")
    __all__.append("FlashRWSharded")
    __all__.append("FlashSantacoderSharded")
    __all__.append("FlashLlama")
    __all__.append("IDEFICSSharded")
    __all__.append("FlashMistral")
    __all__.append("FlashMixtral")
    __all__.append("FlashDbrx")
    __all__.append("FlashPhi")
    __all__.append("FlashQwen2")
    __all__.append("FlashStarcoder2")
    __all__.append("FlashGemma")
    __all__.append("FlashCohere")

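# Mamba support is optional as well; its extra kernel dependencies may not
# be installed.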
MAMBA_AVAILABLE = True
try:
    from text_generation_server.models.mamba import Mamba
except ImportError as e:
    logger.warning(f"Could not import Mamba: {e}")
    MAMBA_AVAILABLE = False

if MAMBA_AVAILABLE:
    __all__.append("Mamba")


def get_model(
    model_id: str,
    revision: Optional[str],
    sharded: bool,
    quantize: Optional[str],
    speculate: Optional[int],
    dtype: Optional[str],
    trust_remote_code: bool,
) -> Model:
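    """Instantiate the appropriate Model subclass for `model_id`.

    Resolution order: speculator configs (Medusa, MLP speculator) are
    unwrapped first, then the model type is dispatched to a dedicated
    flash / sharded implementation, and finally we fall back to the
    generic transformers AutoModel path.
    """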
    if dtype is None:
        # Keep it as the default for now and let
        # each model resolve its own default dtype.
        dtype = None
    elif dtype == "float16":
        dtype = torch.float16
    elif dtype == "bfloat16":
        dtype = torch.bfloat16
    else:
        raise RuntimeError(f"Unknown dtype {dtype}")

    if speculate is not None:
        set_speculate(speculate)
    else:
        set_speculate(0)

    config_dict, _ = PretrainedConfig.get_config_dict(
        model_id, revision=revision, trust_remote_code=trust_remote_code
    )
    model_type = config_dict.get("model_type", None)

    speculator = None
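    # Medusa speculative decoding: the config points at a base model, so load
    # that instead and attach the medusa heads as a speculator.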
    if "medusa_num_heads" in config_dict:
        medusa_model_id = model_id
        medusa_revision = revision
        model_id = config_dict["base_model_name_or_path"]
        revision = "main"
        speculate_medusa = config_dict["medusa_num_heads"]
        if speculate is not None:
            if speculate > speculate_medusa:
                raise RuntimeError(
                    f"Speculate is set to `{speculate}` but this medusa models only has `{speculate_medusa}` heads, please make them match"
                )
            else:
                set_speculate(speculate)
        else:
            set_speculate(speculate_medusa)

        config_dict, _ = PretrainedConfig.get_config_dict(
            model_id, revision=revision, trust_remote_code=trust_remote_code
        )
        # Reload model type from parent.
        model_type = config_dict.get("model_type", None)
        is_local = Path(medusa_model_id).exists()
        if not is_local:
            medusa_config = hf_hub_download(
                medusa_model_id, revision=medusa_revision, filename="config.json"
            )
            hf_hub_download(
                medusa_model_id,
                revision=medusa_revision,
                filename="medusa_lm_head.safetensors",
            )
            speculator = {
                "path": Path(medusa_config).parent,
                "model_paths": ["medusa_lm_head.safetensors"],
            }
        else:
            speculator = {
                "path": Path(medusa_model_id),
                "model_paths": ["medusa_lm_head.safetensors"],
            }

        method = "medusa"
    elif model_type == "mlp_speculator":
        mlp_model_id = model_id
        mlp_revision = revision
        model_id = config_dict["base_model_name_or_path"]
        revision = "main"
        speculate_mlp = config_dict["n_predict"]
        if speculate is not None:
            if speculate > speculate_mlp:
                raise RuntimeError(
                    f"Speculate is set to `{speculate}` but this mlp_speculator models only has `{speculate_mlp}` heads, please make them match"
                )
            else:
                set_speculate(speculate)
        else:
            set_speculate(speculate_mlp)

        config_dict, _ = PretrainedConfig.get_config_dict(
            model_id, revision=revision, trust_remote_code=trust_remote_code
        )
        # Reload model type from parent.
        model_type = config_dict.get("model_type", None)
        is_local = Path(mlp_model_id).exists()
        extension = ".safetensors"
        if not is_local:
            mlp_speculator_config = hf_hub_download(
                mlp_model_id, revision=mlp_revision, filename="config.json"
            )
            api = HfApi()
            info = api.model_info(mlp_model_id, revision=mlp_revision)
            filenames = [
                s.rfilename
                for s in info.siblings
                if s.rfilename.endswith(extension)
                and len(s.rfilename.split("/")) == 1
                and "arguments" not in s.rfilename
                and "args" not in s.rfilename
                and "training" not in s.rfilename
            ]
            for filename in filenames:
                hf_hub_download(
                    mlp_model_id,
                    revision=mlp_revision,
                    filename=filename,
                )
            speculator = {
                "path": Path(mlp_speculator_config).parent,
                "model_paths": filenames,
            }
        else:
            speculator = Path(mlp_model_id)
            filenames = [p for p in os.listdir(speculator) if p.endswith(extension)]
            speculator = {"path": speculator, "model_paths": filenames}
        method = "mlp_speculator"
    else:
        method = "n-gram"

    speculate = get_speculate()
    if speculate > 0:
        logger.info(f"Using speculation {method} with {speculate} input ids.")

    if model_type is None:
        # TODO: fix how we determine model type for Mamba
        if "ssm_cfg" in config_dict:
            # currently this only happens for Mamba
            model_type = "ssm"
        else:
            raise RuntimeError(
                f"Could not determine model type for {model_id} revision {revision}"
            )
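    # If the checkpoint ships its own quantization config and no quantize
    # option was given, auto-select the matching method.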
    quantization_config = config_dict.get("quantization_config", None)
    if quantization_config is not None and quantize is None:
        method = quantization_config.get("quant_method", None)
        if method in {"gptq", "awq"}:
            logger.info(f"Auto selecting quantization method {method}")
            quantize = method
        else:
            logger.info(f"Unknown quantization method {method}")

    if model_type == "ssm":
        return Mamba(
            model_id,
            revision,
            quantize=quantize,
            speculator=speculator,
            dtype=dtype,
            trust_remote_code=trust_remote_code,
        )

    if model_id.startswith("facebook/galactica"):
        return GalacticaSharded(
            model_id,
            revision,
            quantize=quantize,
            speculator=speculator,
            dtype=dtype,
            trust_remote_code=trust_remote_code,
        )

    if model_type == "gpt_bigcode" or (
        model_type == "gpt2" and model_id.startswith("bigcode/")
    ):
        if FLASH_ATTENTION:
            return FlashSantacoderSharded(
                model_id,
                revision,
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )
        elif sharded:
            raise NotImplementedError(
                FLASH_ATT_ERROR_MESSAGE.format("Sharded Santacoder")
            )
        else:
            return SantaCoder(
                model_id,
                revision,
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )

    if model_type == "bloom":
        return BLOOMSharded(
            model_id,
            revision,
            quantize=quantize,
            speculator=speculator,
            dtype=dtype,
            trust_remote_code=trust_remote_code,
        )
    elif model_type == "mpt":
        return MPTSharded(
            model_id,
            revision,
            quantize=quantize,
            speculator=speculator,
            dtype=dtype,
            trust_remote_code=trust_remote_code,
        )
    elif model_type == "gpt2":
        if FLASH_ATTENTION:
            return FlashGPT2(
                model_id,
                revision,
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )
        elif sharded:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded GPT-2"))
        else:
            return CausalLM(
                model_id,
                revision,
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )
    elif model_type == "gpt_neox":
        if FLASH_ATTENTION:
            return FlashNeoXSharded(
                model_id,
                revision,
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )
        elif sharded:
            return GPTNeoxSharded(
                model_id,
                revision,
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )
        else:
            return CausalLM(
                model_id,
                revision,
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )

    elif model_type == "phi":
        if FLASH_ATTENTION:
            return FlashPhi(
                model_id,
                revision,
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )
        else:
            return CausalLM(
                model_id,
                revision,
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )

    elif model_type == "phi-msft":
        if FLASH_ATTENTION:
            raise NotImplementedError(
                "Legacy phi-msft is not supported with Flash Attention"
            )
        else:
            return Phi(
                model_id,
                revision,
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )

    elif model_type == "llama" or model_type == "baichuan" or model_type == "phi3":
        if FLASH_ATTENTION:
            return FlashLlama(
                model_id,
                revision,
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )
        elif sharded:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Llama"))
        else:
            return CausalLM(
                model_id,
                revision,
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )
    if model_type == "gemma":
        if FLASH_ATTENTION:
            return FlashGemma(
                model_id,
                revision,
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )
        elif sharded:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Gemma"))
        else:
            return CausalLM(
                model_id,
                revision,
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )

    if model_type == "cohere":
        if FLASH_ATTENTION:
            return FlashCohere(
                model_id,
                revision,
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )
        elif sharded:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Cohere"))
        else:
            return CausalLM(
                model_id,
                revision,
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )

    if model_type == "dbrx":
        if FLASH_ATTENTION:
            return FlashDbrx(
                model_id,
                revision,
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )
        elif sharded:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded DBRX"))
        else:
            return CausalLM(
                model_id,
                revision,
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )

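    # Falcon / RefinedWeb: sharding requires the flash implementation, and
    # alibi-based variants cannot be sharded at all.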
    if model_type in ["RefinedWeb", "RefinedWebModel", "falcon"]:
        if sharded:
            if FLASH_ATTENTION:
                if config_dict.get("alibi", False):
                    raise NotImplementedError("sharded is not supported for this model")
                return FlashRWSharded(
                    model_id,
                    revision,
                    quantize=quantize,
                    speculator=speculator,
                    dtype=dtype,
                    trust_remote_code=trust_remote_code,
                )
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Falcon"))
        else:
            if FLASH_ATTENTION and not config_dict.get("alibi", False):
                return FlashRWSharded(
                    model_id,
                    revision,
                    quantize=quantize,
                    speculator=speculator,
                    dtype=dtype,
                    trust_remote_code=trust_remote_code,
                )
            else:
                return RW(
                    model_id,
                    revision,
                    quantize=quantize,
                    speculator=speculator,
                    dtype=dtype,
                    trust_remote_code=trust_remote_code,
                )

    if model_type == "mistral":
        sliding_window = config_dict.get("sliding_window", -1)
        if (
            ((sliding_window is None or sliding_window == -1) and FLASH_ATTENTION)
            or HAS_FLASH_ATTN_V2_CUDA
            or HAS_FLASH_ATTN_V2_ROCM
        ):
            return FlashMistral(
                model_id,
                revision,
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )
        elif sharded:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Mistral"))
        else:
            return CausalLM(
                model_id,
                revision,
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )

    if model_type == "mixtral":
        sliding_window = config_dict.get("sliding_window", -1)
        if (
            ((sliding_window is None or sliding_window == -1) and FLASH_ATTENTION)
            or HAS_FLASH_ATTN_V2_CUDA
            or HAS_FLASH_ATTN_V2_ROCM
        ):
            return FlashMixtral(
                model_id,
                revision,
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )
        elif sharded:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Mixtral"))
        else:
            return CausalLM(
                model_id,
                revision,
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )

    if model_type == "starcoder2":
        sliding_window = config_dict.get("sliding_window", -1)
        if (
            ((sliding_window is None or sliding_window == -1) and FLASH_ATTENTION)
            or HAS_FLASH_ATTN_V2_CUDA
            or HAS_FLASH_ATTN_V2_ROCM
        ):
            return FlashStarcoder2(
                model_id,
                revision,
                quantize=quantize,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )
        elif sharded:
            raise NotImplementedError(
                FLASH_ATT_ERROR_MESSAGE.format("Sharded Starcoder2")
            )
        else:
            return CausalLM(
                model_id,
                revision,
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )

    if model_type == "qwen2":
        sliding_window = config_dict.get("sliding_window", -1)
        if (
            ((sliding_window is None or sliding_window == -1) and FLASH_ATTENTION)
            or HAS_FLASH_ATTN_V2_CUDA
            or HAS_FLASH_ATTN_V2_ROCM
        ):
            return FlashQwen2(
                model_id,
                revision,
                quantize=quantize,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )
        elif sharded:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Qwen2"))
        else:
            return CausalLM(
                model_id,
                revision,
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )

    if model_type == "opt":
        return OPTSharded(
            model_id,
            revision,
            quantize=quantize,
            speculator=speculator,
            dtype=dtype,
            trust_remote_code=trust_remote_code,
        )

    if model_type == "t5":
        return T5Sharded(
            model_id,
            revision,
            quantize=quantize,
            speculator=speculator,
            dtype=dtype,
            trust_remote_code=trust_remote_code,
        )
    if model_type == "idefics":
        if FLASH_ATTENTION:
            return IDEFICSSharded(
                model_id,
                revision,
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )
        else:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Idefics"))
    if model_type == "idefics2":
        if FLASH_ATTENTION:
            return Idefics2(
                model_id,
                revision,
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )
        else:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Idefics2"))
    if model_type == "paligemma":
        if FLASH_ATTENTION:
            return PaliGemma(
                model_id,
                revision,
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )
        else:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("PaliGemma"))

    if model_type == "llava_next":
        if FLASH_ATTENTION:
            return LlavaNext(
                model_id,
                revision,
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )
        else:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("LlavaNext"))

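    # No dedicated implementation matched: fall back to transformers'
    # AutoModel classes, which support neither sharding nor gptq/awq/4-bit/
    # eetq quantization.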
    if sharded:
        raise NotImplementedError("sharded is not supported for AutoModel")
    if quantize == "gptq":
        raise NotImplementedError(
            "gptq quantization is not supported for AutoModel, you can try to quantize it with `text-generation-server quantize ORIGINAL_MODEL_ID NEW_MODEL_ID`"
        )
    if quantize == "awq":
        raise NotImplementedError("awq quantization is not supported for AutoModel")
    elif quantize in ("bitsandbytes-fp4", "bitsandbytes-nf4"):
        raise NotImplementedError("4bit quantization is not supported for AutoModel")
    elif quantize == "eetq":
        raise NotImplementedError("Eetq quantization is not supported for AutoModel")
    if model_type in modeling_auto.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES:
        return CausalLM(
            model_id,
            revision,
            quantize=quantize,
            speculator=speculator,
            dtype=dtype,
            trust_remote_code=trust_remote_code,
        )
    if model_type in modeling_auto.MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES:
        return Seq2SeqLM(
            model_id,
            revision,
            quantize=quantize,
            speculator=speculator,
            dtype=dtype,
            trust_remote_code=trust_remote_code,
        )

    auto_map = config_dict.get("auto_map", None)
    if trust_remote_code and auto_map is not None:
        if "AutoModelForCausalLM" in auto_map.keys():
            return CausalLM(
                model_id,
                revision,
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )
        if "AutoModelForSeq2SeqLM" in auto_map.keys():
            return Seq2SeqLM(
                model_id,
                revision,
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )

    raise ValueError(f"Unsupported model type {model_type}")