registry.py 48.9 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
4
5
6
"""
Whenever you add an architecture to this file, please also update
`tests/models/registry.py` with example HuggingFace models for it.
"""
7

8
import importlib
9
import json
10
import os
11
import pickle
12
13
import subprocess
import sys
14
import tempfile
15
from abc import ABC, abstractmethod
16
from collections.abc import Callable, Set
17
from dataclasses import asdict, dataclass, field
18
from functools import lru_cache
19
from pathlib import Path
20
from typing import TYPE_CHECKING, Any, TypeVar
21
22

import torch.nn as nn
23
import transformers
24

25
from vllm import envs
26
27
28
29
30
from vllm.config import (
    ModelConfig,
    iter_architecture_defaults,
    try_match_architecture_defaults,
)
31
from vllm.logger import init_logger
32
from vllm.logging_utils import logtime
33
from vllm.transformers_utils.dynamic_module import try_get_class_from_dynamic_module
34
from vllm.utils.hashing import safe_hash
35

36
37
if TYPE_CHECKING:
    from vllm.config.model import AttnTypeStr
38
    from vllm.config.pooler import SequencePoolingType, TokenPoolingType
39
40
else:
    AttnTypeStr = Any
41
42
    SequencePoolingType = Any
    TokenPoolingType = Any
43
44


45
46
47
48
49
from .interfaces import (
    has_inner_state,
    has_noops,
    is_attention_free,
    is_hybrid,
Patrick von Platen's avatar
Patrick von Platen committed
50
    requires_raw_input_tokens,
51
    supports_cross_encoding,
52
    supports_mamba_prefix_caching,
53
54
55
56
57
58
59
    supports_multimodal,
    supports_multimodal_encoder_tp_data,
    supports_multimodal_raw_input_only,
    supports_pp,
    supports_transcription,
)
from .interfaces_base import (
60
    get_attn_type,
61
62
    get_default_seq_pooling_type,
    get_default_tok_pooling_type,
63
64
65
    is_pooling_model,
    is_text_generation_model,
)
66
67
68

# Module-level logger, named after this module.
logger = init_logger(__name__)

69
70
_TEXT_GENERATION_MODELS = {
    # [Decoder-only]
71
    "AfmoeForCausalLM": ("afmoe", "AfmoeForCausalLM"),
72
    "ApertusForCausalLM": ("apertus", "ApertusForCausalLM"),
73
74
    "AquilaModel": ("llama", "LlamaForCausalLM"),
    "AquilaForCausalLM": ("llama", "LlamaForCausalLM"),  # AquilaChat2
Raghav Ravishankar's avatar
Raghav Ravishankar committed
75
    "ArceeForCausalLM": ("arcee", "ArceeForCausalLM"),
76
    "ArcticForCausalLM": ("arctic", "ArcticForCausalLM"),
77
78
79
80
    # baichuan-7b, upper case 'C' in the class name
    "BaiChuanForCausalLM": ("baichuan", "BaiChuanForCausalLM"),
    # baichuan-13b, lower case 'c' in the class name
    "BaichuanForCausalLM": ("baichuan", "BaichuanForCausalLM"),
81
    "BailingMoeForCausalLM": ("bailing_moe", "BailingMoeForCausalLM"),
ant-yy's avatar
ant-yy committed
82
    "BailingMoeV2ForCausalLM": ("bailing_moe", "BailingMoeV2ForCausalLM"),
Yu Chin Fabian Lim's avatar
Yu Chin Fabian Lim committed
83
    "BambaForCausalLM": ("bamba", "BambaForCausalLM"),
84
    "BloomForCausalLM": ("bloom", "BloomForCausalLM"),
85
    "ChatGLMModel": ("chatglm", "ChatGLMForCausalLM"),
86
    "ChatGLMForConditionalGeneration": ("chatglm", "ChatGLMForCausalLM"),
87
    "CohereForCausalLM": ("commandr", "CohereForCausalLM"),
88
    "Cohere2ForCausalLM": ("commandr", "CohereForCausalLM"),
89
    "CwmForCausalLM": ("llama", "LlamaForCausalLM"),
90
    "DbrxForCausalLM": ("dbrx", "DbrxForCausalLM"),
91
    "DeciLMForCausalLM": ("nemotron_nas", "DeciLMForCausalLM"),
92
    "DeepseekForCausalLM": ("deepseek_v2", "DeepseekForCausalLM"),
93
    "DeepseekV2ForCausalLM": ("deepseek_v2", "DeepseekV2ForCausalLM"),
94
    "DeepseekV3ForCausalLM": ("deepseek_v2", "DeepseekV3ForCausalLM"),
95
    "DeepseekV32ForCausalLM": ("deepseek_v2", "DeepseekV3ForCausalLM"),
96
    "Dots1ForCausalLM": ("dots1", "Dots1ForCausalLM"),
97
    "Ernie4_5ForCausalLM": ("ernie45", "Ernie4_5ForCausalLM"),
98
    "Ernie4_5_MoeForCausalLM": ("ernie45_moe", "Ernie4_5_MoeForCausalLM"),
99
    "ExaoneForCausalLM": ("exaone", "ExaoneForCausalLM"),
100
    "Exaone4ForCausalLM": ("exaone4", "Exaone4ForCausalLM"),
Kyungmin Lee's avatar
Kyungmin Lee committed
101
    "ExaoneMoEForCausalLM": ("exaone_moe", "ExaoneMoeForCausalLM"),
102
    "Fairseq2LlamaForCausalLM": ("fairseq2_llama", "Fairseq2LlamaForCausalLM"),
103
104
105
    "FalconForCausalLM": ("falcon", "FalconForCausalLM"),
    "FalconMambaForCausalLM": ("mamba", "MambaForCausalLM"),
    "FalconH1ForCausalLM": ("falcon_h1", "FalconH1ForCausalLM"),
106
    "FlexOlmoForCausalLM": ("flex_olmo", "FlexOlmoForCausalLM"),
107
108
    "GemmaForCausalLM": ("gemma", "GemmaForCausalLM"),
    "Gemma2ForCausalLM": ("gemma2", "Gemma2ForCausalLM"),
109
    "Gemma3ForCausalLM": ("gemma3", "Gemma3ForCausalLM"),
Nicolò Lucchesi's avatar
Nicolò Lucchesi committed
110
    "Gemma3nForCausalLM": ("gemma3n", "Gemma3nForCausalLM"),
111
    "Qwen3NextForCausalLM": ("qwen3_next", "Qwen3NextForCausalLM"),
112
    "GlmForCausalLM": ("glm", "GlmForCausalLM"),
Yuxuan Zhang's avatar
Yuxuan Zhang committed
113
    "Glm4ForCausalLM": ("glm4", "Glm4ForCausalLM"),
Yuxuan Zhang's avatar
Yuxuan Zhang committed
114
    "Glm4MoeForCausalLM": ("glm4_moe", "Glm4MoeForCausalLM"),
115
    "Glm4MoeLiteForCausalLM": ("glm4_moe_lite", "Glm4MoeLiteForCausalLM"),
116
    "GptOssForCausalLM": ("gpt_oss", "GptOssForCausalLM"),
117
118
119
120
121
122
    "GPT2LMHeadModel": ("gpt2", "GPT2LMHeadModel"),
    "GPTBigCodeForCausalLM": ("gpt_bigcode", "GPTBigCodeForCausalLM"),
    "GPTJForCausalLM": ("gpt_j", "GPTJForCausalLM"),
    "GPTNeoXForCausalLM": ("gpt_neox", "GPTNeoXForCausalLM"),
    "GraniteForCausalLM": ("granite", "GraniteForCausalLM"),
    "GraniteMoeForCausalLM": ("granitemoe", "GraniteMoeForCausalLM"),
123
124
    "GraniteMoeHybridForCausalLM": ("granitemoehybrid", "GraniteMoeHybridForCausalLM"),  # noqa: E501
    "GraniteMoeSharedForCausalLM": ("granitemoeshared", "GraniteMoeSharedForCausalLM"),  # noqa: E501
125
    "GritLM": ("gritlm", "GritLM"),
Bijaya Dangol's avatar
Bijaya Dangol committed
126
127
    "Grok1ModelForCausalLM": ("grok1", "GrokForCausalLM"),
    "Grok1ForCausalLM": ("grok1", "GrokForCausalLM"),
128
129
    "HunYuanMoEV1ForCausalLM": ("hunyuan_v1", "HunYuanMoEV1ForCausalLM"),
    "HunYuanDenseV1ForCausalLM": ("hunyuan_v1", "HunYuanDenseV1ForCausalLM"),
130
    "HCXVisionForCausalLM": ("hyperclovax_vision", "HCXVisionForCausalLM"),
131
132
    "InternLMForCausalLM": ("llama", "LlamaForCausalLM"),
    "InternLM2ForCausalLM": ("internlm2", "InternLM2ForCausalLM"),
133
    "InternLM2VEForCausalLM": ("internlm2_ve", "InternLM2VEForCausalLM"),
134
    "InternLM3ForCausalLM": ("llama", "LlamaForCausalLM"),
135
136
    "IQuestCoderForCausalLM": ("llama", "LlamaForCausalLM"),
    "IQuestLoopCoderForCausalLM": ("iquest_loopcoder", "IQuestLoopCoderForCausalLM"),
137
    "JAISLMHeadModel": ("jais", "JAISLMHeadModel"),
138
    "Jais2ForCausalLM": ("jais2", "Jais2ForCausalLM"),
139
    "JambaForCausalLM": ("jamba", "JambaForCausalLM"),
140
    "KimiLinearForCausalLM": ("kimi_linear", "KimiLinearForCausalLM"),  # noqa: E501
141
    "Lfm2ForCausalLM": ("lfm2", "Lfm2ForCausalLM"),
Paul Pak's avatar
Paul Pak committed
142
    "Lfm2MoeForCausalLM": ("lfm2_moe", "Lfm2MoeForCausalLM"),
143
    "LlamaForCausalLM": ("llama", "LlamaForCausalLM"),
144
    "Llama4ForCausalLM": ("llama4", "Llama4ForCausalLM"),
145
146
    # For decapoda-research/llama-*
    "LLaMAForCausalLM": ("llama", "LlamaForCausalLM"),
XuruiYang's avatar
XuruiYang committed
147
    "LongcatFlashForCausalLM": ("longcat_flash", "LongcatFlashForCausalLM"),
148
    "MambaForCausalLM": ("mamba", "MambaForCausalLM"),
149
    "Mamba2ForCausalLM": ("mamba2", "Mamba2ForCausalLM"),
150
151
    "MiniCPMForCausalLM": ("minicpm", "MiniCPMForCausalLM"),
    "MiniCPM3ForCausalLM": ("minicpm3", "MiniCPM3ForCausalLM"),
152
153
154
    "MiniMaxForCausalLM": ("minimax_text_01", "MiniMaxText01ForCausalLM"),
    "MiniMaxText01ForCausalLM": ("minimax_text_01", "MiniMaxText01ForCausalLM"),
    "MiniMaxM1ForCausalLM": ("minimax_text_01", "MiniMaxText01ForCausalLM"),
155
    "MiniMaxM2ForCausalLM": ("minimax_m2", "MiniMaxM2ForCausalLM"),
156
    "MistralForCausalLM": ("mistral", "MistralForCausalLM"),
157
    "MistralLarge3ForCausalLM": ("mistral_large_3", "MistralLarge3ForCausalLM"),
158
159
160
161
    "MixtralForCausalLM": ("mixtral", "MixtralForCausalLM"),
    # transformers's mpt class has lower case
    "MptForCausalLM": ("mpt", "MPTForCausalLM"),
    "MPTForCausalLM": ("mpt", "MPTForCausalLM"),
162
    "MiMoForCausalLM": ("mimo", "MiMoForCausalLM"),
163
    "MiMoV2FlashForCausalLM": ("mimo_v2_flash", "MiMoV2FlashForCausalLM"),
164
    "NemotronForCausalLM": ("nemotron", "NemotronForCausalLM"),
Luis Vega's avatar
Luis Vega committed
165
    "NemotronHForCausalLM": ("nemotron_h", "NemotronHForCausalLM"),
166
    "NemotronHPuzzleForCausalLM": ("nemotron_h", "NemotronHForCausalLM"),
167
    "OlmoForCausalLM": ("olmo", "OlmoForCausalLM"),
168
    "Olmo2ForCausalLM": ("olmo2", "Olmo2ForCausalLM"),
169
    "Olmo3ForCausalLM": ("olmo2", "Olmo2ForCausalLM"),
170
171
172
    "OlmoeForCausalLM": ("olmoe", "OlmoeForCausalLM"),
    "OPTForCausalLM": ("opt", "OPTForCausalLM"),
    "OrionForCausalLM": ("orion", "OrionForCausalLM"),
173
    "OuroForCausalLM": ("ouro", "OuroForCausalLM"),
174
    "PanguEmbeddedForCausalLM": ("openpangu", "PanguEmbeddedForCausalLM"),
175
    "PanguProMoEV2ForCausalLM": ("openpangu", "PanguProMoEV2ForCausalLM"),
176
    "PanguUltraMoEForCausalLM": ("openpangu", "PanguUltraMoEForCausalLM"),
177
178
179
180
    "PersimmonForCausalLM": ("persimmon", "PersimmonForCausalLM"),
    "PhiForCausalLM": ("phi", "PhiForCausalLM"),
    "Phi3ForCausalLM": ("phi3", "Phi3ForCausalLM"),
    "PhiMoEForCausalLM": ("phimoe", "PhiMoEForCausalLM"),
Shinichi Hemmi's avatar
Shinichi Hemmi committed
181
    "Plamo2ForCausalLM": ("plamo2", "Plamo2ForCausalLM"),
182
    "Plamo3ForCausalLM": ("plamo3", "Plamo3ForCausalLM"),
183
    "QWenLMHeadModel": ("qwen", "QWenLMHeadModel"),
184
185
    "Qwen2ForCausalLM": ("qwen2", "Qwen2ForCausalLM"),
    "Qwen2MoeForCausalLM": ("qwen2_moe", "Qwen2MoeForCausalLM"),
186
187
    "Qwen3ForCausalLM": ("qwen3", "Qwen3ForCausalLM"),
    "Qwen3MoeForCausalLM": ("qwen3_moe", "Qwen3MoeForCausalLM"),
188
    "RWForCausalLM": ("falcon", "FalconForCausalLM"),
189
    "SeedOssForCausalLM": ("seed_oss", "SeedOssForCausalLM"),
Li Xie's avatar
Li Xie committed
190
    "Step1ForCausalLM": ("step1", "Step1ForCausalLM"),
Song's avatar
Song committed
191
    "Step3TextForCausalLM": ("step3_text", "Step3TextForCausalLM"),
csy0225's avatar
csy0225 committed
192
    "Step3p5ForCausalLM": ("step3p5", "Step3p5ForCausalLM"),
193
194
195
196
    "StableLMEpochForCausalLM": ("stablelm", "StablelmForCausalLM"),
    "StableLmForCausalLM": ("stablelm", "StablelmForCausalLM"),
    "Starcoder2ForCausalLM": ("starcoder2", "Starcoder2ForCausalLM"),
    "SolarForCausalLM": ("solar", "SolarForCausalLM"),
197
    "TeleChatForCausalLM": ("telechat2", "TeleChat2ForCausalLM"),
198
    "TeleChat2ForCausalLM": ("telechat2", "TeleChat2ForCausalLM"),
199
    "TeleFLMForCausalLM": ("teleflm", "TeleFLMForCausalLM"),
200
    "XverseForCausalLM": ("llama", "LlamaForCausalLM"),
201
    "Zamba2ForCausalLM": ("zamba2", "Zamba2ForCausalLM"),
202
203
204
}

# Pooling / embedding architectures: architecture name -> (module name, class name).
_EMBEDDING_MODELS = {
    # [Text-only]
    "BertModel": ("bert", "BertEmbeddingModel"),
    "BertSpladeSparseEmbeddingModel": ("bert", "BertSpladeSparseEmbeddingModel"),
    "DeciLMForCausalLM": ("nemotron_nas", "DeciLMForCausalLM"),
    "Gemma2Model": ("gemma2", "Gemma2ForCausalLM"),
    "Gemma3TextModel": ("gemma3", "Gemma3Model"),
    "GlmForCausalLM": ("glm", "GlmForCausalLM"),
    "GPT2ForSequenceClassification": ("gpt2", "GPT2ForSequenceClassification"),
    "GritLM": ("gritlm", "GritLM"),
    "GteModel": ("bert_with_rope", "SnowflakeGteNewModel"),
    "GteNewModel": ("bert_with_rope", "GteNewModel"),
    "InternLM2ForRewardModel": ("internlm2", "InternLM2ForRewardModel"),
    "JambaForSequenceClassification": ("jamba", "JambaForSequenceClassification"),  # noqa: E501
    "LlamaBidirectionalModel": ("llama", "LlamaBidirectionalModel"),
    "LlamaModel": ("llama", "LlamaForCausalLM"),
    **{
        # Many text-generation aliases (Aquila, InternLM, Xverse, ...) share the
        # LlamaForCausalLM implementation; re-register all of them here too.
        k: (mod, arch)
        for k, (mod, arch) in _TEXT_GENERATION_MODELS.items()
        if arch == "LlamaForCausalLM"
    },
    "MistralModel": ("llama", "LlamaForCausalLM"),
    "ModernBertModel": ("modernbert", "ModernBertModel"),
    "NomicBertModel": ("bert_with_rope", "NomicBertModel"),
    "Phi3ForCausalLM": ("phi3", "Phi3ForCausalLM"),
    "Qwen2Model": ("qwen2", "Qwen2ForCausalLM"),
    "Qwen2ForCausalLM": ("qwen2", "Qwen2ForCausalLM"),
    "Qwen2ForRewardModel": ("qwen2_rm", "Qwen2ForRewardModel"),
    "Qwen2ForProcessRewardModel": ("qwen2_rm", "Qwen2ForProcessRewardModel"),
    "RobertaForMaskedLM": ("roberta", "RobertaEmbeddingModel"),
    "RobertaModel": ("roberta", "RobertaEmbeddingModel"),
    "TeleChatForCausalLM": ("telechat2", "TeleChat2ForCausalLM"),
    "TeleChat2ForCausalLM": ("telechat2", "TeleChat2ForCausalLM"),
    "XLMRobertaModel": ("roberta", "RobertaEmbeddingModel"),
    "BgeM3EmbeddingModel": ("roberta", "BgeM3EmbeddingModel"),
    # [Multimodal]
    "CLIPModel": ("clip", "CLIPEmbeddingModel"),
    "LlavaNextForConditionalGeneration": (
        "llava_next",
        "LlavaNextForConditionalGeneration",
    ),
    "Phi3VForCausalLM": ("phi3v", "Phi3VForCausalLM"),
    "Qwen2VLForConditionalGeneration": ("qwen2_vl", "Qwen2VLForConditionalGeneration"),  # noqa: E501
    "SiglipModel": ("siglip", "SiglipEmbeddingModel"),
    # Terratorch models work on images for both input and output; they are
    # registered here only because they piggy-back on embedding models for
    # the time being.
    "PrithviGeoSpatialMAE": ("terratorch", "Terratorch"),
    "Terratorch": ("terratorch", "Terratorch"),
}

256
257
_CROSS_ENCODER_MODELS = {
    "BertForSequenceClassification": ("bert", "BertForSequenceClassification"),
258
    "BertForTokenClassification": ("bert", "BertForTokenClassification"),
259
260
261
262
    "GteNewForSequenceClassification": (
        "bert_with_rope",
        "GteNewForSequenceClassification",
    ),
263
264
265
266
267
    "JinaVLForRanking": ("jina_vl", "JinaVLForSequenceClassification"),
    "LlamaBidirectionalForSequenceClassification": (
        "llama",
        "LlamaBidirectionalForSequenceClassification",
    ),
268
269
270
271
    "ModernBertForSequenceClassification": (
        "modernbert",
        "ModernBertForSequenceClassification",
    ),
272
273
274
275
    "ModernBertForTokenClassification": (
        "modernbert",
        "ModernBertForTokenClassification",
    ),
276
277
278
279
280
    "RobertaForSequenceClassification": ("roberta", "RobertaForSequenceClassification"),
    "XLMRobertaForSequenceClassification": (
        "roberta",
        "RobertaForSequenceClassification",
    ),
281
282
}

283
_MULTIMODAL_MODELS = {
284
    # [Decoder-only]
285
    "AriaForConditionalGeneration": ("aria", "AriaForConditionalGeneration"),
286
287
288
289
    "AudioFlamingo3ForConditionalGeneration": (
        "audioflamingo3",
        "AudioFlamingo3ForConditionalGeneration",
    ),
290
291
292
293
    "MusicFlamingoForConditionalGeneration": (
        "musicflamingo",
        "MusicFlamingoForConditionalGeneration",
    ),
294
295
296
    "AyaVisionForConditionalGeneration": (
        "aya_vision",
        "AyaVisionForConditionalGeneration",
297
    ),
298
    "BagelForConditionalGeneration": ("bagel", "BagelForConditionalGeneration"),
299
    "BeeForConditionalGeneration": ("bee", "BeeForConditionalGeneration"),
300
    "Blip2ForConditionalGeneration": ("blip2", "Blip2ForConditionalGeneration"),
301
302
303
    "ChameleonForConditionalGeneration": (
        "chameleon",
        "ChameleonForConditionalGeneration",
304
    ),
305
306
307
    "Cohere2VisionForConditionalGeneration": (
        "cohere2_vision",
        "Cohere2VisionForConditionalGeneration",
308
    ),
309
    "DeepseekVLV2ForCausalLM": ("deepseek_vl2", "DeepseekVLV2ForCausalLM"),
310
    "DeepseekOCRForCausalLM": ("deepseek_ocr", "DeepseekOCRForCausalLM"),
RED's avatar
RED committed
311
    "DeepseekOCR2ForCausalLM": ("deepseek_ocr2", "DeepseekOCR2ForCausalLM"),
Roger Wang's avatar
Roger Wang committed
312
    "DotsOCRForCausalLM": ("dots_ocr", "DotsOCRForCausalLM"),
313
314
315
316
    "Eagle2_5_VLForConditionalGeneration": (
        "eagle2_5_vl",
        "Eagle2_5_VLForConditionalGeneration",
    ),
317
318
319
    "Ernie4_5_VLMoeForConditionalGeneration": (
        "ernie45_vl",
        "Ernie4_5_VLMoeForConditionalGeneration",
320
    ),
321
322
323
324
    "FunAudioChatForConditionalGeneration": (
        "funaudiochat",
        "FunAudioChatForConditionalGeneration",
    ),
325
    "FuyuForCausalLM": ("fuyu", "FuyuForCausalLM"),
326
    "Gemma3ForConditionalGeneration": ("gemma3_mm", "Gemma3ForConditionalGeneration"),  # noqa: E501
327
328
329
    "Gemma3nForConditionalGeneration": (
        "gemma3n_mm",
        "Gemma3nForConditionalGeneration",
330
    ),
331
    "GlmAsrForConditionalGeneration": ("glmasr", "GlmAsrForConditionalGeneration"),
332
    "GLM4VForCausalLM": ("glm4v", "GLM4VForCausalLM"),
333
334
335
    "Glm4vForConditionalGeneration": ("glm4_1v", "Glm4vForConditionalGeneration"),
    "Glm4vMoeForConditionalGeneration": ("glm4_1v", "Glm4vMoeForConditionalGeneration"),
    "GlmOcrForConditionalGeneration": ("glm_ocr", "GlmOcrForConditionalGeneration"),  # noqa: E501
336
337
338
    "GraniteSpeechForConditionalGeneration": (
        "granite_speech",
        "GraniteSpeechForConditionalGeneration",
339
    ),
340
    "H2OVLChatModel": ("h2ovl", "H2OVLChatModel"),
341
342
343
344
    "HunYuanVLForConditionalGeneration": (
        "hunyuan_vision",
        "HunYuanVLForConditionalGeneration",
    ),
ltd0924's avatar
ltd0924 committed
345
    "StepVLForConditionalGeneration": ("step_vl", "StepVLForConditionalGeneration"),
346
    "InternVLChatModel": ("internvl", "InternVLChatModel"),
347
    "NemotronH_Nano_VL_V2": ("nano_nemotron_vl", "NemotronH_Nano_VL_V2"),
Zero's avatar
Zero committed
348
349
350
351
    "OpenCUAForConditionalGeneration": (
        "opencua",
        "OpenCUAForConditionalGeneration",
    ),
352
353
354
    "InternS1ForConditionalGeneration": (
        "interns1",
        "InternS1ForConditionalGeneration",
355
    ),
356
357
358
    "InternVLForConditionalGeneration": (
        "interns1",
        "InternS1ForConditionalGeneration",
359
    ),
zxy's avatar
zxy committed
360
361
362
363
    "InternS1ProForConditionalGeneration": (
        "interns1_pro",
        "InternS1ProForConditionalGeneration",
    ),
364
365
366
367
    "Idefics3ForConditionalGeneration": (
        "idefics3",
        "Idefics3ForConditionalGeneration",
    ),
oscardev256's avatar
oscardev256 committed
368
    "IsaacForConditionalGeneration": ("isaac", "IsaacForConditionalGeneration"),
369
    "SmolVLMForConditionalGeneration": ("smolvlm", "SmolVLMForConditionalGeneration"),  # noqa: E501
370
    "KananaVForConditionalGeneration": ("kanana_v", "KananaVForConditionalGeneration"),
371
    "KeyeForConditionalGeneration": ("keye", "KeyeForConditionalGeneration"),
372
373
374
    "KeyeVL1_5ForConditionalGeneration": (
        "keye_vl1_5",
        "KeyeVL1_5ForConditionalGeneration",
375
    ),
376
    "RForConditionalGeneration": ("rvl", "RForConditionalGeneration"),
377
    "KimiVLForConditionalGeneration": ("kimi_vl", "KimiVLForConditionalGeneration"),  # noqa: E501
Roger Wang's avatar
Roger Wang committed
378
    "KimiK25ForConditionalGeneration": ("kimi_k25", "KimiK25ForConditionalGeneration"),  # noqa: E501
379
380
381
382
    "LightOnOCRForConditionalGeneration": (
        "lightonocr",
        "LightOnOCRForConditionalGeneration",
    ),
383
    "Lfm2VlForConditionalGeneration": ("lfm2_vl", "Lfm2VLForConditionalGeneration"),
384
    "Llama_Nemotron_Nano_VL": ("nemotron_vl", "LlamaNemotronVLChatModel"),
385
    "Llama4ForConditionalGeneration": ("mllama4", "Llama4ForConditionalGeneration"),  # noqa: E501
386
    "LlavaForConditionalGeneration": ("llava", "LlavaForConditionalGeneration"),
387
388
389
    "LlavaNextForConditionalGeneration": (
        "llava_next",
        "LlavaNextForConditionalGeneration",
390
    ),
391
392
393
    "LlavaNextVideoForConditionalGeneration": (
        "llava_next_video",
        "LlavaNextVideoForConditionalGeneration",
394
    ),
395
396
397
    "LlavaOnevisionForConditionalGeneration": (
        "llava_onevision",
        "LlavaOnevisionForConditionalGeneration",
398
    ),
399
    "MantisForConditionalGeneration": ("llava", "MantisForConditionalGeneration"),  # noqa: E501
400
    "MiDashengLMModel": ("midashenglm", "MiDashengLMModel"),
401
402
403
    "MiniMaxVL01ForConditionalGeneration": (
        "minimax_vl_01",
        "MiniMaxVL01ForConditionalGeneration",
404
    ),
405
    "MiniCPMO": ("minicpmo", "MiniCPMO"),
406
    "MiniCPMV": ("minicpmv", "MiniCPMV"),
407
408
409
    "Mistral3ForConditionalGeneration": (
        "mistral3",
        "Mistral3ForConditionalGeneration",
410
    ),
411
    "MolmoForCausalLM": ("molmo", "MolmoForCausalLM"),
412
    "Molmo2ForConditionalGeneration": ("molmo2", "Molmo2ForConditionalGeneration"),
413
    "NVLM_D": ("nvlm_d", "NVLM_D_Model"),
414
415
416
417
    "OpenPanguVLForConditionalGeneration": (
        "openpangu_vl",
        "OpenPanguVLForConditionalGeneration",
    ),
418
    "Ovis": ("ovis", "Ovis"),
419
    "Ovis2_5": ("ovis2_5", "Ovis2_5"),
420
421
422
423
    "PaddleOCRVLForConditionalGeneration": (
        "paddleocr_vl",
        "PaddleOCRVLForConditionalGeneration",
    ),
424
425
426
427
    "PaliGemmaForConditionalGeneration": (
        "paligemma",
        "PaliGemmaForConditionalGeneration",
    ),
428
    "Phi3VForCausalLM": ("phi3v", "Phi3VForCausalLM"),
429
    "Phi4MMForCausalLM": ("phi4mm", "Phi4MMForCausalLM"),
430
    "PixtralForConditionalGeneration": ("pixtral", "PixtralForConditionalGeneration"),  # noqa: E501
431
    "QwenVLForConditionalGeneration": ("qwen_vl", "QwenVLForConditionalGeneration"),  # noqa: E501
432
    "Qwen2VLForConditionalGeneration": ("qwen2_vl", "Qwen2VLForConditionalGeneration"),  # noqa: E501
433
434
435
    "Qwen2_5_VLForConditionalGeneration": (
        "qwen2_5_vl",
        "Qwen2_5_VLForConditionalGeneration",
436
    ),
437
438
439
    "Qwen2AudioForConditionalGeneration": (
        "qwen2_audio",
        "Qwen2AudioForConditionalGeneration",
440
    ),
441
442
443
    "Qwen2_5OmniModel": (
        "qwen2_5_omni_thinker",
        "Qwen2_5OmniThinkerForConditionalGeneration",
444
    ),
445
446
447
    "Qwen2_5OmniForConditionalGeneration": (
        "qwen2_5_omni_thinker",
        "Qwen2_5OmniThinkerForConditionalGeneration",
448
    ),
449
450
451
452
    "Qwen3OmniMoeForConditionalGeneration": (
        "qwen3_omni_moe_thinker",
        "Qwen3OmniMoeThinkerForConditionalGeneration",
    ),
Roger Wang's avatar
Roger Wang committed
453
454
455
456
    "Qwen3ASRForConditionalGeneration": (
        "qwen3_asr",
        "Qwen3ASRForConditionalGeneration",
    ),
457
    "Qwen3VLForConditionalGeneration": ("qwen3_vl", "Qwen3VLForConditionalGeneration"),  # noqa: E501
458
459
460
    "Qwen3VLMoeForConditionalGeneration": (
        "qwen3_vl_moe",
        "Qwen3VLMoeForConditionalGeneration",
461
    ),
462
    "SkyworkR1VChatModel": ("skyworkr1v", "SkyworkR1VChatModel"),
Song's avatar
Song committed
463
    "Step3VLForConditionalGeneration": ("step3_vl", "Step3VLForConditionalGeneration"),  # noqa: E501
汪志鹏's avatar
汪志鹏 committed
464
    "TarsierForConditionalGeneration": ("tarsier", "TarsierForConditionalGeneration"),  # noqa: E501
465
466
467
    "Tarsier2ForConditionalGeneration": (
        "qwen2_vl",
        "Tarsier2ForConditionalGeneration",
468
    ),
469
    "UltravoxModel": ("ultravox", "UltravoxModel"),
Patrick von Platen's avatar
Patrick von Platen committed
470
    "VoxtralForConditionalGeneration": ("voxtral", "VoxtralForConditionalGeneration"),  # noqa: E501
471
    "VoxtralRealtimeGeneration": ("voxtral_realtime", "VoxtralRealtimeGeneration"),  # noqa: E501
472
    # [Encoder-decoder]
473
474
475
476
    "NemotronParseForConditionalGeneration": (
        "nemotron_parse",
        "NemotronParseForConditionalGeneration",
    ),
477
    "WhisperForConditionalGeneration": ("whisper", "WhisperForConditionalGeneration"),  # noqa: E501
478
}
479
480

_SPECULATIVE_DECODING_MODELS = {
481
    "MiMoMTPModel": ("mimo_mtp", "MiMoMTP"),
482
    "EagleLlamaForCausalLM": ("llama_eagle", "EagleLlamaForCausalLM"),
zhiweiz's avatar
zhiweiz committed
483
    "EagleLlama4ForCausalLM": ("llama4_eagle", "EagleLlama4ForCausalLM"),
484
    "EagleMiniCPMForCausalLM": ("minicpm_eagle", "EagleMiniCPMForCausalLM"),
485
    "Eagle3LlamaForCausalLM": ("llama_eagle3", "Eagle3LlamaForCausalLM"),
486
    "LlamaForCausalLMEagle3": ("llama_eagle3", "Eagle3LlamaForCausalLM"),
487
    "Eagle3Qwen2_5vlForCausalLM": ("llama_eagle3", "Eagle3LlamaForCausalLM"),
488
    "Eagle3Qwen3vlForCausalLM": ("llama_eagle3", "Eagle3LlamaForCausalLM"),
489
490
491
492
    "EagleMistralLarge3ForCausalLM": (
        "mistral_large_3_eagle",
        "EagleMistralLarge3ForCausalLM",
    ),
493
    "EagleDeepSeekMTPModel": ("deepseek_eagle", "EagleDeepseekV3ForCausalLM"),
494
    "DeepSeekMTPModel": ("deepseek_mtp", "DeepSeekMTP"),
495
    "ErnieMTPModel": ("ernie_mtp", "ErnieMTP"),
Kyungmin Lee's avatar
Kyungmin Lee committed
496
    "ExaoneMoeMTP": ("exaone_moe_mtp", "ExaoneMoeMTP"),
XuruiYang's avatar
XuruiYang committed
497
    "LongCatFlashMTPModel": ("longcat_flash_mtp", "LongCatFlashMTP"),
Yuxuan Zhang's avatar
Yuxuan Zhang committed
498
    "Glm4MoeMTPModel": ("glm4_moe_mtp", "Glm4MoeMTP"),
499
    "Glm4MoeLiteMTPModel": ("glm4_moe_lite_mtp", "Glm4MoeLiteMTP"),
500
    "GlmOcrMTPModel": ("glm_ocr_mtp", "GlmOcrMTP"),
501
    "MedusaModel": ("medusa", "Medusa"),
502
    "OpenPanguMTPModel": ("openpangu_mtp", "OpenPanguMTP"),
503
    "Qwen3NextMTP": ("qwen3_next_mtp", "Qwen3NextMTP"),
csy0225's avatar
csy0225 committed
504
    "Step3p5MTP": ("step3p5_mtp", "Step3p5MTP"),
505
506
507
    # Temporarily disabled.
    # # TODO(woosuk): Re-enable this once the MLP Speculator is supported in V1.
    # "MLPSpeculatorPreTrainedModel": ("mlp_speculator", "MLPSpeculator"),
508
}
509

510
_TRANSFORMERS_SUPPORTED_MODELS = {
511
512
513
    # Text generation models
    "SmolLM3ForCausalLM": ("transformers", "TransformersForCausalLM"),
    # Multimodal models
514
515
516
517
    "Emu3ForConditionalGeneration": (
        "transformers",
        "TransformersMultiModalForCausalLM",
    ),
518
519
520
}

_TRANSFORMERS_BACKEND_MODELS = {
521
    # Text generation models
522
    "TransformersForCausalLM": ("transformers", "TransformersForCausalLM"),
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
    "TransformersMoEForCausalLM": ("transformers", "TransformersMoEForCausalLM"),
    # Multimodal models
    "TransformersMultiModalForCausalLM": (
        "transformers",
        "TransformersMultiModalForCausalLM",
    ),
    "TransformersMultiModalMoEForCausalLM": (
        "transformers",
        "TransformersMultiModalMoEForCausalLM",
    ),
    # Embedding models
    "TransformersEmbeddingModel": ("transformers", "TransformersEmbeddingModel"),
    "TransformersMoEEmbeddingModel": ("transformers", "TransformersMoEEmbeddingModel"),
    "TransformersMultiModalEmbeddingModel": (
        "transformers",
        "TransformersMultiModalEmbeddingModel",
    ),
    # Sequence classification models
541
    "TransformersForSequenceClassification": (
542
        "transformers",
543
        "TransformersForSequenceClassification",
544
    ),
545
    "TransformersMoEForSequenceClassification": (
546
        "transformers",
547
        "TransformersMoEForSequenceClassification",
548
    ),
549
550
551
    "TransformersMultiModalForSequenceClassification": (
        "transformers",
        "TransformersMultiModalForSequenceClassification",
552
    ),
553
}
554

555
# Union of every registry table above. Dict unpacking keeps the last value
# for a repeated key, so if an architecture appears in several tables the
# entry from the later table in this list wins.
_VLLM_MODELS = {
    **_TEXT_GENERATION_MODELS,
    **_EMBEDDING_MODELS,
    **_CROSS_ENCODER_MODELS,
    **_MULTIMODAL_MODELS,
    **_SPECULATIVE_DECODING_MODELS,
    **_TRANSFORMERS_SUPPORTED_MODELS,
    **_TRANSFORMERS_BACKEND_MODELS,
}

565
566
567
568
# This variable is used as the args for subprocess.run(). We
# can modify  this variable to alter the args if needed. e.g.
# when we use par format to pack things together, sys.executable
# might not be the target we want to run.
569
_SUBPROCESS_COMMAND = [sys.executable, "-m", "vllm.model_executor.models.registry"]
570

571
_PREVIOUSLY_SUPPORTED_MODELS = {
572
    "MotifForCausalLM": "0.10.2",
573
    "Phi3SmallForCausalLM": "0.9.2",
574
    "Phi4FlashForCausalLM": "0.10.2",
575
    "Phi4MultimodalForCausalLM": "0.12.0",
576
577
578
579
580
581
582
583
584
    # encoder-decoder models except whisper
    # have been removed for V0 deprecation.
    "BartModel": "0.10.2",
    "BartForConditionalGeneration": "0.10.2",
    "DonutForConditionalGeneration": "0.10.2",
    "Florence2ForConditionalGeneration": "0.10.2",
    "MBartForConditionalGeneration": "0.10.2",
    "MllamaForConditionalGeneration": "0.10.2",
}
585

586

587
588
@dataclass(frozen=True)
class _ModelInfo:
    """Immutable summary of a model class's capabilities and defaults.

    Built by :meth:`from_model_cls`. ``architecture`` is the class name, and
    every other field is the result of the same-named helper imported from
    ``.interfaces`` / ``.interfaces_base`` (``attn_type`` and the default
    pooling types come from the corresponding ``get_*`` helpers).
    """

    architecture: str
    is_text_generation_model: bool
    is_pooling_model: bool
    attn_type: AttnTypeStr
    default_seq_pooling_type: SequencePoolingType
    default_tok_pooling_type: TokenPoolingType
    supports_cross_encoding: bool
    supports_multimodal: bool
    supports_multimodal_raw_input_only: bool
    requires_raw_input_tokens: bool
    supports_multimodal_encoder_tp_data: bool
    supports_pp: bool
    has_inner_state: bool
    is_attention_free: bool
    is_hybrid: bool
    has_noops: bool
    supports_mamba_prefix_caching: bool
    supports_transcription: bool
    # True only if the model supports transcription AND sets the class
    # attribute ``supports_transcription_only``.
    supports_transcription_only: bool

    @staticmethod
    def from_model_cls(model: type[nn.Module]) -> "_ModelInfo":
        """Inspect a model class and collect its capabilities.

        Args:
            model: The model class (not an instance) to inspect.

        Returns:
            A frozen ``_ModelInfo`` whose ``architecture`` is ``model.__name__``.
        """
        return _ModelInfo(
            architecture=model.__name__,
            is_text_generation_model=is_text_generation_model(model),
            is_pooling_model=is_pooling_model(model),
            default_seq_pooling_type=get_default_seq_pooling_type(model),
            default_tok_pooling_type=get_default_tok_pooling_type(model),
            attn_type=get_attn_type(model),
            supports_cross_encoding=supports_cross_encoding(model),
            supports_multimodal=supports_multimodal(model),
            supports_multimodal_raw_input_only=supports_multimodal_raw_input_only(
                model
            ),
            requires_raw_input_tokens=requires_raw_input_tokens(model),
            supports_multimodal_encoder_tp_data=supports_multimodal_encoder_tp_data(
                model
            ),
            supports_pp=supports_pp(model),
            has_inner_state=has_inner_state(model),
            is_attention_free=is_attention_free(model),
            is_hybrid=is_hybrid(model),
            supports_mamba_prefix_caching=supports_mamba_prefix_caching(model),
            supports_transcription=supports_transcription(model),
            # Short-circuits, so the class attribute is only read when the
            # model supports transcription.
            supports_transcription_only=(
                supports_transcription(model) and model.supports_transcription_only
            ),
            has_noops=has_noops(model),
        )
638
639


640
641
642
643
class _BaseRegisteredModel(ABC):
    """Abstract handle for one architecture entry in the model registry.

    Concrete subclasses decide how (and when) the model class is obtained;
    the registry only relies on the two methods below.
    """

    @abstractmethod
    def inspect_model_cls(self) -> _ModelInfo:
        """Return the `_ModelInfo` describing the registered model class."""
        raise NotImplementedError

    @abstractmethod
    def load_model_cls(self) -> type[nn.Module]:
        """Return the registered model class itself."""
        raise NotImplementedError
648
649


650
651
652
653
654
655
656
@dataclass(frozen=True)
class _RegisteredModel(_BaseRegisteredModel):
    """
    Represents a model that has already been imported in the main process.

    Both the class object and its `_ModelInfo` are stored on the entry, so
    neither inspection nor loading has to do any further work.
    """

    interfaces: _ModelInfo
    model_cls: type[nn.Module]

    @staticmethod
    def from_model_cls(model_cls: type[nn.Module]):
        """Wrap an imported model class, probing its interfaces right away."""
        info = _ModelInfo.from_model_cls(model_cls)
        return _RegisteredModel(interfaces=info, model_cls=model_cls)

    def inspect_model_cls(self) -> _ModelInfo:
        # Stored on the entry at construction time.
        return self.interfaces

    def load_model_cls(self) -> type[nn.Module]:
        # Already imported; just hand the class back.
        return self.model_cls


@dataclass(frozen=True)
class _LazyRegisteredModel(_BaseRegisteredModel):
    """
    Represents a model that has not been imported in the main process.

    `inspect_model_cls` imports the model in a subprocess (so CUDA is not
    initialized in this process) and caches the resulting `_ModelInfo` as a
    JSON file under `VLLM_CACHE_ROOT`, keyed by a hash of the model's source
    file when that file can be located.
    """

    module_name: str
    class_name: str

    @staticmethod
    def _get_cache_dir() -> Path:
        return Path(envs.VLLM_CACHE_ROOT) / "modelinfos"

    def _get_cache_filename(self) -> str:
        cls_name = f"{self.module_name}-{self.class_name}".replace(".", "-")
        return f"{cls_name}.json"

    def _get_module_source_path(self) -> Path | None:
        """Locate the `.py` file of `module_name` without importing it.

        Only modules inside the package that contains this file (vLLM's
        built-in model package) can be resolved this way; the full dotted
        sub-path is mapped onto the directory tree. Any other module, e.g.
        one registered by a plugin as `"my_pkg.llama:MyLlama"`, yields
        `None`, so its model info is never cached against the hash of an
        unrelated vLLM file that merely shares the same basename.
        """
        package = __package__ or __name__.rpartition(".")[0]
        if not package:
            return None

        prefix = f"{package}."
        if not self.module_name.startswith(prefix):
            return None

        *parents, leaf = self.module_name[len(prefix) :].split(".")
        return Path(__file__).parent.joinpath(*parents, f"{leaf}.py")

    def _load_modelinfo_from_cache(self, module_hash: str) -> _ModelInfo | None:
        """Return the cached `_ModelInfo` if its hash equals `module_hash`.

        Returns `None` on a cache miss: missing file, stale hash, or any
        error while reading/decoding the entry.
        """
        try:
            try:
                modelinfo_path = self._get_cache_dir() / self._get_cache_filename()
                with open(modelinfo_path, encoding="utf-8") as file:
                    mi_dict = json.load(file)
            except FileNotFoundError:
                logger.debug(
                    "Cached model info file for class %s.%s not found",
                    self.module_name,
                    self.class_name,
                )
                return None

            if mi_dict["hash"] != module_hash:
                logger.debug(
                    "Cached model info file for class %s.%s is stale",
                    self.module_name,
                    self.class_name,
                )
                return None

            # file not changed, use cached _ModelInfo properties
            return _ModelInfo(**mi_dict["modelinfo"])
        except Exception:
            # Best effort: a corrupt or incompatible entry (e.g. keys that no
            # longer match `_ModelInfo`'s fields) just forces re-inspection.
            # Keep the traceback so the reason is visible at debug level.
            logger.debug(
                "Cached model info for class %s.%s error. ",
                self.module_name,
                self.class_name,
                exc_info=True,
            )
            return None

    def _save_modelinfo_to_cache(self, mi: _ModelInfo, module_hash: str) -> None:
        """save dictionary json file to cache"""
        from vllm.model_executor.model_loader.weight_utils import atomic_writer

        try:
            modelinfo_dict = {
                "hash": module_hash,
                "modelinfo": asdict(mi),
            }
            cache_dir = self._get_cache_dir()
            cache_dir.mkdir(parents=True, exist_ok=True)
            modelinfo_path = cache_dir / self._get_cache_filename()
            with atomic_writer(modelinfo_path, encoding="utf-8") as f:
                json.dump(modelinfo_dict, f, indent=2)
        except Exception:
            # Caching is an optimization; failure to write must not break
            # model inspection.
            logger.exception("Error saving model info cache.")

    @logtime(logger=logger, msg="Registry inspect model class")
    def inspect_model_cls(self) -> _ModelInfo:
        """Return the `_ModelInfo` for this model, using the disk cache if valid.

        NOTE: the cache key hashes only the model's own source file, so
        changes in modules it imports do not invalidate the cached entry.
        """
        model_path = self._get_module_source_path()
        module_hash = None

        if model_path is not None and model_path.exists():
            with open(model_path, "rb") as f:
                module_hash = safe_hash(f.read(), usedforsecurity=False).hexdigest()

            mi = self._load_modelinfo_from_cache(module_hash)
            if mi is not None:
                logger.debug(
                    "Loaded model info for class %s.%s from cache",
                    self.module_name,
                    self.class_name,
                )
                return mi

            logger.debug(
                "Cache model info for class %s.%s miss. Loading model instead.",
                self.module_name,
                self.class_name,
            )

        # Performed in another process to avoid initializing CUDA
        mi = _run_in_subprocess(
            lambda: _ModelInfo.from_model_cls(self.load_model_cls())
        )
        logger.debug(
            "Loaded model info for class %s.%s", self.module_name, self.class_name
        )

        # save cache file (only when the source file could be hashed)
        if module_hash is not None:
            self._save_modelinfo_to_cache(mi, module_hash)

        return mi

    def load_model_cls(self) -> type[nn.Module]:
        mod = importlib.import_module(self.module_name)
        return getattr(mod, self.class_name)


@lru_cache(maxsize=128)
def _try_load_model_cls(
    model_arch: str,
    model: _BaseRegisteredModel,
) -> type[nn.Module] | None:
    """Load the model class registered under `model_arch`.

    The current platform verifies the architecture first; anything it raises
    propagates to the caller. Exceptions raised while loading the class are
    logged and converted into a `None` result. Results are memoized per
    `(model_arch, model)` pair.
    """
    from vllm.platforms import current_platform

    current_platform.verify_model_arch(model_arch)

    try:
        loaded = model.load_model_cls()
    except Exception:
        logger.exception("Error in loading model architecture '%s'", model_arch)
        return None
    else:
        return loaded
795
796


797
798
799
800
@lru_cache(maxsize=128)
def _try_inspect_model_cls(
    model_arch: str,
    model: _BaseRegisteredModel,
) -> _ModelInfo | None:
    """Inspect the model registered under `model_arch`.

    Any exception raised during inspection is logged and converted into a
    `None` result. Results are memoized per `(model_arch, model)` pair.
    """
    try:
        info = model.inspect_model_cls()
    except Exception:
        logger.exception("Error in inspecting model architecture '%s'", model_arch)
        return None
    else:
        return info
807
808


809
810
811
@dataclass
class _ModelRegistry:
    """Maps architecture names to registered model entries.

    Besides direct lookups, resolution (`inspect_model_cls` /
    `resolve_model_cls`) may fall back to the Transformers modeling backend
    and to architecture-name defaults, depending on `model_config`.
    """

    # Keyed by model_arch
    models: dict[str, _BaseRegisteredModel] = field(default_factory=dict)

    def get_supported_archs(self) -> Set[str]:
        return self.models.keys()

    def register_model(
        self,
        model_arch: str,
        model_cls: type[nn.Module] | str,
    ) -> None:
        """
        Register an external model to be used in vLLM.

        `model_cls` can be either:

        - A [`torch.nn.Module`][] class directly referencing the model.
        - A string in the format `<module>:<class>` which can be used to
          lazily import the model. This is useful to avoid initializing CUDA
          when importing the model and thus the related error
          `RuntimeError: Cannot re-initialize CUDA in forked subprocess`.

        Raises:
            TypeError: If `model_arch` is not a string, or `model_cls` is
                neither a string nor an `nn.Module` subclass.
            ValueError: If a string `model_cls` is not `<module>:<class>`.
        """
        if not isinstance(model_arch, str):
            msg = f"`model_arch` should be a string, not a {type(model_arch)}"
            raise TypeError(msg)

        if model_arch in self.models:
            logger.warning(
                "Model architecture %s is already registered, and will be "
                "overwritten by the new model class %s.",
                model_arch,
                model_cls,
            )

        if isinstance(model_cls, str):
            split_str = model_cls.split(":")
            if len(split_str) != 2:
                msg = "Expected a string in the format `<module>:<class>`"
                raise ValueError(msg)

            model = _LazyRegisteredModel(*split_str)
        elif isinstance(model_cls, type) and issubclass(model_cls, nn.Module):
            model = _RegisteredModel.from_model_cls(model_cls)
        else:
            # Report the type of the offending `model_cls` (`model_arch` has
            # already been validated as a string above).
            msg = (
                "`model_cls` should be a string or PyTorch model class, "
                f"not a {type(model_cls)}"
            )
            raise TypeError(msg)

        self.models[model_arch] = model

    def _raise_for_unsupported(self, architectures: list[str]):
        """Always raise a `ValueError` explaining why no architecture worked.

        Distinguishes registered-but-failed architectures, architectures
        removed in an earlier vLLM release, and unknown architectures.
        """
        all_supported_archs = self.get_supported_archs()

        if any(arch in all_supported_archs for arch in architectures):
            raise ValueError(
                f"Model architectures {architectures} failed "
                "to be inspected. Please check the logs for more details."
            )

        for arch in architectures:
            if arch in _PREVIOUSLY_SUPPORTED_MODELS:
                previous_version = _PREVIOUSLY_SUPPORTED_MODELS[arch]

                raise ValueError(
                    f"Model architecture {arch} was supported in vLLM until "
                    f"v{previous_version}, and is not supported anymore. "
                    "Please use an older version of vLLM if you want to "
                    "use this model architecture."
                )

        raise ValueError(
            f"Model architectures {architectures} are not supported for now. "
            f"Supported architectures: {all_supported_archs}"
        )

    def _try_load_model_cls(self, model_arch: str) -> type[nn.Module] | None:
        if model_arch not in self.models:
            return None

        return _try_load_model_cls(model_arch, self.models[model_arch])

    def _try_inspect_model_cls(self, model_arch: str) -> _ModelInfo | None:
        if model_arch not in self.models:
            return None

        return _try_inspect_model_cls(model_arch, self.models[model_arch])

    def _try_resolve_transformers(
        self,
        architecture: str,
        model_config: ModelConfig,
    ) -> str | None:
        """Map `architecture` to a Transformers-backend registry entry.

        Returns `None` when the Transformers backend cannot be used and
        `model_config.model_impl` is not `"transformers"`; raises
        `ValueError` in that situation when the Transformers impl was
        explicitly requested.
        """
        if architecture in _TRANSFORMERS_BACKEND_MODELS:
            return architecture

        auto_map: dict[str, str] = (
            getattr(model_config.hf_config, "auto_map", None) or dict()
        )

        # Make sure that config class is always initialized before model class,
        # otherwise the model class won't be able to access the config class,
        # the expected auto_map should have correct order like:
        # "auto_map": {
        #     "AutoConfig": "<your-repo-name>--<config-name>",
        #     "AutoModel": "<your-repo-name>--<config-name>",
        #     "AutoModelFor<Task>": "<your-repo-name>--<config-name>",
        # },
        for prefix in ("AutoConfig", "AutoModel"):
            for name, module in auto_map.items():
                if name.startswith(prefix):
                    try_get_class_from_dynamic_module(
                        module,
                        model_config.model,
                        revision=model_config.revision,
                        trust_remote_code=model_config.trust_remote_code,
                        warn_on_fail=False,
                    )

        model_module = getattr(transformers, architecture, None)

        if model_module is None:
            for name, module in auto_map.items():
                if name.startswith("AutoModel"):
                    model_module = try_get_class_from_dynamic_module(
                        module,
                        model_config.model,
                        revision=model_config.revision,
                        trust_remote_code=model_config.trust_remote_code,
                        warn_on_fail=True,
                    )
                    if model_module is not None:
                        break
            else:
                # No `AutoModel*` entry produced a class.
                if model_config.model_impl != "transformers":
                    return None

                raise ValueError(
                    f"Cannot find model module. {architecture!r} is not a "
                    "registered model in the Transformers library (only "
                    "relevant if the model is meant to be in Transformers) "
                    "and 'AutoModel' is not present in the model config's "
                    "'auto_map' (relevant if the model is custom)."
                )

        if not model_module.is_backend_compatible():
            if model_config.model_impl != "transformers":
                return None

            raise ValueError(
                f"The Transformers implementation of {architecture!r} "
                "is not compatible with vLLM."
            )

        return model_config._get_transformers_backend_cls()

    def _normalize_arch(
        self,
        architecture: str,
        model_config: ModelConfig,
    ) -> str:
        if architecture in self.models:
            return architecture

        # This may be called in order to resolve runner_type and convert_type
        # in the first place, in which case we consider the default match
        match = try_match_architecture_defaults(
            architecture,
            runner_type=getattr(model_config, "runner_type", None),
            convert_type=getattr(model_config, "convert_type", None),
        )
        if match:
            suffix, _ = match

            # Get the name of the base model to convert
            for repl_suffix, _ in iter_architecture_defaults():
                base_arch = architecture.replace(suffix, repl_suffix)
                if base_arch in self.models:
                    return base_arch

        return architecture

    def _resolve_arch(
        self,
        architectures: str | list[str],
        model_config: ModelConfig,
        try_arch: Callable[[str], Any],
    ) -> tuple[Any, str]:
        """Lookup order shared by `inspect_model_cls` and `resolve_model_cls`.

        `try_arch` maps a registry key to a result, or to `None` when that
        key cannot be used. Returns the first non-`None` result paired with
        the architecture name it was found under; raises `ValueError` if no
        candidate succeeds.
        """
        if isinstance(architectures, str):
            architectures = [architectures]
        if not architectures:
            raise ValueError("No model architectures are specified")

        def try_transformers() -> tuple[Any, str] | None:
            arch = self._try_resolve_transformers(architectures[0], model_config)
            if arch is None:
                return None
            result = try_arch(arch)
            return None if result is None else (result, arch)

        # Require transformers impl
        if model_config.model_impl == "transformers":
            if (found := try_transformers()) is not None:
                return found
        elif model_config.model_impl == "terratorch":
            # Only return on success so a failed lookup never yields `None`.
            result = try_arch("Terratorch")
            if result is not None:
                return (result, "Terratorch")

        # Fallback to transformers impl (after resolving convert_type)
        if (
            all(arch not in self.models for arch in architectures)
            and model_config.model_impl == "auto"
            and getattr(model_config, "convert_type", "none") == "none"
        ):
            if (found := try_transformers()) is not None:
                return found

        for arch in architectures:
            normalized_arch = self._normalize_arch(arch, model_config)
            result = try_arch(normalized_arch)
            if result is not None:
                # Report the requested name, not the normalized base arch.
                return (result, arch)

        # Fallback to transformers impl (before resolving runner_type)
        if (
            all(arch not in self.models for arch in architectures)
            and model_config.model_impl == "auto"
        ):
            if (found := try_transformers()) is not None:
                return found

        return self._raise_for_unsupported(architectures)

    def inspect_model_cls(
        self,
        architectures: str | list[str],
        model_config: ModelConfig,
    ) -> tuple[_ModelInfo, str]:
        """Return `(model_info, arch)` for the first usable architecture."""
        return self._resolve_arch(
            architectures, model_config, self._try_inspect_model_cls
        )

    def resolve_model_cls(
        self,
        architectures: str | list[str],
        model_config: ModelConfig,
    ) -> tuple[type[nn.Module], str]:
        """Return `(model_cls, arch)` for the first loadable architecture."""
        return self._resolve_arch(
            architectures, model_config, self._try_load_model_cls
        )

    def is_text_generation_model(
        self,
        architectures: str | list[str],
        model_config: ModelConfig,
    ) -> bool:
        model_cls, _ = self.inspect_model_cls(architectures, model_config)
        return model_cls.is_text_generation_model

    def is_pooling_model(
        self,
        architectures: str | list[str],
        model_config: ModelConfig,
    ) -> bool:
        model_cls, _ = self.inspect_model_cls(architectures, model_config)
        return model_cls.is_pooling_model

    def is_cross_encoder_model(
        self,
        architectures: str | list[str],
        model_config: ModelConfig,
    ) -> bool:
        model_cls, _ = self.inspect_model_cls(architectures, model_config)
        return model_cls.supports_cross_encoding

    def is_multimodal_model(
        self,
        architectures: str | list[str],
        model_config: ModelConfig,
    ) -> bool:
        model_cls, _ = self.inspect_model_cls(architectures, model_config)
        return model_cls.supports_multimodal

    def is_multimodal_raw_input_only_model(
        self,
        architectures: str | list[str],
        model_config: ModelConfig,
    ) -> bool:
        model_cls, _ = self.inspect_model_cls(architectures, model_config)
        return model_cls.supports_multimodal_raw_input_only

    def is_pp_supported_model(
        self,
        architectures: str | list[str],
        model_config: ModelConfig,
    ) -> bool:
        model_cls, _ = self.inspect_model_cls(architectures, model_config)
        return model_cls.supports_pp

    def model_has_inner_state(
        self,
        architectures: str | list[str],
        model_config: ModelConfig,
    ) -> bool:
        model_cls, _ = self.inspect_model_cls(architectures, model_config)
        return model_cls.has_inner_state

    def is_attention_free_model(
        self,
        architectures: str | list[str],
        model_config: ModelConfig,
    ) -> bool:
        model_cls, _ = self.inspect_model_cls(architectures, model_config)
        return model_cls.is_attention_free

    def is_hybrid_model(
        self,
        architectures: str | list[str],
        model_config: ModelConfig,
    ) -> bool:
        model_cls, _ = self.inspect_model_cls(architectures, model_config)
        return model_cls.is_hybrid

    def is_noops_model(
        self,
        architectures: str | list[str],
        model_config: ModelConfig,
    ) -> bool:
        model_cls, _ = self.inspect_model_cls(architectures, model_config)
        return model_cls.has_noops

    def is_transcription_model(
        self,
        architectures: str | list[str],
        model_config: ModelConfig,
    ) -> bool:
        model_cls, _ = self.inspect_model_cls(architectures, model_config)
        return model_cls.supports_transcription

    def is_transcription_only_model(
        self,
        architectures: str | list[str],
        model_config: ModelConfig,
    ) -> bool:
        model_cls, _ = self.inspect_model_cls(architectures, model_config)
        return model_cls.supports_transcription_only

1196

1197
1198
1199
1200
1201
1202
1203
1204
1205
# Built-in architectures are registered lazily: each entry only records the
# module path and class name, so no model module is imported here.
ModelRegistry = _ModelRegistry(
    {
        arch: _LazyRegisteredModel(
            module_name=f"vllm.model_executor.models.{relname}",
            class_name=clsname,
        )
        for arch, (relname, clsname) in _VLLM_MODELS.items()
    }
)
1206
1207
1208
1209
1210

_T = TypeVar("_T")


def _run_in_subprocess(fn: Callable[[], _T]) -> _T:
    """Run `fn` in a separate Python process and return its result.

    `fn` and an output file path are serialized with `cloudpickle` and sent
    to the child on stdin. The child is expected to write the pickled result
    to that path; presumably `_SUBPROCESS_COMMAND` runs this module's `_run`
    (confirm). The output file lives in a temporary directory that is removed
    afterwards.

    Raises:
        RuntimeError: If the subprocess exits with a non-zero status; the
            message includes the child's stderr.
    """
    # `cloudpickle` allows pickling lambda functions directly
    import cloudpickle

    # NOTE: We use a temporary directory instead of a temporary file to avoid
    # issues like https://stackoverflow.com/questions/23212435/permission-denied-to-write-to-my-temporary-file
    with tempfile.TemporaryDirectory() as tempdir:
        output_filepath = os.path.join(tempdir, "registry_output.tmp")
        payload = cloudpickle.dumps((fn, output_filepath))

        # cannot use `sys.executable __file__` here because the script
        # contains relative imports
        proc = subprocess.run(_SUBPROCESS_COMMAND, input=payload, capture_output=True)

        # check if the subprocess is successful
        try:
            proc.check_returncode()
        except subprocess.CalledProcessError as err:
            # wrap raised exception to provide more information
            raise RuntimeError(
                f"Error raised in subprocess:\n{proc.stderr.decode()}"
            ) from err

        # The file was produced by our own child process, not external input.
        with open(output_filepath, "rb") as result_file:
            return pickle.load(result_file)


def _run() -> None:
    """Subprocess side of `_run_in_subprocess`.

    Reads a pickled `(fn, output_file)` pair from stdin, calls `fn` after
    loading vLLM's general plugins, and writes the pickled result to
    `output_file`.
    """
    # Setup plugins
    from vllm.plugins import load_general_plugins

    load_general_plugins()

    # stdin is written by the parent process, so unpickling it is trusted.
    payload = sys.stdin.buffer.read()
    fn, output_file = pickle.loads(payload)

    result = fn()
    serialized = pickle.dumps(result)

    with open(output_file, "wb") as out:
        out.write(serialized)
1252
1253
1254


if __name__ == "__main__":
    # Executed when this module is run as a script, e.g. as the helper
    # process spawned by `_run_in_subprocess` (assumed; confirm against
    # `_SUBPROCESS_COMMAND`).
    _run()