registry.py 42.2 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
4
5
6
"""
Whenever you add an architecture to this page, please also update
`tests/models/registry.py` with example HuggingFace models for it.
"""
7
import hashlib
8
import importlib
9
import json
10
import os
11
import pickle
12
13
import subprocess
import sys
14
import tempfile
15
from abc import ABC, abstractmethod
16
from collections.abc import Set
17
from dataclasses import asdict, dataclass, field
18
from functools import lru_cache
19
from pathlib import Path
20
from typing import Callable, Optional, TypeVar, Union
21
22

import torch.nn as nn
23
import transformers
24

25
from vllm import envs
26
from vllm.config import (ModelConfig, iter_architecture_defaults,
27
                         try_match_architecture_defaults)
28
from vllm.logger import init_logger
29
from vllm.logging_utils import logtime
30
31
from vllm.transformers_utils.dynamic_module import (
    try_get_class_from_dynamic_module)
32

33
34
from .interfaces import (has_inner_state, has_noops, is_attention_free,
                         is_hybrid, supports_cross_encoding,
35
36
                         supports_multimodal,
                         supports_multimodal_encoder_tp_data,
37
                         supports_multimodal_raw_input_only, supports_pp,
38
                         supports_transcription, supports_v0_only)
39
40
from .interfaces_base import (get_default_pooling_type, is_pooling_model,
                              is_text_generation_model)
41
42
43

logger = init_logger(__name__)

# yapf: disable

# Each entry maps an architecture name to a `(module name, class name)` pair.
_TEXT_GENERATION_MODELS = {
    # [Decoder-only]
    "ApertusForCausalLM": ("apertus", "ApertusForCausalLM"),
    "AquilaModel": ("llama", "LlamaForCausalLM"),
    "AquilaForCausalLM": ("llama", "LlamaForCausalLM"),  # AquilaChat2
    "ArceeForCausalLM": ("arcee", "ArceeForCausalLM"),
    "ArcticForCausalLM": ("arctic", "ArcticForCausalLM"),
    "MiniMaxForCausalLM": ("minimax_text_01", "MiniMaxText01ForCausalLM"),
    "MiniMaxText01ForCausalLM": ("minimax_text_01", "MiniMaxText01ForCausalLM"),
    "MiniMaxM1ForCausalLM": ("minimax_text_01", "MiniMaxText01ForCausalLM"),
    # baichuan-7b, upper case 'C' in the class name
    "BaiChuanForCausalLM": ("baichuan", "BaiChuanForCausalLM"),
    # baichuan-13b, lower case 'c' in the class name
    "BaichuanForCausalLM": ("baichuan", "BaichuanForCausalLM"),
    "BailingMoeForCausalLM": ("bailing_moe", "BailingMoeForCausalLM"),
    "BailingMoeV2ForCausalLM": ("bailing_moe", "BailingMoeV2ForCausalLM"),
    "BambaForCausalLM": ("bamba", "BambaForCausalLM"),
    "BloomForCausalLM": ("bloom", "BloomForCausalLM"),
    "ChatGLMModel": ("chatglm", "ChatGLMForCausalLM"),
    "ChatGLMForConditionalGeneration": ("chatglm", "ChatGLMForCausalLM"),
    "CohereForCausalLM": ("commandr", "CohereForCausalLM"),
    "Cohere2ForCausalLM": ("commandr", "CohereForCausalLM"),
    "DbrxForCausalLM": ("dbrx", "DbrxForCausalLM"),
    "DeciLMForCausalLM": ("nemotron_nas", "DeciLMForCausalLM"),
    "DeepseekForCausalLM": ("deepseek", "DeepseekForCausalLM"),
    "DeepseekV2ForCausalLM": ("deepseek_v2", "DeepseekV2ForCausalLM"),
    "DeepseekV3ForCausalLM": ("deepseek_v2", "DeepseekV3ForCausalLM"),
    "Dots1ForCausalLM": ("dots1", "Dots1ForCausalLM"),
    "Ernie4_5ForCausalLM": ("ernie45", "Ernie4_5ForCausalLM"),
    "Ernie4_5_MoeForCausalLM": ("ernie45_moe", "Ernie4_5_MoeForCausalLM"),
    "ExaoneForCausalLM": ("exaone", "ExaoneForCausalLM"),
    "Exaone4ForCausalLM": ("exaone4", "Exaone4ForCausalLM"),
    "FalconForCausalLM": ("falcon", "FalconForCausalLM"),
    "Fairseq2LlamaForCausalLM": ("fairseq2_llama", "Fairseq2LlamaForCausalLM"),
    "GemmaForCausalLM": ("gemma", "GemmaForCausalLM"),
    "Gemma2ForCausalLM": ("gemma2", "Gemma2ForCausalLM"),
    "Gemma3ForCausalLM": ("gemma3", "Gemma3ForCausalLM"),
    "Gemma3nForCausalLM": ("gemma3n", "Gemma3nForCausalLM"),
    "Qwen3NextForCausalLM": ("qwen3_next", "Qwen3NextForCausalLM"),
    "GlmForCausalLM": ("glm", "GlmForCausalLM"),
    "Glm4ForCausalLM": ("glm4", "Glm4ForCausalLM"),
    "Glm4MoeForCausalLM": ("glm4_moe", "Glm4MoeForCausalLM"),
    "GptOssForCausalLM": ("gpt_oss", "GptOssForCausalLM"),
    "GPT2LMHeadModel": ("gpt2", "GPT2LMHeadModel"),
    "GPTBigCodeForCausalLM": ("gpt_bigcode", "GPTBigCodeForCausalLM"),
    "GPTJForCausalLM": ("gpt_j", "GPTJForCausalLM"),
    "GPTNeoXForCausalLM": ("gpt_neox", "GPTNeoXForCausalLM"),
    "GraniteForCausalLM": ("granite", "GraniteForCausalLM"),
    "GraniteMoeForCausalLM": ("granitemoe", "GraniteMoeForCausalLM"),
    "GraniteMoeHybridForCausalLM": ("granitemoehybrid", "GraniteMoeHybridForCausalLM"),   # noqa: E501
    "GraniteMoeSharedForCausalLM": ("granitemoeshared", "GraniteMoeSharedForCausalLM"),   # noqa: E501
    "GritLM": ("gritlm", "GritLM"),
    "Grok1ModelForCausalLM": ("grok1", "Grok1ForCausalLM"),
    "HunYuanMoEV1ForCausalLM": ("hunyuan_v1", "HunYuanMoEV1ForCausalLM"),
    "HunYuanDenseV1ForCausalLM": ("hunyuan_v1", "HunYuanDenseV1ForCausalLM"),
    "HCXVisionForCausalLM": ("hyperclovax_vision", "HCXVisionForCausalLM"),
    "InternLMForCausalLM": ("llama", "LlamaForCausalLM"),
    "InternLM2ForCausalLM": ("internlm2", "InternLM2ForCausalLM"),
    "InternLM2VEForCausalLM": ("internlm2_ve", "InternLM2VEForCausalLM"),
    "InternLM3ForCausalLM": ("llama", "LlamaForCausalLM"),
    "JAISLMHeadModel": ("jais", "JAISLMHeadModel"),
    "JambaForCausalLM": ("jamba", "JambaForCausalLM"),
    "Lfm2ForCausalLM": ("lfm2", "Lfm2ForCausalLM"),
    "LlamaForCausalLM": ("llama", "LlamaForCausalLM"),
    "Llama4ForCausalLM": ("llama4", "Llama4ForCausalLM"),  # noqa: E501
    # For decapoda-research/llama-*
    "LLaMAForCausalLM": ("llama", "LlamaForCausalLM"),
    "MambaForCausalLM": ("mamba", "MambaForCausalLM"),
    "FalconMambaForCausalLM": ("mamba", "MambaForCausalLM"),
    "FalconH1ForCausalLM": ("falcon_h1", "FalconH1ForCausalLM"),
    "Mamba2ForCausalLM": ("mamba2", "Mamba2ForCausalLM"),
    "MiniCPMForCausalLM": ("minicpm", "MiniCPMForCausalLM"),
    "MiniCPM3ForCausalLM": ("minicpm3", "MiniCPM3ForCausalLM"),
    "MistralForCausalLM": ("llama", "LlamaForCausalLM"),
    "MixtralForCausalLM": ("mixtral", "MixtralForCausalLM"),
    "MotifForCausalLM": ("motif", "MotifForCausalLM"),
    # transformers's mpt class has lower case
    "MptForCausalLM": ("mpt", "MPTForCausalLM"),
    "MPTForCausalLM": ("mpt", "MPTForCausalLM"),
    "MiMoForCausalLM": ("mimo", "MiMoForCausalLM"),
    "NemotronForCausalLM": ("nemotron", "NemotronForCausalLM"),
    "NemotronHForCausalLM": ("nemotron_h", "NemotronHForCausalLM"),
    "OlmoForCausalLM": ("olmo", "OlmoForCausalLM"),
    "Olmo2ForCausalLM": ("olmo2", "Olmo2ForCausalLM"),
    "Olmo3ForCausalLM": ("olmo2", "Olmo2ForCausalLM"),
    "OlmoeForCausalLM": ("olmoe", "OlmoeForCausalLM"),
    "OPTForCausalLM": ("opt", "OPTForCausalLM"),
    "OrionForCausalLM": ("orion", "OrionForCausalLM"),
    "PersimmonForCausalLM": ("persimmon", "PersimmonForCausalLM"),
    "PhiForCausalLM": ("phi", "PhiForCausalLM"),
    "Phi3ForCausalLM": ("phi3", "Phi3ForCausalLM"),
    "PhiMoEForCausalLM": ("phimoe", "PhiMoEForCausalLM"),
    "Plamo2ForCausalLM": ("plamo2", "Plamo2ForCausalLM"),
    "QWenLMHeadModel": ("qwen", "QWenLMHeadModel"),
    "Qwen2ForCausalLM": ("qwen2", "Qwen2ForCausalLM"),
    "Qwen2MoeForCausalLM": ("qwen2_moe", "Qwen2MoeForCausalLM"),
    "Qwen3ForCausalLM": ("qwen3", "Qwen3ForCausalLM"),
    "Qwen3MoeForCausalLM": ("qwen3_moe", "Qwen3MoeForCausalLM"),
    "RWForCausalLM": ("falcon", "FalconForCausalLM"),
    "SeedOssForCausalLM": ("seed_oss", "SeedOssForCausalLM"),
    "Step3TextForCausalLM": ("step3_text", "Step3TextForCausalLM"),
    "StableLMEpochForCausalLM": ("stablelm", "StablelmForCausalLM"),
    "StableLmForCausalLM": ("stablelm", "StablelmForCausalLM"),
    "Starcoder2ForCausalLM": ("starcoder2", "Starcoder2ForCausalLM"),
    "SolarForCausalLM": ("solar", "SolarForCausalLM"),
    "TeleChat2ForCausalLM": ("telechat2", "TeleChat2ForCausalLM"),
    "TeleFLMForCausalLM": ("teleflm", "TeleFLMForCausalLM"),
    "XverseForCausalLM": ("llama", "LlamaForCausalLM"),
    "Zamba2ForCausalLM": ("zamba2", "Zamba2ForCausalLM"),
}

# Architectures registered for embedding / pooling use, as
# `(module name, class name)` pairs.
_EMBEDDING_MODELS = {
    # [Text-only]
    "BertModel": ("bert", "BertEmbeddingModel"),
    "DeciLMForCausalLM": ("nemotron_nas", "DeciLMForCausalLM"),
    "Gemma2Model": ("gemma2", "Gemma2ForCausalLM"),
    "Gemma3TextModel": ("gemma3", "Gemma3Model"),
    "GlmForCausalLM": ("glm", "GlmForCausalLM"),
    "GPT2ForSequenceClassification": ("gpt2", "GPT2ForSequenceClassification"),
    "GritLM": ("gritlm", "GritLM"),
    "GteModel": ("bert_with_rope", "SnowflakeGteNewModel"),
    "GteNewModel": ("bert_with_rope", "GteNewModel"),
    "InternLM2ForRewardModel": ("internlm2", "InternLM2ForRewardModel"),
    "JambaForSequenceClassification": ("jamba", "JambaForSequenceClassification"),  # noqa: E501
    "LlamaModel": ("llama", "LlamaForCausalLM"),
    **{
        # Multiple models share the same architecture, so we include them all
        k: (mod, arch) for k, (mod, arch) in _TEXT_GENERATION_MODELS.items()
        if arch == "LlamaForCausalLM"
    },
    "MistralModel": ("llama", "LlamaForCausalLM"),
    "ModernBertModel": ("modernbert", "ModernBertModel"),
    "NomicBertModel": ("bert_with_rope", "NomicBertModel"),
    "Phi3ForCausalLM": ("phi3", "Phi3ForCausalLM"),
    "Qwen2Model": ("qwen2", "Qwen2ForCausalLM"),
    "Qwen2ForCausalLM": ("qwen2", "Qwen2ForCausalLM"),
    "Qwen2ForRewardModel": ("qwen2_rm", "Qwen2ForRewardModel"),
    "Qwen2ForProcessRewardModel": ("qwen2_rm", "Qwen2ForProcessRewardModel"),
    "RobertaForMaskedLM": ("roberta", "RobertaEmbeddingModel"),
    "RobertaModel": ("roberta", "RobertaEmbeddingModel"),
    "TeleChat2ForCausalLM": ("telechat2", "TeleChat2ForCausalLM"),
    "XLMRobertaModel": ("roberta", "RobertaEmbeddingModel"),
    # [Multimodal]
    "LlavaNextForConditionalGeneration": ("llava_next", "LlavaNextForConditionalGeneration"),  # noqa: E501
    "Phi3VForCausalLM": ("phi3v", "Phi3VForCausalLM"),
    "Qwen2VLForConditionalGeneration": ("qwen2_vl", "Qwen2VLForConditionalGeneration"),  # noqa: E501
    # Technically Terratorch models work on images, both in
    # input and output. I am adding it here because it piggy-backs on embedding
    # models for the time being.
    "PrithviGeoSpatialMAE": ("terratorch", "Terratorch"),
    "Terratorch": ("terratorch", "Terratorch"),
}

198
199
# Sequence / token classification (cross-encoder) architectures, as
# `(module name, class name)` pairs.
_CROSS_ENCODER_MODELS = {
    "BertForSequenceClassification": ("bert", "BertForSequenceClassification"),
    "BertForTokenClassification": ("bert", "BertForTokenClassification"),
    "GteNewForSequenceClassification": ("bert_with_rope",
                                        "GteNewForSequenceClassification"),
    "ModernBertForSequenceClassification": ("modernbert",
                                            "ModernBertForSequenceClassification"),
    "RobertaForSequenceClassification": ("roberta",
                                         "RobertaForSequenceClassification"),
    "XLMRobertaForSequenceClassification": ("roberta",
                                            "RobertaForSequenceClassification"),
    # [Auto-converted (see adapters.py)]
    "JinaVLForRanking": ("jina_vl", "JinaVLForSequenceClassification"),  # noqa: E501
}

213
# Multimodal architectures, as `(module name, class name)` pairs.
_MULTIMODAL_MODELS = {
    # [Decoder-only]
    "AriaForConditionalGeneration": ("aria", "AriaForConditionalGeneration"),
    "AyaVisionForConditionalGeneration": ("aya_vision", "AyaVisionForConditionalGeneration"),  # noqa: E501
    "Blip2ForConditionalGeneration": ("blip2", "Blip2ForConditionalGeneration"),
    "ChameleonForConditionalGeneration": ("chameleon", "ChameleonForConditionalGeneration"),  # noqa: E501
    "Cohere2VisionForConditionalGeneration": ("cohere2_vision", "Cohere2VisionForConditionalGeneration"),  # noqa: E501
    "DeepseekVLV2ForCausalLM": ("deepseek_vl2", "DeepseekVLV2ForCausalLM"),
    "DotsOCRForCausalLM": ("dots_ocr", "DotsOCRForCausalLM"),
    "Ernie4_5_VLMoeForConditionalGeneration": ("ernie45_vl", "Ernie4_5_VLMoeForConditionalGeneration"),  # noqa: E501
    "FuyuForCausalLM": ("fuyu", "FuyuForCausalLM"),
    "Gemma3ForConditionalGeneration": ("gemma3_mm", "Gemma3ForConditionalGeneration"),  # noqa: E501
    "Gemma3nForConditionalGeneration": ("gemma3n_mm", "Gemma3nForConditionalGeneration"),  # noqa: E501
    "GLM4VForCausalLM": ("glm4v", "GLM4VForCausalLM"),
    "Glm4vForConditionalGeneration": ("glm4_1v", "Glm4vForConditionalGeneration"),  # noqa: E501
    "Glm4vMoeForConditionalGeneration": ("glm4_1v", "Glm4vMoeForConditionalGeneration"),  # noqa: E501
    "GraniteSpeechForConditionalGeneration": ("granite_speech", "GraniteSpeechForConditionalGeneration"),  # noqa: E501
    "H2OVLChatModel": ("h2ovl", "H2OVLChatModel"),
    "InternVLChatModel": ("internvl", "InternVLChatModel"),
    "NemotronH_Nano_VL": ("nano_nemotron_vl", "NemotronH_Nano_VL"),
    "InternS1ForConditionalGeneration": ("interns1", "InternS1ForConditionalGeneration"),  # noqa: E501
    "InternVLForConditionalGeneration": ("interns1", "InternS1ForConditionalGeneration"),  # noqa: E501
    "Idefics3ForConditionalGeneration": ("idefics3", "Idefics3ForConditionalGeneration"),  # noqa: E501
    "SmolVLMForConditionalGeneration": ("smolvlm", "SmolVLMForConditionalGeneration"),  # noqa: E501
    "KeyeForConditionalGeneration": ("keye", "KeyeForConditionalGeneration"),
    "KeyeVL1_5ForConditionalGeneration": ("keye_vl1_5", "KeyeVL1_5ForConditionalGeneration"),  # noqa: E501
    "RForConditionalGeneration": ("rvl", "RForConditionalGeneration"),
    "KimiVLForConditionalGeneration": ("kimi_vl", "KimiVLForConditionalGeneration"),  # noqa: E501
    "Llama_Nemotron_Nano_VL": ("nemotron_vl", "LlamaNemotronVLChatModel"),
    "Llama4ForConditionalGeneration": ("mllama4", "Llama4ForConditionalGeneration"),  # noqa: E501
    "LlavaForConditionalGeneration": ("llava", "LlavaForConditionalGeneration"),
    "LlavaNextForConditionalGeneration": ("llava_next", "LlavaNextForConditionalGeneration"),  # noqa: E501
    "LlavaNextVideoForConditionalGeneration": ("llava_next_video", "LlavaNextVideoForConditionalGeneration"),  # noqa: E501
    "LlavaOnevisionForConditionalGeneration": ("llava_onevision", "LlavaOnevisionForConditionalGeneration"),  # noqa: E501
    "MantisForConditionalGeneration": ("llava", "MantisForConditionalGeneration"),  # noqa: E501
    "MiDashengLMModel": ("midashenglm", "MiDashengLMModel"),
    "MiniMaxVL01ForConditionalGeneration": ("minimax_vl_01", "MiniMaxVL01ForConditionalGeneration"),  # noqa: E501
    "MiniCPMO": ("minicpmo", "MiniCPMO"),
    "MiniCPMV": ("minicpmv", "MiniCPMV"),
    "Mistral3ForConditionalGeneration": ("mistral3", "Mistral3ForConditionalGeneration"),  # noqa: E501
    "MolmoForCausalLM": ("molmo", "MolmoForCausalLM"),
    "NVLM_D": ("nvlm_d", "NVLM_D_Model"),
    "Ovis": ("ovis", "Ovis"),
    "Ovis2_5": ("ovis2_5", "Ovis2_5"),
    "PaliGemmaForConditionalGeneration": ("paligemma", "PaliGemmaForConditionalGeneration"),  # noqa: E501
    "Phi3VForCausalLM": ("phi3v", "Phi3VForCausalLM"),
    "Phi4MMForCausalLM": ("phi4mm", "Phi4MMForCausalLM"),
    "Phi4MultimodalForCausalLM": ("phi4_multimodal", "Phi4MultimodalForCausalLM"),  # noqa: E501
    "PixtralForConditionalGeneration": ("pixtral", "PixtralForConditionalGeneration"),  # noqa: E501
    "QwenVLForConditionalGeneration": ("qwen_vl", "QwenVLForConditionalGeneration"),  # noqa: E501
    "Qwen2VLForConditionalGeneration": ("qwen2_vl", "Qwen2VLForConditionalGeneration"),  # noqa: E501
    "Qwen2_5_VLForConditionalGeneration": ("qwen2_5_vl", "Qwen2_5_VLForConditionalGeneration"),  # noqa: E501
    "Qwen2AudioForConditionalGeneration": ("qwen2_audio", "Qwen2AudioForConditionalGeneration"),  # noqa: E501
    "Qwen2_5OmniModel": ("qwen2_5_omni_thinker", "Qwen2_5OmniThinkerForConditionalGeneration"),  # noqa: E501
    "Qwen2_5OmniForConditionalGeneration": ("qwen2_5_omni_thinker", "Qwen2_5OmniThinkerForConditionalGeneration"),  # noqa: E501
    "Qwen3VLForConditionalGeneration": ("qwen3_vl", "Qwen3VLForConditionalGeneration"),  # noqa: E501
    "Qwen3VLMoeForConditionalGeneration": ("qwen3_vl_moe", "Qwen3VLMoeForConditionalGeneration"),  # noqa: E501
    "SkyworkR1VChatModel": ("skyworkr1v", "SkyworkR1VChatModel"),
    "Step3VLForConditionalGeneration": ("step3_vl", "Step3VLForConditionalGeneration"),  # noqa: E501
    "TarsierForConditionalGeneration": ("tarsier", "TarsierForConditionalGeneration"),  # noqa: E501
    "Tarsier2ForConditionalGeneration": ("qwen2_vl", "Tarsier2ForConditionalGeneration"),  # noqa: E501
    "UltravoxModel": ("ultravox", "UltravoxModel"),
    "VoxtralForConditionalGeneration": ("voxtral", "VoxtralForConditionalGeneration"),  # noqa: E501
    # [Encoder-decoder]
    "WhisperForConditionalGeneration": ("whisper", "WhisperForConditionalGeneration"),  # noqa: E501
}
279
280

# Draft / MTP architectures used for speculative decoding, as
# `(module name, class name)` pairs.
_SPECULATIVE_DECODING_MODELS = {
    "MiMoMTPModel": ("mimo_mtp", "MiMoMTP"),
    "EagleLlamaForCausalLM": ("llama_eagle", "EagleLlamaForCausalLM"),
    "EagleLlama4ForCausalLM": ("llama4_eagle", "EagleLlama4ForCausalLM"),
    "EagleMiniCPMForCausalLM": ("minicpm_eagle", "EagleMiniCPMForCausalLM"),
    "Eagle3LlamaForCausalLM": ("llama_eagle3", "Eagle3LlamaForCausalLM"),
    "LlamaForCausalLMEagle3": ("llama_eagle3", "Eagle3LlamaForCausalLM"),
    "EagleDeepSeekMTPModel": ("deepseek_eagle", "EagleDeepseekV3ForCausalLM"),
    "DeepSeekMTPModel": ("deepseek_mtp", "DeepSeekMTP"),
    "ErnieMTPModel": ("ernie_mtp", "ErnieMTP"),
    "Glm4MoeMTPModel": ("glm4_moe_mtp", "Glm4MoeMTP"),
    "MedusaModel": ("medusa", "Medusa"),
    "Qwen3NextMTP": ("qwen3_next_mtp", "Qwen3NextMTP"),
    # Temporarily disabled.
    # # TODO(woosuk): Re-enable this once the MLP Speculator is supported in V1.
    # "MLPSpeculatorPreTrainedModel": ("mlp_speculator", "MLPSpeculator"),
}
297

298
# Architectures served only through the Transformers backend.
_TRANSFORMERS_SUPPORTED_MODELS = {
    # Text generation models
    "SmolLM3ForCausalLM": ("transformers", "TransformersForCausalLM"),
    # Multimodal models
    "Emu3ForConditionalGeneration": ("transformers", "TransformersForMultimodalLM"),  # noqa: E501
}

# The generic Transformers-backend wrapper classes themselves.
_TRANSFORMERS_BACKEND_MODELS = {
    "TransformersModel": ("transformers", "TransformersModel"),
    "TransformersForCausalLM": ("transformers", "TransformersForCausalLM"),
    "TransformersForMultimodalLM": ("transformers", "TransformersForMultimodalLM"),  # noqa: E501
}
# yapf: enable
311

312
# Union of all built-in tables. Later tables win on duplicate keys, so e.g.
# a name present in both the embedding and multimodal tables resolves to the
# multimodal entry.
_VLLM_MODELS = {
    **_TEXT_GENERATION_MODELS,
    **_EMBEDDING_MODELS,
    **_CROSS_ENCODER_MODELS,
    **_MULTIMODAL_MODELS,
    **_SPECULATIVE_DECODING_MODELS,
    **_TRANSFORMERS_SUPPORTED_MODELS,
    **_TRANSFORMERS_BACKEND_MODELS,
}

322
323
324
325
326
327
328
329
# Argument list passed to subprocess.run(). It is kept as a module-level
# variable so it can be adjusted when needed: for instance, when things are
# packed together with the par format, sys.executable might not be the
# program we actually want to launch.
_SUBPROCESS_COMMAND = [
    sys.executable,
    "-m",
    "vllm.model_executor.models.registry",
]

330
331
332
333
334
335
336
337
338
339
340
# Architectures that are no longer supported, mapped to a vLLM version string.
# NOTE(review): the version is presumably the last release that still
# supported the architecture; confirm against the code that reads this table.
_PREVIOUSLY_SUPPORTED_MODELS = {
    "Phi3SmallForCausalLM": "0.9.2",
    # Every encoder-decoder model except Whisper was dropped as part of the
    # V0 deprecation.
    **dict.fromkeys(
        (
            "BartModel",
            "BartForConditionalGeneration",
            "DonutForConditionalGeneration",
            "Florence2ForConditionalGeneration",
            "MBartForConditionalGeneration",
            "MllamaForConditionalGeneration",
        ),
        "0.10.2",
    ),
}
341

342

343
344
@dataclass(frozen=True)
class _ModelInfo:
    """Capability flags of a model class.

    Every flag is produced by one of the helpers imported from `interfaces`
    and `interfaces_base`. The dataclass is frozen (hence hashable) and holds
    only plain values, so it can be serialized with `dataclasses.asdict`.
    """

    architecture: str
    # Kind of model
    is_text_generation_model: bool
    is_pooling_model: bool
    default_pooling_type: str
    supports_cross_encoding: bool
    # Multimodal support
    supports_multimodal: bool
    supports_multimodal_raw_input_only: bool
    supports_multimodal_encoder_tp_data: bool
    # Parallelism
    supports_pp: bool
    # Architecture traits
    has_inner_state: bool
    is_attention_free: bool
    is_hybrid: bool
    has_noops: bool
    # Speech / transcription
    supports_transcription: bool
    supports_transcription_only: bool
    supports_v0_only: bool

    @staticmethod
    def from_model_cls(model: type[nn.Module]) -> "_ModelInfo":
        """Probe `model` with every interface helper and bundle the results.

        The helpers are called in the same order as the keyword arguments
        below; `supports_transcription_only` is only read from the class
        when `supports_transcription(model)` is truthy.
        """
        name = model.__name__
        return _ModelInfo(
            architecture=name,
            is_text_generation_model=is_text_generation_model(model),
            is_pooling_model=is_pooling_model(model),
            default_pooling_type=get_default_pooling_type(model),
            supports_cross_encoding=supports_cross_encoding(model),
            supports_multimodal=supports_multimodal(model),
            supports_multimodal_raw_input_only=(
                supports_multimodal_raw_input_only(model)),
            supports_multimodal_encoder_tp_data=(
                supports_multimodal_encoder_tp_data(model)),
            supports_pp=supports_pp(model),
            has_inner_state=has_inner_state(model),
            is_attention_free=is_attention_free(model),
            is_hybrid=is_hybrid(model),
            supports_transcription=supports_transcription(model),
            # Short-circuits: the class attribute is only consulted for
            # models that support transcription at all.
            supports_transcription_only=(
                supports_transcription(model)
                and model.supports_transcription_only),
            supports_v0_only=supports_v0_only(model),
            has_noops=has_noops(model),
        )
385
386


387
class _BaseRegisteredModel(ABC):
    """Abstract interface shared by every entry stored in the registry.

    Concrete subclasses decide whether the model class is held directly or
    imported on demand.
    """

    @abstractmethod
    def inspect_model_cls(self) -> _ModelInfo:
        """Return the capability flags of the registered model class."""
        raise NotImplementedError

    @abstractmethod
    def load_model_cls(self) -> type[nn.Module]:
        """Return the registered model class itself."""
        raise NotImplementedError
396
397


398
399
400
401
402
403
404
@dataclass(frozen=True)
class _RegisteredModel(_BaseRegisteredModel):
    """
    Represents a model that has already been imported in the main process.

    Its capability flags are computed once, at construction time, from the
    class object it holds.
    """

    interfaces: _ModelInfo
    model_cls: type[nn.Module]

    @staticmethod
    def from_model_cls(model_cls: type[nn.Module]):
        """Wrap an already-imported model class, probing it immediately."""
        info = _ModelInfo.from_model_cls(model_cls)
        return _RegisteredModel(interfaces=info, model_cls=model_cls)

    def inspect_model_cls(self) -> _ModelInfo:
        # Flags were computed eagerly in `from_model_cls`.
        return self.interfaces

    def load_model_cls(self) -> type[nn.Module]:
        # The class object is stored directly, so there is nothing to import.
        return self.model_cls


@dataclass(frozen=True)
class _LazyRegisteredModel(_BaseRegisteredModel):
    """
    Represents a model that has not been imported in the main process.

    Inspection runs in a subprocess so that importing the model does not
    initialize CUDA in the main process. The result is cached as JSON under
    `VLLM_CACHE_ROOT/modelinfos`, keyed on an MD5 of the model's source file.
    Only modules living directly in this package are cached; other modules
    (e.g. out-of-tree plugin models) are inspected afresh every time.
    """
    # Fully qualified module name (the `<module>` part of `<module>:<class>`).
    module_name: str
    # Name of the model class inside `module_name`.
    class_name: str

    @staticmethod
    def _get_cache_dir() -> Path:
        """Directory holding the cached model-info JSON files."""
        return Path(envs.VLLM_CACHE_ROOT) / "modelinfos"

    def _get_cache_filename(self) -> str:
        """Cache file name derived from the module and class names."""
        cls_name = f"{self.module_name}-{self.class_name}".replace(".", "-")
        return f"{cls_name}.json"

    def _get_module_hash(self) -> Optional[str]:
        """Return the MD5 hex digest of the model's source file, or None.

        A hash is only computed for modules that are single `<name>.py`
        files directly inside this package. For any other module (an
        out-of-tree plugin, a package directory, ...), None is returned and
        the on-disk cache is bypassed. Hashing a same-named file in this
        directory would validate the cache against the wrong source.
        """
        package, _, leaf = self.module_name.rpartition(".")
        if package != __package__:
            return None

        model_path = Path(__file__).parent / f"{leaf}.py"
        if not model_path.is_file():
            return None

        with open(model_path, "rb") as f:
            # The digest is a change detector, not a security measure;
            # saying so keeps md5 usable on FIPS-restricted interpreters.
            return hashlib.md5(f.read(), usedforsecurity=False).hexdigest()

    def _load_modelinfo_from_cache(self,
                                   module_hash: str) -> Optional[_ModelInfo]:
        """Return the cached `_ModelInfo` if it was saved for `module_hash`.

        Returns None when the cache file is missing, stale (hash mismatch)
        or unreadable. An unreadable file is logged with a traceback.
        """
        try:
            try:
                modelinfo_path = self._get_cache_dir(
                ) / self._get_cache_filename()
                with open(modelinfo_path, encoding="utf-8") as file:
                    mi_dict = json.load(file)
            except FileNotFoundError:
                logger.debug(("Cached model info file "
                              "for class %s.%s not found"), self.module_name,
                             self.class_name)
                return None

            if mi_dict["hash"] != module_hash:
                logger.debug(("Cached model info file "
                              "for class %s.%s is stale"), self.module_name,
                             self.class_name)
                return None

            # file not changed, use cached _ModelInfo properties
            return _ModelInfo(**mi_dict["modelinfo"])
        except Exception:
            logger.exception(("Cached model info "
                              "for class %s.%s error. "), self.module_name,
                             self.class_name)
            return None

    def _save_modelinfo_to_cache(self, mi: _ModelInfo,
                                 module_hash: str) -> None:
        """save dictionary json file to cache"""
        from vllm.model_executor.model_loader.weight_utils import atomic_writer
        try:
            modelinfo_dict = {
                "hash": module_hash,
                "modelinfo": asdict(mi),
            }
            cache_dir = self._get_cache_dir()
            cache_dir.mkdir(parents=True, exist_ok=True)
            modelinfo_path = cache_dir / self._get_cache_filename()
            with atomic_writer(modelinfo_path, encoding='utf-8') as f:
                json.dump(modelinfo_dict, f, indent=2)
        except Exception:
            logger.exception("Error saving model info cache.")

    @logtime(logger=logger, msg="Registry inspect model class")
    def inspect_model_cls(self) -> _ModelInfo:
        """Return the model's capability flags, using the disk cache if valid.

        Modules without a hashable in-package source file skip the cache
        entirely instead of failing.
        """
        module_hash = self._get_module_hash()

        if module_hash is not None:
            mi = self._load_modelinfo_from_cache(module_hash)
            if mi is not None:
                logger.debug(("Loaded model info "
                              "for class %s.%s from cache"), self.module_name,
                             self.class_name)
                return mi
            logger.debug(("Cache model info "
                          "for class %s.%s miss. "
                          "Loading model instead."), self.module_name,
                         self.class_name)

        # Performed in another process to avoid initializing CUDA
        mi = _run_in_subprocess(
            lambda: _ModelInfo.from_model_cls(self.load_model_cls()))
        logger.debug("Loaded model info for class %s.%s", self.module_name,
                     self.class_name)

        # Only cache when there is a source hash to validate against later.
        if module_hash is not None:
            self._save_modelinfo_to_cache(mi, module_hash)

        return mi

    def load_model_cls(self) -> type[nn.Module]:
        """Import `module_name` and return its `class_name` attribute."""
        mod = importlib.import_module(self.module_name)
        return getattr(mod, self.class_name)


@lru_cache(maxsize=128)
def _try_load_model_cls(
    model_arch: str,
    model: _BaseRegisteredModel,
) -> Optional[type[nn.Module]]:
    """Load the class behind `model`, or return None if loading fails.

    The platform check runs before the ``try``, so any exception raised by
    `verify_model_arch` propagates to the caller. Failures inside
    `load_model_cls` are logged and turned into None. Results (including
    None) are memoized by ``lru_cache``.
    """
    from vllm.platforms import current_platform
    current_platform.verify_model_arch(model_arch)

    try:
        loaded_cls = model.load_model_cls()
    except Exception:
        logger.exception("Error in loading model architecture '%s'",
                         model_arch)
        return None
    return loaded_cls
533
534


535
536
537
538
539
540
541
542
543
544
545
@lru_cache(maxsize=128)
def _try_inspect_model_cls(
    model_arch: str,
    model: _BaseRegisteredModel,
) -> Optional[_ModelInfo]:
    """Inspect `model`, returning None (after logging) if inspection fails.

    Results, including None for failures, are memoized per
    ``(model_arch, model)`` pair by ``lru_cache``.
    """
    try:
        info = model.inspect_model_cls()
    except Exception:
        logger.exception("Error in inspecting model architecture '%s'",
                         model_arch)
        info = None
    return info
546
547


548
549
550
@dataclass
class _ModelRegistry:
    """Mapping from architecture names to registered vLLM models.

    Besides direct lookups, the registry can fall back to the Transformers
    backend and can map an architecture name onto a registered base
    architecture (see `_normalize_arch`).
    """

    # Keyed by model_arch
    models: dict[str, _BaseRegisteredModel] = field(default_factory=dict)

    def get_supported_archs(self) -> Set[str]:
        """Return a (live) view of all registered architecture names."""
        return self.models.keys()

    def register_model(
        self,
        model_arch: str,
        model_cls: Union[type[nn.Module], str],
    ) -> None:
        """
        Register an external model to be used in vLLM.

        `model_cls` can be either:

        - A [`torch.nn.Module`][] class directly referencing the model.
        - A string in the format `<module>:<class>` which can be used to
          lazily import the model. This is useful to avoid initializing CUDA
          when importing the model and thus the related error
          `RuntimeError: Cannot re-initialize CUDA in forked subprocess`.

        Registering an architecture that already exists overwrites it and
        logs a warning.

        Raises:
            TypeError: if `model_arch` is not a string, or `model_cls` is
                neither a string nor an `nn.Module` subclass.
            ValueError: if a string `model_cls` is not `<module>:<class>`.
        """
        if not isinstance(model_arch, str):
            msg = f"`model_arch` should be a string, not a {type(model_arch)}"
            raise TypeError(msg)

        if model_arch in self.models:
            logger.warning(
                "Model architecture %s is already registered, and will be "
                "overwritten by the new model class %s.", model_arch,
                model_cls)

        if isinstance(model_cls, str):
            split_str = model_cls.split(":")
            if len(split_str) != 2:
                msg = "Expected a string in the format `<module>:<class>`"
                raise ValueError(msg)

            model = _LazyRegisteredModel(*split_str)
        elif isinstance(model_cls, type) and issubclass(model_cls, nn.Module):
            model = _RegisteredModel.from_model_cls(model_cls)
        else:
            # Report the type of the offending argument, i.e. `model_cls`.
            msg = ("`model_cls` should be a string or PyTorch model class, "
                   f"not a {type(model_cls)}")
            raise TypeError(msg)

        self.models[model_arch] = model

    def _raise_for_unsupported(self, architectures: list[str]):
        """Raise a ValueError explaining why `architectures` can't be used.

        The message distinguishes three cases:

        - some architecture is registered but could not be inspected or
          loaded;
        - an architecture was supported in an older vLLM release;
        - none of the architectures is supported.

        This method never returns normally.
        """
        all_supported_archs = self.get_supported_archs()

        if any(arch in all_supported_archs for arch in architectures):
            raise ValueError(
                f"Model architectures {architectures} failed "
                "to be inspected. Please check the logs for more details.")

        for arch in architectures:
            if arch in _PREVIOUSLY_SUPPORTED_MODELS:
                previous_version = _PREVIOUSLY_SUPPORTED_MODELS[arch]

                raise ValueError(
                    f"Model architecture {arch} was supported in vLLM until "
                    f"v{previous_version}, and is not supported anymore. "
                    "Please use an older version of vLLM if you want to "
                    "use this model architecture.")

        raise ValueError(
            f"Model architectures {architectures} are not supported for now. "
            f"Supported architectures: {all_supported_archs}")

    def _try_load_model_cls(self,
                            model_arch: str) -> Optional[type[nn.Module]]:
        """Load the class registered as `model_arch`.

        Returns None if `model_arch` is unregistered or fails to load.
        """
        if model_arch not in self.models:
            return None

        return _try_load_model_cls(model_arch, self.models[model_arch])

    def _try_inspect_model_cls(self, model_arch: str) -> Optional[_ModelInfo]:
        """Inspect the class registered as `model_arch`.

        Returns None if `model_arch` is unregistered or fails inspection.
        """
        if model_arch not in self.models:
            return None

        return _try_inspect_model_cls(model_arch, self.models[model_arch])

    def _try_resolve_transformers(
        self,
        architecture: str,
        model_config: ModelConfig,
    ) -> Optional[str]:
        """Return the Transformers-backend architecture for `architecture`.

        Returns None when the Transformers backend cannot serve it. If
        `model_config.model_impl == "transformers"`, a ValueError is raised
        instead of returning None.
        """
        if architecture in _TRANSFORMERS_BACKEND_MODELS:
            return architecture

        auto_map: dict[str, str] = getattr(model_config.hf_config, "auto_map",
                                           None) or dict()

        # Make sure that config class is always initialized before model class,
        # otherwise the model class won't be able to access the config class,
        # the expected auto_map should have correct order like:
        # "auto_map": {
        #     "AutoConfig": "<your-repo-name>--<config-name>",
        #     "AutoModel": "<your-repo-name>--<config-name>",
        #     "AutoModelFor<Task>": "<your-repo-name>--<config-name>",
        # },
        for prefix in ("AutoConfig", "AutoModel"):
            for name, module in auto_map.items():
                if name.startswith(prefix):
                    try_get_class_from_dynamic_module(
                        module,
                        model_config.model,
                        revision=model_config.revision,
                        warn_on_fail=False,
                    )

        model_module = getattr(transformers, architecture, None)

        if model_module is None:
            for name, module in auto_map.items():
                if name.startswith("AutoModel"):
                    model_module = try_get_class_from_dynamic_module(
                        module,
                        model_config.model,
                        revision=model_config.revision,
                        warn_on_fail=True,
                    )
                    if model_module is not None:
                        break
            else:
                # No `AutoModel*` entry of `auto_map` produced a class.
                if model_config.model_impl != "transformers":
                    return None

                raise ValueError(
                    f"Cannot find model module. {architecture!r} is not a "
                    "registered model in the Transformers library (only "
                    "relevant if the model is meant to be in Transformers) "
                    "and 'AutoModel' is not present in the model config's "
                    "'auto_map' (relevant if the model is custom).")

        if not model_module.is_backend_compatible():
            if model_config.model_impl != "transformers":
                return None

            raise ValueError(
                f"The Transformers implementation of {architecture!r} "
                "is not compatible with vLLM.")

        return model_config._get_transformers_backend_cls()

    def _normalize_arch(
        self,
        architecture: str,
        model_config: ModelConfig,
    ) -> str:
        """Map `architecture` onto a registered architecture name if possible.

        Registered names are returned unchanged. Otherwise, if the name
        matches one of the architecture defaults, the matched suffix is
        swapped for each default suffix in turn. The first registered
        candidate is returned. Falls back to returning `architecture` as is.
        """
        if architecture in self.models:
            return architecture

        # This may be called in order to resolve runner_type and convert_type
        # in the first place, in which case we consider the default match
        match = try_match_architecture_defaults(
            architecture,
            runner_type=getattr(model_config, "runner_type", None),
            convert_type=getattr(model_config, "convert_type", None),
        )
        if match:
            suffix, _ = match

            # Get the name of the base model to convert
            for repl_suffix, _ in iter_architecture_defaults():
                base_arch = architecture.replace(suffix, repl_suffix)
                if base_arch in self.models:
                    return base_arch

        return architecture

    def _try_transformers_backend(
        self,
        architecture: str,
        model_config: ModelConfig,
        try_arch: "Callable[[str], Optional[_T]]",
    ) -> "Optional[tuple[_T, str]]":
        """Resolve `architecture` to the Transformers backend and try it.

        Returns `(result, resolved_arch)` where `result` is the non-None
        value of `try_arch(resolved_arch)`, or None otherwise.
        """
        arch = self._try_resolve_transformers(architecture, model_config)
        if arch is None:
            return None

        result = try_arch(arch)
        if result is None:
            return None

        return (result, arch)

    def _resolve_arch(
        self,
        architectures: Union[str, list[str]],
        model_config: ModelConfig,
        try_arch: "Callable[[str], Optional[_T]]",
    ) -> "tuple[_T, str]":
        """Shared resolution logic of `inspect_model_cls`/`resolve_model_cls`.

        Candidate architecture names are passed to `try_arch` in priority
        order. The first non-None result is returned together with the
        architecture name it was obtained for.

        Raises:
            ValueError: if no architecture is given or none can be resolved.
        """
        if isinstance(architectures, str):
            architectures = [architectures]
        if not architectures:
            raise ValueError("No model architectures are specified")

        # Require transformers impl
        if model_config.model_impl == "transformers":
            resolved = self._try_transformers_backend(architectures[0],
                                                      model_config, try_arch)
            if resolved is not None:
                return resolved
        elif model_config.model_impl == "terratorch":
            arch = "Terratorch"
            result = try_arch(arch)
            # Only accept a successful result; otherwise fall through.
            if result is not None:
                return (result, arch)

        none_registered = all(arch not in self.models
                              for arch in architectures)

        # Fallback to transformers impl (after resolving convert_type)
        if (none_registered and model_config.model_impl == "auto"
                and getattr(model_config, "convert_type", "none") == "none"):
            resolved = self._try_transformers_backend(architectures[0],
                                                      model_config, try_arch)
            if resolved is not None:
                return resolved

        for arch in architectures:
            normalized_arch = self._normalize_arch(arch, model_config)
            result = try_arch(normalized_arch)
            if result is not None:
                # The requested name is returned, not the normalized one.
                return (result, arch)

        # Fallback to transformers impl (before resolving runner_type)
        if none_registered and model_config.model_impl == "auto":
            resolved = self._try_transformers_backend(architectures[0],
                                                      model_config, try_arch)
            if resolved is not None:
                return resolved

        return self._raise_for_unsupported(architectures)

    def inspect_model_cls(
        self,
        architectures: Union[str, list[str]],
        model_config: ModelConfig,
    ) -> tuple[_ModelInfo, str]:
        """Return `(model_info, arch)` for the first resolvable architecture.

        Raises:
            ValueError: if no architecture is given or none can be inspected.
        """
        return self._resolve_arch(architectures, model_config,
                                  self._try_inspect_model_cls)

    def resolve_model_cls(
        self,
        architectures: Union[str, list[str]],
        model_config: ModelConfig,
    ) -> tuple[type[nn.Module], str]:
        """Return `(model_cls, arch)` for the first loadable architecture.

        Raises:
            ValueError: if no architecture is given or none can be loaded.
        """
        return self._resolve_arch(architectures, model_config,
                                  self._try_load_model_cls)

    def is_text_generation_model(
        self,
        architectures: Union[str, list[str]],
        model_config: ModelConfig,
    ) -> bool:
        """Whether the resolved model is a text generation model."""
        model_info, _ = self.inspect_model_cls(architectures, model_config)
        return model_info.is_text_generation_model

    def is_pooling_model(
        self,
        architectures: Union[str, list[str]],
        model_config: ModelConfig,
    ) -> bool:
        """Whether the resolved model is a pooling model."""
        model_info, _ = self.inspect_model_cls(architectures, model_config)
        return model_info.is_pooling_model

    def is_cross_encoder_model(
        self,
        architectures: Union[str, list[str]],
        model_config: ModelConfig,
    ) -> bool:
        """Whether the resolved model supports cross-encoding."""
        model_info, _ = self.inspect_model_cls(architectures, model_config)
        return model_info.supports_cross_encoding

    def is_multimodal_model(
        self,
        architectures: Union[str, list[str]],
        model_config: ModelConfig,
    ) -> bool:
        """Whether the resolved model supports multimodal inputs."""
        model_info, _ = self.inspect_model_cls(architectures, model_config)
        return model_info.supports_multimodal

    def is_multimodal_raw_input_only_model(
        self,
        architectures: Union[str, list[str]],
        model_config: ModelConfig,
    ) -> bool:
        """Whether the resolved model only supports raw multimodal inputs."""
        model_info, _ = self.inspect_model_cls(architectures, model_config)
        return model_info.supports_multimodal_raw_input_only

    def is_pp_supported_model(
        self,
        architectures: Union[str, list[str]],
        model_config: ModelConfig,
    ) -> bool:
        """Whether the resolved model supports pipeline parallelism."""
        model_info, _ = self.inspect_model_cls(architectures, model_config)
        return model_info.supports_pp

    def model_has_inner_state(
        self,
        architectures: Union[str, list[str]],
        model_config: ModelConfig,
    ) -> bool:
        """Whether the resolved model has inner state."""
        model_info, _ = self.inspect_model_cls(architectures, model_config)
        return model_info.has_inner_state

    def is_attention_free_model(
        self,
        architectures: Union[str, list[str]],
        model_config: ModelConfig,
    ) -> bool:
        """Whether the resolved model is attention-free."""
        model_info, _ = self.inspect_model_cls(architectures, model_config)
        return model_info.is_attention_free

    def is_hybrid_model(
        self,
        architectures: Union[str, list[str]],
        model_config: ModelConfig,
    ) -> bool:
        """Whether the resolved model is a hybrid model."""
        model_info, _ = self.inspect_model_cls(architectures, model_config)
        return model_info.is_hybrid

    def is_noops_model(
        self,
        architectures: Union[str, list[str]],
        model_config: ModelConfig,
    ) -> bool:
        """Whether the resolved model has no-op layers (`has_noops`)."""
        model_info, _ = self.inspect_model_cls(architectures, model_config)
        return model_info.has_noops

    def is_transcription_model(
        self,
        architectures: Union[str, list[str]],
        model_config: ModelConfig,
    ) -> bool:
        """Whether the resolved model supports transcription."""
        model_info, _ = self.inspect_model_cls(architectures, model_config)
        return model_info.supports_transcription

    def is_transcription_only_model(
        self,
        architectures: Union[str, list[str]],
        model_config: ModelConfig,
    ) -> bool:
        """Whether the resolved model supports transcription only."""
        model_info, _ = self.inspect_model_cls(architectures, model_config)
        return model_info.supports_transcription_only

    def is_v1_compatible(
        self,
        architectures: Union[str, list[str]],
        model_config: ModelConfig,
    ) -> bool:
        """Whether the resolved model is not restricted to the V0 engine."""
        model_info, _ = self.inspect_model_cls(architectures, model_config)
        return not model_info.supports_v0_only

930
931

# The global registry, pre-populated with every in-tree architecture listed
# in `_VLLM_MODELS`. Each entry refers to a class inside
# `vllm.model_executor.models.<module>` by name, so nothing is imported here.
ModelRegistry = _ModelRegistry({
    arch_name: _LazyRegisteredModel(
        module_name=f"vllm.model_executor.models.{module_relname}",
        class_name=target_cls_name,
    )
    for arch_name, (module_relname, target_cls_name) in _VLLM_MODELS.items()
})

# Result type of the callable executed by `_run_in_subprocess`.
_T = TypeVar("_T")


def _run_in_subprocess(fn: Callable[[], _T]) -> _T:
    """Run `fn` in a separate Python process and return its result.

    `fn` and the path of an output file are serialized with `cloudpickle`
    and written to the stdin of `_SUBPROCESS_COMMAND`. The child process is
    expected to pickle `fn()`'s result into that file, which is unpickled
    here.

    Raises:
        RuntimeError: if the subprocess exits with a non-zero status. The
            message contains the subprocess's stderr.
    """
    # NOTE: We use a temporary directory instead of a temporary file to avoid
    # issues like https://stackoverflow.com/questions/23212435/permission-denied-to-write-to-my-temporary-file
    with tempfile.TemporaryDirectory() as tempdir:
        output_filepath = os.path.join(tempdir, "registry_output.tmp")

        # `cloudpickle` allows pickling lambda functions directly
        import cloudpickle
        input_bytes = cloudpickle.dumps((fn, output_filepath))

        # cannot use `sys.executable __file__` here because the script
        # contains relative imports
        returned = subprocess.run(_SUBPROCESS_COMMAND,
                                  input=input_bytes,
                                  capture_output=True)

        # check if the subprocess is successful
        try:
            returned.check_returncode()
        except subprocess.CalledProcessError as e:
            # Decode leniently: a UnicodeDecodeError here would otherwise
            # replace the actual subprocess failure.
            stderr = returned.stderr.decode(errors="replace")
            # wrap raised exception to provide more information
            raise RuntimeError(f"Error raised in subprocess:\n"
                               f"{stderr}") from e

        # The file was written by our own child process.
        with open(output_filepath, "rb") as f:
            return pickle.load(f)


def _run() -> None:
    """Subprocess-side counterpart of `_run_in_subprocess`.

    Loads the general plugins, then reads a pickled `(fn, output_file)` pair
    from stdin, calls `fn()`, and writes the pickled result to
    `output_file`.
    """
    # Setup plugins
    from vllm.plugins import load_general_plugins
    load_general_plugins()

    # The payload is produced by the parent process in `_run_in_subprocess`.
    payload = sys.stdin.buffer.read()
    fn, output_file = pickle.loads(payload)

    result = fn()
    serialized = pickle.dumps(result)

    with open(output_file, "wb") as f:
        f.write(serialized)
982
983
984


if __name__ == "__main__":
    # Script entry point; presumably launched by `_run_in_subprocess` via
    # `_SUBPROCESS_COMMAND` (not visible here), so delegate to `_run`.
    _run()