registry.py 47 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
4
5
6
"""
Whenever you add an architecture to this page, please also update
`tests/models/registry.py` with example HuggingFace models for it.
"""
7

8
import importlib
9
import json
10
import os
11
import pickle
12
13
import subprocess
import sys
14
import tempfile
15
from abc import ABC, abstractmethod
16
from collections.abc import Callable, Set
17
from dataclasses import asdict, dataclass, field
18
from functools import lru_cache
19
from pathlib import Path
20
from typing import TYPE_CHECKING, Any, TypeVar
21
22

import torch.nn as nn
23
import transformers
24

25
from vllm import envs
26
27
28
29
30
from vllm.config import (
    ModelConfig,
    iter_architecture_defaults,
    try_match_architecture_defaults,
)
31
from vllm.logger import init_logger
32
from vllm.logging_utils import logtime
33
from vllm.transformers_utils.dynamic_module import try_get_class_from_dynamic_module
34
from vllm.utils.hashing import safe_hash
35

36
37
if TYPE_CHECKING:
    from vllm.config.model import AttnTypeStr
38
    from vllm.config.pooler import SequencePoolingType, TokenPoolingType
39
40
else:
    AttnTypeStr = Any
41
42
    SequencePoolingType = Any
    TokenPoolingType = Any
43
44


45
46
47
48
49
from .interfaces import (
    has_inner_state,
    has_noops,
    is_attention_free,
    is_hybrid,
Patrick von Platen's avatar
Patrick von Platen committed
50
    requires_raw_input_tokens,
51
    supports_cross_encoding,
52
    supports_mamba_prefix_caching,
53
54
55
56
57
58
59
    supports_multimodal,
    supports_multimodal_encoder_tp_data,
    supports_multimodal_raw_input_only,
    supports_pp,
    supports_transcription,
)
from .interfaces_base import (
60
    get_attn_type,
61
62
    get_default_seq_pooling_type,
    get_default_tok_pooling_type,
63
64
65
    is_pooling_model,
    is_text_generation_model,
)
66
67
68

logger = init_logger(__name__)

69
70
# Text-generation architectures.
# Each key is an architecture name and each value is a `(module, class)` pair
# of strings. Several keys may point at the same implementation class (aliases).
# NOTE(review): the module string is presumably resolved relative to this
# package when the entry is imported lazily -- confirm against the registry
# construction code, which is outside this chunk.
_TEXT_GENERATION_MODELS = {
    # [Decoder-only]
    "AfmoeForCausalLM": ("afmoe", "AfmoeForCausalLM"),
    "ApertusForCausalLM": ("apertus", "ApertusForCausalLM"),
    "AquilaModel": ("llama", "LlamaForCausalLM"),
    "AquilaForCausalLM": ("llama", "LlamaForCausalLM"),  # AquilaChat2
    "ArceeForCausalLM": ("arcee", "ArceeForCausalLM"),
    "ArcticForCausalLM": ("arctic", "ArcticForCausalLM"),
    # baichuan-7b, upper case 'C' in the class name
    "BaiChuanForCausalLM": ("baichuan", "BaiChuanForCausalLM"),
    # baichuan-13b, lower case 'c' in the class name
    "BaichuanForCausalLM": ("baichuan", "BaichuanForCausalLM"),
    "BailingMoeForCausalLM": ("bailing_moe", "BailingMoeForCausalLM"),
    "BailingMoeV2ForCausalLM": ("bailing_moe", "BailingMoeV2ForCausalLM"),
    "BambaForCausalLM": ("bamba", "BambaForCausalLM"),
    "BloomForCausalLM": ("bloom", "BloomForCausalLM"),
    "ChatGLMModel": ("chatglm", "ChatGLMForCausalLM"),
    "ChatGLMForConditionalGeneration": ("chatglm", "ChatGLMForCausalLM"),
    "CohereForCausalLM": ("commandr", "CohereForCausalLM"),
    "Cohere2ForCausalLM": ("commandr", "CohereForCausalLM"),
    "CwmForCausalLM": ("llama", "LlamaForCausalLM"),
    "DbrxForCausalLM": ("dbrx", "DbrxForCausalLM"),
    "DeciLMForCausalLM": ("nemotron_nas", "DeciLMForCausalLM"),
    "DeepseekForCausalLM": ("deepseek_v2", "DeepseekForCausalLM"),
    "DeepseekV2ForCausalLM": ("deepseek_v2", "DeepseekV2ForCausalLM"),
    "DeepseekV3ForCausalLM": ("deepseek_v2", "DeepseekV3ForCausalLM"),
    "DeepseekV32ForCausalLM": ("deepseek_v2", "DeepseekV3ForCausalLM"),
    "Dots1ForCausalLM": ("dots1", "Dots1ForCausalLM"),
    "Ernie4_5ForCausalLM": ("ernie45", "Ernie4_5ForCausalLM"),
    "Ernie4_5_MoeForCausalLM": ("ernie45_moe", "Ernie4_5_MoeForCausalLM"),
    "ExaoneForCausalLM": ("exaone", "ExaoneForCausalLM"),
    "Exaone4ForCausalLM": ("exaone4", "Exaone4ForCausalLM"),
    "ExaoneMoEForCausalLM": ("exaone_moe", "ExaoneMoeForCausalLM"),
    "Fairseq2LlamaForCausalLM": ("fairseq2_llama", "Fairseq2LlamaForCausalLM"),
    "FalconForCausalLM": ("falcon", "FalconForCausalLM"),
    "FalconMambaForCausalLM": ("mamba", "MambaForCausalLM"),
    "FalconH1ForCausalLM": ("falcon_h1", "FalconH1ForCausalLM"),
    "FlexOlmoForCausalLM": ("flex_olmo", "FlexOlmoForCausalLM"),
    "GemmaForCausalLM": ("gemma", "GemmaForCausalLM"),
    "Gemma2ForCausalLM": ("gemma2", "Gemma2ForCausalLM"),
    "Gemma3ForCausalLM": ("gemma3", "Gemma3ForCausalLM"),
    "Gemma3nForCausalLM": ("gemma3n", "Gemma3nForCausalLM"),
    "Qwen3NextForCausalLM": ("qwen3_next", "Qwen3NextForCausalLM"),
    "GlmForCausalLM": ("glm", "GlmForCausalLM"),
    "Glm4ForCausalLM": ("glm4", "Glm4ForCausalLM"),
    "Glm4MoeForCausalLM": ("glm4_moe", "Glm4MoeForCausalLM"),
    "GptOssForCausalLM": ("gpt_oss", "GptOssForCausalLM"),
    "GPT2LMHeadModel": ("gpt2", "GPT2LMHeadModel"),
    "GPTBigCodeForCausalLM": ("gpt_bigcode", "GPTBigCodeForCausalLM"),
    "GPTJForCausalLM": ("gpt_j", "GPTJForCausalLM"),
    "GPTNeoXForCausalLM": ("gpt_neox", "GPTNeoXForCausalLM"),
    "GraniteForCausalLM": ("granite", "GraniteForCausalLM"),
    "GraniteMoeForCausalLM": ("granitemoe", "GraniteMoeForCausalLM"),
    "GraniteMoeHybridForCausalLM": ("granitemoehybrid", "GraniteMoeHybridForCausalLM"),  # noqa: E501
    "GraniteMoeSharedForCausalLM": ("granitemoeshared", "GraniteMoeSharedForCausalLM"),  # noqa: E501
    "GritLM": ("gritlm", "GritLM"),
    "Grok1ModelForCausalLM": ("grok1", "GrokForCausalLM"),
    "Grok1ForCausalLM": ("grok1", "GrokForCausalLM"),
    "HunYuanMoEV1ForCausalLM": ("hunyuan_v1", "HunYuanMoEV1ForCausalLM"),
    "HunYuanDenseV1ForCausalLM": ("hunyuan_v1", "HunYuanDenseV1ForCausalLM"),
    "HCXVisionForCausalLM": ("hyperclovax_vision", "HCXVisionForCausalLM"),
    "InternLMForCausalLM": ("llama", "LlamaForCausalLM"),
    "InternLM2ForCausalLM": ("internlm2", "InternLM2ForCausalLM"),
    "InternLM2VEForCausalLM": ("internlm2_ve", "InternLM2VEForCausalLM"),
    "InternLM3ForCausalLM": ("llama", "LlamaForCausalLM"),
    "IQuestCoderForCausalLM": ("llama", "LlamaForCausalLM"),
    "IQuestLoopCoderForCausalLM": ("iquest_loopcoder", "IQuestLoopCoderForCausalLM"),
    "JAISLMHeadModel": ("jais", "JAISLMHeadModel"),
    "Jais2ForCausalLM": ("jais2", "Jais2ForCausalLM"),
    "JambaForCausalLM": ("jamba", "JambaForCausalLM"),
    "KimiLinearForCausalLM": ("kimi_linear", "KimiLinearForCausalLM"),  # noqa: E501
    "Lfm2ForCausalLM": ("lfm2", "Lfm2ForCausalLM"),
    "Lfm2MoeForCausalLM": ("lfm2_moe", "Lfm2MoeForCausalLM"),
    "LlamaForCausalLM": ("llama", "LlamaForCausalLM"),
    "Llama4ForCausalLM": ("llama4", "Llama4ForCausalLM"),
    # For decapoda-research/llama-*
    "LLaMAForCausalLM": ("llama", "LlamaForCausalLM"),
    "LongcatFlashForCausalLM": ("longcat_flash", "LongcatFlashForCausalLM"),
    "MambaForCausalLM": ("mamba", "MambaForCausalLM"),
    "Mamba2ForCausalLM": ("mamba2", "Mamba2ForCausalLM"),
    "MiniCPMForCausalLM": ("minicpm", "MiniCPMForCausalLM"),
    "MiniCPM3ForCausalLM": ("minicpm3", "MiniCPM3ForCausalLM"),
    "MiniMaxForCausalLM": ("minimax_text_01", "MiniMaxText01ForCausalLM"),
    "MiniMaxText01ForCausalLM": ("minimax_text_01", "MiniMaxText01ForCausalLM"),
    "MiniMaxM1ForCausalLM": ("minimax_text_01", "MiniMaxText01ForCausalLM"),
    "MiniMaxM2ForCausalLM": ("minimax_m2", "MiniMaxM2ForCausalLM"),
    "MistralForCausalLM": ("llama", "LlamaForCausalLM"),
    "MistralLarge3ForCausalLM": ("mistral_large_3", "MistralLarge3ForCausalLM"),
    "MixtralForCausalLM": ("mixtral", "MixtralForCausalLM"),
    # The transformers MPT class name uses lower case ("Mpt")
    "MptForCausalLM": ("mpt", "MPTForCausalLM"),
    "MPTForCausalLM": ("mpt", "MPTForCausalLM"),
    "MiMoForCausalLM": ("mimo", "MiMoForCausalLM"),
    "MiMoV2FlashForCausalLM": ("mimo_v2_flash", "MiMoV2FlashForCausalLM"),
    "NemotronForCausalLM": ("nemotron", "NemotronForCausalLM"),
    "NemotronHForCausalLM": ("nemotron_h", "NemotronHForCausalLM"),
    "OlmoForCausalLM": ("olmo", "OlmoForCausalLM"),
    "Olmo2ForCausalLM": ("olmo2", "Olmo2ForCausalLM"),
    "Olmo3ForCausalLM": ("olmo2", "Olmo2ForCausalLM"),
    "OlmoeForCausalLM": ("olmoe", "OlmoeForCausalLM"),
    "OPTForCausalLM": ("opt", "OPTForCausalLM"),
    "OrionForCausalLM": ("orion", "OrionForCausalLM"),
    "OuroForCausalLM": ("ouro", "OuroForCausalLM"),
    "PanguEmbeddedForCausalLM": ("openpangu", "PanguEmbeddedForCausalLM"),
    "PanguProMoEV2ForCausalLM": ("openpangu", "PanguProMoEV2ForCausalLM"),
    "PanguUltraMoEForCausalLM": ("openpangu", "PanguUltraMoEForCausalLM"),
    "PersimmonForCausalLM": ("persimmon", "PersimmonForCausalLM"),
    "PhiForCausalLM": ("phi", "PhiForCausalLM"),
    "Phi3ForCausalLM": ("phi3", "Phi3ForCausalLM"),
    "PhiMoEForCausalLM": ("phimoe", "PhiMoEForCausalLM"),
    "Plamo2ForCausalLM": ("plamo2", "Plamo2ForCausalLM"),
    "Plamo3ForCausalLM": ("plamo3", "Plamo3ForCausalLM"),
    "QWenLMHeadModel": ("qwen", "QWenLMHeadModel"),
    "Qwen2ForCausalLM": ("qwen2", "Qwen2ForCausalLM"),
    "Qwen2MoeForCausalLM": ("qwen2_moe", "Qwen2MoeForCausalLM"),
    "Qwen3ForCausalLM": ("qwen3", "Qwen3ForCausalLM"),
    "Qwen3MoeForCausalLM": ("qwen3_moe", "Qwen3MoeForCausalLM"),
    "RWForCausalLM": ("falcon", "FalconForCausalLM"),
    "SeedOssForCausalLM": ("seed_oss", "SeedOssForCausalLM"),
    "Step3TextForCausalLM": ("step3_text", "Step3TextForCausalLM"),
    "StableLMEpochForCausalLM": ("stablelm", "StablelmForCausalLM"),
    "StableLmForCausalLM": ("stablelm", "StablelmForCausalLM"),
    "Starcoder2ForCausalLM": ("starcoder2", "Starcoder2ForCausalLM"),
    "SolarForCausalLM": ("solar", "SolarForCausalLM"),
    "TeleChatForCausalLM": ("telechat2", "TeleChat2ForCausalLM"),
    "TeleChat2ForCausalLM": ("telechat2", "TeleChat2ForCausalLM"),
    "TeleFLMForCausalLM": ("teleflm", "TeleFLMForCausalLM"),
    "XverseForCausalLM": ("llama", "LlamaForCausalLM"),
    "Zamba2ForCausalLM": ("zamba2", "Zamba2ForCausalLM"),
}

# Embedding / pooling architectures: architecture name -> `(module, class)`.
# Every text-generation entry implemented by "LlamaForCausalLM" is copied in
# through the unpacked comprehension. Keys listed after that comprehension
# would override a copied key of the same name (later dict entries win).
_EMBEDDING_MODELS = {
    # [Text-only]
    "BertModel": ("bert", "BertEmbeddingModel"),
    "BertSpladeSparseEmbeddingModel": ("bert", "BertSpladeSparseEmbeddingModel"),
    "DeciLMForCausalLM": ("nemotron_nas", "DeciLMForCausalLM"),
    "Gemma2Model": ("gemma2", "Gemma2ForCausalLM"),
    "Gemma3TextModel": ("gemma3", "Gemma3Model"),
    "GlmForCausalLM": ("glm", "GlmForCausalLM"),
    "GPT2ForSequenceClassification": ("gpt2", "GPT2ForSequenceClassification"),
    "GritLM": ("gritlm", "GritLM"),
    "GteModel": ("bert_with_rope", "SnowflakeGteNewModel"),
    "GteNewModel": ("bert_with_rope", "GteNewModel"),
    "InternLM2ForRewardModel": ("internlm2", "InternLM2ForRewardModel"),
    "JambaForSequenceClassification": ("jamba", "JambaForSequenceClassification"),  # noqa: E501
    "LlamaBidirectionalModel": ("llama", "LlamaBidirectionalModel"),
    "LlamaModel": ("llama", "LlamaForCausalLM"),
    **{
        # Multiple models share the same architecture, so we include them all
        k: (mod, arch)
        for k, (mod, arch) in _TEXT_GENERATION_MODELS.items()
        if arch == "LlamaForCausalLM"
    },
    "MistralModel": ("llama", "LlamaForCausalLM"),
    "ModernBertModel": ("modernbert", "ModernBertModel"),
    "NomicBertModel": ("bert_with_rope", "NomicBertModel"),
    "Phi3ForCausalLM": ("phi3", "Phi3ForCausalLM"),
    "Qwen2Model": ("qwen2", "Qwen2ForCausalLM"),
    "Qwen2ForCausalLM": ("qwen2", "Qwen2ForCausalLM"),
    "Qwen2ForRewardModel": ("qwen2_rm", "Qwen2ForRewardModel"),
    "Qwen2ForProcessRewardModel": ("qwen2_rm", "Qwen2ForProcessRewardModel"),
    "RobertaForMaskedLM": ("roberta", "RobertaEmbeddingModel"),
    "RobertaModel": ("roberta", "RobertaEmbeddingModel"),
    "TeleChatForCausalLM": ("telechat2", "TeleChat2ForCausalLM"),
    "TeleChat2ForCausalLM": ("telechat2", "TeleChat2ForCausalLM"),
    "XLMRobertaModel": ("roberta", "RobertaEmbeddingModel"),
    # [Multimodal]
    "CLIPModel": ("clip", "CLIPEmbeddingModel"),
    "LlavaNextForConditionalGeneration": (
        "llava_next",
        "LlavaNextForConditionalGeneration",
    ),
    "Phi3VForCausalLM": ("phi3v", "Phi3VForCausalLM"),
    "Qwen2VLForConditionalGeneration": ("qwen2_vl", "Qwen2VLForConditionalGeneration"),  # noqa: E501
    "SiglipModel": ("siglip", "SiglipEmbeddingModel"),
    # Technically Terratorch models work on images, both in
    # input and output. I am adding it here because it piggy-backs on embedding
    # models for the time being.
    "PrithviGeoSpatialMAE": ("terratorch", "Terratorch"),
    "Terratorch": ("terratorch", "Terratorch"),
}

251
252
# Cross-encoder (sequence / token classification and ranking) architectures:
# architecture name -> `(module, class)` pair of strings.
_CROSS_ENCODER_MODELS = {
    "BertForSequenceClassification": ("bert", "BertForSequenceClassification"),
    "BertForTokenClassification": ("bert", "BertForTokenClassification"),
    "GteNewForSequenceClassification": (
        "bert_with_rope",
        "GteNewForSequenceClassification",
    ),
    "JinaVLForRanking": ("jina_vl", "JinaVLForSequenceClassification"),
    "LlamaBidirectionalForSequenceClassification": (
        "llama",
        "LlamaBidirectionalForSequenceClassification",
    ),
    "ModernBertForSequenceClassification": (
        "modernbert",
        "ModernBertForSequenceClassification",
    ),
    "ModernBertForTokenClassification": (
        "modernbert",
        "ModernBertForTokenClassification",
    ),
    "RobertaForSequenceClassification": ("roberta", "RobertaForSequenceClassification"),
    "XLMRobertaForSequenceClassification": (
        "roberta",
        "RobertaForSequenceClassification",
    ),
}

278
# Multimodal architectures: architecture name -> `(module, class)` pair of
# strings, grouped into decoder-only and encoder-decoder sections.
_MULTIMODAL_MODELS = {
    # [Decoder-only]
    "AriaForConditionalGeneration": ("aria", "AriaForConditionalGeneration"),
    "AudioFlamingo3ForConditionalGeneration": (
        "audioflamingo3",
        "AudioFlamingo3ForConditionalGeneration",
    ),
    "AyaVisionForConditionalGeneration": (
        "aya_vision",
        "AyaVisionForConditionalGeneration",
    ),
    "BagelForConditionalGeneration": ("bagel", "BagelForConditionalGeneration"),
    "BeeForConditionalGeneration": ("bee", "BeeForConditionalGeneration"),
    "Blip2ForConditionalGeneration": ("blip2", "Blip2ForConditionalGeneration"),
    "ChameleonForConditionalGeneration": (
        "chameleon",
        "ChameleonForConditionalGeneration",
    ),
    "Cohere2VisionForConditionalGeneration": (
        "cohere2_vision",
        "Cohere2VisionForConditionalGeneration",
    ),
    "DeepseekVLV2ForCausalLM": ("deepseek_vl2", "DeepseekVLV2ForCausalLM"),
    "DeepseekOCRForCausalLM": ("deepseek_ocr", "DeepseekOCRForCausalLM"),
    "DotsOCRForCausalLM": ("dots_ocr", "DotsOCRForCausalLM"),
    "Ernie4_5_VLMoeForConditionalGeneration": (
        "ernie45_vl",
        "Ernie4_5_VLMoeForConditionalGeneration",
    ),
    "FuyuForCausalLM": ("fuyu", "FuyuForCausalLM"),
    "Gemma3ForConditionalGeneration": ("gemma3_mm", "Gemma3ForConditionalGeneration"),  # noqa: E501
    "Gemma3nForConditionalGeneration": (
        "gemma3n_mm",
        "Gemma3nForConditionalGeneration",
    ),
    "GlmAsrForConditionalGeneration": ("glmasr", "GlmAsrForConditionalGeneration"),
    "GLM4VForCausalLM": ("glm4v", "GLM4VForCausalLM"),
    "Glm4vForConditionalGeneration": ("glm4_1v", "Glm4vForConditionalGeneration"),  # noqa: E501
    "Glm4vMoeForConditionalGeneration": ("glm4_1v", "Glm4vMoeForConditionalGeneration"),  # noqa: E501
    "GraniteSpeechForConditionalGeneration": (
        "granite_speech",
        "GraniteSpeechForConditionalGeneration",
    ),
    "H2OVLChatModel": ("h2ovl", "H2OVLChatModel"),
    "HunYuanVLForConditionalGeneration": (
        "hunyuan_vision",
        "HunYuanVLForConditionalGeneration",
    ),
    "InternVLChatModel": ("internvl", "InternVLChatModel"),
    "NemotronH_Nano_VL_V2": ("nano_nemotron_vl", "NemotronH_Nano_VL_V2"),
    "OpenCUAForConditionalGeneration": (
        "opencua",
        "OpenCUAForConditionalGeneration",
    ),
    "InternS1ForConditionalGeneration": (
        "interns1",
        "InternS1ForConditionalGeneration",
    ),
    "InternVLForConditionalGeneration": (
        "interns1",
        "InternS1ForConditionalGeneration",
    ),
    "Idefics3ForConditionalGeneration": (
        "idefics3",
        "Idefics3ForConditionalGeneration",
    ),
    "IsaacForConditionalGeneration": ("isaac", "IsaacForConditionalGeneration"),
    "SmolVLMForConditionalGeneration": ("smolvlm", "SmolVLMForConditionalGeneration"),  # noqa: E501
    "KeyeForConditionalGeneration": ("keye", "KeyeForConditionalGeneration"),
    "KeyeVL1_5ForConditionalGeneration": (
        "keye_vl1_5",
        "KeyeVL1_5ForConditionalGeneration",
    ),
    "RForConditionalGeneration": ("rvl", "RForConditionalGeneration"),
    "KimiVLForConditionalGeneration": ("kimi_vl", "KimiVLForConditionalGeneration"),  # noqa: E501
    "LightOnOCRForConditionalGeneration": (
        "lightonocr",
        "LightOnOCRForConditionalGeneration",
    ),
    "Lfm2VlForConditionalGeneration": ("lfm2_vl", "Lfm2VLForConditionalGeneration"),
    "Llama_Nemotron_Nano_VL": ("nemotron_vl", "LlamaNemotronVLChatModel"),
    "Llama4ForConditionalGeneration": ("mllama4", "Llama4ForConditionalGeneration"),  # noqa: E501
    "LlavaForConditionalGeneration": ("llava", "LlavaForConditionalGeneration"),
    "LlavaNextForConditionalGeneration": (
        "llava_next",
        "LlavaNextForConditionalGeneration",
    ),
    "LlavaNextVideoForConditionalGeneration": (
        "llava_next_video",
        "LlavaNextVideoForConditionalGeneration",
    ),
    "LlavaOnevisionForConditionalGeneration": (
        "llava_onevision",
        "LlavaOnevisionForConditionalGeneration",
    ),
    "MantisForConditionalGeneration": ("llava", "MantisForConditionalGeneration"),  # noqa: E501
    "MiDashengLMModel": ("midashenglm", "MiDashengLMModel"),
    "MiniMaxVL01ForConditionalGeneration": (
        "minimax_vl_01",
        "MiniMaxVL01ForConditionalGeneration",
    ),
    "MiniCPMO": ("minicpmo", "MiniCPMO"),
    "MiniCPMV": ("minicpmv", "MiniCPMV"),
    "Mistral3ForConditionalGeneration": (
        "mistral3",
        "Mistral3ForConditionalGeneration",
    ),
    "MolmoForCausalLM": ("molmo", "MolmoForCausalLM"),
    "NVLM_D": ("nvlm_d", "NVLM_D_Model"),
    "Ovis": ("ovis", "Ovis"),
    "Ovis2_5": ("ovis2_5", "Ovis2_5"),
    "PaddleOCRVLForConditionalGeneration": (
        "paddleocr_vl",
        "PaddleOCRVLForConditionalGeneration",
    ),
    "PaliGemmaForConditionalGeneration": (
        "paligemma",
        "PaliGemmaForConditionalGeneration",
    ),
    "Phi3VForCausalLM": ("phi3v", "Phi3VForCausalLM"),
    "Phi4MMForCausalLM": ("phi4mm", "Phi4MMForCausalLM"),
    "PixtralForConditionalGeneration": ("pixtral", "PixtralForConditionalGeneration"),  # noqa: E501
    "QwenVLForConditionalGeneration": ("qwen_vl", "QwenVLForConditionalGeneration"),  # noqa: E501
    "Qwen2VLForConditionalGeneration": ("qwen2_vl", "Qwen2VLForConditionalGeneration"),  # noqa: E501
    "Qwen2_5_VLForConditionalGeneration": (
        "qwen2_5_vl",
        "Qwen2_5_VLForConditionalGeneration",
    ),
    "Qwen2AudioForConditionalGeneration": (
        "qwen2_audio",
        "Qwen2AudioForConditionalGeneration",
    ),
    "Qwen2_5OmniModel": (
        "qwen2_5_omni_thinker",
        "Qwen2_5OmniThinkerForConditionalGeneration",
    ),
    "Qwen2_5OmniForConditionalGeneration": (
        "qwen2_5_omni_thinker",
        "Qwen2_5OmniThinkerForConditionalGeneration",
    ),
    "Qwen3OmniMoeForConditionalGeneration": (
        "qwen3_omni_moe_thinker",
        "Qwen3OmniMoeThinkerForConditionalGeneration",
    ),
    "Qwen3VLForConditionalGeneration": ("qwen3_vl", "Qwen3VLForConditionalGeneration"),  # noqa: E501
    "Qwen3VLMoeForConditionalGeneration": (
        "qwen3_vl_moe",
        "Qwen3VLMoeForConditionalGeneration",
    ),
    "SkyworkR1VChatModel": ("skyworkr1v", "SkyworkR1VChatModel"),
    "Step3VLForConditionalGeneration": ("step3_vl", "Step3VLForConditionalGeneration"),  # noqa: E501
    "TarsierForConditionalGeneration": ("tarsier", "TarsierForConditionalGeneration"),  # noqa: E501
    "Tarsier2ForConditionalGeneration": (
        "qwen2_vl",
        "Tarsier2ForConditionalGeneration",
    ),
    "UltravoxModel": ("ultravox", "UltravoxModel"),
    "VoxtralForConditionalGeneration": ("voxtral", "VoxtralForConditionalGeneration"),  # noqa: E501
    "VoxtralStreamingGeneration": ("voxtral_streaming", "VoxtralStreamingGeneration"),  # noqa: E501
    # [Encoder-decoder]
    "NemotronParseForConditionalGeneration": (
        "nemotron_parse",
        "NemotronParseForConditionalGeneration",
    ),
    "WhisperForConditionalGeneration": ("whisper", "WhisperForConditionalGeneration"),  # noqa: E501
}
444
445

# Draft / speculative-decoding architectures (Eagle, MTP, Medusa variants):
# architecture name -> `(module, class)` pair of strings.
_SPECULATIVE_DECODING_MODELS = {
    "MiMoMTPModel": ("mimo_mtp", "MiMoMTP"),
    "EagleLlamaForCausalLM": ("llama_eagle", "EagleLlamaForCausalLM"),
    "EagleLlama4ForCausalLM": ("llama4_eagle", "EagleLlama4ForCausalLM"),
    "EagleMiniCPMForCausalLM": ("minicpm_eagle", "EagleMiniCPMForCausalLM"),
    "Eagle3LlamaForCausalLM": ("llama_eagle3", "Eagle3LlamaForCausalLM"),
    "LlamaForCausalLMEagle3": ("llama_eagle3", "Eagle3LlamaForCausalLM"),
    "Eagle3Qwen2_5vlForCausalLM": ("llama_eagle3", "Eagle3LlamaForCausalLM"),
    "Eagle3Qwen3vlForCausalLM": ("llama_eagle3", "Eagle3LlamaForCausalLM"),
    "EagleMistralLarge3ForCausalLM": (
        "mistral_large_3_eagle",
        "EagleMistralLarge3ForCausalLM",
    ),
    "EagleDeepSeekMTPModel": ("deepseek_eagle", "EagleDeepseekV3ForCausalLM"),
    "DeepSeekMTPModel": ("deepseek_mtp", "DeepSeekMTP"),
    "ErnieMTPModel": ("ernie_mtp", "ErnieMTP"),
    "ExaoneMoeMTP": ("exaone_moe_mtp", "ExaoneMoeMTP"),
    "LongCatFlashMTPModel": ("longcat_flash_mtp", "LongCatFlashMTP"),
    "Glm4MoeMTPModel": ("glm4_moe_mtp", "Glm4MoeMTP"),
    "MedusaModel": ("medusa", "Medusa"),
    "OpenPanguMTPModel": ("openpangu_mtp", "OpenPanguMTP"),
    "Qwen3NextMTP": ("qwen3_next_mtp", "Qwen3NextMTP"),
    # Temporarily disabled.
    # # TODO(woosuk): Re-enable this once the MLP Speculator is supported in V1.
    # "MLPSpeculatorPreTrainedModel": ("mlp_speculator", "MLPSpeculator"),
}
471

472
# Architectures mapped onto the generic classes in the "transformers" module
# rather than onto a dedicated implementation module.
_TRANSFORMERS_SUPPORTED_MODELS = {
    # Text generation models
    "SmolLM3ForCausalLM": ("transformers", "TransformersForCausalLM"),
    # Multimodal models
    "Emu3ForConditionalGeneration": (
        "transformers",
        "TransformersMultiModalForCausalLM",
    ),
}

# Generic classes of the "transformers" module, keyed by their own class name
# (every key equals the class name in its value).
_TRANSFORMERS_BACKEND_MODELS = {
    # Text generation models
    "TransformersForCausalLM": ("transformers", "TransformersForCausalLM"),
    "TransformersMoEForCausalLM": ("transformers", "TransformersMoEForCausalLM"),
    # Multimodal models
    "TransformersMultiModalForCausalLM": (
        "transformers",
        "TransformersMultiModalForCausalLM",
    ),
    "TransformersMultiModalMoEForCausalLM": (
        "transformers",
        "TransformersMultiModalMoEForCausalLM",
    ),
    # Embedding models
    "TransformersEmbeddingModel": ("transformers", "TransformersEmbeddingModel"),
    "TransformersMoEEmbeddingModel": ("transformers", "TransformersMoEEmbeddingModel"),
    "TransformersMultiModalEmbeddingModel": (
        "transformers",
        "TransformersMultiModalEmbeddingModel",
    ),
    # Sequence classification models
    "TransformersForSequenceClassification": (
        "transformers",
        "TransformersForSequenceClassification",
    ),
    "TransformersMoEForSequenceClassification": (
        "transformers",
        "TransformersMoEForSequenceClassification",
    ),
    "TransformersMultiModalForSequenceClassification": (
        "transformers",
        "TransformersMultiModalForSequenceClassification",
    ),
}
516

517
# Union of all per-category tables above. The unpacking order matters: when
# the same architecture name appears in several tables, the table unpacked
# last supplies the value.
_VLLM_MODELS = {
    **_TEXT_GENERATION_MODELS,
    **_EMBEDDING_MODELS,
    **_CROSS_ENCODER_MODELS,
    **_MULTIMODAL_MODELS,
    **_SPECULATIVE_DECODING_MODELS,
    **_TRANSFORMERS_SUPPORTED_MODELS,
    **_TRANSFORMERS_BACKEND_MODELS,
}

527
528
529
530
# This variable is used as the args for subprocess.run(). We
# can modify this variable to alter the args if needed, e.g.
# when we use par format to pack things together, sys.executable
# might not be the target we want to run.
_SUBPROCESS_COMMAND = [sys.executable, "-m", "vllm.model_executor.models.registry"]
532

533
# Architectures that are no longer registered, each mapped to a version string.
# NOTE(review): the version is presumably the last vLLM release that still
# supported the architecture -- confirm against the code that reads this table.
_PREVIOUSLY_SUPPORTED_MODELS = {
    "MotifForCausalLM": "0.10.2",
    "Phi3SmallForCausalLM": "0.9.2",
    "Phi4FlashForCausalLM": "0.10.2",
    "Phi4MultimodalForCausalLM": "0.12.0",
    # encoder-decoder models except whisper
    # have been removed for V0 deprecation.
    "BartModel": "0.10.2",
    "BartForConditionalGeneration": "0.10.2",
    "DonutForConditionalGeneration": "0.10.2",
    "Florence2ForConditionalGeneration": "0.10.2",
    "MBartForConditionalGeneration": "0.10.2",
    "MllamaForConditionalGeneration": "0.10.2",
}
547

548

549
550
@dataclass(frozen=True)
class _ModelInfo:
    """Immutable record of the capabilities detected on a model class.

    Every field other than `architecture` is filled in by the like-named
    inspection helper imported from `.interfaces` / `.interfaces_base`.

    Instances are rebuilt from cached JSON with `_ModelInfo(**dict)`, so the
    field names are part of the on-disk cache format.
    """

    architecture: str
    is_text_generation_model: bool
    is_pooling_model: bool
    attn_type: AttnTypeStr
    default_seq_pooling_type: SequencePoolingType
    default_tok_pooling_type: TokenPoolingType
    supports_cross_encoding: bool
    supports_multimodal: bool
    supports_multimodal_raw_input_only: bool
    requires_raw_input_tokens: bool
    supports_multimodal_encoder_tp_data: bool
    supports_pp: bool
    has_inner_state: bool
    is_attention_free: bool
    is_hybrid: bool
    has_noops: bool
    supports_mamba_prefix_caching: bool
    supports_transcription: bool
    supports_transcription_only: bool

    @staticmethod
    def from_model_cls(model: type[nn.Module]) -> "_ModelInfo":
        """Inspect `model` and return the corresponding `_ModelInfo`.

        Args:
            model: The model class (not an instance) to inspect.

        Returns:
            A `_ModelInfo` whose `architecture` is `model.__name__` and whose
            remaining fields come from the interface inspection helpers.
        """
        return _ModelInfo(
            architecture=model.__name__,
            is_text_generation_model=is_text_generation_model(model),
            is_pooling_model=is_pooling_model(model),
            default_seq_pooling_type=get_default_seq_pooling_type(model),
            default_tok_pooling_type=get_default_tok_pooling_type(model),
            attn_type=get_attn_type(model),
            supports_cross_encoding=supports_cross_encoding(model),
            supports_multimodal=supports_multimodal(model),
            supports_multimodal_raw_input_only=supports_multimodal_raw_input_only(
                model
            ),
            requires_raw_input_tokens=requires_raw_input_tokens(model),
            supports_multimodal_encoder_tp_data=supports_multimodal_encoder_tp_data(
                model
            ),
            supports_pp=supports_pp(model),
            has_inner_state=has_inner_state(model),
            is_attention_free=is_attention_free(model),
            is_hybrid=is_hybrid(model),
            supports_mamba_prefix_caching=supports_mamba_prefix_caching(model),
            supports_transcription=supports_transcription(model),
            # The class attribute is only read when the model supports
            # transcription at all; `and` short-circuits otherwise.
            supports_transcription_only=(
                supports_transcription(model) and model.supports_transcription_only
            ),
            has_noops=has_noops(model),
        )
600
601


602
603
604
605
class _BaseRegisteredModel(ABC):
    """Abstract interface shared by the registered-model variants."""

    @abstractmethod
    def inspect_model_cls(self) -> _ModelInfo:
        """Return the `_ModelInfo` describing the registered model class."""
        raise NotImplementedError

    @abstractmethod
    def load_model_cls(self) -> type[nn.Module]:
        """Return the registered model class itself."""
        raise NotImplementedError
610
611


612
613
614
615
616
617
618
@dataclass(frozen=True)
class _RegisteredModel(_BaseRegisteredModel):
    """
    Represents a model that has already been imported in the main process.
    """

    # Capability summary computed once from `model_cls`.
    interfaces: _ModelInfo
    # The already-imported model class.
    model_cls: type[nn.Module]

    @staticmethod
    def from_model_cls(model_cls: type[nn.Module]) -> "_RegisteredModel":
        """Build a `_RegisteredModel` by inspecting `model_cls` eagerly."""
        return _RegisteredModel(
            interfaces=_ModelInfo.from_model_cls(model_cls),
            model_cls=model_cls,
        )

    def inspect_model_cls(self) -> _ModelInfo:
        """Return the `_ModelInfo` computed when this entry was created."""
        return self.interfaces

    def load_model_cls(self) -> type[nn.Module]:
        """Return the stored model class; nothing is imported here."""
        return self.model_cls


@dataclass(frozen=True)
class _LazyRegisteredModel(_BaseRegisteredModel):
    """
    Represents a model that has not been imported in the main process.
    """
640

641
642
643
    module_name: str
    class_name: str

644
645
646
647
648
649
650
651
    @staticmethod
    def _get_cache_dir() -> Path:
        """Return the `modelinfos` directory under `envs.VLLM_CACHE_ROOT`."""
        cache_root = Path(envs.VLLM_CACHE_ROOT)
        return cache_root / "modelinfos"

    def _get_cache_filename(self) -> str:
        """
        Return the JSON cache file name for this module/class pair.

        The module and class names are joined with `-`, and every `.` in the
        result is replaced by `-` before the `.json` extension is appended.
        """
        stem = f"{self.module_name}-{self.class_name}"
        return stem.replace(".", "-") + ".json"

652
    def _load_modelinfo_from_cache(self, module_hash: str) -> _ModelInfo | None:
        """
        Try to load the cached `_ModelInfo` for this class.

        Args:
            module_hash: Hash of the current model source file. The cached
                entry is only used when its stored hash equals this value.

        Returns:
            The cached `_ModelInfo`, or `None` when the cache file is
            missing, its hash differs, or it cannot be read or converted.
            This method never raises; every failure is logged at debug level.
        """
        try:
            try:
                modelinfo_path = self._get_cache_dir() / self._get_cache_filename()
                with open(modelinfo_path, encoding="utf-8") as file:
                    mi_dict = json.load(file)
            except FileNotFoundError:
                logger.debug(
                    "Cached model info file for class %s.%s not found",
                    self.module_name,
                    self.class_name,
                )
                return None

            if mi_dict["hash"] != module_hash:
                logger.debug(
                    "Cached model info file for class %s.%s is stale",
                    self.module_name,
                    self.class_name,
                )
                return None

            # file not changed, use cached _ModelInfo properties
            return _ModelInfo(**mi_dict["modelinfo"])
        except Exception:
            # Best effort: a broken cache entry must not prevent inspection.
            # Attach the traceback so malformed JSON or a field mismatch in
            # the cached payload can be diagnosed from the debug log.
            logger.debug(
                "Cached model info for class %s.%s error. ",
                self.module_name,
                self.class_name,
                exc_info=True,
            )
            return None

684
    def _save_modelinfo_to_cache(self, mi: _ModelInfo, module_hash: str) -> None:
        """
        Write `mi` together with `module_hash` to the JSON cache file.

        The cache directory is created if needed. Any failure is logged and
        swallowed, so this method does not raise.
        """
        from vllm.model_executor.model_loader.weight_utils import atomic_writer

        try:
            payload = {"hash": module_hash, "modelinfo": asdict(mi)}

            target_dir = self._get_cache_dir()
            target_dir.mkdir(parents=True, exist_ok=True)
            target_path = target_dir / self._get_cache_filename()

            with atomic_writer(target_path, encoding="utf-8") as out:
                json.dump(payload, out, indent=2)
        except Exception:
            logger.exception("Error saving model info cache.")

    @logtime(logger=logger, msg="Registry inspect model class")
    def inspect_model_cls(self) -> _ModelInfo:
        """
        Return the model info for this class, preferring the on-disk cache.

        If a `.py` file named after the last component of `module_name`
        exists next to this file, its content hash is used to validate a
        cached entry. On a cache miss, or when no such file exists, the model
        class is inspected in a subprocess. The result is written back to the
        cache only when a hash could be computed.
        """
        leaf_name = self.module_name.rpartition(".")[2]
        model_path = Path(__file__).parent / f"{leaf_name}.py"
        module_hash = None

        if model_path.exists():
            source_bytes = model_path.read_bytes()
            module_hash = safe_hash(source_bytes, usedforsecurity=False).hexdigest()

            cached = self._load_modelinfo_from_cache(module_hash)
            if cached is not None:
                logger.debug(
                    "Loaded model info for class %s.%s from cache",
                    self.module_name,
                    self.class_name,
                )
                return cached

            logger.debug(
                "Cache model info for class %s.%s miss. Loading model instead.",
                self.module_name,
                self.class_name,
            )

        # Performed in another process to avoid initializing CUDA
        mi = _run_in_subprocess(
            lambda: _ModelInfo.from_model_cls(self.load_model_cls())
        )
        logger.debug(
            "Loaded model info for class %s.%s", self.module_name, self.class_name
        )

        # Without a hash there is nothing to validate a cache entry against.
        if module_hash is not None:
            self._save_modelinfo_to_cache(mi, module_hash)

        return mi
738

739
    def load_model_cls(self) -> type[nn.Module]:
        """Import `module_name` and return its `class_name` attribute."""
        return getattr(importlib.import_module(self.module_name), self.class_name)


@lru_cache(maxsize=128)
def _try_load_model_cls(
    model_arch: str,
    model: _BaseRegisteredModel,
) -> type[nn.Module] | None:
    """
    Load the class behind `model`, returning `None` if loading fails.

    The current platform is asked to verify `model_arch` first. That check
    runs outside the `try`, so an exception it raises is not caught here.
    Results (including `None`) are memoized by `lru_cache`.
    """
    from vllm.platforms import current_platform

    current_platform.verify_model_arch(model_arch)

    try:
        loaded_cls = model.load_model_cls()
    except Exception:
        logger.exception("Error in loading model architecture '%s'", model_arch)
        return None
    return loaded_cls
757
758


759
760
761
762
@lru_cache(maxsize=128)
def _try_inspect_model_cls(
    model_arch: str,
    model: _BaseRegisteredModel,
) -> _ModelInfo | None:
    """
    Inspect the class behind `model`, returning `None` if inspection fails.

    Failures are logged with their traceback. Results (including `None`)
    are memoized by `lru_cache`.
    """
    try:
        info = model.inspect_model_cls()
    except Exception:
        logger.exception("Error in inspecting model architecture '%s'", model_arch)
        return None
    return info
769
770


771
772
773
@dataclass
class _ModelRegistry:
    # Keyed by model_arch
774
    models: dict[str, _BaseRegisteredModel] = field(default_factory=dict)
775

776
    def get_supported_archs(self) -> Set[str]:
        """Return a view of the registered architecture names."""
        return self.models.keys()
778

779
780
781
    def register_model(
        self,
        model_arch: str,
        model_cls: type[nn.Module] | str,
    ) -> None:
        """
        Register an external model to be used in vLLM.

        `model_cls` can be either:

        - A [`torch.nn.Module`][] class directly referencing the model.
        - A string in the format `<module>:<class>` which can be used to
          lazily import the model. This is useful to avoid initializing CUDA
          when importing the model and thus the related error
          `RuntimeError: Cannot re-initialize CUDA in forked subprocess`.

        Registering an architecture that already exists logs a warning and
        replaces the previous entry.

        Raises:
            TypeError: If `model_arch` is not a string, or `model_cls` is
                neither a string nor an `nn.Module` subclass.
            ValueError: If `model_cls` is a string that is not in the
                `<module>:<class>` format.
        """
        if not isinstance(model_arch, str):
            msg = f"`model_arch` should be a string, not a {type(model_arch)}"
            raise TypeError(msg)

        if model_arch in self.models:
            logger.warning(
                "Model architecture %s is already registered, and will be "
                "overwritten by the new model class %s.",
                model_arch,
                model_cls,
            )

        if isinstance(model_cls, str):
            split_str = model_cls.split(":")
            if len(split_str) != 2:
                msg = "Expected a string in the format `<module>:<class>`"
                raise ValueError(msg)

            model = _LazyRegisteredModel(*split_str)
        elif isinstance(model_cls, type) and issubclass(model_cls, nn.Module):
            model = _RegisteredModel.from_model_cls(model_cls)
        else:
            # Report the type of the offending argument (`model_cls`), not
            # `model_arch`, which has already been validated above.
            msg = (
                "`model_cls` should be a string or PyTorch model class, "
                f"not a {type(model_cls)}"
            )
            raise TypeError(msg)

        self.models[model_arch] = model
824

825
    def _raise_for_unsupported(self, architectures: list[str]):
        """
        Raise a `ValueError` explaining why `architectures` could not be used.

        Three cases are distinguished, in this order: at least one
        architecture is registered (so inspection must have failed), an
        architecture is listed in `_PREVIOUSLY_SUPPORTED_MODELS`, or none of
        them is known. This method always raises.
        """
        supported = self.get_supported_archs()

        for candidate in architectures:
            if candidate in supported:
                raise ValueError(
                    f"Model architectures {architectures} failed "
                    "to be inspected. Please check the logs for more details."
                )

        removed = [a for a in architectures if a in _PREVIOUSLY_SUPPORTED_MODELS]
        if removed:
            # Report the first matching architecture, in input order.
            arch = removed[0]
            previous_version = _PREVIOUSLY_SUPPORTED_MODELS[arch]

            raise ValueError(
                f"Model architecture {arch} was supported in vLLM until "
                f"v{previous_version}, and is not supported anymore. "
                "Please use an older version of vLLM if you want to "
                "use this model architecture."
            )

        raise ValueError(
            f"Model architectures {architectures} are not supported for now. "
            f"Supported architectures: {supported}"
        )
849

850
    def _try_load_model_cls(self, model_arch: str) -> type[nn.Module] | None:
        """Load the class registered for `model_arch`, or `None` if unknown."""
        if model_arch in self.models:
            return _try_load_model_cls(model_arch, self.models[model_arch])

        return None
855

856
    def _try_inspect_model_cls(self, model_arch: str) -> _ModelInfo | None:
        """Inspect the class registered for `model_arch`, or `None` if unknown."""
        if model_arch in self.models:
            return _try_inspect_model_cls(model_arch, self.models[model_arch])

        return None

    def _try_resolve_transformers(
        self,
        architecture: str,
        model_config: ModelConfig,
    ) -> str | None:
        """
        Try to map `architecture` onto a Transformers backend class name.

        Returns:
            `architecture` itself when it is already listed in
            `_TRANSFORMERS_BACKEND_MODELS`; otherwise the value of
            `model_config._get_transformers_backend_cls()` once a compatible
            model class has been found; or `None` when no compatible class
            is found and `model_config.model_impl` is not `"transformers"`.

        Raises:
            ValueError: If `model_config.model_impl` is `"transformers"` and
                either no model class can be found or the class found
                reports that it is not backend compatible.
        """
        if architecture in _TRANSFORMERS_BACKEND_MODELS:
            return architecture

        # A missing or empty `auto_map` is treated as an empty mapping.
        auto_map: dict[str, str] = (
            getattr(model_config.hf_config, "auto_map", None) or dict()
        )

        # Make sure that config class is always initialized before model class,
        # otherwise the model class won't be able to access the config class,
        # the expected auto_map should have correct order like:
        # "auto_map": {
        #     "AutoConfig": "<your-repo-name>--<config-name>",
        #     "AutoModel": "<your-repo-name>--<config-name>",
        #     "AutoModelFor<Task>": "<your-repo-name>--<config-name>",
        # },
        for prefix in ("AutoConfig", "AutoModel"):
            for name, module in auto_map.items():
                if name.startswith(prefix):
                    # Return value is ignored here; the call is made only
                    # for its loading side effect, with warnings suppressed.
                    try_get_class_from_dynamic_module(
                        module,
                        model_config.model,
                        revision=model_config.revision,
                        warn_on_fail=False,
                    )

        # Prefer a class shipped with the installed Transformers library.
        model_module = getattr(transformers, architecture, None)

        if model_module is None:
            # Otherwise take the first `AutoModel*` entry that yields a class.
            for name, module in auto_map.items():
                if name.startswith("AutoModel"):
                    model_module = try_get_class_from_dynamic_module(
                        module,
                        model_config.model,
                        revision=model_config.revision,
                        warn_on_fail=True,
                    )
                    if model_module is not None:
                        break
            else:
                # Reached only when the loop finished without `break`, i.e.
                # no `AutoModel*` entry produced a class.
                if model_config.model_impl != "transformers":
                    return None

                raise ValueError(
                    f"Cannot find model module. {architecture!r} is not a "
                    "registered model in the Transformers library (only "
                    "relevant if the model is meant to be in Transformers) "
                    "and 'AutoModel' is not present in the model config's "
                    "'auto_map' (relevant if the model is custom)."
                )

        if not model_module.is_backend_compatible():
            # Only an explicit request for the Transformers implementation
            # turns incompatibility into an error.
            if model_config.model_impl != "transformers":
                return None

            raise ValueError(
                f"The Transformers implementation of {architecture!r} "
                "is not compatible with vLLM."
            )

        return model_config._get_transformers_backend_cls()
927

928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
    def _normalize_arch(
        self,
        architecture: str,
        model_config: ModelConfig,
    ) -> str:
        """
        Map `architecture` to a registered base architecture when possible.

        A registered name is returned unchanged. Otherwise, if the name
        matches one of the architecture defaults, its matched suffix is
        swapped for each known suffix in turn and the first registered
        result is returned. If nothing matches, the input is returned as-is.
        """
        if architecture in self.models:
            return architecture

        # This may be called in order to resolve runner_type and convert_type
        # in the first place, in which case we consider the default match
        match = try_match_architecture_defaults(
            architecture,
            runner_type=getattr(model_config, "runner_type", None),
            convert_type=getattr(model_config, "convert_type", None),
        )
        if not match:
            return architecture

        suffix, _ = match

        # Get the name of the base model to convert
        for repl_suffix, _ in iter_architecture_defaults():
            candidate = architecture.replace(suffix, repl_suffix)
            if candidate in self.models:
                return candidate

        return architecture
953

954
955
    def inspect_model_cls(
        self,
        architectures: str | list[str],
        model_config: ModelConfig,
    ) -> tuple[_ModelInfo, str]:
        """
        Resolve `architectures` to a `_ModelInfo` and the architecture used.

        Candidates are tried in order and the first successful inspection
        is returned:

        1. If `model_impl` is `"transformers"`, the Transformers backend for
           the first architecture; if it is `"terratorch"`, the
           `"Terratorch"` entry.
        2. If nothing is registered, `model_impl` is `"auto"` and
           `convert_type` is `"none"`, the Transformers backend.
        3. Each architecture after normalization via `_normalize_arch`.
        4. If nothing is registered and `model_impl` is `"auto"`, the
           Transformers backend.

        Args:
            architectures: One architecture name or a list of them.
            model_config: Configuration consulted for `model_impl`,
                `convert_type` and Transformers resolution.

        Returns:
            A `(model_info, architecture)` tuple.

        Raises:
            ValueError: If `architectures` is empty or no candidate could
                be inspected.
        """
        if isinstance(architectures, str):
            architectures = [architectures]
        if not architectures:
            raise ValueError("No model architectures are specified")

        # Require transformers impl
        if model_config.model_impl == "transformers":
            arch = self._try_resolve_transformers(architectures[0], model_config)
            if arch is not None:
                model_info = self._try_inspect_model_cls(arch)
                if model_info is not None:
                    return (model_info, arch)
        elif model_config.model_impl == "terratorch":
            arch = "Terratorch"
            model_info = self._try_inspect_model_cls(arch)
            # Only return on success, as `resolve_model_cls` does. A failed
            # inspection must fall through to the remaining candidates (and
            # ultimately a ValueError) rather than hand `None` to the caller.
            if model_info is not None:
                return (model_info, arch)

        # Fallback to transformers impl (after resolving convert_type)
        if (
            all(arch not in self.models for arch in architectures)
            and model_config.model_impl == "auto"
            and getattr(model_config, "convert_type", "none") == "none"
        ):
            arch = self._try_resolve_transformers(architectures[0], model_config)
            if arch is not None:
                model_info = self._try_inspect_model_cls(arch)
                if model_info is not None:
                    return (model_info, arch)

        for arch in architectures:
            normalized_arch = self._normalize_arch(arch, model_config)
            model_info = self._try_inspect_model_cls(normalized_arch)
            if model_info is not None:
                # The caller gets the original name, not the normalized one.
                return (model_info, arch)

        # Fallback to transformers impl (before resolving runner_type)
        if (
            all(arch not in self.models for arch in architectures)
            and model_config.model_impl == "auto"
        ):
            arch = self._try_resolve_transformers(architectures[0], model_config)
            if arch is not None:
                model_info = self._try_inspect_model_cls(arch)
                if model_info is not None:
                    return (model_info, arch)

        return self._raise_for_unsupported(architectures)
1005

1006
1007
    def resolve_model_cls(
        self,
        architectures: str | list[str],
        model_config: ModelConfig,
    ) -> tuple[type[nn.Module], str]:
        """
        Resolve `architectures` to a loadable model class and its name.

        The first candidate that loads successfully is returned as a
        `(model_cls, architecture)` tuple. A `ValueError` is raised when
        `architectures` is empty or when no candidate can be loaded.
        """
        if isinstance(architectures, str):
            architectures = [architectures]
        if not architectures:
            raise ValueError("No model architectures are specified")

        def load_via_transformers() -> tuple[type[nn.Module], str] | None:
            # Resolve the first architecture through the Transformers backend
            # and load the resulting class, if both steps succeed.
            resolved = self._try_resolve_transformers(architectures[0], model_config)
            if resolved is None:
                return None
            loaded = self._try_load_model_cls(resolved)
            return None if loaded is None else (loaded, resolved)

        def nothing_registered() -> bool:
            # Re-evaluated at every use, like the inline checks it replaces.
            return all(arch not in self.models for arch in architectures)

        # Require transformers impl
        if model_config.model_impl == "transformers":
            found = load_via_transformers()
            if found is not None:
                return found
        elif model_config.model_impl == "terratorch":
            terratorch_cls = self._try_load_model_cls("Terratorch")
            if terratorch_cls is not None:
                return (terratorch_cls, "Terratorch")

        # Fallback to transformers impl (after resolving convert_type)
        if (
            nothing_registered()
            and model_config.model_impl == "auto"
            and getattr(model_config, "convert_type", "none") == "none"
        ):
            found = load_via_transformers()
            if found is not None:
                return found

        for arch in architectures:
            normalized_arch = self._normalize_arch(arch, model_config)
            model_cls = self._try_load_model_cls(normalized_arch)
            if model_cls is not None:
                return (model_cls, arch)

        # Fallback to transformers impl (before resolving runner_type)
        if nothing_registered() and model_config.model_impl == "auto":
            found = load_via_transformers()
            if found is not None:
                return found

        return self._raise_for_unsupported(architectures)
1059

1060
1061
    def is_text_generation_model(
        self,
        architectures: str | list[str],
        model_config: ModelConfig,
    ) -> bool:
        """Return the `is_text_generation_model` flag of the inspected model."""
        model_info = self.inspect_model_cls(architectures, model_config)[0]
        return model_info.is_text_generation_model
1067

1068
    def is_pooling_model(
        self,
        architectures: str | list[str],
        model_config: ModelConfig,
    ) -> bool:
        """Return the `is_pooling_model` flag of the inspected model."""
        model_info = self.inspect_model_cls(architectures, model_config)[0]
        return model_info.is_pooling_model
1075

1076
1077
    def is_cross_encoder_model(
        self,
        architectures: str | list[str],
        model_config: ModelConfig,
    ) -> bool:
        """Return the `supports_cross_encoding` flag of the inspected model."""
        model_info = self.inspect_model_cls(architectures, model_config)[0]
        return model_info.supports_cross_encoding
1083

1084
1085
    def is_multimodal_model(
        self,
        architectures: str | list[str],
        model_config: ModelConfig,
    ) -> bool:
        """Return the `supports_multimodal` flag of the inspected model."""
        model_info = self.inspect_model_cls(architectures, model_config)[0]
        return model_info.supports_multimodal
1091

1092
    def is_multimodal_raw_input_only_model(
        self,
        architectures: str | list[str],
        model_config: ModelConfig,
    ) -> bool:
        """
        Return the `supports_multimodal_raw_input_only` flag of the
        inspected model.
        """
        model_info = self.inspect_model_cls(architectures, model_config)[0]
        return model_info.supports_multimodal_raw_input_only
1099

1100
1101
    def is_pp_supported_model(
        self,
        architectures: str | list[str],
        model_config: ModelConfig,
    ) -> bool:
        """Return the `supports_pp` flag of the inspected model."""
        model_info = self.inspect_model_cls(architectures, model_config)[0]
        return model_info.supports_pp
1107

1108
1109
    def model_has_inner_state(
        self,
        architectures: str | list[str],
        model_config: ModelConfig,
    ) -> bool:
        """Return the `has_inner_state` flag of the inspected model."""
        model_info = self.inspect_model_cls(architectures, model_config)[0]
        return model_info.has_inner_state
1115

1116
1117
    def is_attention_free_model(
        self,
        architectures: str | list[str],
        model_config: ModelConfig,
    ) -> bool:
        """Return the `is_attention_free` flag of the inspected model."""
        model_info = self.inspect_model_cls(architectures, model_config)[0]
        return model_info.is_attention_free
1123

1124
1125
    def is_hybrid_model(
        self,
        architectures: str | list[str],
        model_config: ModelConfig,
    ) -> bool:
        """Return the `is_hybrid` flag of the inspected model."""
        model_info = self.inspect_model_cls(architectures, model_config)[0]
        return model_info.is_hybrid

1132
1133
    def is_noops_model(
        self,
        architectures: str | list[str],
        model_config: ModelConfig,
    ) -> bool:
        """Return the `has_noops` flag of the inspected model."""
        model_info = self.inspect_model_cls(architectures, model_config)[0]
        return model_info.has_noops

1140
1141
    def is_transcription_model(
        self,
        architectures: str | list[str],
        model_config: ModelConfig,
    ) -> bool:
        """Return the `supports_transcription` flag of the inspected model."""
        model_info = self.inspect_model_cls(architectures, model_config)[0]
        return model_info.supports_transcription

1148
1149
    def is_transcription_only_model(
        self,
        architectures: str | list[str],
        model_config: ModelConfig,
    ) -> bool:
        """
        Return the `supports_transcription_only` flag of the inspected model.
        """
        model_info = self.inspect_model_cls(architectures, model_config)[0]
        return model_info.supports_transcription_only

1156

1157
1158
1159
1160
1161
1162
1163
1164
1165
# Module-level registry, pre-populated with one lazily imported entry for
# every architecture listed in `_VLLM_MODELS`. Each entry's module name is
# built under the `vllm.model_executor.models` package.
ModelRegistry = _ModelRegistry(
    {
        model_arch: _LazyRegisteredModel(
            module_name=f"vllm.model_executor.models.{mod_relname}",
            class_name=cls_name,
        )
        for model_arch, (mod_relname, cls_name) in _VLLM_MODELS.items()
    }
)
1166
1167
1168
1169
1170

_T = TypeVar("_T")


def _run_in_subprocess(fn: Callable[[], _T]) -> _T:
    """
    Run `fn` in a separate process and return its result.

    `fn` and an output path are serialized with `cloudpickle` and sent to
    the subprocess on stdin. The subprocess writes the pickled result to
    that path, which is read back here.

    Args:
        fn: A zero-argument callable; it may be a lambda.

    Returns:
        The object unpickled from the subprocess's output file.

    Raises:
        RuntimeError: If the subprocess exits with a non-zero status. The
            message contains the subprocess's stderr output.
    """
    # NOTE: We use a temporary directory instead of a temporary file to avoid
    # issues like https://stackoverflow.com/questions/23212435/permission-denied-to-write-to-my-temporary-file
    with tempfile.TemporaryDirectory() as tempdir:
        output_filepath = os.path.join(tempdir, "registry_output.tmp")

        # `cloudpickle` allows pickling lambda functions directly
        import cloudpickle

        input_bytes = cloudpickle.dumps((fn, output_filepath))

        # cannot use `sys.executable __file__` here because the script
        # contains relative imports
        returned = subprocess.run(
            _SUBPROCESS_COMMAND, input=input_bytes, capture_output=True
        )

        # check if the subprocess is successful
        try:
            returned.check_returncode()
        except Exception as e:
            # wrap raised exception to provide more information
            # Decode leniently: stderr may contain bytes that are not valid
            # UTF-8, and a decode error here would hide the real failure.
            stderr_text = returned.stderr.decode(errors="replace")
            raise RuntimeError(
                f"Error raised in subprocess:\n{stderr_text}"
            ) from e

        with open(output_filepath, "rb") as f:
            return pickle.load(f)


def _run() -> None:
    """
    Entry point executed when this module is run as a script.

    Loads the general plugins, reads a pickled `(callable, output_path)`
    pair from stdin, calls the callable, and writes the pickled result to
    `output_path`.
    """
    # Setup plugins
    from vllm.plugins import load_general_plugins

    load_general_plugins()

    # NOTE(review): unpickling executes arbitrary code. The payload is
    # presumably only ever written by `_run_in_subprocess` - confirm that
    # nothing else feeds this process's stdin.
    payload = sys.stdin.buffer.read()
    fn, output_file = pickle.loads(payload)

    outcome = fn()

    with open(output_file, "wb") as out:
        out.write(pickle.dumps(outcome))
1212
1213
1214


# When executed as a script, run the callable received on stdin (see `_run`).
if __name__ == "__main__":
    _run()