registry.py 44.9 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
4
5
6
"""
Whenever you add an architecture to this file, please also update
`tests/models/registry.py` with example HuggingFace models for it.
"""
7

8
import importlib
9
import json
10
import os
11
import pickle
12
13
import subprocess
import sys
14
import tempfile
15
from abc import ABC, abstractmethod
16
from collections.abc import Callable, Set
17
from dataclasses import asdict, dataclass, field
18
from functools import lru_cache
19
from pathlib import Path
20
from typing import TypeVar
21
22

import torch.nn as nn
23
import transformers
24

25
from vllm import envs
26
27
28
29
30
from vllm.config import (
    ModelConfig,
    iter_architecture_defaults,
    try_match_architecture_defaults,
)
31
from vllm.logger import init_logger
32
from vllm.logging_utils import logtime
33
from vllm.transformers_utils.dynamic_module import try_get_class_from_dynamic_module
34
from vllm.utils.hashing import safe_hash
35
36
37
38
39
40
41

from .interfaces import (
    has_inner_state,
    has_noops,
    is_attention_free,
    is_hybrid,
    supports_cross_encoding,
42
    supports_mamba_prefix_caching,
43
44
45
46
47
48
49
50
51
52
53
    supports_multimodal,
    supports_multimodal_encoder_tp_data,
    supports_multimodal_raw_input_only,
    supports_pp,
    supports_transcription,
)
from .interfaces_base import (
    get_default_pooling_type,
    is_pooling_model,
    is_text_generation_model,
)
54
55
56

logger = init_logger(__name__)

57
58
_TEXT_GENERATION_MODELS = {
    # [Decoder-only]
59
    "AfmoeForCausalLM": ("afmoe", "AfmoeForCausalLM"),
60
    "ApertusForCausalLM": ("apertus", "ApertusForCausalLM"),
61
62
    "AquilaModel": ("llama", "LlamaForCausalLM"),
    "AquilaForCausalLM": ("llama", "LlamaForCausalLM"),  # AquilaChat2
Raghav Ravishankar's avatar
Raghav Ravishankar committed
63
    "ArceeForCausalLM": ("arcee", "ArceeForCausalLM"),
64
    "ArcticForCausalLM": ("arctic", "ArcticForCausalLM"),
65
66
67
68
    # baichuan-7b, upper case 'C' in the class name
    "BaiChuanForCausalLM": ("baichuan", "BaiChuanForCausalLM"),
    # baichuan-13b, lower case 'c' in the class name
    "BaichuanForCausalLM": ("baichuan", "BaichuanForCausalLM"),
69
    "BailingMoeForCausalLM": ("bailing_moe", "BailingMoeForCausalLM"),
ant-yy's avatar
ant-yy committed
70
    "BailingMoeV2ForCausalLM": ("bailing_moe", "BailingMoeV2ForCausalLM"),
Yu Chin Fabian Lim's avatar
Yu Chin Fabian Lim committed
71
    "BambaForCausalLM": ("bamba", "BambaForCausalLM"),
72
    "BloomForCausalLM": ("bloom", "BloomForCausalLM"),
73
    "ChatGLMModel": ("chatglm", "ChatGLMForCausalLM"),
74
    "ChatGLMForConditionalGeneration": ("chatglm", "ChatGLMForCausalLM"),
75
    "CohereForCausalLM": ("commandr", "CohereForCausalLM"),
76
    "Cohere2ForCausalLM": ("commandr", "CohereForCausalLM"),
77
    "CwmForCausalLM": ("llama", "LlamaForCausalLM"),
78
    "DbrxForCausalLM": ("dbrx", "DbrxForCausalLM"),
79
    "DeciLMForCausalLM": ("nemotron_nas", "DeciLMForCausalLM"),
80
    "DeepseekForCausalLM": ("deepseek_v2", "DeepseekForCausalLM"),
81
    "DeepseekV2ForCausalLM": ("deepseek_v2", "DeepseekV2ForCausalLM"),
82
    "DeepseekV3ForCausalLM": ("deepseek_v2", "DeepseekV3ForCausalLM"),
83
    "DeepseekV32ForCausalLM": ("deepseek_v2", "DeepseekV3ForCausalLM"),
84
    "Dots1ForCausalLM": ("dots1", "Dots1ForCausalLM"),
85
    "Ernie4_5ForCausalLM": ("ernie45", "Ernie4_5ForCausalLM"),
86
    "Ernie4_5_MoeForCausalLM": ("ernie45_moe", "Ernie4_5_MoeForCausalLM"),
87
    "ExaoneForCausalLM": ("exaone", "ExaoneForCausalLM"),
88
    "Exaone4ForCausalLM": ("exaone4", "Exaone4ForCausalLM"),
89
    "Fairseq2LlamaForCausalLM": ("fairseq2_llama", "Fairseq2LlamaForCausalLM"),
90
91
92
    "FalconForCausalLM": ("falcon", "FalconForCausalLM"),
    "FalconMambaForCausalLM": ("mamba", "MambaForCausalLM"),
    "FalconH1ForCausalLM": ("falcon_h1", "FalconH1ForCausalLM"),
93
    "FlexOlmoForCausalLM": ("flex_olmo", "FlexOlmoForCausalLM"),
94
95
    "GemmaForCausalLM": ("gemma", "GemmaForCausalLM"),
    "Gemma2ForCausalLM": ("gemma2", "Gemma2ForCausalLM"),
96
    "Gemma3ForCausalLM": ("gemma3", "Gemma3ForCausalLM"),
Nicolò Lucchesi's avatar
Nicolò Lucchesi committed
97
    "Gemma3nForCausalLM": ("gemma3n", "Gemma3nForCausalLM"),
98
    "Qwen3NextForCausalLM": ("qwen3_next", "Qwen3NextForCausalLM"),
99
    "GlmForCausalLM": ("glm", "GlmForCausalLM"),
Yuxuan Zhang's avatar
Yuxuan Zhang committed
100
    "Glm4ForCausalLM": ("glm4", "Glm4ForCausalLM"),
Yuxuan Zhang's avatar
Yuxuan Zhang committed
101
    "Glm4MoeForCausalLM": ("glm4_moe", "Glm4MoeForCausalLM"),
102
    "GptOssForCausalLM": ("gpt_oss", "GptOssForCausalLM"),
103
104
105
106
107
108
    "GPT2LMHeadModel": ("gpt2", "GPT2LMHeadModel"),
    "GPTBigCodeForCausalLM": ("gpt_bigcode", "GPTBigCodeForCausalLM"),
    "GPTJForCausalLM": ("gpt_j", "GPTJForCausalLM"),
    "GPTNeoXForCausalLM": ("gpt_neox", "GPTNeoXForCausalLM"),
    "GraniteForCausalLM": ("granite", "GraniteForCausalLM"),
    "GraniteMoeForCausalLM": ("granitemoe", "GraniteMoeForCausalLM"),
109
110
    "GraniteMoeHybridForCausalLM": ("granitemoehybrid", "GraniteMoeHybridForCausalLM"),  # noqa: E501
    "GraniteMoeSharedForCausalLM": ("granitemoeshared", "GraniteMoeSharedForCausalLM"),  # noqa: E501
111
    "GritLM": ("gritlm", "GritLM"),
Michael Goin's avatar
Michael Goin committed
112
    "Grok1ModelForCausalLM": ("grok1", "Grok1ForCausalLM"),
113
114
    "HunYuanMoEV1ForCausalLM": ("hunyuan_v1", "HunYuanMoEV1ForCausalLM"),
    "HunYuanDenseV1ForCausalLM": ("hunyuan_v1", "HunYuanDenseV1ForCausalLM"),
115
    "HCXVisionForCausalLM": ("hyperclovax_vision", "HCXVisionForCausalLM"),
116
117
    "InternLMForCausalLM": ("llama", "LlamaForCausalLM"),
    "InternLM2ForCausalLM": ("internlm2", "InternLM2ForCausalLM"),
118
    "InternLM2VEForCausalLM": ("internlm2_ve", "InternLM2VEForCausalLM"),
119
    "InternLM3ForCausalLM": ("llama", "LlamaForCausalLM"),
120
121
    "JAISLMHeadModel": ("jais", "JAISLMHeadModel"),
    "JambaForCausalLM": ("jamba", "JambaForCausalLM"),
122
    "KimiLinearForCausalLM": ("kimi_linear", "KimiLinearForCausalLM"),  # noqa: E501
123
    "Lfm2ForCausalLM": ("lfm2", "Lfm2ForCausalLM"),
Paul Pak's avatar
Paul Pak committed
124
    "Lfm2MoeForCausalLM": ("lfm2_moe", "Lfm2MoeForCausalLM"),
125
    "LlamaForCausalLM": ("llama", "LlamaForCausalLM"),
126
    "Llama4ForCausalLM": ("llama4", "Llama4ForCausalLM"),
127
128
    # For decapoda-research/llama-*
    "LLaMAForCausalLM": ("llama", "LlamaForCausalLM"),
XuruiYang's avatar
XuruiYang committed
129
    "LongcatFlashForCausalLM": ("longcat_flash", "LongcatFlashForCausalLM"),
130
    "MambaForCausalLM": ("mamba", "MambaForCausalLM"),
131
    "Mamba2ForCausalLM": ("mamba2", "Mamba2ForCausalLM"),
132
133
    "MiniCPMForCausalLM": ("minicpm", "MiniCPMForCausalLM"),
    "MiniCPM3ForCausalLM": ("minicpm3", "MiniCPM3ForCausalLM"),
134
135
136
    "MiniMaxForCausalLM": ("minimax_text_01", "MiniMaxText01ForCausalLM"),
    "MiniMaxText01ForCausalLM": ("minimax_text_01", "MiniMaxText01ForCausalLM"),
    "MiniMaxM1ForCausalLM": ("minimax_text_01", "MiniMaxText01ForCausalLM"),
137
    "MiniMaxM2ForCausalLM": ("minimax_m2", "MiniMaxM2ForCausalLM"),
138
139
140
141
142
    "MistralForCausalLM": ("llama", "LlamaForCausalLM"),
    "MixtralForCausalLM": ("mixtral", "MixtralForCausalLM"),
    # transformers's mpt class has lower case
    "MptForCausalLM": ("mpt", "MPTForCausalLM"),
    "MPTForCausalLM": ("mpt", "MPTForCausalLM"),
143
    "MiMoForCausalLM": ("mimo", "MiMoForCausalLM"),
144
    "NemotronForCausalLM": ("nemotron", "NemotronForCausalLM"),
Luis Vega's avatar
Luis Vega committed
145
    "NemotronHForCausalLM": ("nemotron_h", "NemotronHForCausalLM"),
146
    "OlmoForCausalLM": ("olmo", "OlmoForCausalLM"),
147
    "Olmo2ForCausalLM": ("olmo2", "Olmo2ForCausalLM"),
148
    "Olmo3ForCausalLM": ("olmo2", "Olmo2ForCausalLM"),
149
150
151
    "OlmoeForCausalLM": ("olmoe", "OlmoeForCausalLM"),
    "OPTForCausalLM": ("opt", "OPTForCausalLM"),
    "OrionForCausalLM": ("orion", "OrionForCausalLM"),
152
    "OuroForCausalLM": ("ouro", "OuroForCausalLM"),
153
154
    "PanguEmbeddedForCausalLM": ("openpangu", "PanguEmbeddedForCausalLM"),
    "PanguUltraMoEForCausalLM": ("openpangu", "PanguUltraMoEForCausalLM"),
155
156
157
158
    "PersimmonForCausalLM": ("persimmon", "PersimmonForCausalLM"),
    "PhiForCausalLM": ("phi", "PhiForCausalLM"),
    "Phi3ForCausalLM": ("phi3", "Phi3ForCausalLM"),
    "PhiMoEForCausalLM": ("phimoe", "PhiMoEForCausalLM"),
Shinichi Hemmi's avatar
Shinichi Hemmi committed
159
    "Plamo2ForCausalLM": ("plamo2", "Plamo2ForCausalLM"),
160
    "Plamo3ForCausalLM": ("plamo3", "Plamo3ForCausalLM"),
161
    "QWenLMHeadModel": ("qwen", "QWenLMHeadModel"),
162
163
    "Qwen2ForCausalLM": ("qwen2", "Qwen2ForCausalLM"),
    "Qwen2MoeForCausalLM": ("qwen2_moe", "Qwen2MoeForCausalLM"),
164
165
    "Qwen3ForCausalLM": ("qwen3", "Qwen3ForCausalLM"),
    "Qwen3MoeForCausalLM": ("qwen3_moe", "Qwen3MoeForCausalLM"),
166
    "RWForCausalLM": ("falcon", "FalconForCausalLM"),
167
    "SeedOssForCausalLM": ("seed_oss", "SeedOssForCausalLM"),
Song's avatar
Song committed
168
    "Step3TextForCausalLM": ("step3_text", "Step3TextForCausalLM"),
169
170
171
172
    "StableLMEpochForCausalLM": ("stablelm", "StablelmForCausalLM"),
    "StableLmForCausalLM": ("stablelm", "StablelmForCausalLM"),
    "Starcoder2ForCausalLM": ("starcoder2", "Starcoder2ForCausalLM"),
    "SolarForCausalLM": ("solar", "SolarForCausalLM"),
173
    "TeleChatForCausalLM": ("telechat2", "TeleChat2ForCausalLM"),
174
    "TeleChat2ForCausalLM": ("telechat2", "TeleChat2ForCausalLM"),
175
    "TeleFLMForCausalLM": ("teleflm", "TeleFLMForCausalLM"),
176
    "XverseForCausalLM": ("llama", "LlamaForCausalLM"),
177
    "Zamba2ForCausalLM": ("zamba2", "Zamba2ForCausalLM"),
178
179
180
}

# Embedding / pooling architectures, in the same `name -> (mod, arch)` layout
# as `_TEXT_GENERATION_MODELS`.
_EMBEDDING_MODELS = {
    # [Text-only]
    "BertModel": ("bert", "BertEmbeddingModel"),
    "BertSpladeSparseEmbeddingModel": ("bert", "BertSpladeSparseEmbeddingModel"),
    "DeciLMForCausalLM": ("nemotron_nas", "DeciLMForCausalLM"),
    "Gemma2Model": ("gemma2", "Gemma2ForCausalLM"),
    "Gemma3TextModel": ("gemma3", "Gemma3Model"),
    "GlmForCausalLM": ("glm", "GlmForCausalLM"),
    "GPT2ForSequenceClassification": ("gpt2", "GPT2ForSequenceClassification"),
    "GritLM": ("gritlm", "GritLM"),
    "GteModel": ("bert_with_rope", "SnowflakeGteNewModel"),
    "GteNewModel": ("bert_with_rope", "GteNewModel"),
    "InternLM2ForRewardModel": ("internlm2", "InternLM2ForRewardModel"),
    "JambaForSequenceClassification": ("jamba", "JambaForSequenceClassification"),  # noqa: E501
    "LlamaModel": ("llama", "LlamaForCausalLM"),
    **{
        # Multiple models share the same architecture, so we include them all
        k: (mod, arch)
        for k, (mod, arch) in _TEXT_GENERATION_MODELS.items()
        if arch == "LlamaForCausalLM"
    },
    "MistralModel": ("llama", "LlamaForCausalLM"),
    "ModernBertModel": ("modernbert", "ModernBertModel"),
    "NomicBertModel": ("bert_with_rope", "NomicBertModel"),
    "Phi3ForCausalLM": ("phi3", "Phi3ForCausalLM"),
    "Qwen2Model": ("qwen2", "Qwen2ForCausalLM"),
    "Qwen2ForCausalLM": ("qwen2", "Qwen2ForCausalLM"),
    "Qwen2ForRewardModel": ("qwen2_rm", "Qwen2ForRewardModel"),
    "Qwen2ForProcessRewardModel": ("qwen2_rm", "Qwen2ForProcessRewardModel"),
    "RobertaForMaskedLM": ("roberta", "RobertaEmbeddingModel"),
    "RobertaModel": ("roberta", "RobertaEmbeddingModel"),
    "TeleChatForCausalLM": ("telechat2", "TeleChat2ForCausalLM"),
    "TeleChat2ForCausalLM": ("telechat2", "TeleChat2ForCausalLM"),
    "XLMRobertaModel": ("roberta", "RobertaEmbeddingModel"),
    # [Multimodal]
    "CLIPModel": ("clip", "CLIPEmbeddingModel"),
    "LlavaNextForConditionalGeneration": (
        "llava_next",
        "LlavaNextForConditionalGeneration",
    ),
    "Phi3VForCausalLM": ("phi3v", "Phi3VForCausalLM"),
    "Qwen2VLForConditionalGeneration": ("qwen2_vl", "Qwen2VLForConditionalGeneration"),  # noqa: E501
    "SiglipModel": ("siglip", "SiglipEmbeddingModel"),
    # Technically Terratorch models work on images, both in
    # input and output. I am adding it here because it piggy-backs on embedding
    # models for the time being.
    "PrithviGeoSpatialMAE": ("terratorch", "Terratorch"),
    "Terratorch": ("terratorch", "Terratorch"),
}

230
231
_CROSS_ENCODER_MODELS = {
    "BertForSequenceClassification": ("bert", "BertForSequenceClassification"),
232
    "BertForTokenClassification": ("bert", "BertForTokenClassification"),
233
234
235
236
237
238
239
240
    "GteNewForSequenceClassification": (
        "bert_with_rope",
        "GteNewForSequenceClassification",
    ),
    "ModernBertForSequenceClassification": (
        "modernbert",
        "ModernBertForSequenceClassification",
    ),
241
242
243
244
    "ModernBertForTokenClassification": (
        "modernbert",
        "ModernBertForTokenClassification",
    ),
245
246
247
248
249
    "RobertaForSequenceClassification": ("roberta", "RobertaForSequenceClassification"),
    "XLMRobertaForSequenceClassification": (
        "roberta",
        "RobertaForSequenceClassification",
    ),
250
    # [Auto-converted (see adapters.py)]
251
    "JinaVLForRanking": ("jina_vl", "JinaVLForSequenceClassification"),  # noqa: E501,
252
253
}

254
_MULTIMODAL_MODELS = {
255
    # [Decoder-only]
256
    "AriaForConditionalGeneration": ("aria", "AriaForConditionalGeneration"),
257
258
259
    "AyaVisionForConditionalGeneration": (
        "aya_vision",
        "AyaVisionForConditionalGeneration",
260
    ),
261
    "BeeForConditionalGeneration": ("bee", "BeeForConditionalGeneration"),
262
    "Blip2ForConditionalGeneration": ("blip2", "Blip2ForConditionalGeneration"),
263
264
265
    "ChameleonForConditionalGeneration": (
        "chameleon",
        "ChameleonForConditionalGeneration",
266
    ),
267
268
269
    "Cohere2VisionForConditionalGeneration": (
        "cohere2_vision",
        "Cohere2VisionForConditionalGeneration",
270
    ),
271
    "DeepseekVLV2ForCausalLM": ("deepseek_vl2", "DeepseekVLV2ForCausalLM"),
272
    "DeepseekOCRForCausalLM": ("deepseek_ocr", "DeepseekOCRForCausalLM"),
Roger Wang's avatar
Roger Wang committed
273
    "DotsOCRForCausalLM": ("dots_ocr", "DotsOCRForCausalLM"),
274
275
276
    "Ernie4_5_VLMoeForConditionalGeneration": (
        "ernie45_vl",
        "Ernie4_5_VLMoeForConditionalGeneration",
277
    ),
278
    "FuyuForCausalLM": ("fuyu", "FuyuForCausalLM"),
279
    "Gemma3ForConditionalGeneration": ("gemma3_mm", "Gemma3ForConditionalGeneration"),  # noqa: E501
280
281
282
    "Gemma3nForConditionalGeneration": (
        "gemma3n_mm",
        "Gemma3nForConditionalGeneration",
283
    ),
284
    "GLM4VForCausalLM": ("glm4v", "GLM4VForCausalLM"),
285
    "Glm4vForConditionalGeneration": ("glm4_1v", "Glm4vForConditionalGeneration"),  # noqa: E501
Jee Jee Li's avatar
Jee Jee Li committed
286
    "Glm4vMoeForConditionalGeneration": ("glm4_1v", "Glm4vMoeForConditionalGeneration"),  # noqa: E501
287
288
289
    "GraniteSpeechForConditionalGeneration": (
        "granite_speech",
        "GraniteSpeechForConditionalGeneration",
290
    ),
291
    "H2OVLChatModel": ("h2ovl", "H2OVLChatModel"),
292
293
294
295
    "HunYuanVLForConditionalGeneration": (
        "hunyuan_vision",
        "HunYuanVLForConditionalGeneration",
    ),
296
    "InternVLChatModel": ("internvl", "InternVLChatModel"),
297
    "NemotronH_Nano_VL_V2": ("nano_nemotron_vl", "NemotronH_Nano_VL_V2"),
Zero's avatar
Zero committed
298
299
300
301
    "OpenCUAForConditionalGeneration": (
        "opencua",
        "OpenCUAForConditionalGeneration",
    ),
302
303
304
    "InternS1ForConditionalGeneration": (
        "interns1",
        "InternS1ForConditionalGeneration",
305
    ),
306
307
308
    "InternVLForConditionalGeneration": (
        "interns1",
        "InternS1ForConditionalGeneration",
309
    ),
310
311
312
313
314
    "Idefics3ForConditionalGeneration": (
        "idefics3",
        "Idefics3ForConditionalGeneration",
    ),
    "SmolVLMForConditionalGeneration": ("smolvlm", "SmolVLMForConditionalGeneration"),  # noqa: E501
315
    "KeyeForConditionalGeneration": ("keye", "KeyeForConditionalGeneration"),
316
317
318
    "KeyeVL1_5ForConditionalGeneration": (
        "keye_vl1_5",
        "KeyeVL1_5ForConditionalGeneration",
319
    ),
320
    "RForConditionalGeneration": ("rvl", "RForConditionalGeneration"),
321
    "KimiVLForConditionalGeneration": ("kimi_vl", "KimiVLForConditionalGeneration"),  # noqa: E501
322
323
324
325
    "LightOnOCRForConditionalGeneration": (
        "lightonocr",
        "LightOnOCRForConditionalGeneration",
    ),
326
    "Llama_Nemotron_Nano_VL": ("nemotron_vl", "LlamaNemotronVLChatModel"),
327
    "Llama4ForConditionalGeneration": ("mllama4", "Llama4ForConditionalGeneration"),  # noqa: E501
328
    "LlavaForConditionalGeneration": ("llava", "LlavaForConditionalGeneration"),
329
330
331
    "LlavaNextForConditionalGeneration": (
        "llava_next",
        "LlavaNextForConditionalGeneration",
332
    ),
333
334
335
    "LlavaNextVideoForConditionalGeneration": (
        "llava_next_video",
        "LlavaNextVideoForConditionalGeneration",
336
    ),
337
338
339
    "LlavaOnevisionForConditionalGeneration": (
        "llava_onevision",
        "LlavaOnevisionForConditionalGeneration",
340
    ),
341
    "MantisForConditionalGeneration": ("llava", "MantisForConditionalGeneration"),  # noqa: E501
342
    "MiDashengLMModel": ("midashenglm", "MiDashengLMModel"),
343
344
345
    "MiniMaxVL01ForConditionalGeneration": (
        "minimax_vl_01",
        "MiniMaxVL01ForConditionalGeneration",
346
    ),
347
    "MiniCPMO": ("minicpmo", "MiniCPMO"),
348
    "MiniCPMV": ("minicpmv", "MiniCPMV"),
349
350
351
    "Mistral3ForConditionalGeneration": (
        "mistral3",
        "Mistral3ForConditionalGeneration",
352
    ),
353
    "MolmoForCausalLM": ("molmo", "MolmoForCausalLM"),
354
    "NVLM_D": ("nvlm_d", "NVLM_D_Model"),
355
    "Ovis": ("ovis", "Ovis"),
356
    "Ovis2_5": ("ovis2_5", "Ovis2_5"),
357
358
359
360
    "PaddleOCRVLForConditionalGeneration": (
        "paddleocr_vl",
        "PaddleOCRVLForConditionalGeneration",
    ),
361
362
363
364
    "PaliGemmaForConditionalGeneration": (
        "paligemma",
        "PaliGemmaForConditionalGeneration",
    ),
365
    "Phi3VForCausalLM": ("phi3v", "Phi3VForCausalLM"),
366
367
    "Phi4MMForCausalLM": ("phi4mm", "Phi4MMForCausalLM"),
    "Phi4MultimodalForCausalLM": ("phi4_multimodal", "Phi4MultimodalForCausalLM"),  # noqa: E501
368
    "PixtralForConditionalGeneration": ("pixtral", "PixtralForConditionalGeneration"),  # noqa: E501
369
    "QwenVLForConditionalGeneration": ("qwen_vl", "QwenVLForConditionalGeneration"),  # noqa: E501
370
    "Qwen2VLForConditionalGeneration": ("qwen2_vl", "Qwen2VLForConditionalGeneration"),  # noqa: E501
371
372
373
    "Qwen2_5_VLForConditionalGeneration": (
        "qwen2_5_vl",
        "Qwen2_5_VLForConditionalGeneration",
374
    ),
375
376
377
    "Qwen2AudioForConditionalGeneration": (
        "qwen2_audio",
        "Qwen2AudioForConditionalGeneration",
378
    ),
379
380
381
    "Qwen2_5OmniModel": (
        "qwen2_5_omni_thinker",
        "Qwen2_5OmniThinkerForConditionalGeneration",
382
    ),
383
384
385
    "Qwen2_5OmniForConditionalGeneration": (
        "qwen2_5_omni_thinker",
        "Qwen2_5OmniThinkerForConditionalGeneration",
386
    ),
387
388
389
390
    "Qwen3OmniMoeForConditionalGeneration": (
        "qwen3_omni_moe_thinker",
        "Qwen3OmniMoeThinkerForConditionalGeneration",
    ),
391
    "Qwen3VLForConditionalGeneration": ("qwen3_vl", "Qwen3VLForConditionalGeneration"),  # noqa: E501
392
393
394
    "Qwen3VLMoeForConditionalGeneration": (
        "qwen3_vl_moe",
        "Qwen3VLMoeForConditionalGeneration",
395
    ),
396
    "SkyworkR1VChatModel": ("skyworkr1v", "SkyworkR1VChatModel"),
Song's avatar
Song committed
397
    "Step3VLForConditionalGeneration": ("step3_vl", "Step3VLForConditionalGeneration"),  # noqa: E501
汪志鹏's avatar
汪志鹏 committed
398
    "TarsierForConditionalGeneration": ("tarsier", "TarsierForConditionalGeneration"),  # noqa: E501
399
400
401
    "Tarsier2ForConditionalGeneration": (
        "qwen2_vl",
        "Tarsier2ForConditionalGeneration",
402
    ),
403
    "UltravoxModel": ("ultravox", "UltravoxModel"),
Patrick von Platen's avatar
Patrick von Platen committed
404
    "VoxtralForConditionalGeneration": ("voxtral", "VoxtralForConditionalGeneration"),  # noqa: E501
405
    # [Encoder-decoder]
406
    "WhisperForConditionalGeneration": ("whisper", "WhisperForConditionalGeneration"),  # noqa: E501
407
}
408
409

_SPECULATIVE_DECODING_MODELS = {
410
    "MiMoMTPModel": ("mimo_mtp", "MiMoMTP"),
411
    "EagleLlamaForCausalLM": ("llama_eagle", "EagleLlamaForCausalLM"),
zhiweiz's avatar
zhiweiz committed
412
    "EagleLlama4ForCausalLM": ("llama4_eagle", "EagleLlama4ForCausalLM"),
413
    "EagleMiniCPMForCausalLM": ("minicpm_eagle", "EagleMiniCPMForCausalLM"),
414
    "Eagle3LlamaForCausalLM": ("llama_eagle3", "Eagle3LlamaForCausalLM"),
415
    "LlamaForCausalLMEagle3": ("llama_eagle3", "Eagle3LlamaForCausalLM"),
416
    "Eagle3Qwen2_5vlForCausalLM": ("llama_eagle3", "Eagle3LlamaForCausalLM"),
417
    "Eagle3Qwen3vlForCausalLM": ("llama_eagle3", "Eagle3LlamaForCausalLM"),
418
    "EagleDeepSeekMTPModel": ("deepseek_eagle", "EagleDeepseekV3ForCausalLM"),
419
    "DeepSeekMTPModel": ("deepseek_mtp", "DeepSeekMTP"),
420
    "ErnieMTPModel": ("ernie_mtp", "ErnieMTP"),
XuruiYang's avatar
XuruiYang committed
421
    "LongCatFlashMTPModel": ("longcat_flash_mtp", "LongCatFlashMTP"),
Yuxuan Zhang's avatar
Yuxuan Zhang committed
422
    "Glm4MoeMTPModel": ("glm4_moe_mtp", "Glm4MoeMTP"),
423
    "MedusaModel": ("medusa", "Medusa"),
424
    "OpenPanguMTPModel": ("openpangu_mtp", "OpenPanguMTP"),
425
    "Qwen3NextMTP": ("qwen3_next_mtp", "Qwen3NextMTP"),
426
427
428
    # Temporarily disabled.
    # # TODO(woosuk): Re-enable this once the MLP Speculator is supported in V1.
    # "MLPSpeculatorPreTrainedModel": ("mlp_speculator", "MLPSpeculator"),
429
}
430

431
_TRANSFORMERS_SUPPORTED_MODELS = {
432
433
434
    # Text generation models
    "SmolLM3ForCausalLM": ("transformers", "TransformersForCausalLM"),
    # Multimodal models
435
436
437
438
    "Emu3ForConditionalGeneration": (
        "transformers",
        "TransformersMultiModalForCausalLM",
    ),
439
440
441
}

_TRANSFORMERS_BACKEND_MODELS = {
442
    # Text generation models
443
    "TransformersForCausalLM": ("transformers", "TransformersForCausalLM"),
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
    "TransformersMoEForCausalLM": ("transformers", "TransformersMoEForCausalLM"),
    # Multimodal models
    "TransformersMultiModalForCausalLM": (
        "transformers",
        "TransformersMultiModalForCausalLM",
    ),
    "TransformersMultiModalMoEForCausalLM": (
        "transformers",
        "TransformersMultiModalMoEForCausalLM",
    ),
    # Embedding models
    "TransformersEmbeddingModel": ("transformers", "TransformersEmbeddingModel"),
    "TransformersMoEEmbeddingModel": ("transformers", "TransformersMoEEmbeddingModel"),
    "TransformersMultiModalEmbeddingModel": (
        "transformers",
        "TransformersMultiModalEmbeddingModel",
    ),
    # Sequence classification models
462
    "TransformersForSequenceClassification": (
463
        "transformers",
464
        "TransformersForSequenceClassification",
465
    ),
466
    "TransformersMoEForSequenceClassification": (
467
        "transformers",
468
        "TransformersMoEForSequenceClassification",
469
    ),
470
471
472
    "TransformersMultiModalForSequenceClassification": (
        "transformers",
        "TransformersMultiModalForSequenceClassification",
473
    ),
474
}
475

476
# Union of all per-category tables above. When the same architecture name
# appears in several tables, the table unpacked later supplies the value.
_VLLM_MODELS = {
    **_TEXT_GENERATION_MODELS,
    **_EMBEDDING_MODELS,
    **_CROSS_ENCODER_MODELS,
    **_MULTIMODAL_MODELS,
    **_SPECULATIVE_DECODING_MODELS,
    **_TRANSFORMERS_SUPPORTED_MODELS,
    **_TRANSFORMERS_BACKEND_MODELS,
}

486
487
488
489
# This variable is used as the args for subprocess.run(). We
# can modify  this variable to alter the args if needed. e.g.
# when we use par format to pack things together, sys.executable
# might not be the target we want to run.
490
_SUBPROCESS_COMMAND = [sys.executable, "-m", "vllm.model_executor.models.registry"]
491

492
_PREVIOUSLY_SUPPORTED_MODELS = {
493
    "MotifForCausalLM": "0.10.2",
494
    "Phi3SmallForCausalLM": "0.9.2",
495
    "Phi4FlashForCausalLM": "0.10.2",
496
497
498
499
500
501
502
503
504
    # encoder-decoder models except whisper
    # have been removed for V0 deprecation.
    "BartModel": "0.10.2",
    "BartForConditionalGeneration": "0.10.2",
    "DonutForConditionalGeneration": "0.10.2",
    "Florence2ForConditionalGeneration": "0.10.2",
    "MBartForConditionalGeneration": "0.10.2",
    "MllamaForConditionalGeneration": "0.10.2",
}
505

506

507
508
@dataclass(frozen=True)
class _ModelInfo:
509
    architecture: str
510
    is_text_generation_model: bool
511
    is_pooling_model: bool
512
    default_pooling_type: str
513
    supports_cross_encoding: bool
514
    supports_multimodal: bool
515
    supports_multimodal_raw_input_only: bool
516
    supports_multimodal_encoder_tp_data: bool
517
    supports_pp: bool
518
519
    has_inner_state: bool
    is_attention_free: bool
520
    is_hybrid: bool
521
    has_noops: bool
522
    supports_mamba_prefix_caching: bool
523
    supports_transcription: bool
524
    supports_transcription_only: bool
525
526

    @staticmethod
527
    def from_model_cls(model: type[nn.Module]) -> "_ModelInfo":
528
        return _ModelInfo(
529
            architecture=model.__name__,
530
            is_text_generation_model=is_text_generation_model(model),
531
            is_pooling_model=is_pooling_model(model),
532
            default_pooling_type=get_default_pooling_type(model),
533
            supports_cross_encoding=supports_cross_encoding(model),
534
            supports_multimodal=supports_multimodal(model),
535
536
537
538
539
540
            supports_multimodal_raw_input_only=supports_multimodal_raw_input_only(
                model
            ),
            supports_multimodal_encoder_tp_data=supports_multimodal_encoder_tp_data(
                model
            ),
541
            supports_pp=supports_pp(model),
542
543
            has_inner_state=has_inner_state(model),
            is_attention_free=is_attention_free(model),
544
            is_hybrid=is_hybrid(model),
545
            supports_mamba_prefix_caching=supports_mamba_prefix_caching(model),
546
            supports_transcription=supports_transcription(model),
547
548
549
            supports_transcription_only=(
                supports_transcription(model) and model.supports_transcription_only
            ),
550
            has_noops=has_noops(model),
551
        )
552
553


554
555
556
557
class _BaseRegisteredModel(ABC):
    """Abstract interface implemented by every entry stored in the registry."""

    @abstractmethod
    def inspect_model_cls(self) -> _ModelInfo:
        """Return the `_ModelInfo` describing this entry's model class."""
        raise NotImplementedError

    @abstractmethod
    def load_model_cls(self) -> type[nn.Module]:
        """Return the model class this entry refers to."""
        raise NotImplementedError
562
563


564
565
566
567
568
569
570
@dataclass(frozen=True)
class _RegisteredModel(_BaseRegisteredModel):
    """
    Represents a model that has already been imported in the main process.
    """

    # Capability flags computed from `model_cls` when the entry is created.
    interfaces: _ModelInfo
    # The already-imported model class.
    model_cls: type[nn.Module]

    @staticmethod
    def from_model_cls(model_cls: type[nn.Module]) -> "_RegisteredModel":
        """Wrap `model_cls`, computing its `_ModelInfo` eagerly."""
        return _RegisteredModel(
            interfaces=_ModelInfo.from_model_cls(model_cls),
            model_cls=model_cls,
        )

    def inspect_model_cls(self) -> _ModelInfo:
        """Return the `_ModelInfo` stored on this entry."""
        return self.interfaces

    def load_model_cls(self) -> type[nn.Module]:
        """Return the stored model class; nothing is imported here."""
        return self.model_cls


@dataclass(frozen=True)
class _LazyRegisteredModel(_BaseRegisteredModel):
    """
    Represents a model that has not been imported in the main process.

    The model is identified by `module_name` and `class_name`. Inspection
    results are cached as JSON files under `VLLM_CACHE_ROOT/modelinfos`, each
    stamped with a hash of the model's source file so that a changed file
    invalidates its cache entry.
    """

    module_name: str
    class_name: str

    @staticmethod
    def _get_cache_dir() -> Path:
        """Return the directory that holds the cached model-info files."""
        return Path(envs.VLLM_CACHE_ROOT) / "modelinfos"

    def _get_cache_filename(self) -> str:
        """Return `<module>-<class>.json` with every `.` replaced by `-`."""
        cls_name = f"{self.module_name}-{self.class_name}".replace(".", "-")
        return f"{cls_name}.json"

    def _load_modelinfo_from_cache(self, module_hash: str) -> _ModelInfo | None:
        """
        Load the cached `_ModelInfo` for this model, if present and current.

        Args:
            module_hash: Hash of the model's source file; the cache entry is
                used only when its stored hash equals this value.

        Returns:
            The cached `_ModelInfo`, or `None` when the file is missing, the
            hash differs, or the file cannot be read/parsed. This method
            never raises; every failure is treated as a cache miss.
        """
        try:
            try:
                modelinfo_path = self._get_cache_dir() / self._get_cache_filename()
                with open(modelinfo_path, encoding="utf-8") as file:
                    mi_dict = json.load(file)
            except FileNotFoundError:
                logger.debug(
                    "Cached model info file for class %s.%s not found",
                    self.module_name,
                    self.class_name,
                )
                return None

            if mi_dict["hash"] != module_hash:
                logger.debug(
                    "Cached model info file for class %s.%s is stale",
                    self.module_name,
                    self.class_name,
                )
                return None

            # file not changed, use cached _ModelInfo properties
            return _ModelInfo(**mi_dict["modelinfo"])
        except Exception:
            # Best-effort: a corrupt or incompatible cache file is a miss, but
            # keep the traceback so the cause can be diagnosed at debug level.
            logger.debug(
                "Cached model info for class %s.%s error. ",
                self.module_name,
                self.class_name,
                exc_info=True,
            )
            return None

    def _save_modelinfo_to_cache(self, mi: _ModelInfo, module_hash: str) -> None:
        """
        Write `mi` and `module_hash` to this model's JSON cache file.

        Failures while building or writing the file are logged and swallowed.
        """
        from vllm.model_executor.model_loader.weight_utils import atomic_writer

        try:
            modelinfo_dict = {
                "hash": module_hash,
                "modelinfo": asdict(mi),
            }
            cache_dir = self._get_cache_dir()
            cache_dir.mkdir(parents=True, exist_ok=True)
            modelinfo_path = cache_dir / self._get_cache_filename()
            with atomic_writer(modelinfo_path, encoding="utf-8") as f:
                json.dump(modelinfo_dict, f, indent=2)
        except Exception:
            logger.exception("Error saving model info cache.")

    @logtime(logger=logger, msg="Registry inspect model class")
    def inspect_model_cls(self) -> _ModelInfo:
        """
        Return the `_ModelInfo` for this model without importing it here.

        The cache is consulted only when a `.py` file named after the last
        component of `module_name` exists next to this file. On a miss the
        model is inspected in a subprocess and, if a source hash is
        available, the result is written back to the cache.
        """
        model_path = Path(__file__).parent / f"{self.module_name.split('.')[-1]}.py"
        module_hash = None

        if model_path.exists():
            with open(model_path, "rb") as f:
                module_hash = safe_hash(f.read(), usedforsecurity=False).hexdigest()

            mi = self._load_modelinfo_from_cache(module_hash)
            if mi is not None:
                logger.debug(
                    "Loaded model info for class %s.%s from cache",
                    self.module_name,
                    self.class_name,
                )
                return mi
            else:
                logger.debug(
                    "Cache model info for class %s.%s miss. Loading model instead.",
                    self.module_name,
                    self.class_name,
                )

        # Performed in another process to avoid initializing CUDA
        mi = _run_in_subprocess(
            lambda: _ModelInfo.from_model_cls(self.load_model_cls())
        )
        logger.debug(
            "Loaded model info for class %s.%s", self.module_name, self.class_name
        )

        # save cache file
        if module_hash is not None:
            self._save_modelinfo_to_cache(mi, module_hash)

        return mi

    def load_model_cls(self) -> type[nn.Module]:
        """Import `module_name` and return its `class_name` attribute."""
        mod = importlib.import_module(self.module_name)
        return getattr(mod, self.class_name)


@lru_cache(maxsize=128)
def _try_load_model_cls(
    model_arch: str,
    model: _BaseRegisteredModel,
) -> type[nn.Module] | None:
    """
    Load the model class for `model_arch`, or return `None` on failure.

    The platform check runs outside the `try`, so an error raised by
    `verify_model_arch` propagates to the caller. Results (including `None`)
    are memoized per `(model_arch, model)` pair by `lru_cache`.
    """
    # Imported lazily inside the function rather than at module import time
    from vllm.platforms import current_platform

    current_platform.verify_model_arch(model_arch)

    try:
        loaded_cls = model.load_model_cls()
    except Exception:
        logger.exception("Error in loading model architecture '%s'", model_arch)
        return None
    else:
        return loaded_cls
709
710


711
712
713
714
@lru_cache(maxsize=128)
def _try_inspect_model_cls(
    model_arch: str,
    model: _BaseRegisteredModel,
715
) -> _ModelInfo | None:
716
717
718
    try:
        return model.inspect_model_cls()
    except Exception:
719
        logger.exception("Error in inspecting model architecture '%s'", model_arch)
720
        return None
721
722


723
724
725
@dataclass
class _ModelRegistry:
    # Keyed by model_arch
726
    models: dict[str, _BaseRegisteredModel] = field(default_factory=dict)
727

728
    def get_supported_archs(self) -> Set[str]:
729
        return self.models.keys()
730

731
732
733
    def register_model(
        self,
        model_arch: str,
734
        model_cls: type[nn.Module] | str,
735
    ) -> None:
736
737
738
        """
        Register an external model to be used in vLLM.

739
        `model_cls` can be either:
740

741
        - A [`torch.nn.Module`][] class directly referencing the model.
742
        - A string in the format `<module>:<class>` which can be used to
743
744
          lazily import the model. This is useful to avoid initializing CUDA
          when importing the model and thus the related error
745
          `RuntimeError: Cannot re-initialize CUDA in forked subprocess`.
746
        """
747
748
749
750
        if not isinstance(model_arch, str):
            msg = f"`model_arch` should be a string, not a {type(model_arch)}"
            raise TypeError(msg)

751
        if model_arch in self.models:
752
753
            logger.warning(
                "Model architecture %s is already registered, and will be "
754
755
756
757
                "overwritten by the new model class %s.",
                model_arch,
                model_cls,
            )
758
759
760
761
762
763

        if isinstance(model_cls, str):
            split_str = model_cls.split(":")
            if len(split_str) != 2:
                msg = "Expected a string in the format `<module>:<class>`"
                raise ValueError(msg)
764

765
            model = _LazyRegisteredModel(*split_str)
766
        elif isinstance(model_cls, type) and issubclass(model_cls, nn.Module):
767
            model = _RegisteredModel.from_model_cls(model_cls)
768
        else:
769
770
771
772
            msg = (
                "`model_cls` should be a string or PyTorch model class, "
                f"not a {type(model_arch)}"
            )
773
            raise TypeError(msg)
774

775
        self.models[model_arch] = model
776

777
    def _raise_for_unsupported(self, architectures: list[str]):
778
        all_supported_archs = self.get_supported_archs()
779

780
781
782
        if any(arch in all_supported_archs for arch in architectures):
            raise ValueError(
                f"Model architectures {architectures} failed "
783
784
                "to be inspected. Please check the logs for more details."
            )
785

786
787
788
789
790
791
792
793
        for arch in architectures:
            if arch in _PREVIOUSLY_SUPPORTED_MODELS:
                previous_version = _PREVIOUSLY_SUPPORTED_MODELS[arch]

                raise ValueError(
                    f"Model architecture {arch} was supported in vLLM until "
                    f"v{previous_version}, and is not supported anymore. "
                    "Please use an older version of vLLM if you want to "
794
795
                    "use this model architecture."
                )
796

797
798
        raise ValueError(
            f"Model architectures {architectures} are not supported for now. "
799
800
            f"Supported architectures: {all_supported_archs}"
        )
801

802
    def _try_load_model_cls(self, model_arch: str) -> type[nn.Module] | None:
803
804
        if model_arch not in self.models:
            return None
805

806
        return _try_load_model_cls(model_arch, self.models[model_arch])
807

808
    def _try_inspect_model_cls(self, model_arch: str) -> _ModelInfo | None:
809
810
        if model_arch not in self.models:
            return None
811

812
813
814
815
816
817
        return _try_inspect_model_cls(model_arch, self.models[model_arch])

    def _try_resolve_transformers(
        self,
        architecture: str,
        model_config: ModelConfig,
818
    ) -> str | None:
819
820
821
        if architecture in _TRANSFORMERS_BACKEND_MODELS:
            return architecture

822
823
824
        auto_map: dict[str, str] = (
            getattr(model_config.hf_config, "auto_map", None) or dict()
        )
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857

        # Make sure that config class is always initialized before model class,
        # otherwise the model class won't be able to access the config class,
        # the expected auto_map should have correct order like:
        # "auto_map": {
        #     "AutoConfig": "<your-repo-name>--<config-name>",
        #     "AutoModel": "<your-repo-name>--<config-name>",
        #     "AutoModelFor<Task>": "<your-repo-name>--<config-name>",
        # },
        for prefix in ("AutoConfig", "AutoModel"):
            for name, module in auto_map.items():
                if name.startswith(prefix):
                    try_get_class_from_dynamic_module(
                        module,
                        model_config.model,
                        revision=model_config.revision,
                        warn_on_fail=False,
                    )

        model_module = getattr(transformers, architecture, None)

        if model_module is None:
            for name, module in auto_map.items():
                if name.startswith("AutoModel"):
                    model_module = try_get_class_from_dynamic_module(
                        module,
                        model_config.model,
                        revision=model_config.revision,
                        warn_on_fail=True,
                    )
                    if model_module is not None:
                        break
            else:
858
                if model_config.model_impl != "transformers":
859
860
861
862
863
864
865
                    return None

                raise ValueError(
                    f"Cannot find model module. {architecture!r} is not a "
                    "registered model in the Transformers library (only "
                    "relevant if the model is meant to be in Transformers) "
                    "and 'AutoModel' is not present in the model config's "
866
867
                    "'auto_map' (relevant if the model is custom)."
                )
868
869

        if not model_module.is_backend_compatible():
870
            if model_config.model_impl != "transformers":
871
                return None
872

873
874
            raise ValueError(
                f"The Transformers implementation of {architecture!r} "
875
876
                "is not compatible with vLLM."
            )
877

878
        return model_config._get_transformers_backend_cls()
879

880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
    def _normalize_arch(
        self,
        architecture: str,
        model_config: ModelConfig,
    ) -> str:
        if architecture in self.models:
            return architecture

        # This may be called in order to resolve runner_type and convert_type
        # in the first place, in which case we consider the default match
        match = try_match_architecture_defaults(
            architecture,
            runner_type=getattr(model_config, "runner_type", None),
            convert_type=getattr(model_config, "convert_type", None),
        )
        if match:
            suffix, _ = match

            # Get the name of the base model to convert
            for repl_suffix, _ in iter_architecture_defaults():
                base_arch = architecture.replace(suffix, repl_suffix)
                if base_arch in self.models:
                    return base_arch

        return architecture
905

906
907
    def inspect_model_cls(
        self,
908
        architectures: str | list[str],
909
        model_config: ModelConfig,
910
    ) -> tuple[_ModelInfo, str]:
911
912
        if isinstance(architectures, str):
            architectures = [architectures]
913
914
        if not architectures:
            raise ValueError("No model architectures are specified")
915
916

        # Require transformers impl
917
        if model_config.model_impl == "transformers":
918
            arch = self._try_resolve_transformers(architectures[0], model_config)
919
920
921
922
            if arch is not None:
                model_info = self._try_inspect_model_cls(arch)
                if model_info is not None:
                    return (model_info, arch)
923
        elif model_config.model_impl == "terratorch":
924
925
            model_info = self._try_inspect_model_cls("Terratorch")
            return (model_info, "Terratorch")
926

927
        # Fallback to transformers impl (after resolving convert_type)
928
929
930
931
932
933
        if (
            all(arch not in self.models for arch in architectures)
            and model_config.model_impl == "auto"
            and getattr(model_config, "convert_type", "none") == "none"
        ):
            arch = self._try_resolve_transformers(architectures[0], model_config)
934
935
936
937
938
939
940
            if arch is not None:
                model_info = self._try_inspect_model_cls(arch)
                if model_info is not None:
                    return (model_info, arch)

        for arch in architectures:
            normalized_arch = self._normalize_arch(arch, model_config)
941
            model_info = self._try_inspect_model_cls(normalized_arch)
942
            if model_info is not None:
943
                return (model_info, arch)
944

945
        # Fallback to transformers impl (before resolving runner_type)
946
947
948
949
950
        if (
            all(arch not in self.models for arch in architectures)
            and model_config.model_impl == "auto"
        ):
            arch = self._try_resolve_transformers(architectures[0], model_config)
951
952
953
954
955
            if arch is not None:
                model_info = self._try_inspect_model_cls(arch)
                if model_info is not None:
                    return (model_info, arch)

956
        return self._raise_for_unsupported(architectures)
957

958
959
    def resolve_model_cls(
        self,
960
        architectures: str | list[str],
961
        model_config: ModelConfig,
962
    ) -> tuple[type[nn.Module], str]:
963
964
        if isinstance(architectures, str):
            architectures = [architectures]
965
966
        if not architectures:
            raise ValueError("No model architectures are specified")
967
968

        # Require transformers impl
969
        if model_config.model_impl == "transformers":
970
            arch = self._try_resolve_transformers(architectures[0], model_config)
971
972
973
974
            if arch is not None:
                model_cls = self._try_load_model_cls(arch)
                if model_cls is not None:
                    return (model_cls, arch)
975
        elif model_config.model_impl == "terratorch":
976
977
978
979
            arch = "Terratorch"
            model_cls = self._try_load_model_cls(arch)
            if model_cls is not None:
                return (model_cls, arch)
980

981
        # Fallback to transformers impl (after resolving convert_type)
982
983
984
985
986
987
        if (
            all(arch not in self.models for arch in architectures)
            and model_config.model_impl == "auto"
            and getattr(model_config, "convert_type", "none") == "none"
        ):
            arch = self._try_resolve_transformers(architectures[0], model_config)
988
989
990
991
992
993
994
            if arch is not None:
                model_cls = self._try_load_model_cls(arch)
                if model_cls is not None:
                    return (model_cls, arch)

        for arch in architectures:
            normalized_arch = self._normalize_arch(arch, model_config)
995
            model_cls = self._try_load_model_cls(normalized_arch)
996
997
            if model_cls is not None:
                return (model_cls, arch)
998

999
        # Fallback to transformers impl (before resolving runner_type)
1000
1001
1002
1003
1004
        if (
            all(arch not in self.models for arch in architectures)
            and model_config.model_impl == "auto"
        ):
            arch = self._try_resolve_transformers(architectures[0], model_config)
1005
1006
1007
1008
1009
            if arch is not None:
                model_cls = self._try_load_model_cls(arch)
                if model_cls is not None:
                    return (model_cls, arch)

1010
        return self._raise_for_unsupported(architectures)
1011

1012
1013
    def is_text_generation_model(
        self,
1014
        architectures: str | list[str],
1015
        model_config: ModelConfig,
1016
    ) -> bool:
1017
        model_cls, _ = self.inspect_model_cls(architectures, model_config)
1018
        return model_cls.is_text_generation_model
1019

1020
    def is_pooling_model(
1021
        self,
1022
        architectures: str | list[str],
1023
        model_config: ModelConfig,
1024
    ) -> bool:
1025
        model_cls, _ = self.inspect_model_cls(architectures, model_config)
1026
        return model_cls.is_pooling_model
1027

1028
1029
    def is_cross_encoder_model(
        self,
1030
        architectures: str | list[str],
1031
        model_config: ModelConfig,
1032
    ) -> bool:
1033
        model_cls, _ = self.inspect_model_cls(architectures, model_config)
1034
        return model_cls.supports_cross_encoding
1035

1036
1037
    def is_multimodal_model(
        self,
1038
        architectures: str | list[str],
1039
        model_config: ModelConfig,
1040
    ) -> bool:
1041
        model_cls, _ = self.inspect_model_cls(architectures, model_config)
1042
        return model_cls.supports_multimodal
1043

1044
    def is_multimodal_raw_input_only_model(
1045
        self,
1046
        architectures: str | list[str],
1047
        model_config: ModelConfig,
1048
    ) -> bool:
1049
        model_cls, _ = self.inspect_model_cls(architectures, model_config)
1050
        return model_cls.supports_multimodal_raw_input_only
1051

1052
1053
    def is_pp_supported_model(
        self,
1054
        architectures: str | list[str],
1055
        model_config: ModelConfig,
1056
    ) -> bool:
1057
        model_cls, _ = self.inspect_model_cls(architectures, model_config)
1058
        return model_cls.supports_pp
1059

1060
1061
    def model_has_inner_state(
        self,
1062
        architectures: str | list[str],
1063
        model_config: ModelConfig,
1064
    ) -> bool:
1065
        model_cls, _ = self.inspect_model_cls(architectures, model_config)
1066
        return model_cls.has_inner_state
1067

1068
1069
    def is_attention_free_model(
        self,
1070
        architectures: str | list[str],
1071
        model_config: ModelConfig,
1072
    ) -> bool:
1073
        model_cls, _ = self.inspect_model_cls(architectures, model_config)
1074
        return model_cls.is_attention_free
1075

1076
1077
    def is_hybrid_model(
        self,
1078
        architectures: str | list[str],
1079
        model_config: ModelConfig,
1080
    ) -> bool:
1081
        model_cls, _ = self.inspect_model_cls(architectures, model_config)
1082
1083
        return model_cls.is_hybrid

1084
1085
    def is_noops_model(
        self,
1086
        architectures: str | list[str],
1087
        model_config: ModelConfig,
1088
    ) -> bool:
1089
        model_cls, _ = self.inspect_model_cls(architectures, model_config)
1090
1091
        return model_cls.has_noops

1092
1093
    def is_transcription_model(
        self,
1094
        architectures: str | list[str],
1095
        model_config: ModelConfig,
1096
    ) -> bool:
1097
        model_cls, _ = self.inspect_model_cls(architectures, model_config)
1098
1099
        return model_cls.supports_transcription

1100
1101
    def is_transcription_only_model(
        self,
1102
        architectures: str | list[str],
1103
        model_config: ModelConfig,
1104
    ) -> bool:
1105
        model_cls, _ = self.inspect_model_cls(architectures, model_config)
1106
1107
        return model_cls.supports_transcription_only

1108

1109
1110
1111
1112
1113
1114
1115
1116
1117
# Global registry instance, pre-populated with one lazily-imported entry per
# architecture listed in `_VLLM_MODELS`.
ModelRegistry = _ModelRegistry(
    models={
        arch: _LazyRegisteredModel(
            module_name=f"vllm.model_executor.models.{module_relname}",
            class_name=class_name,
        )
        for arch, (module_relname, class_name) in _VLLM_MODELS.items()
    }
)
1118
1119
1120
1121
1122

_T = TypeVar("_T")


def _run_in_subprocess(fn: Callable[[], _T]) -> _T:
    """
    Run `fn` in a fresh subprocess and return its result.

    `fn` and an output file path are serialized with `cloudpickle` and sent
    to the subprocess started with `_SUBPROCESS_COMMAND` via stdin. The
    subprocess is expected to write the pickled result to the output file,
    which is then loaded and returned.

    Raises:
        RuntimeError: If the subprocess exits with a non-zero return code.
            The message contains the subprocess's stderr.
    """
    # NOTE: We use a temporary directory instead of a temporary file to avoid
    # issues like https://stackoverflow.com/questions/23212435/permission-denied-to-write-to-my-temporary-file
    with tempfile.TemporaryDirectory() as tempdir:
        output_filepath = os.path.join(tempdir, "registry_output.tmp")

        # `cloudpickle` allows pickling lambda functions directly
        import cloudpickle

        input_bytes = cloudpickle.dumps((fn, output_filepath))

        # cannot use `sys.executable __file__` here because the script
        # contains relative imports
        returned = subprocess.run(
            _SUBPROCESS_COMMAND, input=input_bytes, capture_output=True
        )

        # check if the subprocess is successful
        try:
            returned.check_returncode()
        except subprocess.CalledProcessError as e:
            # wrap raised exception to provide more information.
            # Decode leniently: stderr may contain bytes that are not valid
            # UTF-8, and a decoding error here would hide the real failure.
            stderr_text = returned.stderr.decode(errors="replace")
            raise RuntimeError(
                f"Error raised in subprocess:\n{stderr_text}"
            ) from e

        # The file was written by the subprocess we just spawned
        with open(output_filepath, "rb") as f:
            return pickle.load(f)


def _run() -> None:
    """
    Subprocess entry point: run a pickled callable and save its result.

    Reads a pickled `(fn, output_file)` pair from stdin, calls `fn()` and
    writes the pickled return value to `output_file`.
    """
    # Setup plugins
    from vllm.plugins import load_general_plugins

    load_general_plugins()

    # NOTE: the payload is unpickled as-is, so stdin must come from a trusted
    # parent process.
    payload = sys.stdin.buffer.read()
    fn, output_file = pickle.loads(payload)

    outcome = fn()

    with open(output_file, "wb") as out:
        serialized = pickle.dumps(outcome)
        out.write(serialized)


if __name__ == "__main__":
    _run()