registry.py 43.6 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
4
5
6
"""
Whenever you add an architecture to this page, please also update
`tests/models/registry.py` with example HuggingFace models for it.
"""
7

8
import hashlib
9
import importlib
10
import json
11
import os
12
import pickle
13
14
import subprocess
import sys
15
import tempfile
16
from abc import ABC, abstractmethod
17
from collections.abc import Callable, Set
18
from dataclasses import asdict, dataclass, field
19
from functools import lru_cache
20
from pathlib import Path
21
from typing import TypeVar
22
23

import torch.nn as nn
24
import transformers
25

26
from vllm import envs
27
28
29
30
31
from vllm.config import (
    ModelConfig,
    iter_architecture_defaults,
    try_match_architecture_defaults,
)
32
from vllm.logger import init_logger
33
from vllm.logging_utils import logtime
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
from vllm.transformers_utils.dynamic_module import try_get_class_from_dynamic_module

from .interfaces import (
    has_inner_state,
    has_noops,
    is_attention_free,
    is_hybrid,
    supports_cross_encoding,
    supports_multimodal,
    supports_multimodal_encoder_tp_data,
    supports_multimodal_raw_input_only,
    supports_pp,
    supports_transcription,
)
from .interfaces_base import (
    get_default_pooling_type,
    is_pooling_model,
    is_text_generation_model,
)
53
54
55

logger = init_logger(__name__)

56
57
# Text-generation architectures.
# Each entry maps an architecture name to a `(module, class)` string pair.
# NOTE(review): the module names look relative to this package (see how
# `_LazyRegisteredModel.inspect_model_cls` resolves `<module>.py` next to
# this file) -- confirm where the registry is populated.
_TEXT_GENERATION_MODELS = {
    # [Decoder-only]
    "ApertusForCausalLM": ("apertus", "ApertusForCausalLM"),
    "AquilaModel": ("llama", "LlamaForCausalLM"),
    "AquilaForCausalLM": ("llama", "LlamaForCausalLM"),  # AquilaChat2
    "ArceeForCausalLM": ("arcee", "ArceeForCausalLM"),
    "ArcticForCausalLM": ("arctic", "ArcticForCausalLM"),
    # baichuan-7b, upper case 'C' in the class name
    "BaiChuanForCausalLM": ("baichuan", "BaiChuanForCausalLM"),
    # baichuan-13b, lower case 'c' in the class name
    "BaichuanForCausalLM": ("baichuan", "BaichuanForCausalLM"),
    "BailingMoeForCausalLM": ("bailing_moe", "BailingMoeForCausalLM"),
    "BailingMoeV2ForCausalLM": ("bailing_moe", "BailingMoeV2ForCausalLM"),
    "BambaForCausalLM": ("bamba", "BambaForCausalLM"),
    "BloomForCausalLM": ("bloom", "BloomForCausalLM"),
    "ChatGLMModel": ("chatglm", "ChatGLMForCausalLM"),
    "ChatGLMForConditionalGeneration": ("chatglm", "ChatGLMForCausalLM"),
    "CohereForCausalLM": ("commandr", "CohereForCausalLM"),
    "Cohere2ForCausalLM": ("commandr", "CohereForCausalLM"),
    "CwmForCausalLM": ("llama", "LlamaForCausalLM"),
    "DbrxForCausalLM": ("dbrx", "DbrxForCausalLM"),
    "DeciLMForCausalLM": ("nemotron_nas", "DeciLMForCausalLM"),
    "DeepseekForCausalLM": ("deepseek", "DeepseekForCausalLM"),
    "DeepseekV2ForCausalLM": ("deepseek_v2", "DeepseekV2ForCausalLM"),
    "DeepseekV3ForCausalLM": ("deepseek_v2", "DeepseekV3ForCausalLM"),
    "DeepseekV32ForCausalLM": ("deepseek_v2", "DeepseekV3ForCausalLM"),
    "Dots1ForCausalLM": ("dots1", "Dots1ForCausalLM"),
    "Ernie4_5ForCausalLM": ("ernie45", "Ernie4_5ForCausalLM"),
    "Ernie4_5_MoeForCausalLM": ("ernie45_moe", "Ernie4_5_MoeForCausalLM"),
    "ExaoneForCausalLM": ("exaone", "ExaoneForCausalLM"),
    "Exaone4ForCausalLM": ("exaone4", "Exaone4ForCausalLM"),
    "Fairseq2LlamaForCausalLM": ("fairseq2_llama", "Fairseq2LlamaForCausalLM"),
    "FalconForCausalLM": ("falcon", "FalconForCausalLM"),
    "FalconMambaForCausalLM": ("mamba", "MambaForCausalLM"),
    "FalconH1ForCausalLM": ("falcon_h1", "FalconH1ForCausalLM"),
    "FlexOlmoForCausalLM": ("flex_olmo", "FlexOlmoForCausalLM"),
    "GemmaForCausalLM": ("gemma", "GemmaForCausalLM"),
    "Gemma2ForCausalLM": ("gemma2", "Gemma2ForCausalLM"),
    "Gemma3ForCausalLM": ("gemma3", "Gemma3ForCausalLM"),
    "Gemma3nForCausalLM": ("gemma3n", "Gemma3nForCausalLM"),
    "Qwen3NextForCausalLM": ("qwen3_next", "Qwen3NextForCausalLM"),
    "GlmForCausalLM": ("glm", "GlmForCausalLM"),
    "Glm4ForCausalLM": ("glm4", "Glm4ForCausalLM"),
    "Glm4MoeForCausalLM": ("glm4_moe", "Glm4MoeForCausalLM"),
    "GptOssForCausalLM": ("gpt_oss", "GptOssForCausalLM"),
    "GPT2LMHeadModel": ("gpt2", "GPT2LMHeadModel"),
    "GPTBigCodeForCausalLM": ("gpt_bigcode", "GPTBigCodeForCausalLM"),
    "GPTJForCausalLM": ("gpt_j", "GPTJForCausalLM"),
    "GPTNeoXForCausalLM": ("gpt_neox", "GPTNeoXForCausalLM"),
    "GraniteForCausalLM": ("granite", "GraniteForCausalLM"),
    "GraniteMoeForCausalLM": ("granitemoe", "GraniteMoeForCausalLM"),
    "GraniteMoeHybridForCausalLM": ("granitemoehybrid", "GraniteMoeHybridForCausalLM"),  # noqa: E501
    "GraniteMoeSharedForCausalLM": ("granitemoeshared", "GraniteMoeSharedForCausalLM"),  # noqa: E501
    "GritLM": ("gritlm", "GritLM"),
    "Grok1ModelForCausalLM": ("grok1", "Grok1ForCausalLM"),
    "HunYuanMoEV1ForCausalLM": ("hunyuan_v1", "HunYuanMoEV1ForCausalLM"),
    "HunYuanDenseV1ForCausalLM": ("hunyuan_v1", "HunYuanDenseV1ForCausalLM"),
    "HCXVisionForCausalLM": ("hyperclovax_vision", "HCXVisionForCausalLM"),
    "InternLMForCausalLM": ("llama", "LlamaForCausalLM"),
    "InternLM2ForCausalLM": ("internlm2", "InternLM2ForCausalLM"),
    "InternLM2VEForCausalLM": ("internlm2_ve", "InternLM2VEForCausalLM"),
    "InternLM3ForCausalLM": ("llama", "LlamaForCausalLM"),
    "JAISLMHeadModel": ("jais", "JAISLMHeadModel"),
    "JambaForCausalLM": ("jamba", "JambaForCausalLM"),
    "Lfm2ForCausalLM": ("lfm2", "Lfm2ForCausalLM"),
    "Lfm2MoeForCausalLM": ("lfm2_moe", "Lfm2MoeForCausalLM"),
    "LlamaForCausalLM": ("llama", "LlamaForCausalLM"),
    "Llama4ForCausalLM": ("llama4", "Llama4ForCausalLM"),
    # For decapoda-research/llama-*
    "LLaMAForCausalLM": ("llama", "LlamaForCausalLM"),
    "LongcatFlashForCausalLM": ("longcat_flash", "LongcatFlashForCausalLM"),
    "MambaForCausalLM": ("mamba", "MambaForCausalLM"),
    "Mamba2ForCausalLM": ("mamba2", "Mamba2ForCausalLM"),
    "MiniCPMForCausalLM": ("minicpm", "MiniCPMForCausalLM"),
    "MiniCPM3ForCausalLM": ("minicpm3", "MiniCPM3ForCausalLM"),
    "MiniMaxForCausalLM": ("minimax_text_01", "MiniMaxText01ForCausalLM"),
    "MiniMaxText01ForCausalLM": ("minimax_text_01", "MiniMaxText01ForCausalLM"),
    "MiniMaxM1ForCausalLM": ("minimax_text_01", "MiniMaxText01ForCausalLM"),
    "MistralForCausalLM": ("llama", "LlamaForCausalLM"),
    "MixtralForCausalLM": ("mixtral", "MixtralForCausalLM"),
    # transformers's mpt class has lower case
    "MptForCausalLM": ("mpt", "MPTForCausalLM"),
    "MPTForCausalLM": ("mpt", "MPTForCausalLM"),
    "MiMoForCausalLM": ("mimo", "MiMoForCausalLM"),
    "NemotronForCausalLM": ("nemotron", "NemotronForCausalLM"),
    "NemotronHForCausalLM": ("nemotron_h", "NemotronHForCausalLM"),
    "OlmoForCausalLM": ("olmo", "OlmoForCausalLM"),
    "Olmo2ForCausalLM": ("olmo2", "Olmo2ForCausalLM"),
    "Olmo3ForCausalLM": ("olmo2", "Olmo2ForCausalLM"),
    "OlmoeForCausalLM": ("olmoe", "OlmoeForCausalLM"),
    "OPTForCausalLM": ("opt", "OPTForCausalLM"),
    "OrionForCausalLM": ("orion", "OrionForCausalLM"),
    "PersimmonForCausalLM": ("persimmon", "PersimmonForCausalLM"),
    "PhiForCausalLM": ("phi", "PhiForCausalLM"),
    "Phi3ForCausalLM": ("phi3", "Phi3ForCausalLM"),
    "PhiMoEForCausalLM": ("phimoe", "PhiMoEForCausalLM"),
    "Plamo2ForCausalLM": ("plamo2", "Plamo2ForCausalLM"),
    "QWenLMHeadModel": ("qwen", "QWenLMHeadModel"),
    "Qwen2ForCausalLM": ("qwen2", "Qwen2ForCausalLM"),
    "Qwen2MoeForCausalLM": ("qwen2_moe", "Qwen2MoeForCausalLM"),
    "Qwen3ForCausalLM": ("qwen3", "Qwen3ForCausalLM"),
    "Qwen3MoeForCausalLM": ("qwen3_moe", "Qwen3MoeForCausalLM"),
    "RWForCausalLM": ("falcon", "FalconForCausalLM"),
    "SeedOssForCausalLM": ("seed_oss", "SeedOssForCausalLM"),
    "Step3TextForCausalLM": ("step3_text", "Step3TextForCausalLM"),
    "StableLMEpochForCausalLM": ("stablelm", "StablelmForCausalLM"),
    "StableLmForCausalLM": ("stablelm", "StablelmForCausalLM"),
    "Starcoder2ForCausalLM": ("starcoder2", "Starcoder2ForCausalLM"),
    "SolarForCausalLM": ("solar", "SolarForCausalLM"),
    "TeleChat2ForCausalLM": ("telechat2", "TeleChat2ForCausalLM"),
    "TeleFLMForCausalLM": ("teleflm", "TeleFLMForCausalLM"),
    "XverseForCausalLM": ("llama", "LlamaForCausalLM"),
    "Zamba2ForCausalLM": ("zamba2", "Zamba2ForCausalLM"),
}

# Embedding / pooling architectures, as `(module, class)` string pairs.
# Several keys also appear in `_TEXT_GENERATION_MODELS`; because this dict is
# spread later into `_VLLM_MODELS`, the values given here are the ones that
# end up in the merged mapping for those keys.
_EMBEDDING_MODELS = {
    # [Text-only]
    "BertModel": ("bert", "BertEmbeddingModel"),
    "BertSpladeSparseEmbeddingModel": ("bert", "BertSpladeSparseEmbeddingModel"),
    "DeciLMForCausalLM": ("nemotron_nas", "DeciLMForCausalLM"),
    "Gemma2Model": ("gemma2", "Gemma2ForCausalLM"),
    "Gemma3TextModel": ("gemma3", "Gemma3Model"),
    "GlmForCausalLM": ("glm", "GlmForCausalLM"),
    "GPT2ForSequenceClassification": ("gpt2", "GPT2ForSequenceClassification"),
    "GritLM": ("gritlm", "GritLM"),
    "GteModel": ("bert_with_rope", "SnowflakeGteNewModel"),
    "GteNewModel": ("bert_with_rope", "GteNewModel"),
    "InternLM2ForRewardModel": ("internlm2", "InternLM2ForRewardModel"),
    "JambaForSequenceClassification": ("jamba", "JambaForSequenceClassification"),  # noqa: E501
    "LlamaModel": ("llama", "LlamaForCausalLM"),
    **{
        # Multiple models share the same architecture, so we include them all
        k: (mod, arch)
        for k, (mod, arch) in _TEXT_GENERATION_MODELS.items()
        if arch == "LlamaForCausalLM"
    },
    "MistralModel": ("llama", "LlamaForCausalLM"),
    "ModernBertModel": ("modernbert", "ModernBertModel"),
    "NomicBertModel": ("bert_with_rope", "NomicBertModel"),
    "Phi3ForCausalLM": ("phi3", "Phi3ForCausalLM"),
    "Qwen2Model": ("qwen2", "Qwen2ForCausalLM"),
    "Qwen2ForCausalLM": ("qwen2", "Qwen2ForCausalLM"),
    "Qwen2ForRewardModel": ("qwen2_rm", "Qwen2ForRewardModel"),
    "Qwen2ForProcessRewardModel": ("qwen2_rm", "Qwen2ForProcessRewardModel"),
    "RobertaForMaskedLM": ("roberta", "RobertaEmbeddingModel"),
    "RobertaModel": ("roberta", "RobertaEmbeddingModel"),
    "TeleChat2ForCausalLM": ("telechat2", "TeleChat2ForCausalLM"),
    "XLMRobertaModel": ("roberta", "RobertaEmbeddingModel"),
    # [Multimodal]
    "CLIPModel": ("clip", "CLIPEmbeddingModel"),
    "LlavaNextForConditionalGeneration": (
        "llava_next",
        "LlavaNextForConditionalGeneration",
    ),
    "Phi3VForCausalLM": ("phi3v", "Phi3VForCausalLM"),
    "Qwen2VLForConditionalGeneration": ("qwen2_vl", "Qwen2VLForConditionalGeneration"),  # noqa: E501
    # Technically Terratorch models work on images, both in
    # input and output. I am adding it here because it piggy-backs on embedding
    # models for the time being.
    "PrithviGeoSpatialMAE": ("terratorch", "Terratorch"),
    "Terratorch": ("terratorch", "Terratorch"),
}

219
220
# Cross-encoder / classification architectures, as `(module, class)` pairs.
_CROSS_ENCODER_MODELS = {
    "BertForSequenceClassification": ("bert", "BertForSequenceClassification"),
    "BertForTokenClassification": ("bert", "BertForTokenClassification"),
    "GteNewForSequenceClassification": (
        "bert_with_rope",
        "GteNewForSequenceClassification",
    ),
    "ModernBertForSequenceClassification": (
        "modernbert",
        "ModernBertForSequenceClassification",
    ),
    "ModernBertForTokenClassification": (
        "modernbert",
        "ModernBertForTokenClassification",
    ),
    "RobertaForSequenceClassification": ("roberta", "RobertaForSequenceClassification"),
    "XLMRobertaForSequenceClassification": (
        "roberta",
        "RobertaForSequenceClassification",
    ),
    # [Auto-converted (see adapters.py)]
    "JinaVLForRanking": ("jina_vl", "JinaVLForSequenceClassification"),
}

243
# Multimodal architectures, as `(module, class)` string pairs.
_MULTIMODAL_MODELS = {
    # [Decoder-only]
    "AriaForConditionalGeneration": ("aria", "AriaForConditionalGeneration"),
    "AyaVisionForConditionalGeneration": (
        "aya_vision",
        "AyaVisionForConditionalGeneration",
    ),
    "BeeForConditionalGeneration": ("bee", "BeeForConditionalGeneration"),
    "Blip2ForConditionalGeneration": ("blip2", "Blip2ForConditionalGeneration"),
    "ChameleonForConditionalGeneration": (
        "chameleon",
        "ChameleonForConditionalGeneration",
    ),
    "Cohere2VisionForConditionalGeneration": (
        "cohere2_vision",
        "Cohere2VisionForConditionalGeneration",
    ),
    "DeepseekVLV2ForCausalLM": ("deepseek_vl2", "DeepseekVLV2ForCausalLM"),
    "DeepseekOCRForCausalLM": ("deepseek_ocr", "DeepseekOCRForCausalLM"),
    "DotsOCRForCausalLM": ("dots_ocr", "DotsOCRForCausalLM"),
    "Ernie4_5_VLMoeForConditionalGeneration": (
        "ernie45_vl",
        "Ernie4_5_VLMoeForConditionalGeneration",
    ),
    "FuyuForCausalLM": ("fuyu", "FuyuForCausalLM"),
    "Gemma3ForConditionalGeneration": ("gemma3_mm", "Gemma3ForConditionalGeneration"),  # noqa: E501
    "Gemma3nForConditionalGeneration": (
        "gemma3n_mm",
        "Gemma3nForConditionalGeneration",
    ),
    "GLM4VForCausalLM": ("glm4v", "GLM4VForCausalLM"),
    "Glm4vForConditionalGeneration": ("glm4_1v", "Glm4vForConditionalGeneration"),  # noqa: E501
    "Glm4vMoeForConditionalGeneration": ("glm4_1v", "Glm4vMoeForConditionalGeneration"),  # noqa: E501
    "GraniteSpeechForConditionalGeneration": (
        "granite_speech",
        "GraniteSpeechForConditionalGeneration",
    ),
    "H2OVLChatModel": ("h2ovl", "H2OVLChatModel"),
    "InternVLChatModel": ("internvl", "InternVLChatModel"),
    "NemotronH_Nano_VL_V2": ("nano_nemotron_vl", "NemotronH_Nano_VL_V2"),
    "InternS1ForConditionalGeneration": (
        "interns1",
        "InternS1ForConditionalGeneration",
    ),
    "InternVLForConditionalGeneration": (
        "interns1",
        "InternS1ForConditionalGeneration",
    ),
    "Idefics3ForConditionalGeneration": (
        "idefics3",
        "Idefics3ForConditionalGeneration",
    ),
    "SmolVLMForConditionalGeneration": ("smolvlm", "SmolVLMForConditionalGeneration"),  # noqa: E501
    "KeyeForConditionalGeneration": ("keye", "KeyeForConditionalGeneration"),
    "KeyeVL1_5ForConditionalGeneration": (
        "keye_vl1_5",
        "KeyeVL1_5ForConditionalGeneration",
    ),
    "RForConditionalGeneration": ("rvl", "RForConditionalGeneration"),
    "KimiVLForConditionalGeneration": ("kimi_vl", "KimiVLForConditionalGeneration"),  # noqa: E501
    "LightOnOCRForConditionalGeneration": (
        "lightonocr",
        "LightOnOCRForConditionalGeneration",
    ),
    "Llama_Nemotron_Nano_VL": ("nemotron_vl", "LlamaNemotronVLChatModel"),
    "Llama4ForConditionalGeneration": ("mllama4", "Llama4ForConditionalGeneration"),  # noqa: E501
    "LlavaForConditionalGeneration": ("llava", "LlavaForConditionalGeneration"),
    "LlavaNextForConditionalGeneration": (
        "llava_next",
        "LlavaNextForConditionalGeneration",
    ),
    "LlavaNextVideoForConditionalGeneration": (
        "llava_next_video",
        "LlavaNextVideoForConditionalGeneration",
    ),
    "LlavaOnevisionForConditionalGeneration": (
        "llava_onevision",
        "LlavaOnevisionForConditionalGeneration",
    ),
    "MantisForConditionalGeneration": ("llava", "MantisForConditionalGeneration"),  # noqa: E501
    "MiDashengLMModel": ("midashenglm", "MiDashengLMModel"),
    "MiniMaxVL01ForConditionalGeneration": (
        "minimax_vl_01",
        "MiniMaxVL01ForConditionalGeneration",
    ),
    "MiniCPMO": ("minicpmo", "MiniCPMO"),
    "MiniCPMV": ("minicpmv", "MiniCPMV"),
    "Mistral3ForConditionalGeneration": (
        "mistral3",
        "Mistral3ForConditionalGeneration",
    ),
    "MolmoForCausalLM": ("molmo", "MolmoForCausalLM"),
    "NVLM_D": ("nvlm_d", "NVLM_D_Model"),
    "Ovis": ("ovis", "Ovis"),
    "Ovis2_5": ("ovis2_5", "Ovis2_5"),
    "PaliGemmaForConditionalGeneration": (
        "paligemma",
        "PaliGemmaForConditionalGeneration",
    ),
    "Phi3VForCausalLM": ("phi3v", "Phi3VForCausalLM"),
    "Phi4MMForCausalLM": ("phi4mm", "Phi4MMForCausalLM"),
    "Phi4MultimodalForCausalLM": ("phi4_multimodal", "Phi4MultimodalForCausalLM"),  # noqa: E501
    "PixtralForConditionalGeneration": ("pixtral", "PixtralForConditionalGeneration"),  # noqa: E501
    "QwenVLForConditionalGeneration": ("qwen_vl", "QwenVLForConditionalGeneration"),  # noqa: E501
    "Qwen2VLForConditionalGeneration": ("qwen2_vl", "Qwen2VLForConditionalGeneration"),  # noqa: E501
    "Qwen2_5_VLForConditionalGeneration": (
        "qwen2_5_vl",
        "Qwen2_5_VLForConditionalGeneration",
    ),
    "Qwen2AudioForConditionalGeneration": (
        "qwen2_audio",
        "Qwen2AudioForConditionalGeneration",
    ),
    "Qwen2_5OmniModel": (
        "qwen2_5_omni_thinker",
        "Qwen2_5OmniThinkerForConditionalGeneration",
    ),
    "Qwen2_5OmniForConditionalGeneration": (
        "qwen2_5_omni_thinker",
        "Qwen2_5OmniThinkerForConditionalGeneration",
    ),
    "Qwen3OmniMoeForConditionalGeneration": (
        "qwen3_omni_moe_thinker",
        "Qwen3OmniMoeThinkerForConditionalGeneration",
    ),
    "Qwen3VLForConditionalGeneration": ("qwen3_vl", "Qwen3VLForConditionalGeneration"),  # noqa: E501
    "Qwen3VLMoeForConditionalGeneration": (
        "qwen3_vl_moe",
        "Qwen3VLMoeForConditionalGeneration",
    ),
    "SkyworkR1VChatModel": ("skyworkr1v", "SkyworkR1VChatModel"),
    "Step3VLForConditionalGeneration": ("step3_vl", "Step3VLForConditionalGeneration"),  # noqa: E501
    "TarsierForConditionalGeneration": ("tarsier", "TarsierForConditionalGeneration"),  # noqa: E501
    "Tarsier2ForConditionalGeneration": (
        "qwen2_vl",
        "Tarsier2ForConditionalGeneration",
    ),
    "UltravoxModel": ("ultravox", "UltravoxModel"),
    "VoxtralForConditionalGeneration": ("voxtral", "VoxtralForConditionalGeneration"),  # noqa: E501
    # [Encoder-decoder]
    "WhisperForConditionalGeneration": ("whisper", "WhisperForConditionalGeneration"),  # noqa: E501
}
385
386

# Draft / speculator architectures, as `(module, class)` string pairs.
_SPECULATIVE_DECODING_MODELS = {
    "MiMoMTPModel": ("mimo_mtp", "MiMoMTP"),
    "EagleLlamaForCausalLM": ("llama_eagle", "EagleLlamaForCausalLM"),
    "EagleLlama4ForCausalLM": ("llama4_eagle", "EagleLlama4ForCausalLM"),
    "EagleMiniCPMForCausalLM": ("minicpm_eagle", "EagleMiniCPMForCausalLM"),
    "Eagle3LlamaForCausalLM": ("llama_eagle3", "Eagle3LlamaForCausalLM"),
    "LlamaForCausalLMEagle3": ("llama_eagle3", "Eagle3LlamaForCausalLM"),
    "Eagle3Qwen2_5vlForCausalLM": ("llama_eagle3", "Eagle3LlamaForCausalLM"),
    "EagleDeepSeekMTPModel": ("deepseek_eagle", "EagleDeepseekV3ForCausalLM"),
    "DeepSeekMTPModel": ("deepseek_mtp", "DeepSeekMTP"),
    "ErnieMTPModel": ("ernie_mtp", "ErnieMTP"),
    "LongCatFlashMTPModel": ("longcat_flash_mtp", "LongCatFlashMTP"),
    "Glm4MoeMTPModel": ("glm4_moe_mtp", "Glm4MoeMTP"),
    "MedusaModel": ("medusa", "Medusa"),
    "Qwen3NextMTP": ("qwen3_next_mtp", "Qwen3NextMTP"),
    # Temporarily disabled.
    # TODO(woosuk): Re-enable this once the MLP Speculator is supported in V1.
    # "MLPSpeculatorPreTrainedModel": ("mlp_speculator", "MLPSpeculator"),
}
405

406
# Architectures mapped onto classes in the `transformers` module rather than
# onto a dedicated per-architecture module.
_TRANSFORMERS_SUPPORTED_MODELS = {
    # Text generation models
    "SmolLM3ForCausalLM": ("transformers", "TransformersForCausalLM"),
    # Multimodal models
    "Emu3ForConditionalGeneration": (
        "transformers",
        "TransformersMultiModalForCausalLM",
    ),
}

# Classes of the `transformers` module, registered under their own names.
# Every key equals the class name in its `(module, class)` value.
_TRANSFORMERS_BACKEND_MODELS = {
    # Text generation models
    "TransformersForCausalLM": ("transformers", "TransformersForCausalLM"),
    "TransformersMoEForCausalLM": ("transformers", "TransformersMoEForCausalLM"),
    # Multimodal models
    "TransformersMultiModalForCausalLM": (
        "transformers",
        "TransformersMultiModalForCausalLM",
    ),
    "TransformersMultiModalMoEForCausalLM": (
        "transformers",
        "TransformersMultiModalMoEForCausalLM",
    ),
    # Embedding models
    "TransformersEmbeddingModel": ("transformers", "TransformersEmbeddingModel"),
    "TransformersMoEEmbeddingModel": ("transformers", "TransformersMoEEmbeddingModel"),
    "TransformersMultiModalEmbeddingModel": (
        "transformers",
        "TransformersMultiModalEmbeddingModel",
    ),
    # Sequence classification models
    "TransformersForSequenceClassification": (
        "transformers",
        "TransformersForSequenceClassification",
    ),
    "TransformersMoEForSequenceClassification": (
        "transformers",
        "TransformersMoEForSequenceClassification",
    ),
    "TransformersMultiModalForSequenceClassification": (
        "transformers",
        "TransformersMultiModalForSequenceClassification",
    ),
}
450

451
# Union of all per-category tables. The dicts are spread in order, so when a
# key occurs in more than one table the value from the later table is kept.
_VLLM_MODELS = {
    **_TEXT_GENERATION_MODELS,
    **_EMBEDDING_MODELS,
    **_CROSS_ENCODER_MODELS,
    **_MULTIMODAL_MODELS,
    **_SPECULATIVE_DECODING_MODELS,
    **_TRANSFORMERS_SUPPORTED_MODELS,
    **_TRANSFORMERS_BACKEND_MODELS,
}

461
462
463
464
# This variable is used as the args for subprocess.run(). It can be
# modified to alter the args if needed, e.g. when things are packed
# together in par format, sys.executable might not be the target we
# want to run.
_SUBPROCESS_COMMAND = [sys.executable, "-m", "vllm.model_executor.models.registry"]
466

467
# Architectures that are no longer registered, mapped to a version string.
# NOTE(review): presumably the last vLLM release that still supported the
# architecture -- confirm against the code that reads this table.
_PREVIOUSLY_SUPPORTED_MODELS = {
    "MotifForCausalLM": "0.10.2",
    "Phi3SmallForCausalLM": "0.9.2",
    "Phi4FlashForCausalLM": "0.10.2",
    # encoder-decoder models except whisper
    # have been removed for V0 deprecation.
    "BartModel": "0.10.2",
    "BartForConditionalGeneration": "0.10.2",
    "DonutForConditionalGeneration": "0.10.2",
    "Florence2ForConditionalGeneration": "0.10.2",
    "MBartForConditionalGeneration": "0.10.2",
    "MllamaForConditionalGeneration": "0.10.2",
}
480

481

482
483
@dataclass(frozen=True)
class _ModelInfo:
    """
    Immutable record of the properties reported for a model class.

    Apart from `architecture`, every field holds the result of the
    same-named predicate/helper imported from `.interfaces` or
    `.interfaces_base`, evaluated on the model class in `from_model_cls`.
    """

    # `__name__` of the model class this record was built from.
    architecture: str
    is_text_generation_model: bool
    is_pooling_model: bool
    default_pooling_type: str
    supports_cross_encoding: bool
    supports_multimodal: bool
    supports_multimodal_raw_input_only: bool
    supports_multimodal_encoder_tp_data: bool
    supports_pp: bool
    has_inner_state: bool
    is_attention_free: bool
    is_hybrid: bool
    has_noops: bool
    supports_transcription: bool
    # True only when the class supports transcription AND its
    # `supports_transcription_only` attribute is truthy.
    supports_transcription_only: bool

    @staticmethod
    def from_model_cls(model: type[nn.Module]) -> "_ModelInfo":
        """
        Build a `_ModelInfo` by querying each interface helper on `model`.

        Args:
            model: The model class (not an instance) to inspect.

        Returns:
            A frozen `_ModelInfo` whose `architecture` is `model.__name__`.
        """
        return _ModelInfo(
            architecture=model.__name__,
            is_text_generation_model=is_text_generation_model(model),
            is_pooling_model=is_pooling_model(model),
            default_pooling_type=get_default_pooling_type(model),
            supports_cross_encoding=supports_cross_encoding(model),
            supports_multimodal=supports_multimodal(model),
            supports_multimodal_raw_input_only=supports_multimodal_raw_input_only(
                model
            ),
            supports_multimodal_encoder_tp_data=supports_multimodal_encoder_tp_data(
                model
            ),
            supports_pp=supports_pp(model),
            has_inner_state=has_inner_state(model),
            is_attention_free=is_attention_free(model),
            is_hybrid=is_hybrid(model),
            supports_transcription=supports_transcription(model),
            # `model.supports_transcription_only` is only read when the class
            # supports transcription (short-circuit `and`).
            supports_transcription_only=(
                supports_transcription(model) and model.supports_transcription_only
            ),
            has_noops=has_noops(model),
        )
525
526


527
528
529
530
class _BaseRegisteredModel(ABC):
    """
    Abstract interface for a registry entry.

    Subclasses must provide both a way to describe the model class
    (`inspect_model_cls`) and a way to obtain it (`load_model_cls`).
    """

    @abstractmethod
    def inspect_model_cls(self) -> _ModelInfo:
        """Return the `_ModelInfo` describing the model class."""
        raise NotImplementedError

    @abstractmethod
    def load_model_cls(self) -> type[nn.Module]:
        """Return the model class itself."""
        raise NotImplementedError
535
536


537
538
539
540
541
542
543
@dataclass(frozen=True)
class _RegisteredModel(_BaseRegisteredModel):
    """
    Represents a model that has already been imported in the main process.
    """

    # Model info computed once, when the entry is created.
    interfaces: _ModelInfo
    model_cls: type[nn.Module]

    @staticmethod
    def from_model_cls(model_cls: type[nn.Module]) -> "_RegisteredModel":
        """Wrap `model_cls`, computing its `_ModelInfo` immediately."""
        return _RegisteredModel(
            interfaces=_ModelInfo.from_model_cls(model_cls),
            model_cls=model_cls,
        )

    def inspect_model_cls(self) -> _ModelInfo:
        """Return the `_ModelInfo` stored at construction time."""
        return self.interfaces

    def load_model_cls(self) -> type[nn.Module]:
        """Return the stored model class."""
        return self.model_cls


@dataclass(frozen=True)
class _LazyRegisteredModel(_BaseRegisteredModel):
    """
    Represents a model that has not been imported in the main process.

    The model is identified by `module_name` and `class_name`; the module is
    imported only when `load_model_cls` runs. `inspect_model_cls` keeps a
    JSON cache of the resulting `_ModelInfo` on disk, keyed by a hash of the
    module's source file.
    """

    module_name: str
    class_name: str

    @staticmethod
    def _get_cache_dir() -> Path:
        """Return the directory holding cached model-info JSON files."""
        return Path(envs.VLLM_CACHE_ROOT) / "modelinfos"

    def _get_cache_filename(self) -> str:
        """Return `<module>-<class>.json`, with dots replaced by dashes."""
        cls_name = f"{self.module_name}-{self.class_name}".replace(".", "-")
        return f"{cls_name}.json"

    def _load_modelinfo_from_cache(self, module_hash: str) -> _ModelInfo | None:
        """
        Load the cached `_ModelInfo` if it exists and matches `module_hash`.

        Returns:
            The cached `_ModelInfo`, or `None` when the cache file is
            missing, was written for a different hash, or cannot be
            read/parsed. This method never raises; every failure is a miss.
        """
        try:
            modelinfo_path = self._get_cache_dir() / self._get_cache_filename()
            with open(modelinfo_path, encoding="utf-8") as file:
                mi_dict = json.load(file)

            if mi_dict["hash"] != module_hash:
                logger.debug(
                    "Cached model info file for class %s.%s is stale",
                    self.module_name,
                    self.class_name,
                )
                return None

            # file not changed, use cached _ModelInfo properties
            return _ModelInfo(**mi_dict["modelinfo"])
        except FileNotFoundError:
            logger.debug(
                "Cached model info file for class %s.%s not found",
                self.module_name,
                self.class_name,
            )
            return None
        except Exception:
            # Best effort: a corrupt or incompatible cache entry (bad JSON,
            # missing keys, fields not matching `_ModelInfo`) is treated as
            # a miss. Attach the exception so the cause shows in debug logs.
            logger.debug(
                "Cached model info for class %s.%s error. ",
                self.module_name,
                self.class_name,
                exc_info=True,
            )
            return None

    def _save_modelinfo_to_cache(self, mi: _ModelInfo, module_hash: str) -> None:
        """
        Write `mi` and `module_hash` as JSON into the cache directory.

        Failures are logged and swallowed so that a cache write problem
        never breaks model inspection.
        """
        # Imported locally, as in the original code.
        from vllm.model_executor.model_loader.weight_utils import atomic_writer

        try:
            modelinfo_dict = {
                "hash": module_hash,
                "modelinfo": asdict(mi),
            }
            cache_dir = self._get_cache_dir()
            cache_dir.mkdir(parents=True, exist_ok=True)
            modelinfo_path = cache_dir / self._get_cache_filename()
            with atomic_writer(modelinfo_path, encoding="utf-8") as f:
                json.dump(modelinfo_dict, f, indent=2)
        except Exception:
            logger.exception("Error saving model info cache.")

    @logtime(logger=logger, msg="Registry inspect model class")
    def inspect_model_cls(self) -> _ModelInfo:
        """
        Return the `_ModelInfo` for this model, using the disk cache if valid.

        On a cache miss the model class is loaded and inspected through
        `_run_in_subprocess`, and the result is written back to the cache
        when a source hash is available.
        """
        # NOTE(review): only the last component of `module_name` is used and
        # it is resolved next to this file, so a module living elsewhere
        # either has no hash (no caching) or is hashed against a same-named
        # file in this directory -- confirm this is intended.
        model_path = Path(__file__).parent / f"{self.module_name.split('.')[-1]}.py"
        module_hash = None

        if model_path.exists():
            with open(model_path, "rb") as f:
                # md5 is used only as a change detector, not for security.
                module_hash = hashlib.md5(f.read(), usedforsecurity=False).hexdigest()

            mi = self._load_modelinfo_from_cache(module_hash)
            if mi is not None:
                logger.debug(
                    "Loaded model info for class %s.%s from cache",
                    self.module_name,
                    self.class_name,
                )
                return mi

            logger.debug(
                "Cache model info for class %s.%s miss. Loading model instead.",
                self.module_name,
                self.class_name,
            )

        # Performed in another process to avoid initializing CUDA
        mi = _run_in_subprocess(
            lambda: _ModelInfo.from_model_cls(self.load_model_cls())
        )
        logger.debug(
            "Loaded model info for class %s.%s", self.module_name, self.class_name
        )

        # save cache file
        if module_hash is not None:
            self._save_modelinfo_to_cache(mi, module_hash)

        return mi

    def load_model_cls(self) -> type[nn.Module]:
        """Import `module_name` and return its `class_name` attribute."""
        mod = importlib.import_module(self.module_name)
        return getattr(mod, self.class_name)


@lru_cache(maxsize=128)
def _try_load_model_cls(
    model_arch: str,
    model: _BaseRegisteredModel,
) -> type[nn.Module] | None:
    """
    Load the model class for `model_arch`, returning `None` on failure.

    The platform check `current_platform.verify_model_arch` runs first and
    is outside the `try`, so anything it raises propagates to the caller.
    Exceptions from `model.load_model_cls()` are logged and turned into
    `None`.

    Results are memoized per `(model_arch, model)`; note that a `None`
    result from a failed load is memoized as well.
    """
    # Imported locally, as in the original code.
    from vllm.platforms import current_platform

    current_platform.verify_model_arch(model_arch)
    try:
        return model.load_model_cls()
    except Exception:
        logger.exception("Error in loading model architecture '%s'", model_arch)
        return None
682
683


684
685
686
687
@lru_cache(maxsize=128)
def _try_inspect_model_cls(
    model_arch: str,
    model: _BaseRegisteredModel,
) -> _ModelInfo | None:
    """
    Inspect the model registered for `model_arch`, returning `None` on failure.

    Any exception from `model.inspect_model_cls()` is logged with its
    traceback and turned into `None`. Results are memoized per
    `(model_arch, model)`; a `None` result is memoized as well.
    """
    try:
        return model.inspect_model_cls()
    except Exception:
        logger.exception("Error in inspecting model architecture '%s'", model_arch)
        return None
694
695


696
697
698
@dataclass
class _ModelRegistry:
    # Keyed by model_arch
699
    models: dict[str, _BaseRegisteredModel] = field(default_factory=dict)
700

701
    def get_supported_archs(self) -> Set[str]:
        """Return the keys view of `self.models` (registered architecture names)."""
        registered_archs = self.models.keys()
        return registered_archs
703

704
705
706
    def register_model(
        self,
        model_arch: str,
707
        model_cls: type[nn.Module] | str,
708
    ) -> None:
709
710
711
        """
        Register an external model to be used in vLLM.

712
        `model_cls` can be either:
713

714
        - A [`torch.nn.Module`][] class directly referencing the model.
715
        - A string in the format `<module>:<class>` which can be used to
716
717
          lazily import the model. This is useful to avoid initializing CUDA
          when importing the model and thus the related error
718
          `RuntimeError: Cannot re-initialize CUDA in forked subprocess`.
719
        """
720
721
722
723
        if not isinstance(model_arch, str):
            msg = f"`model_arch` should be a string, not a {type(model_arch)}"
            raise TypeError(msg)

724
        if model_arch in self.models:
725
726
            logger.warning(
                "Model architecture %s is already registered, and will be "
727
728
729
730
                "overwritten by the new model class %s.",
                model_arch,
                model_cls,
            )
731
732
733
734
735
736

        if isinstance(model_cls, str):
            split_str = model_cls.split(":")
            if len(split_str) != 2:
                msg = "Expected a string in the format `<module>:<class>`"
                raise ValueError(msg)
737

738
            model = _LazyRegisteredModel(*split_str)
739
        elif isinstance(model_cls, type) and issubclass(model_cls, nn.Module):
740
            model = _RegisteredModel.from_model_cls(model_cls)
741
        else:
742
743
744
745
            msg = (
                "`model_cls` should be a string or PyTorch model class, "
                f"not a {type(model_arch)}"
            )
746
            raise TypeError(msg)
747

748
        self.models[model_arch] = model
749

750
    def _raise_for_unsupported(self, architectures: list[str]):
        """
        Always raise a `ValueError` explaining why `architectures` is unusable.

        The message is chosen in this order:

        1. At least one architecture is registered: inspection must have
           failed, so point the user to the logs.
        2. An architecture appears in `_PREVIOUSLY_SUPPORTED_MODELS`: report
           the last vLLM version recorded for it.
        3. Otherwise: report that the architectures are not supported and
           list the supported ones.
        """
        supported_archs = self.get_supported_archs()

        registered = [arch for arch in architectures if arch in supported_archs]
        if registered:
            raise ValueError(
                f"Model architectures {architectures} failed "
                "to be inspected. Please check the logs for more details."
            )

        for arch in architectures:
            if arch not in _PREVIOUSLY_SUPPORTED_MODELS:
                continue

            previous_version = _PREVIOUSLY_SUPPORTED_MODELS[arch]
            raise ValueError(
                f"Model architecture {arch} was supported in vLLM until "
                f"v{previous_version}, and is not supported anymore. "
                "Please use an older version of vLLM if you want to "
                "use this model architecture."
            )

        raise ValueError(
            f"Model architectures {architectures} are not supported for now. "
            f"Supported architectures: {supported_archs}"
        )
774

775
    def _try_load_model_cls(self, model_arch: str) -> type[nn.Module] | None:
        """
        Load the class registered under `model_arch`.

        Returns `None` when `model_arch` is not registered; otherwise
        returns whatever the module-level `_try_load_model_cls` returns
        for the registered entry.
        """
        if model_arch in self.models:
            return _try_load_model_cls(model_arch, self.models[model_arch])

        return None
780

781
    def _try_inspect_model_cls(self, model_arch: str) -> _ModelInfo | None:
        """
        Inspect the model registered under `model_arch`.

        Returns `None` when `model_arch` is not registered; otherwise
        returns whatever the module-level `_try_inspect_model_cls` returns
        for the registered entry.
        """
        if model_arch in self.models:
            return _try_inspect_model_cls(model_arch, self.models[model_arch])

        return None

    def _try_resolve_transformers(
        self,
        architecture: str,
        model_config: ModelConfig,
    ) -> str | None:
        """
        Resolve `architecture` to a Transformers-backend architecture name.

        Returns `architecture` unchanged if it is already one of
        `_TRANSFORMERS_BACKEND_MODELS`; otherwise returns the result of
        `model_config._get_transformers_backend_cls()` once a backing
        Transformers model class has been found and reports itself as
        backend-compatible.

        When no model class can be found, or the one found is not
        backend-compatible, the outcome depends on `model_config.model_impl`:
        `None` is returned unless it equals `"transformers"`, in which case
        a `ValueError` is raised.
        """
        if architecture in _TRANSFORMERS_BACKEND_MODELS:
            return architecture

        hf_auto_map = getattr(model_config.hf_config, "auto_map", None)
        auto_map: dict[str, str] = hf_auto_map or dict()

        # Make sure that config class is always initialized before model class,
        # otherwise the model class won't be able to access the config class,
        # the expected auto_map should have correct order like:
        # "auto_map": {
        #     "AutoConfig": "<your-repo-name>--<config-name>",
        #     "AutoModel": "<your-repo-name>--<config-name>",
        #     "AutoModelFor<Task>": "<your-repo-name>--<config-name>",
        # },
        for prefix in ("AutoConfig", "AutoModel"):
            for name, module in auto_map.items():
                if not name.startswith(prefix):
                    continue

                try_get_class_from_dynamic_module(
                    module,
                    model_config.model,
                    revision=model_config.revision,
                    warn_on_fail=False,
                )

        # Prefer a class shipped with the Transformers library itself.
        model_module = getattr(transformers, architecture, None)

        if model_module is None:
            # Otherwise use the first "AutoModel*" entry that can be loaded.
            for name, module in auto_map.items():
                if not name.startswith("AutoModel"):
                    continue

                model_module = try_get_class_from_dynamic_module(
                    module,
                    model_config.model,
                    revision=model_config.revision,
                    warn_on_fail=True,
                )
                if model_module is not None:
                    break

            # Still unset: no "AutoModel*" entry existed or none could be loaded.
            if model_module is None:
                if model_config.model_impl != "transformers":
                    return None

                raise ValueError(
                    f"Cannot find model module. {architecture!r} is not a "
                    "registered model in the Transformers library (only "
                    "relevant if the model is meant to be in Transformers) "
                    "and 'AutoModel' is not present in the model config's "
                    "'auto_map' (relevant if the model is custom)."
                )

        if model_module.is_backend_compatible():
            return model_config._get_transformers_backend_cls()

        if model_config.model_impl != "transformers":
            return None

        raise ValueError(
            f"The Transformers implementation of {architecture!r} "
            "is not compatible with vLLM."
        )
852

853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
    def _normalize_arch(
        self,
        architecture: str,
        model_config: ModelConfig,
    ) -> str:
        """
        Map `architecture` to a registered base architecture name if possible.

        A registered architecture is returned as-is. Otherwise, if
        `try_match_architecture_defaults` yields a match, the matched suffix
        is replaced with each suffix from `iter_architecture_defaults()` in
        turn, and the first resulting name found in `self.models` is
        returned. If nothing matches, `architecture` is returned unchanged.
        """
        if architecture in self.models:
            return architecture

        # This may be called in order to resolve runner_type and convert_type
        # in the first place, in which case we consider the default match
        runner_type = getattr(model_config, "runner_type", None)
        convert_type = getattr(model_config, "convert_type", None)
        default_match = try_match_architecture_defaults(
            architecture,
            runner_type=runner_type,
            convert_type=convert_type,
        )
        if not default_match:
            return architecture

        matched_suffix, _ = default_match

        # Get the name of the base model to convert
        for candidate_suffix, _ in iter_architecture_defaults():
            candidate = architecture.replace(matched_suffix, candidate_suffix)
            if candidate in self.models:
                return candidate

        return architecture
878

879
880
    def inspect_model_cls(
        self,
        architectures: str | list[str],
        model_config: ModelConfig,
    ) -> tuple[_ModelInfo, str]:
        """
        Find the first usable architecture and return its inspected info.

        Resolution is attempted in this order, returning on the first
        success:

        1. The implementation forced by `model_config.model_impl`
           (`"transformers"` or `"terratorch"`).
        2. The Transformers backend, when no architecture is registered,
           `model_impl` is `"auto"` and `convert_type` is `"none"`.
        3. Each architecture in order, after `_normalize_arch`.
        4. The Transformers backend again, when no architecture is
           registered and `model_impl` is `"auto"`.

        Args:
            architectures: One architecture name or a list of candidates.
            model_config: Model configuration consulted for `model_impl`,
                `convert_type` and Transformers backend resolution.

        Returns:
            A `(model_info, architecture)` tuple. For step 3 the
            architecture is the name as given, not the normalized one.

        Raises:
            ValueError: If `architectures` is empty, or if no architecture
                could be inspected (see `_raise_for_unsupported`).
        """
        if isinstance(architectures, str):
            architectures = [architectures]
        if not architectures:
            raise ValueError("No model architectures are specified")

        # Require transformers impl
        if model_config.model_impl == "transformers":
            arch = self._try_resolve_transformers(architectures[0], model_config)
            if arch is not None:
                model_info = self._try_inspect_model_cls(arch)
                if model_info is not None:
                    return (model_info, arch)
        elif model_config.model_impl == "terratorch":
            arch = "Terratorch"
            model_info = self._try_inspect_model_cls(arch)
            # Only return on success, mirroring `resolve_model_cls`; a failed
            # inspection must not leak `None` to callers that expect a
            # `_ModelInfo` and would otherwise hit an AttributeError.
            if model_info is not None:
                return (model_info, arch)

        # Fallback to transformers impl (after resolving convert_type)
        if (
            all(arch not in self.models for arch in architectures)
            and model_config.model_impl == "auto"
            and getattr(model_config, "convert_type", "none") == "none"
        ):
            arch = self._try_resolve_transformers(architectures[0], model_config)
            if arch is not None:
                model_info = self._try_inspect_model_cls(arch)
                if model_info is not None:
                    return (model_info, arch)

        for arch in architectures:
            normalized_arch = self._normalize_arch(arch, model_config)
            model_info = self._try_inspect_model_cls(normalized_arch)
            if model_info is not None:
                return (model_info, arch)

        # Fallback to transformers impl (before resolving runner_type)
        if (
            all(arch not in self.models for arch in architectures)
            and model_config.model_impl == "auto"
        ):
            arch = self._try_resolve_transformers(architectures[0], model_config)
            if arch is not None:
                model_info = self._try_inspect_model_cls(arch)
                if model_info is not None:
                    return (model_info, arch)

        return self._raise_for_unsupported(architectures)
930

931
932
    def resolve_model_cls(
        self,
        architectures: str | list[str],
        model_config: ModelConfig,
    ) -> tuple[type[nn.Module], str]:
        """
        Find the first usable architecture and return its loaded model class.

        Resolution is attempted in this order, returning on the first
        success:

        1. The implementation forced by `model_config.model_impl`
           (`"transformers"` or `"terratorch"`).
        2. The Transformers backend, when no architecture is registered,
           `model_impl` is `"auto"` and `convert_type` is `"none"`.
        3. Each architecture in order, after `_normalize_arch`.
        4. The Transformers backend again, when no architecture is
           registered and `model_impl` is `"auto"`.

        Returns a `(model_cls, architecture)` tuple; for step 3 the
        architecture is the name as given, not the normalized one. Raises
        `ValueError` if `architectures` is empty or nothing could be loaded.
        """
        if isinstance(architectures, str):
            architectures = [architectures]
        if not architectures:
            raise ValueError("No model architectures are specified")

        impl = model_config.model_impl

        # Require transformers impl
        if impl == "transformers":
            resolved = self._try_resolve_transformers(architectures[0], model_config)
            if resolved is not None:
                loaded = self._try_load_model_cls(resolved)
                if loaded is not None:
                    return (loaded, resolved)
        elif impl == "terratorch":
            resolved = "Terratorch"
            loaded = self._try_load_model_cls(resolved)
            if loaded is not None:
                return (loaded, resolved)

        # Fallback to transformers impl (after resolving convert_type)
        if (
            not any(arch in self.models for arch in architectures)
            and model_config.model_impl == "auto"
            and getattr(model_config, "convert_type", "none") == "none"
        ):
            resolved = self._try_resolve_transformers(architectures[0], model_config)
            if resolved is not None:
                loaded = self._try_load_model_cls(resolved)
                if loaded is not None:
                    return (loaded, resolved)

        for requested_arch in architectures:
            normalized_arch = self._normalize_arch(requested_arch, model_config)
            loaded = self._try_load_model_cls(normalized_arch)
            if loaded is not None:
                return (loaded, requested_arch)

        # Fallback to transformers impl (before resolving runner_type)
        if (
            not any(arch in self.models for arch in architectures)
            and model_config.model_impl == "auto"
        ):
            resolved = self._try_resolve_transformers(architectures[0], model_config)
            if resolved is not None:
                loaded = self._try_load_model_cls(resolved)
                if loaded is not None:
                    return (loaded, resolved)

        return self._raise_for_unsupported(architectures)
984

985
986
    def is_text_generation_model(
        self,
        architectures: str | list[str],
        model_config: ModelConfig,
    ) -> bool:
        """Return the `is_text_generation_model` flag of the inspected model info."""
        model_info, _arch = self.inspect_model_cls(architectures, model_config)
        return model_info.is_text_generation_model
992

993
    def is_pooling_model(
        self,
        architectures: str | list[str],
        model_config: ModelConfig,
    ) -> bool:
        """Return the `is_pooling_model` flag of the inspected model info."""
        model_info, _arch = self.inspect_model_cls(architectures, model_config)
        return model_info.is_pooling_model
1000

1001
1002
    def is_cross_encoder_model(
        self,
        architectures: str | list[str],
        model_config: ModelConfig,
    ) -> bool:
        """Return the `supports_cross_encoding` flag of the inspected model info."""
        model_info, _arch = self.inspect_model_cls(architectures, model_config)
        return model_info.supports_cross_encoding
1008

1009
1010
    def is_multimodal_model(
        self,
        architectures: str | list[str],
        model_config: ModelConfig,
    ) -> bool:
        """Return the `supports_multimodal` flag of the inspected model info."""
        model_info, _arch = self.inspect_model_cls(architectures, model_config)
        return model_info.supports_multimodal
1016

1017
    def is_multimodal_raw_input_only_model(
        self,
        architectures: str | list[str],
        model_config: ModelConfig,
    ) -> bool:
        """
        Return the `supports_multimodal_raw_input_only` flag of the
        inspected model info.
        """
        model_info, _arch = self.inspect_model_cls(architectures, model_config)
        return model_info.supports_multimodal_raw_input_only
1024

1025
1026
    def is_pp_supported_model(
        self,
        architectures: str | list[str],
        model_config: ModelConfig,
    ) -> bool:
        """Return the `supports_pp` flag of the inspected model info."""
        model_info, _arch = self.inspect_model_cls(architectures, model_config)
        return model_info.supports_pp
1032

1033
1034
    def model_has_inner_state(
        self,
        architectures: str | list[str],
        model_config: ModelConfig,
    ) -> bool:
        """Return the `has_inner_state` flag of the inspected model info."""
        model_info, _arch = self.inspect_model_cls(architectures, model_config)
        return model_info.has_inner_state
1040

1041
1042
    def is_attention_free_model(
        self,
        architectures: str | list[str],
        model_config: ModelConfig,
    ) -> bool:
        """Return the `is_attention_free` flag of the inspected model info."""
        model_info, _arch = self.inspect_model_cls(architectures, model_config)
        return model_info.is_attention_free
1048

1049
1050
    def is_hybrid_model(
        self,
        architectures: str | list[str],
        model_config: ModelConfig,
    ) -> bool:
        """Return the `is_hybrid` flag of the inspected model info."""
        model_info, _arch = self.inspect_model_cls(architectures, model_config)
        return model_info.is_hybrid

1057
1058
    def is_noops_model(
        self,
        architectures: str | list[str],
        model_config: ModelConfig,
    ) -> bool:
        """Return the `has_noops` flag of the inspected model info."""
        model_info, _arch = self.inspect_model_cls(architectures, model_config)
        return model_info.has_noops

1065
1066
    def is_transcription_model(
        self,
        architectures: str | list[str],
        model_config: ModelConfig,
    ) -> bool:
        """Return the `supports_transcription` flag of the inspected model info."""
        model_info, _arch = self.inspect_model_cls(architectures, model_config)
        return model_info.supports_transcription

1073
1074
    def is_transcription_only_model(
        self,
        architectures: str | list[str],
        model_config: ModelConfig,
    ) -> bool:
        """
        Return the `supports_transcription_only` flag of the inspected
        model info.
        """
        model_info, _arch = self.inspect_model_cls(architectures, model_config)
        return model_info.supports_transcription_only

1081

1082
1083
1084
1085
1086
1087
1088
1089
1090
# Module-level registry instance, pre-populated with one
# `_LazyRegisteredModel` entry (module name and class name only) for every
# architecture listed in `_VLLM_MODELS`. Module names are resolved relative
# to the `vllm.model_executor.models` package.
ModelRegistry = _ModelRegistry(
    {
        model_arch: _LazyRegisteredModel(
            module_name=f"vllm.model_executor.models.{mod_relname}",
            class_name=cls_name,
        )
        for model_arch, (mod_relname, cls_name) in _VLLM_MODELS.items()
    }
)
1091
1092
1093
1094
1095

# Return type of the callable passed to `_run_in_subprocess`.
_T = TypeVar("_T")


def _run_in_subprocess(fn: Callable[[], _T]) -> _T:
    """
    Run `fn` in a separate Python process and return its result.

    `fn` and an output path are serialized with `cloudpickle` and sent to
    the child process (`_SUBPROCESS_COMMAND`) on stdin. The child is
    expected to write the pickled result to the output path, which is then
    loaded and returned.

    Args:
        fn: Zero-argument callable to execute in the subprocess.

    Returns:
        The unpickled contents of the output file written by the subprocess.

    Raises:
        RuntimeError: If the subprocess exits with a non-zero return code.
            The message contains the subprocess's stderr output.
    """
    # NOTE: We use a temporary directory instead of a temporary file to avoid
    # issues like https://stackoverflow.com/questions/23212435/permission-denied-to-write-to-my-temporary-file
    with tempfile.TemporaryDirectory() as tempdir:
        output_filepath = os.path.join(tempdir, "registry_output.tmp")

        # `cloudpickle` allows pickling lambda functions directly
        import cloudpickle

        input_bytes = cloudpickle.dumps((fn, output_filepath))

        # cannot use `sys.executable __file__` here because the script
        # contains relative imports
        returned = subprocess.run(
            _SUBPROCESS_COMMAND, input=input_bytes, capture_output=True
        )

        # check if the subprocess is successful
        try:
            returned.check_returncode()
        except Exception as e:
            # wrap raised exception to provide more information.
            # Decode leniently: stderr with invalid UTF-8 must not raise a
            # UnicodeDecodeError here and hide the real subprocess failure.
            stderr_text = returned.stderr.decode(errors="replace")
            raise RuntimeError(
                f"Error raised in subprocess:\n{stderr_text}"
            ) from e

        # The file is written by the subprocess started above.
        with open(output_filepath, "rb") as f:
            return pickle.load(f)


def _run() -> None:
    """
    Subprocess-side entry point.

    Loads general plugins, reads a pickled `(fn, output_file)` pair from
    stdin, calls `fn()`, and writes the pickled result to `output_file`.
    """
    # Setup plugins
    from vllm.plugins import load_general_plugins

    load_general_plugins()

    # NOTE(review): the payload is unpickled without validation; presumably
    # only the parent process in `_run_in_subprocess` feeds this stdin -
    # confirm before exposing this entry point to other callers.
    payload = sys.stdin.buffer.read()
    fn, output_file = pickle.loads(payload)

    outcome = fn()

    with open(output_file, "wb") as out:
        out.write(pickle.dumps(outcome))
1137
1138
1139


if __name__ == "__main__":
    # Executed as a script: run the subprocess-side entry point.
    _run()