registry.py 49.2 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
4
5
6
"""
Whenever you add an architecture to this page, please also update
`tests/models/registry.py` with example HuggingFace models for it.
"""
7

8
import importlib
9
import json
10
import os
11
import pickle
12
13
import subprocess
import sys
14
import tempfile
15
from abc import ABC, abstractmethod
16
from collections.abc import Callable, Set
17
from dataclasses import asdict, dataclass, field
18
from functools import lru_cache
19
from pathlib import Path
20
from typing import TYPE_CHECKING, Any, TypeVar
21
22

import torch.nn as nn
23
import transformers
24

25
from vllm import envs
26
27
28
29
30
from vllm.config import (
    ModelConfig,
    iter_architecture_defaults,
    try_match_architecture_defaults,
)
31
from vllm.logger import init_logger
32
from vllm.logging_utils import logtime
33
from vllm.transformers_utils.dynamic_module import try_get_class_from_dynamic_module
34
from vllm.utils.hashing import safe_hash
35

36
37
# These names are only needed for annotations. Under a static type checker
# they resolve to the real types; at runtime they are bound to `Any` so that
# the `vllm.config.model` / `vllm.config.pooler` modules are not imported here.
if TYPE_CHECKING:
    from vllm.config.model import AttnTypeStr
    from vllm.config.pooler import SequencePoolingType, TokenPoolingType
else:
    AttnTypeStr = Any
    SequencePoolingType = Any
    TokenPoolingType = Any
43
44


45
46
47
48
49
from .interfaces import (
    has_inner_state,
    has_noops,
    is_attention_free,
    is_hybrid,
Patrick von Platen's avatar
Patrick von Platen committed
50
    requires_raw_input_tokens,
51
    supports_cross_encoding,
52
    supports_late_interaction,
53
    supports_mamba_prefix_caching,
54
55
56
57
58
59
60
    supports_multimodal,
    supports_multimodal_encoder_tp_data,
    supports_multimodal_raw_input_only,
    supports_pp,
    supports_transcription,
)
from .interfaces_base import (
61
    get_attn_type,
62
63
    get_default_seq_pooling_type,
    get_default_tok_pooling_type,
64
65
66
    is_pooling_model,
    is_text_generation_model,
)
67
68
69

# Module-level logger, named after this module.
logger = init_logger(__name__)

70
71
# Text-generation architectures.
#
# Each key is an architecture name; each value is a `(module, class)` pair of
# strings naming the implementing class. Several architecture names alias to
# the same implementation (for example, many keys map to
# `("llama", "LlamaForCausalLM")`).
# NOTE(review): the module string is presumably resolved relative to this
# package when the model is lazily imported -- confirm in the registration code.
_TEXT_GENERATION_MODELS = {
    # [Decoder-only]
    "AfmoeForCausalLM": ("afmoe", "AfmoeForCausalLM"),
    "ApertusForCausalLM": ("apertus", "ApertusForCausalLM"),
    "AquilaModel": ("llama", "LlamaForCausalLM"),
    "AquilaForCausalLM": ("llama", "LlamaForCausalLM"),  # AquilaChat2
    "ArceeForCausalLM": ("arcee", "ArceeForCausalLM"),
    "ArcticForCausalLM": ("arctic", "ArcticForCausalLM"),
    # baichuan-7b, upper case 'C' in the class name
    "BaiChuanForCausalLM": ("baichuan", "BaiChuanForCausalLM"),
    # baichuan-13b, lower case 'c' in the class name
    "BaichuanForCausalLM": ("baichuan", "BaichuanForCausalLM"),
    "BailingMoeForCausalLM": ("bailing_moe", "BailingMoeForCausalLM"),
    "BailingMoeV2ForCausalLM": ("bailing_moe", "BailingMoeV2ForCausalLM"),
    "BambaForCausalLM": ("bamba", "BambaForCausalLM"),
    "BloomForCausalLM": ("bloom", "BloomForCausalLM"),
    "ChatGLMModel": ("chatglm", "ChatGLMForCausalLM"),
    "ChatGLMForConditionalGeneration": ("chatglm", "ChatGLMForCausalLM"),
    "CohereForCausalLM": ("commandr", "CohereForCausalLM"),
    "Cohere2ForCausalLM": ("commandr", "CohereForCausalLM"),
    "CwmForCausalLM": ("llama", "LlamaForCausalLM"),
    "DbrxForCausalLM": ("dbrx", "DbrxForCausalLM"),
    "DeciLMForCausalLM": ("nemotron_nas", "DeciLMForCausalLM"),
    "DeepseekForCausalLM": ("deepseek_v2", "DeepseekForCausalLM"),
    "DeepseekV2ForCausalLM": ("deepseek_v2", "DeepseekV2ForCausalLM"),
    "DeepseekV3ForCausalLM": ("deepseek_v2", "DeepseekV3ForCausalLM"),
    "DeepseekV32ForCausalLM": ("deepseek_v2", "DeepseekV3ForCausalLM"),
    "Dots1ForCausalLM": ("dots1", "Dots1ForCausalLM"),
    "Ernie4_5ForCausalLM": ("ernie45", "Ernie4_5ForCausalLM"),
    "Ernie4_5_MoeForCausalLM": ("ernie45_moe", "Ernie4_5_MoeForCausalLM"),
    "ExaoneForCausalLM": ("exaone", "ExaoneForCausalLM"),
    "Exaone4ForCausalLM": ("exaone4", "Exaone4ForCausalLM"),
    "ExaoneMoEForCausalLM": ("exaone_moe", "ExaoneMoeForCausalLM"),
    "Fairseq2LlamaForCausalLM": ("fairseq2_llama", "Fairseq2LlamaForCausalLM"),
    "FalconForCausalLM": ("falcon", "FalconForCausalLM"),
    "FalconMambaForCausalLM": ("mamba", "MambaForCausalLM"),
    "FalconH1ForCausalLM": ("falcon_h1", "FalconH1ForCausalLM"),
    "FlexOlmoForCausalLM": ("flex_olmo", "FlexOlmoForCausalLM"),
    "GemmaForCausalLM": ("gemma", "GemmaForCausalLM"),
    "Gemma2ForCausalLM": ("gemma2", "Gemma2ForCausalLM"),
    "Gemma3ForCausalLM": ("gemma3", "Gemma3ForCausalLM"),
    "Gemma3nForCausalLM": ("gemma3n", "Gemma3nForCausalLM"),
    "Qwen3NextForCausalLM": ("qwen3_next", "Qwen3NextForCausalLM"),
    "GlmForCausalLM": ("glm", "GlmForCausalLM"),
    "Glm4ForCausalLM": ("glm4", "Glm4ForCausalLM"),
    "Glm4MoeForCausalLM": ("glm4_moe", "Glm4MoeForCausalLM"),
    "Glm4MoeLiteForCausalLM": ("glm4_moe_lite", "Glm4MoeLiteForCausalLM"),
    "GlmMoeDsaForCausalLM": ("deepseek_v2", "GlmMoeDsaForCausalLM"),
    "GptOssForCausalLM": ("gpt_oss", "GptOssForCausalLM"),
    "GPT2LMHeadModel": ("gpt2", "GPT2LMHeadModel"),
    "GPTBigCodeForCausalLM": ("gpt_bigcode", "GPTBigCodeForCausalLM"),
    "GPTJForCausalLM": ("gpt_j", "GPTJForCausalLM"),
    "GPTNeoXForCausalLM": ("gpt_neox", "GPTNeoXForCausalLM"),
    "GraniteForCausalLM": ("granite", "GraniteForCausalLM"),
    "GraniteMoeForCausalLM": ("granitemoe", "GraniteMoeForCausalLM"),
    "GraniteMoeHybridForCausalLM": ("granitemoehybrid", "GraniteMoeHybridForCausalLM"),  # noqa: E501
    "GraniteMoeSharedForCausalLM": ("granitemoeshared", "GraniteMoeSharedForCausalLM"),  # noqa: E501
    "GritLM": ("gritlm", "GritLM"),
    "Grok1ModelForCausalLM": ("grok1", "GrokForCausalLM"),
    "Grok1ForCausalLM": ("grok1", "GrokForCausalLM"),
    "HunYuanMoEV1ForCausalLM": ("hunyuan_v1", "HunYuanMoEV1ForCausalLM"),
    "HunYuanDenseV1ForCausalLM": ("hunyuan_v1", "HunYuanDenseV1ForCausalLM"),
    "HCXVisionForCausalLM": ("hyperclovax_vision", "HCXVisionForCausalLM"),
    "InternLMForCausalLM": ("llama", "LlamaForCausalLM"),
    "InternLM2ForCausalLM": ("internlm2", "InternLM2ForCausalLM"),
    "InternLM2VEForCausalLM": ("internlm2_ve", "InternLM2VEForCausalLM"),
    "InternLM3ForCausalLM": ("llama", "LlamaForCausalLM"),
    "IQuestCoderForCausalLM": ("llama", "LlamaForCausalLM"),
    "IQuestLoopCoderForCausalLM": ("iquest_loopcoder", "IQuestLoopCoderForCausalLM"),
    "JAISLMHeadModel": ("jais", "JAISLMHeadModel"),
    "Jais2ForCausalLM": ("jais2", "Jais2ForCausalLM"),
    "JambaForCausalLM": ("jamba", "JambaForCausalLM"),
    "KimiLinearForCausalLM": ("kimi_linear", "KimiLinearForCausalLM"),  # noqa: E501
    "Lfm2ForCausalLM": ("lfm2", "Lfm2ForCausalLM"),
    "Lfm2MoeForCausalLM": ("lfm2_moe", "Lfm2MoeForCausalLM"),
    "LlamaForCausalLM": ("llama", "LlamaForCausalLM"),
    "Llama4ForCausalLM": ("llama4", "Llama4ForCausalLM"),
    # For decapoda-research/llama-*
    "LLaMAForCausalLM": ("llama", "LlamaForCausalLM"),
    "LongcatFlashForCausalLM": ("longcat_flash", "LongcatFlashForCausalLM"),
    "MambaForCausalLM": ("mamba", "MambaForCausalLM"),
    "Mamba2ForCausalLM": ("mamba2", "Mamba2ForCausalLM"),
    "MiniCPMForCausalLM": ("minicpm", "MiniCPMForCausalLM"),
    "MiniCPM3ForCausalLM": ("minicpm3", "MiniCPM3ForCausalLM"),
    "MiniMaxForCausalLM": ("minimax_text_01", "MiniMaxText01ForCausalLM"),
    "MiniMaxText01ForCausalLM": ("minimax_text_01", "MiniMaxText01ForCausalLM"),
    "MiniMaxM1ForCausalLM": ("minimax_text_01", "MiniMaxText01ForCausalLM"),
    "MiniMaxM2ForCausalLM": ("minimax_m2", "MiniMaxM2ForCausalLM"),
    "MistralForCausalLM": ("mistral", "MistralForCausalLM"),
    "MistralLarge3ForCausalLM": ("mistral_large_3", "MistralLarge3ForCausalLM"),
    "MixtralForCausalLM": ("mixtral", "MixtralForCausalLM"),
    # transformers's mpt class has lower case
    "MptForCausalLM": ("mpt", "MPTForCausalLM"),
    "MPTForCausalLM": ("mpt", "MPTForCausalLM"),
    "MiMoForCausalLM": ("mimo", "MiMoForCausalLM"),
    "MiMoV2FlashForCausalLM": ("mimo_v2_flash", "MiMoV2FlashForCausalLM"),
    "NemotronForCausalLM": ("nemotron", "NemotronForCausalLM"),
    "NemotronHForCausalLM": ("nemotron_h", "NemotronHForCausalLM"),
    "NemotronHPuzzleForCausalLM": ("nemotron_h", "NemotronHForCausalLM"),
    "OlmoForCausalLM": ("olmo", "OlmoForCausalLM"),
    "Olmo2ForCausalLM": ("olmo2", "Olmo2ForCausalLM"),
    "Olmo3ForCausalLM": ("olmo2", "Olmo2ForCausalLM"),
    "OlmoeForCausalLM": ("olmoe", "OlmoeForCausalLM"),
    "OPTForCausalLM": ("opt", "OPTForCausalLM"),
    "OrionForCausalLM": ("orion", "OrionForCausalLM"),
    "OuroForCausalLM": ("ouro", "OuroForCausalLM"),
    "PanguEmbeddedForCausalLM": ("openpangu", "PanguEmbeddedForCausalLM"),
    "PanguProMoEV2ForCausalLM": ("openpangu", "PanguProMoEV2ForCausalLM"),
    "PanguUltraMoEForCausalLM": ("openpangu", "PanguUltraMoEForCausalLM"),
    "PersimmonForCausalLM": ("persimmon", "PersimmonForCausalLM"),
    "PhiForCausalLM": ("phi", "PhiForCausalLM"),
    "Phi3ForCausalLM": ("phi3", "Phi3ForCausalLM"),
    "PhiMoEForCausalLM": ("phimoe", "PhiMoEForCausalLM"),
    "Plamo2ForCausalLM": ("plamo2", "Plamo2ForCausalLM"),
    "Plamo3ForCausalLM": ("plamo3", "Plamo3ForCausalLM"),
    "QWenLMHeadModel": ("qwen", "QWenLMHeadModel"),
    "Qwen2ForCausalLM": ("qwen2", "Qwen2ForCausalLM"),
    "Qwen2MoeForCausalLM": ("qwen2_moe", "Qwen2MoeForCausalLM"),
    "Qwen3ForCausalLM": ("qwen3", "Qwen3ForCausalLM"),
    "Qwen3MoeForCausalLM": ("qwen3_moe", "Qwen3MoeForCausalLM"),
    "RWForCausalLM": ("falcon", "FalconForCausalLM"),
    "SeedOssForCausalLM": ("seed_oss", "SeedOssForCausalLM"),
    "Step1ForCausalLM": ("step1", "Step1ForCausalLM"),
    "Step3TextForCausalLM": ("step3_text", "Step3TextForCausalLM"),
    "Step3p5ForCausalLM": ("step3p5", "Step3p5ForCausalLM"),
    "StableLMEpochForCausalLM": ("stablelm", "StablelmForCausalLM"),
    "StableLmForCausalLM": ("stablelm", "StablelmForCausalLM"),
    "Starcoder2ForCausalLM": ("starcoder2", "Starcoder2ForCausalLM"),
    "SolarForCausalLM": ("solar", "SolarForCausalLM"),
    "TeleChatForCausalLM": ("telechat2", "TeleChat2ForCausalLM"),
    "TeleChat2ForCausalLM": ("telechat2", "TeleChat2ForCausalLM"),
    "TeleFLMForCausalLM": ("teleflm", "TeleFLMForCausalLM"),
    "XverseForCausalLM": ("llama", "LlamaForCausalLM"),
    "Zamba2ForCausalLM": ("zamba2", "Zamba2ForCausalLM"),
}

# Embedding / pooling architectures.
#
# Each key is an architecture name; each value is a `(module, class)` pair of
# strings naming the implementing class. Every `_TEXT_GENERATION_MODELS` entry
# whose class is `LlamaForCausalLM` is also copied into this table via the
# dict comprehension below. Because later keys in a dict display override
# earlier ones, the position of that comprehension relative to the explicit
# entries is significant and is kept as-is.
_EMBEDDING_MODELS = {
    # [Text-only]
    "BertModel": ("bert", "BertEmbeddingModel"),
    "BertSpladeSparseEmbeddingModel": ("bert", "BertSpladeSparseEmbeddingModel"),
    "HF_ColBERT": ("colbert", "ColBERTModel"),
    "DeciLMForCausalLM": ("nemotron_nas", "DeciLMForCausalLM"),
    "Gemma2Model": ("gemma2", "Gemma2ForCausalLM"),
    "Gemma3TextModel": ("gemma3", "Gemma3Model"),
    "GlmForCausalLM": ("glm", "GlmForCausalLM"),
    "GPT2ForSequenceClassification": ("gpt2", "GPT2ForSequenceClassification"),
    "GritLM": ("gritlm", "GritLM"),
    "GteModel": ("bert_with_rope", "SnowflakeGteNewModel"),
    "GteNewModel": ("bert_with_rope", "GteNewModel"),
    "InternLM2ForRewardModel": ("internlm2", "InternLM2ForRewardModel"),
    "JambaForSequenceClassification": ("jamba", "JambaForSequenceClassification"),  # noqa: E501
    "LlamaBidirectionalModel": ("llama", "LlamaBidirectionalModel"),
    "LlamaModel": ("llama", "LlamaForCausalLM"),
    **{
        # Multiple models share the same architecture, so we include them all
        k: (mod, arch)
        for k, (mod, arch) in _TEXT_GENERATION_MODELS.items()
        if arch == "LlamaForCausalLM"
    },
    "MistralModel": ("llama", "LlamaForCausalLM"),
    "ModernBertModel": ("modernbert", "ModernBertModel"),
    "NomicBertModel": ("bert_with_rope", "NomicBertModel"),
    "Phi3ForCausalLM": ("phi3", "Phi3ForCausalLM"),
    "Qwen2Model": ("qwen2", "Qwen2ForCausalLM"),
    "Qwen2ForCausalLM": ("qwen2", "Qwen2ForCausalLM"),
    "Qwen2ForRewardModel": ("qwen2_rm", "Qwen2ForRewardModel"),
    "Qwen2ForProcessRewardModel": ("qwen2_rm", "Qwen2ForProcessRewardModel"),
    "RobertaForMaskedLM": ("roberta", "RobertaEmbeddingModel"),
    "RobertaModel": ("roberta", "RobertaEmbeddingModel"),
    "TeleChatForCausalLM": ("telechat2", "TeleChat2ForCausalLM"),
    "TeleChat2ForCausalLM": ("telechat2", "TeleChat2ForCausalLM"),
    "VoyageQwen3BidirectionalEmbedModel": (
        "voyage",
        "VoyageQwen3BidirectionalEmbedModel",
    ),
    "XLMRobertaModel": ("roberta", "RobertaEmbeddingModel"),
    "BgeM3EmbeddingModel": ("roberta", "BgeM3EmbeddingModel"),
    # [Multimodal]
    "CLIPModel": ("clip", "CLIPEmbeddingModel"),
    "LlavaNextForConditionalGeneration": (
        "llava_next",
        "LlavaNextForConditionalGeneration",
    ),
    "Phi3VForCausalLM": ("phi3v", "Phi3VForCausalLM"),
    "Qwen2VLForConditionalGeneration": ("qwen2_vl", "Qwen2VLForConditionalGeneration"),  # noqa: E501
    "SiglipModel": ("siglip", "SiglipEmbeddingModel"),
    # Technically Terratorch models work on images, both in
    # input and output. I am adding it here because it piggy-backs on embedding
    # models for the time being.
    "PrithviGeoSpatialMAE": ("terratorch", "Terratorch"),
    "Terratorch": ("terratorch", "Terratorch"),
}

263
264
# Cross-encoder (sequence/token classification and ranking) architectures.
#
# Each key is an architecture name; each value is a `(module, class)` pair of
# strings naming the implementing class.
_CROSS_ENCODER_MODELS = {
    "BertForSequenceClassification": ("bert", "BertForSequenceClassification"),
    "BertForTokenClassification": ("bert", "BertForTokenClassification"),
    "GteNewForSequenceClassification": (
        "bert_with_rope",
        "GteNewForSequenceClassification",
    ),
    "JinaVLForRanking": ("jina_vl", "JinaVLForSequenceClassification"),
    "LlamaBidirectionalForSequenceClassification": (
        "llama",
        "LlamaBidirectionalForSequenceClassification",
    ),
    "ModernBertForSequenceClassification": (
        "modernbert",
        "ModernBertForSequenceClassification",
    ),
    "ModernBertForTokenClassification": (
        "modernbert",
        "ModernBertForTokenClassification",
    ),
    "RobertaForSequenceClassification": ("roberta", "RobertaForSequenceClassification"),
    "XLMRobertaForSequenceClassification": (
        "roberta",
        "RobertaForSequenceClassification",
    ),
}

290
# Multimodal architectures.
#
# Each key is an architecture name; each value is a `(module, class)` pair of
# strings naming the implementing class. Some architecture names alias to a
# class with a different name (for example `Qwen2_5OmniModel` maps to the
# `Qwen2_5OmniThinkerForConditionalGeneration` class).
_MULTIMODAL_MODELS = {
    # [Decoder-only]
    "AriaForConditionalGeneration": ("aria", "AriaForConditionalGeneration"),
    "AudioFlamingo3ForConditionalGeneration": (
        "audioflamingo3",
        "AudioFlamingo3ForConditionalGeneration",
    ),
    "MusicFlamingoForConditionalGeneration": (
        "musicflamingo",
        "MusicFlamingoForConditionalGeneration",
    ),
    "AyaVisionForConditionalGeneration": (
        "aya_vision",
        "AyaVisionForConditionalGeneration",
    ),
    "BagelForConditionalGeneration": ("bagel", "BagelForConditionalGeneration"),
    "BeeForConditionalGeneration": ("bee", "BeeForConditionalGeneration"),
    "Blip2ForConditionalGeneration": ("blip2", "Blip2ForConditionalGeneration"),
    "ChameleonForConditionalGeneration": (
        "chameleon",
        "ChameleonForConditionalGeneration",
    ),
    "Cohere2VisionForConditionalGeneration": (
        "cohere2_vision",
        "Cohere2VisionForConditionalGeneration",
    ),
    "DeepseekVLV2ForCausalLM": ("deepseek_vl2", "DeepseekVLV2ForCausalLM"),
    "DeepseekOCRForCausalLM": ("deepseek_ocr", "DeepseekOCRForCausalLM"),
    "DeepseekOCR2ForCausalLM": ("deepseek_ocr2", "DeepseekOCR2ForCausalLM"),
    "DotsOCRForCausalLM": ("dots_ocr", "DotsOCRForCausalLM"),
    "Eagle2_5_VLForConditionalGeneration": (
        "eagle2_5_vl",
        "Eagle2_5_VLForConditionalGeneration",
    ),
    "Ernie4_5_VLMoeForConditionalGeneration": (
        "ernie45_vl",
        "Ernie4_5_VLMoeForConditionalGeneration",
    ),
    "FunAudioChatForConditionalGeneration": (
        "funaudiochat",
        "FunAudioChatForConditionalGeneration",
    ),
    "FuyuForCausalLM": ("fuyu", "FuyuForCausalLM"),
    "Gemma3ForConditionalGeneration": ("gemma3_mm", "Gemma3ForConditionalGeneration"),  # noqa: E501
    "Gemma3nForConditionalGeneration": (
        "gemma3n_mm",
        "Gemma3nForConditionalGeneration",
    ),
    "GlmAsrForConditionalGeneration": ("glmasr", "GlmAsrForConditionalGeneration"),
    "GLM4VForCausalLM": ("glm4v", "GLM4VForCausalLM"),
    "Glm4vForConditionalGeneration": ("glm4_1v", "Glm4vForConditionalGeneration"),
    "Glm4vMoeForConditionalGeneration": ("glm4_1v", "Glm4vMoeForConditionalGeneration"),
    "GlmOcrForConditionalGeneration": ("glm_ocr", "GlmOcrForConditionalGeneration"),  # noqa: E501
    "GraniteSpeechForConditionalGeneration": (
        "granite_speech",
        "GraniteSpeechForConditionalGeneration",
    ),
    "H2OVLChatModel": ("h2ovl", "H2OVLChatModel"),
    "HunYuanVLForConditionalGeneration": (
        "hunyuan_vision",
        "HunYuanVLForConditionalGeneration",
    ),
    "StepVLForConditionalGeneration": ("step_vl", "StepVLForConditionalGeneration"),
    "InternVLChatModel": ("internvl", "InternVLChatModel"),
    "NemotronH_Nano_VL_V2": ("nano_nemotron_vl", "NemotronH_Nano_VL_V2"),
    "OpenCUAForConditionalGeneration": (
        "opencua",
        "OpenCUAForConditionalGeneration",
    ),
    "InternS1ForConditionalGeneration": (
        "interns1",
        "InternS1ForConditionalGeneration",
    ),
    "InternVLForConditionalGeneration": (
        "interns1",
        "InternS1ForConditionalGeneration",
    ),
    "InternS1ProForConditionalGeneration": (
        "interns1_pro",
        "InternS1ProForConditionalGeneration",
    ),
    "Idefics3ForConditionalGeneration": (
        "idefics3",
        "Idefics3ForConditionalGeneration",
    ),
    "IsaacForConditionalGeneration": ("isaac", "IsaacForConditionalGeneration"),
    "SmolVLMForConditionalGeneration": ("smolvlm", "SmolVLMForConditionalGeneration"),  # noqa: E501
    "KananaVForConditionalGeneration": ("kanana_v", "KananaVForConditionalGeneration"),
    "KeyeForConditionalGeneration": ("keye", "KeyeForConditionalGeneration"),
    "KeyeVL1_5ForConditionalGeneration": (
        "keye_vl1_5",
        "KeyeVL1_5ForConditionalGeneration",
    ),
    "RForConditionalGeneration": ("rvl", "RForConditionalGeneration"),
    "KimiVLForConditionalGeneration": ("kimi_vl", "KimiVLForConditionalGeneration"),  # noqa: E501
    "KimiK25ForConditionalGeneration": ("kimi_k25", "KimiK25ForConditionalGeneration"),  # noqa: E501
    "LightOnOCRForConditionalGeneration": (
        "lightonocr",
        "LightOnOCRForConditionalGeneration",
    ),
    "Lfm2VlForConditionalGeneration": ("lfm2_vl", "Lfm2VLForConditionalGeneration"),
    "Llama_Nemotron_Nano_VL": ("nemotron_vl", "LlamaNemotronVLChatModel"),
    "Llama4ForConditionalGeneration": ("mllama4", "Llama4ForConditionalGeneration"),  # noqa: E501
    "LlavaForConditionalGeneration": ("llava", "LlavaForConditionalGeneration"),
    "LlavaNextForConditionalGeneration": (
        "llava_next",
        "LlavaNextForConditionalGeneration",
    ),
    "LlavaNextVideoForConditionalGeneration": (
        "llava_next_video",
        "LlavaNextVideoForConditionalGeneration",
    ),
    "LlavaOnevisionForConditionalGeneration": (
        "llava_onevision",
        "LlavaOnevisionForConditionalGeneration",
    ),
    "MantisForConditionalGeneration": ("llava", "MantisForConditionalGeneration"),  # noqa: E501
    "MiDashengLMModel": ("midashenglm", "MiDashengLMModel"),
    "MiniMaxVL01ForConditionalGeneration": (
        "minimax_vl_01",
        "MiniMaxVL01ForConditionalGeneration",
    ),
    "MiniCPMO": ("minicpmo", "MiniCPMO"),
    "MiniCPMV": ("minicpmv", "MiniCPMV"),
    "Mistral3ForConditionalGeneration": (
        "mistral3",
        "Mistral3ForConditionalGeneration",
    ),
    "MolmoForCausalLM": ("molmo", "MolmoForCausalLM"),
    "Molmo2ForConditionalGeneration": ("molmo2", "Molmo2ForConditionalGeneration"),
    "NVLM_D": ("nvlm_d", "NVLM_D_Model"),
    "OpenPanguVLForConditionalGeneration": (
        "openpangu_vl",
        "OpenPanguVLForConditionalGeneration",
    ),
    "Ovis": ("ovis", "Ovis"),
    "Ovis2_5": ("ovis2_5", "Ovis2_5"),
    "PaddleOCRVLForConditionalGeneration": (
        "paddleocr_vl",
        "PaddleOCRVLForConditionalGeneration",
    ),
    "PaliGemmaForConditionalGeneration": (
        "paligemma",
        "PaliGemmaForConditionalGeneration",
    ),
    "Phi3VForCausalLM": ("phi3v", "Phi3VForCausalLM"),
    "Phi4MMForCausalLM": ("phi4mm", "Phi4MMForCausalLM"),
    "PixtralForConditionalGeneration": ("pixtral", "PixtralForConditionalGeneration"),  # noqa: E501
    "QwenVLForConditionalGeneration": ("qwen_vl", "QwenVLForConditionalGeneration"),  # noqa: E501
    "Qwen2VLForConditionalGeneration": ("qwen2_vl", "Qwen2VLForConditionalGeneration"),  # noqa: E501
    "Qwen2_5_VLForConditionalGeneration": (
        "qwen2_5_vl",
        "Qwen2_5_VLForConditionalGeneration",
    ),
    "Qwen2AudioForConditionalGeneration": (
        "qwen2_audio",
        "Qwen2AudioForConditionalGeneration",
    ),
    "Qwen2_5OmniModel": (
        "qwen2_5_omni_thinker",
        "Qwen2_5OmniThinkerForConditionalGeneration",
    ),
    "Qwen2_5OmniForConditionalGeneration": (
        "qwen2_5_omni_thinker",
        "Qwen2_5OmniThinkerForConditionalGeneration",
    ),
    "Qwen3OmniMoeForConditionalGeneration": (
        "qwen3_omni_moe_thinker",
        "Qwen3OmniMoeThinkerForConditionalGeneration",
    ),
    "Qwen3ASRForConditionalGeneration": (
        "qwen3_asr",
        "Qwen3ASRForConditionalGeneration",
    ),
    "Qwen3VLForConditionalGeneration": ("qwen3_vl", "Qwen3VLForConditionalGeneration"),  # noqa: E501
    "Qwen3VLMoeForConditionalGeneration": (
        "qwen3_vl_moe",
        "Qwen3VLMoeForConditionalGeneration",
    ),
    "SkyworkR1VChatModel": ("skyworkr1v", "SkyworkR1VChatModel"),
    "Step3VLForConditionalGeneration": ("step3_vl", "Step3VLForConditionalGeneration"),  # noqa: E501
    "TarsierForConditionalGeneration": ("tarsier", "TarsierForConditionalGeneration"),  # noqa: E501
    "Tarsier2ForConditionalGeneration": (
        "qwen2_vl",
        "Tarsier2ForConditionalGeneration",
    ),
    "UltravoxModel": ("ultravox", "UltravoxModel"),
    "VoxtralForConditionalGeneration": ("voxtral", "VoxtralForConditionalGeneration"),  # noqa: E501
    "VoxtralRealtimeGeneration": ("voxtral_realtime", "VoxtralRealtimeGeneration"),  # noqa: E501
    # [Encoder-decoder]
    "NemotronParseForConditionalGeneration": (
        "nemotron_parse",
        "NemotronParseForConditionalGeneration",
    ),
    "WhisperForConditionalGeneration": ("whisper", "WhisperForConditionalGeneration"),  # noqa: E501
}
486
487

# Speculative-decoding (draft / MTP / Eagle / Medusa) architectures.
#
# Each key is an architecture name; each value is a `(module, class)` pair of
# strings naming the implementing class. Several Eagle3 architecture names
# alias to the single `Eagle3LlamaForCausalLM` implementation.
_SPECULATIVE_DECODING_MODELS = {
    "MiMoMTPModel": ("mimo_mtp", "MiMoMTP"),
    "EagleLlamaForCausalLM": ("llama_eagle", "EagleLlamaForCausalLM"),
    "EagleLlama4ForCausalLM": ("llama4_eagle", "EagleLlama4ForCausalLM"),
    "EagleMiniCPMForCausalLM": ("minicpm_eagle", "EagleMiniCPMForCausalLM"),
    "Eagle3LlamaForCausalLM": ("llama_eagle3", "Eagle3LlamaForCausalLM"),
    "LlamaForCausalLMEagle3": ("llama_eagle3", "Eagle3LlamaForCausalLM"),
    "Eagle3Qwen2_5vlForCausalLM": ("llama_eagle3", "Eagle3LlamaForCausalLM"),
    "Eagle3Qwen3vlForCausalLM": ("llama_eagle3", "Eagle3LlamaForCausalLM"),
    "EagleMistralLarge3ForCausalLM": (
        "mistral_large_3_eagle",
        "EagleMistralLarge3ForCausalLM",
    ),
    "EagleDeepSeekMTPModel": ("deepseek_eagle", "EagleDeepseekV3ForCausalLM"),
    "DeepSeekMTPModel": ("deepseek_mtp", "DeepSeekMTP"),
    "ErnieMTPModel": ("ernie_mtp", "ErnieMTP"),
    "ExaoneMoeMTP": ("exaone_moe_mtp", "ExaoneMoeMTP"),
    "LongCatFlashMTPModel": ("longcat_flash_mtp", "LongCatFlashMTP"),
    "Glm4MoeMTPModel": ("glm4_moe_mtp", "Glm4MoeMTP"),
    "Glm4MoeLiteMTPModel": ("glm4_moe_lite_mtp", "Glm4MoeLiteMTP"),
    "GlmOcrMTPModel": ("glm_ocr_mtp", "GlmOcrMTP"),
    "MedusaModel": ("medusa", "Medusa"),
    "OpenPanguMTPModel": ("openpangu_mtp", "OpenPanguMTP"),
    "Qwen3NextMTP": ("qwen3_next_mtp", "Qwen3NextMTP"),
    "Step3p5MTP": ("step3p5_mtp", "Step3p5MTP"),
    # Temporarily disabled.
    # # TODO(woosuk): Re-enable this once the MLP Speculator is supported in V1.
    # "MLPSpeculatorPreTrainedModel": ("mlp_speculator", "MLPSpeculator"),
}
516

517
# Specific architectures that are mapped onto the generic classes in the
# "transformers" module rather than onto a dedicated implementation.
#
# Each key is an architecture name; each value is a `(module, class)` pair.
_TRANSFORMERS_SUPPORTED_MODELS = {
    # Text generation models
    "SmolLM3ForCausalLM": ("transformers", "TransformersForCausalLM"),
    # Multimodal models
    "Emu3ForConditionalGeneration": (
        "transformers",
        "TransformersMultiModalForCausalLM",
    ),
}

# The generic classes of the "transformers" module, registered under their own
# names. In every entry the key equals the class name in the value.
#
# Each value is a `(module, class)` pair.
_TRANSFORMERS_BACKEND_MODELS = {
    # Text generation models
    "TransformersForCausalLM": ("transformers", "TransformersForCausalLM"),
    "TransformersMoEForCausalLM": ("transformers", "TransformersMoEForCausalLM"),
    # Multimodal models
    "TransformersMultiModalForCausalLM": (
        "transformers",
        "TransformersMultiModalForCausalLM",
    ),
    "TransformersMultiModalMoEForCausalLM": (
        "transformers",
        "TransformersMultiModalMoEForCausalLM",
    ),
    # Embedding models
    "TransformersEmbeddingModel": ("transformers", "TransformersEmbeddingModel"),
    "TransformersMoEEmbeddingModel": ("transformers", "TransformersMoEEmbeddingModel"),
    "TransformersMultiModalEmbeddingModel": (
        "transformers",
        "TransformersMultiModalEmbeddingModel",
    ),
    # Sequence classification models
    "TransformersForSequenceClassification": (
        "transformers",
        "TransformersForSequenceClassification",
    ),
    "TransformersMoEForSequenceClassification": (
        "transformers",
        "TransformersMoEForSequenceClassification",
    ),
    "TransformersMultiModalForSequenceClassification": (
        "transformers",
        "TransformersMultiModalForSequenceClassification",
    ),
}
561

562
# Union of all per-category tables above.
#
# The tables are unpacked in order, so when the same architecture name appears
# in more than one table the entry from the table listed later wins
# (standard dict-unpacking semantics).
_VLLM_MODELS = {
    **_TEXT_GENERATION_MODELS,
    **_EMBEDDING_MODELS,
    **_CROSS_ENCODER_MODELS,
    **_MULTIMODAL_MODELS,
    **_SPECULATIVE_DECODING_MODELS,
    **_TRANSFORMERS_SUPPORTED_MODELS,
    **_TRANSFORMERS_BACKEND_MODELS,
}

572
573
574
575
# This variable is used as the args for subprocess.run(). It can be
# modified to alter the args if needed, e.g. when a par format is used
# to pack things together and sys.executable is not the target we want
# to run.
_SUBPROCESS_COMMAND = [sys.executable, "-m", "vllm.model_executor.models.registry"]
577

578
# Architectures that are no longer registered, keyed by architecture name.
# Each value is a version string.
# NOTE(review): the version is presumably the last release that supported the
# architecture -- confirm against the code that reads this table.
_PREVIOUSLY_SUPPORTED_MODELS = {
    "MotifForCausalLM": "0.10.2",
    "Phi3SmallForCausalLM": "0.9.2",
    "Phi4FlashForCausalLM": "0.10.2",
    "Phi4MultimodalForCausalLM": "0.12.0",
    # encoder-decoder models except whisper
    # have been removed for V0 deprecation.
    "BartModel": "0.10.2",
    "BartForConditionalGeneration": "0.10.2",
    "DonutForConditionalGeneration": "0.10.2",
    "Florence2ForConditionalGeneration": "0.10.2",
    "MBartForConditionalGeneration": "0.10.2",
    "MllamaForConditionalGeneration": "0.10.2",
}
592

593

594
595
@dataclass(frozen=True)
class _ModelInfo:
    """
    Immutable summary of the capabilities of one model class.

    Every field is filled in by `from_model_cls`, which calls the
    correspondingly named helper from `.interfaces` / `.interfaces_base`
    on the model class.
    """

    # `__name__` of the inspected model class.
    architecture: str
    is_text_generation_model: bool
    is_pooling_model: bool
    attn_type: AttnTypeStr
    default_seq_pooling_type: SequencePoolingType
    default_tok_pooling_type: TokenPoolingType
    supports_cross_encoding: bool
    supports_late_interaction: bool
    supports_multimodal: bool
    supports_multimodal_raw_input_only: bool
    requires_raw_input_tokens: bool
    supports_multimodal_encoder_tp_data: bool
    supports_pp: bool
    has_inner_state: bool
    is_attention_free: bool
    is_hybrid: bool
    has_noops: bool
    supports_mamba_prefix_caching: bool
    supports_transcription: bool
    supports_transcription_only: bool

    @staticmethod
    def from_model_cls(model: type[nn.Module]) -> "_ModelInfo":
        """
        Build a `_ModelInfo` by inspecting `model`.

        Args:
            model: The model class (not an instance) to inspect.

        Returns:
            A frozen `_ModelInfo` whose `architecture` is `model.__name__`
            and whose remaining fields are the results of the interface
            helper functions applied to `model`.
        """
        return _ModelInfo(
            architecture=model.__name__,
            is_text_generation_model=is_text_generation_model(model),
            is_pooling_model=is_pooling_model(model),
            default_seq_pooling_type=get_default_seq_pooling_type(model),
            default_tok_pooling_type=get_default_tok_pooling_type(model),
            attn_type=get_attn_type(model),
            supports_cross_encoding=supports_cross_encoding(model),
            supports_late_interaction=supports_late_interaction(model),
            supports_multimodal=supports_multimodal(model),
            supports_multimodal_raw_input_only=supports_multimodal_raw_input_only(
                model
            ),
            requires_raw_input_tokens=requires_raw_input_tokens(model),
            supports_multimodal_encoder_tp_data=supports_multimodal_encoder_tp_data(
                model
            ),
            supports_pp=supports_pp(model),
            has_inner_state=has_inner_state(model),
            is_attention_free=is_attention_free(model),
            is_hybrid=is_hybrid(model),
            supports_mamba_prefix_caching=supports_mamba_prefix_caching(model),
            supports_transcription=supports_transcription(model),
            # The class attribute is only read when the model supports
            # transcription at all; `and` short-circuits otherwise.
            supports_transcription_only=(
                supports_transcription(model) and model.supports_transcription_only
            ),
            has_noops=has_noops(model),
        )
647
648


649
650
651
652
class _BaseRegisteredModel(ABC):
    @abstractmethod
    def inspect_model_cls(self) -> _ModelInfo:
        raise NotImplementedError
653

654
    @abstractmethod
655
    def load_model_cls(self) -> type[nn.Module]:
656
        raise NotImplementedError
657
658


659
660
661
662
663
664
665
@dataclass(frozen=True)
class _RegisteredModel(_BaseRegisteredModel):
    """
    Registry entry for a model class that has already been imported in the
    main process.
    """

    interfaces: _ModelInfo
    model_cls: type[nn.Module]

    @staticmethod
    def from_model_cls(model_cls: type[nn.Module]):
        """Wrap `model_cls` together with the `_ModelInfo` computed from it."""
        model_info = _ModelInfo.from_model_cls(model_cls)
        return _RegisteredModel(model_info, model_cls)

    def inspect_model_cls(self) -> _ModelInfo:
        """Return the stored `_ModelInfo`."""
        return self.interfaces

    def load_model_cls(self) -> type[nn.Module]:
        """Return the stored model class."""
        return self.model_cls


@dataclass(frozen=True)
class _LazyRegisteredModel(_BaseRegisteredModel):
    """
    Registry entry for a model that has not been imported in the main
    process; it is located through `module_name` and `class_name`.
    """

    module_name: str
    class_name: str

    @staticmethod
    def _get_cache_dir() -> Path:
        """Return the directory that holds the cached model info files."""
        cache_root = Path(envs.VLLM_CACHE_ROOT)
        return cache_root / "modelinfos"

    def _get_cache_filename(self) -> str:
        """Return the cache file name built from the module and class names."""
        stem = "-".join((self.module_name, self.class_name))
        return stem.replace(".", "-") + ".json"

    def _load_modelinfo_from_cache(self, module_hash: str) -> _ModelInfo | None:
        """Return the cached `_ModelInfo` when it matches `module_hash`.

        `None` is returned when the cache file does not exist, when it was
        written for a different hash, or when reading/converting it fails
        for any other reason. Every miss is reported at debug level only.
        """
        try:
            try:
                cache_file = self._get_cache_dir() / self._get_cache_filename()
                cached = json.loads(cache_file.read_text(encoding="utf-8"))
            except FileNotFoundError:
                logger.debug(
                    "Cached model info file for class %s.%s not found",
                    self.module_name,
                    self.class_name,
                )
                return None

            if cached["hash"] == module_hash:
                # file not changed, use cached _ModelInfo properties
                return _ModelInfo(**cached["modelinfo"])

            logger.debug(
                "Cached model info file for class %s.%s is stale",
                self.module_name,
                self.class_name,
            )
            return None
        except Exception:
            # Malformed JSON, missing keys, unexpected fields, ...: treat
            # the cache as unusable rather than failing the inspection.
            logger.debug(
                "Cached model info for class %s.%s error. ",
                self.module_name,
                self.class_name,
            )
            return None

    def _save_modelinfo_to_cache(self, mi: _ModelInfo, module_hash: str) -> None:
        """Write `mi` and `module_hash` to the JSON cache file.

        Saving is best-effort: any failure is logged with its traceback
        and otherwise ignored.
        """
        from vllm.model_executor.model_loader.weight_utils import atomic_writer

        try:
            payload = {"hash": module_hash, "modelinfo": asdict(mi)}
            target_dir = self._get_cache_dir()
            target_dir.mkdir(parents=True, exist_ok=True)
            target_file = target_dir / self._get_cache_filename()
            with atomic_writer(target_file, encoding="utf-8") as out:
                json.dump(payload, out, indent=2)
        except Exception:
            logger.exception("Error saving model info cache.")

    @logtime(logger=logger, msg="Registry inspect model class")
    def inspect_model_cls(self) -> _ModelInfo:
        """Return the `_ModelInfo` for this model, using the cache if valid.

        The hash of the model's source file (looked up next to this file by
        the last component of `module_name`) keys the cache. On a miss, or
        when no such source file exists, the info is computed in a
        subprocess; it is written back to the cache only when a hash is
        available.
        """
        module_stem = self.module_name.rsplit(".", 1)[-1]
        model_path = Path(__file__).parent / f"{module_stem}.py"
        module_hash = None

        if model_path.exists():
            source_bytes = model_path.read_bytes()
            module_hash = safe_hash(source_bytes, usedforsecurity=False).hexdigest()

            cached_info = self._load_modelinfo_from_cache(module_hash)
            if cached_info is not None:
                logger.debug(
                    "Loaded model info for class %s.%s from cache",
                    self.module_name,
                    self.class_name,
                )
                return cached_info

            logger.debug(
                "Cache model info for class %s.%s miss. Loading model instead.",
                self.module_name,
                self.class_name,
            )

        # Performed in another process to avoid initializing CUDA
        mi = _run_in_subprocess(
            lambda: _ModelInfo.from_model_cls(self.load_model_cls())
        )
        logger.debug(
            "Loaded model info for class %s.%s", self.module_name, self.class_name
        )

        # save cache file
        if module_hash is not None:
            self._save_modelinfo_to_cache(mi, module_hash)

        return mi

    def load_model_cls(self) -> type[nn.Module]:
        """Import `module_name` and return its `class_name` attribute."""
        return getattr(importlib.import_module(self.module_name), self.class_name)


@lru_cache(maxsize=128)
def _try_load_model_cls(
    model_arch: str,
    model: _BaseRegisteredModel,
) -> type[nn.Module] | None:
    """Load the class behind `model`, returning `None` if loading fails.

    The current platform is asked to verify `model_arch` first; an
    exception raised by that check is not caught here. A failure inside
    `model.load_model_cls()` is logged with its traceback and turned into
    `None`. Results are memoized per `(model_arch, model)` pair.
    """
    from vllm.platforms import current_platform

    current_platform.verify_model_arch(model_arch)

    try:
        loaded_cls = model.load_model_cls()
    except Exception:
        logger.exception("Error in loading model architecture '%s'", model_arch)
        return None
    return loaded_cls
804
805


806
807
808
809
@lru_cache(maxsize=128)
def _try_inspect_model_cls(
    model_arch: str,
    model: _BaseRegisteredModel,
810
) -> _ModelInfo | None:
811
812
813
    try:
        return model.inspect_model_cls()
    except Exception:
814
        logger.exception("Error in inspecting model architecture '%s'", model_arch)
815
        return None
816
817


818
819
820
@dataclass
class _ModelRegistry:
    # Keyed by model_arch
821
    models: dict[str, _BaseRegisteredModel] = field(default_factory=dict)
822

823
    def get_supported_archs(self) -> Set[str]:
824
        return self.models.keys()
825

826
827
828
    def register_model(
        self,
        model_arch: str,
829
        model_cls: type[nn.Module] | str,
830
    ) -> None:
831
832
833
        """
        Register an external model to be used in vLLM.

834
        `model_cls` can be either:
835

836
        - A [`torch.nn.Module`][] class directly referencing the model.
837
        - A string in the format `<module>:<class>` which can be used to
838
839
          lazily import the model. This is useful to avoid initializing CUDA
          when importing the model and thus the related error
840
          `RuntimeError: Cannot re-initialize CUDA in forked subprocess`.
841
        """
842
843
844
845
        if not isinstance(model_arch, str):
            msg = f"`model_arch` should be a string, not a {type(model_arch)}"
            raise TypeError(msg)

846
        if model_arch in self.models:
847
848
            logger.warning(
                "Model architecture %s is already registered, and will be "
849
850
851
852
                "overwritten by the new model class %s.",
                model_arch,
                model_cls,
            )
853
854
855
856
857
858

        if isinstance(model_cls, str):
            split_str = model_cls.split(":")
            if len(split_str) != 2:
                msg = "Expected a string in the format `<module>:<class>`"
                raise ValueError(msg)
859

860
            model = _LazyRegisteredModel(*split_str)
861
        elif isinstance(model_cls, type) and issubclass(model_cls, nn.Module):
862
            model = _RegisteredModel.from_model_cls(model_cls)
863
        else:
864
865
866
867
            msg = (
                "`model_cls` should be a string or PyTorch model class, "
                f"not a {type(model_arch)}"
            )
868
            raise TypeError(msg)
869

870
        self.models[model_arch] = model
871

872
    def _raise_for_unsupported(self, architectures: list[str]):
873
        all_supported_archs = self.get_supported_archs()
874

875
876
877
        if any(arch in all_supported_archs for arch in architectures):
            raise ValueError(
                f"Model architectures {architectures} failed "
878
879
                "to be inspected. Please check the logs for more details."
            )
880

881
882
883
884
885
886
887
888
        for arch in architectures:
            if arch in _PREVIOUSLY_SUPPORTED_MODELS:
                previous_version = _PREVIOUSLY_SUPPORTED_MODELS[arch]

                raise ValueError(
                    f"Model architecture {arch} was supported in vLLM until "
                    f"v{previous_version}, and is not supported anymore. "
                    "Please use an older version of vLLM if you want to "
889
890
                    "use this model architecture."
                )
891

892
893
        raise ValueError(
            f"Model architectures {architectures} are not supported for now. "
894
895
            f"Supported architectures: {all_supported_archs}"
        )
896

897
    def _try_load_model_cls(self, model_arch: str) -> type[nn.Module] | None:
898
899
        if model_arch not in self.models:
            return None
900

901
        return _try_load_model_cls(model_arch, self.models[model_arch])
902

903
    def _try_inspect_model_cls(self, model_arch: str) -> _ModelInfo | None:
904
905
        if model_arch not in self.models:
            return None
906

907
908
909
910
911
912
        return _try_inspect_model_cls(model_arch, self.models[model_arch])

    def _try_resolve_transformers(
        self,
        architecture: str,
        model_config: ModelConfig,
913
    ) -> str | None:
914
915
916
        if architecture in _TRANSFORMERS_BACKEND_MODELS:
            return architecture

917
918
919
        auto_map: dict[str, str] = (
            getattr(model_config.hf_config, "auto_map", None) or dict()
        )
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935

        # Make sure that config class is always initialized before model class,
        # otherwise the model class won't be able to access the config class,
        # the expected auto_map should have correct order like:
        # "auto_map": {
        #     "AutoConfig": "<your-repo-name>--<config-name>",
        #     "AutoModel": "<your-repo-name>--<config-name>",
        #     "AutoModelFor<Task>": "<your-repo-name>--<config-name>",
        # },
        for prefix in ("AutoConfig", "AutoModel"):
            for name, module in auto_map.items():
                if name.startswith(prefix):
                    try_get_class_from_dynamic_module(
                        module,
                        model_config.model,
                        revision=model_config.revision,
936
                        trust_remote_code=model_config.trust_remote_code,
937
938
939
940
941
942
943
944
945
946
947
948
                        warn_on_fail=False,
                    )

        model_module = getattr(transformers, architecture, None)

        if model_module is None:
            for name, module in auto_map.items():
                if name.startswith("AutoModel"):
                    model_module = try_get_class_from_dynamic_module(
                        module,
                        model_config.model,
                        revision=model_config.revision,
949
                        trust_remote_code=model_config.trust_remote_code,
950
951
952
953
954
                        warn_on_fail=True,
                    )
                    if model_module is not None:
                        break
            else:
955
                if model_config.model_impl != "transformers":
956
957
958
959
960
961
962
                    return None

                raise ValueError(
                    f"Cannot find model module. {architecture!r} is not a "
                    "registered model in the Transformers library (only "
                    "relevant if the model is meant to be in Transformers) "
                    "and 'AutoModel' is not present in the model config's "
963
964
                    "'auto_map' (relevant if the model is custom)."
                )
965
966

        if not model_module.is_backend_compatible():
967
            if model_config.model_impl != "transformers":
968
                return None
969

970
971
            raise ValueError(
                f"The Transformers implementation of {architecture!r} "
972
973
                "is not compatible with vLLM."
            )
974

975
        return model_config._get_transformers_backend_cls()
976

977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
1001
    def _normalize_arch(
        self,
        architecture: str,
        model_config: ModelConfig,
    ) -> str:
        if architecture in self.models:
            return architecture

        # This may be called in order to resolve runner_type and convert_type
        # in the first place, in which case we consider the default match
        match = try_match_architecture_defaults(
            architecture,
            runner_type=getattr(model_config, "runner_type", None),
            convert_type=getattr(model_config, "convert_type", None),
        )
        if match:
            suffix, _ = match

            # Get the name of the base model to convert
            for repl_suffix, _ in iter_architecture_defaults():
                base_arch = architecture.replace(suffix, repl_suffix)
                if base_arch in self.models:
                    return base_arch

        return architecture
1002

1003
1004
    def inspect_model_cls(
        self,
1005
        architectures: str | list[str],
1006
        model_config: ModelConfig,
1007
    ) -> tuple[_ModelInfo, str]:
1008
1009
        if isinstance(architectures, str):
            architectures = [architectures]
1010
1011
        if not architectures:
            raise ValueError("No model architectures are specified")
1012
1013

        # Require transformers impl
1014
        if model_config.model_impl == "transformers":
1015
            arch = self._try_resolve_transformers(architectures[0], model_config)
1016
1017
1018
1019
            if arch is not None:
                model_info = self._try_inspect_model_cls(arch)
                if model_info is not None:
                    return (model_info, arch)
1020
        elif model_config.model_impl == "terratorch":
1021
1022
            model_info = self._try_inspect_model_cls("Terratorch")
            return (model_info, "Terratorch")
1023

1024
        # Fallback to transformers impl (after resolving convert_type)
1025
1026
1027
1028
1029
1030
        if (
            all(arch not in self.models for arch in architectures)
            and model_config.model_impl == "auto"
            and getattr(model_config, "convert_type", "none") == "none"
        ):
            arch = self._try_resolve_transformers(architectures[0], model_config)
1031
1032
1033
1034
1035
1036
1037
            if arch is not None:
                model_info = self._try_inspect_model_cls(arch)
                if model_info is not None:
                    return (model_info, arch)

        for arch in architectures:
            normalized_arch = self._normalize_arch(arch, model_config)
1038
            model_info = self._try_inspect_model_cls(normalized_arch)
1039
            if model_info is not None:
1040
                return (model_info, arch)
1041

1042
        # Fallback to transformers impl (before resolving runner_type)
1043
1044
1045
1046
1047
        if (
            all(arch not in self.models for arch in architectures)
            and model_config.model_impl == "auto"
        ):
            arch = self._try_resolve_transformers(architectures[0], model_config)
1048
1049
1050
1051
1052
            if arch is not None:
                model_info = self._try_inspect_model_cls(arch)
                if model_info is not None:
                    return (model_info, arch)

1053
        return self._raise_for_unsupported(architectures)
1054

1055
1056
    def resolve_model_cls(
        self,
1057
        architectures: str | list[str],
1058
        model_config: ModelConfig,
1059
    ) -> tuple[type[nn.Module], str]:
1060
1061
        if isinstance(architectures, str):
            architectures = [architectures]
1062
1063
        if not architectures:
            raise ValueError("No model architectures are specified")
1064
1065

        # Require transformers impl
1066
        if model_config.model_impl == "transformers":
1067
            arch = self._try_resolve_transformers(architectures[0], model_config)
1068
1069
1070
1071
            if arch is not None:
                model_cls = self._try_load_model_cls(arch)
                if model_cls is not None:
                    return (model_cls, arch)
1072
        elif model_config.model_impl == "terratorch":
1073
1074
1075
1076
            arch = "Terratorch"
            model_cls = self._try_load_model_cls(arch)
            if model_cls is not None:
                return (model_cls, arch)
1077

1078
        # Fallback to transformers impl (after resolving convert_type)
1079
1080
1081
1082
1083
1084
        if (
            all(arch not in self.models for arch in architectures)
            and model_config.model_impl == "auto"
            and getattr(model_config, "convert_type", "none") == "none"
        ):
            arch = self._try_resolve_transformers(architectures[0], model_config)
1085
1086
1087
1088
1089
1090
1091
            if arch is not None:
                model_cls = self._try_load_model_cls(arch)
                if model_cls is not None:
                    return (model_cls, arch)

        for arch in architectures:
            normalized_arch = self._normalize_arch(arch, model_config)
1092
            model_cls = self._try_load_model_cls(normalized_arch)
1093
1094
            if model_cls is not None:
                return (model_cls, arch)
1095

1096
        # Fallback to transformers impl (before resolving runner_type)
1097
1098
1099
1100
1101
        if (
            all(arch not in self.models for arch in architectures)
            and model_config.model_impl == "auto"
        ):
            arch = self._try_resolve_transformers(architectures[0], model_config)
1102
1103
1104
1105
1106
            if arch is not None:
                model_cls = self._try_load_model_cls(arch)
                if model_cls is not None:
                    return (model_cls, arch)

1107
        return self._raise_for_unsupported(architectures)
1108

1109
1110
    def is_text_generation_model(
        self,
1111
        architectures: str | list[str],
1112
        model_config: ModelConfig,
1113
    ) -> bool:
1114
        model_cls, _ = self.inspect_model_cls(architectures, model_config)
1115
        return model_cls.is_text_generation_model
1116

1117
    def is_pooling_model(
1118
        self,
1119
        architectures: str | list[str],
1120
        model_config: ModelConfig,
1121
    ) -> bool:
1122
        model_cls, _ = self.inspect_model_cls(architectures, model_config)
1123
        return model_cls.is_pooling_model
1124

1125
1126
    def is_cross_encoder_model(
        self,
1127
        architectures: str | list[str],
1128
        model_config: ModelConfig,
1129
    ) -> bool:
1130
        model_cls, _ = self.inspect_model_cls(architectures, model_config)
1131
        return model_cls.supports_cross_encoding
1132

1133
1134
    def is_multimodal_model(
        self,
1135
        architectures: str | list[str],
1136
        model_config: ModelConfig,
1137
    ) -> bool:
1138
        model_cls, _ = self.inspect_model_cls(architectures, model_config)
1139
        return model_cls.supports_multimodal
1140

1141
    def is_multimodal_raw_input_only_model(
1142
        self,
1143
        architectures: str | list[str],
1144
        model_config: ModelConfig,
1145
    ) -> bool:
1146
        model_cls, _ = self.inspect_model_cls(architectures, model_config)
1147
        return model_cls.supports_multimodal_raw_input_only
1148

1149
1150
    def is_pp_supported_model(
        self,
1151
        architectures: str | list[str],
1152
        model_config: ModelConfig,
1153
    ) -> bool:
1154
        model_cls, _ = self.inspect_model_cls(architectures, model_config)
1155
        return model_cls.supports_pp
1156

1157
1158
    def model_has_inner_state(
        self,
1159
        architectures: str | list[str],
1160
        model_config: ModelConfig,
1161
    ) -> bool:
1162
        model_cls, _ = self.inspect_model_cls(architectures, model_config)
1163
        return model_cls.has_inner_state
1164

1165
1166
    def is_attention_free_model(
        self,
1167
        architectures: str | list[str],
1168
        model_config: ModelConfig,
1169
    ) -> bool:
1170
        model_cls, _ = self.inspect_model_cls(architectures, model_config)
1171
        return model_cls.is_attention_free
1172

1173
1174
    def is_hybrid_model(
        self,
1175
        architectures: str | list[str],
1176
        model_config: ModelConfig,
1177
    ) -> bool:
1178
        model_cls, _ = self.inspect_model_cls(architectures, model_config)
1179
1180
        return model_cls.is_hybrid

1181
1182
    def is_noops_model(
        self,
1183
        architectures: str | list[str],
1184
        model_config: ModelConfig,
1185
    ) -> bool:
1186
        model_cls, _ = self.inspect_model_cls(architectures, model_config)
1187
1188
        return model_cls.has_noops

1189
1190
    def is_transcription_model(
        self,
1191
        architectures: str | list[str],
1192
        model_config: ModelConfig,
1193
    ) -> bool:
1194
        model_cls, _ = self.inspect_model_cls(architectures, model_config)
1195
1196
        return model_cls.supports_transcription

1197
1198
    def is_transcription_only_model(
        self,
1199
        architectures: str | list[str],
1200
        model_config: ModelConfig,
1201
    ) -> bool:
1202
        model_cls, _ = self.inspect_model_cls(architectures, model_config)
1203
1204
        return model_cls.supports_transcription_only

1205

1206
1207
1208
1209
1210
1211
1212
1213
1214
# Global registry holding one lazily imported entry per `_VLLM_MODELS` item.
ModelRegistry = _ModelRegistry(
    models={
        arch_name: _LazyRegisteredModel(
            f"vllm.model_executor.models.{module_relname}",
            registered_name,
        )
        for arch_name, (module_relname, registered_name) in _VLLM_MODELS.items()
    }
)
1215
1216
1217
1218
1219

_T = TypeVar("_T")


def _run_in_subprocess(fn: Callable[[], _T]) -> _T:
    """Run `fn` in a separate process and return its result.

    `fn` and an output path are serialized with `cloudpickle` and passed
    on stdin to `_SUBPROCESS_COMMAND`; the result is then read back with
    `pickle` from a file inside a temporary directory.

    Raises:
        RuntimeError: If the subprocess exits with a non-zero status. The
            message contains the subprocess's stderr output.
    """
    # NOTE: We use a temporary directory instead of a temporary file to avoid
    # issues like https://stackoverflow.com/questions/23212435/permission-denied-to-write-to-my-temporary-file
    with tempfile.TemporaryDirectory() as tempdir:
        output_filepath = os.path.join(tempdir, "registry_output.tmp")

        # `cloudpickle` allows pickling lambda functions directly
        import cloudpickle

        input_bytes = cloudpickle.dumps((fn, output_filepath))

        # cannot use `sys.executable __file__` here because the script
        # contains relative imports
        returned = subprocess.run(
            _SUBPROCESS_COMMAND, input=input_bytes, capture_output=True
        )

        # check if the subprocess is successful
        try:
            returned.check_returncode()
        except subprocess.CalledProcessError as e:
            # Decode leniently: stderr may hold bytes that are not valid
            # UTF-8, and a UnicodeDecodeError here would hide the real
            # failure of the subprocess.
            stderr_text = returned.stderr.decode(errors="replace")
            # wrap raised exception to provide more information
            raise RuntimeError(
                f"Error raised in subprocess:\n{stderr_text}"
            ) from e

        # The file was produced by the subprocess started above, so its
        # content is trusted for unpickling.
        with open(output_filepath, "rb") as f:
            return pickle.load(f)


def _run() -> None:
    """Run the pickled callable received on stdin and store its result.

    stdin must carry a pickled `(fn, output_file)` pair; the pickled
    return value of `fn()` is written to `output_file`.
    """
    # Setup plugins
    from vllm.plugins import load_general_plugins

    load_general_plugins()

    # NOTE(review): unpickling executes arbitrary code; the payload is
    # presumably always produced by `_run_in_subprocess` -- confirm.
    payload = sys.stdin.buffer.read()
    fn, output_file = pickle.loads(payload)

    outcome = fn()

    with open(output_file, "wb") as sink:
        sink.write(pickle.dumps(outcome))
1261
1262
1263


if __name__ == "__main__":
    # Only when executed as a script: process the pickled job from stdin.
    _run()