registry.py 50.4 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
4
5
6
"""
Whenever you add an architecture to this file, please also update
`tests/models/registry.py` with example HuggingFace models for it.
"""
7

8
import importlib
9
import json
10
import os
11
import pickle
12
13
import subprocess
import sys
14
import tempfile
15
from abc import ABC, abstractmethod
16
from collections.abc import Callable, Set
17
from dataclasses import asdict, dataclass, field
18
from functools import lru_cache
19
from pathlib import Path
20
from typing import TYPE_CHECKING, Any, TypeVar
21
22

import torch.nn as nn
23
import transformers
24

25
from vllm import envs
26
27
28
29
30
from vllm.config import (
    ModelConfig,
    iter_architecture_defaults,
    try_match_architecture_defaults,
)
31
from vllm.logger import init_logger
32
from vllm.logging_utils import logtime
33
from vllm.transformers_utils.dynamic_module import try_get_class_from_dynamic_module
34
from vllm.utils.hashing import safe_hash
35

36
37
if TYPE_CHECKING:
    from vllm.config.model import AttnTypeStr
38
    from vllm.config.pooler import SequencePoolingType, TokenPoolingType
39
40
else:
    AttnTypeStr = Any
41
42
    SequencePoolingType = Any
    TokenPoolingType = Any
43
44


45
46
47
48
49
from .interfaces import (
    has_inner_state,
    has_noops,
    is_attention_free,
    is_hybrid,
Patrick von Platen's avatar
Patrick von Platen committed
50
    requires_raw_input_tokens,
51
    supports_cross_encoding,
52
    supports_late_interaction,
53
    supports_mamba_prefix_caching,
54
55
56
57
58
59
60
    supports_multimodal,
    supports_multimodal_encoder_tp_data,
    supports_multimodal_raw_input_only,
    supports_pp,
    supports_transcription,
)
from .interfaces_base import (
61
    get_attn_type,
62
63
    get_default_seq_pooling_type,
    get_default_tok_pooling_type,
64
65
66
    is_pooling_model,
    is_text_generation_model,
)
67
68
69

logger = init_logger(__name__)

70
71
_TEXT_GENERATION_MODELS = {
    # [Decoder-only]
72
    "AfmoeForCausalLM": ("afmoe", "AfmoeForCausalLM"),
73
    "ApertusForCausalLM": ("apertus", "ApertusForCausalLM"),
74
75
    "AquilaModel": ("llama", "LlamaForCausalLM"),
    "AquilaForCausalLM": ("llama", "LlamaForCausalLM"),  # AquilaChat2
Raghav Ravishankar's avatar
Raghav Ravishankar committed
76
    "ArceeForCausalLM": ("arcee", "ArceeForCausalLM"),
77
    "ArcticForCausalLM": ("arctic", "ArcticForCausalLM"),
78
79
80
81
    # baichuan-7b, upper case 'C' in the class name
    "BaiChuanForCausalLM": ("baichuan", "BaiChuanForCausalLM"),
    # baichuan-13b, lower case 'c' in the class name
    "BaichuanForCausalLM": ("baichuan", "BaichuanForCausalLM"),
82
    "BailingMoeForCausalLM": ("bailing_moe", "BailingMoeForCausalLM"),
ant-yy's avatar
ant-yy committed
83
    "BailingMoeV2ForCausalLM": ("bailing_moe", "BailingMoeV2ForCausalLM"),
Jiangyun Zhu's avatar
Jiangyun Zhu committed
84
    "BailingMoeV2_5ForCausalLM": ("bailing_moe_linear", "BailingMoeV25ForCausalLM"),
Yu Chin Fabian Lim's avatar
Yu Chin Fabian Lim committed
85
    "BambaForCausalLM": ("bamba", "BambaForCausalLM"),
86
    "BloomForCausalLM": ("bloom", "BloomForCausalLM"),
87
    "ChatGLMModel": ("chatglm", "ChatGLMForCausalLM"),
88
    "ChatGLMForConditionalGeneration": ("chatglm", "ChatGLMForCausalLM"),
89
    "CohereForCausalLM": ("commandr", "CohereForCausalLM"),
90
    "Cohere2ForCausalLM": ("commandr", "CohereForCausalLM"),
91
    "CwmForCausalLM": ("llama", "LlamaForCausalLM"),
92
    "DbrxForCausalLM": ("dbrx", "DbrxForCausalLM"),
93
    "DeciLMForCausalLM": ("nemotron_nas", "DeciLMForCausalLM"),
94
    "DeepseekForCausalLM": ("deepseek_v2", "DeepseekForCausalLM"),
95
    "DeepseekV2ForCausalLM": ("deepseek_v2", "DeepseekV2ForCausalLM"),
96
    "DeepseekV3ForCausalLM": ("deepseek_v2", "DeepseekV3ForCausalLM"),
97
    "DeepseekV32ForCausalLM": ("deepseek_v2", "DeepseekV3ForCausalLM"),
98
    "Dots1ForCausalLM": ("dots1", "Dots1ForCausalLM"),
99
    "Ernie4_5ForCausalLM": ("ernie45", "Ernie4_5ForCausalLM"),
100
    "Ernie4_5_MoeForCausalLM": ("ernie45_moe", "Ernie4_5_MoeForCausalLM"),
101
    "ExaoneForCausalLM": ("exaone", "ExaoneForCausalLM"),
102
    "Exaone4ForCausalLM": ("exaone4", "Exaone4ForCausalLM"),
Kyungmin Lee's avatar
Kyungmin Lee committed
103
    "ExaoneMoEForCausalLM": ("exaone_moe", "ExaoneMoeForCausalLM"),
104
    "Fairseq2LlamaForCausalLM": ("fairseq2_llama", "Fairseq2LlamaForCausalLM"),
105
106
107
    "FalconForCausalLM": ("falcon", "FalconForCausalLM"),
    "FalconMambaForCausalLM": ("mamba", "MambaForCausalLM"),
    "FalconH1ForCausalLM": ("falcon_h1", "FalconH1ForCausalLM"),
108
    "FlexOlmoForCausalLM": ("flex_olmo", "FlexOlmoForCausalLM"),
109
110
    "GemmaForCausalLM": ("gemma", "GemmaForCausalLM"),
    "Gemma2ForCausalLM": ("gemma2", "Gemma2ForCausalLM"),
111
    "Gemma3ForCausalLM": ("gemma3", "Gemma3ForCausalLM"),
Nicolò Lucchesi's avatar
Nicolò Lucchesi committed
112
    "Gemma3nForCausalLM": ("gemma3n", "Gemma3nForCausalLM"),
113
    "Qwen3NextForCausalLM": ("qwen3_next", "Qwen3NextForCausalLM"),
114
    "GlmForCausalLM": ("glm", "GlmForCausalLM"),
Yuxuan Zhang's avatar
Yuxuan Zhang committed
115
    "Glm4ForCausalLM": ("glm4", "Glm4ForCausalLM"),
Yuxuan Zhang's avatar
Yuxuan Zhang committed
116
    "Glm4MoeForCausalLM": ("glm4_moe", "Glm4MoeForCausalLM"),
117
    "Glm4MoeLiteForCausalLM": ("glm4_moe_lite", "Glm4MoeLiteForCausalLM"),
Jee Jee Li's avatar
Jee Jee Li committed
118
    "GlmMoeDsaForCausalLM": ("deepseek_v2", "GlmMoeDsaForCausalLM"),
119
    "GptOssForCausalLM": ("gpt_oss", "GptOssForCausalLM"),
120
121
122
123
124
125
    "GPT2LMHeadModel": ("gpt2", "GPT2LMHeadModel"),
    "GPTBigCodeForCausalLM": ("gpt_bigcode", "GPTBigCodeForCausalLM"),
    "GPTJForCausalLM": ("gpt_j", "GPTJForCausalLM"),
    "GPTNeoXForCausalLM": ("gpt_neox", "GPTNeoXForCausalLM"),
    "GraniteForCausalLM": ("granite", "GraniteForCausalLM"),
    "GraniteMoeForCausalLM": ("granitemoe", "GraniteMoeForCausalLM"),
126
127
    "GraniteMoeHybridForCausalLM": ("granitemoehybrid", "GraniteMoeHybridForCausalLM"),  # noqa: E501
    "GraniteMoeSharedForCausalLM": ("granitemoeshared", "GraniteMoeSharedForCausalLM"),  # noqa: E501
128
    "GritLM": ("gritlm", "GritLM"),
Bijaya Dangol's avatar
Bijaya Dangol committed
129
130
    "Grok1ModelForCausalLM": ("grok1", "GrokForCausalLM"),
    "Grok1ForCausalLM": ("grok1", "GrokForCausalLM"),
131
132
    "HunYuanMoEV1ForCausalLM": ("hunyuan_v1", "HunYuanMoEV1ForCausalLM"),
    "HunYuanDenseV1ForCausalLM": ("hunyuan_v1", "HunYuanDenseV1ForCausalLM"),
133
    "HCXVisionForCausalLM": ("hyperclovax_vision", "HCXVisionForCausalLM"),
134
135
    "InternLMForCausalLM": ("llama", "LlamaForCausalLM"),
    "InternLM2ForCausalLM": ("internlm2", "InternLM2ForCausalLM"),
136
    "InternLM2VEForCausalLM": ("internlm2_ve", "InternLM2VEForCausalLM"),
137
    "InternLM3ForCausalLM": ("llama", "LlamaForCausalLM"),
138
139
    "IQuestCoderForCausalLM": ("llama", "LlamaForCausalLM"),
    "IQuestLoopCoderForCausalLM": ("iquest_loopcoder", "IQuestLoopCoderForCausalLM"),
140
    "JAISLMHeadModel": ("jais", "JAISLMHeadModel"),
141
    "Jais2ForCausalLM": ("jais2", "Jais2ForCausalLM"),
142
    "JambaForCausalLM": ("jamba", "JambaForCausalLM"),
143
    "KimiLinearForCausalLM": ("kimi_linear", "KimiLinearForCausalLM"),  # noqa: E501
144
    "Lfm2ForCausalLM": ("lfm2", "Lfm2ForCausalLM"),
Paul Pak's avatar
Paul Pak committed
145
    "Lfm2MoeForCausalLM": ("lfm2_moe", "Lfm2MoeForCausalLM"),
146
    "LlamaForCausalLM": ("llama", "LlamaForCausalLM"),
147
    "Llama4ForCausalLM": ("llama4", "Llama4ForCausalLM"),
148
149
    # For decapoda-research/llama-*
    "LLaMAForCausalLM": ("llama", "LlamaForCausalLM"),
XuruiYang's avatar
XuruiYang committed
150
    "LongcatFlashForCausalLM": ("longcat_flash", "LongcatFlashForCausalLM"),
151
    "MambaForCausalLM": ("mamba", "MambaForCausalLM"),
152
    "Mamba2ForCausalLM": ("mamba2", "Mamba2ForCausalLM"),
153
154
    "MiniCPMForCausalLM": ("minicpm", "MiniCPMForCausalLM"),
    "MiniCPM3ForCausalLM": ("minicpm3", "MiniCPM3ForCausalLM"),
155
156
157
    "MiniMaxForCausalLM": ("minimax_text_01", "MiniMaxText01ForCausalLM"),
    "MiniMaxText01ForCausalLM": ("minimax_text_01", "MiniMaxText01ForCausalLM"),
    "MiniMaxM1ForCausalLM": ("minimax_text_01", "MiniMaxText01ForCausalLM"),
158
    "MiniMaxM2ForCausalLM": ("minimax_m2", "MiniMaxM2ForCausalLM"),
159
    "MistralForCausalLM": ("mistral", "MistralForCausalLM"),
160
    "MistralLarge3ForCausalLM": ("mistral_large_3", "MistralLarge3ForCausalLM"),
161
162
163
164
    "MixtralForCausalLM": ("mixtral", "MixtralForCausalLM"),
    # transformers's mpt class has lower case
    "MptForCausalLM": ("mpt", "MPTForCausalLM"),
    "MPTForCausalLM": ("mpt", "MPTForCausalLM"),
165
    "MiMoForCausalLM": ("mimo", "MiMoForCausalLM"),
166
    "MiMoV2FlashForCausalLM": ("mimo_v2_flash", "MiMoV2FlashForCausalLM"),
167
    "NemotronForCausalLM": ("nemotron", "NemotronForCausalLM"),
Luis Vega's avatar
Luis Vega committed
168
    "NemotronHForCausalLM": ("nemotron_h", "NemotronHForCausalLM"),
169
    "NemotronHPuzzleForCausalLM": ("nemotron_h", "NemotronHForCausalLM"),
170
    "OlmoForCausalLM": ("olmo", "OlmoForCausalLM"),
171
    "Olmo2ForCausalLM": ("olmo2", "Olmo2ForCausalLM"),
172
    "Olmo3ForCausalLM": ("olmo2", "Olmo2ForCausalLM"),
173
174
175
    "OlmoeForCausalLM": ("olmoe", "OlmoeForCausalLM"),
    "OPTForCausalLM": ("opt", "OPTForCausalLM"),
    "OrionForCausalLM": ("orion", "OrionForCausalLM"),
176
    "OuroForCausalLM": ("ouro", "OuroForCausalLM"),
177
    "PanguEmbeddedForCausalLM": ("openpangu", "PanguEmbeddedForCausalLM"),
178
    "PanguProMoEV2ForCausalLM": ("openpangu", "PanguProMoEV2ForCausalLM"),
179
    "PanguUltraMoEForCausalLM": ("openpangu", "PanguUltraMoEForCausalLM"),
180
181
182
183
    "PersimmonForCausalLM": ("persimmon", "PersimmonForCausalLM"),
    "PhiForCausalLM": ("phi", "PhiForCausalLM"),
    "Phi3ForCausalLM": ("phi3", "Phi3ForCausalLM"),
    "PhiMoEForCausalLM": ("phimoe", "PhiMoEForCausalLM"),
Shinichi Hemmi's avatar
Shinichi Hemmi committed
184
    "Plamo2ForCausalLM": ("plamo2", "Plamo2ForCausalLM"),
185
    "Plamo3ForCausalLM": ("plamo3", "Plamo3ForCausalLM"),
186
    "QWenLMHeadModel": ("qwen", "QWenLMHeadModel"),
187
188
    "Qwen2ForCausalLM": ("qwen2", "Qwen2ForCausalLM"),
    "Qwen2MoeForCausalLM": ("qwen2_moe", "Qwen2MoeForCausalLM"),
189
190
    "Qwen3ForCausalLM": ("qwen3", "Qwen3ForCausalLM"),
    "Qwen3MoeForCausalLM": ("qwen3_moe", "Qwen3MoeForCausalLM"),
191
    "RWForCausalLM": ("falcon", "FalconForCausalLM"),
192
    "SeedOssForCausalLM": ("seed_oss", "SeedOssForCausalLM"),
Li Xie's avatar
Li Xie committed
193
    "Step1ForCausalLM": ("step1", "Step1ForCausalLM"),
Song's avatar
Song committed
194
    "Step3TextForCausalLM": ("step3_text", "Step3TextForCausalLM"),
csy0225's avatar
csy0225 committed
195
    "Step3p5ForCausalLM": ("step3p5", "Step3p5ForCausalLM"),
196
197
198
199
    "StableLMEpochForCausalLM": ("stablelm", "StablelmForCausalLM"),
    "StableLmForCausalLM": ("stablelm", "StablelmForCausalLM"),
    "Starcoder2ForCausalLM": ("starcoder2", "Starcoder2ForCausalLM"),
    "SolarForCausalLM": ("solar", "SolarForCausalLM"),
200
    "TeleChatForCausalLM": ("telechat2", "TeleChat2ForCausalLM"),
201
    "TeleChat2ForCausalLM": ("telechat2", "TeleChat2ForCausalLM"),
202
    "TeleFLMForCausalLM": ("teleflm", "TeleFLMForCausalLM"),
203
    "XverseForCausalLM": ("llama", "LlamaForCausalLM"),
204
    "Zamba2ForCausalLM": ("zamba2", "Zamba2ForCausalLM"),
205
206
207
}

# Embedding / pooling architectures.
# Maps an architecture name to a `(mod, arch)` pair of strings, in the same
# shape as `_TEXT_GENERATION_MODELS`. Every text-generation entry whose `arch`
# is "LlamaForCausalLM" is copied in as well.
_EMBEDDING_MODELS = {
    # [Text-only]
    "BertModel": ("bert", "BertEmbeddingModel"),
    "BertSpladeSparseEmbeddingModel": ("bert", "BertSpladeSparseEmbeddingModel"),
    "HF_ColBERT": ("colbert", "ColBERTModel"),
    "ColBERTModernBertModel": ("colbert", "ColBERTModernBertModel"),
    "ColBERTJinaRobertaModel": ("colbert", "ColBERTJinaRobertaModel"),
    "DeciLMForCausalLM": ("nemotron_nas", "DeciLMForCausalLM"),
    "Gemma2Model": ("gemma2", "Gemma2ForCausalLM"),
    "Gemma3TextModel": ("gemma3", "Gemma3Model"),
    "GlmForCausalLM": ("glm", "GlmForCausalLM"),
    "GPT2ForSequenceClassification": ("gpt2", "GPT2ForSequenceClassification"),
    "GritLM": ("gritlm", "GritLM"),
    "GteModel": ("bert_with_rope", "SnowflakeGteNewModel"),
    "GteNewModel": ("bert_with_rope", "GteNewModel"),
    "InternLM2ForRewardModel": ("internlm2", "InternLM2ForRewardModel"),
    "JambaForSequenceClassification": ("jamba", "JambaForSequenceClassification"),  # noqa: E501
    "LlamaBidirectionalModel": ("llama", "LlamaBidirectionalModel"),
    "LlamaModel": ("llama", "LlamaForCausalLM"),
    **{
        # Multiple models share the same architecture, so we include them all
        k: (mod, arch)
        for k, (mod, arch) in _TEXT_GENERATION_MODELS.items()
        if arch == "LlamaForCausalLM"
    },
    "MistralModel": ("llama", "LlamaForCausalLM"),
    "ModernBertModel": ("modernbert", "ModernBertModel"),
    "NomicBertModel": ("bert_with_rope", "NomicBertModel"),
    "Phi3ForCausalLM": ("phi3", "Phi3ForCausalLM"),
    "Qwen2Model": ("qwen2", "Qwen2ForCausalLM"),
    "Qwen2ForCausalLM": ("qwen2", "Qwen2ForCausalLM"),
    "Qwen2ForRewardModel": ("qwen2_rm", "Qwen2ForRewardModel"),
    "Qwen2ForProcessRewardModel": ("qwen2_rm", "Qwen2ForProcessRewardModel"),
    "RobertaForMaskedLM": ("roberta", "RobertaEmbeddingModel"),
    "RobertaModel": ("roberta", "RobertaEmbeddingModel"),
    "TeleChatForCausalLM": ("telechat2", "TeleChat2ForCausalLM"),
    "TeleChat2ForCausalLM": ("telechat2", "TeleChat2ForCausalLM"),
    "VoyageQwen3BidirectionalEmbedModel": (
        "voyage",
        "VoyageQwen3BidirectionalEmbedModel",
    ),
    "XLMRobertaModel": ("roberta", "RobertaEmbeddingModel"),
    "BgeM3EmbeddingModel": ("roberta", "BgeM3EmbeddingModel"),
    # [Multimodal]
    "CLIPModel": ("clip", "CLIPEmbeddingModel"),
    "ColModernVBertForRetrieval": ("colmodernvbert", "ColModernVBertForRetrieval"),
    "LlavaNextForConditionalGeneration": (
        "llava_next",
        "LlavaNextForConditionalGeneration",
    ),
    "Phi3VForCausalLM": ("phi3v", "Phi3VForCausalLM"),
    "Qwen2VLForConditionalGeneration": ("qwen2_vl", "Qwen2VLForConditionalGeneration"),  # noqa: E501
    "ColQwen3": ("colqwen3", "ColQwen3Model"),
    "OpsColQwen3Model": ("colqwen3", "ColQwen3Model"),
    "Qwen3VLNemotronEmbedModel": ("colqwen3", "ColQwen3Model"),
    "SiglipModel": ("siglip", "SiglipEmbeddingModel"),
    # Technically, Terratorch models work on images for both input and
    # output. They are listed here because they piggy-back on embedding
    # models for the time being.
    "PrithviGeoSpatialMAE": ("terratorch", "Terratorch"),
    "Terratorch": ("terratorch", "Terratorch"),
}

270
271
_CROSS_ENCODER_MODELS = {
    "BertForSequenceClassification": ("bert", "BertForSequenceClassification"),
272
    "BertForTokenClassification": ("bert", "BertForTokenClassification"),
273
274
275
276
    "GteNewForSequenceClassification": (
        "bert_with_rope",
        "GteNewForSequenceClassification",
    ),
277
278
279
280
281
    "JinaVLForRanking": ("jina_vl", "JinaVLForSequenceClassification"),
    "LlamaBidirectionalForSequenceClassification": (
        "llama",
        "LlamaBidirectionalForSequenceClassification",
    ),
282
283
284
285
    "ModernBertForSequenceClassification": (
        "modernbert",
        "ModernBertForSequenceClassification",
    ),
286
287
288
289
    "ModernBertForTokenClassification": (
        "modernbert",
        "ModernBertForTokenClassification",
    ),
290
291
292
293
294
    "RobertaForSequenceClassification": ("roberta", "RobertaForSequenceClassification"),
    "XLMRobertaForSequenceClassification": (
        "roberta",
        "RobertaForSequenceClassification",
    ),
295
296
}

297
_MULTIMODAL_MODELS = {
298
    # [Decoder-only]
299
    "AriaForConditionalGeneration": ("aria", "AriaForConditionalGeneration"),
300
301
302
303
    "AudioFlamingo3ForConditionalGeneration": (
        "audioflamingo3",
        "AudioFlamingo3ForConditionalGeneration",
    ),
304
305
306
307
    "MusicFlamingoForConditionalGeneration": (
        "musicflamingo",
        "MusicFlamingoForConditionalGeneration",
    ),
308
309
310
    "AyaVisionForConditionalGeneration": (
        "aya_vision",
        "AyaVisionForConditionalGeneration",
311
    ),
312
    "BagelForConditionalGeneration": ("bagel", "BagelForConditionalGeneration"),
313
    "BeeForConditionalGeneration": ("bee", "BeeForConditionalGeneration"),
314
    "Blip2ForConditionalGeneration": ("blip2", "Blip2ForConditionalGeneration"),
315
316
317
    "ChameleonForConditionalGeneration": (
        "chameleon",
        "ChameleonForConditionalGeneration",
318
    ),
319
320
321
    "Cohere2VisionForConditionalGeneration": (
        "cohere2_vision",
        "Cohere2VisionForConditionalGeneration",
322
    ),
323
    "DeepseekVLV2ForCausalLM": ("deepseek_vl2", "DeepseekVLV2ForCausalLM"),
324
    "DeepseekOCRForCausalLM": ("deepseek_ocr", "DeepseekOCRForCausalLM"),
RED's avatar
RED committed
325
    "DeepseekOCR2ForCausalLM": ("deepseek_ocr2", "DeepseekOCR2ForCausalLM"),
Roger Wang's avatar
Roger Wang committed
326
    "DotsOCRForCausalLM": ("dots_ocr", "DotsOCRForCausalLM"),
327
328
329
330
    "Eagle2_5_VLForConditionalGeneration": (
        "eagle2_5_vl",
        "Eagle2_5_VLForConditionalGeneration",
    ),
331
332
333
    "Ernie4_5_VLMoeForConditionalGeneration": (
        "ernie45_vl",
        "Ernie4_5_VLMoeForConditionalGeneration",
334
    ),
335
    "FunASRForConditionalGeneration": ("funasr", "FunASRForConditionalGeneration"),  # noqa: E501
336
337
338
339
    "FunAudioChatForConditionalGeneration": (
        "funaudiochat",
        "FunAudioChatForConditionalGeneration",
    ),
340
    "FuyuForCausalLM": ("fuyu", "FuyuForCausalLM"),
341
    "Gemma3ForConditionalGeneration": ("gemma3_mm", "Gemma3ForConditionalGeneration"),  # noqa: E501
342
343
344
    "Gemma3nForConditionalGeneration": (
        "gemma3n_mm",
        "Gemma3nForConditionalGeneration",
345
    ),
346
    "GlmAsrForConditionalGeneration": ("glmasr", "GlmAsrForConditionalGeneration"),
347
    "GLM4VForCausalLM": ("glm4v", "GLM4VForCausalLM"),
348
349
350
    "Glm4vForConditionalGeneration": ("glm4_1v", "Glm4vForConditionalGeneration"),
    "Glm4vMoeForConditionalGeneration": ("glm4_1v", "Glm4vMoeForConditionalGeneration"),
    "GlmOcrForConditionalGeneration": ("glm_ocr", "GlmOcrForConditionalGeneration"),  # noqa: E501
351
352
353
    "GraniteSpeechForConditionalGeneration": (
        "granite_speech",
        "GraniteSpeechForConditionalGeneration",
354
    ),
355
    "H2OVLChatModel": ("h2ovl", "H2OVLChatModel"),
356
357
358
359
    "HunYuanVLForConditionalGeneration": (
        "hunyuan_vision",
        "HunYuanVLForConditionalGeneration",
    ),
ltd0924's avatar
ltd0924 committed
360
    "StepVLForConditionalGeneration": ("step_vl", "StepVLForConditionalGeneration"),
361
    "InternVLChatModel": ("internvl", "InternVLChatModel"),
362
    "NemotronH_Nano_VL_V2": ("nano_nemotron_vl", "NemotronH_Nano_VL_V2"),
Zero's avatar
Zero committed
363
364
365
366
    "OpenCUAForConditionalGeneration": (
        "opencua",
        "OpenCUAForConditionalGeneration",
    ),
367
368
369
    "InternS1ForConditionalGeneration": (
        "interns1",
        "InternS1ForConditionalGeneration",
370
    ),
371
372
373
    "InternVLForConditionalGeneration": (
        "interns1",
        "InternS1ForConditionalGeneration",
374
    ),
zxy's avatar
zxy committed
375
376
377
378
    "InternS1ProForConditionalGeneration": (
        "interns1_pro",
        "InternS1ProForConditionalGeneration",
    ),
379
380
381
382
    "Idefics3ForConditionalGeneration": (
        "idefics3",
        "Idefics3ForConditionalGeneration",
    ),
oscardev256's avatar
oscardev256 committed
383
    "IsaacForConditionalGeneration": ("isaac", "IsaacForConditionalGeneration"),
384
    "SmolVLMForConditionalGeneration": ("smolvlm", "SmolVLMForConditionalGeneration"),  # noqa: E501
385
    "KananaVForConditionalGeneration": ("kanana_v", "KananaVForConditionalGeneration"),
386
    "KeyeForConditionalGeneration": ("keye", "KeyeForConditionalGeneration"),
387
388
389
    "KeyeVL1_5ForConditionalGeneration": (
        "keye_vl1_5",
        "KeyeVL1_5ForConditionalGeneration",
390
    ),
391
    "RForConditionalGeneration": ("rvl", "RForConditionalGeneration"),
392
    "KimiVLForConditionalGeneration": ("kimi_vl", "KimiVLForConditionalGeneration"),  # noqa: E501
Roger Wang's avatar
Roger Wang committed
393
    "KimiK25ForConditionalGeneration": ("kimi_k25", "KimiK25ForConditionalGeneration"),  # noqa: E501
394
395
396
397
    "LightOnOCRForConditionalGeneration": (
        "lightonocr",
        "LightOnOCRForConditionalGeneration",
    ),
398
    "Lfm2VlForConditionalGeneration": ("lfm2_vl", "Lfm2VLForConditionalGeneration"),
399
    "Llama_Nemotron_Nano_VL": ("nemotron_vl", "LlamaNemotronVLChatModel"),
400
    "Llama4ForConditionalGeneration": ("mllama4", "Llama4ForConditionalGeneration"),  # noqa: E501
401
    "LlavaForConditionalGeneration": ("llava", "LlavaForConditionalGeneration"),
402
403
404
    "LlavaNextForConditionalGeneration": (
        "llava_next",
        "LlavaNextForConditionalGeneration",
405
    ),
406
407
408
    "LlavaNextVideoForConditionalGeneration": (
        "llava_next_video",
        "LlavaNextVideoForConditionalGeneration",
409
    ),
410
411
412
    "LlavaOnevisionForConditionalGeneration": (
        "llava_onevision",
        "LlavaOnevisionForConditionalGeneration",
413
    ),
414
    "MantisForConditionalGeneration": ("llava", "MantisForConditionalGeneration"),  # noqa: E501
415
    "MiDashengLMModel": ("midashenglm", "MiDashengLMModel"),
416
417
418
    "MiniMaxVL01ForConditionalGeneration": (
        "minimax_vl_01",
        "MiniMaxVL01ForConditionalGeneration",
419
    ),
420
    "MiniCPMO": ("minicpmo", "MiniCPMO"),
421
    "MiniCPMV": ("minicpmv", "MiniCPMV"),
422
423
424
    "Mistral3ForConditionalGeneration": (
        "mistral3",
        "Mistral3ForConditionalGeneration",
425
    ),
426
    "MolmoForCausalLM": ("molmo", "MolmoForCausalLM"),
427
    "Molmo2ForConditionalGeneration": ("molmo2", "Molmo2ForConditionalGeneration"),
428
    "NVLM_D": ("nvlm_d", "NVLM_D_Model"),
429
430
431
432
    "OpenPanguVLForConditionalGeneration": (
        "openpangu_vl",
        "OpenPanguVLForConditionalGeneration",
    ),
433
    "Ovis": ("ovis", "Ovis"),
434
    "Ovis2_5": ("ovis2_5", "Ovis2_5"),
435
436
    "Ovis2_6ForCausalLM": ("ovis2_5", "Ovis2_5"),
    "Ovis2_6_MoeForCausalLM": ("ovis2_5", "Ovis2_5"),
437
438
439
440
    "PaddleOCRVLForConditionalGeneration": (
        "paddleocr_vl",
        "PaddleOCRVLForConditionalGeneration",
    ),
441
442
443
444
    "PaliGemmaForConditionalGeneration": (
        "paligemma",
        "PaliGemmaForConditionalGeneration",
    ),
445
    "Phi3VForCausalLM": ("phi3v", "Phi3VForCausalLM"),
446
    "Phi4MMForCausalLM": ("phi4mm", "Phi4MMForCausalLM"),
447
    "PixtralForConditionalGeneration": ("pixtral", "PixtralForConditionalGeneration"),  # noqa: E501
448
    "QwenVLForConditionalGeneration": ("qwen_vl", "QwenVLForConditionalGeneration"),  # noqa: E501
449
    "Qwen2VLForConditionalGeneration": ("qwen2_vl", "Qwen2VLForConditionalGeneration"),  # noqa: E501
450
451
452
    "Qwen2_5_VLForConditionalGeneration": (
        "qwen2_5_vl",
        "Qwen2_5_VLForConditionalGeneration",
453
    ),
454
455
456
    "Qwen2AudioForConditionalGeneration": (
        "qwen2_audio",
        "Qwen2AudioForConditionalGeneration",
457
    ),
458
459
460
    "Qwen2_5OmniModel": (
        "qwen2_5_omni_thinker",
        "Qwen2_5OmniThinkerForConditionalGeneration",
461
    ),
462
463
464
    "Qwen2_5OmniForConditionalGeneration": (
        "qwen2_5_omni_thinker",
        "Qwen2_5OmniThinkerForConditionalGeneration",
465
    ),
466
467
468
469
    "Qwen3OmniMoeForConditionalGeneration": (
        "qwen3_omni_moe_thinker",
        "Qwen3OmniMoeThinkerForConditionalGeneration",
    ),
Roger Wang's avatar
Roger Wang committed
470
471
472
473
    "Qwen3ASRForConditionalGeneration": (
        "qwen3_asr",
        "Qwen3ASRForConditionalGeneration",
    ),
474
475
476
477
    "Qwen3ASRRealtimeGeneration": (
        "qwen3_asr_realtime",
        "Qwen3ASRRealtimeGeneration",
    ),
478
    "Qwen3VLForConditionalGeneration": ("qwen3_vl", "Qwen3VLForConditionalGeneration"),  # noqa: E501
479
480
481
    "Qwen3VLMoeForConditionalGeneration": (
        "qwen3_vl_moe",
        "Qwen3VLMoeForConditionalGeneration",
482
    ),
483
484
485
486
487
488
489
490
    "Qwen3_5ForConditionalGeneration": (
        "qwen3_5",
        "Qwen3_5ForConditionalGeneration",
    ),
    "Qwen3_5MoeForConditionalGeneration": (
        "qwen3_5",
        "Qwen3_5MoeForConditionalGeneration",
    ),
491
    "SkyworkR1VChatModel": ("skyworkr1v", "SkyworkR1VChatModel"),
Song's avatar
Song committed
492
    "Step3VLForConditionalGeneration": ("step3_vl", "Step3VLForConditionalGeneration"),  # noqa: E501
汪志鹏's avatar
汪志鹏 committed
493
    "TarsierForConditionalGeneration": ("tarsier", "TarsierForConditionalGeneration"),  # noqa: E501
494
495
496
    "Tarsier2ForConditionalGeneration": (
        "qwen2_vl",
        "Tarsier2ForConditionalGeneration",
497
    ),
498
    "UltravoxModel": ("ultravox", "UltravoxModel"),
Patrick von Platen's avatar
Patrick von Platen committed
499
    "VoxtralForConditionalGeneration": ("voxtral", "VoxtralForConditionalGeneration"),  # noqa: E501
500
    "VoxtralRealtimeGeneration": ("voxtral_realtime", "VoxtralRealtimeGeneration"),  # noqa: E501
501
    # [Encoder-decoder]
502
503
504
505
    "NemotronParseForConditionalGeneration": (
        "nemotron_parse",
        "NemotronParseForConditionalGeneration",
    ),
506
    "WhisperForConditionalGeneration": ("whisper", "WhisperForConditionalGeneration"),  # noqa: E501
507
}
508
509

_SPECULATIVE_DECODING_MODELS = {
510
    "MiMoMTPModel": ("mimo_mtp", "MiMoMTP"),
511
    "EagleLlamaForCausalLM": ("llama_eagle", "EagleLlamaForCausalLM"),
zhiweiz's avatar
zhiweiz committed
512
    "EagleLlama4ForCausalLM": ("llama4_eagle", "EagleLlama4ForCausalLM"),
513
    "EagleMiniCPMForCausalLM": ("minicpm_eagle", "EagleMiniCPMForCausalLM"),
514
    "Eagle3LlamaForCausalLM": ("llama_eagle3", "Eagle3LlamaForCausalLM"),
515
    "LlamaForCausalLMEagle3": ("llama_eagle3", "Eagle3LlamaForCausalLM"),
516
    "Eagle3Qwen2_5vlForCausalLM": ("llama_eagle3", "Eagle3LlamaForCausalLM"),
517
    "Eagle3Qwen3vlForCausalLM": ("llama_eagle3", "Eagle3LlamaForCausalLM"),
518
519
520
521
    "EagleMistralLarge3ForCausalLM": (
        "mistral_large_3_eagle",
        "EagleMistralLarge3ForCausalLM",
    ),
522
    "EagleDeepSeekMTPModel": ("deepseek_eagle", "EagleDeepseekV3ForCausalLM"),
523
    "DeepSeekMTPModel": ("deepseek_mtp", "DeepSeekMTP"),
524
    "ErnieMTPModel": ("ernie_mtp", "ErnieMTP"),
Kyungmin Lee's avatar
Kyungmin Lee committed
525
    "ExaoneMoeMTP": ("exaone_moe_mtp", "ExaoneMoeMTP"),
526
    "NemotronHMTPModel": ("nemotron_h_mtp", "NemotronHMTP"),
XuruiYang's avatar
XuruiYang committed
527
    "LongCatFlashMTPModel": ("longcat_flash_mtp", "LongCatFlashMTP"),
Yuxuan Zhang's avatar
Yuxuan Zhang committed
528
    "Glm4MoeMTPModel": ("glm4_moe_mtp", "Glm4MoeMTP"),
529
    "Glm4MoeLiteMTPModel": ("glm4_moe_lite_mtp", "Glm4MoeLiteMTP"),
530
    "GlmOcrMTPModel": ("glm_ocr_mtp", "GlmOcrMTP"),
531
    "MedusaModel": ("medusa", "Medusa"),
532
    "OpenPanguMTPModel": ("openpangu_mtp", "OpenPanguMTP"),
533
    "Qwen3NextMTP": ("qwen3_next_mtp", "Qwen3NextMTP"),
csy0225's avatar
csy0225 committed
534
    "Step3p5MTP": ("step3p5_mtp", "Step3p5MTP"),
535
536
    "Qwen3_5MTP": ("qwen3_5_mtp", "Qwen3_5MTP"),
    "Qwen3_5MoeMTP": ("qwen3_5_mtp", "Qwen3_5MoeMTP"),
537
538
539
    # Temporarily disabled.
    # # TODO(woosuk): Re-enable this once the MLP Speculator is supported in V1.
    # "MLPSpeculatorPreTrainedModel": ("mlp_speculator", "MLPSpeculator"),
540
}
541

542
_TRANSFORMERS_SUPPORTED_MODELS = {
543
544
545
    # Text generation models
    "SmolLM3ForCausalLM": ("transformers", "TransformersForCausalLM"),
    # Multimodal models
546
547
548
549
    "Emu3ForConditionalGeneration": (
        "transformers",
        "TransformersMultiModalForCausalLM",
    ),
550
551
552
}

_TRANSFORMERS_BACKEND_MODELS = {
553
    # Text generation models
554
    "TransformersForCausalLM": ("transformers", "TransformersForCausalLM"),
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
    "TransformersMoEForCausalLM": ("transformers", "TransformersMoEForCausalLM"),
    # Multimodal models
    "TransformersMultiModalForCausalLM": (
        "transformers",
        "TransformersMultiModalForCausalLM",
    ),
    "TransformersMultiModalMoEForCausalLM": (
        "transformers",
        "TransformersMultiModalMoEForCausalLM",
    ),
    # Embedding models
    "TransformersEmbeddingModel": ("transformers", "TransformersEmbeddingModel"),
    "TransformersMoEEmbeddingModel": ("transformers", "TransformersMoEEmbeddingModel"),
    "TransformersMultiModalEmbeddingModel": (
        "transformers",
        "TransformersMultiModalEmbeddingModel",
    ),
    # Sequence classification models
573
    "TransformersForSequenceClassification": (
574
        "transformers",
575
        "TransformersForSequenceClassification",
576
    ),
577
    "TransformersMoEForSequenceClassification": (
578
        "transformers",
579
        "TransformersMoEForSequenceClassification",
580
    ),
581
582
583
    "TransformersMultiModalForSequenceClassification": (
        "transformers",
        "TransformersMultiModalForSequenceClassification",
584
    ),
585
}
586

587
# Union of all the registry tables above. The tables are unpacked in the
# order listed, so for an architecture name that appears in more than one
# table the value from the later table wins.
_VLLM_MODELS = {
    **_TEXT_GENERATION_MODELS,
    **_EMBEDDING_MODELS,
    **_CROSS_ENCODER_MODELS,
    **_MULTIMODAL_MODELS,
    **_SPECULATIVE_DECODING_MODELS,
    **_TRANSFORMERS_SUPPORTED_MODELS,
    **_TRANSFORMERS_BACKEND_MODELS,
}

597
598
599
600
# This variable is used as the args for subprocess.run(). We
# can modify  this variable to alter the args if needed. e.g.
# when we use par format to pack things together, sys.executable
# might not be the target we want to run.
601
_SUBPROCESS_COMMAND = [sys.executable, "-m", "vllm.model_executor.models.registry"]
602

603
_PREVIOUSLY_SUPPORTED_MODELS = {
604
    "MotifForCausalLM": "0.10.2",
605
    "Phi3SmallForCausalLM": "0.9.2",
606
    "Phi4FlashForCausalLM": "0.10.2",
607
    "Phi4MultimodalForCausalLM": "0.12.0",
608
609
610
611
612
613
614
615
616
    # encoder-decoder models except whisper
    # have been removed for V0 deprecation.
    "BartModel": "0.10.2",
    "BartForConditionalGeneration": "0.10.2",
    "DonutForConditionalGeneration": "0.10.2",
    "Florence2ForConditionalGeneration": "0.10.2",
    "MBartForConditionalGeneration": "0.10.2",
    "MllamaForConditionalGeneration": "0.10.2",
}
617

618

619
620
@dataclass(frozen=True)
class _ModelInfo:
    """
    Immutable summary of what a model class supports.

    Every field except `architecture` is filled in by `from_model_cls` from
    the result of the identically named interface helper (or, for the
    `attn_type` / `default_*_pooling_type` fields, the matching `get_*`
    helper).
    """

    architecture: str
    is_text_generation_model: bool
    is_pooling_model: bool
    attn_type: AttnTypeStr
    default_seq_pooling_type: SequencePoolingType
    default_tok_pooling_type: TokenPoolingType
    supports_cross_encoding: bool
    supports_late_interaction: bool
    supports_multimodal: bool
    supports_multimodal_raw_input_only: bool
    requires_raw_input_tokens: bool
    supports_multimodal_encoder_tp_data: bool
    supports_pp: bool
    has_inner_state: bool
    is_attention_free: bool
    is_hybrid: bool
    has_noops: bool
    supports_mamba_prefix_caching: bool
    supports_transcription: bool
    supports_transcription_only: bool

    @staticmethod
    def from_model_cls(model: type[nn.Module]) -> "_ModelInfo":
        """
        Build the info record for `model` by querying each interface helper.

        `architecture` is the class's `__name__`.
        """
        # Evaluated once: it feeds both transcription fields below.
        transcribes = supports_transcription(model)

        return _ModelInfo(
            architecture=model.__name__,
            is_text_generation_model=is_text_generation_model(model),
            is_pooling_model=is_pooling_model(model),
            default_seq_pooling_type=get_default_seq_pooling_type(model),
            default_tok_pooling_type=get_default_tok_pooling_type(model),
            attn_type=get_attn_type(model),
            supports_cross_encoding=supports_cross_encoding(model),
            supports_late_interaction=supports_late_interaction(model),
            supports_multimodal=supports_multimodal(model),
            supports_multimodal_raw_input_only=supports_multimodal_raw_input_only(
                model
            ),
            requires_raw_input_tokens=requires_raw_input_tokens(model),
            supports_multimodal_encoder_tp_data=supports_multimodal_encoder_tp_data(
                model
            ),
            supports_pp=supports_pp(model),
            has_inner_state=has_inner_state(model),
            is_attention_free=is_attention_free(model),
            is_hybrid=is_hybrid(model),
            supports_mamba_prefix_caching=supports_mamba_prefix_caching(model),
            supports_transcription=transcribes,
            # The class attribute is only read when transcription is supported.
            supports_transcription_only=(
                transcribes and model.supports_transcription_only
            ),
            has_noops=has_noops(model),
        )
672
673


674
675
676
677
class _BaseRegisteredModel(ABC):
    """Interface every registry entry implements."""

    @abstractmethod
    def inspect_model_cls(self) -> _ModelInfo:
        """Return the `_ModelInfo` describing the registered model class."""
        raise NotImplementedError

    @abstractmethod
    def load_model_cls(self) -> type[nn.Module]:
        """Return the registered model class itself."""
        raise NotImplementedError
682
683


684
685
686
687
688
689
690
@dataclass(frozen=True)
class _RegisteredModel(_BaseRegisteredModel):
    """
    Represents a model that has already been imported in the main process.
    """

    # Capability summary computed from `model_cls` when the entry is created.
    interfaces: _ModelInfo
    # The already-imported model class.
    model_cls: type[nn.Module]

    @staticmethod
    def from_model_cls(model_cls: type[nn.Module]) -> "_RegisteredModel":
        """Create an entry for `model_cls`, inspecting it immediately."""
        return _RegisteredModel(
            interfaces=_ModelInfo.from_model_cls(model_cls),
            model_cls=model_cls,
        )

    def inspect_model_cls(self) -> _ModelInfo:
        """Return the info computed when the entry was created."""
        return self.interfaces

    def load_model_cls(self) -> type[nn.Module]:
        """Return the stored model class."""
        return self.model_cls


@dataclass(frozen=True)
class _LazyRegisteredModel(_BaseRegisteredModel):
    """
    Represents a model that has not been imported in the main process.
    """

    # Importable module path and the name of the model class inside it.
    module_name: str
    class_name: str

    @staticmethod
    def _get_cache_dir() -> Path:
        """Return the directory that holds the cached model-info JSON files."""
        return Path(envs.VLLM_CACHE_ROOT) / "modelinfos"

    def _get_cache_filename(self) -> str:
        """Return the cache file name: `<module>-<class>.json`, dots as dashes."""
        cls_name = f"{self.module_name}-{self.class_name}".replace(".", "-")
        return f"{cls_name}.json"

    def _load_modelinfo_from_cache(self, module_hash: str) -> _ModelInfo | None:
        """
        Return the cached `_ModelInfo` if it was saved for `module_hash`.

        Returns `None` when the cache file is missing, was written for a
        different hash, or cannot be read or converted back into a
        `_ModelInfo`. This method never raises; every failure is logged at
        debug level.
        """
        try:
            modelinfo_path = self._get_cache_dir() / self._get_cache_filename()
            with open(modelinfo_path, encoding="utf-8") as file:
                mi_dict = json.load(file)

            if mi_dict["hash"] != module_hash:
                logger.debug(
                    "Cached model info file for class %s.%s is stale",
                    self.module_name,
                    self.class_name,
                )
                return None

            # file not changed, use cached _ModelInfo properties
            return _ModelInfo(**mi_dict["modelinfo"])
        except FileNotFoundError:
            logger.debug(
                "Cached model info file for class %s.%s not found",
                self.module_name,
                self.class_name,
            )
            return None
        except Exception:
            # Best effort: a corrupt or outdated cache entry only costs a
            # re-inspection. Keep the traceback so the cause is diagnosable.
            logger.debug(
                "Cached model info for class %s.%s error. ",
                self.module_name,
                self.class_name,
                exc_info=True,
            )
            return None

    def _save_modelinfo_to_cache(self, mi: _ModelInfo, module_hash: str) -> None:
        """
        Write `mi` and `module_hash` to this entry's JSON cache file.

        Failures are logged and swallowed; the cache is an optimisation only.
        """
        from vllm.model_executor.model_loader.weight_utils import atomic_writer

        try:
            modelinfo_dict = {
                "hash": module_hash,
                "modelinfo": asdict(mi),
            }
            cache_dir = self._get_cache_dir()
            cache_dir.mkdir(parents=True, exist_ok=True)
            modelinfo_path = cache_dir / self._get_cache_filename()
            with atomic_writer(modelinfo_path, encoding="utf-8") as f:
                json.dump(modelinfo_dict, f, indent=2)
        except Exception:
            logger.exception("Error saving model info cache.")

    @logtime(logger=logger, msg="Registry inspect model class")
    def inspect_model_cls(self) -> _ModelInfo:
        """
        Return the `_ModelInfo` of the model class without importing it here.

        The module's source file is looked up next to this file, using the
        last component of `module_name`. If it exists, its hash keys a JSON
        cache, and a matching cache entry is returned directly. Otherwise
        the class is loaded and inspected in a subprocess, and the result is
        cached when a hash is available.
        """
        # NOTE(review): only the last component of `module_name` is used, so
        # a module outside this package would be matched against a same-named
        # file here if one exists -- confirm this is acceptable.
        model_path = Path(__file__).parent / f"{self.module_name.split('.')[-1]}.py"
        module_hash = None

        if model_path.exists():
            with open(model_path, "rb") as f:
                module_hash = safe_hash(f.read(), usedforsecurity=False).hexdigest()

            mi = self._load_modelinfo_from_cache(module_hash)
            if mi is not None:
                logger.debug(
                    "Loaded model info for class %s.%s from cache",
                    self.module_name,
                    self.class_name,
                )
                return mi
            else:
                logger.debug(
                    "Cache model info for class %s.%s miss. Loading model instead.",
                    self.module_name,
                    self.class_name,
                )

        # Performed in another process to avoid initializing CUDA
        mi = _run_in_subprocess(
            lambda: _ModelInfo.from_model_cls(self.load_model_cls())
        )
        logger.debug(
            "Loaded model info for class %s.%s", self.module_name, self.class_name
        )

        # save cache file
        if module_hash is not None:
            self._save_modelinfo_to_cache(mi, module_hash)

        return mi

    def load_model_cls(self) -> type[nn.Module]:
        """Import `module_name` and return its `class_name` attribute."""
        mod = importlib.import_module(self.module_name)
        return getattr(mod, self.class_name)


@lru_cache(maxsize=128)
def _try_load_model_cls(
    model_arch: str,
    model: _BaseRegisteredModel,
) -> type[nn.Module] | None:
    """
    Load the model class behind `model`, or return `None` if loading fails.

    The platform's `verify_model_arch` check runs outside the `try`, so any
    error it raises propagates to the caller. Loading errors are logged with
    their traceback. Results, including `None`, are memoised by `lru_cache`
    per `(model_arch, model)` pair.
    """
    from vllm.platforms import current_platform

    current_platform.verify_model_arch(model_arch)
    try:
        return model.load_model_cls()
    except Exception:
        logger.exception("Error in loading model architecture '%s'", model_arch)
        return None
829
830


831
832
833
834
@lru_cache(maxsize=128)
def _try_inspect_model_cls(
    model_arch: str,
    model: _BaseRegisteredModel,
) -> _ModelInfo | None:
    """
    Inspect the model class behind `model`, or return `None` on failure.

    Inspection errors are logged with their traceback. Results, including
    `None`, are memoised by `lru_cache` per `(model_arch, model)` pair.
    """
    try:
        return model.inspect_model_cls()
    except Exception:
        logger.exception("Error in inspecting model architecture '%s'", model_arch)
        return None
841
842


843
844
845
@dataclass
class _ModelRegistry:
    """
    Registry that maps architecture names to model entries and resolves the
    architectures of a model config to a model class or its `_ModelInfo`.
    """

    # Keyed by model_arch
    models: dict[str, _BaseRegisteredModel] = field(default_factory=dict)

    def get_supported_archs(self) -> Set[str]:
        """Return the key view of all registered architecture names."""
        return self.models.keys()

    def register_model(
        self,
        model_arch: str,
        model_cls: type[nn.Module] | str,
    ) -> None:
        """
        Register an external model to be used in vLLM.

        `model_cls` can be either:

        - A [`torch.nn.Module`][] class directly referencing the model.
        - A string in the format `<module>:<class>` which can be used to
          lazily import the model. This is useful to avoid initializing CUDA
          when importing the model and thus the related error
          `RuntimeError: Cannot re-initialize CUDA in forked subprocess`.

        Raises:
            TypeError: If `model_arch` is not a string, or `model_cls` is
                neither a string nor an `nn.Module` subclass.
            ValueError: If a string `model_cls` is not `<module>:<class>`.
        """
        if not isinstance(model_arch, str):
            msg = f"`model_arch` should be a string, not a {type(model_arch)}"
            raise TypeError(msg)

        if model_arch in self.models:
            logger.warning(
                "Model architecture %s is already registered, and will be "
                "overwritten by the new model class %s.",
                model_arch,
                model_cls,
            )

        if isinstance(model_cls, str):
            split_str = model_cls.split(":")
            if len(split_str) != 2:
                msg = "Expected a string in the format `<module>:<class>`"
                raise ValueError(msg)

            model = _LazyRegisteredModel(*split_str)
        elif isinstance(model_cls, type) and issubclass(model_cls, nn.Module):
            model = _RegisteredModel.from_model_cls(model_cls)
        else:
            # Report the type of the offending argument (`model_cls`).
            msg = (
                "`model_cls` should be a string or PyTorch model class, "
                f"not a {type(model_cls)}"
            )
            raise TypeError(msg)

        self.models[model_arch] = model

    def _raise_for_unsupported(self, architectures: list[str]):
        """
        Always raise a `ValueError` explaining why resolution failed.

        The message distinguishes three cases: an architecture is registered
        but could not be inspected/loaded, an architecture was supported by
        an earlier vLLM version, or the architectures are simply unknown.
        """
        all_supported_archs = self.get_supported_archs()

        if any(arch in all_supported_archs for arch in architectures):
            raise ValueError(
                f"Model architectures {architectures} failed "
                "to be inspected. Please check the logs for more details."
            )

        for arch in architectures:
            if arch in _PREVIOUSLY_SUPPORTED_MODELS:
                previous_version = _PREVIOUSLY_SUPPORTED_MODELS[arch]

                raise ValueError(
                    f"Model architecture {arch} was supported in vLLM until "
                    f"v{previous_version}, and is not supported anymore. "
                    "Please use an older version of vLLM if you want to "
                    "use this model architecture."
                )

        raise ValueError(
            f"Model architectures {architectures} are not supported for now. "
            f"Supported architectures: {all_supported_archs}"
        )

    def _try_load_model_cls(self, model_arch: str) -> type[nn.Module] | None:
        """Load the class registered for `model_arch`; `None` if unavailable."""
        if model_arch not in self.models:
            return None

        # Delegates to the cached module-level function of the same name.
        return _try_load_model_cls(model_arch, self.models[model_arch])

    def _try_inspect_model_cls(self, model_arch: str) -> _ModelInfo | None:
        """Inspect the class registered for `model_arch`; `None` if unavailable."""
        if model_arch not in self.models:
            return None

        # Delegates to the cached module-level function of the same name.
        return _try_inspect_model_cls(model_arch, self.models[model_arch])

    def _try_resolve_transformers(
        self,
        architecture: str,
        model_config: ModelConfig,
    ) -> str | None:
        """
        Resolve `architecture` to an architecture backed by Transformers.

        Returns `architecture` unchanged if it already names a Transformers
        backend model, otherwise the value of
        `model_config._get_transformers_backend_cls()`. Returns `None` when
        no compatible Transformers implementation is found and
        `model_config.model_impl` is not `"transformers"`.

        Raises:
            ValueError: If `model_impl` is `"transformers"` and the model
                module cannot be found or is not backend compatible.
        """
        if architecture in _TRANSFORMERS_BACKEND_MODELS:
            return architecture

        auto_map: dict[str, str] = (
            getattr(model_config.hf_config, "auto_map", None) or dict()
        )

        # Make sure that config class is always initialized before model class,
        # otherwise the model class won't be able to access the config class,
        # the expected auto_map should have correct order like:
        # "auto_map": {
        #     "AutoConfig": "<your-repo-name>--<config-name>",
        #     "AutoModel": "<your-repo-name>--<config-name>",
        #     "AutoModelFor<Task>": "<your-repo-name>--<config-name>",
        # },
        for prefix in ("AutoConfig", "AutoModel"):
            for name, module in auto_map.items():
                if name.startswith(prefix):
                    try_get_class_from_dynamic_module(
                        module,
                        model_config.model,
                        revision=model_config.revision,
                        trust_remote_code=model_config.trust_remote_code,
                        warn_on_fail=False,
                    )

        model_module = getattr(transformers, architecture, None)

        if model_module is None:
            for name, module in auto_map.items():
                if name.startswith("AutoModel"):
                    model_module = try_get_class_from_dynamic_module(
                        module,
                        model_config.model,
                        revision=model_config.revision,
                        trust_remote_code=model_config.trust_remote_code,
                        warn_on_fail=True,
                    )
                    if model_module is not None:
                        break
            else:
                # Reached only when no `AutoModel*` entry produced a class.
                if model_config.model_impl != "transformers":
                    return None

                raise ValueError(
                    f"Cannot find model module. {architecture!r} is not a "
                    "registered model in the Transformers library (only "
                    "relevant if the model is meant to be in Transformers) "
                    "and 'AutoModel' is not present in the model config's "
                    "'auto_map' (relevant if the model is custom)."
                )

        if not model_module.is_backend_compatible():
            if model_config.model_impl != "transformers":
                return None

            raise ValueError(
                f"The Transformers implementation of {architecture!r} "
                "is not compatible with vLLM."
            )

        return model_config._get_transformers_backend_cls()

    def _normalize_arch(
        self,
        architecture: str,
        model_config: ModelConfig,
    ) -> str:
        """
        Map `architecture` to a registered base architecture when possible.

        A registered name is returned unchanged. Otherwise, if the name
        matches an architecture default, each known suffix is tried in place
        of the matched one and the first registered result is returned.
        Falls back to `architecture` itself.
        """
        if architecture in self.models:
            return architecture

        # This may be called in order to resolve runner_type and convert_type
        # in the first place, in which case we consider the default match
        match = try_match_architecture_defaults(
            architecture,
            runner_type=getattr(model_config, "runner_type", None),
            convert_type=getattr(model_config, "convert_type", None),
        )
        if match:
            suffix, _ = match

            # Get the name of the base model to convert
            for repl_suffix, _ in iter_architecture_defaults():
                base_arch = architecture.replace(suffix, repl_suffix)
                if base_arch in self.models:
                    return base_arch

        return architecture

    def _resolve_with(
        self,
        architectures: str | list[str],
        model_config: ModelConfig,
        try_fn: Callable[[str], Any],
    ) -> tuple[Any, str]:
        """
        Shared resolution order behind `inspect_model_cls` and
        `resolve_model_cls`.

        `try_fn` maps an architecture name to a result, or `None` on failure.
        The first non-`None` result is returned together with the
        architecture it was resolved for.

        Raises:
            ValueError: If `architectures` is empty or nothing resolves.
        """
        if isinstance(architectures, str):
            architectures = [architectures]
        if not architectures:
            raise ValueError("No model architectures are specified")

        def try_transformers() -> tuple[Any, str] | None:
            # Only the first architecture is considered for Transformers.
            arch = self._try_resolve_transformers(architectures[0], model_config)
            if arch is None:
                return None
            resolved = try_fn(arch)
            return None if resolved is None else (resolved, arch)

        def none_registered() -> bool:
            return all(arch not in self.models for arch in architectures)

        # Require transformers impl
        if model_config.model_impl == "transformers":
            found = try_transformers()
            if found is not None:
                return found
        elif model_config.model_impl == "terratorch":
            resolved = try_fn("Terratorch")
            if resolved is not None:
                return (resolved, "Terratorch")

        # Fallback to transformers impl (after resolving convert_type)
        if (
            none_registered()
            and model_config.model_impl == "auto"
            and getattr(model_config, "convert_type", "none") == "none"
        ):
            found = try_transformers()
            if found is not None:
                return found

        for arch in architectures:
            normalized_arch = self._normalize_arch(arch, model_config)
            resolved = try_fn(normalized_arch)
            if resolved is not None:
                # The caller gets the original name, not the normalized one.
                return (resolved, arch)

        # Fallback to transformers impl (before resolving runner_type)
        if none_registered() and model_config.model_impl == "auto":
            found = try_transformers()
            if found is not None:
                return found

        return self._raise_for_unsupported(architectures)

    def inspect_model_cls(
        self,
        architectures: str | list[str],
        model_config: ModelConfig,
    ) -> tuple[_ModelInfo, str]:
        """
        Return the `_ModelInfo` and architecture name for the first of
        `architectures` that can be inspected.

        Raises:
            ValueError: If `architectures` is empty or none can be inspected.
        """
        return self._resolve_with(
            architectures, model_config, self._try_inspect_model_cls
        )

    def resolve_model_cls(
        self,
        architectures: str | list[str],
        model_config: ModelConfig,
    ) -> tuple[type[nn.Module], str]:
        """
        Return the model class and architecture name for the first of
        `architectures` that can be loaded.

        Raises:
            ValueError: If `architectures` is empty or none can be loaded.
        """
        return self._resolve_with(
            architectures, model_config, self._try_load_model_cls
        )

    def is_text_generation_model(
        self,
        architectures: str | list[str],
        model_config: ModelConfig,
    ) -> bool:
        """Return the resolved model's `is_text_generation_model` flag."""
        model_info, _ = self.inspect_model_cls(architectures, model_config)
        return model_info.is_text_generation_model

    def is_pooling_model(
        self,
        architectures: str | list[str],
        model_config: ModelConfig,
    ) -> bool:
        """Return the resolved model's `is_pooling_model` flag."""
        model_info, _ = self.inspect_model_cls(architectures, model_config)
        return model_info.is_pooling_model

    def is_cross_encoder_model(
        self,
        architectures: str | list[str],
        model_config: ModelConfig,
    ) -> bool:
        """Return the resolved model's `supports_cross_encoding` flag."""
        model_info, _ = self.inspect_model_cls(architectures, model_config)
        return model_info.supports_cross_encoding

    def is_multimodal_model(
        self,
        architectures: str | list[str],
        model_config: ModelConfig,
    ) -> bool:
        """Return the resolved model's `supports_multimodal` flag."""
        model_info, _ = self.inspect_model_cls(architectures, model_config)
        return model_info.supports_multimodal

    def is_multimodal_raw_input_only_model(
        self,
        architectures: str | list[str],
        model_config: ModelConfig,
    ) -> bool:
        """Return the resolved model's `supports_multimodal_raw_input_only` flag."""
        model_info, _ = self.inspect_model_cls(architectures, model_config)
        return model_info.supports_multimodal_raw_input_only

    def is_pp_supported_model(
        self,
        architectures: str | list[str],
        model_config: ModelConfig,
    ) -> bool:
        """Return the resolved model's `supports_pp` flag."""
        model_info, _ = self.inspect_model_cls(architectures, model_config)
        return model_info.supports_pp

    def model_has_inner_state(
        self,
        architectures: str | list[str],
        model_config: ModelConfig,
    ) -> bool:
        """Return the resolved model's `has_inner_state` flag."""
        model_info, _ = self.inspect_model_cls(architectures, model_config)
        return model_info.has_inner_state

    def is_attention_free_model(
        self,
        architectures: str | list[str],
        model_config: ModelConfig,
    ) -> bool:
        """Return the resolved model's `is_attention_free` flag."""
        model_info, _ = self.inspect_model_cls(architectures, model_config)
        return model_info.is_attention_free

    def is_hybrid_model(
        self,
        architectures: str | list[str],
        model_config: ModelConfig,
    ) -> bool:
        """Return the resolved model's `is_hybrid` flag."""
        model_info, _ = self.inspect_model_cls(architectures, model_config)
        return model_info.is_hybrid

    def is_noops_model(
        self,
        architectures: str | list[str],
        model_config: ModelConfig,
    ) -> bool:
        """Return the resolved model's `has_noops` flag."""
        model_info, _ = self.inspect_model_cls(architectures, model_config)
        return model_info.has_noops

    def is_transcription_model(
        self,
        architectures: str | list[str],
        model_config: ModelConfig,
    ) -> bool:
        """Return the resolved model's `supports_transcription` flag."""
        model_info, _ = self.inspect_model_cls(architectures, model_config)
        return model_info.supports_transcription

    def is_transcription_only_model(
        self,
        architectures: str | list[str],
        model_config: ModelConfig,
    ) -> bool:
        """Return the resolved model's `supports_transcription_only` flag."""
        model_info, _ = self.inspect_model_cls(architectures, model_config)
        return model_info.supports_transcription_only

1230

1231
1232
1233
1234
1235
1236
1237
1238
1239
# The global registry: one lazily imported entry per built-in architecture,
# so no model module is imported when this module is loaded.
ModelRegistry = _ModelRegistry(
    {
        arch_name: _LazyRegisteredModel(
            module_name="vllm.model_executor.models." + module_stem,
            class_name=model_cls_name,
        )
        for arch_name, (module_stem, model_cls_name) in _VLLM_MODELS.items()
    }
)
1240
1241
1242
1243
1244

# Return type of the callable passed to `_run_in_subprocess`.
_T = TypeVar("_T")


def _run_in_subprocess(fn: Callable[[], _T]) -> _T:
    """
    Run `fn` in a separate Python process and return its result.

    `fn` and an output path are sent to the command in `_SUBPROCESS_COMMAND`
    on stdin (serialised with `cloudpickle`). The child writes the pickled
    result to the output path, which is read back here.

    Raises:
        RuntimeError: If the subprocess exits with a non-zero status; the
            message contains the subprocess's stderr.
    """
    # NOTE: We use a temporary directory instead of a temporary file to avoid
    # issues like https://stackoverflow.com/questions/23212435/permission-denied-to-write-to-my-temporary-file
    with tempfile.TemporaryDirectory() as tempdir:
        output_filepath = os.path.join(tempdir, "registry_output.tmp")

        # `cloudpickle` allows pickling lambda functions directly
        import cloudpickle

        input_bytes = cloudpickle.dumps((fn, output_filepath))

        # cannot use `sys.executable __file__` here because the script
        # contains relative imports
        returned = subprocess.run(
            _SUBPROCESS_COMMAND, input=input_bytes, capture_output=True
        )

        # check if the subprocess is successful
        try:
            returned.check_returncode()
        except subprocess.CalledProcessError as e:
            # wrap raised exception to provide more information;
            # undecodable bytes in stderr must not hide the real failure
            stderr_text = returned.stderr.decode(errors="replace")
            raise RuntimeError(
                f"Error raised in subprocess:\n{stderr_text}"
            ) from e

        # The file was written by the subprocess started above.
        with open(output_filepath, "rb") as f:
            return pickle.load(f)


def _run() -> None:
    """
    Subprocess entry point: run the callable received on stdin.

    Reads a pickled `(fn, output_file)` pair from stdin, calls `fn()`, and
    writes the pickled result to `output_file`.
    """
    # Setup plugins
    from vllm.plugins import load_general_plugins

    load_general_plugins()

    # SECURITY NOTE: unpickling executes arbitrary code. stdin is expected to
    # carry only the payload produced by `_run_in_subprocess`; never feed
    # this entry point untrusted input.
    fn, output_file = pickle.loads(sys.stdin.buffer.read())

    result = fn()

    with open(output_file, "wb") as f:
        pickle.dump(result, f)
1286
1287
1288


if __name__ == "__main__":
    # Reached when this module is executed as a script, as done by the
    # command in `_SUBPROCESS_COMMAND`.
    _run()