registry.py 43.4 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
4
5
6
"""
Whenever you add an architecture to this file, please also update
`tests/models/registry.py` with example HuggingFace models for it.
"""
7

8
import hashlib
9
import importlib
10
import json
11
import os
12
import pickle
13
14
import subprocess
import sys
15
import tempfile
16
from abc import ABC, abstractmethod
17
from collections.abc import Callable, Set
18
from dataclasses import asdict, dataclass, field
19
from functools import lru_cache
20
from pathlib import Path
21
from typing import TypeVar
22
23

import torch.nn as nn
24
import transformers
25

26
from vllm import envs
27
28
29
30
31
from vllm.config import (
    ModelConfig,
    iter_architecture_defaults,
    try_match_architecture_defaults,
)
32
from vllm.logger import init_logger
33
from vllm.logging_utils import logtime
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
from vllm.transformers_utils.dynamic_module import try_get_class_from_dynamic_module

from .interfaces import (
    has_inner_state,
    has_noops,
    is_attention_free,
    is_hybrid,
    supports_cross_encoding,
    supports_multimodal,
    supports_multimodal_encoder_tp_data,
    supports_multimodal_raw_input_only,
    supports_pp,
    supports_transcription,
)
from .interfaces_base import (
    get_default_pooling_type,
    is_pooling_model,
    is_text_generation_model,
)
53
54
55

# Module-level logger used by the registry code in this file.
logger = init_logger(__name__)

56
57
# Maps an architecture name to the `(module, class)` pair implementing it.
# Several names may alias the same implementation (e.g. the Llama variants).
# Insertion order is preserved because other registries are derived from it.
_TEXT_GENERATION_MODELS = {
    # [Decoder-only]
    "ApertusForCausalLM": ("apertus", "ApertusForCausalLM"),
    "AquilaModel": ("llama", "LlamaForCausalLM"),
    "AquilaForCausalLM": ("llama", "LlamaForCausalLM"),  # AquilaChat2
    "ArceeForCausalLM": ("arcee", "ArceeForCausalLM"),
    "ArcticForCausalLM": ("arctic", "ArcticForCausalLM"),
    # baichuan-7b, upper case 'C' in the class name
    "BaiChuanForCausalLM": ("baichuan", "BaiChuanForCausalLM"),
    # baichuan-13b, lower case 'c' in the class name
    "BaichuanForCausalLM": ("baichuan", "BaichuanForCausalLM"),
    "BailingMoeForCausalLM": ("bailing_moe", "BailingMoeForCausalLM"),
    "BailingMoeV2ForCausalLM": ("bailing_moe", "BailingMoeV2ForCausalLM"),
    "BambaForCausalLM": ("bamba", "BambaForCausalLM"),
    "BloomForCausalLM": ("bloom", "BloomForCausalLM"),
    "ChatGLMModel": ("chatglm", "ChatGLMForCausalLM"),
    "ChatGLMForConditionalGeneration": ("chatglm", "ChatGLMForCausalLM"),
    "CohereForCausalLM": ("commandr", "CohereForCausalLM"),
    "Cohere2ForCausalLM": ("commandr", "CohereForCausalLM"),
    "CwmForCausalLM": ("llama", "LlamaForCausalLM"),
    "DbrxForCausalLM": ("dbrx", "DbrxForCausalLM"),
    "DeciLMForCausalLM": ("nemotron_nas", "DeciLMForCausalLM"),
    "DeepseekForCausalLM": ("deepseek", "DeepseekForCausalLM"),
    "DeepseekV2ForCausalLM": ("deepseek_v2", "DeepseekV2ForCausalLM"),
    "DeepseekV3ForCausalLM": ("deepseek_v2", "DeepseekV3ForCausalLM"),
    "DeepseekV32ForCausalLM": ("deepseek_v2", "DeepseekV3ForCausalLM"),
    "Dots1ForCausalLM": ("dots1", "Dots1ForCausalLM"),
    "Ernie4_5ForCausalLM": ("ernie45", "Ernie4_5ForCausalLM"),
    "Ernie4_5_MoeForCausalLM": ("ernie45_moe", "Ernie4_5_MoeForCausalLM"),
    "ExaoneForCausalLM": ("exaone", "ExaoneForCausalLM"),
    "Exaone4ForCausalLM": ("exaone4", "Exaone4ForCausalLM"),
    "Fairseq2LlamaForCausalLM": ("fairseq2_llama", "Fairseq2LlamaForCausalLM"),
    "FalconForCausalLM": ("falcon", "FalconForCausalLM"),
    "FalconMambaForCausalLM": ("mamba", "MambaForCausalLM"),
    "FalconH1ForCausalLM": ("falcon_h1", "FalconH1ForCausalLM"),
    "FlexOlmoForCausalLM": ("flex_olmo", "FlexOlmoForCausalLM"),
    "GemmaForCausalLM": ("gemma", "GemmaForCausalLM"),
    "Gemma2ForCausalLM": ("gemma2", "Gemma2ForCausalLM"),
    "Gemma3ForCausalLM": ("gemma3", "Gemma3ForCausalLM"),
    "Gemma3nForCausalLM": ("gemma3n", "Gemma3nForCausalLM"),
    "Qwen3NextForCausalLM": ("qwen3_next", "Qwen3NextForCausalLM"),
    "GlmForCausalLM": ("glm", "GlmForCausalLM"),
    "Glm4ForCausalLM": ("glm4", "Glm4ForCausalLM"),
    "Glm4MoeForCausalLM": ("glm4_moe", "Glm4MoeForCausalLM"),
    "GptOssForCausalLM": ("gpt_oss", "GptOssForCausalLM"),
    "GPT2LMHeadModel": ("gpt2", "GPT2LMHeadModel"),
    "GPTBigCodeForCausalLM": ("gpt_bigcode", "GPTBigCodeForCausalLM"),
    "GPTJForCausalLM": ("gpt_j", "GPTJForCausalLM"),
    "GPTNeoXForCausalLM": ("gpt_neox", "GPTNeoXForCausalLM"),
    "GraniteForCausalLM": ("granite", "GraniteForCausalLM"),
    "GraniteMoeForCausalLM": ("granitemoe", "GraniteMoeForCausalLM"),
    "GraniteMoeHybridForCausalLM": ("granitemoehybrid", "GraniteMoeHybridForCausalLM"),  # noqa: E501
    "GraniteMoeSharedForCausalLM": ("granitemoeshared", "GraniteMoeSharedForCausalLM"),  # noqa: E501
    "GritLM": ("gritlm", "GritLM"),
    "Grok1ModelForCausalLM": ("grok1", "Grok1ForCausalLM"),
    "HunYuanMoEV1ForCausalLM": ("hunyuan_v1", "HunYuanMoEV1ForCausalLM"),
    "HunYuanDenseV1ForCausalLM": ("hunyuan_v1", "HunYuanDenseV1ForCausalLM"),
    "HCXVisionForCausalLM": ("hyperclovax_vision", "HCXVisionForCausalLM"),
    "InternLMForCausalLM": ("llama", "LlamaForCausalLM"),
    "InternLM2ForCausalLM": ("internlm2", "InternLM2ForCausalLM"),
    "InternLM2VEForCausalLM": ("internlm2_ve", "InternLM2VEForCausalLM"),
    "InternLM3ForCausalLM": ("llama", "LlamaForCausalLM"),
    "JAISLMHeadModel": ("jais", "JAISLMHeadModel"),
    "JambaForCausalLM": ("jamba", "JambaForCausalLM"),
    "Lfm2ForCausalLM": ("lfm2", "Lfm2ForCausalLM"),
    "Lfm2MoeForCausalLM": ("lfm2_moe", "Lfm2MoeForCausalLM"),
    "LlamaForCausalLM": ("llama", "LlamaForCausalLM"),
    "Llama4ForCausalLM": ("llama4", "Llama4ForCausalLM"),
    # For decapoda-research/llama-*
    "LLaMAForCausalLM": ("llama", "LlamaForCausalLM"),
    "LongcatFlashForCausalLM": ("longcat_flash", "LongcatFlashForCausalLM"),
    "MambaForCausalLM": ("mamba", "MambaForCausalLM"),
    "Mamba2ForCausalLM": ("mamba2", "Mamba2ForCausalLM"),
    "MiniCPMForCausalLM": ("minicpm", "MiniCPMForCausalLM"),
    "MiniCPM3ForCausalLM": ("minicpm3", "MiniCPM3ForCausalLM"),
    "MiniMaxForCausalLM": ("minimax_text_01", "MiniMaxText01ForCausalLM"),
    "MiniMaxText01ForCausalLM": ("minimax_text_01", "MiniMaxText01ForCausalLM"),
    "MiniMaxM1ForCausalLM": ("minimax_text_01", "MiniMaxText01ForCausalLM"),
    "MistralForCausalLM": ("llama", "LlamaForCausalLM"),
    "MixtralForCausalLM": ("mixtral", "MixtralForCausalLM"),
    # transformers's mpt class has lower case
    "MptForCausalLM": ("mpt", "MPTForCausalLM"),
    "MPTForCausalLM": ("mpt", "MPTForCausalLM"),
    "MiMoForCausalLM": ("mimo", "MiMoForCausalLM"),
    "NemotronForCausalLM": ("nemotron", "NemotronForCausalLM"),
    "NemotronHForCausalLM": ("nemotron_h", "NemotronHForCausalLM"),
    "OlmoForCausalLM": ("olmo", "OlmoForCausalLM"),
    "Olmo2ForCausalLM": ("olmo2", "Olmo2ForCausalLM"),
    "Olmo3ForCausalLM": ("olmo2", "Olmo2ForCausalLM"),
    "OlmoeForCausalLM": ("olmoe", "OlmoeForCausalLM"),
    "OPTForCausalLM": ("opt", "OPTForCausalLM"),
    "OrionForCausalLM": ("orion", "OrionForCausalLM"),
    "PersimmonForCausalLM": ("persimmon", "PersimmonForCausalLM"),
    "PhiForCausalLM": ("phi", "PhiForCausalLM"),
    "Phi3ForCausalLM": ("phi3", "Phi3ForCausalLM"),
    "PhiMoEForCausalLM": ("phimoe", "PhiMoEForCausalLM"),
    "Plamo2ForCausalLM": ("plamo2", "Plamo2ForCausalLM"),
    "QWenLMHeadModel": ("qwen", "QWenLMHeadModel"),
    "Qwen2ForCausalLM": ("qwen2", "Qwen2ForCausalLM"),
    "Qwen2MoeForCausalLM": ("qwen2_moe", "Qwen2MoeForCausalLM"),
    "Qwen3ForCausalLM": ("qwen3", "Qwen3ForCausalLM"),
    "Qwen3MoeForCausalLM": ("qwen3_moe", "Qwen3MoeForCausalLM"),
    "RWForCausalLM": ("falcon", "FalconForCausalLM"),
    "SeedOssForCausalLM": ("seed_oss", "SeedOssForCausalLM"),
    "Step3TextForCausalLM": ("step3_text", "Step3TextForCausalLM"),
    "StableLMEpochForCausalLM": ("stablelm", "StablelmForCausalLM"),
    "StableLmForCausalLM": ("stablelm", "StablelmForCausalLM"),
    "Starcoder2ForCausalLM": ("starcoder2", "Starcoder2ForCausalLM"),
    "SolarForCausalLM": ("solar", "SolarForCausalLM"),
    "TeleChat2ForCausalLM": ("telechat2", "TeleChat2ForCausalLM"),
    "TeleFLMForCausalLM": ("teleflm", "TeleFLMForCausalLM"),
    "XverseForCausalLM": ("llama", "LlamaForCausalLM"),
    "Zamba2ForCausalLM": ("zamba2", "Zamba2ForCausalLM"),
}

# Architectures usable as pooling/embedding models, mapped to the
# `(module, class)` pair that implements them.
_EMBEDDING_MODELS = {
    # [Text-only]
    "BertModel": ("bert", "BertEmbeddingModel"),
    "BertSpladeSparseEmbeddingModel": ("bert", "BertSpladeSparseEmbeddingModel"),
    "DeciLMForCausalLM": ("nemotron_nas", "DeciLMForCausalLM"),
    "Gemma2Model": ("gemma2", "Gemma2ForCausalLM"),
    "Gemma3TextModel": ("gemma3", "Gemma3Model"),
    "GlmForCausalLM": ("glm", "GlmForCausalLM"),
    "GPT2ForSequenceClassification": ("gpt2", "GPT2ForSequenceClassification"),
    "GritLM": ("gritlm", "GritLM"),
    "GteModel": ("bert_with_rope", "SnowflakeGteNewModel"),
    "GteNewModel": ("bert_with_rope", "GteNewModel"),
    "InternLM2ForRewardModel": ("internlm2", "InternLM2ForRewardModel"),
    "JambaForSequenceClassification": ("jamba", "JambaForSequenceClassification"),  # noqa: E501
    "LlamaModel": ("llama", "LlamaForCausalLM"),
    **{
        # Multiple models share the same architecture, so we include them all
        k: (mod, arch)
        for k, (mod, arch) in _TEXT_GENERATION_MODELS.items()
        if arch == "LlamaForCausalLM"
    },
    "MistralModel": ("llama", "LlamaForCausalLM"),
    "ModernBertModel": ("modernbert", "ModernBertModel"),
    "NomicBertModel": ("bert_with_rope", "NomicBertModel"),
    "Phi3ForCausalLM": ("phi3", "Phi3ForCausalLM"),
    "Qwen2Model": ("qwen2", "Qwen2ForCausalLM"),
    "Qwen2ForCausalLM": ("qwen2", "Qwen2ForCausalLM"),
    "Qwen2ForRewardModel": ("qwen2_rm", "Qwen2ForRewardModel"),
    "Qwen2ForProcessRewardModel": ("qwen2_rm", "Qwen2ForProcessRewardModel"),
    "RobertaForMaskedLM": ("roberta", "RobertaEmbeddingModel"),
    "RobertaModel": ("roberta", "RobertaEmbeddingModel"),
    "TeleChat2ForCausalLM": ("telechat2", "TeleChat2ForCausalLM"),
    "XLMRobertaModel": ("roberta", "RobertaEmbeddingModel"),
    # [Multimodal]
    "CLIPModel": ("clip", "CLIPEmbeddingModel"),
    "LlavaNextForConditionalGeneration": (
        "llava_next",
        "LlavaNextForConditionalGeneration",
    ),
    "Phi3VForCausalLM": ("phi3v", "Phi3VForCausalLM"),
    "Qwen2VLForConditionalGeneration": ("qwen2_vl", "Qwen2VLForConditionalGeneration"),  # noqa: E501
    # Technically Terratorch models work on images, both in input and output.
    # They are listed here because they piggy-back on embedding models for
    # the time being.
    "PrithviGeoSpatialMAE": ("terratorch", "Terratorch"),
    "Terratorch": ("terratorch", "Terratorch"),
}

219
220
# Architectures that score (query, document) pairs, mapped to the
# `(module, class)` pair that implements them.
_CROSS_ENCODER_MODELS = {
    "BertForSequenceClassification": ("bert", "BertForSequenceClassification"),
    "BertForTokenClassification": ("bert", "BertForTokenClassification"),
    "GteNewForSequenceClassification": (
        "bert_with_rope",
        "GteNewForSequenceClassification",
    ),
    "ModernBertForSequenceClassification": (
        "modernbert",
        "ModernBertForSequenceClassification",
    ),
    "ModernBertForTokenClassification": (
        "modernbert",
        "ModernBertForTokenClassification",
    ),
    "RobertaForSequenceClassification": ("roberta", "RobertaForSequenceClassification"),
    "XLMRobertaForSequenceClassification": (
        "roberta",
        "RobertaForSequenceClassification",
    ),
    # [Auto-converted (see adapters.py)]
    "JinaVLForRanking": ("jina_vl", "JinaVLForSequenceClassification"),
}

243
# Architectures accepting non-text inputs (images, video, audio), mapped to
# the `(module, class)` pair that implements them.
_MULTIMODAL_MODELS = {
    # [Decoder-only]
    "AriaForConditionalGeneration": ("aria", "AriaForConditionalGeneration"),
    "AyaVisionForConditionalGeneration": (
        "aya_vision",
        "AyaVisionForConditionalGeneration",
    ),
    "Blip2ForConditionalGeneration": ("blip2", "Blip2ForConditionalGeneration"),
    "ChameleonForConditionalGeneration": (
        "chameleon",
        "ChameleonForConditionalGeneration",
    ),
    "Cohere2VisionForConditionalGeneration": (
        "cohere2_vision",
        "Cohere2VisionForConditionalGeneration",
    ),
    "DeepseekVLV2ForCausalLM": ("deepseek_vl2", "DeepseekVLV2ForCausalLM"),
    "DotsOCRForCausalLM": ("dots_ocr", "DotsOCRForCausalLM"),
    "Ernie4_5_VLMoeForConditionalGeneration": (
        "ernie45_vl",
        "Ernie4_5_VLMoeForConditionalGeneration",
    ),
    "FuyuForCausalLM": ("fuyu", "FuyuForCausalLM"),
    "Gemma3nForConditionalGeneration": (
        "gemma3n_mm",
        "Gemma3nForConditionalGeneration",
    ),
    "GLM4VForCausalLM": ("glm4v", "GLM4VForCausalLM"),
    "Glm4vForConditionalGeneration": ("glm4_1v", "Glm4vForConditionalGeneration"),  # noqa: E501
    "Glm4vMoeForConditionalGeneration": ("glm4_1v", "Glm4vMoeForConditionalGeneration"),  # noqa: E501
    "GraniteSpeechForConditionalGeneration": (
        "granite_speech",
        "GraniteSpeechForConditionalGeneration",
    ),
    "H2OVLChatModel": ("h2ovl", "H2OVLChatModel"),
    "InternVLChatModel": ("internvl", "InternVLChatModel"),
    "NemotronH_Nano_VL_V2": ("nano_nemotron_vl", "NemotronH_Nano_VL_V2"),
    "InternS1ForConditionalGeneration": (
        "interns1",
        "InternS1ForConditionalGeneration",
    ),
    "InternVLForConditionalGeneration": (
        "interns1",
        "InternS1ForConditionalGeneration",
    ),
    "Idefics3ForConditionalGeneration": (
        "idefics3",
        "Idefics3ForConditionalGeneration",
    ),
    "SmolVLMForConditionalGeneration": ("smolvlm", "SmolVLMForConditionalGeneration"),  # noqa: E501
    "KeyeForConditionalGeneration": ("keye", "KeyeForConditionalGeneration"),
    "KeyeVL1_5ForConditionalGeneration": (
        "keye_vl1_5",
        "KeyeVL1_5ForConditionalGeneration",
    ),
    "RForConditionalGeneration": ("rvl", "RForConditionalGeneration"),
    "KimiVLForConditionalGeneration": ("kimi_vl", "KimiVLForConditionalGeneration"),  # noqa: E501
    "Llama_Nemotron_Nano_VL": ("nemotron_vl", "LlamaNemotronVLChatModel"),
    "Llama4ForConditionalGeneration": ("mllama4", "Llama4ForConditionalGeneration"),  # noqa: E501
    "LlavaForConditionalGeneration": ("llava", "LlavaForConditionalGeneration"),
    "LlavaNextForConditionalGeneration": (
        "llava_next",
        "LlavaNextForConditionalGeneration",
    ),
    "LlavaNextVideoForConditionalGeneration": (
        "llava_next_video",
        "LlavaNextVideoForConditionalGeneration",
    ),
    "LlavaOnevisionForConditionalGeneration": (
        "llava_onevision",
        "LlavaOnevisionForConditionalGeneration",
    ),
    "MantisForConditionalGeneration": ("llava", "MantisForConditionalGeneration"),  # noqa: E501
    "MiDashengLMModel": ("midashenglm", "MiDashengLMModel"),
    "MiniMaxVL01ForConditionalGeneration": (
        "minimax_vl_01",
        "MiniMaxVL01ForConditionalGeneration",
    ),
    "MiniCPMO": ("minicpmo", "MiniCPMO"),
    "MiniCPMV": ("minicpmv", "MiniCPMV"),
    "Mistral3ForConditionalGeneration": (
        "mistral3",
        "Mistral3ForConditionalGeneration",
    ),
    "MolmoForCausalLM": ("molmo", "MolmoForCausalLM"),
    "NVLM_D": ("nvlm_d", "NVLM_D_Model"),
    "Ovis": ("ovis", "Ovis"),
    "Ovis2_5": ("ovis2_5", "Ovis2_5"),
    "Phi3VForCausalLM": ("phi3v", "Phi3VForCausalLM"),
    "Phi4MMForCausalLM": ("phi4mm", "Phi4MMForCausalLM"),
    "Phi4MultimodalForCausalLM": ("phi4_multimodal", "Phi4MultimodalForCausalLM"),  # noqa: E501
    "PixtralForConditionalGeneration": ("pixtral", "PixtralForConditionalGeneration"),  # noqa: E501
    "QwenVLForConditionalGeneration": ("qwen_vl", "QwenVLForConditionalGeneration"),  # noqa: E501
    "Qwen2VLForConditionalGeneration": ("qwen2_vl", "Qwen2VLForConditionalGeneration"),  # noqa: E501
    "Qwen2_5_VLForConditionalGeneration": (
        "qwen2_5_vl",
        "Qwen2_5_VLForConditionalGeneration",
    ),
    "Qwen2AudioForConditionalGeneration": (
        "qwen2_audio",
        "Qwen2AudioForConditionalGeneration",
    ),
    "Qwen2_5OmniModel": (
        "qwen2_5_omni_thinker",
        "Qwen2_5OmniThinkerForConditionalGeneration",
    ),
    "Qwen2_5OmniForConditionalGeneration": (
        "qwen2_5_omni_thinker",
        "Qwen2_5OmniThinkerForConditionalGeneration",
    ),
    "Qwen3OmniMoeForConditionalGeneration": (
        "qwen3_omni_moe_thinker",
        "Qwen3OmniMoeThinkerForConditionalGeneration",
    ),
    "Qwen3VLForConditionalGeneration": ("qwen3_vl", "Qwen3VLForConditionalGeneration"),  # noqa: E501
    "Qwen3VLMoeForConditionalGeneration": (
        "qwen3_vl_moe",
        "Qwen3VLMoeForConditionalGeneration",
    ),
    "SkyworkR1VChatModel": ("skyworkr1v", "SkyworkR1VChatModel"),
    "Step3VLForConditionalGeneration": ("step3_vl", "Step3VLForConditionalGeneration"),  # noqa: E501
    "TarsierForConditionalGeneration": ("tarsier", "TarsierForConditionalGeneration"),  # noqa: E501
    "Tarsier2ForConditionalGeneration": (
        "qwen2_vl",
        "Tarsier2ForConditionalGeneration",
    ),
    "UltravoxModel": ("ultravox", "UltravoxModel"),
    "VoxtralForConditionalGeneration": ("voxtral", "VoxtralForConditionalGeneration"),  # noqa: E501
    # [Encoder-decoder]
    "WhisperForConditionalGeneration": ("whisper", "WhisperForConditionalGeneration"),  # noqa: E501
}
374
375

# Draft / multi-token-prediction architectures used for speculative decoding,
# mapped to the `(module, class)` pair that implements them.
_SPECULATIVE_DECODING_MODELS = {
    "MiMoMTPModel": ("mimo_mtp", "MiMoMTP"),
    "EagleLlamaForCausalLM": ("llama_eagle", "EagleLlamaForCausalLM"),
    "EagleLlama4ForCausalLM": ("llama4_eagle", "EagleLlama4ForCausalLM"),
    "EagleMiniCPMForCausalLM": ("minicpm_eagle", "EagleMiniCPMForCausalLM"),
    "Eagle3LlamaForCausalLM": ("llama_eagle3", "Eagle3LlamaForCausalLM"),
    "LlamaForCausalLMEagle3": ("llama_eagle3", "Eagle3LlamaForCausalLM"),
    "Eagle3Qwen2_5vlForCausalLM": ("llama_eagle3", "Eagle3LlamaForCausalLM"),
    "EagleDeepSeekMTPModel": ("deepseek_eagle", "EagleDeepseekV3ForCausalLM"),
    "DeepSeekMTPModel": ("deepseek_mtp", "DeepSeekMTP"),
    "ErnieMTPModel": ("ernie_mtp", "ErnieMTP"),
    "LongCatFlashMTPModel": ("longcat_flash_mtp", "LongCatFlashMTP"),
    "Glm4MoeMTPModel": ("glm4_moe_mtp", "Glm4MoeMTP"),
    "MedusaModel": ("medusa", "Medusa"),
    "Qwen3NextMTP": ("qwen3_next_mtp", "Qwen3NextMTP"),
    # Temporarily disabled.
    # # TODO(woosuk): Re-enable this once the MLP Speculator is supported in V1.
    # "MLPSpeculatorPreTrainedModel": ("mlp_speculator", "MLPSpeculator"),
}
394

395
# Architectures without a native implementation here that are served through
# the Transformers backend classes.
_TRANSFORMERS_SUPPORTED_MODELS = {
    # Text generation models
    "SmolLM3ForCausalLM": ("transformers", "TransformersForCausalLM"),
    # Multimodal models
    "Emu3ForConditionalGeneration": (
        "transformers",
        "TransformersMultiModalForCausalLM",
    ),
    "Gemma3ForConditionalGeneration": (
        "transformers",
        "TransformersMultiModalForCausalLM",
    ),
    "PaliGemmaForConditionalGeneration": (
        "transformers",
        "TransformersMultiModalForCausalLM",
    ),
}

# The Transformers backend classes themselves, registered under their own
# names so they can also be selected directly.
_TRANSFORMERS_BACKEND_MODELS = {
    # Text generation models
    "TransformersForCausalLM": ("transformers", "TransformersForCausalLM"),
    "TransformersMoEForCausalLM": ("transformers", "TransformersMoEForCausalLM"),
    # Multimodal models
    "TransformersMultiModalForCausalLM": (
        "transformers",
        "TransformersMultiModalForCausalLM",
    ),
    "TransformersMultiModalMoEForCausalLM": (
        "transformers",
        "TransformersMultiModalMoEForCausalLM",
    ),
    # Embedding models
    "TransformersEmbeddingModel": ("transformers", "TransformersEmbeddingModel"),
    "TransformersMoEEmbeddingModel": ("transformers", "TransformersMoEEmbeddingModel"),
    "TransformersMultiModalEmbeddingModel": (
        "transformers",
        "TransformersMultiModalEmbeddingModel",
    ),
    # Sequence classification models
    "TransformersForSequenceClassification": (
        "transformers",
        "TransformersForSequenceClassification",
    ),
    "TransformersMoEForSequenceClassification": (
        "transformers",
        "TransformersMoEForSequenceClassification",
    ),
    "TransformersMultiModalForSequenceClassification": (
        "transformers",
        "TransformersMultiModalForSequenceClassification",
    ),
}
447

448
# Union of every built-in registry. When the same architecture name appears in
# more than one mapping, the entry from the later mapping below wins.
_VLLM_MODELS = {
    **_TEXT_GENERATION_MODELS,
    **_EMBEDDING_MODELS,
    **_CROSS_ENCODER_MODELS,
    **_MULTIMODAL_MODELS,
    **_SPECULATIVE_DECODING_MODELS,
    **_TRANSFORMERS_SUPPORTED_MODELS,
    **_TRANSFORMERS_BACKEND_MODELS,
}

458
459
460
461
# This variable is used as the args for subprocess.run(). We
# can modify this variable to alter the args if needed. e.g.
# when we use par format to pack things together, sys.executable
# might not be the target we want to run.
_SUBPROCESS_COMMAND = [sys.executable, "-m", "vllm.model_executor.models.registry"]
463

464
# Architectures that are no longer available, mapped to a vLLM version string
# (presumably the last release that still supported them — confirm with the
# code that reports it).
_PREVIOUSLY_SUPPORTED_MODELS = {
    "MotifForCausalLM": "0.10.2",
    "Phi3SmallForCausalLM": "0.9.2",
    "Phi4FlashForCausalLM": "0.10.2",
    # encoder-decoder models except whisper
    # have been removed for V0 deprecation.
    "BartModel": "0.10.2",
    "BartForConditionalGeneration": "0.10.2",
    "DonutForConditionalGeneration": "0.10.2",
    "Florence2ForConditionalGeneration": "0.10.2",
    "MBartForConditionalGeneration": "0.10.2",
    "MllamaForConditionalGeneration": "0.10.2",
}
477

478

479
480
@dataclass(frozen=True)
class _ModelInfo:
    """Capability flags of a model class, as reported by the interface checks.

    This is plain, immutable data, so it can be compared, hashed, and
    round-tripped through `dataclasses.asdict` / keyword construction (which
    the JSON model-info cache in this file relies on).
    """

    architecture: str
    is_text_generation_model: bool
    is_pooling_model: bool
    default_pooling_type: str
    supports_cross_encoding: bool
    supports_multimodal: bool
    supports_multimodal_raw_input_only: bool
    supports_multimodal_encoder_tp_data: bool
    supports_pp: bool
    has_inner_state: bool
    is_attention_free: bool
    is_hybrid: bool
    has_noops: bool
    supports_transcription: bool
    supports_transcription_only: bool

    @staticmethod
    def from_model_cls(model: type[nn.Module]) -> "_ModelInfo":
        """Build a `_ModelInfo` by probing `model` with the interface checks.

        Args:
            model: The model class to inspect.

        Returns:
            The capability flags of `model`.
        """
        # Computed once: it is both reported directly and guards the
        # transcription-only attribute lookup below.
        transcription = supports_transcription(model)
        return _ModelInfo(
            architecture=model.__name__,
            is_text_generation_model=is_text_generation_model(model),
            is_pooling_model=is_pooling_model(model),
            default_pooling_type=get_default_pooling_type(model),
            supports_cross_encoding=supports_cross_encoding(model),
            supports_multimodal=supports_multimodal(model),
            supports_multimodal_raw_input_only=supports_multimodal_raw_input_only(
                model
            ),
            supports_multimodal_encoder_tp_data=supports_multimodal_encoder_tp_data(
                model
            ),
            supports_pp=supports_pp(model),
            has_inner_state=has_inner_state(model),
            is_attention_free=is_attention_free(model),
            is_hybrid=is_hybrid(model),
            supports_transcription=transcription,
            # Only read the class attribute once the interface is confirmed.
            supports_transcription_only=(
                transcription and model.supports_transcription_only
            ),
            has_noops=has_noops(model),
        )
522
523


524
525
526
527
class _BaseRegisteredModel(ABC):
    """Common interface for the entries stored in the model registry."""

    @abstractmethod
    def inspect_model_cls(self) -> _ModelInfo:
        """Return the capability flags of the registered model class."""
        raise NotImplementedError

    @abstractmethod
    def load_model_cls(self) -> type[nn.Module]:
        """Return the registered model class, importing it if necessary."""
        raise NotImplementedError
532
533


534
535
536
537
538
539
540
@dataclass(frozen=True)
class _RegisteredModel(_BaseRegisteredModel):
    """
    Represents a model that has already been imported in the main process.

    Its `_ModelInfo` is computed eagerly when the entry is created and then
    returned as-is by `inspect_model_cls`.
    """

    interfaces: _ModelInfo
    model_cls: type[nn.Module]

    @staticmethod
    def from_model_cls(model_cls: type[nn.Module]) -> "_RegisteredModel":
        """Wrap `model_cls`, inspecting its interfaces immediately."""
        return _RegisteredModel(
            interfaces=_ModelInfo.from_model_cls(model_cls),
            model_cls=model_cls,
        )

    def inspect_model_cls(self) -> _ModelInfo:
        return self.interfaces

    def load_model_cls(self) -> type[nn.Module]:
        return self.model_cls


@dataclass(frozen=True)
class _LazyRegisteredModel(_BaseRegisteredModel):
    """
    Represents a model that has not been imported in the main process.

    The class is only imported by `load_model_cls`. `inspect_model_cls`
    gathers the model's `_ModelInfo` in a subprocess. For model modules that
    live in this package, it caches the result as JSON keyed by the MD5 hash
    of the module's source file.
    """

    module_name: str
    class_name: str

    @staticmethod
    def _get_cache_dir() -> Path:
        return Path(envs.VLLM_CACHE_ROOT) / "modelinfos"

    def _get_cache_filename(self) -> str:
        # Dots are replaced so the dotted module path yields a flat file name.
        cls_name = f"{self.module_name}-{self.class_name}".replace(".", "-")
        return f"{cls_name}.json"

    def _get_in_tree_module_path(self) -> Path | None:
        """Return the source file of `module_name` if it is an in-tree module.

        Only modules that are direct children of this file's package map to a
        file next to this one. An out-of-tree module whose last component
        merely matches an in-tree file name (e.g. `my_plugin.llama`) gets
        `None`. Otherwise its cache would be validated against an unrelated
        in-tree file and would never be invalidated when the module changes.
        """
        package, _, stem = self.module_name.rpartition(".")
        if not package or package != __package__:
            return None
        return Path(__file__).parent / f"{stem}.py"

    def _load_modelinfo_from_cache(self, module_hash: str) -> _ModelInfo | None:
        """Return the cached `_ModelInfo` for this model, or `None` on a miss.

        A miss is a missing cache file, a file recorded for a different
        `module_hash`, or any error while reading or decoding it. The cache is
        best-effort: failures are logged at debug level and never propagate.
        """
        try:
            try:
                modelinfo_path = self._get_cache_dir() / self._get_cache_filename()
                with open(modelinfo_path, encoding="utf-8") as file:
                    mi_dict = json.load(file)
            except FileNotFoundError:
                logger.debug(
                    "Cached model info file for class %s.%s not found",
                    self.module_name,
                    self.class_name,
                )
                return None

            if mi_dict["hash"] != module_hash:
                logger.debug(
                    "Cached model info file for class %s.%s is stale",
                    self.module_name,
                    self.class_name,
                )
                return None

            # Source file unchanged since the cache was written; reuse it.
            return _ModelInfo(**mi_dict["modelinfo"])
        except Exception:
            logger.debug(
                "Cached model info for class %s.%s error. ",
                self.module_name,
                self.class_name,
                exc_info=True,
            )
            return None

    def _save_modelinfo_to_cache(self, mi: _ModelInfo, module_hash: str) -> None:
        """Save `mi`, tagged with `module_hash`, as a JSON file in the cache."""
        from vllm.model_executor.model_loader.weight_utils import atomic_writer

        try:
            modelinfo_dict = {
                "hash": module_hash,
                "modelinfo": asdict(mi),
            }
            cache_dir = self._get_cache_dir()
            cache_dir.mkdir(parents=True, exist_ok=True)
            modelinfo_path = cache_dir / self._get_cache_filename()
            with atomic_writer(modelinfo_path, encoding="utf-8") as f:
                json.dump(modelinfo_dict, f, indent=2)
        except Exception:
            # Caching is an optimization; a failed write must not break loading.
            logger.exception("Error saving model info cache.")

    @logtime(logger=logger, msg="Registry inspect model class")
    def inspect_model_cls(self) -> _ModelInfo:
        """Return the model's `_ModelInfo`, consulting the on-disk cache first.

        Only in-tree model modules are hashed, so only they are read from or
        written to the cache. Every other model is inspected afresh.
        """
        model_path = self._get_in_tree_module_path()
        module_hash = None

        if model_path is not None and model_path.exists():
            with open(model_path, "rb") as f:
                module_hash = hashlib.md5(f.read(), usedforsecurity=False).hexdigest()

            mi = self._load_modelinfo_from_cache(module_hash)
            if mi is not None:
                logger.debug(
                    "Loaded model info for class %s.%s from cache",
                    self.module_name,
                    self.class_name,
                )
                return mi

            logger.debug(
                "Cache model info for class %s.%s miss. Loading model instead.",
                self.module_name,
                self.class_name,
            )

        # Performed in another process to avoid initializing CUDA
        mi = _run_in_subprocess(
            lambda: _ModelInfo.from_model_cls(self.load_model_cls())
        )
        logger.debug(
            "Loaded model info for class %s.%s", self.module_name, self.class_name
        )

        # Persist for next time; only hashed (in-tree) modules are cacheable.
        if module_hash is not None:
            self._save_modelinfo_to_cache(mi, module_hash)

        return mi

    def load_model_cls(self) -> type[nn.Module]:
        """Import `module_name` and return its `class_name` attribute."""
        mod = importlib.import_module(self.module_name)
        return getattr(mod, self.class_name)


@lru_cache(maxsize=128)
def _try_load_model_cls(
    model_arch: str,
    model: _BaseRegisteredModel,
) -> type[nn.Module] | None:
    """Load the model class registered for `model_arch`.

    Returns `None` after logging the traceback if loading fails. The platform
    check runs outside the `try`, so anything it raises propagates to the
    caller. Results, including `None`, are memoized per `(model_arch, model)`.
    """
    from vllm.platforms import current_platform

    current_platform.verify_model_arch(model_arch)
    try:
        return model.load_model_cls()
    except Exception:
        logger.exception("Error in loading model architecture '%s'", model_arch)
        return None
679
680


681
682
683
684
@lru_cache(maxsize=128)
def _try_inspect_model_cls(
    model_arch: str,
    model: _BaseRegisteredModel,
) -> _ModelInfo | None:
    """Return the `_ModelInfo` for `model_arch`, or `None` if inspection fails.

    Failures are logged with their traceback. Results, including `None`, are
    memoized per `(model_arch, model)`.
    """
    try:
        return model.inspect_model_cls()
    except Exception:
        logger.exception("Error in inspecting model architecture '%s'", model_arch)
        return None
691
692


693
694
695
@dataclass
class _ModelRegistry:
    """Registry mapping architecture names to registered model entries.

    Loading and inspection are delegated to the module-level, memoized
    ``_try_load_model_cls`` / ``_try_inspect_model_cls`` helpers. Those
    helpers log failures and return ``None`` instead of raising.
    """

    # Keyed by model_arch
    models: dict[str, _BaseRegisteredModel] = field(default_factory=dict)

    def get_supported_archs(self) -> Set[str]:
        """Return a (live) keys view of every registered architecture name."""
        return self.models.keys()

    def register_model(
        self,
        model_arch: str,
        model_cls: type[nn.Module] | str,
    ) -> None:
        """
        Register an external model to be used in vLLM.

        `model_cls` can be either:

        - A [`torch.nn.Module`][] class directly referencing the model.
        - A string in the format `<module>:<class>` which can be used to
          lazily import the model. This is useful to avoid initializing CUDA
          when importing the model and thus the related error
          `RuntimeError: Cannot re-initialize CUDA in forked subprocess`.

        Raises:
            TypeError: if `model_arch` is not a string, or `model_cls` is
                neither a string nor an `nn.Module` subclass.
            ValueError: if `model_cls` is a string not of the form
                `<module>:<class>`.
        """
        if not isinstance(model_arch, str):
            msg = f"`model_arch` should be a string, not a {type(model_arch)}"
            raise TypeError(msg)

        if model_arch in self.models:
            logger.warning(
                "Model architecture %s is already registered, and will be "
                "overwritten by the new model class %s.",
                model_arch,
                model_cls,
            )

        if isinstance(model_cls, str):
            split_str = model_cls.split(":")
            if len(split_str) != 2:
                msg = "Expected a string in the format `<module>:<class>`"
                raise ValueError(msg)

            model = _LazyRegisteredModel(*split_str)
        elif isinstance(model_cls, type) and issubclass(model_cls, nn.Module):
            model = _RegisteredModel.from_model_cls(model_cls)
        else:
            # Report the type of the offending argument (`model_cls`), not of
            # `model_arch`, which is already known to be a string here.
            msg = (
                "`model_cls` should be a string or PyTorch model class, "
                f"not a {type(model_cls)}"
            )
            raise TypeError(msg)

        self.models[model_arch] = model

    def _raise_for_unsupported(self, architectures: list[str]):
        """Raise a ``ValueError`` explaining why ``architectures`` are unusable.

        This method never returns normally. The message distinguishes
        registered architectures that failed to load or inspect, architectures
        dropped in an earlier vLLM release, and unknown architectures.
        """
        all_supported_archs = self.get_supported_archs()

        if any(arch in all_supported_archs for arch in architectures):
            raise ValueError(
                f"Model architectures {architectures} failed "
                "to be inspected. Please check the logs for more details."
            )

        for arch in architectures:
            if arch in _PREVIOUSLY_SUPPORTED_MODELS:
                previous_version = _PREVIOUSLY_SUPPORTED_MODELS[arch]

                raise ValueError(
                    f"Model architecture {arch} was supported in vLLM until "
                    f"v{previous_version}, and is not supported anymore. "
                    "Please use an older version of vLLM if you want to "
                    "use this model architecture."
                )

        raise ValueError(
            f"Model architectures {architectures} are not supported for now. "
            f"Supported architectures: {all_supported_archs}"
        )

    def _try_load_model_cls(self, model_arch: str) -> type[nn.Module] | None:
        """Load the class for ``model_arch``; ``None`` if unregistered or failed."""
        if model_arch not in self.models:
            return None

        return _try_load_model_cls(model_arch, self.models[model_arch])

    def _try_inspect_model_cls(self, model_arch: str) -> _ModelInfo | None:
        """Inspect ``model_arch``; ``None`` if unregistered or inspection failed."""
        if model_arch not in self.models:
            return None

        return _try_inspect_model_cls(model_arch, self.models[model_arch])

    def _try_resolve_transformers(
        self,
        architecture: str,
        model_config: ModelConfig,
    ) -> str | None:
        """Resolve ``architecture`` to a Transformers-backend architecture name.

        Returns:
            ``architecture`` itself if it is already in
            ``_TRANSFORMERS_BACKEND_MODELS``. Otherwise, once the HF model
            class is located (in ``transformers`` or via the config's
            ``auto_map``) and reports itself backend-compatible, the value
            of ``model_config._get_transformers_backend_cls()``. ``None`` if
            either step fails and ``model_impl`` is not ``"transformers"``.

        Raises:
            ValueError: if ``model_impl == "transformers"`` and the model
                class cannot be found or is not compatible with vLLM.
        """
        if architecture in _TRANSFORMERS_BACKEND_MODELS:
            return architecture

        auto_map: dict[str, str] = (
            getattr(model_config.hf_config, "auto_map", None) or dict()
        )

        # Make sure that config class is always initialized before model class,
        # otherwise the model class won't be able to access the config class,
        # the expected auto_map should have correct order like:
        # "auto_map": {
        #     "AutoConfig": "<your-repo-name>--<config-name>",
        #     "AutoModel": "<your-repo-name>--<model-name>",
        #     "AutoModelFor<Task>": "<your-repo-name>--<model-name>",
        # },
        for prefix in ("AutoConfig", "AutoModel"):
            for name, module in auto_map.items():
                if name.startswith(prefix):
                    try_get_class_from_dynamic_module(
                        module,
                        model_config.model,
                        revision=model_config.revision,
                        warn_on_fail=False,
                    )

        model_module = getattr(transformers, architecture, None)

        if model_module is None:
            for name, module in auto_map.items():
                if name.startswith("AutoModel"):
                    model_module = try_get_class_from_dynamic_module(
                        module,
                        model_config.model,
                        revision=model_config.revision,
                        warn_on_fail=True,
                    )
                    if model_module is not None:
                        break
            else:
                # No `AutoModel*` entry yielded a class.
                if model_config.model_impl != "transformers":
                    return None

                raise ValueError(
                    f"Cannot find model module. {architecture!r} is not a "
                    "registered model in the Transformers library (only "
                    "relevant if the model is meant to be in Transformers) "
                    "and 'AutoModel' is not present in the model config's "
                    "'auto_map' (relevant if the model is custom)."
                )

        if not model_module.is_backend_compatible():
            if model_config.model_impl != "transformers":
                return None

            raise ValueError(
                f"The Transformers implementation of {architecture!r} "
                "is not compatible with vLLM."
            )

        return model_config._get_transformers_backend_cls()

    def _normalize_arch(
        self,
        architecture: str,
        model_config: ModelConfig,
    ) -> str:
        """Map ``architecture`` to a registered base architecture if possible.

        If ``architecture`` is not registered but matches a known
        architecture-suffix default, its suffix is swapped for each known
        default suffix in turn. The first result that is registered is
        returned. Otherwise ``architecture`` is returned unchanged.
        """
        if architecture in self.models:
            return architecture

        # This may be called in order to resolve runner_type and convert_type
        # in the first place, in which case we consider the default match
        match = try_match_architecture_defaults(
            architecture,
            runner_type=getattr(model_config, "runner_type", None),
            convert_type=getattr(model_config, "convert_type", None),
        )
        if match:
            suffix, _ = match

            # Get the name of the base model to convert
            for repl_suffix, _ in iter_architecture_defaults():
                base_arch = architecture.replace(suffix, repl_suffix)
                if base_arch in self.models:
                    return base_arch

        return architecture

    def inspect_model_cls(
        self,
        architectures: str | list[str],
        model_config: ModelConfig,
    ) -> tuple[_ModelInfo, str]:
        """Return ``(model_info, arch)`` for the first usable architecture.

        Resolution order:

        1. The implementation forced by ``model_impl`` (Transformers or
           Terratorch).
        2. The Transformers backend, if no architecture is registered
           natively, ``model_impl`` is ``"auto"``, and no conversion is
           requested.
        3. Each architecture in order, after ``_normalize_arch``.
        4. The Transformers backend again, if no architecture is registered
           natively and ``model_impl`` is ``"auto"``.

        Raises:
            ValueError: if ``architectures`` is empty or none can be resolved.
        """
        if isinstance(architectures, str):
            architectures = [architectures]
        if not architectures:
            raise ValueError("No model architectures are specified")

        # Require transformers impl
        if model_config.model_impl == "transformers":
            arch = self._try_resolve_transformers(architectures[0], model_config)
            if arch is not None:
                model_info = self._try_inspect_model_cls(arch)
                if model_info is not None:
                    return (model_info, arch)
        elif model_config.model_impl == "terratorch":
            arch = "Terratorch"
            model_info = self._try_inspect_model_cls(arch)
            # Never hand back a `None` info; fall through to normal resolution
            # (and ultimately `_raise_for_unsupported`) like `resolve_model_cls`.
            if model_info is not None:
                return (model_info, arch)

        # Fallback to transformers impl (after resolving convert_type)
        if (
            all(arch not in self.models for arch in architectures)
            and model_config.model_impl == "auto"
            and getattr(model_config, "convert_type", "none") == "none"
        ):
            arch = self._try_resolve_transformers(architectures[0], model_config)
            if arch is not None:
                model_info = self._try_inspect_model_cls(arch)
                if model_info is not None:
                    return (model_info, arch)

        for arch in architectures:
            normalized_arch = self._normalize_arch(arch, model_config)
            model_info = self._try_inspect_model_cls(normalized_arch)
            if model_info is not None:
                # Report the architecture as requested, not the normalized one.
                return (model_info, arch)

        # Fallback to transformers impl (before resolving runner_type)
        if (
            all(arch not in self.models for arch in architectures)
            and model_config.model_impl == "auto"
        ):
            arch = self._try_resolve_transformers(architectures[0], model_config)
            if arch is not None:
                model_info = self._try_inspect_model_cls(arch)
                if model_info is not None:
                    return (model_info, arch)

        return self._raise_for_unsupported(architectures)

    def resolve_model_cls(
        self,
        architectures: str | list[str],
        model_config: ModelConfig,
    ) -> tuple[type[nn.Module], str]:
        """Return ``(model_cls, arch)`` for the first loadable architecture.

        Follows the same resolution order as ``inspect_model_cls``, but
        loads the model class instead of inspecting it.

        Raises:
            ValueError: if ``architectures`` is empty or none can be loaded.
        """
        if isinstance(architectures, str):
            architectures = [architectures]
        if not architectures:
            raise ValueError("No model architectures are specified")

        # Require transformers impl
        if model_config.model_impl == "transformers":
            arch = self._try_resolve_transformers(architectures[0], model_config)
            if arch is not None:
                model_cls = self._try_load_model_cls(arch)
                if model_cls is not None:
                    return (model_cls, arch)
        elif model_config.model_impl == "terratorch":
            arch = "Terratorch"
            model_cls = self._try_load_model_cls(arch)
            if model_cls is not None:
                return (model_cls, arch)

        # Fallback to transformers impl (after resolving convert_type)
        if (
            all(arch not in self.models for arch in architectures)
            and model_config.model_impl == "auto"
            and getattr(model_config, "convert_type", "none") == "none"
        ):
            arch = self._try_resolve_transformers(architectures[0], model_config)
            if arch is not None:
                model_cls = self._try_load_model_cls(arch)
                if model_cls is not None:
                    return (model_cls, arch)

        for arch in architectures:
            normalized_arch = self._normalize_arch(arch, model_config)
            model_cls = self._try_load_model_cls(normalized_arch)
            if model_cls is not None:
                # Report the architecture as requested, not the normalized one.
                return (model_cls, arch)

        # Fallback to transformers impl (before resolving runner_type)
        if (
            all(arch not in self.models for arch in architectures)
            and model_config.model_impl == "auto"
        ):
            arch = self._try_resolve_transformers(architectures[0], model_config)
            if arch is not None:
                model_cls = self._try_load_model_cls(arch)
                if model_cls is not None:
                    return (model_cls, arch)

        return self._raise_for_unsupported(architectures)

    def is_text_generation_model(
        self,
        architectures: str | list[str],
        model_config: ModelConfig,
    ) -> bool:
        """Whether the resolved model is a text generation model."""
        model_cls, _ = self.inspect_model_cls(architectures, model_config)
        return model_cls.is_text_generation_model

    def is_pooling_model(
        self,
        architectures: str | list[str],
        model_config: ModelConfig,
    ) -> bool:
        """Whether the resolved model is a pooling model."""
        model_cls, _ = self.inspect_model_cls(architectures, model_config)
        return model_cls.is_pooling_model

    def is_cross_encoder_model(
        self,
        architectures: str | list[str],
        model_config: ModelConfig,
    ) -> bool:
        """Whether the resolved model supports cross-encoding."""
        model_cls, _ = self.inspect_model_cls(architectures, model_config)
        return model_cls.supports_cross_encoding

    def is_multimodal_model(
        self,
        architectures: str | list[str],
        model_config: ModelConfig,
    ) -> bool:
        """Whether the resolved model supports multimodal inputs."""
        model_cls, _ = self.inspect_model_cls(architectures, model_config)
        return model_cls.supports_multimodal

    def is_multimodal_raw_input_only_model(
        self,
        architectures: str | list[str],
        model_config: ModelConfig,
    ) -> bool:
        """Whether the resolved model only supports raw multimodal input."""
        model_cls, _ = self.inspect_model_cls(architectures, model_config)
        return model_cls.supports_multimodal_raw_input_only

    def is_pp_supported_model(
        self,
        architectures: str | list[str],
        model_config: ModelConfig,
    ) -> bool:
        """Whether the resolved model supports pipeline parallelism."""
        model_cls, _ = self.inspect_model_cls(architectures, model_config)
        return model_cls.supports_pp

    def model_has_inner_state(
        self,
        architectures: str | list[str],
        model_config: ModelConfig,
    ) -> bool:
        """Whether the resolved model has inner state."""
        model_cls, _ = self.inspect_model_cls(architectures, model_config)
        return model_cls.has_inner_state

    def is_attention_free_model(
        self,
        architectures: str | list[str],
        model_config: ModelConfig,
    ) -> bool:
        """Whether the resolved model is attention-free."""
        model_cls, _ = self.inspect_model_cls(architectures, model_config)
        return model_cls.is_attention_free

    def is_hybrid_model(
        self,
        architectures: str | list[str],
        model_config: ModelConfig,
    ) -> bool:
        """Whether the resolved model is a hybrid model."""
        model_cls, _ = self.inspect_model_cls(architectures, model_config)
        return model_cls.is_hybrid

    def is_noops_model(
        self,
        architectures: str | list[str],
        model_config: ModelConfig,
    ) -> bool:
        """Whether the resolved model has no-ops."""
        model_cls, _ = self.inspect_model_cls(architectures, model_config)
        return model_cls.has_noops

    def is_transcription_model(
        self,
        architectures: str | list[str],
        model_config: ModelConfig,
    ) -> bool:
        """Whether the resolved model supports transcription."""
        model_cls, _ = self.inspect_model_cls(architectures, model_config)
        return model_cls.supports_transcription

    def is_transcription_only_model(
        self,
        architectures: str | list[str],
        model_config: ModelConfig,
    ) -> bool:
        """Whether the resolved model supports only transcription."""
        model_cls, _ = self.inspect_model_cls(architectures, model_config)
        return model_cls.supports_transcription_only

1078

1079
1080
1081
1082
1083
1084
1085
1086
1087
# Global registry seeded with every in-tree architecture from `_VLLM_MODELS`.
# Entries are `_LazyRegisteredModel`s, so no model module is imported here.
ModelRegistry = _ModelRegistry(
    {
        arch_name: _LazyRegisteredModel(
            module_name="vllm.model_executor.models." + relative_module,
            class_name=model_cls_name,
        )
        for arch_name, (relative_module, model_cls_name) in _VLLM_MODELS.items()
    }
)
1088
1089
1090
1091
1092

_T = TypeVar("_T")


def _run_in_subprocess(fn: Callable[[], _T]) -> _T:
    """Run ``fn`` in a fresh Python subprocess and return its result.

    ``fn`` and the path of a temporary output file are serialized with
    ``cloudpickle`` (so lambdas work) and fed to ``_SUBPROCESS_COMMAND`` on
    stdin. The child is expected to pickle ``fn()`` into that file, which is
    read back here.

    Raises:
        RuntimeError: if the subprocess exits with a non-zero status; the
            message contains the subprocess's stderr.
    """
    # NOTE: We use a temporary directory instead of a temporary file to avoid
    # issues like https://stackoverflow.com/questions/23212435/permission-denied-to-write-to-my-temporary-file
    with tempfile.TemporaryDirectory() as tempdir:
        output_filepath = os.path.join(tempdir, "registry_output.tmp")

        # `cloudpickle` allows pickling lambda functions directly
        import cloudpickle

        input_bytes = cloudpickle.dumps((fn, output_filepath))

        # cannot use `sys.executable __file__` here because the script
        # contains relative imports
        returned = subprocess.run(
            _SUBPROCESS_COMMAND, input=input_bytes, capture_output=True
        )

        # check if the subprocess is successful
        try:
            returned.check_returncode()
        except subprocess.CalledProcessError as e:
            # Decode leniently: a stray non-UTF-8 byte in stderr must not
            # replace the real failure with a UnicodeDecodeError.
            stderr_text = returned.stderr.decode(errors="replace")
            # wrap raised exception to provide more information
            raise RuntimeError(
                f"Error raised in subprocess:\n{stderr_text}"
            ) from e

        # The file was written by the child we just spawned, so unpickling
        # it does not cross a trust boundary.
        with open(output_filepath, "rb") as f:
            return pickle.load(f)


def _run() -> None:
    """Script entry point: execute a pickled callable and pickle its result.

    Reads a pickled ``(fn, output_file)`` pair from stdin, calls ``fn()``,
    and writes the pickled return value to ``output_file``. General plugins
    are loaded before anything is deserialized.
    """
    # Set up plugins before deserializing and running `fn`.
    from vllm.plugins import load_general_plugins

    load_general_plugins()

    payload = sys.stdin.buffer.read()
    fn, output_file = pickle.loads(payload)

    result = fn()

    with open(output_file, "wb") as out_stream:
        pickle.dump(result, out_stream)


if __name__ == "__main__":
    _run()