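"""Model registry for text-generation-server.

Imports all supported model implementations and exposes `get_model`, which selects
the most appropriate one for a given checkpoint and configuration.
"""
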
import torch

from loguru import logger
from transformers.configuration_utils import PretrainedConfig
from transformers.models.auto import modeling_auto
from huggingface_hub import hf_hub_download
from typing import Optional
from pathlib import Path

from text_generation_server.utils.speculate import get_speculate, set_speculate
from text_generation_server.models.model import Model
from text_generation_server.models.causal_lm import CausalLM
from text_generation_server.models.flash_causal_lm import FlashCausalLM
from text_generation_server.models.bloom import BLOOMSharded
from text_generation_server.models.mpt import MPTSharded
from text_generation_server.models.seq2seq_lm import Seq2SeqLM
from text_generation_server.models.rw import RW
from text_generation_server.models.opt import OPTSharded
from text_generation_server.models.galactica import GalacticaSharded
from text_generation_server.models.santacoder import SantaCoder
from text_generation_server.models.t5 import T5Sharded
from text_generation_server.models.gpt_neox import GPTNeoxSharded
from text_generation_server.models.phi import Phi

# The flag below controls whether to allow TF32 on matmul. This flag defaults to False
# in PyTorch 1.12 and later.
torch.backends.cuda.matmul.allow_tf32 = True

# The flag below controls whether to allow TF32 on cuDNN. This flag defaults to True.
torch.backends.cudnn.allow_tf32 = True

# Disable gradients
torch.set_grad_enabled(False)

__all__ = [
    "Model",
    "BLOOMSharded",
    "CausalLM",
    "GalacticaSharded",
    "Seq2SeqLM",
    "SantaCoder",
    "OPTSharded",
    "T5Sharded",
    "get_model",
]

FLASH_ATT_ERROR_MESSAGE = "{} requires Flash Attention enabled models."

FLASH_ATTENTION = True

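# The Flash Attention model classes below depend on optional CUDA kernels; if
# importing them fails, fall back to the regular (non flash) implementations.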
try:
    from text_generation_server.models.flash_rw import FlashRWSharded
    from text_generation_server.models.flash_neox import FlashNeoXSharded
    from text_generation_server.models.flash_llama import (
        FlashLlama,
    )
    from text_generation_server.models.flash_qwen2 import (
        FlashQwen2,
    )
    from text_generation_server.models.flash_cohere import (
        FlashCohere,
    )
    from text_generation_server.models.flash_gemma import (
        FlashGemma,
    )
    from text_generation_server.models.flash_santacoder import (
        FlashSantacoderSharded,
    )
    from text_generation_server.models.idefics import IDEFICSSharded
    from text_generation_server.models.llava_next import LlavaNext
    from text_generation_server.models.idefics2 import Idefics2
    from text_generation_server.models.flash_mistral import FlashMistral
    from text_generation_server.models.flash_mixtral import FlashMixtral
    from text_generation_server.models.flash_phi import FlashPhi
    from text_generation_server.models.flash_starcoder2 import FlashStarcoder2
    from text_generation_server.models.flash_dbrx import FlashDbrx
    from text_generation_server.utils.flash_attn import HAS_FLASH_ATTN_V2_CUDA

except ImportError as e:
    logger.warning(f"Could not import Flash Attention enabled models: {e}")
    FLASH_ATTENTION = False
    HAS_FLASH_ATTN_V2_CUDA = False

if FLASH_ATTENTION:
    __all__.append(FlashNeoXSharded)
    __all__.append(FlashRWSharded)
    __all__.append(FlashSantacoderSharded)
    __all__.append(FlashLlama)
    __all__.append(IDEFICSSharded)
    __all__.append(FlashMistral)
    __all__.append(FlashMixtral)
    __all__.append(FlashDbrx)
    __all__.append(FlashPhi)
    __all__.append(FlashQwen2)
    __all__.append(FlashStarcoder2)
    __all__.append(FlashGemma)
    __all__.append(FlashCohere)

MAMBA_AVAILABLE = True
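# Mamba support is optional as well: disable it when its dependencies are missing.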
try:
    from text_generation_server.models.mamba import Mamba
except ImportError as e:
    logger.warning(f"Could not import Mamba: {e}")
    MAMBA_AVAILABLE = False

if MAMBA_AVAILABLE:
    __all__.append(Mamba)


def get_model(
    model_id: str,
    revision: Optional[str],
    sharded: bool,
    quantize: Optional[str],
    speculate: Optional[int],
    dtype: Optional[str],
    trust_remote_code: bool,
) -> Model:
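    """Instantiate the most appropriate Model implementation for `model_id`."""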
    if dtype is None:
        # Keep it as default for now and let
        # every model resolve their own default dtype.
        dtype = None
    elif dtype == "float16":
        dtype = torch.float16
    elif dtype == "bfloat16":
        dtype = torch.bfloat16
    else:
        raise RuntimeError(f"Unknown dtype {dtype}")

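    # Default to no speculation unless explicitly requested; a Medusa checkpoint
    # detected below can still enable it.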
    if speculate is not None:
        set_speculate(speculate)
    else:
        set_speculate(0)

    config_dict, _ = PretrainedConfig.get_config_dict(
        model_id, revision=revision, trust_remote_code=trust_remote_code
    )

    use_medusa = None
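    # Medusa speculative-decoding checkpoints: the config carries `medusa_num_heads`
    # and points back at the base model, so switch model_id to the base model and
    # remember where the Medusa head weights live.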
    if "medusa_num_heads" in config_dict:
        medusa_model_id = model_id
        medusa_revision = revision
        model_id = config_dict["base_model_name_or_path"]
        revision = "main"
        speculate_medusa = config_dict["medusa_num_heads"]
        if speculate is not None:
            if speculate > speculate_medusa:
                raise RuntimeError(
                    f"Speculate is set to `{speculate}` but this medusa model only has `{speculate_medusa}` heads, please make them match"
                )
            else:
                set_speculate(speculate)
        else:
            set_speculate(speculate_medusa)

        config_dict, _ = PretrainedConfig.get_config_dict(
            model_id, revision=revision, trust_remote_code=trust_remote_code
        )
        is_local = Path(medusa_model_id).exists()
        if not is_local:
            medusa_config = hf_hub_download(
                medusa_model_id, revision=medusa_revision, filename="config.json"
            )
            hf_hub_download(
                medusa_model_id,
                revision=medusa_revision,
                filename="medusa_lm_head.safetensors",
            )
            use_medusa = Path(medusa_config).parent
        else:
            use_medusa = Path(medusa_model_id)

        method = "medusa"
    else:
        method = "n-gram"

    speculate = get_speculate()
    if speculate > 0:
        logger.info(f"Using speculation {method} with {speculate} input ids.")

    model_type = config_dict.get("model_type", None)
    if model_type is None:
        # TODO: fix how we determine model type for Mamba
        if "ssm_cfg" in config_dict:
            # *only happens in Mamba case
            model_type = "ssm"
        else:
            raise RuntimeError(
                f"Could not determine model type for {model_id} revision {revision}"
            )
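    # If the checkpoint declares its own quantization method (e.g. GPTQ or AWQ) and
    # none was requested explicitly, pick it up from the config.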
    quantization_config = config_dict.get("quantization_config", None)
    if quantization_config is not None and quantize is None:
        method = quantization_config.get("quant_method", None)
        if method in {"gptq", "awq"}:
            logger.info(f"Auto selecting quantization method {method}")
            quantize = method
        else:
            logger.info(f"Unknown quantization method {method}")

    if model_type == "ssm":
        return Mamba(
            model_id,
            revision,
            quantize=quantize,
            use_medusa=use_medusa,
            dtype=dtype,
            trust_remote_code=trust_remote_code,
        )

    if model_id.startswith("facebook/galactica"):
        return GalacticaSharded(
            model_id,
            revision,
            quantize=quantize,
            use_medusa=use_medusa,
            dtype=dtype,
            trust_remote_code=trust_remote_code,
        )

    if (
        model_type == "gpt_bigcode"
        or model_type == "gpt2"
        and model_id.startswith("bigcode/")
    ):
        if FLASH_ATTENTION:
            return FlashSantacoderSharded(
                model_id,
                revision,
                quantize=quantize,
                use_medusa=use_medusa,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )
        elif sharded:
            raise NotImplementedError(
                FLASH_ATT_ERROR_MESSAGE.format("Sharded Santacoder")
            )
        else:
            return SantaCoder(
                model_id,
                revision,
                quantize=quantize,
                use_medusa=use_medusa,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )

    if model_type == "bloom":
        return BLOOMSharded(
            model_id,
            revision,
            quantize=quantize,
            use_medusa=use_medusa,
            dtype=dtype,
            trust_remote_code=trust_remote_code,
        )
    elif model_type == "mpt":
        return MPTSharded(
            model_id,
            revision,
            quantize=quantize,
            use_medusa=use_medusa,
            dtype=dtype,
            trust_remote_code=trust_remote_code,
        )

    elif model_type == "gpt_neox":
        if FLASH_ATTENTION:
            return FlashNeoXSharded(
                model_id,
                revision,
                quantize=quantize,
                use_medusa=use_medusa,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )
        elif sharded:
            return GPTNeoxSharded(
                model_id,
                revision,
                quantize=quantize,
                use_medusa=use_medusa,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )
        else:
            return CausalLM(
                model_id,
                revision,
                quantize=quantize,
                use_medusa=use_medusa,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )

    elif model_type == "phi":
        if FLASH_ATTENTION:
            return FlashPhi(
                model_id,
                revision,
                quantize=quantize,
                use_medusa=use_medusa,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )
        else:
            return CausalLM(
                model_id,
                revision,
                quantize=quantize,
                use_medusa=use_medusa,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )

    elif model_type == "phi-msft":
        if FLASH_ATTENTION:
            raise NotImplementedError(
                "Legacy phi-msft is not supported with Flash Attention"
            )
        else:
            return Phi(
                model_id,
                revision,
                quantize=quantize,
                use_medusa=use_medusa,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )

    elif model_type == "llama" or model_type == "baichuan" or model_type == "phi3":
        if FLASH_ATTENTION:
            return FlashLlama(
                model_id,
                revision,
                quantize=quantize,
                use_medusa=use_medusa,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )
        elif sharded:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Llama"))
        else:
            return CausalLM(
                model_id,
                revision,
                quantize=quantize,
                use_medusa=use_medusa,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )
    if model_type == "gemma":
        if FLASH_ATTENTION:
            return FlashGemma(
                model_id,
                revision,
                quantize=quantize,
                use_medusa=use_medusa,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )
        elif sharded:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Gemma"))
        else:
            return CausalLM(
                model_id,
                revision,
                quantize=quantize,
                use_medusa=use_medusa,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )

    if model_type == "cohere":
        if FLASH_ATTENTION:
            return FlashCohere(
                model_id,
                revision,
                quantize=quantize,
                use_medusa=use_medusa,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )
        elif sharded:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Cohere"))
        else:
            return CausalLM(
                model_id,
                revision,
                quantize=quantize,
                use_medusa=use_medusa,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )

    if model_type == "dbrx":
        if FLASH_ATTENTION:
            return FlashDbrx(
                model_id,
                revision,
                quantize=quantize,
                use_medusa=use_medusa,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )
        elif sharded:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded DBRX"))
        else:
            return CausalLM(
                model_id,
                revision,
                quantize=quantize,
                use_medusa=use_medusa,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )

    if model_type in ["RefinedWeb", "RefinedWebModel", "falcon"]:
        if sharded:
            if FLASH_ATTENTION:
                if config_dict.get("alibi", False):
                    raise NotImplementedError("sharded is not supported for this model")
                return FlashRWSharded(
                    model_id,
                    revision,
                    quantize=quantize,
                    use_medusa=use_medusa,
                    dtype=dtype,
                    trust_remote_code=trust_remote_code,
                )
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format(f"Sharded Falcon"))
        else:
            if FLASH_ATTENTION and not config_dict.get("alibi", False):
                return FlashRWSharded(
                    model_id,
                    revision,
                    quantize=quantize,
                    use_medusa=use_medusa,
                    dtype=dtype,
                    trust_remote_code=trust_remote_code,
                )
            else:
                return RW(
                    model_id,
                    revision,
                    quantize=quantize,
                    use_medusa=use_medusa,
                    dtype=dtype,
                    trust_remote_code=trust_remote_code,
                )

    if model_type == "mistral":
        sliding_window = config_dict.get("sliding_window", -1)
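        # Models without a sliding window only need regular Flash Attention;
        # sliding-window attention additionally requires Flash Attention v2 on CUDA.
        # Otherwise fall back to the eager CausalLM implementation.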
        if (
            (sliding_window is None or sliding_window == -1) and FLASH_ATTENTION
        ) or HAS_FLASH_ATTN_V2_CUDA:
            return FlashMistral(
                model_id,
                revision,
                quantize=quantize,
                use_medusa=use_medusa,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )
        elif sharded:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Mistral"))
        else:
            return CausalLM(
                model_id,
                revision,
                quantize=quantize,
                use_medusa=use_medusa,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )

    if model_type == "mixtral":
        sliding_window = config_dict.get("sliding_window", -1)
        if (
            (sliding_window is None or sliding_window == -1) and FLASH_ATTENTION
        ) or HAS_FLASH_ATTN_V2_CUDA:
            return FlashMixtral(
                model_id,
                revision,
                quantize=quantize,
                use_medusa=use_medusa,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )
        elif sharded:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Mixtral"))
        else:
            return CausalLM(
                model_id,
                revision,
                quantize=quantize,
                use_medusa=use_medusa,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )

    if model_type == "starcoder2":
        sliding_window = config_dict.get("sliding_window", -1)
        if (
            (sliding_window is None or sliding_window == -1) and FLASH_ATTENTION
        ) or HAS_FLASH_ATTN_V2_CUDA:
            return FlashStarcoder2(
                model_id,
                revision,
                quantize=quantize,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )
        elif sharded:
            raise NotImplementedError(
                FLASH_ATT_ERROR_MESSAGE.format("Sharded Starcoder2")
            )
        else:
            return CausalLM(
                model_id,
                revision,
                quantize=quantize,
                use_medusa=use_medusa,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )

    if model_type == "qwen2":
        sliding_window = config_dict.get("sliding_window", -1)
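        # Qwen2 only honours the sliding window when `use_sliding_window` is set.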
        use_sliding_window = config_dict.get("use_sliding_window", False)
        sliding_window = sliding_window if use_sliding_window else None
        if (
            (sliding_window is None or sliding_window == -1) and FLASH_ATTENTION
        ) or HAS_FLASH_ATTN_V2_CUDA:
            return FlashQwen2(
                model_id,
                revision,
                quantize=quantize,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )
        elif sharded:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Qwen2"))
        else:
            return CausalLM(
                model_id,
                revision,
                quantize=quantize,
                use_medusa=use_medusa,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )

    if model_type == "opt":
        return OPTSharded(
            model_id,
            revision,
            quantize=quantize,
            use_medusa=use_medusa,
            dtype=dtype,
            trust_remote_code=trust_remote_code,
        )

    if model_type == "t5":
        return T5Sharded(
            model_id,
            revision,
            quantize=quantize,
            use_medusa=use_medusa,
            dtype=dtype,
            trust_remote_code=trust_remote_code,
        )
    if model_type == "idefics":
        if FLASH_ATTENTION:
            return IDEFICSSharded(
                model_id,
                revision,
                quantize=quantize,
                use_medusa=use_medusa,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )
        else:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Idefics"))
    if model_type == "idefics2":
        if FLASH_ATTENTION:
            return Idefics2(
                model_id,
                revision,
                quantize=quantize,
                use_medusa=use_medusa,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )
        else:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Idefics2"))

    if model_type == "llava_next":
        if FLASH_ATTENTION:
            return LlavaNext(
                model_id,
                revision,
                quantize=quantize,
                use_medusa=use_medusa,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )
        else:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("LlavaNext"))

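    # No dedicated implementation matched: fall back to the generic
    # CausalLM / Seq2SeqLM auto-model path.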
    if sharded:
        raise NotImplementedError("sharded is not supported for AutoModel")
    if quantize == "gptq":
        raise NotImplementedError(
            "gptq quantization is not supported for AutoModel, you can try to quantize it with `text-generation-server quantize ORIGINAL_MODEL_ID NEW_MODEL_ID`"
        )
    if quantize == "awq":
        raise NotImplementedError("awq quantization is not supported for AutoModel")
    elif (quantize == "bitsandbytes-fp4") or (quantize == "bitsandbytes-nf4"):
        raise NotImplementedError("4bit quantization is not supported for AutoModel")
    elif quantize == "eetq":
        raise NotImplementedError("Eetq quantization is not supported for AutoModel")
    if model_type in modeling_auto.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES:
        return CausalLM(
            model_id,
            revision,
            quantize=quantize,
            use_medusa=use_medusa,
            dtype=dtype,
            trust_remote_code=trust_remote_code,
        )
    if model_type in modeling_auto.MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES:
        return Seq2SeqLM(
            model_id,
            revision,
            quantize=quantize,
            use_medusa=use_medusa,
            dtype=dtype,
            trust_remote_code=trust_remote_code,
        )

    auto_map = config_dict.get("auto_map", None)
    if trust_remote_code and auto_map is not None:
        if "AutoModelForCausalLM" in auto_map.keys():
            return CausalLM(
                model_id,
                revision,
                quantize=quantize,
                use_medusa=use_medusa,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )
        if "AutoModelForSeq2SeqLM" in auto_map.keys():
            return Seq2SeqLM(
                model_id,
                revision,
                quantize=quantize,
                use_medusa=use_medusa,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )

    raise ValueError(f"Unsupported model type {model_type}")