registry.py 53.6 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
4
5
6
"""
Whenever you add an architecture to this page, please also update
`tests/models/registry.py` with example HuggingFace models for it.
"""
7

8
import importlib
9
import json
10
import os
11
import pickle
12
13
import subprocess
import sys
14
import tempfile
15
from abc import ABC, abstractmethod
16
from collections.abc import Callable, Set
17
from dataclasses import asdict, dataclass, field
18
from functools import lru_cache
19
from pathlib import Path
20
from typing import TYPE_CHECKING, Any, TypeVar
21
22

import torch.nn as nn
23
import transformers
24

25
from vllm import envs
26
27
28
29
30
from vllm.config import (
    ModelConfig,
    iter_architecture_defaults,
    try_match_architecture_defaults,
)
31
from vllm.logger import init_logger
32
from vllm.logging_utils import logtime
33
from vllm.tasks import ScoreType
34
from vllm.transformers_utils.dynamic_module import try_get_class_from_dynamic_module
35
from vllm.utils.hashing import safe_hash
36

37
38
if TYPE_CHECKING:
    from vllm.config.model import AttnTypeStr
39
    from vllm.config.pooler import SequencePoolingType, TokenPoolingType
40
41
else:
    AttnTypeStr = Any
42
43
    SequencePoolingType = Any
    TokenPoolingType = Any
44
45


46
47
48
49
50
from .interfaces import (
    has_inner_state,
    has_noops,
    is_attention_free,
    is_hybrid,
Patrick von Platen's avatar
Patrick von Platen committed
51
    requires_raw_input_tokens,
52
    supports_mamba_prefix_caching,
53
54
55
56
57
58
59
    supports_multimodal,
    supports_multimodal_encoder_tp_data,
    supports_multimodal_raw_input_only,
    supports_pp,
    supports_transcription,
)
from .interfaces_base import (
60
    get_attn_type,
61
62
    get_default_seq_pooling_type,
    get_default_tok_pooling_type,
63
    get_score_type,
64
65
66
    is_pooling_model,
    is_text_generation_model,
)
67
68
69

logger = init_logger(__name__)

70
71
_TEXT_GENERATION_MODELS = {
    # [Decoder-only]
72
    "AfmoeForCausalLM": ("afmoe", "AfmoeForCausalLM"),
73
    "ApertusForCausalLM": ("apertus", "ApertusForCausalLM"),
74
75
    "AquilaModel": ("llama", "LlamaForCausalLM"),
    "AquilaForCausalLM": ("llama", "LlamaForCausalLM"),  # AquilaChat2
Raghav Ravishankar's avatar
Raghav Ravishankar committed
76
    "ArceeForCausalLM": ("arcee", "ArceeForCausalLM"),
77
    "ArcticForCausalLM": ("arctic", "ArcticForCausalLM"),
78
    "AXK1ForCausalLM": ("AXK1", "AXK1ForCausalLM"),
79
80
81
82
    # baichuan-7b, upper case 'C' in the class name
    "BaiChuanForCausalLM": ("baichuan", "BaiChuanForCausalLM"),
    # baichuan-13b, lower case 'c' in the class name
    "BaichuanForCausalLM": ("baichuan", "BaichuanForCausalLM"),
83
    "BailingMoeForCausalLM": ("bailing_moe", "BailingMoeForCausalLM"),
ant-yy's avatar
ant-yy committed
84
    "BailingMoeV2ForCausalLM": ("bailing_moe", "BailingMoeV2ForCausalLM"),
Jiangyun Zhu's avatar
Jiangyun Zhu committed
85
    "BailingMoeV2_5ForCausalLM": ("bailing_moe_linear", "BailingMoeV25ForCausalLM"),
Yu Chin Fabian Lim's avatar
Yu Chin Fabian Lim committed
86
    "BambaForCausalLM": ("bamba", "BambaForCausalLM"),
87
    "BloomForCausalLM": ("bloom", "BloomForCausalLM"),
88
    "ChatGLMModel": ("chatglm", "ChatGLMForCausalLM"),
89
    "ChatGLMForConditionalGeneration": ("chatglm", "ChatGLMForCausalLM"),
90
    "CohereForCausalLM": ("commandr", "CohereForCausalLM"),
91
    "Cohere2ForCausalLM": ("commandr", "CohereForCausalLM"),
92
    "CwmForCausalLM": ("llama", "LlamaForCausalLM"),
93
    "DbrxForCausalLM": ("dbrx", "DbrxForCausalLM"),
94
    "DeciLMForCausalLM": ("nemotron_nas", "DeciLMForCausalLM"),
95
    "DeepseekForCausalLM": ("deepseek_v2", "DeepseekForCausalLM"),
96
    "DeepseekV2ForCausalLM": ("deepseek_v2", "DeepseekV2ForCausalLM"),
97
    "DeepseekV3ForCausalLM": ("deepseek_v2", "DeepseekV3ForCausalLM"),
98
    "DeepseekV32ForCausalLM": ("deepseek_v2", "DeepseekV3ForCausalLM"),
99
    "Dots1ForCausalLM": ("dots1", "Dots1ForCausalLM"),
100
    "Ernie4_5ForCausalLM": ("ernie45", "Ernie4_5ForCausalLM"),
101
    "Ernie4_5_MoeForCausalLM": ("ernie45_moe", "Ernie4_5_MoeForCausalLM"),
102
    "ExaoneForCausalLM": ("exaone", "ExaoneForCausalLM"),
103
    "Exaone4ForCausalLM": ("exaone4", "Exaone4ForCausalLM"),
Kyungmin Lee's avatar
Kyungmin Lee committed
104
    "ExaoneMoEForCausalLM": ("exaone_moe", "ExaoneMoeForCausalLM"),
105
    "Fairseq2LlamaForCausalLM": ("fairseq2_llama", "Fairseq2LlamaForCausalLM"),
106
107
108
    "FalconForCausalLM": ("falcon", "FalconForCausalLM"),
    "FalconMambaForCausalLM": ("mamba", "MambaForCausalLM"),
    "FalconH1ForCausalLM": ("falcon_h1", "FalconH1ForCausalLM"),
109
    "FlexOlmoForCausalLM": ("flex_olmo", "FlexOlmoForCausalLM"),
110
111
    "GemmaForCausalLM": ("gemma", "GemmaForCausalLM"),
    "Gemma2ForCausalLM": ("gemma2", "Gemma2ForCausalLM"),
112
    "Gemma3ForCausalLM": ("gemma3", "Gemma3ForCausalLM"),
113
    "Rnj1ForCausalLM": ("rnj1", "Rnj1ForCausalLM"),
Nicolò Lucchesi's avatar
Nicolò Lucchesi committed
114
    "Gemma3nForCausalLM": ("gemma3n", "Gemma3nForCausalLM"),
115
    "Gemma4ForCausalLM": ("gemma4", "Gemma4ForCausalLM"),
116
    "Qwen3NextForCausalLM": ("qwen3_next", "Qwen3NextForCausalLM"),
117
    "GlmForCausalLM": ("glm", "GlmForCausalLM"),
Yuxuan Zhang's avatar
Yuxuan Zhang committed
118
    "Glm4ForCausalLM": ("glm4", "Glm4ForCausalLM"),
Yuxuan Zhang's avatar
Yuxuan Zhang committed
119
    "Glm4MoeForCausalLM": ("glm4_moe", "Glm4MoeForCausalLM"),
120
    "Glm4MoeLiteForCausalLM": ("glm4_moe_lite", "Glm4MoeLiteForCausalLM"),
Jee Jee Li's avatar
Jee Jee Li committed
121
    "GlmMoeDsaForCausalLM": ("deepseek_v2", "GlmMoeDsaForCausalLM"),
122
    "GptOssForCausalLM": ("gpt_oss", "GptOssForCausalLM"),
123
124
125
126
127
128
    "GPT2LMHeadModel": ("gpt2", "GPT2LMHeadModel"),
    "GPTBigCodeForCausalLM": ("gpt_bigcode", "GPTBigCodeForCausalLM"),
    "GPTJForCausalLM": ("gpt_j", "GPTJForCausalLM"),
    "GPTNeoXForCausalLM": ("gpt_neox", "GPTNeoXForCausalLM"),
    "GraniteForCausalLM": ("granite", "GraniteForCausalLM"),
    "GraniteMoeForCausalLM": ("granitemoe", "GraniteMoeForCausalLM"),
129
130
    "GraniteMoeHybridForCausalLM": ("granitemoehybrid", "GraniteMoeHybridForCausalLM"),
    "GraniteMoeSharedForCausalLM": ("granitemoeshared", "GraniteMoeSharedForCausalLM"),
131
    "GritLM": ("gritlm", "GritLM"),
Bijaya Dangol's avatar
Bijaya Dangol committed
132
133
    "Grok1ModelForCausalLM": ("grok1", "GrokForCausalLM"),
    "Grok1ForCausalLM": ("grok1", "GrokForCausalLM"),
134
135
    "HunYuanMoEV1ForCausalLM": ("hunyuan_v1", "HunYuanMoEV1ForCausalLM"),
    "HunYuanDenseV1ForCausalLM": ("hunyuan_v1", "HunYuanDenseV1ForCausalLM"),
136
    "HCXVisionForCausalLM": ("hyperclovax_vision", "HCXVisionForCausalLM"),
137
    "HCXVisionV2ForCausalLM": ("hyperclovax_vision_v2", "HCXVisionV2ForCausalLM"),
138
    "HyperCLOVAXForCausalLM": ("hyperclovax", "HyperCLOVAXForCausalLM"),
139
140
    "InternLMForCausalLM": ("llama", "LlamaForCausalLM"),
    "InternLM2ForCausalLM": ("internlm2", "InternLM2ForCausalLM"),
141
    "InternLM2VEForCausalLM": ("internlm2_ve", "InternLM2VEForCausalLM"),
142
    "InternLM3ForCausalLM": ("llama", "LlamaForCausalLM"),
143
144
    "IQuestCoderForCausalLM": ("llama", "LlamaForCausalLM"),
    "IQuestLoopCoderForCausalLM": ("iquest_loopcoder", "IQuestLoopCoderForCausalLM"),
145
    "JAISLMHeadModel": ("jais", "JAISLMHeadModel"),
146
    "Jais2ForCausalLM": ("jais2", "Jais2ForCausalLM"),
147
    "JambaForCausalLM": ("jamba", "JambaForCausalLM"),
148
    "KimiLinearForCausalLM": ("kimi_linear", "KimiLinearForCausalLM"),
149
    "Lfm2ForCausalLM": ("lfm2", "Lfm2ForCausalLM"),
Paul Pak's avatar
Paul Pak committed
150
    "Lfm2MoeForCausalLM": ("lfm2_moe", "Lfm2MoeForCausalLM"),
151
    "LlamaForCausalLM": ("llama", "LlamaForCausalLM"),
152
    "Llama4ForCausalLM": ("llama4", "Llama4ForCausalLM"),
153
154
    # For decapoda-research/llama-*
    "LLaMAForCausalLM": ("llama", "LlamaForCausalLM"),
XuruiYang's avatar
XuruiYang committed
155
    "LongcatFlashForCausalLM": ("longcat_flash", "LongcatFlashForCausalLM"),
156
    "MambaForCausalLM": ("mamba", "MambaForCausalLM"),
157
    "Mamba2ForCausalLM": ("mamba2", "Mamba2ForCausalLM"),
158
159
    "MiniCPMForCausalLM": ("minicpm", "MiniCPMForCausalLM"),
    "MiniCPM3ForCausalLM": ("minicpm3", "MiniCPM3ForCausalLM"),
160
161
162
    "MiniMaxForCausalLM": ("minimax_text_01", "MiniMaxText01ForCausalLM"),
    "MiniMaxText01ForCausalLM": ("minimax_text_01", "MiniMaxText01ForCausalLM"),
    "MiniMaxM1ForCausalLM": ("minimax_text_01", "MiniMaxText01ForCausalLM"),
163
    "MiniMaxM2ForCausalLM": ("minimax_m2", "MiniMaxM2ForCausalLM"),
164
    "Ministral3ForCausalLM": ("mistral", "MistralForCausalLM"),
165
    "MistralForCausalLM": ("mistral", "MistralForCausalLM"),
166
    "MistralLarge3ForCausalLM": ("mistral_large_3", "MistralLarge3ForCausalLM"),
167
168
169
170
    "MixtralForCausalLM": ("mixtral", "MixtralForCausalLM"),
    # transformers's mpt class has lower case
    "MptForCausalLM": ("mpt", "MPTForCausalLM"),
    "MPTForCausalLM": ("mpt", "MPTForCausalLM"),
171
    "MiMoForCausalLM": ("mimo", "MiMoForCausalLM"),
172
    "MiMoV2FlashForCausalLM": ("mimo_v2_flash", "MiMoV2FlashForCausalLM"),
173
    "NemotronForCausalLM": ("nemotron", "NemotronForCausalLM"),
Luis Vega's avatar
Luis Vega committed
174
    "NemotronHForCausalLM": ("nemotron_h", "NemotronHForCausalLM"),
175
    "NemotronHPuzzleForCausalLM": ("nemotron_h", "NemotronHForCausalLM"),
176
    "OlmoForCausalLM": ("olmo", "OlmoForCausalLM"),
177
    "Olmo2ForCausalLM": ("olmo2", "Olmo2ForCausalLM"),
178
    "Olmo3ForCausalLM": ("olmo2", "Olmo2ForCausalLM"),
179
    "OlmoHybridForCausalLM": ("olmo_hybrid", "OlmoHybridForCausalLM"),
180
181
182
    "OlmoeForCausalLM": ("olmoe", "OlmoeForCausalLM"),
    "OPTForCausalLM": ("opt", "OPTForCausalLM"),
    "OrionForCausalLM": ("orion", "OrionForCausalLM"),
183
    "OuroForCausalLM": ("ouro", "OuroForCausalLM"),
184
    "PanguEmbeddedForCausalLM": ("openpangu", "PanguEmbeddedForCausalLM"),
185
    "PanguProMoEV2ForCausalLM": ("openpangu", "PanguProMoEV2ForCausalLM"),
186
    "PanguUltraMoEForCausalLM": ("openpangu", "PanguUltraMoEForCausalLM"),
187
    "Param2MoEForCausalLM": ("param2moe", "Param2MoEForCausalLM"),
188
189
190
191
    "PersimmonForCausalLM": ("persimmon", "PersimmonForCausalLM"),
    "PhiForCausalLM": ("phi", "PhiForCausalLM"),
    "Phi3ForCausalLM": ("phi3", "Phi3ForCausalLM"),
    "PhiMoEForCausalLM": ("phimoe", "PhiMoEForCausalLM"),
Shinichi Hemmi's avatar
Shinichi Hemmi committed
192
    "Plamo2ForCausalLM": ("plamo2", "Plamo2ForCausalLM"),
193
    "Plamo3ForCausalLM": ("plamo3", "Plamo3ForCausalLM"),
194
    "QWenLMHeadModel": ("qwen", "QWenLMHeadModel"),
195
196
    "Qwen2ForCausalLM": ("qwen2", "Qwen2ForCausalLM"),
    "Qwen2MoeForCausalLM": ("qwen2_moe", "Qwen2MoeForCausalLM"),
197
198
    "Qwen3ForCausalLM": ("qwen3", "Qwen3ForCausalLM"),
    "Qwen3MoeForCausalLM": ("qwen3_moe", "Qwen3MoeForCausalLM"),
199
    "RWForCausalLM": ("falcon", "FalconForCausalLM"),
200
201
    "SarvamMoEForCausalLM": ("sarvam", "SarvamMoEForCausalLM"),
    "SarvamMLAForCausalLM": ("sarvam", "SarvamMLAForCausalLM"),
202
    "SeedOssForCausalLM": ("seed_oss", "SeedOssForCausalLM"),
Li Xie's avatar
Li Xie committed
203
    "Step1ForCausalLM": ("step1", "Step1ForCausalLM"),
Song's avatar
Song committed
204
    "Step3TextForCausalLM": ("step3_text", "Step3TextForCausalLM"),
csy0225's avatar
csy0225 committed
205
    "Step3p5ForCausalLM": ("step3p5", "Step3p5ForCausalLM"),
206
207
208
209
    "StableLMEpochForCausalLM": ("stablelm", "StablelmForCausalLM"),
    "StableLmForCausalLM": ("stablelm", "StablelmForCausalLM"),
    "Starcoder2ForCausalLM": ("starcoder2", "Starcoder2ForCausalLM"),
    "SolarForCausalLM": ("solar", "SolarForCausalLM"),
210
    "TeleChatForCausalLM": ("telechat2", "TeleChat2ForCausalLM"),
211
    "TeleChat2ForCausalLM": ("telechat2", "TeleChat2ForCausalLM"),
212
    "TeleChat3ForCausalLM": ("llama", "LlamaForCausalLM"),
213
    "TeleFLMForCausalLM": ("teleflm", "TeleFLMForCausalLM"),
214
    "XverseForCausalLM": ("llama", "LlamaForCausalLM"),
215
    "Zamba2ForCausalLM": ("zamba2", "Zamba2ForCausalLM"),
216
217
218
}

# Registry of embedding architectures.
# Keys are architecture names; each value is a `(mod, arch)` 2-tuple of
# strings, the same shape as the entries of `_TEXT_GENERATION_MODELS`.
_EMBEDDING_MODELS = {
    # [Text-only]
    "BertModel": ("bert", "BertEmbeddingModel"),
    "BertSpladeSparseEmbeddingModel": ("bert", "BertSpladeSparseEmbeddingModel"),
    "ErnieModel": ("ernie", "ErnieEmbeddingModel"),
    "BgeM3EmbeddingModel": ("roberta", "BgeM3EmbeddingModel"),
    "DeciLMForCausalLM": ("nemotron_nas", "DeciLMForCausalLM"),
    "Gemma2Model": ("gemma2", "Gemma2ForCausalLM"),
    "Gemma3TextModel": ("gemma3", "Gemma3Model"),
    "GlmForCausalLM": ("glm", "GlmForCausalLM"),
    "GritLM": ("gritlm", "GritLM"),
    "GteModel": ("bert_with_rope", "SnowflakeGteNewModel"),
    "GteNewModel": ("bert_with_rope", "GteNewModel"),
    "JinaEmbeddingsV5Model": ("jina", "JinaEmbeddingsV5Model"),
    "LlamaBidirectionalModel": ("llama", "LlamaBidirectionalModel"),
    "LlamaModel": ("llama", "LlamaForCausalLM"),
    **{
        # Multiple models share the same architecture, so we include them all
        k: (mod, arch)
        for k, (mod, arch) in _TEXT_GENERATION_MODELS.items()
        if arch == "LlamaForCausalLM"
    },
    "MistralModel": ("llama", "LlamaForCausalLM"),
    "ModernBertModel": ("modernbert", "ModernBertModel"),
    "NomicBertModel": ("bert_with_rope", "NomicBertModel"),
    "Phi3ForCausalLM": ("phi3", "Phi3ForCausalLM"),
    "Qwen2Model": ("qwen2", "Qwen2ForCausalLM"),
    "Qwen2ForCausalLM": ("qwen2", "Qwen2ForCausalLM"),
    "RobertaForMaskedLM": ("roberta", "RobertaEmbeddingModel"),
    "RobertaModel": ("roberta", "RobertaEmbeddingModel"),
    "TeleChatForCausalLM": ("telechat2", "TeleChat2ForCausalLM"),
    "TeleChat2ForCausalLM": ("telechat2", "TeleChat2ForCausalLM"),
    "VoyageQwen3BidirectionalEmbedModel": (
        "voyage",
        "VoyageQwen3BidirectionalEmbedModel",
    ),
    "XLMRobertaModel": ("roberta", "RobertaEmbeddingModel"),
    # [Multimodal]
    "CLIPModel": ("clip", "CLIPEmbeddingModel"),
    "ColPaliForRetrieval": ("colpali", "ColPaliModel"),
    "LlamaNemotronVLModel": ("nemotron_vl", "LlamaNemotronVLForEmbedding"),
    "LlavaNextForConditionalGeneration": (
        "llava_next",
        "LlavaNextForConditionalGeneration",
    ),
    "Phi3VForCausalLM": ("phi3v", "Phi3VForCausalLM"),
    "Qwen2VLForConditionalGeneration": ("qwen2_vl", "Qwen2VLForConditionalGeneration"),
    "SiglipModel": ("siglip", "SiglipEmbeddingModel"),
    # Technically Terratorch models work on images, both in
    # input and output. I am adding it here because it piggy-backs on embedding
    # models for the time being.
    "PrithviGeoSpatialMAE": ("terratorch", "Terratorch"),
    "Terratorch": ("terratorch", "Terratorch"),
}

273
274
275
276
277
_LATE_INTERACTION_MODELS = {
    # [Text-only]
    "HF_ColBERT": ("colbert", "ColBERTModel"),
    "ColBERTModernBertModel": ("colbert", "ColBERTModernBertModel"),
    "ColBERTJinaRobertaModel": ("colbert", "ColBERTJinaRobertaModel"),
278
    "ColBERTLfm2Model": ("colbert", "ColBERTLfm2Model"),
279
    "JinaForRanking": ("jina", "JinaForRanking"),
280
281
    # [Multimodal]
    "ColModernVBertForRetrieval": ("colmodernvbert", "ColModernVBertForRetrieval"),
282
    "ColPaliForRetrieval": ("colpali", "ColPaliModel"),
283
284
    "ColQwen3": ("colqwen3", "ColQwen3Model"),
    "OpsColQwen3Model": ("colqwen3", "ColQwen3Model"),
285
    "ColQwen3_5": ("colqwen3_5", "ColQwen3_5Model"),
286
287
288
289
290
291
292
293
294
295
    "Qwen3VLNemotronEmbedModel": ("colqwen3", "ColQwen3Model"),
}

_REWARD_MODELS = {
    "InternLM2ForRewardModel": ("internlm2", "InternLM2ForRewardModel"),
    "Qwen2ForRewardModel": ("qwen2_rm", "Qwen2ForRewardModel"),
    "Qwen2ForProcessRewardModel": ("qwen2_rm", "Qwen2ForProcessRewardModel"),
}

_TOKEN_CLASSIFICATION_MODELS = {
296
    "BertForTokenClassification": ("bert", "BertForTokenClassification"),
297
    "ErnieForTokenClassification": ("ernie", "ErnieForTokenClassification"),
298
299
300
301
    "ModernBertForTokenClassification": (
        "modernbert",
        "ModernBertForTokenClassification",
    ),
302
303
304
305
    "Qwen3ASRForcedAlignerForTokenClassification": (
        "qwen3_asr_forced_aligner",
        "Qwen3ASRForcedAlignerForTokenClassification",
    ),
306
307
308
309
310
}

_SEQUENCE_CLASSIFICATION_MODELS = {
    "BertForSequenceClassification": ("bert", "BertForSequenceClassification"),
    "GPT2ForSequenceClassification": ("gpt2", "GPT2ForSequenceClassification"),
311
    "ErnieForSequenceClassification": ("ernie", "ErnieForSequenceClassification"),
312
313
314
315
    "GteNewForSequenceClassification": (
        "bert_with_rope",
        "GteNewForSequenceClassification",
    ),
316
    "JambaForSequenceClassification": ("jamba", "JambaForSequenceClassification"),
317
318
319
    "LlamaBidirectionalForSequenceClassification": (
        "llama",
        "LlamaBidirectionalForSequenceClassification",
320
    ),
321
322
323
324
325
326
327
328
329
    "ModernBertForSequenceClassification": (
        "modernbert",
        "ModernBertForSequenceClassification",
    ),
    "RobertaForSequenceClassification": ("roberta", "RobertaForSequenceClassification"),
    "XLMRobertaForSequenceClassification": (
        "roberta",
        "RobertaForSequenceClassification",
    ),
330
331
332
333
334
335
    # [Multimodal]
    "JinaVLForRanking": ("jina_vl", "JinaVLForSequenceClassification"),
    "LlamaNemotronVLForSequenceClassification": (
        "nemotron_vl",
        "LlamaNemotronVLForSequenceClassification",
    ),
336
337
}

338
_MULTIMODAL_MODELS = {
339
    # [Decoder-only]
340
    "AriaForConditionalGeneration": ("aria", "AriaForConditionalGeneration"),
341
342
343
344
    "AudioFlamingo3ForConditionalGeneration": (
        "audioflamingo3",
        "AudioFlamingo3ForConditionalGeneration",
    ),
345
346
347
348
    "MusicFlamingoForConditionalGeneration": (
        "musicflamingo",
        "MusicFlamingoForConditionalGeneration",
    ),
349
350
351
    "AyaVisionForConditionalGeneration": (
        "aya_vision",
        "AyaVisionForConditionalGeneration",
352
    ),
353
    "BagelForConditionalGeneration": ("bagel", "BagelForConditionalGeneration"),
354
    "BeeForConditionalGeneration": ("bee", "BeeForConditionalGeneration"),
355
    "Blip2ForConditionalGeneration": ("blip2", "Blip2ForConditionalGeneration"),
356
357
358
    "ChameleonForConditionalGeneration": (
        "chameleon",
        "ChameleonForConditionalGeneration",
359
    ),
360
361
    "Cheers": ("cheers", "CheersForConditionalGeneration"),
    "CheersForConditionalGeneration": ("cheers", "CheersForConditionalGeneration"),
362
363
364
    "Cohere2VisionForConditionalGeneration": (
        "cohere2_vision",
        "Cohere2VisionForConditionalGeneration",
365
    ),
366
    "DeepseekVLV2ForCausalLM": ("deepseek_vl2", "DeepseekVLV2ForCausalLM"),
367
    "DeepseekOCRForCausalLM": ("deepseek_ocr", "DeepseekOCRForCausalLM"),
RED's avatar
RED committed
368
    "DeepseekOCR2ForCausalLM": ("deepseek_ocr2", "DeepseekOCR2ForCausalLM"),
Roger Wang's avatar
Roger Wang committed
369
    "DotsOCRForCausalLM": ("dots_ocr", "DotsOCRForCausalLM"),
370
371
372
373
    "Eagle2_5_VLForConditionalGeneration": (
        "eagle2_5_vl",
        "Eagle2_5_VLForConditionalGeneration",
    ),
374
375
376
    "Ernie4_5_VLMoeForConditionalGeneration": (
        "ernie45_vl",
        "Ernie4_5_VLMoeForConditionalGeneration",
377
    ),
Kyungmin Lee's avatar
Kyungmin Lee committed
378
379
380
381
    "Exaone4_5_ForConditionalGeneration": (
        "exaone4_5",
        "Exaone4_5_ForConditionalGeneration",
    ),  # noqa: E501
382
383
384
385
    "FireRedASR2ForConditionalGeneration": (
        "fireredasr2",
        "FireRedASR2ForConditionalGeneration",
    ),
386
    "FunASRForConditionalGeneration": ("funasr", "FunASRForConditionalGeneration"),
387
388
389
390
    "FireRedLIDForConditionalGeneration": (
        "fireredlid",
        "FireRedLIDForConditionalGeneration",
    ),
391
392
393
394
    "FunAudioChatForConditionalGeneration": (
        "funaudiochat",
        "FunAudioChatForConditionalGeneration",
    ),
395
    "FuyuForCausalLM": ("fuyu", "FuyuForCausalLM"),
396
    "Gemma3ForConditionalGeneration": ("gemma3_mm", "Gemma3ForConditionalGeneration"),
397
398
399
    "Gemma3nForConditionalGeneration": (
        "gemma3n_mm",
        "Gemma3nForConditionalGeneration",
400
    ),
401
    "Gemma4ForConditionalGeneration": ("gemma4_mm", "Gemma4ForConditionalGeneration"),
402
    "GlmAsrForConditionalGeneration": ("glmasr", "GlmAsrForConditionalGeneration"),
403
    "GLM4VForCausalLM": ("glm4v", "GLM4VForCausalLM"),
404
405
    "Glm4vForConditionalGeneration": ("glm4_1v", "Glm4vForConditionalGeneration"),
    "Glm4vMoeForConditionalGeneration": ("glm4_1v", "Glm4vMoeForConditionalGeneration"),
406
    "GlmOcrForConditionalGeneration": ("glm_ocr", "GlmOcrForConditionalGeneration"),
407
408
409
    "GraniteSpeechForConditionalGeneration": (
        "granite_speech",
        "GraniteSpeechForConditionalGeneration",
410
    ),
411
412
413
414
    "Granite4VisionForConditionalGeneration": (
        "granite4_vision",
        "Granite4VisionForConditionalGeneration",
    ),
415
    "H2OVLChatModel": ("h2ovl", "H2OVLChatModel"),
416
417
418
419
    "HunYuanVLForConditionalGeneration": (
        "hunyuan_vision",
        "HunYuanVLForConditionalGeneration",
    ),
420
    "InternVLChatModel": ("internvl", "InternVLChatModel"),
421
422
423
    "InternS1ForConditionalGeneration": (
        "interns1",
        "InternS1ForConditionalGeneration",
424
    ),
425
426
427
    "InternVLForConditionalGeneration": (
        "interns1",
        "InternS1ForConditionalGeneration",
428
    ),
zxy's avatar
zxy committed
429
430
431
432
    "InternS1ProForConditionalGeneration": (
        "interns1_pro",
        "InternS1ProForConditionalGeneration",
    ),
433
434
435
436
    "Idefics3ForConditionalGeneration": (
        "idefics3",
        "Idefics3ForConditionalGeneration",
    ),
oscardev256's avatar
oscardev256 committed
437
    "IsaacForConditionalGeneration": ("isaac", "IsaacForConditionalGeneration"),
438
    "KananaVForConditionalGeneration": ("kanana_v", "KananaVForConditionalGeneration"),
439
    "KeyeForConditionalGeneration": ("keye", "KeyeForConditionalGeneration"),
440
441
442
    "KeyeVL1_5ForConditionalGeneration": (
        "keye_vl1_5",
        "KeyeVL1_5ForConditionalGeneration",
443
    ),
444
445
446
    "KimiVLForConditionalGeneration": ("kimi_vl", "KimiVLForConditionalGeneration"),
    "KimiK25ForConditionalGeneration": ("kimi_k25", "KimiK25ForConditionalGeneration"),
    "MoonshotKimiaForCausalLM": ("kimi_audio", "KimiAudioForConditionalGeneration"),
447
448
449
450
    "LightOnOCRForConditionalGeneration": (
        "lightonocr",
        "LightOnOCRForConditionalGeneration",
    ),
451
    "Lfm2VlForConditionalGeneration": ("lfm2_vl", "Lfm2VLForConditionalGeneration"),
452
    "Llama4ForConditionalGeneration": ("mllama4", "Llama4ForConditionalGeneration"),
453
    "Llama_Nemotron_Nano_VL": ("nemotron_vl", "LlamaNemotronVLChatModel"),
454
    "LlavaForConditionalGeneration": ("llava", "LlavaForConditionalGeneration"),
455
456
457
    "LlavaNextForConditionalGeneration": (
        "llava_next",
        "LlavaNextForConditionalGeneration",
458
    ),
459
460
461
    "LlavaNextVideoForConditionalGeneration": (
        "llava_next_video",
        "LlavaNextVideoForConditionalGeneration",
462
    ),
463
464
465
    "LlavaOnevisionForConditionalGeneration": (
        "llava_onevision",
        "LlavaOnevisionForConditionalGeneration",
466
    ),
467
    "MantisForConditionalGeneration": ("llava", "MantisForConditionalGeneration"),
468
    "MiDashengLMModel": ("midashenglm", "MiDashengLMModel"),
469
470
471
    "MiniMaxVL01ForConditionalGeneration": (
        "minimax_vl_01",
        "MiniMaxVL01ForConditionalGeneration",
472
    ),
473
    "MiniCPMO": ("minicpmo", "MiniCPMO"),
474
    "MiniCPMV": ("minicpmv", "MiniCPMV"),
475
476
477
    "Mistral3ForConditionalGeneration": (
        "mistral3",
        "Mistral3ForConditionalGeneration",
478
    ),
479
    "MolmoForCausalLM": ("molmo", "MolmoForCausalLM"),
480
    "Molmo2ForConditionalGeneration": ("molmo2", "Molmo2ForConditionalGeneration"),
481
    "NemotronH_Nano_VL_V2": ("nano_nemotron_vl", "NemotronH_Nano_VL_V2"),
482
483
    "NemotronH_Nano_Omni_Reasoning_V3": ("nano_nemotron_vl", "NemotronH_Nano_VL_V2"),
    "NemotronH_Super_Omni_Reasoning_V3": ("nano_nemotron_vl", "NemotronH_Nano_VL_V2"),
484
    "NVLM_D": ("nvlm_d", "NVLM_D_Model"),
485
    "OpenCUAForConditionalGeneration": ("opencua", "OpenCUAForConditionalGeneration"),
486
487
488
489
    "OpenPanguVLForConditionalGeneration": (
        "openpangu_vl",
        "OpenPanguVLForConditionalGeneration",
    ),
490
    "Ovis": ("ovis", "Ovis"),
491
    "Ovis2_5": ("ovis2_5", "Ovis2_5"),
492
493
    "Ovis2_6ForCausalLM": ("ovis2_5", "Ovis2_5"),
    "Ovis2_6_MoeForCausalLM": ("ovis2_5", "Ovis2_5"),
494
495
496
497
    "PaddleOCRVLForConditionalGeneration": (
        "paddleocr_vl",
        "PaddleOCRVLForConditionalGeneration",
    ),
498
499
500
501
    "PaliGemmaForConditionalGeneration": (
        "paligemma",
        "PaliGemmaForConditionalGeneration",
    ),
502
    "Phi3VForCausalLM": ("phi3v", "Phi3VForCausalLM"),
503
    "Phi4ForCausalLMV": ("phi4siglip", "Phi4ForCausalLMV"),
504
    "Phi4MMForCausalLM": ("phi4mm", "Phi4MMForCausalLM"),
505
506
507
    "PixtralForConditionalGeneration": ("pixtral", "PixtralForConditionalGeneration"),
    "QwenVLForConditionalGeneration": ("qwen_vl", "QwenVLForConditionalGeneration"),
    "Qwen2VLForConditionalGeneration": ("qwen2_vl", "Qwen2VLForConditionalGeneration"),
508
509
510
    "Qwen2_5_VLForConditionalGeneration": (
        "qwen2_5_vl",
        "Qwen2_5_VLForConditionalGeneration",
511
    ),
512
513
514
    "Qwen2AudioForConditionalGeneration": (
        "qwen2_audio",
        "Qwen2AudioForConditionalGeneration",
515
    ),
516
517
518
    "Qwen2_5OmniModel": (
        "qwen2_5_omni_thinker",
        "Qwen2_5OmniThinkerForConditionalGeneration",
519
    ),
520
521
522
    "Qwen2_5OmniForConditionalGeneration": (
        "qwen2_5_omni_thinker",
        "Qwen2_5OmniThinkerForConditionalGeneration",
523
    ),
524
525
526
527
    "Qwen3OmniMoeForConditionalGeneration": (
        "qwen3_omni_moe_thinker",
        "Qwen3OmniMoeThinkerForConditionalGeneration",
    ),
Roger Wang's avatar
Roger Wang committed
528
529
530
531
    "Qwen3ASRForConditionalGeneration": (
        "qwen3_asr",
        "Qwen3ASRForConditionalGeneration",
    ),
532
533
    "Qwen3ASRRealtimeGeneration": ("qwen3_asr_realtime", "Qwen3ASRRealtimeGeneration"),
    "Qwen3VLForConditionalGeneration": ("qwen3_vl", "Qwen3VLForConditionalGeneration"),
534
535
536
    "Qwen3VLMoeForConditionalGeneration": (
        "qwen3_vl_moe",
        "Qwen3VLMoeForConditionalGeneration",
537
    ),
538
    "Qwen3_5ForConditionalGeneration": ("qwen3_5", "Qwen3_5ForConditionalGeneration"),
539
540
541
542
    "Qwen3_5MoeForConditionalGeneration": (
        "qwen3_5",
        "Qwen3_5MoeForConditionalGeneration",
    ),
543
    "RForConditionalGeneration": ("rvl", "RForConditionalGeneration"),
544
    "SkyworkR1VChatModel": ("skyworkr1v", "SkyworkR1VChatModel"),
545
546
547
548
    "SmolVLMForConditionalGeneration": ("smolvlm", "SmolVLMForConditionalGeneration"),
    "StepVLForConditionalGeneration": ("step_vl", "StepVLForConditionalGeneration"),
    "Step3VLForConditionalGeneration": ("step3_vl", "Step3VLForConditionalGeneration"),
    "TarsierForConditionalGeneration": ("tarsier", "TarsierForConditionalGeneration"),
549
550
551
    "Tarsier2ForConditionalGeneration": (
        "qwen2_vl",
        "Tarsier2ForConditionalGeneration",
552
    ),
553
    "UltravoxModel": ("ultravox", "UltravoxModel"),
554
555
    "VoxtralForConditionalGeneration": ("voxtral", "VoxtralForConditionalGeneration"),
    "VoxtralRealtimeGeneration": ("voxtral_realtime", "VoxtralRealtimeGeneration"),
556
    # [Encoder-decoder]
557
    "CohereAsrForConditionalGeneration": (
Ekagra Ranjan's avatar
Ekagra Ranjan committed
558
        "cohere_asr",
559
        "CohereAsrForConditionalGeneration",
Ekagra Ranjan's avatar
Ekagra Ranjan committed
560
    ),
561
562
563
564
    "NemotronParseForConditionalGeneration": (
        "nemotron_parse",
        "NemotronParseForConditionalGeneration",
    ),
565
    "WhisperForConditionalGeneration": ("whisper", "WhisperForConditionalGeneration"),
566
}
567
568

_SPECULATIVE_DECODING_MODELS = {
569
    "ExtractHiddenStatesModel": ("extract_hidden_states", "ExtractHiddenStatesModel"),
570
    "MiMoMTPModel": ("mimo_mtp", "MiMoMTP"),
571
    "EagleLlamaForCausalLM": ("llama_eagle", "EagleLlamaForCausalLM"),
zhiweiz's avatar
zhiweiz committed
572
    "EagleLlama4ForCausalLM": ("llama4_eagle", "EagleLlama4ForCausalLM"),
573
    "EagleMiniCPMForCausalLM": ("minicpm_eagle", "EagleMiniCPMForCausalLM"),
574
    "DFlashDraftModel": ("qwen3_dflash", "DFlashQwen3ForCausalLM"),
575
    "Eagle3LlamaForCausalLM": ("llama_eagle3", "Eagle3LlamaForCausalLM"),
576
    "Eagle3MiniMaxM2ForCausalLM": ("llama_eagle3", "Eagle3LlamaForCausalLM"),
577
    "LlamaForCausalLMEagle3": ("llama_eagle3", "Eagle3LlamaForCausalLM"),
578
    "Eagle3Qwen2_5vlForCausalLM": ("llama_eagle3", "Eagle3LlamaForCausalLM"),
579
    "Eagle3Qwen3vlForCausalLM": ("llama_eagle3", "Eagle3LlamaForCausalLM"),
580
581
582
583
    "EagleMistralLarge3ForCausalLM": (
        "mistral_large_3_eagle",
        "EagleMistralLarge3ForCausalLM",
    ),
584
585
    "Eagle3DeepseekV2ForCausalLM": ("deepseek_eagle3", "Eagle3DeepseekV2ForCausalLM"),
    "Eagle3DeepseekV3ForCausalLM": ("deepseek_eagle3", "Eagle3DeepseekV2ForCausalLM"),
586
    "EagleDeepSeekMTPModel": ("deepseek_eagle", "EagleDeepseekV3ForCausalLM"),
587
    "DeepSeekMTPModel": ("deepseek_mtp", "DeepSeekMTP"),
588
    "ErnieMTPModel": ("ernie_mtp", "ErnieMTP"),
Kyungmin Lee's avatar
Kyungmin Lee committed
589
    "ExaoneMoeMTP": ("exaone_moe_mtp", "ExaoneMoeMTP"),
Kyungmin Lee's avatar
Kyungmin Lee committed
590
    "Exaone4_5_MTP": ("exaone4_5_mtp", "Exaone4_5_MTP"),
591
    "NemotronHMTPModel": ("nemotron_h_mtp", "NemotronHMTP"),
XuruiYang's avatar
XuruiYang committed
592
    "LongCatFlashMTPModel": ("longcat_flash_mtp", "LongCatFlashMTP"),
Yuxuan Zhang's avatar
Yuxuan Zhang committed
593
    "Glm4MoeMTPModel": ("glm4_moe_mtp", "Glm4MoeMTP"),
594
    "Glm4MoeLiteMTPModel": ("glm4_moe_lite_mtp", "Glm4MoeLiteMTP"),
595
    "GlmOcrMTPModel": ("glm_ocr_mtp", "GlmOcrMTP"),
596
    "MedusaModel": ("medusa", "Medusa"),
597
    "OpenPanguMTPModel": ("openpangu_mtp", "OpenPanguMTP"),
598
    "Qwen3NextMTP": ("qwen3_next_mtp", "Qwen3NextMTP"),
csy0225's avatar
csy0225 committed
599
    "Step3p5MTP": ("step3p5_mtp", "Step3p5MTP"),
600
601
    "Qwen3_5MTP": ("qwen3_5_mtp", "Qwen3_5MTP"),
    "Qwen3_5MoeMTP": ("qwen3_5_mtp", "Qwen3_5MoeMTP"),
602
603
604
    # Temporarily disabled.
    # # TODO(woosuk): Re-enable this once the MLP Speculator is supported in V1.
    # "MLPSpeculatorPreTrainedModel": ("mlp_speculator", "MLPSpeculator"),
605
}
606

607
_TRANSFORMERS_SUPPORTED_MODELS = {
608
609
610
    # Text generation models
    "SmolLM3ForCausalLM": ("transformers", "TransformersForCausalLM"),
    # Multimodal models
611
612
613
614
    "Emu3ForConditionalGeneration": (
        "transformers",
        "TransformersMultiModalForCausalLM",
    ),
615
616
617
}

_TRANSFORMERS_BACKEND_MODELS = {
618
    # Text generation models
619
    "TransformersForCausalLM": ("transformers", "TransformersForCausalLM"),
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
    "TransformersMoEForCausalLM": ("transformers", "TransformersMoEForCausalLM"),
    # Multimodal models
    "TransformersMultiModalForCausalLM": (
        "transformers",
        "TransformersMultiModalForCausalLM",
    ),
    "TransformersMultiModalMoEForCausalLM": (
        "transformers",
        "TransformersMultiModalMoEForCausalLM",
    ),
    # Embedding models
    "TransformersEmbeddingModel": ("transformers", "TransformersEmbeddingModel"),
    "TransformersMoEEmbeddingModel": ("transformers", "TransformersMoEEmbeddingModel"),
    "TransformersMultiModalEmbeddingModel": (
        "transformers",
        "TransformersMultiModalEmbeddingModel",
    ),
    # Sequence classification models
638
    "TransformersForSequenceClassification": (
639
        "transformers",
640
        "TransformersForSequenceClassification",
641
    ),
642
    "TransformersMoEForSequenceClassification": (
643
        "transformers",
644
        "TransformersMoEForSequenceClassification",
645
    ),
646
647
648
    "TransformersMultiModalForSequenceClassification": (
        "transformers",
        "TransformersMultiModalForSequenceClassification",
649
    ),
650
}
651

652
# Union of all per-category registries above. Later unpackings win on
# duplicate keys, so the order of the `**` entries is significant.
_VLLM_MODELS = {
    **_TEXT_GENERATION_MODELS,
    **_EMBEDDING_MODELS,
    **_LATE_INTERACTION_MODELS,
    **_REWARD_MODELS,
    **_TOKEN_CLASSIFICATION_MODELS,
    **_SEQUENCE_CLASSIFICATION_MODELS,
    **_MULTIMODAL_MODELS,
    **_SPECULATIVE_DECODING_MODELS,
    **_TRANSFORMERS_SUPPORTED_MODELS,
    **_TRANSFORMERS_BACKEND_MODELS,
}

665
666
667
668
# This variable is used as the args for subprocess.run(). We
# can modify  this variable to alter the args if needed. e.g.
# when we use par format to pack things together, sys.executable
# might not be the target we want to run.
669
_SUBPROCESS_COMMAND = [sys.executable, "-m", "vllm.model_executor.models.registry"]
670

671
_PREVIOUSLY_SUPPORTED_MODELS = {
672
    "MotifForCausalLM": "0.10.2",
673
    "Phi3SmallForCausalLM": "0.9.2",
674
    "Phi4FlashForCausalLM": "0.10.2",
675
    "Phi4MultimodalForCausalLM": "0.12.0",
676
677
678
679
680
    # encoder-decoder models except whisper
    # have been removed for V0 deprecation.
    "DonutForConditionalGeneration": "0.10.2",
    "MllamaForConditionalGeneration": "0.10.2",
}
681

682
683
684
685
686
687
688
_OOT_SUPPORTED_MODELS = {
    "BartModel": "https://github.com/vllm-project/bart-plugin",
    "BartForConditionalGeneration": "https://github.com/vllm-project/bart-plugin",
    "Florence2ForConditionalGeneration": "https://github.com/vllm-project/bart-plugin",
    "MBartForConditionalGeneration": "https://github.com/vllm-project/bart-plugin",
}

689

690
691
@dataclass(frozen=True)
class _ModelInfo:
    """
    Immutable record of the properties of a model class.

    `from_model_cls` fills in every field by querying the model class with
    the interface helpers imported at the top of this module.
    """

    # Name of the model class (`model.__name__`)
    architecture: str
    is_text_generation_model: bool
    is_pooling_model: bool
    attn_type: AttnTypeStr
    default_seq_pooling_type: SequencePoolingType
    default_tok_pooling_type: TokenPoolingType
    score_type: ScoreType
    supports_multimodal: bool
    supports_multimodal_raw_input_only: bool
    requires_raw_input_tokens: bool
    supports_multimodal_encoder_tp_data: bool
    supports_pp: bool
    has_inner_state: bool
    is_attention_free: bool
    is_hybrid: bool
    has_noops: bool
    supports_mamba_prefix_caching: bool
    supports_transcription: bool
    # Only taken from the model class when `supports_transcription` holds
    supports_transcription_only: bool

    @staticmethod
    def from_model_cls(model: type[nn.Module]) -> "_ModelInfo":
        """Build the `_ModelInfo` that describes the model class `model`."""
        # Both transcription fields depend on this check
        can_transcribe = supports_transcription(model)
        raw_input_only = supports_multimodal_raw_input_only(model)
        encoder_tp_data = supports_multimodal_encoder_tp_data(model)

        return _ModelInfo(
            architecture=model.__name__,
            is_text_generation_model=is_text_generation_model(model),
            is_pooling_model=is_pooling_model(model),
            attn_type=get_attn_type(model),
            default_seq_pooling_type=get_default_seq_pooling_type(model),
            default_tok_pooling_type=get_default_tok_pooling_type(model),
            score_type=get_score_type(model),
            supports_multimodal=supports_multimodal(model),
            supports_multimodal_raw_input_only=raw_input_only,
            requires_raw_input_tokens=requires_raw_input_tokens(model),
            supports_multimodal_encoder_tp_data=encoder_tp_data,
            supports_pp=supports_pp(model),
            has_inner_state=has_inner_state(model),
            is_attention_free=is_attention_free(model),
            is_hybrid=is_hybrid(model),
            has_noops=has_noops(model),
            supports_mamba_prefix_caching=supports_mamba_prefix_caching(model),
            supports_transcription=can_transcribe,
            supports_transcription_only=(
                can_transcribe and model.supports_transcription_only
            ),
        )
741
742


743
744
745
746
class _BaseRegisteredModel(ABC):
    """Interface shared by the entries stored in the model registry."""

    @abstractmethod
    def inspect_model_cls(self) -> _ModelInfo:
        """Return the `_ModelInfo` of the registered model class."""
        raise NotImplementedError

    @abstractmethod
    def load_model_cls(self) -> type[nn.Module]:
        """Return the registered model class."""
        raise NotImplementedError
751
752


753
754
755
756
757
758
759
@dataclass(frozen=True)
class _RegisteredModel(_BaseRegisteredModel):
    """
    Registry entry for a model class that has already been imported in the
    main process.
    """

    interfaces: _ModelInfo
    model_cls: type[nn.Module]

    @staticmethod
    def from_model_cls(model_cls: type[nn.Module]):
        """Create the entry for `model_cls`, inspecting the class right away."""
        info = _ModelInfo.from_model_cls(model_cls)
        return _RegisteredModel(interfaces=info, model_cls=model_cls)

    def inspect_model_cls(self) -> _ModelInfo:
        """Return the stored model info."""
        return self.interfaces

    def load_model_cls(self) -> type[nn.Module]:
        """Return the stored model class."""
        return self.model_cls


@dataclass(frozen=True)
class _LazyRegisteredModel(_BaseRegisteredModel):
    """
    Registry entry for a model class that has not been imported in the main
    process; the class is referenced by module name and class name only.
    """

    module_name: str
    class_name: str

    @staticmethod
    def _get_cache_dir() -> Path:
        """Return the directory (under `VLLM_CACHE_ROOT`) of the cache files."""
        cache_root = Path(envs.VLLM_CACHE_ROOT)
        return cache_root / "modelinfos"

    def _get_cache_filename(self) -> str:
        """Return the cache file name derived from the module and class name."""
        qualified_name = f"{self.module_name}-{self.class_name}"
        return qualified_name.replace(".", "-") + ".json"

    def _load_modelinfo_from_cache(self, module_hash: str) -> _ModelInfo | None:
        """
        Return the cached model info if it was stored for `module_hash`.

        `None` is returned when the cache file does not exist, when it was
        stored for a different hash, or when reading it or converting it
        into a `_ModelInfo` fails for any reason.
        """
        try:
            try:
                cache_file = self._get_cache_dir() / self._get_cache_filename()
                cached = json.loads(cache_file.read_text(encoding="utf-8"))
            except FileNotFoundError:
                logger.debug(
                    "Cached model info file for class %s.%s not found",
                    self.module_name,
                    self.class_name,
                )
                return None

            if cached["hash"] == module_hash:
                # file not changed, use cached _ModelInfo properties
                return _ModelInfo(**cached["modelinfo"])

            logger.debug(
                "Cached model info file for class %s.%s is stale",
                self.module_name,
                self.class_name,
            )
            return None
        except Exception:
            logger.debug(
                "Cached model info for class %s.%s error. ",
                self.module_name,
                self.class_name,
            )
            return None

    def _save_modelinfo_to_cache(self, mi: _ModelInfo, module_hash: str) -> None:
        """
        Store `mi` together with `module_hash` as a JSON file in the cache.

        Any error is logged and otherwise ignored.
        """
        from vllm.model_executor.model_loader.weight_utils import atomic_writer

        try:
            payload = {"hash": module_hash, "modelinfo": asdict(mi)}
            target_dir = self._get_cache_dir()
            target_dir.mkdir(parents=True, exist_ok=True)
            target_file = target_dir / self._get_cache_filename()
            with atomic_writer(target_file, encoding="utf-8") as f:
                json.dump(payload, f, indent=2)
        except Exception:
            logger.exception("Error saving model info cache.")

    @logtime(logger=logger, msg="Registry inspect model class")
    def inspect_model_cls(self) -> _ModelInfo:
        """
        Return the model info of the referenced model class.

        If a `.py` file named after the last component of `module_name`
        exists next to this file, the hash of its content is used to look up
        a previously cached result. Otherwise, or on a cache miss, the class
        is loaded and inspected in a subprocess; the result is written to
        the cache when a hash is available.
        """
        module_basename = self.module_name.split(".")[-1]
        source_file = Path(__file__).parent / f"{module_basename}.py"
        module_hash = None

        if source_file.exists():
            source_bytes = source_file.read_bytes()
            module_hash = safe_hash(source_bytes, usedforsecurity=False).hexdigest()

            cached_info = self._load_modelinfo_from_cache(module_hash)
            if cached_info is not None:
                logger.debug(
                    "Loaded model info for class %s.%s from cache",
                    self.module_name,
                    self.class_name,
                )
                return cached_info

            logger.debug(
                "Cache model info for class %s.%s miss. Loading model instead.",
                self.module_name,
                self.class_name,
            )

        # Performed in another process to avoid initializing CUDA
        mi = _run_in_subprocess(
            lambda: _ModelInfo.from_model_cls(self.load_model_cls())
        )
        logger.debug(
            "Loaded model info for class %s.%s", self.module_name, self.class_name
        )

        # save cache file
        if module_hash is not None:
            self._save_modelinfo_to_cache(mi, module_hash)

        return mi

    def load_model_cls(self) -> type[nn.Module]:
        """Import `module_name` and return its attribute `class_name`."""
        return getattr(importlib.import_module(self.module_name), self.class_name)


@lru_cache(maxsize=128)
def _try_load_model_cls(
    model_arch: str,
    model: _BaseRegisteredModel,
) -> type[nn.Module] | None:
    """
    Load the model class of a registry entry, or return `None` on failure.

    `model_arch` is first verified against the current platform; that call
    is not guarded, so an error raised by it propagates to the caller. An
    error raised while loading the class is logged instead. Results are
    memoized by `lru_cache` on the `(model_arch, model)` pair.
    """
    from vllm.platforms import current_platform

    current_platform.verify_model_arch(model_arch)

    try:
        loaded_cls = model.load_model_cls()
    except Exception:
        logger.exception("Error in loading model architecture '%s'", model_arch)
        return None
    return loaded_cls
898
899


900
901
902
903
@lru_cache(maxsize=128)
def _try_inspect_model_cls(
    model_arch: str,
    model: _BaseRegisteredModel,
) -> _ModelInfo | None:
    """
    Inspect the model class of a registry entry, or return `None` on failure.

    An error raised during inspection is logged together with `model_arch`.
    Results are memoized by `lru_cache` on the `(model_arch, model)` pair.
    """
    try:
        model_info = model.inspect_model_cls()
    except Exception:
        logger.exception("Error in inspecting model architecture '%s'", model_arch)
        return None
    return model_info
910
911


912
913
914
@dataclass
class _ModelRegistry:
    """
    Maps model architecture names to registered models and resolves the
    architectures of a model config to one of the registered models.
    """

    # Keyed by model_arch
    models: dict[str, _BaseRegisteredModel] = field(default_factory=dict)

    def get_supported_archs(self) -> Set[str]:
        """Return a view of the names of all registered architectures."""
        return self.models.keys()

    def register_model(
        self,
        model_arch: str,
        model_cls: type[nn.Module] | str,
    ) -> None:
        """
        Register an external model to be used in vLLM.

        `model_cls` can be either:

        - A [`torch.nn.Module`][] class directly referencing the model.
        - A string in the format `<module>:<class>` which can be used to
          lazily import the model. This is useful to avoid initializing CUDA
          when importing the model and thus the related error
          `RuntimeError: Cannot re-initialize CUDA in forked subprocess`.

        An existing registration of `model_arch` is overwritten after a
        warning has been logged.

        Raises:
            TypeError: If `model_arch` is not a string, or if `model_cls` is
                neither a string nor a `torch.nn.Module` subclass.
            ValueError: If `model_cls` is a string that is not in the format
                `<module>:<class>`.
        """
        if not isinstance(model_arch, str):
            msg = f"`model_arch` should be a string, not a {type(model_arch)}"
            raise TypeError(msg)

        if model_arch in self.models:
            logger.warning(
                "Model architecture %s is already registered, and will be "
                "overwritten by the new model class %s.",
                model_arch,
                model_cls,
            )

        model: _BaseRegisteredModel
        if isinstance(model_cls, str):
            split_str = model_cls.split(":")
            if len(split_str) != 2:
                msg = "Expected a string in the format `<module>:<class>`"
                raise ValueError(msg)

            model = _LazyRegisteredModel(*split_str)
        elif isinstance(model_cls, type) and issubclass(model_cls, nn.Module):
            model = _RegisteredModel.from_model_cls(model_cls)
        else:
            # Report the type of the offending argument, i.e. `model_cls`
            msg = (
                "`model_cls` should be a string or PyTorch model class, "
                f"not a {type(model_cls)}"
            )
            raise TypeError(msg)

        self.models[model_arch] = model

    def _raise_for_unsupported(self, architectures: list[str]):
        """
        Raise the `ValueError` that explains why `architectures` could not
        be resolved. This method never returns normally.

        The message distinguishes registered architectures (which must have
        failed to be inspected or loaded), architectures listed in
        `_PREVIOUSLY_SUPPORTED_MODELS`, architectures listed in
        `_OOT_SUPPORTED_MODELS`, and unknown architectures.
        """
        all_supported_archs = self.get_supported_archs()

        if any(arch in all_supported_archs for arch in architectures):
            raise ValueError(
                f"Model architectures {architectures} failed "
                "to be inspected. Please check the logs for more details."
            )

        for arch in architectures:
            if arch in _PREVIOUSLY_SUPPORTED_MODELS:
                previous_version = _PREVIOUSLY_SUPPORTED_MODELS[arch]

                raise ValueError(
                    f"Model architecture {arch} was supported in vLLM until "
                    f"v{previous_version}, and is not supported anymore. "
                    "Please use an older version of vLLM if you want to "
                    "use this model architecture."
                )

            if arch in _OOT_SUPPORTED_MODELS:
                plugin_url = _OOT_SUPPORTED_MODELS[arch]

                raise ValueError(
                    f"Model architecture {arch} is not supported in-tree anymore. "
                    f"Please install the plugin at {plugin_url} if you want to "
                    "use this model architecture."
                )

        raise ValueError(
            f"Model architectures {architectures} are not supported for now. "
            f"Supported architectures: {all_supported_archs}"
        )

    def _try_load_model_cls(self, model_arch: str) -> type[nn.Module] | None:
        """Load the class registered for `model_arch`; `None` if unavailable."""
        if model_arch not in self.models:
            return None

        return _try_load_model_cls(model_arch, self.models[model_arch])

    def _try_inspect_model_cls(self, model_arch: str) -> _ModelInfo | None:
        """Inspect the class registered for `model_arch`; `None` if unavailable."""
        if model_arch not in self.models:
            return None

        return _try_inspect_model_cls(model_arch, self.models[model_arch])

    def _try_resolve_transformers(
        self,
        architecture: str,
        model_config: ModelConfig,
    ) -> str | None:
        """
        Resolve `architecture` to the name of a Transformers backend class.

        An architecture that is itself listed in
        `_TRANSFORMERS_BACKEND_MODELS` is returned unchanged. Otherwise the
        model class is looked up in the `transformers` package and, failing
        that, through the `AutoModel*` entries of the `auto_map` in the HF
        config; if it is found and backend compatible, the result of
        `model_config._get_transformers_backend_cls()` is returned.

        Returns:
            The resolved name, or `None` if the model class cannot be found
            or is not backend compatible and `model_config.model_impl` is
            not `"transformers"`.

        Raises:
            ValueError: In the same failure cases when
                `model_config.model_impl` is `"transformers"`.
        """
        if architecture in _TRANSFORMERS_BACKEND_MODELS:
            return architecture

        auto_map: dict[str, str] = (
            getattr(model_config.hf_config, "auto_map", None) or dict()
        )

        # Make sure that config class is always initialized before model class,
        # otherwise the model class won't be able to access the config class,
        # the expected auto_map should have correct order like:
        # "auto_map": {
        #     "AutoConfig": "<your-repo-name>--<config-name>",
        #     "AutoModel": "<your-repo-name>--<config-name>",
        #     "AutoModelFor<Task>": "<your-repo-name>--<config-name>",
        # },
        for prefix in ("AutoConfig", "AutoModel"):
            for name, module in auto_map.items():
                if name.startswith(prefix):
                    try_get_class_from_dynamic_module(
                        module,
                        model_config.model,
                        revision=model_config.revision,
                        trust_remote_code=model_config.trust_remote_code,
                        warn_on_fail=False,
                    )

        model_module = getattr(transformers, architecture, None)

        if model_module is None:
            for name, module in auto_map.items():
                if name.startswith("AutoModel"):
                    model_module = try_get_class_from_dynamic_module(
                        module,
                        model_config.model,
                        revision=model_config.revision,
                        trust_remote_code=model_config.trust_remote_code,
                        warn_on_fail=True,
                    )
                    if model_module is not None:
                        break
            else:
                # Reached when the loop finished without `break`, i.e. no
                # `AutoModel*` entry yielded a class
                if model_config.model_impl != "transformers":
                    return None

                raise ValueError(
                    f"Cannot find model module. {architecture!r} is not a "
                    "registered model in the Transformers library (only "
                    "relevant if the model is meant to be in Transformers) "
                    "and 'AutoModel' is not present in the model config's "
                    "'auto_map' (relevant if the model is custom)."
                )

        if not model_module.is_backend_compatible():
            if model_config.model_impl != "transformers":
                return None

            raise ValueError(
                f"The Transformers implementation of {architecture!r} "
                "is not compatible with vLLM."
            )

        return model_config._get_transformers_backend_cls()

    def _normalize_arch(
        self,
        architecture: str,
        model_config: ModelConfig,
    ) -> str:
        """
        Map `architecture` to the name of a registered base architecture.

        A registered architecture is returned unchanged. Otherwise, if the
        architecture matches one of the architecture defaults, the matched
        suffix is replaced by each known suffix in turn and the first
        resulting name that is registered is returned. If nothing matches,
        `architecture` is returned unchanged.
        """
        if architecture in self.models:
            return architecture

        # This may be called in order to resolve runner_type and convert_type
        # in the first place, in which case we consider the default match
        match = try_match_architecture_defaults(
            architecture,
            runner_type=getattr(model_config, "runner_type", None),
            convert_type=getattr(model_config, "convert_type", None),
        )
        if match:
            suffix, _ = match

            # Get the name of the base model to convert
            for repl_suffix, _ in iter_architecture_defaults():
                base_arch = architecture.replace(suffix, repl_suffix)
                if base_arch in self.models:
                    return base_arch

        return architecture

    def inspect_model_cls(
        self,
        architectures: str | list[str],
        model_config: ModelConfig,
    ) -> tuple[_ModelInfo, str]:
        """
        Resolve `architectures` and return the model info of the first one
        that can be inspected, together with the resolved architecture name.

        The candidates are tried in this order: the implementation required
        by `model_config.model_impl` (`"transformers"` or `"terratorch"`),
        the Transformers fallback when nothing is registered and no
        conversion is requested, the registered (normalized) architectures,
        and finally the Transformers fallback again.

        Raises:
            ValueError: If `architectures` is empty or none of the
                candidates can be inspected.
        """
        if isinstance(architectures, str):
            architectures = [architectures]
        if not architectures:
            raise ValueError("No model architectures are specified")

        # Require transformers impl
        if model_config.model_impl == "transformers":
            arch = self._try_resolve_transformers(architectures[0], model_config)
            if arch is not None:
                model_info = self._try_inspect_model_cls(arch)
                if model_info is not None:
                    return (model_info, arch)
        elif model_config.model_impl == "terratorch":
            arch = "Terratorch"
            model_info = self._try_inspect_model_cls(arch)
            # Never hand out `None` as model info; fall through instead so
            # that a failure ends in the error raised at the bottom, like
            # in `resolve_model_cls`
            if model_info is not None:
                return (model_info, arch)

        # Fallback to transformers impl (after resolving convert_type)
        if (
            all(arch not in self.models for arch in architectures)
            and model_config.model_impl == "auto"
            and getattr(model_config, "convert_type", "none") == "none"
        ):
            arch = self._try_resolve_transformers(architectures[0], model_config)
            if arch is not None:
                model_info = self._try_inspect_model_cls(arch)
                if model_info is not None:
                    return (model_info, arch)

        for arch in architectures:
            normalized_arch = self._normalize_arch(arch, model_config)
            model_info = self._try_inspect_model_cls(normalized_arch)
            if model_info is not None:
                # The name as given is returned, not the normalized one
                return (model_info, arch)

        # Fallback to transformers impl (before resolving runner_type)
        if (
            all(arch not in self.models for arch in architectures)
            and model_config.model_impl == "auto"
        ):
            arch = self._try_resolve_transformers(architectures[0], model_config)
            if arch is not None:
                model_info = self._try_inspect_model_cls(arch)
                if model_info is not None:
                    return (model_info, arch)

        return self._raise_for_unsupported(architectures)

    def resolve_model_cls(
        self,
        architectures: str | list[str],
        model_config: ModelConfig,
    ) -> tuple[type[nn.Module], str]:
        """
        Resolve `architectures` and return the first model class that can
        be loaded, together with the resolved architecture name.

        The candidates are tried in the same order as in
        `inspect_model_cls`.

        Raises:
            ValueError: If `architectures` is empty or none of the
                candidates can be loaded.
        """
        if isinstance(architectures, str):
            architectures = [architectures]
        if not architectures:
            raise ValueError("No model architectures are specified")

        # Require transformers impl
        if model_config.model_impl == "transformers":
            arch = self._try_resolve_transformers(architectures[0], model_config)
            if arch is not None:
                model_cls = self._try_load_model_cls(arch)
                if model_cls is not None:
                    return (model_cls, arch)
        elif model_config.model_impl == "terratorch":
            arch = "Terratorch"
            model_cls = self._try_load_model_cls(arch)
            if model_cls is not None:
                return (model_cls, arch)

        # Fallback to transformers impl (after resolving convert_type)
        if (
            all(arch not in self.models for arch in architectures)
            and model_config.model_impl == "auto"
            and getattr(model_config, "convert_type", "none") == "none"
        ):
            arch = self._try_resolve_transformers(architectures[0], model_config)
            if arch is not None:
                model_cls = self._try_load_model_cls(arch)
                if model_cls is not None:
                    return (model_cls, arch)

        for arch in architectures:
            normalized_arch = self._normalize_arch(arch, model_config)
            model_cls = self._try_load_model_cls(normalized_arch)
            if model_cls is not None:
                # The name as given is returned, not the normalized one
                return (model_cls, arch)

        # Fallback to transformers impl (before resolving runner_type)
        if (
            all(arch not in self.models for arch in architectures)
            and model_config.model_impl == "auto"
        ):
            arch = self._try_resolve_transformers(architectures[0], model_config)
            if arch is not None:
                model_cls = self._try_load_model_cls(arch)
                if model_cls is not None:
                    return (model_cls, arch)

        return self._raise_for_unsupported(architectures)

    def is_text_generation_model(
        self,
        architectures: str | list[str],
        model_config: ModelConfig,
    ) -> bool:
        """Return `is_text_generation_model` of the resolved model info."""
        model_info, _ = self.inspect_model_cls(architectures, model_config)
        return model_info.is_text_generation_model

    def is_pooling_model(
        self,
        architectures: str | list[str],
        model_config: ModelConfig,
    ) -> bool:
        """Return `is_pooling_model` of the resolved model info."""
        model_info, _ = self.inspect_model_cls(architectures, model_config)
        return model_info.is_pooling_model

    def is_multimodal_model(
        self,
        architectures: str | list[str],
        model_config: ModelConfig,
    ) -> bool:
        """Return `supports_multimodal` of the resolved model info."""
        model_info, _ = self.inspect_model_cls(architectures, model_config)
        return model_info.supports_multimodal

    def is_multimodal_raw_input_only_model(
        self,
        architectures: str | list[str],
        model_config: ModelConfig,
    ) -> bool:
        """Return `supports_multimodal_raw_input_only` of the resolved model info."""
        model_info, _ = self.inspect_model_cls(architectures, model_config)
        return model_info.supports_multimodal_raw_input_only

    def is_pp_supported_model(
        self,
        architectures: str | list[str],
        model_config: ModelConfig,
    ) -> bool:
        """Return `supports_pp` of the resolved model info."""
        model_info, _ = self.inspect_model_cls(architectures, model_config)
        return model_info.supports_pp

    def model_has_inner_state(
        self,
        architectures: str | list[str],
        model_config: ModelConfig,
    ) -> bool:
        """Return `has_inner_state` of the resolved model info."""
        model_info, _ = self.inspect_model_cls(architectures, model_config)
        return model_info.has_inner_state

    def is_attention_free_model(
        self,
        architectures: str | list[str],
        model_config: ModelConfig,
    ) -> bool:
        """Return `is_attention_free` of the resolved model info."""
        model_info, _ = self.inspect_model_cls(architectures, model_config)
        return model_info.is_attention_free

    def is_hybrid_model(
        self,
        architectures: str | list[str],
        model_config: ModelConfig,
    ) -> bool:
        """Return `is_hybrid` of the resolved model info."""
        model_info, _ = self.inspect_model_cls(architectures, model_config)
        return model_info.is_hybrid

    def is_noops_model(
        self,
        architectures: str | list[str],
        model_config: ModelConfig,
    ) -> bool:
        """Return `has_noops` of the resolved model info."""
        model_info, _ = self.inspect_model_cls(architectures, model_config)
        return model_info.has_noops

    def is_transcription_model(
        self,
        architectures: str | list[str],
        model_config: ModelConfig,
    ) -> bool:
        """Return `supports_transcription` of the resolved model info."""
        model_info, _ = self.inspect_model_cls(architectures, model_config)
        return model_info.supports_transcription

    def is_transcription_only_model(
        self,
        architectures: str | list[str],
        model_config: ModelConfig,
    ) -> bool:
        """Return `supports_transcription_only` of the resolved model info."""
        model_info, _ = self.inspect_model_cls(architectures, model_config)
        return model_info.supports_transcription_only

1299

1300
1301
1302
1303
1304
1305
1306
1307
1308
# Registry instance of this module, pre-populated with one lazily imported
# entry for every architecture listed in `_VLLM_MODELS`.
ModelRegistry = _ModelRegistry(
    models={
        arch: _LazyRegisteredModel(
            f"vllm.model_executor.models.{module_relname}",
            class_name,
        )
        for arch, (module_relname, class_name) in _VLLM_MODELS.items()
    }
)
1309
1310
1311
1312
1313

_T = TypeVar("_T")


def _run_in_subprocess(fn: Callable[[], _T]) -> _T:
    """
    Call `fn` in a subprocess and return its result.

    `fn` and the path of an output file are serialized with `cloudpickle`
    and sent to the stdin of the process started with `_SUBPROCESS_COMMAND`.
    After the process has finished, the result is unpickled from the output
    file.

    Raises:
        RuntimeError: If the subprocess exits with a non-zero return code.
            The message contains the captured stderr of the subprocess.
    """
    # NOTE: We use a temporary directory instead of a temporary file to avoid
    # issues like https://stackoverflow.com/questions/23212435/permission-denied-to-write-to-my-temporary-file
    with tempfile.TemporaryDirectory() as tempdir:
        output_filepath = os.path.join(tempdir, "registry_output.tmp")

        # `cloudpickle` allows pickling lambda functions directly
        import cloudpickle

        input_bytes = cloudpickle.dumps((fn, output_filepath))

        # cannot use `sys.executable __file__` here because the script
        # contains relative imports
        returned = subprocess.run(
            _SUBPROCESS_COMMAND, input=input_bytes, capture_output=True
        )

        # check if the subprocess is successful
        try:
            returned.check_returncode()
        except subprocess.CalledProcessError as e:
            # wrap raised exception to provide more information; bytes that
            # cannot be decoded are replaced so that a decoding error can
            # never hide the actual failure of the subprocess
            stderr_text = returned.stderr.decode(errors="replace")
            raise RuntimeError(
                f"Error raised in subprocess:\n{stderr_text}"
            ) from e

        # The file was written by the subprocess started above
        with open(output_filepath, "rb") as f:
            return pickle.load(f)


def _run() -> None:
    """
    Subprocess counterpart of `_run_in_subprocess`.

    Reads a pickled `(fn, output_file)` pair from stdin, calls `fn` and
    writes the pickled result to `output_file`.
    """
    # Setup plugins
    from vllm.plugins import load_general_plugins

    load_general_plugins()

    # The payload is the one that `_run_in_subprocess` passes as input
    payload = sys.stdin.buffer.read()
    fn, output_file = pickle.loads(payload)

    result = fn()

    with open(output_file, "wb") as f:
        pickle.dump(result, f)
1355
1356
1357


if __name__ == "__main__":
    # Reached when this module is executed as a script, which is what the
    # `-m` invocation in `_SUBPROCESS_COMMAND` does
    _run()