# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

"""
Whenever you add an architecture to this file, please also update
`tests/models/registry.py` with example Hugging Face models for it.
"""

import importlib
import json
import os
import pickle
import subprocess
import sys
import tempfile
from abc import ABC, abstractmethod
from collections.abc import Callable, Set
from dataclasses import asdict, dataclass, field
from functools import lru_cache
from pathlib import Path
from typing import TYPE_CHECKING, Any, TypeVar

import torch.nn as nn
import transformers

from vllm import envs
from vllm.config import (
    ModelConfig,
    iter_architecture_defaults,
    try_match_architecture_defaults,
)
from vllm.logger import init_logger
from vllm.logging_utils import logtime
from vllm.transformers_utils.dynamic_module import try_get_class_from_dynamic_module
from vllm.utils.hashing import safe_hash

if TYPE_CHECKING:
    from vllm.config.model import AttnTypeStr
    from vllm.config.pooler import SequencePoolingType, TokenPoolingType
else:
    AttnTypeStr = Any
    SequencePoolingType = Any
    TokenPoolingType = Any

from .interfaces import (
    has_inner_state,
    has_noops,
    is_attention_free,
    is_hybrid,
    requires_raw_input_tokens,
    supports_cross_encoding,
    supports_late_interaction,
    supports_mamba_prefix_caching,
    supports_multimodal,
    supports_multimodal_encoder_tp_data,
    supports_multimodal_raw_input_only,
    supports_pp,
    supports_transcription,
)
from .interfaces_base import (
    get_attn_type,
    get_default_seq_pooling_type,
    get_default_tok_pooling_type,
    is_pooling_model,
    is_text_generation_model,
)

# Module-level logger, named after this module.
logger = init_logger(__name__)

# Text-generation architectures. Each entry maps an architecture name to a
# ``(module_name, class_name)`` pair naming its implementation; several names
# alias the same class (e.g. "AquilaModel" -> "LlamaForCausalLM").
_TEXT_GENERATION_MODELS = {
    # [Decoder-only]
    "AfmoeForCausalLM": ("afmoe", "AfmoeForCausalLM"),
    "ApertusForCausalLM": ("apertus", "ApertusForCausalLM"),
    "AquilaModel": ("llama", "LlamaForCausalLM"),
    "AquilaForCausalLM": ("llama", "LlamaForCausalLM"),  # AquilaChat2
    "ArceeForCausalLM": ("arcee", "ArceeForCausalLM"),
    "ArcticForCausalLM": ("arctic", "ArcticForCausalLM"),
    # baichuan-7b, upper case 'C' in the class name
    "BaiChuanForCausalLM": ("baichuan", "BaiChuanForCausalLM"),
    # baichuan-13b, lower case 'c' in the class name
    "BaichuanForCausalLM": ("baichuan", "BaichuanForCausalLM"),
    "BailingMoeForCausalLM": ("bailing_moe", "BailingMoeForCausalLM"),
    "BailingMoeV2ForCausalLM": ("bailing_moe", "BailingMoeV2ForCausalLM"),
    "BambaForCausalLM": ("bamba", "BambaForCausalLM"),
    "BloomForCausalLM": ("bloom", "BloomForCausalLM"),
    "ChatGLMModel": ("chatglm", "ChatGLMForCausalLM"),
    "ChatGLMForConditionalGeneration": ("chatglm", "ChatGLMForCausalLM"),
    "CohereForCausalLM": ("commandr", "CohereForCausalLM"),
    "Cohere2ForCausalLM": ("commandr", "CohereForCausalLM"),
    "CwmForCausalLM": ("llama", "LlamaForCausalLM"),
    "DbrxForCausalLM": ("dbrx", "DbrxForCausalLM"),
    "DeciLMForCausalLM": ("nemotron_nas", "DeciLMForCausalLM"),
    "DeepseekForCausalLM": ("deepseek_v2", "DeepseekForCausalLM"),
    "DeepseekV2ForCausalLM": ("deepseek_v2", "DeepseekV2ForCausalLM"),
    "DeepseekV3ForCausalLM": ("deepseek_v2", "DeepseekV3ForCausalLM"),
    "DeepseekV32ForCausalLM": ("deepseek_v2", "DeepseekV3ForCausalLM"),
    "Dots1ForCausalLM": ("dots1", "Dots1ForCausalLM"),
    "Ernie4_5ForCausalLM": ("ernie45", "Ernie4_5ForCausalLM"),
    "Ernie4_5_MoeForCausalLM": ("ernie45_moe", "Ernie4_5_MoeForCausalLM"),
    "ExaoneForCausalLM": ("exaone", "ExaoneForCausalLM"),
    "Exaone4ForCausalLM": ("exaone4", "Exaone4ForCausalLM"),
    "ExaoneMoEForCausalLM": ("exaone_moe", "ExaoneMoeForCausalLM"),
    "Fairseq2LlamaForCausalLM": ("fairseq2_llama", "Fairseq2LlamaForCausalLM"),
    "FalconForCausalLM": ("falcon", "FalconForCausalLM"),
    "FalconMambaForCausalLM": ("mamba", "MambaForCausalLM"),
    "FalconH1ForCausalLM": ("falcon_h1", "FalconH1ForCausalLM"),
    "FlexOlmoForCausalLM": ("flex_olmo", "FlexOlmoForCausalLM"),
    "GemmaForCausalLM": ("gemma", "GemmaForCausalLM"),
    "Gemma2ForCausalLM": ("gemma2", "Gemma2ForCausalLM"),
    "Gemma3ForCausalLM": ("gemma3", "Gemma3ForCausalLM"),
    "Gemma3nForCausalLM": ("gemma3n", "Gemma3nForCausalLM"),
    "Qwen3NextForCausalLM": ("qwen3_next", "Qwen3NextForCausalLM"),
    "GlmForCausalLM": ("glm", "GlmForCausalLM"),
    "Glm4ForCausalLM": ("glm4", "Glm4ForCausalLM"),
    "Glm4MoeForCausalLM": ("glm4_moe", "Glm4MoeForCausalLM"),
    "Glm4MoeLiteForCausalLM": ("glm4_moe_lite", "Glm4MoeLiteForCausalLM"),
    "GlmMoeDsaForCausalLM": ("deepseek_v2", "GlmMoeDsaForCausalLM"),
    "GptOssForCausalLM": ("gpt_oss", "GptOssForCausalLM"),
    "GPT2LMHeadModel": ("gpt2", "GPT2LMHeadModel"),
    "GPTBigCodeForCausalLM": ("gpt_bigcode", "GPTBigCodeForCausalLM"),
    "GPTJForCausalLM": ("gpt_j", "GPTJForCausalLM"),
    "GPTNeoXForCausalLM": ("gpt_neox", "GPTNeoXForCausalLM"),
    "GraniteForCausalLM": ("granite", "GraniteForCausalLM"),
    "GraniteMoeForCausalLM": ("granitemoe", "GraniteMoeForCausalLM"),
    "GraniteMoeHybridForCausalLM": ("granitemoehybrid", "GraniteMoeHybridForCausalLM"),  # noqa: E501
    "GraniteMoeSharedForCausalLM": ("granitemoeshared", "GraniteMoeSharedForCausalLM"),  # noqa: E501
    "GritLM": ("gritlm", "GritLM"),
    "Grok1ModelForCausalLM": ("grok1", "GrokForCausalLM"),
    "Grok1ForCausalLM": ("grok1", "GrokForCausalLM"),
    "HunYuanMoEV1ForCausalLM": ("hunyuan_v1", "HunYuanMoEV1ForCausalLM"),
    "HunYuanDenseV1ForCausalLM": ("hunyuan_v1", "HunYuanDenseV1ForCausalLM"),
    "HCXVisionForCausalLM": ("hyperclovax_vision", "HCXVisionForCausalLM"),
    "InternLMForCausalLM": ("llama", "LlamaForCausalLM"),
    "InternLM2ForCausalLM": ("internlm2", "InternLM2ForCausalLM"),
    "InternLM2VEForCausalLM": ("internlm2_ve", "InternLM2VEForCausalLM"),
    "InternLM3ForCausalLM": ("llama", "LlamaForCausalLM"),
    "IQuestCoderForCausalLM": ("llama", "LlamaForCausalLM"),
    "IQuestLoopCoderForCausalLM": ("iquest_loopcoder", "IQuestLoopCoderForCausalLM"),
    "JAISLMHeadModel": ("jais", "JAISLMHeadModel"),
    "Jais2ForCausalLM": ("jais2", "Jais2ForCausalLM"),
    "JambaForCausalLM": ("jamba", "JambaForCausalLM"),
    "KimiLinearForCausalLM": ("kimi_linear", "KimiLinearForCausalLM"),  # noqa: E501
    "Lfm2ForCausalLM": ("lfm2", "Lfm2ForCausalLM"),
    "Lfm2MoeForCausalLM": ("lfm2_moe", "Lfm2MoeForCausalLM"),
    "LlamaForCausalLM": ("llama", "LlamaForCausalLM"),
    "Llama4ForCausalLM": ("llama4", "Llama4ForCausalLM"),
    # For decapoda-research/llama-*
    "LLaMAForCausalLM": ("llama", "LlamaForCausalLM"),
    "LongcatFlashForCausalLM": ("longcat_flash", "LongcatFlashForCausalLM"),
    "MambaForCausalLM": ("mamba", "MambaForCausalLM"),
    "Mamba2ForCausalLM": ("mamba2", "Mamba2ForCausalLM"),
    "MiniCPMForCausalLM": ("minicpm", "MiniCPMForCausalLM"),
    "MiniCPM3ForCausalLM": ("minicpm3", "MiniCPM3ForCausalLM"),
    "MiniMaxForCausalLM": ("minimax_text_01", "MiniMaxText01ForCausalLM"),
    "MiniMaxText01ForCausalLM": ("minimax_text_01", "MiniMaxText01ForCausalLM"),
    "MiniMaxM1ForCausalLM": ("minimax_text_01", "MiniMaxText01ForCausalLM"),
    "MiniMaxM2ForCausalLM": ("minimax_m2", "MiniMaxM2ForCausalLM"),
    "MistralForCausalLM": ("mistral", "MistralForCausalLM"),
    "MistralLarge3ForCausalLM": ("mistral_large_3", "MistralLarge3ForCausalLM"),
    "MixtralForCausalLM": ("mixtral", "MixtralForCausalLM"),
    # transformers's mpt class has lower case
    "MptForCausalLM": ("mpt", "MPTForCausalLM"),
    "MPTForCausalLM": ("mpt", "MPTForCausalLM"),
    "MiMoForCausalLM": ("mimo", "MiMoForCausalLM"),
    "MiMoV2FlashForCausalLM": ("mimo_v2_flash", "MiMoV2FlashForCausalLM"),
    "NemotronForCausalLM": ("nemotron", "NemotronForCausalLM"),
    "NemotronHForCausalLM": ("nemotron_h", "NemotronHForCausalLM"),
    "NemotronHPuzzleForCausalLM": ("nemotron_h", "NemotronHForCausalLM"),
    "OlmoForCausalLM": ("olmo", "OlmoForCausalLM"),
    "Olmo2ForCausalLM": ("olmo2", "Olmo2ForCausalLM"),
    "Olmo3ForCausalLM": ("olmo2", "Olmo2ForCausalLM"),
    "OlmoeForCausalLM": ("olmoe", "OlmoeForCausalLM"),
    "OPTForCausalLM": ("opt", "OPTForCausalLM"),
    "OrionForCausalLM": ("orion", "OrionForCausalLM"),
    "OuroForCausalLM": ("ouro", "OuroForCausalLM"),
    "PanguEmbeddedForCausalLM": ("openpangu", "PanguEmbeddedForCausalLM"),
    "PanguProMoEV2ForCausalLM": ("openpangu", "PanguProMoEV2ForCausalLM"),
    "PanguUltraMoEForCausalLM": ("openpangu", "PanguUltraMoEForCausalLM"),
    "PersimmonForCausalLM": ("persimmon", "PersimmonForCausalLM"),
    "PhiForCausalLM": ("phi", "PhiForCausalLM"),
    "Phi3ForCausalLM": ("phi3", "Phi3ForCausalLM"),
    "PhiMoEForCausalLM": ("phimoe", "PhiMoEForCausalLM"),
    "Plamo2ForCausalLM": ("plamo2", "Plamo2ForCausalLM"),
    "Plamo3ForCausalLM": ("plamo3", "Plamo3ForCausalLM"),
    "QWenLMHeadModel": ("qwen", "QWenLMHeadModel"),
    "Qwen2ForCausalLM": ("qwen2", "Qwen2ForCausalLM"),
    "Qwen2MoeForCausalLM": ("qwen2_moe", "Qwen2MoeForCausalLM"),
    "Qwen3ForCausalLM": ("qwen3", "Qwen3ForCausalLM"),
    "Qwen3MoeForCausalLM": ("qwen3_moe", "Qwen3MoeForCausalLM"),
    "RWForCausalLM": ("falcon", "FalconForCausalLM"),
    "SeedOssForCausalLM": ("seed_oss", "SeedOssForCausalLM"),
    "Step1ForCausalLM": ("step1", "Step1ForCausalLM"),
    "Step3TextForCausalLM": ("step3_text", "Step3TextForCausalLM"),
    "Step3p5ForCausalLM": ("step3p5", "Step3p5ForCausalLM"),
    "StableLMEpochForCausalLM": ("stablelm", "StablelmForCausalLM"),
    "StableLmForCausalLM": ("stablelm", "StablelmForCausalLM"),
    "Starcoder2ForCausalLM": ("starcoder2", "Starcoder2ForCausalLM"),
    "SolarForCausalLM": ("solar", "SolarForCausalLM"),
    "TeleChatForCausalLM": ("telechat2", "TeleChat2ForCausalLM"),
    "TeleChat2ForCausalLM": ("telechat2", "TeleChat2ForCausalLM"),
    "TeleFLMForCausalLM": ("teleflm", "TeleFLMForCausalLM"),
    "XverseForCausalLM": ("llama", "LlamaForCausalLM"),
    "Zamba2ForCausalLM": ("zamba2", "Zamba2ForCausalLM"),
}

# Embedding architectures (including the reward models listed below), in the
# same ``(module_name, class_name)`` format. Every text-generation architecture
# implemented by LlamaForCausalLM is also registered here by the comprehension.
_EMBEDDING_MODELS = {
    # [Text-only]
    "BertModel": ("bert", "BertEmbeddingModel"),
    "BertSpladeSparseEmbeddingModel": ("bert", "BertSpladeSparseEmbeddingModel"),
    "HF_ColBERT": ("colbert", "ColBERTModel"),
    "ColBERTModernBertModel": ("colbert", "ColBERTModernBertModel"),
    "ColBERTJinaRobertaModel": ("colbert", "ColBERTJinaRobertaModel"),
    "DeciLMForCausalLM": ("nemotron_nas", "DeciLMForCausalLM"),
    "Gemma2Model": ("gemma2", "Gemma2ForCausalLM"),
    "Gemma3TextModel": ("gemma3", "Gemma3Model"),
    "GlmForCausalLM": ("glm", "GlmForCausalLM"),
    "GPT2ForSequenceClassification": ("gpt2", "GPT2ForSequenceClassification"),
    "GritLM": ("gritlm", "GritLM"),
    "GteModel": ("bert_with_rope", "SnowflakeGteNewModel"),
    "GteNewModel": ("bert_with_rope", "GteNewModel"),
    "InternLM2ForRewardModel": ("internlm2", "InternLM2ForRewardModel"),
    "JambaForSequenceClassification": ("jamba", "JambaForSequenceClassification"),  # noqa: E501
    "LlamaBidirectionalModel": ("llama", "LlamaBidirectionalModel"),
    "LlamaModel": ("llama", "LlamaForCausalLM"),
    **{
        # Multiple models share the same architecture, so we include them all
        k: (mod, arch)
        for k, (mod, arch) in _TEXT_GENERATION_MODELS.items()
        if arch == "LlamaForCausalLM"
    },
    "MistralModel": ("llama", "LlamaForCausalLM"),
    "ModernBertModel": ("modernbert", "ModernBertModel"),
    "NomicBertModel": ("bert_with_rope", "NomicBertModel"),
    "Phi3ForCausalLM": ("phi3", "Phi3ForCausalLM"),
    "Qwen2Model": ("qwen2", "Qwen2ForCausalLM"),
    "Qwen2ForCausalLM": ("qwen2", "Qwen2ForCausalLM"),
    "Qwen2ForRewardModel": ("qwen2_rm", "Qwen2ForRewardModel"),
    "Qwen2ForProcessRewardModel": ("qwen2_rm", "Qwen2ForProcessRewardModel"),
    "RobertaForMaskedLM": ("roberta", "RobertaEmbeddingModel"),
    "RobertaModel": ("roberta", "RobertaEmbeddingModel"),
    "TeleChatForCausalLM": ("telechat2", "TeleChat2ForCausalLM"),
    "TeleChat2ForCausalLM": ("telechat2", "TeleChat2ForCausalLM"),
    "VoyageQwen3BidirectionalEmbedModel": (
        "voyage",
        "VoyageQwen3BidirectionalEmbedModel",
    ),
    "XLMRobertaModel": ("roberta", "RobertaEmbeddingModel"),
    "BgeM3EmbeddingModel": ("roberta", "BgeM3EmbeddingModel"),
    # [Multimodal]
    "CLIPModel": ("clip", "CLIPEmbeddingModel"),
    "LlavaNextForConditionalGeneration": (
        "llava_next",
        "LlavaNextForConditionalGeneration",
    ),
    "Phi3VForCausalLM": ("phi3v", "Phi3VForCausalLM"),
    "Qwen2VLForConditionalGeneration": ("qwen2_vl", "Qwen2VLForConditionalGeneration"),  # noqa: E501
    "ColQwen3": ("colqwen3", "ColQwen3Model"),
    "OpsColQwen3Model": ("colqwen3", "ColQwen3Model"),
    "SiglipModel": ("siglip", "SiglipEmbeddingModel"),
    # Terratorch models technically take images as input and produce images
    # as output; they are listed here because they piggy-back on the embedding
    # models for the time being.
    "PrithviGeoSpatialMAE": ("terratorch", "Terratorch"),
    "Terratorch": ("terratorch", "Terratorch"),
}

# Sequence/token classification and ranking architectures, in
# ``(module_name, class_name)`` format.
_CROSS_ENCODER_MODELS = {
    "BertForSequenceClassification": ("bert", "BertForSequenceClassification"),
    "BertForTokenClassification": ("bert", "BertForTokenClassification"),
    "GteNewForSequenceClassification": (
        "bert_with_rope",
        "GteNewForSequenceClassification",
    ),
    "JinaVLForRanking": ("jina_vl", "JinaVLForSequenceClassification"),
    "LlamaBidirectionalForSequenceClassification": (
        "llama",
        "LlamaBidirectionalForSequenceClassification",
    ),
    "ModernBertForSequenceClassification": (
        "modernbert",
        "ModernBertForSequenceClassification",
    ),
    "ModernBertForTokenClassification": (
        "modernbert",
        "ModernBertForTokenClassification",
    ),
    "RobertaForSequenceClassification": ("roberta", "RobertaForSequenceClassification"),
    "XLMRobertaForSequenceClassification": (
        "roberta",
        "RobertaForSequenceClassification",
    ),
}

# Multimodal architectures (decoder-only and encoder-decoder), in
# ``(module_name, class_name)`` format.
_MULTIMODAL_MODELS = {
    # [Decoder-only]
    "AriaForConditionalGeneration": ("aria", "AriaForConditionalGeneration"),
    "AudioFlamingo3ForConditionalGeneration": (
        "audioflamingo3",
        "AudioFlamingo3ForConditionalGeneration",
    ),
    "MusicFlamingoForConditionalGeneration": (
        "musicflamingo",
        "MusicFlamingoForConditionalGeneration",
    ),
    "AyaVisionForConditionalGeneration": (
        "aya_vision",
        "AyaVisionForConditionalGeneration",
    ),
    "BagelForConditionalGeneration": ("bagel", "BagelForConditionalGeneration"),
    "BeeForConditionalGeneration": ("bee", "BeeForConditionalGeneration"),
    "Blip2ForConditionalGeneration": ("blip2", "Blip2ForConditionalGeneration"),
    "ChameleonForConditionalGeneration": (
        "chameleon",
        "ChameleonForConditionalGeneration",
    ),
    "Cohere2VisionForConditionalGeneration": (
        "cohere2_vision",
        "Cohere2VisionForConditionalGeneration",
    ),
    "DeepseekVLV2ForCausalLM": ("deepseek_vl2", "DeepseekVLV2ForCausalLM"),
    "DeepseekOCRForCausalLM": ("deepseek_ocr", "DeepseekOCRForCausalLM"),
    "DeepseekOCR2ForCausalLM": ("deepseek_ocr2", "DeepseekOCR2ForCausalLM"),
    "DotsOCRForCausalLM": ("dots_ocr", "DotsOCRForCausalLM"),
    "Eagle2_5_VLForConditionalGeneration": (
        "eagle2_5_vl",
        "Eagle2_5_VLForConditionalGeneration",
    ),
    "Ernie4_5_VLMoeForConditionalGeneration": (
        "ernie45_vl",
        "Ernie4_5_VLMoeForConditionalGeneration",
    ),
    "FunASRForConditionalGeneration": ("funasr", "FunASRForConditionalGeneration"),  # noqa: E501
    "FunAudioChatForConditionalGeneration": (
        "funaudiochat",
        "FunAudioChatForConditionalGeneration",
    ),
    "FuyuForCausalLM": ("fuyu", "FuyuForCausalLM"),
    "Gemma3ForConditionalGeneration": ("gemma3_mm", "Gemma3ForConditionalGeneration"),  # noqa: E501
    "Gemma3nForConditionalGeneration": (
        "gemma3n_mm",
        "Gemma3nForConditionalGeneration",
    ),
    "GlmAsrForConditionalGeneration": ("glmasr", "GlmAsrForConditionalGeneration"),
    "GLM4VForCausalLM": ("glm4v", "GLM4VForCausalLM"),
    "Glm4vForConditionalGeneration": ("glm4_1v", "Glm4vForConditionalGeneration"),
    "Glm4vMoeForConditionalGeneration": ("glm4_1v", "Glm4vMoeForConditionalGeneration"),
    "GlmOcrForConditionalGeneration": ("glm_ocr", "GlmOcrForConditionalGeneration"),  # noqa: E501
    "GraniteSpeechForConditionalGeneration": (
        "granite_speech",
        "GraniteSpeechForConditionalGeneration",
    ),
    "H2OVLChatModel": ("h2ovl", "H2OVLChatModel"),
    "HunYuanVLForConditionalGeneration": (
        "hunyuan_vision",
        "HunYuanVLForConditionalGeneration",
    ),
    "StepVLForConditionalGeneration": ("step_vl", "StepVLForConditionalGeneration"),
    "InternVLChatModel": ("internvl", "InternVLChatModel"),
    "NemotronH_Nano_VL_V2": ("nano_nemotron_vl", "NemotronH_Nano_VL_V2"),
    "OpenCUAForConditionalGeneration": (
        "opencua",
        "OpenCUAForConditionalGeneration",
    ),
    "InternS1ForConditionalGeneration": (
        "interns1",
        "InternS1ForConditionalGeneration",
    ),
    "InternVLForConditionalGeneration": (
        "interns1",
        "InternS1ForConditionalGeneration",
    ),
    "InternS1ProForConditionalGeneration": (
        "interns1_pro",
        "InternS1ProForConditionalGeneration",
    ),
    "Idefics3ForConditionalGeneration": (
        "idefics3",
        "Idefics3ForConditionalGeneration",
    ),
    "IsaacForConditionalGeneration": ("isaac", "IsaacForConditionalGeneration"),
    "SmolVLMForConditionalGeneration": ("smolvlm", "SmolVLMForConditionalGeneration"),  # noqa: E501
    "KananaVForConditionalGeneration": ("kanana_v", "KananaVForConditionalGeneration"),
    "KeyeForConditionalGeneration": ("keye", "KeyeForConditionalGeneration"),
    "KeyeVL1_5ForConditionalGeneration": (
        "keye_vl1_5",
        "KeyeVL1_5ForConditionalGeneration",
    ),
    "RForConditionalGeneration": ("rvl", "RForConditionalGeneration"),
    "KimiVLForConditionalGeneration": ("kimi_vl", "KimiVLForConditionalGeneration"),  # noqa: E501
    "KimiK25ForConditionalGeneration": ("kimi_k25", "KimiK25ForConditionalGeneration"),  # noqa: E501
    "LightOnOCRForConditionalGeneration": (
        "lightonocr",
        "LightOnOCRForConditionalGeneration",
    ),
    "Lfm2VlForConditionalGeneration": ("lfm2_vl", "Lfm2VLForConditionalGeneration"),
    "Llama_Nemotron_Nano_VL": ("nemotron_vl", "LlamaNemotronVLChatModel"),
    "Llama4ForConditionalGeneration": ("mllama4", "Llama4ForConditionalGeneration"),  # noqa: E501
    "LlavaForConditionalGeneration": ("llava", "LlavaForConditionalGeneration"),
    "LlavaNextForConditionalGeneration": (
        "llava_next",
        "LlavaNextForConditionalGeneration",
    ),
    "LlavaNextVideoForConditionalGeneration": (
        "llava_next_video",
        "LlavaNextVideoForConditionalGeneration",
    ),
    "LlavaOnevisionForConditionalGeneration": (
        "llava_onevision",
        "LlavaOnevisionForConditionalGeneration",
    ),
    "MantisForConditionalGeneration": ("llava", "MantisForConditionalGeneration"),  # noqa: E501
    "MiDashengLMModel": ("midashenglm", "MiDashengLMModel"),
    "MiniMaxVL01ForConditionalGeneration": (
        "minimax_vl_01",
        "MiniMaxVL01ForConditionalGeneration",
    ),
    "MiniCPMO": ("minicpmo", "MiniCPMO"),
    "MiniCPMV": ("minicpmv", "MiniCPMV"),
    "Mistral3ForConditionalGeneration": (
        "mistral3",
        "Mistral3ForConditionalGeneration",
    ),
    "MolmoForCausalLM": ("molmo", "MolmoForCausalLM"),
    "Molmo2ForConditionalGeneration": ("molmo2", "Molmo2ForConditionalGeneration"),
    "NVLM_D": ("nvlm_d", "NVLM_D_Model"),
    "OpenPanguVLForConditionalGeneration": (
        "openpangu_vl",
        "OpenPanguVLForConditionalGeneration",
    ),
    "Ovis": ("ovis", "Ovis"),
    "Ovis2_5": ("ovis2_5", "Ovis2_5"),
    "Ovis2_6ForCausalLM": ("ovis2_5", "Ovis2_5"),
    "Ovis2_6_MoeForCausalLM": ("ovis2_5", "Ovis2_5"),
    "PaddleOCRVLForConditionalGeneration": (
        "paddleocr_vl",
        "PaddleOCRVLForConditionalGeneration",
    ),
    "PaliGemmaForConditionalGeneration": (
        "paligemma",
        "PaliGemmaForConditionalGeneration",
    ),
    "Phi3VForCausalLM": ("phi3v", "Phi3VForCausalLM"),
    "Phi4MMForCausalLM": ("phi4mm", "Phi4MMForCausalLM"),
    "PixtralForConditionalGeneration": ("pixtral", "PixtralForConditionalGeneration"),  # noqa: E501
    "QwenVLForConditionalGeneration": ("qwen_vl", "QwenVLForConditionalGeneration"),  # noqa: E501
    "Qwen2VLForConditionalGeneration": ("qwen2_vl", "Qwen2VLForConditionalGeneration"),  # noqa: E501
    "Qwen2_5_VLForConditionalGeneration": (
        "qwen2_5_vl",
        "Qwen2_5_VLForConditionalGeneration",
    ),
    "Qwen2AudioForConditionalGeneration": (
        "qwen2_audio",
        "Qwen2AudioForConditionalGeneration",
    ),
    "Qwen2_5OmniModel": (
        "qwen2_5_omni_thinker",
        "Qwen2_5OmniThinkerForConditionalGeneration",
    ),
    "Qwen2_5OmniForConditionalGeneration": (
        "qwen2_5_omni_thinker",
        "Qwen2_5OmniThinkerForConditionalGeneration",
    ),
    "Qwen3OmniMoeForConditionalGeneration": (
        "qwen3_omni_moe_thinker",
        "Qwen3OmniMoeThinkerForConditionalGeneration",
    ),
    "Qwen3ASRForConditionalGeneration": (
        "qwen3_asr",
        "Qwen3ASRForConditionalGeneration",
    ),
    "Qwen3VLForConditionalGeneration": ("qwen3_vl", "Qwen3VLForConditionalGeneration"),  # noqa: E501
    "Qwen3VLMoeForConditionalGeneration": (
        "qwen3_vl_moe",
        "Qwen3VLMoeForConditionalGeneration",
    ),
    "Qwen3_5ForConditionalGeneration": (
        "qwen3_5",
        "Qwen3_5ForConditionalGeneration",
    ),
    "Qwen3_5MoeForConditionalGeneration": (
        "qwen3_5",
        "Qwen3_5MoeForConditionalGeneration",
    ),
    "SkyworkR1VChatModel": ("skyworkr1v", "SkyworkR1VChatModel"),
    "Step3VLForConditionalGeneration": ("step3_vl", "Step3VLForConditionalGeneration"),  # noqa: E501
    "TarsierForConditionalGeneration": ("tarsier", "TarsierForConditionalGeneration"),  # noqa: E501
    "Tarsier2ForConditionalGeneration": (
        "qwen2_vl",
        "Tarsier2ForConditionalGeneration",
    ),
    "UltravoxModel": ("ultravox", "UltravoxModel"),
    "VoxtralForConditionalGeneration": ("voxtral", "VoxtralForConditionalGeneration"),  # noqa: E501
    "VoxtralRealtimeGeneration": ("voxtral_realtime", "VoxtralRealtimeGeneration"),  # noqa: E501
    # [Encoder-decoder]
    "NemotronParseForConditionalGeneration": (
        "nemotron_parse",
        "NemotronParseForConditionalGeneration",
    ),
    "WhisperForConditionalGeneration": ("whisper", "WhisperForConditionalGeneration"),  # noqa: E501
}

# Draft-model architectures used for speculative decoding (EAGLE, MTP,
# Medusa variants), in ``(module_name, class_name)`` format.
_SPECULATIVE_DECODING_MODELS = {
    "MiMoMTPModel": ("mimo_mtp", "MiMoMTP"),
    "EagleLlamaForCausalLM": ("llama_eagle", "EagleLlamaForCausalLM"),
    "EagleLlama4ForCausalLM": ("llama4_eagle", "EagleLlama4ForCausalLM"),
    "EagleMiniCPMForCausalLM": ("minicpm_eagle", "EagleMiniCPMForCausalLM"),
    "Eagle3LlamaForCausalLM": ("llama_eagle3", "Eagle3LlamaForCausalLM"),
    "LlamaForCausalLMEagle3": ("llama_eagle3", "Eagle3LlamaForCausalLM"),
    "Eagle3Qwen2_5vlForCausalLM": ("llama_eagle3", "Eagle3LlamaForCausalLM"),
    "Eagle3Qwen3vlForCausalLM": ("llama_eagle3", "Eagle3LlamaForCausalLM"),
    "EagleMistralLarge3ForCausalLM": (
        "mistral_large_3_eagle",
        "EagleMistralLarge3ForCausalLM",
    ),
    "EagleDeepSeekMTPModel": ("deepseek_eagle", "EagleDeepseekV3ForCausalLM"),
    "DeepSeekMTPModel": ("deepseek_mtp", "DeepSeekMTP"),
    "ErnieMTPModel": ("ernie_mtp", "ErnieMTP"),
    "ExaoneMoeMTP": ("exaone_moe_mtp", "ExaoneMoeMTP"),
    "LongCatFlashMTPModel": ("longcat_flash_mtp", "LongCatFlashMTP"),
    "Glm4MoeMTPModel": ("glm4_moe_mtp", "Glm4MoeMTP"),
    "Glm4MoeLiteMTPModel": ("glm4_moe_lite_mtp", "Glm4MoeLiteMTP"),
    "GlmOcrMTPModel": ("glm_ocr_mtp", "GlmOcrMTP"),
    "MedusaModel": ("medusa", "Medusa"),
    "OpenPanguMTPModel": ("openpangu_mtp", "OpenPanguMTP"),
    "Qwen3NextMTP": ("qwen3_next_mtp", "Qwen3NextMTP"),
    "Step3p5MTP": ("step3p5_mtp", "Step3p5MTP"),
    "Qwen3_5MTP": ("qwen3_5_mtp", "Qwen3_5MTP"),
    "Qwen3_5MoeMTP": ("qwen3_5_mtp", "Qwen3_5MoeMTP"),
    # Temporarily disabled.
    # # TODO(woosuk): Re-enable this once the MLP Speculator is supported in V1.
    # "MLPSpeculatorPreTrainedModel": ("mlp_speculator", "MLPSpeculator"),
}

# Architectures that are mapped onto the Transformers backend classes
# (module "transformers") defined in _TRANSFORMERS_BACKEND_MODELS.
_TRANSFORMERS_SUPPORTED_MODELS = {
    # Text generation models
    "SmolLM3ForCausalLM": ("transformers", "TransformersForCausalLM"),
    # Multimodal models
    "Emu3ForConditionalGeneration": (
        "transformers",
        "TransformersMultiModalForCausalLM",
    ),
}

# The Transformers backend classes themselves: every entry maps a class name
# to the class of the same name in the "transformers" module.
_TRANSFORMERS_BACKEND_MODELS = {
    # Text generation models
    "TransformersForCausalLM": ("transformers", "TransformersForCausalLM"),
    "TransformersMoEForCausalLM": ("transformers", "TransformersMoEForCausalLM"),
    # Multimodal models
    "TransformersMultiModalForCausalLM": (
        "transformers",
        "TransformersMultiModalForCausalLM",
    ),
    "TransformersMultiModalMoEForCausalLM": (
        "transformers",
        "TransformersMultiModalMoEForCausalLM",
    ),
    # Embedding models
    "TransformersEmbeddingModel": ("transformers", "TransformersEmbeddingModel"),
    "TransformersMoEEmbeddingModel": ("transformers", "TransformersMoEEmbeddingModel"),
    "TransformersMultiModalEmbeddingModel": (
        "transformers",
        "TransformersMultiModalEmbeddingModel",
    ),
    # Sequence classification models
    "TransformersForSequenceClassification": (
        "transformers",
        "TransformersForSequenceClassification",
    ),
    "TransformersMoEForSequenceClassification": (
        "transformers",
        "TransformersMoEForSequenceClassification",
    ),
    "TransformersMultiModalForSequenceClassification": (
        "transformers",
        "TransformersMultiModalForSequenceClassification",
    ),
}

# Combined registry of every category above. With ``**`` unpacking, a key that
# appears in more than one mapping takes the value from the one listed last.
_VLLM_MODELS = {
    **_TEXT_GENERATION_MODELS,
    **_EMBEDDING_MODELS,
    **_CROSS_ENCODER_MODELS,
    **_MULTIMODAL_MODELS,
    **_SPECULATIVE_DECODING_MODELS,
    **_TRANSFORMERS_SUPPORTED_MODELS,
    **_TRANSFORMERS_BACKEND_MODELS,
}

# Argument list passed to subprocess.run(). Override it when sys.executable is
# not the right program to launch, e.g. when the application is packaged in
# par format.
_SUBPROCESS_COMMAND = [sys.executable, "-m", "vllm.model_executor.models.registry"]

# Architectures that are no longer supported, mapped to a vLLM version string
# (presumably the last release that supported them — TODO confirm).
_PREVIOUSLY_SUPPORTED_MODELS = {
    "MotifForCausalLM": "0.10.2",
    "Phi3SmallForCausalLM": "0.9.2",
    "Phi4FlashForCausalLM": "0.10.2",
    "Phi4MultimodalForCausalLM": "0.12.0",
    # encoder-decoder models except whisper
    # have been removed for V0 deprecation.
    "BartModel": "0.10.2",
    "BartForConditionalGeneration": "0.10.2",
    "DonutForConditionalGeneration": "0.10.2",
    "Florence2ForConditionalGeneration": "0.10.2",
    "MBartForConditionalGeneration": "0.10.2",
    "MllamaForConditionalGeneration": "0.10.2",
}


@dataclass(frozen=True)
class _ModelInfo:
    """
    Capability summary of a model class, gathered from the helpers in
    `.interfaces` and `.interfaces_base`.

    Instances are written to the model-info cache via `dataclasses.asdict`
    + JSON and rebuilt with `_ModelInfo(**fields)`, so the field names also
    serve as the cache schema.
    """

    architecture: str
    is_text_generation_model: bool
    is_pooling_model: bool
    attn_type: AttnTypeStr
    default_seq_pooling_type: SequencePoolingType
    default_tok_pooling_type: TokenPoolingType
    supports_cross_encoding: bool
    supports_late_interaction: bool
    supports_multimodal: bool
    supports_multimodal_raw_input_only: bool
    requires_raw_input_tokens: bool
    supports_multimodal_encoder_tp_data: bool
    supports_pp: bool
    has_inner_state: bool
    is_attention_free: bool
    is_hybrid: bool
    has_noops: bool
    supports_mamba_prefix_caching: bool
    supports_transcription: bool
    # True only when the model supports transcription AND its class sets
    # `supports_transcription_only`.
    supports_transcription_only: bool

    @staticmethod
    def from_model_cls(model: type[nn.Module]) -> "_ModelInfo":
        """Probe `model` with every interface helper and bundle the results."""
        name = model.__name__
        text_generation = is_text_generation_model(model)
        pooling = is_pooling_model(model)
        seq_pooling = get_default_seq_pooling_type(model)
        tok_pooling = get_default_tok_pooling_type(model)
        attention = get_attn_type(model)
        cross_encoding = supports_cross_encoding(model)
        late_interaction = supports_late_interaction(model)
        multimodal = supports_multimodal(model)
        mm_raw_input_only = supports_multimodal_raw_input_only(model)
        raw_input_tokens = requires_raw_input_tokens(model)
        mm_encoder_tp_data = supports_multimodal_encoder_tp_data(model)
        pipeline_parallel = supports_pp(model)
        inner_state = has_inner_state(model)
        attention_free = is_attention_free(model)
        hybrid = is_hybrid(model)
        mamba_prefix_caching = supports_mamba_prefix_caching(model)
        transcription = supports_transcription(model)
        # `supports_transcription_only` is only read from classes that
        # implement the transcription interface (short-circuit otherwise).
        transcription_only = (
            supports_transcription(model) and model.supports_transcription_only
        )
        noops = has_noops(model)

        return _ModelInfo(
            architecture=name,
            is_text_generation_model=text_generation,
            is_pooling_model=pooling,
            attn_type=attention,
            default_seq_pooling_type=seq_pooling,
            default_tok_pooling_type=tok_pooling,
            supports_cross_encoding=cross_encoding,
            supports_late_interaction=late_interaction,
            supports_multimodal=multimodal,
            supports_multimodal_raw_input_only=mm_raw_input_only,
            requires_raw_input_tokens=raw_input_tokens,
            supports_multimodal_encoder_tp_data=mm_encoder_tp_data,
            supports_pp=pipeline_parallel,
            has_inner_state=inner_state,
            is_attention_free=attention_free,
            is_hybrid=hybrid,
            has_noops=noops,
            supports_mamba_prefix_caching=mamba_prefix_caching,
            supports_transcription=transcription,
            supports_transcription_only=transcription_only,
        )
664
665


666
667
668
669
class _BaseRegisteredModel(ABC):
    """Interface shared by every entry stored in the model registry."""

    @abstractmethod
    def inspect_model_cls(self) -> _ModelInfo:
        """Return the capability summary of the registered model."""
        raise NotImplementedError()

    @abstractmethod
    def load_model_cls(self) -> type[nn.Module]:
        """Return the registered model class, importing it if required."""
        raise NotImplementedError()
674
675


676
677
678
679
680
681
682
@dataclass(frozen=True)
class _RegisteredModel(_BaseRegisteredModel):
    """
    Registry entry for a model class that is already imported in the main
    process. Its interface information is computed eagerly on creation.
    """

    interfaces: _ModelInfo
    model_cls: type[nn.Module]

    @staticmethod
    def from_model_cls(model_cls: type[nn.Module]):
        """Wrap `model_cls`, inspecting its interfaces right away."""
        info = _ModelInfo.from_model_cls(model_cls)
        return _RegisteredModel(info, model_cls)

    def inspect_model_cls(self) -> _ModelInfo:
        # Already computed at construction time; nothing to import.
        return self.interfaces

    def load_model_cls(self) -> type[nn.Module]:
        # The class object is held directly.
        return self.model_cls


@dataclass(frozen=True)
class _LazyRegisteredModel(_BaseRegisteredModel):
    """
    Represents a model that has not been imported in the main process.

    The class is imported only in `load_model_cls`. `inspect_model_cls`
    computes the interface information in a subprocess and, for models whose
    source file lives in this package, caches it on disk keyed by a hash of
    that file.
    """

    module_name: str
    class_name: str

    @staticmethod
    def _get_cache_dir() -> Path:
        """Directory under `VLLM_CACHE_ROOT` that holds cached model infos."""
        return Path(envs.VLLM_CACHE_ROOT) / "modelinfos"

    def _get_cache_filename(self) -> str:
        """JSON cache file name derived from the module and class names."""
        cls_name = f"{self.module_name}-{self.class_name}".replace(".", "-")
        return f"{cls_name}.json"

    def _get_source_path(self) -> Path | None:
        """Return the in-tree source file that defines this model, if any.

        Only modules located directly in the `vllm.model_executor.models`
        package (the directory of this file) are mapped to a source file.
        Out-of-tree modules, e.g. ones registered by plugins, return None even
        when their last dotted component matches an in-tree file name:
        hashing that unrelated file would never invalidate their cache entry,
        so changes to the plugin would be served stale interface info.
        """
        package, _, basename = self.module_name.rpartition(".")
        if package != "vllm.model_executor.models":
            return None

        path = Path(__file__).parent / f"{basename}.py"
        return path if path.is_file() else None

    def _load_modelinfo_from_cache(self, module_hash: str) -> _ModelInfo | None:
        """Return the cached `_ModelInfo` if it exists and matches `module_hash`.

        Any problem (missing file, stale hash, unreadable or incompatible
        content) yields None so that the caller re-inspects the model.
        """
        try:
            try:
                modelinfo_path = self._get_cache_dir() / self._get_cache_filename()
                with open(modelinfo_path, encoding="utf-8") as file:
                    mi_dict = json.load(file)
            except FileNotFoundError:
                logger.debug(
                    "Cached model info file for class %s.%s not found",
                    self.module_name,
                    self.class_name,
                )
                return None

            if mi_dict["hash"] != module_hash:
                logger.debug(
                    "Cached model info file for class %s.%s is stale",
                    self.module_name,
                    self.class_name,
                )
                return None

            # file not changed, use cached _ModelInfo properties
            return _ModelInfo(**mi_dict["modelinfo"])
        except Exception:
            # Best effort: corrupted JSON or a schema change simply means
            # the cache is not used.
            logger.debug(
                "Cached model info for class %s.%s error. ",
                self.module_name,
                self.class_name,
            )
            return None

    def _save_modelinfo_to_cache(self, mi: _ModelInfo, module_hash: str) -> None:
        """save dictionary json file to cache"""
        from vllm.model_executor.model_loader.weight_utils import atomic_writer

        try:
            modelinfo_dict = {
                "hash": module_hash,
                "modelinfo": asdict(mi),
            }
            cache_dir = self._get_cache_dir()
            cache_dir.mkdir(parents=True, exist_ok=True)
            modelinfo_path = cache_dir / self._get_cache_filename()
            with atomic_writer(modelinfo_path, encoding="utf-8") as f:
                json.dump(modelinfo_dict, f, indent=2)
        except Exception:
            # A failed cache write must not break model inspection.
            logger.exception("Error saving model info cache.")

    @logtime(logger=logger, msg="Registry inspect model class")
    def inspect_model_cls(self) -> _ModelInfo:
        """Return this model's `_ModelInfo`, using the on-disk cache if valid.

        On a cache miss, or when the model has no in-tree source file, the
        model is imported and inspected in a subprocess. The result is
        written back to the cache when a source hash is available.
        """
        source_path = self._get_source_path()
        module_hash = None

        if source_path is not None:
            with open(source_path, "rb") as f:
                module_hash = safe_hash(f.read(), usedforsecurity=False).hexdigest()

            mi = self._load_modelinfo_from_cache(module_hash)
            if mi is not None:
                logger.debug(
                    "Loaded model info for class %s.%s from cache",
                    self.module_name,
                    self.class_name,
                )
                return mi
            else:
                logger.debug(
                    "Cache model info for class %s.%s miss. Loading model instead.",
                    self.module_name,
                    self.class_name,
                )

        # Performed in another process to avoid initializing CUDA
        mi = _run_in_subprocess(
            lambda: _ModelInfo.from_model_cls(self.load_model_cls())
        )
        logger.debug(
            "Loaded model info for class %s.%s", self.module_name, self.class_name
        )

        # save cache file
        if module_hash is not None:
            self._save_modelinfo_to_cache(mi, module_hash)

        return mi

    def load_model_cls(self) -> type[nn.Module]:
        """Import `module_name` and return its `class_name` attribute."""
        mod = importlib.import_module(self.module_name)
        return getattr(mod, self.class_name)


@lru_cache(maxsize=128)
def _try_load_model_cls(
    model_arch: str,
    model: _BaseRegisteredModel,
) -> type[nn.Module] | None:
    """Load the model class behind `model`; log and return None on failure.

    The platform check runs first and is not guarded, so anything raised by
    `verify_model_arch` propagates to the caller. Successful and failed
    loads are memoised per `(model_arch, model)` pair.
    """
    from vllm.platforms import current_platform

    current_platform.verify_model_arch(model_arch)

    loaded = None
    try:
        loaded = model.load_model_cls()
    except Exception:
        logger.exception("Error in loading model architecture '%s'", model_arch)
    return loaded
821
822


823
824
825
826
@lru_cache(maxsize=128)
def _try_inspect_model_cls(
    model_arch: str,
    model: _BaseRegisteredModel,
) -> _ModelInfo | None:
    """Inspect `model`; log and return None if inspection raises.

    Memoised per `(model_arch, model)`; `model_arch` is only used in the log
    message.
    """
    info = None
    try:
        info = model.inspect_model_cls()
    except Exception:
        logger.exception("Error in inspecting model architecture '%s'", model_arch)
    return info
833
834


835
836
837
@dataclass
class _ModelRegistry:
    """
    Registry mapping model architecture names to registered models.

    Besides lookups, it resolves a list of candidate architectures (plus the
    model config) to a model class or its `_ModelInfo`, including fallbacks
    to the Transformers backend.
    """

    # Keyed by model_arch
    models: dict[str, _BaseRegisteredModel] = field(default_factory=dict)

    def get_supported_archs(self) -> Set[str]:
        """Return a live view of the registered architecture names."""
        return self.models.keys()

    def register_model(
        self,
        model_arch: str,
        model_cls: type[nn.Module] | str,
    ) -> None:
        """
        Register an external model to be used in vLLM.

        `model_cls` can be either:

        - A [`torch.nn.Module`][] class directly referencing the model.
        - A string in the format `<module>:<class>` which can be used to
          lazily import the model. This is useful to avoid initializing CUDA
          when importing the model and thus the related error
          `RuntimeError: Cannot re-initialize CUDA in forked subprocess`.

        Registering an already registered architecture overwrites it (a
        warning is logged).

        Raises:
            TypeError: If `model_arch` is not a string, or `model_cls` is
                neither a string nor an `nn.Module` subclass.
            ValueError: If a string `model_cls` is not `<module>:<class>`.
        """
        if not isinstance(model_arch, str):
            msg = f"`model_arch` should be a string, not a {type(model_arch)}"
            raise TypeError(msg)

        # Validate and build the entry first so that an invalid `model_cls`
        # neither logs an "overwritten" warning nor touches the registry.
        if isinstance(model_cls, str):
            split_str = model_cls.split(":")
            if len(split_str) != 2:
                msg = "Expected a string in the format `<module>:<class>`"
                raise ValueError(msg)

            model = _LazyRegisteredModel(*split_str)
        elif isinstance(model_cls, type) and issubclass(model_cls, nn.Module):
            model = _RegisteredModel.from_model_cls(model_cls)
        else:
            # Report the type of the argument that actually failed validation.
            msg = (
                "`model_cls` should be a string or PyTorch model class, "
                f"not a {type(model_cls)}"
            )
            raise TypeError(msg)

        if model_arch in self.models:
            logger.warning(
                "Model architecture %s is already registered, and will be "
                "overwritten by the new model class %s.",
                model_arch,
                model_cls,
            )

        self.models[model_arch] = model

    def _raise_for_unsupported(self, architectures: list[str]):
        """Always raise a `ValueError` explaining why `architectures` failed.

        The message distinguishes registered architectures that failed to be
        loaded/inspected, architectures removed from vLLM, and architectures
        that were never supported.
        """
        all_supported_archs = self.get_supported_archs()

        if any(arch in all_supported_archs for arch in architectures):
            # Registered, so loading/inspection must have failed; the module
            # level `_try_*_model_cls` helpers already logged the exception.
            raise ValueError(
                f"Model architectures {architectures} failed "
                "to be inspected. Please check the logs for more details."
            )

        for arch in architectures:
            if arch in _PREVIOUSLY_SUPPORTED_MODELS:
                previous_version = _PREVIOUSLY_SUPPORTED_MODELS[arch]

                raise ValueError(
                    f"Model architecture {arch} was supported in vLLM until "
                    f"v{previous_version}, and is not supported anymore. "
                    "Please use an older version of vLLM if you want to "
                    "use this model architecture."
                )

        raise ValueError(
            f"Model architectures {architectures} are not supported for now. "
            f"Supported architectures: {all_supported_archs}"
        )

    def _try_load_model_cls(self, model_arch: str) -> type[nn.Module] | None:
        """Load the class registered as `model_arch`, or None if unavailable."""
        if model_arch not in self.models:
            return None

        return _try_load_model_cls(model_arch, self.models[model_arch])

    def _try_inspect_model_cls(self, model_arch: str) -> _ModelInfo | None:
        """Inspect the class registered as `model_arch`, or None if unavailable."""
        if model_arch not in self.models:
            return None

        return _try_inspect_model_cls(model_arch, self.models[model_arch])

    def _try_resolve_transformers(
        self,
        architecture: str,
        model_config: ModelConfig,
    ) -> str | None:
        """Map `architecture` to a Transformers-backend architecture name.

        Returns `architecture` itself if it already names a Transformers
        backend model. Otherwise the model class is looked up in the
        `transformers` package or through the config's `auto_map` (remote
        code); if it is backend compatible, the name returned by
        `model_config._get_transformers_backend_cls()` is returned.

        When no suitable class is found, returns None, unless
        `model_config.model_impl == "transformers"`, in which case a
        `ValueError` is raised.
        """
        if architecture in _TRANSFORMERS_BACKEND_MODELS:
            return architecture

        auto_map: dict[str, str] = (
            getattr(model_config.hf_config, "auto_map", None) or dict()
        )

        # Make sure that config class is always initialized before model class,
        # otherwise the model class won't be able to access the config class,
        # the expected auto_map should have correct order like:
        # "auto_map": {
        #     "AutoConfig": "<your-repo-name>--<config-name>",
        #     "AutoModel": "<your-repo-name>--<config-name>",
        #     "AutoModelFor<Task>": "<your-repo-name>--<config-name>",
        # },
        for prefix in ("AutoConfig", "AutoModel"):
            for name, module in auto_map.items():
                if name.startswith(prefix):
                    try_get_class_from_dynamic_module(
                        module,
                        model_config.model,
                        revision=model_config.revision,
                        trust_remote_code=model_config.trust_remote_code,
                        warn_on_fail=False,
                    )

        model_module = getattr(transformers, architecture, None)

        if model_module is None:
            for name, module in auto_map.items():
                if name.startswith("AutoModel"):
                    model_module = try_get_class_from_dynamic_module(
                        module,
                        model_config.model,
                        revision=model_config.revision,
                        trust_remote_code=model_config.trust_remote_code,
                        warn_on_fail=True,
                    )
                    if model_module is not None:
                        break
            else:
                # for-else: no `AutoModel*` entry produced a class.
                if model_config.model_impl != "transformers":
                    return None

                raise ValueError(
                    f"Cannot find model module. {architecture!r} is not a "
                    "registered model in the Transformers library (only "
                    "relevant if the model is meant to be in Transformers) "
                    "and 'AutoModel' is not present in the model config's "
                    "'auto_map' (relevant if the model is custom)."
                )

        if not model_module.is_backend_compatible():
            if model_config.model_impl != "transformers":
                return None

            raise ValueError(
                f"The Transformers implementation of {architecture!r} "
                "is not compatible with vLLM."
            )

        return model_config._get_transformers_backend_cls()

    def _normalize_arch(
        self,
        architecture: str,
        model_config: ModelConfig,
    ) -> str:
        """Map a converted-variant name to its registered base architecture.

        Returns `architecture` unchanged if it is registered itself, or if
        no registered base architecture can be derived from it.
        """
        if architecture in self.models:
            return architecture

        # This may be called in order to resolve runner_type and convert_type
        # in the first place, in which case we consider the default match
        match = try_match_architecture_defaults(
            architecture,
            runner_type=getattr(model_config, "runner_type", None),
            convert_type=getattr(model_config, "convert_type", None),
        )
        if match:
            suffix, _ = match

            # Get the name of the base model to convert
            for repl_suffix, _ in iter_architecture_defaults():
                base_arch = architecture.replace(suffix, repl_suffix)
                if base_arch in self.models:
                    return base_arch

        return architecture

    def _resolve_with(
        self,
        architectures: str | list[str],
        model_config: ModelConfig,
        try_arch: Callable[[str], Any],
    ) -> tuple[Any, str]:
        """Resolution strategy shared by `inspect_model_cls`/`resolve_model_cls`.

        `try_arch` maps a registered architecture name to a result (model
        info or model class), or None on failure. Candidates are tried in
        this order; the first non-None result is returned together with the
        architecture name it was obtained for:

        1. An explicitly requested implementation (`model_impl` equal to
           "transformers" or "terratorch").
        2. The Transformers backend, if none of `architectures` is
           registered, `model_impl` is "auto" and no conversion is requested.
        3. Each entry of `architectures` in order, normalized through
           `_normalize_arch` (the returned name is the un-normalized entry).
        4. The Transformers backend, if none of `architectures` is
           registered and `model_impl` is "auto".

        Raises:
            ValueError: If `architectures` is empty or nothing resolves.
        """
        if isinstance(architectures, str):
            architectures = [architectures]
        if not architectures:
            raise ValueError("No model architectures are specified")

        def try_transformers() -> tuple[Any, str] | None:
            # Only the first architecture is considered for Transformers.
            arch = self._try_resolve_transformers(architectures[0], model_config)
            if arch is not None:
                result = try_arch(arch)
                if result is not None:
                    return (result, arch)
            return None

        # Require transformers impl
        if model_config.model_impl == "transformers":
            found = try_transformers()
            if found is not None:
                return found
        elif model_config.model_impl == "terratorch":
            # Fall through (and eventually raise) instead of returning a
            # None result when Terratorch cannot be loaded/inspected.
            result = try_arch("Terratorch")
            if result is not None:
                return (result, "Terratorch")

        # Fallback to transformers impl (after resolving convert_type)
        if (
            all(arch not in self.models for arch in architectures)
            and model_config.model_impl == "auto"
            and getattr(model_config, "convert_type", "none") == "none"
        ):
            found = try_transformers()
            if found is not None:
                return found

        for arch in architectures:
            normalized_arch = self._normalize_arch(arch, model_config)
            result = try_arch(normalized_arch)
            if result is not None:
                return (result, arch)

        # Fallback to transformers impl (before resolving runner_type)
        if (
            all(arch not in self.models for arch in architectures)
            and model_config.model_impl == "auto"
        ):
            found = try_transformers()
            if found is not None:
                return found

        return self._raise_for_unsupported(architectures)

    def inspect_model_cls(
        self,
        architectures: str | list[str],
        model_config: ModelConfig,
    ) -> tuple[_ModelInfo, str]:
        """Return the `_ModelInfo` of the first resolvable architecture.

        The second element is the architecture name it was resolved for. See
        `_resolve_with` for the resolution order.

        Raises:
            ValueError: If no architecture can be resolved.
        """
        return self._resolve_with(
            architectures, model_config, self._try_inspect_model_cls
        )

    def resolve_model_cls(
        self,
        architectures: str | list[str],
        model_config: ModelConfig,
    ) -> tuple[type[nn.Module], str]:
        """Return the model class of the first resolvable architecture.

        The second element is the architecture name it was resolved for. See
        `_resolve_with` for the resolution order.

        Raises:
            ValueError: If no architecture can be resolved.
        """
        return self._resolve_with(
            architectures, model_config, self._try_load_model_cls
        )

    def is_text_generation_model(
        self,
        architectures: str | list[str],
        model_config: ModelConfig,
    ) -> bool:
        model_info, _ = self.inspect_model_cls(architectures, model_config)
        return model_info.is_text_generation_model

    def is_pooling_model(
        self,
        architectures: str | list[str],
        model_config: ModelConfig,
    ) -> bool:
        model_info, _ = self.inspect_model_cls(architectures, model_config)
        return model_info.is_pooling_model

    def is_cross_encoder_model(
        self,
        architectures: str | list[str],
        model_config: ModelConfig,
    ) -> bool:
        model_info, _ = self.inspect_model_cls(architectures, model_config)
        return model_info.supports_cross_encoding

    def is_multimodal_model(
        self,
        architectures: str | list[str],
        model_config: ModelConfig,
    ) -> bool:
        model_info, _ = self.inspect_model_cls(architectures, model_config)
        return model_info.supports_multimodal

    def is_multimodal_raw_input_only_model(
        self,
        architectures: str | list[str],
        model_config: ModelConfig,
    ) -> bool:
        model_info, _ = self.inspect_model_cls(architectures, model_config)
        return model_info.supports_multimodal_raw_input_only

    def is_pp_supported_model(
        self,
        architectures: str | list[str],
        model_config: ModelConfig,
    ) -> bool:
        model_info, _ = self.inspect_model_cls(architectures, model_config)
        return model_info.supports_pp

    def model_has_inner_state(
        self,
        architectures: str | list[str],
        model_config: ModelConfig,
    ) -> bool:
        model_info, _ = self.inspect_model_cls(architectures, model_config)
        return model_info.has_inner_state

    def is_attention_free_model(
        self,
        architectures: str | list[str],
        model_config: ModelConfig,
    ) -> bool:
        model_info, _ = self.inspect_model_cls(architectures, model_config)
        return model_info.is_attention_free

    def is_hybrid_model(
        self,
        architectures: str | list[str],
        model_config: ModelConfig,
    ) -> bool:
        model_info, _ = self.inspect_model_cls(architectures, model_config)
        return model_info.is_hybrid

    def is_noops_model(
        self,
        architectures: str | list[str],
        model_config: ModelConfig,
    ) -> bool:
        model_info, _ = self.inspect_model_cls(architectures, model_config)
        return model_info.has_noops

    def is_transcription_model(
        self,
        architectures: str | list[str],
        model_config: ModelConfig,
    ) -> bool:
        model_info, _ = self.inspect_model_cls(architectures, model_config)
        return model_info.supports_transcription

    def is_transcription_only_model(
        self,
        architectures: str | list[str],
        model_config: ModelConfig,
    ) -> bool:
        model_info, _ = self.inspect_model_cls(architectures, model_config)
        return model_info.supports_transcription_only

1222

1223
1224
1225
1226
1227
1228
1229
1230
1231
# Module-level registry, pre-populated with every in-tree architecture as a
# lazily imported entry.
ModelRegistry = _ModelRegistry(
    models={
        arch: _LazyRegisteredModel(
            module_name="vllm.model_executor.models." + relname,
            class_name=clsname,
        )
        for arch, (relname, clsname) in _VLLM_MODELS.items()
    }
)
1232
1233
1234
1235
1236

_T = TypeVar("_T")


def _run_in_subprocess(fn: Callable[[], _T]) -> _T:
    """Call `fn` in a fresh interpreter and return its result.

    `fn` and the output path are serialized with `cloudpickle` and fed to
    `_SUBPROCESS_COMMAND` on stdin; the child pickles `fn()` into the output
    file, which is read back here.

    Raises:
        RuntimeError: If the child exits with a non-zero status. The message
            contains the child's stderr.
    """
    # NOTE: We use a temporary directory instead of a temporary file to avoid
    # issues like https://stackoverflow.com/questions/23212435/permission-denied-to-write-to-my-temporary-file
    with tempfile.TemporaryDirectory() as tempdir:
        output_filepath = os.path.join(tempdir, "registry_output.tmp")

        # `cloudpickle` allows pickling lambda functions directly
        import cloudpickle

        input_bytes = cloudpickle.dumps((fn, output_filepath))

        # cannot use `sys.executable __file__` here because the script
        # contains relative imports
        returned = subprocess.run(
            _SUBPROCESS_COMMAND, input=input_bytes, capture_output=True
        )

        # check if the subprocess is successful
        try:
            returned.check_returncode()
        except subprocess.CalledProcessError as e:
            # Decode leniently: a crashing child may emit bytes that are not
            # valid UTF-8, and a UnicodeDecodeError here would hide the
            # original failure.
            stderr = returned.stderr.decode(errors="replace")
            raise RuntimeError(f"Error raised in subprocess:\n{stderr}") from e

        # Written by our own child process into a directory we created.
        with open(output_filepath, "rb") as f:
            return pickle.load(f)


def _run() -> None:
    """Entry point of the helper process started by `_run_in_subprocess`.

    Loads general plugins, reads a pickled `(fn, output_file)` pair from
    stdin, calls `fn()` and pickles its result into `output_file`.
    """
    # Setup plugins
    from vllm.plugins import load_general_plugins

    load_general_plugins()

    payload = sys.stdin.buffer.read()
    fn, output_file = pickle.loads(payload)

    # Compute before opening the file so a failing `fn` leaves no output.
    result = fn()
    with open(output_file, "wb") as out:
        pickle.dump(result, out)


if __name__ == "__main__":
    _run()