registry.py 42.5 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
4
5
6
"""
Whenever you add an architecture to this page, please also update
`tests/models/registry.py` with example HuggingFace models for it.
"""
7
import hashlib
8
import importlib
9
import json
10
import os
11
import pickle
12
13
import subprocess
import sys
14
import tempfile
15
from abc import ABC, abstractmethod
16
from collections.abc import Set
17
from dataclasses import asdict, dataclass, field
18
from functools import lru_cache
19
from pathlib import Path
20
from typing import Callable, Optional, TypeVar, Union
21
22

import torch.nn as nn
23
import transformers
24

25
from vllm import envs
26
from vllm.config import (ModelConfig, iter_architecture_defaults,
27
                         try_match_architecture_defaults)
28
from vllm.logger import init_logger
29
from vllm.logging_utils import logtime
30
31
from vllm.transformers_utils.dynamic_module import (
    try_get_class_from_dynamic_module)
32

33
34
from .interfaces import (has_inner_state, has_noops, is_attention_free,
                         is_hybrid, supports_cross_encoding,
35
36
                         supports_multimodal,
                         supports_multimodal_encoder_tp_data,
37
                         supports_multimodal_raw_input_only, supports_pp,
38
                         supports_transcription, supports_v0_only)
39
40
from .interfaces_base import (get_default_pooling_type, is_pooling_model,
                              is_text_generation_model)
41
42
43

logger = init_logger(__name__)

44
# yapf: disable
45
46
# Text-generation architectures, keyed by architecture name. Each value is a
# (module, class) pair; presumably the module is a file in this package that
# defines the class -- confirm against the code that consumes these tables.
_TEXT_GENERATION_MODELS = {
    # [Decoder-only]
    "ApertusForCausalLM": ("apertus", "ApertusForCausalLM"),
    "AquilaModel": ("llama", "LlamaForCausalLM"),
    "AquilaForCausalLM": ("llama", "LlamaForCausalLM"),  # AquilaChat2
    "ArceeForCausalLM": ("arcee", "ArceeForCausalLM"),
    "ArcticForCausalLM": ("arctic", "ArcticForCausalLM"),
    "MiniMaxForCausalLM": ("minimax_text_01", "MiniMaxText01ForCausalLM"),
    "MiniMaxText01ForCausalLM": ("minimax_text_01", "MiniMaxText01ForCausalLM"),
    "MiniMaxM1ForCausalLM": ("minimax_text_01", "MiniMaxText01ForCausalLM"),
    # baichuan-7b, upper case 'C' in the class name
    "BaiChuanForCausalLM": ("baichuan", "BaiChuanForCausalLM"),
    # baichuan-13b, lower case 'c' in the class name
    "BaichuanForCausalLM": ("baichuan", "BaichuanForCausalLM"),
    "BailingMoeForCausalLM": ("bailing_moe", "BailingMoeForCausalLM"),
    "BailingMoeV2ForCausalLM": ("bailing_moe", "BailingMoeV2ForCausalLM"),
    "BambaForCausalLM": ("bamba", "BambaForCausalLM"),
    "BloomForCausalLM": ("bloom", "BloomForCausalLM"),
    "ChatGLMModel": ("chatglm", "ChatGLMForCausalLM"),
    "ChatGLMForConditionalGeneration": ("chatglm", "ChatGLMForCausalLM"),
    "CohereForCausalLM": ("commandr", "CohereForCausalLM"),
    "Cohere2ForCausalLM": ("commandr", "CohereForCausalLM"),
    "CwmForCausalLM": ("llama", "LlamaForCausalLM"),
    "DbrxForCausalLM": ("dbrx", "DbrxForCausalLM"),
    "DeciLMForCausalLM": ("nemotron_nas", "DeciLMForCausalLM"),
    "DeepseekForCausalLM": ("deepseek", "DeepseekForCausalLM"),
    "DeepseekV2ForCausalLM": ("deepseek_v2", "DeepseekV2ForCausalLM"),
    "DeepseekV3ForCausalLM": ("deepseek_v2", "DeepseekV3ForCausalLM"),
    "DeepseekV32ForCausalLM": ("deepseek_v2", "DeepseekV3ForCausalLM"),
    "Dots1ForCausalLM": ("dots1", "Dots1ForCausalLM"),
    "Ernie4_5ForCausalLM": ("ernie45", "Ernie4_5ForCausalLM"),
    "Ernie4_5_MoeForCausalLM": ("ernie45_moe", "Ernie4_5_MoeForCausalLM"),
    "ExaoneForCausalLM": ("exaone", "ExaoneForCausalLM"),
    "Exaone4ForCausalLM": ("exaone4", "Exaone4ForCausalLM"),
    "FalconForCausalLM": ("falcon", "FalconForCausalLM"),
    "Fairseq2LlamaForCausalLM": ("fairseq2_llama", "Fairseq2LlamaForCausalLM"),
    "GemmaForCausalLM": ("gemma", "GemmaForCausalLM"),
    "Gemma2ForCausalLM": ("gemma2", "Gemma2ForCausalLM"),
    "Gemma3ForCausalLM": ("gemma3", "Gemma3ForCausalLM"),
    "Gemma3nForCausalLM": ("gemma3n", "Gemma3nForCausalLM"),
    "Qwen3NextForCausalLM": ("qwen3_next", "Qwen3NextForCausalLM"),
    "GlmForCausalLM": ("glm", "GlmForCausalLM"),
    "Glm4ForCausalLM": ("glm4", "Glm4ForCausalLM"),
    "Glm4MoeForCausalLM": ("glm4_moe", "Glm4MoeForCausalLM"),
    "GptOssForCausalLM": ("gpt_oss", "GptOssForCausalLM"),
    "GPT2LMHeadModel": ("gpt2", "GPT2LMHeadModel"),
    "GPTBigCodeForCausalLM": ("gpt_bigcode", "GPTBigCodeForCausalLM"),
    "GPTJForCausalLM": ("gpt_j", "GPTJForCausalLM"),
    "GPTNeoXForCausalLM": ("gpt_neox", "GPTNeoXForCausalLM"),
    "GraniteForCausalLM": ("granite", "GraniteForCausalLM"),
    "GraniteMoeForCausalLM": ("granitemoe", "GraniteMoeForCausalLM"),
    "GraniteMoeHybridForCausalLM": ("granitemoehybrid", "GraniteMoeHybridForCausalLM"),   # noqa: E501
    "GraniteMoeSharedForCausalLM": ("granitemoeshared", "GraniteMoeSharedForCausalLM"),   # noqa: E501
    "GritLM": ("gritlm", "GritLM"),
    "Grok1ModelForCausalLM": ("grok1", "Grok1ForCausalLM"),
    "HunYuanMoEV1ForCausalLM": ("hunyuan_v1", "HunYuanMoEV1ForCausalLM"),
    "HunYuanDenseV1ForCausalLM": ("hunyuan_v1", "HunYuanDenseV1ForCausalLM"),
    "HCXVisionForCausalLM": ("hyperclovax_vision", "HCXVisionForCausalLM"),
    "InternLMForCausalLM": ("llama", "LlamaForCausalLM"),
    "InternLM2ForCausalLM": ("internlm2", "InternLM2ForCausalLM"),
    "InternLM2VEForCausalLM": ("internlm2_ve", "InternLM2VEForCausalLM"),
    "InternLM3ForCausalLM": ("llama", "LlamaForCausalLM"),
    "JAISLMHeadModel": ("jais", "JAISLMHeadModel"),
    "JambaForCausalLM": ("jamba", "JambaForCausalLM"),
    "Lfm2ForCausalLM": ("lfm2", "Lfm2ForCausalLM"),
    "LlamaForCausalLM": ("llama", "LlamaForCausalLM"),
    "Llama4ForCausalLM": ("llama4", "Llama4ForCausalLM"),  # noqa: E501
    # For decapoda-research/llama-*
    "LLaMAForCausalLM": ("llama", "LlamaForCausalLM"),
    "LongcatFlashForCausalLM": ("longcat_flash", "LongcatFlashForCausalLM"),
    "MambaForCausalLM": ("mamba", "MambaForCausalLM"),
    "FalconMambaForCausalLM": ("mamba", "MambaForCausalLM"),
    "FalconH1ForCausalLM": ("falcon_h1", "FalconH1ForCausalLM"),
    "Mamba2ForCausalLM": ("mamba2", "Mamba2ForCausalLM"),
    "MiniCPMForCausalLM": ("minicpm", "MiniCPMForCausalLM"),
    "MiniCPM3ForCausalLM": ("minicpm3", "MiniCPM3ForCausalLM"),
    "MistralForCausalLM": ("llama", "LlamaForCausalLM"),
    "MixtralForCausalLM": ("mixtral", "MixtralForCausalLM"),
    "MotifForCausalLM": ("motif", "MotifForCausalLM"),
    # transformers's mpt class has lower case
    "MptForCausalLM": ("mpt", "MPTForCausalLM"),
    "MPTForCausalLM": ("mpt", "MPTForCausalLM"),
    "MiMoForCausalLM": ("mimo", "MiMoForCausalLM"),
    "NemotronForCausalLM": ("nemotron", "NemotronForCausalLM"),
    "NemotronHForCausalLM": ("nemotron_h", "NemotronHForCausalLM"),
    "OlmoForCausalLM": ("olmo", "OlmoForCausalLM"),
    "Olmo2ForCausalLM": ("olmo2", "Olmo2ForCausalLM"),
    "Olmo3ForCausalLM": ("olmo2", "Olmo2ForCausalLM"),
    "OlmoeForCausalLM": ("olmoe", "OlmoeForCausalLM"),
    "OPTForCausalLM": ("opt", "OPTForCausalLM"),
    "OrionForCausalLM": ("orion", "OrionForCausalLM"),
    "PersimmonForCausalLM": ("persimmon", "PersimmonForCausalLM"),
    "PhiForCausalLM": ("phi", "PhiForCausalLM"),
    "Phi3ForCausalLM": ("phi3", "Phi3ForCausalLM"),
    "PhiMoEForCausalLM": ("phimoe", "PhiMoEForCausalLM"),
    "Plamo2ForCausalLM": ("plamo2", "Plamo2ForCausalLM"),
    "QWenLMHeadModel": ("qwen", "QWenLMHeadModel"),
    "Qwen2ForCausalLM": ("qwen2", "Qwen2ForCausalLM"),
    "Qwen2MoeForCausalLM": ("qwen2_moe", "Qwen2MoeForCausalLM"),
    "Qwen3ForCausalLM": ("qwen3", "Qwen3ForCausalLM"),
    "Qwen3MoeForCausalLM": ("qwen3_moe", "Qwen3MoeForCausalLM"),
    "RWForCausalLM": ("falcon", "FalconForCausalLM"),
    "SeedOssForCausalLM": ("seed_oss", "SeedOssForCausalLM"),
    "Step3TextForCausalLM": ("step3_text", "Step3TextForCausalLM"),
    "StableLMEpochForCausalLM": ("stablelm", "StablelmForCausalLM"),
    "StableLmForCausalLM": ("stablelm", "StablelmForCausalLM"),
    "Starcoder2ForCausalLM": ("starcoder2", "Starcoder2ForCausalLM"),
    "SolarForCausalLM": ("solar", "SolarForCausalLM"),
    "TeleChat2ForCausalLM": ("telechat2", "TeleChat2ForCausalLM"),
    "TeleFLMForCausalLM": ("teleflm", "TeleFLMForCausalLM"),
    "XverseForCausalLM": ("llama", "LlamaForCausalLM"),
    "Zamba2ForCausalLM": ("zamba2", "Zamba2ForCausalLM"),
}

# Embedding / pooling architectures, keyed by architecture name, with
# (module, class) pairs as values.
_EMBEDDING_MODELS = {
    # [Text-only]
    "BertModel": ("bert", "BertEmbeddingModel"),
    "DeciLMForCausalLM": ("nemotron_nas", "DeciLMForCausalLM"),
    "Gemma2Model": ("gemma2", "Gemma2ForCausalLM"),
    "Gemma3TextModel": ("gemma3", "Gemma3Model"),
    "GlmForCausalLM": ("glm", "GlmForCausalLM"),
    "GPT2ForSequenceClassification": ("gpt2", "GPT2ForSequenceClassification"),
    "GritLM": ("gritlm", "GritLM"),
    "GteModel": ("bert_with_rope", "SnowflakeGteNewModel"),
    "GteNewModel": ("bert_with_rope", "GteNewModel"),
    "InternLM2ForRewardModel": ("internlm2", "InternLM2ForRewardModel"),
    "JambaForSequenceClassification": ("jamba", "JambaForSequenceClassification"),  # noqa: E501
    "LlamaModel": ("llama", "LlamaForCausalLM"),
    **{
        # Multiple models share the same architecture, so we include them all
        k: (mod, arch) for k, (mod, arch) in _TEXT_GENERATION_MODELS.items()
        if arch == "LlamaForCausalLM"
    },
    "MistralModel": ("llama", "LlamaForCausalLM"),
    "ModernBertModel": ("modernbert", "ModernBertModel"),
    "NomicBertModel": ("bert_with_rope", "NomicBertModel"),
    "Phi3ForCausalLM": ("phi3", "Phi3ForCausalLM"),
    "Qwen2Model": ("qwen2", "Qwen2ForCausalLM"),
    "Qwen2ForCausalLM": ("qwen2", "Qwen2ForCausalLM"),
    "Qwen2ForRewardModel": ("qwen2_rm", "Qwen2ForRewardModel"),
    "Qwen2ForProcessRewardModel": ("qwen2_rm", "Qwen2ForProcessRewardModel"),
    "RobertaForMaskedLM": ("roberta", "RobertaEmbeddingModel"),
    "RobertaModel": ("roberta", "RobertaEmbeddingModel"),
    "TeleChat2ForCausalLM": ("telechat2", "TeleChat2ForCausalLM"),
    "XLMRobertaModel": ("roberta", "RobertaEmbeddingModel"),
    # [Multimodal]
    "LlavaNextForConditionalGeneration": ("llava_next", "LlavaNextForConditionalGeneration"),  # noqa: E501
    "Phi3VForCausalLM": ("phi3v", "Phi3VForCausalLM"),
    "Qwen2VLForConditionalGeneration": ("qwen2_vl", "Qwen2VLForConditionalGeneration"),  # noqa: E501
    # Technically Terratorch models work on images, both in
    # input and output. I am adding it here because it piggy-backs on embedding
    # models for the time being.
    "PrithviGeoSpatialMAE": ("terratorch", "Terratorch"),
    "Terratorch": ("terratorch", "Terratorch"),
}

201
202
# Cross-encoder (sequence/token classification and ranking) architectures,
# keyed by architecture name, with (module, class) pairs as values.
_CROSS_ENCODER_MODELS = {
    "BertForSequenceClassification": ("bert", "BertForSequenceClassification"),
    "BertForTokenClassification": ("bert", "BertForTokenClassification"),
    "GteNewForSequenceClassification": ("bert_with_rope",
                                        "GteNewForSequenceClassification"),
    "ModernBertForSequenceClassification": ("modernbert",
                                            "ModernBertForSequenceClassification"),
    "RobertaForSequenceClassification": ("roberta",
                                         "RobertaForSequenceClassification"),
    "XLMRobertaForSequenceClassification": ("roberta",
                                            "RobertaForSequenceClassification"),
    # [Auto-converted (see adapters.py)]
    "JinaVLForRanking": ("jina_vl", "JinaVLForSequenceClassification"),  # noqa: E501
}

216
# Multimodal architectures, keyed by architecture name, with
# (module, class) pairs as values.
_MULTIMODAL_MODELS = {
    # [Decoder-only]
    "AriaForConditionalGeneration": ("aria", "AriaForConditionalGeneration"),
    "AyaVisionForConditionalGeneration": ("aya_vision", "AyaVisionForConditionalGeneration"),  # noqa: E501
    "Blip2ForConditionalGeneration": ("blip2", "Blip2ForConditionalGeneration"),
    "ChameleonForConditionalGeneration": ("chameleon", "ChameleonForConditionalGeneration"),  # noqa: E501
    "Cohere2VisionForConditionalGeneration": ("cohere2_vision", "Cohere2VisionForConditionalGeneration"),  # noqa: E501
    "DeepseekVLV2ForCausalLM": ("deepseek_vl2", "DeepseekVLV2ForCausalLM"),
    "DotsOCRForCausalLM": ("dots_ocr", "DotsOCRForCausalLM"),
    "Ernie4_5_VLMoeForConditionalGeneration": ("ernie45_vl", "Ernie4_5_VLMoeForConditionalGeneration"),  # noqa: E501
    "FuyuForCausalLM": ("fuyu", "FuyuForCausalLM"),
    "Gemma3ForConditionalGeneration": ("gemma3_mm", "Gemma3ForConditionalGeneration"),  # noqa: E501
    "Gemma3nForConditionalGeneration": ("gemma3n_mm", "Gemma3nForConditionalGeneration"),    # noqa: E501
    "GLM4VForCausalLM": ("glm4v", "GLM4VForCausalLM"),
    "Glm4vForConditionalGeneration": ("glm4_1v", "Glm4vForConditionalGeneration"),  # noqa: E501
    "Glm4vMoeForConditionalGeneration": ("glm4_1v", "Glm4vMoeForConditionalGeneration"),  # noqa: E501
    "GraniteSpeechForConditionalGeneration": ("granite_speech", "GraniteSpeechForConditionalGeneration"),  # noqa: E501
    "H2OVLChatModel": ("h2ovl", "H2OVLChatModel"),
    "InternVLChatModel": ("internvl", "InternVLChatModel"),
    "NemotronH_Nano_VL_V2": ("nano_nemotron_vl", "NemotronH_Nano_VL_V2"),
    "InternS1ForConditionalGeneration": ("interns1", "InternS1ForConditionalGeneration"),  # noqa: E501
    "InternVLForConditionalGeneration": ("interns1", "InternS1ForConditionalGeneration"),  # noqa: E501
    "Idefics3ForConditionalGeneration": ("idefics3", "Idefics3ForConditionalGeneration"),  # noqa: E501
    "SmolVLMForConditionalGeneration": ("smolvlm", "SmolVLMForConditionalGeneration"),  # noqa: E501
    "KeyeForConditionalGeneration": ("keye", "KeyeForConditionalGeneration"),
    "KeyeVL1_5ForConditionalGeneration": ("keye_vl1_5", "KeyeVL1_5ForConditionalGeneration"),  # noqa: E501
    "RForConditionalGeneration": ("rvl", "RForConditionalGeneration"),
    "KimiVLForConditionalGeneration": ("kimi_vl", "KimiVLForConditionalGeneration"),  # noqa: E501
    "Llama_Nemotron_Nano_VL": ("nemotron_vl", "LlamaNemotronVLChatModel"),
    "Llama4ForConditionalGeneration": ("mllama4", "Llama4ForConditionalGeneration"),  # noqa: E501
    "LlavaForConditionalGeneration": ("llava", "LlavaForConditionalGeneration"),
    "LlavaNextForConditionalGeneration": ("llava_next", "LlavaNextForConditionalGeneration"),  # noqa: E501
    "LlavaNextVideoForConditionalGeneration": ("llava_next_video", "LlavaNextVideoForConditionalGeneration"),  # noqa: E501
    "LlavaOnevisionForConditionalGeneration": ("llava_onevision", "LlavaOnevisionForConditionalGeneration"),  # noqa: E501
    "MantisForConditionalGeneration": ("llava", "MantisForConditionalGeneration"),  # noqa: E501
    "MiDashengLMModel": ("midashenglm", "MiDashengLMModel"),
    "MiniMaxVL01ForConditionalGeneration": ("minimax_vl_01", "MiniMaxVL01ForConditionalGeneration"),  # noqa: E501
    "MiniCPMO": ("minicpmo", "MiniCPMO"),
    "MiniCPMV": ("minicpmv", "MiniCPMV"),
    "Mistral3ForConditionalGeneration": ("mistral3", "Mistral3ForConditionalGeneration"),  # noqa: E501
    "MolmoForCausalLM": ("molmo", "MolmoForCausalLM"),
    "NVLM_D": ("nvlm_d", "NVLM_D_Model"),
    "Ovis": ("ovis", "Ovis"),
    "Ovis2_5": ("ovis2_5", "Ovis2_5"),
    "PaliGemmaForConditionalGeneration": ("paligemma", "PaliGemmaForConditionalGeneration"),  # noqa: E501
    "Phi3VForCausalLM": ("phi3v", "Phi3VForCausalLM"),
    "Phi4MMForCausalLM": ("phi4mm", "Phi4MMForCausalLM"),
    "Phi4MultimodalForCausalLM": ("phi4_multimodal", "Phi4MultimodalForCausalLM"),  # noqa: E501
    "PixtralForConditionalGeneration": ("pixtral", "PixtralForConditionalGeneration"),  # noqa: E501
    "QwenVLForConditionalGeneration": ("qwen_vl", "QwenVLForConditionalGeneration"),  # noqa: E501
    "Qwen2VLForConditionalGeneration": ("qwen2_vl", "Qwen2VLForConditionalGeneration"),  # noqa: E501
    "Qwen2_5_VLForConditionalGeneration": ("qwen2_5_vl", "Qwen2_5_VLForConditionalGeneration"),  # noqa: E501
    "Qwen2AudioForConditionalGeneration": ("qwen2_audio", "Qwen2AudioForConditionalGeneration"),  # noqa: E501
    "Qwen2_5OmniModel": ("qwen2_5_omni_thinker", "Qwen2_5OmniThinkerForConditionalGeneration"),  # noqa: E501
    "Qwen2_5OmniForConditionalGeneration": ("qwen2_5_omni_thinker", "Qwen2_5OmniThinkerForConditionalGeneration"),  # noqa: E501
    "Qwen3VLForConditionalGeneration": ("qwen3_vl", "Qwen3VLForConditionalGeneration"),  # noqa: E501
    "Qwen3VLMoeForConditionalGeneration": ("qwen3_vl_moe", "Qwen3VLMoeForConditionalGeneration"),  # noqa: E501
    "SkyworkR1VChatModel": ("skyworkr1v", "SkyworkR1VChatModel"),
    "Step3VLForConditionalGeneration": ("step3_vl", "Step3VLForConditionalGeneration"),  # noqa: E501
    "TarsierForConditionalGeneration": ("tarsier", "TarsierForConditionalGeneration"),  # noqa: E501
    "Tarsier2ForConditionalGeneration": ("qwen2_vl", "Tarsier2ForConditionalGeneration"),  # noqa: E501
    "UltravoxModel": ("ultravox", "UltravoxModel"),
    "VoxtralForConditionalGeneration": ("voxtral", "VoxtralForConditionalGeneration"),  # noqa: E501
    # [Encoder-decoder]
    "WhisperForConditionalGeneration": ("whisper", "WhisperForConditionalGeneration"),  # noqa: E501
}
282
283

# Draft / speculative-decoding architectures, keyed by architecture name,
# with (module, class) pairs as values.
_SPECULATIVE_DECODING_MODELS = {
    "MiMoMTPModel": ("mimo_mtp", "MiMoMTP"),
    "EagleLlamaForCausalLM": ("llama_eagle", "EagleLlamaForCausalLM"),
    "EagleLlama4ForCausalLM": ("llama4_eagle", "EagleLlama4ForCausalLM"),
    "EagleMiniCPMForCausalLM": ("minicpm_eagle", "EagleMiniCPMForCausalLM"),
    "Eagle3LlamaForCausalLM": ("llama_eagle3", "Eagle3LlamaForCausalLM"),
    "LlamaForCausalLMEagle3": ("llama_eagle3", "Eagle3LlamaForCausalLM"),
    "EagleDeepSeekMTPModel": ("deepseek_eagle", "EagleDeepseekV3ForCausalLM"),
    "DeepSeekMTPModel": ("deepseek_mtp", "DeepSeekMTP"),
    "ErnieMTPModel": ("ernie_mtp", "ErnieMTP"),
    "LongCatFlashMTPModel": ("longcat_flash_mtp", "LongCatFlashMTP"),
    "Glm4MoeMTPModel": ("glm4_moe_mtp", "Glm4MoeMTP"),
    "MedusaModel": ("medusa", "Medusa"),
    "Qwen3NextMTP": ("qwen3_next_mtp", "Qwen3NextMTP"),
    # Temporarily disabled.
    # # TODO(woosuk): Re-enable this once the MLP Speculator is supported in V1.
    # "MLPSpeculatorPreTrainedModel": ("mlp_speculator", "MLPSpeculator"),
}
301

302
# Architectures mapped onto the generic classes in the "transformers" module
# instead of a dedicated implementation.
_TRANSFORMERS_SUPPORTED_MODELS = {
    # Text generation models
    "SmolLM3ForCausalLM": ("transformers", "TransformersForCausalLM"),
    # Multimodal models
    "Emu3ForConditionalGeneration": ("transformers", "TransformersForMultimodalLM"),  # noqa: E501
}

# The generic classes of the "transformers" module, registered under their
# own class names.
_TRANSFORMERS_BACKEND_MODELS = {
    "TransformersModel": ("transformers", "TransformersModel"),
    "TransformersForCausalLM": ("transformers", "TransformersForCausalLM"),
    "TransformersForMultimodalLM": ("transformers", "TransformersForMultimodalLM"),  # noqa: E501
}
314
# yapf: enable
315

316
# Union of every per-category table above. When an architecture name appears
# in more than one table, the entry from the table unpacked last wins.
_VLLM_MODELS = {
    **_TEXT_GENERATION_MODELS,
    **_EMBEDDING_MODELS,
    **_CROSS_ENCODER_MODELS,
    **_MULTIMODAL_MODELS,
    **_SPECULATIVE_DECODING_MODELS,
    **_TRANSFORMERS_SUPPORTED_MODELS,
    **_TRANSFORMERS_BACKEND_MODELS,
}

326
327
328
329
330
331
332
333
# Argument list passed to subprocess.run(). It is a module-level variable so
# that it can be changed when needed, e.g. when everything is packed together
# in par format and sys.executable is not the program we want to run.
_SUBPROCESS_COMMAND = [
    sys.executable,
    "-m",
    "vllm.model_executor.models.registry",
]

334
335
336
337
338
339
340
341
342
343
344
# Architectures that are no longer supported, mapped to a version string
# (presumably the last release that still supported them -- confirm against
# the code that reads this table).
_PREVIOUSLY_SUPPORTED_MODELS = {
    "Phi3SmallForCausalLM": "0.9.2",
    # encoder-decoder models except whisper
    # have been removed for V0 deprecation.
    **dict.fromkeys(
        (
            "BartModel",
            "BartForConditionalGeneration",
            "DonutForConditionalGeneration",
            "Florence2ForConditionalGeneration",
            "MBartForConditionalGeneration",
            "MllamaForConditionalGeneration",
        ),
        "0.10.2",
    ),
}
345

346

347
348
@dataclass(frozen=True)
class _ModelInfo:
    """
    Immutable record of what a model class supports.

    `from_model_cls` fills `architecture` from the class name and every
    other field from the interface helper of the same name, except
    `supports_transcription_only`, which also reads the class attribute of
    that name.
    """
    architecture: str
    is_text_generation_model: bool
    is_pooling_model: bool
    default_pooling_type: str
    supports_cross_encoding: bool
    supports_multimodal: bool
    supports_multimodal_raw_input_only: bool
    supports_multimodal_encoder_tp_data: bool
    supports_pp: bool
    has_inner_state: bool
    is_attention_free: bool
    is_hybrid: bool
    has_noops: bool
    supports_transcription: bool
    supports_transcription_only: bool
    supports_v0_only: bool

    @staticmethod
    def from_model_cls(model: type[nn.Module]) -> "_ModelInfo":
        """Build a `_ModelInfo` by probing `model` with the interface helpers."""
        return _ModelInfo(
            architecture=model.__name__,
            is_text_generation_model=is_text_generation_model(model),
            is_pooling_model=is_pooling_model(model),
            default_pooling_type=get_default_pooling_type(model),
            supports_cross_encoding=supports_cross_encoding(model),
            supports_multimodal=supports_multimodal(model),
            supports_multimodal_raw_input_only=
            supports_multimodal_raw_input_only(model),
            supports_multimodal_encoder_tp_data=
            supports_multimodal_encoder_tp_data(model),
            supports_pp=supports_pp(model),
            has_inner_state=has_inner_state(model),
            is_attention_free=is_attention_free(model),
            is_hybrid=is_hybrid(model),
            supports_transcription=supports_transcription(model),
            # The class attribute is only read when the model supports
            # transcription at all; the `and` short-circuits otherwise.
            supports_transcription_only=(supports_transcription(model) and
                                         model.supports_transcription_only),
            supports_v0_only=supports_v0_only(model),
            has_noops=has_noops(model),
        )
389
390


391
class _BaseRegisteredModel(ABC):
    """Interface shared by every entry stored in the model registry."""

    @abstractmethod
    def inspect_model_cls(self) -> _ModelInfo:
        """Return the `_ModelInfo` describing the model class."""
        raise NotImplementedError

    @abstractmethod
    def load_model_cls(self) -> type[nn.Module]:
        """Return the model class itself."""
        raise NotImplementedError
400
401


402
403
404
405
406
407
408
@dataclass(frozen=True)
class _RegisteredModel(_BaseRegisteredModel):
    """
    Represents a model that has already been imported in the main process.
    """

    # Capability record computed once when the entry is created.
    interfaces: _ModelInfo
    # The imported model class itself.
    model_cls: type[nn.Module]

    @staticmethod
    def from_model_cls(model_cls: type[nn.Module]) -> "_RegisteredModel":
        """Wrap `model_cls` together with its freshly computed `_ModelInfo`."""
        return _RegisteredModel(
            interfaces=_ModelInfo.from_model_cls(model_cls),
            model_cls=model_cls,
        )

    def inspect_model_cls(self) -> _ModelInfo:
        """Return the `_ModelInfo` stored at construction time."""
        return self.interfaces

    def load_model_cls(self) -> type[nn.Module]:
        """Return the stored model class; nothing needs importing."""
        return self.model_cls


@dataclass(frozen=True)
class _LazyRegisteredModel(_BaseRegisteredModel):
    """
    Represents a model that has not been imported in the main process.
    """
    module_name: str
    class_name: str

    @staticmethod
    def _get_cache_dir() -> Path:
        """Return the directory that holds the cached model-info files."""
        return Path(envs.VLLM_CACHE_ROOT) / "modelinfos"

    def _get_cache_filename(self) -> str:
        """Return the JSON file name used to cache this module/class pair."""
        cls_name = f"{self.module_name}-{self.class_name}".replace(".", "-")
        return f"{cls_name}.json"

    def _load_modelinfo_from_cache(self,
                                   module_hash: str) -> _ModelInfo | None:
        """
        Return the cached `_ModelInfo`, or `None` if it cannot be used.

        `None` is returned when the cache file does not exist, when the hash
        stored in it differs from `module_hash`, or when reading or decoding
        the file fails for any other reason (that failure is logged).
        """
        try:
            try:
                modelinfo_path = self._get_cache_dir(
                ) / self._get_cache_filename()
                with open(modelinfo_path, encoding="utf-8") as file:
                    mi_dict = json.load(file)
            except FileNotFoundError:
                logger.debug(("Cached model info file "
                              "for class %s.%s not found"), self.module_name,
                             self.class_name)
                return None

            if mi_dict["hash"] != module_hash:
                logger.debug(("Cached model info file "
                              "for class %s.%s is stale"), self.module_name,
                             self.class_name)
                return None

            # file not changed, use cached _ModelInfo properties
            return _ModelInfo(**mi_dict["modelinfo"])
        except Exception:
            # A broken cache entry must never stop inspection; the caller
            # falls back to inspecting the model for real.
            logger.exception(("Cached model info "
                              "for class %s.%s error. "), self.module_name,
                             self.class_name)
            return None

    def _save_modelinfo_to_cache(self, mi: _ModelInfo,
                                 module_hash: str) -> None:
        """
        Write `mi` and `module_hash` to the cache file as JSON.

        Failures are logged and otherwise ignored.
        """
        from vllm.model_executor.model_loader.weight_utils import atomic_writer
        try:
            modelinfo_dict = {
                "hash": module_hash,
                "modelinfo": asdict(mi),
            }
            cache_dir = self._get_cache_dir()
            cache_dir.mkdir(parents=True, exist_ok=True)
            modelinfo_path = cache_dir / self._get_cache_filename()
            with atomic_writer(modelinfo_path, encoding='utf-8') as f:
                json.dump(modelinfo_dict, f, indent=2)
        except Exception:
            logger.exception("Error saving model info cache.")

    @logtime(logger=logger, msg="Registry inspect model class")
    def inspect_model_cls(self) -> _ModelInfo:
        """
        Return the `_ModelInfo` for this model, using the on-disk cache.

        The cache is only consulted and updated when a `.py` file named
        after the last component of `module_name` exists next to this file;
        its content hash decides whether a cached entry is still valid.
        """
        model_path = Path(
            __file__).parent / f"{self.module_name.split('.')[-1]}.py"
        module_hash = None

        if model_path.exists():
            with open(model_path, "rb") as f:
                # MD5 only detects file changes here. Declaring that it is
                # not used for security keeps it usable on Python builds
                # that block insecure hashes; the digest is unchanged.
                module_hash = hashlib.md5(f.read(),
                                          usedforsecurity=False).hexdigest()

            mi = self._load_modelinfo_from_cache(module_hash)
            if mi is not None:
                logger.debug(("Loaded model info "
                              "for class %s.%s from cache"), self.module_name,
                             self.class_name)
                return mi
            else:
                logger.debug(("Cache model info "
                              "for class %s.%s miss. "
                              "Loading model instead."), self.module_name,
                             self.class_name)

        # Performed in another process to avoid initializing CUDA
        mi = _run_in_subprocess(
            lambda: _ModelInfo.from_model_cls(self.load_model_cls()))
        logger.debug("Loaded model info for class %s.%s", self.module_name,
                     self.class_name)

        # save cache file
        if module_hash is not None:
            self._save_modelinfo_to_cache(mi, module_hash)

        return mi

    def load_model_cls(self) -> type[nn.Module]:
        """Import `module_name` and return its `class_name` attribute."""
        mod = importlib.import_module(self.module_name)
        return getattr(mod, self.class_name)


@lru_cache(maxsize=128)
def _try_load_model_cls(
    model_arch: str,
    model: _BaseRegisteredModel,
) -> Optional[type[nn.Module]]:
    """
    Load the model class for `model_arch`, or return `None` if that fails.

    The platform check runs outside the `try`, so an exception it raises
    propagates to the caller. A failure while loading the class is logged
    and turned into `None`. Results, including `None`, are memoized per
    `(model_arch, model)` pair.
    """
    from vllm.platforms import current_platform
    current_platform.verify_model_arch(model_arch)
    try:
        return model.load_model_cls()
    except Exception:
        logger.exception("Error in loading model architecture '%s'",
                         model_arch)
        return None
538
539


540
541
542
543
544
545
546
547
548
549
550
@lru_cache(maxsize=128)
def _try_inspect_model_cls(
    model_arch: str,
    model: _BaseRegisteredModel,
) -> Optional[_ModelInfo]:
    """
    Inspect `model`, returning its `_ModelInfo` or `None` on failure.

    Any exception raised during inspection is logged together with
    `model_arch` and turned into `None`. Results are memoized per
    `(model_arch, model)` pair.
    """
    try:
        info = model.inspect_model_cls()
    except Exception:
        logger.exception("Error in inspecting model architecture '%s'",
                         model_arch)
        info = None
    return info
551
552


553
554
555
@dataclass
class _ModelRegistry:
    """
    Registry that maps model architecture names to registered model entries
    and picks the entry to use for a given list of architectures.
    """

    # Keyed by model_arch
    models: dict[str, _BaseRegisteredModel] = field(default_factory=dict)

    def get_supported_archs(self) -> Set[str]:
        """Return a view over the names of all registered architectures."""
        return self.models.keys()

    def register_model(
        self,
        model_arch: str,
        model_cls: Union[type[nn.Module], str],
    ) -> None:
        """
        Register an external model to be used in vLLM.

        `model_cls` can be either:

        - A [`torch.nn.Module`][] class directly referencing the model.
        - A string in the format `<module>:<class>` which can be used to
          lazily import the model. This is useful to avoid initializing CUDA
          when importing the model and thus the related error
          `RuntimeError: Cannot re-initialize CUDA in forked subprocess`.

        Registering an architecture that is already present logs a warning
        and overwrites the previous entry.

        Raises:
            TypeError: If `model_arch` is not a string, or if `model_cls` is
                neither a string nor a subclass of `nn.Module`.
            ValueError: If `model_cls` is a string that does not have the
                form `<module>:<class>`.
        """
        if not isinstance(model_arch, str):
            msg = f"`model_arch` should be a string, not a {type(model_arch)}"
            raise TypeError(msg)

        if model_arch in self.models:
            logger.warning(
                "Model architecture %s is already registered, and will be "
                "overwritten by the new model class %s.", model_arch,
                model_cls)

        if isinstance(model_cls, str):
            split_str = model_cls.split(":")
            if len(split_str) != 2:
                msg = "Expected a string in the format `<module>:<class>`"
                raise ValueError(msg)

            model = _LazyRegisteredModel(*split_str)
        elif isinstance(model_cls, type) and issubclass(model_cls, nn.Module):
            model = _RegisteredModel.from_model_cls(model_cls)
        else:
            # Report the type of the offending argument (`model_cls`), not
            # the type of `model_arch`, which was already validated above.
            msg = ("`model_cls` should be a string or PyTorch model class, "
                   f"not a {type(model_cls)}")
            raise TypeError(msg)

        self.models[model_arch] = model

    def _raise_for_unsupported(self, architectures: list[str]):
        """
        Raise a `ValueError` explaining why none of `architectures` could be
        used. This method never returns normally.

        The message depends on the first matching case:

        1. At least one architecture is registered: inspection/loading
           failed, so the user is pointed to the logs.
        2. An architecture is listed in `_PREVIOUSLY_SUPPORTED_MODELS`: the
           last supporting vLLM version is reported.
        3. Otherwise: the architectures are reported as unsupported together
           with the registered ones.
        """
        all_supported_archs = self.get_supported_archs()

        if any(arch in all_supported_archs for arch in architectures):
            raise ValueError(
                f"Model architectures {architectures} failed "
                "to be inspected. Please check the logs for more details.")

        for arch in architectures:
            if arch in _PREVIOUSLY_SUPPORTED_MODELS:
                previous_version = _PREVIOUSLY_SUPPORTED_MODELS[arch]

                raise ValueError(
                    f"Model architecture {arch} was supported in vLLM until "
                    f"v{previous_version}, and is not supported anymore. "
                    "Please use an older version of vLLM if you want to "
                    "use this model architecture.")

        raise ValueError(
            f"Model architectures {architectures} are not supported for now. "
            f"Supported architectures: {all_supported_archs}")

    def _try_load_model_cls(self,
                            model_arch: str) -> Optional[type[nn.Module]]:
        """
        Load the class registered for `model_arch` via the module-level
        `_try_load_model_cls`; return `None` if the architecture is not
        registered.
        """
        if model_arch not in self.models:
            return None

        return _try_load_model_cls(model_arch, self.models[model_arch])

    def _try_inspect_model_cls(self, model_arch: str) -> Optional[_ModelInfo]:
        """
        Inspect the model registered for `model_arch` via the module-level
        `_try_inspect_model_cls`; return `None` if the architecture is not
        registered.
        """
        if model_arch not in self.models:
            return None

        return _try_inspect_model_cls(model_arch, self.models[model_arch])

    def _try_resolve_transformers(
        self,
        architecture: str,
        model_config: ModelConfig,
    ) -> Optional[str]:
        """
        Resolve the architecture name to use for the Transformers backend.

        Returns `architecture` itself when it is listed in
        `_TRANSFORMERS_BACKEND_MODELS`, otherwise the value of
        `model_config._get_transformers_backend_cls()` once a backend
        compatible model module has been found. Returns `None` when no
        usable module is found and `model_config.model_impl` is not
        `"transformers"`.

        Raises:
            ValueError: If `model_config.model_impl` is `"transformers"` and
                either no model module can be found or the one found is not
                backend compatible.
        """
        if architecture in _TRANSFORMERS_BACKEND_MODELS:
            return architecture

        auto_map: dict[str, str] = getattr(model_config.hf_config, "auto_map",
                                           None) or dict()

        # Make sure that config class is always initialized before model class,
        # otherwise the model class won't be able to access the config class,
        # the expected auto_map should have correct order like:
        # "auto_map": {
        #     "AutoConfig": "<your-repo-name>--<config-name>",
        #     "AutoModel": "<your-repo-name>--<config-name>",
        #     "AutoModelFor<Task>": "<your-repo-name>--<config-name>",
        # },
        for prefix in ("AutoConfig", "AutoModel"):
            for name, module in auto_map.items():
                if name.startswith(prefix):
                    try_get_class_from_dynamic_module(
                        module,
                        model_config.model,
                        revision=model_config.revision,
                        warn_on_fail=False,
                    )

        model_module = getattr(transformers, architecture, None)

        if model_module is None:
            # Not a class exported by `transformers`: try the custom classes
            # referenced by the `AutoModel*` entries of `auto_map`.
            for name, module in auto_map.items():
                if name.startswith("AutoModel"):
                    model_module = try_get_class_from_dynamic_module(
                        module,
                        model_config.model,
                        revision=model_config.revision,
                        warn_on_fail=True,
                    )
                    if model_module is not None:
                        break
            else:
                # The loop finished without `break`, i.e. no entry produced
                # a usable module.
                if model_config.model_impl != "transformers":
                    return None

                raise ValueError(
                    f"Cannot find model module. {architecture!r} is not a "
                    "registered model in the Transformers library (only "
                    "relevant if the model is meant to be in Transformers) "
                    "and 'AutoModel' is not present in the model config's "
                    "'auto_map' (relevant if the model is custom).")

        if not model_module.is_backend_compatible():
            if model_config.model_impl != "transformers":
                return None

            raise ValueError(
                f"The Transformers implementation of {architecture!r} "
                "is not compatible with vLLM.")

        return model_config._get_transformers_backend_cls()

    def _normalize_arch(
        self,
        architecture: str,
        model_config: ModelConfig,
    ) -> str:
        """
        Map `architecture` to a registered architecture name when possible.

        A registered name is returned unchanged. Otherwise, if
        `try_match_architecture_defaults` yields a suffix match, that suffix
        is replaced by each suffix from `iter_architecture_defaults()` in
        turn and the first registered result is returned. If nothing
        matches, `architecture` is returned unchanged.
        """
        if architecture in self.models:
            return architecture

        # This may be called in order to resolve runner_type and convert_type
        # in the first place, in which case we consider the default match
        match = try_match_architecture_defaults(
            architecture,
            runner_type=getattr(model_config, "runner_type", None),
            convert_type=getattr(model_config, "convert_type", None),
        )
        if match:
            suffix, _ = match

            # Get the name of the base model to convert
            for repl_suffix, _ in iter_architecture_defaults():
                base_arch = architecture.replace(suffix, repl_suffix)
                if base_arch in self.models:
                    return base_arch

        return architecture

    def inspect_model_cls(
        self,
        architectures: Union[str, list[str]],
        model_config: ModelConfig,
    ) -> tuple[_ModelInfo, str]:
        """
        Find the first usable architecture and return its model info.

        Returns a `(model_info, arch)` tuple. For architectures matched in
        the main loop, `arch` is the name as given by the caller, not the
        normalized name used for the lookup.

        Raises:
            ValueError: If `architectures` is empty, or if no architecture
                can be inspected (see `_raise_for_unsupported`).
        """
        if isinstance(architectures, str):
            architectures = [architectures]
        if not architectures:
            raise ValueError("No model architectures are specified")

        # Require transformers impl
        if model_config.model_impl == "transformers":
            arch = self._try_resolve_transformers(architectures[0],
                                                  model_config)
            if arch is not None:
                model_info = self._try_inspect_model_cls(arch)
                if model_info is not None:
                    return (model_info, arch)
        elif model_config.model_impl == "terratorch":
            # Only return on success, mirroring `resolve_model_cls`; a failed
            # inspection falls through to the regular error reporting
            # instead of handing `None` back to the caller.
            arch = "Terratorch"
            model_info = self._try_inspect_model_cls(arch)
            if model_info is not None:
                return (model_info, arch)

        # Fallback to transformers impl (after resolving convert_type)
        if (all(arch not in self.models for arch in architectures)
                and model_config.model_impl == "auto"
                and getattr(model_config, "convert_type", "none") == "none"):
            arch = self._try_resolve_transformers(architectures[0],
                                                  model_config)
            if arch is not None:
                model_info = self._try_inspect_model_cls(arch)
                if model_info is not None:
                    return (model_info, arch)

        for arch in architectures:
            normalized_arch = self._normalize_arch(arch, model_config)
            model_info = self._try_inspect_model_cls(normalized_arch)
            if model_info is not None:
                return (model_info, arch)

        # Fallback to transformers impl (before resolving runner_type)
        if (all(arch not in self.models for arch in architectures)
                and model_config.model_impl == "auto"):
            arch = self._try_resolve_transformers(architectures[0],
                                                  model_config)
            if arch is not None:
                model_info = self._try_inspect_model_cls(arch)
                if model_info is not None:
                    return (model_info, arch)

        return self._raise_for_unsupported(architectures)

    def resolve_model_cls(
        self,
        architectures: Union[str, list[str]],
        model_config: ModelConfig,
    ) -> tuple[type[nn.Module], str]:
        """
        Find the first usable architecture and return its model class.

        Follows the same resolution order as `inspect_model_cls`, but loads
        the model class instead of inspecting it. Returns a
        `(model_cls, arch)` tuple. For architectures matched in the main
        loop, `arch` is the name as given by the caller, not the normalized
        name used for the lookup.

        Raises:
            ValueError: If `architectures` is empty, or if no architecture
                can be loaded (see `_raise_for_unsupported`).
        """
        if isinstance(architectures, str):
            architectures = [architectures]
        if not architectures:
            raise ValueError("No model architectures are specified")

        # Require transformers impl
        if model_config.model_impl == "transformers":
            arch = self._try_resolve_transformers(architectures[0],
                                                  model_config)
            if arch is not None:
                model_cls = self._try_load_model_cls(arch)
                if model_cls is not None:
                    return (model_cls, arch)
        elif model_config.model_impl == "terratorch":
            arch = "Terratorch"
            model_cls = self._try_load_model_cls(arch)
            if model_cls is not None:
                return (model_cls, arch)

        # Fallback to transformers impl (after resolving convert_type)
        if (all(arch not in self.models for arch in architectures)
                and model_config.model_impl == "auto"
                and getattr(model_config, "convert_type", "none") == "none"):
            arch = self._try_resolve_transformers(architectures[0],
                                                  model_config)
            if arch is not None:
                model_cls = self._try_load_model_cls(arch)
                if model_cls is not None:
                    return (model_cls, arch)

        for arch in architectures:
            normalized_arch = self._normalize_arch(arch, model_config)
            model_cls = self._try_load_model_cls(normalized_arch)
            if model_cls is not None:
                return (model_cls, arch)

        # Fallback to transformers impl (before resolving runner_type)
        if (all(arch not in self.models for arch in architectures)
                and model_config.model_impl == "auto"):
            arch = self._try_resolve_transformers(architectures[0],
                                                  model_config)
            if arch is not None:
                model_cls = self._try_load_model_cls(arch)
                if model_cls is not None:
                    return (model_cls, arch)

        return self._raise_for_unsupported(architectures)

    # The predicates below all inspect the resolved model (see
    # `inspect_model_cls`) and return one attribute of its `_ModelInfo`.

    def is_text_generation_model(
        self,
        architectures: Union[str, list[str]],
        model_config: ModelConfig,
    ) -> bool:
        """Return `is_text_generation_model` of the inspected model info."""
        model_info, _ = self.inspect_model_cls(architectures, model_config)
        return model_info.is_text_generation_model

    def is_pooling_model(
        self,
        architectures: Union[str, list[str]],
        model_config: ModelConfig,
    ) -> bool:
        """Return `is_pooling_model` of the inspected model info."""
        model_info, _ = self.inspect_model_cls(architectures, model_config)
        return model_info.is_pooling_model

    def is_cross_encoder_model(
        self,
        architectures: Union[str, list[str]],
        model_config: ModelConfig,
    ) -> bool:
        """Return `supports_cross_encoding` of the inspected model info."""
        model_info, _ = self.inspect_model_cls(architectures, model_config)
        return model_info.supports_cross_encoding

    def is_multimodal_model(
        self,
        architectures: Union[str, list[str]],
        model_config: ModelConfig,
    ) -> bool:
        """Return `supports_multimodal` of the inspected model info."""
        model_info, _ = self.inspect_model_cls(architectures, model_config)
        return model_info.supports_multimodal

    def is_multimodal_raw_input_only_model(
        self,
        architectures: Union[str, list[str]],
        model_config: ModelConfig,
    ) -> bool:
        """
        Return `supports_multimodal_raw_input_only` of the inspected model
        info.
        """
        model_info, _ = self.inspect_model_cls(architectures, model_config)
        return model_info.supports_multimodal_raw_input_only

    def is_pp_supported_model(
        self,
        architectures: Union[str, list[str]],
        model_config: ModelConfig,
    ) -> bool:
        """Return `supports_pp` of the inspected model info."""
        model_info, _ = self.inspect_model_cls(architectures, model_config)
        return model_info.supports_pp

    def model_has_inner_state(
        self,
        architectures: Union[str, list[str]],
        model_config: ModelConfig,
    ) -> bool:
        """Return `has_inner_state` of the inspected model info."""
        model_info, _ = self.inspect_model_cls(architectures, model_config)
        return model_info.has_inner_state

    def is_attention_free_model(
        self,
        architectures: Union[str, list[str]],
        model_config: ModelConfig,
    ) -> bool:
        """Return `is_attention_free` of the inspected model info."""
        model_info, _ = self.inspect_model_cls(architectures, model_config)
        return model_info.is_attention_free

    def is_hybrid_model(
        self,
        architectures: Union[str, list[str]],
        model_config: ModelConfig,
    ) -> bool:
        """Return `is_hybrid` of the inspected model info."""
        model_info, _ = self.inspect_model_cls(architectures, model_config)
        return model_info.is_hybrid

    def is_noops_model(
        self,
        architectures: Union[str, list[str]],
        model_config: ModelConfig,
    ) -> bool:
        """Return `has_noops` of the inspected model info."""
        model_info, _ = self.inspect_model_cls(architectures, model_config)
        return model_info.has_noops

    def is_transcription_model(
        self,
        architectures: Union[str, list[str]],
        model_config: ModelConfig,
    ) -> bool:
        """Return `supports_transcription` of the inspected model info."""
        model_info, _ = self.inspect_model_cls(architectures, model_config)
        return model_info.supports_transcription

    def is_transcription_only_model(
        self,
        architectures: Union[str, list[str]],
        model_config: ModelConfig,
    ) -> bool:
        """
        Return `supports_transcription_only` of the inspected model info.
        """
        model_info, _ = self.inspect_model_cls(architectures, model_config)
        return model_info.supports_transcription_only

    def is_v1_compatible(
        self,
        architectures: Union[str, list[str]],
        model_config: ModelConfig,
    ) -> bool:
        """
        Return `True` unless the inspected model info has `supports_v0_only`
        set.
        """
        model_info, _ = self.inspect_model_cls(architectures, model_config)
        return not model_info.supports_v0_only

935
936

# Global registry instance. Every architecture listed in `_VLLM_MODELS` is
# registered as a lazily imported entry that points at
# `vllm.model_executor.models.<module>`.
ModelRegistry = _ModelRegistry(models={
    arch_name: _LazyRegisteredModel(
        module_name=f"vllm.model_executor.models.{module_relname}",
        class_name=model_cls_name,
    )
    for arch_name, (module_relname, model_cls_name) in _VLLM_MODELS.items()
})

# Generic return type of the callable passed to `_run_in_subprocess`.
_T = TypeVar("_T")


def _run_in_subprocess(fn: Callable[[], _T]) -> _T:
    """
    Execute `fn` in a child process and return its result.

    `fn` and an output file path are serialized with `cloudpickle` and sent
    to the stdin of the process started with `_SUBPROCESS_COMMAND`. Once the
    child has exited successfully, the result is unpickled from the output
    file.

    Raises:
        RuntimeError: If the child exits with a non-zero return code; the
            message contains the child's decoded stderr.
    """
    # `cloudpickle` allows pickling lambda functions directly
    import cloudpickle

    # NOTE: We use a temporary directory instead of a temporary file to avoid
    # issues like https://stackoverflow.com/questions/23212435/permission-denied-to-write-to-my-temporary-file
    with tempfile.TemporaryDirectory() as scratch_dir:
        result_path = os.path.join(scratch_dir, "registry_output.tmp")
        payload = cloudpickle.dumps((fn, result_path))

        # cannot use `sys.executable __file__` here because the script
        # contains relative imports
        completed = subprocess.run(_SUBPROCESS_COMMAND,
                                   input=payload,
                                   capture_output=True)

        # check if the subprocess is successful
        try:
            completed.check_returncode()
        except Exception as exc:
            # wrap raised exception to provide more information
            raise RuntimeError(f"Error raised in subprocess:\n"
                               f"{completed.stderr.decode()}") from exc

        # The result must be read before the temporary directory is removed
        # on exit from the `with` block.
        with open(result_path, "rb") as result_file:
            return pickle.load(result_file)


def _run() -> None:
    """
    Entry point used when this module is executed as a script.

    Reads a pickled `(fn, output_file)` pair from stdin, calls `fn` after
    loading the general plugins, and writes the pickled result to
    `output_file`.
    """
    # Setup plugins
    from vllm.plugins import load_general_plugins
    load_general_plugins()

    # NOTE(review): the payload is unpickled as-is; this presumably only ever
    # receives bytes produced by `_run_in_subprocess` - confirm that stdin is
    # never fed from an untrusted source.
    payload = sys.stdin.buffer.read()
    fn, output_file = pickle.loads(payload)

    # Run the callable before opening the output file.
    result = fn()

    with open(output_file, "wb") as out:
        pickle.dump(result, out)


if __name__ == "__main__":
    _run()