registry.py 51.2 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
4
5
6
"""
Whenever you add an architecture to this page, please also update
`tests/models/registry.py` with example HuggingFace models for it.
"""
7

8
import importlib
9
import json
10
import os
11
import pickle
12
13
import subprocess
import sys
14
import tempfile
15
from abc import ABC, abstractmethod
16
from collections.abc import Callable, Set
17
from dataclasses import asdict, dataclass, field
18
from functools import lru_cache
19
from pathlib import Path
20
from typing import TYPE_CHECKING, Any, TypeVar
21
22

import torch.nn as nn
23
import transformers
24

25
from vllm import envs
26
27
28
29
30
from vllm.config import (
    ModelConfig,
    iter_architecture_defaults,
    try_match_architecture_defaults,
)
31
from vllm.logger import init_logger
32
from vllm.logging_utils import logtime
33
from vllm.tasks import ScoreType
34
from vllm.transformers_utils.dynamic_module import try_get_class_from_dynamic_module
35
from vllm.utils.hashing import safe_hash
36

37
38
# The three type names below are imported only for static type checkers.
# At runtime they are bound to `Any`, so annotations that reference them
# still resolve without importing the config submodules.
if TYPE_CHECKING:
    from vllm.config.model import AttnTypeStr
    from vllm.config.pooler import SequencePoolingType, TokenPoolingType
else:
    AttnTypeStr = Any
    SequencePoolingType = Any
    TokenPoolingType = Any
44
45


46
47
48
49
50
from .interfaces import (
    has_inner_state,
    has_noops,
    is_attention_free,
    is_hybrid,
Patrick von Platen's avatar
Patrick von Platen committed
51
    requires_raw_input_tokens,
52
    supports_mamba_prefix_caching,
53
54
55
56
57
58
59
    supports_multimodal,
    supports_multimodal_encoder_tp_data,
    supports_multimodal_raw_input_only,
    supports_pp,
    supports_transcription,
)
from .interfaces_base import (
60
    get_attn_type,
61
62
    get_default_seq_pooling_type,
    get_default_tok_pooling_type,
63
    get_score_type,
64
65
66
    is_pooling_model,
    is_text_generation_model,
)
67
68
69

# Module-level logger, created through vLLM's `init_logger` helper.
logger = init_logger(__name__)

70
71
_TEXT_GENERATION_MODELS = {
    # [Decoder-only]
72
    "AfmoeForCausalLM": ("afmoe", "AfmoeForCausalLM"),
73
    "ApertusForCausalLM": ("apertus", "ApertusForCausalLM"),
74
75
    "AquilaModel": ("llama", "LlamaForCausalLM"),
    "AquilaForCausalLM": ("llama", "LlamaForCausalLM"),  # AquilaChat2
Raghav Ravishankar's avatar
Raghav Ravishankar committed
76
    "ArceeForCausalLM": ("arcee", "ArceeForCausalLM"),
77
    "ArcticForCausalLM": ("arctic", "ArcticForCausalLM"),
78
    "AXK1ForCausalLM": ("AXK1", "AXK1ForCausalLM"),
79
80
81
82
    # baichuan-7b, upper case 'C' in the class name
    "BaiChuanForCausalLM": ("baichuan", "BaiChuanForCausalLM"),
    # baichuan-13b, lower case 'c' in the class name
    "BaichuanForCausalLM": ("baichuan", "BaichuanForCausalLM"),
83
    "BailingMoeForCausalLM": ("bailing_moe", "BailingMoeForCausalLM"),
ant-yy's avatar
ant-yy committed
84
    "BailingMoeV2ForCausalLM": ("bailing_moe", "BailingMoeV2ForCausalLM"),
Jiangyun Zhu's avatar
Jiangyun Zhu committed
85
    "BailingMoeV2_5ForCausalLM": ("bailing_moe_linear", "BailingMoeV25ForCausalLM"),
Yu Chin Fabian Lim's avatar
Yu Chin Fabian Lim committed
86
    "BambaForCausalLM": ("bamba", "BambaForCausalLM"),
87
    "BloomForCausalLM": ("bloom", "BloomForCausalLM"),
88
    "ChatGLMModel": ("chatglm", "ChatGLMForCausalLM"),
89
    "ChatGLMForConditionalGeneration": ("chatglm", "ChatGLMForCausalLM"),
90
    "CohereForCausalLM": ("commandr", "CohereForCausalLM"),
91
    "Cohere2ForCausalLM": ("commandr", "CohereForCausalLM"),
92
    "CwmForCausalLM": ("llama", "LlamaForCausalLM"),
93
    "DbrxForCausalLM": ("dbrx", "DbrxForCausalLM"),
94
    "DeciLMForCausalLM": ("nemotron_nas", "DeciLMForCausalLM"),
95
    "DeepseekForCausalLM": ("deepseek_v2", "DeepseekForCausalLM"),
96
    "DeepseekV2ForCausalLM": ("deepseek_v2", "DeepseekV2ForCausalLM"),
97
    "DeepseekV3ForCausalLM": ("deepseek_v2", "DeepseekV3ForCausalLM"),
98
    "DeepseekV32ForCausalLM": ("deepseek_v2", "DeepseekV3ForCausalLM"),
99
    "Dots1ForCausalLM": ("dots1", "Dots1ForCausalLM"),
100
    "Ernie4_5ForCausalLM": ("ernie45", "Ernie4_5ForCausalLM"),
101
    "Ernie4_5_MoeForCausalLM": ("ernie45_moe", "Ernie4_5_MoeForCausalLM"),
102
    "ExaoneForCausalLM": ("exaone", "ExaoneForCausalLM"),
103
    "Exaone4ForCausalLM": ("exaone4", "Exaone4ForCausalLM"),
Kyungmin Lee's avatar
Kyungmin Lee committed
104
    "ExaoneMoEForCausalLM": ("exaone_moe", "ExaoneMoeForCausalLM"),
105
    "Fairseq2LlamaForCausalLM": ("fairseq2_llama", "Fairseq2LlamaForCausalLM"),
106
107
108
    "FalconForCausalLM": ("falcon", "FalconForCausalLM"),
    "FalconMambaForCausalLM": ("mamba", "MambaForCausalLM"),
    "FalconH1ForCausalLM": ("falcon_h1", "FalconH1ForCausalLM"),
109
    "FlexOlmoForCausalLM": ("flex_olmo", "FlexOlmoForCausalLM"),
110
111
    "GemmaForCausalLM": ("gemma", "GemmaForCausalLM"),
    "Gemma2ForCausalLM": ("gemma2", "Gemma2ForCausalLM"),
112
    "Gemma3ForCausalLM": ("gemma3", "Gemma3ForCausalLM"),
Nicolò Lucchesi's avatar
Nicolò Lucchesi committed
113
    "Gemma3nForCausalLM": ("gemma3n", "Gemma3nForCausalLM"),
114
    "Qwen3NextForCausalLM": ("qwen3_next", "Qwen3NextForCausalLM"),
115
    "GlmForCausalLM": ("glm", "GlmForCausalLM"),
Yuxuan Zhang's avatar
Yuxuan Zhang committed
116
    "Glm4ForCausalLM": ("glm4", "Glm4ForCausalLM"),
Yuxuan Zhang's avatar
Yuxuan Zhang committed
117
    "Glm4MoeForCausalLM": ("glm4_moe", "Glm4MoeForCausalLM"),
118
    "Glm4MoeLiteForCausalLM": ("glm4_moe_lite", "Glm4MoeLiteForCausalLM"),
Jee Jee Li's avatar
Jee Jee Li committed
119
    "GlmMoeDsaForCausalLM": ("deepseek_v2", "GlmMoeDsaForCausalLM"),
120
    "GptOssForCausalLM": ("gpt_oss", "GptOssForCausalLM"),
121
122
123
124
125
126
    "GPT2LMHeadModel": ("gpt2", "GPT2LMHeadModel"),
    "GPTBigCodeForCausalLM": ("gpt_bigcode", "GPTBigCodeForCausalLM"),
    "GPTJForCausalLM": ("gpt_j", "GPTJForCausalLM"),
    "GPTNeoXForCausalLM": ("gpt_neox", "GPTNeoXForCausalLM"),
    "GraniteForCausalLM": ("granite", "GraniteForCausalLM"),
    "GraniteMoeForCausalLM": ("granitemoe", "GraniteMoeForCausalLM"),
127
128
    "GraniteMoeHybridForCausalLM": ("granitemoehybrid", "GraniteMoeHybridForCausalLM"),  # noqa: E501
    "GraniteMoeSharedForCausalLM": ("granitemoeshared", "GraniteMoeSharedForCausalLM"),  # noqa: E501
129
    "GritLM": ("gritlm", "GritLM"),
Bijaya Dangol's avatar
Bijaya Dangol committed
130
131
    "Grok1ModelForCausalLM": ("grok1", "GrokForCausalLM"),
    "Grok1ForCausalLM": ("grok1", "GrokForCausalLM"),
132
133
    "HunYuanMoEV1ForCausalLM": ("hunyuan_v1", "HunYuanMoEV1ForCausalLM"),
    "HunYuanDenseV1ForCausalLM": ("hunyuan_v1", "HunYuanDenseV1ForCausalLM"),
134
    "HCXVisionForCausalLM": ("hyperclovax_vision", "HCXVisionForCausalLM"),
135
136
    "HCXVisionV2ForCausalLM": ("hyperclovax_vision_v2", "HCXVisionV2ForCausalLM"),
    "HyperCLOVAXForCausalLM": ("llama", "LlamaForCausalLM"),
137
138
    "InternLMForCausalLM": ("llama", "LlamaForCausalLM"),
    "InternLM2ForCausalLM": ("internlm2", "InternLM2ForCausalLM"),
139
    "InternLM2VEForCausalLM": ("internlm2_ve", "InternLM2VEForCausalLM"),
140
    "InternLM3ForCausalLM": ("llama", "LlamaForCausalLM"),
141
142
    "IQuestCoderForCausalLM": ("llama", "LlamaForCausalLM"),
    "IQuestLoopCoderForCausalLM": ("iquest_loopcoder", "IQuestLoopCoderForCausalLM"),
143
    "JAISLMHeadModel": ("jais", "JAISLMHeadModel"),
144
    "Jais2ForCausalLM": ("jais2", "Jais2ForCausalLM"),
145
    "JambaForCausalLM": ("jamba", "JambaForCausalLM"),
146
    "KimiLinearForCausalLM": ("kimi_linear", "KimiLinearForCausalLM"),  # noqa: E501
147
    "Lfm2ForCausalLM": ("lfm2", "Lfm2ForCausalLM"),
Paul Pak's avatar
Paul Pak committed
148
    "Lfm2MoeForCausalLM": ("lfm2_moe", "Lfm2MoeForCausalLM"),
149
    "LlamaForCausalLM": ("llama", "LlamaForCausalLM"),
150
    "Llama4ForCausalLM": ("llama4", "Llama4ForCausalLM"),
151
152
    # For decapoda-research/llama-*
    "LLaMAForCausalLM": ("llama", "LlamaForCausalLM"),
XuruiYang's avatar
XuruiYang committed
153
    "LongcatFlashForCausalLM": ("longcat_flash", "LongcatFlashForCausalLM"),
154
    "MambaForCausalLM": ("mamba", "MambaForCausalLM"),
155
    "Mamba2ForCausalLM": ("mamba2", "Mamba2ForCausalLM"),
156
157
    "MiniCPMForCausalLM": ("minicpm", "MiniCPMForCausalLM"),
    "MiniCPM3ForCausalLM": ("minicpm3", "MiniCPM3ForCausalLM"),
158
159
160
    "MiniMaxForCausalLM": ("minimax_text_01", "MiniMaxText01ForCausalLM"),
    "MiniMaxText01ForCausalLM": ("minimax_text_01", "MiniMaxText01ForCausalLM"),
    "MiniMaxM1ForCausalLM": ("minimax_text_01", "MiniMaxText01ForCausalLM"),
161
    "MiniMaxM2ForCausalLM": ("minimax_m2", "MiniMaxM2ForCausalLM"),
162
    "MistralForCausalLM": ("mistral", "MistralForCausalLM"),
163
    "MistralLarge3ForCausalLM": ("mistral_large_3", "MistralLarge3ForCausalLM"),
164
165
166
167
    "MixtralForCausalLM": ("mixtral", "MixtralForCausalLM"),
    # transformers's mpt class has lower case
    "MptForCausalLM": ("mpt", "MPTForCausalLM"),
    "MPTForCausalLM": ("mpt", "MPTForCausalLM"),
168
    "MiMoForCausalLM": ("mimo", "MiMoForCausalLM"),
169
    "MiMoV2FlashForCausalLM": ("mimo_v2_flash", "MiMoV2FlashForCausalLM"),
170
    "NemotronForCausalLM": ("nemotron", "NemotronForCausalLM"),
Luis Vega's avatar
Luis Vega committed
171
    "NemotronHForCausalLM": ("nemotron_h", "NemotronHForCausalLM"),
172
    "NemotronHPuzzleForCausalLM": ("nemotron_h", "NemotronHForCausalLM"),
173
    "OlmoForCausalLM": ("olmo", "OlmoForCausalLM"),
174
    "Olmo2ForCausalLM": ("olmo2", "Olmo2ForCausalLM"),
175
    "Olmo3ForCausalLM": ("olmo2", "Olmo2ForCausalLM"),
176
    "OlmoHybridForCausalLM": ("olmo_hybrid", "OlmoHybridForCausalLM"),
177
178
179
    "OlmoeForCausalLM": ("olmoe", "OlmoeForCausalLM"),
    "OPTForCausalLM": ("opt", "OPTForCausalLM"),
    "OrionForCausalLM": ("orion", "OrionForCausalLM"),
180
    "OuroForCausalLM": ("ouro", "OuroForCausalLM"),
181
    "PanguEmbeddedForCausalLM": ("openpangu", "PanguEmbeddedForCausalLM"),
182
    "PanguProMoEV2ForCausalLM": ("openpangu", "PanguProMoEV2ForCausalLM"),
183
    "PanguUltraMoEForCausalLM": ("openpangu", "PanguUltraMoEForCausalLM"),
184
185
186
187
    "PersimmonForCausalLM": ("persimmon", "PersimmonForCausalLM"),
    "PhiForCausalLM": ("phi", "PhiForCausalLM"),
    "Phi3ForCausalLM": ("phi3", "Phi3ForCausalLM"),
    "PhiMoEForCausalLM": ("phimoe", "PhiMoEForCausalLM"),
Shinichi Hemmi's avatar
Shinichi Hemmi committed
188
    "Plamo2ForCausalLM": ("plamo2", "Plamo2ForCausalLM"),
189
    "Plamo3ForCausalLM": ("plamo3", "Plamo3ForCausalLM"),
190
    "QWenLMHeadModel": ("qwen", "QWenLMHeadModel"),
191
192
    "Qwen2ForCausalLM": ("qwen2", "Qwen2ForCausalLM"),
    "Qwen2MoeForCausalLM": ("qwen2_moe", "Qwen2MoeForCausalLM"),
193
194
    "Qwen3ForCausalLM": ("qwen3", "Qwen3ForCausalLM"),
    "Qwen3MoeForCausalLM": ("qwen3_moe", "Qwen3MoeForCausalLM"),
195
    "RWForCausalLM": ("falcon", "FalconForCausalLM"),
196
197
    "SarvamMoEForCausalLM": ("sarvam", "SarvamMoEForCausalLM"),
    "SarvamMLAForCausalLM": ("sarvam", "SarvamMLAForCausalLM"),
198
    "SeedOssForCausalLM": ("seed_oss", "SeedOssForCausalLM"),
Li Xie's avatar
Li Xie committed
199
    "Step1ForCausalLM": ("step1", "Step1ForCausalLM"),
Song's avatar
Song committed
200
    "Step3TextForCausalLM": ("step3_text", "Step3TextForCausalLM"),
csy0225's avatar
csy0225 committed
201
    "Step3p5ForCausalLM": ("step3p5", "Step3p5ForCausalLM"),
202
203
204
205
    "StableLMEpochForCausalLM": ("stablelm", "StablelmForCausalLM"),
    "StableLmForCausalLM": ("stablelm", "StablelmForCausalLM"),
    "Starcoder2ForCausalLM": ("starcoder2", "Starcoder2ForCausalLM"),
    "SolarForCausalLM": ("solar", "SolarForCausalLM"),
206
    "TeleChatForCausalLM": ("telechat2", "TeleChat2ForCausalLM"),
207
    "TeleChat2ForCausalLM": ("telechat2", "TeleChat2ForCausalLM"),
208
    "TeleFLMForCausalLM": ("teleflm", "TeleFLMForCausalLM"),
209
    "XverseForCausalLM": ("llama", "LlamaForCausalLM"),
210
    "Zamba2ForCausalLM": ("zamba2", "Zamba2ForCausalLM"),
211
212
213
}

# Embedding (pooling) architectures, as `name -> (module, class)` string pairs.
# Every `_TEXT_GENERATION_MODELS` entry whose class is "LlamaForCausalLM" is
# spliced in as well. Later keys in this literal override earlier ones.
_EMBEDDING_MODELS = {
    # [Text-only]
    "BertModel": ("bert", "BertEmbeddingModel"),
    "BertSpladeSparseEmbeddingModel": ("bert", "BertSpladeSparseEmbeddingModel"),
    "BgeM3EmbeddingModel": ("roberta", "BgeM3EmbeddingModel"),
    "DeciLMForCausalLM": ("nemotron_nas", "DeciLMForCausalLM"),
    "Gemma2Model": ("gemma2", "Gemma2ForCausalLM"),
    "Gemma3TextModel": ("gemma3", "Gemma3Model"),
    "GlmForCausalLM": ("glm", "GlmForCausalLM"),
    "GritLM": ("gritlm", "GritLM"),
    "GteModel": ("bert_with_rope", "SnowflakeGteNewModel"),
    "GteNewModel": ("bert_with_rope", "GteNewModel"),
    "LlamaBidirectionalModel": ("llama", "LlamaBidirectionalModel"),
    "LlamaModel": ("llama", "LlamaForCausalLM"),
    **{
        # Multiple models share the same architecture, so we include them all
        k: (mod, arch)
        for k, (mod, arch) in _TEXT_GENERATION_MODELS.items()
        if arch == "LlamaForCausalLM"
    },
    "MistralModel": ("llama", "LlamaForCausalLM"),
    "ModernBertModel": ("modernbert", "ModernBertModel"),
    "NomicBertModel": ("bert_with_rope", "NomicBertModel"),
    "Phi3ForCausalLM": ("phi3", "Phi3ForCausalLM"),
    "Qwen2Model": ("qwen2", "Qwen2ForCausalLM"),
    "Qwen2ForCausalLM": ("qwen2", "Qwen2ForCausalLM"),
    "RobertaForMaskedLM": ("roberta", "RobertaEmbeddingModel"),
    "RobertaModel": ("roberta", "RobertaEmbeddingModel"),
    "TeleChatForCausalLM": ("telechat2", "TeleChat2ForCausalLM"),
    "TeleChat2ForCausalLM": ("telechat2", "TeleChat2ForCausalLM"),
    "VoyageQwen3BidirectionalEmbedModel": (
        "voyage",
        "VoyageQwen3BidirectionalEmbedModel",
    ),
    "XLMRobertaModel": ("roberta", "RobertaEmbeddingModel"),
    # [Multimodal]
    "CLIPModel": ("clip", "CLIPEmbeddingModel"),
    "LlavaNextForConditionalGeneration": (
        "llava_next",
        "LlavaNextForConditionalGeneration",
    ),
    "Phi3VForCausalLM": ("phi3v", "Phi3VForCausalLM"),
    "Qwen2VLForConditionalGeneration": ("qwen2_vl", "Qwen2VLForConditionalGeneration"),  # noqa: E501
    "SiglipModel": ("siglip", "SiglipEmbeddingModel"),
    "LlamaNemotronVLModel": (
        "nemotron_vl",
        "LlamaNemotronVLForEmbedding",
    ),
    # Technically Terratorch models work on images, both in
    # input and output. I am adding it here because it piggy-backs on embedding
    # models for the time being.
    "PrithviGeoSpatialMAE": ("terratorch", "Terratorch"),
    "Terratorch": ("terratorch", "Terratorch"),
}

268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
_LATE_INTERACTION_MODELS = {
    # [Text-only]
    "HF_ColBERT": ("colbert", "ColBERTModel"),
    "ColBERTModernBertModel": ("colbert", "ColBERTModernBertModel"),
    "ColBERTJinaRobertaModel": ("colbert", "ColBERTJinaRobertaModel"),
    # [Multimodal]
    "ColModernVBertForRetrieval": ("colmodernvbert", "ColModernVBertForRetrieval"),
    "ColQwen3": ("colqwen3", "ColQwen3Model"),
    "OpsColQwen3Model": ("colqwen3", "ColQwen3Model"),
    "Qwen3VLNemotronEmbedModel": ("colqwen3", "ColQwen3Model"),
}

_REWARD_MODELS = {
    "InternLM2ForRewardModel": ("internlm2", "InternLM2ForRewardModel"),
    "Qwen2ForRewardModel": ("qwen2_rm", "Qwen2ForRewardModel"),
    "Qwen2ForProcessRewardModel": ("qwen2_rm", "Qwen2ForProcessRewardModel"),
}

_TOKEN_CLASSIFICATION_MODELS = {
287
    "BertForTokenClassification": ("bert", "BertForTokenClassification"),
288
289
290
291
292
293
294
295
296
    "ModernBertForTokenClassification": (
        "modernbert",
        "ModernBertForTokenClassification",
    ),
}

_SEQUENCE_CLASSIFICATION_MODELS = {
    "BertForSequenceClassification": ("bert", "BertForSequenceClassification"),
    "GPT2ForSequenceClassification": ("gpt2", "GPT2ForSequenceClassification"),
297
298
299
300
    "GteNewForSequenceClassification": (
        "bert_with_rope",
        "GteNewForSequenceClassification",
    ),
301
    "JambaForSequenceClassification": ("jamba", "JambaForSequenceClassification"),  # noqa: E501
302
303
304
    "LlamaBidirectionalForSequenceClassification": (
        "llama",
        "LlamaBidirectionalForSequenceClassification",
305
    ),
306
307
308
309
310
311
312
313
314
    "ModernBertForSequenceClassification": (
        "modernbert",
        "ModernBertForSequenceClassification",
    ),
    "RobertaForSequenceClassification": ("roberta", "RobertaForSequenceClassification"),
    "XLMRobertaForSequenceClassification": (
        "roberta",
        "RobertaForSequenceClassification",
    ),
315
316
317
318
319
320
    # [Multimodal]
    "JinaVLForRanking": ("jina_vl", "JinaVLForSequenceClassification"),
    "LlamaNemotronVLForSequenceClassification": (
        "nemotron_vl",
        "LlamaNemotronVLForSequenceClassification",
    ),
321
322
}

323
_MULTIMODAL_MODELS = {
324
    # [Decoder-only]
325
    "AriaForConditionalGeneration": ("aria", "AriaForConditionalGeneration"),
326
327
328
329
    "AudioFlamingo3ForConditionalGeneration": (
        "audioflamingo3",
        "AudioFlamingo3ForConditionalGeneration",
    ),
330
331
332
333
    "MusicFlamingoForConditionalGeneration": (
        "musicflamingo",
        "MusicFlamingoForConditionalGeneration",
    ),
334
335
336
    "AyaVisionForConditionalGeneration": (
        "aya_vision",
        "AyaVisionForConditionalGeneration",
337
    ),
338
    "BagelForConditionalGeneration": ("bagel", "BagelForConditionalGeneration"),
339
    "BeeForConditionalGeneration": ("bee", "BeeForConditionalGeneration"),
340
    "Blip2ForConditionalGeneration": ("blip2", "Blip2ForConditionalGeneration"),
341
342
343
    "ChameleonForConditionalGeneration": (
        "chameleon",
        "ChameleonForConditionalGeneration",
344
    ),
345
346
347
    "Cohere2VisionForConditionalGeneration": (
        "cohere2_vision",
        "Cohere2VisionForConditionalGeneration",
348
    ),
349
    "DeepseekVLV2ForCausalLM": ("deepseek_vl2", "DeepseekVLV2ForCausalLM"),
350
    "DeepseekOCRForCausalLM": ("deepseek_ocr", "DeepseekOCRForCausalLM"),
RED's avatar
RED committed
351
    "DeepseekOCR2ForCausalLM": ("deepseek_ocr2", "DeepseekOCR2ForCausalLM"),
Roger Wang's avatar
Roger Wang committed
352
    "DotsOCRForCausalLM": ("dots_ocr", "DotsOCRForCausalLM"),
353
354
355
356
    "Eagle2_5_VLForConditionalGeneration": (
        "eagle2_5_vl",
        "Eagle2_5_VLForConditionalGeneration",
    ),
357
358
359
    "Ernie4_5_VLMoeForConditionalGeneration": (
        "ernie45_vl",
        "Ernie4_5_VLMoeForConditionalGeneration",
360
    ),
361
362
363
364
    "FireRedASR2ForConditionalGeneration": (
        "fireredasr2",
        "FireRedASR2ForConditionalGeneration",
    ),
365
    "FunASRForConditionalGeneration": ("funasr", "FunASRForConditionalGeneration"),  # noqa: E501
366
367
368
369
    "FunAudioChatForConditionalGeneration": (
        "funaudiochat",
        "FunAudioChatForConditionalGeneration",
    ),
370
    "FuyuForCausalLM": ("fuyu", "FuyuForCausalLM"),
371
    "Gemma3ForConditionalGeneration": ("gemma3_mm", "Gemma3ForConditionalGeneration"),  # noqa: E501
372
373
374
    "Gemma3nForConditionalGeneration": (
        "gemma3n_mm",
        "Gemma3nForConditionalGeneration",
375
    ),
376
    "GlmAsrForConditionalGeneration": ("glmasr", "GlmAsrForConditionalGeneration"),
377
    "GLM4VForCausalLM": ("glm4v", "GLM4VForCausalLM"),
378
379
380
    "Glm4vForConditionalGeneration": ("glm4_1v", "Glm4vForConditionalGeneration"),
    "Glm4vMoeForConditionalGeneration": ("glm4_1v", "Glm4vMoeForConditionalGeneration"),
    "GlmOcrForConditionalGeneration": ("glm_ocr", "GlmOcrForConditionalGeneration"),  # noqa: E501
381
382
383
    "GraniteSpeechForConditionalGeneration": (
        "granite_speech",
        "GraniteSpeechForConditionalGeneration",
384
    ),
385
    "H2OVLChatModel": ("h2ovl", "H2OVLChatModel"),
386
387
388
389
    "HunYuanVLForConditionalGeneration": (
        "hunyuan_vision",
        "HunYuanVLForConditionalGeneration",
    ),
ltd0924's avatar
ltd0924 committed
390
    "StepVLForConditionalGeneration": ("step_vl", "StepVLForConditionalGeneration"),
391
    "InternVLChatModel": ("internvl", "InternVLChatModel"),
392
    "NemotronH_Nano_VL_V2": ("nano_nemotron_vl", "NemotronH_Nano_VL_V2"),
Zero's avatar
Zero committed
393
394
395
396
    "OpenCUAForConditionalGeneration": (
        "opencua",
        "OpenCUAForConditionalGeneration",
    ),
397
398
399
    "InternS1ForConditionalGeneration": (
        "interns1",
        "InternS1ForConditionalGeneration",
400
    ),
401
402
403
    "InternVLForConditionalGeneration": (
        "interns1",
        "InternS1ForConditionalGeneration",
404
    ),
zxy's avatar
zxy committed
405
406
407
408
    "InternS1ProForConditionalGeneration": (
        "interns1_pro",
        "InternS1ProForConditionalGeneration",
    ),
409
410
411
412
    "Idefics3ForConditionalGeneration": (
        "idefics3",
        "Idefics3ForConditionalGeneration",
    ),
oscardev256's avatar
oscardev256 committed
413
    "IsaacForConditionalGeneration": ("isaac", "IsaacForConditionalGeneration"),
414
    "SmolVLMForConditionalGeneration": ("smolvlm", "SmolVLMForConditionalGeneration"),  # noqa: E501
415
    "KananaVForConditionalGeneration": ("kanana_v", "KananaVForConditionalGeneration"),
416
    "KeyeForConditionalGeneration": ("keye", "KeyeForConditionalGeneration"),
417
418
419
    "KeyeVL1_5ForConditionalGeneration": (
        "keye_vl1_5",
        "KeyeVL1_5ForConditionalGeneration",
420
    ),
421
    "RForConditionalGeneration": ("rvl", "RForConditionalGeneration"),
422
    "KimiVLForConditionalGeneration": ("kimi_vl", "KimiVLForConditionalGeneration"),  # noqa: E501
Roger Wang's avatar
Roger Wang committed
423
    "KimiK25ForConditionalGeneration": ("kimi_k25", "KimiK25ForConditionalGeneration"),  # noqa: E501
424
    "MoonshotKimiaForCausalLM": ("kimi_audio", "KimiAudioForConditionalGeneration"),  # noqa: E501
425
426
427
428
    "LightOnOCRForConditionalGeneration": (
        "lightonocr",
        "LightOnOCRForConditionalGeneration",
    ),
429
    "Lfm2VlForConditionalGeneration": ("lfm2_vl", "Lfm2VLForConditionalGeneration"),
430
    "Llama_Nemotron_Nano_VL": ("nemotron_vl", "LlamaNemotronVLChatModel"),
431
    "Llama4ForConditionalGeneration": ("mllama4", "Llama4ForConditionalGeneration"),  # noqa: E501
432
    "LlavaForConditionalGeneration": ("llava", "LlavaForConditionalGeneration"),
433
434
435
    "LlavaNextForConditionalGeneration": (
        "llava_next",
        "LlavaNextForConditionalGeneration",
436
    ),
437
438
439
    "LlavaNextVideoForConditionalGeneration": (
        "llava_next_video",
        "LlavaNextVideoForConditionalGeneration",
440
    ),
441
442
443
    "LlavaOnevisionForConditionalGeneration": (
        "llava_onevision",
        "LlavaOnevisionForConditionalGeneration",
444
    ),
445
    "MantisForConditionalGeneration": ("llava", "MantisForConditionalGeneration"),  # noqa: E501
446
    "MiDashengLMModel": ("midashenglm", "MiDashengLMModel"),
447
448
449
    "MiniMaxVL01ForConditionalGeneration": (
        "minimax_vl_01",
        "MiniMaxVL01ForConditionalGeneration",
450
    ),
451
    "MiniCPMO": ("minicpmo", "MiniCPMO"),
452
    "MiniCPMV": ("minicpmv", "MiniCPMV"),
453
454
455
    "Mistral3ForConditionalGeneration": (
        "mistral3",
        "Mistral3ForConditionalGeneration",
456
    ),
457
    "MolmoForCausalLM": ("molmo", "MolmoForCausalLM"),
458
    "Molmo2ForConditionalGeneration": ("molmo2", "Molmo2ForConditionalGeneration"),
459
    "NVLM_D": ("nvlm_d", "NVLM_D_Model"),
460
461
462
463
    "OpenPanguVLForConditionalGeneration": (
        "openpangu_vl",
        "OpenPanguVLForConditionalGeneration",
    ),
464
    "Ovis": ("ovis", "Ovis"),
465
    "Ovis2_5": ("ovis2_5", "Ovis2_5"),
466
467
    "Ovis2_6ForCausalLM": ("ovis2_5", "Ovis2_5"),
    "Ovis2_6_MoeForCausalLM": ("ovis2_5", "Ovis2_5"),
468
469
470
471
    "PaddleOCRVLForConditionalGeneration": (
        "paddleocr_vl",
        "PaddleOCRVLForConditionalGeneration",
    ),
472
473
474
475
    "PaliGemmaForConditionalGeneration": (
        "paligemma",
        "PaliGemmaForConditionalGeneration",
    ),
476
    "Phi3VForCausalLM": ("phi3v", "Phi3VForCausalLM"),
477
    "Phi4MMForCausalLM": ("phi4mm", "Phi4MMForCausalLM"),
478
    "PixtralForConditionalGeneration": ("pixtral", "PixtralForConditionalGeneration"),  # noqa: E501
479
    "QwenVLForConditionalGeneration": ("qwen_vl", "QwenVLForConditionalGeneration"),  # noqa: E501
480
    "Qwen2VLForConditionalGeneration": ("qwen2_vl", "Qwen2VLForConditionalGeneration"),  # noqa: E501
481
482
483
    "Qwen2_5_VLForConditionalGeneration": (
        "qwen2_5_vl",
        "Qwen2_5_VLForConditionalGeneration",
484
    ),
485
486
487
    "Qwen2AudioForConditionalGeneration": (
        "qwen2_audio",
        "Qwen2AudioForConditionalGeneration",
488
    ),
489
490
491
    "Qwen2_5OmniModel": (
        "qwen2_5_omni_thinker",
        "Qwen2_5OmniThinkerForConditionalGeneration",
492
    ),
493
494
495
    "Qwen2_5OmniForConditionalGeneration": (
        "qwen2_5_omni_thinker",
        "Qwen2_5OmniThinkerForConditionalGeneration",
496
    ),
497
498
499
500
    "Qwen3OmniMoeForConditionalGeneration": (
        "qwen3_omni_moe_thinker",
        "Qwen3OmniMoeThinkerForConditionalGeneration",
    ),
Roger Wang's avatar
Roger Wang committed
501
502
503
504
    "Qwen3ASRForConditionalGeneration": (
        "qwen3_asr",
        "Qwen3ASRForConditionalGeneration",
    ),
505
506
507
508
    "Qwen3ASRRealtimeGeneration": (
        "qwen3_asr_realtime",
        "Qwen3ASRRealtimeGeneration",
    ),
509
    "Qwen3VLForConditionalGeneration": ("qwen3_vl", "Qwen3VLForConditionalGeneration"),  # noqa: E501
510
511
512
    "Qwen3VLMoeForConditionalGeneration": (
        "qwen3_vl_moe",
        "Qwen3VLMoeForConditionalGeneration",
513
    ),
514
515
516
517
518
519
520
521
    "Qwen3_5ForConditionalGeneration": (
        "qwen3_5",
        "Qwen3_5ForConditionalGeneration",
    ),
    "Qwen3_5MoeForConditionalGeneration": (
        "qwen3_5",
        "Qwen3_5MoeForConditionalGeneration",
    ),
522
    "SkyworkR1VChatModel": ("skyworkr1v", "SkyworkR1VChatModel"),
Song's avatar
Song committed
523
    "Step3VLForConditionalGeneration": ("step3_vl", "Step3VLForConditionalGeneration"),  # noqa: E501
汪志鹏's avatar
汪志鹏 committed
524
    "TarsierForConditionalGeneration": ("tarsier", "TarsierForConditionalGeneration"),  # noqa: E501
525
526
527
    "Tarsier2ForConditionalGeneration": (
        "qwen2_vl",
        "Tarsier2ForConditionalGeneration",
528
    ),
529
    "UltravoxModel": ("ultravox", "UltravoxModel"),
Patrick von Platen's avatar
Patrick von Platen committed
530
    "VoxtralForConditionalGeneration": ("voxtral", "VoxtralForConditionalGeneration"),  # noqa: E501
531
    "VoxtralRealtimeGeneration": ("voxtral_realtime", "VoxtralRealtimeGeneration"),  # noqa: E501
532
    # [Encoder-decoder]
533
534
535
536
    "NemotronParseForConditionalGeneration": (
        "nemotron_parse",
        "NemotronParseForConditionalGeneration",
    ),
537
    "WhisperForConditionalGeneration": ("whisper", "WhisperForConditionalGeneration"),  # noqa: E501
538
}
539
540

_SPECULATIVE_DECODING_MODELS = {
541
    "ExtractHiddenStatesModel": ("extract_hidden_states", "ExtractHiddenStatesModel"),
542
    "MiMoMTPModel": ("mimo_mtp", "MiMoMTP"),
543
    "EagleLlamaForCausalLM": ("llama_eagle", "EagleLlamaForCausalLM"),
zhiweiz's avatar
zhiweiz committed
544
    "EagleLlama4ForCausalLM": ("llama4_eagle", "EagleLlama4ForCausalLM"),
545
    "EagleMiniCPMForCausalLM": ("minicpm_eagle", "EagleMiniCPMForCausalLM"),
546
    "Eagle3LlamaForCausalLM": ("llama_eagle3", "Eagle3LlamaForCausalLM"),
547
    "LlamaForCausalLMEagle3": ("llama_eagle3", "Eagle3LlamaForCausalLM"),
548
    "Eagle3Qwen2_5vlForCausalLM": ("llama_eagle3", "Eagle3LlamaForCausalLM"),
549
    "Eagle3Qwen3vlForCausalLM": ("llama_eagle3", "Eagle3LlamaForCausalLM"),
550
551
552
553
    "EagleMistralLarge3ForCausalLM": (
        "mistral_large_3_eagle",
        "EagleMistralLarge3ForCausalLM",
    ),
554
    "EagleDeepSeekMTPModel": ("deepseek_eagle", "EagleDeepseekV3ForCausalLM"),
555
    "DeepSeekMTPModel": ("deepseek_mtp", "DeepSeekMTP"),
556
    "ErnieMTPModel": ("ernie_mtp", "ErnieMTP"),
Kyungmin Lee's avatar
Kyungmin Lee committed
557
    "ExaoneMoeMTP": ("exaone_moe_mtp", "ExaoneMoeMTP"),
558
    "NemotronHMTPModel": ("nemotron_h_mtp", "NemotronHMTP"),
XuruiYang's avatar
XuruiYang committed
559
    "LongCatFlashMTPModel": ("longcat_flash_mtp", "LongCatFlashMTP"),
Yuxuan Zhang's avatar
Yuxuan Zhang committed
560
    "Glm4MoeMTPModel": ("glm4_moe_mtp", "Glm4MoeMTP"),
561
    "Glm4MoeLiteMTPModel": ("glm4_moe_lite_mtp", "Glm4MoeLiteMTP"),
562
    "GlmOcrMTPModel": ("glm_ocr_mtp", "GlmOcrMTP"),
563
    "MedusaModel": ("medusa", "Medusa"),
564
    "OpenPanguMTPModel": ("openpangu_mtp", "OpenPanguMTP"),
565
    "Qwen3NextMTP": ("qwen3_next_mtp", "Qwen3NextMTP"),
csy0225's avatar
csy0225 committed
566
    "Step3p5MTP": ("step3p5_mtp", "Step3p5MTP"),
567
568
    "Qwen3_5MTP": ("qwen3_5_mtp", "Qwen3_5MTP"),
    "Qwen3_5MoeMTP": ("qwen3_5_mtp", "Qwen3_5MoeMTP"),
569
570
571
    # Temporarily disabled.
    # # TODO(woosuk): Re-enable this once the MLP Speculator is supported in V1.
    # "MLPSpeculatorPreTrainedModel": ("mlp_speculator", "MLPSpeculator"),
572
}
573

574
_TRANSFORMERS_SUPPORTED_MODELS = {
575
576
577
    # Text generation models
    "SmolLM3ForCausalLM": ("transformers", "TransformersForCausalLM"),
    # Multimodal models
578
579
580
581
    "Emu3ForConditionalGeneration": (
        "transformers",
        "TransformersMultiModalForCausalLM",
    ),
582
583
584
}

_TRANSFORMERS_BACKEND_MODELS = {
585
    # Text generation models
586
    "TransformersForCausalLM": ("transformers", "TransformersForCausalLM"),
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
    "TransformersMoEForCausalLM": ("transformers", "TransformersMoEForCausalLM"),
    # Multimodal models
    "TransformersMultiModalForCausalLM": (
        "transformers",
        "TransformersMultiModalForCausalLM",
    ),
    "TransformersMultiModalMoEForCausalLM": (
        "transformers",
        "TransformersMultiModalMoEForCausalLM",
    ),
    # Embedding models
    "TransformersEmbeddingModel": ("transformers", "TransformersEmbeddingModel"),
    "TransformersMoEEmbeddingModel": ("transformers", "TransformersMoEEmbeddingModel"),
    "TransformersMultiModalEmbeddingModel": (
        "transformers",
        "TransformersMultiModalEmbeddingModel",
    ),
    # Sequence classification models
605
    "TransformersForSequenceClassification": (
606
        "transformers",
607
        "TransformersForSequenceClassification",
608
    ),
609
    "TransformersMoEForSequenceClassification": (
610
        "transformers",
611
        "TransformersMoEForSequenceClassification",
612
    ),
613
614
615
    "TransformersMultiModalForSequenceClassification": (
        "transformers",
        "TransformersMultiModalForSequenceClassification",
616
    ),
617
}
618

619
# Union of all per-task tables above. When the same architecture name appears
# in more than one table, the table unpacked later takes precedence.
_VLLM_MODELS = {
    **_TEXT_GENERATION_MODELS,
    **_EMBEDDING_MODELS,
    **_LATE_INTERACTION_MODELS,
    **_REWARD_MODELS,
    **_TOKEN_CLASSIFICATION_MODELS,
    **_SEQUENCE_CLASSIFICATION_MODELS,
    **_MULTIMODAL_MODELS,
    **_SPECULATIVE_DECODING_MODELS,
    **_TRANSFORMERS_SUPPORTED_MODELS,
    **_TRANSFORMERS_BACKEND_MODELS,
}

632
633
634
635
# This variable is used as the args for subprocess.run(). We
# can modify  this variable to alter the args if needed. e.g.
# when we use par format to pack things together, sys.executable
# might not be the target we want to run.
636
_SUBPROCESS_COMMAND = [sys.executable, "-m", "vllm.model_executor.models.registry"]
637

638
_PREVIOUSLY_SUPPORTED_MODELS = {
639
    "MotifForCausalLM": "0.10.2",
640
    "Phi3SmallForCausalLM": "0.9.2",
641
    "Phi4FlashForCausalLM": "0.10.2",
642
    "Phi4MultimodalForCausalLM": "0.12.0",
643
644
645
646
647
648
649
650
651
    # encoder-decoder models except whisper
    # have been removed for V0 deprecation.
    "BartModel": "0.10.2",
    "BartForConditionalGeneration": "0.10.2",
    "DonutForConditionalGeneration": "0.10.2",
    "Florence2ForConditionalGeneration": "0.10.2",
    "MBartForConditionalGeneration": "0.10.2",
    "MllamaForConditionalGeneration": "0.10.2",
}
652

653

654
655
@dataclass(frozen=True)
class _ModelInfo:
    """
    Immutable record of the capabilities reported for a model class.

    Apart from `architecture` (the class name), every field holds the result
    of the interface helper of the same name, or of the matching `get_*`
    helper, as evaluated by `from_model_cls`.
    """

    architecture: str
    is_text_generation_model: bool
    is_pooling_model: bool
    attn_type: AttnTypeStr
    default_seq_pooling_type: SequencePoolingType
    default_tok_pooling_type: TokenPoolingType
    score_type: ScoreType
    supports_multimodal: bool
    supports_multimodal_raw_input_only: bool
    requires_raw_input_tokens: bool
    supports_multimodal_encoder_tp_data: bool
    supports_pp: bool
    has_inner_state: bool
    is_attention_free: bool
    is_hybrid: bool
    has_noops: bool
    supports_mamba_prefix_caching: bool
    supports_transcription: bool
    supports_transcription_only: bool

    @staticmethod
    def from_model_cls(model: type[nn.Module]) -> "_ModelInfo":
        """Build the info record by probing `model` with the interface helpers."""
        # `supports_transcription_only` is only read from the class when the
        # class supports transcription in the first place.
        transcribes = supports_transcription(model)
        transcription_only = transcribes and model.supports_transcription_only

        raw_input_only = supports_multimodal_raw_input_only(model)
        encoder_tp_data = supports_multimodal_encoder_tp_data(model)

        return _ModelInfo(
            architecture=model.__name__,
            is_text_generation_model=is_text_generation_model(model),
            is_pooling_model=is_pooling_model(model),
            attn_type=get_attn_type(model),
            default_seq_pooling_type=get_default_seq_pooling_type(model),
            default_tok_pooling_type=get_default_tok_pooling_type(model),
            score_type=get_score_type(model),
            supports_multimodal=supports_multimodal(model),
            supports_multimodal_raw_input_only=raw_input_only,
            requires_raw_input_tokens=requires_raw_input_tokens(model),
            supports_multimodal_encoder_tp_data=encoder_tp_data,
            supports_pp=supports_pp(model),
            has_inner_state=has_inner_state(model),
            is_attention_free=is_attention_free(model),
            is_hybrid=is_hybrid(model),
            has_noops=has_noops(model),
            supports_mamba_prefix_caching=supports_mamba_prefix_caching(model),
            supports_transcription=transcribes,
            supports_transcription_only=transcription_only,
        )
705
706


707
708
709
710
class _BaseRegisteredModel(ABC):
    """Interface shared by every entry stored in the model registry."""

    @abstractmethod
    def inspect_model_cls(self) -> _ModelInfo:
        """Return the `_ModelInfo` describing the registered model class."""
        raise NotImplementedError

    @abstractmethod
    def load_model_cls(self) -> type[nn.Module]:
        """Return the registered model class itself."""
        raise NotImplementedError
715
716


717
718
719
720
721
722
723
@dataclass(frozen=True)
class _RegisteredModel(_BaseRegisteredModel):
    """
    Represents a model that has already been imported in the main process.

    Both the class and its inspection result are stored on the instance, so
    neither accessor has to import or probe anything.
    """

    interfaces: _ModelInfo
    model_cls: type[nn.Module]

    @staticmethod
    def from_model_cls(model_cls: type[nn.Module]):
        """Wrap `model_cls`, inspecting it immediately."""
        info = _ModelInfo.from_model_cls(model_cls)
        return _RegisteredModel(interfaces=info, model_cls=model_cls)

    def inspect_model_cls(self) -> _ModelInfo:
        """Return the info computed when the entry was created."""
        return self.interfaces

    def load_model_cls(self) -> type[nn.Module]:
        """Return the stored model class."""
        return self.model_cls


@dataclass(frozen=True)
class _LazyRegisteredModel(_BaseRegisteredModel):
    """
    Represents a model that has not been imported in the main process.

    The model is identified by `module_name` and `class_name` and is only
    imported by `load_model_cls`. Inspection results are cached as JSON files
    under the vLLM cache root, together with a hash of the model source file.
    """

    module_name: str
    class_name: str

    @staticmethod
    def _get_cache_dir() -> Path:
        """Return the directory that holds the cached model info files."""
        cache_root = Path(envs.VLLM_CACHE_ROOT)
        return cache_root / "modelinfos"

    def _get_cache_filename(self) -> str:
        """Return `<module>-<class>.json` with every dot replaced by a dash."""
        qualified_name = f"{self.module_name}-{self.class_name}"
        return qualified_name.replace(".", "-") + ".json"

    def _load_modelinfo_from_cache(self, module_hash: str) -> _ModelInfo | None:
        """
        Return the cached `_ModelInfo` whose stored hash equals `module_hash`.

        Returns `None` when the cache file is missing, stale or unreadable;
        each case is reported at debug level only.
        """
        try:
            cache_file = self._get_cache_dir() / self._get_cache_filename()
            with open(cache_file, encoding="utf-8") as file:
                cached = json.load(file)

            if cached["hash"] != module_hash:
                logger.debug(
                    "Cached model info file for class %s.%s is stale",
                    self.module_name,
                    self.class_name,
                )
                return None

            # file not changed, use cached _ModelInfo properties
            return _ModelInfo(**cached["modelinfo"])
        except FileNotFoundError:
            logger.debug(
                "Cached model info file for class %s.%s not found",
                self.module_name,
                self.class_name,
            )
            return None
        except Exception:
            logger.debug(
                "Cached model info for class %s.%s error. ",
                self.module_name,
                self.class_name,
            )
            return None

    def _save_modelinfo_to_cache(self, mi: _ModelInfo, module_hash: str) -> None:
        """save dictionary json file to cache"""
        from vllm.model_executor.model_loader.weight_utils import atomic_writer

        try:
            payload = {"hash": module_hash, "modelinfo": asdict(mi)}

            target_dir = self._get_cache_dir()
            target_dir.mkdir(parents=True, exist_ok=True)
            target_file = target_dir / self._get_cache_filename()

            with atomic_writer(target_file, encoding="utf-8") as f:
                json.dump(payload, f, indent=2)
        except Exception:
            # Caching is best-effort: a failed write is logged, not raised.
            logger.exception("Error saving model info cache.")

    @logtime(logger=logger, msg="Registry inspect model class")
    def inspect_model_cls(self) -> _ModelInfo:
        """
        Return the `_ModelInfo` of the model, using the on-disk cache if valid.

        The cache is only consulted and written when a `<module>.py` file
        named after the last component of `module_name` exists next to this
        file; otherwise the model is always inspected in a subprocess.
        """
        module_basename = self.module_name.split(".")[-1]
        model_path = Path(__file__).parent / f"{module_basename}.py"
        module_hash = None

        if model_path.exists():
            source_bytes = model_path.read_bytes()
            module_hash = safe_hash(source_bytes, usedforsecurity=False).hexdigest()

            cached_info = self._load_modelinfo_from_cache(module_hash)
            if cached_info is not None:
                logger.debug(
                    "Loaded model info for class %s.%s from cache",
                    self.module_name,
                    self.class_name,
                )
                return cached_info

            logger.debug(
                "Cache model info for class %s.%s miss. Loading model instead.",
                self.module_name,
                self.class_name,
            )

        # Performed in another process to avoid initializing CUDA
        mi = _run_in_subprocess(
            lambda: _ModelInfo.from_model_cls(self.load_model_cls())
        )
        logger.debug(
            "Loaded model info for class %s.%s", self.module_name, self.class_name
        )

        # save cache file
        if module_hash is not None:
            self._save_modelinfo_to_cache(mi, module_hash)

        return mi

    def load_model_cls(self) -> type[nn.Module]:
        """Import `module_name` and return its `class_name` attribute."""
        module = importlib.import_module(self.module_name)
        return getattr(module, self.class_name)


@lru_cache(maxsize=128)
def _try_load_model_cls(
    model_arch: str,
    model: _BaseRegisteredModel,
) -> type[nn.Module] | None:
    """
    Load the model class behind `model`, or return `None` if loading fails.

    The current platform gets to verify `model_arch` first; an error raised by
    that check is not caught here. Results, including `None` after a failed
    load, are memoized by `lru_cache`.
    """
    from vllm.platforms import current_platform

    current_platform.verify_model_arch(model_arch)

    try:
        loaded_cls = model.load_model_cls()
    except Exception:
        logger.exception("Error in loading model architecture '%s'", model_arch)
        return None
    return loaded_cls
862
863


864
865
866
867
@lru_cache(maxsize=128)
def _try_inspect_model_cls(
    model_arch: str,
    model: _BaseRegisteredModel,
) -> _ModelInfo | None:
    """
    Inspect the model behind `model`, or return `None` if inspection fails.

    Failures are logged with their traceback. Results, including `None` after
    a failed inspection, are memoized by `lru_cache`.
    """
    try:
        info = model.inspect_model_cls()
    except Exception:
        logger.exception("Error in inspecting model architecture '%s'", model_arch)
        return None
    return info
874
875


876
877
878
@dataclass
class _ModelRegistry:
    """
    Maps model architecture names to registered model entries.

    Besides registration, the registry resolves the `architectures` listed in
    a model config to a model class (`resolve_model_cls`) or to its
    `_ModelInfo` (`inspect_model_cls`), optionally falling back to the
    Transformers backend.
    """

    # Keyed by model_arch
    models: dict[str, _BaseRegisteredModel] = field(default_factory=dict)

    def get_supported_archs(self) -> Set[str]:
        """Return a view of all registered architecture names."""
        return self.models.keys()

    def register_model(
        self,
        model_arch: str,
        model_cls: type[nn.Module] | str,
    ) -> None:
        """
        Register an external model to be used in vLLM.

        `model_cls` can be either:

        - A [`torch.nn.Module`][] class directly referencing the model.
        - A string in the format `<module>:<class>` which can be used to
          lazily import the model. This is useful to avoid initializing CUDA
          when importing the model and thus the related error
          `RuntimeError: Cannot re-initialize CUDA in forked subprocess`.

        Registering an architecture that already exists logs a warning and
        overwrites the previous entry.

        Raises:
            TypeError: If `model_arch` is not a string, or `model_cls` is
                neither a string nor a `torch.nn.Module` subclass.
            ValueError: If `model_cls` is a string that is not in the
                `<module>:<class>` format.
        """
        if not isinstance(model_arch, str):
            msg = f"`model_arch` should be a string, not a {type(model_arch)}"
            raise TypeError(msg)

        if model_arch in self.models:
            logger.warning(
                "Model architecture %s is already registered, and will be "
                "overwritten by the new model class %s.",
                model_arch,
                model_cls,
            )

        if isinstance(model_cls, str):
            split_str = model_cls.split(":")
            if len(split_str) != 2:
                msg = "Expected a string in the format `<module>:<class>`"
                raise ValueError(msg)

            model = _LazyRegisteredModel(*split_str)
        elif isinstance(model_cls, type) and issubclass(model_cls, nn.Module):
            model = _RegisteredModel.from_model_cls(model_cls)
        else:
            # Report the type of the offending argument (`model_cls`);
            # `model_arch` has already been validated to be a string above.
            msg = (
                "`model_cls` should be a string or PyTorch model class, "
                f"not a {type(model_cls)}"
            )
            raise TypeError(msg)

        self.models[model_arch] = model

    def _raise_for_unsupported(self, architectures: list[str]):
        """
        Raise the most specific `ValueError` for unresolved `architectures`.

        This method never returns. In order of precedence the error says that
        a registered architecture failed to be inspected, that an architecture
        was supported by an older vLLM version, or that the architectures are
        not supported at all.
        """
        all_supported_archs = self.get_supported_archs()

        if any(arch in all_supported_archs for arch in architectures):
            raise ValueError(
                f"Model architectures {architectures} failed "
                "to be inspected. Please check the logs for more details."
            )

        for arch in architectures:
            if arch in _PREVIOUSLY_SUPPORTED_MODELS:
                previous_version = _PREVIOUSLY_SUPPORTED_MODELS[arch]

                raise ValueError(
                    f"Model architecture {arch} was supported in vLLM until "
                    f"v{previous_version}, and is not supported anymore. "
                    "Please use an older version of vLLM if you want to "
                    "use this model architecture."
                )

        raise ValueError(
            f"Model architectures {architectures} are not supported for now. "
            f"Supported architectures: {all_supported_archs}"
        )

    def _try_load_model_cls(self, model_arch: str) -> type[nn.Module] | None:
        """Load the class registered for `model_arch`; `None` if unknown or failed."""
        if model_arch not in self.models:
            return None

        return _try_load_model_cls(model_arch, self.models[model_arch])

    def _try_inspect_model_cls(self, model_arch: str) -> _ModelInfo | None:
        """Inspect the entry for `model_arch`; `None` if unknown or failed."""
        if model_arch not in self.models:
            return None

        return _try_inspect_model_cls(model_arch, self.models[model_arch])

    def _try_resolve_transformers(
        self,
        architecture: str,
        model_config: ModelConfig,
    ) -> str | None:
        """
        Resolve `architecture` to an architecture of the Transformers backend.

        `architecture` is returned unchanged when it already is one of
        `_TRANSFORMERS_BACKEND_MODELS`. Otherwise the model class is looked up
        as an attribute of the `transformers` package, falling back to the
        `AutoModel*` entries of the HF config's `auto_map`, and the result of
        `model_config._get_transformers_backend_cls()` is returned.

        Returns:
            The resolved architecture name, or `None` when no compatible model
            class is found and `model_config.model_impl` is not
            `"transformers"`.

        Raises:
            ValueError: If `model_config.model_impl` is `"transformers"` and
                the model class cannot be found or is not backend compatible.
        """
        if architecture in _TRANSFORMERS_BACKEND_MODELS:
            return architecture

        auto_map: dict[str, str] = (
            getattr(model_config.hf_config, "auto_map", None) or {}
        )

        # Make sure that config class is always initialized before model class,
        # otherwise the model class won't be able to access the config class,
        # the expected auto_map should have correct order like:
        # "auto_map": {
        #     "AutoConfig": "<your-repo-name>--<config-name>",
        #     "AutoModel": "<your-repo-name>--<config-name>",
        #     "AutoModelFor<Task>": "<your-repo-name>--<config-name>",
        # },
        for prefix in ("AutoConfig", "AutoModel"):
            for name, module in auto_map.items():
                if name.startswith(prefix):
                    try_get_class_from_dynamic_module(
                        module,
                        model_config.model,
                        revision=model_config.revision,
                        trust_remote_code=model_config.trust_remote_code,
                        warn_on_fail=False,
                    )

        model_module = getattr(transformers, architecture, None)

        if model_module is None:
            for name, module in auto_map.items():
                if name.startswith("AutoModel"):
                    model_module = try_get_class_from_dynamic_module(
                        module,
                        model_config.model,
                        revision=model_config.revision,
                        trust_remote_code=model_config.trust_remote_code,
                        warn_on_fail=True,
                    )
                    if model_module is not None:
                        break
            else:
                # Reached when no `AutoModel*` entry yielded a class.
                if model_config.model_impl != "transformers":
                    return None

                raise ValueError(
                    f"Cannot find model module. {architecture!r} is not a "
                    "registered model in the Transformers library (only "
                    "relevant if the model is meant to be in Transformers) "
                    "and 'AutoModel' is not present in the model config's "
                    "'auto_map' (relevant if the model is custom)."
                )

        if not model_module.is_backend_compatible():
            if model_config.model_impl != "transformers":
                return None

            raise ValueError(
                f"The Transformers implementation of {architecture!r} "
                "is not compatible with vLLM."
            )

        return model_config._get_transformers_backend_cls()

    def _normalize_arch(
        self,
        architecture: str,
        model_config: ModelConfig,
    ) -> str:
        """
        Map `architecture` to the registered base architecture to convert.

        Registered architectures are returned unchanged. Otherwise the suffix
        matched by `try_match_architecture_defaults` is replaced by each
        suffix from `iter_architecture_defaults()` in turn, and the first
        resulting name that is registered is returned. If nothing matches,
        `architecture` is returned unchanged.
        """
        if architecture in self.models:
            return architecture

        # This may be called in order to resolve runner_type and convert_type
        # in the first place, in which case we consider the default match
        match = try_match_architecture_defaults(
            architecture,
            runner_type=getattr(model_config, "runner_type", None),
            convert_type=getattr(model_config, "convert_type", None),
        )
        if match:
            suffix, _ = match

            # Get the name of the base model to convert
            for repl_suffix, _ in iter_architecture_defaults():
                base_arch = architecture.replace(suffix, repl_suffix)
                if base_arch in self.models:
                    return base_arch

        return architecture

    def inspect_model_cls(
        self,
        architectures: str | list[str],
        model_config: ModelConfig,
    ) -> tuple[_ModelInfo, str]:
        """
        Return `(model_info, architecture)` for the first inspectable match.

        The resolution order mirrors `resolve_model_cls`. For architectures
        resolved through the registry, the returned name is the architecture
        as given by the caller, not its normalized form.

        Raises:
            ValueError: If `architectures` is empty or none of them can be
                inspected.
        """
        if isinstance(architectures, str):
            architectures = [architectures]
        if not architectures:
            raise ValueError("No model architectures are specified")

        # Require transformers impl
        if model_config.model_impl == "transformers":
            arch = self._try_resolve_transformers(architectures[0], model_config)
            if arch is not None:
                model_info = self._try_inspect_model_cls(arch)
                if model_info is not None:
                    return (model_info, arch)
        elif model_config.model_impl == "terratorch":
            arch = "Terratorch"
            model_info = self._try_inspect_model_cls(arch)
            # Only return on success, like `resolve_model_cls` does, so that a
            # failed inspection ends in a descriptive error rather than a
            # `None` model info being handed to the caller.
            if model_info is not None:
                return (model_info, arch)

        # Fallback to transformers impl (after resolving convert_type)
        if (
            all(arch not in self.models for arch in architectures)
            and model_config.model_impl == "auto"
            and getattr(model_config, "convert_type", "none") == "none"
        ):
            arch = self._try_resolve_transformers(architectures[0], model_config)
            if arch is not None:
                model_info = self._try_inspect_model_cls(arch)
                if model_info is not None:
                    return (model_info, arch)

        for arch in architectures:
            normalized_arch = self._normalize_arch(arch, model_config)
            model_info = self._try_inspect_model_cls(normalized_arch)
            if model_info is not None:
                return (model_info, arch)

        # Fallback to transformers impl (before resolving runner_type)
        if (
            all(arch not in self.models for arch in architectures)
            and model_config.model_impl == "auto"
        ):
            arch = self._try_resolve_transformers(architectures[0], model_config)
            if arch is not None:
                model_info = self._try_inspect_model_cls(arch)
                if model_info is not None:
                    return (model_info, arch)

        return self._raise_for_unsupported(architectures)

    def resolve_model_cls(
        self,
        architectures: str | list[str],
        model_config: ModelConfig,
    ) -> tuple[type[nn.Module], str]:
        """
        Return `(model_cls, architecture)` for the first loadable match.

        The resolution order mirrors `inspect_model_cls`. For architectures
        resolved through the registry, the returned name is the architecture
        as given by the caller, not its normalized form.

        Raises:
            ValueError: If `architectures` is empty or none of them can be
                loaded.
        """
        if isinstance(architectures, str):
            architectures = [architectures]
        if not architectures:
            raise ValueError("No model architectures are specified")

        # Require transformers impl
        if model_config.model_impl == "transformers":
            arch = self._try_resolve_transformers(architectures[0], model_config)
            if arch is not None:
                model_cls = self._try_load_model_cls(arch)
                if model_cls is not None:
                    return (model_cls, arch)
        elif model_config.model_impl == "terratorch":
            arch = "Terratorch"
            model_cls = self._try_load_model_cls(arch)
            if model_cls is not None:
                return (model_cls, arch)

        # Fallback to transformers impl (after resolving convert_type)
        if (
            all(arch not in self.models for arch in architectures)
            and model_config.model_impl == "auto"
            and getattr(model_config, "convert_type", "none") == "none"
        ):
            arch = self._try_resolve_transformers(architectures[0], model_config)
            if arch is not None:
                model_cls = self._try_load_model_cls(arch)
                if model_cls is not None:
                    return (model_cls, arch)

        for arch in architectures:
            normalized_arch = self._normalize_arch(arch, model_config)
            model_cls = self._try_load_model_cls(normalized_arch)
            if model_cls is not None:
                return (model_cls, arch)

        # Fallback to transformers impl (before resolving runner_type)
        if (
            all(arch not in self.models for arch in architectures)
            and model_config.model_impl == "auto"
        ):
            arch = self._try_resolve_transformers(architectures[0], model_config)
            if arch is not None:
                model_cls = self._try_load_model_cls(arch)
                if model_cls is not None:
                    return (model_cls, arch)

        return self._raise_for_unsupported(architectures)

    def is_text_generation_model(
        self,
        architectures: str | list[str],
        model_config: ModelConfig,
    ) -> bool:
        """Return the `is_text_generation_model` flag of the resolved model."""
        model_info, _ = self.inspect_model_cls(architectures, model_config)
        return model_info.is_text_generation_model

    def is_pooling_model(
        self,
        architectures: str | list[str],
        model_config: ModelConfig,
    ) -> bool:
        """Return the `is_pooling_model` flag of the resolved model."""
        model_info, _ = self.inspect_model_cls(architectures, model_config)
        return model_info.is_pooling_model

    def is_multimodal_model(
        self,
        architectures: str | list[str],
        model_config: ModelConfig,
    ) -> bool:
        """Return the `supports_multimodal` flag of the resolved model."""
        model_info, _ = self.inspect_model_cls(architectures, model_config)
        return model_info.supports_multimodal

    def is_multimodal_raw_input_only_model(
        self,
        architectures: str | list[str],
        model_config: ModelConfig,
    ) -> bool:
        """Return the `supports_multimodal_raw_input_only` flag of the model."""
        model_info, _ = self.inspect_model_cls(architectures, model_config)
        return model_info.supports_multimodal_raw_input_only

    def is_pp_supported_model(
        self,
        architectures: str | list[str],
        model_config: ModelConfig,
    ) -> bool:
        """Return the `supports_pp` flag of the resolved model."""
        model_info, _ = self.inspect_model_cls(architectures, model_config)
        return model_info.supports_pp

    def model_has_inner_state(
        self,
        architectures: str | list[str],
        model_config: ModelConfig,
    ) -> bool:
        """Return the `has_inner_state` flag of the resolved model."""
        model_info, _ = self.inspect_model_cls(architectures, model_config)
        return model_info.has_inner_state

    def is_attention_free_model(
        self,
        architectures: str | list[str],
        model_config: ModelConfig,
    ) -> bool:
        """Return the `is_attention_free` flag of the resolved model."""
        model_info, _ = self.inspect_model_cls(architectures, model_config)
        return model_info.is_attention_free

    def is_hybrid_model(
        self,
        architectures: str | list[str],
        model_config: ModelConfig,
    ) -> bool:
        """Return the `is_hybrid` flag of the resolved model."""
        model_info, _ = self.inspect_model_cls(architectures, model_config)
        return model_info.is_hybrid

    def is_noops_model(
        self,
        architectures: str | list[str],
        model_config: ModelConfig,
    ) -> bool:
        """Return the `has_noops` flag of the resolved model."""
        model_info, _ = self.inspect_model_cls(architectures, model_config)
        return model_info.has_noops

    def is_transcription_model(
        self,
        architectures: str | list[str],
        model_config: ModelConfig,
    ) -> bool:
        """Return the `supports_transcription` flag of the resolved model."""
        model_info, _ = self.inspect_model_cls(architectures, model_config)
        return model_info.supports_transcription

    def is_transcription_only_model(
        self,
        architectures: str | list[str],
        model_config: ModelConfig,
    ) -> bool:
        """Return the `supports_transcription_only` flag of the resolved model."""
        model_info, _ = self.inspect_model_cls(architectures, model_config)
        return model_info.supports_transcription_only

1255

1256
1257
1258
1259
1260
1261
1262
1263
1264
# Global registry pre-populated with every architecture in `_VLLM_MODELS`.
# The entries are lazy: only module and class names are stored here, the
# model modules themselves are imported on demand.
ModelRegistry = _ModelRegistry(
    models={
        arch_name: _LazyRegisteredModel(
            f"vllm.model_executor.models.{module_stem}",
            class_name,
        )
        for arch_name, (module_stem, class_name) in _VLLM_MODELS.items()
    }
)
1265
1266
1267
1268
1269

_T = TypeVar("_T")


def _run_in_subprocess(fn: Callable[[], _T]) -> _T:
    """
    Run `fn` through `_SUBPROCESS_COMMAND` and return its result.

    `fn` and the path of an output file are serialized with `cloudpickle` and
    passed to the subprocess on stdin. The result is then unpickled from that
    output file.

    Raises:
        RuntimeError: If the subprocess exits with a non-zero return code; the
            message contains the decoded stderr of the subprocess.
    """
    # `cloudpickle` allows pickling lambda functions directly
    import cloudpickle

    # NOTE: We use a temporary directory instead of a temporary file to avoid
    # issues like https://stackoverflow.com/questions/23212435/permission-denied-to-write-to-my-temporary-file
    with tempfile.TemporaryDirectory() as workdir:
        result_path = os.path.join(workdir, "registry_output.tmp")
        payload = cloudpickle.dumps((fn, result_path))

        # cannot use `sys.executable __file__` here because the script
        # contains relative imports
        completed = subprocess.run(
            _SUBPROCESS_COMMAND,
            input=payload,
            capture_output=True,
        )

        # check if the subprocess is successful
        try:
            completed.check_returncode()
        except Exception as e:
            # wrap raised exception to provide more information
            stderr_text = completed.stderr.decode()
            raise RuntimeError(
                f"Error raised in subprocess:\n{stderr_text}"
            ) from e

        with open(result_path, "rb") as result_file:
            return pickle.load(result_file)


def _run() -> None:
    """
    Entry point used when this module is executed as a script.

    Reads a pickled `(fn, output_file)` pair from stdin, calls `fn` after
    loading the general plugins, and writes the pickled result to
    `output_file`.
    """
    # Setup plugins
    from vllm.plugins import load_general_plugins

    load_general_plugins()

    request = sys.stdin.buffer.read()
    fn, output_file = pickle.loads(request)

    result = fn()

    with open(output_file, "wb") as sink:
        sink.write(pickle.dumps(result))


if __name__ == "__main__":
    _run()