registry.py 45.6 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
4
5
6
"""
Whenever you add an architecture to this file, please also update
`tests/models/registry.py` with example HuggingFace models for it.
"""
7

8
import importlib
9
import json
10
import os
11
import pickle
12
13
import subprocess
import sys
14
import tempfile
15
from abc import ABC, abstractmethod
16
from collections.abc import Callable, Set
17
from dataclasses import asdict, dataclass, field
18
from functools import lru_cache
19
from pathlib import Path
20
from typing import TYPE_CHECKING, Any, TypeVar
21
22

import torch.nn as nn
23
import transformers
24

25
from vllm import envs
26
27
28
29
30
from vllm.config import (
    ModelConfig,
    iter_architecture_defaults,
    try_match_architecture_defaults,
)
31
from vllm.logger import init_logger
32
from vllm.logging_utils import logtime
33
from vllm.transformers_utils.dynamic_module import try_get_class_from_dynamic_module
34
from vllm.utils.hashing import safe_hash
35

36
37
38
39
40
41
42
43
if TYPE_CHECKING:
    from vllm.config.model import AttnTypeStr
    from vllm.config.pooler import PoolingTypeStr
else:
    AttnTypeStr = Any
    PoolingTypeStr = Any


44
45
46
47
48
49
from .interfaces import (
    has_inner_state,
    has_noops,
    is_attention_free,
    is_hybrid,
    supports_cross_encoding,
50
    supports_mamba_prefix_caching,
51
52
53
54
55
56
57
    supports_multimodal,
    supports_multimodal_encoder_tp_data,
    supports_multimodal_raw_input_only,
    supports_pp,
    supports_transcription,
)
from .interfaces_base import (
58
    get_attn_type,
59
60
61
62
    get_default_pooling_type,
    is_pooling_model,
    is_text_generation_model,
)
63
64
65

# Module-level logger for registry diagnostics: model-info cache hits and
# misses, and errors raised while loading or inspecting model classes.
logger = init_logger(__name__)

66
67
_TEXT_GENERATION_MODELS = {
    # [Decoder-only]
68
    "AfmoeForCausalLM": ("afmoe", "AfmoeForCausalLM"),
69
    "ApertusForCausalLM": ("apertus", "ApertusForCausalLM"),
70
71
    "AquilaModel": ("llama", "LlamaForCausalLM"),
    "AquilaForCausalLM": ("llama", "LlamaForCausalLM"),  # AquilaChat2
Raghav Ravishankar's avatar
Raghav Ravishankar committed
72
    "ArceeForCausalLM": ("arcee", "ArceeForCausalLM"),
73
    "ArcticForCausalLM": ("arctic", "ArcticForCausalLM"),
74
75
76
77
    # baichuan-7b, upper case 'C' in the class name
    "BaiChuanForCausalLM": ("baichuan", "BaiChuanForCausalLM"),
    # baichuan-13b, lower case 'c' in the class name
    "BaichuanForCausalLM": ("baichuan", "BaichuanForCausalLM"),
78
    "BailingMoeForCausalLM": ("bailing_moe", "BailingMoeForCausalLM"),
ant-yy's avatar
ant-yy committed
79
    "BailingMoeV2ForCausalLM": ("bailing_moe", "BailingMoeV2ForCausalLM"),
Yu Chin Fabian Lim's avatar
Yu Chin Fabian Lim committed
80
    "BambaForCausalLM": ("bamba", "BambaForCausalLM"),
81
    "BloomForCausalLM": ("bloom", "BloomForCausalLM"),
82
    "ChatGLMModel": ("chatglm", "ChatGLMForCausalLM"),
83
    "ChatGLMForConditionalGeneration": ("chatglm", "ChatGLMForCausalLM"),
84
    "CohereForCausalLM": ("commandr", "CohereForCausalLM"),
85
    "Cohere2ForCausalLM": ("commandr", "CohereForCausalLM"),
86
    "CwmForCausalLM": ("llama", "LlamaForCausalLM"),
87
    "DbrxForCausalLM": ("dbrx", "DbrxForCausalLM"),
88
    "DeciLMForCausalLM": ("nemotron_nas", "DeciLMForCausalLM"),
89
    "DeepseekForCausalLM": ("deepseek_v2", "DeepseekForCausalLM"),
90
    "DeepseekV2ForCausalLM": ("deepseek_v2", "DeepseekV2ForCausalLM"),
91
    "DeepseekV3ForCausalLM": ("deepseek_v2", "DeepseekV3ForCausalLM"),
92
    "DeepseekV32ForCausalLM": ("deepseek_v2", "DeepseekV3ForCausalLM"),
93
    "Dots1ForCausalLM": ("dots1", "Dots1ForCausalLM"),
94
    "Ernie4_5ForCausalLM": ("ernie45", "Ernie4_5ForCausalLM"),
95
    "Ernie4_5_MoeForCausalLM": ("ernie45_moe", "Ernie4_5_MoeForCausalLM"),
96
    "ExaoneForCausalLM": ("exaone", "ExaoneForCausalLM"),
97
    "Exaone4ForCausalLM": ("exaone4", "Exaone4ForCausalLM"),
98
    "Fairseq2LlamaForCausalLM": ("fairseq2_llama", "Fairseq2LlamaForCausalLM"),
99
100
101
    "FalconForCausalLM": ("falcon", "FalconForCausalLM"),
    "FalconMambaForCausalLM": ("mamba", "MambaForCausalLM"),
    "FalconH1ForCausalLM": ("falcon_h1", "FalconH1ForCausalLM"),
102
    "FlexOlmoForCausalLM": ("flex_olmo", "FlexOlmoForCausalLM"),
103
104
    "GemmaForCausalLM": ("gemma", "GemmaForCausalLM"),
    "Gemma2ForCausalLM": ("gemma2", "Gemma2ForCausalLM"),
105
    "Gemma3ForCausalLM": ("gemma3", "Gemma3ForCausalLM"),
Nicolò Lucchesi's avatar
Nicolò Lucchesi committed
106
    "Gemma3nForCausalLM": ("gemma3n", "Gemma3nForCausalLM"),
107
    "Qwen3NextForCausalLM": ("qwen3_next", "Qwen3NextForCausalLM"),
108
    "GlmForCausalLM": ("glm", "GlmForCausalLM"),
Yuxuan Zhang's avatar
Yuxuan Zhang committed
109
    "Glm4ForCausalLM": ("glm4", "Glm4ForCausalLM"),
Yuxuan Zhang's avatar
Yuxuan Zhang committed
110
    "Glm4MoeForCausalLM": ("glm4_moe", "Glm4MoeForCausalLM"),
111
    "GptOssForCausalLM": ("gpt_oss", "GptOssForCausalLM"),
112
113
114
115
116
117
    "GPT2LMHeadModel": ("gpt2", "GPT2LMHeadModel"),
    "GPTBigCodeForCausalLM": ("gpt_bigcode", "GPTBigCodeForCausalLM"),
    "GPTJForCausalLM": ("gpt_j", "GPTJForCausalLM"),
    "GPTNeoXForCausalLM": ("gpt_neox", "GPTNeoXForCausalLM"),
    "GraniteForCausalLM": ("granite", "GraniteForCausalLM"),
    "GraniteMoeForCausalLM": ("granitemoe", "GraniteMoeForCausalLM"),
118
119
    "GraniteMoeHybridForCausalLM": ("granitemoehybrid", "GraniteMoeHybridForCausalLM"),  # noqa: E501
    "GraniteMoeSharedForCausalLM": ("granitemoeshared", "GraniteMoeSharedForCausalLM"),  # noqa: E501
120
    "GritLM": ("gritlm", "GritLM"),
Michael Goin's avatar
Michael Goin committed
121
    "Grok1ModelForCausalLM": ("grok1", "Grok1ForCausalLM"),
122
123
    "HunYuanMoEV1ForCausalLM": ("hunyuan_v1", "HunYuanMoEV1ForCausalLM"),
    "HunYuanDenseV1ForCausalLM": ("hunyuan_v1", "HunYuanDenseV1ForCausalLM"),
124
    "HCXVisionForCausalLM": ("hyperclovax_vision", "HCXVisionForCausalLM"),
125
126
    "InternLMForCausalLM": ("llama", "LlamaForCausalLM"),
    "InternLM2ForCausalLM": ("internlm2", "InternLM2ForCausalLM"),
127
    "InternLM2VEForCausalLM": ("internlm2_ve", "InternLM2VEForCausalLM"),
128
    "InternLM3ForCausalLM": ("llama", "LlamaForCausalLM"),
129
    "JAISLMHeadModel": ("jais", "JAISLMHeadModel"),
130
    "Jais2ForCausalLM": ("jais2", "Jais2ForCausalLM"),
131
    "JambaForCausalLM": ("jamba", "JambaForCausalLM"),
132
    "KimiLinearForCausalLM": ("kimi_linear", "KimiLinearForCausalLM"),  # noqa: E501
133
    "Lfm2ForCausalLM": ("lfm2", "Lfm2ForCausalLM"),
Paul Pak's avatar
Paul Pak committed
134
    "Lfm2MoeForCausalLM": ("lfm2_moe", "Lfm2MoeForCausalLM"),
135
    "LlamaForCausalLM": ("llama", "LlamaForCausalLM"),
136
    "Llama4ForCausalLM": ("llama4", "Llama4ForCausalLM"),
137
138
    # For decapoda-research/llama-*
    "LLaMAForCausalLM": ("llama", "LlamaForCausalLM"),
XuruiYang's avatar
XuruiYang committed
139
    "LongcatFlashForCausalLM": ("longcat_flash", "LongcatFlashForCausalLM"),
140
    "MambaForCausalLM": ("mamba", "MambaForCausalLM"),
141
    "Mamba2ForCausalLM": ("mamba2", "Mamba2ForCausalLM"),
142
143
    "MiniCPMForCausalLM": ("minicpm", "MiniCPMForCausalLM"),
    "MiniCPM3ForCausalLM": ("minicpm3", "MiniCPM3ForCausalLM"),
144
145
146
    "MiniMaxForCausalLM": ("minimax_text_01", "MiniMaxText01ForCausalLM"),
    "MiniMaxText01ForCausalLM": ("minimax_text_01", "MiniMaxText01ForCausalLM"),
    "MiniMaxM1ForCausalLM": ("minimax_text_01", "MiniMaxText01ForCausalLM"),
147
    "MiniMaxM2ForCausalLM": ("minimax_m2", "MiniMaxM2ForCausalLM"),
148
    "MistralForCausalLM": ("llama", "LlamaForCausalLM"),
149
    "MistralLarge3ForCausalLM": ("mistral_large_3", "MistralLarge3ForCausalLM"),
150
151
152
153
    "MixtralForCausalLM": ("mixtral", "MixtralForCausalLM"),
    # transformers's mpt class has lower case
    "MptForCausalLM": ("mpt", "MPTForCausalLM"),
    "MPTForCausalLM": ("mpt", "MPTForCausalLM"),
154
    "MiMoForCausalLM": ("mimo", "MiMoForCausalLM"),
155
    "NemotronForCausalLM": ("nemotron", "NemotronForCausalLM"),
Luis Vega's avatar
Luis Vega committed
156
    "NemotronHForCausalLM": ("nemotron_h", "NemotronHForCausalLM"),
157
    "OlmoForCausalLM": ("olmo", "OlmoForCausalLM"),
158
    "Olmo2ForCausalLM": ("olmo2", "Olmo2ForCausalLM"),
159
    "Olmo3ForCausalLM": ("olmo2", "Olmo2ForCausalLM"),
160
161
162
    "OlmoeForCausalLM": ("olmoe", "OlmoeForCausalLM"),
    "OPTForCausalLM": ("opt", "OPTForCausalLM"),
    "OrionForCausalLM": ("orion", "OrionForCausalLM"),
163
    "OuroForCausalLM": ("ouro", "OuroForCausalLM"),
164
165
    "PanguEmbeddedForCausalLM": ("openpangu", "PanguEmbeddedForCausalLM"),
    "PanguUltraMoEForCausalLM": ("openpangu", "PanguUltraMoEForCausalLM"),
166
167
168
169
    "PersimmonForCausalLM": ("persimmon", "PersimmonForCausalLM"),
    "PhiForCausalLM": ("phi", "PhiForCausalLM"),
    "Phi3ForCausalLM": ("phi3", "Phi3ForCausalLM"),
    "PhiMoEForCausalLM": ("phimoe", "PhiMoEForCausalLM"),
Shinichi Hemmi's avatar
Shinichi Hemmi committed
170
    "Plamo2ForCausalLM": ("plamo2", "Plamo2ForCausalLM"),
171
    "Plamo3ForCausalLM": ("plamo3", "Plamo3ForCausalLM"),
172
    "QWenLMHeadModel": ("qwen", "QWenLMHeadModel"),
173
174
    "Qwen2ForCausalLM": ("qwen2", "Qwen2ForCausalLM"),
    "Qwen2MoeForCausalLM": ("qwen2_moe", "Qwen2MoeForCausalLM"),
175
176
    "Qwen3ForCausalLM": ("qwen3", "Qwen3ForCausalLM"),
    "Qwen3MoeForCausalLM": ("qwen3_moe", "Qwen3MoeForCausalLM"),
177
    "RWForCausalLM": ("falcon", "FalconForCausalLM"),
178
    "SeedOssForCausalLM": ("seed_oss", "SeedOssForCausalLM"),
Song's avatar
Song committed
179
    "Step3TextForCausalLM": ("step3_text", "Step3TextForCausalLM"),
180
181
182
183
    "StableLMEpochForCausalLM": ("stablelm", "StablelmForCausalLM"),
    "StableLmForCausalLM": ("stablelm", "StablelmForCausalLM"),
    "Starcoder2ForCausalLM": ("starcoder2", "Starcoder2ForCausalLM"),
    "SolarForCausalLM": ("solar", "SolarForCausalLM"),
184
    "TeleChatForCausalLM": ("telechat2", "TeleChat2ForCausalLM"),
185
    "TeleChat2ForCausalLM": ("telechat2", "TeleChat2ForCausalLM"),
186
    "TeleFLMForCausalLM": ("teleflm", "TeleFLMForCausalLM"),
187
    "XverseForCausalLM": ("llama", "LlamaForCausalLM"),
188
    "Zamba2ForCausalLM": ("zamba2", "Zamba2ForCausalLM"),
189
190
191
}

# Pooling / embedding architectures, in the same
# name -> (module, class) format as the text-generation table.
_EMBEDDING_MODELS = {
    # [Text-only]
    "BertModel": ("bert", "BertEmbeddingModel"),
    "BertSpladeSparseEmbeddingModel": ("bert", "BertSpladeSparseEmbeddingModel"),
    "DeciLMForCausalLM": ("nemotron_nas", "DeciLMForCausalLM"),
    "Gemma2Model": ("gemma2", "Gemma2ForCausalLM"),
    "Gemma3TextModel": ("gemma3", "Gemma3Model"),
    "GlmForCausalLM": ("glm", "GlmForCausalLM"),
    "GPT2ForSequenceClassification": ("gpt2", "GPT2ForSequenceClassification"),
    "GritLM": ("gritlm", "GritLM"),
    "GteModel": ("bert_with_rope", "SnowflakeGteNewModel"),
    "GteNewModel": ("bert_with_rope", "GteNewModel"),
    "InternLM2ForRewardModel": ("internlm2", "InternLM2ForRewardModel"),
    "JambaForSequenceClassification": ("jamba", "JambaForSequenceClassification"),  # noqa: E501
    "LlamaModel": ("llama", "LlamaForCausalLM"),
    # Several architecture names share the LlamaForCausalLM implementation;
    # pull in every such alias from the text-generation table.
    **{
        alias: target
        for alias, target in _TEXT_GENERATION_MODELS.items()
        if target[1] == "LlamaForCausalLM"
    },
    "MistralModel": ("llama", "LlamaForCausalLM"),
    "ModernBertModel": ("modernbert", "ModernBertModel"),
    "NomicBertModel": ("bert_with_rope", "NomicBertModel"),
    "Phi3ForCausalLM": ("phi3", "Phi3ForCausalLM"),
    "Qwen2Model": ("qwen2", "Qwen2ForCausalLM"),
    "Qwen2ForCausalLM": ("qwen2", "Qwen2ForCausalLM"),
    "Qwen2ForRewardModel": ("qwen2_rm", "Qwen2ForRewardModel"),
    "Qwen2ForProcessRewardModel": ("qwen2_rm", "Qwen2ForProcessRewardModel"),
    "RobertaForMaskedLM": ("roberta", "RobertaEmbeddingModel"),
    "RobertaModel": ("roberta", "RobertaEmbeddingModel"),
    "TeleChatForCausalLM": ("telechat2", "TeleChat2ForCausalLM"),
    "TeleChat2ForCausalLM": ("telechat2", "TeleChat2ForCausalLM"),
    "XLMRobertaModel": ("roberta", "RobertaEmbeddingModel"),
    # [Multimodal]
    "CLIPModel": ("clip", "CLIPEmbeddingModel"),
    "LlavaNextForConditionalGeneration": ("llava_next", "LlavaNextForConditionalGeneration"),  # noqa: E501
    "Phi3VForCausalLM": ("phi3v", "Phi3VForCausalLM"),
    "Qwen2VLForConditionalGeneration": ("qwen2_vl", "Qwen2VLForConditionalGeneration"),  # noqa: E501
    "SiglipModel": ("siglip", "SiglipEmbeddingModel"),
    # Terratorch models take images as input and produce images as output.
    # They are listed here only because, for now, they reuse the embedding
    # model code path.
    "PrithviGeoSpatialMAE": ("terratorch", "Terratorch"),
    "Terratorch": ("terratorch", "Terratorch"),
}

241
242
_CROSS_ENCODER_MODELS = {
    "BertForSequenceClassification": ("bert", "BertForSequenceClassification"),
243
    "BertForTokenClassification": ("bert", "BertForTokenClassification"),
244
245
246
247
248
249
250
251
    "GteNewForSequenceClassification": (
        "bert_with_rope",
        "GteNewForSequenceClassification",
    ),
    "ModernBertForSequenceClassification": (
        "modernbert",
        "ModernBertForSequenceClassification",
    ),
252
253
254
255
    "ModernBertForTokenClassification": (
        "modernbert",
        "ModernBertForTokenClassification",
    ),
256
257
258
259
260
    "RobertaForSequenceClassification": ("roberta", "RobertaForSequenceClassification"),
    "XLMRobertaForSequenceClassification": (
        "roberta",
        "RobertaForSequenceClassification",
    ),
261
    # [Auto-converted (see adapters.py)]
262
    "JinaVLForRanking": ("jina_vl", "JinaVLForSequenceClassification"),  # noqa: E501,
263
264
}

265
_MULTIMODAL_MODELS = {
266
    # [Decoder-only]
267
    "AriaForConditionalGeneration": ("aria", "AriaForConditionalGeneration"),
268
269
270
271
    "AudioFlamingo3ForConditionalGeneration": (
        "audioflamingo3",
        "AudioFlamingo3ForConditionalGeneration",
    ),
272
273
274
    "AyaVisionForConditionalGeneration": (
        "aya_vision",
        "AyaVisionForConditionalGeneration",
275
    ),
276
    "BagelForConditionalGeneration": ("bagel", "BagelForConditionalGeneration"),
277
    "BeeForConditionalGeneration": ("bee", "BeeForConditionalGeneration"),
278
    "Blip2ForConditionalGeneration": ("blip2", "Blip2ForConditionalGeneration"),
279
280
281
    "ChameleonForConditionalGeneration": (
        "chameleon",
        "ChameleonForConditionalGeneration",
282
    ),
283
284
285
    "Cohere2VisionForConditionalGeneration": (
        "cohere2_vision",
        "Cohere2VisionForConditionalGeneration",
286
    ),
287
    "DeepseekVLV2ForCausalLM": ("deepseek_vl2", "DeepseekVLV2ForCausalLM"),
288
    "DeepseekOCRForCausalLM": ("deepseek_ocr", "DeepseekOCRForCausalLM"),
Roger Wang's avatar
Roger Wang committed
289
    "DotsOCRForCausalLM": ("dots_ocr", "DotsOCRForCausalLM"),
290
291
292
    "Ernie4_5_VLMoeForConditionalGeneration": (
        "ernie45_vl",
        "Ernie4_5_VLMoeForConditionalGeneration",
293
    ),
294
    "FuyuForCausalLM": ("fuyu", "FuyuForCausalLM"),
295
    "Gemma3ForConditionalGeneration": ("gemma3_mm", "Gemma3ForConditionalGeneration"),  # noqa: E501
296
297
298
    "Gemma3nForConditionalGeneration": (
        "gemma3n_mm",
        "Gemma3nForConditionalGeneration",
299
    ),
300
    "GLM4VForCausalLM": ("glm4v", "GLM4VForCausalLM"),
301
    "Glm4vForConditionalGeneration": ("glm4_1v", "Glm4vForConditionalGeneration"),  # noqa: E501
Jee Jee Li's avatar
Jee Jee Li committed
302
    "Glm4vMoeForConditionalGeneration": ("glm4_1v", "Glm4vMoeForConditionalGeneration"),  # noqa: E501
303
304
305
    "GraniteSpeechForConditionalGeneration": (
        "granite_speech",
        "GraniteSpeechForConditionalGeneration",
306
    ),
307
    "H2OVLChatModel": ("h2ovl", "H2OVLChatModel"),
308
309
310
311
    "HunYuanVLForConditionalGeneration": (
        "hunyuan_vision",
        "HunYuanVLForConditionalGeneration",
    ),
312
    "InternVLChatModel": ("internvl", "InternVLChatModel"),
313
    "NemotronH_Nano_VL_V2": ("nano_nemotron_vl", "NemotronH_Nano_VL_V2"),
Zero's avatar
Zero committed
314
315
316
317
    "OpenCUAForConditionalGeneration": (
        "opencua",
        "OpenCUAForConditionalGeneration",
    ),
318
319
320
    "InternS1ForConditionalGeneration": (
        "interns1",
        "InternS1ForConditionalGeneration",
321
    ),
322
323
324
    "InternVLForConditionalGeneration": (
        "interns1",
        "InternS1ForConditionalGeneration",
325
    ),
326
327
328
329
330
    "Idefics3ForConditionalGeneration": (
        "idefics3",
        "Idefics3ForConditionalGeneration",
    ),
    "SmolVLMForConditionalGeneration": ("smolvlm", "SmolVLMForConditionalGeneration"),  # noqa: E501
331
    "KeyeForConditionalGeneration": ("keye", "KeyeForConditionalGeneration"),
332
333
334
    "KeyeVL1_5ForConditionalGeneration": (
        "keye_vl1_5",
        "KeyeVL1_5ForConditionalGeneration",
335
    ),
336
    "RForConditionalGeneration": ("rvl", "RForConditionalGeneration"),
337
    "KimiVLForConditionalGeneration": ("kimi_vl", "KimiVLForConditionalGeneration"),  # noqa: E501
338
339
340
341
    "LightOnOCRForConditionalGeneration": (
        "lightonocr",
        "LightOnOCRForConditionalGeneration",
    ),
342
    "Llama_Nemotron_Nano_VL": ("nemotron_vl", "LlamaNemotronVLChatModel"),
343
    "Llama4ForConditionalGeneration": ("mllama4", "Llama4ForConditionalGeneration"),  # noqa: E501
344
    "LlavaForConditionalGeneration": ("llava", "LlavaForConditionalGeneration"),
345
346
347
    "LlavaNextForConditionalGeneration": (
        "llava_next",
        "LlavaNextForConditionalGeneration",
348
    ),
349
350
351
    "LlavaNextVideoForConditionalGeneration": (
        "llava_next_video",
        "LlavaNextVideoForConditionalGeneration",
352
    ),
353
354
355
    "LlavaOnevisionForConditionalGeneration": (
        "llava_onevision",
        "LlavaOnevisionForConditionalGeneration",
356
    ),
357
    "MantisForConditionalGeneration": ("llava", "MantisForConditionalGeneration"),  # noqa: E501
358
    "MiDashengLMModel": ("midashenglm", "MiDashengLMModel"),
359
360
361
    "MiniMaxVL01ForConditionalGeneration": (
        "minimax_vl_01",
        "MiniMaxVL01ForConditionalGeneration",
362
    ),
363
    "MiniCPMO": ("minicpmo", "MiniCPMO"),
364
    "MiniCPMV": ("minicpmv", "MiniCPMV"),
365
366
367
    "Mistral3ForConditionalGeneration": (
        "mistral3",
        "Mistral3ForConditionalGeneration",
368
    ),
369
    "MolmoForCausalLM": ("molmo", "MolmoForCausalLM"),
370
    "NVLM_D": ("nvlm_d", "NVLM_D_Model"),
371
    "Ovis": ("ovis", "Ovis"),
372
    "Ovis2_5": ("ovis2_5", "Ovis2_5"),
373
374
375
376
    "PaddleOCRVLForConditionalGeneration": (
        "paddleocr_vl",
        "PaddleOCRVLForConditionalGeneration",
    ),
377
378
379
380
    "PaliGemmaForConditionalGeneration": (
        "paligemma",
        "PaliGemmaForConditionalGeneration",
    ),
381
    "Phi3VForCausalLM": ("phi3v", "Phi3VForCausalLM"),
382
    "Phi4MMForCausalLM": ("phi4mm", "Phi4MMForCausalLM"),
383
    "PixtralForConditionalGeneration": ("pixtral", "PixtralForConditionalGeneration"),  # noqa: E501
384
    "QwenVLForConditionalGeneration": ("qwen_vl", "QwenVLForConditionalGeneration"),  # noqa: E501
385
    "Qwen2VLForConditionalGeneration": ("qwen2_vl", "Qwen2VLForConditionalGeneration"),  # noqa: E501
386
387
388
    "Qwen2_5_VLForConditionalGeneration": (
        "qwen2_5_vl",
        "Qwen2_5_VLForConditionalGeneration",
389
    ),
390
391
392
    "Qwen2AudioForConditionalGeneration": (
        "qwen2_audio",
        "Qwen2AudioForConditionalGeneration",
393
    ),
394
395
396
    "Qwen2_5OmniModel": (
        "qwen2_5_omni_thinker",
        "Qwen2_5OmniThinkerForConditionalGeneration",
397
    ),
398
399
400
    "Qwen2_5OmniForConditionalGeneration": (
        "qwen2_5_omni_thinker",
        "Qwen2_5OmniThinkerForConditionalGeneration",
401
    ),
402
403
404
405
    "Qwen3OmniMoeForConditionalGeneration": (
        "qwen3_omni_moe_thinker",
        "Qwen3OmniMoeThinkerForConditionalGeneration",
    ),
406
    "Qwen3VLForConditionalGeneration": ("qwen3_vl", "Qwen3VLForConditionalGeneration"),  # noqa: E501
407
408
409
    "Qwen3VLMoeForConditionalGeneration": (
        "qwen3_vl_moe",
        "Qwen3VLMoeForConditionalGeneration",
410
    ),
411
    "SkyworkR1VChatModel": ("skyworkr1v", "SkyworkR1VChatModel"),
Song's avatar
Song committed
412
    "Step3VLForConditionalGeneration": ("step3_vl", "Step3VLForConditionalGeneration"),  # noqa: E501
汪志鹏's avatar
汪志鹏 committed
413
    "TarsierForConditionalGeneration": ("tarsier", "TarsierForConditionalGeneration"),  # noqa: E501
414
415
416
    "Tarsier2ForConditionalGeneration": (
        "qwen2_vl",
        "Tarsier2ForConditionalGeneration",
417
    ),
418
    "UltravoxModel": ("ultravox", "UltravoxModel"),
Patrick von Platen's avatar
Patrick von Platen committed
419
    "VoxtralForConditionalGeneration": ("voxtral", "VoxtralForConditionalGeneration"),  # noqa: E501
420
    # [Encoder-decoder]
421
    "WhisperForConditionalGeneration": ("whisper", "WhisperForConditionalGeneration"),  # noqa: E501
422
}
423
424

_SPECULATIVE_DECODING_MODELS = {
425
    "MiMoMTPModel": ("mimo_mtp", "MiMoMTP"),
426
    "EagleLlamaForCausalLM": ("llama_eagle", "EagleLlamaForCausalLM"),
zhiweiz's avatar
zhiweiz committed
427
    "EagleLlama4ForCausalLM": ("llama4_eagle", "EagleLlama4ForCausalLM"),
428
    "EagleMiniCPMForCausalLM": ("minicpm_eagle", "EagleMiniCPMForCausalLM"),
429
    "Eagle3LlamaForCausalLM": ("llama_eagle3", "Eagle3LlamaForCausalLM"),
430
    "LlamaForCausalLMEagle3": ("llama_eagle3", "Eagle3LlamaForCausalLM"),
431
    "Eagle3Qwen2_5vlForCausalLM": ("llama_eagle3", "Eagle3LlamaForCausalLM"),
432
    "Eagle3Qwen3vlForCausalLM": ("llama_eagle3", "Eagle3LlamaForCausalLM"),
433
434
435
436
    "EagleMistralLarge3ForCausalLM": (
        "mistral_large_3_eagle",
        "EagleMistralLarge3ForCausalLM",
    ),
437
    "EagleDeepSeekMTPModel": ("deepseek_eagle", "EagleDeepseekV3ForCausalLM"),
438
    "DeepSeekMTPModel": ("deepseek_mtp", "DeepSeekMTP"),
439
    "ErnieMTPModel": ("ernie_mtp", "ErnieMTP"),
XuruiYang's avatar
XuruiYang committed
440
    "LongCatFlashMTPModel": ("longcat_flash_mtp", "LongCatFlashMTP"),
Yuxuan Zhang's avatar
Yuxuan Zhang committed
441
    "Glm4MoeMTPModel": ("glm4_moe_mtp", "Glm4MoeMTP"),
442
    "MedusaModel": ("medusa", "Medusa"),
443
    "OpenPanguMTPModel": ("openpangu_mtp", "OpenPanguMTP"),
444
    "Qwen3NextMTP": ("qwen3_next_mtp", "Qwen3NextMTP"),
445
446
447
    # Temporarily disabled.
    # # TODO(woosuk): Re-enable this once the MLP Speculator is supported in V1.
    # "MLPSpeculatorPreTrainedModel": ("mlp_speculator", "MLPSpeculator"),
448
}
449

450
_TRANSFORMERS_SUPPORTED_MODELS = {
451
452
453
    # Text generation models
    "SmolLM3ForCausalLM": ("transformers", "TransformersForCausalLM"),
    # Multimodal models
454
455
456
457
    "Emu3ForConditionalGeneration": (
        "transformers",
        "TransformersMultiModalForCausalLM",
    ),
458
459
460
}

_TRANSFORMERS_BACKEND_MODELS = {
461
    # Text generation models
462
    "TransformersForCausalLM": ("transformers", "TransformersForCausalLM"),
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
    "TransformersMoEForCausalLM": ("transformers", "TransformersMoEForCausalLM"),
    # Multimodal models
    "TransformersMultiModalForCausalLM": (
        "transformers",
        "TransformersMultiModalForCausalLM",
    ),
    "TransformersMultiModalMoEForCausalLM": (
        "transformers",
        "TransformersMultiModalMoEForCausalLM",
    ),
    # Embedding models
    "TransformersEmbeddingModel": ("transformers", "TransformersEmbeddingModel"),
    "TransformersMoEEmbeddingModel": ("transformers", "TransformersMoEEmbeddingModel"),
    "TransformersMultiModalEmbeddingModel": (
        "transformers",
        "TransformersMultiModalEmbeddingModel",
    ),
    # Sequence classification models
481
    "TransformersForSequenceClassification": (
482
        "transformers",
483
        "TransformersForSequenceClassification",
484
    ),
485
    "TransformersMoEForSequenceClassification": (
486
        "transformers",
487
        "TransformersMoEForSequenceClassification",
488
    ),
489
490
491
    "TransformersMultiModalForSequenceClassification": (
        "transformers",
        "TransformersMultiModalForSequenceClassification",
492
    ),
493
}
494

495
# Union of every architecture table above. Groups are merged in order, so
# when a name appears in several groups the later group's (module, class)
# entry wins, while the key keeps the position of its first appearance.
_VLLM_MODELS = {
    arch_name: target
    for group in (
        _TEXT_GENERATION_MODELS,
        _EMBEDDING_MODELS,
        _CROSS_ENCODER_MODELS,
        _MULTIMODAL_MODELS,
        _SPECULATIVE_DECODING_MODELS,
        _TRANSFORMERS_SUPPORTED_MODELS,
        _TRANSFORMERS_BACKEND_MODELS,
    )
    for arch_name, target in group.items()
}

505
506
507
508
# This variable is used as the args for subprocess.run(). We
# can modify  this variable to alter the args if needed. e.g.
# when we use par format to pack things together, sys.executable
# might not be the target we want to run.
509
_SUBPROCESS_COMMAND = [sys.executable, "-m", "vllm.model_executor.models.registry"]
510

511
_PREVIOUSLY_SUPPORTED_MODELS = {
512
    "MotifForCausalLM": "0.10.2",
513
    "Phi3SmallForCausalLM": "0.9.2",
514
    "Phi4FlashForCausalLM": "0.10.2",
515
    "Phi4MultimodalForCausalLM": "0.12.0",
516
517
518
519
520
521
522
523
524
    # encoder-decoder models except whisper
    # have been removed for V0 deprecation.
    "BartModel": "0.10.2",
    "BartForConditionalGeneration": "0.10.2",
    "DonutForConditionalGeneration": "0.10.2",
    "Florence2ForConditionalGeneration": "0.10.2",
    "MBartForConditionalGeneration": "0.10.2",
    "MllamaForConditionalGeneration": "0.10.2",
}
525

526

527
528
@dataclass(frozen=True)
class _ModelInfo:
    """Interface capabilities detected on a model class.

    Built by :meth:`from_model_cls`, which probes the class with the interface
    helpers imported at the top of this module. Instances are immutable and are
    round-tripped through ``dataclasses.asdict`` / ``_ModelInfo(**d)`` by the
    on-disk model-info cache, so the field names form part of that cache format.
    """

    # Name of the probed class (``model.__name__``).
    architecture: str
    is_text_generation_model: bool
    is_pooling_model: bool
    # Values returned by get_attn_type / get_default_pooling_type.
    attn_type: AttnTypeStr
    default_pooling_type: PoolingTypeStr
    supports_cross_encoding: bool
    supports_multimodal: bool
    supports_multimodal_raw_input_only: bool
    supports_multimodal_encoder_tp_data: bool
    supports_pp: bool
    has_inner_state: bool
    is_attention_free: bool
    is_hybrid: bool
    has_noops: bool
    supports_mamba_prefix_caching: bool
    supports_transcription: bool
    # True only when the class supports transcription AND its
    # `supports_transcription_only` attribute is truthy.
    supports_transcription_only: bool

    @staticmethod
    def from_model_cls(model: type[nn.Module]) -> "_ModelInfo":
        """Probe ``model`` with every interface helper and record the results."""
        can_transcribe = supports_transcription(model)
        # Short-circuit so `supports_transcription_only` is only read on
        # classes that actually implement the transcription interface.
        transcribe_only = can_transcribe and model.supports_transcription_only

        return _ModelInfo(
            architecture=model.__name__,
            is_text_generation_model=is_text_generation_model(model),
            is_pooling_model=is_pooling_model(model),
            attn_type=get_attn_type(model),
            default_pooling_type=get_default_pooling_type(model),
            supports_cross_encoding=supports_cross_encoding(model),
            supports_multimodal=supports_multimodal(model),
            supports_multimodal_raw_input_only=(
                supports_multimodal_raw_input_only(model)
            ),
            supports_multimodal_encoder_tp_data=(
                supports_multimodal_encoder_tp_data(model)
            ),
            supports_pp=supports_pp(model),
            has_inner_state=has_inner_state(model),
            is_attention_free=is_attention_free(model),
            is_hybrid=is_hybrid(model),
            has_noops=has_noops(model),
            supports_mamba_prefix_caching=supports_mamba_prefix_caching(model),
            supports_transcription=can_transcribe,
            supports_transcription_only=transcribe_only,
        )
574
575


576
577
578
579
class _BaseRegisteredModel(ABC):
    """Abstract registry entry: something that can describe and load a model.

    Subclasses decide whether the model class is already imported
    (`_RegisteredModel`) or imported on demand (`_LazyRegisteredModel`).
    """

    @abstractmethod
    def inspect_model_cls(self) -> _ModelInfo:
        """Return the interface capabilities of the registered model class."""
        raise NotImplementedError()

    @abstractmethod
    def load_model_cls(self) -> type[nn.Module]:
        """Return the registered model class object itself."""
        raise NotImplementedError()
584
585


586
587
588
589
590
591
592
@dataclass(frozen=True)
class _RegisteredModel(_BaseRegisteredModel):
    """
    Represents a model that has already been imported in the main process.

    Since the class object is already available, its `_ModelInfo` is computed
    in-process when the entry is built with `from_model_cls`.
    """

    interfaces: _ModelInfo
    model_cls: type[nn.Module]

    @staticmethod
    def from_model_cls(model_cls: type[nn.Module]):
        info = _ModelInfo.from_model_cls(model_cls)
        # Positional order matches the field order: (interfaces, model_cls).
        return _RegisteredModel(info, model_cls)

    def inspect_model_cls(self) -> _ModelInfo:
        # The stored info is returned as-is; nothing is imported or probed here.
        return self.interfaces

    def load_model_cls(self) -> type[nn.Module]:
        return self.model_cls


@dataclass(frozen=True)
class _LazyRegisteredModel(_BaseRegisteredModel):
    """
    Represents a model that has not been imported in the main process.

    The class is identified by ``module_name`` and ``class_name`` and is only
    imported on demand. Its `_ModelInfo` is computed in a separate process and
    cached as JSON under ``VLLM_CACHE_ROOT/modelinfos``. The cache is keyed by
    a hash of the model's source file, so a cached entry is reused only while
    that file is unchanged.
    """

    module_name: str
    class_name: str

    @staticmethod
    def _get_cache_dir() -> Path:
        return Path(envs.VLLM_CACHE_ROOT) / "modelinfos"

    def _get_cache_filename(self) -> str:
        # e.g. ("vllm.model_executor.models.llama", "LlamaForCausalLM")
        #   -> "vllm-model_executor-models-llama-LlamaForCausalLM.json"
        cls_name = f"{self.module_name}-{self.class_name}".replace(".", "-")
        return f"{cls_name}.json"

    def _get_module_source_path(self) -> Path | None:
        """Return the source file of ``module_name`` if it is in this package.

        Only modules that are direct children of this package are resolved.
        For any other module, such as an out-of-tree model registered lazily
        from a ``"<module>:<class>"`` string like ``"my_pkg.llama:MyLlama"``,
        this returns None. Mapping that name onto ``<this package>/llama.py``
        would hash an unrelated in-tree file, so edits to the out-of-tree
        module would never invalidate its cached model info.

        Returns None as well when the file does not exist.
        """
        parent_pkg, _, leaf = self.module_name.rpartition(".")
        if not __package__ or parent_pkg != __package__:
            return None
        model_path = Path(__file__).parent / f"{leaf}.py"
        return model_path if model_path.exists() else None

    def _load_modelinfo_from_cache(self, module_hash: str) -> _ModelInfo | None:
        """Return the cached `_ModelInfo` for this class.

        Returns None if the cache file is missing, was written for a different
        ``module_hash`` (stale), or cannot be read or parsed.
        """
        try:
            modelinfo_path = self._get_cache_dir() / self._get_cache_filename()
            try:
                with open(modelinfo_path, encoding="utf-8") as file:
                    mi_dict = json.load(file)
            except FileNotFoundError:
                logger.debug(
                    "Cached model info file for class %s.%s not found",
                    self.module_name,
                    self.class_name,
                )
                return None

            if mi_dict["hash"] != module_hash:
                logger.debug(
                    "Cached model info file for class %s.%s is stale",
                    self.module_name,
                    self.class_name,
                )
                return None

            # file not changed, use cached _ModelInfo properties
            return _ModelInfo(**mi_dict["modelinfo"])
        except Exception:
            # Best effort: a corrupt entry, missing keys, or "modelinfo" keys
            # that no longer match the _ModelInfo fields are all treated as a
            # cache miss. The traceback is kept at debug level for diagnosis.
            logger.debug(
                "Cached model info for class %s.%s error. ",
                self.module_name,
                self.class_name,
                exc_info=True,
            )
            return None

    def _save_modelinfo_to_cache(self, mi: _ModelInfo, module_hash: str) -> None:
        """save dictionary json file to cache"""
        from vllm.model_executor.model_loader.weight_utils import atomic_writer

        try:
            modelinfo_dict = {
                "hash": module_hash,
                "modelinfo": asdict(mi),
            }
            cache_dir = self._get_cache_dir()
            cache_dir.mkdir(parents=True, exist_ok=True)
            modelinfo_path = cache_dir / self._get_cache_filename()
            with atomic_writer(modelinfo_path, encoding="utf-8") as f:
                json.dump(modelinfo_dict, f, indent=2)
        except Exception:
            # Failing to write the cache must not break model inspection.
            logger.exception("Error saving model info cache.")

    @logtime(logger=logger, msg="Registry inspect model class")
    def inspect_model_cls(self) -> _ModelInfo:
        """Return the interface info of the model class.

        The class is never imported in the calling process. The on-disk cache
        is consulted first when the model's source file can be identified.
        Otherwise the class is imported and probed in a subprocess, and the
        result is cached (again only when the source file was hashed).
        """
        model_path = self._get_module_source_path()
        module_hash = None

        if model_path is not None:
            with open(model_path, "rb") as f:
                module_hash = safe_hash(f.read(), usedforsecurity=False).hexdigest()

            mi = self._load_modelinfo_from_cache(module_hash)
            if mi is not None:
                logger.debug(
                    "Loaded model info for class %s.%s from cache",
                    self.module_name,
                    self.class_name,
                )
                return mi

            logger.debug(
                "Cache model info for class %s.%s miss. Loading model instead.",
                self.module_name,
                self.class_name,
            )

        # Performed in another process to avoid initializing CUDA
        mi = _run_in_subprocess(
            lambda: _ModelInfo.from_model_cls(self.load_model_cls())
        )
        logger.debug(
            "Loaded model info for class %s.%s", self.module_name, self.class_name
        )

        # Only cache when the result can be tied to a hashed source file;
        # otherwise there is nothing to detect staleness against.
        if module_hash is not None:
            self._save_modelinfo_to_cache(mi, module_hash)

        return mi

    def load_model_cls(self) -> type[nn.Module]:
        """Import ``module_name`` and return its ``class_name`` attribute."""
        mod = importlib.import_module(self.module_name)
        return getattr(mod, self.class_name)


@lru_cache(maxsize=128)
def _try_load_model_cls(
    model_arch: str,
    model: _BaseRegisteredModel,
) -> type[nn.Module] | None:
    """Load the model class registered for `model_arch`.

    The platform check runs first and is not guarded, so any exception it
    raises propagates to the caller. Exceptions raised while loading the
    class itself are logged and turned into `None`. Results (including
    `None`) are memoized per `(model_arch, model)` pair.
    """
    from vllm.platforms import current_platform

    current_platform.verify_model_arch(model_arch)

    try:
        model_cls = model.load_model_cls()
    except Exception:
        logger.exception("Error in loading model architecture '%s'", model_arch)
        model_cls = None
    return model_cls
731
732


733
734
735
736
@lru_cache(maxsize=128)
def _try_inspect_model_cls(
    model_arch: str,
    model: _BaseRegisteredModel,
) -> _ModelInfo | None:
    """Inspect the model registered for `model_arch`, or return `None`.

    Exceptions from `model.inspect_model_cls()` are logged instead of
    propagated. Results (including `None`) are memoized per
    `(model_arch, model)` pair.
    """
    model_info = None
    try:
        model_info = model.inspect_model_cls()
    except Exception:
        logger.exception("Error in inspecting model architecture '%s'", model_arch)
    return model_info
743
744


745
746
747
@dataclass
class _ModelRegistry:
    """Registry mapping model architecture names to their implementations.

    Entries are `_BaseRegisteredModel` instances keyed by architecture name.
    Besides plain lookup, the registry decides which implementation to use
    for a `ModelConfig` (a registered vLLM model, the Transformers backend,
    or Terratorch) and answers capability queries from the `_ModelInfo` of
    the resolved model.
    """

    # Keyed by model_arch
    models: dict[str, _BaseRegisteredModel] = field(default_factory=dict)

    def get_supported_archs(self) -> Set[str]:
        """Return a live view of the registered architecture names."""
        return self.models.keys()

    def register_model(
        self,
        model_arch: str,
        model_cls: type[nn.Module] | str,
    ) -> None:
        """
        Register an external model to be used in vLLM.

        `model_cls` can be either:

        - A [`torch.nn.Module`][] class directly referencing the model.
        - A string in the format `<module>:<class>` which can be used to
          lazily import the model. This is useful to avoid initializing CUDA
          when importing the model and thus the related error
          `RuntimeError: Cannot re-initialize CUDA in forked subprocess`.

        Raises:
            TypeError: if `model_arch` is not a string, or `model_cls` is
                neither a string nor an `nn.Module` subclass.
            ValueError: if `model_cls` is a string that does not have the
                form `<module>:<class>`.
        """
        if not isinstance(model_arch, str):
            msg = f"`model_arch` should be a string, not a {type(model_arch)}"
            raise TypeError(msg)

        if isinstance(model_cls, str):
            split_str = model_cls.split(":")
            if len(split_str) != 2:
                msg = "Expected a string in the format `<module>:<class>`"
                raise ValueError(msg)

            model = _LazyRegisteredModel(*split_str)
        elif isinstance(model_cls, type) and issubclass(model_cls, nn.Module):
            model = _RegisteredModel.from_model_cls(model_cls)
        else:
            # Report the type of the offending argument (`model_cls`).
            msg = (
                "`model_cls` should be a string or PyTorch model class, "
                f"not a {type(model_cls)}"
            )
            raise TypeError(msg)

        # Warn only once the new entry is known to be valid: a rejected
        # registration leaves the existing entry untouched.
        if model_arch in self.models:
            logger.warning(
                "Model architecture %s is already registered, and will be "
                "overwritten by the new model class %s.",
                model_arch,
                model_cls,
            )

        self.models[model_arch] = model

    def _raise_for_unsupported(self, architectures: list[str]):
        """Raise a ValueError explaining why `architectures` cannot be used.

        Never returns normally. The message distinguishes three cases: an
        architecture is registered but failed to be inspected/loaded, an
        architecture was removed in an earlier vLLM release, or none of the
        architectures is supported.
        """
        all_supported_archs = self.get_supported_archs()

        if any(arch in all_supported_archs for arch in architectures):
            raise ValueError(
                f"Model architectures {architectures} failed "
                "to be inspected. Please check the logs for more details."
            )

        for arch in architectures:
            if arch in _PREVIOUSLY_SUPPORTED_MODELS:
                previous_version = _PREVIOUSLY_SUPPORTED_MODELS[arch]

                raise ValueError(
                    f"Model architecture {arch} was supported in vLLM until "
                    f"v{previous_version}, and is not supported anymore. "
                    "Please use an older version of vLLM if you want to "
                    "use this model architecture."
                )

        raise ValueError(
            f"Model architectures {architectures} are not supported for now. "
            f"Supported architectures: {all_supported_archs}"
        )

    def _try_load_model_cls(self, model_arch: str) -> type[nn.Module] | None:
        """Load the class registered for `model_arch`; None if unknown/failed."""
        if model_arch not in self.models:
            return None

        return _try_load_model_cls(model_arch, self.models[model_arch])

    def _try_inspect_model_cls(self, model_arch: str) -> _ModelInfo | None:
        """Inspect the model registered for `model_arch`; None if unknown/failed."""
        if model_arch not in self.models:
            return None

        return _try_inspect_model_cls(model_arch, self.models[model_arch])

    def _try_resolve_transformers(
        self,
        architecture: str,
        model_config: ModelConfig,
    ) -> str | None:
        """Map `architecture` to a Transformers-backend registry key.

        Returns `architecture` unchanged if it is already a Transformers
        backend model. Otherwise the model class is looked up as an attribute
        of `transformers` or, for custom models, through the `AutoModel*`
        entries of the config's `auto_map`. The class must report
        `is_backend_compatible()`. On success the key is taken from
        `model_config._get_transformers_backend_cls()`.

        Returns None when resolution fails and `model_config.model_impl` is
        not "transformers".

        Raises:
            ValueError: if resolution fails while the Transformers
                implementation was explicitly requested.
        """
        if architecture in _TRANSFORMERS_BACKEND_MODELS:
            return architecture

        auto_map: dict[str, str] = (
            getattr(model_config.hf_config, "auto_map", None) or dict()
        )

        # Make sure that config class is always initialized before model class,
        # otherwise the model class won't be able to access the config class,
        # the expected auto_map should have correct order like:
        # "auto_map": {
        #     "AutoConfig": "<your-repo-name>--<config-name>",
        #     "AutoModel": "<your-repo-name>--<config-name>",
        #     "AutoModelFor<Task>": "<your-repo-name>--<config-name>",
        # },
        for prefix in ("AutoConfig", "AutoModel"):
            for name, module in auto_map.items():
                if name.startswith(prefix):
                    try_get_class_from_dynamic_module(
                        module,
                        model_config.model,
                        revision=model_config.revision,
                        warn_on_fail=False,
                    )

        model_module = getattr(transformers, architecture, None)

        if model_module is None:
            for name, module in auto_map.items():
                if name.startswith("AutoModel"):
                    model_module = try_get_class_from_dynamic_module(
                        module,
                        model_config.model,
                        revision=model_config.revision,
                        warn_on_fail=True,
                    )
                    if model_module is not None:
                        break
            else:
                # No `AutoModel*` entry produced a class.
                if model_config.model_impl != "transformers":
                    return None

                raise ValueError(
                    f"Cannot find model module. {architecture!r} is not a "
                    "registered model in the Transformers library (only "
                    "relevant if the model is meant to be in Transformers) "
                    "and 'AutoModel' is not present in the model config's "
                    "'auto_map' (relevant if the model is custom)."
                )

        if not model_module.is_backend_compatible():
            if model_config.model_impl != "transformers":
                return None

            raise ValueError(
                f"The Transformers implementation of {architecture!r} "
                "is not compatible with vLLM."
            )

        return model_config._get_transformers_backend_cls()

    def _normalize_arch(
        self,
        architecture: str,
        model_config: ModelConfig,
    ) -> str:
        """Return the registered base architecture for `architecture`.

        If `architecture` is not registered but matches one of the
        architecture-default suffixes, the suffix is swapped for each default
        suffix in turn and the first registered result is returned.
        Otherwise `architecture` is returned unchanged.
        """
        if architecture in self.models:
            return architecture

        # This may be called in order to resolve runner_type and convert_type
        # in the first place, in which case we consider the default match
        match = try_match_architecture_defaults(
            architecture,
            runner_type=getattr(model_config, "runner_type", None),
            convert_type=getattr(model_config, "convert_type", None),
        )
        if match:
            suffix, _ = match

            # Get the name of the base model to convert
            for repl_suffix, _ in iter_architecture_defaults():
                base_arch = architecture.replace(suffix, repl_suffix)
                if base_arch in self.models:
                    return base_arch

        return architecture

    def _none_registered(self, architectures: list[str]) -> bool:
        """Return True if none of `architectures` is a registry key."""
        return all(arch not in self.models for arch in architectures)

    def _try_transformers_arch(
        self,
        architectures: list[str],
        model_config: ModelConfig,
        try_arch: Callable[[str], Any],
    ) -> tuple[Any, str] | None:
        """Resolve `architectures[0]` through the Transformers backend.

        Returns `(try_arch(arch), arch)` for the resolved key, or None if
        either resolution or `try_arch` fails.
        """
        arch = self._try_resolve_transformers(architectures[0], model_config)
        if arch is None:
            return None

        result = try_arch(arch)
        if result is None:
            return None

        return (result, arch)

    def _resolve_arch(
        self,
        architectures: str | list[str],
        model_config: ModelConfig,
        try_arch: Callable[[str], Any],
    ) -> tuple[Any, str]:
        """Shared resolution order of `inspect_model_cls`/`resolve_model_cls`.

        `try_arch` maps a registry key to a result (model info or model
        class), or None on failure. Candidates are tried in the order below,
        and the first non-None result is returned with its architecture:

        1. `model_impl == "transformers"`: the Transformers backend;
           `model_impl == "terratorch"`: the "Terratorch" entry.
        2. `model_impl == "auto"`, no requested architecture registered and
           no `convert_type`: the Transformers backend.
        3. Each requested architecture after `_normalize_arch` (the
           original, un-normalized name is returned).
        4. `model_impl == "auto"` and no requested architecture registered:
           the Transformers backend.

        Raises:
            ValueError: if no architecture is given or none can be resolved.
        """
        if isinstance(architectures, str):
            architectures = [architectures]

        if not architectures:
            raise ValueError("No model architectures are specified")

        # Require transformers impl
        if model_config.model_impl == "transformers":
            resolved = self._try_transformers_arch(
                architectures, model_config, try_arch
            )
            if resolved is not None:
                return resolved
        elif model_config.model_impl == "terratorch":
            # Fall through (instead of returning a None result) if the
            # Terratorch entry cannot be inspected/loaded.
            result = try_arch("Terratorch")
            if result is not None:
                return (result, "Terratorch")

        # Fallback to transformers impl (after resolving convert_type)
        if (
            self._none_registered(architectures)
            and model_config.model_impl == "auto"
            and getattr(model_config, "convert_type", "none") == "none"
        ):
            resolved = self._try_transformers_arch(
                architectures, model_config, try_arch
            )
            if resolved is not None:
                return resolved

        for arch in architectures:
            normalized_arch = self._normalize_arch(arch, model_config)
            result = try_arch(normalized_arch)
            if result is not None:
                return (result, arch)

        # Fallback to transformers impl (before resolving runner_type)
        if (
            self._none_registered(architectures)
            and model_config.model_impl == "auto"
        ):
            resolved = self._try_transformers_arch(
                architectures, model_config, try_arch
            )
            if resolved is not None:
                return resolved

        return self._raise_for_unsupported(architectures)

    def inspect_model_cls(
        self,
        architectures: str | list[str],
        model_config: ModelConfig,
    ) -> tuple[_ModelInfo, str]:
        """Return `(model_info, arch)` for the first resolvable architecture.

        Raises:
            ValueError: if no architecture is given or none can be inspected.
        """
        return self._resolve_arch(
            architectures, model_config, self._try_inspect_model_cls
        )

    def resolve_model_cls(
        self,
        architectures: str | list[str],
        model_config: ModelConfig,
    ) -> tuple[type[nn.Module], str]:
        """Return `(model_cls, arch)` for the first loadable architecture.

        Raises:
            ValueError: if no architecture is given or none can be loaded.
        """
        return self._resolve_arch(
            architectures, model_config, self._try_load_model_cls
        )

    # Capability queries: each resolves the model via `inspect_model_cls`
    # and reads one flag from its `_ModelInfo`.

    def is_text_generation_model(
        self,
        architectures: str | list[str],
        model_config: ModelConfig,
    ) -> bool:
        model_info, _ = self.inspect_model_cls(architectures, model_config)
        return model_info.is_text_generation_model

    def is_pooling_model(
        self,
        architectures: str | list[str],
        model_config: ModelConfig,
    ) -> bool:
        model_info, _ = self.inspect_model_cls(architectures, model_config)
        return model_info.is_pooling_model

    def is_cross_encoder_model(
        self,
        architectures: str | list[str],
        model_config: ModelConfig,
    ) -> bool:
        model_info, _ = self.inspect_model_cls(architectures, model_config)
        return model_info.supports_cross_encoding

    def is_multimodal_model(
        self,
        architectures: str | list[str],
        model_config: ModelConfig,
    ) -> bool:
        model_info, _ = self.inspect_model_cls(architectures, model_config)
        return model_info.supports_multimodal

    def is_multimodal_raw_input_only_model(
        self,
        architectures: str | list[str],
        model_config: ModelConfig,
    ) -> bool:
        model_info, _ = self.inspect_model_cls(architectures, model_config)
        return model_info.supports_multimodal_raw_input_only

    def is_pp_supported_model(
        self,
        architectures: str | list[str],
        model_config: ModelConfig,
    ) -> bool:
        model_info, _ = self.inspect_model_cls(architectures, model_config)
        return model_info.supports_pp

    def model_has_inner_state(
        self,
        architectures: str | list[str],
        model_config: ModelConfig,
    ) -> bool:
        model_info, _ = self.inspect_model_cls(architectures, model_config)
        return model_info.has_inner_state

    def is_attention_free_model(
        self,
        architectures: str | list[str],
        model_config: ModelConfig,
    ) -> bool:
        model_info, _ = self.inspect_model_cls(architectures, model_config)
        return model_info.is_attention_free

    def is_hybrid_model(
        self,
        architectures: str | list[str],
        model_config: ModelConfig,
    ) -> bool:
        model_info, _ = self.inspect_model_cls(architectures, model_config)
        return model_info.is_hybrid

    def is_noops_model(
        self,
        architectures: str | list[str],
        model_config: ModelConfig,
    ) -> bool:
        model_info, _ = self.inspect_model_cls(architectures, model_config)
        return model_info.has_noops

    def is_transcription_model(
        self,
        architectures: str | list[str],
        model_config: ModelConfig,
    ) -> bool:
        model_info, _ = self.inspect_model_cls(architectures, model_config)
        return model_info.supports_transcription

    def is_transcription_only_model(
        self,
        architectures: str | list[str],
        model_config: ModelConfig,
    ) -> bool:
        model_info, _ = self.inspect_model_cls(architectures, model_config)
        return model_info.supports_transcription_only

1130

1131
1132
1133
1134
1135
1136
1137
1138
1139
# Global registry pre-populated with every built-in architecture listed in
# `_VLLM_MODELS`. Each entry only records a module path and class name;
# nothing is imported until the entry is loaded or inspected.
ModelRegistry = _ModelRegistry(
    models={
        arch_name: _LazyRegisteredModel(
            module_name=f"vllm.model_executor.models.{module_relname}",
            class_name=model_class_name,
        )
        for arch_name, (module_relname, model_class_name) in _VLLM_MODELS.items()
    }
)

# Result type of the callable passed to `_run_in_subprocess`.
_T = TypeVar("_T")


def _run_in_subprocess(fn: Callable[[], _T]) -> _T:
    """Run `fn` in a fresh Python subprocess and return its result.

    `fn` and a temporary output path are serialized with `cloudpickle` (so
    lambdas and closures work) and written to the stdin of
    `_SUBPROCESS_COMMAND`. The child stores the pickled result at that path,
    and the result is read back from there.

    Raises:
        RuntimeError: if the subprocess exits with a non-zero status. The
            child's stderr is included in the message.
    """
    # NOTE: We use a temporary directory instead of a temporary file to avoid
    # issues like https://stackoverflow.com/questions/23212435/permission-denied-to-write-to-my-temporary-file
    with tempfile.TemporaryDirectory() as tempdir:
        output_filepath = os.path.join(tempdir, "registry_output.tmp")

        # `cloudpickle` allows pickling lambda functions directly
        import cloudpickle

        input_bytes = cloudpickle.dumps((fn, output_filepath))

        # cannot use `sys.executable __file__` here because the script
        # contains relative imports
        returned = subprocess.run(
            _SUBPROCESS_COMMAND, input=input_bytes, capture_output=True
        )

        # check if the subprocess is successful
        try:
            returned.check_returncode()
        except subprocess.CalledProcessError as e:
            # Decode leniently: a non-UTF-8 byte in the child's stderr must
            # not raise UnicodeDecodeError here and hide the real failure.
            stderr = returned.stderr.decode(errors="replace")
            # wrap raised exception to provide more information
            raise RuntimeError(f"Error raised in subprocess:\n{stderr}") from e

        # The output file was produced by our own child process.
        with open(output_filepath, "rb") as f:
            return pickle.load(f)


def _run() -> None:
    """Child-process side of `_run_in_subprocess`.

    Loads vLLM's general plugins, reads a pickled `(fn, output_file)` pair
    from stdin, calls `fn()` and writes the pickled return value to
    `output_file`.
    """
    from vllm.plugins import load_general_plugins

    # Setup plugins before unpickling, in case the payload refers to them.
    load_general_plugins()

    payload = sys.stdin.buffer.read()
    fn, output_file = pickle.loads(payload)

    result = fn()

    with open(output_file, "wb") as out:
        out.write(pickle.dumps(result))


if __name__ == "__main__":
    # Executed when this module is run as the registry subprocess.
    _run()