# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Whenever you add an architecture to this file, please also update
`tests/models/registry.py` with example HuggingFace models for it.
"""

import hashlib
import importlib
import json
import os
import pickle
import subprocess
import sys
import tempfile
from abc import ABC, abstractmethod
from collections.abc import Callable, Set
from dataclasses import asdict, dataclass, field
from functools import lru_cache
from pathlib import Path
from typing import TypeVar

import torch.nn as nn
import transformers

from vllm import envs
from vllm.config import (
    ModelConfig,
    iter_architecture_defaults,
    try_match_architecture_defaults,
)
from vllm.logger import init_logger
from vllm.logging_utils import logtime
from vllm.transformers_utils.dynamic_module import try_get_class_from_dynamic_module

from .interfaces import (
    has_inner_state,
    has_noops,
    is_attention_free,
    is_hybrid,
    supports_cross_encoding,
    supports_mamba_prefix_caching,
    supports_multimodal,
    supports_multimodal_encoder_tp_data,
    supports_multimodal_raw_input_only,
    supports_pp,
    supports_transcription,
)
from .interfaces_base import (
    get_default_pooling_type,
    is_pooling_model,
    is_text_generation_model,
)

logger = init_logger(__name__)

_TEXT_GENERATION_MODELS = {
    # [Decoder-only]
59
    "ApertusForCausalLM": ("apertus", "ApertusForCausalLM"),
60
61
    "AquilaModel": ("llama", "LlamaForCausalLM"),
    "AquilaForCausalLM": ("llama", "LlamaForCausalLM"),  # AquilaChat2
Raghav Ravishankar's avatar
Raghav Ravishankar committed
62
    "ArceeForCausalLM": ("arcee", "ArceeForCausalLM"),
63
    "ArcticForCausalLM": ("arctic", "ArcticForCausalLM"),
64
65
66
67
    # baichuan-7b, upper case 'C' in the class name
    "BaiChuanForCausalLM": ("baichuan", "BaiChuanForCausalLM"),
    # baichuan-13b, lower case 'c' in the class name
    "BaichuanForCausalLM": ("baichuan", "BaichuanForCausalLM"),
68
    "BailingMoeForCausalLM": ("bailing_moe", "BailingMoeForCausalLM"),
ant-yy's avatar
ant-yy committed
69
    "BailingMoeV2ForCausalLM": ("bailing_moe", "BailingMoeV2ForCausalLM"),
Yu Chin Fabian Lim's avatar
Yu Chin Fabian Lim committed
70
    "BambaForCausalLM": ("bamba", "BambaForCausalLM"),
71
    "BloomForCausalLM": ("bloom", "BloomForCausalLM"),
72
    "ChatGLMModel": ("chatglm", "ChatGLMForCausalLM"),
73
    "ChatGLMForConditionalGeneration": ("chatglm", "ChatGLMForCausalLM"),
74
    "CohereForCausalLM": ("commandr", "CohereForCausalLM"),
75
    "Cohere2ForCausalLM": ("commandr", "CohereForCausalLM"),
76
    "CwmForCausalLM": ("llama", "LlamaForCausalLM"),
77
    "DbrxForCausalLM": ("dbrx", "DbrxForCausalLM"),
78
    "DeciLMForCausalLM": ("nemotron_nas", "DeciLMForCausalLM"),
79
80
    "DeepseekForCausalLM": ("deepseek", "DeepseekForCausalLM"),
    "DeepseekV2ForCausalLM": ("deepseek_v2", "DeepseekV2ForCausalLM"),
81
    "DeepseekV3ForCausalLM": ("deepseek_v2", "DeepseekV3ForCausalLM"),
82
    "DeepseekV32ForCausalLM": ("deepseek_v2", "DeepseekV3ForCausalLM"),
83
    "Dots1ForCausalLM": ("dots1", "Dots1ForCausalLM"),
84
    "Ernie4_5ForCausalLM": ("ernie45", "Ernie4_5ForCausalLM"),
85
    "Ernie4_5_MoeForCausalLM": ("ernie45_moe", "Ernie4_5_MoeForCausalLM"),
86
    "ExaoneForCausalLM": ("exaone", "ExaoneForCausalLM"),
87
    "Exaone4ForCausalLM": ("exaone4", "Exaone4ForCausalLM"),
88
    "Fairseq2LlamaForCausalLM": ("fairseq2_llama", "Fairseq2LlamaForCausalLM"),
89
90
91
    "FalconForCausalLM": ("falcon", "FalconForCausalLM"),
    "FalconMambaForCausalLM": ("mamba", "MambaForCausalLM"),
    "FalconH1ForCausalLM": ("falcon_h1", "FalconH1ForCausalLM"),
92
    "FlexOlmoForCausalLM": ("flex_olmo", "FlexOlmoForCausalLM"),
93
94
    "GemmaForCausalLM": ("gemma", "GemmaForCausalLM"),
    "Gemma2ForCausalLM": ("gemma2", "Gemma2ForCausalLM"),
95
    "Gemma3ForCausalLM": ("gemma3", "Gemma3ForCausalLM"),
Nicolò Lucchesi's avatar
Nicolò Lucchesi committed
96
    "Gemma3nForCausalLM": ("gemma3n", "Gemma3nForCausalLM"),
97
    "Qwen3NextForCausalLM": ("qwen3_next", "Qwen3NextForCausalLM"),
98
    "GlmForCausalLM": ("glm", "GlmForCausalLM"),
Yuxuan Zhang's avatar
Yuxuan Zhang committed
99
    "Glm4ForCausalLM": ("glm4", "Glm4ForCausalLM"),
Yuxuan Zhang's avatar
Yuxuan Zhang committed
100
    "Glm4MoeForCausalLM": ("glm4_moe", "Glm4MoeForCausalLM"),
101
    "GptOssForCausalLM": ("gpt_oss", "GptOssForCausalLM"),
102
103
104
105
106
107
    "GPT2LMHeadModel": ("gpt2", "GPT2LMHeadModel"),
    "GPTBigCodeForCausalLM": ("gpt_bigcode", "GPTBigCodeForCausalLM"),
    "GPTJForCausalLM": ("gpt_j", "GPTJForCausalLM"),
    "GPTNeoXForCausalLM": ("gpt_neox", "GPTNeoXForCausalLM"),
    "GraniteForCausalLM": ("granite", "GraniteForCausalLM"),
    "GraniteMoeForCausalLM": ("granitemoe", "GraniteMoeForCausalLM"),
108
109
    "GraniteMoeHybridForCausalLM": ("granitemoehybrid", "GraniteMoeHybridForCausalLM"),  # noqa: E501
    "GraniteMoeSharedForCausalLM": ("granitemoeshared", "GraniteMoeSharedForCausalLM"),  # noqa: E501
110
    "GritLM": ("gritlm", "GritLM"),
Michael Goin's avatar
Michael Goin committed
111
    "Grok1ModelForCausalLM": ("grok1", "Grok1ForCausalLM"),
112
113
    "HunYuanMoEV1ForCausalLM": ("hunyuan_v1", "HunYuanMoEV1ForCausalLM"),
    "HunYuanDenseV1ForCausalLM": ("hunyuan_v1", "HunYuanDenseV1ForCausalLM"),
114
    "HCXVisionForCausalLM": ("hyperclovax_vision", "HCXVisionForCausalLM"),
115
116
    "InternLMForCausalLM": ("llama", "LlamaForCausalLM"),
    "InternLM2ForCausalLM": ("internlm2", "InternLM2ForCausalLM"),
117
    "InternLM2VEForCausalLM": ("internlm2_ve", "InternLM2VEForCausalLM"),
118
    "InternLM3ForCausalLM": ("llama", "LlamaForCausalLM"),
119
120
    "JAISLMHeadModel": ("jais", "JAISLMHeadModel"),
    "JambaForCausalLM": ("jamba", "JambaForCausalLM"),
121
    "KimiLinearForCausalLM": ("kimi_linear", "KimiLinearForCausalLM"),  # noqa: E501
122
    "Lfm2ForCausalLM": ("lfm2", "Lfm2ForCausalLM"),
Paul Pak's avatar
Paul Pak committed
123
    "Lfm2MoeForCausalLM": ("lfm2_moe", "Lfm2MoeForCausalLM"),
124
    "LlamaForCausalLM": ("llama", "LlamaForCausalLM"),
125
    "Llama4ForCausalLM": ("llama4", "Llama4ForCausalLM"),
126
127
    # For decapoda-research/llama-*
    "LLaMAForCausalLM": ("llama", "LlamaForCausalLM"),
XuruiYang's avatar
XuruiYang committed
128
    "LongcatFlashForCausalLM": ("longcat_flash", "LongcatFlashForCausalLM"),
129
    "MambaForCausalLM": ("mamba", "MambaForCausalLM"),
130
    "Mamba2ForCausalLM": ("mamba2", "Mamba2ForCausalLM"),
131
132
    "MiniCPMForCausalLM": ("minicpm", "MiniCPMForCausalLM"),
    "MiniCPM3ForCausalLM": ("minicpm3", "MiniCPM3ForCausalLM"),
133
134
135
    "MiniMaxForCausalLM": ("minimax_text_01", "MiniMaxText01ForCausalLM"),
    "MiniMaxText01ForCausalLM": ("minimax_text_01", "MiniMaxText01ForCausalLM"),
    "MiniMaxM1ForCausalLM": ("minimax_text_01", "MiniMaxText01ForCausalLM"),
136
    "MiniMaxM2ForCausalLM": ("minimax_m2", "MiniMaxM2ForCausalLM"),
137
138
139
140
141
    "MistralForCausalLM": ("llama", "LlamaForCausalLM"),
    "MixtralForCausalLM": ("mixtral", "MixtralForCausalLM"),
    # transformers's mpt class has lower case
    "MptForCausalLM": ("mpt", "MPTForCausalLM"),
    "MPTForCausalLM": ("mpt", "MPTForCausalLM"),
142
    "MiMoForCausalLM": ("mimo", "MiMoForCausalLM"),
143
    "NemotronForCausalLM": ("nemotron", "NemotronForCausalLM"),
Luis Vega's avatar
Luis Vega committed
144
    "NemotronHForCausalLM": ("nemotron_h", "NemotronHForCausalLM"),
145
    "OlmoForCausalLM": ("olmo", "OlmoForCausalLM"),
146
    "Olmo2ForCausalLM": ("olmo2", "Olmo2ForCausalLM"),
147
    "Olmo3ForCausalLM": ("olmo2", "Olmo2ForCausalLM"),
148
149
150
    "OlmoeForCausalLM": ("olmoe", "OlmoeForCausalLM"),
    "OPTForCausalLM": ("opt", "OPTForCausalLM"),
    "OrionForCausalLM": ("orion", "OrionForCausalLM"),
151
    "OuroForCausalLM": ("ouro", "OuroForCausalLM"),
152
153
154
155
    "PersimmonForCausalLM": ("persimmon", "PersimmonForCausalLM"),
    "PhiForCausalLM": ("phi", "PhiForCausalLM"),
    "Phi3ForCausalLM": ("phi3", "Phi3ForCausalLM"),
    "PhiMoEForCausalLM": ("phimoe", "PhiMoEForCausalLM"),
Shinichi Hemmi's avatar
Shinichi Hemmi committed
156
    "Plamo2ForCausalLM": ("plamo2", "Plamo2ForCausalLM"),
157
    "QWenLMHeadModel": ("qwen", "QWenLMHeadModel"),
158
159
    "Qwen2ForCausalLM": ("qwen2", "Qwen2ForCausalLM"),
    "Qwen2MoeForCausalLM": ("qwen2_moe", "Qwen2MoeForCausalLM"),
160
161
    "Qwen3ForCausalLM": ("qwen3", "Qwen3ForCausalLM"),
    "Qwen3MoeForCausalLM": ("qwen3_moe", "Qwen3MoeForCausalLM"),
162
    "RWForCausalLM": ("falcon", "FalconForCausalLM"),
163
    "SeedOssForCausalLM": ("seed_oss", "SeedOssForCausalLM"),
Song's avatar
Song committed
164
    "Step3TextForCausalLM": ("step3_text", "Step3TextForCausalLM"),
165
166
167
168
    "StableLMEpochForCausalLM": ("stablelm", "StablelmForCausalLM"),
    "StableLmForCausalLM": ("stablelm", "StablelmForCausalLM"),
    "Starcoder2ForCausalLM": ("starcoder2", "Starcoder2ForCausalLM"),
    "SolarForCausalLM": ("solar", "SolarForCausalLM"),
169
    "TeleChat2ForCausalLM": ("telechat2", "TeleChat2ForCausalLM"),
170
    "TeleFLMForCausalLM": ("teleflm", "TeleFLMForCausalLM"),
171
    "XverseForCausalLM": ("llama", "LlamaForCausalLM"),
172
    "Zamba2ForCausalLM": ("zamba2", "Zamba2ForCausalLM"),
173
174
175
}

_EMBEDDING_MODELS = {
    # [Text-only]
    "BertModel": ("bert", "BertEmbeddingModel"),
    "BertSpladeSparseEmbeddingModel": ("bert", "BertSpladeSparseEmbeddingModel"),
    "DeciLMForCausalLM": ("nemotron_nas", "DeciLMForCausalLM"),
    "Gemma2Model": ("gemma2", "Gemma2ForCausalLM"),
    "Gemma3TextModel": ("gemma3", "Gemma3Model"),
    "GlmForCausalLM": ("glm", "GlmForCausalLM"),
    "GPT2ForSequenceClassification": ("gpt2", "GPT2ForSequenceClassification"),
    "GritLM": ("gritlm", "GritLM"),
    "GteModel": ("bert_with_rope", "SnowflakeGteNewModel"),
    "GteNewModel": ("bert_with_rope", "GteNewModel"),
    "InternLM2ForRewardModel": ("internlm2", "InternLM2ForRewardModel"),
    "JambaForSequenceClassification": ("jamba", "JambaForSequenceClassification"),  # noqa: E501
    "LlamaModel": ("llama", "LlamaForCausalLM"),
    **{
        # Multiple models share the same architecture, so we include them all
        # (every text-generation entry whose class is LlamaForCausalLM).
        k: (mod, arch)
        for k, (mod, arch) in _TEXT_GENERATION_MODELS.items()
        if arch == "LlamaForCausalLM"
    },
    "MistralModel": ("llama", "LlamaForCausalLM"),
    "ModernBertModel": ("modernbert", "ModernBertModel"),
    "NomicBertModel": ("bert_with_rope", "NomicBertModel"),
    "Phi3ForCausalLM": ("phi3", "Phi3ForCausalLM"),
    "Qwen2Model": ("qwen2", "Qwen2ForCausalLM"),
    "Qwen2ForCausalLM": ("qwen2", "Qwen2ForCausalLM"),
    "Qwen2ForRewardModel": ("qwen2_rm", "Qwen2ForRewardModel"),
    "Qwen2ForProcessRewardModel": ("qwen2_rm", "Qwen2ForProcessRewardModel"),
    "RobertaForMaskedLM": ("roberta", "RobertaEmbeddingModel"),
    "RobertaModel": ("roberta", "RobertaEmbeddingModel"),
    "TeleChat2ForCausalLM": ("telechat2", "TeleChat2ForCausalLM"),
    "XLMRobertaModel": ("roberta", "RobertaEmbeddingModel"),
    # [Multimodal]
    "CLIPModel": ("clip", "CLIPEmbeddingModel"),
    "LlavaNextForConditionalGeneration": (
        "llava_next",
        "LlavaNextForConditionalGeneration",
    ),
    "Phi3VForCausalLM": ("phi3v", "Phi3VForCausalLM"),
    "Qwen2VLForConditionalGeneration": ("qwen2_vl", "Qwen2VLForConditionalGeneration"),  # noqa: E501
    "SiglipModel": ("siglip", "SiglipEmbeddingModel"),
    # Technically Terratorch models work on images, both in
    # input and output. I am adding it here because it piggy-backs on embedding
    # models for the time being.
    "PrithviGeoSpatialMAE": ("terratorch", "Terratorch"),
    "Terratorch": ("terratorch", "Terratorch"),
}

_CROSS_ENCODER_MODELS = {
    "BertForSequenceClassification": ("bert", "BertForSequenceClassification"),
226
    "BertForTokenClassification": ("bert", "BertForTokenClassification"),
227
228
229
230
231
232
233
234
    "GteNewForSequenceClassification": (
        "bert_with_rope",
        "GteNewForSequenceClassification",
    ),
    "ModernBertForSequenceClassification": (
        "modernbert",
        "ModernBertForSequenceClassification",
    ),
235
236
237
238
    "ModernBertForTokenClassification": (
        "modernbert",
        "ModernBertForTokenClassification",
    ),
239
240
241
242
243
    "RobertaForSequenceClassification": ("roberta", "RobertaForSequenceClassification"),
    "XLMRobertaForSequenceClassification": (
        "roberta",
        "RobertaForSequenceClassification",
    ),
244
    # [Auto-converted (see adapters.py)]
245
    "JinaVLForRanking": ("jina_vl", "JinaVLForSequenceClassification"),  # noqa: E501,
246
247
}

248
_MULTIMODAL_MODELS = {
249
    # [Decoder-only]
250
    "AriaForConditionalGeneration": ("aria", "AriaForConditionalGeneration"),
251
252
253
    "AyaVisionForConditionalGeneration": (
        "aya_vision",
        "AyaVisionForConditionalGeneration",
254
    ),
255
    "BeeForConditionalGeneration": ("bee", "BeeForConditionalGeneration"),
256
    "Blip2ForConditionalGeneration": ("blip2", "Blip2ForConditionalGeneration"),
257
258
259
    "ChameleonForConditionalGeneration": (
        "chameleon",
        "ChameleonForConditionalGeneration",
260
    ),
261
262
263
    "Cohere2VisionForConditionalGeneration": (
        "cohere2_vision",
        "Cohere2VisionForConditionalGeneration",
264
    ),
265
    "DeepseekVLV2ForCausalLM": ("deepseek_vl2", "DeepseekVLV2ForCausalLM"),
266
    "DeepseekOCRForCausalLM": ("deepseek_ocr", "DeepseekOCRForCausalLM"),
Roger Wang's avatar
Roger Wang committed
267
    "DotsOCRForCausalLM": ("dots_ocr", "DotsOCRForCausalLM"),
268
269
270
    "Ernie4_5_VLMoeForConditionalGeneration": (
        "ernie45_vl",
        "Ernie4_5_VLMoeForConditionalGeneration",
271
    ),
272
    "FuyuForCausalLM": ("fuyu", "FuyuForCausalLM"),
273
    "Gemma3ForConditionalGeneration": ("gemma3_mm", "Gemma3ForConditionalGeneration"),  # noqa: E501
274
275
276
    "Gemma3nForConditionalGeneration": (
        "gemma3n_mm",
        "Gemma3nForConditionalGeneration",
277
    ),
278
    "GLM4VForCausalLM": ("glm4v", "GLM4VForCausalLM"),
279
    "Glm4vForConditionalGeneration": ("glm4_1v", "Glm4vForConditionalGeneration"),  # noqa: E501
Jee Jee Li's avatar
Jee Jee Li committed
280
    "Glm4vMoeForConditionalGeneration": ("glm4_1v", "Glm4vMoeForConditionalGeneration"),  # noqa: E501
281
282
283
    "GraniteSpeechForConditionalGeneration": (
        "granite_speech",
        "GraniteSpeechForConditionalGeneration",
284
    ),
285
    "H2OVLChatModel": ("h2ovl", "H2OVLChatModel"),
286
    "InternVLChatModel": ("internvl", "InternVLChatModel"),
287
    "NemotronH_Nano_VL_V2": ("nano_nemotron_vl", "NemotronH_Nano_VL_V2"),
288
289
290
    "InternS1ForConditionalGeneration": (
        "interns1",
        "InternS1ForConditionalGeneration",
291
    ),
292
293
294
    "InternVLForConditionalGeneration": (
        "interns1",
        "InternS1ForConditionalGeneration",
295
    ),
296
297
298
299
300
    "Idefics3ForConditionalGeneration": (
        "idefics3",
        "Idefics3ForConditionalGeneration",
    ),
    "SmolVLMForConditionalGeneration": ("smolvlm", "SmolVLMForConditionalGeneration"),  # noqa: E501
301
    "KeyeForConditionalGeneration": ("keye", "KeyeForConditionalGeneration"),
302
303
304
    "KeyeVL1_5ForConditionalGeneration": (
        "keye_vl1_5",
        "KeyeVL1_5ForConditionalGeneration",
305
    ),
306
    "RForConditionalGeneration": ("rvl", "RForConditionalGeneration"),
307
    "KimiVLForConditionalGeneration": ("kimi_vl", "KimiVLForConditionalGeneration"),  # noqa: E501
308
309
310
311
    "LightOnOCRForConditionalGeneration": (
        "lightonocr",
        "LightOnOCRForConditionalGeneration",
    ),
312
    "Llama_Nemotron_Nano_VL": ("nemotron_vl", "LlamaNemotronVLChatModel"),
313
    "Llama4ForConditionalGeneration": ("mllama4", "Llama4ForConditionalGeneration"),  # noqa: E501
314
    "LlavaForConditionalGeneration": ("llava", "LlavaForConditionalGeneration"),
315
316
317
    "LlavaNextForConditionalGeneration": (
        "llava_next",
        "LlavaNextForConditionalGeneration",
318
    ),
319
320
321
    "LlavaNextVideoForConditionalGeneration": (
        "llava_next_video",
        "LlavaNextVideoForConditionalGeneration",
322
    ),
323
324
325
    "LlavaOnevisionForConditionalGeneration": (
        "llava_onevision",
        "LlavaOnevisionForConditionalGeneration",
326
    ),
327
    "MantisForConditionalGeneration": ("llava", "MantisForConditionalGeneration"),  # noqa: E501
328
    "MiDashengLMModel": ("midashenglm", "MiDashengLMModel"),
329
330
331
    "MiniMaxVL01ForConditionalGeneration": (
        "minimax_vl_01",
        "MiniMaxVL01ForConditionalGeneration",
332
    ),
333
    "MiniCPMO": ("minicpmo", "MiniCPMO"),
334
    "MiniCPMV": ("minicpmv", "MiniCPMV"),
335
336
337
    "Mistral3ForConditionalGeneration": (
        "mistral3",
        "Mistral3ForConditionalGeneration",
338
    ),
339
    "MolmoForCausalLM": ("molmo", "MolmoForCausalLM"),
340
    "NVLM_D": ("nvlm_d", "NVLM_D_Model"),
341
    "Ovis": ("ovis", "Ovis"),
342
    "Ovis2_5": ("ovis2_5", "Ovis2_5"),
343
344
345
346
    "PaddleOCRVLForConditionalGeneration": (
        "paddleocr_vl",
        "PaddleOCRVLForConditionalGeneration",
    ),
347
348
349
350
    "PaliGemmaForConditionalGeneration": (
        "paligemma",
        "PaliGemmaForConditionalGeneration",
    ),
351
    "Phi3VForCausalLM": ("phi3v", "Phi3VForCausalLM"),
352
353
    "Phi4MMForCausalLM": ("phi4mm", "Phi4MMForCausalLM"),
    "Phi4MultimodalForCausalLM": ("phi4_multimodal", "Phi4MultimodalForCausalLM"),  # noqa: E501
354
    "PixtralForConditionalGeneration": ("pixtral", "PixtralForConditionalGeneration"),  # noqa: E501
355
    "QwenVLForConditionalGeneration": ("qwen_vl", "QwenVLForConditionalGeneration"),  # noqa: E501
356
    "Qwen2VLForConditionalGeneration": ("qwen2_vl", "Qwen2VLForConditionalGeneration"),  # noqa: E501
357
358
359
    "Qwen2_5_VLForConditionalGeneration": (
        "qwen2_5_vl",
        "Qwen2_5_VLForConditionalGeneration",
360
    ),
361
362
363
    "Qwen2AudioForConditionalGeneration": (
        "qwen2_audio",
        "Qwen2AudioForConditionalGeneration",
364
    ),
365
366
367
    "Qwen2_5OmniModel": (
        "qwen2_5_omni_thinker",
        "Qwen2_5OmniThinkerForConditionalGeneration",
368
    ),
369
370
371
    "Qwen2_5OmniForConditionalGeneration": (
        "qwen2_5_omni_thinker",
        "Qwen2_5OmniThinkerForConditionalGeneration",
372
    ),
373
374
375
376
    "Qwen3OmniMoeForConditionalGeneration": (
        "qwen3_omni_moe_thinker",
        "Qwen3OmniMoeThinkerForConditionalGeneration",
    ),
377
    "Qwen3VLForConditionalGeneration": ("qwen3_vl", "Qwen3VLForConditionalGeneration"),  # noqa: E501
378
379
380
    "Qwen3VLMoeForConditionalGeneration": (
        "qwen3_vl_moe",
        "Qwen3VLMoeForConditionalGeneration",
381
    ),
382
    "SkyworkR1VChatModel": ("skyworkr1v", "SkyworkR1VChatModel"),
Song's avatar
Song committed
383
    "Step3VLForConditionalGeneration": ("step3_vl", "Step3VLForConditionalGeneration"),  # noqa: E501
汪志鹏's avatar
汪志鹏 committed
384
    "TarsierForConditionalGeneration": ("tarsier", "TarsierForConditionalGeneration"),  # noqa: E501
385
386
387
    "Tarsier2ForConditionalGeneration": (
        "qwen2_vl",
        "Tarsier2ForConditionalGeneration",
388
    ),
389
    "UltravoxModel": ("ultravox", "UltravoxModel"),
Patrick von Platen's avatar
Patrick von Platen committed
390
    "VoxtralForConditionalGeneration": ("voxtral", "VoxtralForConditionalGeneration"),  # noqa: E501
391
    # [Encoder-decoder]
392
    "WhisperForConditionalGeneration": ("whisper", "WhisperForConditionalGeneration"),  # noqa: E501
393
}
394
395

_SPECULATIVE_DECODING_MODELS = {
396
    "MiMoMTPModel": ("mimo_mtp", "MiMoMTP"),
397
    "EagleLlamaForCausalLM": ("llama_eagle", "EagleLlamaForCausalLM"),
zhiweiz's avatar
zhiweiz committed
398
    "EagleLlama4ForCausalLM": ("llama4_eagle", "EagleLlama4ForCausalLM"),
399
    "EagleMiniCPMForCausalLM": ("minicpm_eagle", "EagleMiniCPMForCausalLM"),
400
    "Eagle3LlamaForCausalLM": ("llama_eagle3", "Eagle3LlamaForCausalLM"),
401
    "LlamaForCausalLMEagle3": ("llama_eagle3", "Eagle3LlamaForCausalLM"),
402
    "Eagle3Qwen2_5vlForCausalLM": ("llama_eagle3", "Eagle3LlamaForCausalLM"),
403
    "EagleDeepSeekMTPModel": ("deepseek_eagle", "EagleDeepseekV3ForCausalLM"),
404
    "DeepSeekMTPModel": ("deepseek_mtp", "DeepSeekMTP"),
405
    "ErnieMTPModel": ("ernie_mtp", "ErnieMTP"),
XuruiYang's avatar
XuruiYang committed
406
    "LongCatFlashMTPModel": ("longcat_flash_mtp", "LongCatFlashMTP"),
Yuxuan Zhang's avatar
Yuxuan Zhang committed
407
    "Glm4MoeMTPModel": ("glm4_moe_mtp", "Glm4MoeMTP"),
408
    "MedusaModel": ("medusa", "Medusa"),
409
    "Qwen3NextMTP": ("qwen3_next_mtp", "Qwen3NextMTP"),
410
411
412
    # Temporarily disabled.
    # # TODO(woosuk): Re-enable this once the MLP Speculator is supported in V1.
    # "MLPSpeculatorPreTrainedModel": ("mlp_speculator", "MLPSpeculator"),
413
}
414

415
_TRANSFORMERS_SUPPORTED_MODELS = {
416
417
418
    # Text generation models
    "SmolLM3ForCausalLM": ("transformers", "TransformersForCausalLM"),
    # Multimodal models
419
420
421
422
    "Emu3ForConditionalGeneration": (
        "transformers",
        "TransformersMultiModalForCausalLM",
    ),
423
424
425
}

_TRANSFORMERS_BACKEND_MODELS = {
426
    # Text generation models
427
    "TransformersForCausalLM": ("transformers", "TransformersForCausalLM"),
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
    "TransformersMoEForCausalLM": ("transformers", "TransformersMoEForCausalLM"),
    # Multimodal models
    "TransformersMultiModalForCausalLM": (
        "transformers",
        "TransformersMultiModalForCausalLM",
    ),
    "TransformersMultiModalMoEForCausalLM": (
        "transformers",
        "TransformersMultiModalMoEForCausalLM",
    ),
    # Embedding models
    "TransformersEmbeddingModel": ("transformers", "TransformersEmbeddingModel"),
    "TransformersMoEEmbeddingModel": ("transformers", "TransformersMoEEmbeddingModel"),
    "TransformersMultiModalEmbeddingModel": (
        "transformers",
        "TransformersMultiModalEmbeddingModel",
    ),
    # Sequence classification models
446
    "TransformersForSequenceClassification": (
447
        "transformers",
448
        "TransformersForSequenceClassification",
449
    ),
450
    "TransformersMoEForSequenceClassification": (
451
        "transformers",
452
        "TransformersMoEForSequenceClassification",
453
    ),
454
455
456
    "TransformersMultiModalForSequenceClassification": (
        "transformers",
        "TransformersMultiModalForSequenceClassification",
457
    ),
458
}
459

460
# Union of every registry table above. If the same architecture name appears
# in more than one table, the entry from the later table wins.
_VLLM_MODELS = {
    **_TEXT_GENERATION_MODELS,
    **_EMBEDDING_MODELS,
    **_CROSS_ENCODER_MODELS,
    **_MULTIMODAL_MODELS,
    **_SPECULATIVE_DECODING_MODELS,
    **_TRANSFORMERS_SUPPORTED_MODELS,
    **_TRANSFORMERS_BACKEND_MODELS,
}

# This variable is used as the args for subprocess.run(). We
# can modify  this variable to alter the args if needed. e.g.
# when we use par format to pack things together, sys.executable
# might not be the target we want to run.
474
_SUBPROCESS_COMMAND = [sys.executable, "-m", "vllm.model_executor.models.registry"]
475

476
_PREVIOUSLY_SUPPORTED_MODELS = {
477
    "MotifForCausalLM": "0.10.2",
478
    "Phi3SmallForCausalLM": "0.9.2",
479
    "Phi4FlashForCausalLM": "0.10.2",
480
481
482
483
484
485
486
487
488
    # encoder-decoder models except whisper
    # have been removed for V0 deprecation.
    "BartModel": "0.10.2",
    "BartForConditionalGeneration": "0.10.2",
    "DonutForConditionalGeneration": "0.10.2",
    "Florence2ForConditionalGeneration": "0.10.2",
    "MBartForConditionalGeneration": "0.10.2",
    "MllamaForConditionalGeneration": "0.10.2",
}
489

490

491
492
@dataclass(frozen=True)
class _ModelInfo:
493
    architecture: str
494
    is_text_generation_model: bool
495
    is_pooling_model: bool
496
    default_pooling_type: str
497
    supports_cross_encoding: bool
498
    supports_multimodal: bool
499
    supports_multimodal_raw_input_only: bool
500
    supports_multimodal_encoder_tp_data: bool
501
    supports_pp: bool
502
503
    has_inner_state: bool
    is_attention_free: bool
504
    is_hybrid: bool
505
    has_noops: bool
506
    supports_mamba_prefix_caching: bool
507
    supports_transcription: bool
508
    supports_transcription_only: bool
509
510

    @staticmethod
511
    def from_model_cls(model: type[nn.Module]) -> "_ModelInfo":
512
        return _ModelInfo(
513
            architecture=model.__name__,
514
            is_text_generation_model=is_text_generation_model(model),
515
            is_pooling_model=is_pooling_model(model),
516
            default_pooling_type=get_default_pooling_type(model),
517
            supports_cross_encoding=supports_cross_encoding(model),
518
            supports_multimodal=supports_multimodal(model),
519
520
521
522
523
524
            supports_multimodal_raw_input_only=supports_multimodal_raw_input_only(
                model
            ),
            supports_multimodal_encoder_tp_data=supports_multimodal_encoder_tp_data(
                model
            ),
525
            supports_pp=supports_pp(model),
526
527
            has_inner_state=has_inner_state(model),
            is_attention_free=is_attention_free(model),
528
            is_hybrid=is_hybrid(model),
529
            supports_mamba_prefix_caching=supports_mamba_prefix_caching(model),
530
            supports_transcription=supports_transcription(model),
531
532
533
            supports_transcription_only=(
                supports_transcription(model) and model.supports_transcription_only
            ),
534
            has_noops=has_noops(model),
535
        )
536
537


538
539
540
541
class _BaseRegisteredModel(ABC):
    """Abstract handle for a registered model architecture."""

    @abstractmethod
    def inspect_model_cls(self) -> _ModelInfo:
        """Return the `_ModelInfo` describing the model class."""
        raise NotImplementedError

    @abstractmethod
    def load_model_cls(self) -> type[nn.Module]:
        """Return the model class itself."""
        raise NotImplementedError


@dataclass(frozen=True)
class _RegisteredModel(_BaseRegisteredModel):
    """
    Represents a model that has already been imported in the main process.
    """

    # Capability info computed when the handle is created.
    interfaces: _ModelInfo
    # The already-imported model class.
    model_cls: type[nn.Module]

    @staticmethod
    def from_model_cls(model_cls: type[nn.Module]) -> "_RegisteredModel":
        """Wrap `model_cls`, inspecting its interfaces immediately."""
        return _RegisteredModel(
            interfaces=_ModelInfo.from_model_cls(model_cls),
            model_cls=model_cls,
        )

    def inspect_model_cls(self) -> _ModelInfo:
        """Return the capability info captured at construction time."""
        return self.interfaces

    def load_model_cls(self) -> type[nn.Module]:
        """Return the wrapped model class."""
        return self.model_cls


@dataclass(frozen=True)
class _LazyRegisteredModel(_BaseRegisteredModel):
    """
    Represents a model that has not been imported in the main process.
    """

    # Importable module path, passed to `importlib.import_module`.
    module_name: str
    # Attribute name of the model class inside `module_name`.
    class_name: str

    @staticmethod
    def _get_cache_dir() -> Path:
        """Directory under `VLLM_CACHE_ROOT` that holds model-info JSON files."""
        return Path(envs.VLLM_CACHE_ROOT) / "modelinfos"

    def _get_cache_filename(self) -> str:
        """Cache file name built from module and class name, dots -> dashes."""
        cls_name = f"{self.module_name}-{self.class_name}".replace(".", "-")
        return f"{cls_name}.json"

    def _get_module_hash(self) -> str | None:
        """Return the MD5 hex digest of this model's source file, or `None`.

        Only modules that live directly in this package are hashed. For any
        other module, `Path(__file__).parent / f"{last_component}.py"` would
        point at an unrelated in-tree file that merely shares the last name
        component (e.g. `my_plugin.llama` -> this package's `llama.py`), so
        the cache would never be invalidated when the real source changes.
        Returning `None` disables the on-disk cache for such modules.
        """
        package, _, module_stem = self.module_name.rpartition(".")
        if package != __package__:
            return None

        model_path = Path(__file__).parent / f"{module_stem}.py"
        if not model_path.exists():
            return None

        with open(model_path, "rb") as f:
            return hashlib.md5(f.read(), usedforsecurity=False).hexdigest()

    def _load_modelinfo_from_cache(self, module_hash: str) -> _ModelInfo | None:
        """Return the cached `_ModelInfo` if it is present and up to date.

        Args:
            module_hash: Hash of the current module source; the cached entry
                is used only if its stored hash equals this value.

        Returns:
            The cached `_ModelInfo`, or `None` on a missing file, a stale
            hash, or any error while reading/parsing (logged at debug level).
        """
        try:
            try:
                modelinfo_path = self._get_cache_dir() / self._get_cache_filename()
                with open(modelinfo_path, encoding="utf-8") as file:
                    mi_dict = json.load(file)
            except FileNotFoundError:
                logger.debug(
                    ("Cached model info file for class %s.%s not found"),
                    self.module_name,
                    self.class_name,
                )
                return None

            if mi_dict["hash"] != module_hash:
                logger.debug(
                    ("Cached model info file for class %s.%s is stale"),
                    self.module_name,
                    self.class_name,
                )
                return None

            # file not changed, use cached _ModelInfo properties
            return _ModelInfo(**mi_dict["modelinfo"])
        except Exception:
            # Best effort: a corrupt or schema-mismatched cache is a miss.
            logger.debug(
                ("Cached model info for class %s.%s error. "),
                self.module_name,
                self.class_name,
            )
            return None

    def _save_modelinfo_to_cache(self, mi: _ModelInfo, module_hash: str) -> None:
        """save dictionary json file to cache

        Writes `{"hash": module_hash, "modelinfo": asdict(mi)}` via
        `atomic_writer`. Failures are logged and never raised.
        """
        from vllm.model_executor.model_loader.weight_utils import atomic_writer

        try:
            modelinfo_dict = {
                "hash": module_hash,
                "modelinfo": asdict(mi),
            }
            cache_dir = self._get_cache_dir()
            cache_dir.mkdir(parents=True, exist_ok=True)
            modelinfo_path = cache_dir / self._get_cache_filename()
            with atomic_writer(modelinfo_path, encoding="utf-8") as f:
                json.dump(modelinfo_dict, f, indent=2)
        except Exception:
            logger.exception("Error saving model info cache.")

    @logtime(logger=logger, msg="Registry inspect model class")
    def inspect_model_cls(self) -> _ModelInfo:
        """Return the `_ModelInfo` for this model class.

        The on-disk cache is consulted first when the module source can be
        hashed (see `_get_module_hash`). On a miss, the class is imported and
        inspected in another process, and the result is cached when possible.
        """
        module_hash = self._get_module_hash()

        if module_hash is not None:
            mi = self._load_modelinfo_from_cache(module_hash)
            if mi is not None:
                logger.debug(
                    ("Loaded model info for class %s.%s from cache"),
                    self.module_name,
                    self.class_name,
                )
                return mi
            else:
                logger.debug(
                    ("Cache model info for class %s.%s miss. Loading model instead."),
                    self.module_name,
                    self.class_name,
                )

        # Performed in another process to avoid initializing CUDA
        mi = _run_in_subprocess(
            lambda: _ModelInfo.from_model_cls(self.load_model_cls())
        )
        logger.debug(
            "Loaded model info for class %s.%s", self.module_name, self.class_name
        )

        # save cache file
        if module_hash is not None:
            self._save_modelinfo_to_cache(mi, module_hash)

        return mi

    def load_model_cls(self) -> type[nn.Module]:
        """Import `module_name` in this process and return `class_name` from it."""
        mod = importlib.import_module(self.module_name)
        return getattr(mod, self.class_name)


@lru_cache(maxsize=128)
def _try_load_model_cls(
    model_arch: str,
    model: _BaseRegisteredModel,
) -> type[nn.Module] | None:
    """Load the model class for `model_arch`, returning `None` on failure.

    `current_platform.verify_model_arch` runs first and is not guarded, so any
    exception it raises propagates to the caller. Errors from
    `model.load_model_cls()` are logged with a traceback and turned into
    `None`. Because of `lru_cache`, both successes and `None` results are
    memoized per `(model_arch, model)`.
    """
    from vllm.platforms import current_platform

    current_platform.verify_model_arch(model_arch)
    try:
        return model.load_model_cls()
    except Exception:
        logger.exception("Error in loading model architecture '%s'", model_arch)
        return None


@lru_cache(maxsize=128)
def _try_inspect_model_cls(
    model_arch: str,
    model: _BaseRegisteredModel,
) -> _ModelInfo | None:
    """Return `model.inspect_model_cls()`, or None if inspection fails.

    Any exception raised during inspection is logged via `logger.exception`
    instead of being propagated. Results are LRU-cached (up to 128 entries).
    """
    info: _ModelInfo | None = None
    try:
        info = model.inspect_model_cls()
    except Exception:
        logger.exception("Error in inspecting model architecture '%s'", model_arch)
    return info
705
706


707
708
709
@dataclass
class _ModelRegistry:
    """Registry mapping architecture names to registered model implementations."""

    # Keyed by model_arch
    models: dict[str, _BaseRegisteredModel] = field(default_factory=dict)

    def get_supported_archs(self) -> Set[str]:
        """Return a live view of all registered architecture names."""
        return self.models.keys()

    def register_model(
        self,
        model_arch: str,
        model_cls: type[nn.Module] | str,
    ) -> None:
        """
        Register an external model to be used in vLLM.

        `model_cls` can be either:

        - A [`torch.nn.Module`][] class directly referencing the model.
        - A string in the format `<module>:<class>` which can be used to
          lazily import the model. This is useful to avoid initializing CUDA
          when importing the model and thus the related error
          `RuntimeError: Cannot re-initialize CUDA in forked subprocess`.

        Registering an architecture that already exists logs a warning and
        overwrites the previous entry.

        Raises:
            TypeError: If `model_arch` is not a string, or `model_cls` is
                neither a string nor an `nn.Module` subclass.
            ValueError: If `model_cls` is a string not of the form
                `<module>:<class>`.
        """
        if not isinstance(model_arch, str):
            msg = f"`model_arch` should be a string, not a {type(model_arch)}"
            raise TypeError(msg)

        if model_arch in self.models:
            logger.warning(
                "Model architecture %s is already registered, and will be "
                "overwritten by the new model class %s.",
                model_arch,
                model_cls,
            )

        if isinstance(model_cls, str):
            split_str = model_cls.split(":")
            if len(split_str) != 2:
                msg = "Expected a string in the format `<module>:<class>`"
                raise ValueError(msg)

            model = _LazyRegisteredModel(*split_str)
        elif isinstance(model_cls, type) and issubclass(model_cls, nn.Module):
            model = _RegisteredModel.from_model_cls(model_cls)
        else:
            # Report the type of the offending argument (`model_cls`), not
            # the architecture name that was already validated above.
            msg = (
                "`model_cls` should be a string or PyTorch model class, "
                f"not a {type(model_cls)}"
            )
            raise TypeError(msg)

        self.models[model_arch] = model

    def _raise_for_unsupported(self, architectures: list[str]):
        """Raise a `ValueError` explaining why none of `architectures` is usable.

        Distinguishes between registered architectures that failed
        inspection/loading, architectures that were dropped in a previous vLLM
        release, and architectures that were never supported. Never returns.
        """
        all_supported_archs = self.get_supported_archs()

        if any(arch in all_supported_archs for arch in architectures):
            raise ValueError(
                f"Model architectures {architectures} failed "
                "to be inspected. Please check the logs for more details."
            )

        for arch in architectures:
            if arch in _PREVIOUSLY_SUPPORTED_MODELS:
                previous_version = _PREVIOUSLY_SUPPORTED_MODELS[arch]

                raise ValueError(
                    f"Model architecture {arch} was supported in vLLM until "
                    f"v{previous_version}, and is not supported anymore. "
                    "Please use an older version of vLLM if you want to "
                    "use this model architecture."
                )

        raise ValueError(
            f"Model architectures {architectures} are not supported for now. "
            f"Supported architectures: {all_supported_archs}"
        )

    def _try_load_model_cls(self, model_arch: str) -> type[nn.Module] | None:
        """Load the class for `model_arch`; None if unregistered or on failure."""
        if model_arch not in self.models:
            return None

        return _try_load_model_cls(model_arch, self.models[model_arch])

    def _try_inspect_model_cls(self, model_arch: str) -> _ModelInfo | None:
        """Inspect the class for `model_arch`; None if unregistered or on failure."""
        if model_arch not in self.models:
            return None

        return _try_inspect_model_cls(model_arch, self.models[model_arch])

    def _try_resolve_transformers(
        self,
        architecture: str,
        model_config: ModelConfig,
    ) -> str | None:
        """Return the architecture to use for the Transformers backend, or None.

        Returns `architecture` unchanged if it is a Transformers-backend
        architecture. Otherwise, loads any remote-code classes listed in the
        config's `auto_map` and checks that the model class is compatible with
        the Transformers backend.

        Raises:
            ValueError: If `model_config.model_impl == "transformers"` and the
                model class cannot be found or is not backend-compatible (in
                other modes None is returned instead).
        """
        if architecture in _TRANSFORMERS_BACKEND_MODELS:
            return architecture

        auto_map: dict[str, str] = (
            getattr(model_config.hf_config, "auto_map", None) or dict()
        )

        # Make sure that config class is always initialized before model class,
        # otherwise the model class won't be able to access the config class,
        # the expected auto_map should have correct order like:
        # "auto_map": {
        #     "AutoConfig": "<your-repo-name>--<config-name>",
        #     "AutoModel": "<your-repo-name>--<config-name>",
        #     "AutoModelFor<Task>": "<your-repo-name>--<config-name>",
        # },
        for prefix in ("AutoConfig", "AutoModel"):
            for name, module in auto_map.items():
                if name.startswith(prefix):
                    try_get_class_from_dynamic_module(
                        module,
                        model_config.model,
                        revision=model_config.revision,
                        warn_on_fail=False,
                    )

        model_module = getattr(transformers, architecture, None)

        if model_module is None:
            for name, module in auto_map.items():
                if name.startswith("AutoModel"):
                    model_module = try_get_class_from_dynamic_module(
                        module,
                        model_config.model,
                        revision=model_config.revision,
                        warn_on_fail=True,
                    )
                    if model_module is not None:
                        break
            else:
                # No `AutoModel*` entry produced a class
                if model_config.model_impl != "transformers":
                    return None

                raise ValueError(
                    f"Cannot find model module. {architecture!r} is not a "
                    "registered model in the Transformers library (only "
                    "relevant if the model is meant to be in Transformers) "
                    "and 'AutoModel' is not present in the model config's "
                    "'auto_map' (relevant if the model is custom)."
                )

        if not model_module.is_backend_compatible():
            if model_config.model_impl != "transformers":
                return None

            raise ValueError(
                f"The Transformers implementation of {architecture!r} "
                "is not compatible with vLLM."
            )

        return model_config._get_transformers_backend_cls()

    def _normalize_arch(
        self,
        architecture: str,
        model_config: ModelConfig,
    ) -> str:
        """Map a converted-model architecture name to a registered base name.

        If `architecture` is not registered but matches one of the
        architecture-default suffixes, the suffix is swapped for each known
        default suffix and the first registered result is returned. Otherwise
        `architecture` is returned unchanged.
        """
        if architecture in self.models:
            return architecture

        # This may be called in order to resolve runner_type and convert_type
        # in the first place, in which case we consider the default match
        match = try_match_architecture_defaults(
            architecture,
            runner_type=getattr(model_config, "runner_type", None),
            convert_type=getattr(model_config, "convert_type", None),
        )
        if match:
            suffix, _ = match

            # Get the name of the base model to convert
            for repl_suffix, _ in iter_architecture_defaults():
                base_arch = architecture.replace(suffix, repl_suffix)
                if base_arch in self.models:
                    return base_arch

        return architecture

    def _resolve_arch(
        self,
        architectures: str | list[str],
        model_config: ModelConfig,
        try_arch: "Callable[[str], _T | None]",
    ) -> "tuple[_T, str]":
        """Resolve `architectures` by calling `try_arch` on candidate names.

        Shared by `inspect_model_cls` and `resolve_model_cls` so that both
        follow exactly the same resolution order:

        1. `model_impl == "transformers"`: the Transformers backend;
           `model_impl == "terratorch"`: the `"Terratorch"` architecture.
        2. `model_impl == "auto"`, no listed architecture registered and no
           `convert_type`: the Transformers backend.
        3. Each listed architecture in order, after `_normalize_arch`.
        4. `model_impl == "auto"` and no listed architecture registered: the
           Transformers backend.

        Returns:
            The first non-None `try_arch` result and the architecture name it
            belongs to (in step 3, the original, un-normalized name).

        Raises:
            ValueError: If `architectures` is empty, if the Transformers
                backend is required but unusable, or if no candidate succeeds.
        """
        if isinstance(architectures, str):
            architectures = [architectures]
        if not architectures:
            raise ValueError("No model architectures are specified")

        def try_transformers() -> "tuple[_T, str] | None":
            # Only the first listed architecture is considered for the
            # Transformers backend.
            arch = self._try_resolve_transformers(architectures[0], model_config)
            if arch is None:
                return None
            result = try_arch(arch)
            return None if result is None else (result, arch)

        # Require transformers impl
        if model_config.model_impl == "transformers":
            resolved = try_transformers()
            if resolved is not None:
                return resolved
        elif model_config.model_impl == "terratorch":
            arch = "Terratorch"
            result = try_arch(arch)
            # Fall through on failure instead of returning a None result
            if result is not None:
                return (result, arch)

        # Fallback to transformers impl (after resolving convert_type)
        if (
            all(arch not in self.models for arch in architectures)
            and model_config.model_impl == "auto"
            and getattr(model_config, "convert_type", "none") == "none"
        ):
            resolved = try_transformers()
            if resolved is not None:
                return resolved

        for arch in architectures:
            normalized_arch = self._normalize_arch(arch, model_config)
            result = try_arch(normalized_arch)
            if result is not None:
                return (result, arch)

        # Fallback to transformers impl (before resolving runner_type)
        if (
            all(arch not in self.models for arch in architectures)
            and model_config.model_impl == "auto"
        ):
            resolved = try_transformers()
            if resolved is not None:
                return resolved

        return self._raise_for_unsupported(architectures)

    def inspect_model_cls(
        self,
        architectures: str | list[str],
        model_config: ModelConfig,
    ) -> tuple[_ModelInfo, str]:
        """Return the `_ModelInfo` of the first usable architecture and its name.

        See `_resolve_arch` for the resolution order and raised errors.
        """
        return self._resolve_arch(
            architectures, model_config, self._try_inspect_model_cls
        )

    def resolve_model_cls(
        self,
        architectures: str | list[str],
        model_config: ModelConfig,
    ) -> tuple[type[nn.Module], str]:
        """Return the model class of the first usable architecture and its name.

        See `_resolve_arch` for the resolution order and raised errors.
        """
        return self._resolve_arch(
            architectures, model_config, self._try_load_model_cls
        )

    def is_text_generation_model(
        self,
        architectures: str | list[str],
        model_config: ModelConfig,
    ) -> bool:
        """Whether the resolved model is a text generation model."""
        model_cls, _ = self.inspect_model_cls(architectures, model_config)
        return model_cls.is_text_generation_model

    def is_pooling_model(
        self,
        architectures: str | list[str],
        model_config: ModelConfig,
    ) -> bool:
        """Whether the resolved model is a pooling model."""
        model_cls, _ = self.inspect_model_cls(architectures, model_config)
        return model_cls.is_pooling_model

    def is_cross_encoder_model(
        self,
        architectures: str | list[str],
        model_config: ModelConfig,
    ) -> bool:
        """Whether the resolved model supports cross-encoding."""
        model_cls, _ = self.inspect_model_cls(architectures, model_config)
        return model_cls.supports_cross_encoding

    def is_multimodal_model(
        self,
        architectures: str | list[str],
        model_config: ModelConfig,
    ) -> bool:
        """Whether the resolved model supports multimodal inputs."""
        model_cls, _ = self.inspect_model_cls(architectures, model_config)
        return model_cls.supports_multimodal

    def is_multimodal_raw_input_only_model(
        self,
        architectures: str | list[str],
        model_config: ModelConfig,
    ) -> bool:
        """Whether the resolved model supports only raw multimodal inputs."""
        model_cls, _ = self.inspect_model_cls(architectures, model_config)
        return model_cls.supports_multimodal_raw_input_only

    def is_pp_supported_model(
        self,
        architectures: str | list[str],
        model_config: ModelConfig,
    ) -> bool:
        """Whether the resolved model supports pipeline parallelism."""
        model_cls, _ = self.inspect_model_cls(architectures, model_config)
        return model_cls.supports_pp

    def model_has_inner_state(
        self,
        architectures: str | list[str],
        model_config: ModelConfig,
    ) -> bool:
        """Whether the resolved model has inner state."""
        model_cls, _ = self.inspect_model_cls(architectures, model_config)
        return model_cls.has_inner_state

    def is_attention_free_model(
        self,
        architectures: str | list[str],
        model_config: ModelConfig,
    ) -> bool:
        """Whether the resolved model is attention-free."""
        model_cls, _ = self.inspect_model_cls(architectures, model_config)
        return model_cls.is_attention_free

    def is_hybrid_model(
        self,
        architectures: str | list[str],
        model_config: ModelConfig,
    ) -> bool:
        """Whether the resolved model is a hybrid model."""
        model_cls, _ = self.inspect_model_cls(architectures, model_config)
        return model_cls.is_hybrid

    def is_noops_model(
        self,
        architectures: str | list[str],
        model_config: ModelConfig,
    ) -> bool:
        """Whether the resolved model has no-op layers (`has_noops`)."""
        model_cls, _ = self.inspect_model_cls(architectures, model_config)
        return model_cls.has_noops

    def is_transcription_model(
        self,
        architectures: str | list[str],
        model_config: ModelConfig,
    ) -> bool:
        """Whether the resolved model supports transcription."""
        model_cls, _ = self.inspect_model_cls(architectures, model_config)
        return model_cls.supports_transcription

    def is_transcription_only_model(
        self,
        architectures: str | list[str],
        model_config: ModelConfig,
    ) -> bool:
        """Whether the resolved model supports only transcription."""
        model_cls, _ = self.inspect_model_cls(architectures, model_config)
        return model_cls.supports_transcription_only

1092

1093
1094
1095
1096
1097
1098
1099
1100
1101
# Global registry of the built-in architectures. Every entry is lazy: the
# model module `vllm.model_executor.models.<module>` is only imported when the
# class is actually needed.
ModelRegistry = _ModelRegistry(
    models={
        arch_name: _LazyRegisteredModel(
            module_name=f"vllm.model_executor.models.{relative_module}",
            class_name=class_name,
        )
        for arch_name, (relative_module, class_name) in _VLLM_MODELS.items()
    }
)


# Generic type variable, e.g. for the result type of `_run_in_subprocess`
_T = TypeVar("_T")


def _run_in_subprocess(fn: Callable[[], _T]) -> _T:
    """Run `fn` in a separate Python process and return its result.

    `fn` and the path of an output file are serialized with `cloudpickle` and
    written to the stdin of `_SUBPROCESS_COMMAND`. The subprocess is expected
    to pickle the result of `fn()` into that file, which is then loaded here.

    Raises:
        RuntimeError: If the subprocess exits with a non-zero status; the
            message contains the subprocess's stderr.
    """
    # NOTE: We use a temporary directory instead of a temporary file to avoid
    # issues like https://stackoverflow.com/questions/23212435/permission-denied-to-write-to-my-temporary-file
    with tempfile.TemporaryDirectory() as tempdir:
        output_filepath = os.path.join(tempdir, "registry_output.tmp")

        # `cloudpickle` allows pickling lambda functions directly
        import cloudpickle

        input_bytes = cloudpickle.dumps((fn, output_filepath))

        # cannot use `sys.executable __file__` here because the script
        # contains relative imports
        returned = subprocess.run(
            _SUBPROCESS_COMMAND, input=input_bytes, capture_output=True
        )

        # check if the subprocess is successful
        try:
            returned.check_returncode()
        except subprocess.CalledProcessError as e:
            # Wrap the raised exception to provide more information. Decode
            # leniently so that undecodable stderr bytes cannot raise a
            # UnicodeDecodeError that hides the actual subprocess failure.
            stderr = returned.stderr.decode(errors="replace")
            raise RuntimeError(f"Error raised in subprocess:\n{stderr}") from e

        # The file was written by our own subprocess into a private temporary
        # directory, so unpickling it is not untrusted input.
        with open(output_filepath, "rb") as f:
            return pickle.load(f)


def _run() -> None:
    """Subprocess entry point: execute a pickled callable and store its result.

    Reads a pickled `(fn, output_file)` pair from stdin, calls `fn()` and
    writes the pickled return value to `output_file`. General plugins are
    loaded before the payload is read.
    """
    from vllm.plugins import load_general_plugins

    # Setup plugins
    load_general_plugins()

    payload = sys.stdin.buffer.read()
    fn, output_file = pickle.loads(payload)

    # Run the callable before touching the output file
    result = fn()

    with open(output_file, "wb") as out:
        out.write(pickle.dumps(result))


if __name__ == "__main__":
    # Presumably launched by `_run_in_subprocess` via `_SUBPROCESS_COMMAND`
    _run()