registry.py 43.3 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
4
5
6
"""
Whenever you add an architecture to this file, please also update
`tests/models/registry.py` with example HuggingFace models for it.
"""
7
import hashlib
8
import importlib
9
import json
10
import os
11
import pickle
12
13
import subprocess
import sys
14
import tempfile
15
from abc import ABC, abstractmethod
16
from collections.abc import Set
17
from dataclasses import asdict, dataclass, field
18
from functools import lru_cache
19
from pathlib import Path
20
from typing import Callable, Optional, TypeVar, Union
21
22

import torch.nn as nn
23
import transformers
24

25
from vllm import envs
26
from vllm.config import (ModelConfig, iter_architecture_defaults,
27
                         try_match_architecture_defaults)
28
from vllm.logger import init_logger
29
from vllm.logging_utils import logtime
30
31
from vllm.transformers_utils.dynamic_module import (
    try_get_class_from_dynamic_module)
32

33
34
from .interfaces import (has_inner_state, has_noops, is_attention_free,
                         is_hybrid, supports_cross_encoding,
35
36
                         supports_multimodal,
                         supports_multimodal_encoder_tp_data,
37
                         supports_multimodal_raw_input_only, supports_pp,
38
                         supports_transcription, supports_v0_only)
39
40
from .interfaces_base import (get_default_pooling_type, is_pooling_model,
                              is_text_generation_model)
41
42
43

logger = init_logger(__name__)

44
# yapf: disable
45
46
_TEXT_GENERATION_MODELS = {
    # architecture name -> (model module, model class)
    # ---- Decoder-only ----
    "ApertusForCausalLM": ("apertus", "ApertusForCausalLM"),
    "AquilaModel": ("llama", "LlamaForCausalLM"),
    "AquilaForCausalLM": ("llama", "LlamaForCausalLM"),  # AquilaChat2
    "ArceeForCausalLM": ("arcee", "ArceeForCausalLM"),
    "ArcticForCausalLM": ("arctic", "ArcticForCausalLM"),
    "MiniMaxForCausalLM": ("minimax_text_01", "MiniMaxText01ForCausalLM"),
    "MiniMaxText01ForCausalLM": ("minimax_text_01", "MiniMaxText01ForCausalLM"),
    "MiniMaxM1ForCausalLM": ("minimax_text_01", "MiniMaxText01ForCausalLM"),
    # Baichuan-7B spells the class with an upper-case 'C' ...
    "BaiChuanForCausalLM": ("baichuan", "BaiChuanForCausalLM"),
    # ... while Baichuan-13B uses a lower-case 'c'.
    "BaichuanForCausalLM": ("baichuan", "BaichuanForCausalLM"),
    "BailingMoeForCausalLM": ("bailing_moe", "BailingMoeForCausalLM"),
    "BailingMoeV2ForCausalLM": ("bailing_moe", "BailingMoeV2ForCausalLM"),
    "BambaForCausalLM": ("bamba", "BambaForCausalLM"),
    "BloomForCausalLM": ("bloom", "BloomForCausalLM"),
    "ChatGLMModel": ("chatglm", "ChatGLMForCausalLM"),
    "ChatGLMForConditionalGeneration": ("chatglm", "ChatGLMForCausalLM"),
    "CohereForCausalLM": ("commandr", "CohereForCausalLM"),
    "Cohere2ForCausalLM": ("commandr", "CohereForCausalLM"),
    "CwmForCausalLM": ("llama", "LlamaForCausalLM"),
    "DbrxForCausalLM": ("dbrx", "DbrxForCausalLM"),
    "DeciLMForCausalLM": ("nemotron_nas", "DeciLMForCausalLM"),
    "DeepseekForCausalLM": ("deepseek", "DeepseekForCausalLM"),
    "DeepseekV2ForCausalLM": ("deepseek_v2", "DeepseekV2ForCausalLM"),
    "DeepseekV3ForCausalLM": ("deepseek_v2", "DeepseekV3ForCausalLM"),
    "DeepseekV32ForCausalLM": ("deepseek_v2", "DeepseekV3ForCausalLM"),
    "Dots1ForCausalLM": ("dots1", "Dots1ForCausalLM"),
    "Ernie4_5ForCausalLM": ("ernie45", "Ernie4_5ForCausalLM"),
    "Ernie4_5_MoeForCausalLM": ("ernie45_moe", "Ernie4_5_MoeForCausalLM"),
    "ExaoneForCausalLM": ("exaone", "ExaoneForCausalLM"),
    "Exaone4ForCausalLM": ("exaone4", "Exaone4ForCausalLM"),
    "FalconForCausalLM": ("falcon", "FalconForCausalLM"),
    "Fairseq2LlamaForCausalLM": ("fairseq2_llama", "Fairseq2LlamaForCausalLM"),
    "GemmaForCausalLM": ("gemma", "GemmaForCausalLM"),
    "Gemma2ForCausalLM": ("gemma2", "Gemma2ForCausalLM"),
    "Gemma3ForCausalLM": ("gemma3", "Gemma3ForCausalLM"),
    "Gemma3nForCausalLM": ("gemma3n", "Gemma3nForCausalLM"),
    "Qwen3NextForCausalLM": ("qwen3_next", "Qwen3NextForCausalLM"),
    "GlmForCausalLM": ("glm", "GlmForCausalLM"),
    "Glm4ForCausalLM": ("glm4", "Glm4ForCausalLM"),
    "Glm4MoeForCausalLM": ("glm4_moe", "Glm4MoeForCausalLM"),
    "GptOssForCausalLM": ("gpt_oss", "GptOssForCausalLM"),
    "GPT2LMHeadModel": ("gpt2", "GPT2LMHeadModel"),
    "GPTBigCodeForCausalLM": ("gpt_bigcode", "GPTBigCodeForCausalLM"),
    "GPTJForCausalLM": ("gpt_j", "GPTJForCausalLM"),
    "GPTNeoXForCausalLM": ("gpt_neox", "GPTNeoXForCausalLM"),
    "GraniteForCausalLM": ("granite", "GraniteForCausalLM"),
    "GraniteMoeForCausalLM": ("granitemoe", "GraniteMoeForCausalLM"),
    "GraniteMoeHybridForCausalLM": ("granitemoehybrid", "GraniteMoeHybridForCausalLM"),  # noqa: E501
    "GraniteMoeSharedForCausalLM": ("granitemoeshared", "GraniteMoeSharedForCausalLM"),  # noqa: E501
    "GritLM": ("gritlm", "GritLM"),
    "Grok1ModelForCausalLM": ("grok1", "Grok1ForCausalLM"),
    "HunYuanMoEV1ForCausalLM": ("hunyuan_v1", "HunYuanMoEV1ForCausalLM"),
    "HunYuanDenseV1ForCausalLM": ("hunyuan_v1", "HunYuanDenseV1ForCausalLM"),
    "HCXVisionForCausalLM": ("hyperclovax_vision", "HCXVisionForCausalLM"),
    "InternLMForCausalLM": ("llama", "LlamaForCausalLM"),
    "InternLM2ForCausalLM": ("internlm2", "InternLM2ForCausalLM"),
    "InternLM2VEForCausalLM": ("internlm2_ve", "InternLM2VEForCausalLM"),
    "InternLM3ForCausalLM": ("llama", "LlamaForCausalLM"),
    "JAISLMHeadModel": ("jais", "JAISLMHeadModel"),
    "JambaForCausalLM": ("jamba", "JambaForCausalLM"),
    "Lfm2ForCausalLM": ("lfm2", "Lfm2ForCausalLM"),
    "LlamaForCausalLM": ("llama", "LlamaForCausalLM"),
    "Llama4ForCausalLM": ("llama4", "Llama4ForCausalLM"),
    # Checkpoints from decapoda-research/llama-* use this spelling.
    "LLaMAForCausalLM": ("llama", "LlamaForCausalLM"),
    "LongcatFlashForCausalLM": ("longcat_flash", "LongcatFlashForCausalLM"),
    "MambaForCausalLM": ("mamba", "MambaForCausalLM"),
    "FalconMambaForCausalLM": ("mamba", "MambaForCausalLM"),
    "FalconH1ForCausalLM": ("falcon_h1", "FalconH1ForCausalLM"),
    "Mamba2ForCausalLM": ("mamba2", "Mamba2ForCausalLM"),
    "MiniCPMForCausalLM": ("minicpm", "MiniCPMForCausalLM"),
    "MiniCPM3ForCausalLM": ("minicpm3", "MiniCPM3ForCausalLM"),
    "MistralForCausalLM": ("llama", "LlamaForCausalLM"),
    "MixtralForCausalLM": ("mixtral", "MixtralForCausalLM"),
    # The transformers MPT class name is spelled in lower case.
    "MptForCausalLM": ("mpt", "MPTForCausalLM"),
    "MPTForCausalLM": ("mpt", "MPTForCausalLM"),
    "MiMoForCausalLM": ("mimo", "MiMoForCausalLM"),
    "NemotronForCausalLM": ("nemotron", "NemotronForCausalLM"),
    "NemotronHForCausalLM": ("nemotron_h", "NemotronHForCausalLM"),
    "OlmoForCausalLM": ("olmo", "OlmoForCausalLM"),
    "Olmo2ForCausalLM": ("olmo2", "Olmo2ForCausalLM"),
    "Olmo3ForCausalLM": ("olmo2", "Olmo2ForCausalLM"),
    "OlmoeForCausalLM": ("olmoe", "OlmoeForCausalLM"),
    "OPTForCausalLM": ("opt", "OPTForCausalLM"),
    "OrionForCausalLM": ("orion", "OrionForCausalLM"),
    "PersimmonForCausalLM": ("persimmon", "PersimmonForCausalLM"),
    "PhiForCausalLM": ("phi", "PhiForCausalLM"),
    "Phi3ForCausalLM": ("phi3", "Phi3ForCausalLM"),
    "PhiMoEForCausalLM": ("phimoe", "PhiMoEForCausalLM"),
    "Plamo2ForCausalLM": ("plamo2", "Plamo2ForCausalLM"),
    "QWenLMHeadModel": ("qwen", "QWenLMHeadModel"),
    "Qwen2ForCausalLM": ("qwen2", "Qwen2ForCausalLM"),
    "Qwen2MoeForCausalLM": ("qwen2_moe", "Qwen2MoeForCausalLM"),
    "Qwen3ForCausalLM": ("qwen3", "Qwen3ForCausalLM"),
    "Qwen3MoeForCausalLM": ("qwen3_moe", "Qwen3MoeForCausalLM"),
    "RWForCausalLM": ("falcon", "FalconForCausalLM"),
    "SeedOssForCausalLM": ("seed_oss", "SeedOssForCausalLM"),
    "Step3TextForCausalLM": ("step3_text", "Step3TextForCausalLM"),
    "StableLMEpochForCausalLM": ("stablelm", "StablelmForCausalLM"),
    "StableLmForCausalLM": ("stablelm", "StablelmForCausalLM"),
    "Starcoder2ForCausalLM": ("starcoder2", "Starcoder2ForCausalLM"),
    "SolarForCausalLM": ("solar", "SolarForCausalLM"),
    "TeleChat2ForCausalLM": ("telechat2", "TeleChat2ForCausalLM"),
    "TeleFLMForCausalLM": ("teleflm", "TeleFLMForCausalLM"),
    "XverseForCausalLM": ("llama", "LlamaForCausalLM"),
    "Zamba2ForCausalLM": ("zamba2", "Zamba2ForCausalLM"),
}

_EMBEDDING_MODELS = {
    # ---- Text-only ----
    "BertModel": ("bert", "BertEmbeddingModel"),
    "DeciLMForCausalLM": ("nemotron_nas", "DeciLMForCausalLM"),
    "Gemma2Model": ("gemma2", "Gemma2ForCausalLM"),
    "Gemma3TextModel": ("gemma3", "Gemma3Model"),
    "GlmForCausalLM": ("glm", "GlmForCausalLM"),
    "GPT2ForSequenceClassification": ("gpt2", "GPT2ForSequenceClassification"),
    "GritLM": ("gritlm", "GritLM"),
    "GteModel": ("bert_with_rope", "SnowflakeGteNewModel"),
    "GteNewModel": ("bert_with_rope", "GteNewModel"),
    "InternLM2ForRewardModel": ("internlm2", "InternLM2ForRewardModel"),
    "JambaForSequenceClassification": ("jamba", "JambaForSequenceClassification"),  # noqa: E501
    "LlamaModel": ("llama", "LlamaForCausalLM"),
    # Every text-generation architecture that is implemented by
    # LlamaForCausalLM can also be served as an embedding model.
    **{
        arch_name: entry
        for arch_name, entry in _TEXT_GENERATION_MODELS.items()
        if entry[1] == "LlamaForCausalLM"
    },
    "MistralModel": ("llama", "LlamaForCausalLM"),
    "ModernBertModel": ("modernbert", "ModernBertModel"),
    "NomicBertModel": ("bert_with_rope", "NomicBertModel"),
    "Phi3ForCausalLM": ("phi3", "Phi3ForCausalLM"),
    "Qwen2Model": ("qwen2", "Qwen2ForCausalLM"),
    "Qwen2ForCausalLM": ("qwen2", "Qwen2ForCausalLM"),
    "Qwen2ForRewardModel": ("qwen2_rm", "Qwen2ForRewardModel"),
    "Qwen2ForProcessRewardModel": ("qwen2_rm", "Qwen2ForProcessRewardModel"),
    "RobertaForMaskedLM": ("roberta", "RobertaEmbeddingModel"),
    "RobertaModel": ("roberta", "RobertaEmbeddingModel"),
    "TeleChat2ForCausalLM": ("telechat2", "TeleChat2ForCausalLM"),
    "XLMRobertaModel": ("roberta", "RobertaEmbeddingModel"),
    # ---- Multimodal ----
    "CLIPModel": ("clip", "CLIPEmbeddingModel"),
    "LlavaNextForConditionalGeneration": ("llava_next", "LlavaNextForConditionalGeneration"),  # noqa: E501
    "Phi3VForCausalLM": ("phi3v", "Phi3VForCausalLM"),
    "Qwen2VLForConditionalGeneration": ("qwen2_vl", "Qwen2VLForConditionalGeneration"),  # noqa: E501
    # Terratorch models take images as input and also produce images, so
    # they are not really embedding models; they are listed here only
    # because, for now, they reuse the embedding-model code path.
    "PrithviGeoSpatialMAE": ("terratorch", "Terratorch"),
    "Terratorch": ("terratorch", "Terratorch"),
}

201
202
_CROSS_ENCODER_MODELS = {
    "BertForSequenceClassification": ("bert", "BertForSequenceClassification"),
    "BertForTokenClassification": ("bert", "BertForTokenClassification"),
    "GteNewForSequenceClassification": ("bert_with_rope", "GteNewForSequenceClassification"),  # noqa: E501
    "ModernBertForSequenceClassification": ("modernbert", "ModernBertForSequenceClassification"),  # noqa: E501
    "RobertaForSequenceClassification": ("roberta", "RobertaForSequenceClassification"),  # noqa: E501
    "XLMRobertaForSequenceClassification": ("roberta", "RobertaForSequenceClassification"),  # noqa: E501
    # ---- Auto-converted (see adapters.py) ----
    "JinaVLForRanking": ("jina_vl", "JinaVLForSequenceClassification"),  # noqa: E501
}

216
_MULTIMODAL_MODELS = {
    # ---- Decoder-only ----
    "AriaForConditionalGeneration": ("aria", "AriaForConditionalGeneration"),
    "AyaVisionForConditionalGeneration": ("aya_vision", "AyaVisionForConditionalGeneration"),  # noqa: E501
    "Blip2ForConditionalGeneration": ("blip2", "Blip2ForConditionalGeneration"),
    "ChameleonForConditionalGeneration": ("chameleon", "ChameleonForConditionalGeneration"),  # noqa: E501
    "Cohere2VisionForConditionalGeneration": ("cohere2_vision", "Cohere2VisionForConditionalGeneration"),  # noqa: E501
    "DeepseekVLV2ForCausalLM": ("deepseek_vl2", "DeepseekVLV2ForCausalLM"),
    "DotsOCRForCausalLM": ("dots_ocr", "DotsOCRForCausalLM"),
    "Ernie4_5_VLMoeForConditionalGeneration": ("ernie45_vl", "Ernie4_5_VLMoeForConditionalGeneration"),  # noqa: E501
    "FuyuForCausalLM": ("fuyu", "FuyuForCausalLM"),
    "Gemma3ForConditionalGeneration": ("gemma3_mm", "Gemma3ForConditionalGeneration"),  # noqa: E501
    "Gemma3nForConditionalGeneration": ("gemma3n_mm", "Gemma3nForConditionalGeneration"),  # noqa: E501
    "GLM4VForCausalLM": ("glm4v", "GLM4VForCausalLM"),
    "Glm4vForConditionalGeneration": ("glm4_1v", "Glm4vForConditionalGeneration"),  # noqa: E501
    "Glm4vMoeForConditionalGeneration": ("glm4_1v", "Glm4vMoeForConditionalGeneration"),  # noqa: E501
    "GraniteSpeechForConditionalGeneration": ("granite_speech", "GraniteSpeechForConditionalGeneration"),  # noqa: E501
    "H2OVLChatModel": ("h2ovl", "H2OVLChatModel"),
    "InternVLChatModel": ("internvl", "InternVLChatModel"),
    "NemotronH_Nano_VL_V2": ("nano_nemotron_vl", "NemotronH_Nano_VL_V2"),
    "InternS1ForConditionalGeneration": ("interns1", "InternS1ForConditionalGeneration"),  # noqa: E501
    "InternVLForConditionalGeneration": ("interns1", "InternS1ForConditionalGeneration"),  # noqa: E501
    "Idefics3ForConditionalGeneration": ("idefics3", "Idefics3ForConditionalGeneration"),  # noqa: E501
    "SmolVLMForConditionalGeneration": ("smolvlm", "SmolVLMForConditionalGeneration"),  # noqa: E501
    "KeyeForConditionalGeneration": ("keye", "KeyeForConditionalGeneration"),
    "KeyeVL1_5ForConditionalGeneration": ("keye_vl1_5", "KeyeVL1_5ForConditionalGeneration"),  # noqa: E501
    "RForConditionalGeneration": ("rvl", "RForConditionalGeneration"),
    "KimiVLForConditionalGeneration": ("kimi_vl", "KimiVLForConditionalGeneration"),  # noqa: E501
    "Llama_Nemotron_Nano_VL": ("nemotron_vl", "LlamaNemotronVLChatModel"),
    "Llama4ForConditionalGeneration": ("mllama4", "Llama4ForConditionalGeneration"),  # noqa: E501
    "LlavaForConditionalGeneration": ("llava", "LlavaForConditionalGeneration"),
    "LlavaNextForConditionalGeneration": ("llava_next", "LlavaNextForConditionalGeneration"),  # noqa: E501
    "LlavaNextVideoForConditionalGeneration": ("llava_next_video", "LlavaNextVideoForConditionalGeneration"),  # noqa: E501
    "LlavaOnevisionForConditionalGeneration": ("llava_onevision", "LlavaOnevisionForConditionalGeneration"),  # noqa: E501
    "MantisForConditionalGeneration": ("llava", "MantisForConditionalGeneration"),  # noqa: E501
    "MiDashengLMModel": ("midashenglm", "MiDashengLMModel"),
    "MiniMaxVL01ForConditionalGeneration": ("minimax_vl_01", "MiniMaxVL01ForConditionalGeneration"),  # noqa: E501
    "MiniCPMO": ("minicpmo", "MiniCPMO"),
    "MiniCPMV": ("minicpmv", "MiniCPMV"),
    "Mistral3ForConditionalGeneration": ("mistral3", "Mistral3ForConditionalGeneration"),  # noqa: E501
    "MolmoForCausalLM": ("molmo", "MolmoForCausalLM"),
    "NVLM_D": ("nvlm_d", "NVLM_D_Model"),
    "Ovis": ("ovis", "Ovis"),
    "Ovis2_5": ("ovis2_5", "Ovis2_5"),
    "PaliGemmaForConditionalGeneration": ("paligemma", "PaliGemmaForConditionalGeneration"),  # noqa: E501
    "Phi3VForCausalLM": ("phi3v", "Phi3VForCausalLM"),
    "Phi4MMForCausalLM": ("phi4mm", "Phi4MMForCausalLM"),
    "Phi4MultimodalForCausalLM": ("phi4_multimodal", "Phi4MultimodalForCausalLM"),  # noqa: E501
    "PixtralForConditionalGeneration": ("pixtral", "PixtralForConditionalGeneration"),  # noqa: E501
    "QwenVLForConditionalGeneration": ("qwen_vl", "QwenVLForConditionalGeneration"),  # noqa: E501
    "Qwen2VLForConditionalGeneration": ("qwen2_vl", "Qwen2VLForConditionalGeneration"),  # noqa: E501
    "Qwen2_5_VLForConditionalGeneration": ("qwen2_5_vl", "Qwen2_5_VLForConditionalGeneration"),  # noqa: E501
    "Qwen2AudioForConditionalGeneration": ("qwen2_audio", "Qwen2AudioForConditionalGeneration"),  # noqa: E501
    "Qwen2_5OmniModel": ("qwen2_5_omni_thinker", "Qwen2_5OmniThinkerForConditionalGeneration"),  # noqa: E501
    "Qwen2_5OmniForConditionalGeneration": ("qwen2_5_omni_thinker", "Qwen2_5OmniThinkerForConditionalGeneration"),  # noqa: E501
    "Qwen3VLForConditionalGeneration": ("qwen3_vl", "Qwen3VLForConditionalGeneration"),  # noqa: E501
    "Qwen3VLMoeForConditionalGeneration": ("qwen3_vl_moe", "Qwen3VLMoeForConditionalGeneration"),  # noqa: E501
    "SkyworkR1VChatModel": ("skyworkr1v", "SkyworkR1VChatModel"),
    "Step3VLForConditionalGeneration": ("step3_vl", "Step3VLForConditionalGeneration"),  # noqa: E501
    "TarsierForConditionalGeneration": ("tarsier", "TarsierForConditionalGeneration"),  # noqa: E501
    "Tarsier2ForConditionalGeneration": ("qwen2_vl", "Tarsier2ForConditionalGeneration"),  # noqa: E501
    "UltravoxModel": ("ultravox", "UltravoxModel"),
    "VoxtralForConditionalGeneration": ("voxtral", "VoxtralForConditionalGeneration"),  # noqa: E501
    # ---- Encoder-decoder ----
    "WhisperForConditionalGeneration": ("whisper", "WhisperForConditionalGeneration"),  # noqa: E501
}
282
283

_SPECULATIVE_DECODING_MODELS = {
    "MiMoMTPModel": ("mimo_mtp", "MiMoMTP"),
    # ---- EAGLE / EAGLE-3 draft models ----
    "EagleLlamaForCausalLM": ("llama_eagle", "EagleLlamaForCausalLM"),
    "EagleLlama4ForCausalLM": ("llama4_eagle", "EagleLlama4ForCausalLM"),
    "EagleMiniCPMForCausalLM": ("minicpm_eagle", "EagleMiniCPMForCausalLM"),
    "Eagle3LlamaForCausalLM": ("llama_eagle3", "Eagle3LlamaForCausalLM"),
    "LlamaForCausalLMEagle3": ("llama_eagle3", "Eagle3LlamaForCausalLM"),
    "Eagle3Qwen2_5vlForCausalLM": ("llama_eagle3", "Eagle3LlamaForCausalLM"),
    "EagleDeepSeekMTPModel": ("deepseek_eagle", "EagleDeepseekV3ForCausalLM"),
    # ---- MTP / Medusa heads ----
    "DeepSeekMTPModel": ("deepseek_mtp", "DeepSeekMTP"),
    "ErnieMTPModel": ("ernie_mtp", "ErnieMTP"),
    "LongCatFlashMTPModel": ("longcat_flash_mtp", "LongCatFlashMTP"),
    "Glm4MoeMTPModel": ("glm4_moe_mtp", "Glm4MoeMTP"),
    "MedusaModel": ("medusa", "Medusa"),
    "Qwen3NextMTP": ("qwen3_next_mtp", "Qwen3NextMTP"),
    # Temporarily disabled.
    # # TODO(woosuk): Re-enable this once the MLP Speculator is supported in V1.
    # "MLPSpeculatorPreTrainedModel": ("mlp_speculator", "MLPSpeculator"),
}
302

303
_TRANSFORMERS_SUPPORTED_MODELS = {
    # ---- Text generation (served through the Transformers backend) ----
    "SmolLM3ForCausalLM": ("transformers", "TransformersForCausalLM"),
    # ---- Multimodal (served through the Transformers backend) ----
    "Emu3ForConditionalGeneration": ("transformers", "TransformersForMultimodalLM"),  # noqa: E501
}

_TRANSFORMERS_BACKEND_MODELS = {
    # ---- Generative ----
    "TransformersForCausalLM": ("transformers", "TransformersForCausalLM"),
    "TransformersForMultimodalLM": ("transformers", "TransformersForMultimodalLM"),  # noqa: E501
    "TransformersMoEForCausalLM": ("transformers_moe", "TransformersMoEForCausalLM"),  # noqa: E501
    "TransformersMoEForMultimodalLM": ("transformers_moe", "TransformersMoEForMultimodalLM"),  # noqa: E501
    # ---- Pooling ----
    "TransformersEmbeddingModel": ("transformers_pooling", "TransformersEmbeddingModel"),  # noqa: E501
    "TransformersForSequenceClassification": ("transformers_pooling", "TransformersForSequenceClassification"),  # noqa: E501
    "TransformersMoEForSequenceClassification": ("transformers_pooling", "TransformersMoEForSequenceClassification"),  # noqa: E501
    "TransformersMoEEmbeddingModel": ("transformers_pooling", "TransformersMoEEmbeddingModel"),  # noqa: E501
}
320
# yapf: enable
321

322
# Union of every table above. Tables are merged in this order; an
# architecture listed in several tables keeps the position of its first
# occurrence and takes the value from its last one (same as `{**a, **b}`).
_VLLM_MODELS = {
    arch: entry
    for table in (
        _TEXT_GENERATION_MODELS,
        _EMBEDDING_MODELS,
        _CROSS_ENCODER_MODELS,
        _MULTIMODAL_MODELS,
        _SPECULATIVE_DECODING_MODELS,
        _TRANSFORMERS_SUPPORTED_MODELS,
        _TRANSFORMERS_BACKEND_MODELS,
    )
    for arch, entry in table.items()
}

332
333
334
335
336
337
338
339
# Argument list passed to subprocess.run(). It is kept in a module-level
# variable so the arguments can be altered when needed, e.g. when a par
# archive packs everything together, sys.executable might not be the
# interpreter we want to run.
_SUBPROCESS_COMMAND = [
    sys.executable,
    "-m",
    "vllm.model_executor.models.registry",
]

340
# Architecture name -> last vLLM version that still supported it.
_PREVIOUSLY_SUPPORTED_MODELS = {
    "MotifForCausalLM": "0.10.2",
    "Phi3SmallForCausalLM": "0.9.2",
    "Phi4FlashForCausalLM": "0.10.2",
    # All encoder-decoder models except Whisper were removed as part of the
    # V0 deprecation.
    **dict.fromkeys(
        (
            "BartModel",
            "BartForConditionalGeneration",
            "DonutForConditionalGeneration",
            "Florence2ForConditionalGeneration",
            "MBartForConditionalGeneration",
            "MllamaForConditionalGeneration",
        ),
        "0.10.2",
    ),
}
353

354

355
356
@dataclass(frozen=True)
class _ModelInfo:
    """Immutable record of the capabilities detected on a model class.

    Every flag is the result of the identically named helper imported from
    `.interfaces` / `.interfaces_base`. Instances can be turned into a plain
    dict with `dataclasses.asdict` and rebuilt with `_ModelInfo(**d)`.
    """

    # `__name__` of the inspected model class.
    architecture: str
    is_text_generation_model: bool
    is_pooling_model: bool
    default_pooling_type: str
    supports_cross_encoding: bool
    supports_multimodal: bool
    supports_multimodal_raw_input_only: bool
    supports_multimodal_encoder_tp_data: bool
    supports_pp: bool
    has_inner_state: bool
    is_attention_free: bool
    is_hybrid: bool
    has_noops: bool
    supports_transcription: bool
    # Truthy only when the class supports transcription AND its own
    # `supports_transcription_only` attribute is set.
    supports_transcription_only: bool
    supports_v0_only: bool

    @staticmethod
    def from_model_cls(model: type[nn.Module]) -> "_ModelInfo":
        """Run every interface check against `model` and collect the results.

        Args:
            model: The model class (not an instance) to inspect.

        Returns:
            A `_ModelInfo` populated from the interface helpers.
        """
        transcribes = supports_transcription(model)
        # `model.supports_transcription_only` is only read when the class
        # actually supports transcription (short-circuit `and`).
        transcription_only = (transcribes
                              and model.supports_transcription_only)

        return _ModelInfo(
            architecture=model.__name__,
            is_text_generation_model=is_text_generation_model(model),
            is_pooling_model=is_pooling_model(model),
            default_pooling_type=get_default_pooling_type(model),
            supports_cross_encoding=supports_cross_encoding(model),
            supports_multimodal=supports_multimodal(model),
            supports_multimodal_raw_input_only=(
                supports_multimodal_raw_input_only(model)),
            supports_multimodal_encoder_tp_data=(
                supports_multimodal_encoder_tp_data(model)),
            supports_pp=supports_pp(model),
            has_inner_state=has_inner_state(model),
            is_attention_free=is_attention_free(model),
            is_hybrid=is_hybrid(model),
            supports_transcription=transcribes,
            supports_transcription_only=transcription_only,
            supports_v0_only=supports_v0_only(model),
            has_noops=has_noops(model),
        )
397
398


399
class _BaseRegisteredModel(ABC):
    """Abstract entry of the model registry.

    A concrete entry can hand out the model class itself and a `_ModelInfo`
    describing that class's capabilities.
    """

    @abstractmethod
    def load_model_cls(self) -> type[nn.Module]:
        """Return the model class, importing it first if necessary."""
        raise NotImplementedError

    @abstractmethod
    def inspect_model_cls(self) -> _ModelInfo:
        """Return the capability flags of the model class."""
        raise NotImplementedError
408
409


410
411
412
413
414
415
416
@dataclass(frozen=True)
class _RegisteredModel(_BaseRegisteredModel):
    """
    Registry entry for a model class that the main process has already
    imported; its capabilities are inspected eagerly, at construction time.
    """

    # Capability flags computed from `model_cls`.
    interfaces: _ModelInfo
    # The already-imported model class.
    model_cls: type[nn.Module]

    @staticmethod
    def from_model_cls(model_cls: type[nn.Module]):
        """Build an entry for `model_cls`, inspecting it in this process."""
        info = _ModelInfo.from_model_cls(model_cls)
        return _RegisteredModel(interfaces=info, model_cls=model_cls)

    def inspect_model_cls(self) -> _ModelInfo:
        """Return the flags computed when this entry was created."""
        info = self.interfaces
        return info

    def load_model_cls(self) -> type[nn.Module]:
        """Return the stored model class (no import needed)."""
        cls = self.model_cls
        return cls


@dataclass(frozen=True)
class _LazyRegisteredModel(_BaseRegisteredModel):
    """
    Represents a model that has not been imported in the main process.

    `inspect_model_cls` imports and inspects the class in a subprocess. For
    modules that live directly in this package, the resulting `_ModelInfo` is
    also cached as JSON under `$VLLM_CACHE_ROOT/modelinfos`, keyed by the MD5
    hash of the module's source file.
    """
    module_name: str
    class_name: str

    @staticmethod
    def _get_cache_dir() -> Path:
        """Directory that holds the cached model-info JSON files."""
        return Path(envs.VLLM_CACHE_ROOT) / "modelinfos"

    def _get_cache_filename(self) -> str:
        """Cache file name for this module/class pair (dots become dashes)."""
        cls_name = f"{self.module_name}-{self.class_name}".replace(".", "-")
        return f"{cls_name}.json"

    def _get_module_source_path(self) -> Optional[Path]:
        """Return the source file of `module_name`, or None if it is unknown.

        Only modules that are direct children of this package can be located
        without importing them (they sit next to this file). For any other
        module, guessing `<this dir>/<last dotted component>.py` could pick an
        unrelated in-tree file (e.g. `my_plugin.llama` -> `llama.py`), whose
        hash would not change when the plugin does, so a stale cache entry
        would keep being used.
        """
        package, _, leaf = self.module_name.rpartition(".")
        if not leaf or package != __package__:
            return None
        model_path = Path(__file__).parent / f"{leaf}.py"
        return model_path if model_path.exists() else None

    def _load_modelinfo_from_cache(self,
                                   module_hash: str) -> Optional[_ModelInfo]:
        """Return the cached `_ModelInfo` if it was stored for `module_hash`.

        Returns None when the cache file is missing, when it was written for
        a different module hash (stale), or when reading/decoding it fails
        (the error is logged).
        """
        try:
            try:
                modelinfo_path = self._get_cache_dir(
                ) / self._get_cache_filename()
                with open(modelinfo_path, encoding="utf-8") as file:
                    mi_dict = json.load(file)
            except FileNotFoundError:
                logger.debug(("Cached model info file "
                              "for class %s.%s not found"), self.module_name,
                             self.class_name)
                return None

            if mi_dict["hash"] != module_hash:
                logger.debug(("Cached model info file "
                              "for class %s.%s is stale"), self.module_name,
                             self.class_name)
                return None

            # file not changed, use cached _ModelInfo properties
            return _ModelInfo(**mi_dict["modelinfo"])
        except Exception:
            logger.exception(("Cached model info "
                              "for class %s.%s error. "), self.module_name,
                             self.class_name)
            return None

    def _save_modelinfo_to_cache(self, mi: _ModelInfo,
                                 module_hash: str) -> None:
        """Write `mi` and `module_hash` to this model's cache file.

        Best effort: any failure is logged and otherwise ignored.
        """
        from vllm.model_executor.model_loader.weight_utils import atomic_writer
        try:
            modelinfo_dict = {
                "hash": module_hash,
                "modelinfo": asdict(mi),
            }
            cache_dir = self._get_cache_dir()
            cache_dir.mkdir(parents=True, exist_ok=True)
            modelinfo_path = cache_dir / self._get_cache_filename()
            with atomic_writer(modelinfo_path, encoding='utf-8') as f:
                json.dump(modelinfo_dict, f, indent=2)
        except Exception:
            logger.exception("Error saving model info cache.")

    @logtime(logger=logger, msg="Registry inspect model class")
    def inspect_model_cls(self) -> _ModelInfo:
        """Return the `_ModelInfo` of this model without importing it here.

        For in-tree modules, the on-disk cache is consulted first and
        refreshed after a miss. Otherwise, or on a miss, the class is
        imported and inspected in a subprocess.
        """
        module_hash = None
        model_path = self._get_module_source_path()

        if model_path is not None:
            # MD5 only detects source changes here; it is not a security
            # measure. Declaring that keeps it usable on FIPS-restricted
            # Python builds.
            module_hash = hashlib.md5(model_path.read_bytes(),
                                      usedforsecurity=False).hexdigest()

            mi = self._load_modelinfo_from_cache(module_hash)
            if mi is not None:
                logger.debug(("Loaded model info "
                              "for class %s.%s from cache"), self.module_name,
                             self.class_name)
                return mi
            logger.debug(("Cache model info "
                          "for class %s.%s miss. "
                          "Loading model instead."), self.module_name,
                         self.class_name)

        # Performed in another process to avoid initializing CUDA
        mi = _run_in_subprocess(
            lambda: _ModelInfo.from_model_cls(self.load_model_cls()))
        logger.debug("Loaded model info for class %s.%s", self.module_name,
                     self.class_name)

        # save cache file (only when the module source could be hashed)
        if module_hash is not None:
            self._save_modelinfo_to_cache(mi, module_hash)

        return mi

    def load_model_cls(self) -> type[nn.Module]:
        """Import `module_name` and return its `class_name` attribute."""
        mod = importlib.import_module(self.module_name)
        return getattr(mod, self.class_name)


@lru_cache(maxsize=128)
def _try_load_model_cls(
    model_arch: str,
    model: _BaseRegisteredModel,
) -> Optional[type[nn.Module]]:
    """Load the class behind `model`, or return None if loading fails.

    The current platform's `verify_model_arch` is called first; it runs
    outside the error handling, so anything it raises reaches the caller.
    A failure inside `model.load_model_cls()` is logged with its traceback.
    Results, including None, are memoized per `(model_arch, model)`.
    """
    from vllm.platforms import current_platform
    current_platform.verify_model_arch(model_arch)

    try:
        model_cls = model.load_model_cls()
    except Exception:
        logger.exception("Error in loading model architecture '%s'",
                         model_arch)
        model_cls = None
    return model_cls
546
547


548
549
550
551
552
553
554
555
556
557
558
@lru_cache(maxsize=128)
def _try_inspect_model_cls(
    model_arch: str,
    model: _BaseRegisteredModel,
) -> Optional[_ModelInfo]:
    """Return `model.inspect_model_cls()`, or None if inspection fails.

    A failure is logged with its traceback. Results, including None, are
    memoized per `(model_arch, model)`.
    """
    try:
        info = model.inspect_model_cls()
    except Exception:
        logger.exception("Error in inspecting model architecture '%s'",
                         model_arch)
        info = None
    return info
559
560


561
562
563
@dataclass
class _ModelRegistry:
    """Mapping from architecture name to a registered model entry.

    Provides lookup (`inspect_model_cls` / `resolve_model_cls`), including
    fallbacks to the Transformers and Terratorch backends, plus convenience
    predicates over the inspected `_ModelInfo`.
    """

    # Keyed by model_arch
    models: dict[str, _BaseRegisteredModel] = field(default_factory=dict)

    def get_supported_archs(self) -> Set[str]:
        """Return a live view of the registered architecture names."""
        return self.models.keys()

    def register_model(
        self,
        model_arch: str,
        model_cls: Union[type[nn.Module], str],
    ) -> None:
        """
        Register an external model to be used in vLLM.

        `model_cls` can be either:

        - A [`torch.nn.Module`][] class directly referencing the model.
        - A string in the format `<module>:<class>` which can be used to
          lazily import the model. This is useful to avoid initializing CUDA
          when importing the model and thus the related error
          `RuntimeError: Cannot re-initialize CUDA in forked subprocess`.
        """
        if not isinstance(model_arch, str):
            msg = f"`model_arch` should be a string, not a {type(model_arch)}"
            raise TypeError(msg)

        # Build (and validate) the new entry before touching the registry.
        if isinstance(model_cls, str):
            split_str = model_cls.split(":")
            if len(split_str) != 2:
                msg = "Expected a string in the format `<module>:<class>`"
                raise ValueError(msg)

            model = _LazyRegisteredModel(*split_str)
        elif isinstance(model_cls, type) and issubclass(model_cls, nn.Module):
            model = _RegisteredModel.from_model_cls(model_cls)
        else:
            # Report the type of the offending argument (`model_cls`).
            msg = ("`model_cls` should be a string or PyTorch model class, "
                   f"not a {type(model_cls)}")
            raise TypeError(msg)

        # Warn only once the new entry is known to be valid, so a rejected
        # registration does not claim to overwrite the existing one.
        if model_arch in self.models:
            logger.warning(
                "Model architecture %s is already registered, and will be "
                "overwritten by the new model class %s.", model_arch,
                model_cls)

        self.models[model_arch] = model

    def _raise_for_unsupported(self, architectures: list[str]):
        """Raise a `ValueError` explaining why `architectures` are unusable.

        Never returns. The message distinguishes registered architectures that
        failed inspection/loading, architectures dropped in a past vLLM
        release, and architectures that were never supported.
        """
        all_supported_archs = self.get_supported_archs()

        if any(arch in all_supported_archs for arch in architectures):
            raise ValueError(
                f"Model architectures {architectures} failed "
                "to be inspected. Please check the logs for more details.")

        for arch in architectures:
            if arch in _PREVIOUSLY_SUPPORTED_MODELS:
                previous_version = _PREVIOUSLY_SUPPORTED_MODELS[arch]

                raise ValueError(
                    f"Model architecture {arch} was supported in vLLM until "
                    f"v{previous_version}, and is not supported anymore. "
                    "Please use an older version of vLLM if you want to "
                    "use this model architecture.")

        raise ValueError(
            f"Model architectures {architectures} are not supported for now. "
            f"Supported architectures: {all_supported_archs}")

    def _try_load_model_cls(self,
                            model_arch: str) -> Optional[type[nn.Module]]:
        """Load the class for `model_arch`; `None` if unknown or failing."""
        if model_arch not in self.models:
            return None

        return _try_load_model_cls(model_arch, self.models[model_arch])

    def _try_inspect_model_cls(self, model_arch: str) -> Optional[_ModelInfo]:
        """Inspect `model_arch`; `None` if unknown or inspection fails."""
        if model_arch not in self.models:
            return None

        return _try_inspect_model_cls(model_arch, self.models[model_arch])

    def _try_resolve_transformers(
        self,
        architecture: str,
        model_config: ModelConfig,
    ) -> Optional[str]:
        """Resolve `architecture` to one served by the Transformers backend.

        Returns `architecture` itself when it is listed in
        `_TRANSFORMERS_BACKEND_MODELS`. Otherwise it looks for a matching
        class in `transformers` or, failing that, in the config's `auto_map`
        (remote code). If that class reports backend compatibility, it returns
        `model_config._get_transformers_backend_cls()`.

        When no usable class is found, returns `None`. The exception is
        `model_config.model_impl == "transformers"`, where `ValueError` is
        raised instead.
        """
        if architecture in _TRANSFORMERS_BACKEND_MODELS:
            return architecture

        auto_map: dict[str, str] = getattr(model_config.hf_config, "auto_map",
                                           None) or dict()

        # Make sure that config class is always initialized before model class,
        # otherwise the model class won't be able to access the config class,
        # the expected auto_map should have correct order like:
        # "auto_map": {
        #     "AutoConfig": "<your-repo-name>--<config-name>",
        #     "AutoModel": "<your-repo-name>--<config-name>",
        #     "AutoModelFor<Task>": "<your-repo-name>--<config-name>",
        # },
        for prefix in ("AutoConfig", "AutoModel"):
            for name, module in auto_map.items():
                if name.startswith(prefix):
                    try_get_class_from_dynamic_module(
                        module,
                        model_config.model,
                        revision=model_config.revision,
                        warn_on_fail=False,
                    )

        model_module = getattr(transformers, architecture, None)

        if model_module is None:
            for name, module in auto_map.items():
                if name.startswith("AutoModel"):
                    model_module = try_get_class_from_dynamic_module(
                        module,
                        model_config.model,
                        revision=model_config.revision,
                        warn_on_fail=True,
                    )
                    if model_module is not None:
                        break
            else:
                # No `AutoModel*` entry yielded a class.
                if model_config.model_impl != "transformers":
                    return None

                raise ValueError(
                    f"Cannot find model module. {architecture!r} is not a "
                    "registered model in the Transformers library (only "
                    "relevant if the model is meant to be in Transformers) "
                    "and 'AutoModel' is not present in the model config's "
                    "'auto_map' (relevant if the model is custom).")

        if not model_module.is_backend_compatible():
            if model_config.model_impl != "transformers":
                return None

            raise ValueError(
                f"The Transformers implementation of {architecture!r} "
                "is not compatible with vLLM.")

        return model_config._get_transformers_backend_cls()

    def _normalize_arch(
        self,
        architecture: str,
        model_config: ModelConfig,
    ) -> str:
        """Map `architecture` to a registered base architecture if possible.

        An already-registered name is returned unchanged. Otherwise, if
        `try_match_architecture_defaults` matches a suffix, that suffix text
        is swapped for each default suffix in turn. The first resulting
        registered name is returned. Falls back to `architecture` itself.
        """
        if architecture in self.models:
            return architecture

        # This may be called in order to resolve runner_type and convert_type
        # in the first place, in which case we consider the default match
        match = try_match_architecture_defaults(
            architecture,
            runner_type=getattr(model_config, "runner_type", None),
            convert_type=getattr(model_config, "convert_type", None),
        )
        if match:
            suffix, _ = match

            # Get the name of the base model to convert
            for repl_suffix, _ in iter_architecture_defaults():
                base_arch = architecture.replace(suffix, repl_suffix)
                if base_arch in self.models:
                    return base_arch

        return architecture

    def inspect_model_cls(
        self,
        architectures: Union[str, list[str]],
        model_config: ModelConfig,
    ) -> tuple[_ModelInfo, str]:
        """Return `(model_info, arch)` for the first usable architecture.

        Raises `ValueError` (via `_raise_for_unsupported`) when none of the
        candidate architectures, nor any backend fallback, can be inspected.
        """
        if isinstance(architectures, str):
            architectures = [architectures]
        if not architectures:
            raise ValueError("No model architectures are specified")

        # Require transformers impl
        if model_config.model_impl == "transformers":
            arch = self._try_resolve_transformers(architectures[0],
                                                  model_config)
            if arch is not None:
                model_info = self._try_inspect_model_cls(arch)
                if model_info is not None:
                    return (model_info, arch)
        elif model_config.model_impl == "terratorch":
            # Only return a successfully inspected entry (mirrors
            # `resolve_model_cls`); a `None` model info would crash callers.
            arch = "Terratorch"
            model_info = self._try_inspect_model_cls(arch)
            if model_info is not None:
                return (model_info, arch)

        # Fallback to transformers impl (after resolving convert_type)
        if (all(arch not in self.models for arch in architectures)
                and model_config.model_impl == "auto"
                and getattr(model_config, "convert_type", "none") == "none"):
            arch = self._try_resolve_transformers(architectures[0],
                                                  model_config)
            if arch is not None:
                model_info = self._try_inspect_model_cls(arch)
                if model_info is not None:
                    return (model_info, arch)

        for arch in architectures:
            normalized_arch = self._normalize_arch(arch, model_config)
            model_info = self._try_inspect_model_cls(normalized_arch)
            if model_info is not None:
                return (model_info, arch)

        # Fallback to transformers impl (before resolving runner_type)
        if (all(arch not in self.models for arch in architectures)
                and model_config.model_impl == "auto"):
            arch = self._try_resolve_transformers(architectures[0],
                                                  model_config)
            if arch is not None:
                model_info = self._try_inspect_model_cls(arch)
                if model_info is not None:
                    return (model_info, arch)

        return self._raise_for_unsupported(architectures)

    def resolve_model_cls(
        self,
        architectures: Union[str, list[str]],
        model_config: ModelConfig,
    ) -> tuple[type[nn.Module], str]:
        """Return `(model_cls, arch)` for the first loadable architecture.

        Uses the same lookup order as `inspect_model_cls`. Raises
        `ValueError` (via `_raise_for_unsupported`) when nothing can be
        loaded.
        """
        if isinstance(architectures, str):
            architectures = [architectures]
        if not architectures:
            raise ValueError("No model architectures are specified")

        # Require transformers impl
        if model_config.model_impl == "transformers":
            arch = self._try_resolve_transformers(architectures[0],
                                                  model_config)
            if arch is not None:
                model_cls = self._try_load_model_cls(arch)
                if model_cls is not None:
                    return (model_cls, arch)
        elif model_config.model_impl == "terratorch":
            arch = "Terratorch"
            model_cls = self._try_load_model_cls(arch)
            if model_cls is not None:
                return (model_cls, arch)

        # Fallback to transformers impl (after resolving convert_type)
        if (all(arch not in self.models for arch in architectures)
                and model_config.model_impl == "auto"
                and getattr(model_config, "convert_type", "none") == "none"):
            arch = self._try_resolve_transformers(architectures[0],
                                                  model_config)
            if arch is not None:
                model_cls = self._try_load_model_cls(arch)
                if model_cls is not None:
                    return (model_cls, arch)

        for arch in architectures:
            normalized_arch = self._normalize_arch(arch, model_config)
            model_cls = self._try_load_model_cls(normalized_arch)
            if model_cls is not None:
                return (model_cls, arch)

        # Fallback to transformers impl (before resolving runner_type)
        if (all(arch not in self.models for arch in architectures)
                and model_config.model_impl == "auto"):
            arch = self._try_resolve_transformers(architectures[0],
                                                  model_config)
            if arch is not None:
                model_cls = self._try_load_model_cls(arch)
                if model_cls is not None:
                    return (model_cls, arch)

        return self._raise_for_unsupported(architectures)

    def is_text_generation_model(
        self,
        architectures: Union[str, list[str]],
        model_config: ModelConfig,
    ) -> bool:
        """Whether the inspected model is a text generation model."""
        model_cls, _ = self.inspect_model_cls(architectures, model_config)
        return model_cls.is_text_generation_model

    def is_pooling_model(
        self,
        architectures: Union[str, list[str]],
        model_config: ModelConfig,
    ) -> bool:
        """Whether the inspected model is a pooling model."""
        model_cls, _ = self.inspect_model_cls(architectures, model_config)
        return model_cls.is_pooling_model

    def is_cross_encoder_model(
        self,
        architectures: Union[str, list[str]],
        model_config: ModelConfig,
    ) -> bool:
        """Whether the inspected model supports cross-encoding."""
        model_cls, _ = self.inspect_model_cls(architectures, model_config)
        return model_cls.supports_cross_encoding

    def is_multimodal_model(
        self,
        architectures: Union[str, list[str]],
        model_config: ModelConfig,
    ) -> bool:
        """Whether the inspected model supports multimodal inputs."""
        model_cls, _ = self.inspect_model_cls(architectures, model_config)
        return model_cls.supports_multimodal

    def is_multimodal_raw_input_only_model(
        self,
        architectures: Union[str, list[str]],
        model_config: ModelConfig,
    ) -> bool:
        """Whether the inspected model only accepts raw multimodal input."""
        model_cls, _ = self.inspect_model_cls(architectures, model_config)
        return model_cls.supports_multimodal_raw_input_only

    def is_pp_supported_model(
        self,
        architectures: Union[str, list[str]],
        model_config: ModelConfig,
    ) -> bool:
        """Whether the inspected model supports pipeline parallelism."""
        model_cls, _ = self.inspect_model_cls(architectures, model_config)
        return model_cls.supports_pp

    def model_has_inner_state(
        self,
        architectures: Union[str, list[str]],
        model_config: ModelConfig,
    ) -> bool:
        """Whether the inspected model has inner state."""
        model_cls, _ = self.inspect_model_cls(architectures, model_config)
        return model_cls.has_inner_state

    def is_attention_free_model(
        self,
        architectures: Union[str, list[str]],
        model_config: ModelConfig,
    ) -> bool:
        """Whether the inspected model is attention-free."""
        model_cls, _ = self.inspect_model_cls(architectures, model_config)
        return model_cls.is_attention_free

    def is_hybrid_model(
        self,
        architectures: Union[str, list[str]],
        model_config: ModelConfig,
    ) -> bool:
        """Whether the inspected model is flagged as hybrid."""
        model_cls, _ = self.inspect_model_cls(architectures, model_config)
        return model_cls.is_hybrid

    def is_noops_model(
        self,
        architectures: Union[str, list[str]],
        model_config: ModelConfig,
    ) -> bool:
        """Whether the inspected model is flagged as having no-ops."""
        model_cls, _ = self.inspect_model_cls(architectures, model_config)
        return model_cls.has_noops

    def is_transcription_model(
        self,
        architectures: Union[str, list[str]],
        model_config: ModelConfig,
    ) -> bool:
        """Whether the inspected model supports transcription."""
        model_cls, _ = self.inspect_model_cls(architectures, model_config)
        return model_cls.supports_transcription

    def is_transcription_only_model(
        self,
        architectures: Union[str, list[str]],
        model_config: ModelConfig,
    ) -> bool:
        """Whether the inspected model supports transcription only."""
        model_cls, _ = self.inspect_model_cls(architectures, model_config)
        return model_cls.supports_transcription_only

    def is_v1_compatible(
        self,
        architectures: Union[str, list[str]],
        model_config: ModelConfig,
    ) -> bool:
        """Whether the inspected model is not restricted to V0 only."""
        model_cls, _ = self.inspect_model_cls(architectures, model_config)
        return not model_cls.supports_v0_only

943
944

# Global registry pre-populated with every built-in architecture listed in
# `_VLLM_MODELS`. Each entry is a `_LazyRegisteredModel`, so the model module
# under `vllm.model_executor.models` is only imported when actually loaded.
ModelRegistry = _ModelRegistry({
    model_arch:
    _LazyRegisteredModel(
        module_name=f"vllm.model_executor.models.{mod_relname}",
        class_name=cls_name,
    )
    for model_arch, (mod_relname, cls_name) in _VLLM_MODELS.items()
})

# Return type of the zero-argument callable executed by `_run_in_subprocess`.
_T = TypeVar("_T")


def _run_in_subprocess(fn: Callable[[], _T]) -> _T:
    """Run `fn` in a separate Python process and return its result.

    `fn` and the path of a temporary output file are serialized with
    `cloudpickle` (so lambdas work) and piped on stdin to
    `_SUBPROCESS_COMMAND`. That command presumably runs this module's `_run`,
    which pickles `fn()`'s result into the output file. The file is then
    unpickled here.

    Raises:
        RuntimeError: if the subprocess exits with a non-zero status; the
            child's stderr is included in the message.
    """
    # NOTE: We use a temporary directory instead of a temporary file to avoid
    # issues like https://stackoverflow.com/questions/23212435/permission-denied-to-write-to-my-temporary-file
    with tempfile.TemporaryDirectory() as tempdir:
        output_filepath = os.path.join(tempdir, "registry_output.tmp")

        # `cloudpickle` allows pickling lambda functions directly
        import cloudpickle
        input_bytes = cloudpickle.dumps((fn, output_filepath))

        # cannot use `sys.executable __file__` here because the script
        # contains relative imports
        returned = subprocess.run(_SUBPROCESS_COMMAND,
                                  input=input_bytes,
                                  capture_output=True)

        # check if the subprocess is successful
        try:
            returned.check_returncode()
        except subprocess.CalledProcessError as e:
            # wrap raised exception to provide more information; decode
            # leniently so undecodable stderr bytes cannot mask the failure
            stderr = returned.stderr.decode(errors="replace")
            raise RuntimeError(f"Error raised in subprocess:\n"
                               f"{stderr}") from e

        # The output file was written by our own child process (trusted).
        with open(output_filepath, "rb") as f:
            return pickle.load(f)


def _run() -> None:
    """Child-side counterpart of `_run_in_subprocess`.

    Loads general plugins, then reads a pickled `(fn, output_file)` pair
    from stdin. It calls `fn()` and pickles the result into `output_file`.
    """
    # Setup plugins
    from vllm.plugins import load_general_plugins
    load_general_plugins()

    # The payload is produced by the parent process, not external input.
    fn, output_file = pickle.loads(sys.stdin.buffer.read())

    result = fn()

    with open(output_file, "wb") as f:
        pickle.dump(result, f)
995
996
997


# Entry point used when this module is executed as a script (the subprocess
# side of `_run_in_subprocess`).
if __name__ == "__main__":
    _run()