registry.py 48.2 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
4
5
6
"""
Whenever you add an architecture to this page, please also update
`tests/models/registry.py` with example HuggingFace models for it.
"""
7

8
import importlib
9
import json
10
import os
11
import pickle
12
13
import subprocess
import sys
14
import tempfile
15
from abc import ABC, abstractmethod
16
from collections.abc import Callable, Set
17
from dataclasses import asdict, dataclass, field
18
from functools import lru_cache
19
from pathlib import Path
20
from typing import TYPE_CHECKING, Any, TypeVar
21
22

import torch.nn as nn
23
import transformers
24

25
from vllm import envs
26
27
28
29
30
from vllm.config import (
    ModelConfig,
    iter_architecture_defaults,
    try_match_architecture_defaults,
)
31
from vllm.logger import init_logger
32
from vllm.logging_utils import logtime
33
from vllm.transformers_utils.dynamic_module import try_get_class_from_dynamic_module
34
from vllm.utils.hashing import safe_hash
35

36
37
if TYPE_CHECKING:
    from vllm.config.model import AttnTypeStr
38
    from vllm.config.pooler import SequencePoolingType, TokenPoolingType
39
40
else:
    AttnTypeStr = Any
41
42
    SequencePoolingType = Any
    TokenPoolingType = Any
43
44


45
46
47
48
49
from .interfaces import (
    has_inner_state,
    has_noops,
    is_attention_free,
    is_hybrid,
Patrick von Platen's avatar
Patrick von Platen committed
50
    requires_raw_input_tokens,
51
    supports_cross_encoding,
52
    supports_mamba_prefix_caching,
53
54
55
56
57
58
59
    supports_multimodal,
    supports_multimodal_encoder_tp_data,
    supports_multimodal_raw_input_only,
    supports_pp,
    supports_transcription,
)
from .interfaces_base import (
60
    get_attn_type,
61
62
    get_default_seq_pooling_type,
    get_default_tok_pooling_type,
63
64
65
    is_pooling_model,
    is_text_generation_model,
)
66
67
68

logger = init_logger(__name__)

69
70
_TEXT_GENERATION_MODELS = {
    # [Decoder-only]
71
    "AfmoeForCausalLM": ("afmoe", "AfmoeForCausalLM"),
72
    "ApertusForCausalLM": ("apertus", "ApertusForCausalLM"),
73
74
    "AquilaModel": ("llama", "LlamaForCausalLM"),
    "AquilaForCausalLM": ("llama", "LlamaForCausalLM"),  # AquilaChat2
Raghav Ravishankar's avatar
Raghav Ravishankar committed
75
    "ArceeForCausalLM": ("arcee", "ArceeForCausalLM"),
76
    "ArcticForCausalLM": ("arctic", "ArcticForCausalLM"),
77
78
79
80
    # baichuan-7b, upper case 'C' in the class name
    "BaiChuanForCausalLM": ("baichuan", "BaiChuanForCausalLM"),
    # baichuan-13b, lower case 'c' in the class name
    "BaichuanForCausalLM": ("baichuan", "BaichuanForCausalLM"),
81
    "BailingMoeForCausalLM": ("bailing_moe", "BailingMoeForCausalLM"),
ant-yy's avatar
ant-yy committed
82
    "BailingMoeV2ForCausalLM": ("bailing_moe", "BailingMoeV2ForCausalLM"),
Yu Chin Fabian Lim's avatar
Yu Chin Fabian Lim committed
83
    "BambaForCausalLM": ("bamba", "BambaForCausalLM"),
84
    "BloomForCausalLM": ("bloom", "BloomForCausalLM"),
85
    "ChatGLMModel": ("chatglm", "ChatGLMForCausalLM"),
86
    "ChatGLMForConditionalGeneration": ("chatglm", "ChatGLMForCausalLM"),
87
    "CohereForCausalLM": ("commandr", "CohereForCausalLM"),
88
    "Cohere2ForCausalLM": ("commandr", "CohereForCausalLM"),
89
    "CwmForCausalLM": ("llama", "LlamaForCausalLM"),
90
    "DbrxForCausalLM": ("dbrx", "DbrxForCausalLM"),
91
    "DeciLMForCausalLM": ("nemotron_nas", "DeciLMForCausalLM"),
92
    "DeepseekForCausalLM": ("deepseek_v2", "DeepseekForCausalLM"),
93
    "DeepseekV2ForCausalLM": ("deepseek_v2", "DeepseekV2ForCausalLM"),
94
    "DeepseekV3ForCausalLM": ("deepseek_v2", "DeepseekV3ForCausalLM"),
95
    "DeepseekV32ForCausalLM": ("deepseek_v2", "DeepseekV3ForCausalLM"),
96
    "Dots1ForCausalLM": ("dots1", "Dots1ForCausalLM"),
97
    "Ernie4_5ForCausalLM": ("ernie45", "Ernie4_5ForCausalLM"),
98
    "Ernie4_5_MoeForCausalLM": ("ernie45_moe", "Ernie4_5_MoeForCausalLM"),
99
    "ExaoneForCausalLM": ("exaone", "ExaoneForCausalLM"),
100
    "Exaone4ForCausalLM": ("exaone4", "Exaone4ForCausalLM"),
Kyungmin Lee's avatar
Kyungmin Lee committed
101
    "ExaoneMoEForCausalLM": ("exaone_moe", "ExaoneMoeForCausalLM"),
102
    "Fairseq2LlamaForCausalLM": ("fairseq2_llama", "Fairseq2LlamaForCausalLM"),
103
104
105
    "FalconForCausalLM": ("falcon", "FalconForCausalLM"),
    "FalconMambaForCausalLM": ("mamba", "MambaForCausalLM"),
    "FalconH1ForCausalLM": ("falcon_h1", "FalconH1ForCausalLM"),
106
    "FlexOlmoForCausalLM": ("flex_olmo", "FlexOlmoForCausalLM"),
107
108
    "GemmaForCausalLM": ("gemma", "GemmaForCausalLM"),
    "Gemma2ForCausalLM": ("gemma2", "Gemma2ForCausalLM"),
109
    "Gemma3ForCausalLM": ("gemma3", "Gemma3ForCausalLM"),
Nicolò Lucchesi's avatar
Nicolò Lucchesi committed
110
    "Gemma3nForCausalLM": ("gemma3n", "Gemma3nForCausalLM"),
111
    "Qwen3NextForCausalLM": ("qwen3_next", "Qwen3NextForCausalLM"),
112
    "GlmForCausalLM": ("glm", "GlmForCausalLM"),
Yuxuan Zhang's avatar
Yuxuan Zhang committed
113
    "Glm4ForCausalLM": ("glm4", "Glm4ForCausalLM"),
zhuwenwen's avatar
zhuwenwen committed
114
    "Glm4MoeForCausalLM": ("glm4_moe", "Glm4MoeForCausalLM"),
115
    "Glm4MoeLiteForCausalLM": ("glm4_moe_lite", "Glm4MoeLiteForCausalLM"),
116
    "GptOssForCausalLM": ("gpt_oss", "GptOssForCausalLM"),
117
118
119
120
121
122
    "GPT2LMHeadModel": ("gpt2", "GPT2LMHeadModel"),
    "GPTBigCodeForCausalLM": ("gpt_bigcode", "GPTBigCodeForCausalLM"),
    "GPTJForCausalLM": ("gpt_j", "GPTJForCausalLM"),
    "GPTNeoXForCausalLM": ("gpt_neox", "GPTNeoXForCausalLM"),
    "GraniteForCausalLM": ("granite", "GraniteForCausalLM"),
    "GraniteMoeForCausalLM": ("granitemoe", "GraniteMoeForCausalLM"),
123
124
    "GraniteMoeHybridForCausalLM": ("granitemoehybrid", "GraniteMoeHybridForCausalLM"),  # noqa: E501
    "GraniteMoeSharedForCausalLM": ("granitemoeshared", "GraniteMoeSharedForCausalLM"),  # noqa: E501
125
    "GritLM": ("gritlm", "GritLM"),
Bijaya Dangol's avatar
Bijaya Dangol committed
126
127
    "Grok1ModelForCausalLM": ("grok1", "GrokForCausalLM"),
    "Grok1ForCausalLM": ("grok1", "GrokForCausalLM"),
128
129
    "HunYuanMoEV1ForCausalLM": ("hunyuan_v1", "HunYuanMoEV1ForCausalLM"),
    "HunYuanDenseV1ForCausalLM": ("hunyuan_v1", "HunYuanDenseV1ForCausalLM"),
130
    "HunYuanForCausalLM": ("hunyuan", "HunYuanForCausalLM"),
131
    "HCXVisionForCausalLM": ("hyperclovax_vision", "HCXVisionForCausalLM"),
132
133
    "InternLMForCausalLM": ("llama", "LlamaForCausalLM"),
    "InternLM2ForCausalLM": ("internlm2", "InternLM2ForCausalLM"),
134
    "InternLM2VEForCausalLM": ("internlm2_ve", "InternLM2VEForCausalLM"),
135
    "InternLM3ForCausalLM": ("llama", "LlamaForCausalLM"),
136
137
    "IQuestCoderForCausalLM": ("llama", "LlamaForCausalLM"),
    "IQuestLoopCoderForCausalLM": ("iquest_loopcoder", "IQuestLoopCoderForCausalLM"),
138
    "JAISLMHeadModel": ("jais", "JAISLMHeadModel"),
139
    "Jais2ForCausalLM": ("jais2", "Jais2ForCausalLM"),
140
    "JambaForCausalLM": ("jamba", "JambaForCausalLM"),
141
    "KimiLinearForCausalLM": ("kimi_linear", "KimiLinearForCausalLM"),  # noqa: E501
142
    "Lfm2ForCausalLM": ("lfm2", "Lfm2ForCausalLM"),
Paul Pak's avatar
Paul Pak committed
143
    "Lfm2MoeForCausalLM": ("lfm2_moe", "Lfm2MoeForCausalLM"),
144
    "LlamaForCausalLM": ("llama", "LlamaForCausalLM"),
145
    "Llama4ForCausalLM": ("llama4", "Llama4ForCausalLM"),
146
147
    # For decapoda-research/llama-*
    "LLaMAForCausalLM": ("llama", "LlamaForCausalLM"),
zhuwenwen's avatar
zhuwenwen committed
148
    "LongcatFlashForCausalLM": ("longcat_flash", "LongcatFlashForCausalLM"),
149
    "MambaForCausalLM": ("mamba", "MambaForCausalLM"),
150
    "Mamba2ForCausalLM": ("mamba2", "Mamba2ForCausalLM"),
151
152
    "MiniCPMForCausalLM": ("minicpm", "MiniCPMForCausalLM"),
    "MiniCPM3ForCausalLM": ("minicpm3", "MiniCPM3ForCausalLM"),
153
154
155
    "MiniMaxForCausalLM": ("minimax_text_01", "MiniMaxText01ForCausalLM"),
    "MiniMaxText01ForCausalLM": ("minimax_text_01", "MiniMaxText01ForCausalLM"),
    "MiniMaxM1ForCausalLM": ("minimax_text_01", "MiniMaxText01ForCausalLM"),
156
    "MiniMaxM2ForCausalLM": ("minimax_m2", "MiniMaxM2ForCausalLM"),
157
    "MistralForCausalLM": ("mistral", "MistralForCausalLM"),
158
    "MistralLarge3ForCausalLM": ("mistral_large_3", "MistralLarge3ForCausalLM"),
159
160
161
162
    "MixtralForCausalLM": ("mixtral", "MixtralForCausalLM"),
    # transformers's mpt class has lower case
    "MptForCausalLM": ("mpt", "MPTForCausalLM"),
    "MPTForCausalLM": ("mpt", "MPTForCausalLM"),
163
    "MiMoForCausalLM": ("mimo", "MiMoForCausalLM"),
164
    "MiMoV2FlashForCausalLM": ("mimo_v2_flash", "MiMoV2FlashForCausalLM"),
165
    "NemotronForCausalLM": ("nemotron", "NemotronForCausalLM"),
Luis Vega's avatar
Luis Vega committed
166
    "NemotronHForCausalLM": ("nemotron_h", "NemotronHForCausalLM"),
167
    "OlmoForCausalLM": ("olmo", "OlmoForCausalLM"),
168
    "Olmo2ForCausalLM": ("olmo2", "Olmo2ForCausalLM"),
169
    "Olmo3ForCausalLM": ("olmo2", "Olmo2ForCausalLM"),
170
171
172
    "OlmoeForCausalLM": ("olmoe", "OlmoeForCausalLM"),
    "OPTForCausalLM": ("opt", "OPTForCausalLM"),
    "OrionForCausalLM": ("orion", "OrionForCausalLM"),
173
    "OuroForCausalLM": ("ouro", "OuroForCausalLM"),
174
    "PanguEmbeddedForCausalLM": ("openpangu", "PanguEmbeddedForCausalLM"),
175
    "PanguProMoEV2ForCausalLM": ("openpangu", "PanguProMoEV2ForCausalLM"),
176
    "PanguUltraMoEForCausalLM": ("openpangu", "PanguUltraMoEForCausalLM"),
177
178
179
180
    "PersimmonForCausalLM": ("persimmon", "PersimmonForCausalLM"),
    "PhiForCausalLM": ("phi", "PhiForCausalLM"),
    "Phi3ForCausalLM": ("phi3", "Phi3ForCausalLM"),
    "PhiMoEForCausalLM": ("phimoe", "PhiMoEForCausalLM"),
Shinichi Hemmi's avatar
Shinichi Hemmi committed
181
    "Plamo2ForCausalLM": ("plamo2", "Plamo2ForCausalLM"),
182
    "Plamo3ForCausalLM": ("plamo3", "Plamo3ForCausalLM"),
183
    "QWenLMHeadModel": ("qwen", "QWenLMHeadModel"),
184
185
    "Qwen2ForCausalLM": ("qwen2", "Qwen2ForCausalLM"),
    "Qwen2MoeForCausalLM": ("qwen2_moe", "Qwen2MoeForCausalLM"),
186
187
    "Qwen3ForCausalLM": ("qwen3", "Qwen3ForCausalLM"),
    "Qwen3MoeForCausalLM": ("qwen3_moe", "Qwen3MoeForCausalLM"),
188
    "RWForCausalLM": ("falcon", "FalconForCausalLM"),
189
    "SeedOssForCausalLM": ("seed_oss", "SeedOssForCausalLM"),
Li Xie's avatar
Li Xie committed
190
    "Step1ForCausalLM": ("step1", "Step1ForCausalLM"),
zhuwenwen's avatar
zhuwenwen committed
191
    "Step3TextForCausalLM": ("step3_text", "Step3TextForCausalLM"),
csy0225's avatar
csy0225 committed
192
    "Step3p5ForCausalLM": ("step3p5", "Step3p5ForCausalLM"),
193
194
195
196
    "StableLMEpochForCausalLM": ("stablelm", "StablelmForCausalLM"),
    "StableLmForCausalLM": ("stablelm", "StablelmForCausalLM"),
    "Starcoder2ForCausalLM": ("starcoder2", "Starcoder2ForCausalLM"),
    "SolarForCausalLM": ("solar", "SolarForCausalLM"),
197
    "TeleChatForCausalLM": ("telechat2", "TeleChat2ForCausalLM"),
198
    "TeleChat2ForCausalLM": ("telechat2", "TeleChat2ForCausalLM"),
199
    "TeleFLMForCausalLM": ("teleflm", "TeleFLMForCausalLM"),
200
    "XverseForCausalLM": ("llama", "LlamaForCausalLM"),
201
    "Zamba2ForCausalLM": ("zamba2", "Zamba2ForCausalLM"),
202
203
204
}

# Embedding / pooling architectures, as architecture name -> (module name,
# class name). Every text-generation alias whose class is "LlamaForCausalLM"
# is re-exported here via the dict comprehension below.
_EMBEDDING_MODELS = {
    # [Text-only]
    "BertModel": ("bert", "BertEmbeddingModel"),
    "BertSpladeSparseEmbeddingModel": ("bert", "BertSpladeSparseEmbeddingModel"),
    "DeciLMForCausalLM": ("nemotron_nas", "DeciLMForCausalLM"),
    "Gemma2Model": ("gemma2", "Gemma2ForCausalLM"),
    "Gemma3TextModel": ("gemma3", "Gemma3Model"),
    "GlmForCausalLM": ("glm", "GlmForCausalLM"),
    "GPT2ForSequenceClassification": ("gpt2", "GPT2ForSequenceClassification"),
    "GritLM": ("gritlm", "GritLM"),
    "GteModel": ("bert_with_rope", "SnowflakeGteNewModel"),
    "GteNewModel": ("bert_with_rope", "GteNewModel"),
    "InternLM2ForRewardModel": ("internlm2", "InternLM2ForRewardModel"),
    "JambaForSequenceClassification": ("jamba", "JambaForSequenceClassification"),  # noqa: E501
    "LlamaBidirectionalModel": ("llama", "LlamaBidirectionalModel"),
    "LlamaModel": ("llama", "LlamaForCausalLM"),
    **{
        # Multiple models share the same architecture, so we include them all
        k: (mod, arch)
        for k, (mod, arch) in _TEXT_GENERATION_MODELS.items()
        if arch == "LlamaForCausalLM"
    },
    "MistralModel": ("llama", "LlamaForCausalLM"),
    "ModernBertModel": ("modernbert", "ModernBertModel"),
    "NomicBertModel": ("bert_with_rope", "NomicBertModel"),
    "Phi3ForCausalLM": ("phi3", "Phi3ForCausalLM"),
    "Qwen2Model": ("qwen2", "Qwen2ForCausalLM"),
    "Qwen2ForCausalLM": ("qwen2", "Qwen2ForCausalLM"),
    "Qwen2ForRewardModel": ("qwen2_rm", "Qwen2ForRewardModel"),
    "Qwen2ForProcessRewardModel": ("qwen2_rm", "Qwen2ForProcessRewardModel"),
    "RobertaForMaskedLM": ("roberta", "RobertaEmbeddingModel"),
    "RobertaModel": ("roberta", "RobertaEmbeddingModel"),
    "TeleChatForCausalLM": ("telechat2", "TeleChat2ForCausalLM"),
    "TeleChat2ForCausalLM": ("telechat2", "TeleChat2ForCausalLM"),
    "XLMRobertaModel": ("roberta", "RobertaEmbeddingModel"),
    "BgeM3EmbeddingModel": ("roberta", "BgeM3EmbeddingModel"),
    # [Multimodal]
    "CLIPModel": ("clip", "CLIPEmbeddingModel"),
    "LlavaNextForConditionalGeneration": (
        "llava_next",
        "LlavaNextForConditionalGeneration",
    ),
    "Phi3VForCausalLM": ("phi3v", "Phi3VForCausalLM"),
    "Qwen2VLForConditionalGeneration": ("qwen2_vl", "Qwen2VLForConditionalGeneration"),  # noqa: E501
    "SiglipModel": ("siglip", "SiglipEmbeddingModel"),
    # Technically Terratorch models work on images, both in
    # input and output. I am adding it here because it piggy-backs on embedding
    # models for the time being.
    "PrithviGeoSpatialMAE": ("terratorch", "Terratorch"),
    "Terratorch": ("terratorch", "Terratorch"),
}

256
257
_CROSS_ENCODER_MODELS = {
    "BertForSequenceClassification": ("bert", "BertForSequenceClassification"),
258
    "BertForTokenClassification": ("bert", "BertForTokenClassification"),
259
260
261
262
    "GteNewForSequenceClassification": (
        "bert_with_rope",
        "GteNewForSequenceClassification",
    ),
263
264
265
266
267
    "JinaVLForRanking": ("jina_vl", "JinaVLForSequenceClassification"),
    "LlamaBidirectionalForSequenceClassification": (
        "llama",
        "LlamaBidirectionalForSequenceClassification",
    ),
268
269
270
271
    "ModernBertForSequenceClassification": (
        "modernbert",
        "ModernBertForSequenceClassification",
    ),
272
273
274
275
    "ModernBertForTokenClassification": (
        "modernbert",
        "ModernBertForTokenClassification",
    ),
276
277
278
279
280
    "RobertaForSequenceClassification": ("roberta", "RobertaForSequenceClassification"),
    "XLMRobertaForSequenceClassification": (
        "roberta",
        "RobertaForSequenceClassification",
    ),
281
282
}

283
_MULTIMODAL_MODELS = {
284
    # [Decoder-only]
285
    "AriaForConditionalGeneration": ("aria", "AriaForConditionalGeneration"),
286
287
288
289
    "AudioFlamingo3ForConditionalGeneration": (
        "audioflamingo3",
        "AudioFlamingo3ForConditionalGeneration",
    ),
290
291
292
    "AyaVisionForConditionalGeneration": (
        "aya_vision",
        "AyaVisionForConditionalGeneration",
293
    ),
294
    "BagelForConditionalGeneration": ("bagel", "BagelForConditionalGeneration"),
295
    "BeeForConditionalGeneration": ("bee", "BeeForConditionalGeneration"),
296
    "Blip2ForConditionalGeneration": ("blip2", "Blip2ForConditionalGeneration"),
297
298
299
    "ChameleonForConditionalGeneration": (
        "chameleon",
        "ChameleonForConditionalGeneration",
300
    ),
301
302
303
    "Cohere2VisionForConditionalGeneration": (
        "cohere2_vision",
        "Cohere2VisionForConditionalGeneration",
304
    ),
305
    "DeepseekVLV2ForCausalLM": ("deepseek_vl2", "DeepseekVLV2ForCausalLM"),
306
    "DeepseekOCRForCausalLM": ("deepseek_ocr", "DeepseekOCRForCausalLM"),
Roger Wang's avatar
Roger Wang committed
307
    "DotsOCRForCausalLM": ("dots_ocr", "DotsOCRForCausalLM"),
308
309
310
311
    "Eagle2_5_VLForConditionalGeneration": (
        "eagle2_5_vl",
        "Eagle2_5_VLForConditionalGeneration",
    ),
312
313
314
    "Ernie4_5_VLMoeForConditionalGeneration": (
        "ernie45_vl",
        "Ernie4_5_VLMoeForConditionalGeneration",
315
    ),
316
    "FuyuForCausalLM": ("fuyu", "FuyuForCausalLM"),
317
    "Gemma3ForConditionalGeneration": ("gemma3_mm", "Gemma3ForConditionalGeneration"),  # noqa: E501
318
319
320
    "Gemma3nForConditionalGeneration": (
        "gemma3n_mm",
        "Gemma3nForConditionalGeneration",
321
    ),
322
    "GlmAsrForConditionalGeneration": ("glmasr", "GlmAsrForConditionalGeneration"),
323
    "GLM4VForCausalLM": ("glm4v", "GLM4VForCausalLM"),
324
325
326
    "Glm4vForConditionalGeneration": ("glm4_1v", "Glm4vForConditionalGeneration"),
    "Glm4vMoeForConditionalGeneration": ("glm4_1v", "Glm4vMoeForConditionalGeneration"),
    "GlmOcrForConditionalGeneration": ("glm_ocr", "GlmOcrForConditionalGeneration"),  # noqa: E501
327
328
329
    "GraniteSpeechForConditionalGeneration": (
        "granite_speech",
        "GraniteSpeechForConditionalGeneration",
330
    ),
331
    "H2OVLChatModel": ("h2ovl", "H2OVLChatModel"),
332
333
334
335
    "HunYuanVLForConditionalGeneration": (
        "hunyuan_vision",
        "HunYuanVLForConditionalGeneration",
    ),
ltd0924's avatar
ltd0924 committed
336
    "StepVLForConditionalGeneration": ("step_vl", "StepVLForConditionalGeneration"),
337
    "InternVLChatModel": ("internvl", "InternVLChatModel"),
338
    "NemotronH_Nano_VL_V2": ("nano_nemotron_vl", "NemotronH_Nano_VL_V2"),
Zero's avatar
Zero committed
339
340
341
342
    "OpenCUAForConditionalGeneration": (
        "opencua",
        "OpenCUAForConditionalGeneration",
    ),
343
344
345
    "InternS1ForConditionalGeneration": (
        "interns1",
        "InternS1ForConditionalGeneration",
346
    ),
347
348
349
    "InternVLForConditionalGeneration": (
        "interns1",
        "InternS1ForConditionalGeneration",
350
    ),
351
352
353
354
    "Idefics3ForConditionalGeneration": (
        "idefics3",
        "Idefics3ForConditionalGeneration",
    ),
oscardev256's avatar
oscardev256 committed
355
    "IsaacForConditionalGeneration": ("isaac", "IsaacForConditionalGeneration"),
356
    "SmolVLMForConditionalGeneration": ("smolvlm", "SmolVLMForConditionalGeneration"),  # noqa: E501
357
    "KananaVForConditionalGeneration": ("kanana_v", "KananaVForConditionalGeneration"),
358
    "KeyeForConditionalGeneration": ("keye", "KeyeForConditionalGeneration"),
359
360
361
    "KeyeVL1_5ForConditionalGeneration": (
        "keye_vl1_5",
        "KeyeVL1_5ForConditionalGeneration",
362
    ),
363
    "RForConditionalGeneration": ("rvl", "RForConditionalGeneration"),
364
    "KimiVLForConditionalGeneration": ("kimi_vl", "KimiVLForConditionalGeneration"),  # noqa: E501
Roger Wang's avatar
Roger Wang committed
365
    "KimiK25ForConditionalGeneration": ("kimi_k25", "KimiK25ForConditionalGeneration"),  # noqa: E501
366
367
368
369
    "LightOnOCRForConditionalGeneration": (
        "lightonocr",
        "LightOnOCRForConditionalGeneration",
    ),
370
    "Lfm2VlForConditionalGeneration": ("lfm2_vl", "Lfm2VLForConditionalGeneration"),
371
    "Llama_Nemotron_Nano_VL": ("nemotron_vl", "LlamaNemotronVLChatModel"),
372
    "Llama4ForConditionalGeneration": ("mllama4", "Llama4ForConditionalGeneration"),  # noqa: E501
373
    "LlavaForConditionalGeneration": ("llava", "LlavaForConditionalGeneration"),
374
375
376
    "LlavaNextForConditionalGeneration": (
        "llava_next",
        "LlavaNextForConditionalGeneration",
377
    ),
378
379
380
    "LlavaNextVideoForConditionalGeneration": (
        "llava_next_video",
        "LlavaNextVideoForConditionalGeneration",
381
    ),
382
383
384
    "LlavaOnevisionForConditionalGeneration": (
        "llava_onevision",
        "LlavaOnevisionForConditionalGeneration",
385
    ),
386
    "MantisForConditionalGeneration": ("llava", "MantisForConditionalGeneration"),  # noqa: E501
387
    "MiDashengLMModel": ("midashenglm", "MiDashengLMModel"),
388
389
390
    "MiniMaxVL01ForConditionalGeneration": (
        "minimax_vl_01",
        "MiniMaxVL01ForConditionalGeneration",
391
    ),
392
    "MiniCPMO": ("minicpmo", "MiniCPMO"),
393
    "MiniCPMV": ("minicpmv", "MiniCPMV"),
394
395
396
    "Mistral3ForConditionalGeneration": (
        "mistral3",
        "Mistral3ForConditionalGeneration",
397
    ),
398
    "MolmoForCausalLM": ("molmo", "MolmoForCausalLM"),
399
    "Molmo2ForConditionalGeneration": ("molmo2", "Molmo2ForConditionalGeneration"),
400
    "NVLM_D": ("nvlm_d", "NVLM_D_Model"),
401
    "Ovis": ("ovis", "Ovis"),
402
    "Ovis2_5": ("ovis2_5", "Ovis2_5"),
403
404
405
406
    "PaddleOCRVLForConditionalGeneration": (
        "paddleocr_vl",
        "PaddleOCRVLForConditionalGeneration",
    ),
407
408
409
410
    "PaliGemmaForConditionalGeneration": (
        "paligemma",
        "PaliGemmaForConditionalGeneration",
    ),
411
    "Phi3VForCausalLM": ("phi3v", "Phi3VForCausalLM"),
412
    "Phi4MMForCausalLM": ("phi4mm", "Phi4MMForCausalLM"),
413
    "PixtralForConditionalGeneration": ("pixtral", "PixtralForConditionalGeneration"),  # noqa: E501
414
    "QwenVLForConditionalGeneration": ("qwen_vl", "QwenVLForConditionalGeneration"),  # noqa: E501
415
    "Qwen2VLForConditionalGeneration": ("qwen2_vl", "Qwen2VLForConditionalGeneration"),  # noqa: E501
416
417
418
    "Qwen2_5_VLForConditionalGeneration": (
        "qwen2_5_vl",
        "Qwen2_5_VLForConditionalGeneration",
419
    ),
420
421
422
    "Qwen2AudioForConditionalGeneration": (
        "qwen2_audio",
        "Qwen2AudioForConditionalGeneration",
423
    ),
424
425
426
    "Qwen2_5OmniModel": (
        "qwen2_5_omni_thinker",
        "Qwen2_5OmniThinkerForConditionalGeneration",
427
    ),
428
429
430
    "Qwen2_5OmniForConditionalGeneration": (
        "qwen2_5_omni_thinker",
        "Qwen2_5OmniThinkerForConditionalGeneration",
431
    ),
432
433
434
435
    "Qwen3OmniMoeForConditionalGeneration": (
        "qwen3_omni_moe_thinker",
        "Qwen3OmniMoeThinkerForConditionalGeneration",
    ),
436
    "Qwen3VLForConditionalGeneration": ("qwen3_vl", "Qwen3VLForConditionalGeneration"),  # noqa: E501
437
438
439
    "Qwen3VLMoeForConditionalGeneration": (
        "qwen3_vl_moe",
        "Qwen3VLMoeForConditionalGeneration",
440
    ),
441
    "SkyworkR1VChatModel": ("skyworkr1v", "SkyworkR1VChatModel"),
zhuwenwen's avatar
zhuwenwen committed
442
    "Step3VLForConditionalGeneration": ("step3_vl", "Step3VLForConditionalGeneration"),  # noqa: E501
汪志鹏's avatar
汪志鹏 committed
443
    "TarsierForConditionalGeneration": ("tarsier", "TarsierForConditionalGeneration"),  # noqa: E501
444
445
446
    "Tarsier2ForConditionalGeneration": (
        "qwen2_vl",
        "Tarsier2ForConditionalGeneration",
447
    ),
448
    "UltravoxModel": ("ultravox", "UltravoxModel"),
Patrick von Platen's avatar
Patrick von Platen committed
449
    "VoxtralForConditionalGeneration": ("voxtral", "VoxtralForConditionalGeneration"),  # noqa: E501
Patrick von Platen's avatar
Patrick von Platen committed
450
    "VoxtralStreamingGeneration": ("voxtral_streaming", "VoxtralStreamingGeneration"),  # noqa: E501
451
    # [Encoder-decoder]
452
453
454
455
    "NemotronParseForConditionalGeneration": (
        "nemotron_parse",
        "NemotronParseForConditionalGeneration",
    ),
456
    "WhisperForConditionalGeneration": ("whisper", "WhisperForConditionalGeneration"),  # noqa: E501
457
}
458
459

_SPECULATIVE_DECODING_MODELS = {
460
    "MiMoMTPModel": ("mimo_mtp", "MiMoMTP"),
461
    "EagleLlamaForCausalLM": ("llama_eagle", "EagleLlamaForCausalLM"),
zhiweiz's avatar
zhiweiz committed
462
    "EagleLlama4ForCausalLM": ("llama4_eagle", "EagleLlama4ForCausalLM"),
463
    "EagleMiniCPMForCausalLM": ("minicpm_eagle", "EagleMiniCPMForCausalLM"),
464
    "Eagle3LlamaForCausalLM": ("llama_eagle3", "Eagle3LlamaForCausalLM"),
465
    "LlamaForCausalLMEagle3": ("llama_eagle3", "Eagle3LlamaForCausalLM"),
466
    "Eagle3Qwen2_5vlForCausalLM": ("llama_eagle3", "Eagle3LlamaForCausalLM"),
467
    "Eagle3Qwen3vlForCausalLM": ("llama_eagle3", "Eagle3LlamaForCausalLM"),
468
469
470
471
    "EagleMistralLarge3ForCausalLM": (
        "mistral_large_3_eagle",
        "EagleMistralLarge3ForCausalLM",
    ),
472
    "EagleDeepSeekMTPModel": ("deepseek_eagle", "EagleDeepseekV3ForCausalLM"),
473
    "DeepSeekMTPModel": ("deepseek_mtp", "DeepSeekMTP"),
474
    "ErnieMTPModel": ("ernie_mtp", "ErnieMTP"),
Kyungmin Lee's avatar
Kyungmin Lee committed
475
    "ExaoneMoeMTP": ("exaone_moe_mtp", "ExaoneMoeMTP"),
zhuwenwen's avatar
zhuwenwen committed
476
    "LongCatFlashMTPModel": ("longcat_flash_mtp", "LongCatFlashMTP"),
zhuwenwen's avatar
zhuwenwen committed
477
    "Glm4MoeMTPModel": ("glm4_moe_mtp", "Glm4MoeMTP"),
478
    "Glm4MoeLiteMTPModel": ("glm4_moe_lite_mtp", "Glm4MoeLiteMTP"),
479
    "GlmOcrMTPModel": ("glm_ocr_mtp", "GlmOcrMTP"),
480
    "MedusaModel": ("medusa", "Medusa"),
481
    "OpenPanguMTPModel": ("openpangu_mtp", "OpenPanguMTP"),
482
    "Qwen3NextMTP": ("qwen3_next_mtp", "Qwen3NextMTP"),
csy0225's avatar
csy0225 committed
483
    "Step3p5MTP": ("step3p5_mtp", "Step3p5MTP"),
484
485
486
    # Temporarily disabled.
    # # TODO(woosuk): Re-enable this once the MLP Speculator is supported in V1.
    # "MLPSpeculatorPreTrainedModel": ("mlp_speculator", "MLPSpeculator"),
487
}
488

489
_TRANSFORMERS_SUPPORTED_MODELS = {
490
491
492
    # Text generation models
    "SmolLM3ForCausalLM": ("transformers", "TransformersForCausalLM"),
    # Multimodal models
493
494
495
496
    "Emu3ForConditionalGeneration": (
        "transformers",
        "TransformersMultiModalForCausalLM",
    ),
497
498
499
}

_TRANSFORMERS_BACKEND_MODELS = {
500
    # Text generation models
501
    "TransformersForCausalLM": ("transformers", "TransformersForCausalLM"),
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
    "TransformersMoEForCausalLM": ("transformers", "TransformersMoEForCausalLM"),
    # Multimodal models
    "TransformersMultiModalForCausalLM": (
        "transformers",
        "TransformersMultiModalForCausalLM",
    ),
    "TransformersMultiModalMoEForCausalLM": (
        "transformers",
        "TransformersMultiModalMoEForCausalLM",
    ),
    # Embedding models
    "TransformersEmbeddingModel": ("transformers", "TransformersEmbeddingModel"),
    "TransformersMoEEmbeddingModel": ("transformers", "TransformersMoEEmbeddingModel"),
    "TransformersMultiModalEmbeddingModel": (
        "transformers",
        "TransformersMultiModalEmbeddingModel",
    ),
    # Sequence classification models
520
    "TransformersForSequenceClassification": (
521
        "transformers",
522
        "TransformersForSequenceClassification",
523
    ),
524
    "TransformersMoEForSequenceClassification": (
525
        "transformers",
526
        "TransformersMoEForSequenceClassification",
527
    ),
528
529
530
    "TransformersMultiModalForSequenceClassification": (
        "transformers",
        "TransformersMultiModalForSequenceClassification",
531
    ),
532
}
533

534
# Union of all per-category tables above. On duplicate architecture names
# the later table wins, because dict unpacking overwrites earlier keys.
_VLLM_MODELS = {
    **_TEXT_GENERATION_MODELS,
    **_EMBEDDING_MODELS,
    **_CROSS_ENCODER_MODELS,
    **_MULTIMODAL_MODELS,
    **_SPECULATIVE_DECODING_MODELS,
    **_TRANSFORMERS_SUPPORTED_MODELS,
    **_TRANSFORMERS_BACKEND_MODELS,
}

544
545
546
547
# This variable is used as the args for subprocess.run(). We
# can modify  this variable to alter the args if needed. e.g.
# when we use par format to pack things together, sys.executable
# might not be the target we want to run.
548
_SUBPROCESS_COMMAND = [sys.executable, "-m", "vllm.model_executor.models.registry"]
549

550
_PREVIOUSLY_SUPPORTED_MODELS = {
551
    "MotifForCausalLM": "0.10.2",
552
    "Phi3SmallForCausalLM": "0.9.2",
553
    "Phi4FlashForCausalLM": "0.10.2",
554
    "Phi4MultimodalForCausalLM": "0.12.0",
555
556
557
558
559
560
561
562
563
    # encoder-decoder models except whisper
    # have been removed for V0 deprecation.
    "BartModel": "0.10.2",
    "BartForConditionalGeneration": "0.10.2",
    "DonutForConditionalGeneration": "0.10.2",
    "Florence2ForConditionalGeneration": "0.10.2",
    "MBartForConditionalGeneration": "0.10.2",
    "MllamaForConditionalGeneration": "0.10.2",
}
564

565

566
567
@dataclass(frozen=True)
class _ModelInfo:
    """Immutable record of the properties inspected from a model class.

    Every field except `architecture` is filled by calling the
    identically named helper (or `get_<field>` helper) from the interfaces
    modules on the model class; see `from_model_cls`.
    """

    # Field order is kept stable on purpose: instances are rebuilt from
    # keyword dictionaries (`_ModelInfo(**d)`) elsewhere in this file.
    architecture: str
    is_text_generation_model: bool
    is_pooling_model: bool
    attn_type: AttnTypeStr
    default_seq_pooling_type: SequencePoolingType
    default_tok_pooling_type: TokenPoolingType
    supports_cross_encoding: bool
    supports_multimodal: bool
    supports_multimodal_raw_input_only: bool
    requires_raw_input_tokens: bool
    supports_multimodal_encoder_tp_data: bool
    supports_pp: bool
    has_inner_state: bool
    is_attention_free: bool
    is_hybrid: bool
    has_noops: bool
    supports_mamba_prefix_caching: bool
    supports_transcription: bool
    supports_transcription_only: bool

    @staticmethod
    def from_model_cls(model: type[nn.Module]) -> "_ModelInfo":
        """Build a `_ModelInfo` by inspecting the given model class.

        Args:
            model: The model class (not an instance) to inspect.

        Returns:
            A frozen `_ModelInfo` whose `architecture` is `model.__name__`.
        """
        return _ModelInfo(
            architecture=model.__name__,
            is_text_generation_model=is_text_generation_model(model),
            is_pooling_model=is_pooling_model(model),
            default_seq_pooling_type=get_default_seq_pooling_type(model),
            default_tok_pooling_type=get_default_tok_pooling_type(model),
            attn_type=get_attn_type(model),
            supports_cross_encoding=supports_cross_encoding(model),
            supports_multimodal=supports_multimodal(model),
            supports_multimodal_raw_input_only=supports_multimodal_raw_input_only(
                model
            ),
            requires_raw_input_tokens=requires_raw_input_tokens(model),
            supports_multimodal_encoder_tp_data=supports_multimodal_encoder_tp_data(
                model
            ),
            supports_pp=supports_pp(model),
            has_inner_state=has_inner_state(model),
            is_attention_free=is_attention_free(model),
            is_hybrid=is_hybrid(model),
            supports_mamba_prefix_caching=supports_mamba_prefix_caching(model),
            supports_transcription=supports_transcription(model),
            # The class attribute is only read once the model is known to
            # support transcription; `and` short-circuits otherwise.
            supports_transcription_only=(
                supports_transcription(model) and model.supports_transcription_only
            ),
            has_noops=has_noops(model),
        )
617
618


619
620
621
622
class _BaseRegisteredModel(ABC):
    """Abstract interface shared by the registered-model representations."""

    @abstractmethod
    def inspect_model_cls(self) -> _ModelInfo:
        """Return the `_ModelInfo` describing the registered model class."""
        raise NotImplementedError

    @abstractmethod
    def load_model_cls(self) -> type[nn.Module]:
        """Return the registered model class itself."""
        raise NotImplementedError
627
628


629
630
631
632
633
634
635
@dataclass(frozen=True)
class _RegisteredModel(_BaseRegisteredModel):
    """
    Represents a model that has already been imported in the main process.
    """

    # Properties inspected from `model_cls` when the instance was built.
    interfaces: _ModelInfo
    # The imported model class.
    model_cls: type[nn.Module]

    @staticmethod
    def from_model_cls(model_cls: type[nn.Module]) -> "_RegisteredModel":
        """Wrap an imported model class together with its inspected info.

        Args:
            model_cls: The model class to register.

        Returns:
            A `_RegisteredModel` holding `model_cls` and the `_ModelInfo`
            produced by `_ModelInfo.from_model_cls(model_cls)`.
        """
        return _RegisteredModel(
            interfaces=_ModelInfo.from_model_cls(model_cls),
            model_cls=model_cls,
        )

    def inspect_model_cls(self) -> _ModelInfo:
        """Return the `_ModelInfo` captured at construction time."""
        return self.interfaces

    def load_model_cls(self) -> type[nn.Module]:
        """Return the stored model class; no import is needed."""
        return self.model_cls


@dataclass(frozen=True)
class _LazyRegisteredModel(_BaseRegisteredModel):
    """
    Registry entry for a model that has not been imported in the main
    process; it is identified by module and class name and imported on demand.
    """

    # Dotted module path passed to `importlib.import_module`.
    module_name: str
    # Attribute looked up on the imported module.
    class_name: str

    @staticmethod
    def _get_cache_dir() -> Path:
        """Return the directory holding the cached model info files."""
        cache_root = Path(envs.VLLM_CACHE_ROOT)
        return cache_root / "modelinfos"

    def _get_cache_filename(self) -> str:
        """Return the cache file name built from the module and class names."""
        qualified_name = f"{self.module_name}-{self.class_name}"
        return qualified_name.replace(".", "-") + ".json"

    def _load_modelinfo_from_cache(self, module_hash: str) -> _ModelInfo | None:
        """
        Return the cached model info if it was stored for `module_hash`.

        `None` is returned when the cache file does not exist, when the hash
        stored in it differs from `module_hash`, or when reading or decoding
        the file fails for any other reason. Nothing is raised to the caller.
        """
        try:
            try:
                cache_file = self._get_cache_dir() / self._get_cache_filename()
                cached = json.loads(cache_file.read_text(encoding="utf-8"))
            except FileNotFoundError:
                logger.debug(
                    "Cached model info file for class %s.%s not found",
                    self.module_name,
                    self.class_name,
                )
                return None

            if cached["hash"] == module_hash:
                # file not changed, use cached _ModelInfo properties
                return _ModelInfo(**cached["modelinfo"])

            logger.debug(
                "Cached model info file for class %s.%s is stale",
                self.module_name,
                self.class_name,
            )
            return None
        except Exception:
            # Best effort: an unusable cache entry is treated as a miss.
            logger.debug(
                "Cached model info for class %s.%s error. ",
                self.module_name,
                self.class_name,
            )
            return None

    def _save_modelinfo_to_cache(self, mi: _ModelInfo, module_hash: str) -> None:
        """
        Write `mi` and `module_hash` to this entry's JSON cache file.

        Failures are logged and swallowed; saving the cache is best effort.
        """
        from vllm.model_executor.model_loader.weight_utils import atomic_writer

        try:
            payload = {"hash": module_hash, "modelinfo": asdict(mi)}
            target_dir = self._get_cache_dir()
            target_dir.mkdir(parents=True, exist_ok=True)
            target_file = target_dir / self._get_cache_filename()
            with atomic_writer(target_file, encoding="utf-8") as out:
                json.dump(payload, out, indent=2)
        except Exception:
            logger.exception("Error saving model info cache.")

    @logtime(logger=logger, msg="Registry inspect model class")
    def inspect_model_cls(self) -> _ModelInfo:
        """
        Return the model info for this entry, using the on-disk cache if valid.

        The cache is keyed by a hash of the model's source file. That file is
        looked up next to this registry module, using only the last component
        of `module_name`. When no such file exists the cache is neither read
        nor written, and the model is always inspected in a subprocess.
        """
        module_stem = self.module_name.rsplit(".", 1)[-1]
        source_file = Path(__file__).parent / f"{module_stem}.py"

        source_hash = None
        if source_file.exists():
            source_hash = safe_hash(
                source_file.read_bytes(), usedforsecurity=False
            ).hexdigest()

            cached_info = self._load_modelinfo_from_cache(source_hash)
            if cached_info is not None:
                logger.debug(
                    "Loaded model info for class %s.%s from cache",
                    self.module_name,
                    self.class_name,
                )
                return cached_info

            logger.debug(
                "Cache model info for class %s.%s miss. Loading model instead.",
                self.module_name,
                self.class_name,
            )

        # Performed in another process to avoid initializing CUDA
        fresh_info = _run_in_subprocess(
            lambda: _ModelInfo.from_model_cls(self.load_model_cls())
        )
        logger.debug(
            "Loaded model info for class %s.%s", self.module_name, self.class_name
        )

        # save cache file
        if source_hash is not None:
            self._save_modelinfo_to_cache(fresh_info, source_hash)

        return fresh_info

    def load_model_cls(self) -> type[nn.Module]:
        """Import `module_name` and return its `class_name` attribute."""
        module = importlib.import_module(self.module_name)
        return getattr(module, self.class_name)


@lru_cache(maxsize=128)
def _try_load_model_cls(
    model_arch: str,
    model: _BaseRegisteredModel,
) -> type[nn.Module] | None:
    """
    Load the model class of `model`, returning `None` if loading fails.

    The platform check runs before the guarded section, so anything it raises
    propagates to the caller. Return values, including `None`, are memoized
    per `(model_arch, model)` pair by `lru_cache`.
    """
    from vllm.platforms import current_platform

    current_platform.verify_model_arch(model_arch)

    try:
        loaded_cls = model.load_model_cls()
    except Exception:
        logger.exception("Error in loading model architecture '%s'", model_arch)
        loaded_cls = None

    return loaded_cls
774
775


776
777
778
779
@lru_cache(maxsize=128)
def _try_inspect_model_cls(
    model_arch: str,
    model: _BaseRegisteredModel,
) -> _ModelInfo | None:
    """
    Inspect `model`, returning `None` (after logging) if inspection fails.

    Return values, including `None`, are memoized per `(model_arch, model)`
    pair by `lru_cache`.
    """
    try:
        inspected = model.inspect_model_cls()
    except Exception:
        logger.exception("Error in inspecting model architecture '%s'", model_arch)
        inspected = None

    return inspected
786
787


788
789
790
@dataclass
class _ModelRegistry:
    """
    Maps architecture names to registered models and resolves the
    architectures named by a model config to one of those entries.
    """

    # Keyed by model_arch
    models: dict[str, _BaseRegisteredModel] = field(default_factory=dict)

    def get_supported_archs(self) -> Set[str]:
        """Return a live view of all registered architecture names."""
        return self.models.keys()

    def register_model(
        self,
        model_arch: str,
        model_cls: type[nn.Module] | str,
    ) -> None:
        """
        Register an external model to be used in vLLM.

        `model_cls` can be either:

        - A [`torch.nn.Module`][] class directly referencing the model.
        - A string in the format `<module>:<class>` which can be used to
          lazily import the model. This is useful to avoid initializing CUDA
          when importing the model and thus the related error
          `RuntimeError: Cannot re-initialize CUDA in forked subprocess`.

        Registering an architecture that already exists logs a warning and
        overwrites the previous entry.

        Raises:
            TypeError: If `model_arch` is not a string, or `model_cls` is
                neither a string nor an `nn.Module` subclass.
            ValueError: If `model_cls` is a string that is not in the
                `<module>:<class>` format.
        """
        if not isinstance(model_arch, str):
            msg = f"`model_arch` should be a string, not a {type(model_arch)}"
            raise TypeError(msg)

        if model_arch in self.models:
            logger.warning(
                "Model architecture %s is already registered, and will be "
                "overwritten by the new model class %s.",
                model_arch,
                model_cls,
            )

        if isinstance(model_cls, str):
            split_str = model_cls.split(":")
            if len(split_str) != 2:
                msg = "Expected a string in the format `<module>:<class>`"
                raise ValueError(msg)

            model = _LazyRegisteredModel(*split_str)
        elif isinstance(model_cls, type) and issubclass(model_cls, nn.Module):
            model = _RegisteredModel.from_model_cls(model_cls)
        else:
            # Report the type of the offending argument (`model_cls`);
            # `model_arch` has already been validated above.
            msg = (
                "`model_cls` should be a string or PyTorch model class, "
                f"not a {type(model_cls)}"
            )
            raise TypeError(msg)

        self.models[model_arch] = model

    def _raise_for_unsupported(self, architectures: list[str]):
        """
        Raise a `ValueError` explaining why `architectures` cannot be used.

        This method never returns. The message distinguishes three cases:
        a registered architecture that failed to load or inspect, an
        architecture that was supported by an earlier vLLM version, and an
        architecture that is simply unknown.
        """
        all_supported_archs = self.get_supported_archs()

        # At least one architecture is registered, so reaching this point
        # means loading/inspecting it failed rather than it being unknown.
        if any(arch in all_supported_archs for arch in architectures):
            raise ValueError(
                f"Model architectures {architectures} failed "
                "to be inspected. Please check the logs for more details."
            )

        for arch in architectures:
            if arch in _PREVIOUSLY_SUPPORTED_MODELS:
                previous_version = _PREVIOUSLY_SUPPORTED_MODELS[arch]

                raise ValueError(
                    f"Model architecture {arch} was supported in vLLM until "
                    f"v{previous_version}, and is not supported anymore. "
                    "Please use an older version of vLLM if you want to "
                    "use this model architecture."
                )

        raise ValueError(
            f"Model architectures {architectures} are not supported for now. "
            f"Supported architectures: {all_supported_archs}"
        )

    def _try_load_model_cls(self, model_arch: str) -> type[nn.Module] | None:
        """Load the class registered for `model_arch`, or return `None`."""
        if model_arch not in self.models:
            return None

        return _try_load_model_cls(model_arch, self.models[model_arch])

    def _try_inspect_model_cls(self, model_arch: str) -> _ModelInfo | None:
        """Inspect the model registered for `model_arch`, or return `None`."""
        if model_arch not in self.models:
            return None

        return _try_inspect_model_cls(model_arch, self.models[model_arch])

    def _try_resolve_transformers(
        self,
        architecture: str,
        model_config: ModelConfig,
    ) -> str | None:
        """
        Try to map `architecture` onto a Transformers backend architecture.

        Returns `architecture` unchanged when it is already one of the
        Transformers backend models, otherwise the result of
        `model_config._get_transformers_backend_cls()`. Returns `None` when
        no compatible Transformers implementation is found and
        `model_config.model_impl` is not `"transformers"`.

        Raises:
            ValueError: If `model_config.model_impl` is `"transformers"` and
                the model module cannot be found or is not compatible.
        """
        if architecture in _TRANSFORMERS_BACKEND_MODELS:
            return architecture

        auto_map: dict[str, str] = (
            getattr(model_config.hf_config, "auto_map", None) or dict()
        )

        # Make sure that config class is always initialized before model class,
        # otherwise the model class won't be able to access the config class,
        # the expected auto_map should have correct order like:
        # "auto_map": {
        #     "AutoConfig": "<your-repo-name>--<config-name>",
        #     "AutoModel": "<your-repo-name>--<config-name>",
        #     "AutoModelFor<Task>": "<your-repo-name>--<config-name>",
        # },
        for prefix in ("AutoConfig", "AutoModel"):
            for name, module in auto_map.items():
                if name.startswith(prefix):
                    # Called for its side effect only; the result is unused.
                    try_get_class_from_dynamic_module(
                        module,
                        model_config.model,
                        revision=model_config.revision,
                        trust_remote_code=model_config.trust_remote_code,
                        warn_on_fail=False,
                    )

        model_module = getattr(transformers, architecture, None)

        if model_module is None:
            for name, module in auto_map.items():
                if name.startswith("AutoModel"):
                    model_module = try_get_class_from_dynamic_module(
                        module,
                        model_config.model,
                        revision=model_config.revision,
                        trust_remote_code=model_config.trust_remote_code,
                        warn_on_fail=True,
                    )
                    if model_module is not None:
                        break
            else:
                # The loop ended without `break`: no "AutoModel*" entry
                # produced a usable class.
                if model_config.model_impl != "transformers":
                    return None

                raise ValueError(
                    f"Cannot find model module. {architecture!r} is not a "
                    "registered model in the Transformers library (only "
                    "relevant if the model is meant to be in Transformers) "
                    "and 'AutoModel' is not present in the model config's "
                    "'auto_map' (relevant if the model is custom)."
                )

        if not model_module.is_backend_compatible():
            if model_config.model_impl != "transformers":
                return None

            raise ValueError(
                f"The Transformers implementation of {architecture!r} "
                "is not compatible with vLLM."
            )

        return model_config._get_transformers_backend_cls()

    def _normalize_arch(
        self,
        architecture: str,
        model_config: ModelConfig,
    ) -> str:
        """
        Map `architecture` to a registered base architecture when possible.

        A registered name is returned unchanged. Otherwise, if the name
        matches one of the architecture defaults, each known suffix is
        substituted in turn and the first registered result is returned.
        When nothing matches, `architecture` is returned as is.
        """
        if architecture in self.models:
            return architecture

        # This may be called in order to resolve runner_type and convert_type
        # in the first place, in which case we consider the default match
        match = try_match_architecture_defaults(
            architecture,
            runner_type=getattr(model_config, "runner_type", None),
            convert_type=getattr(model_config, "convert_type", None),
        )
        if match:
            suffix, _ = match

            # Get the name of the base model to convert
            for repl_suffix, _ in iter_architecture_defaults():
                base_arch = architecture.replace(suffix, repl_suffix)
                if base_arch in self.models:
                    return base_arch

        return architecture

    def inspect_model_cls(
        self,
        architectures: str | list[str],
        model_config: ModelConfig,
    ) -> tuple[_ModelInfo, str]:
        """
        Resolve `architectures` to a model info and the architecture name.

        Resolution order: the implementation forced by
        `model_config.model_impl` ("transformers" or "terratorch"), the
        Transformers fallback when nothing is registered and no conversion is
        requested, each listed architecture after normalization, and finally
        the Transformers fallback again.

        In the per-architecture loop the name returned is the one listed in
        `architectures`, not the normalized one.

        Raises:
            ValueError: If `architectures` is empty or cannot be resolved.
        """
        if isinstance(architectures, str):
            architectures = [architectures]
        if not architectures:
            raise ValueError("No model architectures are specified")

        # Require transformers impl
        if model_config.model_impl == "transformers":
            arch = self._try_resolve_transformers(architectures[0], model_config)
            if arch is not None:
                model_info = self._try_inspect_model_cls(arch)
                if model_info is not None:
                    return (model_info, arch)
        elif model_config.model_impl == "terratorch":
            arch = "Terratorch"
            model_info = self._try_inspect_model_cls(arch)
            # Never hand back `None` as the model info: on failure fall
            # through like `resolve_model_cls` does for the same branch.
            if model_info is not None:
                return (model_info, arch)

        # Fallback to transformers impl (after resolving convert_type)
        if (
            all(arch not in self.models for arch in architectures)
            and model_config.model_impl == "auto"
            and getattr(model_config, "convert_type", "none") == "none"
        ):
            arch = self._try_resolve_transformers(architectures[0], model_config)
            if arch is not None:
                model_info = self._try_inspect_model_cls(arch)
                if model_info is not None:
                    return (model_info, arch)

        for arch in architectures:
            normalized_arch = self._normalize_arch(arch, model_config)
            model_info = self._try_inspect_model_cls(normalized_arch)
            if model_info is not None:
                return (model_info, arch)

        # Fallback to transformers impl (before resolving runner_type)
        if (
            all(arch not in self.models for arch in architectures)
            and model_config.model_impl == "auto"
        ):
            arch = self._try_resolve_transformers(architectures[0], model_config)
            if arch is not None:
                model_info = self._try_inspect_model_cls(arch)
                if model_info is not None:
                    return (model_info, arch)

        return self._raise_for_unsupported(architectures)

    def resolve_model_cls(
        self,
        architectures: str | list[str],
        model_config: ModelConfig,
    ) -> tuple[type[nn.Module], str]:
        """
        Resolve `architectures` to a model class and the architecture name.

        Follows the same resolution order as `inspect_model_cls`, but loads
        the model class instead of inspecting it.

        Raises:
            ValueError: If `architectures` is empty or cannot be resolved.
        """
        if isinstance(architectures, str):
            architectures = [architectures]
        if not architectures:
            raise ValueError("No model architectures are specified")

        # Require transformers impl
        if model_config.model_impl == "transformers":
            arch = self._try_resolve_transformers(architectures[0], model_config)
            if arch is not None:
                model_cls = self._try_load_model_cls(arch)
                if model_cls is not None:
                    return (model_cls, arch)
        elif model_config.model_impl == "terratorch":
            arch = "Terratorch"
            model_cls = self._try_load_model_cls(arch)
            if model_cls is not None:
                return (model_cls, arch)

        # Fallback to transformers impl (after resolving convert_type)
        if (
            all(arch not in self.models for arch in architectures)
            and model_config.model_impl == "auto"
            and getattr(model_config, "convert_type", "none") == "none"
        ):
            arch = self._try_resolve_transformers(architectures[0], model_config)
            if arch is not None:
                model_cls = self._try_load_model_cls(arch)
                if model_cls is not None:
                    return (model_cls, arch)

        for arch in architectures:
            normalized_arch = self._normalize_arch(arch, model_config)
            model_cls = self._try_load_model_cls(normalized_arch)
            if model_cls is not None:
                return (model_cls, arch)

        # Fallback to transformers impl (before resolving runner_type)
        if (
            all(arch not in self.models for arch in architectures)
            and model_config.model_impl == "auto"
        ):
            arch = self._try_resolve_transformers(architectures[0], model_config)
            if arch is not None:
                model_cls = self._try_load_model_cls(arch)
                if model_cls is not None:
                    return (model_cls, arch)

        return self._raise_for_unsupported(architectures)

    def is_text_generation_model(
        self,
        architectures: str | list[str],
        model_config: ModelConfig,
    ) -> bool:
        """Return the resolved model's `is_text_generation_model` flag."""
        model_info, _ = self.inspect_model_cls(architectures, model_config)
        return model_info.is_text_generation_model

    def is_pooling_model(
        self,
        architectures: str | list[str],
        model_config: ModelConfig,
    ) -> bool:
        """Return the resolved model's `is_pooling_model` flag."""
        model_info, _ = self.inspect_model_cls(architectures, model_config)
        return model_info.is_pooling_model

    def is_cross_encoder_model(
        self,
        architectures: str | list[str],
        model_config: ModelConfig,
    ) -> bool:
        """Return the resolved model's `supports_cross_encoding` flag."""
        model_info, _ = self.inspect_model_cls(architectures, model_config)
        return model_info.supports_cross_encoding

    def is_multimodal_model(
        self,
        architectures: str | list[str],
        model_config: ModelConfig,
    ) -> bool:
        """Return the resolved model's `supports_multimodal` flag."""
        model_info, _ = self.inspect_model_cls(architectures, model_config)
        return model_info.supports_multimodal

    def is_multimodal_raw_input_only_model(
        self,
        architectures: str | list[str],
        model_config: ModelConfig,
    ) -> bool:
        """Return the resolved model's `supports_multimodal_raw_input_only`."""
        model_info, _ = self.inspect_model_cls(architectures, model_config)
        return model_info.supports_multimodal_raw_input_only

    def is_pp_supported_model(
        self,
        architectures: str | list[str],
        model_config: ModelConfig,
    ) -> bool:
        """Return the resolved model's `supports_pp` flag."""
        model_info, _ = self.inspect_model_cls(architectures, model_config)
        return model_info.supports_pp

    def model_has_inner_state(
        self,
        architectures: str | list[str],
        model_config: ModelConfig,
    ) -> bool:
        """Return the resolved model's `has_inner_state` flag."""
        model_info, _ = self.inspect_model_cls(architectures, model_config)
        return model_info.has_inner_state

    def is_attention_free_model(
        self,
        architectures: str | list[str],
        model_config: ModelConfig,
    ) -> bool:
        """Return the resolved model's `is_attention_free` flag."""
        model_info, _ = self.inspect_model_cls(architectures, model_config)
        return model_info.is_attention_free

    def is_hybrid_model(
        self,
        architectures: str | list[str],
        model_config: ModelConfig,
    ) -> bool:
        """Return the resolved model's `is_hybrid` flag."""
        model_info, _ = self.inspect_model_cls(architectures, model_config)
        return model_info.is_hybrid

    def is_noops_model(
        self,
        architectures: str | list[str],
        model_config: ModelConfig,
    ) -> bool:
        """Return the resolved model's `has_noops` flag."""
        model_info, _ = self.inspect_model_cls(architectures, model_config)
        return model_info.has_noops

    def is_transcription_model(
        self,
        architectures: str | list[str],
        model_config: ModelConfig,
    ) -> bool:
        """Return the resolved model's `supports_transcription` flag."""
        model_info, _ = self.inspect_model_cls(architectures, model_config)
        return model_info.supports_transcription

    def is_transcription_only_model(
        self,
        architectures: str | list[str],
        model_config: ModelConfig,
    ) -> bool:
        """Return the resolved model's `supports_transcription_only` flag."""
        model_info, _ = self.inspect_model_cls(architectures, model_config)
        return model_info.supports_transcription_only

1175

1176
1177
1178
1179
1180
1181
1182
1183
1184
# Registry instance seeded with one lazily imported entry per architecture
# listed in `_VLLM_MODELS`; no model module is imported here.
ModelRegistry = _ModelRegistry(
    models={
        arch_name: _LazyRegisteredModel(
            module_name=f"vllm.model_executor.models.{module_stem}",
            class_name=model_cls_name,
        )
        for arch_name, (module_stem, model_cls_name) in _VLLM_MODELS.items()
    }
)

# Return type of the callable handed to `_run_in_subprocess`.
_T = TypeVar("_T")


def _run_in_subprocess(fn: Callable[[], _T]) -> _T:
    """
    Run `fn` in a separate process and return its unpickled result.

    `fn` and an output path are serialized with `cloudpickle` and sent to
    `_SUBPROCESS_COMMAND` on stdin. The child is expected to write the
    pickled result to that path (presumably via this module's `_run`;
    verify against `_SUBPROCESS_COMMAND`).

    Raises:
        RuntimeError: If the subprocess exits with a non-zero status; the
            message contains the child's stderr.
    """
    # NOTE: We use a temporary directory instead of a temporary file to avoid
    # issues like https://stackoverflow.com/questions/23212435/permission-denied-to-write-to-my-temporary-file
    with tempfile.TemporaryDirectory() as workdir:
        result_path = os.path.join(workdir, "registry_output.tmp")

        # `cloudpickle` allows pickling lambda functions directly
        import cloudpickle

        payload = cloudpickle.dumps((fn, result_path))

        # cannot use `sys.executable __file__` here because the script
        # contains relative imports
        completed = subprocess.run(
            _SUBPROCESS_COMMAND, input=payload, capture_output=True
        )

        # check if the subprocess is successful
        try:
            completed.check_returncode()
        except Exception as exc:
            # wrap raised exception to provide more information
            child_stderr = completed.stderr.decode()
            raise RuntimeError(
                f"Error raised in subprocess:\n{child_stderr}"
            ) from exc

        with open(result_path, "rb") as result_file:
            return pickle.load(result_file)


def _run() -> None:
    """
    Execute one pickled task read from stdin and store its pickled result.

    Stdin must contain a pickled `(fn, output_file)` pair. `fn` is called
    without arguments and its return value is pickled into `output_file`.
    """
    # Setup plugins
    from vllm.plugins import load_general_plugins

    load_general_plugins()

    # NOTE(review): unpickling executes arbitrary code, so stdin must only
    # ever be fed by a trusted parent process.
    payload = sys.stdin.buffer.read()
    task, result_path = pickle.loads(payload)

    outcome = task()

    with open(result_path, "wb") as result_file:
        result_file.write(pickle.dumps(outcome))
1231
1232
1233


# Run the pickled-task worker only when this module is executed as a script;
# importing the module must not trigger it.
if __name__ == "__main__":
    _run()