# registry.py
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
4
5
6
"""
Whenever you add an architecture to this page, please also update
`tests/models/registry.py` with example HuggingFace models for it.
"""
7

8
import hashlib
9
import importlib
10
import json
11
import os
12
import pickle
13
14
import subprocess
import sys
15
import tempfile
16
from abc import ABC, abstractmethod
17
from collections.abc import Set
18
from dataclasses import asdict, dataclass, field
19
from functools import lru_cache
20
from pathlib import Path
21
from typing import Callable, Optional, TypeVar, Union
22
23

import torch.nn as nn
24
import transformers
25

26
from vllm import envs
27
28
29
30
31
from vllm.config import (
    ModelConfig,
    iter_architecture_defaults,
    try_match_architecture_defaults,
)
32
from vllm.logger import init_logger
33
from vllm.logging_utils import logtime
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
from vllm.transformers_utils.dynamic_module import try_get_class_from_dynamic_module

from .interfaces import (
    has_inner_state,
    has_noops,
    is_attention_free,
    is_hybrid,
    supports_cross_encoding,
    supports_multimodal,
    supports_multimodal_encoder_tp_data,
    supports_multimodal_raw_input_only,
    supports_pp,
    supports_transcription,
    supports_v0_only,
)
from .interfaces_base import (
    get_default_pooling_type,
    is_pooling_model,
    is_text_generation_model,
)
54
55
56

# Module-level logger used for registry cache/inspection diagnostics.
logger = init_logger(__name__)

57
58
_TEXT_GENERATION_MODELS = {
    # [Decoder-only]
59
    "ApertusForCausalLM": ("apertus", "ApertusForCausalLM"),
60
61
    "AquilaModel": ("llama", "LlamaForCausalLM"),
    "AquilaForCausalLM": ("llama", "LlamaForCausalLM"),  # AquilaChat2
Raghav Ravishankar's avatar
Raghav Ravishankar committed
62
    "ArceeForCausalLM": ("arcee", "ArceeForCausalLM"),
63
    "ArcticForCausalLM": ("arctic", "ArcticForCausalLM"),
64
    "MiniMaxForCausalLM": ("minimax_text_01", "MiniMaxText01ForCausalLM"),
65
    "MiniMaxText01ForCausalLM": ("minimax_text_01", "MiniMaxText01ForCausalLM"),
66
    "MiniMaxM1ForCausalLM": ("minimax_text_01", "MiniMaxText01ForCausalLM"),
67
68
69
70
    # baichuan-7b, upper case 'C' in the class name
    "BaiChuanForCausalLM": ("baichuan", "BaiChuanForCausalLM"),
    # baichuan-13b, lower case 'c' in the class name
    "BaichuanForCausalLM": ("baichuan", "BaichuanForCausalLM"),
71
    "BailingMoeForCausalLM": ("bailing_moe", "BailingMoeForCausalLM"),
ant-yy's avatar
ant-yy committed
72
    "BailingMoeV2ForCausalLM": ("bailing_moe", "BailingMoeV2ForCausalLM"),
Yu Chin Fabian Lim's avatar
Yu Chin Fabian Lim committed
73
    "BambaForCausalLM": ("bamba", "BambaForCausalLM"),
74
    "BloomForCausalLM": ("bloom", "BloomForCausalLM"),
75
    "ChatGLMModel": ("chatglm", "ChatGLMForCausalLM"),
76
    "ChatGLMForConditionalGeneration": ("chatglm", "ChatGLMForCausalLM"),
77
    "CohereForCausalLM": ("commandr", "CohereForCausalLM"),
78
    "Cohere2ForCausalLM": ("commandr", "CohereForCausalLM"),
79
    "CwmForCausalLM": ("llama", "LlamaForCausalLM"),
80
    "DbrxForCausalLM": ("dbrx", "DbrxForCausalLM"),
81
    "DeciLMForCausalLM": ("nemotron_nas", "DeciLMForCausalLM"),
82
83
    "DeepseekForCausalLM": ("deepseek", "DeepseekForCausalLM"),
    "DeepseekV2ForCausalLM": ("deepseek_v2", "DeepseekV2ForCausalLM"),
84
    "DeepseekV3ForCausalLM": ("deepseek_v2", "DeepseekV3ForCausalLM"),
85
    "DeepseekV32ForCausalLM": ("deepseek_v2", "DeepseekV3ForCausalLM"),
86
    "Dots1ForCausalLM": ("dots1", "Dots1ForCausalLM"),
87
    "Ernie4_5ForCausalLM": ("ernie45", "Ernie4_5ForCausalLM"),
88
    "Ernie4_5_MoeForCausalLM": ("ernie45_moe", "Ernie4_5_MoeForCausalLM"),
89
    "ExaoneForCausalLM": ("exaone", "ExaoneForCausalLM"),
90
    "Exaone4ForCausalLM": ("exaone4", "Exaone4ForCausalLM"),
91
    "FalconForCausalLM": ("falcon", "FalconForCausalLM"),
92
    "Fairseq2LlamaForCausalLM": ("fairseq2_llama", "Fairseq2LlamaForCausalLM"),
93
94
    "GemmaForCausalLM": ("gemma", "GemmaForCausalLM"),
    "Gemma2ForCausalLM": ("gemma2", "Gemma2ForCausalLM"),
95
    "Gemma3ForCausalLM": ("gemma3", "Gemma3ForCausalLM"),
Nicolò Lucchesi's avatar
Nicolò Lucchesi committed
96
    "Gemma3nForCausalLM": ("gemma3n", "Gemma3nForCausalLM"),
97
    "Qwen3NextForCausalLM": ("qwen3_next", "Qwen3NextForCausalLM"),
98
    "GlmForCausalLM": ("glm", "GlmForCausalLM"),
Yuxuan Zhang's avatar
Yuxuan Zhang committed
99
    "Glm4ForCausalLM": ("glm4", "Glm4ForCausalLM"),
Yuxuan Zhang's avatar
Yuxuan Zhang committed
100
    "Glm4MoeForCausalLM": ("glm4_moe", "Glm4MoeForCausalLM"),
101
    "GptOssForCausalLM": ("gpt_oss", "GptOssForCausalLM"),
102
103
104
105
106
107
    "GPT2LMHeadModel": ("gpt2", "GPT2LMHeadModel"),
    "GPTBigCodeForCausalLM": ("gpt_bigcode", "GPTBigCodeForCausalLM"),
    "GPTJForCausalLM": ("gpt_j", "GPTJForCausalLM"),
    "GPTNeoXForCausalLM": ("gpt_neox", "GPTNeoXForCausalLM"),
    "GraniteForCausalLM": ("granite", "GraniteForCausalLM"),
    "GraniteMoeForCausalLM": ("granitemoe", "GraniteMoeForCausalLM"),
108
109
    "GraniteMoeHybridForCausalLM": ("granitemoehybrid", "GraniteMoeHybridForCausalLM"),  # noqa: E501
    "GraniteMoeSharedForCausalLM": ("granitemoeshared", "GraniteMoeSharedForCausalLM"),  # noqa: E501
110
    "GritLM": ("gritlm", "GritLM"),
Michael Goin's avatar
Michael Goin committed
111
    "Grok1ModelForCausalLM": ("grok1", "Grok1ForCausalLM"),
112
113
    "HunYuanMoEV1ForCausalLM": ("hunyuan_v1", "HunYuanMoEV1ForCausalLM"),
    "HunYuanDenseV1ForCausalLM": ("hunyuan_v1", "HunYuanDenseV1ForCausalLM"),
114
    "HCXVisionForCausalLM": ("hyperclovax_vision", "HCXVisionForCausalLM"),
115
116
    "InternLMForCausalLM": ("llama", "LlamaForCausalLM"),
    "InternLM2ForCausalLM": ("internlm2", "InternLM2ForCausalLM"),
117
    "InternLM2VEForCausalLM": ("internlm2_ve", "InternLM2VEForCausalLM"),
118
    "InternLM3ForCausalLM": ("llama", "LlamaForCausalLM"),
119
120
    "JAISLMHeadModel": ("jais", "JAISLMHeadModel"),
    "JambaForCausalLM": ("jamba", "JambaForCausalLM"),
121
    "Lfm2ForCausalLM": ("lfm2", "Lfm2ForCausalLM"),
122
    "LlamaForCausalLM": ("llama", "LlamaForCausalLM"),
123
    "Llama4ForCausalLM": ("llama4", "Llama4ForCausalLM"),
124
125
    # For decapoda-research/llama-*
    "LLaMAForCausalLM": ("llama", "LlamaForCausalLM"),
XuruiYang's avatar
XuruiYang committed
126
    "LongcatFlashForCausalLM": ("longcat_flash", "LongcatFlashForCausalLM"),
127
    "MambaForCausalLM": ("mamba", "MambaForCausalLM"),
128
    "FalconMambaForCausalLM": ("mamba", "MambaForCausalLM"),
129
    "FalconH1ForCausalLM": ("falcon_h1", "FalconH1ForCausalLM"),
130
    "Mamba2ForCausalLM": ("mamba2", "Mamba2ForCausalLM"),
131
132
    "MiniCPMForCausalLM": ("minicpm", "MiniCPMForCausalLM"),
    "MiniCPM3ForCausalLM": ("minicpm3", "MiniCPM3ForCausalLM"),
133
134
135
136
137
    "MistralForCausalLM": ("llama", "LlamaForCausalLM"),
    "MixtralForCausalLM": ("mixtral", "MixtralForCausalLM"),
    # transformers's mpt class has lower case
    "MptForCausalLM": ("mpt", "MPTForCausalLM"),
    "MPTForCausalLM": ("mpt", "MPTForCausalLM"),
138
    "MiMoForCausalLM": ("mimo", "MiMoForCausalLM"),
139
    "NemotronForCausalLM": ("nemotron", "NemotronForCausalLM"),
Luis Vega's avatar
Luis Vega committed
140
    "NemotronHForCausalLM": ("nemotron_h", "NemotronHForCausalLM"),
141
    "OlmoForCausalLM": ("olmo", "OlmoForCausalLM"),
142
    "Olmo2ForCausalLM": ("olmo2", "Olmo2ForCausalLM"),
143
    "Olmo3ForCausalLM": ("olmo2", "Olmo2ForCausalLM"),
144
145
146
147
148
149
150
    "OlmoeForCausalLM": ("olmoe", "OlmoeForCausalLM"),
    "OPTForCausalLM": ("opt", "OPTForCausalLM"),
    "OrionForCausalLM": ("orion", "OrionForCausalLM"),
    "PersimmonForCausalLM": ("persimmon", "PersimmonForCausalLM"),
    "PhiForCausalLM": ("phi", "PhiForCausalLM"),
    "Phi3ForCausalLM": ("phi3", "Phi3ForCausalLM"),
    "PhiMoEForCausalLM": ("phimoe", "PhiMoEForCausalLM"),
Shinichi Hemmi's avatar
Shinichi Hemmi committed
151
    "Plamo2ForCausalLM": ("plamo2", "Plamo2ForCausalLM"),
152
    "QWenLMHeadModel": ("qwen", "QWenLMHeadModel"),
153
154
    "Qwen2ForCausalLM": ("qwen2", "Qwen2ForCausalLM"),
    "Qwen2MoeForCausalLM": ("qwen2_moe", "Qwen2MoeForCausalLM"),
155
156
    "Qwen3ForCausalLM": ("qwen3", "Qwen3ForCausalLM"),
    "Qwen3MoeForCausalLM": ("qwen3_moe", "Qwen3MoeForCausalLM"),
157
    "RWForCausalLM": ("falcon", "FalconForCausalLM"),
158
    "SeedOssForCausalLM": ("seed_oss", "SeedOssForCausalLM"),
Song's avatar
Song committed
159
    "Step3TextForCausalLM": ("step3_text", "Step3TextForCausalLM"),
160
161
162
163
    "StableLMEpochForCausalLM": ("stablelm", "StablelmForCausalLM"),
    "StableLmForCausalLM": ("stablelm", "StablelmForCausalLM"),
    "Starcoder2ForCausalLM": ("starcoder2", "Starcoder2ForCausalLM"),
    "SolarForCausalLM": ("solar", "SolarForCausalLM"),
164
    "TeleChat2ForCausalLM": ("telechat2", "TeleChat2ForCausalLM"),
165
    "TeleFLMForCausalLM": ("teleflm", "TeleFLMForCausalLM"),
166
    "XverseForCausalLM": ("llama", "LlamaForCausalLM"),
167
    "Zamba2ForCausalLM": ("zamba2", "Zamba2ForCausalLM"),
168
169
170
}

# Embedding architectures (this table also holds reward and classification
# heads such as Qwen2ForRewardModel): name -> (module, class name).
_EMBEDDING_MODELS = {
    # [Text-only]
    "BertModel": ("bert", "BertEmbeddingModel"),
    "DeciLMForCausalLM": ("nemotron_nas", "DeciLMForCausalLM"),
    "Gemma2Model": ("gemma2", "Gemma2ForCausalLM"),
    "Gemma3TextModel": ("gemma3", "Gemma3Model"),
    "GlmForCausalLM": ("glm", "GlmForCausalLM"),
    "GPT2ForSequenceClassification": ("gpt2", "GPT2ForSequenceClassification"),
    "GritLM": ("gritlm", "GritLM"),
    "GteModel": ("bert_with_rope", "SnowflakeGteNewModel"),
    "GteNewModel": ("bert_with_rope", "GteNewModel"),
    "InternLM2ForRewardModel": ("internlm2", "InternLM2ForRewardModel"),
    "JambaForSequenceClassification": ("jamba", "JambaForSequenceClassification"),  # noqa: E501
    "LlamaModel": ("llama", "LlamaForCausalLM"),
    **{
        # Multiple models share the same architecture, so we include every
        # text-generation alias that maps to LlamaForCausalLM.
        k: (mod, arch)
        for k, (mod, arch) in _TEXT_GENERATION_MODELS.items()
        if arch == "LlamaForCausalLM"
    },
    "MistralModel": ("llama", "LlamaForCausalLM"),
    "ModernBertModel": ("modernbert", "ModernBertModel"),
    "NomicBertModel": ("bert_with_rope", "NomicBertModel"),
    "Phi3ForCausalLM": ("phi3", "Phi3ForCausalLM"),
    "Qwen2Model": ("qwen2", "Qwen2ForCausalLM"),
    "Qwen2ForCausalLM": ("qwen2", "Qwen2ForCausalLM"),
    "Qwen2ForRewardModel": ("qwen2_rm", "Qwen2ForRewardModel"),
    "Qwen2ForProcessRewardModel": ("qwen2_rm", "Qwen2ForProcessRewardModel"),
    "RobertaForMaskedLM": ("roberta", "RobertaEmbeddingModel"),
    "RobertaModel": ("roberta", "RobertaEmbeddingModel"),
    "TeleChat2ForCausalLM": ("telechat2", "TeleChat2ForCausalLM"),
    "XLMRobertaModel": ("roberta", "RobertaEmbeddingModel"),
    # [Multimodal]
    "CLIPModel": ("clip", "CLIPEmbeddingModel"),
    "LlavaNextForConditionalGeneration": (
        "llava_next",
        "LlavaNextForConditionalGeneration",
    ),
    "Phi3VForCausalLM": ("phi3v", "Phi3VForCausalLM"),
    "Qwen2VLForConditionalGeneration": ("qwen2_vl", "Qwen2VLForConditionalGeneration"),  # noqa: E501
    # Terratorch models technically take images as input and produce images
    # as output; they are listed here because they currently piggy-back on
    # the embedding-model path.
    "PrithviGeoSpatialMAE": ("terratorch", "Terratorch"),
    "Terratorch": ("terratorch", "Terratorch"),
}

217
218
_CROSS_ENCODER_MODELS = {
    "BertForSequenceClassification": ("bert", "BertForSequenceClassification"),
219
    "BertForTokenClassification": ("bert", "BertForTokenClassification"),
220
221
222
223
224
225
226
227
    "GteNewForSequenceClassification": (
        "bert_with_rope",
        "GteNewForSequenceClassification",
    ),
    "ModernBertForSequenceClassification": (
        "modernbert",
        "ModernBertForSequenceClassification",
    ),
228
229
230
231
    "ModernBertForTokenClassification": (
        "modernbert",
        "ModernBertForTokenClassification",
    ),
232
233
234
235
236
    "RobertaForSequenceClassification": ("roberta", "RobertaForSequenceClassification"),
    "XLMRobertaForSequenceClassification": (
        "roberta",
        "RobertaForSequenceClassification",
    ),
237
    # [Auto-converted (see adapters.py)]
238
    "JinaVLForRanking": ("jina_vl", "JinaVLForSequenceClassification"),  # noqa: E501,
239
240
}

241
_MULTIMODAL_MODELS = {
242
    # [Decoder-only]
243
    "AriaForConditionalGeneration": ("aria", "AriaForConditionalGeneration"),
244
245
246
    "AyaVisionForConditionalGeneration": (
        "aya_vision",
        "AyaVisionForConditionalGeneration",
247
    ),
248
    "Blip2ForConditionalGeneration": ("blip2", "Blip2ForConditionalGeneration"),
249
250
251
    "ChameleonForConditionalGeneration": (
        "chameleon",
        "ChameleonForConditionalGeneration",
252
    ),
253
254
255
    "Cohere2VisionForConditionalGeneration": (
        "cohere2_vision",
        "Cohere2VisionForConditionalGeneration",
256
    ),
257
    "DeepseekVLV2ForCausalLM": ("deepseek_vl2", "DeepseekVLV2ForCausalLM"),
Roger Wang's avatar
Roger Wang committed
258
    "DotsOCRForCausalLM": ("dots_ocr", "DotsOCRForCausalLM"),
259
260
261
    "Ernie4_5_VLMoeForConditionalGeneration": (
        "ernie45_vl",
        "Ernie4_5_VLMoeForConditionalGeneration",
262
    ),
263
    "FuyuForCausalLM": ("fuyu", "FuyuForCausalLM"),
264
    "Gemma3ForConditionalGeneration": ("gemma3_mm", "Gemma3ForConditionalGeneration"),  # noqa: E501
265
266
267
    "Gemma3nForConditionalGeneration": (
        "gemma3n_mm",
        "Gemma3nForConditionalGeneration",
268
    ),
269
    "GLM4VForCausalLM": ("glm4v", "GLM4VForCausalLM"),
270
    "Glm4vForConditionalGeneration": ("glm4_1v", "Glm4vForConditionalGeneration"),  # noqa: E501
Jee Jee Li's avatar
Jee Jee Li committed
271
    "Glm4vMoeForConditionalGeneration": ("glm4_1v", "Glm4vMoeForConditionalGeneration"),  # noqa: E501
272
273
274
    "GraniteSpeechForConditionalGeneration": (
        "granite_speech",
        "GraniteSpeechForConditionalGeneration",
275
    ),
276
    "H2OVLChatModel": ("h2ovl", "H2OVLChatModel"),
277
    "InternVLChatModel": ("internvl", "InternVLChatModel"),
278
    "NemotronH_Nano_VL_V2": ("nano_nemotron_vl", "NemotronH_Nano_VL_V2"),
279
280
281
    "InternS1ForConditionalGeneration": (
        "interns1",
        "InternS1ForConditionalGeneration",
282
    ),
283
284
285
    "InternVLForConditionalGeneration": (
        "interns1",
        "InternS1ForConditionalGeneration",
286
    ),
287
288
289
290
291
    "Idefics3ForConditionalGeneration": (
        "idefics3",
        "Idefics3ForConditionalGeneration",
    ),
    "SmolVLMForConditionalGeneration": ("smolvlm", "SmolVLMForConditionalGeneration"),  # noqa: E501
292
    "KeyeForConditionalGeneration": ("keye", "KeyeForConditionalGeneration"),
293
294
295
    "KeyeVL1_5ForConditionalGeneration": (
        "keye_vl1_5",
        "KeyeVL1_5ForConditionalGeneration",
296
    ),
297
    "RForConditionalGeneration": ("rvl", "RForConditionalGeneration"),
298
    "KimiVLForConditionalGeneration": ("kimi_vl", "KimiVLForConditionalGeneration"),  # noqa: E501
299
    "Llama_Nemotron_Nano_VL": ("nemotron_vl", "LlamaNemotronVLChatModel"),
300
    "Llama4ForConditionalGeneration": ("mllama4", "Llama4ForConditionalGeneration"),  # noqa: E501
301
    "LlavaForConditionalGeneration": ("llava", "LlavaForConditionalGeneration"),
302
303
304
    "LlavaNextForConditionalGeneration": (
        "llava_next",
        "LlavaNextForConditionalGeneration",
305
    ),
306
307
308
    "LlavaNextVideoForConditionalGeneration": (
        "llava_next_video",
        "LlavaNextVideoForConditionalGeneration",
309
    ),
310
311
312
    "LlavaOnevisionForConditionalGeneration": (
        "llava_onevision",
        "LlavaOnevisionForConditionalGeneration",
313
    ),
314
    "MantisForConditionalGeneration": ("llava", "MantisForConditionalGeneration"),  # noqa: E501
315
    "MiDashengLMModel": ("midashenglm", "MiDashengLMModel"),
316
317
318
    "MiniMaxVL01ForConditionalGeneration": (
        "minimax_vl_01",
        "MiniMaxVL01ForConditionalGeneration",
319
    ),
320
    "MiniCPMO": ("minicpmo", "MiniCPMO"),
321
    "MiniCPMV": ("minicpmv", "MiniCPMV"),
322
323
324
    "Mistral3ForConditionalGeneration": (
        "mistral3",
        "Mistral3ForConditionalGeneration",
325
    ),
326
    "MolmoForCausalLM": ("molmo", "MolmoForCausalLM"),
327
    "NVLM_D": ("nvlm_d", "NVLM_D_Model"),
328
    "Ovis": ("ovis", "Ovis"),
329
    "Ovis2_5": ("ovis2_5", "Ovis2_5"),
330
331
332
    "PaliGemmaForConditionalGeneration": (
        "paligemma",
        "PaliGemmaForConditionalGeneration",
333
    ),
334
    "Phi3VForCausalLM": ("phi3v", "Phi3VForCausalLM"),
335
336
    "Phi4MMForCausalLM": ("phi4mm", "Phi4MMForCausalLM"),
    "Phi4MultimodalForCausalLM": ("phi4_multimodal", "Phi4MultimodalForCausalLM"),  # noqa: E501
337
    "PixtralForConditionalGeneration": ("pixtral", "PixtralForConditionalGeneration"),  # noqa: E501
338
    "QwenVLForConditionalGeneration": ("qwen_vl", "QwenVLForConditionalGeneration"),  # noqa: E501
339
    "Qwen2VLForConditionalGeneration": ("qwen2_vl", "Qwen2VLForConditionalGeneration"),  # noqa: E501
340
341
342
    "Qwen2_5_VLForConditionalGeneration": (
        "qwen2_5_vl",
        "Qwen2_5_VLForConditionalGeneration",
343
    ),
344
345
346
    "Qwen2AudioForConditionalGeneration": (
        "qwen2_audio",
        "Qwen2AudioForConditionalGeneration",
347
    ),
348
349
350
    "Qwen2_5OmniModel": (
        "qwen2_5_omni_thinker",
        "Qwen2_5OmniThinkerForConditionalGeneration",
351
    ),
352
353
354
    "Qwen2_5OmniForConditionalGeneration": (
        "qwen2_5_omni_thinker",
        "Qwen2_5OmniThinkerForConditionalGeneration",
355
    ),
356
    "Qwen3VLForConditionalGeneration": ("qwen3_vl", "Qwen3VLForConditionalGeneration"),  # noqa: E501
357
358
359
    "Qwen3VLMoeForConditionalGeneration": (
        "qwen3_vl_moe",
        "Qwen3VLMoeForConditionalGeneration",
360
    ),
361
    "SkyworkR1VChatModel": ("skyworkr1v", "SkyworkR1VChatModel"),
Song's avatar
Song committed
362
    "Step3VLForConditionalGeneration": ("step3_vl", "Step3VLForConditionalGeneration"),  # noqa: E501
汪志鹏's avatar
汪志鹏 committed
363
    "TarsierForConditionalGeneration": ("tarsier", "TarsierForConditionalGeneration"),  # noqa: E501
364
365
366
    "Tarsier2ForConditionalGeneration": (
        "qwen2_vl",
        "Tarsier2ForConditionalGeneration",
367
    ),
368
    "UltravoxModel": ("ultravox", "UltravoxModel"),
Patrick von Platen's avatar
Patrick von Platen committed
369
    "VoxtralForConditionalGeneration": ("voxtral", "VoxtralForConditionalGeneration"),  # noqa: E501
370
    # [Encoder-decoder]
371
    "WhisperForConditionalGeneration": ("whisper", "WhisperForConditionalGeneration"),  # noqa: E501
372
}
373
374

_SPECULATIVE_DECODING_MODELS = {
375
    "MiMoMTPModel": ("mimo_mtp", "MiMoMTP"),
376
    "EagleLlamaForCausalLM": ("llama_eagle", "EagleLlamaForCausalLM"),
zhiweiz's avatar
zhiweiz committed
377
    "EagleLlama4ForCausalLM": ("llama4_eagle", "EagleLlama4ForCausalLM"),
378
    "EagleMiniCPMForCausalLM": ("minicpm_eagle", "EagleMiniCPMForCausalLM"),
379
    "Eagle3LlamaForCausalLM": ("llama_eagle3", "Eagle3LlamaForCausalLM"),
380
    "LlamaForCausalLMEagle3": ("llama_eagle3", "Eagle3LlamaForCausalLM"),
381
    "Eagle3Qwen2_5vlForCausalLM": ("llama_eagle3", "Eagle3LlamaForCausalLM"),
382
    "EagleDeepSeekMTPModel": ("deepseek_eagle", "EagleDeepseekV3ForCausalLM"),
383
    "DeepSeekMTPModel": ("deepseek_mtp", "DeepSeekMTP"),
384
    "ErnieMTPModel": ("ernie_mtp", "ErnieMTP"),
XuruiYang's avatar
XuruiYang committed
385
    "LongCatFlashMTPModel": ("longcat_flash_mtp", "LongCatFlashMTP"),
Yuxuan Zhang's avatar
Yuxuan Zhang committed
386
    "Glm4MoeMTPModel": ("glm4_moe_mtp", "Glm4MoeMTP"),
387
    "MedusaModel": ("medusa", "Medusa"),
388
    "Qwen3NextMTP": ("qwen3_next_mtp", "Qwen3NextMTP"),
389
390
391
    # Temporarily disabled.
    # # TODO(woosuk): Re-enable this once the MLP Speculator is supported in V1.
    # "MLPSpeculatorPreTrainedModel": ("mlp_speculator", "MLPSpeculator"),
392
}
393

394
_TRANSFORMERS_SUPPORTED_MODELS = {
395
396
397
    # Text generation models
    "SmolLM3ForCausalLM": ("transformers", "TransformersForCausalLM"),
    # Multimodal models
398
399
400
401
    "Emu3ForConditionalGeneration": ("transformers", "TransformersForMultimodalLM"),  # noqa: E501
}

_TRANSFORMERS_BACKEND_MODELS = {
402
    "TransformersForCausalLM": ("transformers", "TransformersForCausalLM"),
403
    "TransformersForMultimodalLM": ("transformers", "TransformersForMultimodalLM"),  # noqa: E501
404
    "TransformersMoEForCausalLM": ("transformers_moe", "TransformersMoEForCausalLM"),  # noqa: E501
405
406
407
    "TransformersMoEForMultimodalLM": (
        "transformers_moe",
        "TransformersMoEForMultimodalLM",
408
    ),
409
410
411
    "TransformersEmbeddingModel": (
        "transformers_pooling",
        "TransformersEmbeddingModel",
412
    ),
413
414
415
    "TransformersForSequenceClassification": (
        "transformers_pooling",
        "TransformersForSequenceClassification",
416
    ),
417
418
419
    "TransformersMoEForSequenceClassification": (
        "transformers_pooling",
        "TransformersMoEForSequenceClassification",
420
    ),
421
422
423
    "TransformersMoEEmbeddingModel": (
        "transformers_pooling",
        "TransformersMoEEmbeddingModel",
424
    ),
425
}
426

427
# Union of every architecture table above. With dict unpacking, a key that
# appears in several tables takes the value from the last table listed
# here, while keeping the position of its first occurrence.
_VLLM_MODELS = {
    **_TEXT_GENERATION_MODELS,
    **_EMBEDDING_MODELS,
    **_CROSS_ENCODER_MODELS,
    **_MULTIMODAL_MODELS,
    **_SPECULATIVE_DECODING_MODELS,
    **_TRANSFORMERS_SUPPORTED_MODELS,
    **_TRANSFORMERS_BACKEND_MODELS,
}

437
438
439
440
# This variable is used as the args for subprocess.run(). We
# can modify  this variable to alter the args if needed. e.g.
# when we use par format to pack things together, sys.executable
# might not be the target we want to run.
441
_SUBPROCESS_COMMAND = [sys.executable, "-m", "vllm.model_executor.models.registry"]
442

443
_PREVIOUSLY_SUPPORTED_MODELS = {
444
    "MotifForCausalLM": "0.10.2",
445
    "Phi3SmallForCausalLM": "0.9.2",
446
    "Phi4FlashForCausalLM": "0.10.2",
447
448
449
450
451
452
453
454
455
    # encoder-decoder models except whisper
    # have been removed for V0 deprecation.
    "BartModel": "0.10.2",
    "BartForConditionalGeneration": "0.10.2",
    "DonutForConditionalGeneration": "0.10.2",
    "Florence2ForConditionalGeneration": "0.10.2",
    "MBartForConditionalGeneration": "0.10.2",
    "MllamaForConditionalGeneration": "0.10.2",
}
456

457

458
459
@dataclass(frozen=True)
class _ModelInfo:
460
    architecture: str
461
    is_text_generation_model: bool
462
    is_pooling_model: bool
463
    default_pooling_type: str
464
    supports_cross_encoding: bool
465
    supports_multimodal: bool
466
    supports_multimodal_raw_input_only: bool
467
    supports_multimodal_encoder_tp_data: bool
468
    supports_pp: bool
469
470
    has_inner_state: bool
    is_attention_free: bool
471
    is_hybrid: bool
472
    has_noops: bool
473
    supports_transcription: bool
474
    supports_transcription_only: bool
475
    supports_v0_only: bool
476
477

    @staticmethod
478
    def from_model_cls(model: type[nn.Module]) -> "_ModelInfo":
479
        return _ModelInfo(
480
            architecture=model.__name__,
481
            is_text_generation_model=is_text_generation_model(model),
482
            is_pooling_model=is_pooling_model(model),
483
            default_pooling_type=get_default_pooling_type(model),
484
            supports_cross_encoding=supports_cross_encoding(model),
485
            supports_multimodal=supports_multimodal(model),
486
487
488
489
490
491
            supports_multimodal_raw_input_only=supports_multimodal_raw_input_only(
                model
            ),
            supports_multimodal_encoder_tp_data=supports_multimodal_encoder_tp_data(
                model
            ),
492
            supports_pp=supports_pp(model),
493
494
            has_inner_state=has_inner_state(model),
            is_attention_free=is_attention_free(model),
495
            is_hybrid=is_hybrid(model),
496
            supports_transcription=supports_transcription(model),
497
498
499
            supports_transcription_only=(
                supports_transcription(model) and model.supports_transcription_only
            ),
500
            supports_v0_only=supports_v0_only(model),
501
            has_noops=has_noops(model),
502
        )
503
504


505
506
507
508
class _BaseRegisteredModel(ABC):
    """Abstract handle for one registered model architecture.

    Implementations either wrap an already-imported class or defer the import;
    both expose the class's capability flags and the class itself.
    """

    @abstractmethod
    def inspect_model_cls(self) -> _ModelInfo:
        """Return the capability flags of the underlying model class."""
        raise NotImplementedError

    @abstractmethod
    def load_model_cls(self) -> type[nn.Module]:
        """Return the underlying model class, importing it if necessary."""
        raise NotImplementedError
513
514


515
516
517
518
519
520
521
@dataclass(frozen=True)
class _RegisteredModel(_BaseRegisteredModel):
    """
    Represents a model that has already been imported in the main process.

    The capability flags are computed once, when the entry is built from the
    class, so inspecting and loading are plain attribute reads.
    """

    # Capability flags computed from `model_cls` at construction time.
    interfaces: _ModelInfo
    # The already-imported model class.
    model_cls: type[nn.Module]

    @staticmethod
    def from_model_cls(model_cls: type[nn.Module]):
        info = _ModelInfo.from_model_cls(model_cls)
        return _RegisteredModel(interfaces=info, model_cls=model_cls)

    def inspect_model_cls(self) -> _ModelInfo:
        return self.interfaces

    def load_model_cls(self) -> type[nn.Module]:
        return self.model_cls


@dataclass(frozen=True)
class _LazyRegisteredModel(_BaseRegisteredModel):
    """
    Represents a model that has not been imported in the main process.

    `inspect_model_cls` avoids importing the model here. When the model's
    source file is known and unchanged, it reuses a JSON cache stored under
    `<VLLM_CACHE_ROOT>/modelinfos`. Otherwise it inspects the class in a
    subprocess.
    """

    # Dotted path of the module that defines the model class.
    module_name: str
    # Name of the model class inside `module_name`.
    class_name: str

    @staticmethod
    def _get_cache_dir() -> Path:
        return Path(envs.VLLM_CACHE_ROOT) / "modelinfos"

    def _get_cache_filename(self) -> str:
        cls_name = f"{self.module_name}-{self.class_name}".replace(".", "-")
        return f"{cls_name}.json"

    def _get_module_source_path(self) -> Path | None:
        """Return the source file whose hash keys this model's cache entry.

        Only modules that are direct children of this package map to a file:
        the sibling `<last dotted component>.py`. Any other module (e.g. an
        out-of-tree plugin such as `my_plugin.llama`) yields ``None``, so it
        is never keyed on an unrelated in-tree file that merely shares its
        basename. Such a key would leave its cache stale after the plugin
        changes.
        """
        package, _, stem = self.module_name.rpartition(".")
        # NOTE(review): assumes built-in lazy models are registered as
        # `<this package>.<file stem>`; other layouts simply skip caching.
        if not package or package != __package__:
            return None
        model_path = Path(__file__).parent / f"{stem}.py"
        return model_path if model_path.exists() else None

    def _load_modelinfo_from_cache(self, module_hash: str) -> _ModelInfo | None:
        """Return the cached `_ModelInfo` if it was saved for `module_hash`.

        Returns ``None`` when the cache file is missing, was written for a
        different hash, or cannot be read/parsed. Read/parse errors are
        logged instead of raised because the cache is best-effort.
        """
        try:
            try:
                modelinfo_path = self._get_cache_dir() / self._get_cache_filename()
                with open(modelinfo_path, encoding="utf-8") as file:
                    mi_dict = json.load(file)
            except FileNotFoundError:
                logger.debug(
                    ("Cached model info file for class %s.%s not found"),
                    self.module_name,
                    self.class_name,
                )
                return None

            if mi_dict["hash"] != module_hash:
                logger.debug(
                    ("Cached model info file for class %s.%s is stale"),
                    self.module_name,
                    self.class_name,
                )
                return None

            # file not changed, use cached _ModelInfo properties
            return _ModelInfo(**mi_dict["modelinfo"])
        except Exception:
            logger.exception(
                ("Cached model info for class %s.%s error. "),
                self.module_name,
                self.class_name,
            )
            return None

    def _save_modelinfo_to_cache(self, mi: _ModelInfo, module_hash: str) -> None:
        """Persist `mi` with `module_hash` as JSON; failures are only logged."""
        from vllm.model_executor.model_loader.weight_utils import atomic_writer

        try:
            modelinfo_dict = {
                "hash": module_hash,
                "modelinfo": asdict(mi),
            }
            cache_dir = self._get_cache_dir()
            cache_dir.mkdir(parents=True, exist_ok=True)
            modelinfo_path = cache_dir / self._get_cache_filename()
            with atomic_writer(modelinfo_path, encoding="utf-8") as f:
                json.dump(modelinfo_dict, f, indent=2)
        except Exception:
            logger.exception("Error saving model info cache.")

    @logtime(logger=logger, msg="Registry inspect model class")
    def inspect_model_cls(self) -> _ModelInfo:
        """Return the model's `_ModelInfo` without importing it in-process.

        The cache key is a digest of the model's own source file only. Edits
        to other files the class depends on do not invalidate the entry.
        """
        module_hash = None
        model_path = self._get_module_source_path()

        if model_path is not None:
            # The digest is only a cache key, not a security primitive;
            # `usedforsecurity=False` keeps md5 usable on FIPS-enabled builds,
            # where a plain `hashlib.md5()` call raises.
            module_hash = hashlib.md5(
                model_path.read_bytes(), usedforsecurity=False
            ).hexdigest()

            mi = self._load_modelinfo_from_cache(module_hash)
            if mi is not None:
                logger.debug(
                    ("Loaded model info for class %s.%s from cache"),
                    self.module_name,
                    self.class_name,
                )
                return mi
            else:
                logger.debug(
                    ("Cache model info for class %s.%s miss. Loading model instead."),
                    self.module_name,
                    self.class_name,
                )

        # Performed in another process to avoid initializing CUDA
        mi = _run_in_subprocess(
            lambda: _ModelInfo.from_model_cls(self.load_model_cls())
        )
        logger.debug(
            "Loaded model info for class %s.%s", self.module_name, self.class_name
        )

        # Only cache when there is a source file to key the entry on.
        if module_hash is not None:
            self._save_modelinfo_to_cache(mi, module_hash)

        return mi

    def load_model_cls(self) -> type[nn.Module]:
        """Import `module_name` and return its `class_name` attribute."""
        mod = importlib.import_module(self.module_name)
        return getattr(mod, self.class_name)


@lru_cache(maxsize=128)
def _try_load_model_cls(
    model_arch: str,
    model: _BaseRegisteredModel,
) -> Optional[type[nn.Module]]:
    """Return the class behind ``model``, or ``None`` if loading it fails.

    ``current_platform.verify_model_arch`` runs outside the ``try``, so
    anything it raises reaches the caller (and, per ``lru_cache`` semantics,
    is not memoized). Errors from ``load_model_cls`` are logged and turned
    into ``None``, which *is* memoized for the same arguments.
    """
    from vllm.platforms import current_platform

    current_platform.verify_model_arch(model_arch)

    try:
        loaded = model.load_model_cls()
    except Exception:
        logger.exception("Error in loading model architecture '%s'", model_arch)
        loaded = None
    return loaded
660
661


662
663
664
665
666
667
668
669
@lru_cache(maxsize=128)
def _try_inspect_model_cls(
    model_arch: str,
    model: _BaseRegisteredModel,
) -> Optional[_ModelInfo]:
    """Return ``model``'s capability flags, or ``None`` if inspection fails.

    Any exception from ``inspect_model_cls`` is logged with its traceback and
    swallowed. The resulting ``None`` is memoized by ``lru_cache`` for the
    same ``(model_arch, model)`` arguments.
    """
    try:
        info = model.inspect_model_cls()
    except Exception:
        logger.exception("Error in inspecting model architecture '%s'", model_arch)
        return None
    return info
672
673


674
675
676
@dataclass
class _ModelRegistry:
    """Registry mapping architecture names to registered model entries.

    Resolves the architecture names from a model config to a model class
    (or its inspected `_ModelInfo`). It consults the registered models and,
    depending on `model_config.model_impl`, the Transformers backend.
    """

    # Keyed by model_arch
    models: dict[str, _BaseRegisteredModel] = field(default_factory=dict)

    def get_supported_archs(self) -> Set[str]:
        """Return a live view of all registered architecture names."""
        return self.models.keys()

    def register_model(
        self,
        model_arch: str,
        model_cls: Union[type[nn.Module], str],
    ) -> None:
        """
        Register an external model to be used in vLLM.

        `model_cls` can be either:

        - A [`torch.nn.Module`][] class directly referencing the model.
        - A string in the format `<module>:<class>` which can be used to
          lazily import the model. This is useful to avoid initializing CUDA
          when importing the model and thus the related error
          `RuntimeError: Cannot re-initialize CUDA in forked subprocess`.

        An existing registration under the same `model_arch` is overwritten
        (with a warning).

        Raises:
            TypeError: If `model_arch` is not a string, or `model_cls` is
                neither a string nor an `nn.Module` subclass.
            ValueError: If `model_cls` is a string that does not contain
                exactly one `:` separator.
        """
        if not isinstance(model_arch, str):
            msg = f"`model_arch` should be a string, not a {type(model_arch)}"
            raise TypeError(msg)

        if model_arch in self.models:
            logger.warning(
                "Model architecture %s is already registered, and will be "
                "overwritten by the new model class %s.",
                model_arch,
                model_cls,
            )

        if isinstance(model_cls, str):
            split_str = model_cls.split(":")
            if len(split_str) != 2:
                msg = "Expected a string in the format `<module>:<class>`"
                raise ValueError(msg)

            model = _LazyRegisteredModel(*split_str)
        elif isinstance(model_cls, type) and issubclass(model_cls, nn.Module):
            model = _RegisteredModel.from_model_cls(model_cls)
        else:
            # Report the type of the rejected argument (`model_cls`); the
            # type of `model_arch` was already validated above.
            msg = (
                "`model_cls` should be a string or PyTorch model class, "
                f"not a {type(model_cls)}"
            )
            raise TypeError(msg)

        self.models[model_arch] = model

    def _raise_for_unsupported(self, architectures: list[str]):
        """Always raise a `ValueError` explaining why none of `architectures`
        could be resolved.

        The message distinguishes three cases:
        - an architecture is registered but failed to load or inspect;
        - an architecture was supported by an older vLLM version;
        - the architectures are not supported at all.
        """
        all_supported_archs = self.get_supported_archs()

        # A registered architecture only ends up here if loading/inspecting
        # it failed (the failure was already logged by the helpers).
        if any(arch in all_supported_archs for arch in architectures):
            raise ValueError(
                f"Model architectures {architectures} failed "
                "to be inspected. Please check the logs for more details."
            )

        for arch in architectures:
            if arch in _PREVIOUSLY_SUPPORTED_MODELS:
                previous_version = _PREVIOUSLY_SUPPORTED_MODELS[arch]

                raise ValueError(
                    f"Model architecture {arch} was supported in vLLM until "
                    f"v{previous_version}, and is not supported anymore. "
                    "Please use an older version of vLLM if you want to "
                    "use this model architecture."
                )

        raise ValueError(
            f"Model architectures {architectures} are not supported for now. "
            f"Supported architectures: {all_supported_archs}"
        )

    def _try_load_model_cls(self, model_arch: str) -> Optional[type[nn.Module]]:
        """Load the class registered for `model_arch`; `None` if unregistered
        or if loading failed."""
        if model_arch not in self.models:
            return None

        return _try_load_model_cls(model_arch, self.models[model_arch])

    def _try_inspect_model_cls(self, model_arch: str) -> Optional[_ModelInfo]:
        """Inspect the model registered for `model_arch`; `None` if
        unregistered or if inspection failed."""
        if model_arch not in self.models:
            return None

        return _try_inspect_model_cls(model_arch, self.models[model_arch])

    def _try_resolve_transformers(
        self,
        architecture: str,
        model_config: ModelConfig,
    ) -> Optional[str]:
        """Try to map `architecture` onto the Transformers backend.

        `architecture` is returned unchanged if it is listed in
        `_TRANSFORMERS_BACKEND_MODELS`. Otherwise the model class is looked up
        in the `transformers` package, or via the `AutoModel*` entries of the
        HF config's `auto_map`. If that class reports
        `is_backend_compatible()`, the result of
        `model_config._get_transformers_backend_cls()` is returned.

        Returns `None` when no usable class is found, unless
        `model_config.model_impl == "transformers"`.

        Raises:
            ValueError: If `model_config.model_impl == "transformers"` and
                the model class cannot be found or is not compatible.
        """
        if architecture in _TRANSFORMERS_BACKEND_MODELS:
            return architecture

        auto_map: dict[str, str] = (
            getattr(model_config.hf_config, "auto_map", None) or dict()
        )

        # Make sure that config class is always initialized before model class,
        # otherwise the model class won't be able to access the config class,
        # the expected auto_map should have correct order like:
        # "auto_map": {
        #     "AutoConfig": "<your-repo-name>--<config-name>",
        #     "AutoModel": "<your-repo-name>--<config-name>",
        #     "AutoModelFor<Task>": "<your-repo-name>--<config-name>",
        # },
        for prefix in ("AutoConfig", "AutoModel"):
            for name, module in auto_map.items():
                if name.startswith(prefix):
                    try_get_class_from_dynamic_module(
                        module,
                        model_config.model,
                        revision=model_config.revision,
                        warn_on_fail=False,
                    )

        model_module = getattr(transformers, architecture, None)

        if model_module is None:
            for name, module in auto_map.items():
                if name.startswith("AutoModel"):
                    model_module = try_get_class_from_dynamic_module(
                        module,
                        model_config.model,
                        revision=model_config.revision,
                        warn_on_fail=True,
                    )
                    if model_module is not None:
                        break
            else:
                # No `AutoModel*` entry yielded a class.
                if model_config.model_impl != "transformers":
                    return None

                raise ValueError(
                    f"Cannot find model module. {architecture!r} is not a "
                    "registered model in the Transformers library (only "
                    "relevant if the model is meant to be in Transformers) "
                    "and 'AutoModel' is not present in the model config's "
                    "'auto_map' (relevant if the model is custom)."
                )

        if not model_module.is_backend_compatible():
            if model_config.model_impl != "transformers":
                return None

            raise ValueError(
                f"The Transformers implementation of {architecture!r} "
                "is not compatible with vLLM."
            )

        return model_config._get_transformers_backend_cls()

    def _normalize_arch(
        self,
        architecture: str,
        model_config: ModelConfig,
    ) -> str:
        """Map `architecture` to a registered base architecture name if possible.

        Registered names are returned as-is. Otherwise, if
        `try_match_architecture_defaults` finds a matching suffix, that suffix
        is swapped for each suffix yielded by `iter_architecture_defaults`.
        The first candidate that is registered is returned. If nothing
        matches, `architecture` is returned unchanged.
        """
        if architecture in self.models:
            return architecture

        # This may be called in order to resolve runner_type and convert_type
        # in the first place, in which case we consider the default match
        match = try_match_architecture_defaults(
            architecture,
            runner_type=getattr(model_config, "runner_type", None),
            convert_type=getattr(model_config, "convert_type", None),
        )
        if match:
            suffix, _ = match

            # Get the name of the base model to convert
            for repl_suffix, _ in iter_architecture_defaults():
                base_arch = architecture.replace(suffix, repl_suffix)
                if base_arch in self.models:
                    return base_arch

        return architecture

    def inspect_model_cls(
        self,
        architectures: Union[str, list[str]],
        model_config: ModelConfig,
    ) -> tuple[_ModelInfo, str]:
        """Return `(model_info, arch)` for the first resolvable architecture.

        Resolution mirrors `resolve_model_cls`. The returned `arch` is the
        resolved Transformers-backend name, `"Terratorch"`, or the entry of
        `architectures` that matched (not its normalized form).

        Raises:
            ValueError: If `architectures` is empty or none of them can be
                resolved.
        """
        if isinstance(architectures, str):
            architectures = [architectures]
        if not architectures:
            raise ValueError("No model architectures are specified")

        # Require transformers impl
        if model_config.model_impl == "transformers":
            arch = self._try_resolve_transformers(architectures[0], model_config)
            if arch is not None:
                model_info = self._try_inspect_model_cls(arch)
                if model_info is not None:
                    return (model_info, arch)
        elif model_config.model_impl == "terratorch":
            # Previously `(None, "Terratorch")` was returned when inspection
            # failed; fall through instead so callers get a clear error
            # (consistent with `resolve_model_cls`).
            arch = "Terratorch"
            model_info = self._try_inspect_model_cls(arch)
            if model_info is not None:
                return (model_info, arch)

        # Fallback to transformers impl (after resolving convert_type)
        if (
            all(arch not in self.models for arch in architectures)
            and model_config.model_impl == "auto"
            and getattr(model_config, "convert_type", "none") == "none"
        ):
            arch = self._try_resolve_transformers(architectures[0], model_config)
            if arch is not None:
                model_info = self._try_inspect_model_cls(arch)
                if model_info is not None:
                    return (model_info, arch)

        for arch in architectures:
            normalized_arch = self._normalize_arch(arch, model_config)
            model_info = self._try_inspect_model_cls(normalized_arch)
            if model_info is not None:
                return (model_info, arch)

        # Fallback to transformers impl (before resolving runner_type)
        if (
            all(arch not in self.models for arch in architectures)
            and model_config.model_impl == "auto"
        ):
            arch = self._try_resolve_transformers(architectures[0], model_config)
            if arch is not None:
                model_info = self._try_inspect_model_cls(arch)
                if model_info is not None:
                    return (model_info, arch)

        return self._raise_for_unsupported(architectures)

    def resolve_model_cls(
        self,
        architectures: Union[str, list[str]],
        model_config: ModelConfig,
    ) -> tuple[type[nn.Module], str]:
        """Return `(model_cls, arch)` for the first loadable architecture.

        The following are tried in order:
        1. The Transformers backend, when `model_impl == "transformers"`.
        2. The `"Terratorch"` entry, when `model_impl == "terratorch"`.
        3. The Transformers backend, if no architecture is registered,
           `model_impl == "auto"` and `convert_type` is `"none"`.
        4. Each architecture after `_normalize_arch`.
        5. The Transformers backend, if no architecture is registered and
           `model_impl == "auto"`.

        Raises:
            ValueError: If `architectures` is empty or none of them can be
                resolved.
        """
        if isinstance(architectures, str):
            architectures = [architectures]
        if not architectures:
            raise ValueError("No model architectures are specified")

        # Require transformers impl
        if model_config.model_impl == "transformers":
            arch = self._try_resolve_transformers(architectures[0], model_config)
            if arch is not None:
                model_cls = self._try_load_model_cls(arch)
                if model_cls is not None:
                    return (model_cls, arch)
        elif model_config.model_impl == "terratorch":
            arch = "Terratorch"
            model_cls = self._try_load_model_cls(arch)
            if model_cls is not None:
                return (model_cls, arch)

        # Fallback to transformers impl (after resolving convert_type)
        if (
            all(arch not in self.models for arch in architectures)
            and model_config.model_impl == "auto"
            and getattr(model_config, "convert_type", "none") == "none"
        ):
            arch = self._try_resolve_transformers(architectures[0], model_config)
            if arch is not None:
                model_cls = self._try_load_model_cls(arch)
                if model_cls is not None:
                    return (model_cls, arch)

        for arch in architectures:
            normalized_arch = self._normalize_arch(arch, model_config)
            model_cls = self._try_load_model_cls(normalized_arch)
            if model_cls is not None:
                return (model_cls, arch)

        # Fallback to transformers impl (before resolving runner_type)
        if (
            all(arch not in self.models for arch in architectures)
            and model_config.model_impl == "auto"
        ):
            arch = self._try_resolve_transformers(architectures[0], model_config)
            if arch is not None:
                model_cls = self._try_load_model_cls(arch)
                if model_cls is not None:
                    return (model_cls, arch)

        return self._raise_for_unsupported(architectures)

    def is_text_generation_model(
        self,
        architectures: Union[str, list[str]],
        model_config: ModelConfig,
    ) -> bool:
        """Whether the resolved model is a text generation model."""
        model_cls, _ = self.inspect_model_cls(architectures, model_config)
        return model_cls.is_text_generation_model

    def is_pooling_model(
        self,
        architectures: Union[str, list[str]],
        model_config: ModelConfig,
    ) -> bool:
        """Whether the resolved model is a pooling model."""
        model_cls, _ = self.inspect_model_cls(architectures, model_config)
        return model_cls.is_pooling_model

    def is_cross_encoder_model(
        self,
        architectures: Union[str, list[str]],
        model_config: ModelConfig,
    ) -> bool:
        """Whether the resolved model supports cross-encoding."""
        model_cls, _ = self.inspect_model_cls(architectures, model_config)
        return model_cls.supports_cross_encoding

    def is_multimodal_model(
        self,
        architectures: Union[str, list[str]],
        model_config: ModelConfig,
    ) -> bool:
        """Whether the resolved model supports multimodal inputs."""
        model_cls, _ = self.inspect_model_cls(architectures, model_config)
        return model_cls.supports_multimodal

    def is_multimodal_raw_input_only_model(
        self,
        architectures: Union[str, list[str]],
        model_config: ModelConfig,
    ) -> bool:
        """Whether the resolved model's `supports_multimodal_raw_input_only`
        flag is set."""
        model_cls, _ = self.inspect_model_cls(architectures, model_config)
        return model_cls.supports_multimodal_raw_input_only

    def is_pp_supported_model(
        self,
        architectures: Union[str, list[str]],
        model_config: ModelConfig,
    ) -> bool:
        """Whether the resolved model's `supports_pp` flag is set."""
        model_cls, _ = self.inspect_model_cls(architectures, model_config)
        return model_cls.supports_pp

    def model_has_inner_state(
        self,
        architectures: Union[str, list[str]],
        model_config: ModelConfig,
    ) -> bool:
        """Whether the resolved model's `has_inner_state` flag is set."""
        model_cls, _ = self.inspect_model_cls(architectures, model_config)
        return model_cls.has_inner_state

    def is_attention_free_model(
        self,
        architectures: Union[str, list[str]],
        model_config: ModelConfig,
    ) -> bool:
        """Whether the resolved model's `is_attention_free` flag is set."""
        model_cls, _ = self.inspect_model_cls(architectures, model_config)
        return model_cls.is_attention_free

    def is_hybrid_model(
        self,
        architectures: Union[str, list[str]],
        model_config: ModelConfig,
    ) -> bool:
        """Whether the resolved model's `is_hybrid` flag is set."""
        model_cls, _ = self.inspect_model_cls(architectures, model_config)
        return model_cls.is_hybrid

    def is_noops_model(
        self,
        architectures: Union[str, list[str]],
        model_config: ModelConfig,
    ) -> bool:
        """Whether the resolved model's `has_noops` flag is set."""
        model_cls, _ = self.inspect_model_cls(architectures, model_config)
        return model_cls.has_noops

    def is_transcription_model(
        self,
        architectures: Union[str, list[str]],
        model_config: ModelConfig,
    ) -> bool:
        """Whether the resolved model's `supports_transcription` flag is set."""
        model_cls, _ = self.inspect_model_cls(architectures, model_config)
        return model_cls.supports_transcription

    def is_transcription_only_model(
        self,
        architectures: Union[str, list[str]],
        model_config: ModelConfig,
    ) -> bool:
        """Whether the resolved model's `supports_transcription_only` flag is
        set."""
        model_cls, _ = self.inspect_model_cls(architectures, model_config)
        return model_cls.supports_transcription_only

    def is_v1_compatible(
        self,
        architectures: Union[str, list[str]],
        model_config: ModelConfig,
    ) -> bool:
        """Whether the resolved model is not restricted to V0
        (i.e. `supports_v0_only` is unset)."""
        model_cls, _ = self.inspect_model_cls(architectures, model_config)
        return not model_cls.supports_v0_only

1067

1068
1069
1070
1071
1072
1073
1074
1075
1076
# Global registry, pre-populated with every built-in architecture from
# `_VLLM_MODELS` as a lazily imported model living under
# `vllm.model_executor.models`.
ModelRegistry = _ModelRegistry(
    {
        arch_name: _LazyRegisteredModel(
            module_name=f"vllm.model_executor.models.{rel_module}",
            class_name=class_name,
        )
        for arch_name, (rel_module, class_name) in _VLLM_MODELS.items()
    }
)

# Return type of the callable executed by `_run_in_subprocess`.
_T = TypeVar("_T")


def _run_in_subprocess(fn: Callable[[], _T]) -> _T:
    """Execute `fn` in a separate Python process and return its result.

    `fn` and a temporary output path are serialized with `cloudpickle` (so
    lambdas work) and passed on stdin to `_SUBPROCESS_COMMAND`. The child is
    expected to write the pickled result to that path; presumably this is
    done by `_run` in this module, which is not verifiable from here.

    Raises:
        RuntimeError: If the subprocess exits with a non-zero status. The
            child's stderr is included in the message.
    """
    # `cloudpickle` allows pickling lambda functions directly
    import cloudpickle

    # NOTE: We use a temporary directory instead of a temporary file to avoid
    # issues like https://stackoverflow.com/questions/23212435/permission-denied-to-write-to-my-temporary-file
    with tempfile.TemporaryDirectory() as tempdir:
        output_filepath = os.path.join(tempdir, "registry_output.tmp")
        input_bytes = cloudpickle.dumps((fn, output_filepath))

        # cannot use `sys.executable __file__` here because the script
        # contains relative imports
        returned = subprocess.run(
            _SUBPROCESS_COMMAND, input=input_bytes, capture_output=True
        )

        # check if the subprocess is successful
        try:
            returned.check_returncode()
        except subprocess.CalledProcessError as e:
            # wrap raised exception to provide more information; decode
            # leniently so non-UTF-8 stderr cannot mask the real failure
            # with a UnicodeDecodeError
            stderr_text = returned.stderr.decode(errors="replace")
            raise RuntimeError(
                f"Error raised in subprocess:\n{stderr_text}"
            ) from e

        with open(output_filepath, "rb") as f:
            return pickle.load(f)


def _run() -> None:
    """Child-process worker.

    Loads vLLM's general plugins, then reads a pickled `(fn, output_file)`
    pair from stdin. It calls `fn()` and writes the pickled return value to
    `output_file`.
    """
    # Setup plugins
    from vllm.plugins import load_general_plugins

    load_general_plugins()

    # NOTE(review): `pickle.loads` on stdin is only safe because the payload
    # presumably comes from `_run_in_subprocess` in the parent process; never
    # pipe untrusted data into this entry point.
    payload = sys.stdin.buffer.read()
    fn, output_file = pickle.loads(payload)

    # Compute before opening the file so a failing `fn` leaves no output.
    outcome = fn()
    serialized = pickle.dumps(outcome)

    with open(output_file, "wb") as sink:
        sink.write(serialized)
1123
1124
1125


if __name__ == "__main__":
1126
    _run()