registry.py 49.9 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
4
5
6
"""
Whenever you add an architecture to this page, please also update
`tests/models/registry.py` with example HuggingFace models for it.
"""
7

8
import importlib
9
import json
10
import os
11
import pickle
12
13
import subprocess
import sys
14
import tempfile
15
from abc import ABC, abstractmethod
16
from collections.abc import Callable, Set
17
from dataclasses import asdict, dataclass, field
18
from functools import lru_cache
19
from pathlib import Path
20
from typing import TYPE_CHECKING, Any, TypeVar
21
22

import torch.nn as nn
23
import transformers
24

25
from vllm import envs
26
27
28
29
30
from vllm.config import (
    ModelConfig,
    iter_architecture_defaults,
    try_match_architecture_defaults,
)
31
from vllm.logger import init_logger
32
from vllm.logging_utils import logtime
33
from vllm.transformers_utils.dynamic_module import try_get_class_from_dynamic_module
34
from vllm.utils.hashing import safe_hash
35

36
37
if TYPE_CHECKING:
    from vllm.config.model import AttnTypeStr
38
    from vllm.config.pooler import SequencePoolingType, TokenPoolingType
39
40
else:
    AttnTypeStr = Any
41
42
    SequencePoolingType = Any
    TokenPoolingType = Any
43
44


45
46
47
48
49
from .interfaces import (
    has_inner_state,
    has_noops,
    is_attention_free,
    is_hybrid,
Patrick von Platen's avatar
Patrick von Platen committed
50
    requires_raw_input_tokens,
51
    supports_cross_encoding,
52
    supports_late_interaction,
53
    supports_mamba_prefix_caching,
54
55
56
57
58
59
60
    supports_multimodal,
    supports_multimodal_encoder_tp_data,
    supports_multimodal_raw_input_only,
    supports_pp,
    supports_transcription,
)
from .interfaces_base import (
61
    get_attn_type,
62
63
    get_default_seq_pooling_type,
    get_default_tok_pooling_type,
64
65
66
    is_pooling_model,
    is_text_generation_model,
)
67
68
69

logger = init_logger(__name__)

70
71
# Text-generation architectures.
# Each key is an architecture name; each value is a pair of strings
# (presumably module name and class name, matching the `module_name` /
# `class_name` fields of `_LazyRegisteredModel` -- confirm against the code
# that builds the registry). Several keys are aliases that point at the same
# implementation class (e.g. many map to ("llama", "LlamaForCausalLM")).
_TEXT_GENERATION_MODELS = {
    # [Decoder-only]
    "AfmoeForCausalLM": ("afmoe", "AfmoeForCausalLM"),
    "ApertusForCausalLM": ("apertus", "ApertusForCausalLM"),
    "AquilaModel": ("llama", "LlamaForCausalLM"),
    "AquilaForCausalLM": ("llama", "LlamaForCausalLM"),  # AquilaChat2
    "ArceeForCausalLM": ("arcee", "ArceeForCausalLM"),
    "ArcticForCausalLM": ("arctic", "ArcticForCausalLM"),
    # baichuan-7b, upper case 'C' in the class name
    "BaiChuanForCausalLM": ("baichuan", "BaiChuanForCausalLM"),
    # baichuan-13b, lower case 'c' in the class name
    "BaichuanForCausalLM": ("baichuan", "BaichuanForCausalLM"),
    "BailingMoeForCausalLM": ("bailing_moe", "BailingMoeForCausalLM"),
    "BailingMoeV2ForCausalLM": ("bailing_moe", "BailingMoeV2ForCausalLM"),
    "BambaForCausalLM": ("bamba", "BambaForCausalLM"),
    "BloomForCausalLM": ("bloom", "BloomForCausalLM"),
    "ChatGLMModel": ("chatglm", "ChatGLMForCausalLM"),
    "ChatGLMForConditionalGeneration": ("chatglm", "ChatGLMForCausalLM"),
    "CohereForCausalLM": ("commandr", "CohereForCausalLM"),
    "Cohere2ForCausalLM": ("commandr", "CohereForCausalLM"),
    "CwmForCausalLM": ("llama", "LlamaForCausalLM"),
    "DbrxForCausalLM": ("dbrx", "DbrxForCausalLM"),
    "DeciLMForCausalLM": ("nemotron_nas", "DeciLMForCausalLM"),
    "DeepseekForCausalLM": ("deepseek_v2", "DeepseekForCausalLM"),
    "DeepseekV2ForCausalLM": ("deepseek_v2", "DeepseekV2ForCausalLM"),
    "DeepseekV3ForCausalLM": ("deepseek_v2", "DeepseekV3ForCausalLM"),
    "DeepseekV32ForCausalLM": ("deepseek_v2", "DeepseekV3ForCausalLM"),
    "Dots1ForCausalLM": ("dots1", "Dots1ForCausalLM"),
    "Ernie4_5ForCausalLM": ("ernie45", "Ernie4_5ForCausalLM"),
    "Ernie4_5_MoeForCausalLM": ("ernie45_moe", "Ernie4_5_MoeForCausalLM"),
    "ExaoneForCausalLM": ("exaone", "ExaoneForCausalLM"),
    "Exaone4ForCausalLM": ("exaone4", "Exaone4ForCausalLM"),
    # NOTE: the key spells "MoE" while the target class spells "Moe".
    "ExaoneMoEForCausalLM": ("exaone_moe", "ExaoneMoeForCausalLM"),
    "Fairseq2LlamaForCausalLM": ("fairseq2_llama", "Fairseq2LlamaForCausalLM"),
    "FalconForCausalLM": ("falcon", "FalconForCausalLM"),
    "FalconMambaForCausalLM": ("mamba", "MambaForCausalLM"),
    "FalconH1ForCausalLM": ("falcon_h1", "FalconH1ForCausalLM"),
    "FlexOlmoForCausalLM": ("flex_olmo", "FlexOlmoForCausalLM"),
    "GemmaForCausalLM": ("gemma", "GemmaForCausalLM"),
    "Gemma2ForCausalLM": ("gemma2", "Gemma2ForCausalLM"),
    "Gemma3ForCausalLM": ("gemma3", "Gemma3ForCausalLM"),
    "Gemma3nForCausalLM": ("gemma3n", "Gemma3nForCausalLM"),
    "Qwen3NextForCausalLM": ("qwen3_next", "Qwen3NextForCausalLM"),
    "GlmForCausalLM": ("glm", "GlmForCausalLM"),
    "Glm4ForCausalLM": ("glm4", "Glm4ForCausalLM"),
    "Glm4MoeForCausalLM": ("glm4_moe", "Glm4MoeForCausalLM"),
    "Glm4MoeLiteForCausalLM": ("glm4_moe_lite", "Glm4MoeLiteForCausalLM"),
    "GlmMoeDsaForCausalLM": ("deepseek_v2", "GlmMoeDsaForCausalLM"),
    "GptOssForCausalLM": ("gpt_oss", "GptOssForCausalLM"),
    "GPT2LMHeadModel": ("gpt2", "GPT2LMHeadModel"),
    "GPTBigCodeForCausalLM": ("gpt_bigcode", "GPTBigCodeForCausalLM"),
    "GPTJForCausalLM": ("gpt_j", "GPTJForCausalLM"),
    "GPTNeoXForCausalLM": ("gpt_neox", "GPTNeoXForCausalLM"),
    "GraniteForCausalLM": ("granite", "GraniteForCausalLM"),
    "GraniteMoeForCausalLM": ("granitemoe", "GraniteMoeForCausalLM"),
    "GraniteMoeHybridForCausalLM": ("granitemoehybrid", "GraniteMoeHybridForCausalLM"),  # noqa: E501
    "GraniteMoeSharedForCausalLM": ("granitemoeshared", "GraniteMoeSharedForCausalLM"),  # noqa: E501
    "GritLM": ("gritlm", "GritLM"),
    "Grok1ModelForCausalLM": ("grok1", "GrokForCausalLM"),
    "Grok1ForCausalLM": ("grok1", "GrokForCausalLM"),
    "HunYuanMoEV1ForCausalLM": ("hunyuan_v1", "HunYuanMoEV1ForCausalLM"),
    "HunYuanDenseV1ForCausalLM": ("hunyuan_v1", "HunYuanDenseV1ForCausalLM"),
    "HCXVisionForCausalLM": ("hyperclovax_vision", "HCXVisionForCausalLM"),
    "InternLMForCausalLM": ("llama", "LlamaForCausalLM"),
    "InternLM2ForCausalLM": ("internlm2", "InternLM2ForCausalLM"),
    "InternLM2VEForCausalLM": ("internlm2_ve", "InternLM2VEForCausalLM"),
    "InternLM3ForCausalLM": ("llama", "LlamaForCausalLM"),
    "IQuestCoderForCausalLM": ("llama", "LlamaForCausalLM"),
    "IQuestLoopCoderForCausalLM": ("iquest_loopcoder", "IQuestLoopCoderForCausalLM"),
    "JAISLMHeadModel": ("jais", "JAISLMHeadModel"),
    "Jais2ForCausalLM": ("jais2", "Jais2ForCausalLM"),
    "JambaForCausalLM": ("jamba", "JambaForCausalLM"),
    "KimiLinearForCausalLM": ("kimi_linear", "KimiLinearForCausalLM"),  # noqa: E501
    "Lfm2ForCausalLM": ("lfm2", "Lfm2ForCausalLM"),
    "Lfm2MoeForCausalLM": ("lfm2_moe", "Lfm2MoeForCausalLM"),
    "LlamaForCausalLM": ("llama", "LlamaForCausalLM"),
    "Llama4ForCausalLM": ("llama4", "Llama4ForCausalLM"),
    # For decapoda-research/llama-*
    "LLaMAForCausalLM": ("llama", "LlamaForCausalLM"),
    "LongcatFlashForCausalLM": ("longcat_flash", "LongcatFlashForCausalLM"),
    "MambaForCausalLM": ("mamba", "MambaForCausalLM"),
    "Mamba2ForCausalLM": ("mamba2", "Mamba2ForCausalLM"),
    "MiniCPMForCausalLM": ("minicpm", "MiniCPMForCausalLM"),
    "MiniCPM3ForCausalLM": ("minicpm3", "MiniCPM3ForCausalLM"),
    "MiniMaxForCausalLM": ("minimax_text_01", "MiniMaxText01ForCausalLM"),
    "MiniMaxText01ForCausalLM": ("minimax_text_01", "MiniMaxText01ForCausalLM"),
    "MiniMaxM1ForCausalLM": ("minimax_text_01", "MiniMaxText01ForCausalLM"),
    "MiniMaxM2ForCausalLM": ("minimax_m2", "MiniMaxM2ForCausalLM"),
    "MistralForCausalLM": ("mistral", "MistralForCausalLM"),
    "MistralLarge3ForCausalLM": ("mistral_large_3", "MistralLarge3ForCausalLM"),
    "MixtralForCausalLM": ("mixtral", "MixtralForCausalLM"),
    # The transformers MPT class name uses lower case ("Mpt"); both
    # spellings are accepted.
    "MptForCausalLM": ("mpt", "MPTForCausalLM"),
    "MPTForCausalLM": ("mpt", "MPTForCausalLM"),
    "MiMoForCausalLM": ("mimo", "MiMoForCausalLM"),
    "MiMoV2FlashForCausalLM": ("mimo_v2_flash", "MiMoV2FlashForCausalLM"),
    "NemotronForCausalLM": ("nemotron", "NemotronForCausalLM"),
    "NemotronHForCausalLM": ("nemotron_h", "NemotronHForCausalLM"),
    "NemotronHPuzzleForCausalLM": ("nemotron_h", "NemotronHForCausalLM"),
    "OlmoForCausalLM": ("olmo", "OlmoForCausalLM"),
    "Olmo2ForCausalLM": ("olmo2", "Olmo2ForCausalLM"),
    "Olmo3ForCausalLM": ("olmo2", "Olmo2ForCausalLM"),
    "OlmoeForCausalLM": ("olmoe", "OlmoeForCausalLM"),
    "OPTForCausalLM": ("opt", "OPTForCausalLM"),
    "OrionForCausalLM": ("orion", "OrionForCausalLM"),
    "OuroForCausalLM": ("ouro", "OuroForCausalLM"),
    "PanguEmbeddedForCausalLM": ("openpangu", "PanguEmbeddedForCausalLM"),
    "PanguProMoEV2ForCausalLM": ("openpangu", "PanguProMoEV2ForCausalLM"),
    "PanguUltraMoEForCausalLM": ("openpangu", "PanguUltraMoEForCausalLM"),
    "PersimmonForCausalLM": ("persimmon", "PersimmonForCausalLM"),
    "PhiForCausalLM": ("phi", "PhiForCausalLM"),
    "Phi3ForCausalLM": ("phi3", "Phi3ForCausalLM"),
    "PhiMoEForCausalLM": ("phimoe", "PhiMoEForCausalLM"),
    "Plamo2ForCausalLM": ("plamo2", "Plamo2ForCausalLM"),
    "Plamo3ForCausalLM": ("plamo3", "Plamo3ForCausalLM"),
    "QWenLMHeadModel": ("qwen", "QWenLMHeadModel"),
    "Qwen2ForCausalLM": ("qwen2", "Qwen2ForCausalLM"),
    "Qwen2MoeForCausalLM": ("qwen2_moe", "Qwen2MoeForCausalLM"),
    "Qwen3ForCausalLM": ("qwen3", "Qwen3ForCausalLM"),
    "Qwen3MoeForCausalLM": ("qwen3_moe", "Qwen3MoeForCausalLM"),
    "RWForCausalLM": ("falcon", "FalconForCausalLM"),
    "SeedOssForCausalLM": ("seed_oss", "SeedOssForCausalLM"),
    "Step1ForCausalLM": ("step1", "Step1ForCausalLM"),
    "Step3TextForCausalLM": ("step3_text", "Step3TextForCausalLM"),
    "Step3p5ForCausalLM": ("step3p5", "Step3p5ForCausalLM"),
    "StableLMEpochForCausalLM": ("stablelm", "StablelmForCausalLM"),
    "StableLmForCausalLM": ("stablelm", "StablelmForCausalLM"),
    "Starcoder2ForCausalLM": ("starcoder2", "Starcoder2ForCausalLM"),
    "SolarForCausalLM": ("solar", "SolarForCausalLM"),
    "TeleChatForCausalLM": ("telechat2", "TeleChat2ForCausalLM"),
    "TeleChat2ForCausalLM": ("telechat2", "TeleChat2ForCausalLM"),
    "TeleFLMForCausalLM": ("teleflm", "TeleFLMForCausalLM"),
    "XverseForCausalLM": ("llama", "LlamaForCausalLM"),
    "Zamba2ForCausalLM": ("zamba2", "Zamba2ForCausalLM"),
}

# Embedding / pooling architectures, keyed by architecture name with a
# two-string tuple value (same layout as `_TEXT_GENERATION_MODELS`).
# Some keys also appear in `_TEXT_GENERATION_MODELS` (e.g. "GritLM",
# "Qwen2ForCausalLM") with the same value.
_EMBEDDING_MODELS = {
    # [Text-only]
    "BertModel": ("bert", "BertEmbeddingModel"),
    "BertSpladeSparseEmbeddingModel": ("bert", "BertSpladeSparseEmbeddingModel"),
    "HF_ColBERT": ("colbert", "ColBERTModel"),
    "ColBERTModernBertModel": ("colbert", "ColBERTModernBertModel"),
    "ColBERTJinaRobertaModel": ("colbert", "ColBERTJinaRobertaModel"),
    "DeciLMForCausalLM": ("nemotron_nas", "DeciLMForCausalLM"),
    "Gemma2Model": ("gemma2", "Gemma2ForCausalLM"),
    "Gemma3TextModel": ("gemma3", "Gemma3Model"),
    "GlmForCausalLM": ("glm", "GlmForCausalLM"),
    "GPT2ForSequenceClassification": ("gpt2", "GPT2ForSequenceClassification"),
    "GritLM": ("gritlm", "GritLM"),
    "GteModel": ("bert_with_rope", "SnowflakeGteNewModel"),
    "GteNewModel": ("bert_with_rope", "GteNewModel"),
    "InternLM2ForRewardModel": ("internlm2", "InternLM2ForRewardModel"),
    "JambaForSequenceClassification": ("jamba", "JambaForSequenceClassification"),  # noqa: E501
    "LlamaBidirectionalModel": ("llama", "LlamaBidirectionalModel"),
    "LlamaModel": ("llama", "LlamaForCausalLM"),
    # Copy in every text-generation entry whose target class is
    # "LlamaForCausalLM", so all of those aliases are also listed here.
    **{
        # Multiple models share the same architecture, so we include them all
        k: (mod, arch)
        for k, (mod, arch) in _TEXT_GENERATION_MODELS.items()
        if arch == "LlamaForCausalLM"
    },
    "MistralModel": ("llama", "LlamaForCausalLM"),
    "ModernBertModel": ("modernbert", "ModernBertModel"),
    "NomicBertModel": ("bert_with_rope", "NomicBertModel"),
    "Phi3ForCausalLM": ("phi3", "Phi3ForCausalLM"),
    "Qwen2Model": ("qwen2", "Qwen2ForCausalLM"),
    "Qwen2ForCausalLM": ("qwen2", "Qwen2ForCausalLM"),
    "Qwen2ForRewardModel": ("qwen2_rm", "Qwen2ForRewardModel"),
    "Qwen2ForProcessRewardModel": ("qwen2_rm", "Qwen2ForProcessRewardModel"),
    "RobertaForMaskedLM": ("roberta", "RobertaEmbeddingModel"),
    "RobertaModel": ("roberta", "RobertaEmbeddingModel"),
    "TeleChatForCausalLM": ("telechat2", "TeleChat2ForCausalLM"),
    "TeleChat2ForCausalLM": ("telechat2", "TeleChat2ForCausalLM"),
    "VoyageQwen3BidirectionalEmbedModel": (
        "voyage",
        "VoyageQwen3BidirectionalEmbedModel",
    ),
    "XLMRobertaModel": ("roberta", "RobertaEmbeddingModel"),
    "BgeM3EmbeddingModel": ("roberta", "BgeM3EmbeddingModel"),
    # [Multimodal]
    "CLIPModel": ("clip", "CLIPEmbeddingModel"),
    "LlavaNextForConditionalGeneration": (
        "llava_next",
        "LlavaNextForConditionalGeneration",
    ),
    "Phi3VForCausalLM": ("phi3v", "Phi3VForCausalLM"),
    "Qwen2VLForConditionalGeneration": ("qwen2_vl", "Qwen2VLForConditionalGeneration"),  # noqa: E501
    "SiglipModel": ("siglip", "SiglipEmbeddingModel"),
    # Technically, Terratorch models work on images for both input and
    # output. They are listed here because they piggy-back on embedding
    # models for the time being.
    "PrithviGeoSpatialMAE": ("terratorch", "Terratorch"),
    "Terratorch": ("terratorch", "Terratorch"),
}

265
266
# Sequence-/token-classification and ranking architectures, keyed by
# architecture name with a two-string tuple value (same layout as the other
# tables in this file).
_CROSS_ENCODER_MODELS = {
    "BertForSequenceClassification": ("bert", "BertForSequenceClassification"),
    "BertForTokenClassification": ("bert", "BertForTokenClassification"),
    "GteNewForSequenceClassification": (
        "bert_with_rope",
        "GteNewForSequenceClassification",
    ),
    "JinaVLForRanking": ("jina_vl", "JinaVLForSequenceClassification"),
    "LlamaBidirectionalForSequenceClassification": (
        "llama",
        "LlamaBidirectionalForSequenceClassification",
    ),
    "ModernBertForSequenceClassification": (
        "modernbert",
        "ModernBertForSequenceClassification",
    ),
    "ModernBertForTokenClassification": (
        "modernbert",
        "ModernBertForTokenClassification",
    ),
    "RobertaForSequenceClassification": ("roberta", "RobertaForSequenceClassification"),
    # XLM-RoBERTa is served by the same class as RoBERTa.
    "XLMRobertaForSequenceClassification": (
        "roberta",
        "RobertaForSequenceClassification",
    ),
}

292
# Multimodal architectures, keyed by architecture name with a two-string
# tuple value (same layout as the other tables in this file). Some keys
# (e.g. "LlavaNextForConditionalGeneration", "Phi3VForCausalLM",
# "Qwen2VLForConditionalGeneration") are also present in `_EMBEDDING_MODELS`
# with the same value.
_MULTIMODAL_MODELS = {
    # [Decoder-only]
    "AriaForConditionalGeneration": ("aria", "AriaForConditionalGeneration"),
    "AudioFlamingo3ForConditionalGeneration": (
        "audioflamingo3",
        "AudioFlamingo3ForConditionalGeneration",
    ),
    "MusicFlamingoForConditionalGeneration": (
        "musicflamingo",
        "MusicFlamingoForConditionalGeneration",
    ),
    "AyaVisionForConditionalGeneration": (
        "aya_vision",
        "AyaVisionForConditionalGeneration",
    ),
    "BagelForConditionalGeneration": ("bagel", "BagelForConditionalGeneration"),
    "BeeForConditionalGeneration": ("bee", "BeeForConditionalGeneration"),
    "Blip2ForConditionalGeneration": ("blip2", "Blip2ForConditionalGeneration"),
    "ChameleonForConditionalGeneration": (
        "chameleon",
        "ChameleonForConditionalGeneration",
    ),
    "Cohere2VisionForConditionalGeneration": (
        "cohere2_vision",
        "Cohere2VisionForConditionalGeneration",
    ),
    "DeepseekVLV2ForCausalLM": ("deepseek_vl2", "DeepseekVLV2ForCausalLM"),
    "DeepseekOCRForCausalLM": ("deepseek_ocr", "DeepseekOCRForCausalLM"),
    "DeepseekOCR2ForCausalLM": ("deepseek_ocr2", "DeepseekOCR2ForCausalLM"),
    "DotsOCRForCausalLM": ("dots_ocr", "DotsOCRForCausalLM"),
    "Eagle2_5_VLForConditionalGeneration": (
        "eagle2_5_vl",
        "Eagle2_5_VLForConditionalGeneration",
    ),
    "Ernie4_5_VLMoeForConditionalGeneration": (
        "ernie45_vl",
        "Ernie4_5_VLMoeForConditionalGeneration",
    ),
    "FunASRForConditionalGeneration": ("funasr", "FunASRForConditionalGeneration"),  # noqa: E501
    "FunAudioChatForConditionalGeneration": (
        "funaudiochat",
        "FunAudioChatForConditionalGeneration",
    ),
    "FuyuForCausalLM": ("fuyu", "FuyuForCausalLM"),
    "Gemma3ForConditionalGeneration": ("gemma3_mm", "Gemma3ForConditionalGeneration"),  # noqa: E501
    "Gemma3nForConditionalGeneration": (
        "gemma3n_mm",
        "Gemma3nForConditionalGeneration",
    ),
    "GlmAsrForConditionalGeneration": ("glmasr", "GlmAsrForConditionalGeneration"),
    "GLM4VForCausalLM": ("glm4v", "GLM4VForCausalLM"),
    "Glm4vForConditionalGeneration": ("glm4_1v", "Glm4vForConditionalGeneration"),
    "Glm4vMoeForConditionalGeneration": ("glm4_1v", "Glm4vMoeForConditionalGeneration"),
    "GlmOcrForConditionalGeneration": ("glm_ocr", "GlmOcrForConditionalGeneration"),  # noqa: E501
    "GraniteSpeechForConditionalGeneration": (
        "granite_speech",
        "GraniteSpeechForConditionalGeneration",
    ),
    "H2OVLChatModel": ("h2ovl", "H2OVLChatModel"),
    "HunYuanVLForConditionalGeneration": (
        "hunyuan_vision",
        "HunYuanVLForConditionalGeneration",
    ),
    "StepVLForConditionalGeneration": ("step_vl", "StepVLForConditionalGeneration"),
    "InternVLChatModel": ("internvl", "InternVLChatModel"),
    "NemotronH_Nano_VL_V2": ("nano_nemotron_vl", "NemotronH_Nano_VL_V2"),
    "OpenCUAForConditionalGeneration": (
        "opencua",
        "OpenCUAForConditionalGeneration",
    ),
    "InternS1ForConditionalGeneration": (
        "interns1",
        "InternS1ForConditionalGeneration",
    ),
    # Alias: served by the InternS1 implementation.
    "InternVLForConditionalGeneration": (
        "interns1",
        "InternS1ForConditionalGeneration",
    ),
    "InternS1ProForConditionalGeneration": (
        "interns1_pro",
        "InternS1ProForConditionalGeneration",
    ),
    "Idefics3ForConditionalGeneration": (
        "idefics3",
        "Idefics3ForConditionalGeneration",
    ),
    "IsaacForConditionalGeneration": ("isaac", "IsaacForConditionalGeneration"),
    "SmolVLMForConditionalGeneration": ("smolvlm", "SmolVLMForConditionalGeneration"),  # noqa: E501
    "KananaVForConditionalGeneration": ("kanana_v", "KananaVForConditionalGeneration"),
    "KeyeForConditionalGeneration": ("keye", "KeyeForConditionalGeneration"),
    "KeyeVL1_5ForConditionalGeneration": (
        "keye_vl1_5",
        "KeyeVL1_5ForConditionalGeneration",
    ),
    "RForConditionalGeneration": ("rvl", "RForConditionalGeneration"),
    "KimiVLForConditionalGeneration": ("kimi_vl", "KimiVLForConditionalGeneration"),  # noqa: E501
    "KimiK25ForConditionalGeneration": ("kimi_k25", "KimiK25ForConditionalGeneration"),  # noqa: E501
    "LightOnOCRForConditionalGeneration": (
        "lightonocr",
        "LightOnOCRForConditionalGeneration",
    ),
    # NOTE: the key spells "Vl" while the target class spells "VL".
    "Lfm2VlForConditionalGeneration": ("lfm2_vl", "Lfm2VLForConditionalGeneration"),
    "Llama_Nemotron_Nano_VL": ("nemotron_vl", "LlamaNemotronVLChatModel"),
    "Llama4ForConditionalGeneration": ("mllama4", "Llama4ForConditionalGeneration"),  # noqa: E501
    "LlavaForConditionalGeneration": ("llava", "LlavaForConditionalGeneration"),
    "LlavaNextForConditionalGeneration": (
        "llava_next",
        "LlavaNextForConditionalGeneration",
    ),
    "LlavaNextVideoForConditionalGeneration": (
        "llava_next_video",
        "LlavaNextVideoForConditionalGeneration",
    ),
    "LlavaOnevisionForConditionalGeneration": (
        "llava_onevision",
        "LlavaOnevisionForConditionalGeneration",
    ),
    "MantisForConditionalGeneration": ("llava", "MantisForConditionalGeneration"),  # noqa: E501
    "MiDashengLMModel": ("midashenglm", "MiDashengLMModel"),
    "MiniMaxVL01ForConditionalGeneration": (
        "minimax_vl_01",
        "MiniMaxVL01ForConditionalGeneration",
    ),
    "MiniCPMO": ("minicpmo", "MiniCPMO"),
    "MiniCPMV": ("minicpmv", "MiniCPMV"),
    "Mistral3ForConditionalGeneration": (
        "mistral3",
        "Mistral3ForConditionalGeneration",
    ),
    "MolmoForCausalLM": ("molmo", "MolmoForCausalLM"),
    "Molmo2ForConditionalGeneration": ("molmo2", "Molmo2ForConditionalGeneration"),
    "NVLM_D": ("nvlm_d", "NVLM_D_Model"),
    "OpenPanguVLForConditionalGeneration": (
        "openpangu_vl",
        "OpenPanguVLForConditionalGeneration",
    ),
    "Ovis": ("ovis", "Ovis"),
    "Ovis2_5": ("ovis2_5", "Ovis2_5"),
    # The Ovis 2.6 names below are served by the Ovis2_5 class.
    "Ovis2_6ForCausalLM": ("ovis2_5", "Ovis2_5"),
    "Ovis2_6_MoeForCausalLM": ("ovis2_5", "Ovis2_5"),
    "PaddleOCRVLForConditionalGeneration": (
        "paddleocr_vl",
        "PaddleOCRVLForConditionalGeneration",
    ),
    "PaliGemmaForConditionalGeneration": (
        "paligemma",
        "PaliGemmaForConditionalGeneration",
    ),
    "Phi3VForCausalLM": ("phi3v", "Phi3VForCausalLM"),
    "Phi4MMForCausalLM": ("phi4mm", "Phi4MMForCausalLM"),
    "PixtralForConditionalGeneration": ("pixtral", "PixtralForConditionalGeneration"),  # noqa: E501
    "QwenVLForConditionalGeneration": ("qwen_vl", "QwenVLForConditionalGeneration"),  # noqa: E501
    "Qwen2VLForConditionalGeneration": ("qwen2_vl", "Qwen2VLForConditionalGeneration"),  # noqa: E501
    "Qwen2_5_VLForConditionalGeneration": (
        "qwen2_5_vl",
        "Qwen2_5_VLForConditionalGeneration",
    ),
    "Qwen2AudioForConditionalGeneration": (
        "qwen2_audio",
        "Qwen2AudioForConditionalGeneration",
    ),
    # Both Qwen2.5-Omni names resolve to the "thinker" implementation.
    "Qwen2_5OmniModel": (
        "qwen2_5_omni_thinker",
        "Qwen2_5OmniThinkerForConditionalGeneration",
    ),
    "Qwen2_5OmniForConditionalGeneration": (
        "qwen2_5_omni_thinker",
        "Qwen2_5OmniThinkerForConditionalGeneration",
    ),
    "Qwen3OmniMoeForConditionalGeneration": (
        "qwen3_omni_moe_thinker",
        "Qwen3OmniMoeThinkerForConditionalGeneration",
    ),
    "Qwen3ASRForConditionalGeneration": (
        "qwen3_asr",
        "Qwen3ASRForConditionalGeneration",
    ),
    "Qwen3VLForConditionalGeneration": ("qwen3_vl", "Qwen3VLForConditionalGeneration"),  # noqa: E501
    "Qwen3VLMoeForConditionalGeneration": (
        "qwen3_vl_moe",
        "Qwen3VLMoeForConditionalGeneration",
    ),
    "Qwen3_5ForConditionalGeneration": (
        "qwen3_5",
        "Qwen3_5ForConditionalGeneration",
    ),
    "Qwen3_5MoeForConditionalGeneration": (
        "qwen3_5",
        "Qwen3_5MoeForConditionalGeneration",
    ),
    "SkyworkR1VChatModel": ("skyworkr1v", "SkyworkR1VChatModel"),
    "Step3VLForConditionalGeneration": ("step3_vl", "Step3VLForConditionalGeneration"),  # noqa: E501
    "TarsierForConditionalGeneration": ("tarsier", "TarsierForConditionalGeneration"),  # noqa: E501
    # Tarsier2 is looked up under the "qwen2_vl" entry.
    "Tarsier2ForConditionalGeneration": (
        "qwen2_vl",
        "Tarsier2ForConditionalGeneration",
    ),
    "UltravoxModel": ("ultravox", "UltravoxModel"),
    "VoxtralForConditionalGeneration": ("voxtral", "VoxtralForConditionalGeneration"),  # noqa: E501
    "VoxtralRealtimeGeneration": ("voxtral_realtime", "VoxtralRealtimeGeneration"),  # noqa: E501
    # [Encoder-decoder]
    "NemotronParseForConditionalGeneration": (
        "nemotron_parse",
        "NemotronParseForConditionalGeneration",
    ),
    "WhisperForConditionalGeneration": ("whisper", "WhisperForConditionalGeneration"),  # noqa: E501
}
499
500

# Draft / speculator architectures (EAGLE, MTP, Medusa variants), keyed by
# architecture name with a two-string tuple value (same layout as the other
# tables in this file).
_SPECULATIVE_DECODING_MODELS = {
    "MiMoMTPModel": ("mimo_mtp", "MiMoMTP"),
    "EagleLlamaForCausalLM": ("llama_eagle", "EagleLlamaForCausalLM"),
    "EagleLlama4ForCausalLM": ("llama4_eagle", "EagleLlama4ForCausalLM"),
    "EagleMiniCPMForCausalLM": ("minicpm_eagle", "EagleMiniCPMForCausalLM"),
    # The four EAGLE-3 names below all resolve to Eagle3LlamaForCausalLM.
    "Eagle3LlamaForCausalLM": ("llama_eagle3", "Eagle3LlamaForCausalLM"),
    "LlamaForCausalLMEagle3": ("llama_eagle3", "Eagle3LlamaForCausalLM"),
    "Eagle3Qwen2_5vlForCausalLM": ("llama_eagle3", "Eagle3LlamaForCausalLM"),
    "Eagle3Qwen3vlForCausalLM": ("llama_eagle3", "Eagle3LlamaForCausalLM"),
    "EagleMistralLarge3ForCausalLM": (
        "mistral_large_3_eagle",
        "EagleMistralLarge3ForCausalLM",
    ),
    "EagleDeepSeekMTPModel": ("deepseek_eagle", "EagleDeepseekV3ForCausalLM"),
    "DeepSeekMTPModel": ("deepseek_mtp", "DeepSeekMTP"),
    "ErnieMTPModel": ("ernie_mtp", "ErnieMTP"),
    "ExaoneMoeMTP": ("exaone_moe_mtp", "ExaoneMoeMTP"),
    "LongCatFlashMTPModel": ("longcat_flash_mtp", "LongCatFlashMTP"),
    "Glm4MoeMTPModel": ("glm4_moe_mtp", "Glm4MoeMTP"),
    "Glm4MoeLiteMTPModel": ("glm4_moe_lite_mtp", "Glm4MoeLiteMTP"),
    "GlmOcrMTPModel": ("glm_ocr_mtp", "GlmOcrMTP"),
    "MedusaModel": ("medusa", "Medusa"),
    "OpenPanguMTPModel": ("openpangu_mtp", "OpenPanguMTP"),
    "Qwen3NextMTP": ("qwen3_next_mtp", "Qwen3NextMTP"),
    "Step3p5MTP": ("step3p5_mtp", "Step3p5MTP"),
    "Qwen3_5MTP": ("qwen3_5_mtp", "Qwen3_5MTP"),
    "Qwen3_5MoeMTP": ("qwen3_5_mtp", "Qwen3_5MoeMTP"),
    # Temporarily disabled.
    # # TODO(woosuk): Re-enable this once the MLP Speculator is supported in V1.
    # "MLPSpeculatorPreTrainedModel": ("mlp_speculator", "MLPSpeculator"),
}
531

532
# Named architectures that are routed to one of the generic "transformers"
# classes instead of a dedicated implementation.
# NOTE(review): presumably these are models known to work through the
# Transformers backend -- confirm with the code that consumes this table.
_TRANSFORMERS_SUPPORTED_MODELS = {
    # Text generation models
    "SmolLM3ForCausalLM": ("transformers", "TransformersForCausalLM"),
    # Multimodal models
    "Emu3ForConditionalGeneration": (
        "transformers",
        "TransformersMultiModalForCausalLM",
    ),
}

# The generic "transformers" classes themselves. Every key equals the class
# name in its value, and every value's first element is "transformers".
_TRANSFORMERS_BACKEND_MODELS = {
    # Text generation models
    "TransformersForCausalLM": ("transformers", "TransformersForCausalLM"),
    "TransformersMoEForCausalLM": ("transformers", "TransformersMoEForCausalLM"),
    # Multimodal models
    "TransformersMultiModalForCausalLM": (
        "transformers",
        "TransformersMultiModalForCausalLM",
    ),
    "TransformersMultiModalMoEForCausalLM": (
        "transformers",
        "TransformersMultiModalMoEForCausalLM",
    ),
    # Embedding models
    "TransformersEmbeddingModel": ("transformers", "TransformersEmbeddingModel"),
    "TransformersMoEEmbeddingModel": ("transformers", "TransformersMoEEmbeddingModel"),
    "TransformersMultiModalEmbeddingModel": (
        "transformers",
        "TransformersMultiModalEmbeddingModel",
    ),
    # Sequence classification models
    "TransformersForSequenceClassification": (
        "transformers",
        "TransformersForSequenceClassification",
    ),
    "TransformersMoEForSequenceClassification": (
        "transformers",
        "TransformersMoEForSequenceClassification",
    ),
    "TransformersMultiModalForSequenceClassification": (
        "transformers",
        "TransformersMultiModalForSequenceClassification",
    ),
}
576

577
# Union of every per-category table above into a single lookup.
# The operands are merged left to right, so when an architecture name occurs
# in more than one table the right-most table's value is the one kept (and
# the key keeps the position of its first occurrence).
_VLLM_MODELS = (
    _TEXT_GENERATION_MODELS
    | _EMBEDDING_MODELS
    | _CROSS_ENCODER_MODELS
    | _MULTIMODAL_MODELS
    | _SPECULATIVE_DECODING_MODELS
    | _TRANSFORMERS_SUPPORTED_MODELS
    | _TRANSFORMERS_BACKEND_MODELS
)

587
588
589
590
# This variable is used as the args for subprocess.run(). We
# can modify  this variable to alter the args if needed. e.g.
# when we use par format to pack things together, sys.executable
# might not be the target we want to run.
591
_SUBPROCESS_COMMAND = [sys.executable, "-m", "vllm.model_executor.models.registry"]
592

593
# Architectures that are no longer registered, mapped to a version string.
# NOTE(review): presumably the value is the last vLLM release that still
# supported the architecture -- confirm with the code that reads this table.
_PREVIOUSLY_SUPPORTED_MODELS = {
    "MotifForCausalLM": "0.10.2",
    "Phi3SmallForCausalLM": "0.9.2",
    "Phi4FlashForCausalLM": "0.10.2",
    "Phi4MultimodalForCausalLM": "0.12.0",
    # Encoder-decoder models other than Whisper
    # have been removed as part of the V0 deprecation.
    "BartModel": "0.10.2",
    "BartForConditionalGeneration": "0.10.2",
    "DonutForConditionalGeneration": "0.10.2",
    "Florence2ForConditionalGeneration": "0.10.2",
    "MBartForConditionalGeneration": "0.10.2",
    "MllamaForConditionalGeneration": "0.10.2",
}
607

608

609
610
@dataclass(frozen=True)
class _ModelInfo:
    """Frozen record of the capabilities probed from one model class.

    Instances are normally created through :meth:`from_model_cls`, which
    runs the interface helpers imported at the top of this module against
    the class and stores each result under the field of the matching name.
    """

    # Taken from ``model.__name__``.
    architecture: str
    is_text_generation_model: bool
    is_pooling_model: bool
    attn_type: AttnTypeStr
    default_seq_pooling_type: SequencePoolingType
    default_tok_pooling_type: TokenPoolingType
    supports_cross_encoding: bool
    supports_late_interaction: bool
    supports_multimodal: bool
    supports_multimodal_raw_input_only: bool
    requires_raw_input_tokens: bool
    supports_multimodal_encoder_tp_data: bool
    supports_pp: bool
    has_inner_state: bool
    is_attention_free: bool
    is_hybrid: bool
    has_noops: bool
    supports_mamba_prefix_caching: bool
    supports_transcription: bool
    # Only truthy when the class supports transcription *and* its
    # ``supports_transcription_only`` attribute is truthy.
    supports_transcription_only: bool

    @staticmethod
    def from_model_cls(model: type[nn.Module]) -> "_ModelInfo":
        """Build a `_ModelInfo` by probing *model* with every interface helper.

        Args:
            model: The model class (not an instance) to inspect.

        Returns:
            A new `_ModelInfo` whose ``architecture`` is ``model.__name__``
            and whose remaining fields hold the helper results.
        """
        # The mapping is evaluated top to bottom, so the helpers run in the
        # order listed here.
        probed = {
            "architecture": model.__name__,
            "is_text_generation_model": is_text_generation_model(model),
            "is_pooling_model": is_pooling_model(model),
            "default_seq_pooling_type": get_default_seq_pooling_type(model),
            "default_tok_pooling_type": get_default_tok_pooling_type(model),
            "attn_type": get_attn_type(model),
            "supports_cross_encoding": supports_cross_encoding(model),
            "supports_late_interaction": supports_late_interaction(model),
            "supports_multimodal": supports_multimodal(model),
            "supports_multimodal_raw_input_only": (
                supports_multimodal_raw_input_only(model)
            ),
            "requires_raw_input_tokens": requires_raw_input_tokens(model),
            "supports_multimodal_encoder_tp_data": (
                supports_multimodal_encoder_tp_data(model)
            ),
            "supports_pp": supports_pp(model),
            "has_inner_state": has_inner_state(model),
            "is_attention_free": is_attention_free(model),
            "is_hybrid": is_hybrid(model),
            "supports_mamba_prefix_caching": supports_mamba_prefix_caching(model),
            "supports_transcription": supports_transcription(model),
            # The attribute is only read when the class supports
            # transcription at all (short-circuit ``and``).
            "supports_transcription_only": (
                supports_transcription(model) and model.supports_transcription_only
            ),
            "has_noops": has_noops(model),
        }
        return _ModelInfo(**probed)
662
663


664
665
666
667
class _BaseRegisteredModel(ABC):
    @abstractmethod
    def inspect_model_cls(self) -> _ModelInfo:
        raise NotImplementedError
668

669
    @abstractmethod
670
    def load_model_cls(self) -> type[nn.Module]:
671
        raise NotImplementedError
672
673


674
675
676
677
678
679
680
@dataclass(frozen=True)
class _RegisteredModel(_BaseRegisteredModel):
    """
    Represents a model that has already been imported in the main process.
    """

    # Model info computed once, when the entry is created.
    interfaces: _ModelInfo
    # The imported model class.
    model_cls: type[nn.Module]

    @staticmethod
    def from_model_cls(model_cls: type[nn.Module]):
        """Create an entry for `model_cls`, inspecting it immediately."""
        info = _ModelInfo.from_model_cls(model_cls)
        return _RegisteredModel(interfaces=info, model_cls=model_cls)

    def inspect_model_cls(self) -> _ModelInfo:
        """Return the model info stored at construction time."""
        return self.interfaces

    def load_model_cls(self) -> type[nn.Module]:
        """Return the stored model class; nothing needs to be imported."""
        return self.model_cls


@dataclass(frozen=True)
class _LazyRegisteredModel(_BaseRegisteredModel):
    """
    Represents a model that has not been imported in the main process.
    """

    # Dotted path of the module that defines the model class.
    module_name: str
    # Name of the model class inside that module.
    class_name: str

    @staticmethod
    def _get_cache_dir() -> Path:
        """Return the directory that holds the cached model info files."""
        return Path(envs.VLLM_CACHE_ROOT, "modelinfos")

    def _get_cache_filename(self) -> str:
        """Return the cache file name for this module/class pair.

        Dots in the combined `<module>-<class>` name are replaced by dashes.
        """
        qualified = f"{self.module_name}-{self.class_name}"
        return qualified.replace(".", "-") + ".json"

    def _load_modelinfo_from_cache(self, module_hash: str) -> _ModelInfo | None:
        """Return the cached model info if it matches `module_hash`.

        Returns `None` when the cache file is missing, when its stored hash
        differs from `module_hash`, or when reading/parsing it fails for any
        other reason. Each case is only reported at debug level.
        """
        try:
            cache_file = self._get_cache_dir() / self._get_cache_filename()
            with open(cache_file, encoding="utf-8") as fp:
                payload = json.load(fp)

            if payload["hash"] == module_hash:
                # file not changed, use cached _ModelInfo properties
                return _ModelInfo(**payload["modelinfo"])

            logger.debug(
                "Cached model info file for class %s.%s is stale",
                self.module_name,
                self.class_name,
            )
        except FileNotFoundError:
            logger.debug(
                "Cached model info file for class %s.%s not found",
                self.module_name,
                self.class_name,
            )
        except Exception:
            logger.debug(
                "Cached model info for class %s.%s error. ",
                self.module_name,
                self.class_name,
            )
        return None

    def _save_modelinfo_to_cache(self, mi: _ModelInfo, module_hash: str) -> None:
        """Write `mi` together with `module_hash` to the JSON cache file.

        Best effort: any failure is logged with its traceback and swallowed.
        """
        from vllm.model_executor.model_loader.weight_utils import atomic_writer

        try:
            payload = {"hash": module_hash, "modelinfo": asdict(mi)}
            target_dir = self._get_cache_dir()
            target_dir.mkdir(parents=True, exist_ok=True)
            target_file = target_dir / self._get_cache_filename()
            with atomic_writer(target_file, encoding="utf-8") as fp:
                json.dump(payload, fp, indent=2)
        except Exception:
            logger.exception("Error saving model info cache.")

    @logtime(logger=logger, msg="Registry inspect model class")
    def inspect_model_cls(self) -> _ModelInfo:
        """Return the model info, using the on-disk cache when possible.

        The source file is looked up next to this file, by the last component
        of `module_name`. If it exists, its hash is used to validate the
        cached info. On a miss the model is inspected in a subprocess and,
        when a hash is available, the result is written back to the cache.
        """
        # NOTE(review): only the leaf module name is used to locate the file,
        # so a module living outside this directory presumably either skips
        # the cache or is hashed against a same-named file here. Confirm.
        leaf_module = self.module_name.split(".")[-1]
        source_file = Path(__file__).parent / f"{leaf_module}.py"
        digest = None

        if source_file.exists():
            digest = safe_hash(
                source_file.read_bytes(), usedforsecurity=False
            ).hexdigest()

            cached = self._load_modelinfo_from_cache(digest)
            if cached is not None:
                logger.debug(
                    "Loaded model info for class %s.%s from cache",
                    self.module_name,
                    self.class_name,
                )
                return cached

            logger.debug(
                "Cache model info for class %s.%s miss. Loading model instead.",
                self.module_name,
                self.class_name,
            )

        # Performed in another process to avoid initializing CUDA
        info = _run_in_subprocess(
            lambda: _ModelInfo.from_model_cls(self.load_model_cls())
        )
        logger.debug(
            "Loaded model info for class %s.%s", self.module_name, self.class_name
        )

        # Only cache results that can later be validated against a hash.
        if digest is not None:
            self._save_modelinfo_to_cache(info, digest)

        return info

    def load_model_cls(self) -> type[nn.Module]:
        """Import `module_name` and return its `class_name` attribute."""
        return getattr(importlib.import_module(self.module_name), self.class_name)


@lru_cache(maxsize=128)
def _try_load_model_cls(
    model_arch: str,
    model: _BaseRegisteredModel,
) -> type[nn.Module] | None:
    """Load the model class of a registry entry, or return `None` on failure.

    The architecture is first passed to the current platform's
    `verify_model_arch`; that call sits outside the `try`, so an exception
    from it propagates. Errors raised while loading are logged with their
    traceback. Results (including `None`) are memoized by `lru_cache`.
    """
    from vllm.platforms import current_platform

    current_platform.verify_model_arch(model_arch)

    try:
        loaded = model.load_model_cls()
    except Exception:
        logger.exception("Error in loading model architecture '%s'", model_arch)
        loaded = None
    return loaded
819
820


821
822
823
824
@lru_cache(maxsize=128)
def _try_inspect_model_cls(
    model_arch: str,
    model: _BaseRegisteredModel,
825
) -> _ModelInfo | None:
826
827
828
    try:
        return model.inspect_model_cls()
    except Exception:
829
        logger.exception("Error in inspecting model architecture '%s'", model_arch)
830
        return None
831
832


833
834
835
@dataclass
class _ModelRegistry:
    # Keyed by model_arch
836
    models: dict[str, _BaseRegisteredModel] = field(default_factory=dict)
837

838
    def get_supported_archs(self) -> Set[str]:
839
        return self.models.keys()
840

841
842
843
    def register_model(
        self,
        model_arch: str,
844
        model_cls: type[nn.Module] | str,
845
    ) -> None:
846
847
848
        """
        Register an external model to be used in vLLM.

849
        `model_cls` can be either:
850

851
        - A [`torch.nn.Module`][] class directly referencing the model.
852
        - A string in the format `<module>:<class>` which can be used to
853
854
          lazily import the model. This is useful to avoid initializing CUDA
          when importing the model and thus the related error
855
          `RuntimeError: Cannot re-initialize CUDA in forked subprocess`.
856
        """
857
858
859
860
        if not isinstance(model_arch, str):
            msg = f"`model_arch` should be a string, not a {type(model_arch)}"
            raise TypeError(msg)

861
        if model_arch in self.models:
862
863
            logger.warning(
                "Model architecture %s is already registered, and will be "
864
865
866
867
                "overwritten by the new model class %s.",
                model_arch,
                model_cls,
            )
868
869
870
871
872
873

        if isinstance(model_cls, str):
            split_str = model_cls.split(":")
            if len(split_str) != 2:
                msg = "Expected a string in the format `<module>:<class>`"
                raise ValueError(msg)
874

875
            model = _LazyRegisteredModel(*split_str)
876
        elif isinstance(model_cls, type) and issubclass(model_cls, nn.Module):
877
            model = _RegisteredModel.from_model_cls(model_cls)
878
        else:
879
880
881
882
            msg = (
                "`model_cls` should be a string or PyTorch model class, "
                f"not a {type(model_arch)}"
            )
883
            raise TypeError(msg)
884

885
        self.models[model_arch] = model
886

887
    def _raise_for_unsupported(self, architectures: list[str]):
888
        all_supported_archs = self.get_supported_archs()
889

890
891
892
        if any(arch in all_supported_archs for arch in architectures):
            raise ValueError(
                f"Model architectures {architectures} failed "
893
894
                "to be inspected. Please check the logs for more details."
            )
895

896
897
898
899
900
901
902
903
        for arch in architectures:
            if arch in _PREVIOUSLY_SUPPORTED_MODELS:
                previous_version = _PREVIOUSLY_SUPPORTED_MODELS[arch]

                raise ValueError(
                    f"Model architecture {arch} was supported in vLLM until "
                    f"v{previous_version}, and is not supported anymore. "
                    "Please use an older version of vLLM if you want to "
904
905
                    "use this model architecture."
                )
906

907
908
        raise ValueError(
            f"Model architectures {architectures} are not supported for now. "
909
910
            f"Supported architectures: {all_supported_archs}"
        )
911

912
    def _try_load_model_cls(self, model_arch: str) -> type[nn.Module] | None:
913
914
        if model_arch not in self.models:
            return None
915

916
        return _try_load_model_cls(model_arch, self.models[model_arch])
917

918
    def _try_inspect_model_cls(self, model_arch: str) -> _ModelInfo | None:
919
920
        if model_arch not in self.models:
            return None
921

922
923
924
925
926
927
        return _try_inspect_model_cls(model_arch, self.models[model_arch])

    def _try_resolve_transformers(
        self,
        architecture: str,
        model_config: ModelConfig,
928
    ) -> str | None:
929
930
931
        if architecture in _TRANSFORMERS_BACKEND_MODELS:
            return architecture

932
933
934
        auto_map: dict[str, str] = (
            getattr(model_config.hf_config, "auto_map", None) or dict()
        )
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950

        # Make sure that config class is always initialized before model class,
        # otherwise the model class won't be able to access the config class,
        # the expected auto_map should have correct order like:
        # "auto_map": {
        #     "AutoConfig": "<your-repo-name>--<config-name>",
        #     "AutoModel": "<your-repo-name>--<config-name>",
        #     "AutoModelFor<Task>": "<your-repo-name>--<config-name>",
        # },
        for prefix in ("AutoConfig", "AutoModel"):
            for name, module in auto_map.items():
                if name.startswith(prefix):
                    try_get_class_from_dynamic_module(
                        module,
                        model_config.model,
                        revision=model_config.revision,
951
                        trust_remote_code=model_config.trust_remote_code,
952
953
954
955
956
957
958
959
960
961
962
963
                        warn_on_fail=False,
                    )

        model_module = getattr(transformers, architecture, None)

        if model_module is None:
            for name, module in auto_map.items():
                if name.startswith("AutoModel"):
                    model_module = try_get_class_from_dynamic_module(
                        module,
                        model_config.model,
                        revision=model_config.revision,
964
                        trust_remote_code=model_config.trust_remote_code,
965
966
967
968
969
                        warn_on_fail=True,
                    )
                    if model_module is not None:
                        break
            else:
970
                if model_config.model_impl != "transformers":
971
972
973
974
975
976
977
                    return None

                raise ValueError(
                    f"Cannot find model module. {architecture!r} is not a "
                    "registered model in the Transformers library (only "
                    "relevant if the model is meant to be in Transformers) "
                    "and 'AutoModel' is not present in the model config's "
978
979
                    "'auto_map' (relevant if the model is custom)."
                )
980
981

        if not model_module.is_backend_compatible():
982
            if model_config.model_impl != "transformers":
983
                return None
984

985
986
            raise ValueError(
                f"The Transformers implementation of {architecture!r} "
987
988
                "is not compatible with vLLM."
            )
989

990
        return model_config._get_transformers_backend_cls()
991

992
993
994
995
996
997
998
999
1000
1001
1002
1003
1004
1005
1006
1007
1008
1009
1010
1011
1012
1013
1014
1015
1016
    def _normalize_arch(
        self,
        architecture: str,
        model_config: ModelConfig,
    ) -> str:
        if architecture in self.models:
            return architecture

        # This may be called in order to resolve runner_type and convert_type
        # in the first place, in which case we consider the default match
        match = try_match_architecture_defaults(
            architecture,
            runner_type=getattr(model_config, "runner_type", None),
            convert_type=getattr(model_config, "convert_type", None),
        )
        if match:
            suffix, _ = match

            # Get the name of the base model to convert
            for repl_suffix, _ in iter_architecture_defaults():
                base_arch = architecture.replace(suffix, repl_suffix)
                if base_arch in self.models:
                    return base_arch

        return architecture
1017

1018
1019
    def inspect_model_cls(
        self,
1020
        architectures: str | list[str],
1021
        model_config: ModelConfig,
1022
    ) -> tuple[_ModelInfo, str]:
1023
1024
        if isinstance(architectures, str):
            architectures = [architectures]
1025
1026
        if not architectures:
            raise ValueError("No model architectures are specified")
1027
1028

        # Require transformers impl
1029
        if model_config.model_impl == "transformers":
1030
            arch = self._try_resolve_transformers(architectures[0], model_config)
1031
1032
1033
1034
            if arch is not None:
                model_info = self._try_inspect_model_cls(arch)
                if model_info is not None:
                    return (model_info, arch)
1035
        elif model_config.model_impl == "terratorch":
1036
1037
            model_info = self._try_inspect_model_cls("Terratorch")
            return (model_info, "Terratorch")
1038

1039
        # Fallback to transformers impl (after resolving convert_type)
1040
1041
1042
1043
1044
1045
        if (
            all(arch not in self.models for arch in architectures)
            and model_config.model_impl == "auto"
            and getattr(model_config, "convert_type", "none") == "none"
        ):
            arch = self._try_resolve_transformers(architectures[0], model_config)
1046
1047
1048
1049
1050
1051
1052
            if arch is not None:
                model_info = self._try_inspect_model_cls(arch)
                if model_info is not None:
                    return (model_info, arch)

        for arch in architectures:
            normalized_arch = self._normalize_arch(arch, model_config)
1053
            model_info = self._try_inspect_model_cls(normalized_arch)
1054
            if model_info is not None:
1055
                return (model_info, arch)
1056

1057
        # Fallback to transformers impl (before resolving runner_type)
1058
1059
1060
1061
1062
        if (
            all(arch not in self.models for arch in architectures)
            and model_config.model_impl == "auto"
        ):
            arch = self._try_resolve_transformers(architectures[0], model_config)
1063
1064
1065
1066
1067
            if arch is not None:
                model_info = self._try_inspect_model_cls(arch)
                if model_info is not None:
                    return (model_info, arch)

1068
        return self._raise_for_unsupported(architectures)
1069

1070
1071
    def resolve_model_cls(
        self,
1072
        architectures: str | list[str],
1073
        model_config: ModelConfig,
1074
    ) -> tuple[type[nn.Module], str]:
1075
1076
        if isinstance(architectures, str):
            architectures = [architectures]
1077
1078
        if not architectures:
            raise ValueError("No model architectures are specified")
1079
1080

        # Require transformers impl
1081
        if model_config.model_impl == "transformers":
1082
            arch = self._try_resolve_transformers(architectures[0], model_config)
1083
1084
1085
1086
            if arch is not None:
                model_cls = self._try_load_model_cls(arch)
                if model_cls is not None:
                    return (model_cls, arch)
1087
        elif model_config.model_impl == "terratorch":
1088
1089
1090
1091
            arch = "Terratorch"
            model_cls = self._try_load_model_cls(arch)
            if model_cls is not None:
                return (model_cls, arch)
1092

1093
        # Fallback to transformers impl (after resolving convert_type)
1094
1095
1096
1097
1098
1099
        if (
            all(arch not in self.models for arch in architectures)
            and model_config.model_impl == "auto"
            and getattr(model_config, "convert_type", "none") == "none"
        ):
            arch = self._try_resolve_transformers(architectures[0], model_config)
1100
1101
1102
1103
1104
1105
1106
            if arch is not None:
                model_cls = self._try_load_model_cls(arch)
                if model_cls is not None:
                    return (model_cls, arch)

        for arch in architectures:
            normalized_arch = self._normalize_arch(arch, model_config)
1107
            model_cls = self._try_load_model_cls(normalized_arch)
1108
1109
            if model_cls is not None:
                return (model_cls, arch)
1110

1111
        # Fallback to transformers impl (before resolving runner_type)
1112
1113
1114
1115
1116
        if (
            all(arch not in self.models for arch in architectures)
            and model_config.model_impl == "auto"
        ):
            arch = self._try_resolve_transformers(architectures[0], model_config)
1117
1118
1119
1120
1121
            if arch is not None:
                model_cls = self._try_load_model_cls(arch)
                if model_cls is not None:
                    return (model_cls, arch)

1122
        return self._raise_for_unsupported(architectures)
1123

1124
1125
    def is_text_generation_model(
        self,
1126
        architectures: str | list[str],
1127
        model_config: ModelConfig,
1128
    ) -> bool:
1129
        model_cls, _ = self.inspect_model_cls(architectures, model_config)
1130
        return model_cls.is_text_generation_model
1131

1132
    def is_pooling_model(
1133
        self,
1134
        architectures: str | list[str],
1135
        model_config: ModelConfig,
1136
    ) -> bool:
1137
        model_cls, _ = self.inspect_model_cls(architectures, model_config)
1138
        return model_cls.is_pooling_model
1139

1140
1141
    def is_cross_encoder_model(
        self,
1142
        architectures: str | list[str],
1143
        model_config: ModelConfig,
1144
    ) -> bool:
1145
        model_cls, _ = self.inspect_model_cls(architectures, model_config)
1146
        return model_cls.supports_cross_encoding
1147

1148
1149
    def is_multimodal_model(
        self,
1150
        architectures: str | list[str],
1151
        model_config: ModelConfig,
1152
    ) -> bool:
1153
        model_cls, _ = self.inspect_model_cls(architectures, model_config)
1154
        return model_cls.supports_multimodal
1155

1156
    def is_multimodal_raw_input_only_model(
1157
        self,
1158
        architectures: str | list[str],
1159
        model_config: ModelConfig,
1160
    ) -> bool:
1161
        model_cls, _ = self.inspect_model_cls(architectures, model_config)
1162
        return model_cls.supports_multimodal_raw_input_only
1163

1164
1165
    def is_pp_supported_model(
        self,
1166
        architectures: str | list[str],
1167
        model_config: ModelConfig,
1168
    ) -> bool:
1169
        model_cls, _ = self.inspect_model_cls(architectures, model_config)
1170
        return model_cls.supports_pp
1171

1172
1173
    def model_has_inner_state(
        self,
1174
        architectures: str | list[str],
1175
        model_config: ModelConfig,
1176
    ) -> bool:
1177
        model_cls, _ = self.inspect_model_cls(architectures, model_config)
1178
        return model_cls.has_inner_state
1179

1180
1181
    def is_attention_free_model(
        self,
1182
        architectures: str | list[str],
1183
        model_config: ModelConfig,
1184
    ) -> bool:
1185
        model_cls, _ = self.inspect_model_cls(architectures, model_config)
1186
        return model_cls.is_attention_free
1187

1188
1189
    def is_hybrid_model(
        self,
1190
        architectures: str | list[str],
1191
        model_config: ModelConfig,
1192
    ) -> bool:
1193
        model_cls, _ = self.inspect_model_cls(architectures, model_config)
1194
1195
        return model_cls.is_hybrid

1196
1197
    def is_noops_model(
        self,
1198
        architectures: str | list[str],
1199
        model_config: ModelConfig,
1200
    ) -> bool:
1201
        model_cls, _ = self.inspect_model_cls(architectures, model_config)
1202
1203
        return model_cls.has_noops

1204
1205
    def is_transcription_model(
        self,
1206
        architectures: str | list[str],
1207
        model_config: ModelConfig,
1208
    ) -> bool:
1209
        model_cls, _ = self.inspect_model_cls(architectures, model_config)
1210
1211
        return model_cls.supports_transcription

1212
1213
    def is_transcription_only_model(
        self,
1214
        architectures: str | list[str],
1215
        model_config: ModelConfig,
1216
    ) -> bool:
1217
        model_cls, _ = self.inspect_model_cls(architectures, model_config)
1218
1219
        return model_cls.supports_transcription_only

1220

1221
1222
1223
1224
1225
1226
1227
1228
1229
# Module-level registry instance, pre-populated with one lazily imported
# entry per architecture listed in `_VLLM_MODELS`.
ModelRegistry = _ModelRegistry(
    {
        model_arch: _LazyRegisteredModel(
            module_name=f"vllm.model_executor.models.{mod_relname}",
            class_name=cls_name,
        )
        for model_arch, (mod_relname, cls_name) in _VLLM_MODELS.items()
    }
)
1230
1231
1232
1233
1234

# Return type of the callable executed by `_run_in_subprocess`.
_T = TypeVar("_T")


def _run_in_subprocess(fn: Callable[[], _T]) -> _T:
    """Run `fn` in a separate process and return its unpickled result.

    `fn` and an output path are serialized with `cloudpickle` and piped to
    the stdin of `_SUBPROCESS_COMMAND`. The result is then read back from
    the output file with `pickle`.

    Raises:
        RuntimeError: If the subprocess exits with a non-zero status; the
            message contains the subprocess's stderr.
    """
    # `cloudpickle` allows pickling lambda functions directly
    import cloudpickle

    # NOTE: We use a temporary directory instead of a temporary file to avoid
    # issues like https://stackoverflow.com/questions/23212435/permission-denied-to-write-to-my-temporary-file
    with tempfile.TemporaryDirectory() as scratch_dir:
        result_path = os.path.join(scratch_dir, "registry_output.tmp")
        payload = cloudpickle.dumps((fn, result_path))

        # cannot use `sys.executable __file__` here because the script
        # contains relative imports
        completed = subprocess.run(
            _SUBPROCESS_COMMAND, input=payload, capture_output=True
        )

        # check if the subprocess is successful
        try:
            completed.check_returncode()
        except Exception as exc:
            # wrap raised exception to provide more information
            raise RuntimeError(
                f"Error raised in subprocess:\n{completed.stderr.decode()}"
            ) from exc

        with open(result_path, "rb") as fp:
            return pickle.loads(fp.read())


def _run() -> None:
    """Execute a pickled callable read from stdin and store its result.

    Stdin must hold a pickled `(fn, output_file)` pair. `fn` is called
    without arguments and its pickled return value is written to
    `output_file`.
    """
    # Setup plugins
    from vllm.plugins import load_general_plugins

    load_general_plugins()

    # NOTE(review): unpickling executes arbitrary code; presumably stdin is
    # only ever fed by `_run_in_subprocess` in the parent process - confirm.
    payload = sys.stdin.buffer.read()
    fn, output_file = pickle.loads(payload)

    outcome = fn()

    with open(output_file, "wb") as sink:
        pickle.dump(outcome, sink)
1276
1277
1278


# When this file is executed as a script, run the pickled job read from stdin.
if __name__ == "__main__":
    _run()