registry.py 43.6 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
4
5
6
"""
Whenever you add an architecture to this file, please also update
`tests/models/registry.py` with example HuggingFace models for it.
"""
7

8
import hashlib
9
import importlib
10
import json
11
import os
12
import pickle
13
14
import subprocess
import sys
15
import tempfile
16
from abc import ABC, abstractmethod
17
from collections.abc import Callable, Set
18
from dataclasses import asdict, dataclass, field
19
from functools import lru_cache
20
from pathlib import Path
21
from typing import TypeVar
22
23

import torch.nn as nn
24
import transformers
25

26
from vllm import envs
27
28
29
30
31
from vllm.config import (
    ModelConfig,
    iter_architecture_defaults,
    try_match_architecture_defaults,
)
32
from vllm.logger import init_logger
33
from vllm.logging_utils import logtime
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
from vllm.transformers_utils.dynamic_module import try_get_class_from_dynamic_module

from .interfaces import (
    has_inner_state,
    has_noops,
    is_attention_free,
    is_hybrid,
    supports_cross_encoding,
    supports_multimodal,
    supports_multimodal_encoder_tp_data,
    supports_multimodal_raw_input_only,
    supports_pp,
    supports_transcription,
)
from .interfaces_base import (
    get_default_pooling_type,
    is_pooling_model,
    is_text_generation_model,
)
53
54
55

# Module-level logger, named after this module.
logger = init_logger(__name__)

56
57
# Text-generation architectures.
# Each entry maps an architecture name to a `(mod, arch)` pair of strings.
# NOTE(review): presumably `mod` is the model module name and `arch` the class
# defined inside it -- confirm where these tables are consumed.
# Several architecture names may point at the same `(mod, arch)` pair.
_TEXT_GENERATION_MODELS = {
    # [Decoder-only]
    "ApertusForCausalLM": ("apertus", "ApertusForCausalLM"),
    "AquilaModel": ("llama", "LlamaForCausalLM"),
    "AquilaForCausalLM": ("llama", "LlamaForCausalLM"),  # AquilaChat2
    "ArceeForCausalLM": ("arcee", "ArceeForCausalLM"),
    "ArcticForCausalLM": ("arctic", "ArcticForCausalLM"),
    # baichuan-7b, upper case 'C' in the class name
    "BaiChuanForCausalLM": ("baichuan", "BaiChuanForCausalLM"),
    # baichuan-13b, lower case 'c' in the class name
    "BaichuanForCausalLM": ("baichuan", "BaichuanForCausalLM"),
    "BailingMoeForCausalLM": ("bailing_moe", "BailingMoeForCausalLM"),
    "BailingMoeV2ForCausalLM": ("bailing_moe", "BailingMoeV2ForCausalLM"),
    "BambaForCausalLM": ("bamba", "BambaForCausalLM"),
    "BloomForCausalLM": ("bloom", "BloomForCausalLM"),
    "ChatGLMModel": ("chatglm", "ChatGLMForCausalLM"),
    "ChatGLMForConditionalGeneration": ("chatglm", "ChatGLMForCausalLM"),
    "CohereForCausalLM": ("commandr", "CohereForCausalLM"),
    "Cohere2ForCausalLM": ("commandr", "CohereForCausalLM"),
    "CwmForCausalLM": ("llama", "LlamaForCausalLM"),
    "DbrxForCausalLM": ("dbrx", "DbrxForCausalLM"),
    "DeciLMForCausalLM": ("nemotron_nas", "DeciLMForCausalLM"),
    "DeepseekForCausalLM": ("deepseek", "DeepseekForCausalLM"),
    "DeepseekV2ForCausalLM": ("deepseek_v2", "DeepseekV2ForCausalLM"),
    "DeepseekV3ForCausalLM": ("deepseek_v2", "DeepseekV3ForCausalLM"),
    "DeepseekV32ForCausalLM": ("deepseek_v2", "DeepseekV3ForCausalLM"),
    "Dots1ForCausalLM": ("dots1", "Dots1ForCausalLM"),
    "Ernie4_5ForCausalLM": ("ernie45", "Ernie4_5ForCausalLM"),
    "Ernie4_5_MoeForCausalLM": ("ernie45_moe", "Ernie4_5_MoeForCausalLM"),
    "ExaoneForCausalLM": ("exaone", "ExaoneForCausalLM"),
    "Exaone4ForCausalLM": ("exaone4", "Exaone4ForCausalLM"),
    "Fairseq2LlamaForCausalLM": ("fairseq2_llama", "Fairseq2LlamaForCausalLM"),
    "FalconForCausalLM": ("falcon", "FalconForCausalLM"),
    "FalconMambaForCausalLM": ("mamba", "MambaForCausalLM"),
    "FalconH1ForCausalLM": ("falcon_h1", "FalconH1ForCausalLM"),
    "FlexOlmoForCausalLM": ("flex_olmo", "FlexOlmoForCausalLM"),
    "GemmaForCausalLM": ("gemma", "GemmaForCausalLM"),
    "Gemma2ForCausalLM": ("gemma2", "Gemma2ForCausalLM"),
    "Gemma3ForCausalLM": ("gemma3", "Gemma3ForCausalLM"),
    "Gemma3nForCausalLM": ("gemma3n", "Gemma3nForCausalLM"),
    "Qwen3NextForCausalLM": ("qwen3_next", "Qwen3NextForCausalLM"),
    "GlmForCausalLM": ("glm", "GlmForCausalLM"),
    "Glm4ForCausalLM": ("glm4", "Glm4ForCausalLM"),
    "Glm4MoeForCausalLM": ("glm4_moe", "Glm4MoeForCausalLM"),
    "GptOssForCausalLM": ("gpt_oss", "GptOssForCausalLM"),
    "GPT2LMHeadModel": ("gpt2", "GPT2LMHeadModel"),
    "GPTBigCodeForCausalLM": ("gpt_bigcode", "GPTBigCodeForCausalLM"),
    "GPTJForCausalLM": ("gpt_j", "GPTJForCausalLM"),
    "GPTNeoXForCausalLM": ("gpt_neox", "GPTNeoXForCausalLM"),
    "GraniteForCausalLM": ("granite", "GraniteForCausalLM"),
    "GraniteMoeForCausalLM": ("granitemoe", "GraniteMoeForCausalLM"),
    "GraniteMoeHybridForCausalLM": ("granitemoehybrid", "GraniteMoeHybridForCausalLM"),  # noqa: E501
    "GraniteMoeSharedForCausalLM": ("granitemoeshared", "GraniteMoeSharedForCausalLM"),  # noqa: E501
    "GritLM": ("gritlm", "GritLM"),
    "Grok1ModelForCausalLM": ("grok1", "Grok1ForCausalLM"),
    "HunYuanMoEV1ForCausalLM": ("hunyuan_v1", "HunYuanMoEV1ForCausalLM"),
    "HunYuanDenseV1ForCausalLM": ("hunyuan_v1", "HunYuanDenseV1ForCausalLM"),
    "HCXVisionForCausalLM": ("hyperclovax_vision", "HCXVisionForCausalLM"),
    "InternLMForCausalLM": ("llama", "LlamaForCausalLM"),
    "InternLM2ForCausalLM": ("internlm2", "InternLM2ForCausalLM"),
    "InternLM2VEForCausalLM": ("internlm2_ve", "InternLM2VEForCausalLM"),
    "InternLM3ForCausalLM": ("llama", "LlamaForCausalLM"),
    "JAISLMHeadModel": ("jais", "JAISLMHeadModel"),
    "JambaForCausalLM": ("jamba", "JambaForCausalLM"),
    "Lfm2ForCausalLM": ("lfm2", "Lfm2ForCausalLM"),
    "Lfm2MoeForCausalLM": ("lfm2_moe", "Lfm2MoeForCausalLM"),
    "LlamaForCausalLM": ("llama", "LlamaForCausalLM"),
    "Llama4ForCausalLM": ("llama4", "Llama4ForCausalLM"),
    # For decapoda-research/llama-*
    "LLaMAForCausalLM": ("llama", "LlamaForCausalLM"),
    "LongcatFlashForCausalLM": ("longcat_flash", "LongcatFlashForCausalLM"),
    "MambaForCausalLM": ("mamba", "MambaForCausalLM"),
    "Mamba2ForCausalLM": ("mamba2", "Mamba2ForCausalLM"),
    "MiniCPMForCausalLM": ("minicpm", "MiniCPMForCausalLM"),
    "MiniCPM3ForCausalLM": ("minicpm3", "MiniCPM3ForCausalLM"),
    "MiniMaxForCausalLM": ("minimax_text_01", "MiniMaxText01ForCausalLM"),
    "MiniMaxText01ForCausalLM": ("minimax_text_01", "MiniMaxText01ForCausalLM"),
    "MiniMaxM1ForCausalLM": ("minimax_text_01", "MiniMaxText01ForCausalLM"),
    "MistralForCausalLM": ("llama", "LlamaForCausalLM"),
    "MixtralForCausalLM": ("mixtral", "MixtralForCausalLM"),
    # The transformers MPT class is spelled with lower-case letters ("Mpt"),
    # so both spellings are registered.
    "MptForCausalLM": ("mpt", "MPTForCausalLM"),
    "MPTForCausalLM": ("mpt", "MPTForCausalLM"),
    "MiMoForCausalLM": ("mimo", "MiMoForCausalLM"),
    "NemotronForCausalLM": ("nemotron", "NemotronForCausalLM"),
    "NemotronHForCausalLM": ("nemotron_h", "NemotronHForCausalLM"),
    "OlmoForCausalLM": ("olmo", "OlmoForCausalLM"),
    "Olmo2ForCausalLM": ("olmo2", "Olmo2ForCausalLM"),
    "Olmo3ForCausalLM": ("olmo2", "Olmo2ForCausalLM"),
    "OlmoeForCausalLM": ("olmoe", "OlmoeForCausalLM"),
    "OPTForCausalLM": ("opt", "OPTForCausalLM"),
    "OrionForCausalLM": ("orion", "OrionForCausalLM"),
    "PersimmonForCausalLM": ("persimmon", "PersimmonForCausalLM"),
    "PhiForCausalLM": ("phi", "PhiForCausalLM"),
    "Phi3ForCausalLM": ("phi3", "Phi3ForCausalLM"),
    "PhiMoEForCausalLM": ("phimoe", "PhiMoEForCausalLM"),
    "Plamo2ForCausalLM": ("plamo2", "Plamo2ForCausalLM"),
    "QWenLMHeadModel": ("qwen", "QWenLMHeadModel"),
    "Qwen2ForCausalLM": ("qwen2", "Qwen2ForCausalLM"),
    "Qwen2MoeForCausalLM": ("qwen2_moe", "Qwen2MoeForCausalLM"),
    "Qwen3ForCausalLM": ("qwen3", "Qwen3ForCausalLM"),
    "Qwen3MoeForCausalLM": ("qwen3_moe", "Qwen3MoeForCausalLM"),
    "RWForCausalLM": ("falcon", "FalconForCausalLM"),
    "SeedOssForCausalLM": ("seed_oss", "SeedOssForCausalLM"),
    "Step3TextForCausalLM": ("step3_text", "Step3TextForCausalLM"),
    "StableLMEpochForCausalLM": ("stablelm", "StablelmForCausalLM"),
    "StableLmForCausalLM": ("stablelm", "StablelmForCausalLM"),
    "Starcoder2ForCausalLM": ("starcoder2", "Starcoder2ForCausalLM"),
    "SolarForCausalLM": ("solar", "SolarForCausalLM"),
    "TeleChat2ForCausalLM": ("telechat2", "TeleChat2ForCausalLM"),
    "TeleFLMForCausalLM": ("teleflm", "TeleFLMForCausalLM"),
    "XverseForCausalLM": ("llama", "LlamaForCausalLM"),
    "Zamba2ForCausalLM": ("zamba2", "Zamba2ForCausalLM"),
}

# Embedding architectures, in the same `name -> (mod, arch)` layout as
# `_TEXT_GENERATION_MODELS`.
# The `**{...}` spread sits in the middle of the literal, so keys written after
# it override any same-named key coming from the spread.
_EMBEDDING_MODELS = {
    # [Text-only]
    "BertModel": ("bert", "BertEmbeddingModel"),
    "BertSpladeSparseEmbeddingModel": ("bert", "BertSpladeSparseEmbeddingModel"),
    "DeciLMForCausalLM": ("nemotron_nas", "DeciLMForCausalLM"),
    "Gemma2Model": ("gemma2", "Gemma2ForCausalLM"),
    "Gemma3TextModel": ("gemma3", "Gemma3Model"),
    "GlmForCausalLM": ("glm", "GlmForCausalLM"),
    "GPT2ForSequenceClassification": ("gpt2", "GPT2ForSequenceClassification"),
    "GritLM": ("gritlm", "GritLM"),
    "GteModel": ("bert_with_rope", "SnowflakeGteNewModel"),
    "GteNewModel": ("bert_with_rope", "GteNewModel"),
    "InternLM2ForRewardModel": ("internlm2", "InternLM2ForRewardModel"),
    "JambaForSequenceClassification": ("jamba", "JambaForSequenceClassification"),  # noqa: E501
    "LlamaModel": ("llama", "LlamaForCausalLM"),
    **{
        # Multiple models share the same architecture, so we include them all:
        # every text-generation entry whose class is "LlamaForCausalLM" is
        # copied into this table unchanged.
        k: (mod, arch)
        for k, (mod, arch) in _TEXT_GENERATION_MODELS.items()
        if arch == "LlamaForCausalLM"
    },
    "MistralModel": ("llama", "LlamaForCausalLM"),
    "ModernBertModel": ("modernbert", "ModernBertModel"),
    "NomicBertModel": ("bert_with_rope", "NomicBertModel"),
    "Phi3ForCausalLM": ("phi3", "Phi3ForCausalLM"),
    "Qwen2Model": ("qwen2", "Qwen2ForCausalLM"),
    "Qwen2ForCausalLM": ("qwen2", "Qwen2ForCausalLM"),
    "Qwen2ForRewardModel": ("qwen2_rm", "Qwen2ForRewardModel"),
    "Qwen2ForProcessRewardModel": ("qwen2_rm", "Qwen2ForProcessRewardModel"),
    "RobertaForMaskedLM": ("roberta", "RobertaEmbeddingModel"),
    "RobertaModel": ("roberta", "RobertaEmbeddingModel"),
    "TeleChat2ForCausalLM": ("telechat2", "TeleChat2ForCausalLM"),
    "XLMRobertaModel": ("roberta", "RobertaEmbeddingModel"),
    # [Multimodal]
    "CLIPModel": ("clip", "CLIPEmbeddingModel"),
    "LlavaNextForConditionalGeneration": (
        "llava_next",
        "LlavaNextForConditionalGeneration",
    ),
    "Phi3VForCausalLM": ("phi3v", "Phi3VForCausalLM"),
    "Qwen2VLForConditionalGeneration": ("qwen2_vl", "Qwen2VLForConditionalGeneration"),  # noqa: E501
    # Technically, Terratorch models work on images for both input and
    # output. They are listed here because they piggy-back on embedding
    # models for the time being.
    "PrithviGeoSpatialMAE": ("terratorch", "Terratorch"),
    "Terratorch": ("terratorch", "Terratorch"),
}

219
220
# Cross-encoder (sequence/token classification) architectures, in the same
# `name -> (mod, arch)` layout as the other tables in this file.
_CROSS_ENCODER_MODELS = {
    "BertForSequenceClassification": ("bert", "BertForSequenceClassification"),
    "BertForTokenClassification": ("bert", "BertForTokenClassification"),
    "GteNewForSequenceClassification": (
        "bert_with_rope",
        "GteNewForSequenceClassification",
    ),
    "ModernBertForSequenceClassification": (
        "modernbert",
        "ModernBertForSequenceClassification",
    ),
    "ModernBertForTokenClassification": (
        "modernbert",
        "ModernBertForTokenClassification",
    ),
    "RobertaForSequenceClassification": ("roberta", "RobertaForSequenceClassification"),
    "XLMRobertaForSequenceClassification": (
        "roberta",
        "RobertaForSequenceClassification",
    ),
    # [Auto-converted (see adapters.py)]
    "JinaVLForRanking": ("jina_vl", "JinaVLForSequenceClassification"),  # noqa: E501
}

243
# Multimodal architectures, in the same `name -> (mod, arch)` layout as the
# other tables in this file. Some names are aliases that resolve to the same
# `(mod, arch)` pair (e.g. the two Qwen2.5-Omni entries).
_MULTIMODAL_MODELS = {
    # [Decoder-only]
    "AriaForConditionalGeneration": ("aria", "AriaForConditionalGeneration"),
    "AyaVisionForConditionalGeneration": (
        "aya_vision",
        "AyaVisionForConditionalGeneration",
    ),
    "BeeForConditionalGeneration": ("bee", "BeeForConditionalGeneration"),
    "Blip2ForConditionalGeneration": ("blip2", "Blip2ForConditionalGeneration"),
    "ChameleonForConditionalGeneration": (
        "chameleon",
        "ChameleonForConditionalGeneration",
    ),
    "Cohere2VisionForConditionalGeneration": (
        "cohere2_vision",
        "Cohere2VisionForConditionalGeneration",
    ),
    "DeepseekVLV2ForCausalLM": ("deepseek_vl2", "DeepseekVLV2ForCausalLM"),
    "DeepseekOCRForCausalLM": ("deepseek_ocr", "DeepseekOCRForCausalLM"),
    "DotsOCRForCausalLM": ("dots_ocr", "DotsOCRForCausalLM"),
    "Ernie4_5_VLMoeForConditionalGeneration": (
        "ernie45_vl",
        "Ernie4_5_VLMoeForConditionalGeneration",
    ),
    "FuyuForCausalLM": ("fuyu", "FuyuForCausalLM"),
    "Gemma3nForConditionalGeneration": (
        "gemma3n_mm",
        "Gemma3nForConditionalGeneration",
    ),
    "GLM4VForCausalLM": ("glm4v", "GLM4VForCausalLM"),
    "Glm4vForConditionalGeneration": ("glm4_1v", "Glm4vForConditionalGeneration"),  # noqa: E501
    "Glm4vMoeForConditionalGeneration": ("glm4_1v", "Glm4vMoeForConditionalGeneration"),  # noqa: E501
    "GraniteSpeechForConditionalGeneration": (
        "granite_speech",
        "GraniteSpeechForConditionalGeneration",
    ),
    "H2OVLChatModel": ("h2ovl", "H2OVLChatModel"),
    "InternVLChatModel": ("internvl", "InternVLChatModel"),
    "NemotronH_Nano_VL_V2": ("nano_nemotron_vl", "NemotronH_Nano_VL_V2"),
    "InternS1ForConditionalGeneration": (
        "interns1",
        "InternS1ForConditionalGeneration",
    ),
    "InternVLForConditionalGeneration": (
        "interns1",
        "InternS1ForConditionalGeneration",
    ),
    "Idefics3ForConditionalGeneration": (
        "idefics3",
        "Idefics3ForConditionalGeneration",
    ),
    "SmolVLMForConditionalGeneration": ("smolvlm", "SmolVLMForConditionalGeneration"),  # noqa: E501
    "KeyeForConditionalGeneration": ("keye", "KeyeForConditionalGeneration"),
    "KeyeVL1_5ForConditionalGeneration": (
        "keye_vl1_5",
        "KeyeVL1_5ForConditionalGeneration",
    ),
    "RForConditionalGeneration": ("rvl", "RForConditionalGeneration"),
    "KimiVLForConditionalGeneration": ("kimi_vl", "KimiVLForConditionalGeneration"),  # noqa: E501
    "LightOnOCRForConditionalGeneration": (
        "lightonocr",
        "LightOnOCRForConditionalGeneration",
    ),
    "Llama_Nemotron_Nano_VL": ("nemotron_vl", "LlamaNemotronVLChatModel"),
    "Llama4ForConditionalGeneration": ("mllama4", "Llama4ForConditionalGeneration"),  # noqa: E501
    "LlavaForConditionalGeneration": ("llava", "LlavaForConditionalGeneration"),
    "LlavaNextForConditionalGeneration": (
        "llava_next",
        "LlavaNextForConditionalGeneration",
    ),
    "LlavaNextVideoForConditionalGeneration": (
        "llava_next_video",
        "LlavaNextVideoForConditionalGeneration",
    ),
    "LlavaOnevisionForConditionalGeneration": (
        "llava_onevision",
        "LlavaOnevisionForConditionalGeneration",
    ),
    "MantisForConditionalGeneration": ("llava", "MantisForConditionalGeneration"),  # noqa: E501
    "MiDashengLMModel": ("midashenglm", "MiDashengLMModel"),
    "MiniMaxVL01ForConditionalGeneration": (
        "minimax_vl_01",
        "MiniMaxVL01ForConditionalGeneration",
    ),
    "MiniCPMO": ("minicpmo", "MiniCPMO"),
    "MiniCPMV": ("minicpmv", "MiniCPMV"),
    "Mistral3ForConditionalGeneration": (
        "mistral3",
        "Mistral3ForConditionalGeneration",
    ),
    "MolmoForCausalLM": ("molmo", "MolmoForCausalLM"),
    "NVLM_D": ("nvlm_d", "NVLM_D_Model"),
    "Ovis": ("ovis", "Ovis"),
    "Ovis2_5": ("ovis2_5", "Ovis2_5"),
    "Phi3VForCausalLM": ("phi3v", "Phi3VForCausalLM"),
    "Phi4MMForCausalLM": ("phi4mm", "Phi4MMForCausalLM"),
    "Phi4MultimodalForCausalLM": ("phi4_multimodal", "Phi4MultimodalForCausalLM"),  # noqa: E501
    "PixtralForConditionalGeneration": ("pixtral", "PixtralForConditionalGeneration"),  # noqa: E501
    "QwenVLForConditionalGeneration": ("qwen_vl", "QwenVLForConditionalGeneration"),  # noqa: E501
    "Qwen2VLForConditionalGeneration": ("qwen2_vl", "Qwen2VLForConditionalGeneration"),  # noqa: E501
    "Qwen2_5_VLForConditionalGeneration": (
        "qwen2_5_vl",
        "Qwen2_5_VLForConditionalGeneration",
    ),
    "Qwen2AudioForConditionalGeneration": (
        "qwen2_audio",
        "Qwen2AudioForConditionalGeneration",
    ),
    "Qwen2_5OmniModel": (
        "qwen2_5_omni_thinker",
        "Qwen2_5OmniThinkerForConditionalGeneration",
    ),
    "Qwen2_5OmniForConditionalGeneration": (
        "qwen2_5_omni_thinker",
        "Qwen2_5OmniThinkerForConditionalGeneration",
    ),
    "Qwen3OmniMoeForConditionalGeneration": (
        "qwen3_omni_moe_thinker",
        "Qwen3OmniMoeThinkerForConditionalGeneration",
    ),
    "Qwen3VLForConditionalGeneration": ("qwen3_vl", "Qwen3VLForConditionalGeneration"),  # noqa: E501
    "Qwen3VLMoeForConditionalGeneration": (
        "qwen3_vl_moe",
        "Qwen3VLMoeForConditionalGeneration",
    ),
    "SkyworkR1VChatModel": ("skyworkr1v", "SkyworkR1VChatModel"),
    "Step3VLForConditionalGeneration": ("step3_vl", "Step3VLForConditionalGeneration"),  # noqa: E501
    "TarsierForConditionalGeneration": ("tarsier", "TarsierForConditionalGeneration"),  # noqa: E501
    "Tarsier2ForConditionalGeneration": (
        "qwen2_vl",
        "Tarsier2ForConditionalGeneration",
    ),
    "UltravoxModel": ("ultravox", "UltravoxModel"),
    "VoxtralForConditionalGeneration": ("voxtral", "VoxtralForConditionalGeneration"),  # noqa: E501
    # [Encoder-decoder]
    "WhisperForConditionalGeneration": ("whisper", "WhisperForConditionalGeneration"),  # noqa: E501
}
380
381

# Speculative-decoding (draft) architectures, in the same
# `name -> (mod, arch)` layout as the other tables in this file.
_SPECULATIVE_DECODING_MODELS = {
    "MiMoMTPModel": ("mimo_mtp", "MiMoMTP"),
    "EagleLlamaForCausalLM": ("llama_eagle", "EagleLlamaForCausalLM"),
    "EagleLlama4ForCausalLM": ("llama4_eagle", "EagleLlama4ForCausalLM"),
    "EagleMiniCPMForCausalLM": ("minicpm_eagle", "EagleMiniCPMForCausalLM"),
    "Eagle3LlamaForCausalLM": ("llama_eagle3", "Eagle3LlamaForCausalLM"),
    "LlamaForCausalLMEagle3": ("llama_eagle3", "Eagle3LlamaForCausalLM"),
    "Eagle3Qwen2_5vlForCausalLM": ("llama_eagle3", "Eagle3LlamaForCausalLM"),
    "EagleDeepSeekMTPModel": ("deepseek_eagle", "EagleDeepseekV3ForCausalLM"),
    "DeepSeekMTPModel": ("deepseek_mtp", "DeepSeekMTP"),
    "ErnieMTPModel": ("ernie_mtp", "ErnieMTP"),
    "LongCatFlashMTPModel": ("longcat_flash_mtp", "LongCatFlashMTP"),
    "Glm4MoeMTPModel": ("glm4_moe_mtp", "Glm4MoeMTP"),
    "MedusaModel": ("medusa", "Medusa"),
    "Qwen3NextMTP": ("qwen3_next_mtp", "Qwen3NextMTP"),
    # Temporarily disabled.
    # # TODO(woosuk): Re-enable this once the MLP Speculator is supported in V1.
    # "MLPSpeculatorPreTrainedModel": ("mlp_speculator", "MLPSpeculator"),
}
400

401
# Concrete HuggingFace architectures that are mapped onto the generic classes
# of the "transformers" module rather than onto a dedicated model module.
_TRANSFORMERS_SUPPORTED_MODELS = {
    # Text generation models
    "SmolLM3ForCausalLM": ("transformers", "TransformersForCausalLM"),
    # Multimodal models
    "Emu3ForConditionalGeneration": (
        "transformers",
        "TransformersMultiModalForCausalLM",
    ),
    "Gemma3ForConditionalGeneration": (
        "transformers",
        "TransformersMultiModalForCausalLM",
    ),
    "PaliGemmaForConditionalGeneration": (
        "transformers",
        "TransformersMultiModalForCausalLM",
    ),
}

# The generic classes of the "transformers" module, registered under their own
# names (every key equals the class name in its value).
_TRANSFORMERS_BACKEND_MODELS = {
    # Text generation models
    "TransformersForCausalLM": ("transformers", "TransformersForCausalLM"),
    "TransformersMoEForCausalLM": ("transformers", "TransformersMoEForCausalLM"),
    # Multimodal models
    "TransformersMultiModalForCausalLM": (
        "transformers",
        "TransformersMultiModalForCausalLM",
    ),
    "TransformersMultiModalMoEForCausalLM": (
        "transformers",
        "TransformersMultiModalMoEForCausalLM",
    ),
    # Embedding models
    "TransformersEmbeddingModel": ("transformers", "TransformersEmbeddingModel"),
    "TransformersMoEEmbeddingModel": ("transformers", "TransformersMoEEmbeddingModel"),
    "TransformersMultiModalEmbeddingModel": (
        "transformers",
        "TransformersMultiModalEmbeddingModel",
    ),
    # Sequence classification models
    "TransformersForSequenceClassification": (
        "transformers",
        "TransformersForSequenceClassification",
    ),
    "TransformersMoEForSequenceClassification": (
        "transformers",
        "TransformersMoEForSequenceClassification",
    ),
    "TransformersMultiModalForSequenceClassification": (
        "transformers",
        "TransformersMultiModalForSequenceClassification",
    ),
}
453

454
# Union of all the architecture tables above. The tables are merged in the
# order listed, so when the same architecture name appears in more than one
# table, the entry from the later table wins.
_VLLM_MODELS = {
    **_TEXT_GENERATION_MODELS,
    **_EMBEDDING_MODELS,
    **_CROSS_ENCODER_MODELS,
    **_MULTIMODAL_MODELS,
    **_SPECULATIVE_DECODING_MODELS,
    **_TRANSFORMERS_SUPPORTED_MODELS,
    **_TRANSFORMERS_BACKEND_MODELS,
}

464
465
466
467
# This variable is used as the args for subprocess.run(). We
# can modify  this variable to alter the args if needed. e.g.
# when we use par format to pack things together, sys.executable
# might not be the target we want to run.
468
_SUBPROCESS_COMMAND = [sys.executable, "-m", "vllm.model_executor.models.registry"]
469

470
# Architectures that are no longer supported, mapped to a version string.
# NOTE(review): the version is presumably the last vLLM release that still
# supported the architecture -- confirm against the code that reads this table.
_PREVIOUSLY_SUPPORTED_MODELS = {
    "MotifForCausalLM": "0.10.2",
    "Phi3SmallForCausalLM": "0.9.2",
    "Phi4FlashForCausalLM": "0.10.2",
    # Encoder-decoder models, except Whisper,
    # have been removed for the V0 deprecation.
    "BartModel": "0.10.2",
    "BartForConditionalGeneration": "0.10.2",
    "DonutForConditionalGeneration": "0.10.2",
    "Florence2ForConditionalGeneration": "0.10.2",
    "MBartForConditionalGeneration": "0.10.2",
    "MllamaForConditionalGeneration": "0.10.2",
}
483

484

485
486
@dataclass(frozen=True)
class _ModelInfo:
    """
    Immutable record of the capabilities reported for a model class.

    Each field holds the result of the interface probe of the same name;
    `from_model_cls` is the constructor that runs those probes.
    """

    architecture: str
    is_text_generation_model: bool
    is_pooling_model: bool
    default_pooling_type: str
    supports_cross_encoding: bool
    supports_multimodal: bool
    supports_multimodal_raw_input_only: bool
    supports_multimodal_encoder_tp_data: bool
    supports_pp: bool
    has_inner_state: bool
    is_attention_free: bool
    is_hybrid: bool
    has_noops: bool
    supports_transcription: bool
    supports_transcription_only: bool

    @staticmethod
    def from_model_cls(model: type[nn.Module]) -> "_ModelInfo":
        """Probe `model` with every interface check and bundle the results."""
        raw_input_only = supports_multimodal_raw_input_only(model)
        encoder_tp_data = supports_multimodal_encoder_tp_data(model)

        # `supports_transcription_only` is only read from the class when the
        # class supports transcription at all.
        transcription = supports_transcription(model)
        transcription_only = transcription and model.supports_transcription_only

        return _ModelInfo(
            architecture=model.__name__,
            is_text_generation_model=is_text_generation_model(model),
            is_pooling_model=is_pooling_model(model),
            default_pooling_type=get_default_pooling_type(model),
            supports_cross_encoding=supports_cross_encoding(model),
            supports_multimodal=supports_multimodal(model),
            supports_multimodal_raw_input_only=raw_input_only,
            supports_multimodal_encoder_tp_data=encoder_tp_data,
            supports_pp=supports_pp(model),
            has_inner_state=has_inner_state(model),
            is_attention_free=is_attention_free(model),
            is_hybrid=is_hybrid(model),
            has_noops=has_noops(model),
            supports_transcription=transcription,
            supports_transcription_only=transcription_only,
        )
528
529


530
531
532
533
class _BaseRegisteredModel(ABC):
    """Abstract interface shared by the registry's model entries."""

    @abstractmethod
    def inspect_model_cls(self) -> _ModelInfo:
        """Return the `_ModelInfo` describing the registered model class."""
        raise NotImplementedError

    @abstractmethod
    def load_model_cls(self) -> type[nn.Module]:
        """Return the registered model class itself."""
        raise NotImplementedError
538
539


540
541
542
543
544
545
546
@dataclass(frozen=True)
class _RegisteredModel(_BaseRegisteredModel):
    """
    Represents a model that has already been imported in the main process.

    Both the class and its `_ModelInfo` are stored on the instance, so
    neither accessor has to import or inspect anything.
    """

    interfaces: _ModelInfo
    model_cls: type[nn.Module]

    @staticmethod
    def from_model_cls(model_cls: type[nn.Module]) -> "_RegisteredModel":
        """Build an entry for `model_cls`, inspecting it immediately."""
        info = _ModelInfo.from_model_cls(model_cls)
        return _RegisteredModel(interfaces=info, model_cls=model_cls)

    def inspect_model_cls(self) -> _ModelInfo:
        """Return the `_ModelInfo` captured when the entry was built."""
        return self.interfaces

    def load_model_cls(self) -> type[nn.Module]:
        """Return the stored model class."""
        return self.model_cls


@dataclass(frozen=True)
class _LazyRegisteredModel(_BaseRegisteredModel):
    """
    Represents a model that has not been imported in the main process.

    The model is identified by `module_name` and `class_name` and is only
    imported when `load_model_cls` is called. Inspection results are cached
    as JSON files under `<VLLM_CACHE_ROOT>/modelinfos`, tagged with an MD5
    digest of the model's source file so that a changed file invalidates
    the cached entry.
    """

    module_name: str
    class_name: str

    @staticmethod
    def _get_cache_dir() -> Path:
        """Return the directory that holds the cached model info files."""
        return Path(envs.VLLM_CACHE_ROOT) / "modelinfos"

    def _get_cache_filename(self) -> str:
        """Return the cache file name derived from the module and class names."""
        cls_name = f"{self.module_name}-{self.class_name}".replace(".", "-")
        return f"{cls_name}.json"

    def _load_modelinfo_from_cache(self, module_hash: str) -> _ModelInfo | None:
        """
        Return the cached `_ModelInfo` if its stored hash equals `module_hash`.

        Returns `None` when the cache file is missing, when its hash differs
        (stale entry), or when it cannot be read or decoded. The lookup is
        best-effort: no exception escapes this method.
        """
        try:
            modelinfo_path = self._get_cache_dir() / self._get_cache_filename()
            with open(modelinfo_path, encoding="utf-8") as file:
                mi_dict = json.load(file)

            if mi_dict["hash"] != module_hash:
                logger.debug(
                    ("Cached model info file for class %s.%s is stale"),
                    self.module_name,
                    self.class_name,
                )
                return None

            # file not changed, use cached _ModelInfo properties
            return _ModelInfo(**mi_dict["modelinfo"])
        except FileNotFoundError:
            logger.debug(
                ("Cached model info file for class %s.%s not found"),
                self.module_name,
                self.class_name,
            )
            return None
        except Exception:
            # Malformed JSON, missing keys, or fields that no longer match
            # `_ModelInfo`. Attach the traceback so the cause can be
            # diagnosed from the debug log instead of being lost.
            logger.debug(
                ("Cached model info for class %s.%s error. "),
                self.module_name,
                self.class_name,
                exc_info=True,
            )
            return None

    def _save_modelinfo_to_cache(self, mi: _ModelInfo, module_hash: str) -> None:
        """
        Write `mi` and `module_hash` to this model's JSON cache file.

        Failures are logged and swallowed; a failed save never propagates
        to the caller.
        """
        from vllm.model_executor.model_loader.weight_utils import atomic_writer

        try:
            modelinfo_dict = {
                "hash": module_hash,
                "modelinfo": asdict(mi),
            }
            cache_dir = self._get_cache_dir()
            cache_dir.mkdir(parents=True, exist_ok=True)
            modelinfo_path = cache_dir / self._get_cache_filename()
            with atomic_writer(modelinfo_path, encoding="utf-8") as f:
                json.dump(modelinfo_dict, f, indent=2)
        except Exception:
            logger.exception("Error saving model info cache.")

    @logtime(logger=logger, msg="Registry inspect model class")
    def inspect_model_cls(self) -> _ModelInfo:
        """
        Return the `_ModelInfo` for this model, using the on-disk cache
        when possible.

        The cache is only consulted (and later written) when a file named
        after the last component of `module_name` exists next to this file;
        its MD5 digest is the cache validity tag. On a miss, the model is
        inspected in a subprocess.
        """
        model_path = Path(__file__).parent / f"{self.module_name.split('.')[-1]}.py"
        module_hash = None

        if model_path.exists():
            with open(model_path, "rb") as f:
                module_hash = hashlib.md5(f.read(), usedforsecurity=False).hexdigest()

            mi = self._load_modelinfo_from_cache(module_hash)
            if mi is not None:
                logger.debug(
                    ("Loaded model info for class %s.%s from cache"),
                    self.module_name,
                    self.class_name,
                )
                return mi

            logger.debug(
                ("Cache model info for class %s.%s miss. Loading model instead."),
                self.module_name,
                self.class_name,
            )

        # Performed in another process to avoid initializing CUDA
        mi = _run_in_subprocess(
            lambda: _ModelInfo.from_model_cls(self.load_model_cls())
        )
        logger.debug(
            "Loaded model info for class %s.%s", self.module_name, self.class_name
        )

        # save cache file (only possible when the source file could be hashed)
        if module_hash is not None:
            self._save_modelinfo_to_cache(mi, module_hash)

        return mi

    def load_model_cls(self) -> type[nn.Module]:
        """Import `module_name` and return its `class_name` attribute."""
        mod = importlib.import_module(self.module_name)
        return getattr(mod, self.class_name)


@lru_cache(maxsize=128)
def _try_load_model_cls(
    model_arch: str,
    model: _BaseRegisteredModel,
) -> type[nn.Module] | None:
    """
    Load the model class for `model_arch`, or return `None` if loading fails.

    Results (including `None`) are memoized per `(model_arch, model)` pair
    by `lru_cache`.
    """
    from vllm.platforms import current_platform

    # Not wrapped in the try below, so anything the platform check raises
    # propagates to the caller.
    current_platform.verify_model_arch(model_arch)

    try:
        loaded_cls = model.load_model_cls()
    except Exception:
        logger.exception("Error in loading model architecture '%s'", model_arch)
        return None
    return loaded_cls
685
686


687
688
689
690
@lru_cache(maxsize=128)
def _try_inspect_model_cls(
    model_arch: str,
    model: _BaseRegisteredModel,
) -> _ModelInfo | None:
    """
    Inspect the model registered for `model_arch`, or return `None` if
    inspection fails.

    Results (including `None`) are memoized per `(model_arch, model)` pair
    by `lru_cache`.
    """
    try:
        model_info = model.inspect_model_cls()
    except Exception:
        logger.exception("Error in inspecting model architecture '%s'", model_arch)
        return None
    return model_info
697
698


699
700
701
@dataclass
class _ModelRegistry:
    """
    Mapping from model architecture names to registered model entries, plus
    the logic for resolving a model config to one of those entries.
    """

    # Keyed by model_arch
    models: dict[str, _BaseRegisteredModel] = field(default_factory=dict)

    def get_supported_archs(self) -> Set[str]:
        """Return a (live) view of all registered architecture names."""
        return self.models.keys()

    def register_model(
        self,
        model_arch: str,
        model_cls: type[nn.Module] | str,
    ) -> None:
        """
        Register an external model to be used in vLLM.

        `model_cls` can be either:

        - A [`torch.nn.Module`][] class directly referencing the model.
        - A string in the format `<module>:<class>` which can be used to
          lazily import the model. This is useful to avoid initializing CUDA
          when importing the model and thus the related error
          `RuntimeError: Cannot re-initialize CUDA in forked subprocess`.

        If `model_arch` is already registered, a warning is logged and the
        existing entry is overwritten.

        Raises:
            TypeError: If `model_arch` is not a string, or `model_cls` is
                neither a string nor a `torch.nn.Module` subclass.
            ValueError: If `model_cls` is a string that is not in the
                `<module>:<class>` format.
        """
        if not isinstance(model_arch, str):
            msg = f"`model_arch` should be a string, not a {type(model_arch)}"
            raise TypeError(msg)

        if model_arch in self.models:
            logger.warning(
                "Model architecture %s is already registered, and will be "
                "overwritten by the new model class %s.",
                model_arch,
                model_cls,
            )

        if isinstance(model_cls, str):
            split_str = model_cls.split(":")
            if len(split_str) != 2:
                msg = "Expected a string in the format `<module>:<class>`"
                raise ValueError(msg)

            model = _LazyRegisteredModel(*split_str)
        elif isinstance(model_cls, type) and issubclass(model_cls, nn.Module):
            model = _RegisteredModel.from_model_cls(model_cls)
        else:
            # Report the type of the offending argument (`model_cls`),
            # not the type of `model_arch`, which was validated above.
            msg = (
                "`model_cls` should be a string or PyTorch model class, "
                f"not a {type(model_cls)}"
            )
            raise TypeError(msg)

        self.models[model_arch] = model

    def _raise_for_unsupported(self, architectures: list[str]):
        """
        Raise a `ValueError` explaining why `architectures` cannot be used.

        This method never returns normally. The message distinguishes between:

        - an architecture that is registered (so the earlier attempt to
          inspect/load it must have failed),
        - an architecture listed in `_PREVIOUSLY_SUPPORTED_MODELS`,
        - an architecture that is simply unknown.
        """
        all_supported_archs = self.get_supported_archs()

        if any(arch in all_supported_archs for arch in architectures):
            raise ValueError(
                f"Model architectures {architectures} failed "
                "to be inspected. Please check the logs for more details."
            )

        for arch in architectures:
            if arch in _PREVIOUSLY_SUPPORTED_MODELS:
                previous_version = _PREVIOUSLY_SUPPORTED_MODELS[arch]

                raise ValueError(
                    f"Model architecture {arch} was supported in vLLM until "
                    f"v{previous_version}, and is not supported anymore. "
                    "Please use an older version of vLLM if you want to "
                    "use this model architecture."
                )

        raise ValueError(
            f"Model architectures {architectures} are not supported for now. "
            f"Supported architectures: {all_supported_archs}"
        )

    def _try_load_model_cls(self, model_arch: str) -> type[nn.Module] | None:
        """
        Load the class registered for `model_arch`.

        Returns `None` if the architecture is not registered or if loading
        fails (delegates to the cached module-level `_try_load_model_cls`).
        """
        if model_arch not in self.models:
            return None

        return _try_load_model_cls(model_arch, self.models[model_arch])

    def _try_inspect_model_cls(self, model_arch: str) -> _ModelInfo | None:
        """
        Inspect the model registered for `model_arch`.

        Returns `None` if the architecture is not registered or if inspection
        fails (delegates to the cached module-level `_try_inspect_model_cls`).
        """
        if model_arch not in self.models:
            return None

        return _try_inspect_model_cls(model_arch, self.models[model_arch])

    def _try_resolve_transformers(
        self,
        architecture: str,
        model_config: ModelConfig,
    ) -> str | None:
        """
        Resolve `architecture` to an architecture name that is served by the
        Transformers backend.

        Returns:
            - `architecture` itself if it is in `_TRANSFORMERS_BACKEND_MODELS`;
            - otherwise the value of
              `model_config._get_transformers_backend_cls()` once a compatible
              Transformers model class has been found;
            - `None` if no compatible class is found and
              `model_config.model_impl` is not `"transformers"`.

        Raises:
            ValueError: If no compatible class is found and
                `model_config.model_impl` is `"transformers"`.
        """
        if architecture in _TRANSFORMERS_BACKEND_MODELS:
            return architecture

        auto_map: dict[str, str] = (
            getattr(model_config.hf_config, "auto_map", None) or dict()
        )

        # Make sure that config class is always initialized before model class,
        # otherwise the model class won't be able to access the config class,
        # the expected auto_map should have correct order like:
        # "auto_map": {
        #     "AutoConfig": "<your-repo-name>--<config-name>",
        #     "AutoModel": "<your-repo-name>--<config-name>",
        #     "AutoModelFor<Task>": "<your-repo-name>--<config-name>",
        # },
        for prefix in ("AutoConfig", "AutoModel"):
            for name, module in auto_map.items():
                if name.startswith(prefix):
                    try_get_class_from_dynamic_module(
                        module,
                        model_config.model,
                        revision=model_config.revision,
                        warn_on_fail=False,
                    )

        model_module = getattr(transformers, architecture, None)

        if model_module is None:
            for name, module in auto_map.items():
                if name.startswith("AutoModel"):
                    model_module = try_get_class_from_dynamic_module(
                        module,
                        model_config.model,
                        revision=model_config.revision,
                        warn_on_fail=True,
                    )
                    if model_module is not None:
                        break
            else:
                # Reached only when the loop above never hit `break`, i.e.
                # no `AutoModel*` entry produced a usable class.
                if model_config.model_impl != "transformers":
                    return None

                raise ValueError(
                    f"Cannot find model module. {architecture!r} is not a "
                    "registered model in the Transformers library (only "
                    "relevant if the model is meant to be in Transformers) "
                    "and 'AutoModel' is not present in the model config's "
                    "'auto_map' (relevant if the model is custom)."
                )

        if not model_module.is_backend_compatible():
            if model_config.model_impl != "transformers":
                return None

            raise ValueError(
                f"The Transformers implementation of {architecture!r} "
                "is not compatible with vLLM."
            )

        return model_config._get_transformers_backend_cls()

    def _normalize_arch(
        self,
        architecture: str,
        model_config: ModelConfig,
    ) -> str:
        """
        Map `architecture` to a registered architecture name when possible.

        A registered name is returned unchanged. Otherwise, if the name
        matches one of the architecture defaults, the matched suffix is
        swapped for each known suffix in turn and the first registered result
        is returned. If nothing matches, `architecture` is returned unchanged.
        """
        if architecture in self.models:
            return architecture

        # This may be called in order to resolve runner_type and convert_type
        # in the first place, in which case we consider the default match
        match = try_match_architecture_defaults(
            architecture,
            runner_type=getattr(model_config, "runner_type", None),
            convert_type=getattr(model_config, "convert_type", None),
        )
        if match:
            suffix, _ = match

            # Get the name of the base model to convert
            for repl_suffix, _ in iter_architecture_defaults():
                base_arch = architecture.replace(suffix, repl_suffix)
                if base_arch in self.models:
                    return base_arch

        return architecture

    def inspect_model_cls(
        self,
        architectures: str | list[str],
        model_config: ModelConfig,
    ) -> tuple[_ModelInfo, str]:
        """
        Resolve `architectures` to a `(model_info, architecture)` pair.

        Candidates are tried in this order, returning the first one whose
        inspection succeeds:

        1. `model_impl == "transformers"`: the Transformers backend for the
           first architecture; `model_impl == "terratorch"`: the
           `"Terratorch"` entry.
        2. If no architecture is registered, `model_impl == "auto"` and
           `convert_type` is `"none"` (or unset): the Transformers backend.
        3. Each architecture in order, after `_normalize_arch`. The returned
           name is the original (un-normalized) architecture.
        4. If no architecture is registered and `model_impl == "auto"`:
           the Transformers backend.

        Raises:
            ValueError: If `architectures` is empty or nothing could be
                resolved.
        """
        if isinstance(architectures, str):
            architectures = [architectures]
        if not architectures:
            raise ValueError("No model architectures are specified")

        # Require transformers impl
        if model_config.model_impl == "transformers":
            arch = self._try_resolve_transformers(architectures[0], model_config)
            if arch is not None:
                model_info = self._try_inspect_model_cls(arch)
                if model_info is not None:
                    return (model_info, arch)
        elif model_config.model_impl == "terratorch":
            arch = "Terratorch"
            model_info = self._try_inspect_model_cls(arch)
            # Never hand back `None` as model info; fall through instead so
            # that failure ends in `_raise_for_unsupported`, mirroring
            # `resolve_model_cls`.
            if model_info is not None:
                return (model_info, arch)

        # Fallback to transformers impl (after resolving convert_type)
        if (
            all(arch not in self.models for arch in architectures)
            and model_config.model_impl == "auto"
            and getattr(model_config, "convert_type", "none") == "none"
        ):
            arch = self._try_resolve_transformers(architectures[0], model_config)
            if arch is not None:
                model_info = self._try_inspect_model_cls(arch)
                if model_info is not None:
                    return (model_info, arch)

        for arch in architectures:
            normalized_arch = self._normalize_arch(arch, model_config)
            model_info = self._try_inspect_model_cls(normalized_arch)
            if model_info is not None:
                return (model_info, arch)

        # Fallback to transformers impl (before resolving runner_type)
        if (
            all(arch not in self.models for arch in architectures)
            and model_config.model_impl == "auto"
        ):
            arch = self._try_resolve_transformers(architectures[0], model_config)
            if arch is not None:
                model_info = self._try_inspect_model_cls(arch)
                if model_info is not None:
                    return (model_info, arch)

        return self._raise_for_unsupported(architectures)

    def resolve_model_cls(
        self,
        architectures: str | list[str],
        model_config: ModelConfig,
    ) -> tuple[type[nn.Module], str]:
        """
        Resolve `architectures` to a `(model_cls, architecture)` pair.

        Follows the same candidate order as `inspect_model_cls`, but loads the
        model class instead of inspecting it.

        Raises:
            ValueError: If `architectures` is empty or nothing could be
                resolved.
        """
        if isinstance(architectures, str):
            architectures = [architectures]
        if not architectures:
            raise ValueError("No model architectures are specified")

        # Require transformers impl
        if model_config.model_impl == "transformers":
            arch = self._try_resolve_transformers(architectures[0], model_config)
            if arch is not None:
                model_cls = self._try_load_model_cls(arch)
                if model_cls is not None:
                    return (model_cls, arch)
        elif model_config.model_impl == "terratorch":
            arch = "Terratorch"
            model_cls = self._try_load_model_cls(arch)
            if model_cls is not None:
                return (model_cls, arch)

        # Fallback to transformers impl (after resolving convert_type)
        if (
            all(arch not in self.models for arch in architectures)
            and model_config.model_impl == "auto"
            and getattr(model_config, "convert_type", "none") == "none"
        ):
            arch = self._try_resolve_transformers(architectures[0], model_config)
            if arch is not None:
                model_cls = self._try_load_model_cls(arch)
                if model_cls is not None:
                    return (model_cls, arch)

        for arch in architectures:
            normalized_arch = self._normalize_arch(arch, model_config)
            model_cls = self._try_load_model_cls(normalized_arch)
            if model_cls is not None:
                return (model_cls, arch)

        # Fallback to transformers impl (before resolving runner_type)
        if (
            all(arch not in self.models for arch in architectures)
            and model_config.model_impl == "auto"
        ):
            arch = self._try_resolve_transformers(architectures[0], model_config)
            if arch is not None:
                model_cls = self._try_load_model_cls(arch)
                if model_cls is not None:
                    return (model_cls, arch)

        return self._raise_for_unsupported(architectures)

    # The helpers below all resolve the model via `inspect_model_cls` and
    # return a single attribute of the resulting model info.

    def is_text_generation_model(
        self,
        architectures: str | list[str],
        model_config: ModelConfig,
    ) -> bool:
        """Return the `is_text_generation_model` flag of the resolved model."""
        model_info, _ = self.inspect_model_cls(architectures, model_config)
        return model_info.is_text_generation_model

    def is_pooling_model(
        self,
        architectures: str | list[str],
        model_config: ModelConfig,
    ) -> bool:
        """Return the `is_pooling_model` flag of the resolved model."""
        model_info, _ = self.inspect_model_cls(architectures, model_config)
        return model_info.is_pooling_model

    def is_cross_encoder_model(
        self,
        architectures: str | list[str],
        model_config: ModelConfig,
    ) -> bool:
        """Return the `supports_cross_encoding` flag of the resolved model."""
        model_info, _ = self.inspect_model_cls(architectures, model_config)
        return model_info.supports_cross_encoding

    def is_multimodal_model(
        self,
        architectures: str | list[str],
        model_config: ModelConfig,
    ) -> bool:
        """Return the `supports_multimodal` flag of the resolved model."""
        model_info, _ = self.inspect_model_cls(architectures, model_config)
        return model_info.supports_multimodal

    def is_multimodal_raw_input_only_model(
        self,
        architectures: str | list[str],
        model_config: ModelConfig,
    ) -> bool:
        """
        Return the `supports_multimodal_raw_input_only` flag of the resolved
        model.
        """
        model_info, _ = self.inspect_model_cls(architectures, model_config)
        return model_info.supports_multimodal_raw_input_only

    def is_pp_supported_model(
        self,
        architectures: str | list[str],
        model_config: ModelConfig,
    ) -> bool:
        """Return the `supports_pp` flag of the resolved model."""
        model_info, _ = self.inspect_model_cls(architectures, model_config)
        return model_info.supports_pp

    def model_has_inner_state(
        self,
        architectures: str | list[str],
        model_config: ModelConfig,
    ) -> bool:
        """Return the `has_inner_state` flag of the resolved model."""
        model_info, _ = self.inspect_model_cls(architectures, model_config)
        return model_info.has_inner_state

    def is_attention_free_model(
        self,
        architectures: str | list[str],
        model_config: ModelConfig,
    ) -> bool:
        """Return the `is_attention_free` flag of the resolved model."""
        model_info, _ = self.inspect_model_cls(architectures, model_config)
        return model_info.is_attention_free

    def is_hybrid_model(
        self,
        architectures: str | list[str],
        model_config: ModelConfig,
    ) -> bool:
        """Return the `is_hybrid` flag of the resolved model."""
        model_info, _ = self.inspect_model_cls(architectures, model_config)
        return model_info.is_hybrid

    def is_noops_model(
        self,
        architectures: str | list[str],
        model_config: ModelConfig,
    ) -> bool:
        """Return the `has_noops` flag of the resolved model."""
        model_info, _ = self.inspect_model_cls(architectures, model_config)
        return model_info.has_noops

    def is_transcription_model(
        self,
        architectures: str | list[str],
        model_config: ModelConfig,
    ) -> bool:
        """Return the `supports_transcription` flag of the resolved model."""
        model_info, _ = self.inspect_model_cls(architectures, model_config)
        return model_info.supports_transcription

    def is_transcription_only_model(
        self,
        architectures: str | list[str],
        model_config: ModelConfig,
    ) -> bool:
        """
        Return the `supports_transcription_only` flag of the resolved model.
        """
        model_info, _ = self.inspect_model_cls(architectures, model_config)
        return model_info.supports_transcription_only

1084

1085
1086
1087
1088
1089
1090
1091
1092
1093
# Global registry instance, pre-populated with one lazily-imported entry per
# architecture listed in `_VLLM_MODELS`.
ModelRegistry = _ModelRegistry(
    models={
        arch_name: _LazyRegisteredModel(
            module_name=f"vllm.model_executor.models.{relative_module}",
            class_name=class_name,
        )
        for arch_name, (relative_module, class_name) in _VLLM_MODELS.items()
    }
)
1094
1095
1096
1097
1098

_T = TypeVar("_T")


def _run_in_subprocess(fn: Callable[[], _T]) -> _T:
    """
    Call `fn` in a separate process and return its result.

    `fn` and an output file path are serialized with `cloudpickle` and passed
    on stdin to the command in `_SUBPROCESS_COMMAND`. The child is expected to
    write the pickled result of `fn()` to that path, from which it is read
    back here.

    Args:
        fn: A zero-argument callable; it must be serializable by
            `cloudpickle` and its result must be picklable.

    Returns:
        The unpickled value found in the output file.

    Raises:
        RuntimeError: If the subprocess exits with a non-zero return code.
            The message contains the subprocess's stderr.
    """
    # `cloudpickle` allows pickling lambda functions directly
    import cloudpickle

    # NOTE: We use a temporary directory instead of a temporary file to avoid
    # issues like https://stackoverflow.com/questions/23212435/permission-denied-to-write-to-my-temporary-file
    with tempfile.TemporaryDirectory() as tempdir:
        output_filepath = os.path.join(tempdir, "registry_output.tmp")

        input_bytes = cloudpickle.dumps((fn, output_filepath))

        # cannot use `sys.executable __file__` here because the script
        # contains relative imports
        returned = subprocess.run(
            _SUBPROCESS_COMMAND, input=input_bytes, capture_output=True
        )

        # check if the subprocess is successful
        try:
            returned.check_returncode()
        except Exception as e:
            # Decode leniently: stderr is arbitrary bytes, and a
            # `UnicodeDecodeError` here would hide the real failure.
            stderr_text = returned.stderr.decode(errors="replace")

            # wrap raised exception to provide more information
            raise RuntimeError(
                f"Error raised in subprocess:\n{stderr_text}"
            ) from e

        # The file is produced by the subprocess launched above, not by an
        # external party, so unpickling it is considered trusted here.
        with open(output_filepath, "rb") as f:
            return pickle.load(f)


def _run() -> None:
    """
    Read a pickled `(fn, output_file)` pair from stdin, call `fn`, and write
    its pickled result to `output_file`.

    General plugins are loaded before the payload is unpickled.
    """
    # Setup plugins
    from vllm.plugins import load_general_plugins

    load_general_plugins()

    payload = sys.stdin.buffer.read()
    fn, output_file = pickle.loads(payload)

    # Run the callable before touching the output file, so that a failing
    # `fn` does not leave a file behind.
    outcome = fn()

    with open(output_file, "wb") as sink:
        pickle.dump(outcome, sink)
1140
1141
1142


if __name__ == "__main__":
    # Executed as a script: run the stdin-driven worker `_run`.
    _run()