# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Whenever you add an architecture to this registry, please also update
`tests/models/registry.py` with example HuggingFace models for it.
"""
7

8
import importlib
9
import json
10
import os
11
import pickle
12
13
import subprocess
import sys
14
import tempfile
15
from abc import ABC, abstractmethod
16
from collections.abc import Callable, Set
17
from dataclasses import asdict, dataclass, field
18
from functools import lru_cache
19
from pathlib import Path
20
from typing import TYPE_CHECKING, Any, TypeVar
21
22

import torch.nn as nn
23
import transformers
24

25
from vllm import envs
26
27
28
29
30
from vllm.config import (
    ModelConfig,
    iter_architecture_defaults,
    try_match_architecture_defaults,
)
31
from vllm.logger import init_logger
32
from vllm.logging_utils import logtime
33
from vllm.transformers_utils.dynamic_module import try_get_class_from_dynamic_module
34
from vllm.utils.hashing import safe_hash
35

36
37
if TYPE_CHECKING:
    from vllm.config.model import AttnTypeStr
    from vllm.config.pooler import SequencePoolingType, TokenPoolingType
else:
    # The real types are only imported for static type checkers; at runtime
    # these names are plain `Any` aliases so they can still be used in
    # annotations (e.g. the `_ModelInfo` fields) without importing them.
    AttnTypeStr = Any
    SequencePoolingType = Any
    TokenPoolingType = Any


45
46
47
48
49
from .interfaces import (
    has_inner_state,
    has_noops,
    is_attention_free,
    is_hybrid,
Patrick von Platen's avatar
Patrick von Platen committed
50
    requires_raw_input_tokens,
51
    supports_cross_encoding,
52
    supports_late_interaction,
53
    supports_mamba_prefix_caching,
54
55
56
57
58
59
60
    supports_multimodal,
    supports_multimodal_encoder_tp_data,
    supports_multimodal_raw_input_only,
    supports_pp,
    supports_transcription,
)
from .interfaces_base import (
61
    get_attn_type,
62
63
    get_default_seq_pooling_type,
    get_default_tok_pooling_type,
64
65
66
    is_pooling_model,
    is_text_generation_model,
)
67
68
69

# Module-level logger, named after this module.
logger = init_logger(__name__)

70
71
# Text-generation architectures. Each key is an architecture name (presumably
# as listed in a model config's `architectures` field) and each value is the
# `(module, class)` pair of the vLLM implementation serving it. Several keys
# may share one implementation.
_TEXT_GENERATION_MODELS: dict[str, tuple[str, str]] = {
    # [Decoder-only]
    "AfmoeForCausalLM": ("afmoe", "AfmoeForCausalLM"),
    "ApertusForCausalLM": ("apertus", "ApertusForCausalLM"),
    "AquilaModel": ("llama", "LlamaForCausalLM"),
    "AquilaForCausalLM": ("llama", "LlamaForCausalLM"),  # AquilaChat2
    "ArceeForCausalLM": ("arcee", "ArceeForCausalLM"),
    "ArcticForCausalLM": ("arctic", "ArcticForCausalLM"),
    # baichuan-7b, upper case 'C' in the class name
    "BaiChuanForCausalLM": ("baichuan", "BaiChuanForCausalLM"),
    # baichuan-13b, lower case 'c' in the class name
    "BaichuanForCausalLM": ("baichuan", "BaichuanForCausalLM"),
    "BailingMoeForCausalLM": ("bailing_moe", "BailingMoeForCausalLM"),
    "BailingMoeV2ForCausalLM": ("bailing_moe", "BailingMoeV2ForCausalLM"),
    "BailingMoeV2_5ForCausalLM": ("bailing_moe_linear", "BailingMoeV25ForCausalLM"),
    "BambaForCausalLM": ("bamba", "BambaForCausalLM"),
    "BloomForCausalLM": ("bloom", "BloomForCausalLM"),
    "ChatGLMModel": ("chatglm", "ChatGLMForCausalLM"),
    "ChatGLMForConditionalGeneration": ("chatglm", "ChatGLMForCausalLM"),
    "CohereForCausalLM": ("commandr", "CohereForCausalLM"),
    "Cohere2ForCausalLM": ("commandr", "CohereForCausalLM"),
    "CwmForCausalLM": ("llama", "LlamaForCausalLM"),
    "DbrxForCausalLM": ("dbrx", "DbrxForCausalLM"),
    "DeciLMForCausalLM": ("nemotron_nas", "DeciLMForCausalLM"),
    "DeepseekForCausalLM": ("deepseek_v2", "DeepseekForCausalLM"),
    "DeepseekV2ForCausalLM": ("deepseek_v2", "DeepseekV2ForCausalLM"),
    "DeepseekV3ForCausalLM": ("deepseek_v2", "DeepseekV3ForCausalLM"),
    "DeepseekV32ForCausalLM": ("deepseek_v2", "DeepseekV3ForCausalLM"),
    "Dots1ForCausalLM": ("dots1", "Dots1ForCausalLM"),
    "Ernie4_5ForCausalLM": ("ernie45", "Ernie4_5ForCausalLM"),
    "Ernie4_5_MoeForCausalLM": ("ernie45_moe", "Ernie4_5_MoeForCausalLM"),
    "ExaoneForCausalLM": ("exaone", "ExaoneForCausalLM"),
    "Exaone4ForCausalLM": ("exaone4", "Exaone4ForCausalLM"),
    "ExaoneMoEForCausalLM": ("exaone_moe", "ExaoneMoeForCausalLM"),
    "Fairseq2LlamaForCausalLM": ("fairseq2_llama", "Fairseq2LlamaForCausalLM"),
    "FalconForCausalLM": ("falcon", "FalconForCausalLM"),
    "FalconMambaForCausalLM": ("mamba", "MambaForCausalLM"),
    "FalconH1ForCausalLM": ("falcon_h1", "FalconH1ForCausalLM"),
    "FlexOlmoForCausalLM": ("flex_olmo", "FlexOlmoForCausalLM"),
    "GemmaForCausalLM": ("gemma", "GemmaForCausalLM"),
    "Gemma2ForCausalLM": ("gemma2", "Gemma2ForCausalLM"),
    "Gemma3ForCausalLM": ("gemma3", "Gemma3ForCausalLM"),
    "Gemma3nForCausalLM": ("gemma3n", "Gemma3nForCausalLM"),
    "Qwen3NextForCausalLM": ("qwen3_next", "Qwen3NextForCausalLM"),
    "GlmForCausalLM": ("glm", "GlmForCausalLM"),
    "Glm4ForCausalLM": ("glm4", "Glm4ForCausalLM"),
    "Glm4MoeForCausalLM": ("glm4_moe", "Glm4MoeForCausalLM"),
    "Glm4MoeLiteForCausalLM": ("glm4_moe_lite", "Glm4MoeLiteForCausalLM"),
    "GlmMoeDsaForCausalLM": ("deepseek_v2", "GlmMoeDsaForCausalLM"),
    "GptOssForCausalLM": ("gpt_oss", "GptOssForCausalLM"),
    "GPT2LMHeadModel": ("gpt2", "GPT2LMHeadModel"),
    "GPTBigCodeForCausalLM": ("gpt_bigcode", "GPTBigCodeForCausalLM"),
    "GPTJForCausalLM": ("gpt_j", "GPTJForCausalLM"),
    "GPTNeoXForCausalLM": ("gpt_neox", "GPTNeoXForCausalLM"),
    "GraniteForCausalLM": ("granite", "GraniteForCausalLM"),
    "GraniteMoeForCausalLM": ("granitemoe", "GraniteMoeForCausalLM"),
    "GraniteMoeHybridForCausalLM": ("granitemoehybrid", "GraniteMoeHybridForCausalLM"),  # noqa: E501
    "GraniteMoeSharedForCausalLM": ("granitemoeshared", "GraniteMoeSharedForCausalLM"),  # noqa: E501
    "GritLM": ("gritlm", "GritLM"),
    "Grok1ModelForCausalLM": ("grok1", "GrokForCausalLM"),
    "Grok1ForCausalLM": ("grok1", "GrokForCausalLM"),
    "HunYuanMoEV1ForCausalLM": ("hunyuan_v1", "HunYuanMoEV1ForCausalLM"),
    "HunYuanDenseV1ForCausalLM": ("hunyuan_v1", "HunYuanDenseV1ForCausalLM"),
    # NOTE(review): vision-named module listed under text generation; confirm
    # whether this belongs in `_MULTIMODAL_MODELS`.
    "HCXVisionForCausalLM": ("hyperclovax_vision", "HCXVisionForCausalLM"),
    "InternLMForCausalLM": ("llama", "LlamaForCausalLM"),
    "InternLM2ForCausalLM": ("internlm2", "InternLM2ForCausalLM"),
    "InternLM2VEForCausalLM": ("internlm2_ve", "InternLM2VEForCausalLM"),
    "InternLM3ForCausalLM": ("llama", "LlamaForCausalLM"),
    "IQuestCoderForCausalLM": ("llama", "LlamaForCausalLM"),
    "IQuestLoopCoderForCausalLM": ("iquest_loopcoder", "IQuestLoopCoderForCausalLM"),
    "JAISLMHeadModel": ("jais", "JAISLMHeadModel"),
    "Jais2ForCausalLM": ("jais2", "Jais2ForCausalLM"),
    "JambaForCausalLM": ("jamba", "JambaForCausalLM"),
    "KimiLinearForCausalLM": ("kimi_linear", "KimiLinearForCausalLM"),
    "Lfm2ForCausalLM": ("lfm2", "Lfm2ForCausalLM"),
    "Lfm2MoeForCausalLM": ("lfm2_moe", "Lfm2MoeForCausalLM"),
    "LlamaForCausalLM": ("llama", "LlamaForCausalLM"),
    "Llama4ForCausalLM": ("llama4", "Llama4ForCausalLM"),
    # For decapoda-research/llama-*
    "LLaMAForCausalLM": ("llama", "LlamaForCausalLM"),
    "LongcatFlashForCausalLM": ("longcat_flash", "LongcatFlashForCausalLM"),
    "MambaForCausalLM": ("mamba", "MambaForCausalLM"),
    "Mamba2ForCausalLM": ("mamba2", "Mamba2ForCausalLM"),
    "MiniCPMForCausalLM": ("minicpm", "MiniCPMForCausalLM"),
    "MiniCPM3ForCausalLM": ("minicpm3", "MiniCPM3ForCausalLM"),
    "MiniMaxForCausalLM": ("minimax_text_01", "MiniMaxText01ForCausalLM"),
    "MiniMaxText01ForCausalLM": ("minimax_text_01", "MiniMaxText01ForCausalLM"),
    "MiniMaxM1ForCausalLM": ("minimax_text_01", "MiniMaxText01ForCausalLM"),
    "MiniMaxM2ForCausalLM": ("minimax_m2", "MiniMaxM2ForCausalLM"),
    "MistralForCausalLM": ("mistral", "MistralForCausalLM"),
    "MistralLarge3ForCausalLM": ("mistral_large_3", "MistralLarge3ForCausalLM"),
    "MixtralForCausalLM": ("mixtral", "MixtralForCausalLM"),
    # transformers's mpt class has lower case
    "MptForCausalLM": ("mpt", "MPTForCausalLM"),
    "MPTForCausalLM": ("mpt", "MPTForCausalLM"),
    "MiMoForCausalLM": ("mimo", "MiMoForCausalLM"),
    "MiMoV2FlashForCausalLM": ("mimo_v2_flash", "MiMoV2FlashForCausalLM"),
    "NemotronForCausalLM": ("nemotron", "NemotronForCausalLM"),
    "NemotronHForCausalLM": ("nemotron_h", "NemotronHForCausalLM"),
    "NemotronHPuzzleForCausalLM": ("nemotron_h", "NemotronHForCausalLM"),
    "OlmoForCausalLM": ("olmo", "OlmoForCausalLM"),
    "Olmo2ForCausalLM": ("olmo2", "Olmo2ForCausalLM"),
    "Olmo3ForCausalLM": ("olmo2", "Olmo2ForCausalLM"),
    "OlmoeForCausalLM": ("olmoe", "OlmoeForCausalLM"),
    "OPTForCausalLM": ("opt", "OPTForCausalLM"),
    "OrionForCausalLM": ("orion", "OrionForCausalLM"),
    "OuroForCausalLM": ("ouro", "OuroForCausalLM"),
    "PanguEmbeddedForCausalLM": ("openpangu", "PanguEmbeddedForCausalLM"),
    "PanguProMoEV2ForCausalLM": ("openpangu", "PanguProMoEV2ForCausalLM"),
    "PanguUltraMoEForCausalLM": ("openpangu", "PanguUltraMoEForCausalLM"),
    "PersimmonForCausalLM": ("persimmon", "PersimmonForCausalLM"),
    "PhiForCausalLM": ("phi", "PhiForCausalLM"),
    "Phi3ForCausalLM": ("phi3", "Phi3ForCausalLM"),
    "PhiMoEForCausalLM": ("phimoe", "PhiMoEForCausalLM"),
    "Plamo2ForCausalLM": ("plamo2", "Plamo2ForCausalLM"),
    "Plamo3ForCausalLM": ("plamo3", "Plamo3ForCausalLM"),
    "QWenLMHeadModel": ("qwen", "QWenLMHeadModel"),
    "Qwen2ForCausalLM": ("qwen2", "Qwen2ForCausalLM"),
    "Qwen2MoeForCausalLM": ("qwen2_moe", "Qwen2MoeForCausalLM"),
    "Qwen3ForCausalLM": ("qwen3", "Qwen3ForCausalLM"),
    "Qwen3MoeForCausalLM": ("qwen3_moe", "Qwen3MoeForCausalLM"),
    "RWForCausalLM": ("falcon", "FalconForCausalLM"),
    "SeedOssForCausalLM": ("seed_oss", "SeedOssForCausalLM"),
    "Step1ForCausalLM": ("step1", "Step1ForCausalLM"),
    "Step3TextForCausalLM": ("step3_text", "Step3TextForCausalLM"),
    "Step3p5ForCausalLM": ("step3p5", "Step3p5ForCausalLM"),
    "StableLMEpochForCausalLM": ("stablelm", "StablelmForCausalLM"),
    "StableLmForCausalLM": ("stablelm", "StablelmForCausalLM"),
    "Starcoder2ForCausalLM": ("starcoder2", "Starcoder2ForCausalLM"),
    "SolarForCausalLM": ("solar", "SolarForCausalLM"),
    "TeleChatForCausalLM": ("telechat2", "TeleChat2ForCausalLM"),
    "TeleChat2ForCausalLM": ("telechat2", "TeleChat2ForCausalLM"),
    "TeleFLMForCausalLM": ("teleflm", "TeleFLMForCausalLM"),
    "XverseForCausalLM": ("llama", "LlamaForCausalLM"),
    "Zamba2ForCausalLM": ("zamba2", "Zamba2ForCausalLM"),
}

# Architectures registered for embedding use, in the same
# `name -> (module, class)` shape as the text-generation table. Some entries
# deliberately reuse text-generation implementations (e.g. every Llama-family
# name pulled in by the `**{...}` expansion below).
_EMBEDDING_MODELS: dict[str, tuple[str, str]] = {
    # [Text-only]
    "BertModel": ("bert", "BertEmbeddingModel"),
    "BertSpladeSparseEmbeddingModel": ("bert", "BertSpladeSparseEmbeddingModel"),
    "HF_ColBERT": ("colbert", "ColBERTModel"),
    "ColBERTModernBertModel": ("colbert", "ColBERTModernBertModel"),
    "ColBERTJinaRobertaModel": ("colbert", "ColBERTJinaRobertaModel"),
    "DeciLMForCausalLM": ("nemotron_nas", "DeciLMForCausalLM"),
    "Gemma2Model": ("gemma2", "Gemma2ForCausalLM"),
    "Gemma3TextModel": ("gemma3", "Gemma3Model"),
    "GlmForCausalLM": ("glm", "GlmForCausalLM"),
    "GPT2ForSequenceClassification": ("gpt2", "GPT2ForSequenceClassification"),
    "GritLM": ("gritlm", "GritLM"),
    "GteModel": ("bert_with_rope", "SnowflakeGteNewModel"),
    "GteNewModel": ("bert_with_rope", "GteNewModel"),
    "InternLM2ForRewardModel": ("internlm2", "InternLM2ForRewardModel"),
    "JambaForSequenceClassification": ("jamba", "JambaForSequenceClassification"),
    "LlamaBidirectionalModel": ("llama", "LlamaBidirectionalModel"),
    "LlamaModel": ("llama", "LlamaForCausalLM"),
    **{
        # Multiple models share the same architecture, so we include them all
        k: (mod, arch)
        for k, (mod, arch) in _TEXT_GENERATION_MODELS.items()
        if arch == "LlamaForCausalLM"
    },
    "MistralModel": ("llama", "LlamaForCausalLM"),
    "ModernBertModel": ("modernbert", "ModernBertModel"),
    "NomicBertModel": ("bert_with_rope", "NomicBertModel"),
    "Phi3ForCausalLM": ("phi3", "Phi3ForCausalLM"),
    "Qwen2Model": ("qwen2", "Qwen2ForCausalLM"),
    "Qwen2ForCausalLM": ("qwen2", "Qwen2ForCausalLM"),
    "Qwen2ForRewardModel": ("qwen2_rm", "Qwen2ForRewardModel"),
    "Qwen2ForProcessRewardModel": ("qwen2_rm", "Qwen2ForProcessRewardModel"),
    "RobertaForMaskedLM": ("roberta", "RobertaEmbeddingModel"),
    "RobertaModel": ("roberta", "RobertaEmbeddingModel"),
    "TeleChatForCausalLM": ("telechat2", "TeleChat2ForCausalLM"),
    "TeleChat2ForCausalLM": ("telechat2", "TeleChat2ForCausalLM"),
    "VoyageQwen3BidirectionalEmbedModel": (
        "voyage",
        "VoyageQwen3BidirectionalEmbedModel",
    ),
    "XLMRobertaModel": ("roberta", "RobertaEmbeddingModel"),
    "BgeM3EmbeddingModel": ("roberta", "BgeM3EmbeddingModel"),
    # [Multimodal]
    "CLIPModel": ("clip", "CLIPEmbeddingModel"),
    "ColModernVBertForRetrieval": ("colmodernvbert", "ColModernVBertForRetrieval"),
    "LlavaNextForConditionalGeneration": (
        "llava_next",
        "LlavaNextForConditionalGeneration",
    ),
    "Phi3VForCausalLM": ("phi3v", "Phi3VForCausalLM"),
    "Qwen2VLForConditionalGeneration": ("qwen2_vl", "Qwen2VLForConditionalGeneration"),  # noqa: E501
    "ColQwen3": ("colqwen3", "ColQwen3Model"),
    "OpsColQwen3Model": ("colqwen3", "ColQwen3Model"),
    "Qwen3VLNemotronEmbedModel": ("colqwen3", "ColQwen3Model"),
    "SiglipModel": ("siglip", "SiglipEmbeddingModel"),
    "LlamaNemotronVLModel": (
        "nemotron_vl",
        "LlamaNemotronVLForEmbedding",
    ),
    # Technically Terratorch models work on images, both in input and output.
    # They are listed here because they piggy-back on embedding models for
    # the time being.
    "PrithviGeoSpatialMAE": ("terratorch", "Terratorch"),
    "Terratorch": ("terratorch", "Terratorch"),
}

274
275
# Sequence/token classification and ranking architectures (cross-encoders),
# in the same `name -> (module, class)` shape as the other tables.
_CROSS_ENCODER_MODELS: dict[str, tuple[str, str]] = {
    "BertForSequenceClassification": ("bert", "BertForSequenceClassification"),
    "BertForTokenClassification": ("bert", "BertForTokenClassification"),
    "GteNewForSequenceClassification": (
        "bert_with_rope",
        "GteNewForSequenceClassification",
    ),
    "JinaVLForRanking": ("jina_vl", "JinaVLForSequenceClassification"),
    "LlamaBidirectionalForSequenceClassification": (
        "llama",
        "LlamaBidirectionalForSequenceClassification",
    ),
    "ModernBertForSequenceClassification": (
        "modernbert",
        "ModernBertForSequenceClassification",
    ),
    "ModernBertForTokenClassification": (
        "modernbert",
        "ModernBertForTokenClassification",
    ),
    "RobertaForSequenceClassification": ("roberta", "RobertaForSequenceClassification"),
    "XLMRobertaForSequenceClassification": (
        "roberta",
        "RobertaForSequenceClassification",
    ),
}

301
# Multimodal architectures, in the same `name -> (module, class)` shape as
# the other tables. Grouped into decoder-only and encoder-decoder sections.
_MULTIMODAL_MODELS: dict[str, tuple[str, str]] = {
    # [Decoder-only]
    "AriaForConditionalGeneration": ("aria", "AriaForConditionalGeneration"),
    "AudioFlamingo3ForConditionalGeneration": (
        "audioflamingo3",
        "AudioFlamingo3ForConditionalGeneration",
    ),
    "MusicFlamingoForConditionalGeneration": (
        "musicflamingo",
        "MusicFlamingoForConditionalGeneration",
    ),
    "AyaVisionForConditionalGeneration": (
        "aya_vision",
        "AyaVisionForConditionalGeneration",
    ),
    "BagelForConditionalGeneration": ("bagel", "BagelForConditionalGeneration"),
    "BeeForConditionalGeneration": ("bee", "BeeForConditionalGeneration"),
    "Blip2ForConditionalGeneration": ("blip2", "Blip2ForConditionalGeneration"),
    "ChameleonForConditionalGeneration": (
        "chameleon",
        "ChameleonForConditionalGeneration",
    ),
    "Cohere2VisionForConditionalGeneration": (
        "cohere2_vision",
        "Cohere2VisionForConditionalGeneration",
    ),
    "DeepseekVLV2ForCausalLM": ("deepseek_vl2", "DeepseekVLV2ForCausalLM"),
    "DeepseekOCRForCausalLM": ("deepseek_ocr", "DeepseekOCRForCausalLM"),
    "DeepseekOCR2ForCausalLM": ("deepseek_ocr2", "DeepseekOCR2ForCausalLM"),
    "DotsOCRForCausalLM": ("dots_ocr", "DotsOCRForCausalLM"),
    "Eagle2_5_VLForConditionalGeneration": (
        "eagle2_5_vl",
        "Eagle2_5_VLForConditionalGeneration",
    ),
    "Ernie4_5_VLMoeForConditionalGeneration": (
        "ernie45_vl",
        "Ernie4_5_VLMoeForConditionalGeneration",
    ),
    "FunASRForConditionalGeneration": ("funasr", "FunASRForConditionalGeneration"),  # noqa: E501
    "FunAudioChatForConditionalGeneration": (
        "funaudiochat",
        "FunAudioChatForConditionalGeneration",
    ),
    "FuyuForCausalLM": ("fuyu", "FuyuForCausalLM"),
    "Gemma3ForConditionalGeneration": ("gemma3_mm", "Gemma3ForConditionalGeneration"),  # noqa: E501
    "Gemma3nForConditionalGeneration": (
        "gemma3n_mm",
        "Gemma3nForConditionalGeneration",
    ),
    "GlmAsrForConditionalGeneration": ("glmasr", "GlmAsrForConditionalGeneration"),
    "GLM4VForCausalLM": ("glm4v", "GLM4VForCausalLM"),
    "Glm4vForConditionalGeneration": ("glm4_1v", "Glm4vForConditionalGeneration"),
    "Glm4vMoeForConditionalGeneration": ("glm4_1v", "Glm4vMoeForConditionalGeneration"),
    "GlmOcrForConditionalGeneration": ("glm_ocr", "GlmOcrForConditionalGeneration"),  # noqa: E501
    "GraniteSpeechForConditionalGeneration": (
        "granite_speech",
        "GraniteSpeechForConditionalGeneration",
    ),
    "H2OVLChatModel": ("h2ovl", "H2OVLChatModel"),
    "HunYuanVLForConditionalGeneration": (
        "hunyuan_vision",
        "HunYuanVLForConditionalGeneration",
    ),
    "StepVLForConditionalGeneration": ("step_vl", "StepVLForConditionalGeneration"),
    "InternVLChatModel": ("internvl", "InternVLChatModel"),
    "NemotronH_Nano_VL_V2": ("nano_nemotron_vl", "NemotronH_Nano_VL_V2"),
    "OpenCUAForConditionalGeneration": (
        "opencua",
        "OpenCUAForConditionalGeneration",
    ),
    "InternS1ForConditionalGeneration": (
        "interns1",
        "InternS1ForConditionalGeneration",
    ),
    "InternVLForConditionalGeneration": (
        "interns1",
        "InternS1ForConditionalGeneration",
    ),
    "InternS1ProForConditionalGeneration": (
        "interns1_pro",
        "InternS1ProForConditionalGeneration",
    ),
    "Idefics3ForConditionalGeneration": (
        "idefics3",
        "Idefics3ForConditionalGeneration",
    ),
    "IsaacForConditionalGeneration": ("isaac", "IsaacForConditionalGeneration"),
    "SmolVLMForConditionalGeneration": ("smolvlm", "SmolVLMForConditionalGeneration"),  # noqa: E501
    "KananaVForConditionalGeneration": ("kanana_v", "KananaVForConditionalGeneration"),
    "KeyeForConditionalGeneration": ("keye", "KeyeForConditionalGeneration"),
    "KeyeVL1_5ForConditionalGeneration": (
        "keye_vl1_5",
        "KeyeVL1_5ForConditionalGeneration",
    ),
    "RForConditionalGeneration": ("rvl", "RForConditionalGeneration"),
    "KimiVLForConditionalGeneration": ("kimi_vl", "KimiVLForConditionalGeneration"),  # noqa: E501
    "KimiK25ForConditionalGeneration": ("kimi_k25", "KimiK25ForConditionalGeneration"),  # noqa: E501
    "LightOnOCRForConditionalGeneration": (
        "lightonocr",
        "LightOnOCRForConditionalGeneration",
    ),
    "Lfm2VlForConditionalGeneration": ("lfm2_vl", "Lfm2VLForConditionalGeneration"),
    "Llama_Nemotron_Nano_VL": ("nemotron_vl", "LlamaNemotronVLChatModel"),
    "Llama4ForConditionalGeneration": ("mllama4", "Llama4ForConditionalGeneration"),  # noqa: E501
    "LlavaForConditionalGeneration": ("llava", "LlavaForConditionalGeneration"),
    "LlavaNextForConditionalGeneration": (
        "llava_next",
        "LlavaNextForConditionalGeneration",
    ),
    "LlavaNextVideoForConditionalGeneration": (
        "llava_next_video",
        "LlavaNextVideoForConditionalGeneration",
    ),
    "LlavaOnevisionForConditionalGeneration": (
        "llava_onevision",
        "LlavaOnevisionForConditionalGeneration",
    ),
    "MantisForConditionalGeneration": ("llava", "MantisForConditionalGeneration"),  # noqa: E501
    "MiDashengLMModel": ("midashenglm", "MiDashengLMModel"),
    "MiniMaxVL01ForConditionalGeneration": (
        "minimax_vl_01",
        "MiniMaxVL01ForConditionalGeneration",
    ),
    "MiniCPMO": ("minicpmo", "MiniCPMO"),
    "MiniCPMV": ("minicpmv", "MiniCPMV"),
    "Mistral3ForConditionalGeneration": (
        "mistral3",
        "Mistral3ForConditionalGeneration",
    ),
    "MolmoForCausalLM": ("molmo", "MolmoForCausalLM"),
    "Molmo2ForConditionalGeneration": ("molmo2", "Molmo2ForConditionalGeneration"),
    "NVLM_D": ("nvlm_d", "NVLM_D_Model"),
    "OpenPanguVLForConditionalGeneration": (
        "openpangu_vl",
        "OpenPanguVLForConditionalGeneration",
    ),
    "Ovis": ("ovis", "Ovis"),
    "Ovis2_5": ("ovis2_5", "Ovis2_5"),
    "Ovis2_6ForCausalLM": ("ovis2_5", "Ovis2_5"),
    "Ovis2_6_MoeForCausalLM": ("ovis2_5", "Ovis2_5"),
    "PaddleOCRVLForConditionalGeneration": (
        "paddleocr_vl",
        "PaddleOCRVLForConditionalGeneration",
    ),
    "PaliGemmaForConditionalGeneration": (
        "paligemma",
        "PaliGemmaForConditionalGeneration",
    ),
    "Phi3VForCausalLM": ("phi3v", "Phi3VForCausalLM"),
    "Phi4MMForCausalLM": ("phi4mm", "Phi4MMForCausalLM"),
    "PixtralForConditionalGeneration": ("pixtral", "PixtralForConditionalGeneration"),  # noqa: E501
    "QwenVLForConditionalGeneration": ("qwen_vl", "QwenVLForConditionalGeneration"),  # noqa: E501
    "Qwen2VLForConditionalGeneration": ("qwen2_vl", "Qwen2VLForConditionalGeneration"),  # noqa: E501
    "Qwen2_5_VLForConditionalGeneration": (
        "qwen2_5_vl",
        "Qwen2_5_VLForConditionalGeneration",
    ),
    "Qwen2AudioForConditionalGeneration": (
        "qwen2_audio",
        "Qwen2AudioForConditionalGeneration",
    ),
    "Qwen2_5OmniModel": (
        "qwen2_5_omni_thinker",
        "Qwen2_5OmniThinkerForConditionalGeneration",
    ),
    "Qwen2_5OmniForConditionalGeneration": (
        "qwen2_5_omni_thinker",
        "Qwen2_5OmniThinkerForConditionalGeneration",
    ),
    "Qwen3OmniMoeForConditionalGeneration": (
        "qwen3_omni_moe_thinker",
        "Qwen3OmniMoeThinkerForConditionalGeneration",
    ),
    "Qwen3ASRForConditionalGeneration": (
        "qwen3_asr",
        "Qwen3ASRForConditionalGeneration",
    ),
    "Qwen3ASRRealtimeGeneration": (
        "qwen3_asr_realtime",
        "Qwen3ASRRealtimeGeneration",
    ),
    "Qwen3VLForConditionalGeneration": ("qwen3_vl", "Qwen3VLForConditionalGeneration"),  # noqa: E501
    "Qwen3VLMoeForConditionalGeneration": (
        "qwen3_vl_moe",
        "Qwen3VLMoeForConditionalGeneration",
    ),
    "Qwen3_5ForConditionalGeneration": (
        "qwen3_5",
        "Qwen3_5ForConditionalGeneration",
    ),
    "Qwen3_5MoeForConditionalGeneration": (
        "qwen3_5",
        "Qwen3_5MoeForConditionalGeneration",
    ),
    "SkyworkR1VChatModel": ("skyworkr1v", "SkyworkR1VChatModel"),
    "Step3VLForConditionalGeneration": ("step3_vl", "Step3VLForConditionalGeneration"),  # noqa: E501
    "TarsierForConditionalGeneration": ("tarsier", "TarsierForConditionalGeneration"),  # noqa: E501
    "Tarsier2ForConditionalGeneration": (
        "qwen2_vl",
        "Tarsier2ForConditionalGeneration",
    ),
    "UltravoxModel": ("ultravox", "UltravoxModel"),
    "VoxtralForConditionalGeneration": ("voxtral", "VoxtralForConditionalGeneration"),  # noqa: E501
    "VoxtralRealtimeGeneration": ("voxtral_realtime", "VoxtralRealtimeGeneration"),  # noqa: E501
    # [Encoder-decoder]
    "NemotronParseForConditionalGeneration": (
        "nemotron_parse",
        "NemotronParseForConditionalGeneration",
    ),
    "WhisperForConditionalGeneration": ("whisper", "WhisperForConditionalGeneration"),  # noqa: E501
}
512
513

# Draft/auxiliary architectures used for speculative decoding (EAGLE, MTP,
# Medusa, ...), in the same `name -> (module, class)` shape.
_SPECULATIVE_DECODING_MODELS: dict[str, tuple[str, str]] = {
    "MiMoMTPModel": ("mimo_mtp", "MiMoMTP"),
    "EagleLlamaForCausalLM": ("llama_eagle", "EagleLlamaForCausalLM"),
    "EagleLlama4ForCausalLM": ("llama4_eagle", "EagleLlama4ForCausalLM"),
    "EagleMiniCPMForCausalLM": ("minicpm_eagle", "EagleMiniCPMForCausalLM"),
    "Eagle3LlamaForCausalLM": ("llama_eagle3", "Eagle3LlamaForCausalLM"),
    "LlamaForCausalLMEagle3": ("llama_eagle3", "Eagle3LlamaForCausalLM"),
    "Eagle3Qwen2_5vlForCausalLM": ("llama_eagle3", "Eagle3LlamaForCausalLM"),
    "Eagle3Qwen3vlForCausalLM": ("llama_eagle3", "Eagle3LlamaForCausalLM"),
    "EagleMistralLarge3ForCausalLM": (
        "mistral_large_3_eagle",
        "EagleMistralLarge3ForCausalLM",
    ),
    "EagleDeepSeekMTPModel": ("deepseek_eagle", "EagleDeepseekV3ForCausalLM"),
    "DeepSeekMTPModel": ("deepseek_mtp", "DeepSeekMTP"),
    "ErnieMTPModel": ("ernie_mtp", "ErnieMTP"),
    "ExaoneMoeMTP": ("exaone_moe_mtp", "ExaoneMoeMTP"),
    "NemotronHMTPModel": ("nemotron_h_mtp", "NemotronHMTP"),
    "LongCatFlashMTPModel": ("longcat_flash_mtp", "LongCatFlashMTP"),
    "Glm4MoeMTPModel": ("glm4_moe_mtp", "Glm4MoeMTP"),
    "Glm4MoeLiteMTPModel": ("glm4_moe_lite_mtp", "Glm4MoeLiteMTP"),
    "GlmOcrMTPModel": ("glm_ocr_mtp", "GlmOcrMTP"),
    "MedusaModel": ("medusa", "Medusa"),
    "OpenPanguMTPModel": ("openpangu_mtp", "OpenPanguMTP"),
    "Qwen3NextMTP": ("qwen3_next_mtp", "Qwen3NextMTP"),
    "Step3p5MTP": ("step3p5_mtp", "Step3p5MTP"),
    "Qwen3_5MTP": ("qwen3_5_mtp", "Qwen3_5MTP"),
    "Qwen3_5MoeMTP": ("qwen3_5_mtp", "Qwen3_5MoeMTP"),
    # Temporarily disabled.
    # TODO(woosuk): Re-enable this once the MLP Speculator is supported in V1.
    # "MLPSpeculatorPreTrainedModel": ("mlp_speculator", "MLPSpeculator"),
}
545

546
# Architectures explicitly routed to the Transformers modeling backend
# classes (see `_TRANSFORMERS_BACKEND_MODELS`).
_TRANSFORMERS_SUPPORTED_MODELS: dict[str, tuple[str, str]] = {
    # Text generation models
    "SmolLM3ForCausalLM": ("transformers", "TransformersForCausalLM"),
    # Multimodal models
    "Emu3ForConditionalGeneration": (
        "transformers",
        "TransformersMultiModalForCausalLM",
    ),
}

# The Transformers-backend implementation classes themselves; every entry
# maps a class name to the same-named class in the `transformers` module.
_TRANSFORMERS_BACKEND_MODELS: dict[str, tuple[str, str]] = {
    # Text generation models
    "TransformersForCausalLM": ("transformers", "TransformersForCausalLM"),
    "TransformersMoEForCausalLM": ("transformers", "TransformersMoEForCausalLM"),
    # Multimodal models
    "TransformersMultiModalForCausalLM": (
        "transformers",
        "TransformersMultiModalForCausalLM",
    ),
    "TransformersMultiModalMoEForCausalLM": (
        "transformers",
        "TransformersMultiModalMoEForCausalLM",
    ),
    # Embedding models
    "TransformersEmbeddingModel": ("transformers", "TransformersEmbeddingModel"),
    "TransformersMoEEmbeddingModel": ("transformers", "TransformersMoEEmbeddingModel"),
    "TransformersMultiModalEmbeddingModel": (
        "transformers",
        "TransformersMultiModalEmbeddingModel",
    ),
    # Sequence classification models
    "TransformersForSequenceClassification": (
        "transformers",
        "TransformersForSequenceClassification",
    ),
    "TransformersMoEForSequenceClassification": (
        "transformers",
        "TransformersMoEForSequenceClassification",
    ),
    "TransformersMultiModalForSequenceClassification": (
        "transformers",
        "TransformersMultiModalForSequenceClassification",
    ),
}
590

591
# Union of all the registry tables. Dict unpacking keeps the last value seen
# for a duplicate key, so a later table in this list overrides an earlier one.
_VLLM_MODELS: dict[str, tuple[str, str]] = {
    **_TEXT_GENERATION_MODELS,
    **_EMBEDDING_MODELS,
    **_CROSS_ENCODER_MODELS,
    **_MULTIMODAL_MODELS,
    **_SPECULATIVE_DECODING_MODELS,
    **_TRANSFORMERS_SUPPORTED_MODELS,
    **_TRANSFORMERS_BACKEND_MODELS,
}

601
602
603
604
# Argument list passed to `subprocess.run()`. It can be modified to alter the
# command if needed: e.g. when things are packed together in the par format,
# `sys.executable` might not be the target we want to run.
_SUBPROCESS_COMMAND: list[str] = [
    sys.executable,
    "-m",
    "vllm.model_executor.models.registry",
]
606

607
# Architectures that no longer have a vLLM implementation, mapped to a
# version string (presumably the last vLLM release that supported them;
# confirm against where this table is consumed).
_PREVIOUSLY_SUPPORTED_MODELS: dict[str, str] = {
    "MotifForCausalLM": "0.10.2",
    "Phi3SmallForCausalLM": "0.9.2",
    "Phi4FlashForCausalLM": "0.10.2",
    "Phi4MultimodalForCausalLM": "0.12.0",
    # Encoder-decoder models other than Whisper were removed as part of the
    # V0 deprecation.
    "BartModel": "0.10.2",
    "BartForConditionalGeneration": "0.10.2",
    "DonutForConditionalGeneration": "0.10.2",
    "Florence2ForConditionalGeneration": "0.10.2",
    "MBartForConditionalGeneration": "0.10.2",
    "MllamaForConditionalGeneration": "0.10.2",
}
621

622

623
624
@dataclass(frozen=True)
class _ModelInfo:
    """Immutable snapshot of what a model class supports.

    Built by ``from_model_cls``. ``_LazyRegisteredModel`` persists instances
    as JSON via ``dataclasses.asdict`` and rebuilds them with
    ``_ModelInfo(**data)``, so every field must be JSON-serializable. A cached
    entry whose keys no longer match these fields fails to load and is
    ignored.
    """

    # Identity and task support
    architecture: str
    is_text_generation_model: bool
    is_pooling_model: bool
    attn_type: AttnTypeStr
    default_seq_pooling_type: SequencePoolingType
    default_tok_pooling_type: TokenPoolingType
    supports_cross_encoding: bool
    supports_late_interaction: bool

    # Multimodal capabilities
    supports_multimodal: bool
    supports_multimodal_raw_input_only: bool
    requires_raw_input_tokens: bool
    supports_multimodal_encoder_tp_data: bool

    # Execution / architecture traits
    supports_pp: bool
    has_inner_state: bool
    is_attention_free: bool
    is_hybrid: bool
    has_noops: bool
    supports_mamba_prefix_caching: bool

    # Speech
    supports_transcription: bool
    supports_transcription_only: bool

    @staticmethod
    def from_model_cls(model: type[nn.Module]) -> "_ModelInfo":
        """Evaluate every interface predicate on ``model`` and bundle the results."""
        return _ModelInfo(
            architecture=model.__name__,
            # Task support
            is_text_generation_model=is_text_generation_model(model),
            is_pooling_model=is_pooling_model(model),
            default_seq_pooling_type=get_default_seq_pooling_type(model),
            default_tok_pooling_type=get_default_tok_pooling_type(model),
            attn_type=get_attn_type(model),
            supports_cross_encoding=supports_cross_encoding(model),
            supports_late_interaction=supports_late_interaction(model),
            # Multimodal capabilities
            supports_multimodal=supports_multimodal(model),
            supports_multimodal_raw_input_only=(
                supports_multimodal_raw_input_only(model)
            ),
            requires_raw_input_tokens=requires_raw_input_tokens(model),
            supports_multimodal_encoder_tp_data=(
                supports_multimodal_encoder_tp_data(model)
            ),
            # Execution / architecture traits
            supports_pp=supports_pp(model),
            has_inner_state=has_inner_state(model),
            is_attention_free=is_attention_free(model),
            is_hybrid=is_hybrid(model),
            supports_mamba_prefix_caching=supports_mamba_prefix_caching(model),
            # Speech: the "only" flag is read from the class solely when the
            # model implements transcription at all.
            supports_transcription=supports_transcription(model),
            supports_transcription_only=(
                supports_transcription(model) and model.supports_transcription_only
            ),
            has_noops=has_noops(model),
        )
676
677


678
679
680
681
class _BaseRegisteredModel(ABC):
    """Interface shared by all entries of ``_ModelRegistry.models``."""

    @abstractmethod
    def inspect_model_cls(self) -> _ModelInfo:
        """Describe the registered model class as a ``_ModelInfo``."""
        raise NotImplementedError

    @abstractmethod
    def load_model_cls(self) -> type[nn.Module]:
        """Return the registered model class object."""
        raise NotImplementedError
686
687


688
689
690
691
692
693
694
@dataclass(frozen=True)
class _RegisteredModel(_BaseRegisteredModel):
    """
    A model whose class object has already been imported in the main process.

    Its ``_ModelInfo`` is computed eagerly in ``from_model_cls``, so
    inspecting it never needs a subprocess.
    """

    interfaces: _ModelInfo
    model_cls: type[nn.Module]

    @staticmethod
    def from_model_cls(model_cls: type[nn.Module]):
        info = _ModelInfo.from_model_cls(model_cls)
        return _RegisteredModel(interfaces=info, model_cls=model_cls)

    def inspect_model_cls(self) -> _ModelInfo:
        # Precomputed when the entry was created.
        return self.interfaces

    def load_model_cls(self) -> type[nn.Module]:
        return self.model_cls


@dataclass(frozen=True)
class _LazyRegisteredModel(_BaseRegisteredModel):
    """
    Represents a model that has not been imported in the main process.

    ``load_model_cls`` imports ``module_name`` on demand. ``inspect_model_cls``
    computes the ``_ModelInfo`` in a subprocess (to avoid initializing CUDA in
    the main process) and, for in-tree model files, caches the result on disk
    keyed by a hash of that file.
    """

    module_name: str
    class_name: str

    @staticmethod
    def _get_cache_dir() -> Path:
        """Directory that holds the ``_ModelInfo`` JSON cache files."""
        return Path(envs.VLLM_CACHE_ROOT) / "modelinfos"

    def _get_cache_filename(self) -> str:
        """Cache file name for this class (dots of the qualified name -> dashes)."""
        cls_name = f"{self.module_name}-{self.class_name}".replace(".", "-")
        return f"{cls_name}.json"

    def _get_module_path(self) -> Path | None:
        """Return the source file to hash for an in-tree model, else ``None``.

        Only ``vllm.model_executor.models.<name>`` is mapped to ``<name>.py``
        next to this file. Matching on the last dotted component alone would
        hash vLLM's own ``llama.py`` for an out-of-tree ``my_plugin.llama``
        module, so edits to the plugin would never invalidate its cached
        ``_ModelInfo``. Modules outside this package, or without a matching
        ``.py`` file, return ``None`` and are inspected without the disk cache.
        """
        package, _, leaf = self.module_name.rpartition(".")
        if package != "vllm.model_executor.models":
            return None

        model_path = Path(__file__).parent / f"{leaf}.py"
        return model_path if model_path.exists() else None

    def _load_modelinfo_from_cache(self, module_hash: str) -> _ModelInfo | None:
        """Return the cached ``_ModelInfo`` if present and built from ``module_hash``.

        Best effort: a missing file, a stale hash, or any read/decode/schema
        error yields ``None`` so the caller re-inspects the model.
        """
        try:
            try:
                modelinfo_path = self._get_cache_dir() / self._get_cache_filename()
                with open(modelinfo_path, encoding="utf-8") as file:
                    mi_dict = json.load(file)
            except FileNotFoundError:
                logger.debug(
                    "Cached model info file for class %s.%s not found",
                    self.module_name,
                    self.class_name,
                )
                return None

            if mi_dict["hash"] != module_hash:
                logger.debug(
                    "Cached model info file for class %s.%s is stale",
                    self.module_name,
                    self.class_name,
                )
                return None

            # file not changed, use cached _ModelInfo properties
            return _ModelInfo(**mi_dict["modelinfo"])
        except Exception:
            # Corrupt JSON, missing keys, or a _ModelInfo schema change: fall
            # back to re-inspection, but keep the cause visible in debug logs.
            logger.debug(
                "Cached model info for class %s.%s error. ",
                self.module_name,
                self.class_name,
                exc_info=True,
            )
            return None

    def _save_modelinfo_to_cache(self, mi: _ModelInfo, module_hash: str) -> None:
        """save dictionary json file to cache"""
        from vllm.model_executor.model_loader.weight_utils import atomic_writer

        try:
            modelinfo_dict = {
                "hash": module_hash,
                "modelinfo": asdict(mi),
            }
            cache_dir = self._get_cache_dir()
            cache_dir.mkdir(parents=True, exist_ok=True)
            modelinfo_path = cache_dir / self._get_cache_filename()
            with atomic_writer(modelinfo_path, encoding="utf-8") as f:
                json.dump(modelinfo_dict, f, indent=2)
        except Exception:
            # Caching is an optimization; failing to write must not fail
            # model inspection.
            logger.exception("Error saving model info cache.")

    @logtime(logger=logger, msg="Registry inspect model class")
    def inspect_model_cls(self) -> _ModelInfo:
        """Return this class's ``_ModelInfo``, preferring a valid disk cache entry.

        NOTE: only the model's own source file is hashed, so changes in other
        files it imports (base classes, mixins) do not invalidate the cache.
        """
        model_path = self._get_module_path()
        module_hash = None

        if model_path is not None:
            with open(model_path, "rb") as f:
                module_hash = safe_hash(f.read(), usedforsecurity=False).hexdigest()

            mi = self._load_modelinfo_from_cache(module_hash)
            if mi is not None:
                logger.debug(
                    "Loaded model info for class %s.%s from cache",
                    self.module_name,
                    self.class_name,
                )
                return mi
            else:
                logger.debug(
                    "Cache model info for class %s.%s miss. Loading model instead.",
                    self.module_name,
                    self.class_name,
                )

        # Performed in another process to avoid initializing CUDA
        mi = _run_in_subprocess(
            lambda: _ModelInfo.from_model_cls(self.load_model_cls())
        )
        logger.debug(
            "Loaded model info for class %s.%s", self.module_name, self.class_name
        )

        # save cache file
        if module_hash is not None:
            self._save_modelinfo_to_cache(mi, module_hash)

        return mi

    def load_model_cls(self) -> type[nn.Module]:
        """Import ``module_name`` and return its ``class_name`` attribute."""
        mod = importlib.import_module(self.module_name)
        return getattr(mod, self.class_name)


@lru_cache(maxsize=128)
def _try_load_model_cls(
    model_arch: str,
    model: _BaseRegisteredModel,
) -> type[nn.Module] | None:
    """Load the model class behind ``model``; return ``None`` if that fails.

    Anything raised by the platform's ``verify_model_arch`` propagates to the
    caller. Errors from ``load_model_cls`` are logged and reported as
    ``None``. Results (``None`` included) are memoized per
    ``(model_arch, model)``.
    """
    from vllm.platforms import current_platform

    current_platform.verify_model_arch(model_arch)

    loaded = None
    try:
        loaded = model.load_model_cls()
    except Exception:
        logger.exception("Error in loading model architecture '%s'", model_arch)
    return loaded
833
834


835
836
837
838
@lru_cache(maxsize=128)
def _try_inspect_model_cls(
    model_arch: str,
    model: _BaseRegisteredModel,
) -> _ModelInfo | None:
    """Inspect ``model``; return ``None`` (after logging) if inspection fails.

    Results (``None`` included) are memoized per ``(model_arch, model)``.
    """
    info = None
    try:
        info = model.inspect_model_cls()
    except Exception:
        logger.exception("Error in inspecting model architecture '%s'", model_arch)
    return info
845
846


847
848
849
@dataclass
class _ModelRegistry:
    """Registry mapping architecture names to (possibly lazily imported) models.

    ``inspect_model_cls`` resolves an architecture to its ``_ModelInfo``
    capability flags; ``resolve_model_cls`` resolves it to the model class.
    Both follow the same lookup order (see ``_resolve_arch``).
    """

    # Keyed by model_arch
    models: dict[str, _BaseRegisteredModel] = field(default_factory=dict)

    def get_supported_archs(self) -> Set[str]:
        """Return a live view of the registered architecture names."""
        return self.models.keys()

    def register_model(
        self,
        model_arch: str,
        model_cls: type[nn.Module] | str,
    ) -> None:
        """
        Register an external model to be used in vLLM.

        `model_cls` can be either:

        - A [`torch.nn.Module`][] class directly referencing the model.
        - A string in the format `<module>:<class>` which can be used to
          lazily import the model. This is useful to avoid initializing CUDA
          when importing the model and thus the related error
          `RuntimeError: Cannot re-initialize CUDA in forked subprocess`.

        Registering an architecture that already exists replaces the previous
        entry and logs a warning.

        Raises:
            TypeError: If `model_arch` is not a string, or `model_cls` is
                neither a string nor an `nn.Module` subclass.
            ValueError: If a string `model_cls` is not `<module>:<class>`.
        """
        if not isinstance(model_arch, str):
            msg = f"`model_arch` should be a string, not a {type(model_arch)}"
            raise TypeError(msg)

        if isinstance(model_cls, str):
            split_str = model_cls.split(":")
            if len(split_str) != 2:
                msg = "Expected a string in the format `<module>:<class>`"
                raise ValueError(msg)

            model = _LazyRegisteredModel(*split_str)
        elif isinstance(model_cls, type) and issubclass(model_cls, nn.Module):
            model = _RegisteredModel.from_model_cls(model_cls)
        else:
            msg = (
                "`model_cls` should be a string or PyTorch model class, "
                f"not a {type(model_cls)}"
            )
            raise TypeError(msg)

        # Warn only once the new entry is known to be valid, so a rejected
        # registration never claims to overwrite an existing one.
        if model_arch in self.models:
            logger.warning(
                "Model architecture %s is already registered, and will be "
                "overwritten by the new model class %s.",
                model_arch,
                model_cls,
            )

        self.models[model_arch] = model

    def _raise_for_unsupported(self, architectures: list[str]):
        """Raise a ``ValueError`` explaining why ``architectures`` cannot be used.

        Always raises. Distinguishes registered-but-broken architectures,
        architectures removed in an earlier release, and unknown ones.
        """
        all_supported_archs = self.get_supported_archs()

        if any(arch in all_supported_archs for arch in architectures):
            raise ValueError(
                f"Model architectures {architectures} failed "
                "to be inspected. Please check the logs for more details."
            )

        for arch in architectures:
            if arch in _PREVIOUSLY_SUPPORTED_MODELS:
                previous_version = _PREVIOUSLY_SUPPORTED_MODELS[arch]

                raise ValueError(
                    f"Model architecture {arch} was supported in vLLM until "
                    f"v{previous_version}, and is not supported anymore. "
                    "Please use an older version of vLLM if you want to "
                    "use this model architecture."
                )

        raise ValueError(
            f"Model architectures {architectures} are not supported for now. "
            f"Supported architectures: {all_supported_archs}"
        )

    def _try_load_model_cls(self, model_arch: str) -> type[nn.Module] | None:
        """Load the class registered as ``model_arch``; ``None`` if unknown/failing."""
        if model_arch not in self.models:
            return None

        return _try_load_model_cls(model_arch, self.models[model_arch])

    def _try_inspect_model_cls(self, model_arch: str) -> _ModelInfo | None:
        """Inspect the class registered as ``model_arch``; ``None`` if unknown/failing."""
        if model_arch not in self.models:
            return None

        return _try_inspect_model_cls(model_arch, self.models[model_arch])

    def _try_resolve_transformers(
        self,
        architecture: str,
        model_config: ModelConfig,
    ) -> str | None:
        """Map ``architecture`` to a Transformers-backend registry key.

        Returns ``None`` when the Transformers backend cannot serve the model
        and ``model_config.model_impl`` is not ``"transformers"``; in those
        same situations a ``ValueError`` is raised if that backend was
        explicitly requested.
        """
        if architecture in _TRANSFORMERS_BACKEND_MODELS:
            return architecture

        auto_map: dict[str, str] = (
            getattr(model_config.hf_config, "auto_map", None) or dict()
        )

        # Make sure that config class is always initialized before model class,
        # otherwise the model class won't be able to access the config class,
        # the expected auto_map should have correct order like:
        # "auto_map": {
        #     "AutoConfig": "<your-repo-name>--<config-name>",
        #     "AutoModel": "<your-repo-name>--<config-name>",
        #     "AutoModelFor<Task>": "<your-repo-name>--<config-name>",
        # },
        for prefix in ("AutoConfig", "AutoModel"):
            for name, module in auto_map.items():
                if name.startswith(prefix):
                    try_get_class_from_dynamic_module(
                        module,
                        model_config.model,
                        revision=model_config.revision,
                        trust_remote_code=model_config.trust_remote_code,
                        warn_on_fail=False,
                    )

        model_module = getattr(transformers, architecture, None)

        if model_module is None:
            for name, module in auto_map.items():
                if name.startswith("AutoModel"):
                    model_module = try_get_class_from_dynamic_module(
                        module,
                        model_config.model,
                        revision=model_config.revision,
                        trust_remote_code=model_config.trust_remote_code,
                        warn_on_fail=True,
                    )
                    if model_module is not None:
                        break
            else:
                # No AutoModel entry produced a class.
                if model_config.model_impl != "transformers":
                    return None

                raise ValueError(
                    f"Cannot find model module. {architecture!r} is not a "
                    "registered model in the Transformers library (only "
                    "relevant if the model is meant to be in Transformers) "
                    "and 'AutoModel' is not present in the model config's "
                    "'auto_map' (relevant if the model is custom)."
                )

        if not model_module.is_backend_compatible():
            if model_config.model_impl != "transformers":
                return None

            raise ValueError(
                f"The Transformers implementation of {architecture!r} "
                "is not compatible with vLLM."
            )

        return model_config._get_transformers_backend_cls()

    def _normalize_arch(
        self,
        architecture: str,
        model_config: ModelConfig,
    ) -> str:
        """Return the registry key to use for ``architecture``.

        Registered names are returned unchanged. Otherwise, if the name ends
        in a known task suffix, try swapping it for other known suffixes to
        find a registered base model to convert; fall back to the input.
        """
        if architecture in self.models:
            return architecture

        # This may be called in order to resolve runner_type and convert_type
        # in the first place, in which case we consider the default match
        match = try_match_architecture_defaults(
            architecture,
            runner_type=getattr(model_config, "runner_type", None),
            convert_type=getattr(model_config, "convert_type", None),
        )
        if match:
            suffix, _ = match

            # Get the name of the base model to convert
            for repl_suffix, _ in iter_architecture_defaults():
                base_arch = architecture.replace(suffix, repl_suffix)
                if base_arch in self.models:
                    return base_arch

        return architecture

    def _resolve_arch(
        self,
        architectures: str | list[str],
        model_config: ModelConfig,
        try_arch: Callable[[str], Any],
    ) -> tuple[Any, str]:
        """Shared lookup order of ``inspect_model_cls`` and ``resolve_model_cls``.

        ``try_arch`` maps a registry key to a result (``_ModelInfo`` or model
        class) or ``None`` on failure. Candidates are tried in this order:

        1. the implementation explicitly requested by ``model_impl``;
        2. the Transformers backend, if no architecture is registered, the
           impl is ``"auto"`` and no conversion is requested;
        3. each architecture (normalized via ``_normalize_arch``);
        4. the Transformers backend, if no architecture is registered and the
           impl is ``"auto"``.

        Returns ``(result, architecture)`` for the first success; otherwise
        raises ``ValueError`` via ``_raise_for_unsupported``.
        """
        if isinstance(architectures, str):
            architectures = [architectures]
        if not architectures:
            raise ValueError("No model architectures are specified")

        def _try_transformers():
            # Resolve the first architecture to a Transformers-backend key.
            arch = self._try_resolve_transformers(architectures[0], model_config)
            if arch is None:
                return None
            result = try_arch(arch)
            return None if result is None else (result, arch)

        # Require transformers impl
        if model_config.model_impl == "transformers":
            found = _try_transformers()
            if found is not None:
                return found
        elif model_config.model_impl == "terratorch":
            # On failure fall through (and eventually raise) rather than
            # returning a ``None`` result to the caller.
            result = try_arch("Terratorch")
            if result is not None:
                return (result, "Terratorch")

        # Fallback to transformers impl (after resolving convert_type)
        if (
            all(arch not in self.models for arch in architectures)
            and model_config.model_impl == "auto"
            and getattr(model_config, "convert_type", "none") == "none"
        ):
            found = _try_transformers()
            if found is not None:
                return found

        for arch in architectures:
            normalized_arch = self._normalize_arch(arch, model_config)
            result = try_arch(normalized_arch)
            if result is not None:
                return (result, arch)

        # Fallback to transformers impl (before resolving runner_type)
        if (
            all(arch not in self.models for arch in architectures)
            and model_config.model_impl == "auto"
        ):
            found = _try_transformers()
            if found is not None:
                return found

        return self._raise_for_unsupported(architectures)

    def inspect_model_cls(
        self,
        architectures: str | list[str],
        model_config: ModelConfig,
    ) -> tuple[_ModelInfo, str]:
        """Return ``(model_info, architecture)`` for the first usable architecture.

        Raises:
            ValueError: If no architectures are given or none can be inspected.
        """
        return self._resolve_arch(
            architectures, model_config, self._try_inspect_model_cls
        )

    def resolve_model_cls(
        self,
        architectures: str | list[str],
        model_config: ModelConfig,
    ) -> tuple[type[nn.Module], str]:
        """Return ``(model_class, architecture)`` for the first loadable architecture.

        Raises:
            ValueError: If no architectures are given or none can be loaded.
        """
        return self._resolve_arch(
            architectures, model_config, self._try_load_model_cls
        )

    def is_text_generation_model(
        self,
        architectures: str | list[str],
        model_config: ModelConfig,
    ) -> bool:
        model_info, _ = self.inspect_model_cls(architectures, model_config)
        return model_info.is_text_generation_model

    def is_pooling_model(
        self,
        architectures: str | list[str],
        model_config: ModelConfig,
    ) -> bool:
        model_info, _ = self.inspect_model_cls(architectures, model_config)
        return model_info.is_pooling_model

    def is_cross_encoder_model(
        self,
        architectures: str | list[str],
        model_config: ModelConfig,
    ) -> bool:
        model_info, _ = self.inspect_model_cls(architectures, model_config)
        return model_info.supports_cross_encoding

    def is_multimodal_model(
        self,
        architectures: str | list[str],
        model_config: ModelConfig,
    ) -> bool:
        model_info, _ = self.inspect_model_cls(architectures, model_config)
        return model_info.supports_multimodal

    def is_multimodal_raw_input_only_model(
        self,
        architectures: str | list[str],
        model_config: ModelConfig,
    ) -> bool:
        model_info, _ = self.inspect_model_cls(architectures, model_config)
        return model_info.supports_multimodal_raw_input_only

    def is_pp_supported_model(
        self,
        architectures: str | list[str],
        model_config: ModelConfig,
    ) -> bool:
        model_info, _ = self.inspect_model_cls(architectures, model_config)
        return model_info.supports_pp

    def model_has_inner_state(
        self,
        architectures: str | list[str],
        model_config: ModelConfig,
    ) -> bool:
        model_info, _ = self.inspect_model_cls(architectures, model_config)
        return model_info.has_inner_state

    def is_attention_free_model(
        self,
        architectures: str | list[str],
        model_config: ModelConfig,
    ) -> bool:
        model_info, _ = self.inspect_model_cls(architectures, model_config)
        return model_info.is_attention_free

    def is_hybrid_model(
        self,
        architectures: str | list[str],
        model_config: ModelConfig,
    ) -> bool:
        model_info, _ = self.inspect_model_cls(architectures, model_config)
        return model_info.is_hybrid

    def is_noops_model(
        self,
        architectures: str | list[str],
        model_config: ModelConfig,
    ) -> bool:
        model_info, _ = self.inspect_model_cls(architectures, model_config)
        return model_info.has_noops

    def is_transcription_model(
        self,
        architectures: str | list[str],
        model_config: ModelConfig,
    ) -> bool:
        model_info, _ = self.inspect_model_cls(architectures, model_config)
        return model_info.supports_transcription

    def is_transcription_only_model(
        self,
        architectures: str | list[str],
        model_config: ModelConfig,
    ) -> bool:
        model_info, _ = self.inspect_model_cls(architectures, model_config)
        return model_info.supports_transcription_only

1234

1235
1236
1237
1238
1239
1240
1241
1242
1243
# Global registry pre-populated with every built-in architecture. All built-in
# entries are lazy: each model module lives in this package and is only
# imported when actually needed.
ModelRegistry = _ModelRegistry(
    models={
        arch_name: _LazyRegisteredModel(
            module_name="vllm.model_executor.models." + relname,
            class_name=clsname,
        )
        for arch_name, (relname, clsname) in _VLLM_MODELS.items()
    }
)
1244
1245
1246
1247
1248

_T = TypeVar("_T")


def _run_in_subprocess(fn: Callable[[], _T]) -> _T:
    """Call ``fn`` in a fresh Python process and return its result.

    ``fn`` is serialized with ``cloudpickle`` (so lambdas and closures work)
    and fed on stdin to ``_SUBPROCESS_COMMAND``, which runs this module's
    ``_run``. The child pickles the result into a temporary file that is
    read back here. This keeps side effects of importing model code (such as
    CUDA initialization) out of the calling process.

    Raises:
        RuntimeError: If the subprocess exits with a non-zero status. The
            message contains the child's stderr and the underlying
            ``subprocess.CalledProcessError`` is chained as the cause.
    """
    # NOTE: We use a temporary directory instead of a temporary file to avoid
    # issues like https://stackoverflow.com/questions/23212435/permission-denied-to-write-to-my-temporary-file
    with tempfile.TemporaryDirectory() as tempdir:
        output_filepath = os.path.join(tempdir, "registry_output.tmp")

        # `cloudpickle` allows pickling lambda functions directly
        import cloudpickle

        input_bytes = cloudpickle.dumps((fn, output_filepath))

        # cannot use `sys.executable __file__` here because the script
        # contains relative imports
        returned = subprocess.run(
            _SUBPROCESS_COMMAND, input=input_bytes, capture_output=True
        )

        # check if the subprocess is successful
        try:
            returned.check_returncode()
        except subprocess.CalledProcessError as e:
            # Decode leniently: stderr containing non-UTF-8 bytes must not
            # replace the real failure with a UnicodeDecodeError.
            stderr = returned.stderr.decode(errors="replace")
            raise RuntimeError(f"Error raised in subprocess:\n{stderr}") from e

        # The file was written by the child process spawned above.
        with open(output_filepath, "rb") as f:
            return pickle.load(f)


def _run() -> None:
    """Subprocess entry point paired with ``_run_in_subprocess``.

    Reads a pickled ``(fn, output_file)`` pair from stdin, calls ``fn()`` and
    writes its pickled return value to ``output_file``.
    """
    # Set up plugins before executing the payload.
    from vllm.plugins import load_general_plugins

    load_general_plugins()

    payload = sys.stdin.buffer.read()
    fn, output_file = pickle.loads(payload)

    result = fn()

    with open(output_file, "wb") as out:
        out.write(pickle.dumps(result))
1290
1291
1292


if __name__ == "__main__":
    # Executed as `python -m vllm.model_executor.models.registry`, which is
    # how `_run_in_subprocess` launches the child (see `_SUBPROCESS_COMMAND`).
    _run()