registry.py 46.1 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
4
5
6
"""
Whenever you add an architecture to this page, please also update
`tests/models/registry.py` with example HuggingFace models for it.
"""
7

8
import importlib
9
import json
10
import os
11
import pickle
12
13
import subprocess
import sys
14
import tempfile
15
from abc import ABC, abstractmethod
16
from collections.abc import Callable, Set
17
from dataclasses import asdict, dataclass, field
18
from functools import lru_cache
19
from pathlib import Path
20
from typing import TYPE_CHECKING, Any, TypeVar
21
22

import torch.nn as nn
23
import transformers
24

25
from vllm import envs
26
27
28
29
30
from vllm.config import (
    ModelConfig,
    iter_architecture_defaults,
    try_match_architecture_defaults,
)
31
from vllm.logger import init_logger
32
from vllm.logging_utils import logtime
33
from vllm.transformers_utils.dynamic_module import try_get_class_from_dynamic_module
34
from vllm.utils.hashing import safe_hash
35

36
37
38
39
40
41
42
43
if TYPE_CHECKING:
    from vllm.config.model import AttnTypeStr
    from vllm.config.pooler import PoolingTypeStr
else:
    AttnTypeStr = Any
    PoolingTypeStr = Any


44
45
46
47
48
from .interfaces import (
    has_inner_state,
    has_noops,
    is_attention_free,
    is_hybrid,
Patrick von Platen's avatar
Patrick von Platen committed
49
    requires_raw_input_tokens,
50
    supports_cross_encoding,
51
    supports_mamba_prefix_caching,
52
53
54
55
56
57
58
    supports_multimodal,
    supports_multimodal_encoder_tp_data,
    supports_multimodal_raw_input_only,
    supports_pp,
    supports_transcription,
)
from .interfaces_base import (
59
    get_attn_type,
60
61
62
63
    get_default_pooling_type,
    is_pooling_model,
    is_text_generation_model,
)
64
65
66

logger = init_logger(__name__)

67
68
# Text-generation architectures.
# Keys are architecture names; each value is a `(mod, arch)` pair of strings:
# the module name and the class name implementing that architecture.
# Several keys may share one implementation (e.g. the Llama-compatible ones).
_TEXT_GENERATION_MODELS = {
    # [Decoder-only]
    "AfmoeForCausalLM": ("afmoe", "AfmoeForCausalLM"),
    "ApertusForCausalLM": ("apertus", "ApertusForCausalLM"),
    "AquilaModel": ("llama", "LlamaForCausalLM"),
    "AquilaForCausalLM": ("llama", "LlamaForCausalLM"),  # AquilaChat2
    "ArceeForCausalLM": ("arcee", "ArceeForCausalLM"),
    "ArcticForCausalLM": ("arctic", "ArcticForCausalLM"),
    # baichuan-7b, upper case 'C' in the class name
    "BaiChuanForCausalLM": ("baichuan", "BaiChuanForCausalLM"),
    # baichuan-13b, lower case 'c' in the class name
    "BaichuanForCausalLM": ("baichuan", "BaichuanForCausalLM"),
    "BailingMoeForCausalLM": ("bailing_moe", "BailingMoeForCausalLM"),
    "BailingMoeV2ForCausalLM": ("bailing_moe", "BailingMoeV2ForCausalLM"),
    "BambaForCausalLM": ("bamba", "BambaForCausalLM"),
    "BloomForCausalLM": ("bloom", "BloomForCausalLM"),
    "ChatGLMModel": ("chatglm", "ChatGLMForCausalLM"),
    "ChatGLMForConditionalGeneration": ("chatglm", "ChatGLMForCausalLM"),
    "CohereForCausalLM": ("commandr", "CohereForCausalLM"),
    "Cohere2ForCausalLM": ("commandr", "CohereForCausalLM"),
    "CwmForCausalLM": ("llama", "LlamaForCausalLM"),
    "DbrxForCausalLM": ("dbrx", "DbrxForCausalLM"),
    "DeciLMForCausalLM": ("nemotron_nas", "DeciLMForCausalLM"),
    "DeepseekForCausalLM": ("deepseek_v2", "DeepseekForCausalLM"),
    "DeepseekV2ForCausalLM": ("deepseek_v2", "DeepseekV2ForCausalLM"),
    "DeepseekV3ForCausalLM": ("deepseek_v2", "DeepseekV3ForCausalLM"),
    "DeepseekV32ForCausalLM": ("deepseek_v2", "DeepseekV3ForCausalLM"),
    "Dots1ForCausalLM": ("dots1", "Dots1ForCausalLM"),
    "Ernie4_5ForCausalLM": ("ernie45", "Ernie4_5ForCausalLM"),
    "Ernie4_5_MoeForCausalLM": ("ernie45_moe", "Ernie4_5_MoeForCausalLM"),
    "ExaoneForCausalLM": ("exaone", "ExaoneForCausalLM"),
    "Exaone4ForCausalLM": ("exaone4", "Exaone4ForCausalLM"),
    "Fairseq2LlamaForCausalLM": ("fairseq2_llama", "Fairseq2LlamaForCausalLM"),
    "FalconForCausalLM": ("falcon", "FalconForCausalLM"),
    "FalconMambaForCausalLM": ("mamba", "MambaForCausalLM"),
    "FalconH1ForCausalLM": ("falcon_h1", "FalconH1ForCausalLM"),
    "FlexOlmoForCausalLM": ("flex_olmo", "FlexOlmoForCausalLM"),
    "GemmaForCausalLM": ("gemma", "GemmaForCausalLM"),
    "Gemma2ForCausalLM": ("gemma2", "Gemma2ForCausalLM"),
    "Gemma3ForCausalLM": ("gemma3", "Gemma3ForCausalLM"),
    "Gemma3nForCausalLM": ("gemma3n", "Gemma3nForCausalLM"),
    "Qwen3NextForCausalLM": ("qwen3_next", "Qwen3NextForCausalLM"),
    "GlmForCausalLM": ("glm", "GlmForCausalLM"),
    "Glm4ForCausalLM": ("glm4", "Glm4ForCausalLM"),
    "Glm4MoeForCausalLM": ("glm4_moe", "Glm4MoeForCausalLM"),
    "GptOssForCausalLM": ("gpt_oss", "GptOssForCausalLM"),
    "GPT2LMHeadModel": ("gpt2", "GPT2LMHeadModel"),
    "GPTBigCodeForCausalLM": ("gpt_bigcode", "GPTBigCodeForCausalLM"),
    "GPTJForCausalLM": ("gpt_j", "GPTJForCausalLM"),
    "GPTNeoXForCausalLM": ("gpt_neox", "GPTNeoXForCausalLM"),
    "GraniteForCausalLM": ("granite", "GraniteForCausalLM"),
    "GraniteMoeForCausalLM": ("granitemoe", "GraniteMoeForCausalLM"),
    "GraniteMoeHybridForCausalLM": ("granitemoehybrid", "GraniteMoeHybridForCausalLM"),  # noqa: E501
    "GraniteMoeSharedForCausalLM": ("granitemoeshared", "GraniteMoeSharedForCausalLM"),  # noqa: E501
    "GritLM": ("gritlm", "GritLM"),
    "Grok1ModelForCausalLM": ("grok1", "Grok1ForCausalLM"),
    "HunYuanMoEV1ForCausalLM": ("hunyuan_v1", "HunYuanMoEV1ForCausalLM"),
    "HunYuanDenseV1ForCausalLM": ("hunyuan_v1", "HunYuanDenseV1ForCausalLM"),
    "HCXVisionForCausalLM": ("hyperclovax_vision", "HCXVisionForCausalLM"),
    "InternLMForCausalLM": ("llama", "LlamaForCausalLM"),
    "InternLM2ForCausalLM": ("internlm2", "InternLM2ForCausalLM"),
    "InternLM2VEForCausalLM": ("internlm2_ve", "InternLM2VEForCausalLM"),
    "InternLM3ForCausalLM": ("llama", "LlamaForCausalLM"),
    "JAISLMHeadModel": ("jais", "JAISLMHeadModel"),
    "Jais2ForCausalLM": ("jais2", "Jais2ForCausalLM"),
    "JambaForCausalLM": ("jamba", "JambaForCausalLM"),
    "KimiLinearForCausalLM": ("kimi_linear", "KimiLinearForCausalLM"),  # noqa: E501
    "Lfm2ForCausalLM": ("lfm2", "Lfm2ForCausalLM"),
    "Lfm2MoeForCausalLM": ("lfm2_moe", "Lfm2MoeForCausalLM"),
    "LlamaForCausalLM": ("llama", "LlamaForCausalLM"),
    "Llama4ForCausalLM": ("llama4", "Llama4ForCausalLM"),
    # For decapoda-research/llama-*
    "LLaMAForCausalLM": ("llama", "LlamaForCausalLM"),
    "LongcatFlashForCausalLM": ("longcat_flash", "LongcatFlashForCausalLM"),
    "MambaForCausalLM": ("mamba", "MambaForCausalLM"),
    "Mamba2ForCausalLM": ("mamba2", "Mamba2ForCausalLM"),
    "MiniCPMForCausalLM": ("minicpm", "MiniCPMForCausalLM"),
    "MiniCPM3ForCausalLM": ("minicpm3", "MiniCPM3ForCausalLM"),
    "MiniMaxForCausalLM": ("minimax_text_01", "MiniMaxText01ForCausalLM"),
    "MiniMaxText01ForCausalLM": ("minimax_text_01", "MiniMaxText01ForCausalLM"),
    "MiniMaxM1ForCausalLM": ("minimax_text_01", "MiniMaxText01ForCausalLM"),
    "MiniMaxM2ForCausalLM": ("minimax_m2", "MiniMaxM2ForCausalLM"),
    "MistralForCausalLM": ("llama", "LlamaForCausalLM"),
    "MistralLarge3ForCausalLM": ("mistral_large_3", "MistralLarge3ForCausalLM"),
    "MixtralForCausalLM": ("mixtral", "MixtralForCausalLM"),
    # transformers's mpt class has lower case
    "MptForCausalLM": ("mpt", "MPTForCausalLM"),
    "MPTForCausalLM": ("mpt", "MPTForCausalLM"),
    "MiMoForCausalLM": ("mimo", "MiMoForCausalLM"),
    "MiMoV2FlashForCausalLM": ("mimo_v2_flash", "MiMoV2FlashForCausalLM"),
    "NemotronForCausalLM": ("nemotron", "NemotronForCausalLM"),
    "NemotronHForCausalLM": ("nemotron_h", "NemotronHForCausalLM"),
    "OlmoForCausalLM": ("olmo", "OlmoForCausalLM"),
    "Olmo2ForCausalLM": ("olmo2", "Olmo2ForCausalLM"),
    "Olmo3ForCausalLM": ("olmo2", "Olmo2ForCausalLM"),
    "OlmoeForCausalLM": ("olmoe", "OlmoeForCausalLM"),
    "OPTForCausalLM": ("opt", "OPTForCausalLM"),
    "OrionForCausalLM": ("orion", "OrionForCausalLM"),
    "OuroForCausalLM": ("ouro", "OuroForCausalLM"),
    "PanguEmbeddedForCausalLM": ("openpangu", "PanguEmbeddedForCausalLM"),
    "PanguUltraMoEForCausalLM": ("openpangu", "PanguUltraMoEForCausalLM"),
    "PersimmonForCausalLM": ("persimmon", "PersimmonForCausalLM"),
    "PhiForCausalLM": ("phi", "PhiForCausalLM"),
    "Phi3ForCausalLM": ("phi3", "Phi3ForCausalLM"),
    "PhiMoEForCausalLM": ("phimoe", "PhiMoEForCausalLM"),
    "Plamo2ForCausalLM": ("plamo2", "Plamo2ForCausalLM"),
    "Plamo3ForCausalLM": ("plamo3", "Plamo3ForCausalLM"),
    "QWenLMHeadModel": ("qwen", "QWenLMHeadModel"),
    "Qwen2ForCausalLM": ("qwen2", "Qwen2ForCausalLM"),
    "Qwen2MoeForCausalLM": ("qwen2_moe", "Qwen2MoeForCausalLM"),
    "Qwen3ForCausalLM": ("qwen3", "Qwen3ForCausalLM"),
    "Qwen3MoeForCausalLM": ("qwen3_moe", "Qwen3MoeForCausalLM"),
    "RWForCausalLM": ("falcon", "FalconForCausalLM"),
    "SeedOssForCausalLM": ("seed_oss", "SeedOssForCausalLM"),
    "Step3TextForCausalLM": ("step3_text", "Step3TextForCausalLM"),
    "StableLMEpochForCausalLM": ("stablelm", "StablelmForCausalLM"),
    "StableLmForCausalLM": ("stablelm", "StablelmForCausalLM"),
    "Starcoder2ForCausalLM": ("starcoder2", "Starcoder2ForCausalLM"),
    "SolarForCausalLM": ("solar", "SolarForCausalLM"),
    "TeleChatForCausalLM": ("telechat2", "TeleChat2ForCausalLM"),
    "TeleChat2ForCausalLM": ("telechat2", "TeleChat2ForCausalLM"),
    "TeleFLMForCausalLM": ("teleflm", "TeleFLMForCausalLM"),
    "XverseForCausalLM": ("llama", "LlamaForCausalLM"),
    "Zamba2ForCausalLM": ("zamba2", "Zamba2ForCausalLM"),
}

# Embedding / pooling architectures.
# Keys are architecture names; each value is a `(mod, arch)` pair of strings.
# Every text-generation entry that resolves to `LlamaForCausalLM` is also
# pulled in via the dict comprehension below.
_EMBEDDING_MODELS = {
    # [Text-only]
    "BertModel": ("bert", "BertEmbeddingModel"),
    "BertSpladeSparseEmbeddingModel": ("bert", "BertSpladeSparseEmbeddingModel"),
    "DeciLMForCausalLM": ("nemotron_nas", "DeciLMForCausalLM"),
    "Gemma2Model": ("gemma2", "Gemma2ForCausalLM"),
    "Gemma3TextModel": ("gemma3", "Gemma3Model"),
    "GlmForCausalLM": ("glm", "GlmForCausalLM"),
    "GPT2ForSequenceClassification": ("gpt2", "GPT2ForSequenceClassification"),
    "GritLM": ("gritlm", "GritLM"),
    "GteModel": ("bert_with_rope", "SnowflakeGteNewModel"),
    "GteNewModel": ("bert_with_rope", "GteNewModel"),
    "InternLM2ForRewardModel": ("internlm2", "InternLM2ForRewardModel"),
    "JambaForSequenceClassification": ("jamba", "JambaForSequenceClassification"),  # noqa: E501
    "LlamaBidirectionalModel": ("llama", "LlamaBidirectionalModel"),
    "LlamaModel": ("llama", "LlamaForCausalLM"),
    **{
        # Multiple models share the same architecture, so we include them all
        k: (mod, arch)
        for k, (mod, arch) in _TEXT_GENERATION_MODELS.items()
        if arch == "LlamaForCausalLM"
    },
    "MistralModel": ("llama", "LlamaForCausalLM"),
    "ModernBertModel": ("modernbert", "ModernBertModel"),
    "NomicBertModel": ("bert_with_rope", "NomicBertModel"),
    "Phi3ForCausalLM": ("phi3", "Phi3ForCausalLM"),
    "Qwen2Model": ("qwen2", "Qwen2ForCausalLM"),
    "Qwen2ForCausalLM": ("qwen2", "Qwen2ForCausalLM"),
    "Qwen2ForRewardModel": ("qwen2_rm", "Qwen2ForRewardModel"),
    "Qwen2ForProcessRewardModel": ("qwen2_rm", "Qwen2ForProcessRewardModel"),
    "RobertaForMaskedLM": ("roberta", "RobertaEmbeddingModel"),
    "RobertaModel": ("roberta", "RobertaEmbeddingModel"),
    "TeleChatForCausalLM": ("telechat2", "TeleChat2ForCausalLM"),
    "TeleChat2ForCausalLM": ("telechat2", "TeleChat2ForCausalLM"),
    "XLMRobertaModel": ("roberta", "RobertaEmbeddingModel"),
    # [Multimodal]
    "CLIPModel": ("clip", "CLIPEmbeddingModel"),
    "LlavaNextForConditionalGeneration": (
        "llava_next",
        "LlavaNextForConditionalGeneration",
    ),
    "Phi3VForCausalLM": ("phi3v", "Phi3VForCausalLM"),
    "Qwen2VLForConditionalGeneration": ("qwen2_vl", "Qwen2VLForConditionalGeneration"),  # noqa: E501
    "SiglipModel": ("siglip", "SiglipEmbeddingModel"),
    # Technically, Terratorch models work on images for both input and
    # output. They are listed here because they piggy-back on embedding
    # models for the time being.
    "PrithviGeoSpatialMAE": ("terratorch", "Terratorch"),
    "Terratorch": ("terratorch", "Terratorch"),
}

244
245
# Sequence/token classification architectures grouped as cross-encoders.
# Keys are architecture names; each value is a `(mod, arch)` pair of strings.
_CROSS_ENCODER_MODELS = {
    "BertForSequenceClassification": ("bert", "BertForSequenceClassification"),
    "BertForTokenClassification": ("bert", "BertForTokenClassification"),
    "GteNewForSequenceClassification": (
        "bert_with_rope",
        "GteNewForSequenceClassification",
    ),
    "JinaVLForRanking": ("jina_vl", "JinaVLForSequenceClassification"),
    "LlamaBidirectionalForSequenceClassification": (
        "llama",
        "LlamaBidirectionalForSequenceClassification",
    ),
    "ModernBertForSequenceClassification": (
        "modernbert",
        "ModernBertForSequenceClassification",
    ),
    "ModernBertForTokenClassification": (
        "modernbert",
        "ModernBertForTokenClassification",
    ),
    "RobertaForSequenceClassification": ("roberta", "RobertaForSequenceClassification"),
    "XLMRobertaForSequenceClassification": (
        "roberta",
        "RobertaForSequenceClassification",
    ),
}

271
# Multimodal architectures.
# Keys are architecture names; each value is a `(mod, arch)` pair of strings.
_MULTIMODAL_MODELS = {
    # [Decoder-only]
    "AriaForConditionalGeneration": ("aria", "AriaForConditionalGeneration"),
    "AudioFlamingo3ForConditionalGeneration": (
        "audioflamingo3",
        "AudioFlamingo3ForConditionalGeneration",
    ),
    "AyaVisionForConditionalGeneration": (
        "aya_vision",
        "AyaVisionForConditionalGeneration",
    ),
    "BagelForConditionalGeneration": ("bagel", "BagelForConditionalGeneration"),
    "BeeForConditionalGeneration": ("bee", "BeeForConditionalGeneration"),
    "Blip2ForConditionalGeneration": ("blip2", "Blip2ForConditionalGeneration"),
    "ChameleonForConditionalGeneration": (
        "chameleon",
        "ChameleonForConditionalGeneration",
    ),
    "Cohere2VisionForConditionalGeneration": (
        "cohere2_vision",
        "Cohere2VisionForConditionalGeneration",
    ),
    "DeepseekVLV2ForCausalLM": ("deepseek_vl2", "DeepseekVLV2ForCausalLM"),
    "DeepseekOCRForCausalLM": ("deepseek_ocr", "DeepseekOCRForCausalLM"),
    "DotsOCRForCausalLM": ("dots_ocr", "DotsOCRForCausalLM"),
    "Ernie4_5_VLMoeForConditionalGeneration": (
        "ernie45_vl",
        "Ernie4_5_VLMoeForConditionalGeneration",
    ),
    "FuyuForCausalLM": ("fuyu", "FuyuForCausalLM"),
    "Gemma3ForConditionalGeneration": ("gemma3_mm", "Gemma3ForConditionalGeneration"),  # noqa: E501
    "Gemma3nForConditionalGeneration": (
        "gemma3n_mm",
        "Gemma3nForConditionalGeneration",
    ),
    "GLM4VForCausalLM": ("glm4v", "GLM4VForCausalLM"),
    "Glm4vForConditionalGeneration": ("glm4_1v", "Glm4vForConditionalGeneration"),  # noqa: E501
    "Glm4vMoeForConditionalGeneration": ("glm4_1v", "Glm4vMoeForConditionalGeneration"),  # noqa: E501
    "GraniteSpeechForConditionalGeneration": (
        "granite_speech",
        "GraniteSpeechForConditionalGeneration",
    ),
    "H2OVLChatModel": ("h2ovl", "H2OVLChatModel"),
    "HunYuanVLForConditionalGeneration": (
        "hunyuan_vision",
        "HunYuanVLForConditionalGeneration",
    ),
    "InternVLChatModel": ("internvl", "InternVLChatModel"),
    "NemotronH_Nano_VL_V2": ("nano_nemotron_vl", "NemotronH_Nano_VL_V2"),
    "OpenCUAForConditionalGeneration": (
        "opencua",
        "OpenCUAForConditionalGeneration",
    ),
    "InternS1ForConditionalGeneration": (
        "interns1",
        "InternS1ForConditionalGeneration",
    ),
    "InternVLForConditionalGeneration": (
        "interns1",
        "InternS1ForConditionalGeneration",
    ),
    "Idefics3ForConditionalGeneration": (
        "idefics3",
        "Idefics3ForConditionalGeneration",
    ),
    "SmolVLMForConditionalGeneration": ("smolvlm", "SmolVLMForConditionalGeneration"),  # noqa: E501
    "KeyeForConditionalGeneration": ("keye", "KeyeForConditionalGeneration"),
    "KeyeVL1_5ForConditionalGeneration": (
        "keye_vl1_5",
        "KeyeVL1_5ForConditionalGeneration",
    ),
    "RForConditionalGeneration": ("rvl", "RForConditionalGeneration"),
    "KimiVLForConditionalGeneration": ("kimi_vl", "KimiVLForConditionalGeneration"),  # noqa: E501
    "LightOnOCRForConditionalGeneration": (
        "lightonocr",
        "LightOnOCRForConditionalGeneration",
    ),
    "Llama_Nemotron_Nano_VL": ("nemotron_vl", "LlamaNemotronVLChatModel"),
    "Llama4ForConditionalGeneration": ("mllama4", "Llama4ForConditionalGeneration"),  # noqa: E501
    "LlavaForConditionalGeneration": ("llava", "LlavaForConditionalGeneration"),
    "LlavaNextForConditionalGeneration": (
        "llava_next",
        "LlavaNextForConditionalGeneration",
    ),
    "LlavaNextVideoForConditionalGeneration": (
        "llava_next_video",
        "LlavaNextVideoForConditionalGeneration",
    ),
    "LlavaOnevisionForConditionalGeneration": (
        "llava_onevision",
        "LlavaOnevisionForConditionalGeneration",
    ),
    "MantisForConditionalGeneration": ("llava", "MantisForConditionalGeneration"),  # noqa: E501
    "MiDashengLMModel": ("midashenglm", "MiDashengLMModel"),
    "MiniMaxVL01ForConditionalGeneration": (
        "minimax_vl_01",
        "MiniMaxVL01ForConditionalGeneration",
    ),
    "MiniCPMO": ("minicpmo", "MiniCPMO"),
    "MiniCPMV": ("minicpmv", "MiniCPMV"),
    "Mistral3ForConditionalGeneration": (
        "mistral3",
        "Mistral3ForConditionalGeneration",
    ),
    "MolmoForCausalLM": ("molmo", "MolmoForCausalLM"),
    "NVLM_D": ("nvlm_d", "NVLM_D_Model"),
    "Ovis": ("ovis", "Ovis"),
    "Ovis2_5": ("ovis2_5", "Ovis2_5"),
    "PaddleOCRVLForConditionalGeneration": (
        "paddleocr_vl",
        "PaddleOCRVLForConditionalGeneration",
    ),
    "PaliGemmaForConditionalGeneration": (
        "paligemma",
        "PaliGemmaForConditionalGeneration",
    ),
    "Phi3VForCausalLM": ("phi3v", "Phi3VForCausalLM"),
    "Phi4MMForCausalLM": ("phi4mm", "Phi4MMForCausalLM"),
    "PixtralForConditionalGeneration": ("pixtral", "PixtralForConditionalGeneration"),  # noqa: E501
    "QwenVLForConditionalGeneration": ("qwen_vl", "QwenVLForConditionalGeneration"),  # noqa: E501
    "Qwen2VLForConditionalGeneration": ("qwen2_vl", "Qwen2VLForConditionalGeneration"),  # noqa: E501
    "Qwen2_5_VLForConditionalGeneration": (
        "qwen2_5_vl",
        "Qwen2_5_VLForConditionalGeneration",
    ),
    "Qwen2AudioForConditionalGeneration": (
        "qwen2_audio",
        "Qwen2AudioForConditionalGeneration",
    ),
    "Qwen2_5OmniModel": (
        "qwen2_5_omni_thinker",
        "Qwen2_5OmniThinkerForConditionalGeneration",
    ),
    "Qwen2_5OmniForConditionalGeneration": (
        "qwen2_5_omni_thinker",
        "Qwen2_5OmniThinkerForConditionalGeneration",
    ),
    "Qwen3OmniMoeForConditionalGeneration": (
        "qwen3_omni_moe_thinker",
        "Qwen3OmniMoeThinkerForConditionalGeneration",
    ),
    "Qwen3VLForConditionalGeneration": ("qwen3_vl", "Qwen3VLForConditionalGeneration"),  # noqa: E501
    "Qwen3VLMoeForConditionalGeneration": (
        "qwen3_vl_moe",
        "Qwen3VLMoeForConditionalGeneration",
    ),
    "SkyworkR1VChatModel": ("skyworkr1v", "SkyworkR1VChatModel"),
    "Step3VLForConditionalGeneration": ("step3_vl", "Step3VLForConditionalGeneration"),  # noqa: E501
    "TarsierForConditionalGeneration": ("tarsier", "TarsierForConditionalGeneration"),  # noqa: E501
    "Tarsier2ForConditionalGeneration": (
        "qwen2_vl",
        "Tarsier2ForConditionalGeneration",
    ),
    "UltravoxModel": ("ultravox", "UltravoxModel"),
    "VoxtralForConditionalGeneration": ("voxtral", "VoxtralForConditionalGeneration"),  # noqa: E501
    "VoxtralStreamingGeneration": ("voxtral_streaming", "VoxtralStreamingGeneration"),  # noqa: E501
    # [Encoder-decoder]
    "WhisperForConditionalGeneration": ("whisper", "WhisperForConditionalGeneration"),  # noqa: E501
}
430
431

# Draft / speculator architectures used for speculative decoding.
# Keys are architecture names; each value is a `(mod, arch)` pair of strings.
_SPECULATIVE_DECODING_MODELS = {
    "MiMoMTPModel": ("mimo_mtp", "MiMoMTP"),
    "EagleLlamaForCausalLM": ("llama_eagle", "EagleLlamaForCausalLM"),
    "EagleLlama4ForCausalLM": ("llama4_eagle", "EagleLlama4ForCausalLM"),
    "EagleMiniCPMForCausalLM": ("minicpm_eagle", "EagleMiniCPMForCausalLM"),
    "Eagle3LlamaForCausalLM": ("llama_eagle3", "Eagle3LlamaForCausalLM"),
    "LlamaForCausalLMEagle3": ("llama_eagle3", "Eagle3LlamaForCausalLM"),
    "Eagle3Qwen2_5vlForCausalLM": ("llama_eagle3", "Eagle3LlamaForCausalLM"),
    "Eagle3Qwen3vlForCausalLM": ("llama_eagle3", "Eagle3LlamaForCausalLM"),
    "EagleMistralLarge3ForCausalLM": (
        "mistral_large_3_eagle",
        "EagleMistralLarge3ForCausalLM",
    ),
    "EagleDeepSeekMTPModel": ("deepseek_eagle", "EagleDeepseekV3ForCausalLM"),
    "DeepSeekMTPModel": ("deepseek_mtp", "DeepSeekMTP"),
    "ErnieMTPModel": ("ernie_mtp", "ErnieMTP"),
    "LongCatFlashMTPModel": ("longcat_flash_mtp", "LongCatFlashMTP"),
    "Glm4MoeMTPModel": ("glm4_moe_mtp", "Glm4MoeMTP"),
    "MedusaModel": ("medusa", "Medusa"),
    "OpenPanguMTPModel": ("openpangu_mtp", "OpenPanguMTP"),
    "Qwen3NextMTP": ("qwen3_next_mtp", "Qwen3NextMTP"),
    # Temporarily disabled.
    # # TODO(woosuk): Re-enable this once the MLP Speculator is supported in V1.
    # "MLPSpeculatorPreTrainedModel": ("mlp_speculator", "MLPSpeculator"),
}
456

457
# Architectures mapped onto classes in the "transformers" module.
# Keys are architecture names; each value is a `(mod, arch)` pair of strings.
_TRANSFORMERS_SUPPORTED_MODELS = {
    # Text generation models
    "SmolLM3ForCausalLM": ("transformers", "TransformersForCausalLM"),
    # Multimodal models
    "Emu3ForConditionalGeneration": (
        "transformers",
        "TransformersMultiModalForCausalLM",
    ),
}

# Classes of the "transformers" module, registered under their own names.
# Keys are architecture names; each value is a `(mod, arch)` pair of strings.
_TRANSFORMERS_BACKEND_MODELS = {
    # Text generation models
    "TransformersForCausalLM": ("transformers", "TransformersForCausalLM"),
    "TransformersMoEForCausalLM": ("transformers", "TransformersMoEForCausalLM"),
    # Multimodal models
    "TransformersMultiModalForCausalLM": (
        "transformers",
        "TransformersMultiModalForCausalLM",
    ),
    "TransformersMultiModalMoEForCausalLM": (
        "transformers",
        "TransformersMultiModalMoEForCausalLM",
    ),
    # Embedding models
    "TransformersEmbeddingModel": ("transformers", "TransformersEmbeddingModel"),
    "TransformersMoEEmbeddingModel": ("transformers", "TransformersMoEEmbeddingModel"),
    "TransformersMultiModalEmbeddingModel": (
        "transformers",
        "TransformersMultiModalEmbeddingModel",
    ),
    # Sequence classification models
    "TransformersForSequenceClassification": (
        "transformers",
        "TransformersForSequenceClassification",
    ),
    "TransformersMoEForSequenceClassification": (
        "transformers",
        "TransformersMoEForSequenceClassification",
    ),
    "TransformersMultiModalForSequenceClassification": (
        "transformers",
        "TransformersMultiModalForSequenceClassification",
    ),
}
501

502
# Union of all per-category tables above. When the same key appears in more
# than one table, the table spread later in this literal wins.
_VLLM_MODELS = {
    **_TEXT_GENERATION_MODELS,
    **_EMBEDDING_MODELS,
    **_CROSS_ENCODER_MODELS,
    **_MULTIMODAL_MODELS,
    **_SPECULATIVE_DECODING_MODELS,
    **_TRANSFORMERS_SUPPORTED_MODELS,
    **_TRANSFORMERS_BACKEND_MODELS,
}

512
513
514
515
# This variable is used as the args for subprocess.run(). It can be
# modified to alter the args if needed, e.g. when the par format is used
# to pack things together, sys.executable might not be the target we want
# to run.
_SUBPROCESS_COMMAND = [sys.executable, "-m", "vllm.model_executor.models.registry"]
517

518
# Architectures that are no longer registered, mapped to a version string.
# NOTE(review): presumably the last vLLM release that supported each
# architecture -- confirm against the code that reads this table.
_PREVIOUSLY_SUPPORTED_MODELS = {
    "MotifForCausalLM": "0.10.2",
    "Phi3SmallForCausalLM": "0.9.2",
    "Phi4FlashForCausalLM": "0.10.2",
    "Phi4MultimodalForCausalLM": "0.12.0",
    # encoder-decoder models except whisper
    # have been removed for V0 deprecation.
    "BartModel": "0.10.2",
    "BartForConditionalGeneration": "0.10.2",
    "DonutForConditionalGeneration": "0.10.2",
    "Florence2ForConditionalGeneration": "0.10.2",
    "MBartForConditionalGeneration": "0.10.2",
    "MllamaForConditionalGeneration": "0.10.2",
}
532

533

534
535
@dataclass(frozen=True)
class _ModelInfo:
    """
    Immutable summary of a model class's capabilities.

    Every field except `architecture` is filled in by `from_model_cls` from
    the like-named helper imported from `.interfaces` / `.interfaces_base`.
    """

    # Name of the model class (`model.__name__`).
    architecture: str
    is_text_generation_model: bool
    is_pooling_model: bool
    attn_type: AttnTypeStr
    default_pooling_type: PoolingTypeStr
    supports_cross_encoding: bool
    supports_multimodal: bool
    supports_multimodal_raw_input_only: bool
    requires_raw_input_tokens: bool
    supports_multimodal_encoder_tp_data: bool
    supports_pp: bool
    has_inner_state: bool
    is_attention_free: bool
    is_hybrid: bool
    has_noops: bool
    supports_mamba_prefix_caching: bool
    supports_transcription: bool
    # True only when the class supports transcription AND its own
    # `supports_transcription_only` attribute is truthy.
    supports_transcription_only: bool

    @staticmethod
    def from_model_cls(model: type[nn.Module]) -> "_ModelInfo":
        """
        Build a `_ModelInfo` by probing `model` with the interface helpers.

        Args:
            model: The model class (not an instance) to inspect.

        Returns:
            A frozen `_ModelInfo` describing `model`.
        """
        return _ModelInfo(
            architecture=model.__name__,
            is_text_generation_model=is_text_generation_model(model),
            is_pooling_model=is_pooling_model(model),
            default_pooling_type=get_default_pooling_type(model),
            attn_type=get_attn_type(model),
            supports_cross_encoding=supports_cross_encoding(model),
            supports_multimodal=supports_multimodal(model),
            supports_multimodal_raw_input_only=supports_multimodal_raw_input_only(
                model
            ),
            requires_raw_input_tokens=requires_raw_input_tokens(model),
            supports_multimodal_encoder_tp_data=supports_multimodal_encoder_tp_data(
                model
            ),
            supports_pp=supports_pp(model),
            has_inner_state=has_inner_state(model),
            is_attention_free=is_attention_free(model),
            is_hybrid=is_hybrid(model),
            supports_mamba_prefix_caching=supports_mamba_prefix_caching(model),
            supports_transcription=supports_transcription(model),
            # The attribute is only read when the class supports
            # transcription, thanks to `and` short-circuiting.
            supports_transcription_only=(
                supports_transcription(model) and model.supports_transcription_only
            ),
            has_noops=has_noops(model),
        )
583
584


585
586
587
588
class _BaseRegisteredModel(ABC):
    """Common interface of registry entries."""

    @abstractmethod
    def inspect_model_cls(self) -> _ModelInfo:
        """Return the `_ModelInfo` describing the model class."""
        raise NotImplementedError

    @abstractmethod
    def load_model_cls(self) -> type[nn.Module]:
        """Return the model class itself."""
        raise NotImplementedError
593
594


595
596
597
598
599
600
601
@dataclass(frozen=True)
class _RegisteredModel(_BaseRegisteredModel):
    """
    Represents a model that has already been imported in the main process.
    """

    # Capabilities computed once from `model_cls` at construction time.
    interfaces: _ModelInfo
    # The already-imported model class.
    model_cls: type[nn.Module]

    @staticmethod
    def from_model_cls(model_cls: type[nn.Module]) -> "_RegisteredModel":
        """Wrap `model_cls` together with its eagerly computed `_ModelInfo`."""
        return _RegisteredModel(
            interfaces=_ModelInfo.from_model_cls(model_cls),
            model_cls=model_cls,
        )

    def inspect_model_cls(self) -> _ModelInfo:
        """Return the `_ModelInfo` stored at construction."""
        return self.interfaces

    def load_model_cls(self) -> type[nn.Module]:
        """Return the stored model class."""
        return self.model_cls


@dataclass(frozen=True)
class _LazyRegisteredModel(_BaseRegisteredModel):
    """
    Represents a model that has not been imported in the main process.

    The model class is imported on demand by `load_model_cls`. Its
    `_ModelInfo` is computed in a subprocess and cached on disk as JSON,
    keyed by a hash of the model's source file.
    """

    module_name: str
    class_name: str

    @staticmethod
    def _get_cache_dir() -> Path:
        """Return the directory that holds the cached model-info files."""
        return Path(envs.VLLM_CACHE_ROOT) / "modelinfos"

    def _get_cache_filename(self) -> str:
        """Return the cache file name derived from the module and class name."""
        cls_name = f"{self.module_name}-{self.class_name}".replace(".", "-")
        return f"{cls_name}.json"

    def _load_modelinfo_from_cache(self, module_hash: str) -> _ModelInfo | None:
        """
        Load the cached `_ModelInfo` if it matches `module_hash`.

        Returns:
            The cached info, or `None` when the cache file is missing, was
            written for a different hash, or cannot be read/parsed. This
            method never raises.
        """
        try:
            modelinfo_path = self._get_cache_dir() / self._get_cache_filename()
            with open(modelinfo_path, encoding="utf-8") as file:
                mi_dict = json.load(file)

            if mi_dict["hash"] != module_hash:
                logger.debug(
                    "Cached model info file for class %s.%s is stale",
                    self.module_name,
                    self.class_name,
                )
                return None

            # file not changed, use cached _ModelInfo properties
            return _ModelInfo(**mi_dict["modelinfo"])
        except FileNotFoundError:
            logger.debug(
                "Cached model info file for class %s.%s not found",
                self.module_name,
                self.class_name,
            )
            return None
        except Exception:
            # Any other problem (corrupt JSON, missing keys, fields that no
            # longer match `_ModelInfo`) is treated as a cache miss. The
            # traceback is attached so the cause is visible at debug level.
            logger.debug(
                "Cached model info for class %s.%s error. ",
                self.module_name,
                self.class_name,
                exc_info=True,
            )
            return None

    def _save_modelinfo_to_cache(self, mi: _ModelInfo, module_hash: str) -> None:
        """
        Write `mi` and `module_hash` to the JSON cache file.

        Best effort: any failure is logged and swallowed.
        """
        from vllm.model_executor.model_loader.weight_utils import atomic_writer

        try:
            modelinfo_dict = {
                "hash": module_hash,
                "modelinfo": asdict(mi),
            }
            cache_dir = self._get_cache_dir()
            cache_dir.mkdir(parents=True, exist_ok=True)
            modelinfo_path = cache_dir / self._get_cache_filename()
            with atomic_writer(modelinfo_path, encoding="utf-8") as f:
                json.dump(modelinfo_dict, f, indent=2)
        except Exception:
            logger.exception("Error saving model info cache.")

    @logtime(logger=logger, msg="Registry inspect model class")
    def inspect_model_cls(self) -> _ModelInfo:
        """
        Return the `_ModelInfo` for this model without importing it here.

        If a `<last module component>.py` file exists next to this file, its
        content hash is used to look up (and later refresh) the on-disk
        cache. On a cache miss, or when no such file exists, the model is
        inspected in a subprocess.
        """
        # NOTE(review): the source file is located by the last component of
        # `module_name` only. A module from another package whose basename
        # matches a file in this directory would be fingerprinted with that
        # unrelated file -- confirm whether this is intended.
        model_path = Path(__file__).parent / f"{self.module_name.split('.')[-1]}.py"
        module_hash = None

        if model_path.exists():
            with open(model_path, "rb") as f:
                module_hash = safe_hash(f.read(), usedforsecurity=False).hexdigest()

            mi = self._load_modelinfo_from_cache(module_hash)
            if mi is not None:
                logger.debug(
                    "Loaded model info for class %s.%s from cache",
                    self.module_name,
                    self.class_name,
                )
                return mi

            logger.debug(
                "Cache model info for class %s.%s miss. Loading model instead.",
                self.module_name,
                self.class_name,
            )

        # Performed in another process to avoid initializing CUDA
        mi = _run_in_subprocess(
            lambda: _ModelInfo.from_model_cls(self.load_model_cls())
        )
        logger.debug(
            "Loaded model info for class %s.%s", self.module_name, self.class_name
        )

        # save cache file
        if module_hash is not None:
            self._save_modelinfo_to_cache(mi, module_hash)

        return mi

    def load_model_cls(self) -> type[nn.Module]:
        """Import `module_name` and return its `class_name` attribute."""
        mod = importlib.import_module(self.module_name)
        return getattr(mod, self.class_name)


@lru_cache(maxsize=128)
def _try_load_model_cls(
    model_arch: str,
    model: _BaseRegisteredModel,
) -> type[nn.Module] | None:
    """Load the class behind `model`, returning `None` if loading fails.

    The current platform verifies `model_arch` first; that call sits outside
    the `try`, so whatever it raises propagates to the caller. Exceptions
    raised by `model.load_model_cls()` are logged and turned into `None`.
    Results are memoized per `(model_arch, model)` pair.
    """
    from vllm.platforms import current_platform

    current_platform.verify_model_arch(model_arch)

    try:
        loaded_cls = model.load_model_cls()
    except Exception:
        logger.exception("Error in loading model architecture '%s'", model_arch)
        return None
    else:
        return loaded_cls
740
741


742
743
744
745
@lru_cache(maxsize=128)
def _try_inspect_model_cls(
    model_arch: str,
    model: _BaseRegisteredModel,
) -> _ModelInfo | None:
    """Inspect `model`, returning `None` if inspection fails.

    Exceptions raised by `model.inspect_model_cls()` are logged and turned
    into `None`. Results are memoized per `(model_arch, model)` pair.
    """
    try:
        inspected = model.inspect_model_cls()
    except Exception:
        logger.exception("Error in inspecting model architecture '%s'", model_arch)
        return None
    else:
        return inspected
752
753


754
755
756
@dataclass
class _ModelRegistry:
    # Keyed by model_arch
757
    models: dict[str, _BaseRegisteredModel] = field(default_factory=dict)
758

759
    def get_supported_archs(self) -> Set[str]:
        """Return the keys of `self.models`, i.e. all registered architectures."""
        registered_archs = self.models.keys()
        return registered_archs
761

762
763
764
    def register_model(
        self,
        model_arch: str,
765
        model_cls: type[nn.Module] | str,
766
    ) -> None:
767
768
769
        """
        Register an external model to be used in vLLM.

770
        `model_cls` can be either:
771

772
        - A [`torch.nn.Module`][] class directly referencing the model.
773
        - A string in the format `<module>:<class>` which can be used to
774
775
          lazily import the model. This is useful to avoid initializing CUDA
          when importing the model and thus the related error
776
          `RuntimeError: Cannot re-initialize CUDA in forked subprocess`.
777
        """
778
779
780
781
        if not isinstance(model_arch, str):
            msg = f"`model_arch` should be a string, not a {type(model_arch)}"
            raise TypeError(msg)

782
        if model_arch in self.models:
783
784
            logger.warning(
                "Model architecture %s is already registered, and will be "
785
786
787
788
                "overwritten by the new model class %s.",
                model_arch,
                model_cls,
            )
789
790
791
792
793
794

        if isinstance(model_cls, str):
            split_str = model_cls.split(":")
            if len(split_str) != 2:
                msg = "Expected a string in the format `<module>:<class>`"
                raise ValueError(msg)
795

796
            model = _LazyRegisteredModel(*split_str)
797
        elif isinstance(model_cls, type) and issubclass(model_cls, nn.Module):
798
            model = _RegisteredModel.from_model_cls(model_cls)
799
        else:
800
801
802
803
            msg = (
                "`model_cls` should be a string or PyTorch model class, "
                f"not a {type(model_arch)}"
            )
804
            raise TypeError(msg)
805

806
        self.models[model_arch] = model
807

808
    def _raise_for_unsupported(self, architectures: list[str]):
        """Raise a `ValueError` describing why `architectures` cannot be used.

        This method never returns. The message distinguishes three cases, in
        order: at least one architecture is registered (so inspection must
        have failed), an architecture is listed in
        `_PREVIOUSLY_SUPPORTED_MODELS`, or the architectures are unknown.
        """
        all_supported_archs = self.get_supported_archs()

        registered = [arch for arch in architectures if arch in all_supported_archs]
        if registered:
            raise ValueError(
                f"Model architectures {architectures} failed "
                "to be inspected. Please check the logs for more details."
            )

        # Only the first previously supported architecture is reported
        removed_arch = next(
            (arch for arch in architectures if arch in _PREVIOUSLY_SUPPORTED_MODELS),
            None,
        )
        if removed_arch is not None:
            arch = removed_arch
            previous_version = _PREVIOUSLY_SUPPORTED_MODELS[arch]

            raise ValueError(
                f"Model architecture {arch} was supported in vLLM until "
                f"v{previous_version}, and is not supported anymore. "
                "Please use an older version of vLLM if you want to "
                "use this model architecture."
            )

        raise ValueError(
            f"Model architectures {architectures} are not supported for now. "
            f"Supported architectures: {all_supported_archs}"
        )
832

833
    def _try_load_model_cls(self, model_arch: str) -> type[nn.Module] | None:
        """Load the class registered under `model_arch`.

        Returns `None` when `model_arch` is not registered; otherwise
        delegates to the module-level `_try_load_model_cls`.
        """
        registered = self.models
        if model_arch in registered:
            return _try_load_model_cls(model_arch, registered[model_arch])

        return None
838

839
    def _try_inspect_model_cls(self, model_arch: str) -> _ModelInfo | None:
        """Inspect the model registered under `model_arch`.

        Returns `None` when `model_arch` is not registered; otherwise
        delegates to the module-level `_try_inspect_model_cls`.
        """
        registered = self.models
        if model_arch in registered:
            return _try_inspect_model_cls(model_arch, registered[model_arch])

        return None

    def _try_resolve_transformers(
        self,
        architecture: str,
        model_config: ModelConfig,
    ) -> str | None:
        """Resolve `architecture` to a Transformers backend architecture name.

        `architecture` is returned unchanged if it is already listed in
        `_TRANSFORMERS_BACKEND_MODELS`. Otherwise the model class is looked up
        as an attribute of `transformers`, then through the `AutoModel*`
        entries of the HF config's `auto_map`. If a class is found and its
        `is_backend_compatible()` is truthy, the value of
        `model_config._get_transformers_backend_cls()` is returned.

        When the lookup or the compatibility check fails, `None` is returned,
        unless `model_config.model_impl == "transformers"`, in which case a
        `ValueError` is raised instead.
        """
        if architecture in _TRANSFORMERS_BACKEND_MODELS:
            return architecture

        configured_auto_map = getattr(model_config.hf_config, "auto_map", None)
        auto_map: dict[str, str] = configured_auto_map if configured_auto_map else {}

        # Make sure that config class is always initialized before model class,
        # otherwise the model class won't be able to access the config class,
        # the expected auto_map should have correct order like:
        # "auto_map": {
        #     "AutoConfig": "<your-repo-name>--<config-name>",
        #     "AutoModel": "<your-repo-name>--<config-name>",
        #     "AutoModelFor<Task>": "<your-repo-name>--<config-name>",
        # },
        for prefix in ("AutoConfig", "AutoModel"):
            for key, class_ref in auto_map.items():
                if not key.startswith(prefix):
                    continue
                try_get_class_from_dynamic_module(
                    class_ref,
                    model_config.model,
                    revision=model_config.revision,
                    warn_on_fail=False,
                )

        model_module = getattr(transformers, architecture, None)

        if model_module is None:
            # Take the first `AutoModel*` entry that yields a class
            for key, class_ref in auto_map.items():
                if not key.startswith("AutoModel"):
                    continue
                model_module = try_get_class_from_dynamic_module(
                    class_ref,
                    model_config.model,
                    revision=model_config.revision,
                    warn_on_fail=True,
                )
                if model_module is not None:
                    break

            # Still `None` here means no `AutoModel*` entry produced a class
            if model_module is None:
                if model_config.model_impl != "transformers":
                    return None

                raise ValueError(
                    f"Cannot find model module. {architecture!r} is not a "
                    "registered model in the Transformers library (only "
                    "relevant if the model is meant to be in Transformers) "
                    "and 'AutoModel' is not present in the model config's "
                    "'auto_map' (relevant if the model is custom)."
                )

        if model_module.is_backend_compatible():
            return model_config._get_transformers_backend_cls()

        if model_config.model_impl != "transformers":
            return None

        raise ValueError(
            f"The Transformers implementation of {architecture!r} "
            "is not compatible with vLLM."
        )
910

911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
    def _normalize_arch(
        self,
        architecture: str,
        model_config: ModelConfig,
    ) -> str:
        """Map `architecture` to a registered architecture name if possible.

        A registered `architecture` is returned as-is. Otherwise, if
        `try_match_architecture_defaults` yields a match, the matched suffix is
        swapped for each suffix from `iter_architecture_defaults()` in turn
        and the first registered result is returned. If nothing matches,
        `architecture` is returned unchanged.
        """
        registered = self.models
        if architecture in registered:
            return architecture

        # This may be called in order to resolve runner_type and convert_type
        # in the first place, in which case we consider the default match
        match = try_match_architecture_defaults(
            architecture,
            runner_type=getattr(model_config, "runner_type", None),
            convert_type=getattr(model_config, "convert_type", None),
        )
        if not match:
            return architecture

        suffix, _defaults = match

        # Get the name of the base model to convert
        candidates = (
            architecture.replace(suffix, repl_suffix)
            for repl_suffix, _ in iter_architecture_defaults()
        )
        return next(
            (base_arch for base_arch in candidates if base_arch in registered),
            architecture,
        )
936

937
938
    def inspect_model_cls(
        self,
        architectures: str | list[str],
        model_config: ModelConfig,
    ) -> tuple[_ModelInfo, str]:
        """Inspect the first of `architectures` that can be resolved.

        Candidates are tried in this order, and the first successful
        inspection wins:

        1. `model_impl == "transformers"`: the Transformers backend resolved
           from the first architecture. `model_impl == "terratorch"`: the
           `"Terratorch"` entry.
        2. `model_impl == "auto"`, no architecture registered, and
           `convert_type` equal to `"none"` (or absent): the Transformers
           backend.
        3. Each architecture in order, looked up via `_normalize_arch`.
        4. `model_impl == "auto"` and no architecture registered: the
           Transformers backend.

        Returns:
            `(model_info, arch)`. In step 3, `arch` is the architecture as
            given, not its normalized form.

        Raises:
            ValueError: If `architectures` is empty, or if no candidate could
                be inspected (raised by `_raise_for_unsupported`).
        """
        if isinstance(architectures, str):
            architectures = [architectures]
        if not architectures:
            raise ValueError("No model architectures are specified")

        # Require transformers impl
        if model_config.model_impl == "transformers":
            arch = self._try_resolve_transformers(architectures[0], model_config)
            if arch is not None:
                model_info = self._try_inspect_model_cls(arch)
                if model_info is not None:
                    return (model_info, arch)
        elif model_config.model_impl == "terratorch":
            arch = "Terratorch"
            model_info = self._try_inspect_model_cls(arch)
            # Never hand `None` back as model info: fall through to the
            # remaining candidates, exactly as `resolve_model_cls` does.
            if model_info is not None:
                return (model_info, arch)

        # Fallback to transformers impl (after resolving convert_type)
        if (
            all(arch not in self.models for arch in architectures)
            and model_config.model_impl == "auto"
            and getattr(model_config, "convert_type", "none") == "none"
        ):
            arch = self._try_resolve_transformers(architectures[0], model_config)
            if arch is not None:
                model_info = self._try_inspect_model_cls(arch)
                if model_info is not None:
                    return (model_info, arch)

        for arch in architectures:
            normalized_arch = self._normalize_arch(arch, model_config)
            model_info = self._try_inspect_model_cls(normalized_arch)
            if model_info is not None:
                return (model_info, arch)

        # Fallback to transformers impl (before resolving runner_type)
        if (
            all(arch not in self.models for arch in architectures)
            and model_config.model_impl == "auto"
        ):
            arch = self._try_resolve_transformers(architectures[0], model_config)
            if arch is not None:
                model_info = self._try_inspect_model_cls(arch)
                if model_info is not None:
                    return (model_info, arch)

        # Always raises; `return` only keeps the control flow explicit
        return self._raise_for_unsupported(architectures)
988

989
990
    def resolve_model_cls(
        self,
        architectures: str | list[str],
        model_config: ModelConfig,
    ) -> tuple[type[nn.Module], str]:
        """Load the model class of the first resolvable architecture.

        Candidates are tried in this order, and the first class that loads
        wins:

        1. `model_impl == "transformers"`: the Transformers backend resolved
           from the first architecture. `model_impl == "terratorch"`: the
           `"Terratorch"` entry.
        2. `model_impl == "auto"`, no architecture registered, and
           `convert_type` equal to `"none"` (or absent): the Transformers
           backend.
        3. Each architecture in order, looked up via `_normalize_arch`.
        4. `model_impl == "auto"` and no architecture registered: the
           Transformers backend.

        Returns:
            `(model_cls, arch)`. In step 3, `arch` is the architecture as
            given, not its normalized form.

        Raises:
            ValueError: If `architectures` is empty, or if no candidate could
                be loaded (raised by `_raise_for_unsupported`).
        """
        arch_list = [architectures] if isinstance(architectures, str) else architectures
        if not arch_list:
            raise ValueError("No model architectures are specified")

        def load_via_transformers() -> tuple[type[nn.Module], str] | None:
            # Resolve the first architecture through the Transformers backend
            backend_arch = self._try_resolve_transformers(arch_list[0], model_config)
            if backend_arch is None:
                return None
            backend_cls = self._try_load_model_cls(backend_arch)
            if backend_cls is None:
                return None
            return (backend_cls, backend_arch)

        # Require transformers impl
        if model_config.model_impl == "transformers":
            resolved = load_via_transformers()
            if resolved is not None:
                return resolved
        elif model_config.model_impl == "terratorch":
            terratorch_cls = self._try_load_model_cls("Terratorch")
            if terratorch_cls is not None:
                return (terratorch_cls, "Terratorch")

        # Fallback to transformers impl (after resolving convert_type)
        if (
            not any(name in self.models for name in arch_list)
            and model_config.model_impl == "auto"
            and getattr(model_config, "convert_type", "none") == "none"
        ):
            resolved = load_via_transformers()
            if resolved is not None:
                return resolved

        for name in arch_list:
            candidate_cls = self._try_load_model_cls(
                self._normalize_arch(name, model_config)
            )
            if candidate_cls is not None:
                return (candidate_cls, name)

        # Fallback to transformers impl (before resolving runner_type)
        if (
            not any(name in self.models for name in arch_list)
            and model_config.model_impl == "auto"
        ):
            resolved = load_via_transformers()
            if resolved is not None:
                return resolved

        return self._raise_for_unsupported(arch_list)
1042

1043
1044
    def is_text_generation_model(
        self,
        architectures: str | list[str],
        model_config: ModelConfig,
    ) -> bool:
        """Return the `is_text_generation_model` flag of the inspected model."""
        model_info, _arch = self.inspect_model_cls(architectures, model_config)
        return model_info.is_text_generation_model
1050

1051
    def is_pooling_model(
        self,
        architectures: str | list[str],
        model_config: ModelConfig,
    ) -> bool:
        """Return the `is_pooling_model` flag of the inspected model."""
        model_info, _arch = self.inspect_model_cls(architectures, model_config)
        return model_info.is_pooling_model
1058

1059
1060
    def is_cross_encoder_model(
        self,
        architectures: str | list[str],
        model_config: ModelConfig,
    ) -> bool:
        """Return the `supports_cross_encoding` flag of the inspected model."""
        model_info, _arch = self.inspect_model_cls(architectures, model_config)
        return model_info.supports_cross_encoding
1066

1067
1068
    def is_multimodal_model(
        self,
        architectures: str | list[str],
        model_config: ModelConfig,
    ) -> bool:
        """Return the `supports_multimodal` flag of the inspected model."""
        model_info, _arch = self.inspect_model_cls(architectures, model_config)
        return model_info.supports_multimodal
1074

1075
    def is_multimodal_raw_input_only_model(
        self,
        architectures: str | list[str],
        model_config: ModelConfig,
    ) -> bool:
        """Return the `supports_multimodal_raw_input_only` flag of the model."""
        model_info, _arch = self.inspect_model_cls(architectures, model_config)
        return model_info.supports_multimodal_raw_input_only
1082

1083
1084
    def is_pp_supported_model(
        self,
        architectures: str | list[str],
        model_config: ModelConfig,
    ) -> bool:
        """Return the `supports_pp` flag of the inspected model."""
        model_info, _arch = self.inspect_model_cls(architectures, model_config)
        return model_info.supports_pp
1090

1091
1092
    def model_has_inner_state(
        self,
        architectures: str | list[str],
        model_config: ModelConfig,
    ) -> bool:
        """Return the `has_inner_state` flag of the inspected model."""
        model_info, _arch = self.inspect_model_cls(architectures, model_config)
        return model_info.has_inner_state
1098

1099
1100
    def is_attention_free_model(
        self,
        architectures: str | list[str],
        model_config: ModelConfig,
    ) -> bool:
        """Return the `is_attention_free` flag of the inspected model."""
        model_info, _arch = self.inspect_model_cls(architectures, model_config)
        return model_info.is_attention_free
1106

1107
1108
    def is_hybrid_model(
        self,
        architectures: str | list[str],
        model_config: ModelConfig,
    ) -> bool:
        """Return the `is_hybrid` flag of the inspected model."""
        model_info, _arch = self.inspect_model_cls(architectures, model_config)
        return model_info.is_hybrid

1115
1116
    def is_noops_model(
        self,
        architectures: str | list[str],
        model_config: ModelConfig,
    ) -> bool:
        """Return the `has_noops` flag of the inspected model."""
        model_info, _arch = self.inspect_model_cls(architectures, model_config)
        return model_info.has_noops

1123
1124
    def is_transcription_model(
        self,
        architectures: str | list[str],
        model_config: ModelConfig,
    ) -> bool:
        """Return the `supports_transcription` flag of the inspected model."""
        model_info, _arch = self.inspect_model_cls(architectures, model_config)
        return model_info.supports_transcription

1131
1132
    def is_transcription_only_model(
        self,
        architectures: str | list[str],
        model_config: ModelConfig,
    ) -> bool:
        """Return the `supports_transcription_only` flag of the inspected model."""
        model_info, _arch = self.inspect_model_cls(architectures, model_config)
        return model_info.supports_transcription_only

1139

1140
1141
1142
1143
1144
1145
1146
1147
1148
# Registry instance holding one lazily imported entry per `_VLLM_MODELS` item
ModelRegistry = _ModelRegistry(
    models={
        arch_name: _LazyRegisteredModel(
            module_name=f"vllm.model_executor.models.{module_relname}",
            class_name=class_name,
        )
        for arch_name, (module_relname, class_name) in _VLLM_MODELS.items()
    }
)

# Return type of the callable passed to `_run_in_subprocess`
_T = TypeVar("_T")


def _run_in_subprocess(fn: Callable[[], _T]) -> _T:
    """Run `fn` in a child process and return its unpickled result.

    `fn` and an output path inside a temporary directory are serialized with
    `cloudpickle` and passed on stdin to `_SUBPROCESS_COMMAND`. The result is
    then read back from that output path.

    Raises:
        RuntimeError: If the subprocess exits with a non-zero return code;
            the message contains the decoded stderr of the subprocess.
    """
    # `cloudpickle` allows pickling lambda functions directly
    import cloudpickle

    # NOTE: We use a temporary directory instead of a temporary file to avoid
    # issues like https://stackoverflow.com/questions/23212435/permission-denied-to-write-to-my-temporary-file
    with tempfile.TemporaryDirectory() as workdir:
        output_filepath = os.path.join(workdir, "registry_output.tmp")
        payload = cloudpickle.dumps((fn, output_filepath))

        # cannot use `sys.executable __file__` here because the script
        # contains relative imports
        completed = subprocess.run(
            _SUBPROCESS_COMMAND, input=payload, capture_output=True
        )

        # check if the subprocess is successful
        try:
            completed.check_returncode()
        except Exception as exc:
            # wrap raised exception to provide more information
            raise RuntimeError(
                f"Error raised in subprocess:\n{completed.stderr.decode()}"
            ) from exc

        # Read the result before the temporary directory is removed
        return pickle.loads(Path(output_filepath).read_bytes())


def _run() -> None:
    """Execute a pickled callable read from stdin and pickle its result.

    stdin must contain a pickled `(fn, output_file)` pair; `fn()` is called
    and its pickled return value is written to `output_file`. General plugins
    are loaded before the payload is unpickled.
    """
    # Setup plugins
    from vllm.plugins import load_general_plugins

    load_general_plugins()

    # NOTE(review): presumably the payload comes from `_run_in_subprocess`;
    # unpickling stdin is only safe for such a trusted parent process.
    request = sys.stdin.buffer.read()
    fn, output_file = pickle.loads(request)

    result = fn()

    with open(output_file, "wb") as out:
        serialized = pickle.dumps(result)
        out.write(serialized)
1195
1196
1197


# When executed as a script, run the pickled callable supplied on stdin
if __name__ == "__main__":
    _run()