registry.py 45.6 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
4
5
6
"""
Whenever you add an architecture to this file, please also update
`tests/models/registry.py` with example HuggingFace models for it.
"""
7

8
import importlib
9
import json
10
import os
11
import pickle
12
13
import subprocess
import sys
14
import tempfile
15
from abc import ABC, abstractmethod
16
from collections.abc import Callable, Set
17
from dataclasses import asdict, dataclass, field
18
from functools import lru_cache
19
from pathlib import Path
20
from typing import TYPE_CHECKING, Any, TypeVar
21
22

import torch.nn as nn
23
import transformers
24

25
from vllm import envs
26
27
28
29
30
from vllm.config import (
    ModelConfig,
    iter_architecture_defaults,
    try_match_architecture_defaults,
)
31
from vllm.logger import init_logger
32
from vllm.logging_utils import logtime
33
from vllm.transformers_utils.dynamic_module import try_get_class_from_dynamic_module
34
from vllm.utils.hashing import safe_hash
35

36
37
38
39
40
41
42
43
if TYPE_CHECKING:
    from vllm.config.model import AttnTypeStr
    from vllm.config.pooler import PoolingTypeStr
else:
    AttnTypeStr = Any
    PoolingTypeStr = Any


44
45
46
47
48
49
from .interfaces import (
    has_inner_state,
    has_noops,
    is_attention_free,
    is_hybrid,
    supports_cross_encoding,
50
    supports_mamba_prefix_caching,
51
52
53
54
55
56
57
    supports_multimodal,
    supports_multimodal_encoder_tp_data,
    supports_multimodal_raw_input_only,
    supports_pp,
    supports_transcription,
)
from .interfaces_base import (
58
    get_attn_type,
59
60
61
62
    get_default_pooling_type,
    is_pooling_model,
    is_text_generation_model,
)
63
64
65

# Logger for this module, named after ``__name__``.
logger = init_logger(__name__)

66
67
# Text-generation architectures.
# Maps architecture name -> (module name, class name) of the vLLM
# implementation. Several architecture names may share one implementation
# (e.g. the many Llama-compatible names mapping to ("llama", "LlamaForCausalLM")).
# Entry order is preserved on purpose; it is the iteration order of the table.
_TEXT_GENERATION_MODELS = {
    # [Decoder-only]
    "AfmoeForCausalLM": ("afmoe", "AfmoeForCausalLM"),
    "ApertusForCausalLM": ("apertus", "ApertusForCausalLM"),
    "AquilaModel": ("llama", "LlamaForCausalLM"),
    "AquilaForCausalLM": ("llama", "LlamaForCausalLM"),  # AquilaChat2
    "ArceeForCausalLM": ("arcee", "ArceeForCausalLM"),
    "ArcticForCausalLM": ("arctic", "ArcticForCausalLM"),
    # baichuan-7b, upper case 'C' in the class name
    "BaiChuanForCausalLM": ("baichuan", "BaiChuanForCausalLM"),
    # baichuan-13b, lower case 'c' in the class name
    "BaichuanForCausalLM": ("baichuan", "BaichuanForCausalLM"),
    "BailingMoeForCausalLM": ("bailing_moe", "BailingMoeForCausalLM"),
    "BailingMoeV2ForCausalLM": ("bailing_moe", "BailingMoeV2ForCausalLM"),
    "BambaForCausalLM": ("bamba", "BambaForCausalLM"),
    "BloomForCausalLM": ("bloom", "BloomForCausalLM"),
    "ChatGLMModel": ("chatglm", "ChatGLMForCausalLM"),
    "ChatGLMForConditionalGeneration": ("chatglm", "ChatGLMForCausalLM"),
    "CohereForCausalLM": ("commandr", "CohereForCausalLM"),
    "Cohere2ForCausalLM": ("commandr", "CohereForCausalLM"),
    "CwmForCausalLM": ("llama", "LlamaForCausalLM"),
    "DbrxForCausalLM": ("dbrx", "DbrxForCausalLM"),
    "DeciLMForCausalLM": ("nemotron_nas", "DeciLMForCausalLM"),
    "DeepseekForCausalLM": ("deepseek_v2", "DeepseekForCausalLM"),
    "DeepseekV2ForCausalLM": ("deepseek_v2", "DeepseekV2ForCausalLM"),
    "DeepseekV3ForCausalLM": ("deepseek_v2", "DeepseekV3ForCausalLM"),
    "DeepseekV32ForCausalLM": ("deepseek_v2", "DeepseekV3ForCausalLM"),
    "Dots1ForCausalLM": ("dots1", "Dots1ForCausalLM"),
    "Ernie4_5ForCausalLM": ("ernie45", "Ernie4_5ForCausalLM"),
    "Ernie4_5_MoeForCausalLM": ("ernie45_moe", "Ernie4_5_MoeForCausalLM"),
    "ExaoneForCausalLM": ("exaone", "ExaoneForCausalLM"),
    "Exaone4ForCausalLM": ("exaone4", "Exaone4ForCausalLM"),
    "Fairseq2LlamaForCausalLM": ("fairseq2_llama", "Fairseq2LlamaForCausalLM"),
    "FalconForCausalLM": ("falcon", "FalconForCausalLM"),
    "FalconMambaForCausalLM": ("mamba", "MambaForCausalLM"),
    "FalconH1ForCausalLM": ("falcon_h1", "FalconH1ForCausalLM"),
    "FlexOlmoForCausalLM": ("flex_olmo", "FlexOlmoForCausalLM"),
    "GemmaForCausalLM": ("gemma", "GemmaForCausalLM"),
    "Gemma2ForCausalLM": ("gemma2", "Gemma2ForCausalLM"),
    "Gemma3ForCausalLM": ("gemma3", "Gemma3ForCausalLM"),
    "Gemma3nForCausalLM": ("gemma3n", "Gemma3nForCausalLM"),
    "Qwen3NextForCausalLM": ("qwen3_next", "Qwen3NextForCausalLM"),
    "GlmForCausalLM": ("glm", "GlmForCausalLM"),
    "Glm4ForCausalLM": ("glm4", "Glm4ForCausalLM"),
    "Glm4MoeForCausalLM": ("glm4_moe", "Glm4MoeForCausalLM"),
    "GptOssForCausalLM": ("gpt_oss", "GptOssForCausalLM"),
    "GPT2LMHeadModel": ("gpt2", "GPT2LMHeadModel"),
    "GPTBigCodeForCausalLM": ("gpt_bigcode", "GPTBigCodeForCausalLM"),
    "GPTJForCausalLM": ("gpt_j", "GPTJForCausalLM"),
    "GPTNeoXForCausalLM": ("gpt_neox", "GPTNeoXForCausalLM"),
    "GraniteForCausalLM": ("granite", "GraniteForCausalLM"),
    "GraniteMoeForCausalLM": ("granitemoe", "GraniteMoeForCausalLM"),
    "GraniteMoeHybridForCausalLM": ("granitemoehybrid", "GraniteMoeHybridForCausalLM"),  # noqa: E501
    "GraniteMoeSharedForCausalLM": ("granitemoeshared", "GraniteMoeSharedForCausalLM"),  # noqa: E501
    "GritLM": ("gritlm", "GritLM"),
    "Grok1ModelForCausalLM": ("grok1", "Grok1ForCausalLM"),
    "HunYuanMoEV1ForCausalLM": ("hunyuan_v1", "HunYuanMoEV1ForCausalLM"),
    "HunYuanDenseV1ForCausalLM": ("hunyuan_v1", "HunYuanDenseV1ForCausalLM"),
    "HCXVisionForCausalLM": ("hyperclovax_vision", "HCXVisionForCausalLM"),
    "InternLMForCausalLM": ("llama", "LlamaForCausalLM"),
    "InternLM2ForCausalLM": ("internlm2", "InternLM2ForCausalLM"),
    "InternLM2VEForCausalLM": ("internlm2_ve", "InternLM2VEForCausalLM"),
    "InternLM3ForCausalLM": ("llama", "LlamaForCausalLM"),
    "JAISLMHeadModel": ("jais", "JAISLMHeadModel"),
    "JambaForCausalLM": ("jamba", "JambaForCausalLM"),
    "KimiLinearForCausalLM": ("kimi_linear", "KimiLinearForCausalLM"),
    "Lfm2ForCausalLM": ("lfm2", "Lfm2ForCausalLM"),
    "Lfm2MoeForCausalLM": ("lfm2_moe", "Lfm2MoeForCausalLM"),
    "LlamaForCausalLM": ("llama", "LlamaForCausalLM"),
    "Llama4ForCausalLM": ("llama4", "Llama4ForCausalLM"),
    # For decapoda-research/llama-*
    "LLaMAForCausalLM": ("llama", "LlamaForCausalLM"),
    "LongcatFlashForCausalLM": ("longcat_flash", "LongcatFlashForCausalLM"),
    "MambaForCausalLM": ("mamba", "MambaForCausalLM"),
    "Mamba2ForCausalLM": ("mamba2", "Mamba2ForCausalLM"),
    "MiniCPMForCausalLM": ("minicpm", "MiniCPMForCausalLM"),
    "MiniCPM3ForCausalLM": ("minicpm3", "MiniCPM3ForCausalLM"),
    "MiniMaxForCausalLM": ("minimax_text_01", "MiniMaxText01ForCausalLM"),
    "MiniMaxText01ForCausalLM": ("minimax_text_01", "MiniMaxText01ForCausalLM"),
    "MiniMaxM1ForCausalLM": ("minimax_text_01", "MiniMaxText01ForCausalLM"),
    "MiniMaxM2ForCausalLM": ("minimax_m2", "MiniMaxM2ForCausalLM"),
    "MistralForCausalLM": ("llama", "LlamaForCausalLM"),
    "MistralLarge3ForCausalLM": ("mistral_large_3", "MistralLarge3ForCausalLM"),
    "MixtralForCausalLM": ("mixtral", "MixtralForCausalLM"),
    # The transformers MPT class name uses lower case ("Mpt...")
    "MptForCausalLM": ("mpt", "MPTForCausalLM"),
    "MPTForCausalLM": ("mpt", "MPTForCausalLM"),
    "MiMoForCausalLM": ("mimo", "MiMoForCausalLM"),
    "NemotronForCausalLM": ("nemotron", "NemotronForCausalLM"),
    "NemotronHForCausalLM": ("nemotron_h", "NemotronHForCausalLM"),
    "OlmoForCausalLM": ("olmo", "OlmoForCausalLM"),
    "Olmo2ForCausalLM": ("olmo2", "Olmo2ForCausalLM"),
    "Olmo3ForCausalLM": ("olmo2", "Olmo2ForCausalLM"),
    "OlmoeForCausalLM": ("olmoe", "OlmoeForCausalLM"),
    "OPTForCausalLM": ("opt", "OPTForCausalLM"),
    "OrionForCausalLM": ("orion", "OrionForCausalLM"),
    "OuroForCausalLM": ("ouro", "OuroForCausalLM"),
    "PanguEmbeddedForCausalLM": ("openpangu", "PanguEmbeddedForCausalLM"),
    "PanguUltraMoEForCausalLM": ("openpangu", "PanguUltraMoEForCausalLM"),
    "PersimmonForCausalLM": ("persimmon", "PersimmonForCausalLM"),
    "PhiForCausalLM": ("phi", "PhiForCausalLM"),
    "Phi3ForCausalLM": ("phi3", "Phi3ForCausalLM"),
    "PhiMoEForCausalLM": ("phimoe", "PhiMoEForCausalLM"),
    "Plamo2ForCausalLM": ("plamo2", "Plamo2ForCausalLM"),
    "Plamo3ForCausalLM": ("plamo3", "Plamo3ForCausalLM"),
    "QWenLMHeadModel": ("qwen", "QWenLMHeadModel"),
    "Qwen2ForCausalLM": ("qwen2", "Qwen2ForCausalLM"),
    "Qwen2MoeForCausalLM": ("qwen2_moe", "Qwen2MoeForCausalLM"),
    "Qwen3ForCausalLM": ("qwen3", "Qwen3ForCausalLM"),
    "Qwen3MoeForCausalLM": ("qwen3_moe", "Qwen3MoeForCausalLM"),
    "RWForCausalLM": ("falcon", "FalconForCausalLM"),
    "SeedOssForCausalLM": ("seed_oss", "SeedOssForCausalLM"),
    "Step3TextForCausalLM": ("step3_text", "Step3TextForCausalLM"),
    "StableLMEpochForCausalLM": ("stablelm", "StablelmForCausalLM"),
    "StableLmForCausalLM": ("stablelm", "StablelmForCausalLM"),
    "Starcoder2ForCausalLM": ("starcoder2", "Starcoder2ForCausalLM"),
    "SolarForCausalLM": ("solar", "SolarForCausalLM"),
    "TeleChatForCausalLM": ("telechat2", "TeleChat2ForCausalLM"),
    "TeleChat2ForCausalLM": ("telechat2", "TeleChat2ForCausalLM"),
    "TeleFLMForCausalLM": ("teleflm", "TeleFLMForCausalLM"),
    "XverseForCausalLM": ("llama", "LlamaForCausalLM"),
    "Zamba2ForCausalLM": ("zamba2", "Zamba2ForCausalLM"),
}

# Embedding / pooling architectures.
# Maps architecture name -> (module name, class name) of the vLLM
# implementation. Some entries reuse a text-generation implementation.
_EMBEDDING_MODELS = {
    # [Text-only]
    "BertModel": ("bert", "BertEmbeddingModel"),
    "BertSpladeSparseEmbeddingModel": ("bert", "BertSpladeSparseEmbeddingModel"),
    "DeciLMForCausalLM": ("nemotron_nas", "DeciLMForCausalLM"),
    "Gemma2Model": ("gemma2", "Gemma2ForCausalLM"),
    "Gemma3TextModel": ("gemma3", "Gemma3Model"),
    "GlmForCausalLM": ("glm", "GlmForCausalLM"),
    "GPT2ForSequenceClassification": ("gpt2", "GPT2ForSequenceClassification"),
    "GritLM": ("gritlm", "GritLM"),
    "GteModel": ("bert_with_rope", "SnowflakeGteNewModel"),
    "GteNewModel": ("bert_with_rope", "GteNewModel"),
    "InternLM2ForRewardModel": ("internlm2", "InternLM2ForRewardModel"),
    "JambaForSequenceClassification": ("jamba", "JambaForSequenceClassification"),  # noqa: E501
    "LlamaModel": ("llama", "LlamaForCausalLM"),
    **{
        # Every text-generation architecture name backed by LlamaForCausalLM
        # is also registered here, since they all share that implementation.
        k: (mod, arch)
        for k, (mod, arch) in _TEXT_GENERATION_MODELS.items()
        if arch == "LlamaForCausalLM"
    },
    "MistralModel": ("llama", "LlamaForCausalLM"),
    "ModernBertModel": ("modernbert", "ModernBertModel"),
    "NomicBertModel": ("bert_with_rope", "NomicBertModel"),
    "Phi3ForCausalLM": ("phi3", "Phi3ForCausalLM"),
    "Qwen2Model": ("qwen2", "Qwen2ForCausalLM"),
    "Qwen2ForCausalLM": ("qwen2", "Qwen2ForCausalLM"),
    "Qwen2ForRewardModel": ("qwen2_rm", "Qwen2ForRewardModel"),
    "Qwen2ForProcessRewardModel": ("qwen2_rm", "Qwen2ForProcessRewardModel"),
    "RobertaForMaskedLM": ("roberta", "RobertaEmbeddingModel"),
    "RobertaModel": ("roberta", "RobertaEmbeddingModel"),
    "TeleChatForCausalLM": ("telechat2", "TeleChat2ForCausalLM"),
    "TeleChat2ForCausalLM": ("telechat2", "TeleChat2ForCausalLM"),
    "XLMRobertaModel": ("roberta", "RobertaEmbeddingModel"),
    # [Multimodal]
    "CLIPModel": ("clip", "CLIPEmbeddingModel"),
    "LlavaNextForConditionalGeneration": (
        "llava_next",
        "LlavaNextForConditionalGeneration",
    ),
    "Phi3VForCausalLM": ("phi3v", "Phi3VForCausalLM"),
    "Qwen2VLForConditionalGeneration": ("qwen2_vl", "Qwen2VLForConditionalGeneration"),  # noqa: E501
    "SiglipModel": ("siglip", "SiglipEmbeddingModel"),
    # Technically Terratorch models work on images, both as input and output.
    # They are listed here because they piggy-back on the embedding model
    # support for the time being.
    "PrithviGeoSpatialMAE": ("terratorch", "Terratorch"),
    "Terratorch": ("terratorch", "Terratorch"),
}

240
241
# Cross-encoder / classification architectures.
# Maps architecture name -> (module name, class name) of the vLLM
# implementation.
_CROSS_ENCODER_MODELS = {
    "BertForSequenceClassification": ("bert", "BertForSequenceClassification"),
    "BertForTokenClassification": ("bert", "BertForTokenClassification"),
    "GteNewForSequenceClassification": (
        "bert_with_rope",
        "GteNewForSequenceClassification",
    ),
    "ModernBertForSequenceClassification": (
        "modernbert",
        "ModernBertForSequenceClassification",
    ),
    "ModernBertForTokenClassification": (
        "modernbert",
        "ModernBertForTokenClassification",
    ),
    "RobertaForSequenceClassification": ("roberta", "RobertaForSequenceClassification"),
    "XLMRobertaForSequenceClassification": (
        "roberta",
        "RobertaForSequenceClassification",
    ),
    # [Auto-converted (see adapters.py)]
    "JinaVLForRanking": ("jina_vl", "JinaVLForSequenceClassification"),
}

264
# Multimodal architectures.
# Maps architecture name -> (module name, class name) of the vLLM
# implementation.
_MULTIMODAL_MODELS = {
    # [Decoder-only]
    "AriaForConditionalGeneration": ("aria", "AriaForConditionalGeneration"),
    "AudioFlamingo3ForConditionalGeneration": (
        "audioflamingo3",
        "AudioFlamingo3ForConditionalGeneration",
    ),
    "AyaVisionForConditionalGeneration": (
        "aya_vision",
        "AyaVisionForConditionalGeneration",
    ),
    "BagelForConditionalGeneration": ("bagel", "BagelForConditionalGeneration"),
    "BeeForConditionalGeneration": ("bee", "BeeForConditionalGeneration"),
    "Blip2ForConditionalGeneration": ("blip2", "Blip2ForConditionalGeneration"),
    "ChameleonForConditionalGeneration": (
        "chameleon",
        "ChameleonForConditionalGeneration",
    ),
    "Cohere2VisionForConditionalGeneration": (
        "cohere2_vision",
        "Cohere2VisionForConditionalGeneration",
    ),
    "DeepseekVLV2ForCausalLM": ("deepseek_vl2", "DeepseekVLV2ForCausalLM"),
    "DeepseekOCRForCausalLM": ("deepseek_ocr", "DeepseekOCRForCausalLM"),
    "DotsOCRForCausalLM": ("dots_ocr", "DotsOCRForCausalLM"),
    "Ernie4_5_VLMoeForConditionalGeneration": (
        "ernie45_vl",
        "Ernie4_5_VLMoeForConditionalGeneration",
    ),
    "FuyuForCausalLM": ("fuyu", "FuyuForCausalLM"),
    "Gemma3ForConditionalGeneration": ("gemma3_mm", "Gemma3ForConditionalGeneration"),  # noqa: E501
    "Gemma3nForConditionalGeneration": (
        "gemma3n_mm",
        "Gemma3nForConditionalGeneration",
    ),
    "GLM4VForCausalLM": ("glm4v", "GLM4VForCausalLM"),
    "Glm4vForConditionalGeneration": ("glm4_1v", "Glm4vForConditionalGeneration"),  # noqa: E501
    "Glm4vMoeForConditionalGeneration": ("glm4_1v", "Glm4vMoeForConditionalGeneration"),  # noqa: E501
    "GraniteSpeechForConditionalGeneration": (
        "granite_speech",
        "GraniteSpeechForConditionalGeneration",
    ),
    "H2OVLChatModel": ("h2ovl", "H2OVLChatModel"),
    "HunYuanVLForConditionalGeneration": (
        "hunyuan_vision",
        "HunYuanVLForConditionalGeneration",
    ),
    "InternVLChatModel": ("internvl", "InternVLChatModel"),
    "NemotronH_Nano_VL_V2": ("nano_nemotron_vl", "NemotronH_Nano_VL_V2"),
    "OpenCUAForConditionalGeneration": (
        "opencua",
        "OpenCUAForConditionalGeneration",
    ),
    "InternS1ForConditionalGeneration": (
        "interns1",
        "InternS1ForConditionalGeneration",
    ),
    "InternVLForConditionalGeneration": (
        "interns1",
        "InternS1ForConditionalGeneration",
    ),
    "Idefics3ForConditionalGeneration": (
        "idefics3",
        "Idefics3ForConditionalGeneration",
    ),
    "SmolVLMForConditionalGeneration": ("smolvlm", "SmolVLMForConditionalGeneration"),  # noqa: E501
    "KeyeForConditionalGeneration": ("keye", "KeyeForConditionalGeneration"),
    "KeyeVL1_5ForConditionalGeneration": (
        "keye_vl1_5",
        "KeyeVL1_5ForConditionalGeneration",
    ),
    "RForConditionalGeneration": ("rvl", "RForConditionalGeneration"),
    "KimiVLForConditionalGeneration": ("kimi_vl", "KimiVLForConditionalGeneration"),  # noqa: E501
    "LightOnOCRForConditionalGeneration": (
        "lightonocr",
        "LightOnOCRForConditionalGeneration",
    ),
    "Llama_Nemotron_Nano_VL": ("nemotron_vl", "LlamaNemotronVLChatModel"),
    "Llama4ForConditionalGeneration": ("mllama4", "Llama4ForConditionalGeneration"),  # noqa: E501
    "LlavaForConditionalGeneration": ("llava", "LlavaForConditionalGeneration"),
    "LlavaNextForConditionalGeneration": (
        "llava_next",
        "LlavaNextForConditionalGeneration",
    ),
    "LlavaNextVideoForConditionalGeneration": (
        "llava_next_video",
        "LlavaNextVideoForConditionalGeneration",
    ),
    "LlavaOnevisionForConditionalGeneration": (
        "llava_onevision",
        "LlavaOnevisionForConditionalGeneration",
    ),
    "MantisForConditionalGeneration": ("llava", "MantisForConditionalGeneration"),  # noqa: E501
    "MiDashengLMModel": ("midashenglm", "MiDashengLMModel"),
    "MiniMaxVL01ForConditionalGeneration": (
        "minimax_vl_01",
        "MiniMaxVL01ForConditionalGeneration",
    ),
    "MiniCPMO": ("minicpmo", "MiniCPMO"),
    "MiniCPMV": ("minicpmv", "MiniCPMV"),
    "Mistral3ForConditionalGeneration": (
        "mistral3",
        "Mistral3ForConditionalGeneration",
    ),
    "MolmoForCausalLM": ("molmo", "MolmoForCausalLM"),
    "NVLM_D": ("nvlm_d", "NVLM_D_Model"),
    "Ovis": ("ovis", "Ovis"),
    "Ovis2_5": ("ovis2_5", "Ovis2_5"),
    "PaddleOCRVLForConditionalGeneration": (
        "paddleocr_vl",
        "PaddleOCRVLForConditionalGeneration",
    ),
    "PaliGemmaForConditionalGeneration": (
        "paligemma",
        "PaliGemmaForConditionalGeneration",
    ),
    "Phi3VForCausalLM": ("phi3v", "Phi3VForCausalLM"),
    "Phi4MMForCausalLM": ("phi4mm", "Phi4MMForCausalLM"),
    "PixtralForConditionalGeneration": ("pixtral", "PixtralForConditionalGeneration"),  # noqa: E501
    "QwenVLForConditionalGeneration": ("qwen_vl", "QwenVLForConditionalGeneration"),  # noqa: E501
    "Qwen2VLForConditionalGeneration": ("qwen2_vl", "Qwen2VLForConditionalGeneration"),  # noqa: E501
    "Qwen2_5_VLForConditionalGeneration": (
        "qwen2_5_vl",
        "Qwen2_5_VLForConditionalGeneration",
    ),
    "Qwen2AudioForConditionalGeneration": (
        "qwen2_audio",
        "Qwen2AudioForConditionalGeneration",
    ),
    "Qwen2_5OmniModel": (
        "qwen2_5_omni_thinker",
        "Qwen2_5OmniThinkerForConditionalGeneration",
    ),
    "Qwen2_5OmniForConditionalGeneration": (
        "qwen2_5_omni_thinker",
        "Qwen2_5OmniThinkerForConditionalGeneration",
    ),
    "Qwen3OmniMoeForConditionalGeneration": (
        "qwen3_omni_moe_thinker",
        "Qwen3OmniMoeThinkerForConditionalGeneration",
    ),
    "Qwen3VLForConditionalGeneration": ("qwen3_vl", "Qwen3VLForConditionalGeneration"),  # noqa: E501
    "Qwen3VLMoeForConditionalGeneration": (
        "qwen3_vl_moe",
        "Qwen3VLMoeForConditionalGeneration",
    ),
    "SkyworkR1VChatModel": ("skyworkr1v", "SkyworkR1VChatModel"),
    "Step3VLForConditionalGeneration": ("step3_vl", "Step3VLForConditionalGeneration"),  # noqa: E501
    "TarsierForConditionalGeneration": ("tarsier", "TarsierForConditionalGeneration"),  # noqa: E501
    "Tarsier2ForConditionalGeneration": (
        "qwen2_vl",
        "Tarsier2ForConditionalGeneration",
    ),
    "UltravoxModel": ("ultravox", "UltravoxModel"),
    "VoxtralForConditionalGeneration": ("voxtral", "VoxtralForConditionalGeneration"),  # noqa: E501
    # [Encoder-decoder]
    "WhisperForConditionalGeneration": ("whisper", "WhisperForConditionalGeneration"),  # noqa: E501
}
422
423

# Draft / speculative-decoding architectures (EAGLE, MTP, Medusa, ...).
# Maps architecture name -> (module name, class name) of the vLLM
# implementation.
_SPECULATIVE_DECODING_MODELS = {
    "MiMoMTPModel": ("mimo_mtp", "MiMoMTP"),
    "EagleLlamaForCausalLM": ("llama_eagle", "EagleLlamaForCausalLM"),
    "EagleLlama4ForCausalLM": ("llama4_eagle", "EagleLlama4ForCausalLM"),
    "EagleMiniCPMForCausalLM": ("minicpm_eagle", "EagleMiniCPMForCausalLM"),
    "Eagle3LlamaForCausalLM": ("llama_eagle3", "Eagle3LlamaForCausalLM"),
    "LlamaForCausalLMEagle3": ("llama_eagle3", "Eagle3LlamaForCausalLM"),
    "Eagle3Qwen2_5vlForCausalLM": ("llama_eagle3", "Eagle3LlamaForCausalLM"),
    "Eagle3Qwen3vlForCausalLM": ("llama_eagle3", "Eagle3LlamaForCausalLM"),
    "EagleMistralLarge3ForCausalLM": (
        "mistral_large_3_eagle",
        "EagleMistralLarge3ForCausalLM",
    ),
    "EagleDeepSeekMTPModel": ("deepseek_eagle", "EagleDeepseekV3ForCausalLM"),
    "DeepSeekMTPModel": ("deepseek_mtp", "DeepSeekMTP"),
    "ErnieMTPModel": ("ernie_mtp", "ErnieMTP"),
    "LongCatFlashMTPModel": ("longcat_flash_mtp", "LongCatFlashMTP"),
    "Glm4MoeMTPModel": ("glm4_moe_mtp", "Glm4MoeMTP"),
    "MedusaModel": ("medusa", "Medusa"),
    "OpenPanguMTPModel": ("openpangu_mtp", "OpenPanguMTP"),
    "Qwen3NextMTP": ("qwen3_next_mtp", "Qwen3NextMTP"),
    # Temporarily disabled.
    # # TODO(woosuk): Re-enable this once the MLP Speculator is supported in V1.
    # "MLPSpeculatorPreTrainedModel": ("mlp_speculator", "MLPSpeculator"),
}
448

449
# Architectures with no native vLLM implementation that are served through
# the Transformers backend classes.
# Maps architecture name -> (module name, class name).
_TRANSFORMERS_SUPPORTED_MODELS = {
    # Text generation models
    "SmolLM3ForCausalLM": ("transformers", "TransformersForCausalLM"),
    # Multimodal models
    "Emu3ForConditionalGeneration": (
        "transformers",
        "TransformersMultiModalForCausalLM",
    ),
}

# The Transformers backend classes themselves, each registered under its own
# class name and implemented in the "transformers" module.
_TRANSFORMERS_BACKEND_MODELS = {
    # Text generation models
    "TransformersForCausalLM": ("transformers", "TransformersForCausalLM"),
    "TransformersMoEForCausalLM": ("transformers", "TransformersMoEForCausalLM"),
    # Multimodal models
    "TransformersMultiModalForCausalLM": (
        "transformers",
        "TransformersMultiModalForCausalLM",
    ),
    "TransformersMultiModalMoEForCausalLM": (
        "transformers",
        "TransformersMultiModalMoEForCausalLM",
    ),
    # Embedding models
    "TransformersEmbeddingModel": ("transformers", "TransformersEmbeddingModel"),
    "TransformersMoEEmbeddingModel": ("transformers", "TransformersMoEEmbeddingModel"),
    "TransformersMultiModalEmbeddingModel": (
        "transformers",
        "TransformersMultiModalEmbeddingModel",
    ),
    # Sequence classification models
    "TransformersForSequenceClassification": (
        "transformers",
        "TransformersForSequenceClassification",
    ),
    "TransformersMoEForSequenceClassification": (
        "transformers",
        "TransformersMoEForSequenceClassification",
    ),
    "TransformersMultiModalForSequenceClassification": (
        "transformers",
        "TransformersMultiModalForSequenceClassification",
    ),
}
493

494
# Union of every architecture table. With dict unpacking, a key present in
# several tables keeps its first insertion position but takes the value from
# the last table listed here.
_VLLM_MODELS = {
    **_TEXT_GENERATION_MODELS,
    **_EMBEDDING_MODELS,
    **_CROSS_ENCODER_MODELS,
    **_MULTIMODAL_MODELS,
    **_SPECULATIVE_DECODING_MODELS,
    **_TRANSFORMERS_SUPPORTED_MODELS,
    **_TRANSFORMERS_BACKEND_MODELS,
}

504
505
506
507
# Command used as the args of subprocess.run() when this module is launched
# as a helper subprocess. It is a module-level variable so it can be replaced
# when needed: e.g. when things are packed together in PAR format,
# sys.executable might not be the target we want to run.
_SUBPROCESS_COMMAND = [sys.executable, "-m", "vllm.model_executor.models.registry"]
509

510
# Architectures that vLLM no longer supports, mapped to a vLLM version string
# (presumably the last release that still supported them — confirm against
# how callers phrase their error message).
_PREVIOUSLY_SUPPORTED_MODELS = {
    "MotifForCausalLM": "0.10.2",
    "Phi3SmallForCausalLM": "0.9.2",
    "Phi4FlashForCausalLM": "0.10.2",
    "Phi4MultimodalForCausalLM": "0.12.0",
    # All encoder-decoder models except Whisper were removed as part of the
    # V0 deprecation.
    "BartModel": "0.10.2",
    "BartForConditionalGeneration": "0.10.2",
    "DonutForConditionalGeneration": "0.10.2",
    "Florence2ForConditionalGeneration": "0.10.2",
    "MBartForConditionalGeneration": "0.10.2",
    "MllamaForConditionalGeneration": "0.10.2",
}
524

525

526
527
@dataclass(frozen=True)
class _ModelInfo:
    """Immutable snapshot of a model class's capabilities.

    Every field is the result of one interface helper applied to the class
    (see ``from_model_cls``). The record is plain data, so it can be turned
    into a dict with ``dataclasses.asdict`` and rebuilt with
    ``_ModelInfo(**d)``; the lazy registry entry relies on this for its JSON
    cache.
    """

    architecture: str
    is_text_generation_model: bool
    is_pooling_model: bool
    attn_type: AttnTypeStr
    default_pooling_type: PoolingTypeStr
    supports_cross_encoding: bool
    supports_multimodal: bool
    supports_multimodal_raw_input_only: bool
    supports_multimodal_encoder_tp_data: bool
    supports_pp: bool
    has_inner_state: bool
    is_attention_free: bool
    is_hybrid: bool
    has_noops: bool
    supports_mamba_prefix_caching: bool
    supports_transcription: bool
    supports_transcription_only: bool

    @staticmethod
    def from_model_cls(model: type[nn.Module]) -> "_ModelInfo":
        """Probe ``model`` with each interface helper and collect the results."""
        transcribes = supports_transcription(model)
        # Only transcription-capable models carry `supports_transcription_only`,
        # so the attribute is read only when the first check is truthy.
        transcription_only = transcribes and model.supports_transcription_only

        return _ModelInfo(
            architecture=model.__name__,
            is_text_generation_model=is_text_generation_model(model),
            is_pooling_model=is_pooling_model(model),
            default_pooling_type=get_default_pooling_type(model),
            attn_type=get_attn_type(model),
            supports_cross_encoding=supports_cross_encoding(model),
            supports_multimodal=supports_multimodal(model),
            supports_multimodal_raw_input_only=(
                supports_multimodal_raw_input_only(model)
            ),
            supports_multimodal_encoder_tp_data=(
                supports_multimodal_encoder_tp_data(model)
            ),
            supports_pp=supports_pp(model),
            has_inner_state=has_inner_state(model),
            is_attention_free=is_attention_free(model),
            is_hybrid=is_hybrid(model),
            supports_mamba_prefix_caching=supports_mamba_prefix_caching(model),
            supports_transcription=transcribes,
            supports_transcription_only=transcription_only,
            has_noops=has_noops(model),
        )
573
574


575
576
577
578
class _BaseRegisteredModel(ABC):
    """Abstract registry entry able to describe and to load one model class."""

    @abstractmethod
    def inspect_model_cls(self) -> _ModelInfo:
        """Return the capability snapshot of the underlying model class."""
        raise NotImplementedError()

    @abstractmethod
    def load_model_cls(self) -> type[nn.Module]:
        """Return the underlying model class object."""
        raise NotImplementedError()
583
584


585
586
587
588
589
590
591
@dataclass(frozen=True)
class _RegisteredModel(_BaseRegisteredModel):
    """
    Represents a model that has already been imported in the main process.

    Both the class object and its ``_ModelInfo`` are held directly; the info
    is computed once by ``from_model_cls`` and returned as-is afterwards.
    """

    interfaces: _ModelInfo
    model_cls: type[nn.Module]

    @staticmethod
    def from_model_cls(model_cls: type[nn.Module]) -> "_RegisteredModel":
        """Wrap an already-imported ``model_cls``, inspecting it immediately."""
        info = _ModelInfo.from_model_cls(model_cls)
        return _RegisteredModel(interfaces=info, model_cls=model_cls)

    def inspect_model_cls(self) -> _ModelInfo:
        """Return the capability snapshot captured at construction time."""
        return self.interfaces

    def load_model_cls(self) -> type[nn.Module]:
        """Return the wrapped model class; nothing needs to be imported."""
        return self.model_cls


@dataclass(frozen=True)
class _LazyRegisteredModel(_BaseRegisteredModel):
    """
    Represents a model that has not been imported in the main process.

    ``inspect_model_cls`` computes the model's ``_ModelInfo`` in another
    process (to avoid initializing CUDA here). For modules that live directly
    in this package, the result is cached as JSON under
    ``<VLLM_CACHE_ROOT>/modelinfos`` and validated against a hash of the
    module's source file.
    """

    # Dotted import path of the module that defines `class_name`.
    module_name: str
    # Name of the model class inside `module_name`.
    class_name: str

    @staticmethod
    def _get_cache_dir() -> Path:
        """Directory that holds the per-model JSON model-info cache files."""
        return Path(envs.VLLM_CACHE_ROOT) / "modelinfos"

    def _get_cache_filename(self) -> str:
        """Cache file name built from the module and class names, dots -> dashes."""
        cls_name = f"{self.module_name}-{self.class_name}".replace(".", "-")
        return f"{cls_name}.json"

    def _get_module_path(self) -> Path | None:
        """Return the source file of ``module_name`` if it belongs to this package.

        Only modules that sit directly in this package (the directory holding
        this file) can be hashed to validate the cache. Any other module -- e.g.
        a plugin's ``"my_plugin.models.llama"`` -- yields ``None`` even if a file
        with the same basename exists here. Hashing that unrelated in-tree file
        would keep serving a stale cached ``_ModelInfo`` after the plugin's own
        module changes.
        """
        package, _, stem = self.module_name.rpartition(".")
        if not __package__ or package != __package__:
            return None
        model_path = Path(__file__).parent / f"{stem}.py"
        return model_path if model_path.exists() else None

    def _load_modelinfo_from_cache(self, module_hash: str) -> _ModelInfo | None:
        """Return the cached ``_ModelInfo`` when present and matching ``module_hash``.

        A missing file, a hash mismatch, or any error while reading/rebuilding
        the record (e.g. fields that no longer match ``_ModelInfo``) is logged
        at debug level and reported as a miss (``None``).
        """
        try:
            try:
                modelinfo_path = self._get_cache_dir() / self._get_cache_filename()
                with open(modelinfo_path, encoding="utf-8") as file:
                    mi_dict = json.load(file)
            except FileNotFoundError:
                logger.debug(
                    "Cached model info file for class %s.%s not found",
                    self.module_name,
                    self.class_name,
                )
                return None

            if mi_dict["hash"] != module_hash:
                logger.debug(
                    "Cached model info file for class %s.%s is stale",
                    self.module_name,
                    self.class_name,
                )
                return None

            # file not changed, use cached _ModelInfo properties
            return _ModelInfo(**mi_dict["modelinfo"])
        except Exception:
            logger.debug(
                "Cached model info for class %s.%s error. ",
                self.module_name,
                self.class_name,
            )
            return None

    def _save_modelinfo_to_cache(self, mi: _ModelInfo, module_hash: str) -> None:
        """Write ``mi`` plus ``module_hash`` to the JSON cache via ``atomic_writer``.

        Failures are logged and swallowed: caching is best-effort.
        """
        from vllm.model_executor.model_loader.weight_utils import atomic_writer

        try:
            modelinfo_dict = {
                "hash": module_hash,
                "modelinfo": asdict(mi),
            }
            cache_dir = self._get_cache_dir()
            cache_dir.mkdir(parents=True, exist_ok=True)
            modelinfo_path = cache_dir / self._get_cache_filename()
            with atomic_writer(modelinfo_path, encoding="utf-8") as f:
                json.dump(modelinfo_dict, f, indent=2)
        except Exception:
            logger.exception("Error saving model info cache.")

    @logtime(logger=logger, msg="Registry inspect model class")
    def inspect_model_cls(self) -> _ModelInfo:
        """Return the model's ``_ModelInfo``, using the on-disk cache when valid.

        On a miss (or when the module is not hashable, see ``_get_module_path``)
        the class is imported and inspected in another process; the result is
        written back to the cache only if a module hash was computed.
        """
        model_path = self._get_module_path()
        module_hash = None

        if model_path is not None:
            with open(model_path, "rb") as f:
                module_hash = safe_hash(f.read(), usedforsecurity=False).hexdigest()

            mi = self._load_modelinfo_from_cache(module_hash)
            if mi is not None:
                logger.debug(
                    "Loaded model info for class %s.%s from cache",
                    self.module_name,
                    self.class_name,
                )
                return mi
            else:
                logger.debug(
                    "Cache model info for class %s.%s miss. Loading model instead.",
                    self.module_name,
                    self.class_name,
                )

        # Performed in another process to avoid initializing CUDA
        mi = _run_in_subprocess(
            lambda: _ModelInfo.from_model_cls(self.load_model_cls())
        )
        logger.debug(
            "Loaded model info for class %s.%s", self.module_name, self.class_name
        )

        # save cache file
        if module_hash is not None:
            self._save_modelinfo_to_cache(mi, module_hash)

        return mi

    def load_model_cls(self) -> type[nn.Module]:
        """Import ``module_name`` and return its ``class_name`` attribute."""
        mod = importlib.import_module(self.module_name)
        return getattr(mod, self.class_name)


@lru_cache(maxsize=128)
def _try_load_model_cls(
    model_arch: str,
    model: _BaseRegisteredModel,
) -> type[nn.Module] | None:
    """Load the model class behind `model`, or return None on failure.

    The current platform first verifies `model_arch`; anything raised by that
    check propagates. Errors raised while loading the class are logged with
    a traceback and turned into a None result. Because of `lru_cache`, the
    outcome (including None) is memoized per `(model_arch, model)` pair.
    """
    from vllm.platforms import current_platform

    current_platform.verify_model_arch(model_arch)

    try:
        loaded_cls = model.load_model_cls()
    except Exception:
        logger.exception("Error in loading model architecture '%s'", model_arch)
        loaded_cls = None
    return loaded_cls


@lru_cache(maxsize=128)
def _try_inspect_model_cls(
    model_arch: str,
    model: _BaseRegisteredModel,
) -> _ModelInfo | None:
    """Return the model info for `model`, or None if inspection fails.

    Exceptions are logged with a traceback instead of being raised. Because
    of `lru_cache`, the outcome (including None) is memoized per
    `(model_arch, model)` pair.
    """
    try:
        info = model.inspect_model_cls()
    except Exception:
        logger.exception("Error in inspecting model architecture '%s'", model_arch)
        info = None
    return info


@dataclass
class _ModelRegistry:
    """Registry of model architectures that vLLM can run.

    Each architecture name maps to a registered model entry: either a model
    class registered directly or a lazily imported `<module>:<class>`
    reference. Resolution honours `model_config.model_impl` and may fall back
    to the Transformers modeling backend for architectures that are not
    registered here.
    """

    # Keyed by model_arch
    models: dict[str, _BaseRegisteredModel] = field(default_factory=dict)

    def get_supported_archs(self) -> Set[str]:
        """Return a live view of all registered architecture names."""
        return self.models.keys()

    def register_model(
        self,
        model_arch: str,
        model_cls: type[nn.Module] | str,
    ) -> None:
        """
        Register an external model to be used in vLLM.

        `model_cls` can be either:

        - A [`torch.nn.Module`][] class directly referencing the model.
        - A string in the format `<module>:<class>` which can be used to
          lazily import the model. This is useful to avoid initializing CUDA
          when importing the model and thus the related error
          `RuntimeError: Cannot re-initialize CUDA in forked subprocess`.

        Registering an architecture that already exists overwrites the
        previous entry and logs a warning.

        Raises:
            TypeError: If `model_arch` is not a string, or `model_cls` is
                neither a string nor an `nn.Module` subclass.
            ValueError: If a string `model_cls` is not `<module>:<class>`.
        """
        if not isinstance(model_arch, str):
            msg = f"`model_arch` should be a string, not a {type(model_arch)}"
            raise TypeError(msg)

        if isinstance(model_cls, str):
            split_str = model_cls.split(":")
            if len(split_str) != 2:
                msg = "Expected a string in the format `<module>:<class>`"
                raise ValueError(msg)

            model = _LazyRegisteredModel(*split_str)
        elif isinstance(model_cls, type) and issubclass(model_cls, nn.Module):
            model = _RegisteredModel.from_model_cls(model_cls)
        else:
            # Report the type of the offending `model_cls`, not `model_arch`.
            msg = (
                "`model_cls` should be a string or PyTorch model class, "
                f"not a {type(model_cls)}"
            )
            raise TypeError(msg)

        # Warn only once the new entry is known to be valid, i.e. when an
        # existing registration is really about to be replaced.
        if model_arch in self.models:
            logger.warning(
                "Model architecture %s is already registered, and will be "
                "overwritten by the new model class %s.",
                model_arch,
                model_cls,
            )

        self.models[model_arch] = model

    def _raise_for_unsupported(self, architectures: list[str]):
        """Raise a `ValueError` explaining why `architectures` can't be used.

        Always raises. The message distinguishes registered architectures
        that failed to load/inspect, architectures that were removed in an
        earlier vLLM release, and architectures that are simply unknown.
        """
        all_supported_archs = self.get_supported_archs()

        if any(arch in all_supported_archs for arch in architectures):
            raise ValueError(
                f"Model architectures {architectures} failed "
                "to be inspected. Please check the logs for more details."
            )

        for arch in architectures:
            if arch in _PREVIOUSLY_SUPPORTED_MODELS:
                previous_version = _PREVIOUSLY_SUPPORTED_MODELS[arch]

                raise ValueError(
                    f"Model architecture {arch} was supported in vLLM until "
                    f"v{previous_version}, and is not supported anymore. "
                    "Please use an older version of vLLM if you want to "
                    "use this model architecture."
                )

        raise ValueError(
            f"Model architectures {architectures} are not supported for now. "
            f"Supported architectures: {all_supported_archs}"
        )

    def _try_load_model_cls(self, model_arch: str) -> type[nn.Module] | None:
        """Load the class registered for `model_arch`, or None if unavailable."""
        if model_arch not in self.models:
            return None

        return _try_load_model_cls(model_arch, self.models[model_arch])

    def _try_inspect_model_cls(self, model_arch: str) -> _ModelInfo | None:
        """Inspect the class registered for `model_arch`, or None if unavailable."""
        if model_arch not in self.models:
            return None

        return _try_inspect_model_cls(model_arch, self.models[model_arch])

    def _try_resolve_transformers(
        self,
        architecture: str,
        model_config: ModelConfig,
    ) -> str | None:
        """Map `architecture` onto the Transformers modeling backend.

        Returns `architecture` itself if it is already a Transformers backend
        architecture. Otherwise the model class is looked up in `transformers`
        or through the config's `auto_map` (dynamic modules). If it is found
        and backend compatible, the result of
        `model_config._get_transformers_backend_cls()` is returned.

        Returns None when the model cannot be used with the backend, unless
        `model_config.model_impl == "transformers"`. In that case a
        `ValueError` is raised instead.
        """
        if architecture in _TRANSFORMERS_BACKEND_MODELS:
            return architecture

        auto_map: dict[str, str] = (
            getattr(model_config.hf_config, "auto_map", None) or dict()
        )

        # Make sure that config class is always initialized before model class,
        # otherwise the model class won't be able to access the config class,
        # the expected auto_map should have correct order like:
        # "auto_map": {
        #     "AutoConfig": "<your-repo-name>--<config-name>",
        #     "AutoModel": "<your-repo-name>--<config-name>",
        #     "AutoModelFor<Task>": "<your-repo-name>--<config-name>",
        # },
        for prefix in ("AutoConfig", "AutoModel"):
            for name, module in auto_map.items():
                if name.startswith(prefix):
                    try_get_class_from_dynamic_module(
                        module,
                        model_config.model,
                        revision=model_config.revision,
                        warn_on_fail=False,
                    )

        model_module = getattr(transformers, architecture, None)

        if model_module is None:
            for name, module in auto_map.items():
                if name.startswith("AutoModel"):
                    model_module = try_get_class_from_dynamic_module(
                        module,
                        model_config.model,
                        revision=model_config.revision,
                        warn_on_fail=True,
                    )
                    if model_module is not None:
                        break
            else:
                # No `AutoModel*` entry produced a class.
                if model_config.model_impl != "transformers":
                    return None

                raise ValueError(
                    f"Cannot find model module. {architecture!r} is not a "
                    "registered model in the Transformers library (only "
                    "relevant if the model is meant to be in Transformers) "
                    "and 'AutoModel' is not present in the model config's "
                    "'auto_map' (relevant if the model is custom)."
                )

        if not model_module.is_backend_compatible():
            if model_config.model_impl != "transformers":
                return None

            raise ValueError(
                f"The Transformers implementation of {architecture!r} "
                "is not compatible with vLLM."
            )

        return model_config._get_transformers_backend_cls()

    def _normalize_arch(
        self,
        architecture: str,
        model_config: ModelConfig,
    ) -> str:
        """Return the registered architecture to use for `architecture`.

        Registered names are returned as-is. Otherwise, if `architecture`
        ends with one of the architecture-default suffixes, that suffix is
        swapped for each default suffix in turn. The first resulting name
        that is registered is returned (the base model to convert). Falls
        back to `architecture` unchanged.
        """
        if architecture in self.models:
            return architecture

        # This may be called in order to resolve runner_type and convert_type
        # in the first place, in which case we consider the default match
        match = try_match_architecture_defaults(
            architecture,
            runner_type=getattr(model_config, "runner_type", None),
            convert_type=getattr(model_config, "convert_type", None),
        )
        if match:
            suffix, _ = match

            # Get the name of the base model to convert
            for repl_suffix, _ in iter_architecture_defaults():
                base_arch = architecture.replace(suffix, repl_suffix)
                if base_arch in self.models:
                    return base_arch

        return architecture

    def _resolve_with(
        self,
        architectures: str | list[str],
        model_config: ModelConfig,
        try_arch: Callable[[str], Any],
    ) -> tuple[Any, str]:
        """Shared resolution order of `inspect_model_cls`/`resolve_model_cls`.

        `try_arch` maps a registered architecture name to a result, or to
        None when that architecture cannot be used. Candidates are tried in
        this order:
        1. the explicitly requested implementation (`transformers`, `terratorch`);
        2. the Transformers backend when nothing is registered, `model_impl`
           is `"auto"` and no conversion is requested;
        3. each requested architecture, after normalization;
        4. the Transformers backend when nothing is registered and
           `model_impl` is `"auto"`.

        Returns:
            `(result, arch)`. For step 3, `arch` is the architecture as
            requested (not the normalized name).

        Raises:
            ValueError: If no architectures are given or none resolves.
        """
        if isinstance(architectures, str):
            architectures = [architectures]
        if not architectures:
            raise ValueError("No model architectures are specified")

        def try_transformers() -> tuple[Any, str] | None:
            # Try the Transformers backend for the first requested architecture.
            arch = self._try_resolve_transformers(architectures[0], model_config)
            if arch is None:
                return None
            result = try_arch(arch)
            return None if result is None else (result, arch)

        # Require transformers impl
        if model_config.model_impl == "transformers":
            if (found := try_transformers()) is not None:
                return found
        elif model_config.model_impl == "terratorch":
            # Only return a successful result, so that a failure surfaces via
            # `_raise_for_unsupported` instead of as a `None` model info/class.
            result = try_arch("Terratorch")
            if result is not None:
                return (result, "Terratorch")

        # Fallback to transformers impl (after resolving convert_type)
        if (
            all(arch not in self.models for arch in architectures)
            and model_config.model_impl == "auto"
            and getattr(model_config, "convert_type", "none") == "none"
            and (found := try_transformers()) is not None
        ):
            return found

        for arch in architectures:
            normalized_arch = self._normalize_arch(arch, model_config)
            result = try_arch(normalized_arch)
            if result is not None:
                return (result, arch)

        # Fallback to transformers impl (before resolving runner_type)
        if (
            all(arch not in self.models for arch in architectures)
            and model_config.model_impl == "auto"
            and (found := try_transformers()) is not None
        ):
            return found

        return self._raise_for_unsupported(architectures)

    def inspect_model_cls(
        self,
        architectures: str | list[str],
        model_config: ModelConfig,
    ) -> tuple[_ModelInfo, str]:
        """Return `(model_info, arch)` for the first usable architecture.

        Raises:
            ValueError: If no architectures are given or none can be inspected.
        """
        return self._resolve_with(
            architectures, model_config, self._try_inspect_model_cls
        )

    def resolve_model_cls(
        self,
        architectures: str | list[str],
        model_config: ModelConfig,
    ) -> tuple[type[nn.Module], str]:
        """Return `(model_class, arch)` for the first loadable architecture.

        Raises:
            ValueError: If no architectures are given or none can be loaded.
        """
        return self._resolve_with(
            architectures, model_config, self._try_load_model_cls
        )

    def is_text_generation_model(
        self,
        architectures: str | list[str],
        model_config: ModelConfig,
    ) -> bool:
        """Whether the resolved model is a text generation model."""
        model_info, _ = self.inspect_model_cls(architectures, model_config)
        return model_info.is_text_generation_model

    def is_pooling_model(
        self,
        architectures: str | list[str],
        model_config: ModelConfig,
    ) -> bool:
        """Whether the resolved model is a pooling model."""
        model_info, _ = self.inspect_model_cls(architectures, model_config)
        return model_info.is_pooling_model

    def is_cross_encoder_model(
        self,
        architectures: str | list[str],
        model_config: ModelConfig,
    ) -> bool:
        """Whether the resolved model supports cross-encoding."""
        model_info, _ = self.inspect_model_cls(architectures, model_config)
        return model_info.supports_cross_encoding

    def is_multimodal_model(
        self,
        architectures: str | list[str],
        model_config: ModelConfig,
    ) -> bool:
        """Whether the resolved model supports multimodal inputs."""
        model_info, _ = self.inspect_model_cls(architectures, model_config)
        return model_info.supports_multimodal

    def is_multimodal_raw_input_only_model(
        self,
        architectures: str | list[str],
        model_config: ModelConfig,
    ) -> bool:
        """Whether the resolved model only supports raw multimodal input."""
        model_info, _ = self.inspect_model_cls(architectures, model_config)
        return model_info.supports_multimodal_raw_input_only

    def is_pp_supported_model(
        self,
        architectures: str | list[str],
        model_config: ModelConfig,
    ) -> bool:
        """Whether the resolved model supports pipeline parallelism."""
        model_info, _ = self.inspect_model_cls(architectures, model_config)
        return model_info.supports_pp

    def model_has_inner_state(
        self,
        architectures: str | list[str],
        model_config: ModelConfig,
    ) -> bool:
        """Whether the resolved model has inner state."""
        model_info, _ = self.inspect_model_cls(architectures, model_config)
        return model_info.has_inner_state

    def is_attention_free_model(
        self,
        architectures: str | list[str],
        model_config: ModelConfig,
    ) -> bool:
        """Whether the resolved model is attention-free."""
        model_info, _ = self.inspect_model_cls(architectures, model_config)
        return model_info.is_attention_free

    def is_hybrid_model(
        self,
        architectures: str | list[str],
        model_config: ModelConfig,
    ) -> bool:
        """Whether the resolved model is a hybrid model."""
        model_info, _ = self.inspect_model_cls(architectures, model_config)
        return model_info.is_hybrid

    def is_noops_model(
        self,
        architectures: str | list[str],
        model_config: ModelConfig,
    ) -> bool:
        """Whether the resolved model has no-op layers."""
        model_info, _ = self.inspect_model_cls(architectures, model_config)
        return model_info.has_noops

    def is_transcription_model(
        self,
        architectures: str | list[str],
        model_config: ModelConfig,
    ) -> bool:
        """Whether the resolved model supports transcription."""
        model_info, _ = self.inspect_model_cls(architectures, model_config)
        return model_info.supports_transcription

    def is_transcription_only_model(
        self,
        architectures: str | list[str],
        model_config: ModelConfig,
    ) -> bool:
        """Whether the resolved model supports transcription only."""
        model_info, _ = self.inspect_model_cls(architectures, model_config)
        return model_info.supports_transcription_only


# Global registry, pre-populated with lazy (import-on-demand) entries for
# every in-tree model listed in `_VLLM_MODELS`.
ModelRegistry = _ModelRegistry(
    models={
        arch_name: _LazyRegisteredModel(
            module_name=f"vllm.model_executor.models.{relative_module}",
            class_name=class_name,
        )
        for arch_name, (relative_module, class_name) in _VLLM_MODELS.items()
    }
)

# Result type of the callable executed by `_run_in_subprocess`.
_T = TypeVar("_T")


def _run_in_subprocess(fn: Callable[[], _T]) -> _T:
    """Execute `fn` in a fresh Python process and return its result.

    `fn` and the path of a scratch output file are serialized with
    `cloudpickle` and sent on stdin to `_SUBPROCESS_COMMAND`. The child
    (see `_run`) writes the pickled result to that file, which is loaded here.

    Raises:
        RuntimeError: If the subprocess exits with a non-zero status. The
            message includes the child's stderr.
    """
    # NOTE: We use a temporary directory instead of a temporary file to avoid
    # issues like https://stackoverflow.com/questions/23212435/permission-denied-to-write-to-my-temporary-file
    with tempfile.TemporaryDirectory() as tempdir:
        output_filepath = os.path.join(tempdir, "registry_output.tmp")

        # `cloudpickle` allows pickling lambda functions directly
        import cloudpickle

        input_bytes = cloudpickle.dumps((fn, output_filepath))

        # cannot use `sys.executable __file__` here because the script
        # contains relative imports
        returned = subprocess.run(
            _SUBPROCESS_COMMAND, input=input_bytes, capture_output=True
        )

        # check if the subprocess is successful
        try:
            returned.check_returncode()
        except subprocess.CalledProcessError as e:
            # Decode leniently: stderr may contain non-UTF-8 bytes (e.g. from
            # native libraries). A strict decode would replace the child's real
            # error with an unrelated UnicodeDecodeError.
            stderr = returned.stderr.decode(errors="replace")
            # wrap raised exception to provide more information
            raise RuntimeError(f"Error raised in subprocess:\n{stderr}") from e

        # The file sits in our private temp dir and was written by our own
        # child process, so unpickling it is not unpickling untrusted input.
        with open(output_filepath, "rb") as f:
            return pickle.load(f)


def _run() -> None:
    """Subprocess entry point: evaluate a pickled callable and store its result.

    Loads vLLM's general plugins, then reads a pickled `(fn, output_file)`
    pair from stdin. It calls `fn` and writes the pickled return value to
    `output_file`.
    """
    # Plugins are loaded before `fn` runs, presumably so that model classes
    # registered by plugins are importable inside `fn` (TODO confirm).
    from vllm.plugins import load_general_plugins

    load_general_plugins()

    payload = sys.stdin.buffer.read()
    fn, output_file = pickle.loads(payload)

    # Evaluate before opening the file, so a failing `fn` leaves no output.
    outcome = fn()
    serialized = pickle.dumps(outcome)

    with open(output_file, "wb") as sink:
        sink.write(serialized)


# Script entry point: runs `_run`, which expects the stdin payload described
# above (presumably produced by `_run_in_subprocess`).
if __name__ == "__main__":
    _run()