# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Whenever you add an architecture to this file, please also update
`tests/models/registry.py` with example HuggingFace models for it.
"""
7

8
import hashlib
9
import importlib
10
import json
11
import os
12
import pickle
13
14
import subprocess
import sys
15
import tempfile
16
from abc import ABC, abstractmethod
17
from collections.abc import Callable, Set
18
from dataclasses import asdict, dataclass, field
19
from functools import lru_cache
20
from pathlib import Path
21
from typing import TypeVar
22
23

import torch.nn as nn
24
import transformers
25

26
from vllm import envs
27
28
29
30
31
from vllm.config import (
    ModelConfig,
    iter_architecture_defaults,
    try_match_architecture_defaults,
)
32
from vllm.logger import init_logger
33
from vllm.logging_utils import logtime
34
35
36
37
38
39
40
41
from vllm.transformers_utils.dynamic_module import try_get_class_from_dynamic_module

from .interfaces import (
    has_inner_state,
    has_noops,
    is_attention_free,
    is_hybrid,
    supports_cross_encoding,
42
    supports_mamba_prefix_caching,
43
44
45
46
47
48
49
50
51
52
53
    supports_multimodal,
    supports_multimodal_encoder_tp_data,
    supports_multimodal_raw_input_only,
    supports_pp,
    supports_transcription,
)
from .interfaces_base import (
    get_default_pooling_type,
    is_pooling_model,
    is_text_generation_model,
)
54
55
56

# Module logger: used for model-info cache debug messages, for the
# `logtime` timing decorator, and to report model load/inspect errors.
logger = init_logger(__name__)

57
58
# Maps an architecture name to ``(module, class)``: the module in this package
# and the class inside it that implements the architecture.
_TEXT_GENERATION_MODELS = {
    # [Decoder-only]
    "ApertusForCausalLM": ("apertus", "ApertusForCausalLM"),
    "AquilaModel": ("llama", "LlamaForCausalLM"),
    "AquilaForCausalLM": ("llama", "LlamaForCausalLM"),  # AquilaChat2
    "ArceeForCausalLM": ("arcee", "ArceeForCausalLM"),
    "ArcticForCausalLM": ("arctic", "ArcticForCausalLM"),
    # baichuan-7b, upper case 'C' in the class name
    "BaiChuanForCausalLM": ("baichuan", "BaiChuanForCausalLM"),
    # baichuan-13b, lower case 'c' in the class name
    "BaichuanForCausalLM": ("baichuan", "BaichuanForCausalLM"),
    "BailingMoeForCausalLM": ("bailing_moe", "BailingMoeForCausalLM"),
    "BailingMoeV2ForCausalLM": ("bailing_moe", "BailingMoeV2ForCausalLM"),
    "BambaForCausalLM": ("bamba", "BambaForCausalLM"),
    "BloomForCausalLM": ("bloom", "BloomForCausalLM"),
    "ChatGLMModel": ("chatglm", "ChatGLMForCausalLM"),
    "ChatGLMForConditionalGeneration": ("chatglm", "ChatGLMForCausalLM"),
    "CohereForCausalLM": ("commandr", "CohereForCausalLM"),
    "Cohere2ForCausalLM": ("commandr", "CohereForCausalLM"),
    "CwmForCausalLM": ("llama", "LlamaForCausalLM"),
    "DbrxForCausalLM": ("dbrx", "DbrxForCausalLM"),
    "DeciLMForCausalLM": ("nemotron_nas", "DeciLMForCausalLM"),
    "DeepseekForCausalLM": ("deepseek", "DeepseekForCausalLM"),
    "DeepseekV2ForCausalLM": ("deepseek_v2", "DeepseekV2ForCausalLM"),
    "DeepseekV3ForCausalLM": ("deepseek_v2", "DeepseekV3ForCausalLM"),
    "DeepseekV32ForCausalLM": ("deepseek_v2", "DeepseekV3ForCausalLM"),
    "Dots1ForCausalLM": ("dots1", "Dots1ForCausalLM"),
    "Ernie4_5ForCausalLM": ("ernie45", "Ernie4_5ForCausalLM"),
    "Ernie4_5_MoeForCausalLM": ("ernie45_moe", "Ernie4_5_MoeForCausalLM"),
    "ExaoneForCausalLM": ("exaone", "ExaoneForCausalLM"),
    "Exaone4ForCausalLM": ("exaone4", "Exaone4ForCausalLM"),
    "Fairseq2LlamaForCausalLM": ("fairseq2_llama", "Fairseq2LlamaForCausalLM"),
    "FalconForCausalLM": ("falcon", "FalconForCausalLM"),
    "FalconMambaForCausalLM": ("mamba", "MambaForCausalLM"),
    "FalconH1ForCausalLM": ("falcon_h1", "FalconH1ForCausalLM"),
    "FlexOlmoForCausalLM": ("flex_olmo", "FlexOlmoForCausalLM"),
    "GemmaForCausalLM": ("gemma", "GemmaForCausalLM"),
    "Gemma2ForCausalLM": ("gemma2", "Gemma2ForCausalLM"),
    "Gemma3ForCausalLM": ("gemma3", "Gemma3ForCausalLM"),
    "Gemma3nForCausalLM": ("gemma3n", "Gemma3nForCausalLM"),
    "Qwen3NextForCausalLM": ("qwen3_next", "Qwen3NextForCausalLM"),
    "GlmForCausalLM": ("glm", "GlmForCausalLM"),
    "Glm4ForCausalLM": ("glm4", "Glm4ForCausalLM"),
    "Glm4MoeForCausalLM": ("glm4_moe", "Glm4MoeForCausalLM"),
    "GptOssForCausalLM": ("gpt_oss", "GptOssForCausalLM"),
    "GPT2LMHeadModel": ("gpt2", "GPT2LMHeadModel"),
    "GPTBigCodeForCausalLM": ("gpt_bigcode", "GPTBigCodeForCausalLM"),
    "GPTJForCausalLM": ("gpt_j", "GPTJForCausalLM"),
    "GPTNeoXForCausalLM": ("gpt_neox", "GPTNeoXForCausalLM"),
    "GraniteForCausalLM": ("granite", "GraniteForCausalLM"),
    "GraniteMoeForCausalLM": ("granitemoe", "GraniteMoeForCausalLM"),
    "GraniteMoeHybridForCausalLM": ("granitemoehybrid", "GraniteMoeHybridForCausalLM"),  # noqa: E501
    "GraniteMoeSharedForCausalLM": ("granitemoeshared", "GraniteMoeSharedForCausalLM"),  # noqa: E501
    "GritLM": ("gritlm", "GritLM"),
    "Grok1ModelForCausalLM": ("grok1", "Grok1ForCausalLM"),
    "HunYuanMoEV1ForCausalLM": ("hunyuan_v1", "HunYuanMoEV1ForCausalLM"),
    "HunYuanDenseV1ForCausalLM": ("hunyuan_v1", "HunYuanDenseV1ForCausalLM"),
    "HCXVisionForCausalLM": ("hyperclovax_vision", "HCXVisionForCausalLM"),
    "InternLMForCausalLM": ("llama", "LlamaForCausalLM"),
    "InternLM2ForCausalLM": ("internlm2", "InternLM2ForCausalLM"),
    "InternLM2VEForCausalLM": ("internlm2_ve", "InternLM2VEForCausalLM"),
    "InternLM3ForCausalLM": ("llama", "LlamaForCausalLM"),
    "JAISLMHeadModel": ("jais", "JAISLMHeadModel"),
    "JambaForCausalLM": ("jamba", "JambaForCausalLM"),
    "KimiLinearForCausalLM": ("kimi_linear", "KimiLinearForCausalLM"),
    "Lfm2ForCausalLM": ("lfm2", "Lfm2ForCausalLM"),
    "Lfm2MoeForCausalLM": ("lfm2_moe", "Lfm2MoeForCausalLM"),
    "LlamaForCausalLM": ("llama", "LlamaForCausalLM"),
    "Llama4ForCausalLM": ("llama4", "Llama4ForCausalLM"),
    # For decapoda-research/llama-*
    "LLaMAForCausalLM": ("llama", "LlamaForCausalLM"),
    "LongcatFlashForCausalLM": ("longcat_flash", "LongcatFlashForCausalLM"),
    "MambaForCausalLM": ("mamba", "MambaForCausalLM"),
    "Mamba2ForCausalLM": ("mamba2", "Mamba2ForCausalLM"),
    "MiniCPMForCausalLM": ("minicpm", "MiniCPMForCausalLM"),
    "MiniCPM3ForCausalLM": ("minicpm3", "MiniCPM3ForCausalLM"),
    "MiniMaxForCausalLM": ("minimax_text_01", "MiniMaxText01ForCausalLM"),
    "MiniMaxText01ForCausalLM": ("minimax_text_01", "MiniMaxText01ForCausalLM"),
    "MiniMaxM1ForCausalLM": ("minimax_text_01", "MiniMaxText01ForCausalLM"),
    "MiniMaxM2ForCausalLM": ("minimax_m2", "MiniMaxM2ForCausalLM"),
    "MistralForCausalLM": ("llama", "LlamaForCausalLM"),
    "MixtralForCausalLM": ("mixtral", "MixtralForCausalLM"),
    # transformers's mpt class has lower case
    "MptForCausalLM": ("mpt", "MPTForCausalLM"),
    "MPTForCausalLM": ("mpt", "MPTForCausalLM"),
    "MiMoForCausalLM": ("mimo", "MiMoForCausalLM"),
    "NemotronForCausalLM": ("nemotron", "NemotronForCausalLM"),
    "NemotronHForCausalLM": ("nemotron_h", "NemotronHForCausalLM"),
    "OlmoForCausalLM": ("olmo", "OlmoForCausalLM"),
    "Olmo2ForCausalLM": ("olmo2", "Olmo2ForCausalLM"),
    "Olmo3ForCausalLM": ("olmo2", "Olmo2ForCausalLM"),
    "OlmoeForCausalLM": ("olmoe", "OlmoeForCausalLM"),
    "OPTForCausalLM": ("opt", "OPTForCausalLM"),
    "OrionForCausalLM": ("orion", "OrionForCausalLM"),
    "PersimmonForCausalLM": ("persimmon", "PersimmonForCausalLM"),
    "PhiForCausalLM": ("phi", "PhiForCausalLM"),
    "Phi3ForCausalLM": ("phi3", "Phi3ForCausalLM"),
    "PhiMoEForCausalLM": ("phimoe", "PhiMoEForCausalLM"),
    "Plamo2ForCausalLM": ("plamo2", "Plamo2ForCausalLM"),
    "QWenLMHeadModel": ("qwen", "QWenLMHeadModel"),
    "Qwen2ForCausalLM": ("qwen2", "Qwen2ForCausalLM"),
    "Qwen2MoeForCausalLM": ("qwen2_moe", "Qwen2MoeForCausalLM"),
    "Qwen3ForCausalLM": ("qwen3", "Qwen3ForCausalLM"),
    "Qwen3MoeForCausalLM": ("qwen3_moe", "Qwen3MoeForCausalLM"),
    "RWForCausalLM": ("falcon", "FalconForCausalLM"),
    "SeedOssForCausalLM": ("seed_oss", "SeedOssForCausalLM"),
    "Step3TextForCausalLM": ("step3_text", "Step3TextForCausalLM"),
    "StableLMEpochForCausalLM": ("stablelm", "StablelmForCausalLM"),
    "StableLmForCausalLM": ("stablelm", "StablelmForCausalLM"),
    "Starcoder2ForCausalLM": ("starcoder2", "Starcoder2ForCausalLM"),
    "SolarForCausalLM": ("solar", "SolarForCausalLM"),
    "TeleChat2ForCausalLM": ("telechat2", "TeleChat2ForCausalLM"),
    "TeleFLMForCausalLM": ("teleflm", "TeleFLMForCausalLM"),
    "XverseForCausalLM": ("llama", "LlamaForCausalLM"),
    "Zamba2ForCausalLM": ("zamba2", "Zamba2ForCausalLM"),
}

# Architectures usable as embedding / pooling models, mapped to
# ``(module, class)`` in this package.
_EMBEDDING_MODELS = {
    # [Text-only]
    "BertModel": ("bert", "BertEmbeddingModel"),
    "BertSpladeSparseEmbeddingModel": ("bert", "BertSpladeSparseEmbeddingModel"),
    "DeciLMForCausalLM": ("nemotron_nas", "DeciLMForCausalLM"),
    "Gemma2Model": ("gemma2", "Gemma2ForCausalLM"),
    "Gemma3TextModel": ("gemma3", "Gemma3Model"),
    "GlmForCausalLM": ("glm", "GlmForCausalLM"),
    "GPT2ForSequenceClassification": ("gpt2", "GPT2ForSequenceClassification"),
    "GritLM": ("gritlm", "GritLM"),
    "GteModel": ("bert_with_rope", "SnowflakeGteNewModel"),
    "GteNewModel": ("bert_with_rope", "GteNewModel"),
    "InternLM2ForRewardModel": ("internlm2", "InternLM2ForRewardModel"),
    "JambaForSequenceClassification": ("jamba", "JambaForSequenceClassification"),  # noqa: E501
    "LlamaModel": ("llama", "LlamaForCausalLM"),
    **{
        # Also register every text-generation architecture whose implementation
        # is LlamaForCausalLM (e.g. MistralForCausalLM, XverseForCausalLM),
        # since they all share that model class.
        k: (mod, arch)
        for k, (mod, arch) in _TEXT_GENERATION_MODELS.items()
        if arch == "LlamaForCausalLM"
    },
    "MistralModel": ("llama", "LlamaForCausalLM"),
    "ModernBertModel": ("modernbert", "ModernBertModel"),
    "NomicBertModel": ("bert_with_rope", "NomicBertModel"),
    "Phi3ForCausalLM": ("phi3", "Phi3ForCausalLM"),
    "Qwen2Model": ("qwen2", "Qwen2ForCausalLM"),
    "Qwen2ForCausalLM": ("qwen2", "Qwen2ForCausalLM"),
    "Qwen2ForRewardModel": ("qwen2_rm", "Qwen2ForRewardModel"),
    "Qwen2ForProcessRewardModel": ("qwen2_rm", "Qwen2ForProcessRewardModel"),
    "RobertaForMaskedLM": ("roberta", "RobertaEmbeddingModel"),
    "RobertaModel": ("roberta", "RobertaEmbeddingModel"),
    "TeleChat2ForCausalLM": ("telechat2", "TeleChat2ForCausalLM"),
    "XLMRobertaModel": ("roberta", "RobertaEmbeddingModel"),
    # [Multimodal]
    "CLIPModel": ("clip", "CLIPEmbeddingModel"),
    "LlavaNextForConditionalGeneration": (
        "llava_next",
        "LlavaNextForConditionalGeneration",
    ),
    "Phi3VForCausalLM": ("phi3v", "Phi3VForCausalLM"),
    "Qwen2VLForConditionalGeneration": ("qwen2_vl", "Qwen2VLForConditionalGeneration"),  # noqa: E501
    "SiglipModel": ("siglip", "SiglipEmbeddingModel"),
    # Terratorch models technically take images as input and produce images
    # as output. They are listed here because, for the time being, they
    # piggy-back on the embedding-model path.
    "PrithviGeoSpatialMAE": ("terratorch", "Terratorch"),
    "Terratorch": ("terratorch", "Terratorch"),
}

223
224
# Sequence / token classification (cross-encoder) architectures, mapped to
# ``(module, class)`` in this package.
_CROSS_ENCODER_MODELS = {
    "BertForSequenceClassification": ("bert", "BertForSequenceClassification"),
    "BertForTokenClassification": ("bert", "BertForTokenClassification"),
    "GteNewForSequenceClassification": (
        "bert_with_rope",
        "GteNewForSequenceClassification",
    ),
    "ModernBertForSequenceClassification": (
        "modernbert",
        "ModernBertForSequenceClassification",
    ),
    "ModernBertForTokenClassification": (
        "modernbert",
        "ModernBertForTokenClassification",
    ),
    "RobertaForSequenceClassification": ("roberta", "RobertaForSequenceClassification"),
    "XLMRobertaForSequenceClassification": (
        "roberta",
        "RobertaForSequenceClassification",
    ),
    # [Auto-converted (see adapters.py)]
    "JinaVLForRanking": ("jina_vl", "JinaVLForSequenceClassification"),
}

247
# Multimodal architectures, mapped to ``(module, class)`` in this package.
_MULTIMODAL_MODELS = {
    # [Decoder-only]
    "AriaForConditionalGeneration": ("aria", "AriaForConditionalGeneration"),
    "AyaVisionForConditionalGeneration": (
        "aya_vision",
        "AyaVisionForConditionalGeneration",
    ),
    "BeeForConditionalGeneration": ("bee", "BeeForConditionalGeneration"),
    "Blip2ForConditionalGeneration": ("blip2", "Blip2ForConditionalGeneration"),
    "ChameleonForConditionalGeneration": (
        "chameleon",
        "ChameleonForConditionalGeneration",
    ),
    "Cohere2VisionForConditionalGeneration": (
        "cohere2_vision",
        "Cohere2VisionForConditionalGeneration",
    ),
    "DeepseekVLV2ForCausalLM": ("deepseek_vl2", "DeepseekVLV2ForCausalLM"),
    "DeepseekOCRForCausalLM": ("deepseek_ocr", "DeepseekOCRForCausalLM"),
    "DotsOCRForCausalLM": ("dots_ocr", "DotsOCRForCausalLM"),
    "Ernie4_5_VLMoeForConditionalGeneration": (
        "ernie45_vl",
        "Ernie4_5_VLMoeForConditionalGeneration",
    ),
    "FuyuForCausalLM": ("fuyu", "FuyuForCausalLM"),
    "Gemma3ForConditionalGeneration": ("gemma3_mm", "Gemma3ForConditionalGeneration"),  # noqa: E501
    "Gemma3nForConditionalGeneration": (
        "gemma3n_mm",
        "Gemma3nForConditionalGeneration",
    ),
    "GLM4VForCausalLM": ("glm4v", "GLM4VForCausalLM"),
    "Glm4vForConditionalGeneration": ("glm4_1v", "Glm4vForConditionalGeneration"),  # noqa: E501
    "Glm4vMoeForConditionalGeneration": ("glm4_1v", "Glm4vMoeForConditionalGeneration"),  # noqa: E501
    "GraniteSpeechForConditionalGeneration": (
        "granite_speech",
        "GraniteSpeechForConditionalGeneration",
    ),
    "H2OVLChatModel": ("h2ovl", "H2OVLChatModel"),
    "InternVLChatModel": ("internvl", "InternVLChatModel"),
    "NemotronH_Nano_VL_V2": ("nano_nemotron_vl", "NemotronH_Nano_VL_V2"),
    "InternS1ForConditionalGeneration": (
        "interns1",
        "InternS1ForConditionalGeneration",
    ),
    "InternVLForConditionalGeneration": (
        "interns1",
        "InternS1ForConditionalGeneration",
    ),
    "Idefics3ForConditionalGeneration": (
        "idefics3",
        "Idefics3ForConditionalGeneration",
    ),
    "SmolVLMForConditionalGeneration": ("smolvlm", "SmolVLMForConditionalGeneration"),  # noqa: E501
    "KeyeForConditionalGeneration": ("keye", "KeyeForConditionalGeneration"),
    "KeyeVL1_5ForConditionalGeneration": (
        "keye_vl1_5",
        "KeyeVL1_5ForConditionalGeneration",
    ),
    "RForConditionalGeneration": ("rvl", "RForConditionalGeneration"),
    "KimiVLForConditionalGeneration": ("kimi_vl", "KimiVLForConditionalGeneration"),  # noqa: E501
    "LightOnOCRForConditionalGeneration": (
        "lightonocr",
        "LightOnOCRForConditionalGeneration",
    ),
    "Llama_Nemotron_Nano_VL": ("nemotron_vl", "LlamaNemotronVLChatModel"),
    "Llama4ForConditionalGeneration": ("mllama4", "Llama4ForConditionalGeneration"),  # noqa: E501
    "LlavaForConditionalGeneration": ("llava", "LlavaForConditionalGeneration"),
    "LlavaNextForConditionalGeneration": (
        "llava_next",
        "LlavaNextForConditionalGeneration",
    ),
    "LlavaNextVideoForConditionalGeneration": (
        "llava_next_video",
        "LlavaNextVideoForConditionalGeneration",
    ),
    "LlavaOnevisionForConditionalGeneration": (
        "llava_onevision",
        "LlavaOnevisionForConditionalGeneration",
    ),
    "MantisForConditionalGeneration": ("llava", "MantisForConditionalGeneration"),  # noqa: E501
    "MiDashengLMModel": ("midashenglm", "MiDashengLMModel"),
    "MiniMaxVL01ForConditionalGeneration": (
        "minimax_vl_01",
        "MiniMaxVL01ForConditionalGeneration",
    ),
    "MiniCPMO": ("minicpmo", "MiniCPMO"),
    "MiniCPMV": ("minicpmv", "MiniCPMV"),
    "Mistral3ForConditionalGeneration": (
        "mistral3",
        "Mistral3ForConditionalGeneration",
    ),
    "MolmoForCausalLM": ("molmo", "MolmoForCausalLM"),
    "NVLM_D": ("nvlm_d", "NVLM_D_Model"),
    "Ovis": ("ovis", "Ovis"),
    "Ovis2_5": ("ovis2_5", "Ovis2_5"),
    "PaliGemmaForConditionalGeneration": (
        "paligemma",
        "PaliGemmaForConditionalGeneration",
    ),
    "Phi3VForCausalLM": ("phi3v", "Phi3VForCausalLM"),
    "Phi4MMForCausalLM": ("phi4mm", "Phi4MMForCausalLM"),
    "Phi4MultimodalForCausalLM": ("phi4_multimodal", "Phi4MultimodalForCausalLM"),  # noqa: E501
    "PixtralForConditionalGeneration": ("pixtral", "PixtralForConditionalGeneration"),  # noqa: E501
    "QwenVLForConditionalGeneration": ("qwen_vl", "QwenVLForConditionalGeneration"),  # noqa: E501
    "Qwen2VLForConditionalGeneration": ("qwen2_vl", "Qwen2VLForConditionalGeneration"),  # noqa: E501
    "Qwen2_5_VLForConditionalGeneration": (
        "qwen2_5_vl",
        "Qwen2_5_VLForConditionalGeneration",
    ),
    "Qwen2AudioForConditionalGeneration": (
        "qwen2_audio",
        "Qwen2AudioForConditionalGeneration",
    ),
    "Qwen2_5OmniModel": (
        "qwen2_5_omni_thinker",
        "Qwen2_5OmniThinkerForConditionalGeneration",
    ),
    "Qwen2_5OmniForConditionalGeneration": (
        "qwen2_5_omni_thinker",
        "Qwen2_5OmniThinkerForConditionalGeneration",
    ),
    "Qwen3OmniMoeForConditionalGeneration": (
        "qwen3_omni_moe_thinker",
        "Qwen3OmniMoeThinkerForConditionalGeneration",
    ),
    "Qwen3VLForConditionalGeneration": ("qwen3_vl", "Qwen3VLForConditionalGeneration"),  # noqa: E501
    "Qwen3VLMoeForConditionalGeneration": (
        "qwen3_vl_moe",
        "Qwen3VLMoeForConditionalGeneration",
    ),
    "SkyworkR1VChatModel": ("skyworkr1v", "SkyworkR1VChatModel"),
    "Step3VLForConditionalGeneration": ("step3_vl", "Step3VLForConditionalGeneration"),  # noqa: E501
    "TarsierForConditionalGeneration": ("tarsier", "TarsierForConditionalGeneration"),  # noqa: E501
    "Tarsier2ForConditionalGeneration": (
        "qwen2_vl",
        "Tarsier2ForConditionalGeneration",
    ),
    "UltravoxModel": ("ultravox", "UltravoxModel"),
    "VoxtralForConditionalGeneration": ("voxtral", "VoxtralForConditionalGeneration"),  # noqa: E501
    # [Encoder-decoder]
    "WhisperForConditionalGeneration": ("whisper", "WhisperForConditionalGeneration"),  # noqa: E501
}
389
390

# Draft / MTP model architectures used for speculative decoding, mapped to
# ``(module, class)`` in this package.
_SPECULATIVE_DECODING_MODELS = {
    "MiMoMTPModel": ("mimo_mtp", "MiMoMTP"),
    "EagleLlamaForCausalLM": ("llama_eagle", "EagleLlamaForCausalLM"),
    "EagleLlama4ForCausalLM": ("llama4_eagle", "EagleLlama4ForCausalLM"),
    "EagleMiniCPMForCausalLM": ("minicpm_eagle", "EagleMiniCPMForCausalLM"),
    "Eagle3LlamaForCausalLM": ("llama_eagle3", "Eagle3LlamaForCausalLM"),
    "LlamaForCausalLMEagle3": ("llama_eagle3", "Eagle3LlamaForCausalLM"),
    "Eagle3Qwen2_5vlForCausalLM": ("llama_eagle3", "Eagle3LlamaForCausalLM"),
    "EagleDeepSeekMTPModel": ("deepseek_eagle", "EagleDeepseekV3ForCausalLM"),
    "DeepSeekMTPModel": ("deepseek_mtp", "DeepSeekMTP"),
    "ErnieMTPModel": ("ernie_mtp", "ErnieMTP"),
    "LongCatFlashMTPModel": ("longcat_flash_mtp", "LongCatFlashMTP"),
    "Glm4MoeMTPModel": ("glm4_moe_mtp", "Glm4MoeMTP"),
    "MedusaModel": ("medusa", "Medusa"),
    "Qwen3NextMTP": ("qwen3_next_mtp", "Qwen3NextMTP"),
    # Temporarily disabled.
    # TODO(woosuk): Re-enable this once the MLP Speculator is supported in V1.
    # "MLPSpeculatorPreTrainedModel": ("mlp_speculator", "MLPSpeculator"),
}
409

410
# Architectures without a native vLLM implementation that are served through
# the Transformers backend classes.
_TRANSFORMERS_SUPPORTED_MODELS = {
    # Text generation models
    "SmolLM3ForCausalLM": ("transformers", "TransformersForCausalLM"),
    # Multimodal models
    "Emu3ForConditionalGeneration": (
        "transformers",
        "TransformersMultiModalForCausalLM",
    ),
}

# The Transformers backend classes themselves; each key names a class of the
# same name in the ``transformers`` module of this package.
_TRANSFORMERS_BACKEND_MODELS = {
    # Text generation models
    "TransformersForCausalLM": ("transformers", "TransformersForCausalLM"),
    "TransformersMoEForCausalLM": ("transformers", "TransformersMoEForCausalLM"),
    # Multimodal models
    "TransformersMultiModalForCausalLM": (
        "transformers",
        "TransformersMultiModalForCausalLM",
    ),
    "TransformersMultiModalMoEForCausalLM": (
        "transformers",
        "TransformersMultiModalMoEForCausalLM",
    ),
    # Embedding models
    "TransformersEmbeddingModel": ("transformers", "TransformersEmbeddingModel"),
    "TransformersMoEEmbeddingModel": ("transformers", "TransformersMoEEmbeddingModel"),
    "TransformersMultiModalEmbeddingModel": (
        "transformers",
        "TransformersMultiModalEmbeddingModel",
    ),
    # Sequence classification models
    "TransformersForSequenceClassification": (
        "transformers",
        "TransformersForSequenceClassification",
    ),
    "TransformersMoEForSequenceClassification": (
        "transformers",
        "TransformersMoEForSequenceClassification",
    ),
    "TransformersMultiModalForSequenceClassification": (
        "transformers",
        "TransformersMultiModalForSequenceClassification",
    ),
}
454

455
# Every architecture vLLM knows about out of the box. This is a plain dict
# merge, so when a key appears in several tables the later table wins.
_VLLM_MODELS = {
    **_TEXT_GENERATION_MODELS,
    **_EMBEDDING_MODELS,
    **_CROSS_ENCODER_MODELS,
    **_MULTIMODAL_MODELS,
    **_SPECULATIVE_DECODING_MODELS,
    **_TRANSFORMERS_SUPPORTED_MODELS,
    **_TRANSFORMERS_BACKEND_MODELS,
}

465
466
467
468
# Arguments passed to subprocess.run() to re-run this module in a fresh
# interpreter. It is a module-level variable so it can be overridden when
# needed, e.g. when things are packed together in the par format,
# sys.executable might not be the target we want to run.
_SUBPROCESS_COMMAND = [sys.executable, "-m", "vllm.model_executor.models.registry"]
470

471
# Architectures that are no longer registered above, mapped to a vLLM version
# string (presumably the last release that supported them -- confirm against
# the code that reports these to users).
_PREVIOUSLY_SUPPORTED_MODELS = {
    "MotifForCausalLM": "0.10.2",
    "Phi3SmallForCausalLM": "0.9.2",
    "Phi4FlashForCausalLM": "0.10.2",
    # encoder-decoder models except whisper
    # have been removed for V0 deprecation.
    "BartModel": "0.10.2",
    "BartForConditionalGeneration": "0.10.2",
    "DonutForConditionalGeneration": "0.10.2",
    "Florence2ForConditionalGeneration": "0.10.2",
    "MBartForConditionalGeneration": "0.10.2",
    "MllamaForConditionalGeneration": "0.10.2",
}
484

485

486
487
@dataclass(frozen=True)
class _ModelInfo:
    """
    Capability flags of a model class, as reported by the interface checks.

    Instances are built by :meth:`from_model_cls`. Every field is annotated as
    ``str`` or ``bool``, which lets the on-disk model-info cache serialize an
    instance with :func:`dataclasses.asdict` and rebuild it with
    ``_ModelInfo(**data)``. Renaming or reordering fields therefore changes
    the cache format.
    """

    architecture: str
    is_text_generation_model: bool
    is_pooling_model: bool
    default_pooling_type: str
    supports_cross_encoding: bool
    supports_multimodal: bool
    supports_multimodal_raw_input_only: bool
    supports_multimodal_encoder_tp_data: bool
    supports_pp: bool
    has_inner_state: bool
    is_attention_free: bool
    is_hybrid: bool
    has_noops: bool
    supports_mamba_prefix_caching: bool
    supports_transcription: bool
    supports_transcription_only: bool

    @staticmethod
    def from_model_cls(model: type[nn.Module]) -> "_ModelInfo":
        """
        Inspect a model class and collect its capability flags.

        Args:
            model: The model class (not an instance) to inspect.

        Returns:
            A ``_ModelInfo`` whose ``architecture`` is ``model.__name__`` and
            whose flags come from the matching interface helper functions.
        """
        return _ModelInfo(
            architecture=model.__name__,
            is_text_generation_model=is_text_generation_model(model),
            is_pooling_model=is_pooling_model(model),
            default_pooling_type=get_default_pooling_type(model),
            supports_cross_encoding=supports_cross_encoding(model),
            supports_multimodal=supports_multimodal(model),
            supports_multimodal_raw_input_only=supports_multimodal_raw_input_only(
                model
            ),
            supports_multimodal_encoder_tp_data=supports_multimodal_encoder_tp_data(
                model
            ),
            supports_pp=supports_pp(model),
            has_inner_state=has_inner_state(model),
            is_attention_free=is_attention_free(model),
            is_hybrid=is_hybrid(model),
            supports_mamba_prefix_caching=supports_mamba_prefix_caching(model),
            supports_transcription=supports_transcription(model),
            # Only read the class attribute when the transcription interface
            # is supported; short-circuiting avoids touching it otherwise.
            supports_transcription_only=(
                supports_transcription(model) and model.supports_transcription_only
            ),
            has_noops=has_noops(model),
        )
531
532


533
534
535
536
class _BaseRegisteredModel(ABC):
    """Common interface of the entries stored in the model registry."""

    @abstractmethod
    def inspect_model_cls(self) -> _ModelInfo:
        """Return the capability flags of the registered model class."""
        raise NotImplementedError

    @abstractmethod
    def load_model_cls(self) -> type[nn.Module]:
        """Return the registered model class itself."""
        raise NotImplementedError
541
542


543
544
545
546
547
548
549
@dataclass(frozen=True)
class _RegisteredModel(_BaseRegisteredModel):
    """
    Represents a model that has already been imported in the main process.

    The ``_ModelInfo`` is computed once, when the entry is created in
    :meth:`from_model_cls`. Inspecting or loading afterwards only returns the
    stored values.
    """

    interfaces: _ModelInfo
    model_cls: type[nn.Module]

    @staticmethod
    def from_model_cls(model_cls: type[nn.Module]) -> "_RegisteredModel":
        """Wrap an already-imported ``model_cls`` and inspect it immediately."""
        return _RegisteredModel(
            interfaces=_ModelInfo.from_model_cls(model_cls),
            model_cls=model_cls,
        )

    def inspect_model_cls(self) -> _ModelInfo:
        """Return the ``_ModelInfo`` computed at registration time."""
        return self.interfaces

    def load_model_cls(self) -> type[nn.Module]:
        """Return the wrapped model class."""
        return self.model_cls


@dataclass(frozen=True)
class _LazyRegisteredModel(_BaseRegisteredModel):
    """
    Represents a model that has not been imported in the main process.

    The class is located by ``module_name``, an absolute module path because
    it is passed to :func:`importlib.import_module` without a package, and by
    ``class_name``. Its ``_ModelInfo`` is computed in a subprocess and cached
    as JSON under ``<VLLM_CACHE_ROOT>/modelinfos``. The cache is keyed by an
    MD5 hash of the model's source file.
    """

    module_name: str
    class_name: str

    @staticmethod
    def _get_cache_dir() -> Path:
        """Return the directory that holds the per-model JSON cache files."""
        return Path(envs.VLLM_CACHE_ROOT) / "modelinfos"

    def _get_cache_filename(self) -> str:
        """Return ``<module>-<class>.json`` with every ``.`` replaced by ``-``."""
        cls_name = f"{self.module_name}-{self.class_name}".replace(".", "-")
        return f"{cls_name}.json"

    def _get_module_path(self) -> Path | None:
        """
        Return the source file of ``module_name`` if it lives in this package.

        Only modules located next to this file can be hashed for cache
        invalidation. Any other module, e.g. a plugin registered lazily as
        ``"my_pkg.models.llama:MyLlama"``, would otherwise be matched by its
        last name component against an unrelated built-in file such as
        ``llama.py``. Editing the plugin would then not invalidate its cache.
        For such modules ``None`` is returned and caching is skipped.
        """
        package, _, leaf = self.module_name.rpartition(".")
        if not __package__ or package != __package__:
            return None
        return Path(__file__).parent / f"{leaf}.py"

    def _load_modelinfo_from_cache(self, module_hash: str) -> _ModelInfo | None:
        """
        Return the cached ``_ModelInfo`` if it is still valid, else ``None``.

        ``None`` is returned in three cases:

        - the cache file is missing;
        - the file was written for a different ``module_hash``, meaning the
          model source changed;
        - the file cannot be read or turned into a ``_ModelInfo``.

        Caching is best-effort, so failures are only logged at debug level.
        """
        try:
            try:
                modelinfo_path = self._get_cache_dir() / self._get_cache_filename()
                with open(modelinfo_path, encoding="utf-8") as file:
                    mi_dict = json.load(file)
            except FileNotFoundError:
                logger.debug(
                    ("Cached model info file for class %s.%s not found"),
                    self.module_name,
                    self.class_name,
                )
                return None

            if mi_dict["hash"] != module_hash:
                logger.debug(
                    ("Cached model info file for class %s.%s is stale"),
                    self.module_name,
                    self.class_name,
                )
                return None

            # file not changed, use cached _ModelInfo properties
            return _ModelInfo(**mi_dict["modelinfo"])
        except Exception:
            # Corrupt JSON, missing keys or a changed _ModelInfo schema all
            # simply count as a cache miss.
            logger.debug(
                ("Cached model info for class %s.%s error. "),
                self.module_name,
                self.class_name,
            )
            return None

    def _save_modelinfo_to_cache(self, mi: _ModelInfo, module_hash: str) -> None:
        """
        Write ``mi`` and ``module_hash`` to this model's JSON cache file.

        The write goes through ``atomic_writer`` and is best-effort: errors
        are logged and swallowed.
        """
        from vllm.model_executor.model_loader.weight_utils import atomic_writer

        try:
            modelinfo_dict = {
                "hash": module_hash,
                "modelinfo": asdict(mi),
            }
            cache_dir = self._get_cache_dir()
            cache_dir.mkdir(parents=True, exist_ok=True)
            modelinfo_path = cache_dir / self._get_cache_filename()
            with atomic_writer(modelinfo_path, encoding="utf-8") as f:
                json.dump(modelinfo_dict, f, indent=2)
        except Exception:
            logger.exception("Error saving model info cache.")

    @logtime(logger=logger, msg="Registry inspect model class")
    def inspect_model_cls(self) -> _ModelInfo:
        """
        Return this model's ``_ModelInfo`` without importing it here.

        If the model's source file is found in this package, its hash is
        computed and a matching cached result is returned when available.
        Otherwise the class is imported and inspected in a subprocess, so
        that CUDA is not initialized in this process. When a hash was
        computed, the fresh result is then written back to the cache.
        """
        model_path = self._get_module_path()
        module_hash = None

        if model_path is not None and model_path.exists():
            with open(model_path, "rb") as f:
                module_hash = hashlib.md5(f.read(), usedforsecurity=False).hexdigest()

            mi = self._load_modelinfo_from_cache(module_hash)
            if mi is not None:
                logger.debug(
                    ("Loaded model info for class %s.%s from cache"),
                    self.module_name,
                    self.class_name,
                )
                return mi
            logger.debug(
                ("Cache model info for class %s.%s miss. Loading model instead."),
                self.module_name,
                self.class_name,
            )

        # Performed in another process to avoid initializing CUDA
        mi = _run_in_subprocess(
            lambda: _ModelInfo.from_model_cls(self.load_model_cls())
        )
        logger.debug(
            "Loaded model info for class %s.%s", self.module_name, self.class_name
        )

        # save cache file
        if module_hash is not None:
            self._save_modelinfo_to_cache(mi, module_hash)

        return mi

    def load_model_cls(self) -> type[nn.Module]:
        """Import ``module_name`` in this process and return ``class_name``."""
        mod = importlib.import_module(self.module_name)
        return getattr(mod, self.class_name)


@lru_cache(maxsize=128)
def _try_load_model_cls(
    model_arch: str,
    model: _BaseRegisteredModel,
) -> type[nn.Module] | None:
    """
    Load the model class registered for ``model_arch``.

    The current platform's ``verify_model_arch`` check runs first, and any
    error it raises propagates to the caller. Errors raised while loading the
    class are logged and turned into ``None``. Because of ``lru_cache``, a
    ``None`` result is memoized too, so a failed load is not retried for the
    same ``(model_arch, model)`` pair.
    """
    # NOTE(review): imported lazily, presumably to avoid platform detection
    # at module import time -- confirm.
    from vllm.platforms import current_platform

    current_platform.verify_model_arch(model_arch)
    try:
        return model.load_model_cls()
    except Exception:
        logger.exception("Error in loading model architecture '%s'", model_arch)
        return None
688
689


690
691
692
693
@lru_cache(maxsize=128)
def _try_inspect_model_cls(
    model_arch: str,
    model: _BaseRegisteredModel,
) -> _ModelInfo | None:
    """Inspect ``model`` and memoize the result per ``(model_arch, model)``.

    Any exception raised by ``model.inspect_model_cls()`` is logged with its
    traceback and turned into a ``None`` result.
    """
    try:
        info = model.inspect_model_cls()
    except Exception:
        logger.exception("Error in inspecting model architecture '%s'", model_arch)
        info = None
    return info


@dataclass
class _ModelRegistry:
    """Registry mapping architecture names to (lazily) loadable model classes.

    Supports registering out-of-tree models and resolving which registered
    class should serve a list of HF ``architectures``. Resolution includes
    fallbacks to the Transformers backend.
    """

    # Keyed by model_arch
    models: dict[str, _BaseRegisteredModel] = field(default_factory=dict)

    def get_supported_archs(self) -> Set[str]:
        """Return a live view of every registered architecture name."""
        return self.models.keys()

    def register_model(
        self,
        model_arch: str,
        model_cls: type[nn.Module] | str,
    ) -> None:
        """
        Register an external model to be used in vLLM.

        `model_cls` can be either:

        - A [`torch.nn.Module`][] class directly referencing the model.
        - A string in the format `<module>:<class>` which can be used to
          lazily import the model. This is useful to avoid initializing CUDA
          when importing the model and thus the related error
          `RuntimeError: Cannot re-initialize CUDA in forked subprocess`.

        Raises:
            TypeError: If `model_arch` is not a string, or `model_cls` is
                neither a string nor a `torch.nn.Module` subclass.
            ValueError: If `model_cls` is a string that is not of the form
                `<module>:<class>`.
        """
        if not isinstance(model_arch, str):
            msg = f"`model_arch` should be a string, not a {type(model_arch)}"
            raise TypeError(msg)

        # Validate `model_cls` before warning about an overwrite, so a
        # rejected registration does not log a misleading warning.
        if isinstance(model_cls, str):
            split_str = model_cls.split(":")
            if len(split_str) != 2:
                msg = "Expected a string in the format `<module>:<class>`"
                raise ValueError(msg)

            model = _LazyRegisteredModel(*split_str)
        elif isinstance(model_cls, type) and issubclass(model_cls, nn.Module):
            model = _RegisteredModel.from_model_cls(model_cls)
        else:
            # Report the type of the offending `model_cls` argument
            # (`model_arch` is already known to be a `str` here).
            msg = (
                "`model_cls` should be a string or PyTorch model class, "
                f"not a {type(model_cls)}"
            )
            raise TypeError(msg)

        if model_arch in self.models:
            logger.warning(
                "Model architecture %s is already registered, and will be "
                "overwritten by the new model class %s.",
                model_arch,
                model_cls,
            )

        self.models[model_arch] = model

    def _raise_for_unsupported(self, architectures: list[str]):
        """Raise a `ValueError` explaining why `architectures` cannot be used.

        - If any architecture is registered, loading/inspecting it must have
          failed (the helpers log the traceback), so point at the logs.
        - If an architecture was removed in an earlier release, name the last
          version that supported it.
        - Otherwise list every supported architecture.

        This method always raises; it never returns normally.
        """
        all_supported_archs = self.get_supported_archs()

        if any(arch in all_supported_archs for arch in architectures):
            raise ValueError(
                f"Model architectures {architectures} failed "
                "to be inspected. Please check the logs for more details."
            )

        for arch in architectures:
            if arch in _PREVIOUSLY_SUPPORTED_MODELS:
                previous_version = _PREVIOUSLY_SUPPORTED_MODELS[arch]

                raise ValueError(
                    f"Model architecture {arch} was supported in vLLM until "
                    f"v{previous_version}, and is not supported anymore. "
                    "Please use an older version of vLLM if you want to "
                    "use this model architecture."
                )

        raise ValueError(
            f"Model architectures {architectures} are not supported for now. "
            f"Supported architectures: {all_supported_archs}"
        )

    def _try_load_model_cls(self, model_arch: str) -> type[nn.Module] | None:
        """Return the model class for `model_arch`, or `None` if unavailable."""
        if model_arch not in self.models:
            return None

        return _try_load_model_cls(model_arch, self.models[model_arch])

    def _try_inspect_model_cls(self, model_arch: str) -> _ModelInfo | None:
        """Return the `_ModelInfo` for `model_arch`, or `None` if unavailable."""
        if model_arch not in self.models:
            return None

        return _try_inspect_model_cls(model_arch, self.models[model_arch])

    def _try_resolve_transformers(
        self,
        architecture: str,
        model_config: ModelConfig,
    ) -> str | None:
        """Resolve the Transformers-backend architecture for `architecture`.

        Returns `architecture` unchanged if it is a known Transformers-backend
        model. Otherwise it looks the model class up in the `transformers`
        package or in the config's `auto_map`. If the class is found and
        backend compatible, it returns `model_config._get_transformers_backend_cls()`.

        Returns `None` when no compatible class is found and
        `model_config.model_impl != "transformers"`.

        Raises:
            ValueError: If `model_config.model_impl == "transformers"` and no
                compatible Transformers class can be found.
        """
        if architecture in _TRANSFORMERS_BACKEND_MODELS:
            return architecture

        auto_map: dict[str, str] = (
            getattr(model_config.hf_config, "auto_map", None) or dict()
        )

        # Make sure that config class is always initialized before model class,
        # otherwise the model class won't be able to access the config class,
        # the expected auto_map should have correct order like:
        # "auto_map": {
        #     "AutoConfig": "<your-repo-name>--<config-name>",
        #     "AutoModel": "<your-repo-name>--<config-name>",
        #     "AutoModelFor<Task>": "<your-repo-name>--<config-name>",
        # },
        for prefix in ("AutoConfig", "AutoModel"):
            for name, module in auto_map.items():
                if name.startswith(prefix):
                    try_get_class_from_dynamic_module(
                        module,
                        model_config.model,
                        revision=model_config.revision,
                        warn_on_fail=False,
                    )

        model_module = getattr(transformers, architecture, None)

        if model_module is None:
            for name, module in auto_map.items():
                if name.startswith("AutoModel"):
                    model_module = try_get_class_from_dynamic_module(
                        module,
                        model_config.model,
                        revision=model_config.revision,
                        warn_on_fail=True,
                    )
                    if model_module is not None:
                        break
            else:
                # No `AutoModel*` entry produced a class.
                if model_config.model_impl != "transformers":
                    return None

                raise ValueError(
                    f"Cannot find model module. {architecture!r} is not a "
                    "registered model in the Transformers library (only "
                    "relevant if the model is meant to be in Transformers) "
                    "and 'AutoModel' is not present in the model config's "
                    "'auto_map' (relevant if the model is custom)."
                )

        if not model_module.is_backend_compatible():
            if model_config.model_impl != "transformers":
                return None

            raise ValueError(
                f"The Transformers implementation of {architecture!r} "
                "is not compatible with vLLM."
            )

        return model_config._get_transformers_backend_cls()

    def _normalize_arch(
        self,
        architecture: str,
        model_config: ModelConfig,
    ) -> str:
        """Map `architecture` to a registered base architecture if possible.

        Registered names are returned as-is. Otherwise, if `architecture`
        matches an architecture-default suffix, the matched suffix is
        replaced by each default suffix in turn. The first registered
        result is returned. If nothing matches, `architecture` is returned
        unchanged.
        """
        if architecture in self.models:
            return architecture

        # This may be called in order to resolve runner_type and convert_type
        # in the first place, in which case we consider the default match
        match = try_match_architecture_defaults(
            architecture,
            runner_type=getattr(model_config, "runner_type", None),
            convert_type=getattr(model_config, "convert_type", None),
        )
        if match:
            suffix, _ = match

            # Get the name of the base model to convert
            for repl_suffix, _ in iter_architecture_defaults():
                base_arch = architecture.replace(suffix, repl_suffix)
                if base_arch in self.models:
                    return base_arch

        return architecture

    def _resolve_arch_with(
        self,
        architectures: str | list[str],
        model_config: ModelConfig,
        try_arch: Callable[[str], object | None],
    ) -> tuple[object, str]:
        """Shared resolution order for `inspect_model_cls`/`resolve_model_cls`.

        `try_arch` maps a registered architecture name to a result (a
        `_ModelInfo` or a model class), or `None` on failure. The first
        non-`None` result is returned together with the architecture name
        to report.

        Raises:
            ValueError: If no architectures are given, or none resolve.
        """
        if isinstance(architectures, str):
            architectures = [architectures]
        if not architectures:
            raise ValueError("No model architectures are specified")

        def try_transformers() -> tuple[object, str] | None:
            # Resolve the first architecture via the Transformers backend.
            arch = self._try_resolve_transformers(architectures[0], model_config)
            if arch is None:
                return None
            result = try_arch(arch)
            return None if result is None else (result, arch)

        # Require transformers impl
        if model_config.model_impl == "transformers":
            found = try_transformers()
            if found is not None:
                return found
        elif model_config.model_impl == "terratorch":
            # Only return on success; a failed inspection/load falls through
            # to the remaining strategies instead of returning `None`.
            result = try_arch("Terratorch")
            if result is not None:
                return (result, "Terratorch")

        # Fallback to transformers impl (after resolving convert_type)
        if (
            all(arch not in self.models for arch in architectures)
            and model_config.model_impl == "auto"
            and getattr(model_config, "convert_type", "none") == "none"
        ):
            found = try_transformers()
            if found is not None:
                return found

        for arch in architectures:
            normalized_arch = self._normalize_arch(arch, model_config)
            result = try_arch(normalized_arch)
            if result is not None:
                # Report the requested name, not the normalized base name.
                return (result, arch)

        # Fallback to transformers impl (before resolving runner_type)
        if (
            all(arch not in self.models for arch in architectures)
            and model_config.model_impl == "auto"
        ):
            found = try_transformers()
            if found is not None:
                return found

        return self._raise_for_unsupported(architectures)

    def inspect_model_cls(
        self,
        architectures: str | list[str],
        model_config: ModelConfig,
    ) -> tuple[_ModelInfo, str]:
        """Return `(model_info, arch)` for the first resolvable architecture.

        Raises:
            ValueError: If no architecture can be resolved and inspected.
        """
        return self._resolve_arch_with(
            architectures, model_config, self._try_inspect_model_cls
        )

    def resolve_model_cls(
        self,
        architectures: str | list[str],
        model_config: ModelConfig,
    ) -> tuple[type[nn.Module], str]:
        """Return `(model_cls, arch)` for the first resolvable architecture.

        Raises:
            ValueError: If no architecture can be resolved and loaded.
        """
        return self._resolve_arch_with(
            architectures, model_config, self._try_load_model_cls
        )

    def is_text_generation_model(
        self,
        architectures: str | list[str],
        model_config: ModelConfig,
    ) -> bool:
        """Whether the resolved model is a text generation model."""
        model_cls, _ = self.inspect_model_cls(architectures, model_config)
        return model_cls.is_text_generation_model

    def is_pooling_model(
        self,
        architectures: str | list[str],
        model_config: ModelConfig,
    ) -> bool:
        """Whether the resolved model is a pooling model."""
        model_cls, _ = self.inspect_model_cls(architectures, model_config)
        return model_cls.is_pooling_model

    def is_cross_encoder_model(
        self,
        architectures: str | list[str],
        model_config: ModelConfig,
    ) -> bool:
        """Whether the resolved model supports cross-encoding."""
        model_cls, _ = self.inspect_model_cls(architectures, model_config)
        return model_cls.supports_cross_encoding

    def is_multimodal_model(
        self,
        architectures: str | list[str],
        model_config: ModelConfig,
    ) -> bool:
        """Whether the resolved model supports multimodal inputs."""
        model_cls, _ = self.inspect_model_cls(architectures, model_config)
        return model_cls.supports_multimodal

    def is_multimodal_raw_input_only_model(
        self,
        architectures: str | list[str],
        model_config: ModelConfig,
    ) -> bool:
        """Whether the resolved model only supports raw multimodal inputs."""
        model_cls, _ = self.inspect_model_cls(architectures, model_config)
        return model_cls.supports_multimodal_raw_input_only

    def is_pp_supported_model(
        self,
        architectures: str | list[str],
        model_config: ModelConfig,
    ) -> bool:
        """Whether the resolved model supports pipeline parallelism."""
        model_cls, _ = self.inspect_model_cls(architectures, model_config)
        return model_cls.supports_pp

    def model_has_inner_state(
        self,
        architectures: str | list[str],
        model_config: ModelConfig,
    ) -> bool:
        """Whether the resolved model has inner state."""
        model_cls, _ = self.inspect_model_cls(architectures, model_config)
        return model_cls.has_inner_state

    def is_attention_free_model(
        self,
        architectures: str | list[str],
        model_config: ModelConfig,
    ) -> bool:
        """Whether the resolved model is attention-free."""
        model_cls, _ = self.inspect_model_cls(architectures, model_config)
        return model_cls.is_attention_free

    def is_hybrid_model(
        self,
        architectures: str | list[str],
        model_config: ModelConfig,
    ) -> bool:
        """Whether the resolved model is a hybrid model."""
        model_cls, _ = self.inspect_model_cls(architectures, model_config)
        return model_cls.is_hybrid

    def is_noops_model(
        self,
        architectures: str | list[str],
        model_config: ModelConfig,
    ) -> bool:
        """Whether the resolved model has no-ops."""
        model_cls, _ = self.inspect_model_cls(architectures, model_config)
        return model_cls.has_noops

    def is_transcription_model(
        self,
        architectures: str | list[str],
        model_config: ModelConfig,
    ) -> bool:
        """Whether the resolved model supports transcription."""
        model_cls, _ = self.inspect_model_cls(architectures, model_config)
        return model_cls.supports_transcription

    def is_transcription_only_model(
        self,
        architectures: str | list[str],
        model_config: ModelConfig,
    ) -> bool:
        """Whether the resolved model supports transcription only."""
        model_cls, _ = self.inspect_model_cls(architectures, model_config)
        return model_cls.supports_transcription_only


# Global registry pre-populated with every built-in architecture in
# `_VLLM_MODELS`; each entry is imported lazily from
# `vllm.model_executor.models.<module>` on first use.
ModelRegistry = _ModelRegistry(
    models={
        arch_name: _LazyRegisteredModel(
            module_name=f"vllm.model_executor.models.{module_relname}",
            class_name=class_name,
        )
        for arch_name, (module_relname, class_name) in _VLLM_MODELS.items()
    }
)

_T = TypeVar("_T")


def _run_in_subprocess(fn: Callable[[], _T]) -> _T:
    """Run ``fn`` in a separate Python process and return its result.

    ``fn`` and a temporary output path are serialized with ``cloudpickle``
    and sent to ``_SUBPROCESS_COMMAND`` on stdin. The child writes the
    pickled result to that path, and this function unpickles it.

    Raises:
        RuntimeError: If the subprocess exits with a non-zero status. The
            message includes the subprocess' stderr.
    """
    # NOTE: We use a temporary directory instead of a temporary file to avoid
    # issues like https://stackoverflow.com/questions/23212435/permission-denied-to-write-to-my-temporary-file
    with tempfile.TemporaryDirectory() as tempdir:
        output_filepath = os.path.join(tempdir, "registry_output.tmp")

        # `cloudpickle` allows pickling lambda functions directly
        import cloudpickle

        input_bytes = cloudpickle.dumps((fn, output_filepath))

        # cannot use `sys.executable __file__` here because the script
        # contains relative imports
        returned = subprocess.run(
            _SUBPROCESS_COMMAND, input=input_bytes, capture_output=True, check=False
        )

        # check if the subprocess is successful
        try:
            returned.check_returncode()
        except subprocess.CalledProcessError as e:
            # Wrap the error to include the child's stderr. Decode leniently so
            # undecodable bytes cannot replace the real failure with a
            # UnicodeDecodeError.
            stderr = returned.stderr.decode(errors="replace")
            raise RuntimeError(f"Error raised in subprocess:\n{stderr}") from e

        # The output file was written by our own child process, not by an
        # external source.
        with open(output_filepath, "rb") as f:
            return pickle.load(f)


def _run() -> None:
    """Child-process entry point used together with ``_run_in_subprocess``.

    Loads vLLM's general plugins, then reads a pickled ``(fn, output_file)``
    pair from stdin. It calls ``fn()`` and writes the pickled result to
    ``output_file``.
    """
    from vllm.plugins import load_general_plugins

    # Plugins must be loaded before the callable runs.
    load_general_plugins()

    payload = sys.stdin.buffer.read()
    fn, output_file = pickle.loads(payload)

    # Call first, so the output file is only opened once a result exists.
    outcome = fn()

    with open(output_file, "wb") as sink:
        sink.write(pickle.dumps(outcome))


if __name__ == "__main__":
    # Script entry point; presumably this is how the child launched via
    # `_SUBPROCESS_COMMAND` (defined elsewhere) runs `_run` — verify.
    _run()