registry.py 50.5 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
4
5
6
"""
Whenever you add an architecture to this page, please also update
`tests/models/registry.py` with example HuggingFace models for it.
"""
7

8
import importlib
9
import json
10
import os
11
import pickle
12
13
import subprocess
import sys
14
import tempfile
15
from abc import ABC, abstractmethod
16
from collections.abc import Callable, Set
17
from dataclasses import asdict, dataclass, field
18
from functools import lru_cache
19
from pathlib import Path
20
from typing import TYPE_CHECKING, Any, TypeVar
21
22

import torch.nn as nn
23
import transformers
24

25
from vllm import envs
26
27
28
29
30
from vllm.config import (
    ModelConfig,
    iter_architecture_defaults,
    try_match_architecture_defaults,
)
31
from vllm.logger import init_logger
32
from vllm.logging_utils import logtime
33
from vllm.transformers_utils.dynamic_module import try_get_class_from_dynamic_module
34
from vllm.utils.hashing import safe_hash
35

36
37
# These names are only needed for annotations. Static type checkers see the
# real types; at runtime they are plain aliases of `Any`, so the vllm.config
# submodules are not imported here.
# NOTE(review): presumably this avoids an import cycle or import cost — confirm.
if TYPE_CHECKING:
    from vllm.config.model import AttnTypeStr
    from vllm.config.pooler import SequencePoolingType, TokenPoolingType
else:
    AttnTypeStr = Any
    SequencePoolingType = Any
    TokenPoolingType = Any
43
44


45
46
47
48
49
from .interfaces import (
    has_inner_state,
    has_noops,
    is_attention_free,
    is_hybrid,
Patrick von Platen's avatar
Patrick von Platen committed
50
    requires_raw_input_tokens,
51
    supports_cross_encoding,
52
    supports_late_interaction,
53
    supports_mamba_prefix_caching,
54
55
56
57
58
59
60
    supports_multimodal,
    supports_multimodal_encoder_tp_data,
    supports_multimodal_raw_input_only,
    supports_pp,
    supports_transcription,
)
from .interfaces_base import (
61
    get_attn_type,
62
63
    get_default_seq_pooling_type,
    get_default_tok_pooling_type,
64
65
66
    is_pooling_model,
    is_text_generation_model,
)
67
68
69

# Module-level logger, obtained from `init_logger` with this module's name.
logger = init_logger(__name__)

70
71
# Text-generation architectures.
# Each key is an architecture name; each value is a `(module, class)` pair of
# strings. Several architecture names may map to the same implementation.
# NOTE(review): the module string is presumably resolved relative to this
# package when the class is loaded — confirm against the loader code.
_TEXT_GENERATION_MODELS = {
    # [Decoder-only]
    "AfmoeForCausalLM": ("afmoe", "AfmoeForCausalLM"),
    "ApertusForCausalLM": ("apertus", "ApertusForCausalLM"),
    "AquilaModel": ("llama", "LlamaForCausalLM"),
    "AquilaForCausalLM": ("llama", "LlamaForCausalLM"),  # AquilaChat2
    "ArceeForCausalLM": ("arcee", "ArceeForCausalLM"),
    "ArcticForCausalLM": ("arctic", "ArcticForCausalLM"),
    "AXK1ForCausalLM": ("AXK1", "AXK1ForCausalLM"),
    # baichuan-7b, upper case 'C' in the class name
    "BaiChuanForCausalLM": ("baichuan", "BaiChuanForCausalLM"),
    # baichuan-13b, lower case 'c' in the class name
    "BaichuanForCausalLM": ("baichuan", "BaichuanForCausalLM"),
    "BailingMoeForCausalLM": ("bailing_moe", "BailingMoeForCausalLM"),
    "BailingMoeV2ForCausalLM": ("bailing_moe", "BailingMoeV2ForCausalLM"),
    "BailingMoeV2_5ForCausalLM": ("bailing_moe_linear", "BailingMoeV25ForCausalLM"),
    "BambaForCausalLM": ("bamba", "BambaForCausalLM"),
    "BloomForCausalLM": ("bloom", "BloomForCausalLM"),
    "ChatGLMModel": ("chatglm", "ChatGLMForCausalLM"),
    "ChatGLMForConditionalGeneration": ("chatglm", "ChatGLMForCausalLM"),
    "CohereForCausalLM": ("commandr", "CohereForCausalLM"),
    "Cohere2ForCausalLM": ("commandr", "CohereForCausalLM"),
    "CwmForCausalLM": ("llama", "LlamaForCausalLM"),
    "DbrxForCausalLM": ("dbrx", "DbrxForCausalLM"),
    "DeciLMForCausalLM": ("nemotron_nas", "DeciLMForCausalLM"),
    "DeepseekForCausalLM": ("deepseek_v2", "DeepseekForCausalLM"),
    "DeepseekV2ForCausalLM": ("deepseek_v2", "DeepseekV2ForCausalLM"),
    "DeepseekV3ForCausalLM": ("deepseek_v2", "DeepseekV3ForCausalLM"),
    "DeepseekV32ForCausalLM": ("deepseek_v2", "DeepseekV3ForCausalLM"),
    "Dots1ForCausalLM": ("dots1", "Dots1ForCausalLM"),
    "Ernie4_5ForCausalLM": ("ernie45", "Ernie4_5ForCausalLM"),
    "Ernie4_5_MoeForCausalLM": ("ernie45_moe", "Ernie4_5_MoeForCausalLM"),
    "ExaoneForCausalLM": ("exaone", "ExaoneForCausalLM"),
    "Exaone4ForCausalLM": ("exaone4", "Exaone4ForCausalLM"),
    "ExaoneMoEForCausalLM": ("exaone_moe", "ExaoneMoeForCausalLM"),
    "Fairseq2LlamaForCausalLM": ("fairseq2_llama", "Fairseq2LlamaForCausalLM"),
    "FalconForCausalLM": ("falcon", "FalconForCausalLM"),
    "FalconMambaForCausalLM": ("mamba", "MambaForCausalLM"),
    "FalconH1ForCausalLM": ("falcon_h1", "FalconH1ForCausalLM"),
    "FlexOlmoForCausalLM": ("flex_olmo", "FlexOlmoForCausalLM"),
    "GemmaForCausalLM": ("gemma", "GemmaForCausalLM"),
    "Gemma2ForCausalLM": ("gemma2", "Gemma2ForCausalLM"),
    "Gemma3ForCausalLM": ("gemma3", "Gemma3ForCausalLM"),
    "Gemma3nForCausalLM": ("gemma3n", "Gemma3nForCausalLM"),
    "Qwen3NextForCausalLM": ("qwen3_next", "Qwen3NextForCausalLM"),
    "GlmForCausalLM": ("glm", "GlmForCausalLM"),
    "Glm4ForCausalLM": ("glm4", "Glm4ForCausalLM"),
    "Glm4MoeForCausalLM": ("glm4_moe", "Glm4MoeForCausalLM"),
    "Glm4MoeLiteForCausalLM": ("glm4_moe_lite", "Glm4MoeLiteForCausalLM"),
    "GlmMoeDsaForCausalLM": ("deepseek_v2", "GlmMoeDsaForCausalLM"),
    "GptOssForCausalLM": ("gpt_oss", "GptOssForCausalLM"),
    "GPT2LMHeadModel": ("gpt2", "GPT2LMHeadModel"),
    "GPTBigCodeForCausalLM": ("gpt_bigcode", "GPTBigCodeForCausalLM"),
    "GPTJForCausalLM": ("gpt_j", "GPTJForCausalLM"),
    "GPTNeoXForCausalLM": ("gpt_neox", "GPTNeoXForCausalLM"),
    "GraniteForCausalLM": ("granite", "GraniteForCausalLM"),
    "GraniteMoeForCausalLM": ("granitemoe", "GraniteMoeForCausalLM"),
    "GraniteMoeHybridForCausalLM": ("granitemoehybrid", "GraniteMoeHybridForCausalLM"),  # noqa: E501
    "GraniteMoeSharedForCausalLM": ("granitemoeshared", "GraniteMoeSharedForCausalLM"),  # noqa: E501
    "GritLM": ("gritlm", "GritLM"),
    "Grok1ModelForCausalLM": ("grok1", "GrokForCausalLM"),
    "Grok1ForCausalLM": ("grok1", "GrokForCausalLM"),
    "HunYuanMoEV1ForCausalLM": ("hunyuan_v1", "HunYuanMoEV1ForCausalLM"),
    "HunYuanDenseV1ForCausalLM": ("hunyuan_v1", "HunYuanDenseV1ForCausalLM"),
    "HCXVisionForCausalLM": ("hyperclovax_vision", "HCXVisionForCausalLM"),
    "InternLMForCausalLM": ("llama", "LlamaForCausalLM"),
    "InternLM2ForCausalLM": ("internlm2", "InternLM2ForCausalLM"),
    "InternLM2VEForCausalLM": ("internlm2_ve", "InternLM2VEForCausalLM"),
    "InternLM3ForCausalLM": ("llama", "LlamaForCausalLM"),
    "IQuestCoderForCausalLM": ("llama", "LlamaForCausalLM"),
    "IQuestLoopCoderForCausalLM": ("iquest_loopcoder", "IQuestLoopCoderForCausalLM"),
    "JAISLMHeadModel": ("jais", "JAISLMHeadModel"),
    "Jais2ForCausalLM": ("jais2", "Jais2ForCausalLM"),
    "JambaForCausalLM": ("jamba", "JambaForCausalLM"),
    "KimiLinearForCausalLM": ("kimi_linear", "KimiLinearForCausalLM"),  # noqa: E501
    "Lfm2ForCausalLM": ("lfm2", "Lfm2ForCausalLM"),
    "Lfm2MoeForCausalLM": ("lfm2_moe", "Lfm2MoeForCausalLM"),
    "LlamaForCausalLM": ("llama", "LlamaForCausalLM"),
    "Llama4ForCausalLM": ("llama4", "Llama4ForCausalLM"),
    # For decapoda-research/llama-*
    "LLaMAForCausalLM": ("llama", "LlamaForCausalLM"),
    "LongcatFlashForCausalLM": ("longcat_flash", "LongcatFlashForCausalLM"),
    "MambaForCausalLM": ("mamba", "MambaForCausalLM"),
    "Mamba2ForCausalLM": ("mamba2", "Mamba2ForCausalLM"),
    "MiniCPMForCausalLM": ("minicpm", "MiniCPMForCausalLM"),
    "MiniCPM3ForCausalLM": ("minicpm3", "MiniCPM3ForCausalLM"),
    "MiniMaxForCausalLM": ("minimax_text_01", "MiniMaxText01ForCausalLM"),
    "MiniMaxText01ForCausalLM": ("minimax_text_01", "MiniMaxText01ForCausalLM"),
    "MiniMaxM1ForCausalLM": ("minimax_text_01", "MiniMaxText01ForCausalLM"),
    "MiniMaxM2ForCausalLM": ("minimax_m2", "MiniMaxM2ForCausalLM"),
    "MistralForCausalLM": ("mistral", "MistralForCausalLM"),
    "MistralLarge3ForCausalLM": ("mistral_large_3", "MistralLarge3ForCausalLM"),
    "MixtralForCausalLM": ("mixtral", "MixtralForCausalLM"),
    # transformers's mpt class has lower case
    "MptForCausalLM": ("mpt", "MPTForCausalLM"),
    "MPTForCausalLM": ("mpt", "MPTForCausalLM"),
    "MiMoForCausalLM": ("mimo", "MiMoForCausalLM"),
    "MiMoV2FlashForCausalLM": ("mimo_v2_flash", "MiMoV2FlashForCausalLM"),
    "NemotronForCausalLM": ("nemotron", "NemotronForCausalLM"),
    "NemotronHForCausalLM": ("nemotron_h", "NemotronHForCausalLM"),
    "NemotronHPuzzleForCausalLM": ("nemotron_h", "NemotronHForCausalLM"),
    "OlmoForCausalLM": ("olmo", "OlmoForCausalLM"),
    "Olmo2ForCausalLM": ("olmo2", "Olmo2ForCausalLM"),
    "Olmo3ForCausalLM": ("olmo2", "Olmo2ForCausalLM"),
    "OlmoeForCausalLM": ("olmoe", "OlmoeForCausalLM"),
    "OPTForCausalLM": ("opt", "OPTForCausalLM"),
    "OrionForCausalLM": ("orion", "OrionForCausalLM"),
    "OuroForCausalLM": ("ouro", "OuroForCausalLM"),
    "PanguEmbeddedForCausalLM": ("openpangu", "PanguEmbeddedForCausalLM"),
    "PanguProMoEV2ForCausalLM": ("openpangu", "PanguProMoEV2ForCausalLM"),
    "PanguUltraMoEForCausalLM": ("openpangu", "PanguUltraMoEForCausalLM"),
    "PersimmonForCausalLM": ("persimmon", "PersimmonForCausalLM"),
    "PhiForCausalLM": ("phi", "PhiForCausalLM"),
    "Phi3ForCausalLM": ("phi3", "Phi3ForCausalLM"),
    "PhiMoEForCausalLM": ("phimoe", "PhiMoEForCausalLM"),
    "Plamo2ForCausalLM": ("plamo2", "Plamo2ForCausalLM"),
    "Plamo3ForCausalLM": ("plamo3", "Plamo3ForCausalLM"),
    "QWenLMHeadModel": ("qwen", "QWenLMHeadModel"),
    "Qwen2ForCausalLM": ("qwen2", "Qwen2ForCausalLM"),
    "Qwen2MoeForCausalLM": ("qwen2_moe", "Qwen2MoeForCausalLM"),
    "Qwen3ForCausalLM": ("qwen3", "Qwen3ForCausalLM"),
    "Qwen3MoeForCausalLM": ("qwen3_moe", "Qwen3MoeForCausalLM"),
    "RWForCausalLM": ("falcon", "FalconForCausalLM"),
    "SeedOssForCausalLM": ("seed_oss", "SeedOssForCausalLM"),
    "Step1ForCausalLM": ("step1", "Step1ForCausalLM"),
    "Step3TextForCausalLM": ("step3_text", "Step3TextForCausalLM"),
    "Step3p5ForCausalLM": ("step3p5", "Step3p5ForCausalLM"),
    "StableLMEpochForCausalLM": ("stablelm", "StablelmForCausalLM"),
    "StableLmForCausalLM": ("stablelm", "StablelmForCausalLM"),
    "Starcoder2ForCausalLM": ("starcoder2", "Starcoder2ForCausalLM"),
    "SolarForCausalLM": ("solar", "SolarForCausalLM"),
    "TeleChatForCausalLM": ("telechat2", "TeleChat2ForCausalLM"),
    "TeleChat2ForCausalLM": ("telechat2", "TeleChat2ForCausalLM"),
    "TeleFLMForCausalLM": ("teleflm", "TeleFLMForCausalLM"),
    "XverseForCausalLM": ("llama", "LlamaForCausalLM"),
    "Zamba2ForCausalLM": ("zamba2", "Zamba2ForCausalLM"),
}

# Embedding / pooling architectures.
# Each key is an architecture name; each value is a `(module, class)` pair of
# strings. Every text-generation entry whose class is "LlamaForCausalLM" is
# also merged in at the position of the `**{...}` expansion below.
_EMBEDDING_MODELS = {
    # [Text-only]
    "BertModel": ("bert", "BertEmbeddingModel"),
    "BertSpladeSparseEmbeddingModel": ("bert", "BertSpladeSparseEmbeddingModel"),
    "HF_ColBERT": ("colbert", "ColBERTModel"),
    "ColBERTModernBertModel": ("colbert", "ColBERTModernBertModel"),
    "ColBERTJinaRobertaModel": ("colbert", "ColBERTJinaRobertaModel"),
    "DeciLMForCausalLM": ("nemotron_nas", "DeciLMForCausalLM"),
    "Gemma2Model": ("gemma2", "Gemma2ForCausalLM"),
    "Gemma3TextModel": ("gemma3", "Gemma3Model"),
    "GlmForCausalLM": ("glm", "GlmForCausalLM"),
    "GPT2ForSequenceClassification": ("gpt2", "GPT2ForSequenceClassification"),
    "GritLM": ("gritlm", "GritLM"),
    "GteModel": ("bert_with_rope", "SnowflakeGteNewModel"),
    "GteNewModel": ("bert_with_rope", "GteNewModel"),
    "InternLM2ForRewardModel": ("internlm2", "InternLM2ForRewardModel"),
    "JambaForSequenceClassification": ("jamba", "JambaForSequenceClassification"),  # noqa: E501
    "LlamaBidirectionalModel": ("llama", "LlamaBidirectionalModel"),
    "LlamaModel": ("llama", "LlamaForCausalLM"),
    **{
        # Multiple models share the same architecture, so we include them all
        k: (mod, arch)
        for k, (mod, arch) in _TEXT_GENERATION_MODELS.items()
        if arch == "LlamaForCausalLM"
    },
    "MistralModel": ("llama", "LlamaForCausalLM"),
    "ModernBertModel": ("modernbert", "ModernBertModel"),
    "NomicBertModel": ("bert_with_rope", "NomicBertModel"),
    "Phi3ForCausalLM": ("phi3", "Phi3ForCausalLM"),
    "Qwen2Model": ("qwen2", "Qwen2ForCausalLM"),
    "Qwen2ForCausalLM": ("qwen2", "Qwen2ForCausalLM"),
    "Qwen2ForRewardModel": ("qwen2_rm", "Qwen2ForRewardModel"),
    "Qwen2ForProcessRewardModel": ("qwen2_rm", "Qwen2ForProcessRewardModel"),
    "RobertaForMaskedLM": ("roberta", "RobertaEmbeddingModel"),
    "RobertaModel": ("roberta", "RobertaEmbeddingModel"),
    "TeleChatForCausalLM": ("telechat2", "TeleChat2ForCausalLM"),
    "TeleChat2ForCausalLM": ("telechat2", "TeleChat2ForCausalLM"),
    "VoyageQwen3BidirectionalEmbedModel": (
        "voyage",
        "VoyageQwen3BidirectionalEmbedModel",
    ),
    "XLMRobertaModel": ("roberta", "RobertaEmbeddingModel"),
    "BgeM3EmbeddingModel": ("roberta", "BgeM3EmbeddingModel"),
    # [Multimodal]
    "CLIPModel": ("clip", "CLIPEmbeddingModel"),
    "ColModernVBertForRetrieval": ("colmodernvbert", "ColModernVBertForRetrieval"),
    "LlavaNextForConditionalGeneration": (
        "llava_next",
        "LlavaNextForConditionalGeneration",
    ),
    "Phi3VForCausalLM": ("phi3v", "Phi3VForCausalLM"),
    "Qwen2VLForConditionalGeneration": ("qwen2_vl", "Qwen2VLForConditionalGeneration"),  # noqa: E501
    "ColQwen3": ("colqwen3", "ColQwen3Model"),
    "OpsColQwen3Model": ("colqwen3", "ColQwen3Model"),
    "Qwen3VLNemotronEmbedModel": ("colqwen3", "ColQwen3Model"),
    "SiglipModel": ("siglip", "SiglipEmbeddingModel"),
    "LlamaNemotronVLModel": (
        "nemotron_vl",
        "LlamaNemotronVLForEmbedding",
    ),
    # Technically Terratorch models work on images, both in
    # input and output. I am adding it here because it piggy-backs on embedding
    # models for the time being.
    "PrithviGeoSpatialMAE": ("terratorch", "Terratorch"),
    "Terratorch": ("terratorch", "Terratorch"),
}

275
276
# Cross-encoder (sequence / token classification) architectures.
# Each key is an architecture name; each value is a `(module, class)` pair of
# strings.
_CROSS_ENCODER_MODELS = {
    "BertForSequenceClassification": ("bert", "BertForSequenceClassification"),
    "BertForTokenClassification": ("bert", "BertForTokenClassification"),
    "GteNewForSequenceClassification": (
        "bert_with_rope",
        "GteNewForSequenceClassification",
    ),
    "JinaVLForRanking": ("jina_vl", "JinaVLForSequenceClassification"),
    "LlamaBidirectionalForSequenceClassification": (
        "llama",
        "LlamaBidirectionalForSequenceClassification",
    ),
    "ModernBertForSequenceClassification": (
        "modernbert",
        "ModernBertForSequenceClassification",
    ),
    "ModernBertForTokenClassification": (
        "modernbert",
        "ModernBertForTokenClassification",
    ),
    "RobertaForSequenceClassification": ("roberta", "RobertaForSequenceClassification"),
    "XLMRobertaForSequenceClassification": (
        "roberta",
        "RobertaForSequenceClassification",
    ),
}

302
# Multimodal architectures (decoder-only first, then encoder-decoder).
# Each key is an architecture name; each value is a `(module, class)` pair of
# strings. Several architecture names may map to the same implementation.
_MULTIMODAL_MODELS = {
    # [Decoder-only]
    "AriaForConditionalGeneration": ("aria", "AriaForConditionalGeneration"),
    "AudioFlamingo3ForConditionalGeneration": (
        "audioflamingo3",
        "AudioFlamingo3ForConditionalGeneration",
    ),
    "MusicFlamingoForConditionalGeneration": (
        "musicflamingo",
        "MusicFlamingoForConditionalGeneration",
    ),
    "AyaVisionForConditionalGeneration": (
        "aya_vision",
        "AyaVisionForConditionalGeneration",
    ),
    "BagelForConditionalGeneration": ("bagel", "BagelForConditionalGeneration"),
    "BeeForConditionalGeneration": ("bee", "BeeForConditionalGeneration"),
    "Blip2ForConditionalGeneration": ("blip2", "Blip2ForConditionalGeneration"),
    "ChameleonForConditionalGeneration": (
        "chameleon",
        "ChameleonForConditionalGeneration",
    ),
    "Cohere2VisionForConditionalGeneration": (
        "cohere2_vision",
        "Cohere2VisionForConditionalGeneration",
    ),
    "DeepseekVLV2ForCausalLM": ("deepseek_vl2", "DeepseekVLV2ForCausalLM"),
    "DeepseekOCRForCausalLM": ("deepseek_ocr", "DeepseekOCRForCausalLM"),
    "DeepseekOCR2ForCausalLM": ("deepseek_ocr2", "DeepseekOCR2ForCausalLM"),
    "DotsOCRForCausalLM": ("dots_ocr", "DotsOCRForCausalLM"),
    "Eagle2_5_VLForConditionalGeneration": (
        "eagle2_5_vl",
        "Eagle2_5_VLForConditionalGeneration",
    ),
    "Ernie4_5_VLMoeForConditionalGeneration": (
        "ernie45_vl",
        "Ernie4_5_VLMoeForConditionalGeneration",
    ),
    "FunASRForConditionalGeneration": ("funasr", "FunASRForConditionalGeneration"),  # noqa: E501
    "FunAudioChatForConditionalGeneration": (
        "funaudiochat",
        "FunAudioChatForConditionalGeneration",
    ),
    "FuyuForCausalLM": ("fuyu", "FuyuForCausalLM"),
    "Gemma3ForConditionalGeneration": ("gemma3_mm", "Gemma3ForConditionalGeneration"),  # noqa: E501
    "Gemma3nForConditionalGeneration": (
        "gemma3n_mm",
        "Gemma3nForConditionalGeneration",
    ),
    "GlmAsrForConditionalGeneration": ("glmasr", "GlmAsrForConditionalGeneration"),
    "GLM4VForCausalLM": ("glm4v", "GLM4VForCausalLM"),
    "Glm4vForConditionalGeneration": ("glm4_1v", "Glm4vForConditionalGeneration"),
    "Glm4vMoeForConditionalGeneration": ("glm4_1v", "Glm4vMoeForConditionalGeneration"),
    "GlmOcrForConditionalGeneration": ("glm_ocr", "GlmOcrForConditionalGeneration"),  # noqa: E501
    "GraniteSpeechForConditionalGeneration": (
        "granite_speech",
        "GraniteSpeechForConditionalGeneration",
    ),
    "H2OVLChatModel": ("h2ovl", "H2OVLChatModel"),
    "HunYuanVLForConditionalGeneration": (
        "hunyuan_vision",
        "HunYuanVLForConditionalGeneration",
    ),
    "StepVLForConditionalGeneration": ("step_vl", "StepVLForConditionalGeneration"),
    "InternVLChatModel": ("internvl", "InternVLChatModel"),
    "NemotronH_Nano_VL_V2": ("nano_nemotron_vl", "NemotronH_Nano_VL_V2"),
    "OpenCUAForConditionalGeneration": (
        "opencua",
        "OpenCUAForConditionalGeneration",
    ),
    "InternS1ForConditionalGeneration": (
        "interns1",
        "InternS1ForConditionalGeneration",
    ),
    "InternVLForConditionalGeneration": (
        "interns1",
        "InternS1ForConditionalGeneration",
    ),
    "InternS1ProForConditionalGeneration": (
        "interns1_pro",
        "InternS1ProForConditionalGeneration",
    ),
    "Idefics3ForConditionalGeneration": (
        "idefics3",
        "Idefics3ForConditionalGeneration",
    ),
    "IsaacForConditionalGeneration": ("isaac", "IsaacForConditionalGeneration"),
    "SmolVLMForConditionalGeneration": ("smolvlm", "SmolVLMForConditionalGeneration"),  # noqa: E501
    "KananaVForConditionalGeneration": ("kanana_v", "KananaVForConditionalGeneration"),
    "KeyeForConditionalGeneration": ("keye", "KeyeForConditionalGeneration"),
    "KeyeVL1_5ForConditionalGeneration": (
        "keye_vl1_5",
        "KeyeVL1_5ForConditionalGeneration",
    ),
    "RForConditionalGeneration": ("rvl", "RForConditionalGeneration"),
    "KimiVLForConditionalGeneration": ("kimi_vl", "KimiVLForConditionalGeneration"),  # noqa: E501
    "KimiK25ForConditionalGeneration": ("kimi_k25", "KimiK25ForConditionalGeneration"),  # noqa: E501
    "LightOnOCRForConditionalGeneration": (
        "lightonocr",
        "LightOnOCRForConditionalGeneration",
    ),
    "Lfm2VlForConditionalGeneration": ("lfm2_vl", "Lfm2VLForConditionalGeneration"),
    "Llama_Nemotron_Nano_VL": ("nemotron_vl", "LlamaNemotronVLChatModel"),
    "Llama4ForConditionalGeneration": ("mllama4", "Llama4ForConditionalGeneration"),  # noqa: E501
    "LlavaForConditionalGeneration": ("llava", "LlavaForConditionalGeneration"),
    "LlavaNextForConditionalGeneration": (
        "llava_next",
        "LlavaNextForConditionalGeneration",
    ),
    "LlavaNextVideoForConditionalGeneration": (
        "llava_next_video",
        "LlavaNextVideoForConditionalGeneration",
    ),
    "LlavaOnevisionForConditionalGeneration": (
        "llava_onevision",
        "LlavaOnevisionForConditionalGeneration",
    ),
    "MantisForConditionalGeneration": ("llava", "MantisForConditionalGeneration"),  # noqa: E501
    "MiDashengLMModel": ("midashenglm", "MiDashengLMModel"),
    "MiniMaxVL01ForConditionalGeneration": (
        "minimax_vl_01",
        "MiniMaxVL01ForConditionalGeneration",
    ),
    "MiniCPMO": ("minicpmo", "MiniCPMO"),
    "MiniCPMV": ("minicpmv", "MiniCPMV"),
    "Mistral3ForConditionalGeneration": (
        "mistral3",
        "Mistral3ForConditionalGeneration",
    ),
    "MolmoForCausalLM": ("molmo", "MolmoForCausalLM"),
    "Molmo2ForConditionalGeneration": ("molmo2", "Molmo2ForConditionalGeneration"),
    "NVLM_D": ("nvlm_d", "NVLM_D_Model"),
    "OpenPanguVLForConditionalGeneration": (
        "openpangu_vl",
        "OpenPanguVLForConditionalGeneration",
    ),
    "Ovis": ("ovis", "Ovis"),
    "Ovis2_5": ("ovis2_5", "Ovis2_5"),
    "Ovis2_6ForCausalLM": ("ovis2_5", "Ovis2_5"),
    "Ovis2_6_MoeForCausalLM": ("ovis2_5", "Ovis2_5"),
    "PaddleOCRVLForConditionalGeneration": (
        "paddleocr_vl",
        "PaddleOCRVLForConditionalGeneration",
    ),
    "PaliGemmaForConditionalGeneration": (
        "paligemma",
        "PaliGemmaForConditionalGeneration",
    ),
    "Phi3VForCausalLM": ("phi3v", "Phi3VForCausalLM"),
    "Phi4MMForCausalLM": ("phi4mm", "Phi4MMForCausalLM"),
    "PixtralForConditionalGeneration": ("pixtral", "PixtralForConditionalGeneration"),  # noqa: E501
    "QwenVLForConditionalGeneration": ("qwen_vl", "QwenVLForConditionalGeneration"),  # noqa: E501
    "Qwen2VLForConditionalGeneration": ("qwen2_vl", "Qwen2VLForConditionalGeneration"),  # noqa: E501
    "Qwen2_5_VLForConditionalGeneration": (
        "qwen2_5_vl",
        "Qwen2_5_VLForConditionalGeneration",
    ),
    "Qwen2AudioForConditionalGeneration": (
        "qwen2_audio",
        "Qwen2AudioForConditionalGeneration",
    ),
    "Qwen2_5OmniModel": (
        "qwen2_5_omni_thinker",
        "Qwen2_5OmniThinkerForConditionalGeneration",
    ),
    "Qwen2_5OmniForConditionalGeneration": (
        "qwen2_5_omni_thinker",
        "Qwen2_5OmniThinkerForConditionalGeneration",
    ),
    "Qwen3OmniMoeForConditionalGeneration": (
        "qwen3_omni_moe_thinker",
        "Qwen3OmniMoeThinkerForConditionalGeneration",
    ),
    "Qwen3ASRForConditionalGeneration": (
        "qwen3_asr",
        "Qwen3ASRForConditionalGeneration",
    ),
    "Qwen3ASRRealtimeGeneration": (
        "qwen3_asr_realtime",
        "Qwen3ASRRealtimeGeneration",
    ),
    "Qwen3VLForConditionalGeneration": ("qwen3_vl", "Qwen3VLForConditionalGeneration"),  # noqa: E501
    "Qwen3VLMoeForConditionalGeneration": (
        "qwen3_vl_moe",
        "Qwen3VLMoeForConditionalGeneration",
    ),
    "Qwen3_5ForConditionalGeneration": (
        "qwen3_5",
        "Qwen3_5ForConditionalGeneration",
    ),
    "Qwen3_5MoeForConditionalGeneration": (
        "qwen3_5",
        "Qwen3_5MoeForConditionalGeneration",
    ),
    "SkyworkR1VChatModel": ("skyworkr1v", "SkyworkR1VChatModel"),
    "Step3VLForConditionalGeneration": ("step3_vl", "Step3VLForConditionalGeneration"),  # noqa: E501
    "TarsierForConditionalGeneration": ("tarsier", "TarsierForConditionalGeneration"),  # noqa: E501
    "Tarsier2ForConditionalGeneration": (
        "qwen2_vl",
        "Tarsier2ForConditionalGeneration",
    ),
    "UltravoxModel": ("ultravox", "UltravoxModel"),
    "VoxtralForConditionalGeneration": ("voxtral", "VoxtralForConditionalGeneration"),  # noqa: E501
    "VoxtralRealtimeGeneration": ("voxtral_realtime", "VoxtralRealtimeGeneration"),  # noqa: E501
    # [Encoder-decoder]
    "NemotronParseForConditionalGeneration": (
        "nemotron_parse",
        "NemotronParseForConditionalGeneration",
    ),
    "WhisperForConditionalGeneration": ("whisper", "WhisperForConditionalGeneration"),  # noqa: E501
}
513
514

# Draft / speculative-decoding architectures (Eagle, MTP, Medusa, ...).
# Each key is an architecture name; each value is a `(module, class)` pair of
# strings.
_SPECULATIVE_DECODING_MODELS = {
    "MiMoMTPModel": ("mimo_mtp", "MiMoMTP"),
    "EagleLlamaForCausalLM": ("llama_eagle", "EagleLlamaForCausalLM"),
    "EagleLlama4ForCausalLM": ("llama4_eagle", "EagleLlama4ForCausalLM"),
    "EagleMiniCPMForCausalLM": ("minicpm_eagle", "EagleMiniCPMForCausalLM"),
    "Eagle3LlamaForCausalLM": ("llama_eagle3", "Eagle3LlamaForCausalLM"),
    "LlamaForCausalLMEagle3": ("llama_eagle3", "Eagle3LlamaForCausalLM"),
    "Eagle3Qwen2_5vlForCausalLM": ("llama_eagle3", "Eagle3LlamaForCausalLM"),
    "Eagle3Qwen3vlForCausalLM": ("llama_eagle3", "Eagle3LlamaForCausalLM"),
    "EagleMistralLarge3ForCausalLM": (
        "mistral_large_3_eagle",
        "EagleMistralLarge3ForCausalLM",
    ),
    "EagleDeepSeekMTPModel": ("deepseek_eagle", "EagleDeepseekV3ForCausalLM"),
    "DeepSeekMTPModel": ("deepseek_mtp", "DeepSeekMTP"),
    "ErnieMTPModel": ("ernie_mtp", "ErnieMTP"),
    "ExaoneMoeMTP": ("exaone_moe_mtp", "ExaoneMoeMTP"),
    "NemotronHMTPModel": ("nemotron_h_mtp", "NemotronHMTP"),
    "LongCatFlashMTPModel": ("longcat_flash_mtp", "LongCatFlashMTP"),
    "Glm4MoeMTPModel": ("glm4_moe_mtp", "Glm4MoeMTP"),
    "Glm4MoeLiteMTPModel": ("glm4_moe_lite_mtp", "Glm4MoeLiteMTP"),
    "GlmOcrMTPModel": ("glm_ocr_mtp", "GlmOcrMTP"),
    "MedusaModel": ("medusa", "Medusa"),
    "OpenPanguMTPModel": ("openpangu_mtp", "OpenPanguMTP"),
    "Qwen3NextMTP": ("qwen3_next_mtp", "Qwen3NextMTP"),
    "Step3p5MTP": ("step3p5_mtp", "Step3p5MTP"),
    "Qwen3_5MTP": ("qwen3_5_mtp", "Qwen3_5MTP"),
    "Qwen3_5MoeMTP": ("qwen3_5_mtp", "Qwen3_5MoeMTP"),
    # Temporarily disabled.
    # # TODO(woosuk): Re-enable this once the MLP Speculator is supported in V1.
    # "MLPSpeculatorPreTrainedModel": ("mlp_speculator", "MLPSpeculator"),
}
546

547
# Named architectures whose entry points at the "transformers" module's generic
# wrapper classes rather than at a dedicated implementation.
# Each value is a `(module, class)` pair of strings.
_TRANSFORMERS_SUPPORTED_MODELS = {
    # Text generation models
    "SmolLM3ForCausalLM": ("transformers", "TransformersForCausalLM"),
    # Multimodal models
    "Emu3ForConditionalGeneration": (
        "transformers",
        "TransformersMultiModalForCausalLM",
    ),
}

# The generic "transformers" wrapper classes themselves, keyed by their own
# class name. Every value is `("transformers", <same name as the key>)`.
_TRANSFORMERS_BACKEND_MODELS = {
    # Text generation models
    "TransformersForCausalLM": ("transformers", "TransformersForCausalLM"),
    "TransformersMoEForCausalLM": ("transformers", "TransformersMoEForCausalLM"),
    # Multimodal models
    "TransformersMultiModalForCausalLM": (
        "transformers",
        "TransformersMultiModalForCausalLM",
    ),
    "TransformersMultiModalMoEForCausalLM": (
        "transformers",
        "TransformersMultiModalMoEForCausalLM",
    ),
    # Embedding models
    "TransformersEmbeddingModel": ("transformers", "TransformersEmbeddingModel"),
    "TransformersMoEEmbeddingModel": ("transformers", "TransformersMoEEmbeddingModel"),
    "TransformersMultiModalEmbeddingModel": (
        "transformers",
        "TransformersMultiModalEmbeddingModel",
    ),
    # Sequence classification models
    "TransformersForSequenceClassification": (
        "transformers",
        "TransformersForSequenceClassification",
    ),
    "TransformersMoEForSequenceClassification": (
        "transformers",
        "TransformersMoEForSequenceClassification",
    ),
    "TransformersMultiModalForSequenceClassification": (
        "transformers",
        "TransformersMultiModalForSequenceClassification",
    ),
}
591

592
# Combined registry: every per-category table merged into one dict, in the
# order listed. When a key appears in more than one table, the value from the
# later table is kept (normal dict-unpacking behaviour).
_VLLM_MODELS = {
    **_TEXT_GENERATION_MODELS,
    **_EMBEDDING_MODELS,
    **_CROSS_ENCODER_MODELS,
    **_MULTIMODAL_MODELS,
    **_SPECULATIVE_DECODING_MODELS,
    **_TRANSFORMERS_SUPPORTED_MODELS,
    **_TRANSFORMERS_BACKEND_MODELS,
}

602
603
604
605
# This variable is used as the args for subprocess.run(). We
# can modify this variable to alter the args if needed, e.g.
# when we use par format to pack things together, sys.executable
# might not be the target we want to run.
606
_SUBPROCESS_COMMAND = [
    sys.executable,  # interpreter running the current process
    "-m",  # run the following dotted name as a module
    "vllm.model_executor.models.registry",
]
607

608
# Architectures that are no longer registered, keyed by architecture name with
# a version string as the value.
# NOTE(review): the version presumably records the last vLLM release that
# supported the architecture — confirm against where this table is read.
_PREVIOUSLY_SUPPORTED_MODELS = {
    "MotifForCausalLM": "0.10.2",
    "Phi3SmallForCausalLM": "0.9.2",
    "Phi4FlashForCausalLM": "0.10.2",
    "Phi4MultimodalForCausalLM": "0.12.0",
    # encoder-decoder models except whisper
    # have been removed for V0 deprecation.
    "BartModel": "0.10.2",
    "BartForConditionalGeneration": "0.10.2",
    "DonutForConditionalGeneration": "0.10.2",
    "Florence2ForConditionalGeneration": "0.10.2",
    "MBartForConditionalGeneration": "0.10.2",
    "MllamaForConditionalGeneration": "0.10.2",
}
622

623

624
625
@dataclass(frozen=True)
class _ModelInfo:
626
    architecture: str
627
    is_text_generation_model: bool
628
    is_pooling_model: bool
629
    attn_type: AttnTypeStr
630
631
    default_seq_pooling_type: SequencePoolingType
    default_tok_pooling_type: TokenPoolingType
632
    supports_cross_encoding: bool
633
    supports_late_interaction: bool
634
    supports_multimodal: bool
635
    supports_multimodal_raw_input_only: bool
Patrick von Platen's avatar
Patrick von Platen committed
636
    requires_raw_input_tokens: bool
637
    supports_multimodal_encoder_tp_data: bool
638
    supports_pp: bool
639
640
    has_inner_state: bool
    is_attention_free: bool
641
    is_hybrid: bool
642
    has_noops: bool
643
    supports_mamba_prefix_caching: bool
644
    supports_transcription: bool
645
    supports_transcription_only: bool
646
647

    @staticmethod
648
    def from_model_cls(model: type[nn.Module]) -> "_ModelInfo":
649
        return _ModelInfo(
650
            architecture=model.__name__,
651
            is_text_generation_model=is_text_generation_model(model),
652
            is_pooling_model=is_pooling_model(model),
653
654
            default_seq_pooling_type=get_default_seq_pooling_type(model),
            default_tok_pooling_type=get_default_tok_pooling_type(model),
655
            attn_type=get_attn_type(model),
656
            supports_cross_encoding=supports_cross_encoding(model),
657
            supports_late_interaction=supports_late_interaction(model),
658
            supports_multimodal=supports_multimodal(model),
659
660
661
            supports_multimodal_raw_input_only=supports_multimodal_raw_input_only(
                model
            ),
Patrick von Platen's avatar
Patrick von Platen committed
662
            requires_raw_input_tokens=requires_raw_input_tokens(model),
663
664
665
            supports_multimodal_encoder_tp_data=supports_multimodal_encoder_tp_data(
                model
            ),
666
            supports_pp=supports_pp(model),
667
668
            has_inner_state=has_inner_state(model),
            is_attention_free=is_attention_free(model),
669
            is_hybrid=is_hybrid(model),
670
            supports_mamba_prefix_caching=supports_mamba_prefix_caching(model),
671
            supports_transcription=supports_transcription(model),
672
673
674
            supports_transcription_only=(
                supports_transcription(model) and model.supports_transcription_only
            ),
675
            has_noops=has_noops(model),
676
        )
677
678


679
680
681
682
class _BaseRegisteredModel(ABC):
    @abstractmethod
    def inspect_model_cls(self) -> _ModelInfo:
        raise NotImplementedError
683

684
    @abstractmethod
685
    def load_model_cls(self) -> type[nn.Module]:
686
        raise NotImplementedError
687
688


689
690
691
692
693
694
695
@dataclass(frozen=True)
class _RegisteredModel(_BaseRegisteredModel):
    """
    Represents a model that has already been imported in the main process.
    """

    interfaces: _ModelInfo
696
    model_cls: type[nn.Module]
697
698

    @staticmethod
699
    def from_model_cls(model_cls: type[nn.Module]):
700
701
702
703
704
705
706
707
        return _RegisteredModel(
            interfaces=_ModelInfo.from_model_cls(model_cls),
            model_cls=model_cls,
        )

    def inspect_model_cls(self) -> _ModelInfo:
        return self.interfaces

708
    def load_model_cls(self) -> type[nn.Module]:
709
710
711
712
713
714
715
716
        return self.model_cls


@dataclass(frozen=True)
class _LazyRegisteredModel(_BaseRegisteredModel):
    """
    Represents a model that has not been imported in the main process.
    """
717

718
719
720
    module_name: str
    class_name: str

721
722
723
724
725
726
727
728
    @staticmethod
    def _get_cache_dir() -> Path:
        """Return the `modelinfos` directory inside `envs.VLLM_CACHE_ROOT`."""
        cache_root = Path(envs.VLLM_CACHE_ROOT)
        return cache_root.joinpath("modelinfos")

    def _get_cache_filename(self) -> str:
        cls_name = f"{self.module_name}-{self.class_name}".replace(".", "-")
        return f"{cls_name}.json"

729
    def _load_modelinfo_from_cache(self, module_hash: str) -> _ModelInfo | None:
        """Return the cached `_ModelInfo` for this class, or `None` on a miss.

        `module_hash` is compared with the hash stored in the cache file; the
        cached entry is used only when both are equal.

        Loading is best-effort and never raises: a missing file, a hash
        mismatch, or any other failure (for example unreadable JSON, missing
        keys, or fields that no longer match `_ModelInfo`) is logged at debug
        level and reported as a miss.
        """
        try:
            modelinfo_path = self._get_cache_dir() / self._get_cache_filename()
            with open(modelinfo_path, encoding="utf-8") as file:
                mi_dict = json.load(file)

            if mi_dict["hash"] != module_hash:
                logger.debug(
                    "Cached model info file for class %s.%s is stale",
                    self.module_name,
                    self.class_name,
                )
                return None

            # file not changed, use cached _ModelInfo properties
            return _ModelInfo(**mi_dict["modelinfo"])
        except FileNotFoundError:
            logger.debug(
                "Cached model info file for class %s.%s not found",
                self.module_name,
                self.class_name,
            )
            return None
        except Exception:
            # Attach the traceback so that a corrupt or outdated cache file
            # can be diagnosed from the debug log.
            logger.debug(
                "Cached model info for class %s.%s error. ",
                self.module_name,
                self.class_name,
                exc_info=True,
            )
            return None

761
    def _save_modelinfo_to_cache(self, mi: _ModelInfo, module_hash: str) -> None:
        """Write `mi` and `module_hash` to this class's JSON cache file.

        The cache directory is created when missing. Any failure is logged
        and swallowed, so the caller is never interrupted.
        """
        from vllm.model_executor.model_loader.weight_utils import atomic_writer

        try:
            payload = {"hash": module_hash, "modelinfo": asdict(mi)}

            target_dir = self._get_cache_dir()
            target_dir.mkdir(parents=True, exist_ok=True)
            target_file = target_dir / self._get_cache_filename()

            with atomic_writer(target_file, encoding="utf-8") as f:
                json.dump(payload, f, indent=2)
        except Exception:
            logger.exception("Error saving model info cache.")

    @logtime(logger=logger, msg="Registry inspect model class")
    def inspect_model_cls(self) -> _ModelInfo:
        """Return the `_ModelInfo` of this class without importing it here.

        When a source file named after the last component of `module_name`
        exists next to this module, its hash keys an on-disk cache and a
        cached entry with a matching hash is returned directly. Otherwise the
        class is inspected in a subprocess and, if a hash is available, the
        result is written back to the cache.
        """
        module_stem = self.module_name.rsplit(".", maxsplit=1)[-1]
        model_path = Path(__file__).parent / f"{module_stem}.py"

        module_hash = None
        if model_path.exists():
            source_bytes = model_path.read_bytes()
            module_hash = safe_hash(source_bytes, usedforsecurity=False).hexdigest()

            cached = self._load_modelinfo_from_cache(module_hash)
            if cached is not None:
                logger.debug(
                    "Loaded model info for class %s.%s from cache",
                    self.module_name,
                    self.class_name,
                )
                return cached

            logger.debug(
                "Cache model info for class %s.%s miss. Loading model instead.",
                self.module_name,
                self.class_name,
            )

        # Performed in another process to avoid initializing CUDA
        mi = _run_in_subprocess(
            lambda: _ModelInfo.from_model_cls(self.load_model_cls())
        )
        logger.debug(
            "Loaded model info for class %s.%s", self.module_name, self.class_name
        )

        # save cache file
        if module_hash is not None:
            self._save_modelinfo_to_cache(mi, module_hash)

        return mi
815

816
    def load_model_cls(self) -> type[nn.Module]:
        """Import `module_name` and return its attribute named `class_name`."""
        module = importlib.import_module(self.module_name)
        return getattr(module, self.class_name)


@lru_cache(maxsize=128)
def _try_load_model_cls(
    model_arch: str,
    model: _BaseRegisteredModel,
) -> type[nn.Module] | None:
    """Load the class behind `model`, returning `None` when loading fails.

    The current platform verifies `model_arch` first; an error raised by that
    check is not caught here. A loading failure is logged with its traceback.
    Results, including `None`, are memoized by `lru_cache`.
    """
    from vllm.platforms import current_platform

    current_platform.verify_model_arch(model_arch)

    try:
        loaded_cls = model.load_model_cls()
    except Exception:
        logger.exception("Error in loading model architecture '%s'", model_arch)
        return None
    return loaded_cls
834
835


836
837
838
839
@lru_cache(maxsize=128)
def _try_inspect_model_cls(
    model_arch: str,
    model: _BaseRegisteredModel,
840
) -> _ModelInfo | None:
841
842
843
    try:
        return model.inspect_model_cls()
    except Exception:
844
        logger.exception("Error in inspecting model architecture '%s'", model_arch)
845
        return None
846
847


848
849
850
@dataclass
class _ModelRegistry:
    # Keyed by model_arch
851
    models: dict[str, _BaseRegisteredModel] = field(default_factory=dict)
852

853
    def get_supported_archs(self) -> Set[str]:
        """Return a live view of all registered architecture names."""
        registered = self.models
        return registered.keys()
855

856
857
858
    def register_model(
        self,
        model_arch: str,
859
        model_cls: type[nn.Module] | str,
860
    ) -> None:
861
862
863
        """
        Register an external model to be used in vLLM.

864
        `model_cls` can be either:
865

866
        - A [`torch.nn.Module`][] class directly referencing the model.
867
        - A string in the format `<module>:<class>` which can be used to
868
869
          lazily import the model. This is useful to avoid initializing CUDA
          when importing the model and thus the related error
870
          `RuntimeError: Cannot re-initialize CUDA in forked subprocess`.
871
        """
872
873
874
875
        if not isinstance(model_arch, str):
            msg = f"`model_arch` should be a string, not a {type(model_arch)}"
            raise TypeError(msg)

876
        if model_arch in self.models:
877
878
            logger.warning(
                "Model architecture %s is already registered, and will be "
879
880
881
882
                "overwritten by the new model class %s.",
                model_arch,
                model_cls,
            )
883
884
885
886
887
888

        if isinstance(model_cls, str):
            split_str = model_cls.split(":")
            if len(split_str) != 2:
                msg = "Expected a string in the format `<module>:<class>`"
                raise ValueError(msg)
889

890
            model = _LazyRegisteredModel(*split_str)
891
        elif isinstance(model_cls, type) and issubclass(model_cls, nn.Module):
892
            model = _RegisteredModel.from_model_cls(model_cls)
893
        else:
894
895
896
897
            msg = (
                "`model_cls` should be a string or PyTorch model class, "
                f"not a {type(model_arch)}"
            )
898
            raise TypeError(msg)
899

900
        self.models[model_arch] = model
901

902
    def _raise_for_unsupported(self, architectures: list[str]):
903
        all_supported_archs = self.get_supported_archs()
904

905
906
907
        if any(arch in all_supported_archs for arch in architectures):
            raise ValueError(
                f"Model architectures {architectures} failed "
908
909
                "to be inspected. Please check the logs for more details."
            )
910

911
912
913
914
915
916
917
918
        for arch in architectures:
            if arch in _PREVIOUSLY_SUPPORTED_MODELS:
                previous_version = _PREVIOUSLY_SUPPORTED_MODELS[arch]

                raise ValueError(
                    f"Model architecture {arch} was supported in vLLM until "
                    f"v{previous_version}, and is not supported anymore. "
                    "Please use an older version of vLLM if you want to "
919
920
                    "use this model architecture."
                )
921

922
923
        raise ValueError(
            f"Model architectures {architectures} are not supported for now. "
924
925
            f"Supported architectures: {all_supported_archs}"
        )
926

927
    def _try_load_model_cls(self, model_arch: str) -> type[nn.Module] | None:
928
929
        if model_arch not in self.models:
            return None
930

931
        return _try_load_model_cls(model_arch, self.models[model_arch])
932

933
    def _try_inspect_model_cls(self, model_arch: str) -> _ModelInfo | None:
934
935
        if model_arch not in self.models:
            return None
936

937
938
939
940
941
942
        return _try_inspect_model_cls(model_arch, self.models[model_arch])

    def _try_resolve_transformers(
        self,
        architecture: str,
        model_config: ModelConfig,
    ) -> str | None:
        """Resolve `architecture` to a Transformers backend class name.

        `architecture` is returned unchanged when it already names a
        Transformers backend model. Otherwise the model class is looked up on
        the `transformers` package and then through the `AutoModel*` entries
        of the HF config's `auto_map`. If it is found and reports itself as
        backend compatible, the result of
        `model_config._get_transformers_backend_cls()` is returned.

        When the class cannot be found or is not compatible, `None` is
        returned, unless `model_config.model_impl` is `"transformers"`, in
        which case a `ValueError` is raised.
        """
        if architecture in _TRANSFORMERS_BACKEND_MODELS:
            return architecture

        auto_map: dict[str, str] = (
            getattr(model_config.hf_config, "auto_map", None) or {}
        )

        # Make sure that config class is always initialized before model class,
        # otherwise the model class won't be able to access the config class,
        # the expected auto_map should have correct order like:
        # "auto_map": {
        #     "AutoConfig": "<your-repo-name>--<config-name>",
        #     "AutoModel": "<your-repo-name>--<config-name>",
        #     "AutoModelFor<Task>": "<your-repo-name>--<config-name>",
        # },
        for prefix in ("AutoConfig", "AutoModel"):
            for name, module in auto_map.items():
                if not name.startswith(prefix):
                    continue
                try_get_class_from_dynamic_module(
                    module,
                    model_config.model,
                    revision=model_config.revision,
                    trust_remote_code=model_config.trust_remote_code,
                    warn_on_fail=False,
                )

        model_module = getattr(transformers, architecture, None)

        if model_module is None:
            for name, module in auto_map.items():
                if not name.startswith("AutoModel"):
                    continue
                model_module = try_get_class_from_dynamic_module(
                    module,
                    model_config.model,
                    revision=model_config.revision,
                    trust_remote_code=model_config.trust_remote_code,
                    warn_on_fail=True,
                )
                if model_module is not None:
                    break
            else:
                # Reached when no `AutoModel*` entry produced a class.
                if model_config.model_impl != "transformers":
                    return None

                raise ValueError(
                    f"Cannot find model module. {architecture!r} is not a "
                    "registered model in the Transformers library (only "
                    "relevant if the model is meant to be in Transformers) "
                    "and 'AutoModel' is not present in the model config's "
                    "'auto_map' (relevant if the model is custom)."
                )

        if not model_module.is_backend_compatible():
            if model_config.model_impl != "transformers":
                return None

            raise ValueError(
                f"The Transformers implementation of {architecture!r} "
                "is not compatible with vLLM."
            )

        return model_config._get_transformers_backend_cls()
1006

1007
1008
1009
1010
1011
1012
1013
1014
1015
1016
1017
1018
1019
1020
1021
1022
1023
1024
1025
1026
1027
1028
1029
1030
1031
    def _normalize_arch(
        self,
        architecture: str,
        model_config: ModelConfig,
    ) -> str:
        if architecture in self.models:
            return architecture

        # This may be called in order to resolve runner_type and convert_type
        # in the first place, in which case we consider the default match
        match = try_match_architecture_defaults(
            architecture,
            runner_type=getattr(model_config, "runner_type", None),
            convert_type=getattr(model_config, "convert_type", None),
        )
        if match:
            suffix, _ = match

            # Get the name of the base model to convert
            for repl_suffix, _ in iter_architecture_defaults():
                base_arch = architecture.replace(suffix, repl_suffix)
                if base_arch in self.models:
                    return base_arch

        return architecture
1032

1033
1034
    def inspect_model_cls(
        self,
        architectures: str | list[str],
        model_config: ModelConfig,
    ) -> tuple[_ModelInfo, str]:
        """Return the `_ModelInfo` and name of the first usable architecture.

        Resolution order:

        1. If `model_config.model_impl` is `"transformers"` or `"terratorch"`,
           that implementation is tried first.
        2. If no architecture is registered, `model_impl` is `"auto"` and no
           conversion is requested, the Transformers backend is tried.
        3. Each architecture is normalized and inspected in order.
        4. If no architecture is registered and `model_impl` is `"auto"`, the
           Transformers backend is tried once more.

        A failed attempt falls through to the next step. When every step
        fails, `_raise_for_unsupported` raises a `ValueError`.

        Raises:
            ValueError: If `architectures` is empty or nothing can be
                inspected.
        """
        if isinstance(architectures, str):
            architectures = [architectures]
        if not architectures:
            raise ValueError("No model architectures are specified")

        def inspect_via_transformers() -> tuple[_ModelInfo, str] | None:
            # Resolve the first architecture to a Transformers backend class
            # and inspect it; `None` means this route is unavailable.
            arch = self._try_resolve_transformers(architectures[0], model_config)
            if arch is None:
                return None
            model_info = self._try_inspect_model_cls(arch)
            return None if model_info is None else (model_info, arch)

        # Require transformers impl
        if model_config.model_impl == "transformers":
            resolved = inspect_via_transformers()
            if resolved is not None:
                return resolved
        elif model_config.model_impl == "terratorch":
            arch = "Terratorch"
            model_info = self._try_inspect_model_cls(arch)
            # Return only on success so that a failed inspection ends in the
            # error path below instead of handing `None` to the caller.
            if model_info is not None:
                return (model_info, arch)

        # Fallback to transformers impl (after resolving convert_type)
        if (
            all(arch not in self.models for arch in architectures)
            and model_config.model_impl == "auto"
            and getattr(model_config, "convert_type", "none") == "none"
        ):
            resolved = inspect_via_transformers()
            if resolved is not None:
                return resolved

        for arch in architectures:
            normalized_arch = self._normalize_arch(arch, model_config)
            model_info = self._try_inspect_model_cls(normalized_arch)
            if model_info is not None:
                # The caller gets the name it asked for, not the normalized one.
                return (model_info, arch)

        # Fallback to transformers impl (before resolving runner_type)
        if (
            all(arch not in self.models for arch in architectures)
            and model_config.model_impl == "auto"
        ):
            resolved = inspect_via_transformers()
            if resolved is not None:
                return resolved

        return self._raise_for_unsupported(architectures)
1084

1085
1086
    def resolve_model_cls(
        self,
        architectures: str | list[str],
        model_config: ModelConfig,
    ) -> tuple[type[nn.Module], str]:
        """Return the model class and name of the first loadable architecture.

        The explicitly requested implementation (`"transformers"` or
        `"terratorch"`) is tried first. Next comes the Transformers fallback
        for unregistered architectures without conversion, then every
        architecture after normalization, and finally the Transformers
        fallback once more. When nothing can be loaded,
        `_raise_for_unsupported` raises a `ValueError`.
        """
        if isinstance(architectures, str):
            architectures = [architectures]
        if not architectures:
            raise ValueError("No model architectures are specified")

        def load_via_transformers() -> tuple[type[nn.Module], str] | None:
            # Resolve the first architecture to a Transformers backend class
            # and load it; `None` means this route is unavailable.
            arch = self._try_resolve_transformers(architectures[0], model_config)
            if arch is None:
                return None
            model_cls = self._try_load_model_cls(arch)
            return None if model_cls is None else (model_cls, arch)

        # Require transformers impl
        if model_config.model_impl == "transformers":
            resolved = load_via_transformers()
            if resolved is not None:
                return resolved
        elif model_config.model_impl == "terratorch":
            terratorch_cls = self._try_load_model_cls("Terratorch")
            if terratorch_cls is not None:
                return (terratorch_cls, "Terratorch")

        # Fallback to transformers impl (after resolving convert_type)
        if (
            all(arch not in self.models for arch in architectures)
            and model_config.model_impl == "auto"
            and getattr(model_config, "convert_type", "none") == "none"
        ):
            resolved = load_via_transformers()
            if resolved is not None:
                return resolved

        for arch in architectures:
            normalized_arch = self._normalize_arch(arch, model_config)
            candidate_cls = self._try_load_model_cls(normalized_arch)
            if candidate_cls is not None:
                return (candidate_cls, arch)

        # Fallback to transformers impl (before resolving runner_type)
        if (
            all(arch not in self.models for arch in architectures)
            and model_config.model_impl == "auto"
        ):
            resolved = load_via_transformers()
            if resolved is not None:
                return resolved

        return self._raise_for_unsupported(architectures)
1138

1139
1140
    def is_text_generation_model(
        self,
1141
        architectures: str | list[str],
1142
        model_config: ModelConfig,
1143
    ) -> bool:
1144
        model_cls, _ = self.inspect_model_cls(architectures, model_config)
1145
        return model_cls.is_text_generation_model
1146

1147
    def is_pooling_model(
1148
        self,
1149
        architectures: str | list[str],
1150
        model_config: ModelConfig,
1151
    ) -> bool:
1152
        model_cls, _ = self.inspect_model_cls(architectures, model_config)
1153
        return model_cls.is_pooling_model
1154

1155
1156
    def is_cross_encoder_model(
        self,
1157
        architectures: str | list[str],
1158
        model_config: ModelConfig,
1159
    ) -> bool:
1160
        model_cls, _ = self.inspect_model_cls(architectures, model_config)
1161
        return model_cls.supports_cross_encoding
1162

1163
1164
    def is_multimodal_model(
        self,
1165
        architectures: str | list[str],
1166
        model_config: ModelConfig,
1167
    ) -> bool:
1168
        model_cls, _ = self.inspect_model_cls(architectures, model_config)
1169
        return model_cls.supports_multimodal
1170

1171
    def is_multimodal_raw_input_only_model(
1172
        self,
1173
        architectures: str | list[str],
1174
        model_config: ModelConfig,
1175
    ) -> bool:
1176
        model_cls, _ = self.inspect_model_cls(architectures, model_config)
1177
        return model_cls.supports_multimodal_raw_input_only
1178

1179
1180
    def is_pp_supported_model(
        self,
1181
        architectures: str | list[str],
1182
        model_config: ModelConfig,
1183
    ) -> bool:
1184
        model_cls, _ = self.inspect_model_cls(architectures, model_config)
1185
        return model_cls.supports_pp
1186

1187
1188
    def model_has_inner_state(
        self,
1189
        architectures: str | list[str],
1190
        model_config: ModelConfig,
1191
    ) -> bool:
1192
        model_cls, _ = self.inspect_model_cls(architectures, model_config)
1193
        return model_cls.has_inner_state
1194

1195
1196
    def is_attention_free_model(
        self,
1197
        architectures: str | list[str],
1198
        model_config: ModelConfig,
1199
    ) -> bool:
1200
        model_cls, _ = self.inspect_model_cls(architectures, model_config)
1201
        return model_cls.is_attention_free
1202

1203
1204
    def is_hybrid_model(
        self,
1205
        architectures: str | list[str],
1206
        model_config: ModelConfig,
1207
    ) -> bool:
1208
        model_cls, _ = self.inspect_model_cls(architectures, model_config)
1209
1210
        return model_cls.is_hybrid

1211
1212
    def is_noops_model(
        self,
1213
        architectures: str | list[str],
1214
        model_config: ModelConfig,
1215
    ) -> bool:
1216
        model_cls, _ = self.inspect_model_cls(architectures, model_config)
1217
1218
        return model_cls.has_noops

1219
1220
    def is_transcription_model(
        self,
1221
        architectures: str | list[str],
1222
        model_config: ModelConfig,
1223
    ) -> bool:
1224
        model_cls, _ = self.inspect_model_cls(architectures, model_config)
1225
1226
        return model_cls.supports_transcription

1227
1228
    def is_transcription_only_model(
        self,
1229
        architectures: str | list[str],
1230
        model_config: ModelConfig,
1231
    ) -> bool:
1232
        model_cls, _ = self.inspect_model_cls(architectures, model_config)
1233
1234
        return model_cls.supports_transcription_only

1235

1236
1237
1238
1239
1240
1241
1242
1243
1244
# Global registry, pre-populated with one lazily imported entry for every
# architecture listed in `_VLLM_MODELS`.
ModelRegistry = _ModelRegistry(
    models={
        arch_name: _LazyRegisteredModel(
            module_name=f"vllm.model_executor.models.{module_stem}",
            class_name=class_name,
        )
        for arch_name, (module_stem, class_name) in _VLLM_MODELS.items()
    }
)
1245
1246
1247
1248
1249

_T = TypeVar("_T")


def _run_in_subprocess(fn: Callable[[], _T]) -> _T:
    """Run `fn` in a separate Python process and return its unpickled result.

    `fn` and an output path are serialized with `cloudpickle` and sent to the
    process started from `_SUBPROCESS_COMMAND` on stdin; the result is read
    back from the output file.

    Raises:
        RuntimeError: If the subprocess exits with a non-zero status; the
            message contains the subprocess's stderr.
    """
    # `cloudpickle` allows pickling lambda functions directly
    import cloudpickle

    # NOTE: We use a temporary directory instead of a temporary file to avoid
    # issues like https://stackoverflow.com/questions/23212435/permission-denied-to-write-to-my-temporary-file
    with tempfile.TemporaryDirectory() as tempdir:
        output_filepath = os.path.join(tempdir, "registry_output.tmp")
        payload = cloudpickle.dumps((fn, output_filepath))

        # cannot use `sys.executable __file__` here because the script
        # contains relative imports
        returned = subprocess.run(
            _SUBPROCESS_COMMAND,
            input=payload,
            capture_output=True,
        )

        # check if the subprocess is successful
        try:
            returned.check_returncode()
        except Exception as e:
            # wrap raised exception to provide more information
            raise RuntimeError(
                f"Error raised in subprocess:\n{returned.stderr.decode()}"
            ) from e

        result_bytes = Path(output_filepath).read_bytes()
        return pickle.loads(result_bytes)


def _run() -> None:
    """Subprocess entry point: run the pickled callable read from stdin.

    Stdin must contain a pickled `(fn, output_file)` pair; the pickled return
    value of `fn()` is written to `output_file`.
    """
    # Setup plugins
    from vllm.plugins import load_general_plugins

    load_general_plugins()

    request = sys.stdin.buffer.read()
    fn, output_file = pickle.loads(request)

    serialized_result = pickle.dumps(fn())

    with open(output_file, "wb") as f:
        f.write(serialized_result)
1291
1292
1293


if __name__ == "__main__":
1294
    _run()