__init__.py 23.9 KB
Newer Older
1
import torch
Nicolas Patry's avatar
Nicolas Patry committed
2
import os
3

4
from loguru import logger
5
from transformers.configuration_utils import PretrainedConfig
6
from transformers.models.auto import modeling_auto
Nicolas Patry's avatar
Nicolas Patry committed
7
from huggingface_hub import hf_hub_download, HfApi
8
from typing import Optional
9
from pathlib import Path
10

Nicolas Patry's avatar
Nicolas Patry committed
11
from text_generation_server.utils.speculate import get_speculate, set_speculate
12
13
from text_generation_server.models.model import Model
from text_generation_server.models.causal_lm import CausalLM
14
from text_generation_server.models.flash_causal_lm import FlashCausalLM
15
from text_generation_server.models.bloom import BLOOMSharded
16
from text_generation_server.models.mpt import MPTSharded
17
from text_generation_server.models.seq2seq_lm import Seq2SeqLM
18
from text_generation_server.models.rw import RW
19
20
from text_generation_server.models.opt import OPTSharded
from text_generation_server.models.galactica import GalacticaSharded
21
22
from text_generation_server.models.santacoder import SantaCoder
from text_generation_server.models.t5 import T5Sharded
23
from text_generation_server.models.gpt_neox import GPTNeoxSharded
drbh's avatar
drbh committed
24
from text_generation_server.models.phi import Phi
25

26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
# Enable TensorFloat-32 for both CUDA matmuls and cuDNN kernels:
#  - torch.backends.cuda.matmul.allow_tf32 defaults to False in PyTorch 1.12 and later;
#  - torch.backends.cudnn.allow_tf32 already defaults to True, but is set explicitly anyway.
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True

# Turn off autograd (gradient tracking) as soon as this package is imported.
torch.set_grad_enabled(False)

# Names exported by `from text_generation_server.models import *`.
# Optional model classes are appended further down when their imports succeed.
__all__ = [
    "Model",
    "BLOOMSharded",
    "CausalLM",
    "GalacticaSharded",
    "Seq2SeqLM",
    "SantaCoder",
    "OPTSharded",
    "T5Sharded",
    "get_model",
]

FLASH_ATT_ERROR_MESSAGE = "{} requires Flash Attention enabled models."

# Flash-attention model implementations are optional. If any of these imports
# fails, FLASH_ATTENTION is set to False and get_model falls back to the
# non-flash implementations (or raises FLASH_ATT_ERROR_MESSAGE).
FLASH_ATTENTION = True

try:
    from text_generation_server.models.flash_rw import FlashRWSharded
    from text_generation_server.models.flash_neox import FlashNeoXSharded
    from text_generation_server.models.flash_llama import FlashLlama
    from text_generation_server.models.flash_qwen2 import FlashQwen2
    from text_generation_server.models.flash_cohere import FlashCohere
    from text_generation_server.models.flash_gemma import FlashGemma
    from text_generation_server.models.flash_santacoder import FlashSantacoderSharded
    from text_generation_server.models.idefics import IDEFICSSharded
    from text_generation_server.models.llava_next import LlavaNext
    from text_generation_server.models.idefics2 import Idefics2
    from text_generation_server.models.flash_mistral import FlashMistral
    from text_generation_server.models.flash_mixtral import FlashMixtral
    from text_generation_server.models.flash_phi import FlashPhi
    from text_generation_server.models.flash_starcoder2 import FlashStarcoder2
    from text_generation_server.models.flash_dbrx import FlashDbrx
    from text_generation_server.utils.flash_attn import HAS_FLASH_ATTN_V2_CUDA
except ImportError as e:
    logger.warning(f"Could not import Flash Attention enabled models: {e}")
    FLASH_ATTENTION = False
    # Defined here too so the sliding-window checks in get_model can always read it.
    HAS_FLASH_ATTN_V2_CUDA = False

if FLASH_ATTENTION:
    # `__all__` must list attribute *names* (strings). Appending the class
    # objects themselves makes `from ... import *` raise TypeError.
    __all__.extend(
        [
            "FlashNeoXSharded",
            "FlashRWSharded",
            "FlashSantacoderSharded",
            "FlashLlama",
            "IDEFICSSharded",
            "FlashMistral",
            "FlashMixtral",
            "FlashDbrx",
            "FlashPhi",
            "FlashQwen2",
            "FlashStarcoder2",
            "FlashGemma",
            "FlashCohere",
        ]
    )
# Mamba is an optional model implementation; MAMBA_AVAILABLE records whether
# the `Mamba` name is actually defined in this module.
MAMBA_AVAILABLE = True
try:
    from text_generation_server.models.mamba import Mamba
except ImportError as e:
    logger.warning(f"Could not import Mamba: {e}")
    MAMBA_AVAILABLE = False

if MAMBA_AVAILABLE:
    # `__all__` entries must be names (strings), not the class object itself.
    __all__.append("Mamba")

def _parse_dtype(dtype: Optional[str]) -> Optional[torch.dtype]:
    """Map the dtype string passed to get_model to a torch dtype.

    ``None`` is passed through unchanged so every model can resolve its own
    default dtype.

    Raises:
        RuntimeError: if ``dtype`` is neither None, "float16" nor "bfloat16".
    """
    if dtype is None:
        return None
    if dtype == "float16":
        return torch.float16
    if dtype == "bfloat16":
        return torch.bfloat16
    raise RuntimeError(f"Unknown dtype {dtype}")


def _set_speculate_from_heads(
    speculate: Optional[int], num_heads: int, kind: str
) -> None:
    """Set the global speculation depth for a speculator checkpoint.

    Args:
        speculate: depth requested by the caller, or None to use every head.
        num_heads: number of heads provided by the speculator checkpoint.
        kind: speculator label used in the error message ("medusa", "mlp_speculator").

    Raises:
        RuntimeError: if more speculative tokens are requested than heads exist.
    """
    if speculate is None:
        set_speculate(num_heads)
        return
    if speculate > num_heads:
        raise RuntimeError(
            f"Speculate is set to `{speculate}` but this {kind} models only has `{num_heads}` heads, please make them match"
        )
    set_speculate(speculate)


def _is_mlp_speculator_weight(filename: str) -> bool:
    """Return True if ``filename`` looks like MLP-speculator weights.

    Only ``.safetensors`` files are kept. Files whose names mention
    arguments/args/training are skipped.
    """
    return (
        filename.endswith(".safetensors")
        and "arguments" not in filename
        and "args" not in filename
        and "training" not in filename
    )


def _use_flash_sliding_window(config_dict: dict) -> bool:
    """Decide whether a sliding-window family model may use its flash implementation.

    Flash is used either when the config declares no sliding window (missing,
    None or -1) and flash attention is importable, or whenever
    HAS_FLASH_ATTN_V2_CUDA is set.
    """
    sliding_window = config_dict.get("sliding_window", -1)
    no_sliding_window = sliding_window is None or sliding_window == -1
    return (no_sliding_window and FLASH_ATTENTION) or HAS_FLASH_ATTN_V2_CUDA


def get_model(
    model_id: str,
    revision: Optional[str],
    sharded: bool,
    quantize: Optional[str],
    speculate: Optional[int],
    dtype: Optional[str],
    trust_remote_code: bool,
) -> Model:
    """Pick and instantiate the model implementation that serves ``model_id``.

    The model's config is read with ``PretrainedConfig.get_config_dict``.

    If the config describes a Medusa or MLP speculator, the base model named
    in ``base_model_name_or_path`` is loaded instead (at revision "main"). The
    speculator weights are fetched and passed to the model as ``speculator``.

    Args:
        model_id: Hub id or local path of the model (or of a speculator checkpoint).
        revision: Hub revision of ``model_id``.
        sharded: whether the model is served across several shards.
        quantize: quantization method. When None, "gptq"/"awq" is picked from
            the config's ``quantization_config.quant_method`` if present.
        speculate: number of speculative tokens. When None, defaults to 0, or
            to the number of heads of a speculator checkpoint.
        dtype: "float16", "bfloat16" or None (let the model choose).
        trust_remote_code: forwarded to the config loader and to the model.

    Returns:
        An instance of the selected Model subclass.

    Raises:
        RuntimeError: unknown dtype, ``speculate`` larger than the speculator's
            head count, or undeterminable model type.
        NotImplementedError: the requested combination (sharding, flash
            attention availability, quantization) is not supported.
        ValueError: the model type is not supported at all.
    """
    dtype = _parse_dtype(dtype)

    # Default speculation depth; a speculator checkpoint below may override it.
    set_speculate(speculate if speculate is not None else 0)

    config_dict, _ = PretrainedConfig.get_config_dict(
        model_id, revision=revision, trust_remote_code=trust_remote_code
    )
    model_type = config_dict.get("model_type", None)

    speculator = None
    if "medusa_num_heads" in config_dict:
        # `model_id` holds Medusa heads; the weights to serve are in the base model.
        medusa_model_id = model_id
        medusa_revision = revision
        model_id = config_dict["base_model_name_or_path"]
        revision = "main"
        _set_speculate_from_heads(
            speculate, config_dict["medusa_num_heads"], "medusa"
        )

        # Reload config and model type from the base model.
        config_dict, _ = PretrainedConfig.get_config_dict(
            model_id, revision=revision, trust_remote_code=trust_remote_code
        )
        model_type = config_dict.get("model_type", None)

        if Path(medusa_model_id).exists():
            speculator_path = Path(medusa_model_id)
        else:
            medusa_config = hf_hub_download(
                medusa_model_id, revision=medusa_revision, filename="config.json"
            )
            hf_hub_download(
                medusa_model_id,
                revision=medusa_revision,
                filename="medusa_lm_head.safetensors",
            )
            speculator_path = Path(medusa_config).parent
        speculator = {
            "path": speculator_path,
            "model_paths": ["medusa_lm_head.safetensors"],
        }
        method = "medusa"
    elif model_type == "mlp_speculator":
        # `model_id` holds MLP speculator weights; serve the base model.
        mlp_model_id = model_id
        mlp_revision = revision
        model_id = config_dict["base_model_name_or_path"]
        revision = "main"
        _set_speculate_from_heads(speculate, config_dict["n_predict"], "mlp_speculator")

        # Reload config and model type from the base model.
        config_dict, _ = PretrainedConfig.get_config_dict(
            model_id, revision=revision, trust_remote_code=trust_remote_code
        )
        model_type = config_dict.get("model_type", None)

        if Path(mlp_model_id).exists():
            speculator_path = Path(mlp_model_id)
            # Same filter as the remote branch, so local and Hub checkpoints agree.
            filenames = [
                name
                for name in os.listdir(speculator_path)
                if _is_mlp_speculator_weight(name)
            ]
        else:
            mlp_speculator_config = hf_hub_download(
                mlp_model_id, revision=mlp_revision, filename="config.json"
            )
            info = HfApi().model_info(mlp_model_id, revision=mlp_revision)
            # Only top-level weight files of the repository are downloaded.
            filenames = [
                sibling.rfilename
                for sibling in info.siblings
                if "/" not in sibling.rfilename
                and _is_mlp_speculator_weight(sibling.rfilename)
            ]
            for filename in filenames:
                hf_hub_download(
                    mlp_model_id,
                    revision=mlp_revision,
                    filename=filename,
                )
            speculator_path = Path(mlp_speculator_config).parent
        speculator = {"path": speculator_path, "model_paths": filenames}
        method = "mlp_speculator"
    else:
        method = "n-gram"

    speculate = get_speculate()
    if speculate > 0:
        logger.info(f"Using speculation {method} with {speculate} input ids.")

    if model_type is None:
        # TODO: fix how we determine model type for Mamba
        if "ssm_cfg" in config_dict:
            # *only happens in Mamba case
            model_type = "ssm"
        else:
            raise RuntimeError(
                f"Could not determine model type for {model_id} revision {revision}"
            )

    quantization_config = config_dict.get("quantization_config", None)
    if quantization_config is not None and quantize is None:
        quant_method = quantization_config.get("quant_method", None)
        if quant_method in {"gptq", "awq"}:
            logger.info(f"Auto selecting quantization method {quant_method}")
            quantize = quant_method
        else:
            logger.info(f"Unknown quantization method {quant_method}")

    # Keyword arguments shared by (almost) every model constructor below.
    model_kwargs = dict(
        quantize=quantize,
        speculator=speculator,
        dtype=dtype,
        trust_remote_code=trust_remote_code,
    )

    if model_type == "ssm":
        if not MAMBA_AVAILABLE:
            # Without this check the call below would fail with a NameError.
            raise NotImplementedError(
                "Mamba models are not available: importing the Mamba implementation failed"
            )
        return Mamba(model_id, revision, **model_kwargs)

    if model_id.startswith("facebook/galactica"):
        return GalacticaSharded(model_id, revision, **model_kwargs)

    if model_type == "gpt_bigcode" or (
        model_type == "gpt2" and model_id.startswith("bigcode/")
    ):
        if FLASH_ATTENTION:
            return FlashSantacoderSharded(model_id, revision, **model_kwargs)
        elif sharded:
            raise NotImplementedError(
                FLASH_ATT_ERROR_MESSAGE.format("Sharded Santacoder")
            )
        else:
            return SantaCoder(model_id, revision, **model_kwargs)

    if model_type == "bloom":
        return BLOOMSharded(model_id, revision, **model_kwargs)
    elif model_type == "mpt":
        return MPTSharded(model_id, revision, **model_kwargs)
    elif model_type == "gpt_neox":
        if FLASH_ATTENTION:
            return FlashNeoXSharded(model_id, revision, **model_kwargs)
        elif sharded:
            return GPTNeoxSharded(model_id, revision, **model_kwargs)
        else:
            return CausalLM(model_id, revision, **model_kwargs)
    elif model_type == "phi":
        if FLASH_ATTENTION:
            return FlashPhi(model_id, revision, **model_kwargs)
        else:
            return CausalLM(model_id, revision, **model_kwargs)
    elif model_type == "phi-msft":
        if FLASH_ATTENTION:
            raise NotImplementedError(
                "Legacy phi-msft is not supported with Flash Attention"
            )
        else:
            return Phi(model_id, revision, **model_kwargs)
    elif model_type in ("llama", "baichuan", "phi3"):
        if FLASH_ATTENTION:
            return FlashLlama(model_id, revision, **model_kwargs)
        elif sharded:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Llama"))
        else:
            return CausalLM(model_id, revision, **model_kwargs)

    if model_type == "gemma":
        if FLASH_ATTENTION:
            return FlashGemma(model_id, revision, **model_kwargs)
        elif sharded:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Gemma"))
        else:
            return CausalLM(model_id, revision, **model_kwargs)

    if model_type == "cohere":
        if FLASH_ATTENTION:
            return FlashCohere(model_id, revision, **model_kwargs)
        elif sharded:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Cohere"))
        else:
            return CausalLM(model_id, revision, **model_kwargs)

    if model_type == "dbrx":
        if FLASH_ATTENTION:
            return FlashDbrx(model_id, revision, **model_kwargs)
        elif sharded:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded DBRX"))
        else:
            return CausalLM(model_id, revision, **model_kwargs)

    if model_type in ["RefinedWeb", "RefinedWebModel", "falcon"]:
        if sharded:
            if FLASH_ATTENTION:
                if config_dict.get("alibi", False):
                    raise NotImplementedError("sharded is not supported for this model")
                return FlashRWSharded(model_id, revision, **model_kwargs)
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Falcon"))
        else:
            if FLASH_ATTENTION and not config_dict.get("alibi", False):
                return FlashRWSharded(model_id, revision, **model_kwargs)
            else:
                return RW(model_id, revision, **model_kwargs)

    if model_type == "mistral":
        if _use_flash_sliding_window(config_dict):
            return FlashMistral(model_id, revision, **model_kwargs)
        elif sharded:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Mistral"))
        else:
            return CausalLM(model_id, revision, **model_kwargs)

    if model_type == "mixtral":
        if _use_flash_sliding_window(config_dict):
            return FlashMixtral(model_id, revision, **model_kwargs)
        elif sharded:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Mixtral"))
        else:
            return CausalLM(model_id, revision, **model_kwargs)

    if model_type == "starcoder2":
        if _use_flash_sliding_window(config_dict):
            # NOTE(review): unlike every other branch, `speculator` is not
            # forwarded here, so speculator heads loaded above are not passed
            # to this model. Confirm whether FlashStarcoder2 accepts `speculator`.
            return FlashStarcoder2(
                model_id,
                revision,
                quantize=quantize,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )
        elif sharded:
            raise NotImplementedError(
                FLASH_ATT_ERROR_MESSAGE.format("Sharded Starcoder2")
            )
        else:
            return CausalLM(model_id, revision, **model_kwargs)

    if model_type == "qwen2":
        if _use_flash_sliding_window(config_dict):
            # NOTE(review): `speculator` is not forwarded here (see starcoder2
            # above). Confirm whether FlashQwen2 accepts `speculator`.
            return FlashQwen2(
                model_id,
                revision,
                quantize=quantize,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )
        elif sharded:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Qwen2"))
        else:
            return CausalLM(model_id, revision, **model_kwargs)

    if model_type == "opt":
        return OPTSharded(model_id, revision, **model_kwargs)

    if model_type == "t5":
        return T5Sharded(model_id, revision, **model_kwargs)

    if model_type == "idefics":
        if FLASH_ATTENTION:
            return IDEFICSSharded(model_id, revision, **model_kwargs)
        else:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Idefics"))

    if model_type == "idefics2":
        if FLASH_ATTENTION:
            return Idefics2(model_id, revision, **model_kwargs)
        else:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Idefics2"))

    if model_type == "llava_next":
        if FLASH_ATTENTION:
            return LlavaNext(model_id, revision, **model_kwargs)
        else:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("LlavaNext"))

    # Generic transformers AutoModel fallback: no sharding, limited quantization.
    if sharded:
        raise NotImplementedError("sharded is not supported for AutoModel")
    if quantize == "gptq":
        raise NotImplementedError(
            "gptq quantization is not supported for AutoModel, you can try to quantize it with `text-generation-server quantize ORIGINAL_MODEL_ID NEW_MODEL_ID`"
        )
    if quantize == "awq":
        raise NotImplementedError("awq quantization is not supported for AutoModel")
    elif quantize in ("bitsandbytes-fp4", "bitsandbytes-nf4"):
        raise NotImplementedError("4bit quantization is not supported for AutoModel")
    elif quantize == "eetq":
        raise NotImplementedError("Eetq quantization is not supported for AutoModel")

    if model_type in modeling_auto.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES:
        return CausalLM(model_id, revision, **model_kwargs)
    if model_type in modeling_auto.MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES:
        return Seq2SeqLM(model_id, revision, **model_kwargs)

    # Remote-code models advertise their auto classes through `auto_map`.
    auto_map = config_dict.get("auto_map", None)
    if trust_remote_code and auto_map is not None:
        if "AutoModelForCausalLM" in auto_map:
            return CausalLM(model_id, revision, **model_kwargs)
        if "AutoModelForSeq2SeqLM" in auto_map:
            return Seq2SeqLM(model_id, revision, **model_kwargs)

    raise ValueError(f"Unsupported model type {model_type}")