# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Whenever you add an architecture to this module, please also update
`tests/models/registry.py` with example HuggingFace models for it.
"""

import importlib
9
import json
10
import os
11
import pickle
12
13
import subprocess
import sys
14
import tempfile
15
from abc import ABC, abstractmethod
16
from collections.abc import Callable, Set
17
from dataclasses import asdict, dataclass, field
18
from functools import lru_cache
19
from pathlib import Path
20
from typing import TYPE_CHECKING, Any, TypeVar
21
22

import torch.nn as nn
23
import transformers
24

25
from vllm import envs
26
27
28
29
30
from vllm.config import (
    ModelConfig,
    iter_architecture_defaults,
    try_match_architecture_defaults,
)
31
from vllm.logger import init_logger
32
from vllm.logging_utils import logtime
33
from vllm.transformers_utils.dynamic_module import try_get_class_from_dynamic_module
34
from vllm.utils.hashing import safe_hash
35

36
37
38
39
40
41
42
43
if TYPE_CHECKING:
    from vllm.config.model import AttnTypeStr
    from vllm.config.pooler import PoolingTypeStr
else:
    # The concrete aliases are imported for static type checking only; at
    # runtime both names fall back to ``Any``.
    AttnTypeStr = PoolingTypeStr = Any

from .interfaces import (
    has_inner_state,
    has_noops,
    is_attention_free,
    is_hybrid,
Patrick von Platen's avatar
Patrick von Platen committed
49
    requires_raw_input_tokens,
50
    supports_cross_encoding,
51
    supports_mamba_prefix_caching,
52
53
54
55
56
57
58
    supports_multimodal,
    supports_multimodal_encoder_tp_data,
    supports_multimodal_raw_input_only,
    supports_pp,
    supports_transcription,
)
from .interfaces_base import (
59
    get_attn_type,
60
61
62
63
    get_default_pooling_type,
    is_pooling_model,
    is_text_generation_model,
)

# Module-level logger for the model registry.
logger = init_logger(__name__)

# Maps a HuggingFace architecture name to the vLLM (module, class) pair that
# implements it as a text-generation model.
_TEXT_GENERATION_MODELS = {
    # [Decoder-only]
    "AfmoeForCausalLM": ("afmoe", "AfmoeForCausalLM"),
    "ApertusForCausalLM": ("apertus", "ApertusForCausalLM"),
    "AquilaModel": ("llama", "LlamaForCausalLM"),
    "AquilaForCausalLM": ("llama", "LlamaForCausalLM"),  # AquilaChat2
    "ArceeForCausalLM": ("arcee", "ArceeForCausalLM"),
    "ArcticForCausalLM": ("arctic", "ArcticForCausalLM"),
    # baichuan-7b, upper case 'C' in the class name
    "BaiChuanForCausalLM": ("baichuan", "BaiChuanForCausalLM"),
    # baichuan-13b, lower case 'c' in the class name
    "BaichuanForCausalLM": ("baichuan", "BaichuanForCausalLM"),
    "BailingMoeForCausalLM": ("bailing_moe", "BailingMoeForCausalLM"),
    "BailingMoeV2ForCausalLM": ("bailing_moe", "BailingMoeV2ForCausalLM"),
    "BambaForCausalLM": ("bamba", "BambaForCausalLM"),
    "BloomForCausalLM": ("bloom", "BloomForCausalLM"),
    "ChatGLMModel": ("chatglm", "ChatGLMForCausalLM"),
    "ChatGLMForConditionalGeneration": ("chatglm", "ChatGLMForCausalLM"),
    "CohereForCausalLM": ("commandr", "CohereForCausalLM"),
    "Cohere2ForCausalLM": ("commandr", "CohereForCausalLM"),
    "CwmForCausalLM": ("llama", "LlamaForCausalLM"),
    "DbrxForCausalLM": ("dbrx", "DbrxForCausalLM"),
    "DeciLMForCausalLM": ("nemotron_nas", "DeciLMForCausalLM"),
    "DeepseekForCausalLM": ("deepseek_v2", "DeepseekForCausalLM"),
    "DeepseekV2ForCausalLM": ("deepseek_v2", "DeepseekV2ForCausalLM"),
    "DeepseekV3ForCausalLM": ("deepseek_v2", "DeepseekV3ForCausalLM"),
    "DeepseekV32ForCausalLM": ("deepseek_v2", "DeepseekV3ForCausalLM"),
    "Dots1ForCausalLM": ("dots1", "Dots1ForCausalLM"),
    "Ernie4_5ForCausalLM": ("ernie45", "Ernie4_5ForCausalLM"),
    "Ernie4_5_MoeForCausalLM": ("ernie45_moe", "Ernie4_5_MoeForCausalLM"),
    "ExaoneForCausalLM": ("exaone", "ExaoneForCausalLM"),
    "Exaone4ForCausalLM": ("exaone4", "Exaone4ForCausalLM"),
    "Fairseq2LlamaForCausalLM": ("fairseq2_llama", "Fairseq2LlamaForCausalLM"),
    "FalconForCausalLM": ("falcon", "FalconForCausalLM"),
    "FalconMambaForCausalLM": ("mamba", "MambaForCausalLM"),
    "FalconH1ForCausalLM": ("falcon_h1", "FalconH1ForCausalLM"),
    "FlexOlmoForCausalLM": ("flex_olmo", "FlexOlmoForCausalLM"),
    "GemmaForCausalLM": ("gemma", "GemmaForCausalLM"),
    "Gemma2ForCausalLM": ("gemma2", "Gemma2ForCausalLM"),
    "Gemma3ForCausalLM": ("gemma3", "Gemma3ForCausalLM"),
    "Gemma3nForCausalLM": ("gemma3n", "Gemma3nForCausalLM"),
    "Qwen3NextForCausalLM": ("qwen3_next", "Qwen3NextForCausalLM"),
    "GlmForCausalLM": ("glm", "GlmForCausalLM"),
    "Glm4ForCausalLM": ("glm4", "Glm4ForCausalLM"),
    "Glm4MoeForCausalLM": ("glm4_moe", "Glm4MoeForCausalLM"),
    "GptOssForCausalLM": ("gpt_oss", "GptOssForCausalLM"),
    "GPT2LMHeadModel": ("gpt2", "GPT2LMHeadModel"),
    "GPTBigCodeForCausalLM": ("gpt_bigcode", "GPTBigCodeForCausalLM"),
    "GPTJForCausalLM": ("gpt_j", "GPTJForCausalLM"),
    "GPTNeoXForCausalLM": ("gpt_neox", "GPTNeoXForCausalLM"),
    "GraniteForCausalLM": ("granite", "GraniteForCausalLM"),
    "GraniteMoeForCausalLM": ("granitemoe", "GraniteMoeForCausalLM"),
    "GraniteMoeHybridForCausalLM": ("granitemoehybrid", "GraniteMoeHybridForCausalLM"),  # noqa: E501
    "GraniteMoeSharedForCausalLM": ("granitemoeshared", "GraniteMoeSharedForCausalLM"),  # noqa: E501
    "GritLM": ("gritlm", "GritLM"),
    "Grok1ModelForCausalLM": ("grok1", "Grok1ForCausalLM"),
    "HunYuanMoEV1ForCausalLM": ("hunyuan_v1", "HunYuanMoEV1ForCausalLM"),
    "HunYuanDenseV1ForCausalLM": ("hunyuan_v1", "HunYuanDenseV1ForCausalLM"),
    "HCXVisionForCausalLM": ("hyperclovax_vision", "HCXVisionForCausalLM"),
    "InternLMForCausalLM": ("llama", "LlamaForCausalLM"),
    "InternLM2ForCausalLM": ("internlm2", "InternLM2ForCausalLM"),
    "InternLM2VEForCausalLM": ("internlm2_ve", "InternLM2VEForCausalLM"),
    "InternLM3ForCausalLM": ("llama", "LlamaForCausalLM"),
    "JAISLMHeadModel": ("jais", "JAISLMHeadModel"),
    "Jais2ForCausalLM": ("jais2", "Jais2ForCausalLM"),
    "JambaForCausalLM": ("jamba", "JambaForCausalLM"),
    "KimiLinearForCausalLM": ("kimi_linear", "KimiLinearForCausalLM"),  # noqa: E501
    "Lfm2ForCausalLM": ("lfm2", "Lfm2ForCausalLM"),
    "Lfm2MoeForCausalLM": ("lfm2_moe", "Lfm2MoeForCausalLM"),
    "LlamaForCausalLM": ("llama", "LlamaForCausalLM"),
    "Llama4ForCausalLM": ("llama4", "Llama4ForCausalLM"),
    # For decapoda-research/llama-*
    "LLaMAForCausalLM": ("llama", "LlamaForCausalLM"),
    "LongcatFlashForCausalLM": ("longcat_flash", "LongcatFlashForCausalLM"),
    "MambaForCausalLM": ("mamba", "MambaForCausalLM"),
    "Mamba2ForCausalLM": ("mamba2", "Mamba2ForCausalLM"),
    "MiniCPMForCausalLM": ("minicpm", "MiniCPMForCausalLM"),
    "MiniCPM3ForCausalLM": ("minicpm3", "MiniCPM3ForCausalLM"),
    "MiniMaxForCausalLM": ("minimax_text_01", "MiniMaxText01ForCausalLM"),
    "MiniMaxText01ForCausalLM": ("minimax_text_01", "MiniMaxText01ForCausalLM"),
    "MiniMaxM1ForCausalLM": ("minimax_text_01", "MiniMaxText01ForCausalLM"),
    "MiniMaxM2ForCausalLM": ("minimax_m2", "MiniMaxM2ForCausalLM"),
    "MistralForCausalLM": ("llama", "LlamaForCausalLM"),
    "MistralLarge3ForCausalLM": ("mistral_large_3", "MistralLarge3ForCausalLM"),
    "MixtralForCausalLM": ("mixtral", "MixtralForCausalLM"),
    # transformers's mpt class has lower case
    "MptForCausalLM": ("mpt", "MPTForCausalLM"),
    "MPTForCausalLM": ("mpt", "MPTForCausalLM"),
    "MiMoForCausalLM": ("mimo", "MiMoForCausalLM"),
    "MiMoV2FlashForCausalLM": ("mimo_v2_flash", "MiMoV2FlashForCausalLM"),
    "NemotronForCausalLM": ("nemotron", "NemotronForCausalLM"),
    "NemotronHForCausalLM": ("nemotron_h", "NemotronHForCausalLM"),
    "OlmoForCausalLM": ("olmo", "OlmoForCausalLM"),
    "Olmo2ForCausalLM": ("olmo2", "Olmo2ForCausalLM"),
    "Olmo3ForCausalLM": ("olmo2", "Olmo2ForCausalLM"),
    "OlmoeForCausalLM": ("olmoe", "OlmoeForCausalLM"),
    "OPTForCausalLM": ("opt", "OPTForCausalLM"),
    "OrionForCausalLM": ("orion", "OrionForCausalLM"),
    "OuroForCausalLM": ("ouro", "OuroForCausalLM"),
    "PanguEmbeddedForCausalLM": ("openpangu", "PanguEmbeddedForCausalLM"),
    "PanguProMoEV2ForCausalLM": ("openpangu", "PanguProMoEV2ForCausalLM"),
    "PanguUltraMoEForCausalLM": ("openpangu", "PanguUltraMoEForCausalLM"),
    "PersimmonForCausalLM": ("persimmon", "PersimmonForCausalLM"),
    "PhiForCausalLM": ("phi", "PhiForCausalLM"),
    "Phi3ForCausalLM": ("phi3", "Phi3ForCausalLM"),
    "PhiMoEForCausalLM": ("phimoe", "PhiMoEForCausalLM"),
    "Plamo2ForCausalLM": ("plamo2", "Plamo2ForCausalLM"),
    "Plamo3ForCausalLM": ("plamo3", "Plamo3ForCausalLM"),
    "QWenLMHeadModel": ("qwen", "QWenLMHeadModel"),
    "Qwen2ForCausalLM": ("qwen2", "Qwen2ForCausalLM"),
    "Qwen2MoeForCausalLM": ("qwen2_moe", "Qwen2MoeForCausalLM"),
    "Qwen3ForCausalLM": ("qwen3", "Qwen3ForCausalLM"),
    "Qwen3MoeForCausalLM": ("qwen3_moe", "Qwen3MoeForCausalLM"),
    "RWForCausalLM": ("falcon", "FalconForCausalLM"),
    "SeedOssForCausalLM": ("seed_oss", "SeedOssForCausalLM"),
    "Step3TextForCausalLM": ("step3_text", "Step3TextForCausalLM"),
    "StableLMEpochForCausalLM": ("stablelm", "StablelmForCausalLM"),
    "StableLmForCausalLM": ("stablelm", "StablelmForCausalLM"),
    "Starcoder2ForCausalLM": ("starcoder2", "Starcoder2ForCausalLM"),
    "SolarForCausalLM": ("solar", "SolarForCausalLM"),
    "TeleChatForCausalLM": ("telechat2", "TeleChat2ForCausalLM"),
    "TeleChat2ForCausalLM": ("telechat2", "TeleChat2ForCausalLM"),
    "TeleFLMForCausalLM": ("teleflm", "TeleFLMForCausalLM"),
    "XverseForCausalLM": ("llama", "LlamaForCausalLM"),
    "Zamba2ForCausalLM": ("zamba2", "Zamba2ForCausalLM"),
}

# Maps a HuggingFace architecture name to the vLLM (module, class) pair used
# when the model is run as an embedding/pooling model.
_EMBEDDING_MODELS = {
    # [Text-only]
    "BertModel": ("bert", "BertEmbeddingModel"),
    "BertSpladeSparseEmbeddingModel": ("bert", "BertSpladeSparseEmbeddingModel"),
    "DeciLMForCausalLM": ("nemotron_nas", "DeciLMForCausalLM"),
    "Gemma2Model": ("gemma2", "Gemma2ForCausalLM"),
    "Gemma3TextModel": ("gemma3", "Gemma3Model"),
    "GlmForCausalLM": ("glm", "GlmForCausalLM"),
    "GPT2ForSequenceClassification": ("gpt2", "GPT2ForSequenceClassification"),
    "GritLM": ("gritlm", "GritLM"),
    "GteModel": ("bert_with_rope", "SnowflakeGteNewModel"),
    "GteNewModel": ("bert_with_rope", "GteNewModel"),
    "InternLM2ForRewardModel": ("internlm2", "InternLM2ForRewardModel"),
    "JambaForSequenceClassification": ("jamba", "JambaForSequenceClassification"),  # noqa: E501
    "LlamaBidirectionalModel": ("llama", "LlamaBidirectionalModel"),
    "LlamaModel": ("llama", "LlamaForCausalLM"),
    **{
        # Multiple models share the same architecture, so we include them all:
        # every text-generation alias that resolves to the Llama
        # implementation is also registered here.
        k: (mod, arch)
        for k, (mod, arch) in _TEXT_GENERATION_MODELS.items()
        if arch == "LlamaForCausalLM"
    },
    "MistralModel": ("llama", "LlamaForCausalLM"),
    "ModernBertModel": ("modernbert", "ModernBertModel"),
    "NomicBertModel": ("bert_with_rope", "NomicBertModel"),
    "Phi3ForCausalLM": ("phi3", "Phi3ForCausalLM"),
    "Qwen2Model": ("qwen2", "Qwen2ForCausalLM"),
    "Qwen2ForCausalLM": ("qwen2", "Qwen2ForCausalLM"),
    "Qwen2ForRewardModel": ("qwen2_rm", "Qwen2ForRewardModel"),
    "Qwen2ForProcessRewardModel": ("qwen2_rm", "Qwen2ForProcessRewardModel"),
    "RobertaForMaskedLM": ("roberta", "RobertaEmbeddingModel"),
    "RobertaModel": ("roberta", "RobertaEmbeddingModel"),
    "TeleChatForCausalLM": ("telechat2", "TeleChat2ForCausalLM"),
    "TeleChat2ForCausalLM": ("telechat2", "TeleChat2ForCausalLM"),
    "XLMRobertaModel": ("roberta", "RobertaEmbeddingModel"),
    # [Multimodal]
    "CLIPModel": ("clip", "CLIPEmbeddingModel"),
    "LlavaNextForConditionalGeneration": (
        "llava_next",
        "LlavaNextForConditionalGeneration",
    ),
    "Phi3VForCausalLM": ("phi3v", "Phi3VForCausalLM"),
    "Qwen2VLForConditionalGeneration": ("qwen2_vl", "Qwen2VLForConditionalGeneration"),  # noqa: E501
    "SiglipModel": ("siglip", "SiglipEmbeddingModel"),
    # Terratorch models work on images for both input and output; they are
    # listed here because, for the time being, they piggy-back on the
    # embedding-model path.
    "PrithviGeoSpatialMAE": ("terratorch", "Terratorch"),
    "Terratorch": ("terratorch", "Terratorch"),
}

# Maps a HuggingFace architecture name to the vLLM (module, class) pair for
# sequence/token classification and ranking (cross-encoder) models.
_CROSS_ENCODER_MODELS = {
    "BertForSequenceClassification": ("bert", "BertForSequenceClassification"),
    "BertForTokenClassification": ("bert", "BertForTokenClassification"),
    "GteNewForSequenceClassification": (
        "bert_with_rope",
        "GteNewForSequenceClassification",
    ),
    "JinaVLForRanking": ("jina_vl", "JinaVLForSequenceClassification"),
    "LlamaBidirectionalForSequenceClassification": (
        "llama",
        "LlamaBidirectionalForSequenceClassification",
    ),
    "ModernBertForSequenceClassification": (
        "modernbert",
        "ModernBertForSequenceClassification",
    ),
    "ModernBertForTokenClassification": (
        "modernbert",
        "ModernBertForTokenClassification",
    ),
    "RobertaForSequenceClassification": ("roberta", "RobertaForSequenceClassification"),
    "XLMRobertaForSequenceClassification": (
        "roberta",
        "RobertaForSequenceClassification",
    ),
}

# Maps a HuggingFace architecture name to the vLLM (module, class) pair for
# models that accept non-text inputs (images, video, audio).
_MULTIMODAL_MODELS = {
    # [Decoder-only]
    "AriaForConditionalGeneration": ("aria", "AriaForConditionalGeneration"),
    "AudioFlamingo3ForConditionalGeneration": (
        "audioflamingo3",
        "AudioFlamingo3ForConditionalGeneration",
    ),
    "AyaVisionForConditionalGeneration": (
        "aya_vision",
        "AyaVisionForConditionalGeneration",
    ),
    "BagelForConditionalGeneration": ("bagel", "BagelForConditionalGeneration"),
    "BeeForConditionalGeneration": ("bee", "BeeForConditionalGeneration"),
    "Blip2ForConditionalGeneration": ("blip2", "Blip2ForConditionalGeneration"),
    "ChameleonForConditionalGeneration": (
        "chameleon",
        "ChameleonForConditionalGeneration",
    ),
    "Cohere2VisionForConditionalGeneration": (
        "cohere2_vision",
        "Cohere2VisionForConditionalGeneration",
    ),
    "DeepseekVLV2ForCausalLM": ("deepseek_vl2", "DeepseekVLV2ForCausalLM"),
    "DeepseekOCRForCausalLM": ("deepseek_ocr", "DeepseekOCRForCausalLM"),
    "DotsOCRForCausalLM": ("dots_ocr", "DotsOCRForCausalLM"),
    "Ernie4_5_VLMoeForConditionalGeneration": (
        "ernie45_vl",
        "Ernie4_5_VLMoeForConditionalGeneration",
    ),
    "FuyuForCausalLM": ("fuyu", "FuyuForCausalLM"),
    "Gemma3ForConditionalGeneration": ("gemma3_mm", "Gemma3ForConditionalGeneration"),  # noqa: E501
    "Gemma3nForConditionalGeneration": (
        "gemma3n_mm",
        "Gemma3nForConditionalGeneration",
    ),
    "GlmAsrForConditionalGeneration": ("glmasr", "GlmAsrForConditionalGeneration"),
    "GLM4VForCausalLM": ("glm4v", "GLM4VForCausalLM"),
    "Glm4vForConditionalGeneration": ("glm4_1v", "Glm4vForConditionalGeneration"),  # noqa: E501
    "Glm4vMoeForConditionalGeneration": ("glm4_1v", "Glm4vMoeForConditionalGeneration"),  # noqa: E501
    "GraniteSpeechForConditionalGeneration": (
        "granite_speech",
        "GraniteSpeechForConditionalGeneration",
    ),
    "H2OVLChatModel": ("h2ovl", "H2OVLChatModel"),
    "HunYuanVLForConditionalGeneration": (
        "hunyuan_vision",
        "HunYuanVLForConditionalGeneration",
    ),
    "InternVLChatModel": ("internvl", "InternVLChatModel"),
    "NemotronH_Nano_VL_V2": ("nano_nemotron_vl", "NemotronH_Nano_VL_V2"),
    "OpenCUAForConditionalGeneration": (
        "opencua",
        "OpenCUAForConditionalGeneration",
    ),
    "InternS1ForConditionalGeneration": (
        "interns1",
        "InternS1ForConditionalGeneration",
    ),
    "InternVLForConditionalGeneration": (
        "interns1",
        "InternS1ForConditionalGeneration",
    ),
    "Idefics3ForConditionalGeneration": (
        "idefics3",
        "Idefics3ForConditionalGeneration",
    ),
    "IsaacForConditionalGeneration": ("isaac", "IsaacForConditionalGeneration"),
    "SmolVLMForConditionalGeneration": ("smolvlm", "SmolVLMForConditionalGeneration"),  # noqa: E501
    "KeyeForConditionalGeneration": ("keye", "KeyeForConditionalGeneration"),
    "KeyeVL1_5ForConditionalGeneration": (
        "keye_vl1_5",
        "KeyeVL1_5ForConditionalGeneration",
    ),
    "RForConditionalGeneration": ("rvl", "RForConditionalGeneration"),
    "KimiVLForConditionalGeneration": ("kimi_vl", "KimiVLForConditionalGeneration"),  # noqa: E501
    "LightOnOCRForConditionalGeneration": (
        "lightonocr",
        "LightOnOCRForConditionalGeneration",
    ),
    "Llama_Nemotron_Nano_VL": ("nemotron_vl", "LlamaNemotronVLChatModel"),
    "Llama4ForConditionalGeneration": ("mllama4", "Llama4ForConditionalGeneration"),  # noqa: E501
    "LlavaForConditionalGeneration": ("llava", "LlavaForConditionalGeneration"),
    "LlavaNextForConditionalGeneration": (
        "llava_next",
        "LlavaNextForConditionalGeneration",
    ),
    "LlavaNextVideoForConditionalGeneration": (
        "llava_next_video",
        "LlavaNextVideoForConditionalGeneration",
    ),
    "LlavaOnevisionForConditionalGeneration": (
        "llava_onevision",
        "LlavaOnevisionForConditionalGeneration",
    ),
    "MantisForConditionalGeneration": ("llava", "MantisForConditionalGeneration"),  # noqa: E501
    "MiDashengLMModel": ("midashenglm", "MiDashengLMModel"),
    "MiniMaxVL01ForConditionalGeneration": (
        "minimax_vl_01",
        "MiniMaxVL01ForConditionalGeneration",
    ),
    "MiniCPMO": ("minicpmo", "MiniCPMO"),
    "MiniCPMV": ("minicpmv", "MiniCPMV"),
    "Mistral3ForConditionalGeneration": (
        "mistral3",
        "Mistral3ForConditionalGeneration",
    ),
    "MolmoForCausalLM": ("molmo", "MolmoForCausalLM"),
    "NVLM_D": ("nvlm_d", "NVLM_D_Model"),
    "Ovis": ("ovis", "Ovis"),
    "Ovis2_5": ("ovis2_5", "Ovis2_5"),
    "PaddleOCRVLForConditionalGeneration": (
        "paddleocr_vl",
        "PaddleOCRVLForConditionalGeneration",
    ),
    "PaliGemmaForConditionalGeneration": (
        "paligemma",
        "PaliGemmaForConditionalGeneration",
    ),
    "Phi3VForCausalLM": ("phi3v", "Phi3VForCausalLM"),
    "Phi4MMForCausalLM": ("phi4mm", "Phi4MMForCausalLM"),
    "PixtralForConditionalGeneration": ("pixtral", "PixtralForConditionalGeneration"),  # noqa: E501
    "QwenVLForConditionalGeneration": ("qwen_vl", "QwenVLForConditionalGeneration"),  # noqa: E501
    "Qwen2VLForConditionalGeneration": ("qwen2_vl", "Qwen2VLForConditionalGeneration"),  # noqa: E501
    "Qwen2_5_VLForConditionalGeneration": (
        "qwen2_5_vl",
        "Qwen2_5_VLForConditionalGeneration",
    ),
    "Qwen2AudioForConditionalGeneration": (
        "qwen2_audio",
        "Qwen2AudioForConditionalGeneration",
    ),
    "Qwen2_5OmniModel": (
        "qwen2_5_omni_thinker",
        "Qwen2_5OmniThinkerForConditionalGeneration",
    ),
    "Qwen2_5OmniForConditionalGeneration": (
        "qwen2_5_omni_thinker",
        "Qwen2_5OmniThinkerForConditionalGeneration",
    ),
    "Qwen3OmniMoeForConditionalGeneration": (
        "qwen3_omni_moe_thinker",
        "Qwen3OmniMoeThinkerForConditionalGeneration",
    ),
    "Qwen3VLForConditionalGeneration": ("qwen3_vl", "Qwen3VLForConditionalGeneration"),  # noqa: E501
    "Qwen3VLMoeForConditionalGeneration": (
        "qwen3_vl_moe",
        "Qwen3VLMoeForConditionalGeneration",
    ),
    "SkyworkR1VChatModel": ("skyworkr1v", "SkyworkR1VChatModel"),
    "Step3VLForConditionalGeneration": ("step3_vl", "Step3VLForConditionalGeneration"),  # noqa: E501
    "TarsierForConditionalGeneration": ("tarsier", "TarsierForConditionalGeneration"),  # noqa: E501
    "Tarsier2ForConditionalGeneration": (
        "qwen2_vl",
        "Tarsier2ForConditionalGeneration",
    ),
    "UltravoxModel": ("ultravox", "UltravoxModel"),
    "VoxtralForConditionalGeneration": ("voxtral", "VoxtralForConditionalGeneration"),  # noqa: E501
    "VoxtralStreamingGeneration": ("voxtral_streaming", "VoxtralStreamingGeneration"),  # noqa: E501
    # [Encoder-decoder]
    "WhisperForConditionalGeneration": ("whisper", "WhisperForConditionalGeneration"),  # noqa: E501
}

# Maps draft/auxiliary architectures used for speculative decoding (EAGLE,
# MTP, Medusa, ...) to their vLLM (module, class) pair.
_SPECULATIVE_DECODING_MODELS = {
    "MiMoMTPModel": ("mimo_mtp", "MiMoMTP"),
    "EagleLlamaForCausalLM": ("llama_eagle", "EagleLlamaForCausalLM"),
    "EagleLlama4ForCausalLM": ("llama4_eagle", "EagleLlama4ForCausalLM"),
    "EagleMiniCPMForCausalLM": ("minicpm_eagle", "EagleMiniCPMForCausalLM"),
    "Eagle3LlamaForCausalLM": ("llama_eagle3", "Eagle3LlamaForCausalLM"),
    "LlamaForCausalLMEagle3": ("llama_eagle3", "Eagle3LlamaForCausalLM"),
    "Eagle3Qwen2_5vlForCausalLM": ("llama_eagle3", "Eagle3LlamaForCausalLM"),
    "Eagle3Qwen3vlForCausalLM": ("llama_eagle3", "Eagle3LlamaForCausalLM"),
    "EagleMistralLarge3ForCausalLM": (
        "mistral_large_3_eagle",
        "EagleMistralLarge3ForCausalLM",
    ),
    "EagleDeepSeekMTPModel": ("deepseek_eagle", "EagleDeepseekV3ForCausalLM"),
    "DeepSeekMTPModel": ("deepseek_mtp", "DeepSeekMTP"),
    "ErnieMTPModel": ("ernie_mtp", "ErnieMTP"),
    "LongCatFlashMTPModel": ("longcat_flash_mtp", "LongCatFlashMTP"),
    "Glm4MoeMTPModel": ("glm4_moe_mtp", "Glm4MoeMTP"),
    "MedusaModel": ("medusa", "Medusa"),
    "OpenPanguMTPModel": ("openpangu_mtp", "OpenPanguMTP"),
    "Qwen3NextMTP": ("qwen3_next_mtp", "Qwen3NextMTP"),
    # Temporarily disabled.
    # # TODO(woosuk): Re-enable this once the MLP Speculator is supported in V1.
    # "MLPSpeculatorPreTrainedModel": ("mlp_speculator", "MLPSpeculator"),
}

# Architectures without a native vLLM implementation that are served through
# the Transformers modeling backend.
_TRANSFORMERS_SUPPORTED_MODELS = {
    # Text generation models
    "SmolLM3ForCausalLM": ("transformers", "TransformersForCausalLM"),
    # Multimodal models
    "Emu3ForConditionalGeneration": (
        "transformers",
        "TransformersMultiModalForCausalLM",
    ),
}

# The Transformers-backend wrapper classes themselves; each name maps to the
# class of the same name in the ``transformers`` module.
_TRANSFORMERS_BACKEND_MODELS = {
    # Text generation models
    "TransformersForCausalLM": ("transformers", "TransformersForCausalLM"),
    "TransformersMoEForCausalLM": ("transformers", "TransformersMoEForCausalLM"),
    # Multimodal models
    "TransformersMultiModalForCausalLM": (
        "transformers",
        "TransformersMultiModalForCausalLM",
    ),
    "TransformersMultiModalMoEForCausalLM": (
        "transformers",
        "TransformersMultiModalMoEForCausalLM",
    ),
    # Embedding models
    "TransformersEmbeddingModel": ("transformers", "TransformersEmbeddingModel"),
    "TransformersMoEEmbeddingModel": ("transformers", "TransformersMoEEmbeddingModel"),
    "TransformersMultiModalEmbeddingModel": (
        "transformers",
        "TransformersMultiModalEmbeddingModel",
    ),
    # Sequence classification models
    "TransformersForSequenceClassification": (
        "transformers",
        "TransformersForSequenceClassification",
    ),
    "TransformersMoEForSequenceClassification": (
        "transformers",
        "TransformersMoEForSequenceClassification",
    ),
    "TransformersMultiModalForSequenceClassification": (
        "transformers",
        "TransformersMultiModalForSequenceClassification",
    ),
}

# Union of every registry table above. When an architecture appears in more
# than one table, the entry from the table merged later wins.
_VLLM_MODELS = {
    **_TEXT_GENERATION_MODELS,
    **_EMBEDDING_MODELS,
    **_CROSS_ENCODER_MODELS,
    **_MULTIMODAL_MODELS,
    **_SPECULATIVE_DECODING_MODELS,
    **_TRANSFORMERS_SUPPORTED_MODELS,
    **_TRANSFORMERS_BACKEND_MODELS,
}

# Argument list passed to subprocess.run() when this module is executed in a
# child process. It can be modified to alter those arguments if needed; for
# example, when things are packed together in par format, sys.executable might
# not be the target we want to run.
_SUBPROCESS_COMMAND = [sys.executable, "-m", "vllm.model_executor.models.registry"]

# Architectures that are no longer supported, mapped to a vLLM version string
# associated with their removal (presumably the last release that supported
# them -- confirm against the code that reports this to users).
_PREVIOUSLY_SUPPORTED_MODELS = {
    "MotifForCausalLM": "0.10.2",
    "Phi3SmallForCausalLM": "0.9.2",
    "Phi4FlashForCausalLM": "0.10.2",
    "Phi4MultimodalForCausalLM": "0.12.0",
    # encoder-decoder models except whisper
    # have been removed for V0 deprecation.
    "BartModel": "0.10.2",
    "BartForConditionalGeneration": "0.10.2",
    "DonutForConditionalGeneration": "0.10.2",
    "Florence2ForConditionalGeneration": "0.10.2",
    "MBartForConditionalGeneration": "0.10.2",
    "MllamaForConditionalGeneration": "0.10.2",
}


@dataclass(frozen=True)
class _ModelInfo:
    """Immutable summary of which vLLM interfaces a model class implements.

    Every field is derived from the model class by :meth:`from_model_cls`.
    Instances can be rebuilt from a mapping of their fields with
    ``_ModelInfo(**fields)``.
    """

    architecture: str
    is_text_generation_model: bool
    is_pooling_model: bool
    attn_type: AttnTypeStr
    default_pooling_type: PoolingTypeStr
    supports_cross_encoding: bool
    supports_multimodal: bool
    supports_multimodal_raw_input_only: bool
    requires_raw_input_tokens: bool
    supports_multimodal_encoder_tp_data: bool
    supports_pp: bool
    has_inner_state: bool
    is_attention_free: bool
    is_hybrid: bool
    has_noops: bool
    supports_mamba_prefix_caching: bool
    supports_transcription: bool
    supports_transcription_only: bool

    @staticmethod
    def from_model_cls(model: type[nn.Module]) -> "_ModelInfo":
        """Inspect a model class and record its capabilities.

        Args:
            model: The model class (not an instance) to inspect.

        Returns:
            A ``_ModelInfo`` whose ``architecture`` is ``model.__name__`` and
            whose flags come from the interface predicates.
        """
        # Evaluated once and reused for both transcription-related fields.
        transcription = supports_transcription(model)
        return _ModelInfo(
            architecture=model.__name__,
            is_text_generation_model=is_text_generation_model(model),
            is_pooling_model=is_pooling_model(model),
            default_pooling_type=get_default_pooling_type(model),
            attn_type=get_attn_type(model),
            supports_cross_encoding=supports_cross_encoding(model),
            supports_multimodal=supports_multimodal(model),
            supports_multimodal_raw_input_only=supports_multimodal_raw_input_only(
                model
            ),
            requires_raw_input_tokens=requires_raw_input_tokens(model),
            supports_multimodal_encoder_tp_data=supports_multimodal_encoder_tp_data(
                model
            ),
            supports_pp=supports_pp(model),
            has_inner_state=has_inner_state(model),
            is_attention_free=is_attention_free(model),
            is_hybrid=is_hybrid(model),
            supports_mamba_prefix_caching=supports_mamba_prefix_caching(model),
            supports_transcription=transcription,
            # Only read the class attribute when the model supports
            # transcription at all.
            supports_transcription_only=(
                transcription and model.supports_transcription_only
            ),
            has_noops=has_noops(model),
        )


class _BaseRegisteredModel(ABC):
    """Abstract interface shared by every entry in the model registry.

    Concrete subclasses decide whether the model class is already imported
    or still has to be resolved.
    """

    @abstractmethod
    def inspect_model_cls(self) -> _ModelInfo:
        """Return the capability summary of the registered model class."""
        raise NotImplementedError

    @abstractmethod
    def load_model_cls(self) -> type[nn.Module]:
        """Return the registered model class itself."""
        raise NotImplementedError


@dataclass(frozen=True)
class _RegisteredModel(_BaseRegisteredModel):
    """
    Represents a model that has already been imported in the main process.
    """

    interfaces: _ModelInfo
    model_cls: type[nn.Module]

    @staticmethod
    def from_model_cls(model_cls: type[nn.Module]) -> "_RegisteredModel":
        """Wrap an imported model class, inspecting its interfaces eagerly."""
        return _RegisteredModel(
            interfaces=_ModelInfo.from_model_cls(model_cls),
            model_cls=model_cls,
        )

    def inspect_model_cls(self) -> _ModelInfo:
        # The summary was computed once when this entry was created.
        return self.interfaces

    def load_model_cls(self) -> type[nn.Module]:
        # Nothing to import: the class object is held directly.
        return self.model_cls


@dataclass(frozen=True)
class _LazyRegisteredModel(_BaseRegisteredModel):
    """
    Represents a model that has not been imported in the main process.

    The model class is identified by `module_name` and `class_name` and is
    only imported on demand. Inspection is performed in a subprocess (to
    avoid initializing CUDA in the main process); for in-tree vLLM model
    modules the result is cached on disk, keyed by a hash of the module's
    source file.
    """

    module_name: str
    class_name: str

    @staticmethod
    def _get_cache_dir() -> Path:
        return Path(envs.VLLM_CACHE_ROOT) / "modelinfos"

    def _get_cache_filename(self) -> str:
        cls_name = f"{self.module_name}-{self.class_name}".replace(".", "-")
        return f"{cls_name}.json"

    def _get_module_source_path(self) -> Path | None:
        """Return the source file of an in-tree model module, or None.

        Only modules under ``vllm.model_executor.models`` are resolved. For
        out-of-tree modules (e.g. registered by a plugin) the source cannot
        be located without importing the module. Guessing a path from the
        last name component could hash an unrelated in-tree file that
        happens to share that name, which would leave the cache stale when
        the external module changes.
        """
        prefix = "vllm.model_executor.models."
        if not self.module_name.startswith(prefix):
            return None

        relname = self.module_name[len(prefix) :]
        # Map dotted sub-modules onto sub-directories of this package.
        *package_parts, module_stem = relname.split(".")
        path = Path(__file__).parent.joinpath(*package_parts, f"{module_stem}.py")
        return path if path.is_file() else None

    def _load_modelinfo_from_cache(self, module_hash: str) -> _ModelInfo | None:
        """Return the cached `_ModelInfo` if present and matching `module_hash`.

        Any problem (missing file, stale hash, malformed content) is logged
        at debug level and results in None, i.e. a cache miss.
        """
        try:
            try:
                modelinfo_path = self._get_cache_dir() / self._get_cache_filename()
                with open(modelinfo_path, encoding="utf-8") as file:
                    mi_dict = json.load(file)
            except FileNotFoundError:
                logger.debug(
                    "Cached model info file for class %s.%s not found",
                    self.module_name,
                    self.class_name,
                )
                return None

            if mi_dict["hash"] != module_hash:
                logger.debug(
                    "Cached model info file for class %s.%s is stale",
                    self.module_name,
                    self.class_name,
                )
                return None

            # file not changed, use cached _ModelInfo properties
            return _ModelInfo(**mi_dict["modelinfo"])
        except Exception:
            logger.debug(
                "Cached model info for class %s.%s error. ",
                self.module_name,
                self.class_name,
            )
            return None

    def _save_modelinfo_to_cache(self, mi: _ModelInfo, module_hash: str) -> None:
        """save dictionary json file to cache"""
        from vllm.model_executor.model_loader.weight_utils import atomic_writer

        try:
            modelinfo_dict = {
                "hash": module_hash,
                "modelinfo": asdict(mi),
            }
            cache_dir = self._get_cache_dir()
            cache_dir.mkdir(parents=True, exist_ok=True)
            modelinfo_path = cache_dir / self._get_cache_filename()
            with atomic_writer(modelinfo_path, encoding="utf-8") as f:
                json.dump(modelinfo_dict, f, indent=2)
        except Exception:
            # Caching is best-effort; failure only costs a re-inspection.
            logger.exception("Error saving model info cache.")

    @logtime(logger=logger, msg="Registry inspect model class")
    def inspect_model_cls(self) -> _ModelInfo:
        """Return the model's `_ModelInfo`, using the on-disk cache if valid.

        On a cache miss the class is imported and inspected in a subprocess,
        and the result is written back to the cache when the module's source
        file could be hashed.
        """
        model_path = self._get_module_source_path()
        module_hash = None

        if model_path is not None:
            with open(model_path, "rb") as f:
                module_hash = safe_hash(f.read(), usedforsecurity=False).hexdigest()

            mi = self._load_modelinfo_from_cache(module_hash)
            if mi is not None:
                logger.debug(
                    "Loaded model info for class %s.%s from cache",
                    self.module_name,
                    self.class_name,
                )
                return mi
            else:
                logger.debug(
                    "Cache model info for class %s.%s miss. Loading model instead.",
                    self.module_name,
                    self.class_name,
                )

        # Performed in another process to avoid initializing CUDA
        mi = _run_in_subprocess(
            lambda: _ModelInfo.from_model_cls(self.load_model_cls())
        )
        logger.debug(
            "Loaded model info for class %s.%s", self.module_name, self.class_name
        )

        # save cache file
        if module_hash is not None:
            self._save_modelinfo_to_cache(mi, module_hash)

        return mi

    def load_model_cls(self) -> type[nn.Module]:
        """Import `module_name` and return its `class_name` attribute."""
        mod = importlib.import_module(self.module_name)
        return getattr(mod, self.class_name)


@lru_cache(maxsize=128)
def _try_load_model_cls(
    model_arch: str,
    model: _BaseRegisteredModel,
) -> type[nn.Module] | None:
    """Load the model class registered for `model_arch`, or return None.

    The current platform validates the architecture first; exceptions from
    that check are not caught. Exceptions raised while loading the class
    are logged and turned into a None result. Results are memoized per
    `(model_arch, model)` pair.
    """
    from vllm.platforms import current_platform

    current_platform.verify_model_arch(model_arch)

    loaded_cls = None
    try:
        loaded_cls = model.load_model_cls()
    except Exception:
        logger.exception("Error in loading model architecture '%s'", model_arch)
    return loaded_cls
743
744


745
746
747
748
@lru_cache(maxsize=128)
def _try_inspect_model_cls(
    model_arch: str,
    model: _BaseRegisteredModel,
) -> _ModelInfo | None:
    """Return the `_ModelInfo` for the entry registered as `model_arch`.

    Any exception raised during inspection is logged and yields None.
    Results (including None) are memoized per `(model_arch, model)` pair.
    """
    info = None
    try:
        info = model.inspect_model_cls()
    except Exception:
        logger.exception("Error in inspecting model architecture '%s'", model_arch)
    return info
755
756


757
758
759
@dataclass
class _ModelRegistry:
    """Maps model architecture names to registered model entries.

    Entries are either already-imported classes (`_RegisteredModel`) or
    lazily imported ones (`_LazyRegisteredModel`). Loading and inspection
    go through the module-level `lru_cache`d `_try_load_model_cls` /
    `_try_inspect_model_cls` helpers.
    """

    # Keyed by model_arch
    models: dict[str, _BaseRegisteredModel] = field(default_factory=dict)

    def get_supported_archs(self) -> Set[str]:
        return self.models.keys()

    def register_model(
        self,
        model_arch: str,
        model_cls: type[nn.Module] | str,
    ) -> None:
        """
        Register an external model to be used in vLLM.

        `model_cls` can be either:

        - A [`torch.nn.Module`][] class directly referencing the model.
        - A string in the format `<module>:<class>` which can be used to
          lazily import the model. This is useful to avoid initializing CUDA
          when importing the model and thus the related error
          `RuntimeError: Cannot re-initialize CUDA in forked subprocess`.

        Registering an architecture that is already present overwrites the
        existing entry and logs a warning.

        Raises:
            TypeError: If `model_arch` is not a string, or `model_cls` is
                neither a string nor an `nn.Module` subclass.
            ValueError: If `model_cls` is a string that is not of the form
                `<module>:<class>`.
        """
        if not isinstance(model_arch, str):
            msg = f"`model_arch` should be a string, not a {type(model_arch)}"
            raise TypeError(msg)

        # Validate and build the entry first, so an invalid `model_cls` neither
        # logs a misleading overwrite warning nor touches the registry.
        if isinstance(model_cls, str):
            split_str = model_cls.split(":")
            if len(split_str) != 2:
                msg = "Expected a string in the format `<module>:<class>`"
                raise ValueError(msg)

            model = _LazyRegisteredModel(*split_str)
        elif isinstance(model_cls, type) and issubclass(model_cls, nn.Module):
            model = _RegisteredModel.from_model_cls(model_cls)
        else:
            msg = (
                "`model_cls` should be a string or PyTorch model class, "
                f"not a {type(model_cls)}"
            )
            raise TypeError(msg)

        if model_arch in self.models:
            logger.warning(
                "Model architecture %s is already registered, and will be "
                "overwritten by the new model class %s.",
                model_arch,
                model_cls,
            )

        self.models[model_arch] = model

    def _raise_for_unsupported(self, architectures: list[str]):
        """Always raise a ValueError explaining why none of `architectures`
        could be resolved.

        - If at least one is registered, it must have failed to load/inspect.
        - If one was supported by an older vLLM version, name that version.
        - Otherwise, list the supported architectures.
        """
        all_supported_archs = self.get_supported_archs()

        if any(arch in all_supported_archs for arch in architectures):
            raise ValueError(
                f"Model architectures {architectures} failed "
                "to be inspected. Please check the logs for more details."
            )

        for arch in architectures:
            if arch in _PREVIOUSLY_SUPPORTED_MODELS:
                previous_version = _PREVIOUSLY_SUPPORTED_MODELS[arch]

                raise ValueError(
                    f"Model architecture {arch} was supported in vLLM until "
                    f"v{previous_version}, and is not supported anymore. "
                    "Please use an older version of vLLM if you want to "
                    "use this model architecture."
                )

        raise ValueError(
            f"Model architectures {architectures} are not supported for now. "
            f"Supported architectures: {all_supported_archs}"
        )

    def _try_load_model_cls(self, model_arch: str) -> type[nn.Module] | None:
        if model_arch not in self.models:
            return None

        return _try_load_model_cls(model_arch, self.models[model_arch])

    def _try_inspect_model_cls(self, model_arch: str) -> _ModelInfo | None:
        if model_arch not in self.models:
            return None

        return _try_inspect_model_cls(model_arch, self.models[model_arch])

    def _try_resolve_transformers(
        self,
        architecture: str,
        model_config: ModelConfig,
    ) -> str | None:
        """Return the Transformers-backend architecture for `architecture`.

        Returns None when the Transformers backend cannot be used and it was
        not explicitly requested; raises ValueError when it was requested
        (`model_impl == "transformers"`) but cannot be used.
        """
        if architecture in _TRANSFORMERS_BACKEND_MODELS:
            return architecture

        auto_map: dict[str, str] = (
            getattr(model_config.hf_config, "auto_map", None) or dict()
        )

        # Make sure that config class is always initialized before model class,
        # otherwise the model class won't be able to access the config class,
        # the expected auto_map should have correct order like:
        # "auto_map": {
        #     "AutoConfig": "<your-repo-name>--<config-name>",
        #     "AutoModel": "<your-repo-name>--<model-name>",
        #     "AutoModelFor<Task>": "<your-repo-name>--<model-name>",
        # },
        for prefix in ("AutoConfig", "AutoModel"):
            for name, module in auto_map.items():
                if name.startswith(prefix):
                    try_get_class_from_dynamic_module(
                        module,
                        model_config.model,
                        revision=model_config.revision,
                        warn_on_fail=False,
                    )

        model_module = getattr(transformers, architecture, None)

        if model_module is None:
            for name, module in auto_map.items():
                if name.startswith("AutoModel"):
                    model_module = try_get_class_from_dynamic_module(
                        module,
                        model_config.model,
                        revision=model_config.revision,
                        warn_on_fail=True,
                    )
                    if model_module is not None:
                        break
            else:
                if model_config.model_impl != "transformers":
                    return None

                raise ValueError(
                    f"Cannot find model module. {architecture!r} is not a "
                    "registered model in the Transformers library (only "
                    "relevant if the model is meant to be in Transformers) "
                    "and 'AutoModel' is not present in the model config's "
                    "'auto_map' (relevant if the model is custom)."
                )

        if not model_module.is_backend_compatible():
            if model_config.model_impl != "transformers":
                return None

            raise ValueError(
                f"The Transformers implementation of {architecture!r} "
                "is not compatible with vLLM."
            )

        return model_config._get_transformers_backend_cls()

    def _normalize_arch(
        self,
        architecture: str,
        model_config: ModelConfig,
    ) -> str:
        if architecture in self.models:
            return architecture

        # This may be called in order to resolve runner_type and convert_type
        # in the first place, in which case we consider the default match
        match = try_match_architecture_defaults(
            architecture,
            runner_type=getattr(model_config, "runner_type", None),
            convert_type=getattr(model_config, "convert_type", None),
        )
        if match:
            suffix, _ = match

            # Get the name of the base model to convert
            for repl_suffix, _ in iter_architecture_defaults():
                base_arch = architecture.replace(suffix, repl_suffix)
                if base_arch in self.models:
                    return base_arch

        return architecture

    def _resolve_with(
        self,
        architectures: str | list[str],
        model_config: ModelConfig,
        lookup: Callable[[str], Any],
    ) -> tuple[Any, str]:
        """Resolve `architectures` to `(lookup(resolved_arch), arch)`.

        Shared by `inspect_model_cls` and `resolve_model_cls`: `lookup` is
        `self._try_inspect_model_cls` or `self._try_load_model_cls` and
        returns None when an architecture is unknown or fails.

        Resolution order:
        1. An explicitly requested implementation (`model_impl` of
           "transformers" or "terratorch").
        2. The Transformers backend, if no architecture is registered,
           `model_impl == "auto"` and no `convert_type` is requested.
        3. Each architecture in order, after normalizing converted names.
        4. The Transformers backend, if no architecture is registered and
           `model_impl == "auto"`.

        Raises:
            ValueError: If no architecture is given or none can be resolved.
        """
        if isinstance(architectures, str):
            architectures = [architectures]

        if not architectures:
            raise ValueError("No model architectures are specified")

        def try_transformers() -> tuple[Any, str] | None:
            arch = self._try_resolve_transformers(architectures[0], model_config)
            if arch is None:
                return None
            result = lookup(arch)
            return None if result is None else (result, arch)

        def none_registered() -> bool:
            return all(arch not in self.models for arch in architectures)

        # Require transformers impl
        if model_config.model_impl == "transformers":
            resolved = try_transformers()
            if resolved is not None:
                return resolved
        elif model_config.model_impl == "terratorch":
            # Fall through on failure instead of returning a None result.
            result = lookup("Terratorch")
            if result is not None:
                return (result, "Terratorch")

        # Fallback to transformers impl (after resolving convert_type)
        if (
            none_registered()
            and model_config.model_impl == "auto"
            and getattr(model_config, "convert_type", "none") == "none"
        ):
            resolved = try_transformers()
            if resolved is not None:
                return resolved

        for arch in architectures:
            normalized_arch = self._normalize_arch(arch, model_config)
            result = lookup(normalized_arch)
            if result is not None:
                return (result, arch)

        # Fallback to transformers impl (before resolving runner_type)
        if none_registered() and model_config.model_impl == "auto":
            resolved = try_transformers()
            if resolved is not None:
                return resolved

        return self._raise_for_unsupported(architectures)

    def inspect_model_cls(
        self,
        architectures: str | list[str],
        model_config: ModelConfig,
    ) -> tuple[_ModelInfo, str]:
        """Return `(model_info, arch)` for the first resolvable architecture.

        Raises:
            ValueError: If no architecture is given or none can be inspected.
        """
        return self._resolve_with(
            architectures, model_config, self._try_inspect_model_cls
        )

    def resolve_model_cls(
        self,
        architectures: str | list[str],
        model_config: ModelConfig,
    ) -> tuple[type[nn.Module], str]:
        """Return `(model_cls, arch)` for the first loadable architecture.

        Raises:
            ValueError: If no architecture is given or none can be loaded.
        """
        return self._resolve_with(
            architectures, model_config, self._try_load_model_cls
        )

    def is_text_generation_model(
        self,
        architectures: str | list[str],
        model_config: ModelConfig,
    ) -> bool:
        model_cls, _ = self.inspect_model_cls(architectures, model_config)
        return model_cls.is_text_generation_model

    def is_pooling_model(
        self,
        architectures: str | list[str],
        model_config: ModelConfig,
    ) -> bool:
        model_cls, _ = self.inspect_model_cls(architectures, model_config)
        return model_cls.is_pooling_model

    def is_cross_encoder_model(
        self,
        architectures: str | list[str],
        model_config: ModelConfig,
    ) -> bool:
        model_cls, _ = self.inspect_model_cls(architectures, model_config)
        return model_cls.supports_cross_encoding

    def is_multimodal_model(
        self,
        architectures: str | list[str],
        model_config: ModelConfig,
    ) -> bool:
        model_cls, _ = self.inspect_model_cls(architectures, model_config)
        return model_cls.supports_multimodal

    def is_multimodal_raw_input_only_model(
        self,
        architectures: str | list[str],
        model_config: ModelConfig,
    ) -> bool:
        model_cls, _ = self.inspect_model_cls(architectures, model_config)
        return model_cls.supports_multimodal_raw_input_only

    def is_pp_supported_model(
        self,
        architectures: str | list[str],
        model_config: ModelConfig,
    ) -> bool:
        model_cls, _ = self.inspect_model_cls(architectures, model_config)
        return model_cls.supports_pp

    def model_has_inner_state(
        self,
        architectures: str | list[str],
        model_config: ModelConfig,
    ) -> bool:
        model_cls, _ = self.inspect_model_cls(architectures, model_config)
        return model_cls.has_inner_state

    def is_attention_free_model(
        self,
        architectures: str | list[str],
        model_config: ModelConfig,
    ) -> bool:
        model_cls, _ = self.inspect_model_cls(architectures, model_config)
        return model_cls.is_attention_free

    def is_hybrid_model(
        self,
        architectures: str | list[str],
        model_config: ModelConfig,
    ) -> bool:
        model_cls, _ = self.inspect_model_cls(architectures, model_config)
        return model_cls.is_hybrid

    def is_noops_model(
        self,
        architectures: str | list[str],
        model_config: ModelConfig,
    ) -> bool:
        model_cls, _ = self.inspect_model_cls(architectures, model_config)
        return model_cls.has_noops

    def is_transcription_model(
        self,
        architectures: str | list[str],
        model_config: ModelConfig,
    ) -> bool:
        model_cls, _ = self.inspect_model_cls(architectures, model_config)
        return model_cls.supports_transcription

    def is_transcription_only_model(
        self,
        architectures: str | list[str],
        model_config: ModelConfig,
    ) -> bool:
        model_cls, _ = self.inspect_model_cls(architectures, model_config)
        return model_cls.supports_transcription_only

1142

1143
1144
1145
1146
1147
1148
1149
1150
1151
# Default registry: every built-in architecture is registered lazily, so no
# model module is imported until it is actually inspected or loaded.
ModelRegistry = _ModelRegistry(
    models={
        arch_name: _LazyRegisteredModel(
            module_name=f"vllm.model_executor.models.{module_relname}",
            class_name=model_cls_name,
        )
        for arch_name, (module_relname, model_cls_name) in _VLLM_MODELS.items()
    }
)
1152
1153
1154
1155
1156

_T = TypeVar("_T")


def _run_in_subprocess(fn: Callable[[], _T]) -> _T:
    """Run `fn` in a fresh Python subprocess and return its result.

    `fn` is serialized with `cloudpickle` (so lambdas work) and sent to the
    subprocess on stdin together with an output path; the subprocess writes
    the pickled result there and it is read back here.

    Raises:
        RuntimeError: If the subprocess exits with a non-zero status. The
            subprocess's stderr is included in the message.
    """
    # NOTE: We use a temporary directory instead of a temporary file to avoid
    # issues like https://stackoverflow.com/questions/23212435/permission-denied-to-write-to-my-temporary-file
    with tempfile.TemporaryDirectory() as tempdir:
        output_filepath = os.path.join(tempdir, "registry_output.tmp")

        # `cloudpickle` allows pickling lambda functions directly
        import cloudpickle

        input_bytes = cloudpickle.dumps((fn, output_filepath))

        # cannot use `sys.executable __file__` here because the script
        # contains relative imports
        returned = subprocess.run(
            _SUBPROCESS_COMMAND, input=input_bytes, capture_output=True
        )

        # check if the subprocess is successful
        try:
            returned.check_returncode()
        except subprocess.CalledProcessError as e:
            # Wrap the error to surface the subprocess's stderr. Decode
            # leniently so undecodable bytes cannot mask the real failure
            # with a UnicodeDecodeError.
            stderr = returned.stderr.decode(errors="replace")
            raise RuntimeError(f"Error raised in subprocess:\n{stderr}") from e

        # The output file is written by our own subprocess (see `_run`).
        with open(output_filepath, "rb") as f:
            return pickle.load(f)


def _run() -> None:
    """Subprocess entry point used by `_run_in_subprocess`.

    Loads general plugins, reads a pickled `(fn, output_file)` pair from
    stdin, calls `fn`, and writes its pickled return value to `output_file`.
    """
    # Setup plugins
    from vllm.plugins import load_general_plugins

    load_general_plugins()

    payload = sys.stdin.buffer.read()
    fn, output_file = pickle.loads(payload)

    # Compute before opening, so a failing `fn` leaves no output file behind.
    outcome = fn()
    serialized = pickle.dumps(outcome)

    with open(output_file, "wb") as out:
        out.write(serialized)


if __name__ == "__main__":
    _run()