"vscode:/vscode.git/clone" did not exist on "fdc5b43d2017640a74f89c42ef61e1c79b4ffdd3"
registry.py 43 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
4
5
6
"""
Whenever you add an architecture to this file, please also update
`tests/models/registry.py` with example HuggingFace models for it.
"""
7

8
import hashlib
9
import importlib
10
import json
11
import os
12
import pickle
13
14
import subprocess
import sys
15
import tempfile
16
from abc import ABC, abstractmethod
17
from collections.abc import Set
18
from dataclasses import asdict, dataclass, field
19
from functools import lru_cache
20
from pathlib import Path
21
from typing import Callable, Optional, TypeVar, Union
22
23

import torch.nn as nn
24
import transformers
25

26
from vllm import envs
27
28
29
30
31
from vllm.config import (
    ModelConfig,
    iter_architecture_defaults,
    try_match_architecture_defaults,
)
32
from vllm.logger import init_logger
33
from vllm.logging_utils import logtime
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
from vllm.transformers_utils.dynamic_module import try_get_class_from_dynamic_module

from .interfaces import (
    has_inner_state,
    has_noops,
    is_attention_free,
    is_hybrid,
    supports_cross_encoding,
    supports_multimodal,
    supports_multimodal_encoder_tp_data,
    supports_multimodal_raw_input_only,
    supports_pp,
    supports_transcription,
    supports_v0_only,
)
from .interfaces_base import (
    get_default_pooling_type,
    is_pooling_model,
    is_text_generation_model,
)

# Module-level logger for the model registry.
logger = init_logger(__name__)

# Text-generation architectures. Each entry maps an architecture name to the
# (module name, class name) pair that implements it. Several names may share
# one implementation (e.g. the many aliases of `LlamaForCausalLM`).
_TEXT_GENERATION_MODELS = {
    # [Decoder-only]
    "ApertusForCausalLM": ("apertus", "ApertusForCausalLM"),
    "AquilaModel": ("llama", "LlamaForCausalLM"),
    "AquilaForCausalLM": ("llama", "LlamaForCausalLM"),  # AquilaChat2
    "ArceeForCausalLM": ("arcee", "ArceeForCausalLM"),
    "ArcticForCausalLM": ("arctic", "ArcticForCausalLM"),
    "MiniMaxForCausalLM": ("minimax_text_01", "MiniMaxText01ForCausalLM"),
    "MiniMaxText01ForCausalLM": ("minimax_text_01", "MiniMaxText01ForCausalLM"),
    "MiniMaxM1ForCausalLM": ("minimax_text_01", "MiniMaxText01ForCausalLM"),
    # baichuan-7b, upper case 'C' in the class name
    "BaiChuanForCausalLM": ("baichuan", "BaiChuanForCausalLM"),
    # baichuan-13b, lower case 'c' in the class name
    "BaichuanForCausalLM": ("baichuan", "BaichuanForCausalLM"),
    "BailingMoeForCausalLM": ("bailing_moe", "BailingMoeForCausalLM"),
    "BailingMoeV2ForCausalLM": ("bailing_moe", "BailingMoeV2ForCausalLM"),
    "BambaForCausalLM": ("bamba", "BambaForCausalLM"),
    "BloomForCausalLM": ("bloom", "BloomForCausalLM"),
    "ChatGLMModel": ("chatglm", "ChatGLMForCausalLM"),
    "ChatGLMForConditionalGeneration": ("chatglm", "ChatGLMForCausalLM"),
    "CohereForCausalLM": ("commandr", "CohereForCausalLM"),
    "Cohere2ForCausalLM": ("commandr", "CohereForCausalLM"),
    "CwmForCausalLM": ("llama", "LlamaForCausalLM"),
    "DbrxForCausalLM": ("dbrx", "DbrxForCausalLM"),
    "DeciLMForCausalLM": ("nemotron_nas", "DeciLMForCausalLM"),
    "DeepseekForCausalLM": ("deepseek", "DeepseekForCausalLM"),
    "DeepseekV2ForCausalLM": ("deepseek_v2", "DeepseekV2ForCausalLM"),
    "DeepseekV3ForCausalLM": ("deepseek_v2", "DeepseekV3ForCausalLM"),
    "DeepseekV32ForCausalLM": ("deepseek_v2", "DeepseekV3ForCausalLM"),
    "Dots1ForCausalLM": ("dots1", "Dots1ForCausalLM"),
    "Ernie4_5ForCausalLM": ("ernie45", "Ernie4_5ForCausalLM"),
    "Ernie4_5_MoeForCausalLM": ("ernie45_moe", "Ernie4_5_MoeForCausalLM"),
    "ExaoneForCausalLM": ("exaone", "ExaoneForCausalLM"),
    "Exaone4ForCausalLM": ("exaone4", "Exaone4ForCausalLM"),
    "FalconForCausalLM": ("falcon", "FalconForCausalLM"),
    "Fairseq2LlamaForCausalLM": ("fairseq2_llama", "Fairseq2LlamaForCausalLM"),
    "GemmaForCausalLM": ("gemma", "GemmaForCausalLM"),
    "Gemma2ForCausalLM": ("gemma2", "Gemma2ForCausalLM"),
    "Gemma3ForCausalLM": ("gemma3", "Gemma3ForCausalLM"),
    "Gemma3nForCausalLM": ("gemma3n", "Gemma3nForCausalLM"),
    "Qwen3NextForCausalLM": ("qwen3_next", "Qwen3NextForCausalLM"),
    "GlmForCausalLM": ("glm", "GlmForCausalLM"),
    "Glm4ForCausalLM": ("glm4", "Glm4ForCausalLM"),
    "Glm4MoeForCausalLM": ("glm4_moe", "Glm4MoeForCausalLM"),
    "GptOssForCausalLM": ("gpt_oss", "GptOssForCausalLM"),
    "GPT2LMHeadModel": ("gpt2", "GPT2LMHeadModel"),
    "GPTBigCodeForCausalLM": ("gpt_bigcode", "GPTBigCodeForCausalLM"),
    "GPTJForCausalLM": ("gpt_j", "GPTJForCausalLM"),
    "GPTNeoXForCausalLM": ("gpt_neox", "GPTNeoXForCausalLM"),
    "GraniteForCausalLM": ("granite", "GraniteForCausalLM"),
    "GraniteMoeForCausalLM": ("granitemoe", "GraniteMoeForCausalLM"),
    "GraniteMoeHybridForCausalLM": ("granitemoehybrid", "GraniteMoeHybridForCausalLM"),  # noqa: E501
    "GraniteMoeSharedForCausalLM": ("granitemoeshared", "GraniteMoeSharedForCausalLM"),  # noqa: E501
    "GritLM": ("gritlm", "GritLM"),
    "Grok1ModelForCausalLM": ("grok1", "Grok1ForCausalLM"),
    "HunYuanMoEV1ForCausalLM": ("hunyuan_v1", "HunYuanMoEV1ForCausalLM"),
    "HunYuanDenseV1ForCausalLM": ("hunyuan_v1", "HunYuanDenseV1ForCausalLM"),
    "HCXVisionForCausalLM": ("hyperclovax_vision", "HCXVisionForCausalLM"),
    "InternLMForCausalLM": ("llama", "LlamaForCausalLM"),
    "InternLM2ForCausalLM": ("internlm2", "InternLM2ForCausalLM"),
    "InternLM2VEForCausalLM": ("internlm2_ve", "InternLM2VEForCausalLM"),
    "InternLM3ForCausalLM": ("llama", "LlamaForCausalLM"),
    "JAISLMHeadModel": ("jais", "JAISLMHeadModel"),
    "JambaForCausalLM": ("jamba", "JambaForCausalLM"),
    "Lfm2ForCausalLM": ("lfm2", "Lfm2ForCausalLM"),
    "LlamaForCausalLM": ("llama", "LlamaForCausalLM"),
    "Llama4ForCausalLM": ("llama4", "Llama4ForCausalLM"),
    # For decapoda-research/llama-*
    "LLaMAForCausalLM": ("llama", "LlamaForCausalLM"),
    "LongcatFlashForCausalLM": ("longcat_flash", "LongcatFlashForCausalLM"),
    "MambaForCausalLM": ("mamba", "MambaForCausalLM"),
    "FalconMambaForCausalLM": ("mamba", "MambaForCausalLM"),
    "FalconH1ForCausalLM": ("falcon_h1", "FalconH1ForCausalLM"),
    "Mamba2ForCausalLM": ("mamba2", "Mamba2ForCausalLM"),
    "MiniCPMForCausalLM": ("minicpm", "MiniCPMForCausalLM"),
    "MiniCPM3ForCausalLM": ("minicpm3", "MiniCPM3ForCausalLM"),
    "MistralForCausalLM": ("llama", "LlamaForCausalLM"),
    "MixtralForCausalLM": ("mixtral", "MixtralForCausalLM"),
    # transformers's mpt class has lower case
    "MptForCausalLM": ("mpt", "MPTForCausalLM"),
    "MPTForCausalLM": ("mpt", "MPTForCausalLM"),
    "MiMoForCausalLM": ("mimo", "MiMoForCausalLM"),
    "NemotronForCausalLM": ("nemotron", "NemotronForCausalLM"),
    "NemotronHForCausalLM": ("nemotron_h", "NemotronHForCausalLM"),
    "OlmoForCausalLM": ("olmo", "OlmoForCausalLM"),
    "Olmo2ForCausalLM": ("olmo2", "Olmo2ForCausalLM"),
    "Olmo3ForCausalLM": ("olmo2", "Olmo2ForCausalLM"),
    "OlmoeForCausalLM": ("olmoe", "OlmoeForCausalLM"),
    "OPTForCausalLM": ("opt", "OPTForCausalLM"),
    "OrionForCausalLM": ("orion", "OrionForCausalLM"),
    "PersimmonForCausalLM": ("persimmon", "PersimmonForCausalLM"),
    "PhiForCausalLM": ("phi", "PhiForCausalLM"),
    "Phi3ForCausalLM": ("phi3", "Phi3ForCausalLM"),
    "PhiMoEForCausalLM": ("phimoe", "PhiMoEForCausalLM"),
    "Plamo2ForCausalLM": ("plamo2", "Plamo2ForCausalLM"),
    "QWenLMHeadModel": ("qwen", "QWenLMHeadModel"),
    "Qwen2ForCausalLM": ("qwen2", "Qwen2ForCausalLM"),
    "Qwen2MoeForCausalLM": ("qwen2_moe", "Qwen2MoeForCausalLM"),
    "Qwen3ForCausalLM": ("qwen3", "Qwen3ForCausalLM"),
    "Qwen3MoeForCausalLM": ("qwen3_moe", "Qwen3MoeForCausalLM"),
    "RWForCausalLM": ("falcon", "FalconForCausalLM"),
    "SeedOssForCausalLM": ("seed_oss", "SeedOssForCausalLM"),
    "Step3TextForCausalLM": ("step3_text", "Step3TextForCausalLM"),
    "StableLMEpochForCausalLM": ("stablelm", "StablelmForCausalLM"),
    "StableLmForCausalLM": ("stablelm", "StablelmForCausalLM"),
    "Starcoder2ForCausalLM": ("starcoder2", "Starcoder2ForCausalLM"),
    "SolarForCausalLM": ("solar", "SolarForCausalLM"),
    "TeleChat2ForCausalLM": ("telechat2", "TeleChat2ForCausalLM"),
    "TeleFLMForCausalLM": ("teleflm", "TeleFLMForCausalLM"),
    "XverseForCausalLM": ("llama", "LlamaForCausalLM"),
    "Zamba2ForCausalLM": ("zamba2", "Zamba2ForCausalLM"),
}

# Architectures usable as embedding/pooling models (this table also lists some
# reward-model and sequence-classification heads). Same
# `name -> (module name, class name)` layout as the text-generation table.
_EMBEDDING_MODELS = {
    # [Text-only]
    "BertModel": ("bert", "BertEmbeddingModel"),
    "DeciLMForCausalLM": ("nemotron_nas", "DeciLMForCausalLM"),
    "Gemma2Model": ("gemma2", "Gemma2ForCausalLM"),
    "Gemma3TextModel": ("gemma3", "Gemma3Model"),
    "GlmForCausalLM": ("glm", "GlmForCausalLM"),
    "GPT2ForSequenceClassification": ("gpt2", "GPT2ForSequenceClassification"),
    "GritLM": ("gritlm", "GritLM"),
    "GteModel": ("bert_with_rope", "SnowflakeGteNewModel"),
    "GteNewModel": ("bert_with_rope", "GteNewModel"),
    "InternLM2ForRewardModel": ("internlm2", "InternLM2ForRewardModel"),
    "JambaForSequenceClassification": ("jamba", "JambaForSequenceClassification"),  # noqa: E501
    "LlamaModel": ("llama", "LlamaForCausalLM"),
    **{
        # Multiple models share the same architecture, so we include them all
        k: (mod, arch)
        for k, (mod, arch) in _TEXT_GENERATION_MODELS.items()
        if arch == "LlamaForCausalLM"
    },
    "MistralModel": ("llama", "LlamaForCausalLM"),
    "ModernBertModel": ("modernbert", "ModernBertModel"),
    "NomicBertModel": ("bert_with_rope", "NomicBertModel"),
    "Phi3ForCausalLM": ("phi3", "Phi3ForCausalLM"),
    "Qwen2Model": ("qwen2", "Qwen2ForCausalLM"),
    "Qwen2ForCausalLM": ("qwen2", "Qwen2ForCausalLM"),
    "Qwen2ForRewardModel": ("qwen2_rm", "Qwen2ForRewardModel"),
    "Qwen2ForProcessRewardModel": ("qwen2_rm", "Qwen2ForProcessRewardModel"),
    "RobertaForMaskedLM": ("roberta", "RobertaEmbeddingModel"),
    "RobertaModel": ("roberta", "RobertaEmbeddingModel"),
    "TeleChat2ForCausalLM": ("telechat2", "TeleChat2ForCausalLM"),
    "XLMRobertaModel": ("roberta", "RobertaEmbeddingModel"),
    # [Multimodal]
    "CLIPModel": ("clip", "CLIPEmbeddingModel"),
    "LlavaNextForConditionalGeneration": (
        "llava_next",
        "LlavaNextForConditionalGeneration",
    ),
    "Phi3VForCausalLM": ("phi3v", "Phi3VForCausalLM"),
    "Qwen2VLForConditionalGeneration": ("qwen2_vl", "Qwen2VLForConditionalGeneration"),  # noqa: E501
    # Technically Terratorch models work on images, both in
    # input and output. I am adding it here because it piggy-backs on embedding
    # models for the time being.
    "PrithviGeoSpatialMAE": ("terratorch", "Terratorch"),
    "Terratorch": ("terratorch", "Terratorch"),
}

# Cross-encoder / sequence- and token-classification architectures, mapped to
# a (module name, class name) pair.
_CROSS_ENCODER_MODELS = {
    "BertForSequenceClassification": ("bert", "BertForSequenceClassification"),
    "BertForTokenClassification": ("bert", "BertForTokenClassification"),
    "GteNewForSequenceClassification": (
        "bert_with_rope",
        "GteNewForSequenceClassification",
    ),
    "ModernBertForSequenceClassification": (
        "modernbert",
        "ModernBertForSequenceClassification",
    ),
    "RobertaForSequenceClassification": ("roberta", "RobertaForSequenceClassification"),
    "XLMRobertaForSequenceClassification": (
        "roberta",
        "RobertaForSequenceClassification",
    ),
    # [Auto-converted (see adapters.py)]
    "JinaVLForRanking": ("jina_vl", "JinaVLForSequenceClassification"),  # noqa: E501
}

# Multimodal architectures (vision/audio/video + text), mapped to a
# (module name, class name) pair.
_MULTIMODAL_MODELS = {
    # [Decoder-only]
    "AriaForConditionalGeneration": ("aria", "AriaForConditionalGeneration"),
    "AyaVisionForConditionalGeneration": (
        "aya_vision",
        "AyaVisionForConditionalGeneration",
    ),
    "Blip2ForConditionalGeneration": ("blip2", "Blip2ForConditionalGeneration"),
    "ChameleonForConditionalGeneration": (
        "chameleon",
        "ChameleonForConditionalGeneration",
    ),
    "Cohere2VisionForConditionalGeneration": (
        "cohere2_vision",
        "Cohere2VisionForConditionalGeneration",
    ),
    "DeepseekVLV2ForCausalLM": ("deepseek_vl2", "DeepseekVLV2ForCausalLM"),
    "DotsOCRForCausalLM": ("dots_ocr", "DotsOCRForCausalLM"),
    "Ernie4_5_VLMoeForConditionalGeneration": (
        "ernie45_vl",
        "Ernie4_5_VLMoeForConditionalGeneration",
    ),
    "FuyuForCausalLM": ("fuyu", "FuyuForCausalLM"),
    "Gemma3ForConditionalGeneration": ("gemma3_mm", "Gemma3ForConditionalGeneration"),  # noqa: E501
    "Gemma3nForConditionalGeneration": (
        "gemma3n_mm",
        "Gemma3nForConditionalGeneration",
    ),
    "GLM4VForCausalLM": ("glm4v", "GLM4VForCausalLM"),
    "Glm4vForConditionalGeneration": ("glm4_1v", "Glm4vForConditionalGeneration"),  # noqa: E501
    "Glm4vMoeForConditionalGeneration": ("glm4_1v", "Glm4vMoeForConditionalGeneration"),  # noqa: E501
    "GraniteSpeechForConditionalGeneration": (
        "granite_speech",
        "GraniteSpeechForConditionalGeneration",
    ),
    "H2OVLChatModel": ("h2ovl", "H2OVLChatModel"),
    "InternVLChatModel": ("internvl", "InternVLChatModel"),
    "NemotronH_Nano_VL_V2": ("nano_nemotron_vl", "NemotronH_Nano_VL_V2"),
    "InternS1ForConditionalGeneration": (
        "interns1",
        "InternS1ForConditionalGeneration",
    ),
    "InternVLForConditionalGeneration": (
        "interns1",
        "InternS1ForConditionalGeneration",
    ),
    "Idefics3ForConditionalGeneration": (
        "idefics3",
        "Idefics3ForConditionalGeneration",
    ),
    "SmolVLMForConditionalGeneration": ("smolvlm", "SmolVLMForConditionalGeneration"),  # noqa: E501
    "KeyeForConditionalGeneration": ("keye", "KeyeForConditionalGeneration"),
    "KeyeVL1_5ForConditionalGeneration": (
        "keye_vl1_5",
        "KeyeVL1_5ForConditionalGeneration",
    ),
    "RForConditionalGeneration": ("rvl", "RForConditionalGeneration"),
    "KimiVLForConditionalGeneration": ("kimi_vl", "KimiVLForConditionalGeneration"),  # noqa: E501
    "Llama_Nemotron_Nano_VL": ("nemotron_vl", "LlamaNemotronVLChatModel"),
    "Llama4ForConditionalGeneration": ("mllama4", "Llama4ForConditionalGeneration"),  # noqa: E501
    "LlavaForConditionalGeneration": ("llava", "LlavaForConditionalGeneration"),
    "LlavaNextForConditionalGeneration": (
        "llava_next",
        "LlavaNextForConditionalGeneration",
    ),
    "LlavaNextVideoForConditionalGeneration": (
        "llava_next_video",
        "LlavaNextVideoForConditionalGeneration",
    ),
    "LlavaOnevisionForConditionalGeneration": (
        "llava_onevision",
        "LlavaOnevisionForConditionalGeneration",
    ),
    "MantisForConditionalGeneration": ("llava", "MantisForConditionalGeneration"),  # noqa: E501
    "MiDashengLMModel": ("midashenglm", "MiDashengLMModel"),
    "MiniMaxVL01ForConditionalGeneration": (
        "minimax_vl_01",
        "MiniMaxVL01ForConditionalGeneration",
    ),
    "MiniCPMO": ("minicpmo", "MiniCPMO"),
    "MiniCPMV": ("minicpmv", "MiniCPMV"),
    "Mistral3ForConditionalGeneration": (
        "mistral3",
        "Mistral3ForConditionalGeneration",
    ),
    "MolmoForCausalLM": ("molmo", "MolmoForCausalLM"),
    "NVLM_D": ("nvlm_d", "NVLM_D_Model"),
    "Ovis": ("ovis", "Ovis"),
    "Ovis2_5": ("ovis2_5", "Ovis2_5"),
    "PaliGemmaForConditionalGeneration": (
        "paligemma",
        "PaliGemmaForConditionalGeneration",
    ),
    "Phi3VForCausalLM": ("phi3v", "Phi3VForCausalLM"),
    "Phi4MMForCausalLM": ("phi4mm", "Phi4MMForCausalLM"),
    "Phi4MultimodalForCausalLM": ("phi4_multimodal", "Phi4MultimodalForCausalLM"),  # noqa: E501
    "PixtralForConditionalGeneration": ("pixtral", "PixtralForConditionalGeneration"),  # noqa: E501
    "QwenVLForConditionalGeneration": ("qwen_vl", "QwenVLForConditionalGeneration"),  # noqa: E501
    "Qwen2VLForConditionalGeneration": ("qwen2_vl", "Qwen2VLForConditionalGeneration"),  # noqa: E501
    "Qwen2_5_VLForConditionalGeneration": (
        "qwen2_5_vl",
        "Qwen2_5_VLForConditionalGeneration",
    ),
    "Qwen2AudioForConditionalGeneration": (
        "qwen2_audio",
        "Qwen2AudioForConditionalGeneration",
    ),
    "Qwen2_5OmniModel": (
        "qwen2_5_omni_thinker",
        "Qwen2_5OmniThinkerForConditionalGeneration",
    ),
    "Qwen2_5OmniForConditionalGeneration": (
        "qwen2_5_omni_thinker",
        "Qwen2_5OmniThinkerForConditionalGeneration",
    ),
    "Qwen3VLForConditionalGeneration": ("qwen3_vl", "Qwen3VLForConditionalGeneration"),  # noqa: E501
    "Qwen3VLMoeForConditionalGeneration": (
        "qwen3_vl_moe",
        "Qwen3VLMoeForConditionalGeneration",
    ),
    "SkyworkR1VChatModel": ("skyworkr1v", "SkyworkR1VChatModel"),
    "Step3VLForConditionalGeneration": ("step3_vl", "Step3VLForConditionalGeneration"),  # noqa: E501
    "TarsierForConditionalGeneration": ("tarsier", "TarsierForConditionalGeneration"),  # noqa: E501
    "Tarsier2ForConditionalGeneration": (
        "qwen2_vl",
        "Tarsier2ForConditionalGeneration",
    ),
    "UltravoxModel": ("ultravox", "UltravoxModel"),
    "VoxtralForConditionalGeneration": ("voxtral", "VoxtralForConditionalGeneration"),  # noqa: E501
    # [Encoder-decoder]
    "WhisperForConditionalGeneration": ("whisper", "WhisperForConditionalGeneration"),  # noqa: E501
}

# Draft / MTP / Medusa models used for speculative decoding, mapped to a
# (module name, class name) pair.
_SPECULATIVE_DECODING_MODELS = {
    "MiMoMTPModel": ("mimo_mtp", "MiMoMTP"),
    "EagleLlamaForCausalLM": ("llama_eagle", "EagleLlamaForCausalLM"),
    "EagleLlama4ForCausalLM": ("llama4_eagle", "EagleLlama4ForCausalLM"),
    "EagleMiniCPMForCausalLM": ("minicpm_eagle", "EagleMiniCPMForCausalLM"),
    "Eagle3LlamaForCausalLM": ("llama_eagle3", "Eagle3LlamaForCausalLM"),
    "LlamaForCausalLMEagle3": ("llama_eagle3", "Eagle3LlamaForCausalLM"),
    "Eagle3Qwen2_5vlForCausalLM": ("llama_eagle3", "Eagle3LlamaForCausalLM"),
    "EagleDeepSeekMTPModel": ("deepseek_eagle", "EagleDeepseekV3ForCausalLM"),
    "DeepSeekMTPModel": ("deepseek_mtp", "DeepSeekMTP"),
    "ErnieMTPModel": ("ernie_mtp", "ErnieMTP"),
    "LongCatFlashMTPModel": ("longcat_flash_mtp", "LongCatFlashMTP"),
    "Glm4MoeMTPModel": ("glm4_moe_mtp", "Glm4MoeMTP"),
    "MedusaModel": ("medusa", "Medusa"),
    "Qwen3NextMTP": ("qwen3_next_mtp", "Qwen3NextMTP"),
    # Temporarily disabled.
    # # TODO(woosuk): Re-enable this once the MLP Speculator is supported in V1.
    # "MLPSpeculatorPreTrainedModel": ("mlp_speculator", "MLPSpeculator"),
}

# Architectures routed to the Transformers-backend classes (module
# "transformers") rather than a dedicated model module.
_TRANSFORMERS_SUPPORTED_MODELS = {
    # Text generation models
    "SmolLM3ForCausalLM": ("transformers", "TransformersForCausalLM"),
    # Multimodal models
    "Emu3ForConditionalGeneration": ("transformers", "TransformersForMultimodalLM"),  # noqa: E501
}

# The Transformers-backend wrapper classes themselves, registered under their
# own names so they can be selected explicitly.
_TRANSFORMERS_BACKEND_MODELS = {
    "TransformersForCausalLM": ("transformers", "TransformersForCausalLM"),
    "TransformersForMultimodalLM": ("transformers", "TransformersForMultimodalLM"),  # noqa: E501
    "TransformersMoEForCausalLM": ("transformers_moe", "TransformersMoEForCausalLM"),  # noqa: E501
    "TransformersMoEForMultimodalLM": (
        "transformers_moe",
        "TransformersMoEForMultimodalLM",
    ),
    "TransformersEmbeddingModel": (
        "transformers_pooling",
        "TransformersEmbeddingModel",
    ),
    "TransformersForSequenceClassification": (
        "transformers_pooling",
        "TransformersForSequenceClassification",
    ),
    "TransformersMoEForSequenceClassification": (
        "transformers_pooling",
        "TransformersMoEForSequenceClassification",
    ),
    "TransformersMoEEmbeddingModel": (
        "transformers_pooling",
        "TransformersMoEEmbeddingModel",
    ),
}

# Combined lookup table of every architecture listed above. On a duplicate
# key, the entry from the table unpacked later wins (dict-unpacking order).
_VLLM_MODELS = {
    **_TEXT_GENERATION_MODELS,
    **_EMBEDDING_MODELS,
    **_CROSS_ENCODER_MODELS,
    **_MULTIMODAL_MODELS,
    **_SPECULATIVE_DECODING_MODELS,
    **_TRANSFORMERS_SUPPORTED_MODELS,
    **_TRANSFORMERS_BACKEND_MODELS,
}

# This variable is used as the args for subprocess.run(). We
# can modify this variable to alter the args if needed. e.g.
# when we use par format to pack things together, sys.executable
# might not be the target we want to run.
_SUBPROCESS_COMMAND = [sys.executable, "-m", "vllm.model_executor.models.registry"]

# Architectures that are no longer supported, mapped to a version string
# (presumably the last vLLM release that still supported them; confirm
# against where this table is consumed).
_PREVIOUSLY_SUPPORTED_MODELS = {
    "MotifForCausalLM": "0.10.2",
    "Phi3SmallForCausalLM": "0.9.2",
    "Phi4FlashForCausalLM": "0.10.2",
    # encoder-decoder models except whisper
    # have been removed for V0 deprecation.
    "BartModel": "0.10.2",
    "BartForConditionalGeneration": "0.10.2",
    "DonutForConditionalGeneration": "0.10.2",
    "Florence2ForConditionalGeneration": "0.10.2",
    "MBartForConditionalGeneration": "0.10.2",
    "MllamaForConditionalGeneration": "0.10.2",
}


@dataclass(frozen=True)
class _ModelInfo:
    """Capability flags of a model class, as reported by the interface helpers.

    Every field is a plain `str`/`bool`, so an instance converts to a JSON
    compatible dict with `dataclasses.asdict` and back with `_ModelInfo(**d)`.
    """

    # `__name__` of the model class.
    architecture: str
    is_text_generation_model: bool
    is_pooling_model: bool
    default_pooling_type: str
    supports_cross_encoding: bool
    supports_multimodal: bool
    supports_multimodal_raw_input_only: bool
    supports_multimodal_encoder_tp_data: bool
    supports_pp: bool
    has_inner_state: bool
    is_attention_free: bool
    is_hybrid: bool
    has_noops: bool
    supports_transcription: bool
    supports_transcription_only: bool
    supports_v0_only: bool

    @staticmethod
    def from_model_cls(model: type[nn.Module]) -> "_ModelInfo":
        """Build a `_ModelInfo` by probing `model` with the interface helpers.

        This imports nothing itself, but `model` must already be a loaded class.
        """
        return _ModelInfo(
            architecture=model.__name__,
            is_text_generation_model=is_text_generation_model(model),
            is_pooling_model=is_pooling_model(model),
            default_pooling_type=get_default_pooling_type(model),
            supports_cross_encoding=supports_cross_encoding(model),
            supports_multimodal=supports_multimodal(model),
            supports_multimodal_raw_input_only=supports_multimodal_raw_input_only(
                model
            ),
            supports_multimodal_encoder_tp_data=supports_multimodal_encoder_tp_data(
                model
            ),
            supports_pp=supports_pp(model),
            has_inner_state=has_inner_state(model),
            is_attention_free=is_attention_free(model),
            is_hybrid=is_hybrid(model),
            supports_transcription=supports_transcription(model),
            # The class attribute is only read when the model supports
            # transcription at all; the inline call keeps type narrowing.
            supports_transcription_only=(
                supports_transcription(model) and model.supports_transcription_only
            ),
            supports_v0_only=supports_v0_only(model),
            has_noops=has_noops(model),
        )


class _BaseRegisteredModel(ABC):
    """Common interface for entries stored in the model registry.

    Subclasses report the model's capabilities (`inspect_model_cls`) and hand
    back the model class itself (`load_model_cls`).
    """

    @abstractmethod
    def inspect_model_cls(self) -> _ModelInfo:
        """Return the `_ModelInfo` describing the registered model class."""
        raise NotImplementedError

    @abstractmethod
    def load_model_cls(self) -> type[nn.Module]:
        """Return the registered model class."""
        raise NotImplementedError


@dataclass(frozen=True)
class _RegisteredModel(_BaseRegisteredModel):
    """
    Represents a model that has already been imported in the main process.

    Its `_ModelInfo` is computed eagerly, when the entry is created via
    `from_model_cls`, so inspection never needs a subprocess.
    """

    # Capabilities computed from `model_cls` at construction time.
    interfaces: _ModelInfo
    # The already-imported model class.
    model_cls: type[nn.Module]

    @staticmethod
    def from_model_cls(model_cls: type[nn.Module]) -> "_RegisteredModel":
        """Create a registry entry for `model_cls`, inspecting it immediately."""
        return _RegisteredModel(
            interfaces=_ModelInfo.from_model_cls(model_cls),
            model_cls=model_cls,
        )

    def inspect_model_cls(self) -> _ModelInfo:
        """Return the capabilities computed in `from_model_cls`."""
        return self.interfaces

    def load_model_cls(self) -> type[nn.Module]:
        """Return the stored model class."""
        return self.model_cls


@dataclass(frozen=True)
class _LazyRegisteredModel(_BaseRegisteredModel):
    """
    Represents a model that has not been imported in the main process.

    `inspect_model_cls` imports the model in a subprocess. The resulting
    `_ModelInfo` is cached on disk, keyed by a hash of the model's source file,
    but only for modules that live in this package. Their source can be located
    without importing them.
    """

    # Fully-qualified module that defines the model class.
    module_name: str
    # Name of the model class inside `module_name`.
    class_name: str

    @staticmethod
    def _get_cache_dir() -> Path:
        return Path(envs.VLLM_CACHE_ROOT) / "modelinfos"

    def _get_cache_filename(self) -> str:
        cls_name = f"{self.module_name}-{self.class_name}".replace(".", "-")
        return f"{cls_name}.json"

    def _get_module_source_path(self) -> Path | None:
        """Return the source file of `module_name` if it is in this package.

        Only a module whose parent package is this module's package maps to
        `<this directory>/<basename>.py`. For anything else, such as an
        out-of-tree plugin registered as `<module>:<class>`, this returns
        `None`. Hashing a same-named file in this directory would key the cache
        on the wrong source and could serve stale model info.
        """
        package, _, basename = self.module_name.rpartition(".")
        if not __package__ or package != __package__:
            return None
        model_path = Path(__file__).parent / f"{basename}.py"
        return model_path if model_path.exists() else None

    def _load_modelinfo_from_cache(self, module_hash: str) -> _ModelInfo | None:
        """Return the cached `_ModelInfo` if it was written for `module_hash`.

        Returns `None` when the cache file is missing, stale (written for a
        different hash) or unreadable. Errors are logged, not raised.
        """
        try:
            modelinfo_path = self._get_cache_dir() / self._get_cache_filename()
            try:
                with open(modelinfo_path, encoding="utf-8") as file:
                    mi_dict = json.load(file)
            except FileNotFoundError:
                logger.debug(
                    "Cached model info file for class %s.%s not found",
                    self.module_name,
                    self.class_name,
                )
                return None

            if mi_dict["hash"] != module_hash:
                logger.debug(
                    "Cached model info file for class %s.%s is stale",
                    self.module_name,
                    self.class_name,
                )
                return None

            # file not changed, use cached _ModelInfo properties
            return _ModelInfo(**mi_dict["modelinfo"])
        except Exception:
            logger.exception(
                "Cached model info for class %s.%s error. ",
                self.module_name,
                self.class_name,
            )
            return None

    def _save_modelinfo_to_cache(self, mi: _ModelInfo, module_hash: str) -> None:
        """Save `mi`, tagged with `module_hash`, as a JSON cache file.

        Best-effort: any failure, including importing the writer helper, is
        logged and swallowed so that model inspection still succeeds.
        """
        try:
            # Deferred import; it is inside the `try` so a failure here cannot
            # abort inspection of an otherwise loadable model.
            from vllm.model_executor.model_loader.weight_utils import atomic_writer

            modelinfo_dict = {
                "hash": module_hash,
                "modelinfo": asdict(mi),
            }
            cache_dir = self._get_cache_dir()
            cache_dir.mkdir(parents=True, exist_ok=True)
            modelinfo_path = cache_dir / self._get_cache_filename()
            with atomic_writer(modelinfo_path, encoding="utf-8") as f:
                json.dump(modelinfo_dict, f, indent=2)
        except Exception:
            logger.exception("Error saving model info cache.")

    @logtime(logger=logger, msg="Registry inspect model class")
    def inspect_model_cls(self) -> _ModelInfo:
        """Return the model's `_ModelInfo` without importing it in this process.

        Uses the on-disk cache when the source file is known and unchanged.
        Otherwise inspects the model in a subprocess and refreshes the cache.
        """
        model_path = self._get_module_source_path()
        module_hash = None

        if model_path is not None:
            # MD5 only detects changes here and is not a security control.
            # `usedforsecurity=False` keeps it usable on FIPS-enabled hosts.
            module_hash = hashlib.md5(
                model_path.read_bytes(), usedforsecurity=False
            ).hexdigest()

            mi = self._load_modelinfo_from_cache(module_hash)
            if mi is not None:
                logger.debug(
                    "Loaded model info for class %s.%s from cache",
                    self.module_name,
                    self.class_name,
                )
                return mi

            logger.debug(
                "Cache model info for class %s.%s miss. Loading model instead.",
                self.module_name,
                self.class_name,
            )

        # Performed in another process to avoid initializing CUDA
        mi = _run_in_subprocess(
            lambda: _ModelInfo.from_model_cls(self.load_model_cls())
        )
        logger.debug(
            "Loaded model info for class %s.%s", self.module_name, self.class_name
        )

        # save cache file
        if module_hash is not None:
            self._save_modelinfo_to_cache(mi, module_hash)

        return mi

    def load_model_cls(self) -> type[nn.Module]:
        """Import `module_name` in this process and return `class_name` from it."""
        mod = importlib.import_module(self.module_name)
        return getattr(mod, self.class_name)


@lru_cache(maxsize=128)
def _try_load_model_cls(
    model_arch: str,
    model: _BaseRegisteredModel,
) -> Optional[type[nn.Module]]:
    """Return the model class for `model_arch`, or `None` if loading fails.

    `current_platform.verify_model_arch` runs outside the `try`, so anything it
    raises propagates to the caller. Errors from `model.load_model_cls()` are
    logged and turned into `None`. Returned values, including `None`, are
    memoized per `(model_arch, model)`. Raised exceptions are not cached.
    """
    from vllm.platforms import current_platform

    current_platform.verify_model_arch(model_arch)
    try:
        return model.load_model_cls()
    except Exception:
        logger.exception("Error in loading model architecture '%s'", model_arch)
        return None


@lru_cache(maxsize=128)
def _try_inspect_model_cls(
    model_arch: str,
    model: _BaseRegisteredModel,
) -> Optional[_ModelInfo]:
    """Return the `_ModelInfo` for `model_arch`, or `None` if inspection fails.

    Errors are logged with a traceback rather than raised. Results, including
    `None`, are memoized per `(model_arch, model)`.
    """
    try:
        return model.inspect_model_cls()
    except Exception:
        logger.exception("Error in inspecting model architecture '%s'", model_arch)
        return None


@dataclass
class _ModelRegistry:
    """Registry that maps architecture names to registered model entries.

    Each value is a `_BaseRegisteredModel`. It can be inspected (to obtain
    a `_ModelInfo`) or loaded (to obtain the `nn.Module` class). Lookups
    may also fall back to the Transformers backend, depending on
    `model_config.model_impl`.
    """

    # Keyed by model_arch
    models: dict[str, _BaseRegisteredModel] = field(default_factory=dict)

    def get_supported_archs(self) -> Set[str]:
        """Return a live view of the registered architecture names."""
        return self.models.keys()

    def register_model(
        self,
        model_arch: str,
        model_cls: Union[type[nn.Module], str],
    ) -> None:
        """
        Register an external model to be used in vLLM.

        `model_cls` can be either:

        - A [`torch.nn.Module`][] class directly referencing the model.
        - A string in the format `<module>:<class>` which can be used to
          lazily import the model. This is useful to avoid initializing CUDA
          when importing the model and thus the related error
          `RuntimeError: Cannot re-initialize CUDA in forked subprocess`.

        If `model_arch` is already registered, a warning is logged and the
        existing entry is replaced.

        Raises:
            TypeError: If `model_arch` is not a string, or if `model_cls`
                is neither a string nor an `nn.Module` subclass.
            ValueError: If `model_cls` is a string that is not of the form
                `<module>:<class>`.
        """
        if not isinstance(model_arch, str):
            msg = f"`model_arch` should be a string, not a {type(model_arch)}"
            raise TypeError(msg)

        # Validate and build the entry before warning about an overwrite.
        # An invalid `model_cls` then never logs a misleading
        # "will be overwritten" message for a registration that fails.
        if isinstance(model_cls, str):
            split_str = model_cls.split(":")
            if len(split_str) != 2:
                msg = "Expected a string in the format `<module>:<class>`"
                raise ValueError(msg)

            model = _LazyRegisteredModel(*split_str)
        elif isinstance(model_cls, type) and issubclass(model_cls, nn.Module):
            model = _RegisteredModel.from_model_cls(model_cls)
        else:
            # Report the type of the offending `model_cls`, not `model_arch`.
            msg = (
                "`model_cls` should be a string or PyTorch model class, "
                f"not a {type(model_cls)}"
            )
            raise TypeError(msg)

        if model_arch in self.models:
            logger.warning(
                "Model architecture %s is already registered, and will be "
                "overwritten by the new model class %s.",
                model_arch,
                model_cls,
            )

        self.models[model_arch] = model

    def _raise_for_unsupported(self, architectures: list[str]):
        """Raise a `ValueError` explaining why `architectures` are unusable.

        Always raises. The message distinguishes three cases:

        - a registered architecture that failed inspection or loading;
        - an architecture that was only supported by an older vLLM version;
        - an architecture that is not supported at all.
        """
        all_supported_archs = self.get_supported_archs()

        if any(arch in all_supported_archs for arch in architectures):
            raise ValueError(
                f"Model architectures {architectures} failed "
                "to be inspected. Please check the logs for more details."
            )

        for arch in architectures:
            if arch in _PREVIOUSLY_SUPPORTED_MODELS:
                previous_version = _PREVIOUSLY_SUPPORTED_MODELS[arch]

                raise ValueError(
                    f"Model architecture {arch} was supported in vLLM until "
                    f"v{previous_version}, and is not supported anymore. "
                    "Please use an older version of vLLM if you want to "
                    "use this model architecture."
                )

        raise ValueError(
            f"Model architectures {architectures} are not supported for now. "
            f"Supported architectures: {all_supported_archs}"
        )

    def _try_load_model_cls(self, model_arch: str) -> Optional[type[nn.Module]]:
        """Load the class for `model_arch`; `None` if unregistered or failing."""
        if model_arch not in self.models:
            return None

        return _try_load_model_cls(model_arch, self.models[model_arch])

    def _try_inspect_model_cls(self, model_arch: str) -> Optional[_ModelInfo]:
        """Inspect `model_arch`; `None` if unregistered or inspection fails."""
        if model_arch not in self.models:
            return None

        return _try_inspect_model_cls(model_arch, self.models[model_arch])

    def _try_resolve_transformers(
        self,
        architecture: str,
        model_config: ModelConfig,
    ) -> Optional[str]:
        """Resolve `architecture` to a Transformers-backend architecture name.

        If `architecture` is already a Transformers-backend model, it is
        returned unchanged. Otherwise the model class is looked up as an
        attribute of the `transformers` package. If that fails, the
        `AutoModel*` entries of the HF config's `auto_map` are tried
        (custom remote code). If a backend-compatible class is found, the
        name from `model_config._get_transformers_backend_cls()` is
        returned.

        Returns `None` on failure unless `model_impl == "transformers"`.

        Raises:
            ValueError: If resolution fails while the Transformers backend
                was explicitly requested (`model_impl == "transformers"`).
        """
        if architecture in _TRANSFORMERS_BACKEND_MODELS:
            return architecture

        auto_map: dict[str, str] = (
            getattr(model_config.hf_config, "auto_map", None) or dict()
        )

        # Make sure that config class is always initialized before model class,
        # otherwise the model class won't be able to access the config class,
        # the expected auto_map should have correct order like:
        # "auto_map": {
        #     "AutoConfig": "<your-repo-name>--<config-name>",
        #     "AutoModel": "<your-repo-name>--<config-name>",
        #     "AutoModelFor<Task>": "<your-repo-name>--<config-name>",
        # },
        for prefix in ("AutoConfig", "AutoModel"):
            for name, module in auto_map.items():
                if name.startswith(prefix):
                    try_get_class_from_dynamic_module(
                        module,
                        model_config.model,
                        revision=model_config.revision,
                        warn_on_fail=False,
                    )

        model_module = getattr(transformers, architecture, None)

        if model_module is None:
            for name, module in auto_map.items():
                if name.startswith("AutoModel"):
                    model_module = try_get_class_from_dynamic_module(
                        module,
                        model_config.model,
                        revision=model_config.revision,
                        warn_on_fail=True,
                    )
                    if model_module is not None:
                        break
            else:
                # No `AutoModel*` entry produced a class.
                if model_config.model_impl != "transformers":
                    return None

                raise ValueError(
                    f"Cannot find model module. {architecture!r} is not a "
                    "registered model in the Transformers library (only "
                    "relevant if the model is meant to be in Transformers) "
                    "and 'AutoModel' is not present in the model config's "
                    "'auto_map' (relevant if the model is custom)."
                )

        if not model_module.is_backend_compatible():
            if model_config.model_impl != "transformers":
                return None

            raise ValueError(
                f"The Transformers implementation of {architecture!r} "
                "is not compatible with vLLM."
            )

        return model_config._get_transformers_backend_cls()

    def _normalize_arch(
        self,
        architecture: str,
        model_config: ModelConfig,
    ) -> str:
        """Map `architecture` to a registered base architecture if possible.

        Registered names are returned unchanged. Otherwise, if the name
        matches one of the architecture-default suffixes (for the model's
        runner/convert type), each default suffix is substituted in turn.
        The first substitution that yields a registered architecture is
        returned. If none does, `architecture` is returned unchanged.
        """
        if architecture in self.models:
            return architecture

        # This may be called in order to resolve runner_type and convert_type
        # in the first place, in which case we consider the default match
        match = try_match_architecture_defaults(
            architecture,
            runner_type=getattr(model_config, "runner_type", None),
            convert_type=getattr(model_config, "convert_type", None),
        )
        if match:
            suffix, _ = match

            # Get the name of the base model to convert
            for repl_suffix, _ in iter_architecture_defaults():
                base_arch = architecture.replace(suffix, repl_suffix)
                if base_arch in self.models:
                    return base_arch

        return architecture

    def inspect_model_cls(
        self,
        architectures: Union[str, list[str]],
        model_config: ModelConfig,
    ) -> tuple[_ModelInfo, str]:
        """Return `(model_info, arch)` for the first usable architecture.

        Resolution order:

        1. An explicitly requested implementation
           (`model_impl == "transformers"` or `"terratorch"`).
        2. The Transformers backend, if no architecture is registered
           natively, `model_impl == "auto"` and no conversion is requested.
        3. Native registrations, after `_normalize_arch`. The originally
           requested name is returned, not the normalized one.
        4. The Transformers backend, if no architecture is registered
           natively and `model_impl == "auto"`.

        Raises:
            ValueError: If `architectures` is empty, or if nothing resolves.
        """
        if isinstance(architectures, str):
            architectures = [architectures]
        if not architectures:
            raise ValueError("No model architectures are specified")

        # Require transformers impl
        if model_config.model_impl == "transformers":
            arch = self._try_resolve_transformers(architectures[0], model_config)
            if arch is not None:
                model_info = self._try_inspect_model_cls(arch)
                if model_info is not None:
                    return (model_info, arch)
        elif model_config.model_impl == "terratorch":
            # Only return on success, mirroring `resolve_model_cls`. Returning
            # `(None, ...)` would break the declared return type and every
            # `is_*` helper that dereferences the result.
            arch = "Terratorch"
            model_info = self._try_inspect_model_cls(arch)
            if model_info is not None:
                return (model_info, arch)

        # Fallback to transformers impl (after resolving convert_type)
        if (
            all(arch not in self.models for arch in architectures)
            and model_config.model_impl == "auto"
            and getattr(model_config, "convert_type", "none") == "none"
        ):
            arch = self._try_resolve_transformers(architectures[0], model_config)
            if arch is not None:
                model_info = self._try_inspect_model_cls(arch)
                if model_info is not None:
                    return (model_info, arch)

        for arch in architectures:
            normalized_arch = self._normalize_arch(arch, model_config)
            model_info = self._try_inspect_model_cls(normalized_arch)
            if model_info is not None:
                return (model_info, arch)

        # Fallback to transformers impl (before resolving runner_type)
        if (
            all(arch not in self.models for arch in architectures)
            and model_config.model_impl == "auto"
        ):
            arch = self._try_resolve_transformers(architectures[0], model_config)
            if arch is not None:
                model_info = self._try_inspect_model_cls(arch)
                if model_info is not None:
                    return (model_info, arch)

        return self._raise_for_unsupported(architectures)

    def resolve_model_cls(
        self,
        architectures: Union[str, list[str]],
        model_config: ModelConfig,
    ) -> tuple[type[nn.Module], str]:
        """Return `(model_cls, arch)` for the first loadable architecture.

        Follows the same resolution order as `inspect_model_cls`, but
        loads the model class instead of inspecting it.

        Raises:
            ValueError: If `architectures` is empty, or if nothing resolves.
        """
        if isinstance(architectures, str):
            architectures = [architectures]
        if not architectures:
            raise ValueError("No model architectures are specified")

        # Require transformers impl
        if model_config.model_impl == "transformers":
            arch = self._try_resolve_transformers(architectures[0], model_config)
            if arch is not None:
                model_cls = self._try_load_model_cls(arch)
                if model_cls is not None:
                    return (model_cls, arch)
        elif model_config.model_impl == "terratorch":
            arch = "Terratorch"
            model_cls = self._try_load_model_cls(arch)
            if model_cls is not None:
                return (model_cls, arch)

        # Fallback to transformers impl (after resolving convert_type)
        if (
            all(arch not in self.models for arch in architectures)
            and model_config.model_impl == "auto"
            and getattr(model_config, "convert_type", "none") == "none"
        ):
            arch = self._try_resolve_transformers(architectures[0], model_config)
            if arch is not None:
                model_cls = self._try_load_model_cls(arch)
                if model_cls is not None:
                    return (model_cls, arch)

        for arch in architectures:
            normalized_arch = self._normalize_arch(arch, model_config)
            model_cls = self._try_load_model_cls(normalized_arch)
            if model_cls is not None:
                return (model_cls, arch)

        # Fallback to transformers impl (before resolving runner_type)
        if (
            all(arch not in self.models for arch in architectures)
            and model_config.model_impl == "auto"
        ):
            arch = self._try_resolve_transformers(architectures[0], model_config)
            if arch is not None:
                model_cls = self._try_load_model_cls(arch)
                if model_cls is not None:
                    return (model_cls, arch)

        return self._raise_for_unsupported(architectures)

    # The predicates below each inspect the resolved architecture and
    # report a single `_ModelInfo` field.

    def is_text_generation_model(
        self,
        architectures: Union[str, list[str]],
        model_config: ModelConfig,
    ) -> bool:
        model_info, _ = self.inspect_model_cls(architectures, model_config)
        return model_info.is_text_generation_model

    def is_pooling_model(
        self,
        architectures: Union[str, list[str]],
        model_config: ModelConfig,
    ) -> bool:
        model_info, _ = self.inspect_model_cls(architectures, model_config)
        return model_info.is_pooling_model

    def is_cross_encoder_model(
        self,
        architectures: Union[str, list[str]],
        model_config: ModelConfig,
    ) -> bool:
        model_info, _ = self.inspect_model_cls(architectures, model_config)
        return model_info.supports_cross_encoding

    def is_multimodal_model(
        self,
        architectures: Union[str, list[str]],
        model_config: ModelConfig,
    ) -> bool:
        model_info, _ = self.inspect_model_cls(architectures, model_config)
        return model_info.supports_multimodal

    def is_multimodal_raw_input_only_model(
        self,
        architectures: Union[str, list[str]],
        model_config: ModelConfig,
    ) -> bool:
        model_info, _ = self.inspect_model_cls(architectures, model_config)
        return model_info.supports_multimodal_raw_input_only

    def is_pp_supported_model(
        self,
        architectures: Union[str, list[str]],
        model_config: ModelConfig,
    ) -> bool:
        model_info, _ = self.inspect_model_cls(architectures, model_config)
        return model_info.supports_pp

    def model_has_inner_state(
        self,
        architectures: Union[str, list[str]],
        model_config: ModelConfig,
    ) -> bool:
        model_info, _ = self.inspect_model_cls(architectures, model_config)
        return model_info.has_inner_state

    def is_attention_free_model(
        self,
        architectures: Union[str, list[str]],
        model_config: ModelConfig,
    ) -> bool:
        model_info, _ = self.inspect_model_cls(architectures, model_config)
        return model_info.is_attention_free

    def is_hybrid_model(
        self,
        architectures: Union[str, list[str]],
        model_config: ModelConfig,
    ) -> bool:
        model_info, _ = self.inspect_model_cls(architectures, model_config)
        return model_info.is_hybrid

    def is_noops_model(
        self,
        architectures: Union[str, list[str]],
        model_config: ModelConfig,
    ) -> bool:
        model_info, _ = self.inspect_model_cls(architectures, model_config)
        return model_info.has_noops

    def is_transcription_model(
        self,
        architectures: Union[str, list[str]],
        model_config: ModelConfig,
    ) -> bool:
        model_info, _ = self.inspect_model_cls(architectures, model_config)
        return model_info.supports_transcription

    def is_transcription_only_model(
        self,
        architectures: Union[str, list[str]],
        model_config: ModelConfig,
    ) -> bool:
        model_info, _ = self.inspect_model_cls(architectures, model_config)
        return model_info.supports_transcription_only

    def is_v1_compatible(
        self,
        architectures: Union[str, list[str]],
        model_config: ModelConfig,
    ) -> bool:
        model_info, _ = self.inspect_model_cls(architectures, model_config)
        return not model_info.supports_v0_only
1063

1064
1065
1066
1067
1068
1069
1070
1071
1072
# Global registry, pre-populated with one lazily-imported entry per
# built-in architecture listed in `_VLLM_MODELS`.
ModelRegistry = _ModelRegistry(
    models={
        arch_name: _LazyRegisteredModel(
            module_name=f"vllm.model_executor.models.{module_suffix}",
            class_name=impl_name,
        )
        for arch_name, (module_suffix, impl_name) in _VLLM_MODELS.items()
    }
)
1073
1074
1075
1076
1077

# Result type of the callable executed by `_run_in_subprocess`.
_T = TypeVar("_T")


def _run_in_subprocess(fn: Callable[[], _T]) -> _T:
    """Execute ``fn`` in a separate process and return its result.

    ``fn`` and a scratch output path are serialized with ``cloudpickle``
    (so lambdas are supported) and written to the stdin of
    ``_SUBPROCESS_COMMAND``. The child process is expected to write the
    pickled result to that path, which is then unpickled here.

    Raises:
        RuntimeError: If the subprocess exits with a non-zero status. The
            child's stderr is included in the message and the original
            ``CalledProcessError`` is chained as the cause.
    """
    # NOTE: We use a temporary directory instead of a temporary file to avoid
    # issues like https://stackoverflow.com/questions/23212435/permission-denied-to-write-to-my-temporary-file
    with tempfile.TemporaryDirectory() as scratch_dir:
        result_path = os.path.join(scratch_dir, "registry_output.tmp")

        # `cloudpickle` allows pickling lambda functions directly
        import cloudpickle

        payload = cloudpickle.dumps((fn, result_path))

        # cannot use `sys.executable __file__` here because the script
        # contains relative imports
        completed = subprocess.run(
            _SUBPROCESS_COMMAND, input=payload, capture_output=True
        )

        # check if the subprocess is successful
        try:
            completed.check_returncode()
        except Exception as exc:
            # wrap raised exception to provide more information
            raise RuntimeError(
                f"Error raised in subprocess:\n{completed.stderr.decode()}"
            ) from exc

        with open(result_path, "rb") as result_file:
            return pickle.load(result_file)


def _run() -> None:
    """Child-side counterpart of `_run_in_subprocess`.

    The function loads vLLM's general plugins first. It then reads a
    pickled ``(fn, output_file)`` pair from stdin, calls ``fn()``, and
    writes the pickled result to ``output_file``.
    """
    # Setup plugins
    from vllm.plugins import load_general_plugins

    load_general_plugins()

    # NOTE(review): stdin is unpickled without validation. This is only safe
    # because the payload presumably comes from `_run_in_subprocess` in the
    # parent process. Never point untrusted input at this entry point.
    stdin_payload = sys.stdin.buffer.read()
    fn, output_file = pickle.loads(stdin_payload)

    outcome = fn()

    with open(output_file, "wb") as sink:
        sink.write(pickle.dumps(outcome))
1119
1120
1121


if __name__ == "__main__":
    # Script entry point; presumably the target launched by
    # `_SUBPROCESS_COMMAND` from `_run_in_subprocess`.
    _run()