import torch

from loguru import logger
from transformers.configuration_utils import PretrainedConfig
from transformers.models.auto import modeling_auto
from huggingface_hub import hf_hub_download
from typing import Optional
from pathlib import Path

from text_generation_server.utils.speculate import get_speculate, set_speculate
from text_generation_server.models.model import Model
from text_generation_server.models.causal_lm import CausalLM
from text_generation_server.models.flash_causal_lm import FlashCausalLM
from text_generation_server.models.bloom import BLOOMSharded
from text_generation_server.models.mpt import MPTSharded
from text_generation_server.models.seq2seq_lm import Seq2SeqLM
from text_generation_server.models.rw import RW
from text_generation_server.models.opt import OPTSharded
from text_generation_server.models.galactica import GalacticaSharded
from text_generation_server.models.santacoder import SantaCoder
from text_generation_server.models.t5 import T5Sharded
from text_generation_server.models.gpt_neox import GPTNeoxSharded
from text_generation_server.models.phi import Phi

# The flag below controls whether to allow TF32 on matmul. This flag defaults to False
# in PyTorch 1.12 and later.
torch.backends.cuda.matmul.allow_tf32 = True

# The flag below controls whether to allow TF32 on cuDNN. This flag defaults to True.
torch.backends.cudnn.allow_tf32 = True

# Disable gradients
torch.set_grad_enabled(False)

__all__ = [
    "Model",
    "BLOOMSharded",
    "CausalLM",
    "GalacticaSharded",
    "Seq2SeqLM",
    "SantaCoder",
    "OPTSharded",
    "T5Sharded",
    "get_model",
]

FLASH_ATT_ERROR_MESSAGE = "{} requires Flash Attention enabled models."

FLASH_ATTENTION = True
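
# Model classes that rely on the custom Flash Attention kernels are imported
# optionally: if the import fails on this platform, log a warning and fall back
# to the non-flash implementations selected in `get_model` below.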
try:
    from text_generation_server.models.flash_rw import FlashRWSharded
    from text_generation_server.models.flash_neox import FlashNeoXSharded
    from text_generation_server.models.flash_llama import (
        FlashLlama,
    )
    from text_generation_server.models.flash_qwen2 import (
        FlashQwen2,
    )
    from text_generation_server.models.flash_cohere import (
        FlashCohere,
    )
    from text_generation_server.models.flash_gemma import (
        FlashGemma,
    )
    from text_generation_server.models.flash_santacoder import (
        FlashSantacoderSharded,
    )
    from text_generation_server.models.idefics import IDEFICSSharded
    from text_generation_server.models.llava_next import LlavaNext
    from text_generation_server.models.flash_mistral import FlashMistral
    from text_generation_server.models.flash_mixtral import FlashMixtral
    from text_generation_server.models.flash_phi import FlashPhi
    from text_generation_server.models.flash_starcoder2 import FlashStarcoder2
    from text_generation_server.models.flash_dbrx import FlashDbrx
    from text_generation_server.utils.flash_attn import HAS_FLASH_ATTN_V2_CUDA

except ImportError as e:
    logger.warning(f"Could not import Flash Attention enabled models: {e}")
    FLASH_ATTENTION = False
    HAS_FLASH_ATTN_V2_CUDA = False

if FLASH_ATTENTION:
    __all__.append("FlashNeoXSharded")
    __all__.append("FlashRWSharded")
    __all__.append("FlashSantacoderSharded")
    __all__.append("FlashLlama")
    __all__.append("IDEFICSSharded")
    __all__.append("FlashMistral")
    __all__.append("FlashMixtral")
    __all__.append("FlashDbrx")
    __all__.append("FlashPhi")
    __all__.append("FlashQwen2")
    __all__.append("FlashStarcoder2")
    __all__.append("FlashGemma")
    __all__.append("FlashCohere")

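# The Mamba implementation has extra dependencies; if it cannot be imported,
# flag it as unavailable instead of failing the whole server.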
MAMBA_AVAILABLE = True
try:
    from text_generation_server.models.mamba import Mamba
except ImportError as e:
    logger.warning(f"Could not import Mamba: {e}")
    MAMBA_AVAILABLE = False

if MAMBA_AVAILABLE:
    __all__.append("Mamba")


def get_model(
    model_id: str,
    revision: Optional[str],
    sharded: bool,
    quantize: Optional[str],
    speculate: Optional[int],
    dtype: Optional[str],
    trust_remote_code: bool,
) -> Model:
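    """Select and instantiate the model implementation for `model_id`.

    The choice is driven by the checkpoint's config (`model_type`, an optional
    `quantization_config`, an optional Medusa speculation config) and by the
    kernels available at import time (Flash Attention, Mamba).
    """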
    if dtype is None:
        # Keep it as default for now and let
        # every model resolve their own default dtype.
        dtype = None
    elif dtype == "float16":
        dtype = torch.float16
    elif dtype == "bfloat16":
        dtype = torch.bfloat16
    else:
        raise RuntimeError(f"Unknown dtype {dtype}")

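    # Record the requested number of speculative input ids; a Medusa checkpoint
    # detected below may override this value.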
    if speculate is not None:
        set_speculate(speculate)
    else:
        set_speculate(0)

    config_dict, _ = PretrainedConfig.get_config_dict(
        model_id, revision=revision, trust_remote_code=trust_remote_code
    )
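
    # A checkpoint whose config declares `medusa_num_heads` is a Medusa
    # speculation head: resolve the underlying base model, fetch the Medusa
    # weights, and speculate with them; otherwise fall back to n-gram
    # speculation.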
    use_medusa = None
    if "medusa_num_heads" in config_dict:
        medusa_model_id = model_id
        medusa_revision = revision
        model_id = config_dict["base_model_name_or_path"]
        revision = "main"
        speculate_medusa = config_dict["medusa_num_heads"]
        if speculate is not None:
            if speculate > speculate_medusa:
                raise RuntimeError(
                    f"Speculate is set to `{speculate}` but this Medusa model only has `{speculate_medusa}` heads, please make them match"
                )
            else:
                set_speculate(speculate)
        else:
            set_speculate(speculate_medusa)

        config_dict, _ = PretrainedConfig.get_config_dict(
            model_id, revision=revision, trust_remote_code=trust_remote_code
        )
        is_local = Path(medusa_model_id).exists()
        if not is_local:
            medusa_config = hf_hub_download(
                medusa_model_id, revision=medusa_revision, filename="config.json"
            )
            hf_hub_download(
                medusa_model_id,
                revision=medusa_revision,
                filename="medusa_lm_head.safetensors",
            )
            use_medusa = Path(medusa_config).parent
        else:
            use_medusa = Path(medusa_model_id)

        method = "medusa"
    else:
        method = "n-gram"

    speculate = get_speculate()
    if speculate > 0:
        logger.info(f"Using speculation {method} with {speculate} input ids.")

    model_type = config_dict.get("model_type", None)
    if model_type is None:
        # TODO: fix how we determine model type for Mamba
        if "ssm_cfg" in config_dict:
            # *only happens in Mamba case
            model_type = "ssm"
        else:
            raise RuntimeError(
                f"Could not determine model type for {model_id} revision {revision}"
            )
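
    # Checkpoints exported with GPTQ or AWQ ship their own `quantization_config`;
    # if no quantization method was requested explicitly, adopt the one declared
    # by the checkpoint.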
    quantization_config = config_dict.get("quantization_config", None)
    if quantization_config is not None and quantize is None:
        method = quantization_config.get("quant_method", None)
        if method in {"gptq", "awq"}:
            logger.info(f"Auto selecting quantization method {method}")
            quantize = method
        else:
            logger.info(f"Unknown quantization method {method}")

    if model_type == "ssm":
        return Mamba(
            model_id,
            revision,
            quantize=quantize,
            use_medusa=use_medusa,
            dtype=dtype,
            trust_remote_code=trust_remote_code,
        )

    if model_id.startswith("facebook/galactica"):
        return GalacticaSharded(
            model_id,
            revision,
            quantize=quantize,
            use_medusa=use_medusa,
            dtype=dtype,
            trust_remote_code=trust_remote_code,
        )

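    # Note: `and` binds tighter than `or`, so this matches gpt_bigcode models as
    # well as gpt2-architecture checkpoints published under `bigcode/`.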
    if (
        model_type == "gpt_bigcode"
        or model_type == "gpt2"
        and model_id.startswith("bigcode/")
    ):
        if FLASH_ATTENTION:
            return FlashSantacoderSharded(
                model_id,
                revision,
                quantize=quantize,
                use_medusa=use_medusa,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )
        elif sharded:
            raise NotImplementedError(
                FLASH_ATT_ERROR_MESSAGE.format("Sharded Santacoder")
            )
        else:
            return SantaCoder(
                model_id,
                revision,
                quantize=quantize,
                use_medusa=use_medusa,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )

    if model_type == "bloom":
        return BLOOMSharded(
            model_id,
            revision,
            quantize=quantize,
            use_medusa=use_medusa,
            dtype=dtype,
            trust_remote_code=trust_remote_code,
        )
    elif model_type == "mpt":
        return MPTSharded(
            model_id,
            revision,
            quantize=quantize,
            use_medusa=use_medusa,
            dtype=dtype,
            trust_remote_code=trust_remote_code,
        )

    elif model_type == "gpt_neox":
        if FLASH_ATTENTION:
            return FlashNeoXSharded(
                model_id,
                revision,
                quantize=quantize,
                use_medusa=use_medusa,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )
        elif sharded:
            return GPTNeoxSharded(
                model_id,
                revision,
                quantize=quantize,
                use_medusa=use_medusa,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )
        else:
            return CausalLM(
                model_id,
                revision,
                quantize=quantize,
                use_medusa=use_medusa,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )

    elif model_type == "phi":
        if FLASH_ATTENTION:
            return FlashPhi(
                model_id,
                revision,
                quantize=quantize,
                use_medusa=use_medusa,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )
        else:
            return CausalLM(
                model_id,
                revision,
                quantize=quantize,
                use_medusa=use_medusa,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )

    elif model_type == "phi-msft":
        if FLASH_ATTENTION:
            raise NotImplementedError(
                "Legacy phi-msft is not supported with Flash Attention"
            )
        else:
            return Phi(
                model_id,
                revision,
                quantize=quantize,
                use_medusa=use_medusa,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )

    elif model_type == "llama" or model_type == "baichuan" or model_type == "phi3":
        if FLASH_ATTENTION:
            return FlashLlama(
                model_id,
                revision,
                quantize=quantize,
                use_medusa=use_medusa,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )
        elif sharded:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Llama"))
        else:
            return CausalLM(
                model_id,
                revision,
                quantize=quantize,
                use_medusa=use_medusa,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )

    if model_type == "gemma":
        if FLASH_ATTENTION:
            return FlashGemma(
                model_id,
                revision,
                quantize=quantize,
                use_medusa=use_medusa,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )
        elif sharded:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Gemma"))
        else:
            return CausalLM(
                model_id,
                revision,
                quantize=quantize,
                use_medusa=use_medusa,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )

    if model_type == "cohere":
        if FLASH_ATTENTION:
            return FlashCohere(
                model_id,
                revision,
                quantize=quantize,
                use_medusa=use_medusa,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )
        elif sharded:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Cohere"))
        else:
            return CausalLM(
                model_id,
                revision,
                quantize=quantize,
                use_medusa=use_medusa,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )

    if model_type == "dbrx":
        if FLASH_ATTENTION:
            return FlashDbrx(
                model_id,
                revision,
                quantize=quantize,
                use_medusa=use_medusa,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )
        elif sharded:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded DBRX"))
        else:
            return CausalLM(
                model_id,
                revision,
                quantize=quantize,
                use_medusa=use_medusa,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )

    if model_type in ["RefinedWeb", "RefinedWebModel", "falcon"]:
        if sharded:
            if FLASH_ATTENTION:
                if config_dict.get("alibi", False):
                    raise NotImplementedError("sharded is not supported for this model")
                return FlashRWSharded(
                    model_id,
                    revision,
                    quantize=quantize,
                    use_medusa=use_medusa,
                    dtype=dtype,
                    trust_remote_code=trust_remote_code,
                )
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Falcon"))
        else:
            if FLASH_ATTENTION and not config_dict.get("alibi", False):
                return FlashRWSharded(
                    model_id,
                    revision,
                    quantize=quantize,
                    use_medusa=use_medusa,
                    dtype=dtype,
                    trust_remote_code=trust_remote_code,
                )
            else:
                return RW(
                    model_id,
                    revision,
                    quantize=quantize,
                    use_medusa=use_medusa,
                    dtype=dtype,
                    trust_remote_code=trust_remote_code,
                )

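    # The flash path is taken when either no sliding window is configured (any
    # Flash Attention build suffices) or Flash Attention v2 for CUDA is
    # available, which is required for sliding-window attention.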
    if model_type == "mistral":
        sliding_window = config_dict.get("sliding_window", -1)
        if (
            (sliding_window is None or sliding_window == -1) and FLASH_ATTENTION
        ) or HAS_FLASH_ATTN_V2_CUDA:
            return FlashMistral(
                model_id,
                revision,
                quantize=quantize,
                use_medusa=use_medusa,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )
        elif sharded:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Mistral"))
        else:
            return CausalLM(
                model_id,
                revision,
                quantize=quantize,
                use_medusa=use_medusa,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )

    if model_type == "mixtral":
        sliding_window = config_dict.get("sliding_window", -1)
        if (
            (sliding_window is None or sliding_window == -1) and FLASH_ATTENTION
        ) or HAS_FLASH_ATTN_V2_CUDA:
            return FlashMixtral(
                model_id,
                revision,
                quantize=quantize,
                use_medusa=use_medusa,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )
        elif sharded:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Mixtral"))
        else:
            return CausalLM(
                model_id,
                revision,
                quantize=quantize,
                use_medusa=use_medusa,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )

    if model_type == "starcoder2":
        sliding_window = config_dict.get("sliding_window", -1)
        if (
            (sliding_window is None or sliding_window == -1) and FLASH_ATTENTION
        ) or HAS_FLASH_ATTN_V2_CUDA:
            return FlashStarcoder2(
                model_id,
                revision,
                quantize=quantize,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )
        elif sharded:
            raise NotImplementedError(
                FLASH_ATT_ERROR_MESSAGE.format("Sharded Starcoder2")
            )
        else:
            return CausalLM(
                model_id,
                revision,
                quantize=quantize,
                use_medusa=use_medusa,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )

    if model_type == "qwen2":
        sliding_window = config_dict.get("sliding_window", -1)
        if (
            (sliding_window is None or sliding_window == -1) and FLASH_ATTENTION
        ) or HAS_FLASH_ATTN_V2_CUDA:
            return FlashQwen2(
                model_id,
                revision,
                quantize=quantize,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )
        elif sharded:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Qwen2"))
        else:
            return CausalLM(
                model_id,
                revision,
                quantize=quantize,
                use_medusa=use_medusa,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )

    if model_type == "opt":
        return OPTSharded(
            model_id,
            revision,
            quantize=quantize,
            use_medusa=use_medusa,
            dtype=dtype,
            trust_remote_code=trust_remote_code,
        )

    if model_type == "t5":
        return T5Sharded(
            model_id,
            revision,
            quantize=quantize,
            use_medusa=use_medusa,
            dtype=dtype,
            trust_remote_code=trust_remote_code,
        )

    if model_type == "idefics":
        if FLASH_ATTENTION:
            return IDEFICSSharded(
                model_id,
                revision,
                quantize=quantize,
                use_medusa=use_medusa,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )
        else:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Idefics"))

    if model_type == "llava_next":
        if FLASH_ATTENTION:
            return LlavaNext(
                model_id,
                revision,
                quantize=quantize,
                use_medusa=use_medusa,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )
        else:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("LlavaNext"))

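    # No dedicated implementation matched: fall back to the generic transformers
    # AutoModel path, which supports neither sharding nor the quantization
    # methods rejected below.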
    if sharded:
        raise NotImplementedError("sharded is not supported for AutoModel")
    if quantize == "gptq":
        raise NotImplementedError(
            "gptq quantization is not supported for AutoModel, you can try to quantize it with `text-generation-server quantize ORIGINAL_MODEL_ID NEW_MODEL_ID`"
        )
    if quantize == "awq":
        raise NotImplementedError("awq quantization is not supported for AutoModel")
    elif (quantize == "bitsandbytes-fp4") or (quantize == "bitsandbytes-nf4"):
        raise NotImplementedError("4bit quantization is not supported for AutoModel")
    elif quantize == "eetq":
        raise NotImplementedError("Eetq quantization is not supported for AutoModel")
    if model_type in modeling_auto.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES:
        return CausalLM(
            model_id,
            revision,
            quantize=quantize,
            use_medusa=use_medusa,
            dtype=dtype,
            trust_remote_code=trust_remote_code,
        )
    if model_type in modeling_auto.MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES:
        return Seq2SeqLM(
            model_id,
            revision,
            quantize=quantize,
            use_medusa=use_medusa,
            dtype=dtype,
            trust_remote_code=trust_remote_code,
        )

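    # Last resort for custom remote-code checkpoints: inspect `auto_map` to pick
    # between the causal-LM and seq2seq wrappers.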
    auto_map = config_dict.get("auto_map", None)
    if trust_remote_code and auto_map is not None:
        if "AutoModelForCausalLM" in auto_map.keys():
            return CausalLM(
                model_id,
                revision,
                quantize=quantize,
                use_medusa=use_medusa,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )
        if "AutoModelForSeq2SeqLM" in auto_map.keys():
            return Seq2SeqLM(
                model_id,
                revision,
                quantize=quantize,
                use_medusa=use_medusa,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )

    raise ValueError(f"Unsupported model type {model_type}")