registry.py 51.1 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
4
5
6
"""
Whenever you add an architecture to this file, please also update
`tests/models/registry.py` with example HuggingFace models for it.
"""
7

8
import importlib
9
import json
10
import os
11
import pickle
12
13
import subprocess
import sys
14
import tempfile
15
from abc import ABC, abstractmethod
16
from collections.abc import Callable, Set
17
from dataclasses import asdict, dataclass, field
18
from functools import lru_cache
19
from pathlib import Path
20
from typing import TYPE_CHECKING, Any, TypeVar
21
22

import torch.nn as nn
23
import transformers
24

25
from vllm import envs
26
27
28
29
30
from vllm.config import (
    ModelConfig,
    iter_architecture_defaults,
    try_match_architecture_defaults,
)
31
from vllm.logger import init_logger
32
from vllm.logging_utils import logtime
33
from vllm.transformers_utils.dynamic_module import try_get_class_from_dynamic_module
34
from vllm.utils.hashing import safe_hash
35

36
37
# The precise alias types are imported for static type checkers only; at
# runtime the names fall back to `Any` so they remain usable in annotations
# without importing the config submodules here.
if TYPE_CHECKING:
    from vllm.config.model import AttnTypeStr
    from vllm.config.pooler import SequencePoolingType, TokenPoolingType
else:
    AttnTypeStr = Any
    SequencePoolingType = Any
    TokenPoolingType = Any
43
44


45
46
47
48
49
from .interfaces import (
    has_inner_state,
    has_noops,
    is_attention_free,
    is_hybrid,
Patrick von Platen's avatar
Patrick von Platen committed
50
    requires_raw_input_tokens,
51
    supports_cross_encoding,
52
    supports_late_interaction,
53
    supports_mamba_prefix_caching,
54
55
56
57
58
59
60
    supports_multimodal,
    supports_multimodal_encoder_tp_data,
    supports_multimodal_raw_input_only,
    supports_pp,
    supports_transcription,
)
from .interfaces_base import (
61
    get_attn_type,
62
63
    get_default_seq_pooling_type,
    get_default_tok_pooling_type,
64
65
66
    is_pooling_model,
    is_text_generation_model,
)
67
68
69

# Module-level logger, named after this module.
logger = init_logger(__name__)

70
71
# Text-generation architectures.
#
# Each key is an architecture name and each value is a
# `(module name, class name)` pair of strings. Several architecture names may
# map to the same pair when they share an implementation (e.g. the many
# aliases of `("llama", "LlamaForCausalLM")`).
# NOTE(review): the module name is presumably resolved relative to this
# package by the loader, which is not visible in this part of the file.
_TEXT_GENERATION_MODELS = {
    # [Decoder-only]
    "AfmoeForCausalLM": ("afmoe", "AfmoeForCausalLM"),
    "ApertusForCausalLM": ("apertus", "ApertusForCausalLM"),
    "AquilaModel": ("llama", "LlamaForCausalLM"),
    "AquilaForCausalLM": ("llama", "LlamaForCausalLM"),  # AquilaChat2
    "ArceeForCausalLM": ("arcee", "ArceeForCausalLM"),
    "ArcticForCausalLM": ("arctic", "ArcticForCausalLM"),
    "AXK1ForCausalLM": ("AXK1", "AXK1ForCausalLM"),
    # baichuan-7b, upper case 'C' in the class name
    "BaiChuanForCausalLM": ("baichuan", "BaiChuanForCausalLM"),
    # baichuan-13b, lower case 'c' in the class name
    "BaichuanForCausalLM": ("baichuan", "BaichuanForCausalLM"),
    "BailingMoeForCausalLM": ("bailing_moe", "BailingMoeForCausalLM"),
    "BailingMoeV2ForCausalLM": ("bailing_moe", "BailingMoeV2ForCausalLM"),
    "BailingMoeV2_5ForCausalLM": ("bailing_moe_linear", "BailingMoeV25ForCausalLM"),
    "BambaForCausalLM": ("bamba", "BambaForCausalLM"),
    "BloomForCausalLM": ("bloom", "BloomForCausalLM"),
    "ChatGLMModel": ("chatglm", "ChatGLMForCausalLM"),
    "ChatGLMForConditionalGeneration": ("chatglm", "ChatGLMForCausalLM"),
    "CohereForCausalLM": ("commandr", "CohereForCausalLM"),
    "Cohere2ForCausalLM": ("commandr", "CohereForCausalLM"),
    "CwmForCausalLM": ("llama", "LlamaForCausalLM"),
    "DbrxForCausalLM": ("dbrx", "DbrxForCausalLM"),
    "DeciLMForCausalLM": ("nemotron_nas", "DeciLMForCausalLM"),
    "DeepseekForCausalLM": ("deepseek_v2", "DeepseekForCausalLM"),
    "DeepseekV2ForCausalLM": ("deepseek_v2", "DeepseekV2ForCausalLM"),
    "DeepseekV3ForCausalLM": ("deepseek_v2", "DeepseekV3ForCausalLM"),
    "DeepseekV32ForCausalLM": ("deepseek_v2", "DeepseekV3ForCausalLM"),
    "Dots1ForCausalLM": ("dots1", "Dots1ForCausalLM"),
    "Ernie4_5ForCausalLM": ("ernie45", "Ernie4_5ForCausalLM"),
    "Ernie4_5_MoeForCausalLM": ("ernie45_moe", "Ernie4_5_MoeForCausalLM"),
    "ExaoneForCausalLM": ("exaone", "ExaoneForCausalLM"),
    "Exaone4ForCausalLM": ("exaone4", "Exaone4ForCausalLM"),
    "ExaoneMoEForCausalLM": ("exaone_moe", "ExaoneMoeForCausalLM"),
    "Fairseq2LlamaForCausalLM": ("fairseq2_llama", "Fairseq2LlamaForCausalLM"),
    "FalconForCausalLM": ("falcon", "FalconForCausalLM"),
    "FalconMambaForCausalLM": ("mamba", "MambaForCausalLM"),
    "FalconH1ForCausalLM": ("falcon_h1", "FalconH1ForCausalLM"),
    "FlexOlmoForCausalLM": ("flex_olmo", "FlexOlmoForCausalLM"),
    "GemmaForCausalLM": ("gemma", "GemmaForCausalLM"),
    "Gemma2ForCausalLM": ("gemma2", "Gemma2ForCausalLM"),
    "Gemma3ForCausalLM": ("gemma3", "Gemma3ForCausalLM"),
    "Gemma3nForCausalLM": ("gemma3n", "Gemma3nForCausalLM"),
    "Qwen3NextForCausalLM": ("qwen3_next", "Qwen3NextForCausalLM"),
    "GlmForCausalLM": ("glm", "GlmForCausalLM"),
    "Glm4ForCausalLM": ("glm4", "Glm4ForCausalLM"),
    "Glm4MoeForCausalLM": ("glm4_moe", "Glm4MoeForCausalLM"),
    "Glm4MoeLiteForCausalLM": ("glm4_moe_lite", "Glm4MoeLiteForCausalLM"),
    "GlmMoeDsaForCausalLM": ("deepseek_v2", "GlmMoeDsaForCausalLM"),
    "GptOssForCausalLM": ("gpt_oss", "GptOssForCausalLM"),
    "GPT2LMHeadModel": ("gpt2", "GPT2LMHeadModel"),
    "GPTBigCodeForCausalLM": ("gpt_bigcode", "GPTBigCodeForCausalLM"),
    "GPTJForCausalLM": ("gpt_j", "GPTJForCausalLM"),
    "GPTNeoXForCausalLM": ("gpt_neox", "GPTNeoXForCausalLM"),
    "GraniteForCausalLM": ("granite", "GraniteForCausalLM"),
    "GraniteMoeForCausalLM": ("granitemoe", "GraniteMoeForCausalLM"),
    "GraniteMoeHybridForCausalLM": ("granitemoehybrid", "GraniteMoeHybridForCausalLM"),  # noqa: E501
    "GraniteMoeSharedForCausalLM": ("granitemoeshared", "GraniteMoeSharedForCausalLM"),  # noqa: E501
    "GritLM": ("gritlm", "GritLM"),
    "Grok1ModelForCausalLM": ("grok1", "GrokForCausalLM"),
    "Grok1ForCausalLM": ("grok1", "GrokForCausalLM"),
    "HunYuanMoEV1ForCausalLM": ("hunyuan_v1", "HunYuanMoEV1ForCausalLM"),
    "HunYuanDenseV1ForCausalLM": ("hunyuan_v1", "HunYuanDenseV1ForCausalLM"),
    "HCXVisionForCausalLM": ("hyperclovax_vision", "HCXVisionForCausalLM"),
    "InternLMForCausalLM": ("llama", "LlamaForCausalLM"),
    "InternLM2ForCausalLM": ("internlm2", "InternLM2ForCausalLM"),
    "InternLM2VEForCausalLM": ("internlm2_ve", "InternLM2VEForCausalLM"),
    "InternLM3ForCausalLM": ("llama", "LlamaForCausalLM"),
    "IQuestCoderForCausalLM": ("llama", "LlamaForCausalLM"),
    "IQuestLoopCoderForCausalLM": ("iquest_loopcoder", "IQuestLoopCoderForCausalLM"),
    "JAISLMHeadModel": ("jais", "JAISLMHeadModel"),
    "Jais2ForCausalLM": ("jais2", "Jais2ForCausalLM"),
    "JambaForCausalLM": ("jamba", "JambaForCausalLM"),
    "KimiLinearForCausalLM": ("kimi_linear", "KimiLinearForCausalLM"),  # noqa: E501
    "Lfm2ForCausalLM": ("lfm2", "Lfm2ForCausalLM"),
    "Lfm2MoeForCausalLM": ("lfm2_moe", "Lfm2MoeForCausalLM"),
    "LlamaForCausalLM": ("llama", "LlamaForCausalLM"),
    "Llama4ForCausalLM": ("llama4", "Llama4ForCausalLM"),
    # For decapoda-research/llama-*
    "LLaMAForCausalLM": ("llama", "LlamaForCausalLM"),
    "LongcatFlashForCausalLM": ("longcat_flash", "LongcatFlashForCausalLM"),
    "MambaForCausalLM": ("mamba", "MambaForCausalLM"),
    "Mamba2ForCausalLM": ("mamba2", "Mamba2ForCausalLM"),
    "MiniCPMForCausalLM": ("minicpm", "MiniCPMForCausalLM"),
    "MiniCPM3ForCausalLM": ("minicpm3", "MiniCPM3ForCausalLM"),
    "MiniMaxForCausalLM": ("minimax_text_01", "MiniMaxText01ForCausalLM"),
    "MiniMaxText01ForCausalLM": ("minimax_text_01", "MiniMaxText01ForCausalLM"),
    "MiniMaxM1ForCausalLM": ("minimax_text_01", "MiniMaxText01ForCausalLM"),
    "MiniMaxM2ForCausalLM": ("minimax_m2", "MiniMaxM2ForCausalLM"),
    "MistralForCausalLM": ("mistral", "MistralForCausalLM"),
    "MistralLarge3ForCausalLM": ("mistral_large_3", "MistralLarge3ForCausalLM"),
    "MixtralForCausalLM": ("mixtral", "MixtralForCausalLM"),
    # transformers's mpt class has lower case
    "MptForCausalLM": ("mpt", "MPTForCausalLM"),
    "MPTForCausalLM": ("mpt", "MPTForCausalLM"),
    "MiMoForCausalLM": ("mimo", "MiMoForCausalLM"),
    "MiMoV2FlashForCausalLM": ("mimo_v2_flash", "MiMoV2FlashForCausalLM"),
    "NemotronForCausalLM": ("nemotron", "NemotronForCausalLM"),
    "NemotronHForCausalLM": ("nemotron_h", "NemotronHForCausalLM"),
    "NemotronHPuzzleForCausalLM": ("nemotron_h", "NemotronHForCausalLM"),
    "OlmoForCausalLM": ("olmo", "OlmoForCausalLM"),
    "Olmo2ForCausalLM": ("olmo2", "Olmo2ForCausalLM"),
    "Olmo3ForCausalLM": ("olmo2", "Olmo2ForCausalLM"),
    "OlmoHybridForCausalLM": ("olmo_hybrid", "OlmoHybridForCausalLM"),
    "OlmoeForCausalLM": ("olmoe", "OlmoeForCausalLM"),
    "OPTForCausalLM": ("opt", "OPTForCausalLM"),
    "OrionForCausalLM": ("orion", "OrionForCausalLM"),
    "OuroForCausalLM": ("ouro", "OuroForCausalLM"),
    "PanguEmbeddedForCausalLM": ("openpangu", "PanguEmbeddedForCausalLM"),
    "PanguProMoEV2ForCausalLM": ("openpangu", "PanguProMoEV2ForCausalLM"),
    "PanguUltraMoEForCausalLM": ("openpangu", "PanguUltraMoEForCausalLM"),
    "PersimmonForCausalLM": ("persimmon", "PersimmonForCausalLM"),
    "PhiForCausalLM": ("phi", "PhiForCausalLM"),
    "Phi3ForCausalLM": ("phi3", "Phi3ForCausalLM"),
    "PhiMoEForCausalLM": ("phimoe", "PhiMoEForCausalLM"),
    "Plamo2ForCausalLM": ("plamo2", "Plamo2ForCausalLM"),
    "Plamo3ForCausalLM": ("plamo3", "Plamo3ForCausalLM"),
    "QWenLMHeadModel": ("qwen", "QWenLMHeadModel"),
    "Qwen2ForCausalLM": ("qwen2", "Qwen2ForCausalLM"),
    "Qwen2MoeForCausalLM": ("qwen2_moe", "Qwen2MoeForCausalLM"),
    "Qwen3ForCausalLM": ("qwen3", "Qwen3ForCausalLM"),
    "Qwen3MoeForCausalLM": ("qwen3_moe", "Qwen3MoeForCausalLM"),
    "RWForCausalLM": ("falcon", "FalconForCausalLM"),
    "SarvamMoEForCausalLM": ("sarvam", "SarvamMoEForCausalLM"),
    "SarvamMLAForCausalLM": ("sarvam", "SarvamMLAForCausalLM"),
    "SeedOssForCausalLM": ("seed_oss", "SeedOssForCausalLM"),
    "Step1ForCausalLM": ("step1", "Step1ForCausalLM"),
    "Step3TextForCausalLM": ("step3_text", "Step3TextForCausalLM"),
    "Step3p5ForCausalLM": ("step3p5", "Step3p5ForCausalLM"),
    "StableLMEpochForCausalLM": ("stablelm", "StablelmForCausalLM"),
    "StableLmForCausalLM": ("stablelm", "StablelmForCausalLM"),
    "Starcoder2ForCausalLM": ("starcoder2", "Starcoder2ForCausalLM"),
    "SolarForCausalLM": ("solar", "SolarForCausalLM"),
    "TeleChatForCausalLM": ("telechat2", "TeleChat2ForCausalLM"),
    "TeleChat2ForCausalLM": ("telechat2", "TeleChat2ForCausalLM"),
    "TeleFLMForCausalLM": ("teleflm", "TeleFLMForCausalLM"),
    "XverseForCausalLM": ("llama", "LlamaForCausalLM"),
    "Zamba2ForCausalLM": ("zamba2", "Zamba2ForCausalLM"),
}

# Embedding architectures: architecture name -> `(module name, class name)`.
#
# Besides the explicit entries, every `_TEXT_GENERATION_MODELS` entry whose
# class name is "LlamaForCausalLM" is copied in via the `**{...}` unpacking
# below. Entries listed after that unpacking win if a key repeats, because
# later keys in a dict display override earlier ones.
_EMBEDDING_MODELS = {
    # [Text-only]
    "BertModel": ("bert", "BertEmbeddingModel"),
    "BertSpladeSparseEmbeddingModel": ("bert", "BertSpladeSparseEmbeddingModel"),
    "HF_ColBERT": ("colbert", "ColBERTModel"),
    "ColBERTModernBertModel": ("colbert", "ColBERTModernBertModel"),
    "ColBERTJinaRobertaModel": ("colbert", "ColBERTJinaRobertaModel"),
    "DeciLMForCausalLM": ("nemotron_nas", "DeciLMForCausalLM"),
    "Gemma2Model": ("gemma2", "Gemma2ForCausalLM"),
    "Gemma3TextModel": ("gemma3", "Gemma3Model"),
    "GlmForCausalLM": ("glm", "GlmForCausalLM"),
    "GPT2ForSequenceClassification": ("gpt2", "GPT2ForSequenceClassification"),
    "GritLM": ("gritlm", "GritLM"),
    "GteModel": ("bert_with_rope", "SnowflakeGteNewModel"),
    "GteNewModel": ("bert_with_rope", "GteNewModel"),
    "InternLM2ForRewardModel": ("internlm2", "InternLM2ForRewardModel"),
    "JambaForSequenceClassification": ("jamba", "JambaForSequenceClassification"),  # noqa: E501
    "LlamaBidirectionalModel": ("llama", "LlamaBidirectionalModel"),
    "LlamaModel": ("llama", "LlamaForCausalLM"),
    **{
        # Multiple models share the same architecture, so we include them all
        k: (mod, arch)
        for k, (mod, arch) in _TEXT_GENERATION_MODELS.items()
        if arch == "LlamaForCausalLM"
    },
    "MistralModel": ("llama", "LlamaForCausalLM"),
    "ModernBertModel": ("modernbert", "ModernBertModel"),
    "NomicBertModel": ("bert_with_rope", "NomicBertModel"),
    "Phi3ForCausalLM": ("phi3", "Phi3ForCausalLM"),
    "Qwen2Model": ("qwen2", "Qwen2ForCausalLM"),
    "Qwen2ForCausalLM": ("qwen2", "Qwen2ForCausalLM"),
    "Qwen2ForRewardModel": ("qwen2_rm", "Qwen2ForRewardModel"),
    "Qwen2ForProcessRewardModel": ("qwen2_rm", "Qwen2ForProcessRewardModel"),
    "RobertaForMaskedLM": ("roberta", "RobertaEmbeddingModel"),
    "RobertaModel": ("roberta", "RobertaEmbeddingModel"),
    "TeleChatForCausalLM": ("telechat2", "TeleChat2ForCausalLM"),
    "TeleChat2ForCausalLM": ("telechat2", "TeleChat2ForCausalLM"),
    "VoyageQwen3BidirectionalEmbedModel": (
        "voyage",
        "VoyageQwen3BidirectionalEmbedModel",
    ),
    "XLMRobertaModel": ("roberta", "RobertaEmbeddingModel"),
    "BgeM3EmbeddingModel": ("roberta", "BgeM3EmbeddingModel"),
    # [Multimodal]
    "CLIPModel": ("clip", "CLIPEmbeddingModel"),
    "ColModernVBertForRetrieval": ("colmodernvbert", "ColModernVBertForRetrieval"),
    "LlavaNextForConditionalGeneration": (
        "llava_next",
        "LlavaNextForConditionalGeneration",
    ),
    "Phi3VForCausalLM": ("phi3v", "Phi3VForCausalLM"),
    "Qwen2VLForConditionalGeneration": ("qwen2_vl", "Qwen2VLForConditionalGeneration"),  # noqa: E501
    "ColQwen3": ("colqwen3", "ColQwen3Model"),
    "OpsColQwen3Model": ("colqwen3", "ColQwen3Model"),
    "Qwen3VLNemotronEmbedModel": ("colqwen3", "ColQwen3Model"),
    "SiglipModel": ("siglip", "SiglipEmbeddingModel"),
    "LlamaNemotronVLModel": (
        "nemotron_vl",
        "LlamaNemotronVLForEmbedding",
    ),
    # Technically Terratorch models work on images, both in
    # input and output. I am adding it here because it piggy-backs on embedding
    # models for the time being.
    "PrithviGeoSpatialMAE": ("terratorch", "Terratorch"),
    "Terratorch": ("terratorch", "Terratorch"),
}

278
279
# Cross-encoder (sequence/token classification and ranking) architectures:
# architecture name -> `(module name, class name)`.
_CROSS_ENCODER_MODELS = {
    "BertForSequenceClassification": ("bert", "BertForSequenceClassification"),
    "BertForTokenClassification": ("bert", "BertForTokenClassification"),
    "GteNewForSequenceClassification": (
        "bert_with_rope",
        "GteNewForSequenceClassification",
    ),
    "JinaVLForRanking": ("jina_vl", "JinaVLForSequenceClassification"),
    "LlamaBidirectionalForSequenceClassification": (
        "llama",
        "LlamaBidirectionalForSequenceClassification",
    ),
    "LlamaNemotronVLForSequenceClassification": (
        "nemotron_vl",
        "LlamaNemotronVLForSequenceClassification",
    ),
    "ModernBertForSequenceClassification": (
        "modernbert",
        "ModernBertForSequenceClassification",
    ),
    "ModernBertForTokenClassification": (
        "modernbert",
        "ModernBertForTokenClassification",
    ),
    "RobertaForSequenceClassification": ("roberta", "RobertaForSequenceClassification"),
    "XLMRobertaForSequenceClassification": (
        "roberta",
        "RobertaForSequenceClassification",
    ),
}

309
# Multimodal architectures: architecture name -> `(module name, class name)`.
# Grouped into decoder-only and encoder-decoder sections; several architecture
# names alias one implementation (e.g. both Qwen2.5-Omni keys).
_MULTIMODAL_MODELS = {
    # [Decoder-only]
    "AriaForConditionalGeneration": ("aria", "AriaForConditionalGeneration"),
    "AudioFlamingo3ForConditionalGeneration": (
        "audioflamingo3",
        "AudioFlamingo3ForConditionalGeneration",
    ),
    "MusicFlamingoForConditionalGeneration": (
        "musicflamingo",
        "MusicFlamingoForConditionalGeneration",
    ),
    "AyaVisionForConditionalGeneration": (
        "aya_vision",
        "AyaVisionForConditionalGeneration",
    ),
    "BagelForConditionalGeneration": ("bagel", "BagelForConditionalGeneration"),
    "BeeForConditionalGeneration": ("bee", "BeeForConditionalGeneration"),
    "Blip2ForConditionalGeneration": ("blip2", "Blip2ForConditionalGeneration"),
    "ChameleonForConditionalGeneration": (
        "chameleon",
        "ChameleonForConditionalGeneration",
    ),
    "Cohere2VisionForConditionalGeneration": (
        "cohere2_vision",
        "Cohere2VisionForConditionalGeneration",
    ),
    "DeepseekVLV2ForCausalLM": ("deepseek_vl2", "DeepseekVLV2ForCausalLM"),
    "DeepseekOCRForCausalLM": ("deepseek_ocr", "DeepseekOCRForCausalLM"),
    "DeepseekOCR2ForCausalLM": ("deepseek_ocr2", "DeepseekOCR2ForCausalLM"),
    "DotsOCRForCausalLM": ("dots_ocr", "DotsOCRForCausalLM"),
    "Eagle2_5_VLForConditionalGeneration": (
        "eagle2_5_vl",
        "Eagle2_5_VLForConditionalGeneration",
    ),
    "Ernie4_5_VLMoeForConditionalGeneration": (
        "ernie45_vl",
        "Ernie4_5_VLMoeForConditionalGeneration",
    ),
    "FireRedASR2ForConditionalGeneration": (
        "fireredasr2",
        "FireRedASR2ForConditionalGeneration",
    ),
    "FunASRForConditionalGeneration": ("funasr", "FunASRForConditionalGeneration"),  # noqa: E501
    "FunAudioChatForConditionalGeneration": (
        "funaudiochat",
        "FunAudioChatForConditionalGeneration",
    ),
    "FuyuForCausalLM": ("fuyu", "FuyuForCausalLM"),
    "Gemma3ForConditionalGeneration": ("gemma3_mm", "Gemma3ForConditionalGeneration"),  # noqa: E501
    "Gemma3nForConditionalGeneration": (
        "gemma3n_mm",
        "Gemma3nForConditionalGeneration",
    ),
    "GlmAsrForConditionalGeneration": ("glmasr", "GlmAsrForConditionalGeneration"),
    "GLM4VForCausalLM": ("glm4v", "GLM4VForCausalLM"),
    "Glm4vForConditionalGeneration": ("glm4_1v", "Glm4vForConditionalGeneration"),
    "Glm4vMoeForConditionalGeneration": ("glm4_1v", "Glm4vMoeForConditionalGeneration"),
    "GlmOcrForConditionalGeneration": ("glm_ocr", "GlmOcrForConditionalGeneration"),  # noqa: E501
    "GraniteSpeechForConditionalGeneration": (
        "granite_speech",
        "GraniteSpeechForConditionalGeneration",
    ),
    "H2OVLChatModel": ("h2ovl", "H2OVLChatModel"),
    "HunYuanVLForConditionalGeneration": (
        "hunyuan_vision",
        "HunYuanVLForConditionalGeneration",
    ),
    "StepVLForConditionalGeneration": ("step_vl", "StepVLForConditionalGeneration"),
    "InternVLChatModel": ("internvl", "InternVLChatModel"),
    "NemotronH_Nano_VL_V2": ("nano_nemotron_vl", "NemotronH_Nano_VL_V2"),
    "OpenCUAForConditionalGeneration": (
        "opencua",
        "OpenCUAForConditionalGeneration",
    ),
    "InternS1ForConditionalGeneration": (
        "interns1",
        "InternS1ForConditionalGeneration",
    ),
    "InternVLForConditionalGeneration": (
        "interns1",
        "InternS1ForConditionalGeneration",
    ),
    "InternS1ProForConditionalGeneration": (
        "interns1_pro",
        "InternS1ProForConditionalGeneration",
    ),
    "Idefics3ForConditionalGeneration": (
        "idefics3",
        "Idefics3ForConditionalGeneration",
    ),
    "IsaacForConditionalGeneration": ("isaac", "IsaacForConditionalGeneration"),
    "SmolVLMForConditionalGeneration": ("smolvlm", "SmolVLMForConditionalGeneration"),  # noqa: E501
    "KananaVForConditionalGeneration": ("kanana_v", "KananaVForConditionalGeneration"),
    "KeyeForConditionalGeneration": ("keye", "KeyeForConditionalGeneration"),
    "KeyeVL1_5ForConditionalGeneration": (
        "keye_vl1_5",
        "KeyeVL1_5ForConditionalGeneration",
    ),
    "RForConditionalGeneration": ("rvl", "RForConditionalGeneration"),
    "KimiVLForConditionalGeneration": ("kimi_vl", "KimiVLForConditionalGeneration"),  # noqa: E501
    "KimiK25ForConditionalGeneration": ("kimi_k25", "KimiK25ForConditionalGeneration"),  # noqa: E501
    "LightOnOCRForConditionalGeneration": (
        "lightonocr",
        "LightOnOCRForConditionalGeneration",
    ),
    "Lfm2VlForConditionalGeneration": ("lfm2_vl", "Lfm2VLForConditionalGeneration"),
    "Llama_Nemotron_Nano_VL": ("nemotron_vl", "LlamaNemotronVLChatModel"),
    "Llama4ForConditionalGeneration": ("mllama4", "Llama4ForConditionalGeneration"),  # noqa: E501
    "LlavaForConditionalGeneration": ("llava", "LlavaForConditionalGeneration"),
    "LlavaNextForConditionalGeneration": (
        "llava_next",
        "LlavaNextForConditionalGeneration",
    ),
    "LlavaNextVideoForConditionalGeneration": (
        "llava_next_video",
        "LlavaNextVideoForConditionalGeneration",
    ),
    "LlavaOnevisionForConditionalGeneration": (
        "llava_onevision",
        "LlavaOnevisionForConditionalGeneration",
    ),
    "MantisForConditionalGeneration": ("llava", "MantisForConditionalGeneration"),  # noqa: E501
    "MiDashengLMModel": ("midashenglm", "MiDashengLMModel"),
    "MiniMaxVL01ForConditionalGeneration": (
        "minimax_vl_01",
        "MiniMaxVL01ForConditionalGeneration",
    ),
    "MiniCPMO": ("minicpmo", "MiniCPMO"),
    "MiniCPMV": ("minicpmv", "MiniCPMV"),
    "Mistral3ForConditionalGeneration": (
        "mistral3",
        "Mistral3ForConditionalGeneration",
    ),
    "MolmoForCausalLM": ("molmo", "MolmoForCausalLM"),
    "Molmo2ForConditionalGeneration": ("molmo2", "Molmo2ForConditionalGeneration"),
    "NVLM_D": ("nvlm_d", "NVLM_D_Model"),
    "OpenPanguVLForConditionalGeneration": (
        "openpangu_vl",
        "OpenPanguVLForConditionalGeneration",
    ),
    "Ovis": ("ovis", "Ovis"),
    "Ovis2_5": ("ovis2_5", "Ovis2_5"),
    "Ovis2_6ForCausalLM": ("ovis2_5", "Ovis2_5"),
    "Ovis2_6_MoeForCausalLM": ("ovis2_5", "Ovis2_5"),
    "PaddleOCRVLForConditionalGeneration": (
        "paddleocr_vl",
        "PaddleOCRVLForConditionalGeneration",
    ),
    "PaliGemmaForConditionalGeneration": (
        "paligemma",
        "PaliGemmaForConditionalGeneration",
    ),
    "Phi3VForCausalLM": ("phi3v", "Phi3VForCausalLM"),
    "Phi4MMForCausalLM": ("phi4mm", "Phi4MMForCausalLM"),
    "PixtralForConditionalGeneration": ("pixtral", "PixtralForConditionalGeneration"),  # noqa: E501
    "QwenVLForConditionalGeneration": ("qwen_vl", "QwenVLForConditionalGeneration"),  # noqa: E501
    "Qwen2VLForConditionalGeneration": ("qwen2_vl", "Qwen2VLForConditionalGeneration"),  # noqa: E501
    "Qwen2_5_VLForConditionalGeneration": (
        "qwen2_5_vl",
        "Qwen2_5_VLForConditionalGeneration",
    ),
    "Qwen2AudioForConditionalGeneration": (
        "qwen2_audio",
        "Qwen2AudioForConditionalGeneration",
    ),
    "Qwen2_5OmniModel": (
        "qwen2_5_omni_thinker",
        "Qwen2_5OmniThinkerForConditionalGeneration",
    ),
    "Qwen2_5OmniForConditionalGeneration": (
        "qwen2_5_omni_thinker",
        "Qwen2_5OmniThinkerForConditionalGeneration",
    ),
    "Qwen3OmniMoeForConditionalGeneration": (
        "qwen3_omni_moe_thinker",
        "Qwen3OmniMoeThinkerForConditionalGeneration",
    ),
    "Qwen3ASRForConditionalGeneration": (
        "qwen3_asr",
        "Qwen3ASRForConditionalGeneration",
    ),
    "Qwen3ASRRealtimeGeneration": (
        "qwen3_asr_realtime",
        "Qwen3ASRRealtimeGeneration",
    ),
    "Qwen3VLForConditionalGeneration": ("qwen3_vl", "Qwen3VLForConditionalGeneration"),  # noqa: E501
    "Qwen3VLMoeForConditionalGeneration": (
        "qwen3_vl_moe",
        "Qwen3VLMoeForConditionalGeneration",
    ),
    "Qwen3_5ForConditionalGeneration": (
        "qwen3_5",
        "Qwen3_5ForConditionalGeneration",
    ),
    "Qwen3_5MoeForConditionalGeneration": (
        "qwen3_5",
        "Qwen3_5MoeForConditionalGeneration",
    ),
    "SkyworkR1VChatModel": ("skyworkr1v", "SkyworkR1VChatModel"),
    "Step3VLForConditionalGeneration": ("step3_vl", "Step3VLForConditionalGeneration"),  # noqa: E501
    "TarsierForConditionalGeneration": ("tarsier", "TarsierForConditionalGeneration"),  # noqa: E501
    "Tarsier2ForConditionalGeneration": (
        "qwen2_vl",
        "Tarsier2ForConditionalGeneration",
    ),
    "UltravoxModel": ("ultravox", "UltravoxModel"),
    "VoxtralForConditionalGeneration": ("voxtral", "VoxtralForConditionalGeneration"),  # noqa: E501
    "VoxtralRealtimeGeneration": ("voxtral_realtime", "VoxtralRealtimeGeneration"),  # noqa: E501
    # [Encoder-decoder]
    "NemotronParseForConditionalGeneration": (
        "nemotron_parse",
        "NemotronParseForConditionalGeneration",
    ),
    "WhisperForConditionalGeneration": ("whisper", "WhisperForConditionalGeneration"),  # noqa: E501
}
524
525

# Speculative-decoding (draft / MTP / Eagle / Medusa) architectures:
# architecture name -> `(module name, class name)`.
_SPECULATIVE_DECODING_MODELS = {
    "ExtractHiddenStatesModel": ("extract_hidden_states", "ExtractHiddenStatesModel"),
    "MiMoMTPModel": ("mimo_mtp", "MiMoMTP"),
    "EagleLlamaForCausalLM": ("llama_eagle", "EagleLlamaForCausalLM"),
    "EagleLlama4ForCausalLM": ("llama4_eagle", "EagleLlama4ForCausalLM"),
    "EagleMiniCPMForCausalLM": ("minicpm_eagle", "EagleMiniCPMForCausalLM"),
    "Eagle3LlamaForCausalLM": ("llama_eagle3", "Eagle3LlamaForCausalLM"),
    "LlamaForCausalLMEagle3": ("llama_eagle3", "Eagle3LlamaForCausalLM"),
    "Eagle3Qwen2_5vlForCausalLM": ("llama_eagle3", "Eagle3LlamaForCausalLM"),
    "Eagle3Qwen3vlForCausalLM": ("llama_eagle3", "Eagle3LlamaForCausalLM"),
    "EagleMistralLarge3ForCausalLM": (
        "mistral_large_3_eagle",
        "EagleMistralLarge3ForCausalLM",
    ),
    "EagleDeepSeekMTPModel": ("deepseek_eagle", "EagleDeepseekV3ForCausalLM"),
    "DeepSeekMTPModel": ("deepseek_mtp", "DeepSeekMTP"),
    "ErnieMTPModel": ("ernie_mtp", "ErnieMTP"),
    "ExaoneMoeMTP": ("exaone_moe_mtp", "ExaoneMoeMTP"),
    "NemotronHMTPModel": ("nemotron_h_mtp", "NemotronHMTP"),
    "LongCatFlashMTPModel": ("longcat_flash_mtp", "LongCatFlashMTP"),
    "Glm4MoeMTPModel": ("glm4_moe_mtp", "Glm4MoeMTP"),
    "Glm4MoeLiteMTPModel": ("glm4_moe_lite_mtp", "Glm4MoeLiteMTP"),
    "GlmOcrMTPModel": ("glm_ocr_mtp", "GlmOcrMTP"),
    "MedusaModel": ("medusa", "Medusa"),
    "OpenPanguMTPModel": ("openpangu_mtp", "OpenPanguMTP"),
    "Qwen3NextMTP": ("qwen3_next_mtp", "Qwen3NextMTP"),
    "Step3p5MTP": ("step3p5_mtp", "Step3p5MTP"),
    "Qwen3_5MTP": ("qwen3_5_mtp", "Qwen3_5MTP"),
    "Qwen3_5MoeMTP": ("qwen3_5_mtp", "Qwen3_5MoeMTP"),
    # Temporarily disabled.
    # # TODO(woosuk): Re-enable this once the MLP Speculator is supported in V1.
    # "MLPSpeculatorPreTrainedModel": ("mlp_speculator", "MLPSpeculator"),
}
558

559
# Named architectures that are mapped onto the generic classes in the
# "transformers" module rather than onto a dedicated implementation.
_TRANSFORMERS_SUPPORTED_MODELS = {
    # Text generation models
    "SmolLM3ForCausalLM": ("transformers", "TransformersForCausalLM"),
    # Multimodal models
    "Emu3ForConditionalGeneration": (
        "transformers",
        "TransformersMultiModalForCausalLM",
    ),
}

# The generic classes of the "transformers" module, registered under their own
# class names (every key equals the class name in its value), grouped by task.
_TRANSFORMERS_BACKEND_MODELS = {
    # Text generation models
    "TransformersForCausalLM": ("transformers", "TransformersForCausalLM"),
    "TransformersMoEForCausalLM": ("transformers", "TransformersMoEForCausalLM"),
    # Multimodal models
    "TransformersMultiModalForCausalLM": (
        "transformers",
        "TransformersMultiModalForCausalLM",
    ),
    "TransformersMultiModalMoEForCausalLM": (
        "transformers",
        "TransformersMultiModalMoEForCausalLM",
    ),
    # Embedding models
    "TransformersEmbeddingModel": ("transformers", "TransformersEmbeddingModel"),
    "TransformersMoEEmbeddingModel": ("transformers", "TransformersMoEEmbeddingModel"),
    "TransformersMultiModalEmbeddingModel": (
        "transformers",
        "TransformersMultiModalEmbeddingModel",
    ),
    # Sequence classification models
    "TransformersForSequenceClassification": (
        "transformers",
        "TransformersForSequenceClassification",
    ),
    "TransformersMoEForSequenceClassification": (
        "transformers",
        "TransformersMoEForSequenceClassification",
    ),
    "TransformersMultiModalForSequenceClassification": (
        "transformers",
        "TransformersMultiModalForSequenceClassification",
    ),
}
603

604
# Union of all per-category tables above. When an architecture name appears in
# more than one table, the entry from the table unpacked last wins (later keys
# in a dict display override earlier ones).
_VLLM_MODELS = {
    **_TEXT_GENERATION_MODELS,
    **_EMBEDDING_MODELS,
    **_CROSS_ENCODER_MODELS,
    **_MULTIMODAL_MODELS,
    **_SPECULATIVE_DECODING_MODELS,
    **_TRANSFORMERS_SUPPORTED_MODELS,
    **_TRANSFORMERS_BACKEND_MODELS,
}

614
615
616
617
# This variable is used as the args for subprocess.run(). It can be modified
# to alter the args if needed, e.g. when things are packed together in par
# format, sys.executable might not be the target we want to run.
_SUBPROCESS_COMMAND = [sys.executable, "-m", "vllm.model_executor.models.registry"]
619

620
# Architectures that are no longer registered, mapped to a version string.
# NOTE(review): the version is presumably the last release that supported the
# architecture; confirm against the code that reads this table.
_PREVIOUSLY_SUPPORTED_MODELS = {
    "MotifForCausalLM": "0.10.2",
    "Phi3SmallForCausalLM": "0.9.2",
    "Phi4FlashForCausalLM": "0.10.2",
    "Phi4MultimodalForCausalLM": "0.12.0",
    # encoder-decoder models except whisper
    # have been removed for V0 deprecation.
    "BartModel": "0.10.2",
    "BartForConditionalGeneration": "0.10.2",
    "DonutForConditionalGeneration": "0.10.2",
    "Florence2ForConditionalGeneration": "0.10.2",
    "MBartForConditionalGeneration": "0.10.2",
    "MllamaForConditionalGeneration": "0.10.2",
}
634

635

636
637
@dataclass(frozen=True)
class _ModelInfo:
    """
    Immutable record of the interface flags and defaults of a model class.

    Use `from_model_cls` to derive an instance from a model class.
    """

    architecture: str
    is_text_generation_model: bool
    is_pooling_model: bool
    attn_type: AttnTypeStr
    default_seq_pooling_type: SequencePoolingType
    default_tok_pooling_type: TokenPoolingType
    supports_cross_encoding: bool
    supports_late_interaction: bool
    supports_multimodal: bool
    supports_multimodal_raw_input_only: bool
    requires_raw_input_tokens: bool
    supports_multimodal_encoder_tp_data: bool
    supports_pp: bool
    has_inner_state: bool
    is_attention_free: bool
    is_hybrid: bool
    has_noops: bool
    supports_mamba_prefix_caching: bool
    supports_transcription: bool
    supports_transcription_only: bool

    @staticmethod
    def from_model_cls(model: type[nn.Module]) -> "_ModelInfo":
        """Build a `_ModelInfo` by querying each interface helper on `model`."""
        raw_input_only = supports_multimodal_raw_input_only(model)
        encoder_tp_data = supports_multimodal_encoder_tp_data(model)
        transcribes = supports_transcription(model)

        return _ModelInfo(
            architecture=model.__name__,
            is_text_generation_model=is_text_generation_model(model),
            is_pooling_model=is_pooling_model(model),
            default_seq_pooling_type=get_default_seq_pooling_type(model),
            default_tok_pooling_type=get_default_tok_pooling_type(model),
            attn_type=get_attn_type(model),
            supports_cross_encoding=supports_cross_encoding(model),
            supports_late_interaction=supports_late_interaction(model),
            supports_multimodal=supports_multimodal(model),
            supports_multimodal_raw_input_only=raw_input_only,
            requires_raw_input_tokens=requires_raw_input_tokens(model),
            supports_multimodal_encoder_tp_data=encoder_tp_data,
            supports_pp=supports_pp(model),
            has_inner_state=has_inner_state(model),
            is_attention_free=is_attention_free(model),
            is_hybrid=is_hybrid(model),
            supports_mamba_prefix_caching=supports_mamba_prefix_caching(model),
            supports_transcription=transcribes,
            # The class attribute is only read when transcription is supported.
            supports_transcription_only=(
                transcribes and model.supports_transcription_only
            ),
            has_noops=has_noops(model),
        )
689
690


691
692
693
694
class _BaseRegisteredModel(ABC):
    """Interface shared by every entry stored in the model registry."""

    @abstractmethod
    def inspect_model_cls(self) -> _ModelInfo:
        """Return the `_ModelInfo` describing the registered model class."""
        raise NotImplementedError

    @abstractmethod
    def load_model_cls(self) -> type[nn.Module]:
        """Return the registered model class."""
        raise NotImplementedError
699
700


701
702
703
704
705
706
707
@dataclass(frozen=True)
class _RegisteredModel(_BaseRegisteredModel):
    """
    Represents a model that has already been imported in the main process.
    """

    interfaces: _ModelInfo
    model_cls: type[nn.Module]

    @staticmethod
    def from_model_cls(model_cls: type[nn.Module]):
        """Create an entry for `model_cls`, inspecting it right away."""
        info = _ModelInfo.from_model_cls(model_cls)
        return _RegisteredModel(interfaces=info, model_cls=model_cls)

    def inspect_model_cls(self) -> _ModelInfo:
        """Return the model info computed when the entry was created."""
        return self.interfaces

    def load_model_cls(self) -> type[nn.Module]:
        """Return the stored model class."""
        return self.model_cls


@dataclass(frozen=True)
class _LazyRegisteredModel(_BaseRegisteredModel):
    """
    Represents a model that has not been imported in the main process.
    """

    module_name: str
    class_name: str

    @staticmethod
    def _get_cache_dir() -> Path:
        """Return the directory that holds the cached model info files."""
        return Path(envs.VLLM_CACHE_ROOT) / "modelinfos"

    def _get_cache_filename(self) -> str:
        """Return the cache file name built from the module and class names."""
        cls_name = f"{self.module_name}-{self.class_name}".replace(".", "-")
        return f"{cls_name}.json"

    def _load_modelinfo_from_cache(self, module_hash: str) -> _ModelInfo | None:
        """
        Load the cached `_ModelInfo` for this model if it is still valid.

        Returns `None` when the cache file is missing, when its stored hash
        differs from `module_hash`, or when it cannot be read or parsed.
        Loading is best-effort: failures are logged and never raised.
        """
        try:
            try:
                modelinfo_path = self._get_cache_dir() / self._get_cache_filename()
                with open(modelinfo_path, encoding="utf-8") as file:
                    mi_dict = json.load(file)
            except FileNotFoundError:
                logger.debug(
                    "Cached model info file for class %s.%s not found",
                    self.module_name,
                    self.class_name,
                )
                return None

            if mi_dict["hash"] != module_hash:
                logger.debug(
                    "Cached model info file for class %s.%s is stale",
                    self.module_name,
                    self.class_name,
                )
                return None

            # file not changed, use cached _ModelInfo properties
            return _ModelInfo(**mi_dict["modelinfo"])
        except Exception:
            # Keep the traceback so that a corrupt or incompatible cache file
            # can be diagnosed from the debug log.
            logger.debug(
                "Cached model info for class %s.%s error. ",
                self.module_name,
                self.class_name,
                exc_info=True,
            )
            return None

    def _save_modelinfo_to_cache(self, mi: _ModelInfo, module_hash: str) -> None:
        """
        Write `mi` together with `module_hash` to the JSON cache file.

        Saving is best-effort: any failure is logged and not raised.
        """
        from vllm.model_executor.model_loader.weight_utils import atomic_writer

        try:
            modelinfo_dict = {
                "hash": module_hash,
                "modelinfo": asdict(mi),
            }
            cache_dir = self._get_cache_dir()
            cache_dir.mkdir(parents=True, exist_ok=True)
            modelinfo_path = cache_dir / self._get_cache_filename()
            with atomic_writer(modelinfo_path, encoding="utf-8") as f:
                json.dump(modelinfo_dict, f, indent=2)
        except Exception:
            logger.exception("Error saving model info cache.")

    @logtime(logger=logger, msg="Registry inspect model class")
    def inspect_model_cls(self) -> _ModelInfo:
        """
        Return the `_ModelInfo` of this model, using the disk cache if possible.

        The cache is validated against a hash of the `<module>.py` file next
        to this file. When no such file exists the cache is neither read nor
        written, and the model is always inspected in a subprocess.
        """
        model_path = Path(__file__).parent / f"{self.module_name.split('.')[-1]}.py"
        module_hash = None

        if model_path.exists():
            with open(model_path, "rb") as f:
                module_hash = safe_hash(f.read(), usedforsecurity=False).hexdigest()

            mi = self._load_modelinfo_from_cache(module_hash)
            if mi is not None:
                logger.debug(
                    "Loaded model info for class %s.%s from cache",
                    self.module_name,
                    self.class_name,
                )
                return mi
            else:
                logger.debug(
                    "Cache model info for class %s.%s miss. Loading model instead.",
                    self.module_name,
                    self.class_name,
                )

        # Performed in another process to avoid initializing CUDA
        mi = _run_in_subprocess(
            lambda: _ModelInfo.from_model_cls(self.load_model_cls())
        )
        logger.debug(
            "Loaded model info for class %s.%s", self.module_name, self.class_name
        )

        # save cache file
        if module_hash is not None:
            self._save_modelinfo_to_cache(mi, module_hash)

        return mi

    def load_model_cls(self) -> type[nn.Module]:
        """Import `module_name` and return its `class_name` attribute."""
        mod = importlib.import_module(self.module_name)
        return getattr(mod, self.class_name)


@lru_cache(maxsize=128)
def _try_load_model_cls(
    model_arch: str,
    model: _BaseRegisteredModel,
) -> type[nn.Module] | None:
    """
    Load the model class of `model`, returning `None` if loading fails.

    The architecture is first verified against the current platform; that
    check runs outside the `try`, so its errors propagate to the caller.
    Results are memoized by `lru_cache`.
    """
    from vllm.platforms import current_platform

    current_platform.verify_model_arch(model_arch)

    try:
        loaded_cls = model.load_model_cls()
    except Exception:
        logger.exception("Error in loading model architecture '%s'", model_arch)
        return None

    return loaded_cls
846
847


848
849
850
851
@lru_cache(maxsize=128)
def _try_inspect_model_cls(
    model_arch: str,
    model: _BaseRegisteredModel,
) -> _ModelInfo | None:
    """
    Inspect `model`, returning `None` if inspection raises.

    Failures are logged with their traceback. Results are memoized by
    `lru_cache`.
    """
    try:
        info = model.inspect_model_cls()
    except Exception:
        logger.exception("Error in inspecting model architecture '%s'", model_arch)
        return None

    return info
858
859


860
861
862
@dataclass
class _ModelRegistry:
    """
    Registry that maps architecture names to registered models and resolves
    the architectures of a model config to a model class or its `_ModelInfo`.
    """

    # Keyed by model_arch
    models: dict[str, _BaseRegisteredModel] = field(default_factory=dict)

    def get_supported_archs(self) -> Set[str]:
        """Return the names of all registered architectures."""
        return self.models.keys()

    def register_model(
        self,
        model_arch: str,
        model_cls: type[nn.Module] | str,
    ) -> None:
        """
        Register an external model to be used in vLLM.

        `model_cls` can be either:

        - A [`torch.nn.Module`][] class directly referencing the model.
        - A string in the format `<module>:<class>` which can be used to
          lazily import the model. This is useful to avoid initializing CUDA
          when importing the model and thus the related error
          `RuntimeError: Cannot re-initialize CUDA in forked subprocess`.

        Raises:
            TypeError: If `model_arch` is not a string, or `model_cls` is
                neither a string nor an `nn.Module` subclass.
            ValueError: If `model_cls` is a string that is not in the
                `<module>:<class>` format.
        """
        if not isinstance(model_arch, str):
            msg = f"`model_arch` should be a string, not a {type(model_arch)}"
            raise TypeError(msg)

        if model_arch in self.models:
            logger.warning(
                "Model architecture %s is already registered, and will be "
                "overwritten by the new model class %s.",
                model_arch,
                model_cls,
            )

        if isinstance(model_cls, str):
            split_str = model_cls.split(":")
            if len(split_str) != 2:
                msg = "Expected a string in the format `<module>:<class>`"
                raise ValueError(msg)

            model = _LazyRegisteredModel(*split_str)
        elif isinstance(model_cls, type) and issubclass(model_cls, nn.Module):
            model = _RegisteredModel.from_model_cls(model_cls)
        else:
            # Report the type of the offending argument (`model_cls`).
            msg = (
                "`model_cls` should be a string or PyTorch model class, "
                f"not a {type(model_cls)}"
            )
            raise TypeError(msg)

        self.models[model_arch] = model

    def _raise_for_unsupported(self, architectures: list[str]):
        """
        Raise a `ValueError` explaining why `architectures` cannot be used.

        This method always raises. The message distinguishes registered
        architectures that failed to be inspected, architectures that were
        supported in an earlier version, and unknown architectures.
        """
        all_supported_archs = self.get_supported_archs()

        if any(arch in all_supported_archs for arch in architectures):
            raise ValueError(
                f"Model architectures {architectures} failed "
                "to be inspected. Please check the logs for more details."
            )

        for arch in architectures:
            if arch in _PREVIOUSLY_SUPPORTED_MODELS:
                previous_version = _PREVIOUSLY_SUPPORTED_MODELS[arch]

                raise ValueError(
                    f"Model architecture {arch} was supported in vLLM until "
                    f"v{previous_version}, and is not supported anymore. "
                    "Please use an older version of vLLM if you want to "
                    "use this model architecture."
                )

        raise ValueError(
            f"Model architectures {architectures} are not supported for now. "
            f"Supported architectures: {all_supported_archs}"
        )

    def _try_load_model_cls(self, model_arch: str) -> type[nn.Module] | None:
        """Load the class registered for `model_arch`, or return `None`."""
        if model_arch not in self.models:
            return None

        return _try_load_model_cls(model_arch, self.models[model_arch])

    def _try_inspect_model_cls(self, model_arch: str) -> _ModelInfo | None:
        """Inspect the model registered for `model_arch`, or return `None`."""
        if model_arch not in self.models:
            return None

        return _try_inspect_model_cls(model_arch, self.models[model_arch])

    def _try_resolve_transformers(
        self,
        architecture: str,
        model_config: ModelConfig,
    ) -> str | None:
        """
        Resolve `architecture` to a Transformers backend architecture name.

        Returns `None` when no compatible Transformers implementation is found
        and `model_config.model_impl` is not `"transformers"`.

        Raises:
            ValueError: If `model_config.model_impl` is `"transformers"` and
                either no model module can be found or the module found is
                not backend compatible.
        """
        if architecture in _TRANSFORMERS_BACKEND_MODELS:
            return architecture

        auto_map: dict[str, str] = (
            getattr(model_config.hf_config, "auto_map", None) or {}
        )

        # Make sure that config class is always initialized before model class,
        # otherwise the model class won't be able to access the config class,
        # the expected auto_map should have correct order like:
        # "auto_map": {
        #     "AutoConfig": "<your-repo-name>--<config-name>",
        #     "AutoModel": "<your-repo-name>--<config-name>",
        #     "AutoModelFor<Task>": "<your-repo-name>--<config-name>",
        # },
        for prefix in ("AutoConfig", "AutoModel"):
            for name, module in auto_map.items():
                if name.startswith(prefix):
                    try_get_class_from_dynamic_module(
                        module,
                        model_config.model,
                        revision=model_config.revision,
                        trust_remote_code=model_config.trust_remote_code,
                        warn_on_fail=False,
                    )

        model_module = getattr(transformers, architecture, None)

        if model_module is None:
            for name, module in auto_map.items():
                if name.startswith("AutoModel"):
                    model_module = try_get_class_from_dynamic_module(
                        module,
                        model_config.model,
                        revision=model_config.revision,
                        trust_remote_code=model_config.trust_remote_code,
                        warn_on_fail=True,
                    )
                    if model_module is not None:
                        break
            else:
                # Reached only when the loop finished without `break`,
                # i.e. no "AutoModel*" entry produced a class.
                if model_config.model_impl != "transformers":
                    return None

                raise ValueError(
                    f"Cannot find model module. {architecture!r} is not a "
                    "registered model in the Transformers library (only "
                    "relevant if the model is meant to be in Transformers) "
                    "and 'AutoModel' is not present in the model config's "
                    "'auto_map' (relevant if the model is custom)."
                )

        if not model_module.is_backend_compatible():
            if model_config.model_impl != "transformers":
                return None

            raise ValueError(
                f"The Transformers implementation of {architecture!r} "
                "is not compatible with vLLM."
            )

        return model_config._get_transformers_backend_cls()

    def _normalize_arch(
        self,
        architecture: str,
        model_config: ModelConfig,
    ) -> str:
        """
        Map `architecture` to a registered architecture name where possible.

        A registered name is returned unchanged. Otherwise, if the name
        matches an architecture default, its matched suffix is replaced with
        each known suffix in turn and the first registered result is returned.
        If nothing matches, `architecture` is returned unchanged.
        """
        if architecture in self.models:
            return architecture

        # This may be called in order to resolve runner_type and convert_type
        # in the first place, in which case we consider the default match
        match = try_match_architecture_defaults(
            architecture,
            runner_type=getattr(model_config, "runner_type", None),
            convert_type=getattr(model_config, "convert_type", None),
        )
        if match:
            suffix, _ = match

            # Get the name of the base model to convert
            for repl_suffix, _ in iter_architecture_defaults():
                base_arch = architecture.replace(suffix, repl_suffix)
                if base_arch in self.models:
                    return base_arch

        return architecture

    def inspect_model_cls(
        self,
        architectures: str | list[str],
        model_config: ModelConfig,
    ) -> tuple[_ModelInfo, str]:
        """
        Return the `_ModelInfo` and architecture name for `architectures`.

        Raises:
            ValueError: If `architectures` is empty or none of them can be
                resolved to a model that inspects successfully.
        """
        if isinstance(architectures, str):
            architectures = [architectures]
        if not architectures:
            raise ValueError("No model architectures are specified")

        # Require transformers impl
        if model_config.model_impl == "transformers":
            arch = self._try_resolve_transformers(architectures[0], model_config)
            if arch is not None:
                model_info = self._try_inspect_model_cls(arch)
                if model_info is not None:
                    return (model_info, arch)
        elif model_config.model_impl == "terratorch":
            # Only return on success, mirroring `resolve_model_cls`; a failed
            # inspection falls through to the error raised at the end.
            arch = "Terratorch"
            model_info = self._try_inspect_model_cls(arch)
            if model_info is not None:
                return (model_info, arch)

        # Fallback to transformers impl (after resolving convert_type)
        if (
            all(arch not in self.models for arch in architectures)
            and model_config.model_impl == "auto"
            and getattr(model_config, "convert_type", "none") == "none"
        ):
            arch = self._try_resolve_transformers(architectures[0], model_config)
            if arch is not None:
                model_info = self._try_inspect_model_cls(arch)
                if model_info is not None:
                    return (model_info, arch)

        for arch in architectures:
            normalized_arch = self._normalize_arch(arch, model_config)
            model_info = self._try_inspect_model_cls(normalized_arch)
            if model_info is not None:
                return (model_info, arch)

        # Fallback to transformers impl (before resolving runner_type)
        if (
            all(arch not in self.models for arch in architectures)
            and model_config.model_impl == "auto"
        ):
            arch = self._try_resolve_transformers(architectures[0], model_config)
            if arch is not None:
                model_info = self._try_inspect_model_cls(arch)
                if model_info is not None:
                    return (model_info, arch)

        return self._raise_for_unsupported(architectures)

    def resolve_model_cls(
        self,
        architectures: str | list[str],
        model_config: ModelConfig,
    ) -> tuple[type[nn.Module], str]:
        """
        Return the model class and architecture name for `architectures`.

        Raises:
            ValueError: If `architectures` is empty or none of them can be
                resolved to a model class that loads successfully.
        """
        if isinstance(architectures, str):
            architectures = [architectures]
        if not architectures:
            raise ValueError("No model architectures are specified")

        # Require transformers impl
        if model_config.model_impl == "transformers":
            arch = self._try_resolve_transformers(architectures[0], model_config)
            if arch is not None:
                model_cls = self._try_load_model_cls(arch)
                if model_cls is not None:
                    return (model_cls, arch)
        elif model_config.model_impl == "terratorch":
            arch = "Terratorch"
            model_cls = self._try_load_model_cls(arch)
            if model_cls is not None:
                return (model_cls, arch)

        # Fallback to transformers impl (after resolving convert_type)
        if (
            all(arch not in self.models for arch in architectures)
            and model_config.model_impl == "auto"
            and getattr(model_config, "convert_type", "none") == "none"
        ):
            arch = self._try_resolve_transformers(architectures[0], model_config)
            if arch is not None:
                model_cls = self._try_load_model_cls(arch)
                if model_cls is not None:
                    return (model_cls, arch)

        for arch in architectures:
            normalized_arch = self._normalize_arch(arch, model_config)
            model_cls = self._try_load_model_cls(normalized_arch)
            if model_cls is not None:
                return (model_cls, arch)

        # Fallback to transformers impl (before resolving runner_type)
        if (
            all(arch not in self.models for arch in architectures)
            and model_config.model_impl == "auto"
        ):
            arch = self._try_resolve_transformers(architectures[0], model_config)
            if arch is not None:
                model_cls = self._try_load_model_cls(arch)
                if model_cls is not None:
                    return (model_cls, arch)

        return self._raise_for_unsupported(architectures)

    def is_text_generation_model(
        self,
        architectures: str | list[str],
        model_config: ModelConfig,
    ) -> bool:
        """Return the `is_text_generation_model` flag of the resolved model."""
        model_info, _ = self.inspect_model_cls(architectures, model_config)
        return model_info.is_text_generation_model

    def is_pooling_model(
        self,
        architectures: str | list[str],
        model_config: ModelConfig,
    ) -> bool:
        """Return the `is_pooling_model` flag of the resolved model."""
        model_info, _ = self.inspect_model_cls(architectures, model_config)
        return model_info.is_pooling_model

    def is_cross_encoder_model(
        self,
        architectures: str | list[str],
        model_config: ModelConfig,
    ) -> bool:
        """Return the `supports_cross_encoding` flag of the resolved model."""
        model_info, _ = self.inspect_model_cls(architectures, model_config)
        return model_info.supports_cross_encoding

    def is_multimodal_model(
        self,
        architectures: str | list[str],
        model_config: ModelConfig,
    ) -> bool:
        """Return the `supports_multimodal` flag of the resolved model."""
        model_info, _ = self.inspect_model_cls(architectures, model_config)
        return model_info.supports_multimodal

    def is_multimodal_raw_input_only_model(
        self,
        architectures: str | list[str],
        model_config: ModelConfig,
    ) -> bool:
        """Return the `supports_multimodal_raw_input_only` flag of the model."""
        model_info, _ = self.inspect_model_cls(architectures, model_config)
        return model_info.supports_multimodal_raw_input_only

    def is_pp_supported_model(
        self,
        architectures: str | list[str],
        model_config: ModelConfig,
    ) -> bool:
        """Return the `supports_pp` flag of the resolved model."""
        model_info, _ = self.inspect_model_cls(architectures, model_config)
        return model_info.supports_pp

    def model_has_inner_state(
        self,
        architectures: str | list[str],
        model_config: ModelConfig,
    ) -> bool:
        """Return the `has_inner_state` flag of the resolved model."""
        model_info, _ = self.inspect_model_cls(architectures, model_config)
        return model_info.has_inner_state

    def is_attention_free_model(
        self,
        architectures: str | list[str],
        model_config: ModelConfig,
    ) -> bool:
        """Return the `is_attention_free` flag of the resolved model."""
        model_info, _ = self.inspect_model_cls(architectures, model_config)
        return model_info.is_attention_free

    def is_hybrid_model(
        self,
        architectures: str | list[str],
        model_config: ModelConfig,
    ) -> bool:
        """Return the `is_hybrid` flag of the resolved model."""
        model_info, _ = self.inspect_model_cls(architectures, model_config)
        return model_info.is_hybrid

    def is_noops_model(
        self,
        architectures: str | list[str],
        model_config: ModelConfig,
    ) -> bool:
        """Return the `has_noops` flag of the resolved model."""
        model_info, _ = self.inspect_model_cls(architectures, model_config)
        return model_info.has_noops

    def is_transcription_model(
        self,
        architectures: str | list[str],
        model_config: ModelConfig,
    ) -> bool:
        """Return the `supports_transcription` flag of the resolved model."""
        model_info, _ = self.inspect_model_cls(architectures, model_config)
        return model_info.supports_transcription

    def is_transcription_only_model(
        self,
        architectures: str | list[str],
        model_config: ModelConfig,
    ) -> bool:
        """Return the `supports_transcription_only` flag of the resolved model."""
        model_info, _ = self.inspect_model_cls(architectures, model_config)
        return model_info.supports_transcription_only

1247

1248
1249
1250
1251
1252
1253
1254
1255
1256
# Global registry, pre-populated with a lazily imported entry for every
# built-in architecture in `_VLLM_MODELS`.
ModelRegistry = _ModelRegistry(
    models={
        arch: _LazyRegisteredModel(
            module_name=f"vllm.model_executor.models.{module_stem}",
            class_name=class_name,
        )
        for arch, (module_stem, class_name) in _VLLM_MODELS.items()
    }
)
1257
1258
1259
1260
1261

_T = TypeVar("_T")


def _run_in_subprocess(fn: Callable[[], _T]) -> _T:
    """
    Call `fn` in a separate interpreter process and return its result.

    `fn` and an output path are sent to the child over stdin; the child
    writes the pickled result to that path, which is read back here.

    Raises:
        RuntimeError: If the subprocess exits with a non-zero return code;
            the message contains the child's stderr.
    """
    # NOTE: We use a temporary directory instead of a temporary file to avoid
    # issues like https://stackoverflow.com/questions/23212435/permission-denied-to-write-to-my-temporary-file
    with tempfile.TemporaryDirectory() as tempdir:
        output_filepath = os.path.join(tempdir, "registry_output.tmp")

        # `cloudpickle` allows pickling lambda functions directly
        import cloudpickle

        payload = cloudpickle.dumps((fn, output_filepath))

        # cannot use `sys.executable __file__` here because the script
        # contains relative imports
        completed = subprocess.run(
            _SUBPROCESS_COMMAND, input=payload, capture_output=True
        )

        # check if the subprocess is successful
        try:
            completed.check_returncode()
        except Exception as e:
            # wrap raised exception to provide more information
            raise RuntimeError(
                f"Error raised in subprocess:\n{completed.stderr.decode()}"
            ) from e

        with open(output_filepath, "rb") as result_file:
            return pickle.load(result_file)


def _run() -> None:
    """
    Subprocess entry point: run the function received on stdin.

    Reads a pickled `(fn, output_file)` pair from stdin, calls `fn`, and
    writes the pickled result to `output_file`.
    """
    # Setup plugins
    from vllm.plugins import load_general_plugins

    load_general_plugins()

    # The payload is the one `_run_in_subprocess` passes as the child's stdin.
    payload = sys.stdin.buffer.read()
    fn, output_file = pickle.loads(payload)

    # Run `fn` before opening the file so a failure leaves no output behind.
    result = fn()

    with open(output_file, "wb") as out:
        pickle.dump(result, out)


if __name__ == "__main__":
    _run()