registry.py 47.1 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
4
5
6
"""
Whenever you add an architecture to this file, please also update
`tests/models/registry.py` with example HuggingFace models for it.
"""
7

8
import importlib
9
import json
10
import os
11
import pickle
12
13
import subprocess
import sys
14
import tempfile
15
from abc import ABC, abstractmethod
16
from collections.abc import Callable, Set
17
from dataclasses import asdict, dataclass, field
18
from functools import lru_cache
19
from pathlib import Path
20
from typing import TYPE_CHECKING, Any, TypeVar
21
22

import torch.nn as nn
23
import transformers
24

25
from vllm import envs
26
27
28
29
30
from vllm.config import (
    ModelConfig,
    iter_architecture_defaults,
    try_match_architecture_defaults,
)
31
from vllm.logger import init_logger
32
from vllm.logging_utils import logtime
33
from vllm.transformers_utils.dynamic_module import try_get_class_from_dynamic_module
34
from vllm.utils.hashing import safe_hash
35

36
37
if TYPE_CHECKING:
    from vllm.config.model import AttnTypeStr
    from vllm.config.pooler import SequencePoolingType, TokenPoolingType
else:
    # At runtime these aliases fall back to ``Any``; the real definitions are
    # only imported by static type checkers.
    AttnTypeStr = Any
    SequencePoolingType = Any
    TokenPoolingType = Any
43
44


45
46
47
48
49
from .interfaces import (
    has_inner_state,
    has_noops,
    is_attention_free,
    is_hybrid,
Patrick von Platen's avatar
Patrick von Platen committed
50
    requires_raw_input_tokens,
51
    supports_cross_encoding,
52
    supports_mamba_prefix_caching,
53
54
55
56
57
58
59
    supports_multimodal,
    supports_multimodal_encoder_tp_data,
    supports_multimodal_raw_input_only,
    supports_pp,
    supports_transcription,
)
from .interfaces_base import (
60
    get_attn_type,
61
62
    get_default_seq_pooling_type,
    get_default_tok_pooling_type,
63
64
65
    is_pooling_model,
    is_text_generation_model,
)
66
67
68

# Module-level logger, named after this module.
logger = init_logger(__name__)

69
70
# Text-generation architectures implemented natively in vLLM.
# Each entry maps an architecture name to the (module name, class name)
# of its vLLM implementation; several aliases may share one implementation.
_TEXT_GENERATION_MODELS = {
    # [Decoder-only]
    "AfmoeForCausalLM": ("afmoe", "AfmoeForCausalLM"),
    "ApertusForCausalLM": ("apertus", "ApertusForCausalLM"),
    "AquilaModel": ("llama", "LlamaForCausalLM"),
    "AquilaForCausalLM": ("llama", "LlamaForCausalLM"),  # AquilaChat2
    "ArceeForCausalLM": ("arcee", "ArceeForCausalLM"),
    "ArcticForCausalLM": ("arctic", "ArcticForCausalLM"),
    # baichuan-7b, upper case 'C' in the class name
    "BaiChuanForCausalLM": ("baichuan", "BaiChuanForCausalLM"),
    # baichuan-13b, lower case 'c' in the class name
    "BaichuanForCausalLM": ("baichuan", "BaichuanForCausalLM"),
    "BailingMoeForCausalLM": ("bailing_moe", "BailingMoeForCausalLM"),
    "BailingMoeV2ForCausalLM": ("bailing_moe", "BailingMoeV2ForCausalLM"),
    "BambaForCausalLM": ("bamba", "BambaForCausalLM"),
    "BloomForCausalLM": ("bloom", "BloomForCausalLM"),
    "ChatGLMModel": ("chatglm", "ChatGLMForCausalLM"),
    "ChatGLMForConditionalGeneration": ("chatglm", "ChatGLMForCausalLM"),
    "CohereForCausalLM": ("commandr", "CohereForCausalLM"),
    "Cohere2ForCausalLM": ("commandr", "CohereForCausalLM"),
    "CwmForCausalLM": ("llama", "LlamaForCausalLM"),
    "DbrxForCausalLM": ("dbrx", "DbrxForCausalLM"),
    "DeciLMForCausalLM": ("nemotron_nas", "DeciLMForCausalLM"),
    "DeepseekForCausalLM": ("deepseek_v2", "DeepseekForCausalLM"),
    "DeepseekV2ForCausalLM": ("deepseek_v2", "DeepseekV2ForCausalLM"),
    "DeepseekV3ForCausalLM": ("deepseek_v2", "DeepseekV3ForCausalLM"),
    "DeepseekV32ForCausalLM": ("deepseek_v2", "DeepseekV3ForCausalLM"),
    "Dots1ForCausalLM": ("dots1", "Dots1ForCausalLM"),
    "Ernie4_5ForCausalLM": ("ernie45", "Ernie4_5ForCausalLM"),
    "Ernie4_5_MoeForCausalLM": ("ernie45_moe", "Ernie4_5_MoeForCausalLM"),
    "ExaoneForCausalLM": ("exaone", "ExaoneForCausalLM"),
    "Exaone4ForCausalLM": ("exaone4", "Exaone4ForCausalLM"),
    "ExaoneMoEForCausalLM": ("exaone_moe", "ExaoneMoeForCausalLM"),
    "Fairseq2LlamaForCausalLM": ("fairseq2_llama", "Fairseq2LlamaForCausalLM"),
    "FalconForCausalLM": ("falcon", "FalconForCausalLM"),
    "FalconMambaForCausalLM": ("mamba", "MambaForCausalLM"),
    "FalconH1ForCausalLM": ("falcon_h1", "FalconH1ForCausalLM"),
    "FlexOlmoForCausalLM": ("flex_olmo", "FlexOlmoForCausalLM"),
    "GemmaForCausalLM": ("gemma", "GemmaForCausalLM"),
    "Gemma2ForCausalLM": ("gemma2", "Gemma2ForCausalLM"),
    "Gemma3ForCausalLM": ("gemma3", "Gemma3ForCausalLM"),
    "Gemma3nForCausalLM": ("gemma3n", "Gemma3nForCausalLM"),
    "Qwen3NextForCausalLM": ("qwen3_next", "Qwen3NextForCausalLM"),
    "GlmForCausalLM": ("glm", "GlmForCausalLM"),
    "Glm4ForCausalLM": ("glm4", "Glm4ForCausalLM"),
    "Glm4MoeForCausalLM": ("glm4_moe", "Glm4MoeForCausalLM"),
    "GptOssForCausalLM": ("gpt_oss", "GptOssForCausalLM"),
    "GPT2LMHeadModel": ("gpt2", "GPT2LMHeadModel"),
    "GPTBigCodeForCausalLM": ("gpt_bigcode", "GPTBigCodeForCausalLM"),
    "GPTJForCausalLM": ("gpt_j", "GPTJForCausalLM"),
    "GPTNeoXForCausalLM": ("gpt_neox", "GPTNeoXForCausalLM"),
    "GraniteForCausalLM": ("granite", "GraniteForCausalLM"),
    "GraniteMoeForCausalLM": ("granitemoe", "GraniteMoeForCausalLM"),
    "GraniteMoeHybridForCausalLM": ("granitemoehybrid", "GraniteMoeHybridForCausalLM"),  # noqa: E501
    "GraniteMoeSharedForCausalLM": ("granitemoeshared", "GraniteMoeSharedForCausalLM"),  # noqa: E501
    "GritLM": ("gritlm", "GritLM"),
    "Grok1ModelForCausalLM": ("grok1", "GrokForCausalLM"),
    "Grok1ForCausalLM": ("grok1", "GrokForCausalLM"),
    "HunYuanMoEV1ForCausalLM": ("hunyuan_v1", "HunYuanMoEV1ForCausalLM"),
    "HunYuanDenseV1ForCausalLM": ("hunyuan_v1", "HunYuanDenseV1ForCausalLM"),
    "HCXVisionForCausalLM": ("hyperclovax_vision", "HCXVisionForCausalLM"),
    "InternLMForCausalLM": ("llama", "LlamaForCausalLM"),
    "InternLM2ForCausalLM": ("internlm2", "InternLM2ForCausalLM"),
    "InternLM2VEForCausalLM": ("internlm2_ve", "InternLM2VEForCausalLM"),
    "InternLM3ForCausalLM": ("llama", "LlamaForCausalLM"),
    "IQuestCoderForCausalLM": ("llama", "LlamaForCausalLM"),
    "IQuestLoopCoderForCausalLM": ("iquest_loopcoder", "IQuestLoopCoderForCausalLM"),
    "JAISLMHeadModel": ("jais", "JAISLMHeadModel"),
    "Jais2ForCausalLM": ("jais2", "Jais2ForCausalLM"),
    "JambaForCausalLM": ("jamba", "JambaForCausalLM"),
    "KimiLinearForCausalLM": ("kimi_linear", "KimiLinearForCausalLM"),  # noqa: E501
    "Lfm2ForCausalLM": ("lfm2", "Lfm2ForCausalLM"),
    "Lfm2MoeForCausalLM": ("lfm2_moe", "Lfm2MoeForCausalLM"),
    "LlamaForCausalLM": ("llama", "LlamaForCausalLM"),
    "Llama4ForCausalLM": ("llama4", "Llama4ForCausalLM"),
    # For decapoda-research/llama-*
    "LLaMAForCausalLM": ("llama", "LlamaForCausalLM"),
    "LongcatFlashForCausalLM": ("longcat_flash", "LongcatFlashForCausalLM"),
    "MambaForCausalLM": ("mamba", "MambaForCausalLM"),
    "Mamba2ForCausalLM": ("mamba2", "Mamba2ForCausalLM"),
    "MiniCPMForCausalLM": ("minicpm", "MiniCPMForCausalLM"),
    "MiniCPM3ForCausalLM": ("minicpm3", "MiniCPM3ForCausalLM"),
    "MiniMaxForCausalLM": ("minimax_text_01", "MiniMaxText01ForCausalLM"),
    "MiniMaxText01ForCausalLM": ("minimax_text_01", "MiniMaxText01ForCausalLM"),
    "MiniMaxM1ForCausalLM": ("minimax_text_01", "MiniMaxText01ForCausalLM"),
    "MiniMaxM2ForCausalLM": ("minimax_m2", "MiniMaxM2ForCausalLM"),
    "MistralForCausalLM": ("llama", "LlamaForCausalLM"),
    "MistralLarge3ForCausalLM": ("mistral_large_3", "MistralLarge3ForCausalLM"),
    "MixtralForCausalLM": ("mixtral", "MixtralForCausalLM"),
    # transformers's mpt class has lower case
    "MptForCausalLM": ("mpt", "MPTForCausalLM"),
    "MPTForCausalLM": ("mpt", "MPTForCausalLM"),
    "MiMoForCausalLM": ("mimo", "MiMoForCausalLM"),
    "MiMoV2FlashForCausalLM": ("mimo_v2_flash", "MiMoV2FlashForCausalLM"),
    "NemotronForCausalLM": ("nemotron", "NemotronForCausalLM"),
    "NemotronHForCausalLM": ("nemotron_h", "NemotronHForCausalLM"),
    "OlmoForCausalLM": ("olmo", "OlmoForCausalLM"),
    "Olmo2ForCausalLM": ("olmo2", "Olmo2ForCausalLM"),
    "Olmo3ForCausalLM": ("olmo2", "Olmo2ForCausalLM"),
    "OlmoeForCausalLM": ("olmoe", "OlmoeForCausalLM"),
    "OPTForCausalLM": ("opt", "OPTForCausalLM"),
    "OrionForCausalLM": ("orion", "OrionForCausalLM"),
    "OuroForCausalLM": ("ouro", "OuroForCausalLM"),
    "PanguEmbeddedForCausalLM": ("openpangu", "PanguEmbeddedForCausalLM"),
    "PanguProMoEV2ForCausalLM": ("openpangu", "PanguProMoEV2ForCausalLM"),
    "PanguUltraMoEForCausalLM": ("openpangu", "PanguUltraMoEForCausalLM"),
    "PersimmonForCausalLM": ("persimmon", "PersimmonForCausalLM"),
    "PhiForCausalLM": ("phi", "PhiForCausalLM"),
    "Phi3ForCausalLM": ("phi3", "Phi3ForCausalLM"),
    "PhiMoEForCausalLM": ("phimoe", "PhiMoEForCausalLM"),
    "Plamo2ForCausalLM": ("plamo2", "Plamo2ForCausalLM"),
    "Plamo3ForCausalLM": ("plamo3", "Plamo3ForCausalLM"),
    "QWenLMHeadModel": ("qwen", "QWenLMHeadModel"),
    "Qwen2ForCausalLM": ("qwen2", "Qwen2ForCausalLM"),
    "Qwen2MoeForCausalLM": ("qwen2_moe", "Qwen2MoeForCausalLM"),
    "Qwen3ForCausalLM": ("qwen3", "Qwen3ForCausalLM"),
    "Qwen3MoeForCausalLM": ("qwen3_moe", "Qwen3MoeForCausalLM"),
    "RWForCausalLM": ("falcon", "FalconForCausalLM"),
    "SeedOssForCausalLM": ("seed_oss", "SeedOssForCausalLM"),
    "Step3TextForCausalLM": ("step3_text", "Step3TextForCausalLM"),
    "StableLMEpochForCausalLM": ("stablelm", "StablelmForCausalLM"),
    "StableLmForCausalLM": ("stablelm", "StablelmForCausalLM"),
    "Starcoder2ForCausalLM": ("starcoder2", "Starcoder2ForCausalLM"),
    "SolarForCausalLM": ("solar", "SolarForCausalLM"),
    "TeleChatForCausalLM": ("telechat2", "TeleChat2ForCausalLM"),
    "TeleChat2ForCausalLM": ("telechat2", "TeleChat2ForCausalLM"),
    "TeleFLMForCausalLM": ("teleflm", "TeleFLMForCausalLM"),
    "XverseForCausalLM": ("llama", "LlamaForCausalLM"),
    "Zamba2ForCausalLM": ("zamba2", "Zamba2ForCausalLM"),
}

# Architectures served as embedding / pooling models.
# Each entry maps an architecture name to the (module name, class name)
# of its vLLM implementation.
_EMBEDDING_MODELS = {
    # [Text-only]
    "BertModel": ("bert", "BertEmbeddingModel"),
    "BertSpladeSparseEmbeddingModel": ("bert", "BertSpladeSparseEmbeddingModel"),
    "DeciLMForCausalLM": ("nemotron_nas", "DeciLMForCausalLM"),
    "Gemma2Model": ("gemma2", "Gemma2ForCausalLM"),
    "Gemma3TextModel": ("gemma3", "Gemma3Model"),
    "GlmForCausalLM": ("glm", "GlmForCausalLM"),
    "GPT2ForSequenceClassification": ("gpt2", "GPT2ForSequenceClassification"),
    "GritLM": ("gritlm", "GritLM"),
    "GteModel": ("bert_with_rope", "SnowflakeGteNewModel"),
    "GteNewModel": ("bert_with_rope", "GteNewModel"),
    "InternLM2ForRewardModel": ("internlm2", "InternLM2ForRewardModel"),
    "JambaForSequenceClassification": ("jamba", "JambaForSequenceClassification"),  # noqa: E501
    "LlamaBidirectionalModel": ("llama", "LlamaBidirectionalModel"),
    "LlamaModel": ("llama", "LlamaForCausalLM"),
    **{
        # Multiple models share the same architecture, so we include them all
        k: (mod, arch)
        for k, (mod, arch) in _TEXT_GENERATION_MODELS.items()
        if arch == "LlamaForCausalLM"
    },
    "MistralModel": ("llama", "LlamaForCausalLM"),
    "ModernBertModel": ("modernbert", "ModernBertModel"),
    "NomicBertModel": ("bert_with_rope", "NomicBertModel"),
    "Phi3ForCausalLM": ("phi3", "Phi3ForCausalLM"),
    "Qwen2Model": ("qwen2", "Qwen2ForCausalLM"),
    "Qwen2ForCausalLM": ("qwen2", "Qwen2ForCausalLM"),
    "Qwen2ForRewardModel": ("qwen2_rm", "Qwen2ForRewardModel"),
    "Qwen2ForProcessRewardModel": ("qwen2_rm", "Qwen2ForProcessRewardModel"),
    "RobertaForMaskedLM": ("roberta", "RobertaEmbeddingModel"),
    "RobertaModel": ("roberta", "RobertaEmbeddingModel"),
    "TeleChatForCausalLM": ("telechat2", "TeleChat2ForCausalLM"),
    "TeleChat2ForCausalLM": ("telechat2", "TeleChat2ForCausalLM"),
    "XLMRobertaModel": ("roberta", "RobertaEmbeddingModel"),
    # [Multimodal]
    "CLIPModel": ("clip", "CLIPEmbeddingModel"),
    "LlavaNextForConditionalGeneration": (
        "llava_next",
        "LlavaNextForConditionalGeneration",
    ),
    "Phi3VForCausalLM": ("phi3v", "Phi3VForCausalLM"),
    "Qwen2VLForConditionalGeneration": ("qwen2_vl", "Qwen2VLForConditionalGeneration"),  # noqa: E501
    "SiglipModel": ("siglip", "SiglipEmbeddingModel"),
    # Technically Terratorch models work on images, both in
    # input and output. I am adding it here because it piggy-backs on embedding
    # models for the time being.
    "PrithviGeoSpatialMAE": ("terratorch", "Terratorch"),
    "Terratorch": ("terratorch", "Terratorch"),
}

251
252
# Architectures served as cross-encoder / classification models.
# Each entry maps an architecture name to the (module name, class name)
# of its vLLM implementation.
_CROSS_ENCODER_MODELS = {
    "BertForSequenceClassification": ("bert", "BertForSequenceClassification"),
    "BertForTokenClassification": ("bert", "BertForTokenClassification"),
    "GteNewForSequenceClassification": (
        "bert_with_rope",
        "GteNewForSequenceClassification",
    ),
    "JinaVLForRanking": ("jina_vl", "JinaVLForSequenceClassification"),
    "LlamaBidirectionalForSequenceClassification": (
        "llama",
        "LlamaBidirectionalForSequenceClassification",
    ),
    "ModernBertForSequenceClassification": (
        "modernbert",
        "ModernBertForSequenceClassification",
    ),
    "ModernBertForTokenClassification": (
        "modernbert",
        "ModernBertForTokenClassification",
    ),
    "RobertaForSequenceClassification": ("roberta", "RobertaForSequenceClassification"),
    "XLMRobertaForSequenceClassification": (
        "roberta",
        "RobertaForSequenceClassification",
    ),
}

278
# Multimodal architectures implemented natively in vLLM.
# Each entry maps an architecture name to the (module name, class name)
# of its vLLM implementation.
_MULTIMODAL_MODELS = {
    # [Decoder-only]
    "AriaForConditionalGeneration": ("aria", "AriaForConditionalGeneration"),
    "AudioFlamingo3ForConditionalGeneration": (
        "audioflamingo3",
        "AudioFlamingo3ForConditionalGeneration",
    ),
    "AyaVisionForConditionalGeneration": (
        "aya_vision",
        "AyaVisionForConditionalGeneration",
    ),
    "BagelForConditionalGeneration": ("bagel", "BagelForConditionalGeneration"),
    "BeeForConditionalGeneration": ("bee", "BeeForConditionalGeneration"),
    "Blip2ForConditionalGeneration": ("blip2", "Blip2ForConditionalGeneration"),
    "ChameleonForConditionalGeneration": (
        "chameleon",
        "ChameleonForConditionalGeneration",
    ),
    "Cohere2VisionForConditionalGeneration": (
        "cohere2_vision",
        "Cohere2VisionForConditionalGeneration",
    ),
    "DeepseekVLV2ForCausalLM": ("deepseek_vl2", "DeepseekVLV2ForCausalLM"),
    "DeepseekOCRForCausalLM": ("deepseek_ocr", "DeepseekOCRForCausalLM"),
    "DotsOCRForCausalLM": ("dots_ocr", "DotsOCRForCausalLM"),
    "Ernie4_5_VLMoeForConditionalGeneration": (
        "ernie45_vl",
        "Ernie4_5_VLMoeForConditionalGeneration",
    ),
    "FuyuForCausalLM": ("fuyu", "FuyuForCausalLM"),
    "Gemma3ForConditionalGeneration": ("gemma3_mm", "Gemma3ForConditionalGeneration"),  # noqa: E501
    "Gemma3nForConditionalGeneration": (
        "gemma3n_mm",
        "Gemma3nForConditionalGeneration",
    ),
    "GlmAsrForConditionalGeneration": ("glmasr", "GlmAsrForConditionalGeneration"),
    "GLM4VForCausalLM": ("glm4v", "GLM4VForCausalLM"),
    "Glm4vForConditionalGeneration": ("glm4_1v", "Glm4vForConditionalGeneration"),  # noqa: E501
    "Glm4vMoeForConditionalGeneration": ("glm4_1v", "Glm4vMoeForConditionalGeneration"),  # noqa: E501
    "GraniteSpeechForConditionalGeneration": (
        "granite_speech",
        "GraniteSpeechForConditionalGeneration",
    ),
    "H2OVLChatModel": ("h2ovl", "H2OVLChatModel"),
    "HunYuanVLForConditionalGeneration": (
        "hunyuan_vision",
        "HunYuanVLForConditionalGeneration",
    ),
    "InternVLChatModel": ("internvl", "InternVLChatModel"),
    "NemotronH_Nano_VL_V2": ("nano_nemotron_vl", "NemotronH_Nano_VL_V2"),
    "OpenCUAForConditionalGeneration": (
        "opencua",
        "OpenCUAForConditionalGeneration",
    ),
    "InternS1ForConditionalGeneration": (
        "interns1",
        "InternS1ForConditionalGeneration",
    ),
    "InternVLForConditionalGeneration": (
        "interns1",
        "InternS1ForConditionalGeneration",
    ),
    "Idefics3ForConditionalGeneration": (
        "idefics3",
        "Idefics3ForConditionalGeneration",
    ),
    "IsaacForConditionalGeneration": ("isaac", "IsaacForConditionalGeneration"),
    "SmolVLMForConditionalGeneration": ("smolvlm", "SmolVLMForConditionalGeneration"),  # noqa: E501
    "KananaVForConditionalGeneration": ("kanana_v", "KananaVForConditionalGeneration"),
    "KeyeForConditionalGeneration": ("keye", "KeyeForConditionalGeneration"),
    "KeyeVL1_5ForConditionalGeneration": (
        "keye_vl1_5",
        "KeyeVL1_5ForConditionalGeneration",
    ),
    "RForConditionalGeneration": ("rvl", "RForConditionalGeneration"),
    "KimiVLForConditionalGeneration": ("kimi_vl", "KimiVLForConditionalGeneration"),  # noqa: E501
    "LightOnOCRForConditionalGeneration": (
        "lightonocr",
        "LightOnOCRForConditionalGeneration",
    ),
    "Lfm2VlForConditionalGeneration": ("lfm2_vl", "Lfm2VLForConditionalGeneration"),
    "Llama_Nemotron_Nano_VL": ("nemotron_vl", "LlamaNemotronVLChatModel"),
    "Llama4ForConditionalGeneration": ("mllama4", "Llama4ForConditionalGeneration"),  # noqa: E501
    "LlavaForConditionalGeneration": ("llava", "LlavaForConditionalGeneration"),
    "LlavaNextForConditionalGeneration": (
        "llava_next",
        "LlavaNextForConditionalGeneration",
    ),
    "LlavaNextVideoForConditionalGeneration": (
        "llava_next_video",
        "LlavaNextVideoForConditionalGeneration",
    ),
    "LlavaOnevisionForConditionalGeneration": (
        "llava_onevision",
        "LlavaOnevisionForConditionalGeneration",
    ),
    "MantisForConditionalGeneration": ("llava", "MantisForConditionalGeneration"),  # noqa: E501
    "MiDashengLMModel": ("midashenglm", "MiDashengLMModel"),
    "MiniMaxVL01ForConditionalGeneration": (
        "minimax_vl_01",
        "MiniMaxVL01ForConditionalGeneration",
    ),
    "MiniCPMO": ("minicpmo", "MiniCPMO"),
    "MiniCPMV": ("minicpmv", "MiniCPMV"),
    "Mistral3ForConditionalGeneration": (
        "mistral3",
        "Mistral3ForConditionalGeneration",
    ),
    "MolmoForCausalLM": ("molmo", "MolmoForCausalLM"),
    "NVLM_D": ("nvlm_d", "NVLM_D_Model"),
    "Ovis": ("ovis", "Ovis"),
    "Ovis2_5": ("ovis2_5", "Ovis2_5"),
    "PaddleOCRVLForConditionalGeneration": (
        "paddleocr_vl",
        "PaddleOCRVLForConditionalGeneration",
    ),
    "PaliGemmaForConditionalGeneration": (
        "paligemma",
        "PaliGemmaForConditionalGeneration",
    ),
    "Phi3VForCausalLM": ("phi3v", "Phi3VForCausalLM"),
    "Phi4MMForCausalLM": ("phi4mm", "Phi4MMForCausalLM"),
    "PixtralForConditionalGeneration": ("pixtral", "PixtralForConditionalGeneration"),  # noqa: E501
    "QwenVLForConditionalGeneration": ("qwen_vl", "QwenVLForConditionalGeneration"),  # noqa: E501
    "Qwen2VLForConditionalGeneration": ("qwen2_vl", "Qwen2VLForConditionalGeneration"),  # noqa: E501
    "Qwen2_5_VLForConditionalGeneration": (
        "qwen2_5_vl",
        "Qwen2_5_VLForConditionalGeneration",
    ),
    "Qwen2AudioForConditionalGeneration": (
        "qwen2_audio",
        "Qwen2AudioForConditionalGeneration",
    ),
    "Qwen2_5OmniModel": (
        "qwen2_5_omni_thinker",
        "Qwen2_5OmniThinkerForConditionalGeneration",
    ),
    "Qwen2_5OmniForConditionalGeneration": (
        "qwen2_5_omni_thinker",
        "Qwen2_5OmniThinkerForConditionalGeneration",
    ),
    "Qwen3OmniMoeForConditionalGeneration": (
        "qwen3_omni_moe_thinker",
        "Qwen3OmniMoeThinkerForConditionalGeneration",
    ),
    "Qwen3VLForConditionalGeneration": ("qwen3_vl", "Qwen3VLForConditionalGeneration"),  # noqa: E501
    "Qwen3VLMoeForConditionalGeneration": (
        "qwen3_vl_moe",
        "Qwen3VLMoeForConditionalGeneration",
    ),
    "SkyworkR1VChatModel": ("skyworkr1v", "SkyworkR1VChatModel"),
    "Step3VLForConditionalGeneration": ("step3_vl", "Step3VLForConditionalGeneration"),  # noqa: E501
    "TarsierForConditionalGeneration": ("tarsier", "TarsierForConditionalGeneration"),  # noqa: E501
    "Tarsier2ForConditionalGeneration": (
        "qwen2_vl",
        "Tarsier2ForConditionalGeneration",
    ),
    "UltravoxModel": ("ultravox", "UltravoxModel"),
    "VoxtralForConditionalGeneration": ("voxtral", "VoxtralForConditionalGeneration"),  # noqa: E501
    "VoxtralStreamingGeneration": ("voxtral_streaming", "VoxtralStreamingGeneration"),  # noqa: E501
    # [Encoder-decoder]
    "NemotronParseForConditionalGeneration": (
        "nemotron_parse",
        "NemotronParseForConditionalGeneration",
    ),
    "WhisperForConditionalGeneration": ("whisper", "WhisperForConditionalGeneration"),  # noqa: E501
}
445
446

# Draft / auxiliary models used for speculative decoding (EAGLE, MTP, Medusa).
# Each entry maps an architecture name to the (module name, class name)
# of its vLLM implementation.
_SPECULATIVE_DECODING_MODELS = {
    "MiMoMTPModel": ("mimo_mtp", "MiMoMTP"),
    "EagleLlamaForCausalLM": ("llama_eagle", "EagleLlamaForCausalLM"),
    "EagleLlama4ForCausalLM": ("llama4_eagle", "EagleLlama4ForCausalLM"),
    "EagleMiniCPMForCausalLM": ("minicpm_eagle", "EagleMiniCPMForCausalLM"),
    "Eagle3LlamaForCausalLM": ("llama_eagle3", "Eagle3LlamaForCausalLM"),
    "LlamaForCausalLMEagle3": ("llama_eagle3", "Eagle3LlamaForCausalLM"),
    "Eagle3Qwen2_5vlForCausalLM": ("llama_eagle3", "Eagle3LlamaForCausalLM"),
    "Eagle3Qwen3vlForCausalLM": ("llama_eagle3", "Eagle3LlamaForCausalLM"),
    "EagleMistralLarge3ForCausalLM": (
        "mistral_large_3_eagle",
        "EagleMistralLarge3ForCausalLM",
    ),
    "EagleDeepSeekMTPModel": ("deepseek_eagle", "EagleDeepseekV3ForCausalLM"),
    "DeepSeekMTPModel": ("deepseek_mtp", "DeepSeekMTP"),
    "ErnieMTPModel": ("ernie_mtp", "ErnieMTP"),
    "ExaoneMoeMTP": ("exaone_moe_mtp", "ExaoneMoeMTP"),
    "LongCatFlashMTPModel": ("longcat_flash_mtp", "LongCatFlashMTP"),
    "Glm4MoeMTPModel": ("glm4_moe_mtp", "Glm4MoeMTP"),
    "MedusaModel": ("medusa", "Medusa"),
    "OpenPanguMTPModel": ("openpangu_mtp", "OpenPanguMTP"),
    "Qwen3NextMTP": ("qwen3_next_mtp", "Qwen3NextMTP"),
    # Temporarily disabled.
    # # TODO(woosuk): Re-enable this once the MLP Speculator is supported in V1.
    # "MLPSpeculatorPreTrainedModel": ("mlp_speculator", "MLPSpeculator"),
}
472

473
# Architectures with no native vLLM implementation that are routed to the
# Transformers backend classes listed in the values.
_TRANSFORMERS_SUPPORTED_MODELS = {
    # Text generation models
    "SmolLM3ForCausalLM": ("transformers", "TransformersForCausalLM"),
    # Multimodal models
    "Emu3ForConditionalGeneration": (
        "transformers",
        "TransformersMultiModalForCausalLM",
    ),
}

# The Transformers backend classes themselves; every entry maps a class name to
# the class of the same name in the ``transformers`` module.
_TRANSFORMERS_BACKEND_MODELS = {
    # Text generation models
    "TransformersForCausalLM": ("transformers", "TransformersForCausalLM"),
    "TransformersMoEForCausalLM": ("transformers", "TransformersMoEForCausalLM"),
    # Multimodal models
    "TransformersMultiModalForCausalLM": (
        "transformers",
        "TransformersMultiModalForCausalLM",
    ),
    "TransformersMultiModalMoEForCausalLM": (
        "transformers",
        "TransformersMultiModalMoEForCausalLM",
    ),
    # Embedding models
    "TransformersEmbeddingModel": ("transformers", "TransformersEmbeddingModel"),
    "TransformersMoEEmbeddingModel": ("transformers", "TransformersMoEEmbeddingModel"),
    "TransformersMultiModalEmbeddingModel": (
        "transformers",
        "TransformersMultiModalEmbeddingModel",
    ),
    # Sequence classification models
    "TransformersForSequenceClassification": (
        "transformers",
        "TransformersForSequenceClassification",
    ),
    "TransformersMoEForSequenceClassification": (
        "transformers",
        "TransformersMoEForSequenceClassification",
    ),
    "TransformersMultiModalForSequenceClassification": (
        "transformers",
        "TransformersMultiModalForSequenceClassification",
    ),
}
517

518
# Union of all registries above. On duplicate keys, the entry from the
# dictionary unpacked later wins (standard ``{**a, **b}`` semantics).
_VLLM_MODELS = {
    **_TEXT_GENERATION_MODELS,
    **_EMBEDDING_MODELS,
    **_CROSS_ENCODER_MODELS,
    **_MULTIMODAL_MODELS,
    **_SPECULATIVE_DECODING_MODELS,
    **_TRANSFORMERS_SUPPORTED_MODELS,
    **_TRANSFORMERS_BACKEND_MODELS,
}

528
529
530
531
# This variable is used as the args for subprocess.run(). We
# can modify this variable to alter the args if needed. e.g.
# when we use par format to pack things together, sys.executable
# might not be the target we want to run.
_SUBPROCESS_COMMAND = [sys.executable, "-m", "vllm.model_executor.models.registry"]
533

534
# Architectures that are no longer supported, mapped to a vLLM version string
# for each (presumably the last release that supported it — confirm against
# the code that reports these).
_PREVIOUSLY_SUPPORTED_MODELS = {
    "MotifForCausalLM": "0.10.2",
    "Phi3SmallForCausalLM": "0.9.2",
    "Phi4FlashForCausalLM": "0.10.2",
    "Phi4MultimodalForCausalLM": "0.12.0",
    # encoder-decoder models except whisper
    # have been removed for V0 deprecation.
    "BartModel": "0.10.2",
    "BartForConditionalGeneration": "0.10.2",
    "DonutForConditionalGeneration": "0.10.2",
    "Florence2ForConditionalGeneration": "0.10.2",
    "MBartForConditionalGeneration": "0.10.2",
    "MllamaForConditionalGeneration": "0.10.2",
}
548

549

550
551
@dataclass(frozen=True)
class _ModelInfo:
    """
    Immutable snapshot of a model class's capabilities.

    Every field is derived from the class itself by the interface helpers
    called in `from_model_cls`.
    """

    architecture: str
    is_text_generation_model: bool
    is_pooling_model: bool
    attn_type: AttnTypeStr
    default_seq_pooling_type: SequencePoolingType
    default_tok_pooling_type: TokenPoolingType
    supports_cross_encoding: bool
    supports_multimodal: bool
    supports_multimodal_raw_input_only: bool
    requires_raw_input_tokens: bool
    supports_multimodal_encoder_tp_data: bool
    supports_pp: bool
    has_inner_state: bool
    is_attention_free: bool
    is_hybrid: bool
    has_noops: bool
    supports_mamba_prefix_caching: bool
    supports_transcription: bool
    supports_transcription_only: bool

    @staticmethod
    def from_model_cls(model: type[nn.Module]) -> "_ModelInfo":
        """
        Build a `_ModelInfo` by inspecting `model`.

        Args:
            model: The model class (not an instance) to inspect.

        Returns:
            A `_ModelInfo` whose ``architecture`` is ``model.__name__``.
        """
        # Evaluated once and reused for both transcription-related fields.
        transcription = supports_transcription(model)
        return _ModelInfo(
            architecture=model.__name__,
            is_text_generation_model=is_text_generation_model(model),
            is_pooling_model=is_pooling_model(model),
            default_seq_pooling_type=get_default_seq_pooling_type(model),
            default_tok_pooling_type=get_default_tok_pooling_type(model),
            attn_type=get_attn_type(model),
            supports_cross_encoding=supports_cross_encoding(model),
            supports_multimodal=supports_multimodal(model),
            supports_multimodal_raw_input_only=supports_multimodal_raw_input_only(
                model
            ),
            requires_raw_input_tokens=requires_raw_input_tokens(model),
            supports_multimodal_encoder_tp_data=supports_multimodal_encoder_tp_data(
                model
            ),
            supports_pp=supports_pp(model),
            has_inner_state=has_inner_state(model),
            is_attention_free=is_attention_free(model),
            is_hybrid=is_hybrid(model),
            supports_mamba_prefix_caching=supports_mamba_prefix_caching(model),
            supports_transcription=transcription,
            # The attribute is only read when the class supports transcription.
            supports_transcription_only=(
                transcription and model.supports_transcription_only
            ),
            has_noops=has_noops(model),
        )
601
602


603
604
605
606
class _BaseRegisteredModel(ABC):
    """Abstract registry entry that can describe and load one model class."""

    @abstractmethod
    def inspect_model_cls(self) -> _ModelInfo:
        """Return the capability info of the underlying model class."""
        raise NotImplementedError

    @abstractmethod
    def load_model_cls(self) -> type[nn.Module]:
        """Return the underlying model class itself."""
        raise NotImplementedError
611
612


613
614
615
616
617
618
619
@dataclass(frozen=True)
class _RegisteredModel(_BaseRegisteredModel):
    """
    Represents a model that has already been imported in the main process.
    """

    # Capability info computed from `model_cls` at construction time.
    interfaces: _ModelInfo
    model_cls: type[nn.Module]

    @staticmethod
    def from_model_cls(model_cls: type[nn.Module]) -> "_RegisteredModel":
        """Wrap an already-imported model class, inspecting it eagerly."""
        return _RegisteredModel(
            interfaces=_ModelInfo.from_model_cls(model_cls),
            model_cls=model_cls,
        )

    def inspect_model_cls(self) -> _ModelInfo:
        return self.interfaces

    def load_model_cls(self) -> type[nn.Module]:
        return self.model_cls


@dataclass(frozen=True)
class _LazyRegisteredModel(_BaseRegisteredModel):
    """
    Represents a model that has not been imported in the main process.

    Interface information is obtained by importing the model in a subprocess
    (see `inspect_model_cls`). The result is cached on disk, keyed by a hash
    of the model's source file whenever that file can be located.
    """

    module_name: str
    class_name: str

    @staticmethod
    def _get_cache_dir() -> Path:
        return Path(envs.VLLM_CACHE_ROOT) / "modelinfos"

    def _get_cache_filename(self) -> str:
        cls_name = f"{self.module_name}-{self.class_name}".replace(".", "-")
        return f"{cls_name}.json"

    def _get_module_source_path(self) -> Path | None:
        """Return the source file of `module_name` if it is a built-in model.

        Only modules living directly in this package (the package containing
        this file) are resolved. For any other module, e.g. an out-of-tree
        model registered as `<module>:<class>`, `None` is returned. This
        keeps the cache from being keyed on an unrelated vLLM file that
        merely shares the module's last name component.
        """
        package, _, basename = self.module_name.rpartition(".")
        if package != __package__:
            return None

        model_path = Path(__file__).parent / f"{basename}.py"
        return model_path if model_path.exists() else None

    def _load_modelinfo_from_cache(self, module_hash: str) -> _ModelInfo | None:
        """Return the cached `_ModelInfo` if it was stored for `module_hash`.

        Returns `None` when the cache file is missing, was written for a
        different module hash, or cannot be read/parsed.
        """
        try:
            try:
                modelinfo_path = self._get_cache_dir() / self._get_cache_filename()
                with open(modelinfo_path, encoding="utf-8") as file:
                    mi_dict = json.load(file)
            except FileNotFoundError:
                logger.debug(
                    "Cached model info file for class %s.%s not found",
                    self.module_name,
                    self.class_name,
                )
                return None

            if mi_dict["hash"] != module_hash:
                logger.debug(
                    "Cached model info file for class %s.%s is stale",
                    self.module_name,
                    self.class_name,
                )
                return None

            # file not changed, use cached _ModelInfo properties
            return _ModelInfo(**mi_dict["modelinfo"])
        except Exception:
            # Any other problem (corrupt JSON, missing keys, fields that no
            # longer match `_ModelInfo`) is treated as a cache miss.
            logger.debug(
                "Cached model info for class %s.%s error. ",
                self.module_name,
                self.class_name,
            )
            return None

    def _save_modelinfo_to_cache(self, mi: _ModelInfo, module_hash: str) -> None:
        """save dictionary json file to cache"""
        from vllm.model_executor.model_loader.weight_utils import atomic_writer

        try:
            modelinfo_dict = {
                "hash": module_hash,
                "modelinfo": asdict(mi),
            }
            cache_dir = self._get_cache_dir()
            cache_dir.mkdir(parents=True, exist_ok=True)
            modelinfo_path = cache_dir / self._get_cache_filename()
            with atomic_writer(modelinfo_path, encoding="utf-8") as f:
                json.dump(modelinfo_dict, f, indent=2)
        except Exception:
            # Caching is best-effort: failing to write it must not fail the
            # inspection itself.
            logger.exception("Error saving model info cache.")

    @logtime(logger=logger, msg="Registry inspect model class")
    def inspect_model_cls(self) -> _ModelInfo:
        """Return the model's interface information.

        The on-disk cache is used when the module's source hash matches.
        Otherwise the model is imported in a subprocess and, if the source
        file was hashed, the cache is refreshed.
        """
        module_hash = None
        model_path = self._get_module_source_path()

        if model_path is not None:
            with open(model_path, "rb") as f:
                module_hash = safe_hash(f.read(), usedforsecurity=False).hexdigest()

            mi = self._load_modelinfo_from_cache(module_hash)
            if mi is not None:
                logger.debug(
                    "Loaded model info for class %s.%s from cache",
                    self.module_name,
                    self.class_name,
                )
                return mi
            else:
                logger.debug(
                    "Cache model info for class %s.%s miss. Loading model instead.",
                    self.module_name,
                    self.class_name,
                )

        # Performed in another process to avoid initializing CUDA
        mi = _run_in_subprocess(
            lambda: _ModelInfo.from_model_cls(self.load_model_cls())
        )
        logger.debug(
            "Loaded model info for class %s.%s", self.module_name, self.class_name
        )

        # save cache file
        if module_hash is not None:
            self._save_modelinfo_to_cache(mi, module_hash)

        return mi

    def load_model_cls(self) -> type[nn.Module]:
        mod = importlib.import_module(self.module_name)
        return getattr(mod, self.class_name)


@lru_cache(maxsize=128)
def _try_load_model_cls(
    model_arch: str,
    model: _BaseRegisteredModel,
) -> type[nn.Module] | None:
    """Load the model class for `model_arch`, or return `None` on failure.

    The platform check runs first and is not guarded, so its errors
    propagate. Errors raised while loading the class are logged and turned
    into `None`. Results, including `None`, are memoized per
    `(model_arch, model)` pair.
    """
    from vllm.platforms import current_platform

    current_platform.verify_model_arch(model_arch)

    try:
        loaded_cls = model.load_model_cls()
    except Exception:
        logger.exception("Error in loading model architecture '%s'", model_arch)
        loaded_cls = None
    return loaded_cls
758
759


760
761
762
763
@lru_cache(maxsize=128)
def _try_inspect_model_cls(
    model_arch: str,
    model: _BaseRegisteredModel,
) -> _ModelInfo | None:
    """Inspect the model registered for `model_arch`.

    Any exception is logged and converted to `None`. Results, including
    `None`, are memoized per `(model_arch, model)` pair.
    """
    try:
        info = model.inspect_model_cls()
    except Exception:
        logger.exception("Error in inspecting model architecture '%s'", model_arch)
        info = None
    return info
770
771


772
773
774
@dataclass
class _ModelRegistry:
    """Registry mapping architecture names to model implementations."""

    # Keyed by model_arch
    models: dict[str, _BaseRegisteredModel] = field(default_factory=dict)

    def get_supported_archs(self) -> Set[str]:
        """Return a live view of the registered architecture names."""
        return self.models.keys()

    def register_model(
        self,
        model_arch: str,
        model_cls: type[nn.Module] | str,
    ) -> None:
        """
        Register an external model to be used in vLLM.

        `model_cls` can be either:

        - A [`torch.nn.Module`][] class directly referencing the model.
        - A string in the format `<module>:<class>` which can be used to
          lazily import the model. This is useful to avoid initializing CUDA
          when importing the model and thus the related error
          `RuntimeError: Cannot re-initialize CUDA in forked subprocess`.

        An existing registration for `model_arch` is overwritten (with a
        warning) only after the new `model_cls` has been validated.

        Raises:
            TypeError: If `model_arch` is not a string, or `model_cls` is
                neither a string nor a `torch.nn.Module` subclass.
            ValueError: If `model_cls` is a string that is not in the
                format `<module>:<class>`.
        """
        if not isinstance(model_arch, str):
            msg = f"`model_arch` should be a string, not a {type(model_arch)}"
            raise TypeError(msg)

        # Validate and build the new entry before announcing an overwrite, so
        # that a rejected `model_cls` does not log a misleading warning.
        if isinstance(model_cls, str):
            split_str = model_cls.split(":")
            if len(split_str) != 2:
                msg = "Expected a string in the format `<module>:<class>`"
                raise ValueError(msg)

            model = _LazyRegisteredModel(*split_str)
        elif isinstance(model_cls, type) and issubclass(model_cls, nn.Module):
            model = _RegisteredModel.from_model_cls(model_cls)
        else:
            # Report the type of the rejected `model_cls` argument.
            msg = (
                "`model_cls` should be a string or PyTorch model class, "
                f"not a {type(model_cls)}"
            )
            raise TypeError(msg)

        if model_arch in self.models:
            logger.warning(
                "Model architecture %s is already registered, and will be "
                "overwritten by the new model class %s.",
                model_arch,
                model_cls,
            )

        self.models[model_arch] = model

    def _raise_for_unsupported(self, architectures: list[str]):
        """Always raise a `ValueError` explaining why `architectures` failed.

        The message distinguishes three cases:
        - a registered architecture failed inspection/loading;
        - an architecture was removed in a past vLLM version;
        - the architectures are simply unknown.
        """
        all_supported_archs = self.get_supported_archs()

        if any(arch in all_supported_archs for arch in architectures):
            raise ValueError(
                f"Model architectures {architectures} failed "
                "to be inspected. Please check the logs for more details."
            )

        for arch in architectures:
            if arch in _PREVIOUSLY_SUPPORTED_MODELS:
                previous_version = _PREVIOUSLY_SUPPORTED_MODELS[arch]

                raise ValueError(
                    f"Model architecture {arch} was supported in vLLM until "
                    f"v{previous_version}, and is not supported anymore. "
                    "Please use an older version of vLLM if you want to "
                    "use this model architecture."
                )

        raise ValueError(
            f"Model architectures {architectures} are not supported for now. "
            f"Supported architectures: {all_supported_archs}"
        )

    def _try_load_model_cls(self, model_arch: str) -> type[nn.Module] | None:
        if model_arch not in self.models:
            return None

        return _try_load_model_cls(model_arch, self.models[model_arch])

    def _try_inspect_model_cls(self, model_arch: str) -> _ModelInfo | None:
        if model_arch not in self.models:
            return None

        return _try_inspect_model_cls(model_arch, self.models[model_arch])

    def _try_resolve_transformers(
        self,
        architecture: str,
        model_config: ModelConfig,
    ) -> str | None:
        """Resolve the Transformers-backend architecture for `architecture`.

        Returns the architecture name to look up in this registry, or `None`
        if the Transformers backend cannot be used. When
        `model_config.model_impl == "transformers"`, failure raises
        `ValueError` instead of returning `None`.
        """
        if architecture in _TRANSFORMERS_BACKEND_MODELS:
            return architecture

        auto_map: dict[str, str] = (
            getattr(model_config.hf_config, "auto_map", None) or dict()
        )

        # Make sure that config class is always initialized before model class,
        # otherwise the model class won't be able to access the config class,
        # the expected auto_map should have correct order like:
        # "auto_map": {
        #     "AutoConfig": "<your-repo-name>--<config-name>",
        #     "AutoModel": "<your-repo-name>--<config-name>",
        #     "AutoModelFor<Task>": "<your-repo-name>--<config-name>",
        # },
        for prefix in ("AutoConfig", "AutoModel"):
            for name, module in auto_map.items():
                if name.startswith(prefix):
                    try_get_class_from_dynamic_module(
                        module,
                        model_config.model,
                        revision=model_config.revision,
                        warn_on_fail=False,
                    )

        model_module = getattr(transformers, architecture, None)

        if model_module is None:
            for name, module in auto_map.items():
                if name.startswith("AutoModel"):
                    model_module = try_get_class_from_dynamic_module(
                        module,
                        model_config.model,
                        revision=model_config.revision,
                        warn_on_fail=True,
                    )
                    if model_module is not None:
                        break
            else:
                # No `AutoModel*` entry yielded a class.
                if model_config.model_impl != "transformers":
                    return None

                raise ValueError(
                    f"Cannot find model module. {architecture!r} is not a "
                    "registered model in the Transformers library (only "
                    "relevant if the model is meant to be in Transformers) "
                    "and 'AutoModel' is not present in the model config's "
                    "'auto_map' (relevant if the model is custom)."
                )

        if not model_module.is_backend_compatible():
            if model_config.model_impl != "transformers":
                return None

            raise ValueError(
                f"The Transformers implementation of {architecture!r} "
                "is not compatible with vLLM."
            )

        return model_config._get_transformers_backend_cls()

    def _normalize_arch(
        self,
        architecture: str,
        model_config: ModelConfig,
    ) -> str:
        """Map a converted architecture name back to a registered base one.

        Returns `architecture` unchanged if it is registered or no registered
        base architecture can be derived from it.
        """
        if architecture in self.models:
            return architecture

        # This may be called in order to resolve runner_type and convert_type
        # in the first place, in which case we consider the default match
        match = try_match_architecture_defaults(
            architecture,
            runner_type=getattr(model_config, "runner_type", None),
            convert_type=getattr(model_config, "convert_type", None),
        )
        if match:
            suffix, _ = match

            # Get the name of the base model to convert
            for repl_suffix, _ in iter_architecture_defaults():
                base_arch = architecture.replace(suffix, repl_suffix)
                if base_arch in self.models:
                    return base_arch

        return architecture

    def inspect_model_cls(
        self,
        architectures: str | list[str],
        model_config: ModelConfig,
    ) -> tuple[_ModelInfo, str]:
        """Return `(model_info, arch)` for the first resolvable architecture.

        Resolution order:
        1. The implementation explicitly requested via `model_impl`
           ("transformers" or "terratorch").
        2. The Transformers backend, if no architecture is registered,
           `model_impl == "auto"` and `convert_type == "none"`.
        3. Each architecture in order, after `_normalize_arch`.
        4. The Transformers backend, if no architecture is registered and
           `model_impl == "auto"`.

        Raises:
            ValueError: If no architecture is given or none can be resolved.
        """
        if isinstance(architectures, str):
            architectures = [architectures]
        if not architectures:
            raise ValueError("No model architectures are specified")

        # Require transformers impl
        if model_config.model_impl == "transformers":
            arch = self._try_resolve_transformers(architectures[0], model_config)
            if arch is not None:
                model_info = self._try_inspect_model_cls(arch)
                if model_info is not None:
                    return (model_info, arch)
        elif model_config.model_impl == "terratorch":
            # Mirror `resolve_model_cls`: never return a `None` model info.
            arch = "Terratorch"
            model_info = self._try_inspect_model_cls(arch)
            if model_info is not None:
                return (model_info, arch)

        # Fallback to transformers impl (after resolving convert_type)
        if (
            all(arch not in self.models for arch in architectures)
            and model_config.model_impl == "auto"
            and getattr(model_config, "convert_type", "none") == "none"
        ):
            arch = self._try_resolve_transformers(architectures[0], model_config)
            if arch is not None:
                model_info = self._try_inspect_model_cls(arch)
                if model_info is not None:
                    return (model_info, arch)

        for arch in architectures:
            normalized_arch = self._normalize_arch(arch, model_config)
            model_info = self._try_inspect_model_cls(normalized_arch)
            if model_info is not None:
                return (model_info, arch)

        # Fallback to transformers impl (before resolving runner_type)
        if (
            all(arch not in self.models for arch in architectures)
            and model_config.model_impl == "auto"
        ):
            arch = self._try_resolve_transformers(architectures[0], model_config)
            if arch is not None:
                model_info = self._try_inspect_model_cls(arch)
                if model_info is not None:
                    return (model_info, arch)

        return self._raise_for_unsupported(architectures)

    def resolve_model_cls(
        self,
        architectures: str | list[str],
        model_config: ModelConfig,
    ) -> tuple[type[nn.Module], str]:
        """Return `(model_cls, arch)` for the first loadable architecture.

        Uses the same resolution order as `inspect_model_cls`, but loads
        the model class instead of inspecting it.

        Raises:
            ValueError: If no architecture is given or none can be resolved.
        """
        if isinstance(architectures, str):
            architectures = [architectures]
        if not architectures:
            raise ValueError("No model architectures are specified")

        # Require transformers impl
        if model_config.model_impl == "transformers":
            arch = self._try_resolve_transformers(architectures[0], model_config)
            if arch is not None:
                model_cls = self._try_load_model_cls(arch)
                if model_cls is not None:
                    return (model_cls, arch)
        elif model_config.model_impl == "terratorch":
            arch = "Terratorch"
            model_cls = self._try_load_model_cls(arch)
            if model_cls is not None:
                return (model_cls, arch)

        # Fallback to transformers impl (after resolving convert_type)
        if (
            all(arch not in self.models for arch in architectures)
            and model_config.model_impl == "auto"
            and getattr(model_config, "convert_type", "none") == "none"
        ):
            arch = self._try_resolve_transformers(architectures[0], model_config)
            if arch is not None:
                model_cls = self._try_load_model_cls(arch)
                if model_cls is not None:
                    return (model_cls, arch)

        for arch in architectures:
            normalized_arch = self._normalize_arch(arch, model_config)
            model_cls = self._try_load_model_cls(normalized_arch)
            if model_cls is not None:
                return (model_cls, arch)

        # Fallback to transformers impl (before resolving runner_type)
        if (
            all(arch not in self.models for arch in architectures)
            and model_config.model_impl == "auto"
        ):
            arch = self._try_resolve_transformers(architectures[0], model_config)
            if arch is not None:
                model_cls = self._try_load_model_cls(arch)
                if model_cls is not None:
                    return (model_cls, arch)

        return self._raise_for_unsupported(architectures)

    def is_text_generation_model(
        self,
        architectures: str | list[str],
        model_config: ModelConfig,
    ) -> bool:
        model_cls, _ = self.inspect_model_cls(architectures, model_config)
        return model_cls.is_text_generation_model

    def is_pooling_model(
        self,
        architectures: str | list[str],
        model_config: ModelConfig,
    ) -> bool:
        model_cls, _ = self.inspect_model_cls(architectures, model_config)
        return model_cls.is_pooling_model

    def is_cross_encoder_model(
        self,
        architectures: str | list[str],
        model_config: ModelConfig,
    ) -> bool:
        model_cls, _ = self.inspect_model_cls(architectures, model_config)
        return model_cls.supports_cross_encoding

    def is_multimodal_model(
        self,
        architectures: str | list[str],
        model_config: ModelConfig,
    ) -> bool:
        model_cls, _ = self.inspect_model_cls(architectures, model_config)
        return model_cls.supports_multimodal

    def is_multimodal_raw_input_only_model(
        self,
        architectures: str | list[str],
        model_config: ModelConfig,
    ) -> bool:
        model_cls, _ = self.inspect_model_cls(architectures, model_config)
        return model_cls.supports_multimodal_raw_input_only

    def is_pp_supported_model(
        self,
        architectures: str | list[str],
        model_config: ModelConfig,
    ) -> bool:
        model_cls, _ = self.inspect_model_cls(architectures, model_config)
        return model_cls.supports_pp

    def model_has_inner_state(
        self,
        architectures: str | list[str],
        model_config: ModelConfig,
    ) -> bool:
        model_cls, _ = self.inspect_model_cls(architectures, model_config)
        return model_cls.has_inner_state

    def is_attention_free_model(
        self,
        architectures: str | list[str],
        model_config: ModelConfig,
    ) -> bool:
        model_cls, _ = self.inspect_model_cls(architectures, model_config)
        return model_cls.is_attention_free

    def is_hybrid_model(
        self,
        architectures: str | list[str],
        model_config: ModelConfig,
    ) -> bool:
        model_cls, _ = self.inspect_model_cls(architectures, model_config)
        return model_cls.is_hybrid

    def is_noops_model(
        self,
        architectures: str | list[str],
        model_config: ModelConfig,
    ) -> bool:
        model_cls, _ = self.inspect_model_cls(architectures, model_config)
        return model_cls.has_noops

    def is_transcription_model(
        self,
        architectures: str | list[str],
        model_config: ModelConfig,
    ) -> bool:
        model_cls, _ = self.inspect_model_cls(architectures, model_config)
        return model_cls.supports_transcription

    def is_transcription_only_model(
        self,
        architectures: str | list[str],
        model_config: ModelConfig,
    ) -> bool:
        model_cls, _ = self.inspect_model_cls(architectures, model_config)
        return model_cls.supports_transcription_only

1157

1158
1159
1160
1161
1162
1163
1164
1165
1166
# Global registry, pre-populated with lazily imported built-in vLLM models.
ModelRegistry = _ModelRegistry(
    {
        arch: _LazyRegisteredModel(
            module_name=f"vllm.model_executor.models.{relname}",
            class_name=clsname,
        )
        for arch, (relname, clsname) in _VLLM_MODELS.items()
    }
)

_T = TypeVar("_T")


def _run_in_subprocess(fn: Callable[[], _T]) -> _T:
    """Call `fn` in a fresh Python subprocess and return its result.

    `fn` and an output path are serialized with `cloudpickle` and sent over
    stdin to `_SUBPROCESS_COMMAND`. The subprocess pickles the result into
    that file, which is read back here.

    Raises:
        RuntimeError: If the subprocess exits with a non-zero status; the
            message contains the subprocess's stderr.
    """
    # NOTE: We use a temporary directory instead of a temporary file to avoid
    # issues like https://stackoverflow.com/questions/23212435/permission-denied-to-write-to-my-temporary-file
    with tempfile.TemporaryDirectory() as tempdir:
        output_filepath = os.path.join(tempdir, "registry_output.tmp")

        # `cloudpickle` allows pickling lambda functions directly
        import cloudpickle

        input_bytes = cloudpickle.dumps((fn, output_filepath))

        # cannot use `sys.executable __file__` here because the script
        # contains relative imports
        returned = subprocess.run(
            _SUBPROCESS_COMMAND, input=input_bytes, capture_output=True
        )

        # check if the subprocess is successful
        try:
            returned.check_returncode()
        except subprocess.CalledProcessError as e:
            # wrap raised exception to provide more information.
            # stderr may contain bytes that are not valid UTF-8; decoding
            # leniently prevents a UnicodeDecodeError from masking the
            # actual subprocess failure.
            stderr = returned.stderr.decode(errors="replace")
            raise RuntimeError(f"Error raised in subprocess:\n{stderr}") from e

        # The file was written by our own subprocess into a private temporary
        # directory, so unpickling it here is not untrusted input.
        with open(output_filepath, "rb") as f:
            return pickle.load(f)


def _run() -> None:
    """Subprocess entry point used by `_run_in_subprocess`.

    Reads a pickled `(fn, output_file)` pair from stdin, calls `fn()`, and
    writes the pickled result to `output_file`.
    """
    # Load plugins in this fresh interpreter before unpickling and running fn.
    from vllm.plugins import load_general_plugins

    load_general_plugins()

    payload = sys.stdin.buffer.read()
    fn, output_file = pickle.loads(payload)

    outcome = fn()

    with open(output_file, "wb") as sink:
        sink.write(pickle.dumps(outcome))
1213
1214
1215


if __name__ == "__main__":
    # Executed as the child process spawned by `_run_in_subprocess`.
    _run()