registry.py 53.7 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
4
5
6
"""
Whenever you add an architecture to this file, please also update
`tests/models/registry.py` with example HuggingFace models for it.
"""
7

8
import importlib
9
import json
10
import os
11
import pickle
12
13
import subprocess
import sys
14
import tempfile
15
from abc import ABC, abstractmethod
16
from collections.abc import Callable, Set
17
from dataclasses import asdict, dataclass, field
18
from functools import lru_cache
19
from pathlib import Path
20
from typing import TYPE_CHECKING, Any, TypeVar
21
22

import torch.nn as nn
23
import transformers
24

25
from vllm import envs
26
27
28
29
30
from vllm.config import (
    ModelConfig,
    iter_architecture_defaults,
    try_match_architecture_defaults,
)
31
from vllm.logger import init_logger
32
from vllm.logging_utils import logtime
33
from vllm.tasks import ScoreType
34
from vllm.transformers_utils.dynamic_module import try_get_class_from_dynamic_module
35
from vllm.utils.hashing import safe_hash
36

37
38
# The literal type aliases are only imported for static type checkers; at
# runtime the same names are bound to `Any` so annotations that mention them
# (e.g. dataclass fields further down) still resolve.
if TYPE_CHECKING:
    from vllm.config.model import AttnTypeStr
    from vllm.config.pooler import SequencePoolingType, TokenPoolingType
else:
    AttnTypeStr = Any
    SequencePoolingType = Any
    TokenPoolingType = Any
44
45


46
47
48
49
50
from .interfaces import (
    has_inner_state,
    has_noops,
    is_attention_free,
    is_hybrid,
Patrick von Platen's avatar
Patrick von Platen committed
51
    requires_raw_input_tokens,
52
    supports_mamba_prefix_caching,
53
54
55
56
57
58
59
    supports_multimodal,
    supports_multimodal_encoder_tp_data,
    supports_multimodal_raw_input_only,
    supports_pp,
    supports_transcription,
)
from .interfaces_base import (
60
    get_attn_type,
61
62
    get_default_seq_pooling_type,
    get_default_tok_pooling_type,
63
    get_score_type,
64
65
66
    is_pooling_model,
    is_text_generation_model,
)
67
68
69

# Module-level logger, created through vLLM's `init_logger` helper and keyed
# by this module's `__name__`.
logger = init_logger(__name__)

70
71
# Text-generation architectures.
# Each key is an architecture name; each value is a pair of strings,
# presumably `(implementing module, model class)` -- the code that consumes
# the pair is outside this table, so confirm there before relying on it.
# Several keys intentionally alias the same implementation.
_TEXT_GENERATION_MODELS = {
    # [Decoder-only]
    "AfmoeForCausalLM": ("afmoe", "AfmoeForCausalLM"),
    "ApertusForCausalLM": ("apertus", "ApertusForCausalLM"),
    "AquilaModel": ("llama", "LlamaForCausalLM"),
    "AquilaForCausalLM": ("llama", "LlamaForCausalLM"),  # AquilaChat2
    "ArceeForCausalLM": ("arcee", "ArceeForCausalLM"),
    "ArcticForCausalLM": ("arctic", "ArcticForCausalLM"),
    "AXK1ForCausalLM": ("AXK1", "AXK1ForCausalLM"),
    # baichuan-7b, upper case 'C' in the class name
    "BaiChuanForCausalLM": ("baichuan", "BaiChuanForCausalLM"),
    # baichuan-13b, lower case 'c' in the class name
    "BaichuanForCausalLM": ("baichuan", "BaichuanForCausalLM"),
    "BailingMoeForCausalLM": ("bailing_moe", "BailingMoeForCausalLM"),
    "BailingMoeV2ForCausalLM": ("bailing_moe", "BailingMoeV2ForCausalLM"),
    "BailingMoeV2_5ForCausalLM": ("bailing_moe_linear", "BailingMoeV25ForCausalLM"),
    "BambaForCausalLM": ("bamba", "BambaForCausalLM"),
    "BloomForCausalLM": ("bloom", "BloomForCausalLM"),
    "ChatGLMModel": ("chatglm", "ChatGLMForCausalLM"),
    "ChatGLMForConditionalGeneration": ("chatglm", "ChatGLMForCausalLM"),
    "CohereForCausalLM": ("commandr", "CohereForCausalLM"),
    "Cohere2ForCausalLM": ("commandr", "CohereForCausalLM"),
    "CwmForCausalLM": ("llama", "LlamaForCausalLM"),
    "DbrxForCausalLM": ("dbrx", "DbrxForCausalLM"),
    "DeciLMForCausalLM": ("nemotron_nas", "DeciLMForCausalLM"),
    "DeepseekForCausalLM": ("deepseek_v2", "DeepseekForCausalLM"),
    "DeepseekV2ForCausalLM": ("deepseek_v2", "DeepseekV2ForCausalLM"),
    "DeepseekV3ForCausalLM": ("deepseek_v2", "DeepseekV3ForCausalLM"),
    "DeepseekV32ForCausalLM": ("deepseek_v2", "DeepseekV3ForCausalLM"),
    "Dots1ForCausalLM": ("dots1", "Dots1ForCausalLM"),
    "Ernie4_5ForCausalLM": ("ernie45", "Ernie4_5ForCausalLM"),
    "Ernie4_5_MoeForCausalLM": ("ernie45_moe", "Ernie4_5_MoeForCausalLM"),
    "ExaoneForCausalLM": ("exaone", "ExaoneForCausalLM"),
    "Exaone4ForCausalLM": ("exaone4", "Exaone4ForCausalLM"),
    "ExaoneMoEForCausalLM": ("exaone_moe", "ExaoneMoeForCausalLM"),
    "Fairseq2LlamaForCausalLM": ("fairseq2_llama", "Fairseq2LlamaForCausalLM"),
    "FalconForCausalLM": ("falcon", "FalconForCausalLM"),
    "FalconMambaForCausalLM": ("mamba", "MambaForCausalLM"),
    "FalconH1ForCausalLM": ("falcon_h1", "FalconH1ForCausalLM"),
    "FlexOlmoForCausalLM": ("flex_olmo", "FlexOlmoForCausalLM"),
    "GemmaForCausalLM": ("gemma", "GemmaForCausalLM"),
    "Gemma2ForCausalLM": ("gemma2", "Gemma2ForCausalLM"),
    "Gemma3ForCausalLM": ("gemma3", "Gemma3ForCausalLM"),
    "Rnj1ForCausalLM": ("rnj1", "Rnj1ForCausalLM"),
    "Gemma3nForCausalLM": ("gemma3n", "Gemma3nForCausalLM"),
    "Gemma4ForCausalLM": ("gemma4", "Gemma4ForCausalLM"),
    "Qwen3NextForCausalLM": ("qwen3_next", "Qwen3NextForCausalLM"),
    "GlmForCausalLM": ("glm", "GlmForCausalLM"),
    "Glm4ForCausalLM": ("glm4", "Glm4ForCausalLM"),
    "Glm4MoeForCausalLM": ("glm4_moe", "Glm4MoeForCausalLM"),
    "Glm4MoeLiteForCausalLM": ("glm4_moe_lite", "Glm4MoeLiteForCausalLM"),
    "GlmMoeDsaForCausalLM": ("deepseek_v2", "GlmMoeDsaForCausalLM"),
    "GptOssForCausalLM": ("gpt_oss", "GptOssForCausalLM"),
    "GPT2LMHeadModel": ("gpt2", "GPT2LMHeadModel"),
    "GPTBigCodeForCausalLM": ("gpt_bigcode", "GPTBigCodeForCausalLM"),
    "GPTJForCausalLM": ("gpt_j", "GPTJForCausalLM"),
    "GPTNeoXForCausalLM": ("gpt_neox", "GPTNeoXForCausalLM"),
    "GraniteForCausalLM": ("granite", "GraniteForCausalLM"),
    "GraniteMoeForCausalLM": ("granitemoe", "GraniteMoeForCausalLM"),
    "GraniteMoeHybridForCausalLM": ("granitemoehybrid", "GraniteMoeHybridForCausalLM"),
    "GraniteMoeSharedForCausalLM": ("granitemoeshared", "GraniteMoeSharedForCausalLM"),
    "GritLM": ("gritlm", "GritLM"),
    "Grok1ModelForCausalLM": ("grok1", "GrokForCausalLM"),
    "Grok1ForCausalLM": ("grok1", "GrokForCausalLM"),
    "HunYuanMoEV1ForCausalLM": ("hunyuan_v1", "HunYuanMoEV1ForCausalLM"),
    "HunYuanDenseV1ForCausalLM": ("hunyuan_v1", "HunYuanDenseV1ForCausalLM"),
    "HYV3ForCausalLM": ("hy_v3", "HYV3ForCausalLM"),
    "HCXVisionForCausalLM": ("hyperclovax_vision", "HCXVisionForCausalLM"),
    "HCXVisionV2ForCausalLM": ("hyperclovax_vision_v2", "HCXVisionV2ForCausalLM"),
    "HyperCLOVAXForCausalLM": ("hyperclovax", "HyperCLOVAXForCausalLM"),
    "InternLMForCausalLM": ("llama", "LlamaForCausalLM"),
    "InternLM2ForCausalLM": ("internlm2", "InternLM2ForCausalLM"),
    "InternLM2VEForCausalLM": ("internlm2_ve", "InternLM2VEForCausalLM"),
    "InternLM3ForCausalLM": ("llama", "LlamaForCausalLM"),
    "IQuestCoderForCausalLM": ("llama", "LlamaForCausalLM"),
    "IQuestLoopCoderForCausalLM": ("iquest_loopcoder", "IQuestLoopCoderForCausalLM"),
    "JAISLMHeadModel": ("jais", "JAISLMHeadModel"),
    "Jais2ForCausalLM": ("jais2", "Jais2ForCausalLM"),
    "JambaForCausalLM": ("jamba", "JambaForCausalLM"),
    "KimiLinearForCausalLM": ("kimi_linear", "KimiLinearForCausalLM"),
    "Lfm2ForCausalLM": ("lfm2", "Lfm2ForCausalLM"),
    "Lfm2MoeForCausalLM": ("lfm2_moe", "Lfm2MoeForCausalLM"),
    "LlamaForCausalLM": ("llama", "LlamaForCausalLM"),
    "Llama4ForCausalLM": ("llama4", "Llama4ForCausalLM"),
    # For decapoda-research/llama-*
    "LLaMAForCausalLM": ("llama", "LlamaForCausalLM"),
    "LongcatFlashForCausalLM": ("longcat_flash", "LongcatFlashForCausalLM"),
    "MambaForCausalLM": ("mamba", "MambaForCausalLM"),
    "Mamba2ForCausalLM": ("mamba2", "Mamba2ForCausalLM"),
    "MiniCPMForCausalLM": ("minicpm", "MiniCPMForCausalLM"),
    "MiniCPM3ForCausalLM": ("minicpm3", "MiniCPM3ForCausalLM"),
    "MiniMaxForCausalLM": ("minimax_text_01", "MiniMaxText01ForCausalLM"),
    "MiniMaxText01ForCausalLM": ("minimax_text_01", "MiniMaxText01ForCausalLM"),
    "MiniMaxM1ForCausalLM": ("minimax_text_01", "MiniMaxText01ForCausalLM"),
    "MiniMaxM2ForCausalLM": ("minimax_m2", "MiniMaxM2ForCausalLM"),
    "Ministral3ForCausalLM": ("mistral", "MistralForCausalLM"),
    "MistralForCausalLM": ("mistral", "MistralForCausalLM"),
    "MistralLarge3ForCausalLM": ("mistral_large_3", "MistralLarge3ForCausalLM"),
    "MixtralForCausalLM": ("mixtral", "MixtralForCausalLM"),
    # transformers's mpt class has lower case
    "MptForCausalLM": ("mpt", "MPTForCausalLM"),
    "MPTForCausalLM": ("mpt", "MPTForCausalLM"),
    "MiMoForCausalLM": ("mimo", "MiMoForCausalLM"),
    "MiMoV2FlashForCausalLM": ("mimo_v2_flash", "MiMoV2FlashForCausalLM"),
    "NemotronForCausalLM": ("nemotron", "NemotronForCausalLM"),
    "NemotronHForCausalLM": ("nemotron_h", "NemotronHForCausalLM"),
    "NemotronHPuzzleForCausalLM": ("nemotron_h", "NemotronHForCausalLM"),
    "OlmoForCausalLM": ("olmo", "OlmoForCausalLM"),
    "Olmo2ForCausalLM": ("olmo2", "Olmo2ForCausalLM"),
    "Olmo3ForCausalLM": ("olmo2", "Olmo2ForCausalLM"),
    "OlmoHybridForCausalLM": ("olmo_hybrid", "OlmoHybridForCausalLM"),
    "OlmoeForCausalLM": ("olmoe", "OlmoeForCausalLM"),
    "OPTForCausalLM": ("opt", "OPTForCausalLM"),
    "OrionForCausalLM": ("orion", "OrionForCausalLM"),
    "OuroForCausalLM": ("ouro", "OuroForCausalLM"),
    "PanguEmbeddedForCausalLM": ("openpangu", "PanguEmbeddedForCausalLM"),
    "PanguProMoEV2ForCausalLM": ("openpangu", "PanguProMoEV2ForCausalLM"),
    "PanguUltraMoEForCausalLM": ("openpangu", "PanguUltraMoEForCausalLM"),
    "Param2MoEForCausalLM": ("param2moe", "Param2MoEForCausalLM"),
    "PersimmonForCausalLM": ("persimmon", "PersimmonForCausalLM"),
    "PhiForCausalLM": ("phi", "PhiForCausalLM"),
    "Phi3ForCausalLM": ("phi3", "Phi3ForCausalLM"),
    "PhiMoEForCausalLM": ("phimoe", "PhiMoEForCausalLM"),
    "Plamo2ForCausalLM": ("plamo2", "Plamo2ForCausalLM"),
    "Plamo3ForCausalLM": ("plamo3", "Plamo3ForCausalLM"),
    "QWenLMHeadModel": ("qwen", "QWenLMHeadModel"),
    "Qwen2ForCausalLM": ("qwen2", "Qwen2ForCausalLM"),
    "Qwen2MoeForCausalLM": ("qwen2_moe", "Qwen2MoeForCausalLM"),
    "Qwen3ForCausalLM": ("qwen3", "Qwen3ForCausalLM"),
    "Qwen3MoeForCausalLM": ("qwen3_moe", "Qwen3MoeForCausalLM"),
    "RWForCausalLM": ("falcon", "FalconForCausalLM"),
    "SarvamMoEForCausalLM": ("sarvam", "SarvamMoEForCausalLM"),
    "SarvamMLAForCausalLM": ("sarvam", "SarvamMLAForCausalLM"),
    "SeedOssForCausalLM": ("seed_oss", "SeedOssForCausalLM"),
    "Step1ForCausalLM": ("step1", "Step1ForCausalLM"),
    "Step3TextForCausalLM": ("step3_text", "Step3TextForCausalLM"),
    "Step3p5ForCausalLM": ("step3p5", "Step3p5ForCausalLM"),
    "StableLMEpochForCausalLM": ("stablelm", "StablelmForCausalLM"),
    "StableLmForCausalLM": ("stablelm", "StablelmForCausalLM"),
    "Starcoder2ForCausalLM": ("starcoder2", "Starcoder2ForCausalLM"),
    "SolarForCausalLM": ("solar", "SolarForCausalLM"),
    "TeleChatForCausalLM": ("telechat2", "TeleChat2ForCausalLM"),
    "TeleChat2ForCausalLM": ("telechat2", "TeleChat2ForCausalLM"),
    "TeleChat3ForCausalLM": ("llama", "LlamaForCausalLM"),
    "TeleFLMForCausalLM": ("teleflm", "TeleFLMForCausalLM"),
    "XverseForCausalLM": ("llama", "LlamaForCausalLM"),
    "Zamba2ForCausalLM": ("zamba2", "Zamba2ForCausalLM"),
}

# Embedding architectures.
# Each key is an architecture name; each value is a `(mod, arch)` string pair.
# Every `_TEXT_GENERATION_MODELS` entry whose class is "LlamaForCausalLM" is
# spliced in as well, so all Llama-based aliases are registered here too.
_EMBEDDING_MODELS = {
    # [Text-only]
    "BertModel": ("bert", "BertEmbeddingModel"),
    "BertSpladeSparseEmbeddingModel": ("bert", "BertSpladeSparseEmbeddingModel"),
    "ErnieModel": ("ernie", "ErnieEmbeddingModel"),
    "BgeM3EmbeddingModel": ("roberta", "BgeM3EmbeddingModel"),
    "DeciLMForCausalLM": ("nemotron_nas", "DeciLMForCausalLM"),
    "Gemma2Model": ("gemma2", "Gemma2ForCausalLM"),
    "Gemma3TextModel": ("gemma3", "Gemma3Model"),
    "GlmForCausalLM": ("glm", "GlmForCausalLM"),
    "GritLM": ("gritlm", "GritLM"),
    "GteModel": ("bert_with_rope", "SnowflakeGteNewModel"),
    "GteNewModel": ("bert_with_rope", "GteNewModel"),
    "JinaEmbeddingsV5Model": ("jina", "JinaEmbeddingsV5Model"),
    "LlamaBidirectionalModel": ("llama", "LlamaBidirectionalModel"),
    "LlamaModel": ("llama", "LlamaForCausalLM"),
    **{
        # Multiple models share the same architecture, so we include them all
        k: (mod, arch)
        for k, (mod, arch) in _TEXT_GENERATION_MODELS.items()
        if arch == "LlamaForCausalLM"
    },
    "MistralModel": ("llama", "LlamaForCausalLM"),
    "ModernBertModel": ("modernbert", "ModernBertModel"),
    "NomicBertModel": ("bert_with_rope", "NomicBertModel"),
    "Phi3ForCausalLM": ("phi3", "Phi3ForCausalLM"),
    "Qwen2Model": ("qwen2", "Qwen2ForCausalLM"),
    "Qwen2ForCausalLM": ("qwen2", "Qwen2ForCausalLM"),
    "RobertaForMaskedLM": ("roberta", "RobertaEmbeddingModel"),
    "RobertaModel": ("roberta", "RobertaEmbeddingModel"),
    "TeleChatForCausalLM": ("telechat2", "TeleChat2ForCausalLM"),
    "TeleChat2ForCausalLM": ("telechat2", "TeleChat2ForCausalLM"),
    "VoyageQwen3BidirectionalEmbedModel": (
        "voyage",
        "VoyageQwen3BidirectionalEmbedModel",
    ),
    "XLMRobertaModel": ("roberta", "RobertaEmbeddingModel"),
    # [Multimodal]
    "CLIPModel": ("clip", "CLIPEmbeddingModel"),
    "ColPaliForRetrieval": ("colpali", "ColPaliModel"),
    "LlamaNemotronVLModel": ("nemotron_vl", "LlamaNemotronVLForEmbedding"),
    "LlavaNextForConditionalGeneration": (
        "llava_next",
        "LlavaNextForConditionalGeneration",
    ),
    "Phi3VForCausalLM": ("phi3v", "Phi3VForCausalLM"),
    "Qwen2VLForConditionalGeneration": ("qwen2_vl", "Qwen2VLForConditionalGeneration"),
    "SiglipModel": ("siglip", "SiglipEmbeddingModel"),
    # Technically Terratorch models work on images, both in
    # input and output. I am adding it here because it piggy-backs on embedding
    # models for the time being.
    "PrithviGeoSpatialMAE": ("terratorch", "Terratorch"),
    "Terratorch": ("terratorch", "Terratorch"),
}

274
275
276
277
278
# Late-interaction architectures.
# Each key is an architecture name; each value is a pair of strings,
# presumably `(implementing module, model class)` -- confirm against the
# code that consumes this table.
_LATE_INTERACTION_MODELS = {
    # [Text-only]
    "HF_ColBERT": ("colbert", "ColBERTModel"),
    "ColBERTModernBertModel": ("colbert", "ColBERTModernBertModel"),
    "ColBERTJinaRobertaModel": ("colbert", "ColBERTJinaRobertaModel"),
    "ColBERTLfm2Model": ("colbert", "ColBERTLfm2Model"),
    "JinaForRanking": ("jina", "JinaForRanking"),
    # [Multimodal]
    "ColModernVBertForRetrieval": ("colmodernvbert", "ColModernVBertForRetrieval"),
    "ColPaliForRetrieval": ("colpali", "ColPaliModel"),
    "ColQwen3": ("colqwen3", "ColQwen3Model"),
    "OpsColQwen3Model": ("colqwen3", "ColQwen3Model"),
    "ColQwen3_5": ("colqwen3_5", "ColQwen3_5Model"),
    "Qwen3VLNemotronEmbedModel": ("colqwen3", "ColQwen3Model"),
}

# Reward-model architectures.
# Each key is an architecture name; each value is a pair of strings,
# presumably `(implementing module, model class)` -- confirm against the
# code that consumes this table.
_REWARD_MODELS = {
    "InternLM2ForRewardModel": (
        "internlm2",
        "InternLM2ForRewardModel",
    ),
    "Qwen2ForRewardModel": (
        "qwen2_rm",
        "Qwen2ForRewardModel",
    ),
    "Qwen2ForProcessRewardModel": (
        "qwen2_rm",
        "Qwen2ForProcessRewardModel",
    ),
}

# Token-classification architectures.
# Each key is an architecture name; each value is a pair of strings,
# presumably `(implementing module, model class)` -- confirm against the
# code that consumes this table.
_TOKEN_CLASSIFICATION_MODELS = {
    "BertForTokenClassification": ("bert", "BertForTokenClassification"),
    "ErnieForTokenClassification": ("ernie", "ErnieForTokenClassification"),
    "ModernBertForTokenClassification": (
        "modernbert",
        "ModernBertForTokenClassification",
    ),
    "Qwen3ASRForcedAlignerForTokenClassification": (
        "qwen3_asr_forced_aligner",
        "Qwen3ASRForcedAlignerForTokenClassification",
    ),
}

# Sequence-classification architectures.
# Each key is an architecture name; each value is a pair of strings,
# presumably `(implementing module, model class)` -- confirm against the
# code that consumes this table.
_SEQUENCE_CLASSIFICATION_MODELS = {
    "BertForSequenceClassification": ("bert", "BertForSequenceClassification"),
    "GPT2ForSequenceClassification": ("gpt2", "GPT2ForSequenceClassification"),
    "ErnieForSequenceClassification": ("ernie", "ErnieForSequenceClassification"),
    "GteNewForSequenceClassification": (
        "bert_with_rope",
        "GteNewForSequenceClassification",
    ),
    "JambaForSequenceClassification": ("jamba", "JambaForSequenceClassification"),
    "LlamaBidirectionalForSequenceClassification": (
        "llama",
        "LlamaBidirectionalForSequenceClassification",
    ),
    "ModernBertForSequenceClassification": (
        "modernbert",
        "ModernBertForSequenceClassification",
    ),
    "RobertaForSequenceClassification": ("roberta", "RobertaForSequenceClassification"),
    "XLMRobertaForSequenceClassification": (
        "roberta",
        "RobertaForSequenceClassification",
    ),
    # [Multimodal]
    "JinaVLForRanking": ("jina_vl", "JinaVLForSequenceClassification"),
    "LlamaNemotronVLForSequenceClassification": (
        "nemotron_vl",
        "LlamaNemotronVLForSequenceClassification",
    ),
}

339
# Multimodal architectures (decoder-only first, then encoder-decoder).
# Each key is an architecture name; each value is a pair of strings,
# presumably `(implementing module, model class)` -- confirm against the
# code that consumes this table.
_MULTIMODAL_MODELS = {
    # [Decoder-only]
    "AriaForConditionalGeneration": ("aria", "AriaForConditionalGeneration"),
    "AudioFlamingo3ForConditionalGeneration": (
        "audioflamingo3",
        "AudioFlamingo3ForConditionalGeneration",
    ),
    "MusicFlamingoForConditionalGeneration": (
        "musicflamingo",
        "MusicFlamingoForConditionalGeneration",
    ),
    "AyaVisionForConditionalGeneration": (
        "aya_vision",
        "AyaVisionForConditionalGeneration",
    ),
    "BagelForConditionalGeneration": ("bagel", "BagelForConditionalGeneration"),
    "BeeForConditionalGeneration": ("bee", "BeeForConditionalGeneration"),
    "Blip2ForConditionalGeneration": ("blip2", "Blip2ForConditionalGeneration"),
    "ChameleonForConditionalGeneration": (
        "chameleon",
        "ChameleonForConditionalGeneration",
    ),
    "Cheers": ("cheers", "CheersForConditionalGeneration"),
    "CheersForConditionalGeneration": ("cheers", "CheersForConditionalGeneration"),
    "Cohere2VisionForConditionalGeneration": (
        "cohere2_vision",
        "Cohere2VisionForConditionalGeneration",
    ),
    "DeepseekVLV2ForCausalLM": ("deepseek_vl2", "DeepseekVLV2ForCausalLM"),
    "DeepseekOCRForCausalLM": ("deepseek_ocr", "DeepseekOCRForCausalLM"),
    "DeepseekOCR2ForCausalLM": ("deepseek_ocr2", "DeepseekOCR2ForCausalLM"),
    "DotsOCRForCausalLM": ("dots_ocr", "DotsOCRForCausalLM"),
    "Eagle2_5_VLForConditionalGeneration": (
        "eagle2_5_vl",
        "Eagle2_5_VLForConditionalGeneration",
    ),
    "Ernie4_5_VLMoeForConditionalGeneration": (
        "ernie45_vl",
        "Ernie4_5_VLMoeForConditionalGeneration",
    ),
    "Exaone4_5_ForConditionalGeneration": (
        "exaone4_5",
        "Exaone4_5_ForConditionalGeneration",
    ),  # noqa: E501
    "FireRedASR2ForConditionalGeneration": (
        "fireredasr2",
        "FireRedASR2ForConditionalGeneration",
    ),
    "FunASRForConditionalGeneration": ("funasr", "FunASRForConditionalGeneration"),
    "FireRedLIDForConditionalGeneration": (
        "fireredlid",
        "FireRedLIDForConditionalGeneration",
    ),
    "FunAudioChatForConditionalGeneration": (
        "funaudiochat",
        "FunAudioChatForConditionalGeneration",
    ),
    "FuyuForCausalLM": ("fuyu", "FuyuForCausalLM"),
    "Gemma3ForConditionalGeneration": ("gemma3_mm", "Gemma3ForConditionalGeneration"),
    "Gemma3nForConditionalGeneration": (
        "gemma3n_mm",
        "Gemma3nForConditionalGeneration",
    ),
    "Gemma4ForConditionalGeneration": ("gemma4_mm", "Gemma4ForConditionalGeneration"),
    "GlmAsrForConditionalGeneration": ("glmasr", "GlmAsrForConditionalGeneration"),
    "GLM4VForCausalLM": ("glm4v", "GLM4VForCausalLM"),
    "Glm4vForConditionalGeneration": ("glm4_1v", "Glm4vForConditionalGeneration"),
    "Glm4vMoeForConditionalGeneration": ("glm4_1v", "Glm4vMoeForConditionalGeneration"),
    "GlmOcrForConditionalGeneration": ("glm_ocr", "GlmOcrForConditionalGeneration"),
    "GraniteSpeechForConditionalGeneration": (
        "granite_speech",
        "GraniteSpeechForConditionalGeneration",
    ),
    "Granite4VisionForConditionalGeneration": (
        "granite4_vision",
        "Granite4VisionForConditionalGeneration",
    ),
    "H2OVLChatModel": ("h2ovl", "H2OVLChatModel"),
    "HunYuanVLForConditionalGeneration": (
        "hunyuan_vision",
        "HunYuanVLForConditionalGeneration",
    ),
    "InternVLChatModel": ("internvl", "InternVLChatModel"),
    "InternS1ForConditionalGeneration": (
        "interns1",
        "InternS1ForConditionalGeneration",
    ),
    "InternVLForConditionalGeneration": (
        "interns1",
        "InternS1ForConditionalGeneration",
    ),
    "InternS1ProForConditionalGeneration": (
        "interns1_pro",
        "InternS1ProForConditionalGeneration",
    ),
    "Idefics3ForConditionalGeneration": (
        "idefics3",
        "Idefics3ForConditionalGeneration",
    ),
    "IsaacForConditionalGeneration": ("isaac", "IsaacForConditionalGeneration"),
    "KananaVForConditionalGeneration": ("kanana_v", "KananaVForConditionalGeneration"),
    "KeyeForConditionalGeneration": ("keye", "KeyeForConditionalGeneration"),
    "KeyeVL1_5ForConditionalGeneration": (
        "keye_vl1_5",
        "KeyeVL1_5ForConditionalGeneration",
    ),
    "KimiVLForConditionalGeneration": ("kimi_vl", "KimiVLForConditionalGeneration"),
    "KimiK25ForConditionalGeneration": ("kimi_k25", "KimiK25ForConditionalGeneration"),
    "MoonshotKimiaForCausalLM": ("kimi_audio", "KimiAudioForConditionalGeneration"),
    "LightOnOCRForConditionalGeneration": (
        "lightonocr",
        "LightOnOCRForConditionalGeneration",
    ),
    "Lfm2VlForConditionalGeneration": ("lfm2_vl", "Lfm2VLForConditionalGeneration"),
    "Llama4ForConditionalGeneration": ("mllama4", "Llama4ForConditionalGeneration"),
    "Llama_Nemotron_Nano_VL": ("nemotron_vl", "LlamaNemotronVLChatModel"),
    "LlavaForConditionalGeneration": ("llava", "LlavaForConditionalGeneration"),
    "LlavaNextForConditionalGeneration": (
        "llava_next",
        "LlavaNextForConditionalGeneration",
    ),
    "LlavaNextVideoForConditionalGeneration": (
        "llava_next_video",
        "LlavaNextVideoForConditionalGeneration",
    ),
    "LlavaOnevisionForConditionalGeneration": (
        "llava_onevision",
        "LlavaOnevisionForConditionalGeneration",
    ),
    "MantisForConditionalGeneration": ("llava", "MantisForConditionalGeneration"),
    "MiDashengLMModel": ("midashenglm", "MiDashengLMModel"),
    "MiniMaxVL01ForConditionalGeneration": (
        "minimax_vl_01",
        "MiniMaxVL01ForConditionalGeneration",
    ),
    "MiniCPMO": ("minicpmo", "MiniCPMO"),
    "MiniCPMV": ("minicpmv", "MiniCPMV"),
    "Mistral3ForConditionalGeneration": (
        "mistral3",
        "Mistral3ForConditionalGeneration",
    ),
    "MolmoForCausalLM": ("molmo", "MolmoForCausalLM"),
    "Molmo2ForConditionalGeneration": ("molmo2", "Molmo2ForConditionalGeneration"),
    "NemotronH_Nano_VL_V2": ("nano_nemotron_vl", "NemotronH_Nano_VL_V2"),
    "NemotronH_Nano_Omni_Reasoning_V3": ("nano_nemotron_vl", "NemotronH_Nano_VL_V2"),
    "NemotronH_Super_Omni_Reasoning_V3": ("nano_nemotron_vl", "NemotronH_Nano_VL_V2"),
    "NVLM_D": ("nvlm_d", "NVLM_D_Model"),
    "OpenCUAForConditionalGeneration": ("opencua", "OpenCUAForConditionalGeneration"),
    "OpenPanguVLForConditionalGeneration": (
        "openpangu_vl",
        "OpenPanguVLForConditionalGeneration",
    ),
    "Ovis": ("ovis", "Ovis"),
    "Ovis2_5": ("ovis2_5", "Ovis2_5"),
    "Ovis2_6ForCausalLM": ("ovis2_5", "Ovis2_5"),
    "Ovis2_6_MoeForCausalLM": ("ovis2_5", "Ovis2_5"),
    "PaddleOCRVLForConditionalGeneration": (
        "paddleocr_vl",
        "PaddleOCRVLForConditionalGeneration",
    ),
    "PaliGemmaForConditionalGeneration": (
        "paligemma",
        "PaliGemmaForConditionalGeneration",
    ),
    "Phi3VForCausalLM": ("phi3v", "Phi3VForCausalLM"),
    "Phi4ForCausalLMV": ("phi4siglip", "Phi4ForCausalLMV"),
    "Phi4MMForCausalLM": ("phi4mm", "Phi4MMForCausalLM"),
    "PixtralForConditionalGeneration": ("pixtral", "PixtralForConditionalGeneration"),
    "QwenVLForConditionalGeneration": ("qwen_vl", "QwenVLForConditionalGeneration"),
    "Qwen2VLForConditionalGeneration": ("qwen2_vl", "Qwen2VLForConditionalGeneration"),
    "Qwen2_5_VLForConditionalGeneration": (
        "qwen2_5_vl",
        "Qwen2_5_VLForConditionalGeneration",
    ),
    "Qwen2AudioForConditionalGeneration": (
        "qwen2_audio",
        "Qwen2AudioForConditionalGeneration",
    ),
    "Qwen2_5OmniModel": (
        "qwen2_5_omni_thinker",
        "Qwen2_5OmniThinkerForConditionalGeneration",
    ),
    "Qwen2_5OmniForConditionalGeneration": (
        "qwen2_5_omni_thinker",
        "Qwen2_5OmniThinkerForConditionalGeneration",
    ),
    "Qwen3OmniMoeForConditionalGeneration": (
        "qwen3_omni_moe_thinker",
        "Qwen3OmniMoeThinkerForConditionalGeneration",
    ),
    "Qwen3ASRForConditionalGeneration": (
        "qwen3_asr",
        "Qwen3ASRForConditionalGeneration",
    ),
    "Qwen3ASRRealtimeGeneration": ("qwen3_asr_realtime", "Qwen3ASRRealtimeGeneration"),
    "Qwen3VLForConditionalGeneration": ("qwen3_vl", "Qwen3VLForConditionalGeneration"),
    "Qwen3VLMoeForConditionalGeneration": (
        "qwen3_vl_moe",
        "Qwen3VLMoeForConditionalGeneration",
    ),
    "Qwen3_5ForConditionalGeneration": ("qwen3_5", "Qwen3_5ForConditionalGeneration"),
    "Qwen3_5MoeForConditionalGeneration": (
        "qwen3_5",
        "Qwen3_5MoeForConditionalGeneration",
    ),
    "RForConditionalGeneration": ("rvl", "RForConditionalGeneration"),
    "SkyworkR1VChatModel": ("skyworkr1v", "SkyworkR1VChatModel"),
    "SmolVLMForConditionalGeneration": ("smolvlm", "SmolVLMForConditionalGeneration"),
    "StepVLForConditionalGeneration": ("step_vl", "StepVLForConditionalGeneration"),
    "Step3VLForConditionalGeneration": ("step3_vl", "Step3VLForConditionalGeneration"),
    "TarsierForConditionalGeneration": ("tarsier", "TarsierForConditionalGeneration"),
    "Tarsier2ForConditionalGeneration": (
        "qwen2_vl",
        "Tarsier2ForConditionalGeneration",
    ),
    "UltravoxModel": ("ultravox", "UltravoxModel"),
    "VoxtralForConditionalGeneration": ("voxtral", "VoxtralForConditionalGeneration"),
    "VoxtralRealtimeGeneration": ("voxtral_realtime", "VoxtralRealtimeGeneration"),
    # [Encoder-decoder]
    "CohereAsrForConditionalGeneration": (
        "cohere_asr",
        "CohereAsrForConditionalGeneration",
    ),
    "NemotronParseForConditionalGeneration": (
        "nemotron_parse",
        "NemotronParseForConditionalGeneration",
    ),
    "WhisperForConditionalGeneration": ("whisper", "WhisperForConditionalGeneration"),
}
568
569

# Speculative-decoding architectures (Eagle/Eagle3 drafts, MTP heads,
# Medusa, ...).
# Each key is an architecture name; each value is a pair of strings,
# presumably `(implementing module, model class)` -- confirm against the
# code that consumes this table.
_SPECULATIVE_DECODING_MODELS = {
    "ExtractHiddenStatesModel": ("extract_hidden_states", "ExtractHiddenStatesModel"),
    "MiMoMTPModel": ("mimo_mtp", "MiMoMTP"),
    "EagleLlamaForCausalLM": ("llama_eagle", "EagleLlamaForCausalLM"),
    "EagleLlama4ForCausalLM": ("llama4_eagle", "EagleLlama4ForCausalLM"),
    "EagleMiniCPMForCausalLM": ("minicpm_eagle", "EagleMiniCPMForCausalLM"),
    "DFlashDraftModel": ("qwen3_dflash", "DFlashQwen3ForCausalLM"),
    "Eagle3LlamaForCausalLM": ("llama_eagle3", "Eagle3LlamaForCausalLM"),
    "Eagle3MiniMaxM2ForCausalLM": ("llama_eagle3", "Eagle3LlamaForCausalLM"),
    "LlamaForCausalLMEagle3": ("llama_eagle3", "Eagle3LlamaForCausalLM"),
    "Eagle3Qwen2_5vlForCausalLM": ("llama_eagle3", "Eagle3LlamaForCausalLM"),
    "Eagle3Qwen3vlForCausalLM": ("llama_eagle3", "Eagle3LlamaForCausalLM"),
    "EagleMistralLarge3ForCausalLM": (
        "mistral_large_3_eagle",
        "EagleMistralLarge3ForCausalLM",
    ),
    "Eagle3DeepseekV2ForCausalLM": ("deepseek_eagle3", "Eagle3DeepseekV2ForCausalLM"),
    "Eagle3DeepseekV3ForCausalLM": ("deepseek_eagle3", "Eagle3DeepseekV2ForCausalLM"),
    "EagleDeepSeekMTPModel": ("deepseek_eagle", "EagleDeepseekV3ForCausalLM"),
    "DeepSeekMTPModel": ("deepseek_mtp", "DeepSeekMTP"),
    "ErnieMTPModel": ("ernie_mtp", "ErnieMTP"),
    "ExaoneMoeMTP": ("exaone_moe_mtp", "ExaoneMoeMTP"),
    "Exaone4_5_MTP": ("exaone4_5_mtp", "Exaone4_5_MTP"),
    "NemotronHMTPModel": ("nemotron_h_mtp", "NemotronHMTP"),
    "LongCatFlashMTPModel": ("longcat_flash_mtp", "LongCatFlashMTP"),
    "Glm4MoeMTPModel": ("glm4_moe_mtp", "Glm4MoeMTP"),
    "Glm4MoeLiteMTPModel": ("glm4_moe_lite_mtp", "Glm4MoeLiteMTP"),
    "GlmOcrMTPModel": ("glm_ocr_mtp", "GlmOcrMTP"),
    "MedusaModel": ("medusa", "Medusa"),
    "OpenPanguMTPModel": ("openpangu_mtp", "OpenPanguMTP"),
    "Qwen3NextMTP": ("qwen3_next_mtp", "Qwen3NextMTP"),
    "Step3p5MTP": ("step3p5_mtp", "Step3p5MTP"),
    "Qwen3_5MTP": ("qwen3_5_mtp", "Qwen3_5MTP"),
    "Qwen3_5MoeMTP": ("qwen3_5_mtp", "Qwen3_5MoeMTP"),
    "HYV3MTPModel": ("hy_v3_mtp", "HYV3MTP"),
    # Temporarily disabled.
    # # TODO(woosuk): Re-enable this once the MLP Speculator is supported in V1.
    # "MLPSpeculatorPreTrainedModel": ("mlp_speculator", "MLPSpeculator"),
}
608

609
# Architectures mapped onto the generic classes in the "transformers" module
# rather than onto a dedicated implementation.
# Each key is an architecture name; each value is a `(module, class)` pair
# of strings.
_TRANSFORMERS_SUPPORTED_MODELS = {
    # Text generation models
    "SmolLM3ForCausalLM": ("transformers", "TransformersForCausalLM"),
    # Multimodal models
    "Emu3ForConditionalGeneration": (
        "transformers",
        "TransformersMultiModalForCausalLM",
    ),
}

# Generic Transformers-backend classes, registered under their own names.
# Every key equals the class name in its value, and every value points at
# the "transformers" module.
_TRANSFORMERS_BACKEND_MODELS = {
    # Text generation models
    "TransformersForCausalLM": ("transformers", "TransformersForCausalLM"),
    "TransformersMoEForCausalLM": ("transformers", "TransformersMoEForCausalLM"),
    # Multimodal models
    "TransformersMultiModalForCausalLM": (
        "transformers",
        "TransformersMultiModalForCausalLM",
    ),
    "TransformersMultiModalMoEForCausalLM": (
        "transformers",
        "TransformersMultiModalMoEForCausalLM",
    ),
    # Embedding models
    "TransformersEmbeddingModel": ("transformers", "TransformersEmbeddingModel"),
    "TransformersMoEEmbeddingModel": ("transformers", "TransformersMoEEmbeddingModel"),
    "TransformersMultiModalEmbeddingModel": (
        "transformers",
        "TransformersMultiModalEmbeddingModel",
    ),
    # Sequence classification models
    "TransformersForSequenceClassification": (
        "transformers",
        "TransformersForSequenceClassification",
    ),
    "TransformersMoEForSequenceClassification": (
        "transformers",
        "TransformersMoEForSequenceClassification",
    ),
    "TransformersMultiModalForSequenceClassification": (
        "transformers",
        "TransformersMultiModalForSequenceClassification",
    ),
}
653

654
# Union of every per-category table above, merged in the order listed.
# When the same architecture name appears in more than one table, the entry
# from the table unpacked last is the one kept.
_VLLM_MODELS = {
    **_TEXT_GENERATION_MODELS,
    **_EMBEDDING_MODELS,
    **_LATE_INTERACTION_MODELS,
    **_REWARD_MODELS,
    **_TOKEN_CLASSIFICATION_MODELS,
    **_SEQUENCE_CLASSIFICATION_MODELS,
    **_MULTIMODAL_MODELS,
    **_SPECULATIVE_DECODING_MODELS,
    **_TRANSFORMERS_SUPPORTED_MODELS,
    **_TRANSFORMERS_BACKEND_MODELS,
}

667
668
669
670
# This variable is used as the args for subprocess.run(). We
# can modify this variable to alter the args if needed, e.g.
# when we use par format to pack things together, sys.executable
# might not be the target we want to run.
_SUBPROCESS_COMMAND = [sys.executable, "-m", "vllm.model_executor.models.registry"]
672

673
# Architectures that are no longer supported.
# Each key is an architecture name; each value is a version string --
# presumably the last vLLM release that still supported it (confirm against
# the code that reads this table).
_PREVIOUSLY_SUPPORTED_MODELS = {
    "MotifForCausalLM": "0.10.2",
    "Phi3SmallForCausalLM": "0.9.2",
    "Phi4FlashForCausalLM": "0.10.2",
    "Phi4MultimodalForCausalLM": "0.12.0",
    # encoder-decoder models except whisper
    # have been removed for V0 deprecation.
    "DonutForConditionalGeneration": "0.10.2",
    "MllamaForConditionalGeneration": "0.10.2",
}
683

684
685
686
687
688
689
690
# Architectures served by a separate plugin repository ("OOT" presumably
# stands for out-of-tree). Each key is an architecture name and each value is
# the URL of the plugin providing it; every entry currently points at the
# same bart-plugin repository.
_OOT_SUPPORTED_MODELS = dict.fromkeys(
    (
        "BartModel",
        "BartForConditionalGeneration",
        "Florence2ForConditionalGeneration",
        "MBartForConditionalGeneration",
    ),
    "https://github.com/vllm-project/bart-plugin",
)

691

692
693
@dataclass(frozen=True)
class _ModelInfo:
    """
    Immutable record of the properties detected on a model class.

    Every field is filled in by `from_model_cls` using the interface helper
    of the same name, except `architecture`, which is the class name.
    """

    architecture: str
    is_text_generation_model: bool
    is_pooling_model: bool
    attn_type: AttnTypeStr
    default_seq_pooling_type: SequencePoolingType
    default_tok_pooling_type: TokenPoolingType
    score_type: ScoreType
    supports_multimodal: bool
    supports_multimodal_raw_input_only: bool
    requires_raw_input_tokens: bool
    supports_multimodal_encoder_tp_data: bool
    supports_pp: bool
    has_inner_state: bool
    is_attention_free: bool
    is_hybrid: bool
    has_noops: bool
    supports_mamba_prefix_caching: bool
    supports_transcription: bool
    supports_transcription_only: bool

    @staticmethod
    def from_model_cls(model: type[nn.Module]) -> "_ModelInfo":
        """Build the record by querying each interface helper with `model`."""
        can_transcribe = supports_transcription(model)
        # `supports_transcription_only` is only read off the class when the
        # class supports transcription at all.
        transcription_only = can_transcribe and model.supports_transcription_only
        raw_input_only = supports_multimodal_raw_input_only(model)
        encoder_tp_data = supports_multimodal_encoder_tp_data(model)

        return _ModelInfo(
            architecture=model.__name__,
            is_text_generation_model=is_text_generation_model(model),
            is_pooling_model=is_pooling_model(model),
            attn_type=get_attn_type(model),
            default_seq_pooling_type=get_default_seq_pooling_type(model),
            default_tok_pooling_type=get_default_tok_pooling_type(model),
            score_type=get_score_type(model),
            supports_multimodal=supports_multimodal(model),
            supports_multimodal_raw_input_only=raw_input_only,
            requires_raw_input_tokens=requires_raw_input_tokens(model),
            supports_multimodal_encoder_tp_data=encoder_tp_data,
            supports_pp=supports_pp(model),
            has_inner_state=has_inner_state(model),
            is_attention_free=is_attention_free(model),
            is_hybrid=is_hybrid(model),
            has_noops=has_noops(model),
            supports_mamba_prefix_caching=supports_mamba_prefix_caching(model),
            supports_transcription=can_transcribe,
            supports_transcription_only=transcription_only,
        )
743
744


745
746
747
748
class _BaseRegisteredModel(ABC):
    """Interface shared by the entries stored in the model registry."""

    @abstractmethod
    def inspect_model_cls(self) -> _ModelInfo:
        """Return the `_ModelInfo` describing the registered model class."""
        raise NotImplementedError

    @abstractmethod
    def load_model_cls(self) -> type[nn.Module]:
        """Return the registered model class itself."""
        raise NotImplementedError
753
754


755
756
757
758
759
760
761
@dataclass(frozen=True)
class _RegisteredModel(_BaseRegisteredModel):
    """
    Represents a model that has already been imported in the main process.
    """

    # Info computed once from `model_cls` when the entry is created.
    interfaces: _ModelInfo
    model_cls: type[nn.Module]

    @staticmethod
    def from_model_cls(model_cls: type[nn.Module]):
        """Create an entry for `model_cls`, inspecting it immediately."""
        info = _ModelInfo.from_model_cls(model_cls)
        return _RegisteredModel(interfaces=info, model_cls=model_cls)

    def inspect_model_cls(self) -> _ModelInfo:
        """Return the info computed when the entry was created."""
        return self.interfaces

    def load_model_cls(self) -> type[nn.Module]:
        """Return the stored model class."""
        return self.model_cls


@dataclass(frozen=True)
class _LazyRegisteredModel(_BaseRegisteredModel):
    """
    Represents a model that has not been imported in the main process.
    """
783

784
785
786
    module_name: str
    class_name: str

787
788
789
790
791
792
793
794
    @staticmethod
    def _get_cache_dir() -> Path:
        """Return the `modelinfos` directory under `envs.VLLM_CACHE_ROOT`."""
        cache_root = Path(envs.VLLM_CACHE_ROOT)
        return cache_root / "modelinfos"

    def _get_cache_filename(self) -> str:
        """Return `<module>-<class>.json` with every `.` replaced by `-`."""
        dashed = f"{self.module_name}-{self.class_name}".replace(".", "-")
        return dashed + ".json"

795
    def _load_modelinfo_from_cache(self, module_hash: str) -> _ModelInfo | None:
        """
        Return the cached `_ModelInfo` for this entry, or `None` on a miss.

        A miss is a missing cache file, a stored hash different from
        `module_hash`, or any error while reading or decoding the file. Each
        case is logged at debug level.
        """
        try:
            cache_file = self._get_cache_dir() / self._get_cache_filename()
            try:
                cached = json.loads(cache_file.read_text(encoding="utf-8"))
            except FileNotFoundError:
                logger.debug(
                    "Cached model info file for class %s.%s not found",
                    self.module_name,
                    self.class_name,
                )
                return None

            if cached["hash"] == module_hash:
                # file not changed, use cached _ModelInfo properties
                return _ModelInfo(**cached["modelinfo"])

            logger.debug(
                "Cached model info file for class %s.%s is stale",
                self.module_name,
                self.class_name,
            )
            return None
        except Exception:
            logger.debug(
                "Cached model info for class %s.%s error. ",
                self.module_name,
                self.class_name,
            )
            return None

827
    def _save_modelinfo_to_cache(self, mi: _ModelInfo, module_hash: str) -> None:
        """
        Write `mi` and `module_hash` to this entry's JSON cache file.

        The cache directory is created if needed. Failures are logged and
        never propagated to the caller.
        """
        from vllm.model_executor.model_loader.weight_utils import atomic_writer

        try:
            payload = {"hash": module_hash, "modelinfo": asdict(mi)}
            target_dir = self._get_cache_dir()
            target_dir.mkdir(parents=True, exist_ok=True)
            target_file = target_dir / self._get_cache_filename()
            with atomic_writer(target_file, encoding="utf-8") as out:
                json.dump(payload, out, indent=2)
        except Exception:
            logger.exception("Error saving model info cache.")

    @logtime(logger=logger, msg="Registry inspect model class")
    def inspect_model_cls(self) -> _ModelInfo:
        """
        Return the `_ModelInfo` of the model without importing it here.

        When a `.py` file named after the last component of `module_name`
        exists next to this file, its content hash is used to look up a cached
        result. On a miss, or when no such file exists, the model is inspected
        in a subprocess; the result is cached only when a hash was computed.
        """
        # NOTE(review): only the last component of `module_name` is used, so a
        # module from another package whose file name matches one in this
        # directory would be hashed against the file here - confirm intended.
        module_stem = self.module_name.rsplit(".", 1)[-1]
        source_file = Path(__file__).with_name(f"{module_stem}.py")
        digest = None

        if source_file.exists():
            digest = safe_hash(
                source_file.read_bytes(), usedforsecurity=False
            ).hexdigest()

            cached = self._load_modelinfo_from_cache(digest)
            if cached is not None:
                logger.debug(
                    "Loaded model info for class %s.%s from cache",
                    self.module_name,
                    self.class_name,
                )
                return cached

            logger.debug(
                "Cache model info for class %s.%s miss. Loading model instead.",
                self.module_name,
                self.class_name,
            )

        # Performed in another process to avoid initializing CUDA
        fresh = _run_in_subprocess(
            lambda: _ModelInfo.from_model_cls(self.load_model_cls())
        )
        logger.debug(
            "Loaded model info for class %s.%s", self.module_name, self.class_name
        )

        # save cache file
        if digest is not None:
            self._save_modelinfo_to_cache(fresh, digest)

        return fresh
881

882
    def load_model_cls(self) -> type[nn.Module]:
        """Import `module_name` and return its `class_name` attribute."""
        return getattr(importlib.import_module(self.module_name), self.class_name)


@lru_cache(maxsize=128)
def _try_load_model_cls(
    model_arch: str,
    model: _BaseRegisteredModel,
) -> type[nn.Module] | None:
    """
    Load the class behind `model`, returning `None` if loading fails.

    The current platform verifies `model_arch` first; that call is outside
    the `try`, so anything it raises propagates. Results are memoized per
    `(model_arch, model)` pair.
    """
    from vllm.platforms import current_platform

    current_platform.verify_model_arch(model_arch)

    try:
        loaded_cls = model.load_model_cls()
    except Exception:
        logger.exception("Error in loading model architecture '%s'", model_arch)
        return None
    return loaded_cls
900
901


902
903
904
905
@lru_cache(maxsize=128)
def _try_inspect_model_cls(
    model_arch: str,
    model: _BaseRegisteredModel,
) -> _ModelInfo | None:
    """
    Inspect `model`, returning `None` (after logging) if inspection fails.

    Results are memoized per `(model_arch, model)` pair.
    """
    try:
        info = model.inspect_model_cls()
    except Exception:
        logger.exception("Error in inspecting model architecture '%s'", model_arch)
        return None
    return info
912
913


914
915
916
@dataclass
class _ModelRegistry:
    # Keyed by model_arch
917
    models: dict[str, _BaseRegisteredModel] = field(default_factory=dict)
918

919
    def get_supported_archs(self) -> Set[str]:
        """Return a live view of the registered architecture names."""
        return self.models.keys()
921

922
923
924
    def register_model(
        self,
        model_arch: str,
        model_cls: type[nn.Module] | str,
    ) -> None:
        """
        Register an external model to be used in vLLM.

        `model_cls` can be either:

        - A [`torch.nn.Module`][] class directly referencing the model.
        - A string in the format `<module>:<class>` which can be used to
          lazily import the model. This is useful to avoid initializing CUDA
          when importing the model and thus the related error
          `RuntimeError: Cannot re-initialize CUDA in forked subprocess`.

        Registering an architecture name that already exists logs a warning
        and replaces the previous entry.

        Raises:
            TypeError: If `model_arch` is not a string, or if `model_cls` is
                neither a string nor a `torch.nn.Module` subclass.
            ValueError: If `model_cls` is a string that does not contain
                exactly one `:` separator.
        """
        if not isinstance(model_arch, str):
            msg = f"`model_arch` should be a string, not a {type(model_arch)}"
            raise TypeError(msg)

        if model_arch in self.models:
            logger.warning(
                "Model architecture %s is already registered, and will be "
                "overwritten by the new model class %s.",
                model_arch,
                model_cls,
            )

        if isinstance(model_cls, str):
            split_str = model_cls.split(":")
            if len(split_str) != 2:
                msg = "Expected a string in the format `<module>:<class>`"
                raise ValueError(msg)

            model = _LazyRegisteredModel(*split_str)
        elif isinstance(model_cls, type) and issubclass(model_cls, nn.Module):
            model = _RegisteredModel.from_model_cls(model_cls)
        else:
            # Report the type of the offending argument, i.e. `model_cls`.
            msg = (
                "`model_cls` should be a string or PyTorch model class, "
                f"not a {type(model_cls)}"
            )
            raise TypeError(msg)

        self.models[model_arch] = model
967

968
    def _raise_for_unsupported(self, architectures: list[str]):
        """
        Raise a `ValueError` explaining why none of `architectures` is usable.

        The message depends on the first case that applies: an architecture is
        registered (so it failed earlier), an architecture was removed from
        vLLM, an architecture moved to a plugin, or nothing matches at all.
        This method never returns normally.
        """
        known_archs = self.get_supported_archs()

        # A registered architecture reaching this point means an earlier step
        # failed for it and already logged the reason.
        if not known_archs.isdisjoint(architectures):
            raise ValueError(
                f"Model architectures {architectures} failed "
                "to be inspected. Please check the logs for more details."
            )

        for arch in architectures:
            previous_version = _PREVIOUSLY_SUPPORTED_MODELS.get(arch)
            if previous_version is not None:
                raise ValueError(
                    f"Model architecture {arch} was supported in vLLM until "
                    f"v{previous_version}, and is not supported anymore. "
                    "Please use an older version of vLLM if you want to "
                    "use this model architecture."
                )

            plugin_url = _OOT_SUPPORTED_MODELS.get(arch)
            if plugin_url is not None:
                raise ValueError(
                    f"Model architecture {arch} is not supported in-tree anymore. "
                    f"Please install the plugin at {plugin_url} if you want to "
                    "use this model architecture."
                )

        raise ValueError(
            f"Model architectures {architectures} are not supported for now. "
            f"Supported architectures: {known_archs}"
        )
1000

1001
    def _try_load_model_cls(self, model_arch: str) -> type[nn.Module] | None:
        """Load the class registered as `model_arch`; `None` if unavailable."""
        registered = self.models.get(model_arch)
        if registered is None:
            return None

        return _try_load_model_cls(model_arch, registered)
1006

1007
    def _try_inspect_model_cls(self, model_arch: str) -> _ModelInfo | None:
        """Inspect the entry registered as `model_arch`; `None` if unavailable."""
        registered = self.models.get(model_arch)
        if registered is None:
            return None

        return _try_inspect_model_cls(model_arch, registered)

    def _try_resolve_transformers(
        self,
        architecture: str,
        model_config: ModelConfig,
    ) -> str | None:
        """
        Resolve `architecture` to an architecture of the Transformers backend.

        Returns `architecture` itself when it is already a key of
        `_TRANSFORMERS_BACKEND_MODELS`. Otherwise the model class is looked up
        as an attribute of `transformers`, then through the `AutoModel*`
        entries of the HF config's `auto_map`. If that class reports backend
        compatibility, the result of
        `model_config._get_transformers_backend_cls()` is returned.

        Returns `None` when no compatible class is found and
        `model_config.model_impl` is not `"transformers"`.

        Raises:
            ValueError: In the same situations when `model_impl` is
                `"transformers"`.
        """
        if architecture in _TRANSFORMERS_BACKEND_MODELS:
            return architecture

        def load_dynamic_class(module_ref: str, *, warn_on_fail: bool):
            # Single place for the arguments shared by every dynamic lookup.
            return try_get_class_from_dynamic_module(
                module_ref,
                model_config.model,
                revision=model_config.revision,
                trust_remote_code=model_config.trust_remote_code,
                warn_on_fail=warn_on_fail,
            )

        auto_map: dict[str, str] = (
            getattr(model_config.hf_config, "auto_map", None) or {}
        )

        # Make sure that config class is always initialized before model class,
        # otherwise the model class won't be able to access the config class,
        # the expected auto_map should have correct order like:
        # "auto_map": {
        #     "AutoConfig": "<your-repo-name>--<config-name>",
        #     "AutoModel": "<your-repo-name>--<config-name>",
        #     "AutoModelFor<Task>": "<your-repo-name>--<config-name>",
        # },
        for prefix in ("AutoConfig", "AutoModel"):
            for entry_name, module_ref in auto_map.items():
                if entry_name.startswith(prefix):
                    load_dynamic_class(module_ref, warn_on_fail=False)

        model_module = getattr(transformers, architecture, None)

        if model_module is None:
            for entry_name, module_ref in auto_map.items():
                if not entry_name.startswith("AutoModel"):
                    continue
                model_module = load_dynamic_class(module_ref, warn_on_fail=True)
                if model_module is not None:
                    break
            else:
                # The loop finished without `break`: no `AutoModel*` entry
                # produced a class.
                if model_config.model_impl == "transformers":
                    raise ValueError(
                        f"Cannot find model module. {architecture!r} is not a "
                        "registered model in the Transformers library (only "
                        "relevant if the model is meant to be in Transformers) "
                        "and 'AutoModel' is not present in the model config's "
                        "'auto_map' (relevant if the model is custom)."
                    )
                return None

        if model_module.is_backend_compatible():
            return model_config._get_transformers_backend_cls()

        if model_config.model_impl == "transformers":
            raise ValueError(
                f"The Transformers implementation of {architecture!r} "
                "is not compatible with vLLM."
            )
        return None
1080

1081
1082
1083
1084
1085
1086
1087
1088
1089
1090
1091
1092
1093
1094
1095
1096
1097
1098
1099
1100
1101
1102
1103
1104
1105
    def _normalize_arch(
        self,
        architecture: str,
        model_config: ModelConfig,
    ) -> str:
        """
        Map `architecture` to a registered base architecture when possible.

        A registered name is returned unchanged. Otherwise, if the name
        matches one of the architecture defaults, the matched suffix is
        swapped for each known suffix in turn and the first registered result
        is returned. When nothing matches, `architecture` is returned as is.
        """
        if architecture in self.models:
            return architecture

        # This may be called in order to resolve runner_type and convert_type
        # in the first place, in which case we consider the default match
        match = try_match_architecture_defaults(
            architecture,
            runner_type=getattr(model_config, "runner_type", None),
            convert_type=getattr(model_config, "convert_type", None),
        )
        if not match:
            return architecture

        suffix, _ = match

        # Get the name of the base model to convert
        candidates = (
            architecture.replace(suffix, other_suffix)
            for other_suffix, _ in iter_architecture_defaults()
        )
        return next(
            (candidate for candidate in candidates if candidate in self.models),
            architecture,
        )
1106

1107
1108
    def inspect_model_cls(
        self,
        architectures: str | list[str],
        model_config: ModelConfig,
    ) -> tuple[_ModelInfo, str]:
        """
        Return the `_ModelInfo` of the first usable architecture and its name.

        Candidates are tried in this order:

        1. The implementation forced by `model_config.model_impl`
           (`"transformers"` or `"terratorch"`).
        2. The Transformers backend, if no architecture is registered,
           `model_impl` is `"auto"` and `convert_type` is `"none"` or unset.
        3. Each architecture in order, after normalization.
        4. The Transformers backend again, if no architecture is registered
           and `model_impl` is `"auto"`.

        The Transformers backend is always resolved from `architectures[0]`.
        A candidate whose inspection fails is skipped.

        Raises:
            ValueError: If `architectures` is empty or no candidate could be
                inspected.
        """
        if isinstance(architectures, str):
            architectures = [architectures]
        if not architectures:
            raise ValueError("No model architectures are specified")

        def inspect_via_transformers() -> tuple[_ModelInfo, str] | None:
            # One attempt at the Transformers backend; None when it cannot be
            # resolved or inspected.
            arch = self._try_resolve_transformers(architectures[0], model_config)
            if arch is None:
                return None
            model_info = self._try_inspect_model_cls(arch)
            if model_info is None:
                return None
            return (model_info, arch)

        # Require transformers impl
        if model_config.model_impl == "transformers":
            resolved = inspect_via_transformers()
            if resolved is not None:
                return resolved
        elif model_config.model_impl == "terratorch":
            model_info = self._try_inspect_model_cls("Terratorch")
            # A failed inspection yields None. Returning it would break the
            # declared return type and fail later with an AttributeError in
            # callers, so fall through like `resolve_model_cls` does.
            if model_info is not None:
                return (model_info, "Terratorch")

        # Fallback to transformers impl (after resolving convert_type)
        if (
            all(arch not in self.models for arch in architectures)
            and model_config.model_impl == "auto"
            and getattr(model_config, "convert_type", "none") == "none"
        ):
            resolved = inspect_via_transformers()
            if resolved is not None:
                return resolved

        for arch in architectures:
            normalized_arch = self._normalize_arch(arch, model_config)
            model_info = self._try_inspect_model_cls(normalized_arch)
            if model_info is not None:
                # The caller gets back the name it passed in, not the
                # normalized one.
                return (model_info, arch)

        # Fallback to transformers impl (before resolving runner_type)
        if (
            all(arch not in self.models for arch in architectures)
            and model_config.model_impl == "auto"
        ):
            resolved = inspect_via_transformers()
            if resolved is not None:
                return resolved

        return self._raise_for_unsupported(architectures)
1158

1159
1160
    def resolve_model_cls(
        self,
        architectures: str | list[str],
        model_config: ModelConfig,
    ) -> tuple[type[nn.Module], str]:
        """
        Return the model class of the first usable architecture and its name.

        Candidates are tried in this order:

        1. The implementation forced by `model_config.model_impl`
           (`"transformers"` or `"terratorch"`).
        2. The Transformers backend, if no architecture is registered,
           `model_impl` is `"auto"` and `convert_type` is `"none"` or unset.
        3. Each architecture in order, after normalization.
        4. The Transformers backend again, if no architecture is registered
           and `model_impl` is `"auto"`.

        The Transformers backend is always resolved from `architectures[0]`.
        A candidate whose class fails to load is skipped.

        Raises:
            ValueError: If `architectures` is empty or no candidate could be
                loaded.
        """
        if isinstance(architectures, str):
            architectures = [architectures]
        if not architectures:
            raise ValueError("No model architectures are specified")

        def load_via_transformers() -> tuple[type[nn.Module], str] | None:
            # One attempt at the Transformers backend; None when it cannot be
            # resolved or loaded.
            arch = self._try_resolve_transformers(architectures[0], model_config)
            if arch is None:
                return None
            model_cls = self._try_load_model_cls(arch)
            if model_cls is None:
                return None
            return (model_cls, arch)

        # Require transformers impl
        if model_config.model_impl == "transformers":
            resolved = load_via_transformers()
            if resolved is not None:
                return resolved
        elif model_config.model_impl == "terratorch":
            arch = "Terratorch"
            model_cls = self._try_load_model_cls(arch)
            if model_cls is not None:
                return (model_cls, arch)

        # Fallback to transformers impl (after resolving convert_type)
        if (
            all(arch not in self.models for arch in architectures)
            and model_config.model_impl == "auto"
            and getattr(model_config, "convert_type", "none") == "none"
        ):
            resolved = load_via_transformers()
            if resolved is not None:
                return resolved

        for arch in architectures:
            normalized_arch = self._normalize_arch(arch, model_config)
            model_cls = self._try_load_model_cls(normalized_arch)
            if model_cls is not None:
                # The caller gets back the name it passed in, not the
                # normalized one.
                return (model_cls, arch)

        # Fallback to transformers impl (before resolving runner_type)
        if (
            all(arch not in self.models for arch in architectures)
            and model_config.model_impl == "auto"
        ):
            resolved = load_via_transformers()
            if resolved is not None:
                return resolved

        return self._raise_for_unsupported(architectures)
1212

1213
1214
    def is_text_generation_model(
        self,
        architectures: str | list[str],
        model_config: ModelConfig,
    ) -> bool:
        """Return `is_text_generation_model` of the inspected model info."""
        model_info, _ = self.inspect_model_cls(architectures, model_config)
        return model_info.is_text_generation_model
1220

1221
    def is_pooling_model(
        self,
        architectures: str | list[str],
        model_config: ModelConfig,
    ) -> bool:
        """Return `is_pooling_model` of the inspected model info."""
        model_info, _ = self.inspect_model_cls(architectures, model_config)
        return model_info.is_pooling_model
1228
1229
1230

    def is_multimodal_model(
        self,
        architectures: str | list[str],
        model_config: ModelConfig,
    ) -> bool:
        """Return `supports_multimodal` of the inspected model info."""
        model_info, _ = self.inspect_model_cls(architectures, model_config)
        return model_info.supports_multimodal
1236

1237
    def is_multimodal_raw_input_only_model(
        self,
        architectures: str | list[str],
        model_config: ModelConfig,
    ) -> bool:
        """Return `supports_multimodal_raw_input_only` of the model info."""
        model_info, _ = self.inspect_model_cls(architectures, model_config)
        return model_info.supports_multimodal_raw_input_only
1244

1245
1246
    def is_pp_supported_model(
        self,
        architectures: str | list[str],
        model_config: ModelConfig,
    ) -> bool:
        """Return `supports_pp` of the inspected model info."""
        model_info, _ = self.inspect_model_cls(architectures, model_config)
        return model_info.supports_pp
1252

1253
1254
    def model_has_inner_state(
        self,
        architectures: str | list[str],
        model_config: ModelConfig,
    ) -> bool:
        """Return `has_inner_state` of the inspected model info."""
        model_info, _ = self.inspect_model_cls(architectures, model_config)
        return model_info.has_inner_state
1260

1261
1262
    def is_attention_free_model(
        self,
        architectures: str | list[str],
        model_config: ModelConfig,
    ) -> bool:
        """Return `is_attention_free` of the inspected model info."""
        model_info, _ = self.inspect_model_cls(architectures, model_config)
        return model_info.is_attention_free
1268

1269
1270
    def is_hybrid_model(
        self,
        architectures: str | list[str],
        model_config: ModelConfig,
    ) -> bool:
        """Return `is_hybrid` of the inspected model info."""
        model_info, _ = self.inspect_model_cls(architectures, model_config)
        return model_info.is_hybrid

1277
1278
    def is_noops_model(
        self,
        architectures: str | list[str],
        model_config: ModelConfig,
    ) -> bool:
        """Return `has_noops` of the inspected model info."""
        model_info, _ = self.inspect_model_cls(architectures, model_config)
        return model_info.has_noops

1285
1286
    def is_transcription_model(
        self,
        architectures: str | list[str],
        model_config: ModelConfig,
    ) -> bool:
        """Return `supports_transcription` of the inspected model info."""
        model_info, _ = self.inspect_model_cls(architectures, model_config)
        return model_info.supports_transcription

1293
1294
    def is_transcription_only_model(
        self,
        architectures: str | list[str],
        model_config: ModelConfig,
    ) -> bool:
        """Return `supports_transcription_only` of the inspected model info."""
        model_info, _ = self.inspect_model_cls(architectures, model_config)
        return model_info.supports_transcription_only

1301

1302
1303
1304
1305
1306
1307
1308
1309
1310
# Registry pre-populated with one lazy entry per architecture listed in
# `_VLLM_MODELS`; no model module is imported at this point.
ModelRegistry = _ModelRegistry(
    models={
        arch_name: _LazyRegisteredModel(
            module_name=f"vllm.model_executor.models.{module_relname}",
            class_name=class_name,
        )
        for arch_name, (module_relname, class_name) in _VLLM_MODELS.items()
    }
)
1311
1312
1313
1314
1315

_T = TypeVar("_T")


def _run_in_subprocess(fn: Callable[[], _T]) -> _T:
    """
    Run `fn` in a separate Python process and return its result.

    `fn` and an output path are serialized with `cloudpickle` and passed on
    the child's stdin. The child is started with `_SUBPROCESS_COMMAND`, and
    the result is then unpickled from the output path.

    Raises:
        RuntimeError: If the child exits with a non-zero status. The message
            contains the child's stderr.
    """
    # NOTE: We use a temporary directory instead of a temporary file to avoid
    # issues like https://stackoverflow.com/questions/23212435/permission-denied-to-write-to-my-temporary-file
    with tempfile.TemporaryDirectory() as tempdir:
        output_filepath = os.path.join(tempdir, "registry_output.tmp")

        # `cloudpickle` allows pickling lambda functions directly
        import cloudpickle

        input_bytes = cloudpickle.dumps((fn, output_filepath))

        # cannot use `sys.executable __file__` here because the script
        # contains relative imports
        returned = subprocess.run(
            _SUBPROCESS_COMMAND, input=input_bytes, capture_output=True
        )

        # check if the subprocess is successful
        try:
            returned.check_returncode()
        except subprocess.CalledProcessError as e:
            # Decode leniently: undecodable bytes in stderr must not replace
            # the real failure with a UnicodeDecodeError.
            stderr_text = returned.stderr.decode(errors="replace")
            # wrap raised exception to provide more information
            raise RuntimeError(
                f"Error raised in subprocess:\n{stderr_text}"
            ) from e

        # The file is read from the temporary directory created above; it is
        # unpickled on the assumption that only the child process wrote it.
        with open(output_filepath, "rb") as f:
            return pickle.load(f)


def _run() -> None:
    """
    Execute one pickled job read from stdin.

    Stdin carries a pickled `(fn, output_file)` pair. General plugins are
    loaded first, then `fn()` is called and its pickled result is written to
    `output_file`.
    """
    # Setup plugins
    from vllm.plugins import load_general_plugins

    load_general_plugins()

    # NOTE: the payload is unpickled as is; presumably it always comes from
    # `_run_in_subprocess` in the parent process - confirm before reusing.
    payload = sys.stdin.buffer.read()
    fn, output_file = pickle.loads(payload)

    with open(output_file, "wb") as out:
        pickle.dump(fn(), out)
1357
1358
1359


# Runs only when this module is executed as a script, e.g. via `python -m`.
if __name__ == "__main__":
    _run()