registry.py 44 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
4
5
6
"""
Whenever you add an architecture to this file, please also update
`tests/models/registry.py` with example HuggingFace models for it.
"""
7

8
import hashlib
9
import importlib
10
import json
11
import os
12
import pickle
13
14
import subprocess
import sys
15
import tempfile
16
from abc import ABC, abstractmethod
17
from collections.abc import Callable, Set
18
from dataclasses import asdict, dataclass, field
19
from functools import lru_cache
20
from pathlib import Path
21
from typing import TypeVar
22
23

import torch.nn as nn
24
import transformers
25

26
from vllm import envs
27
28
29
30
31
from vllm.config import (
    ModelConfig,
    iter_architecture_defaults,
    try_match_architecture_defaults,
)
32
from vllm.logger import init_logger
33
from vllm.logging_utils import logtime
34
35
36
37
38
39
40
41
from vllm.transformers_utils.dynamic_module import try_get_class_from_dynamic_module

from .interfaces import (
    has_inner_state,
    has_noops,
    is_attention_free,
    is_hybrid,
    supports_cross_encoding,
42
    supports_mamba_prefix_caching,
43
44
45
46
47
48
49
50
51
52
53
    supports_multimodal,
    supports_multimodal_encoder_tp_data,
    supports_multimodal_raw_input_only,
    supports_pp,
    supports_transcription,
)
from .interfaces_base import (
    get_default_pooling_type,
    is_pooling_model,
    is_text_generation_model,
)
54
55
56

logger = init_logger(__name__)

57
58
# Text generation architectures.
# Each key is an architecture name and each value is a
# `(module name, class name)` pair of strings.
# NOTE(review): the module name is presumably resolved relative to this
# package when the registry is built - confirm where these tables are consumed.
_TEXT_GENERATION_MODELS = {
    # [Decoder-only]
    "ApertusForCausalLM": ("apertus", "ApertusForCausalLM"),
    "AquilaModel": ("llama", "LlamaForCausalLM"),
    "AquilaForCausalLM": ("llama", "LlamaForCausalLM"),  # AquilaChat2
    "ArceeForCausalLM": ("arcee", "ArceeForCausalLM"),
    "ArcticForCausalLM": ("arctic", "ArcticForCausalLM"),
    # baichuan-7b, upper case 'C' in the class name
    "BaiChuanForCausalLM": ("baichuan", "BaiChuanForCausalLM"),
    # baichuan-13b, lower case 'c' in the class name
    "BaichuanForCausalLM": ("baichuan", "BaichuanForCausalLM"),
    "BailingMoeForCausalLM": ("bailing_moe", "BailingMoeForCausalLM"),
    "BailingMoeV2ForCausalLM": ("bailing_moe", "BailingMoeV2ForCausalLM"),
    "BambaForCausalLM": ("bamba", "BambaForCausalLM"),
    "BloomForCausalLM": ("bloom", "BloomForCausalLM"),
    "ChatGLMModel": ("chatglm", "ChatGLMForCausalLM"),
    "ChatGLMForConditionalGeneration": ("chatglm", "ChatGLMForCausalLM"),
    "CohereForCausalLM": ("commandr", "CohereForCausalLM"),
    "Cohere2ForCausalLM": ("commandr", "CohereForCausalLM"),
    "CwmForCausalLM": ("llama", "LlamaForCausalLM"),
    "DbrxForCausalLM": ("dbrx", "DbrxForCausalLM"),
    "DeciLMForCausalLM": ("nemotron_nas", "DeciLMForCausalLM"),
    "DeepseekForCausalLM": ("deepseek", "DeepseekForCausalLM"),
    "DeepseekV2ForCausalLM": ("deepseek_v2", "DeepseekV2ForCausalLM"),
    "DeepseekV3ForCausalLM": ("deepseek_v2", "DeepseekV3ForCausalLM"),
    "DeepseekV32ForCausalLM": ("deepseek_v2", "DeepseekV3ForCausalLM"),
    "Dots1ForCausalLM": ("dots1", "Dots1ForCausalLM"),
    "Ernie4_5ForCausalLM": ("ernie45", "Ernie4_5ForCausalLM"),
    "Ernie4_5_MoeForCausalLM": ("ernie45_moe", "Ernie4_5_MoeForCausalLM"),
    "ExaoneForCausalLM": ("exaone", "ExaoneForCausalLM"),
    "Exaone4ForCausalLM": ("exaone4", "Exaone4ForCausalLM"),
    "Fairseq2LlamaForCausalLM": ("fairseq2_llama", "Fairseq2LlamaForCausalLM"),
    "FalconForCausalLM": ("falcon", "FalconForCausalLM"),
    "FalconMambaForCausalLM": ("mamba", "MambaForCausalLM"),
    "FalconH1ForCausalLM": ("falcon_h1", "FalconH1ForCausalLM"),
    "FlexOlmoForCausalLM": ("flex_olmo", "FlexOlmoForCausalLM"),
    "GemmaForCausalLM": ("gemma", "GemmaForCausalLM"),
    "Gemma2ForCausalLM": ("gemma2", "Gemma2ForCausalLM"),
    "Gemma3ForCausalLM": ("gemma3", "Gemma3ForCausalLM"),
    "Gemma3nForCausalLM": ("gemma3n", "Gemma3nForCausalLM"),
    "Qwen3NextForCausalLM": ("qwen3_next", "Qwen3NextForCausalLM"),
    "GlmForCausalLM": ("glm", "GlmForCausalLM"),
    "Glm4ForCausalLM": ("glm4", "Glm4ForCausalLM"),
    "Glm4MoeForCausalLM": ("glm4_moe", "Glm4MoeForCausalLM"),
    "GptOssForCausalLM": ("gpt_oss", "GptOssForCausalLM"),
    "GPT2LMHeadModel": ("gpt2", "GPT2LMHeadModel"),
    "GPTBigCodeForCausalLM": ("gpt_bigcode", "GPTBigCodeForCausalLM"),
    "GPTJForCausalLM": ("gpt_j", "GPTJForCausalLM"),
    "GPTNeoXForCausalLM": ("gpt_neox", "GPTNeoXForCausalLM"),
    "GraniteForCausalLM": ("granite", "GraniteForCausalLM"),
    "GraniteMoeForCausalLM": ("granitemoe", "GraniteMoeForCausalLM"),
    "GraniteMoeHybridForCausalLM": ("granitemoehybrid", "GraniteMoeHybridForCausalLM"),  # noqa: E501
    "GraniteMoeSharedForCausalLM": ("granitemoeshared", "GraniteMoeSharedForCausalLM"),  # noqa: E501
    "GritLM": ("gritlm", "GritLM"),
    "Grok1ModelForCausalLM": ("grok1", "Grok1ForCausalLM"),
    "HunYuanMoEV1ForCausalLM": ("hunyuan_v1", "HunYuanMoEV1ForCausalLM"),
    "HunYuanDenseV1ForCausalLM": ("hunyuan_v1", "HunYuanDenseV1ForCausalLM"),
    "HCXVisionForCausalLM": ("hyperclovax_vision", "HCXVisionForCausalLM"),
    "InternLMForCausalLM": ("llama", "LlamaForCausalLM"),
    "InternLM2ForCausalLM": ("internlm2", "InternLM2ForCausalLM"),
    "InternLM2VEForCausalLM": ("internlm2_ve", "InternLM2VEForCausalLM"),
    "InternLM3ForCausalLM": ("llama", "LlamaForCausalLM"),
    "JAISLMHeadModel": ("jais", "JAISLMHeadModel"),
    "JambaForCausalLM": ("jamba", "JambaForCausalLM"),
    "KimiLinearForCausalLM": ("kimi_linear", "KimiLinearForCausalLM"),  # noqa: E501
    "Lfm2ForCausalLM": ("lfm2", "Lfm2ForCausalLM"),
    "Lfm2MoeForCausalLM": ("lfm2_moe", "Lfm2MoeForCausalLM"),
    "LlamaForCausalLM": ("llama", "LlamaForCausalLM"),
    "Llama4ForCausalLM": ("llama4", "Llama4ForCausalLM"),
    # For decapoda-research/llama-*
    "LLaMAForCausalLM": ("llama", "LlamaForCausalLM"),
    "LongcatFlashForCausalLM": ("longcat_flash", "LongcatFlashForCausalLM"),
    "MambaForCausalLM": ("mamba", "MambaForCausalLM"),
    "Mamba2ForCausalLM": ("mamba2", "Mamba2ForCausalLM"),
    "MiniCPMForCausalLM": ("minicpm", "MiniCPMForCausalLM"),
    "MiniCPM3ForCausalLM": ("minicpm3", "MiniCPM3ForCausalLM"),
    "MiniMaxForCausalLM": ("minimax_text_01", "MiniMaxText01ForCausalLM"),
    "MiniMaxText01ForCausalLM": ("minimax_text_01", "MiniMaxText01ForCausalLM"),
    "MiniMaxM1ForCausalLM": ("minimax_text_01", "MiniMaxText01ForCausalLM"),
    "MiniMaxM2ForCausalLM": ("minimax_m2", "MiniMaxM2ForCausalLM"),
    "MistralForCausalLM": ("llama", "LlamaForCausalLM"),
    "MixtralForCausalLM": ("mixtral", "MixtralForCausalLM"),
    # transformers's mpt class has lower case
    "MptForCausalLM": ("mpt", "MPTForCausalLM"),
    "MPTForCausalLM": ("mpt", "MPTForCausalLM"),
    "MiMoForCausalLM": ("mimo", "MiMoForCausalLM"),
    "NemotronForCausalLM": ("nemotron", "NemotronForCausalLM"),
    "NemotronHForCausalLM": ("nemotron_h", "NemotronHForCausalLM"),
    "OlmoForCausalLM": ("olmo", "OlmoForCausalLM"),
    "Olmo2ForCausalLM": ("olmo2", "Olmo2ForCausalLM"),
    "Olmo3ForCausalLM": ("olmo2", "Olmo2ForCausalLM"),
    "OlmoeForCausalLM": ("olmoe", "OlmoeForCausalLM"),
    "OPTForCausalLM": ("opt", "OPTForCausalLM"),
    "OrionForCausalLM": ("orion", "OrionForCausalLM"),
    "OuroForCausalLM": ("ouro", "OuroForCausalLM"),
    "PersimmonForCausalLM": ("persimmon", "PersimmonForCausalLM"),
    "PhiForCausalLM": ("phi", "PhiForCausalLM"),
    "Phi3ForCausalLM": ("phi3", "Phi3ForCausalLM"),
    "PhiMoEForCausalLM": ("phimoe", "PhiMoEForCausalLM"),
    "Plamo2ForCausalLM": ("plamo2", "Plamo2ForCausalLM"),
    "QWenLMHeadModel": ("qwen", "QWenLMHeadModel"),
    "Qwen2ForCausalLM": ("qwen2", "Qwen2ForCausalLM"),
    "Qwen2MoeForCausalLM": ("qwen2_moe", "Qwen2MoeForCausalLM"),
    "Qwen3ForCausalLM": ("qwen3", "Qwen3ForCausalLM"),
    "Qwen3MoeForCausalLM": ("qwen3_moe", "Qwen3MoeForCausalLM"),
    "RWForCausalLM": ("falcon", "FalconForCausalLM"),
    "SeedOssForCausalLM": ("seed_oss", "SeedOssForCausalLM"),
    "Step3TextForCausalLM": ("step3_text", "Step3TextForCausalLM"),
    "StableLMEpochForCausalLM": ("stablelm", "StablelmForCausalLM"),
    "StableLmForCausalLM": ("stablelm", "StablelmForCausalLM"),
    "Starcoder2ForCausalLM": ("starcoder2", "Starcoder2ForCausalLM"),
    "SolarForCausalLM": ("solar", "SolarForCausalLM"),
    "TeleChat2ForCausalLM": ("telechat2", "TeleChat2ForCausalLM"),
    "TeleFLMForCausalLM": ("teleflm", "TeleFLMForCausalLM"),
    "XverseForCausalLM": ("llama", "LlamaForCausalLM"),
    "Zamba2ForCausalLM": ("zamba2", "Zamba2ForCausalLM"),
}

# Embedding / pooling architectures, as `(module name, class name)` pairs.
# Several entries point at a `*ForCausalLM` class.
_EMBEDDING_MODELS = {
    # [Text-only]
    "BertModel": ("bert", "BertEmbeddingModel"),
    "BertSpladeSparseEmbeddingModel": ("bert", "BertSpladeSparseEmbeddingModel"),
    "DeciLMForCausalLM": ("nemotron_nas", "DeciLMForCausalLM"),
    "Gemma2Model": ("gemma2", "Gemma2ForCausalLM"),
    "Gemma3TextModel": ("gemma3", "Gemma3Model"),
    "GlmForCausalLM": ("glm", "GlmForCausalLM"),
    "GPT2ForSequenceClassification": ("gpt2", "GPT2ForSequenceClassification"),
    "GritLM": ("gritlm", "GritLM"),
    "GteModel": ("bert_with_rope", "SnowflakeGteNewModel"),
    "GteNewModel": ("bert_with_rope", "GteNewModel"),
    "InternLM2ForRewardModel": ("internlm2", "InternLM2ForRewardModel"),
    "JambaForSequenceClassification": ("jamba", "JambaForSequenceClassification"),  # noqa: E501
    "LlamaModel": ("llama", "LlamaForCausalLM"),
    **{
        # Multiple models share the same architecture, so we include them all
        k: (mod, arch)
        for k, (mod, arch) in _TEXT_GENERATION_MODELS.items()
        if arch == "LlamaForCausalLM"
    },
    "MistralModel": ("llama", "LlamaForCausalLM"),
    "ModernBertModel": ("modernbert", "ModernBertModel"),
    "NomicBertModel": ("bert_with_rope", "NomicBertModel"),
    "Phi3ForCausalLM": ("phi3", "Phi3ForCausalLM"),
    "Qwen2Model": ("qwen2", "Qwen2ForCausalLM"),
    "Qwen2ForCausalLM": ("qwen2", "Qwen2ForCausalLM"),
    "Qwen2ForRewardModel": ("qwen2_rm", "Qwen2ForRewardModel"),
    "Qwen2ForProcessRewardModel": ("qwen2_rm", "Qwen2ForProcessRewardModel"),
    "RobertaForMaskedLM": ("roberta", "RobertaEmbeddingModel"),
    "RobertaModel": ("roberta", "RobertaEmbeddingModel"),
    "TeleChat2ForCausalLM": ("telechat2", "TeleChat2ForCausalLM"),
    "XLMRobertaModel": ("roberta", "RobertaEmbeddingModel"),
    # [Multimodal]
    "CLIPModel": ("clip", "CLIPEmbeddingModel"),
    "LlavaNextForConditionalGeneration": (
        "llava_next",
        "LlavaNextForConditionalGeneration",
    ),
    "Phi3VForCausalLM": ("phi3v", "Phi3VForCausalLM"),
    "Qwen2VLForConditionalGeneration": ("qwen2_vl", "Qwen2VLForConditionalGeneration"),  # noqa: E501
    "SiglipModel": ("siglip", "SiglipEmbeddingModel"),
    # Technically, Terratorch models work on images for both input and
    # output. They are listed here because they piggy-back on embedding
    # models for the time being.
    "PrithviGeoSpatialMAE": ("terratorch", "Terratorch"),
    "Terratorch": ("terratorch", "Terratorch"),
}

224
225
# Sequence/token classification architectures, as
# `(module name, class name)` pairs.
_CROSS_ENCODER_MODELS = {
    "BertForSequenceClassification": ("bert", "BertForSequenceClassification"),
    "BertForTokenClassification": ("bert", "BertForTokenClassification"),
    "GteNewForSequenceClassification": (
        "bert_with_rope",
        "GteNewForSequenceClassification",
    ),
    "ModernBertForSequenceClassification": (
        "modernbert",
        "ModernBertForSequenceClassification",
    ),
    "ModernBertForTokenClassification": (
        "modernbert",
        "ModernBertForTokenClassification",
    ),
    "RobertaForSequenceClassification": ("roberta", "RobertaForSequenceClassification"),
    "XLMRobertaForSequenceClassification": (
        "roberta",
        "RobertaForSequenceClassification",
    ),
    # [Auto-converted (see adapters.py)]
    "JinaVLForRanking": ("jina_vl", "JinaVLForSequenceClassification"),  # noqa: E501
}

248
# Multimodal architectures, as `(module name, class name)` pairs.
_MULTIMODAL_MODELS = {
    # [Decoder-only]
    "AriaForConditionalGeneration": ("aria", "AriaForConditionalGeneration"),
    "AyaVisionForConditionalGeneration": (
        "aya_vision",
        "AyaVisionForConditionalGeneration",
    ),
    "BeeForConditionalGeneration": ("bee", "BeeForConditionalGeneration"),
    "Blip2ForConditionalGeneration": ("blip2", "Blip2ForConditionalGeneration"),
    "ChameleonForConditionalGeneration": (
        "chameleon",
        "ChameleonForConditionalGeneration",
    ),
    "Cohere2VisionForConditionalGeneration": (
        "cohere2_vision",
        "Cohere2VisionForConditionalGeneration",
    ),
    "DeepseekVLV2ForCausalLM": ("deepseek_vl2", "DeepseekVLV2ForCausalLM"),
    "DeepseekOCRForCausalLM": ("deepseek_ocr", "DeepseekOCRForCausalLM"),
    "DotsOCRForCausalLM": ("dots_ocr", "DotsOCRForCausalLM"),
    "Ernie4_5_VLMoeForConditionalGeneration": (
        "ernie45_vl",
        "Ernie4_5_VLMoeForConditionalGeneration",
    ),
    "FuyuForCausalLM": ("fuyu", "FuyuForCausalLM"),
    "Gemma3ForConditionalGeneration": ("gemma3_mm", "Gemma3ForConditionalGeneration"),  # noqa: E501
    "Gemma3nForConditionalGeneration": (
        "gemma3n_mm",
        "Gemma3nForConditionalGeneration",
    ),
    "GLM4VForCausalLM": ("glm4v", "GLM4VForCausalLM"),
    "Glm4vForConditionalGeneration": ("glm4_1v", "Glm4vForConditionalGeneration"),  # noqa: E501
    "Glm4vMoeForConditionalGeneration": ("glm4_1v", "Glm4vMoeForConditionalGeneration"),  # noqa: E501
    "GraniteSpeechForConditionalGeneration": (
        "granite_speech",
        "GraniteSpeechForConditionalGeneration",
    ),
    "H2OVLChatModel": ("h2ovl", "H2OVLChatModel"),
    "InternVLChatModel": ("internvl", "InternVLChatModel"),
    "NemotronH_Nano_VL_V2": ("nano_nemotron_vl", "NemotronH_Nano_VL_V2"),
    "InternS1ForConditionalGeneration": (
        "interns1",
        "InternS1ForConditionalGeneration",
    ),
    "InternVLForConditionalGeneration": (
        "interns1",
        "InternS1ForConditionalGeneration",
    ),
    "Idefics3ForConditionalGeneration": (
        "idefics3",
        "Idefics3ForConditionalGeneration",
    ),
    "SmolVLMForConditionalGeneration": ("smolvlm", "SmolVLMForConditionalGeneration"),  # noqa: E501
    "KeyeForConditionalGeneration": ("keye", "KeyeForConditionalGeneration"),
    "KeyeVL1_5ForConditionalGeneration": (
        "keye_vl1_5",
        "KeyeVL1_5ForConditionalGeneration",
    ),
    "RForConditionalGeneration": ("rvl", "RForConditionalGeneration"),
    "KimiVLForConditionalGeneration": ("kimi_vl", "KimiVLForConditionalGeneration"),  # noqa: E501
    "LightOnOCRForConditionalGeneration": (
        "lightonocr",
        "LightOnOCRForConditionalGeneration",
    ),
    "Llama_Nemotron_Nano_VL": ("nemotron_vl", "LlamaNemotronVLChatModel"),
    "Llama4ForConditionalGeneration": ("mllama4", "Llama4ForConditionalGeneration"),  # noqa: E501
    "LlavaForConditionalGeneration": ("llava", "LlavaForConditionalGeneration"),
    "LlavaNextForConditionalGeneration": (
        "llava_next",
        "LlavaNextForConditionalGeneration",
    ),
    "LlavaNextVideoForConditionalGeneration": (
        "llava_next_video",
        "LlavaNextVideoForConditionalGeneration",
    ),
    "LlavaOnevisionForConditionalGeneration": (
        "llava_onevision",
        "LlavaOnevisionForConditionalGeneration",
    ),
    "MantisForConditionalGeneration": ("llava", "MantisForConditionalGeneration"),  # noqa: E501
    "MiDashengLMModel": ("midashenglm", "MiDashengLMModel"),
    "MiniMaxVL01ForConditionalGeneration": (
        "minimax_vl_01",
        "MiniMaxVL01ForConditionalGeneration",
    ),
    "MiniCPMO": ("minicpmo", "MiniCPMO"),
    "MiniCPMV": ("minicpmv", "MiniCPMV"),
    "Mistral3ForConditionalGeneration": (
        "mistral3",
        "Mistral3ForConditionalGeneration",
    ),
    "MolmoForCausalLM": ("molmo", "MolmoForCausalLM"),
    "NVLM_D": ("nvlm_d", "NVLM_D_Model"),
    "Ovis": ("ovis", "Ovis"),
    "Ovis2_5": ("ovis2_5", "Ovis2_5"),
    "PaliGemmaForConditionalGeneration": (
        "paligemma",
        "PaliGemmaForConditionalGeneration",
    ),
    "Phi3VForCausalLM": ("phi3v", "Phi3VForCausalLM"),
    "Phi4MMForCausalLM": ("phi4mm", "Phi4MMForCausalLM"),
    "Phi4MultimodalForCausalLM": ("phi4_multimodal", "Phi4MultimodalForCausalLM"),  # noqa: E501
    "PixtralForConditionalGeneration": ("pixtral", "PixtralForConditionalGeneration"),  # noqa: E501
    "QwenVLForConditionalGeneration": ("qwen_vl", "QwenVLForConditionalGeneration"),  # noqa: E501
    "Qwen2VLForConditionalGeneration": ("qwen2_vl", "Qwen2VLForConditionalGeneration"),  # noqa: E501
    "Qwen2_5_VLForConditionalGeneration": (
        "qwen2_5_vl",
        "Qwen2_5_VLForConditionalGeneration",
    ),
    "Qwen2AudioForConditionalGeneration": (
        "qwen2_audio",
        "Qwen2AudioForConditionalGeneration",
    ),
    "Qwen2_5OmniModel": (
        "qwen2_5_omni_thinker",
        "Qwen2_5OmniThinkerForConditionalGeneration",
    ),
    "Qwen2_5OmniForConditionalGeneration": (
        "qwen2_5_omni_thinker",
        "Qwen2_5OmniThinkerForConditionalGeneration",
    ),
    "Qwen3OmniMoeForConditionalGeneration": (
        "qwen3_omni_moe_thinker",
        "Qwen3OmniMoeThinkerForConditionalGeneration",
    ),
    "Qwen3VLForConditionalGeneration": ("qwen3_vl", "Qwen3VLForConditionalGeneration"),  # noqa: E501
    "Qwen3VLMoeForConditionalGeneration": (
        "qwen3_vl_moe",
        "Qwen3VLMoeForConditionalGeneration",
    ),
    "SkyworkR1VChatModel": ("skyworkr1v", "SkyworkR1VChatModel"),
    "Step3VLForConditionalGeneration": ("step3_vl", "Step3VLForConditionalGeneration"),  # noqa: E501
    "TarsierForConditionalGeneration": ("tarsier", "TarsierForConditionalGeneration"),  # noqa: E501
    "Tarsier2ForConditionalGeneration": (
        "qwen2_vl",
        "Tarsier2ForConditionalGeneration",
    ),
    "UltravoxModel": ("ultravox", "UltravoxModel"),
    "VoxtralForConditionalGeneration": ("voxtral", "VoxtralForConditionalGeneration"),  # noqa: E501
    # [Encoder-decoder]
    "WhisperForConditionalGeneration": ("whisper", "WhisperForConditionalGeneration"),  # noqa: E501
}
390
391

# Speculative decoding (draft) architectures, as
# `(module name, class name)` pairs.
_SPECULATIVE_DECODING_MODELS = {
    "MiMoMTPModel": ("mimo_mtp", "MiMoMTP"),
    "EagleLlamaForCausalLM": ("llama_eagle", "EagleLlamaForCausalLM"),
    "EagleLlama4ForCausalLM": ("llama4_eagle", "EagleLlama4ForCausalLM"),
    "EagleMiniCPMForCausalLM": ("minicpm_eagle", "EagleMiniCPMForCausalLM"),
    "Eagle3LlamaForCausalLM": ("llama_eagle3", "Eagle3LlamaForCausalLM"),
    "LlamaForCausalLMEagle3": ("llama_eagle3", "Eagle3LlamaForCausalLM"),
    "Eagle3Qwen2_5vlForCausalLM": ("llama_eagle3", "Eagle3LlamaForCausalLM"),
    "EagleDeepSeekMTPModel": ("deepseek_eagle", "EagleDeepseekV3ForCausalLM"),
    "DeepSeekMTPModel": ("deepseek_mtp", "DeepSeekMTP"),
    "ErnieMTPModel": ("ernie_mtp", "ErnieMTP"),
    "LongCatFlashMTPModel": ("longcat_flash_mtp", "LongCatFlashMTP"),
    "Glm4MoeMTPModel": ("glm4_moe_mtp", "Glm4MoeMTP"),
    "MedusaModel": ("medusa", "Medusa"),
    "Qwen3NextMTP": ("qwen3_next_mtp", "Qwen3NextMTP"),
    # Temporarily disabled.
    # # TODO(woosuk): Re-enable this once the MLP Speculator is supported in V1.
    # "MLPSpeculatorPreTrainedModel": ("mlp_speculator", "MLPSpeculator"),
}
410

411
# Named architectures that map onto the generic classes in the
# "transformers" module.
_TRANSFORMERS_SUPPORTED_MODELS = {
    # Text generation models
    "SmolLM3ForCausalLM": ("transformers", "TransformersForCausalLM"),
    # Multimodal models
    "Emu3ForConditionalGeneration": (
        "transformers",
        "TransformersMultiModalForCausalLM",
    ),
}

# The generic classes of the "transformers" module, registered under their
# own class names.
_TRANSFORMERS_BACKEND_MODELS = {
    # Text generation models
    "TransformersForCausalLM": ("transformers", "TransformersForCausalLM"),
    "TransformersMoEForCausalLM": ("transformers", "TransformersMoEForCausalLM"),
    # Multimodal models
    "TransformersMultiModalForCausalLM": (
        "transformers",
        "TransformersMultiModalForCausalLM",
    ),
    "TransformersMultiModalMoEForCausalLM": (
        "transformers",
        "TransformersMultiModalMoEForCausalLM",
    ),
    # Embedding models
    "TransformersEmbeddingModel": ("transformers", "TransformersEmbeddingModel"),
    "TransformersMoEEmbeddingModel": ("transformers", "TransformersMoEEmbeddingModel"),
    "TransformersMultiModalEmbeddingModel": (
        "transformers",
        "TransformersMultiModalEmbeddingModel",
    ),
    # Sequence classification models
    "TransformersForSequenceClassification": (
        "transformers",
        "TransformersForSequenceClassification",
    ),
    "TransformersMoEForSequenceClassification": (
        "transformers",
        "TransformersMoEForSequenceClassification",
    ),
    "TransformersMultiModalForSequenceClassification": (
        "transformers",
        "TransformersMultiModalForSequenceClassification",
    ),
}
455

456
# Union of all architecture tables above. When the same architecture name
# appears in more than one table, the entry from the table unpacked last wins.
_VLLM_MODELS = {
    **_TEXT_GENERATION_MODELS,
    **_EMBEDDING_MODELS,
    **_CROSS_ENCODER_MODELS,
    **_MULTIMODAL_MODELS,
    **_SPECULATIVE_DECODING_MODELS,
    **_TRANSFORMERS_SUPPORTED_MODELS,
    **_TRANSFORMERS_BACKEND_MODELS,
}

466
467
468
469
# This variable is used as the args for subprocess.run(). It can be
# modified to alter the args if needed, e.g. when things are packed
# together in the par format, sys.executable might not be the target
# we want to run.
_SUBPROCESS_COMMAND = [sys.executable, "-m", "vllm.model_executor.models.registry"]
471

472
# Architectures that are no longer in the registry, mapped to a version string.
# NOTE(review): the value is presumably the last vLLM release that supported
# the architecture - confirm against the code that reads this table.
_PREVIOUSLY_SUPPORTED_MODELS = {
    "MotifForCausalLM": "0.10.2",
    "Phi3SmallForCausalLM": "0.9.2",
    "Phi4FlashForCausalLM": "0.10.2",
    # encoder-decoder models except whisper
    # have been removed for V0 deprecation.
    "BartModel": "0.10.2",
    "BartForConditionalGeneration": "0.10.2",
    "DonutForConditionalGeneration": "0.10.2",
    "Florence2ForConditionalGeneration": "0.10.2",
    "MBartForConditionalGeneration": "0.10.2",
    "MllamaForConditionalGeneration": "0.10.2",
}
485

486

487
488
@dataclass(frozen=True)
class _ModelInfo:
    """
    Immutable summary of the interfaces a model class supports.

    Every field except `architecture` is filled from the matching helper
    function applied to the model class (see `from_model_cls`). The field
    order is significant for positional construction and for `asdict`
    output, so new fields should be added with care.
    """

    # Name of the model class (`model.__name__`).
    architecture: str
    is_text_generation_model: bool
    is_pooling_model: bool
    default_pooling_type: str
    supports_cross_encoding: bool
    supports_multimodal: bool
    supports_multimodal_raw_input_only: bool
    supports_multimodal_encoder_tp_data: bool
    supports_pp: bool
    has_inner_state: bool
    is_attention_free: bool
    is_hybrid: bool
    has_noops: bool
    supports_mamba_prefix_caching: bool
    supports_transcription: bool
    # True only when the model supports transcription AND its class-level
    # `supports_transcription_only` attribute is truthy.
    supports_transcription_only: bool

    @staticmethod
    def from_model_cls(model: type[nn.Module]) -> "_ModelInfo":
        """
        Build a `_ModelInfo` by probing `model` with the interface helpers.

        Args:
            model: The model class (not an instance) to inspect.

        Returns:
            A `_ModelInfo` whose `architecture` is `model.__name__`.
        """
        # Evaluate once; the result feeds two fields below.
        transcription = supports_transcription(model)

        return _ModelInfo(
            architecture=model.__name__,
            is_text_generation_model=is_text_generation_model(model),
            is_pooling_model=is_pooling_model(model),
            default_pooling_type=get_default_pooling_type(model),
            supports_cross_encoding=supports_cross_encoding(model),
            supports_multimodal=supports_multimodal(model),
            supports_multimodal_raw_input_only=supports_multimodal_raw_input_only(
                model
            ),
            supports_multimodal_encoder_tp_data=supports_multimodal_encoder_tp_data(
                model
            ),
            supports_pp=supports_pp(model),
            has_inner_state=has_inner_state(model),
            is_attention_free=is_attention_free(model),
            is_hybrid=is_hybrid(model),
            supports_mamba_prefix_caching=supports_mamba_prefix_caching(model),
            supports_transcription=transcription,
            # `model.supports_transcription_only` is only read when the model
            # supports transcription at all (short-circuit `and`).
            supports_transcription_only=(
                transcription and model.supports_transcription_only
            ),
            has_noops=has_noops(model),
        )
532
533


534
535
536
537
class _BaseRegisteredModel(ABC):
    """Abstract interface shared by the registry entry types in this module."""

    @abstractmethod
    def inspect_model_cls(self) -> _ModelInfo:
        """Return the `_ModelInfo` describing the model class."""
        raise NotImplementedError

    @abstractmethod
    def load_model_cls(self) -> type[nn.Module]:
        """Return the model class itself."""
        raise NotImplementedError
542
543


544
545
546
547
548
549
550
@dataclass(frozen=True)
class _RegisteredModel(_BaseRegisteredModel):
    """
    Represents a model that has already been imported in the main process.
    """

    # Interface summary computed from `model_cls` when the entry is created
    # through `from_model_cls`.
    interfaces: _ModelInfo
    model_cls: type[nn.Module]

    @staticmethod
    def from_model_cls(model_cls: type[nn.Module]) -> "_RegisteredModel":
        """Create an entry for `model_cls`, inspecting it immediately."""
        return _RegisteredModel(
            interfaces=_ModelInfo.from_model_cls(model_cls),
            model_cls=model_cls,
        )

    def inspect_model_cls(self) -> _ModelInfo:
        """Return the stored interface summary."""
        return self.interfaces

    def load_model_cls(self) -> type[nn.Module]:
        """Return the stored model class."""
        return self.model_cls


@dataclass(frozen=True)
class _LazyRegisteredModel(_BaseRegisteredModel):
    """
    Represents a model that has not been imported in the main process.

    The model is identified by `module_name` and `class_name`. The class is
    only imported when `load_model_cls` is called. `inspect_model_cls`
    obtains the `_ModelInfo` through `_run_in_subprocess` and keeps a JSON
    cache on disk that is validated against a hash of the model source file.
    """

    module_name: str
    class_name: str

    @staticmethod
    def _get_cache_dir() -> Path:
        """Return the directory that holds the cached model info files."""
        return Path(envs.VLLM_CACHE_ROOT) / "modelinfos"

    def _get_cache_filename(self) -> str:
        """Return the cache file name for this module/class pair."""
        # Dots are replaced so the only "." left is the ".json" extension.
        cls_name = f"{self.module_name}-{self.class_name}".replace(".", "-")
        return f"{cls_name}.json"

    def _load_modelinfo_from_cache(self, module_hash: str) -> _ModelInfo | None:
        """
        Try to load the cached `_ModelInfo` for this model.

        Args:
            module_hash: Hash of the current model source file; the cached
                entry is only used when its stored hash matches.

        Returns:
            The cached `_ModelInfo`, or `None` when the cache file is
            missing, stale, or cannot be read or parsed. This method never
            raises for those cases; the cache is best-effort.
        """
        try:
            try:
                modelinfo_path = self._get_cache_dir() / self._get_cache_filename()
                with open(modelinfo_path, encoding="utf-8") as file:
                    mi_dict = json.load(file)
            except FileNotFoundError:
                logger.debug(
                    "Cached model info file for class %s.%s not found",
                    self.module_name,
                    self.class_name,
                )
                return None

            if mi_dict["hash"] != module_hash:
                logger.debug(
                    "Cached model info file for class %s.%s is stale",
                    self.module_name,
                    self.class_name,
                )
                return None

            # file not changed, use cached _ModelInfo properties
            return _ModelInfo(**mi_dict["modelinfo"])
        except Exception:
            # Covers corrupt JSON, missing keys, and fields that no longer
            # match `_ModelInfo`. Keep the traceback so the cause can be
            # diagnosed from debug logs.
            logger.debug(
                "Cached model info for class %s.%s error. ",
                self.module_name,
                self.class_name,
                exc_info=True,
            )
            return None

    def _save_modelinfo_to_cache(self, mi: _ModelInfo, module_hash: str) -> None:
        """
        Write `mi` and `module_hash` to the JSON cache file.

        Failures are logged and swallowed; a failed write must not break
        model inspection.
        """
        from vllm.model_executor.model_loader.weight_utils import atomic_writer

        try:
            modelinfo_dict = {
                "hash": module_hash,
                "modelinfo": asdict(mi),
            }
            cache_dir = self._get_cache_dir()
            cache_dir.mkdir(parents=True, exist_ok=True)
            modelinfo_path = cache_dir / self._get_cache_filename()
            with atomic_writer(modelinfo_path, encoding="utf-8") as f:
                json.dump(modelinfo_dict, f, indent=2)
        except Exception:
            logger.exception("Error saving model info cache.")

    @logtime(logger=logger, msg="Registry inspect model class")
    def inspect_model_cls(self) -> _ModelInfo:
        """
        Return the `_ModelInfo` for this model without importing it here.

        The cached value is used when the model source file can be read and
        its hash matches the cache. Otherwise the model is inspected via
        `_run_in_subprocess` and, if a hash is available, the result is
        written back to the cache.
        """
        # NOTE(review): only the last component of `module_name` is used and
        # the file is looked up next to this module. For modules that live
        # elsewhere the file is usually absent and the cache is bypassed;
        # a same-named file here would be hashed instead - confirm intent.
        model_path = Path(__file__).parent / f"{self.module_name.split('.')[-1]}.py"
        module_hash = None

        try:
            source = model_path.read_bytes()
        except OSError:
            # Missing or unreadable source file: skip the cache entirely.
            source = None

        if source is not None:
            module_hash = hashlib.md5(source, usedforsecurity=False).hexdigest()

            mi = self._load_modelinfo_from_cache(module_hash)
            if mi is not None:
                logger.debug(
                    "Loaded model info for class %s.%s from cache",
                    self.module_name,
                    self.class_name,
                )
                return mi

            logger.debug(
                "Cache model info for class %s.%s miss. Loading model instead.",
                self.module_name,
                self.class_name,
            )

        # Performed in another process to avoid initializing CUDA
        mi = _run_in_subprocess(
            lambda: _ModelInfo.from_model_cls(self.load_model_cls())
        )
        logger.debug(
            "Loaded model info for class %s.%s", self.module_name, self.class_name
        )

        # save cache file
        if module_hash is not None:
            self._save_modelinfo_to_cache(mi, module_hash)

        return mi

    def load_model_cls(self) -> type[nn.Module]:
        """Import `module_name` and return its `class_name` attribute."""
        mod = importlib.import_module(self.module_name)
        return getattr(mod, self.class_name)


@lru_cache(maxsize=128)
def _try_load_model_cls(
    model_arch: str,
    model: _BaseRegisteredModel,
) -> type[nn.Module] | None:
    """
    Load the model class for `model_arch`, returning `None` on failure.

    The platform check runs outside the `try`, so an error raised by
    `verify_model_arch` propagates to the caller. Results, including `None`,
    are memoized by `lru_cache`, so both arguments must be hashable.
    """
    from vllm.platforms import current_platform

    current_platform.verify_model_arch(model_arch)
    try:
        return model.load_model_cls()
    except Exception:
        logger.exception("Error in loading model architecture '%s'", model_arch)
        return None
689
690


691
692
693
694
@lru_cache(maxsize=128)
def _try_inspect_model_cls(
    model_arch: str,
    model: _BaseRegisteredModel,
) -> _ModelInfo | None:
    """Inspect the model registered for `model_arch`, or return `None` on failure.

    Any exception raised by `model.inspect_model_cls()` is logged (with its
    traceback) and turned into `None`.

    Results (including `None`) are memoized per `(model_arch, model)` pair
    by `lru_cache`.
    """
    inspected: _ModelInfo | None = None
    try:
        inspected = model.inspect_model_cls()
    except Exception:
        logger.exception("Error in inspecting model architecture '%s'", model_arch)
    return inspected
701
702


703
704
705
@dataclass
class _ModelRegistry:
    """Mapping from architecture names to registered model entries.

    Provides registration of additional models, resolution of a list of
    candidate architectures to a model class (`resolve_model_cls`) or to its
    inspected metadata (`inspect_model_cls`), and a set of boolean capability
    queries built on top of `inspect_model_cls`.
    """

    # Keyed by model_arch
    models: dict[str, _BaseRegisteredModel] = field(default_factory=dict)

    def get_supported_archs(self) -> Set[str]:
        """Return a live view of all registered architecture names."""
        return self.models.keys()

    def register_model(
        self,
        model_arch: str,
        model_cls: type[nn.Module] | str,
    ) -> None:
        """
        Register an external model to be used in vLLM.

        `model_cls` can be either:

        - A [`torch.nn.Module`][] class directly referencing the model.
        - A string in the format `<module>:<class>` which can be used to
          lazily import the model. This is useful to avoid initializing CUDA
          when importing the model and thus the related error
          `RuntimeError: Cannot re-initialize CUDA in forked subprocess`.

        If `model_arch` is already registered, a warning is logged and the
        existing entry is overwritten.

        Raises:
            TypeError: If `model_arch` is not a string, or `model_cls` is
                neither a string nor a `torch.nn.Module` subclass.
            ValueError: If `model_cls` is a string that is not in the
                `<module>:<class>` format.
        """
        if not isinstance(model_arch, str):
            msg = f"`model_arch` should be a string, not a {type(model_arch)}"
            raise TypeError(msg)

        if model_arch in self.models:
            logger.warning(
                "Model architecture %s is already registered, and will be "
                "overwritten by the new model class %s.",
                model_arch,
                model_cls,
            )

        if isinstance(model_cls, str):
            split_str = model_cls.split(":")
            if len(split_str) != 2:
                msg = "Expected a string in the format `<module>:<class>`"
                raise ValueError(msg)

            model = _LazyRegisteredModel(*split_str)
        elif isinstance(model_cls, type) and issubclass(model_cls, nn.Module):
            model = _RegisteredModel.from_model_cls(model_cls)
        else:
            # Report the type of the offending argument (`model_cls`);
            # `model_arch` is already known to be a `str` at this point.
            msg = (
                "`model_cls` should be a string or PyTorch model class, "
                f"not a {type(model_cls)}"
            )
            raise TypeError(msg)

        self.models[model_arch] = model

    def _raise_for_unsupported(self, architectures: list[str]):
        """Raise a `ValueError` explaining why `architectures` cannot be used.

        This method never returns normally. The message depends on the case:

        - at least one architecture is registered: resolution must have
          failed earlier, so the user is pointed at the logs;
        - an architecture is listed in `_PREVIOUSLY_SUPPORTED_MODELS`: the
          last supporting vLLM version is reported;
        - otherwise: the architectures are reported as unsupported, together
          with the currently supported ones.
        """
        all_supported_archs = self.get_supported_archs()

        if any(arch in all_supported_archs for arch in architectures):
            raise ValueError(
                f"Model architectures {architectures} failed "
                "to be inspected. Please check the logs for more details."
            )

        for arch in architectures:
            if arch in _PREVIOUSLY_SUPPORTED_MODELS:
                previous_version = _PREVIOUSLY_SUPPORTED_MODELS[arch]

                raise ValueError(
                    f"Model architecture {arch} was supported in vLLM until "
                    f"v{previous_version}, and is not supported anymore. "
                    "Please use an older version of vLLM if you want to "
                    "use this model architecture."
                )

        raise ValueError(
            f"Model architectures {architectures} are not supported for now. "
            f"Supported architectures: {all_supported_archs}"
        )

    def _try_load_model_cls(self, model_arch: str) -> type[nn.Module] | None:
        """Load the class registered under `model_arch`.

        Returns `None` if the architecture is not registered or if the
        module-level `_try_load_model_cls` returns `None`.
        """
        if model_arch not in self.models:
            return None

        return _try_load_model_cls(model_arch, self.models[model_arch])

    def _try_inspect_model_cls(self, model_arch: str) -> _ModelInfo | None:
        """Inspect the model registered under `model_arch`.

        Returns `None` if the architecture is not registered or if the
        module-level `_try_inspect_model_cls` returns `None`.
        """
        if model_arch not in self.models:
            return None

        return _try_inspect_model_cls(model_arch, self.models[model_arch])

    def _try_resolve_transformers(
        self,
        architecture: str,
        model_config: ModelConfig,
    ) -> str | None:
        """Try to resolve `architecture` to a Transformers backend class name.

        Returns `architecture` unchanged if it is already one of
        `_TRANSFORMERS_BACKEND_MODELS`. Otherwise the model class is looked
        up on the `transformers` package, then via the `AutoModel*` entries
        of the HF config's `auto_map`.

        Returns:
            The value of `model_config._get_transformers_backend_cls()` on
            success, or `None` when the model cannot be found / is not
            backend compatible and `model_config.model_impl` is not
            `"transformers"`.

        Raises:
            ValueError: If `model_config.model_impl == "transformers"` and
                the model cannot be found or is not backend compatible.
        """
        if architecture in _TRANSFORMERS_BACKEND_MODELS:
            return architecture

        auto_map: dict[str, str] = (
            getattr(model_config.hf_config, "auto_map", None) or dict()
        )

        # Make sure that config class is always initialized before model class,
        # otherwise the model class won't be able to access the config class,
        # the expected auto_map should have correct order like:
        # "auto_map": {
        #     "AutoConfig": "<your-repo-name>--<config-name>",
        #     "AutoModel": "<your-repo-name>--<config-name>",
        #     "AutoModelFor<Task>": "<your-repo-name>--<config-name>",
        # },
        for prefix in ("AutoConfig", "AutoModel"):
            for name, module in auto_map.items():
                if name.startswith(prefix):
                    try_get_class_from_dynamic_module(
                        module,
                        model_config.model,
                        revision=model_config.revision,
                        warn_on_fail=False,
                    )

        model_module = getattr(transformers, architecture, None)

        if model_module is None:
            for name, module in auto_map.items():
                if name.startswith("AutoModel"):
                    model_module = try_get_class_from_dynamic_module(
                        module,
                        model_config.model,
                        revision=model_config.revision,
                        warn_on_fail=True,
                    )
                    if model_module is not None:
                        break
            else:
                # Reached only when the loop finished without `break`, i.e.
                # no `AutoModel*` entry yielded a class.
                if model_config.model_impl != "transformers":
                    return None

                raise ValueError(
                    f"Cannot find model module. {architecture!r} is not a "
                    "registered model in the Transformers library (only "
                    "relevant if the model is meant to be in Transformers) "
                    "and 'AutoModel' is not present in the model config's "
                    "'auto_map' (relevant if the model is custom)."
                )

        if not model_module.is_backend_compatible():
            if model_config.model_impl != "transformers":
                return None

            raise ValueError(
                f"The Transformers implementation of {architecture!r} "
                "is not compatible with vLLM."
            )

        return model_config._get_transformers_backend_cls()

    def _normalize_arch(
        self,
        architecture: str,
        model_config: ModelConfig,
    ) -> str:
        """Map `architecture` to the registered base architecture, if any.

        A registered architecture is returned as-is. Otherwise, if the name
        matches one of the architecture defaults, the matched suffix is
        swapped for each known suffix in turn and the first registered result
        is returned. If nothing matches, `architecture` is returned unchanged.
        """
        if architecture in self.models:
            return architecture

        # This may be called in order to resolve runner_type and convert_type
        # in the first place, in which case we consider the default match
        match = try_match_architecture_defaults(
            architecture,
            runner_type=getattr(model_config, "runner_type", None),
            convert_type=getattr(model_config, "convert_type", None),
        )
        if match:
            suffix, _ = match

            # Get the name of the base model to convert
            for repl_suffix, _ in iter_architecture_defaults():
                base_arch = architecture.replace(suffix, repl_suffix)
                if base_arch in self.models:
                    return base_arch

        return architecture

    # NOTE: The annotations of the two helpers below are quoted because `_T`
    # is defined further down in this module, after this class body runs.
    def _try_transformers_fallback(
        self,
        architecture: str,
        model_config: ModelConfig,
        try_get: "Callable[[str], _T | None]",
    ) -> "tuple[_T, str] | None":
        """Resolve `architecture` via the Transformers backend and fetch it.

        Returns `(try_get(arch), arch)` when both the resolution and
        `try_get` succeed, otherwise `None`.
        """
        arch = self._try_resolve_transformers(architecture, model_config)
        if arch is None:
            return None

        resolved = try_get(arch)
        if resolved is None:
            return None

        return (resolved, arch)

    def _resolve_with(
        self,
        architectures: str | list[str],
        model_config: ModelConfig,
        try_get: "Callable[[str], _T | None]",
    ) -> "tuple[_T, str]":
        """Shared resolution logic of `inspect_model_cls`/`resolve_model_cls`.

        `try_get` is called with a candidate architecture name and returns
        the desired object (model info or model class) or `None` on failure.
        The first successful candidate is returned together with its
        architecture name.

        Raises:
            ValueError: If `architectures` is empty or none of the candidates
                can be resolved.
        """
        if isinstance(architectures, str):
            architectures = [architectures]
        if not architectures:
            raise ValueError("No model architectures are specified")

        # Require transformers impl
        if model_config.model_impl == "transformers":
            found = self._try_transformers_fallback(
                architectures[0], model_config, try_get
            )
            if found is not None:
                return found
        elif model_config.model_impl == "terratorch":
            arch = "Terratorch"
            resolved = try_get(arch)
            # On failure, fall through to the remaining strategies instead
            # of handing `None` back to the caller.
            if resolved is not None:
                return (resolved, arch)

        # Fallback to transformers impl (after resolving convert_type)
        if (
            all(arch not in self.models for arch in architectures)
            and model_config.model_impl == "auto"
            and getattr(model_config, "convert_type", "none") == "none"
        ):
            found = self._try_transformers_fallback(
                architectures[0], model_config, try_get
            )
            if found is not None:
                return found

        for arch in architectures:
            normalized_arch = self._normalize_arch(arch, model_config)
            resolved = try_get(normalized_arch)
            if resolved is not None:
                # The originally requested name is reported, not the
                # normalized one.
                return (resolved, arch)

        # Fallback to transformers impl (before resolving runner_type)
        if (
            all(arch not in self.models for arch in architectures)
            and model_config.model_impl == "auto"
        ):
            found = self._try_transformers_fallback(
                architectures[0], model_config, try_get
            )
            if found is not None:
                return found

        return self._raise_for_unsupported(architectures)

    def inspect_model_cls(
        self,
        architectures: str | list[str],
        model_config: ModelConfig,
    ) -> tuple[_ModelInfo, str]:
        """Return `(model_info, architecture)` for the first usable candidate.

        Raises:
            ValueError: If `architectures` is empty or none of them can be
                inspected.
        """
        return self._resolve_with(
            architectures, model_config, self._try_inspect_model_cls
        )

    def resolve_model_cls(
        self,
        architectures: str | list[str],
        model_config: ModelConfig,
    ) -> tuple[type[nn.Module], str]:
        """Return `(model_cls, architecture)` for the first usable candidate.

        Raises:
            ValueError: If `architectures` is empty or none of them can be
                loaded.
        """
        return self._resolve_with(
            architectures, model_config, self._try_load_model_cls
        )

    def is_text_generation_model(
        self,
        architectures: str | list[str],
        model_config: ModelConfig,
    ) -> bool:
        """Return the `is_text_generation_model` flag of the resolved model."""
        model_info, _ = self.inspect_model_cls(architectures, model_config)
        return model_info.is_text_generation_model

    def is_pooling_model(
        self,
        architectures: str | list[str],
        model_config: ModelConfig,
    ) -> bool:
        """Return the `is_pooling_model` flag of the resolved model."""
        model_info, _ = self.inspect_model_cls(architectures, model_config)
        return model_info.is_pooling_model

    def is_cross_encoder_model(
        self,
        architectures: str | list[str],
        model_config: ModelConfig,
    ) -> bool:
        """Return the `supports_cross_encoding` flag of the resolved model."""
        model_info, _ = self.inspect_model_cls(architectures, model_config)
        return model_info.supports_cross_encoding

    def is_multimodal_model(
        self,
        architectures: str | list[str],
        model_config: ModelConfig,
    ) -> bool:
        """Return the `supports_multimodal` flag of the resolved model."""
        model_info, _ = self.inspect_model_cls(architectures, model_config)
        return model_info.supports_multimodal

    def is_multimodal_raw_input_only_model(
        self,
        architectures: str | list[str],
        model_config: ModelConfig,
    ) -> bool:
        """Return the `supports_multimodal_raw_input_only` flag of the model."""
        model_info, _ = self.inspect_model_cls(architectures, model_config)
        return model_info.supports_multimodal_raw_input_only

    def is_pp_supported_model(
        self,
        architectures: str | list[str],
        model_config: ModelConfig,
    ) -> bool:
        """Return the `supports_pp` flag of the resolved model."""
        model_info, _ = self.inspect_model_cls(architectures, model_config)
        return model_info.supports_pp

    def model_has_inner_state(
        self,
        architectures: str | list[str],
        model_config: ModelConfig,
    ) -> bool:
        """Return the `has_inner_state` flag of the resolved model."""
        model_info, _ = self.inspect_model_cls(architectures, model_config)
        return model_info.has_inner_state

    def is_attention_free_model(
        self,
        architectures: str | list[str],
        model_config: ModelConfig,
    ) -> bool:
        """Return the `is_attention_free` flag of the resolved model."""
        model_info, _ = self.inspect_model_cls(architectures, model_config)
        return model_info.is_attention_free

    def is_hybrid_model(
        self,
        architectures: str | list[str],
        model_config: ModelConfig,
    ) -> bool:
        """Return the `is_hybrid` flag of the resolved model."""
        model_info, _ = self.inspect_model_cls(architectures, model_config)
        return model_info.is_hybrid

    def is_noops_model(
        self,
        architectures: str | list[str],
        model_config: ModelConfig,
    ) -> bool:
        """Return the `has_noops` flag of the resolved model."""
        model_info, _ = self.inspect_model_cls(architectures, model_config)
        return model_info.has_noops

    def is_transcription_model(
        self,
        architectures: str | list[str],
        model_config: ModelConfig,
    ) -> bool:
        """Return the `supports_transcription` flag of the resolved model."""
        model_info, _ = self.inspect_model_cls(architectures, model_config)
        return model_info.supports_transcription

    def is_transcription_only_model(
        self,
        architectures: str | list[str],
        model_config: ModelConfig,
    ) -> bool:
        """Return the `supports_transcription_only` flag of the resolved model."""
        model_info, _ = self.inspect_model_cls(architectures, model_config)
        return model_info.supports_transcription_only

1088

1089
1090
1091
1092
1093
1094
1095
1096
1097
# Global registry instance: every architecture listed in `_VLLM_MODELS` is
# registered as a `_LazyRegisteredModel` pointing at the corresponding module
# under `vllm.model_executor.models`.
ModelRegistry = _ModelRegistry(
    models={
        arch_name: _LazyRegisteredModel(
            module_name=f"vllm.model_executor.models.{relative_module}",
            class_name=model_cls_name,
        )
        for arch_name, (relative_module, model_cls_name) in _VLLM_MODELS.items()
    }
)
1098
1099
1100
1101
1102

_T = TypeVar("_T")


def _run_in_subprocess(fn: Callable[[], _T]) -> _T:
    """Run `fn` in a separate process and return its result.

    `fn` and an output path are serialized with `cloudpickle` and passed to
    the child process (started via `_SUBPROCESS_COMMAND`) on stdin. The
    result is then unpickled from the output file, which lives in a
    temporary directory that is removed when the call finishes.

    Args:
        fn: Zero-argument callable to execute in the child process.

    Returns:
        The object unpickled from the child's output file.

    Raises:
        RuntimeError: If the child exits with a non-zero return code; the
            message contains the child's stderr.
    """
    # NOTE: We use a temporary directory instead of a temporary file to avoid
    # issues like https://stackoverflow.com/questions/23212435/permission-denied-to-write-to-my-temporary-file
    with tempfile.TemporaryDirectory() as tempdir:
        output_filepath = os.path.join(tempdir, "registry_output.tmp")

        # `cloudpickle` allows pickling lambda functions directly
        import cloudpickle

        input_bytes = cloudpickle.dumps((fn, output_filepath))

        # cannot use `sys.executable __file__` here because the script
        # contains relative imports
        returned = subprocess.run(
            _SUBPROCESS_COMMAND, input=input_bytes, capture_output=True
        )

        # check if the subprocess is successful
        try:
            returned.check_returncode()
        except subprocess.CalledProcessError as e:
            # wrap raised exception to provide more information
            # Decode leniently: stderr is not guaranteed to be valid UTF-8,
            # and a decoding error here would hide the actual failure.
            stderr_text = returned.stderr.decode(errors="replace")
            raise RuntimeError(
                f"Error raised in subprocess:\n{stderr_text}"
            ) from e

        with open(output_filepath, "rb") as f:
            return pickle.load(f)


def _run() -> None:
    """Entry point used when this module is executed as a script.

    Loads the general plugins, reads a pickled `(fn, output_file)` pair from
    stdin, calls `fn`, and writes the pickled result to `output_file`.
    """
    # Setup plugins
    from vllm.plugins import load_general_plugins

    load_general_plugins()

    # NOTE(review): the stdin payload is unpickled without validation;
    # presumably it is only ever produced by `_run_in_subprocess` in the
    # parent process - confirm before feeding it from anywhere else.
    payload = sys.stdin.buffer.read()
    fn, output_file = pickle.loads(payload)

    result = fn()

    with open(output_file, "wb") as out:
        out.write(pickle.dumps(result))


if __name__ == "__main__":
    _run()