import torch

from loguru import logger
from transformers.configuration_utils import PretrainedConfig
from transformers.models.auto import modeling_auto
from huggingface_hub import hf_hub_download
from typing import Optional
from pathlib import Path

from text_generation_server.utils.speculate import get_speculate, set_speculate

from text_generation_server.models.model import Model
from text_generation_server.models.causal_lm import CausalLM
from text_generation_server.models.flash_causal_lm import FlashCausalLM
from text_generation_server.models.bloom import BLOOMSharded
from text_generation_server.models.mpt import MPTSharded
from text_generation_server.models.seq2seq_lm import Seq2SeqLM
from text_generation_server.models.rw import RW
from text_generation_server.models.opt import OPTSharded
from text_generation_server.models.galactica import GalacticaSharded
from text_generation_server.models.santacoder import SantaCoder
from text_generation_server.models.t5 import T5Sharded
from text_generation_server.models.gpt_neox import GPTNeoxSharded
from text_generation_server.models.phi import Phi

# The flag below controls whether to allow TF32 on matmul. This flag defaults to False
# in PyTorch 1.12 and later.
torch.backends.cuda.matmul.allow_tf32 = True

# The flag below controls whether to allow TF32 on cuDNN. This flag defaults to True.
torch.backends.cudnn.allow_tf32 = True

# Disable gradients
torch.set_grad_enabled(False)

__all__ = [
    "Model",
    "BLOOMSharded",
    "CausalLM",
    "GalacticaSharded",
    "Seq2SeqLM",
    "SantaCoder",
    "OPTSharded",
    "T5Sharded",
    "get_model",
]
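
# Flash Attention model classes are optional: if the custom kernels are not
# installed, the imports below raise ImportError, FLASH_ATTENTION is set to False
# and the server falls back to the non-flash implementations.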

FLASH_ATT_ERROR_MESSAGE = "{} requires Flash Attention enabled models."

FLASH_ATTENTION = True

try:
    from text_generation_server.models.flash_rw import FlashRWSharded
    from text_generation_server.models.flash_neox import FlashNeoXSharded
    from text_generation_server.models.flash_llama import (
        FlashLlama,
    )
    from text_generation_server.models.flash_qwen2 import (
        FlashQwen2,
    )
    from text_generation_server.models.flash_cohere import (
        FlashCohere,
    )
    from text_generation_server.models.flash_gemma import (
        FlashGemma,
    )
    from text_generation_server.models.flash_santacoder import (
        FlashSantacoderSharded,
    )
    from text_generation_server.models.idefics import IDEFICSSharded
    from text_generation_server.models.llava_next import LlavaNext
    from text_generation_server.models.idefics2 import Idefics2
    from text_generation_server.models.flash_mistral import FlashMistral
    from text_generation_server.models.flash_mixtral import FlashMixtral
    from text_generation_server.models.flash_phi import FlashPhi
    from text_generation_server.models.flash_starcoder2 import FlashStarcoder2
    from text_generation_server.models.flash_dbrx import FlashDbrx
    from text_generation_server.utils.flash_attn import HAS_FLASH_ATTN_V2_CUDA

except ImportError as e:
    logger.warning(f"Could not import Flash Attention enabled models: {e}")
    FLASH_ATTENTION = False
    HAS_FLASH_ATTN_V2_CUDA = False

if FLASH_ATTENTION:
    __all__.append(FlashNeoXSharded)
    __all__.append(FlashRWSharded)
    __all__.append(FlashSantacoderSharded)
    __all__.append(FlashLlama)
    __all__.append(IDEFICSSharded)
    __all__.append(FlashMistral)
    __all__.append(FlashMixtral)
    __all__.append(FlashDbrx)
    __all__.append(FlashPhi)
    __all__.append(FlashQwen2)
    __all__.append(FlashStarcoder2)
    __all__.append(FlashGemma)
    __all__.append(FlashCohere)

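# The Mamba implementation has its own optional dependencies; if the import fails,
# the failure is logged and the model class is simply not registered.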
MAMBA_AVAILABLE = True
try:
    from text_generation_server.models.mamba import Mamba
except ImportError as e:
    logger.warning(f"Could not import Mamba: {e}")
    MAMBA_AVAILABLE = False

if MAMBA_AVAILABLE:
    __all__.append(Mamba)


def get_model(
    model_id: str,
    revision: Optional[str],
    sharded: bool,
    quantize: Optional[str],
    speculate: Optional[int],
    dtype: Optional[str],
    trust_remote_code: bool,
) -> Model:
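    """Select and instantiate the most specific Model implementation for `model_id`.

    The checkpoint config is inspected to resolve Medusa speculative decoding and
    auto-detected quantization, then dispatched on `model_type` to a flash-attention
    implementation when available, with AutoModel-based classes as the fallback.
    """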
    if dtype is None:
        # Keep it as default for now and let
        # every model resolve their own default dtype.
        dtype = None
    elif dtype == "float16":
        dtype = torch.float16
    elif dtype == "bfloat16":
        dtype = torch.bfloat16
    else:
        raise RuntimeError(f"Unknown dtype {dtype}")

    if speculate is not None:
        set_speculate(speculate)
    else:
        set_speculate(0)

    config_dict, _ = PretrainedConfig.get_config_dict(
        model_id, revision=revision, trust_remote_code=trust_remote_code
    )
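
    # Medusa speculative-decoding repositories expose `medusa_num_heads` and point
    # at their base model via `base_model_name_or_path`: swap `model_id` to that
    # base model and remember where the Medusa weights live in `use_medusa`.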
    use_medusa = None
    if "medusa_num_heads" in config_dict:
        medusa_model_id = model_id
        medusa_revision = revision
        model_id = config_dict["base_model_name_or_path"]
        revision = "main"
        speculate_medusa = config_dict["medusa_num_heads"]
        if speculate is not None:
            if speculate > speculate_medusa:
                raise RuntimeError(
                    f"Speculate is set to `{speculate}` but this medusa model only has `{speculate_medusa}` heads, please make them match"
                )
            else:
                set_speculate(speculate)
        else:
            set_speculate(speculate_medusa)

        config_dict, _ = PretrainedConfig.get_config_dict(
            model_id, revision=revision, trust_remote_code=trust_remote_code
        )
        is_local = Path(medusa_model_id).exists()
        if not is_local:
            medusa_config = hf_hub_download(
                medusa_model_id, revision=medusa_revision, filename="config.json"
            )
            hf_hub_download(
                medusa_model_id,
                revision=medusa_revision,
                filename="medusa_lm_head.safetensors",
            )
            use_medusa = Path(medusa_config).parent
        else:
            use_medusa = Path(medusa_model_id)

        method = "medusa"
    else:
        method = "n-gram"

    speculate = get_speculate()
    if speculate > 0:
        logger.info(f"Using speculation {method} with {speculate} input ids.")

    model_type = config_dict.get("model_type", None)
    if model_type is None:
        # TODO: fix how we determine model type for Mamba
        if "ssm_cfg" in config_dict:
            # this only happens in the Mamba case
            model_type = "ssm"
        else:
            raise RuntimeError(
                f"Could not determine model type for {model_id} revision {revision}"
            )
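
    # If the checkpoint ships its own `quantization_config` (e.g. a GPTQ or AWQ
    # export) and no quantize option was passed, adopt that method automatically.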
    quantization_config = config_dict.get("quantization_config", None)
    if quantization_config is not None and quantize is None:
        method = quantization_config.get("quant_method", None)
        if method in {"gptq", "awq"}:
            logger.info(f"Auto selecting quantization method {method}")
            quantize = method
        else:
            logger.info(f"Unknown quantization method {method}")

    if model_type == "ssm":
        return Mamba(
            model_id,
            revision,
            quantize=quantize,
            use_medusa=use_medusa,
            dtype=dtype,
            trust_remote_code=trust_remote_code,
        )

    if model_id.startswith("facebook/galactica"):
        return GalacticaSharded(
            model_id,
            revision,
            quantize=quantize,
            use_medusa=use_medusa,
            dtype=dtype,
            trust_remote_code=trust_remote_code,
        )

    if (
        model_type == "gpt_bigcode"
        or model_type == "gpt2"
        and model_id.startswith("bigcode/")
    ):
        if FLASH_ATTENTION:
            return FlashSantacoderSharded(
                model_id,
                revision,
                quantize=quantize,
                use_medusa=use_medusa,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )
        elif sharded:
            raise NotImplementedError(
                FLASH_ATT_ERROR_MESSAGE.format("Sharded Santacoder")
            )
        else:
            return SantaCoder(
                model_id,
                revision,
                quantize=quantize,
                use_medusa=use_medusa,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )

    if model_type == "bloom":
        return BLOOMSharded(
            model_id,
            revision,
            quantize=quantize,
            use_medusa=use_medusa,
            dtype=dtype,
            trust_remote_code=trust_remote_code,
        )
    elif model_type == "mpt":
        return MPTSharded(
            model_id,
            revision,
            quantize=quantize,
            use_medusa=use_medusa,
            dtype=dtype,
            trust_remote_code=trust_remote_code,
        )

    elif model_type == "gpt_neox":
        if FLASH_ATTENTION:
            return FlashNeoXSharded(
                model_id,
                revision,
                quantize=quantize,
                use_medusa=use_medusa,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )
        elif sharded:
            return GPTNeoxSharded(
                model_id,
                revision,
                quantize=quantize,
                use_medusa=use_medusa,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )
        else:
            return CausalLM(
                model_id,
                revision,
                quantize=quantize,
                use_medusa=use_medusa,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )

    elif model_type == "phi":
        if FLASH_ATTENTION:
            return FlashPhi(
                model_id,
                revision,
                quantize=quantize,
                use_medusa=use_medusa,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )
        else:
            return CausalLM(
                model_id,
                revision,
                quantize=quantize,
                use_medusa=use_medusa,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )

    elif model_type == "phi-msft":
        if FLASH_ATTENTION:
OlivierDehaene's avatar
OlivierDehaene committed
318
319
320
            raise NotImplementedError(
                "Legacy phi-msft is not supported with Flash Attention"
            )
drbh's avatar
drbh committed
321
322
323
324
325
        else:
            return Phi(
                model_id,
                revision,
                quantize=quantize,
326
                use_medusa=use_medusa,
drbh's avatar
drbh committed
327
328
329
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )
330

    elif model_type == "llama" or model_type == "baichuan" or model_type == "phi3":
        if FLASH_ATTENTION:
            return FlashLlama(
                model_id,
                revision,
                quantize=quantize,
                use_medusa=use_medusa,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )
        elif sharded:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Llama"))
        else:
            return CausalLM(
                model_id,
                revision,
                quantize=quantize,
                use_medusa=use_medusa,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )

    if model_type == "gemma":
        if FLASH_ATTENTION:
            return FlashGemma(
                model_id,
                revision,
                quantize=quantize,
                use_medusa=use_medusa,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )
        elif sharded:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Gemma"))
        else:
            return CausalLM(
                model_id,
                revision,
                quantize=quantize,
                use_medusa=use_medusa,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )

    if model_type == "cohere":
        if FLASH_ATTENTION:
            return FlashCohere(
                model_id,
                revision,
                quantize=quantize,
                use_medusa=use_medusa,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )
        elif sharded:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Cohere"))
        else:
            return CausalLM(
                model_id,
                revision,
                quantize=quantize,
                use_medusa=use_medusa,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )

    if model_type == "dbrx":
        if FLASH_ATTENTION:
            return FlashDbrx(
                model_id,
                revision,
                quantize=quantize,
                use_medusa=use_medusa,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )
        elif sharded:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded DBRX"))
        else:
            return CausalLM(
                model_id,
                revision,
                quantize=quantize,
                use_medusa=use_medusa,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )

    if model_type in ["RefinedWeb", "RefinedWebModel", "falcon"]:
        if sharded:
            if FLASH_ATTENTION:
                if config_dict.get("alibi", False):
                    raise NotImplementedError("sharded is not supported for this model")
                return FlashRWSharded(
                    model_id,
                    revision,
                    quantize=quantize,
                    use_medusa=use_medusa,
                    dtype=dtype,
                    trust_remote_code=trust_remote_code,
                )
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Falcon"))
        else:
            if FLASH_ATTENTION and not config_dict.get("alibi", False):
                return FlashRWSharded(
                    model_id,
                    revision,
                    quantize=quantize,
                    use_medusa=use_medusa,
                    dtype=dtype,
                    trust_remote_code=trust_remote_code,
                )
            else:
                return RW(
                    model_id,
                    revision,
                    quantize=quantize,
                    use_medusa=use_medusa,
                    dtype=dtype,
                    trust_remote_code=trust_remote_code,
                )
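
    # The branches below (Mistral, Mixtral, Starcoder2, Qwen2) may use sliding-window
    # attention, which needs Flash Attention v2 on CUDA; the flash implementations are
    # only selected when that is available or when no sliding window is configured.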

    if model_type == "mistral":
        sliding_window = config_dict.get("sliding_window", -1)
        if (
            (sliding_window is None or sliding_window == -1) and FLASH_ATTENTION
        ) or HAS_FLASH_ATTN_V2_CUDA:
            return FlashMistral(
                model_id,
                revision,
                quantize=quantize,
                use_medusa=use_medusa,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )
        elif sharded:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Mistral"))
        else:
            return CausalLM(
                model_id,
                revision,
                quantize=quantize,
                use_medusa=use_medusa,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )

    if model_type == "mixtral":
        sliding_window = config_dict.get("sliding_window", -1)
        if (
            (sliding_window is None or sliding_window == -1) and FLASH_ATTENTION
        ) or HAS_FLASH_ATTN_V2_CUDA:
            return FlashMixtral(
                model_id,
                revision,
                quantize=quantize,
                use_medusa=use_medusa,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )
        elif sharded:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Mixtral"))
        else:
            return CausalLM(
                model_id,
                revision,
                quantize=quantize,
                use_medusa=use_medusa,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )

    if model_type == "starcoder2":
        sliding_window = config_dict.get("sliding_window", -1)
        if (
            (sliding_window is None or sliding_window == -1) and FLASH_ATTENTION
        ) or HAS_FLASH_ATTN_V2_CUDA:
            return FlashStarcoder2(
                model_id,
                revision,
                quantize=quantize,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )
        elif sharded:
            raise NotImplementedError(
                FLASH_ATT_ERROR_MESSAGE.format("Sharded Starcoder2")
            )
        else:
            return CausalLM(
                model_id,
                revision,
                quantize=quantize,
                use_medusa=use_medusa,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )

    if model_type == "qwen2":
        sliding_window = config_dict.get("sliding_window", -1)
        if (
            (sliding_window is None or sliding_window == -1) and FLASH_ATTENTION
        ) or HAS_FLASH_ATTN_V2_CUDA:
            return FlashQwen2(
                model_id,
                revision,
                quantize=quantize,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )
        elif sharded:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Qwen2"))
        else:
            return CausalLM(
                model_id,
                revision,
                quantize=quantize,
                use_medusa=use_medusa,
                dtype=dtype,
OlivierDehaene's avatar
OlivierDehaene committed
549
550
                trust_remote_code=trust_remote_code,
            )
551
552

    if model_type == "opt":
553
        return OPTSharded(
554
555
556
            model_id,
            revision,
            quantize=quantize,
557
            use_medusa=use_medusa,
558
559
            dtype=dtype,
            trust_remote_code=trust_remote_code,
560
        )
561

562
    if model_type == "t5":
563
564
565
566
        return T5Sharded(
            model_id,
            revision,
            quantize=quantize,
567
            use_medusa=use_medusa,
568
            dtype=dtype,
569
570
            trust_remote_code=trust_remote_code,
        )
571
    if model_type == "idefics":
572
        if FLASH_ATTENTION:
OlivierDehaene's avatar
OlivierDehaene committed
573
574
575
576
            return IDEFICSSharded(
                model_id,
                revision,
                quantize=quantize,
577
                use_medusa=use_medusa,
OlivierDehaene's avatar
OlivierDehaene committed
578
579
580
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )
581
582
        else:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Idefics"))
Nicolas Patry's avatar
Nicolas Patry committed
583
584
585
586
587
588
589
590
591
592
593
594
    if model_type == "idefics2":
        if FLASH_ATTENTION:
            return Idefics2(
                model_id,
                revision,
                quantize=quantize,
                use_medusa=use_medusa,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )
        else:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Idefics"))
595

    if model_type == "llava_next":
        if FLASH_ATTENTION:
            return LlavaNext(
                model_id,
                revision,
                quantize=quantize,
                use_medusa=use_medusa,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )
        else:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("LlavaNext"))
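
    # No dedicated implementation matched: fall back to the transformers AutoModel
    # classes, which support neither sharding nor the quantization schemes rejected
    # below.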

    if sharded:
        raise NotImplementedError("sharded is not supported for AutoModel")
    if quantize == "gptq":
        raise NotImplementedError(
            "gptq quantization is not supported for AutoModel, you can try to quantize it with `text-generation-server quantize ORIGINAL_MODEL_ID NEW_MODEL_ID`"
        )
    if quantize == "awq":
        raise NotImplementedError("awq quantization is not supported for AutoModel")
    elif (quantize == "bitsandbytes-fp4") or (quantize == "bitsandbytes-nf4"):
        raise NotImplementedError("4bit quantization is not supported for AutoModel")
    elif quantize == "eetq":
        raise NotImplementedError("Eetq quantization is not supported for AutoModel")
    if model_type in modeling_auto.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES:
        return CausalLM(
            model_id,
            revision,
            quantize=quantize,
            use_medusa=use_medusa,
            dtype=dtype,
            trust_remote_code=trust_remote_code,
        )
    if model_type in modeling_auto.MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES:
        return Seq2SeqLM(
            model_id,
            revision,
            quantize=quantize,
            use_medusa=use_medusa,
            dtype=dtype,
            trust_remote_code=trust_remote_code,
        )

    auto_map = config_dict.get("auto_map", None)
    if trust_remote_code and auto_map is not None:
        if "AutoModelForCausalLM" in auto_map.keys():
            return CausalLM(
                model_id,
                revision,
                quantize=quantize,
                use_medusa=use_medusa,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )
        if "AutoModelForSeq2SeqLM" in auto_map.keys():
            return Seq2SeqLM(
                model_id,
                revision,
                quantize=quantize,
                use_medusa=use_medusa,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )

    raise ValueError(f"Unsupported model type {model_type}")