registry.py 43.4 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
4
5
6
"""
Whenever you add an architecture to this file, please also update
`tests/models/registry.py` with example HuggingFace models for it.
"""
7

8
import hashlib
9
import importlib
10
import json
11
import os
12
import pickle
13
14
import subprocess
import sys
15
import tempfile
16
from abc import ABC, abstractmethod
17
from collections.abc import Callable, Set
18
from dataclasses import asdict, dataclass, field
19
from functools import lru_cache
20
from pathlib import Path
21
from typing import TypeVar
22
23

import torch.nn as nn
24
import transformers
25

26
from vllm import envs
27
28
29
30
31
from vllm.config import (
    ModelConfig,
    iter_architecture_defaults,
    try_match_architecture_defaults,
)
32
from vllm.logger import init_logger
33
from vllm.logging_utils import logtime
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
from vllm.transformers_utils.dynamic_module import try_get_class_from_dynamic_module

from .interfaces import (
    has_inner_state,
    has_noops,
    is_attention_free,
    is_hybrid,
    supports_cross_encoding,
    supports_multimodal,
    supports_multimodal_encoder_tp_data,
    supports_multimodal_raw_input_only,
    supports_pp,
    supports_transcription,
    supports_v0_only,
)
from .interfaces_base import (
    get_default_pooling_type,
    is_pooling_model,
    is_text_generation_model,
)
54
55
56

logger = init_logger(__name__)

57
58
# Text-generation architectures: architecture name -> (module, class name).
# Insertion order is kept stable on purpose (other tables iterate this one).
_TEXT_GENERATION_MODELS = {
    # [Decoder-only]
    "ApertusForCausalLM": ("apertus", "ApertusForCausalLM"),
    "AquilaModel": ("llama", "LlamaForCausalLM"),
    "AquilaForCausalLM": ("llama", "LlamaForCausalLM"),  # AquilaChat2
    "ArceeForCausalLM": ("arcee", "ArceeForCausalLM"),
    "ArcticForCausalLM": ("arctic", "ArcticForCausalLM"),
    "MiniMaxForCausalLM": ("minimax_text_01", "MiniMaxText01ForCausalLM"),
    "MiniMaxText01ForCausalLM": ("minimax_text_01", "MiniMaxText01ForCausalLM"),
    "MiniMaxM1ForCausalLM": ("minimax_text_01", "MiniMaxText01ForCausalLM"),
    # baichuan-7b, upper case 'C' in the class name
    "BaiChuanForCausalLM": ("baichuan", "BaiChuanForCausalLM"),
    # baichuan-13b, lower case 'c' in the class name
    "BaichuanForCausalLM": ("baichuan", "BaichuanForCausalLM"),
    "BailingMoeForCausalLM": ("bailing_moe", "BailingMoeForCausalLM"),
    "BailingMoeV2ForCausalLM": ("bailing_moe", "BailingMoeV2ForCausalLM"),
    "BambaForCausalLM": ("bamba", "BambaForCausalLM"),
    "BloomForCausalLM": ("bloom", "BloomForCausalLM"),
    "ChatGLMModel": ("chatglm", "ChatGLMForCausalLM"),
    "ChatGLMForConditionalGeneration": ("chatglm", "ChatGLMForCausalLM"),
    "CohereForCausalLM": ("commandr", "CohereForCausalLM"),
    "Cohere2ForCausalLM": ("commandr", "CohereForCausalLM"),
    "CwmForCausalLM": ("llama", "LlamaForCausalLM"),
    "DbrxForCausalLM": ("dbrx", "DbrxForCausalLM"),
    "DeciLMForCausalLM": ("nemotron_nas", "DeciLMForCausalLM"),
    "DeepseekForCausalLM": ("deepseek", "DeepseekForCausalLM"),
    "DeepseekV2ForCausalLM": ("deepseek_v2", "DeepseekV2ForCausalLM"),
    "DeepseekV3ForCausalLM": ("deepseek_v2", "DeepseekV3ForCausalLM"),
    "DeepseekV32ForCausalLM": ("deepseek_v2", "DeepseekV3ForCausalLM"),
    "Dots1ForCausalLM": ("dots1", "Dots1ForCausalLM"),
    "Ernie4_5ForCausalLM": ("ernie45", "Ernie4_5ForCausalLM"),
    "Ernie4_5_MoeForCausalLM": ("ernie45_moe", "Ernie4_5_MoeForCausalLM"),
    "ExaoneForCausalLM": ("exaone", "ExaoneForCausalLM"),
    "Exaone4ForCausalLM": ("exaone4", "Exaone4ForCausalLM"),
    "FalconForCausalLM": ("falcon", "FalconForCausalLM"),
    "Fairseq2LlamaForCausalLM": ("fairseq2_llama", "Fairseq2LlamaForCausalLM"),
    "FlexOlmoForCausalLM": ("flex_olmo", "FlexOlmoForCausalLM"),
    "GemmaForCausalLM": ("gemma", "GemmaForCausalLM"),
    "Gemma2ForCausalLM": ("gemma2", "Gemma2ForCausalLM"),
    "Gemma3ForCausalLM": ("gemma3", "Gemma3ForCausalLM"),
    "Gemma3nForCausalLM": ("gemma3n", "Gemma3nForCausalLM"),
    "Qwen3NextForCausalLM": ("qwen3_next", "Qwen3NextForCausalLM"),
    "GlmForCausalLM": ("glm", "GlmForCausalLM"),
    "Glm4ForCausalLM": ("glm4", "Glm4ForCausalLM"),
    "Glm4MoeForCausalLM": ("glm4_moe", "Glm4MoeForCausalLM"),
    "GptOssForCausalLM": ("gpt_oss", "GptOssForCausalLM"),
    "GPT2LMHeadModel": ("gpt2", "GPT2LMHeadModel"),
    "GPTBigCodeForCausalLM": ("gpt_bigcode", "GPTBigCodeForCausalLM"),
    "GPTJForCausalLM": ("gpt_j", "GPTJForCausalLM"),
    "GPTNeoXForCausalLM": ("gpt_neox", "GPTNeoXForCausalLM"),
    "GraniteForCausalLM": ("granite", "GraniteForCausalLM"),
    "GraniteMoeForCausalLM": ("granitemoe", "GraniteMoeForCausalLM"),
    "GraniteMoeHybridForCausalLM": (
        "granitemoehybrid",
        "GraniteMoeHybridForCausalLM",
    ),
    "GraniteMoeSharedForCausalLM": (
        "granitemoeshared",
        "GraniteMoeSharedForCausalLM",
    ),
    "GritLM": ("gritlm", "GritLM"),
    "Grok1ModelForCausalLM": ("grok1", "Grok1ForCausalLM"),
    "HunYuanMoEV1ForCausalLM": ("hunyuan_v1", "HunYuanMoEV1ForCausalLM"),
    "HunYuanDenseV1ForCausalLM": ("hunyuan_v1", "HunYuanDenseV1ForCausalLM"),
    # NOTE(review): a vision-language model listed among text-only models;
    # confirm this placement is intentional before moving it.
    "HCXVisionForCausalLM": ("hyperclovax_vision", "HCXVisionForCausalLM"),
    "InternLMForCausalLM": ("llama", "LlamaForCausalLM"),
    "InternLM2ForCausalLM": ("internlm2", "InternLM2ForCausalLM"),
    "InternLM2VEForCausalLM": ("internlm2_ve", "InternLM2VEForCausalLM"),
    "InternLM3ForCausalLM": ("llama", "LlamaForCausalLM"),
    "JAISLMHeadModel": ("jais", "JAISLMHeadModel"),
    "JambaForCausalLM": ("jamba", "JambaForCausalLM"),
    "Lfm2ForCausalLM": ("lfm2", "Lfm2ForCausalLM"),
    "Lfm2MoeForCausalLM": ("lfm2_moe", "Lfm2MoeForCausalLM"),
    "LlamaForCausalLM": ("llama", "LlamaForCausalLM"),
    "Llama4ForCausalLM": ("llama4", "Llama4ForCausalLM"),
    # For decapoda-research/llama-*
    "LLaMAForCausalLM": ("llama", "LlamaForCausalLM"),
    "LongcatFlashForCausalLM": ("longcat_flash", "LongcatFlashForCausalLM"),
    "MambaForCausalLM": ("mamba", "MambaForCausalLM"),
    "FalconMambaForCausalLM": ("mamba", "MambaForCausalLM"),
    "FalconH1ForCausalLM": ("falcon_h1", "FalconH1ForCausalLM"),
    "Mamba2ForCausalLM": ("mamba2", "Mamba2ForCausalLM"),
    "MiniCPMForCausalLM": ("minicpm", "MiniCPMForCausalLM"),
    "MiniCPM3ForCausalLM": ("minicpm3", "MiniCPM3ForCausalLM"),
    "MistralForCausalLM": ("llama", "LlamaForCausalLM"),
    "MixtralForCausalLM": ("mixtral", "MixtralForCausalLM"),
    # transformers' MPT class name is spelled with lower case letters
    "MptForCausalLM": ("mpt", "MPTForCausalLM"),
    "MPTForCausalLM": ("mpt", "MPTForCausalLM"),
    "MiMoForCausalLM": ("mimo", "MiMoForCausalLM"),
    "NemotronForCausalLM": ("nemotron", "NemotronForCausalLM"),
    "NemotronHForCausalLM": ("nemotron_h", "NemotronHForCausalLM"),
    "OlmoForCausalLM": ("olmo", "OlmoForCausalLM"),
    "Olmo2ForCausalLM": ("olmo2", "Olmo2ForCausalLM"),
    "Olmo3ForCausalLM": ("olmo2", "Olmo2ForCausalLM"),
    "OlmoeForCausalLM": ("olmoe", "OlmoeForCausalLM"),
    "OPTForCausalLM": ("opt", "OPTForCausalLM"),
    "OrionForCausalLM": ("orion", "OrionForCausalLM"),
    "PersimmonForCausalLM": ("persimmon", "PersimmonForCausalLM"),
    "PhiForCausalLM": ("phi", "PhiForCausalLM"),
    "Phi3ForCausalLM": ("phi3", "Phi3ForCausalLM"),
    "PhiMoEForCausalLM": ("phimoe", "PhiMoEForCausalLM"),
    "Plamo2ForCausalLM": ("plamo2", "Plamo2ForCausalLM"),
    "QWenLMHeadModel": ("qwen", "QWenLMHeadModel"),
    "Qwen2ForCausalLM": ("qwen2", "Qwen2ForCausalLM"),
    "Qwen2MoeForCausalLM": ("qwen2_moe", "Qwen2MoeForCausalLM"),
    "Qwen3ForCausalLM": ("qwen3", "Qwen3ForCausalLM"),
    "Qwen3MoeForCausalLM": ("qwen3_moe", "Qwen3MoeForCausalLM"),
    "RWForCausalLM": ("falcon", "FalconForCausalLM"),
    "SeedOssForCausalLM": ("seed_oss", "SeedOssForCausalLM"),
    "Step3TextForCausalLM": ("step3_text", "Step3TextForCausalLM"),
    "StableLMEpochForCausalLM": ("stablelm", "StablelmForCausalLM"),
    "StableLmForCausalLM": ("stablelm", "StablelmForCausalLM"),
    "Starcoder2ForCausalLM": ("starcoder2", "Starcoder2ForCausalLM"),
    "SolarForCausalLM": ("solar", "SolarForCausalLM"),
    "TeleChat2ForCausalLM": ("telechat2", "TeleChat2ForCausalLM"),
    "TeleFLMForCausalLM": ("teleflm", "TeleFLMForCausalLM"),
    "XverseForCausalLM": ("llama", "LlamaForCausalLM"),
    "Zamba2ForCausalLM": ("zamba2", "Zamba2ForCausalLM"),
}

# Pooling / embedding architectures: architecture name -> (module, class name).
_EMBEDDING_MODELS = {
    # [Text-only]
    "BertModel": ("bert", "BertEmbeddingModel"),
    "BertSpladeSparseEmbeddingModel": ("bert", "BertSpladeSparseEmbeddingModel"),
    "DeciLMForCausalLM": ("nemotron_nas", "DeciLMForCausalLM"),
    "Gemma2Model": ("gemma2", "Gemma2ForCausalLM"),
    "Gemma3TextModel": ("gemma3", "Gemma3Model"),
    "GlmForCausalLM": ("glm", "GlmForCausalLM"),
    "GPT2ForSequenceClassification": ("gpt2", "GPT2ForSequenceClassification"),
    "GritLM": ("gritlm", "GritLM"),
    "GteModel": ("bert_with_rope", "SnowflakeGteNewModel"),
    "GteNewModel": ("bert_with_rope", "GteNewModel"),
    "InternLM2ForRewardModel": ("internlm2", "InternLM2ForRewardModel"),
    "JambaForSequenceClassification": (
        "jamba",
        "JambaForSequenceClassification",
    ),
    "LlamaModel": ("llama", "LlamaForCausalLM"),
    # Several text-generation architectures (Aquila, InternLM, Mistral, ...)
    # are served by the Llama implementation; register every one of them here.
    **{
        hf_arch: (module, cls_name)
        for hf_arch, (module, cls_name) in _TEXT_GENERATION_MODELS.items()
        if cls_name == "LlamaForCausalLM"
    },
    "MistralModel": ("llama", "LlamaForCausalLM"),
    "ModernBertModel": ("modernbert", "ModernBertModel"),
    "NomicBertModel": ("bert_with_rope", "NomicBertModel"),
    "Phi3ForCausalLM": ("phi3", "Phi3ForCausalLM"),
    "Qwen2Model": ("qwen2", "Qwen2ForCausalLM"),
    "Qwen2ForCausalLM": ("qwen2", "Qwen2ForCausalLM"),
    "Qwen2ForRewardModel": ("qwen2_rm", "Qwen2ForRewardModel"),
    "Qwen2ForProcessRewardModel": ("qwen2_rm", "Qwen2ForProcessRewardModel"),
    "RobertaForMaskedLM": ("roberta", "RobertaEmbeddingModel"),
    "RobertaModel": ("roberta", "RobertaEmbeddingModel"),
    "TeleChat2ForCausalLM": ("telechat2", "TeleChat2ForCausalLM"),
    "XLMRobertaModel": ("roberta", "RobertaEmbeddingModel"),
    # [Multimodal]
    "CLIPModel": ("clip", "CLIPEmbeddingModel"),
    "LlavaNextForConditionalGeneration": (
        "llava_next",
        "LlavaNextForConditionalGeneration",
    ),
    "Phi3VForCausalLM": ("phi3v", "Phi3VForCausalLM"),
    "Qwen2VLForConditionalGeneration": (
        "qwen2_vl",
        "Qwen2VLForConditionalGeneration",
    ),
    # Terratorch models take images as input and produce images as output.
    # They are registered here only because, for now, they reuse the
    # embedding-model machinery.
    "PrithviGeoSpatialMAE": ("terratorch", "Terratorch"),
    "Terratorch": ("terratorch", "Terratorch"),
}

220
221
# Cross-encoder / classification architectures:
# architecture name -> (module, class name).
_CROSS_ENCODER_MODELS = {
    "BertForSequenceClassification": ("bert", "BertForSequenceClassification"),
    "BertForTokenClassification": ("bert", "BertForTokenClassification"),
    "GteNewForSequenceClassification": (
        "bert_with_rope",
        "GteNewForSequenceClassification",
    ),
    "ModernBertForSequenceClassification": (
        "modernbert",
        "ModernBertForSequenceClassification",
    ),
    "ModernBertForTokenClassification": (
        "modernbert",
        "ModernBertForTokenClassification",
    ),
    "RobertaForSequenceClassification": (
        "roberta",
        "RobertaForSequenceClassification",
    ),
    # XLM-RoBERTa reuses the RoBERTa implementation.
    "XLMRobertaForSequenceClassification": (
        "roberta",
        "RobertaForSequenceClassification",
    ),
    # [Auto-converted (see adapters.py)]
    "JinaVLForRanking": ("jina_vl", "JinaVLForSequenceClassification"),
}

244
# Multimodal architectures: architecture name -> (module, class name).
_MULTIMODAL_MODELS = {
    # [Decoder-only]
    "AriaForConditionalGeneration": ("aria", "AriaForConditionalGeneration"),
    "AyaVisionForConditionalGeneration": (
        "aya_vision",
        "AyaVisionForConditionalGeneration",
    ),
    "Blip2ForConditionalGeneration": ("blip2", "Blip2ForConditionalGeneration"),
    "ChameleonForConditionalGeneration": (
        "chameleon",
        "ChameleonForConditionalGeneration",
    ),
    "Cohere2VisionForConditionalGeneration": (
        "cohere2_vision",
        "Cohere2VisionForConditionalGeneration",
    ),
    "DeepseekVLV2ForCausalLM": ("deepseek_vl2", "DeepseekVLV2ForCausalLM"),
    "DotsOCRForCausalLM": ("dots_ocr", "DotsOCRForCausalLM"),
    "Ernie4_5_VLMoeForConditionalGeneration": (
        "ernie45_vl",
        "Ernie4_5_VLMoeForConditionalGeneration",
    ),
    "FuyuForCausalLM": ("fuyu", "FuyuForCausalLM"),
    "Gemma3ForConditionalGeneration": (
        "gemma3_mm",
        "Gemma3ForConditionalGeneration",
    ),
    "Gemma3nForConditionalGeneration": (
        "gemma3n_mm",
        "Gemma3nForConditionalGeneration",
    ),
    "GLM4VForCausalLM": ("glm4v", "GLM4VForCausalLM"),
    "Glm4vForConditionalGeneration": (
        "glm4_1v",
        "Glm4vForConditionalGeneration",
    ),
    "Glm4vMoeForConditionalGeneration": (
        "glm4_1v",
        "Glm4vMoeForConditionalGeneration",
    ),
    "GraniteSpeechForConditionalGeneration": (
        "granite_speech",
        "GraniteSpeechForConditionalGeneration",
    ),
    "H2OVLChatModel": ("h2ovl", "H2OVLChatModel"),
    "InternVLChatModel": ("internvl", "InternVLChatModel"),
    "NemotronH_Nano_VL_V2": ("nano_nemotron_vl", "NemotronH_Nano_VL_V2"),
    "InternS1ForConditionalGeneration": (
        "interns1",
        "InternS1ForConditionalGeneration",
    ),
    "InternVLForConditionalGeneration": (
        "interns1",
        "InternS1ForConditionalGeneration",
    ),
    "Idefics3ForConditionalGeneration": (
        "idefics3",
        "Idefics3ForConditionalGeneration",
    ),
    "SmolVLMForConditionalGeneration": (
        "smolvlm",
        "SmolVLMForConditionalGeneration",
    ),
    "KeyeForConditionalGeneration": ("keye", "KeyeForConditionalGeneration"),
    "KeyeVL1_5ForConditionalGeneration": (
        "keye_vl1_5",
        "KeyeVL1_5ForConditionalGeneration",
    ),
    "RForConditionalGeneration": ("rvl", "RForConditionalGeneration"),
    "KimiVLForConditionalGeneration": (
        "kimi_vl",
        "KimiVLForConditionalGeneration",
    ),
    "Llama_Nemotron_Nano_VL": ("nemotron_vl", "LlamaNemotronVLChatModel"),
    "Llama4ForConditionalGeneration": (
        "mllama4",
        "Llama4ForConditionalGeneration",
    ),
    "LlavaForConditionalGeneration": ("llava", "LlavaForConditionalGeneration"),
    "LlavaNextForConditionalGeneration": (
        "llava_next",
        "LlavaNextForConditionalGeneration",
    ),
    "LlavaNextVideoForConditionalGeneration": (
        "llava_next_video",
        "LlavaNextVideoForConditionalGeneration",
    ),
    "LlavaOnevisionForConditionalGeneration": (
        "llava_onevision",
        "LlavaOnevisionForConditionalGeneration",
    ),
    "MantisForConditionalGeneration": (
        "llava",
        "MantisForConditionalGeneration",
    ),
    "MiDashengLMModel": ("midashenglm", "MiDashengLMModel"),
    "MiniMaxVL01ForConditionalGeneration": (
        "minimax_vl_01",
        "MiniMaxVL01ForConditionalGeneration",
    ),
    "MiniCPMO": ("minicpmo", "MiniCPMO"),
    "MiniCPMV": ("minicpmv", "MiniCPMV"),
    "Mistral3ForConditionalGeneration": (
        "mistral3",
        "Mistral3ForConditionalGeneration",
    ),
    "MolmoForCausalLM": ("molmo", "MolmoForCausalLM"),
    "NVLM_D": ("nvlm_d", "NVLM_D_Model"),
    "Ovis": ("ovis", "Ovis"),
    "Ovis2_5": ("ovis2_5", "Ovis2_5"),
    "PaliGemmaForConditionalGeneration": (
        "paligemma",
        "PaliGemmaForConditionalGeneration",
    ),
    "Phi3VForCausalLM": ("phi3v", "Phi3VForCausalLM"),
    "Phi4MMForCausalLM": ("phi4mm", "Phi4MMForCausalLM"),
    "Phi4MultimodalForCausalLM": (
        "phi4_multimodal",
        "Phi4MultimodalForCausalLM",
    ),
    "PixtralForConditionalGeneration": (
        "pixtral",
        "PixtralForConditionalGeneration",
    ),
    "QwenVLForConditionalGeneration": (
        "qwen_vl",
        "QwenVLForConditionalGeneration",
    ),
    "Qwen2VLForConditionalGeneration": (
        "qwen2_vl",
        "Qwen2VLForConditionalGeneration",
    ),
    "Qwen2_5_VLForConditionalGeneration": (
        "qwen2_5_vl",
        "Qwen2_5_VLForConditionalGeneration",
    ),
    "Qwen2AudioForConditionalGeneration": (
        "qwen2_audio",
        "Qwen2AudioForConditionalGeneration",
    ),
    "Qwen2_5OmniModel": (
        "qwen2_5_omni_thinker",
        "Qwen2_5OmniThinkerForConditionalGeneration",
    ),
    "Qwen2_5OmniForConditionalGeneration": (
        "qwen2_5_omni_thinker",
        "Qwen2_5OmniThinkerForConditionalGeneration",
    ),
    "Qwen3OmniMoeForConditionalGeneration": (
        "qwen3_omni_moe_thinker",
        "Qwen3OmniMoeThinkerForConditionalGeneration",
    ),
    "Qwen3VLForConditionalGeneration": (
        "qwen3_vl",
        "Qwen3VLForConditionalGeneration",
    ),
    "Qwen3VLMoeForConditionalGeneration": (
        "qwen3_vl_moe",
        "Qwen3VLMoeForConditionalGeneration",
    ),
    "SkyworkR1VChatModel": ("skyworkr1v", "SkyworkR1VChatModel"),
    "Step3VLForConditionalGeneration": (
        "step3_vl",
        "Step3VLForConditionalGeneration",
    ),
    "TarsierForConditionalGeneration": (
        "tarsier",
        "TarsierForConditionalGeneration",
    ),
    "Tarsier2ForConditionalGeneration": (
        "qwen2_vl",
        "Tarsier2ForConditionalGeneration",
    ),
    "UltravoxModel": ("ultravox", "UltravoxModel"),
    "VoxtralForConditionalGeneration": (
        "voxtral",
        "VoxtralForConditionalGeneration",
    ),
    # [Encoder-decoder]
    "WhisperForConditionalGeneration": (
        "whisper",
        "WhisperForConditionalGeneration",
    ),
}
380
381

# Draft / speculative-decoding architectures:
# architecture name -> (module, class name).
_SPECULATIVE_DECODING_MODELS = {
    "MiMoMTPModel": ("mimo_mtp", "MiMoMTP"),
    "EagleLlamaForCausalLM": ("llama_eagle", "EagleLlamaForCausalLM"),
    "EagleLlama4ForCausalLM": ("llama4_eagle", "EagleLlama4ForCausalLM"),
    "EagleMiniCPMForCausalLM": ("minicpm_eagle", "EagleMiniCPMForCausalLM"),
    "Eagle3LlamaForCausalLM": ("llama_eagle3", "Eagle3LlamaForCausalLM"),
    "LlamaForCausalLMEagle3": ("llama_eagle3", "Eagle3LlamaForCausalLM"),
    "Eagle3Qwen2_5vlForCausalLM": ("llama_eagle3", "Eagle3LlamaForCausalLM"),
    "EagleDeepSeekMTPModel": ("deepseek_eagle", "EagleDeepseekV3ForCausalLM"),
    "DeepSeekMTPModel": ("deepseek_mtp", "DeepSeekMTP"),
    "ErnieMTPModel": ("ernie_mtp", "ErnieMTP"),
    "LongCatFlashMTPModel": ("longcat_flash_mtp", "LongCatFlashMTP"),
    "Glm4MoeMTPModel": ("glm4_moe_mtp", "Glm4MoeMTP"),
    "MedusaModel": ("medusa", "Medusa"),
    "Qwen3NextMTP": ("qwen3_next_mtp", "Qwen3NextMTP"),
    # Temporarily disabled.
    # TODO(woosuk): Re-enable this once the MLP Speculator is supported in V1.
    # "MLPSpeculatorPreTrainedModel": ("mlp_speculator", "MLPSpeculator"),
}
400

401
# Architectures served through the Transformers backend rather than a
# dedicated vLLM implementation: architecture name -> (module, class name).
_TRANSFORMERS_SUPPORTED_MODELS = {
    # Text generation models
    "SmolLM3ForCausalLM": ("transformers", "TransformersForCausalLM"),
    # Multimodal models
    "Emu3ForConditionalGeneration": (
        "transformers",
        "TransformersForMultimodalLM",
    ),
}

# The Transformers-backend implementation classes themselves, each keyed by
# its own class name.
_TRANSFORMERS_BACKEND_MODELS = {
    # Defined in the `transformers` module
    "TransformersForCausalLM": ("transformers", "TransformersForCausalLM"),
    "TransformersForMultimodalLM": (
        "transformers",
        "TransformersForMultimodalLM",
    ),
    # Defined in the `transformers_moe` module
    "TransformersMoEForCausalLM": (
        "transformers_moe",
        "TransformersMoEForCausalLM",
    ),
    "TransformersMoEForMultimodalLM": (
        "transformers_moe",
        "TransformersMoEForMultimodalLM",
    ),
    # Defined in the `transformers_pooling` module
    "TransformersEmbeddingModel": (
        "transformers_pooling",
        "TransformersEmbeddingModel",
    ),
    "TransformersForSequenceClassification": (
        "transformers_pooling",
        "TransformersForSequenceClassification",
    ),
    "TransformersMoEForSequenceClassification": (
        "transformers_pooling",
        "TransformersMoEForSequenceClassification",
    ),
    "TransformersMoEEmbeddingModel": (
        "transformers_pooling",
        "TransformersMoEEmbeddingModel",
    ),
}
433

434
# Combined lookup table of every architecture registered in this file. The
# tables are merged left to right, so when an architecture appears in more
# than one table the entry from the later table wins.
_VLLM_MODELS = (
    _TEXT_GENERATION_MODELS
    | _EMBEDDING_MODELS
    | _CROSS_ENCODER_MODELS
    | _MULTIMODAL_MODELS
    | _SPECULATIVE_DECODING_MODELS
    | _TRANSFORMERS_SUPPORTED_MODELS
    | _TRANSFORMERS_BACKEND_MODELS
)

444
445
446
447
# Argument list handed to `subprocess.run()` when a model has to be inspected
# out of process. It is kept in a module-level variable so it can be replaced
# when needed, e.g. when the code is packed with the par format and
# `sys.executable` is not the interpreter we actually want to run.
_SUBPROCESS_COMMAND = [
    sys.executable,
    "-m",
    "vllm.model_executor.models.registry",
]
449

450
# Architectures that are no longer supported, each mapped to a vLLM version
# string (presumably the last release that still supported it — confirm
# against the code that reads this table).
_PREVIOUSLY_SUPPORTED_MODELS = {
    "MotifForCausalLM": "0.10.2",
    "Phi3SmallForCausalLM": "0.9.2",
    "Phi4FlashForCausalLM": "0.10.2",
    # All encoder-decoder models except Whisper were removed as part of the
    # V0 deprecation.
    "BartModel": "0.10.2",
    "BartForConditionalGeneration": "0.10.2",
    "DonutForConditionalGeneration": "0.10.2",
    "Florence2ForConditionalGeneration": "0.10.2",
    "MBartForConditionalGeneration": "0.10.2",
    "MllamaForConditionalGeneration": "0.10.2",
}
463

464

465
466
@dataclass(frozen=True)
class _ModelInfo:
    """
    Capability flags of one model class, as reported by the interface checks.

    Built by `from_model_cls`. The registry also round-trips instances through
    `dataclasses.asdict` / keyword construction for its on-disk cache, so the
    field names below double as the cache's JSON keys.
    """

    architecture: str
    is_text_generation_model: bool
    is_pooling_model: bool
    default_pooling_type: str
    supports_cross_encoding: bool
    supports_multimodal: bool
    supports_multimodal_raw_input_only: bool
    supports_multimodal_encoder_tp_data: bool
    supports_pp: bool
    has_inner_state: bool
    is_attention_free: bool
    is_hybrid: bool
    has_noops: bool
    supports_transcription: bool
    supports_transcription_only: bool
    supports_v0_only: bool

    @staticmethod
    def from_model_cls(model: type[nn.Module]) -> "_ModelInfo":
        """Run every interface check against `model` and collect the results."""
        can_transcribe = supports_transcription(model)
        # The `supports_transcription_only` attribute is only consulted for
        # models that pass the transcription check.
        transcription_only = can_transcribe and model.supports_transcription_only

        return _ModelInfo(
            architecture=model.__name__,
            is_text_generation_model=is_text_generation_model(model),
            is_pooling_model=is_pooling_model(model),
            default_pooling_type=get_default_pooling_type(model),
            supports_cross_encoding=supports_cross_encoding(model),
            supports_multimodal=supports_multimodal(model),
            supports_multimodal_raw_input_only=(
                supports_multimodal_raw_input_only(model)
            ),
            supports_multimodal_encoder_tp_data=(
                supports_multimodal_encoder_tp_data(model)
            ),
            supports_pp=supports_pp(model),
            has_inner_state=has_inner_state(model),
            is_attention_free=is_attention_free(model),
            is_hybrid=is_hybrid(model),
            supports_transcription=can_transcribe,
            supports_transcription_only=transcription_only,
            supports_v0_only=supports_v0_only(model),
            has_noops=has_noops(model),
        )
510
511


512
513
514
515
class _BaseRegisteredModel(ABC):
    """Interface shared by every entry stored in the model registry."""

    @abstractmethod
    def load_model_cls(self) -> type[nn.Module]:
        """Return the model class this entry refers to."""
        raise NotImplementedError

    @abstractmethod
    def inspect_model_cls(self) -> _ModelInfo:
        """Return the capability flags of the model class."""
        raise NotImplementedError
520
521


522
523
524
525
526
527
528
@dataclass(frozen=True)
class _RegisteredModel(_BaseRegisteredModel):
    """
    Registry entry for a model class that is already imported in the main
    process. Its capability flags are computed once, when the entry is built.
    """

    interfaces: _ModelInfo
    model_cls: type[nn.Module]

    @staticmethod
    def from_model_cls(model_cls: type[nn.Module]):
        """Wrap an imported model class, inspecting its interfaces eagerly."""
        info = _ModelInfo.from_model_cls(model_cls)
        return _RegisteredModel(interfaces=info, model_cls=model_cls)

    def load_model_cls(self) -> type[nn.Module]:
        # Nothing to import: the class object is stored directly.
        return self.model_cls

    def inspect_model_cls(self) -> _ModelInfo:
        # Precomputed in `from_model_cls`.
        return self.interfaces


@dataclass(frozen=True)
class _LazyRegisteredModel(_BaseRegisteredModel):
    """
    Represents a model that has not been imported in the main process.

    The class is imported on demand from `module_name`. Inspection runs in a
    subprocess (see `inspect_model_cls`). For modules that belong to this
    package, the resulting `_ModelInfo` is also cached on disk as JSON, keyed
    by an MD5 hash of the module's source file.
    """

    # Absolute module path, passed as-is to `importlib.import_module`.
    module_name: str
    # Name of the model class attribute inside `module_name`.
    class_name: str

    @staticmethod
    def _get_cache_dir() -> Path:
        """Return the directory that holds cached model-info JSON files."""
        return Path(envs.VLLM_CACHE_ROOT) / "modelinfos"

    def _get_cache_filename(self) -> str:
        """Return the cache file name for this module/class pair."""
        # Dots become dashes so the dotted module path is a flat file name.
        cls_name = f"{self.module_name}-{self.class_name}".replace(".", "-")
        return f"{cls_name}.json"

    def _get_module_hash(self) -> str | None:
        """
        Return the MD5 hex digest of the model's source file, or None.

        The source file is looked up next to this file, so only modules that
        live in this package are hashed. An out-of-tree module that merely
        shares a basename with an in-tree one (e.g. `my_plugin.llama` vs. the
        in-tree `llama.py`) gets None. Otherwise its cache entry would be keyed
        on the wrong file and never invalidated when the plugin changes.
        Returning None disables the on-disk cache for this model.

        NOTE(review): the hash covers only the model's own source file, so
        changes to the interface checks used by `_ModelInfo.from_model_cls`
        do not invalidate existing cache entries.
        """
        package, _, basename = self.module_name.rpartition(".")
        # `__package__` is empty/None when this file is not imported as part
        # of a package; keep the basename-only lookup in that case.
        if __package__ and package != __package__:
            return None

        model_path = Path(__file__).parent / f"{basename}.py"
        if not model_path.exists():
            return None

        with open(model_path, "rb") as f:
            return hashlib.md5(f.read(), usedforsecurity=False).hexdigest()

    def _load_modelinfo_from_cache(self, module_hash: str) -> _ModelInfo | None:
        """
        Return the cached `_ModelInfo` if it was stored for `module_hash`.

        Returns None when the cache file is missing or was written for a
        different hash (both logged at debug level). It also returns None on
        any other read/parse error, which is logged with a traceback instead
        of being raised.
        """
        try:
            modelinfo_path = self._get_cache_dir() / self._get_cache_filename()
            with open(modelinfo_path, encoding="utf-8") as file:
                mi_dict = json.load(file)

            if mi_dict["hash"] != module_hash:
                logger.debug(
                    ("Cached model info file for class %s.%s is stale"),
                    self.module_name,
                    self.class_name,
                )
                return None

            # file not changed, use cached _ModelInfo properties
            return _ModelInfo(**mi_dict["modelinfo"])
        except FileNotFoundError:
            logger.debug(
                ("Cached model info file for class %s.%s not found"),
                self.module_name,
                self.class_name,
            )
            return None
        except Exception:
            logger.exception(
                ("Cached model info for class %s.%s error. "),
                self.module_name,
                self.class_name,
            )
            return None

    def _save_modelinfo_to_cache(self, mi: _ModelInfo, module_hash: str) -> None:
        """
        Write `mi` together with `module_hash` to this model's cache file.

        The file is written via `atomic_writer`. Failures while building or
        writing the file are logged rather than raised.
        """
        from vllm.model_executor.model_loader.weight_utils import atomic_writer

        try:
            modelinfo_dict = {
                "hash": module_hash,
                "modelinfo": asdict(mi),
            }
            cache_dir = self._get_cache_dir()
            cache_dir.mkdir(parents=True, exist_ok=True)
            modelinfo_path = cache_dir / self._get_cache_filename()
            with atomic_writer(modelinfo_path, encoding="utf-8") as f:
                json.dump(modelinfo_dict, f, indent=2)
        except Exception:
            logger.exception("Error saving model info cache.")

    @logtime(logger=logger, msg="Registry inspect model class")
    def inspect_model_cls(self) -> _ModelInfo:
        """
        Return the `_ModelInfo` for this model, preferring the disk cache.

        On a cache miss (or when caching is disabled because no module hash
        is available), the class is imported and inspected in a subprocess.
        The result is written back to the cache when a hash is available.
        """
        module_hash = self._get_module_hash()

        if module_hash is not None:
            mi = self._load_modelinfo_from_cache(module_hash)
            if mi is not None:
                logger.debug(
                    ("Loaded model info for class %s.%s from cache"),
                    self.module_name,
                    self.class_name,
                )
                return mi
            logger.debug(
                ("Cache model info for class %s.%s miss. Loading model instead."),
                self.module_name,
                self.class_name,
            )

        # Performed in another process to avoid initializing CUDA
        mi = _run_in_subprocess(
            lambda: _ModelInfo.from_model_cls(self.load_model_cls())
        )
        logger.debug(
            "Loaded model info for class %s.%s", self.module_name, self.class_name
        )

        # save cache file
        if module_hash is not None:
            self._save_modelinfo_to_cache(mi, module_hash)

        return mi

    def load_model_cls(self) -> type[nn.Module]:
        """Import `module_name` and return its `class_name` attribute."""
        mod = importlib.import_module(self.module_name)
        return getattr(mod, self.class_name)


@lru_cache(maxsize=128)
def _try_load_model_cls(
    model_arch: str,
    model: _BaseRegisteredModel,
) -> type[nn.Module] | None:
    """
    Return the model class registered for `model_arch`, or None on failure.

    The current platform verifies the architecture first. That check runs
    outside the error handling below, so anything it raises reaches the
    caller. Errors raised while loading the class itself are logged and turned
    into None. Return values (including None) are memoized per
    `(model_arch, model)` pair.
    """
    from vllm.platforms import current_platform

    current_platform.verify_model_arch(model_arch)

    try:
        loaded = model.load_model_cls()
    except Exception:
        logger.exception("Error in loading model architecture '%s'", model_arch)
        loaded = None
    return loaded
667
668


669
670
671
672
@lru_cache(maxsize=128)
def _try_inspect_model_cls(
    model_arch: str,
    model: _BaseRegisteredModel,
) -> _ModelInfo | None:
    """
    Return the `_ModelInfo` of `model`, or None if inspecting it fails.

    Any exception from `model.inspect_model_cls()` is logged (naming
    `model_arch`) and turned into None. Return values (including None) are
    memoized per `(model_arch, model)` pair.
    """
    try:
        info = model.inspect_model_cls()
    except Exception:
        logger.exception("Error in inspecting model architecture '%s'", model_arch)
        info = None
    return info
679
680


681
682
683
@dataclass
class _ModelRegistry:
    """Registry of model architectures and helpers to resolve/inspect them.

    Maps architecture names to registered model entries and provides the
    resolution logic (registered models, Terratorch, Transformers backend
    fallback) used to pick a model class or its static `_ModelInfo`.
    """

    # Keyed by model_arch
    models: dict[str, _BaseRegisteredModel] = field(default_factory=dict)

    def get_supported_archs(self) -> Set[str]:
        """Return the registered architecture names (a live view of the keys)."""
        return self.models.keys()

    def register_model(
        self,
        model_arch: str,
        model_cls: type[nn.Module] | str,
    ) -> None:
        """
        Register an external model to be used in vLLM.

        `model_cls` can be either:

        - A [`torch.nn.Module`][] class directly referencing the model.
        - A string in the format `<module>:<class>` which can be used to
          lazily import the model. This is useful to avoid initializing CUDA
          when importing the model and thus the related error
          `RuntimeError: Cannot re-initialize CUDA in forked subprocess`.

        Registering an architecture that is already present logs a warning
        and overwrites the previous entry.

        Raises:
            TypeError: If `model_arch` is not a string, or `model_cls` is
                neither a string nor an `nn.Module` subclass.
            ValueError: If a string `model_cls` is not `<module>:<class>`.
        """
        if not isinstance(model_arch, str):
            msg = f"`model_arch` should be a string, not a {type(model_arch)}"
            raise TypeError(msg)

        # Validate and build the entry first so that an invalid `model_cls`
        # does not emit a misleading "will be overwritten" warning.
        if isinstance(model_cls, str):
            split_str = model_cls.split(":")
            if len(split_str) != 2:
                msg = "Expected a string in the format `<module>:<class>`"
                raise ValueError(msg)

            model = _LazyRegisteredModel(*split_str)
        elif isinstance(model_cls, type) and issubclass(model_cls, nn.Module):
            model = _RegisteredModel.from_model_cls(model_cls)
        else:
            # Report the type of the offending `model_cls`, not `model_arch`.
            msg = (
                "`model_cls` should be a string or PyTorch model class, "
                f"not a {type(model_cls)}"
            )
            raise TypeError(msg)

        if model_arch in self.models:
            logger.warning(
                "Model architecture %s is already registered, and will be "
                "overwritten by the new model class %s.",
                model_arch,
                model_cls,
            )

        self.models[model_arch] = model

    def _raise_for_unsupported(self, architectures: list[str]):
        """Raise a ValueError explaining why `architectures` could not be used.

        Always raises. The message distinguishes between registered
        architectures that failed to load/inspect, architectures listed in
        `_PREVIOUSLY_SUPPORTED_MODELS`, and unknown architectures.
        """
        all_supported_archs = self.get_supported_archs()

        # Registered but still unresolved: the `_try_*` helpers logged the
        # underlying exception.
        if any(arch in all_supported_archs for arch in architectures):
            raise ValueError(
                f"Model architectures {architectures} failed "
                "to be inspected. Please check the logs for more details."
            )

        for arch in architectures:
            if arch in _PREVIOUSLY_SUPPORTED_MODELS:
                previous_version = _PREVIOUSLY_SUPPORTED_MODELS[arch]

                raise ValueError(
                    f"Model architecture {arch} was supported in vLLM until "
                    f"v{previous_version}, and is not supported anymore. "
                    "Please use an older version of vLLM if you want to "
                    "use this model architecture."
                )

        raise ValueError(
            f"Model architectures {architectures} are not supported for now. "
            f"Supported architectures: {all_supported_archs}"
        )

    def _try_load_model_cls(self, model_arch: str) -> type[nn.Module] | None:
        """Load the class for a registered `model_arch`; None if unknown/failed."""
        if model_arch not in self.models:
            return None

        return _try_load_model_cls(model_arch, self.models[model_arch])

    def _try_inspect_model_cls(self, model_arch: str) -> _ModelInfo | None:
        """Inspect a registered `model_arch`; None if unknown or inspection failed."""
        if model_arch not in self.models:
            return None

        return _try_inspect_model_cls(model_arch, self.models[model_arch])

    def _try_resolve_transformers(
        self,
        architecture: str,
        model_config: ModelConfig,
    ) -> str | None:
        """Return the Transformers-backend architecture to use, or None.

        `architecture` is returned as-is if it is already a Transformers
        backend entry. Otherwise the model class is looked up in the
        `transformers` package or via the config's `auto_map`, and the
        backend name comes from `model_config._get_transformers_backend_cls()`.

        Returns None when the backend is not applicable and
        `model_config.model_impl != "transformers"`.

        Raises:
            ValueError: If `model_impl == "transformers"` but no compatible
                Transformers model class can be found.
        """
        if architecture in _TRANSFORMERS_BACKEND_MODELS:
            return architecture

        auto_map: dict[str, str] = (
            getattr(model_config.hf_config, "auto_map", None) or dict()
        )

        # Make sure that config class is always initialized before model class,
        # otherwise the model class won't be able to access the config class,
        # the expected auto_map should have correct order like:
        # "auto_map": {
        #     "AutoConfig": "<your-repo-name>--<config-name>",
        #     "AutoModel": "<your-repo-name>--<config-name>",
        #     "AutoModelFor<Task>": "<your-repo-name>--<config-name>",
        # },
        for prefix in ("AutoConfig", "AutoModel"):
            for name, module in auto_map.items():
                if name.startswith(prefix):
                    try_get_class_from_dynamic_module(
                        module,
                        model_config.model,
                        revision=model_config.revision,
                        warn_on_fail=False,
                    )

        model_module = getattr(transformers, architecture, None)

        if model_module is None:
            for name, module in auto_map.items():
                if name.startswith("AutoModel"):
                    model_module = try_get_class_from_dynamic_module(
                        module,
                        model_config.model,
                        revision=model_config.revision,
                        warn_on_fail=True,
                    )
                    if model_module is not None:
                        break
            else:
                # No usable AutoModel entry was found in `auto_map`.
                if model_config.model_impl != "transformers":
                    return None

                raise ValueError(
                    f"Cannot find model module. {architecture!r} is not a "
                    "registered model in the Transformers library (only "
                    "relevant if the model is meant to be in Transformers) "
                    "and 'AutoModel' is not present in the model config's "
                    "'auto_map' (relevant if the model is custom)."
                )

        if not model_module.is_backend_compatible():
            if model_config.model_impl != "transformers":
                return None

            raise ValueError(
                f"The Transformers implementation of {architecture!r} "
                "is not compatible with vLLM."
            )

        return model_config._get_transformers_backend_cls()

    def _normalize_arch(
        self,
        architecture: str,
        model_config: ModelConfig,
    ) -> str:
        """Map `architecture` to a registered name where possible.

        Returns `architecture` itself if registered. Otherwise, if it ends in
        a default runner/convert suffix, returns the first registered name
        obtained by swapping that suffix for another default suffix. Falls
        back to `architecture` unchanged.
        """
        if architecture in self.models:
            return architecture

        # This may be called in order to resolve runner_type and convert_type
        # in the first place, in which case we consider the default match
        match = try_match_architecture_defaults(
            architecture,
            runner_type=getattr(model_config, "runner_type", None),
            convert_type=getattr(model_config, "convert_type", None),
        )
        if match:
            suffix, _ = match

            # Get the name of the base model to convert
            for repl_suffix, _ in iter_architecture_defaults():
                base_arch = architecture.replace(suffix, repl_suffix)
                if base_arch in self.models:
                    return base_arch

        return architecture

    def inspect_model_cls(
        self,
        architectures: str | list[str],
        model_config: ModelConfig,
    ) -> tuple[_ModelInfo, str]:
        """Return `(model_info, arch)` for the first resolvable architecture.

        Resolution order mirrors `resolve_model_cls`:

        1. An explicit `model_impl` of "transformers" or "terratorch".
        2. Transformers fallback if no architecture is registered,
           `model_impl == "auto"` and `convert_type` is "none".
        3. Each architecture in order, after `_normalize_arch`. The returned
           name is the original entry, not the normalized one.
        4. Transformers fallback if no architecture is registered and
           `model_impl == "auto"`.

        Raises:
            ValueError: If `architectures` is empty or nothing resolves.
        """
        if isinstance(architectures, str):
            architectures = [architectures]
        if not architectures:
            raise ValueError("No model architectures are specified")

        # Require transformers impl
        if model_config.model_impl == "transformers":
            arch = self._try_resolve_transformers(architectures[0], model_config)
            if arch is not None:
                model_info = self._try_inspect_model_cls(arch)
                if model_info is not None:
                    return (model_info, arch)
        elif model_config.model_impl == "terratorch":
            # Only return a real result; a failed inspection must not leak
            # `None` out as `_ModelInfo` (matches `resolve_model_cls`).
            arch = "Terratorch"
            model_info = self._try_inspect_model_cls(arch)
            if model_info is not None:
                return (model_info, arch)

        # Fallback to transformers impl (after resolving convert_type)
        if (
            all(arch not in self.models for arch in architectures)
            and model_config.model_impl == "auto"
            and getattr(model_config, "convert_type", "none") == "none"
        ):
            arch = self._try_resolve_transformers(architectures[0], model_config)
            if arch is not None:
                model_info = self._try_inspect_model_cls(arch)
                if model_info is not None:
                    return (model_info, arch)

        for arch in architectures:
            normalized_arch = self._normalize_arch(arch, model_config)
            model_info = self._try_inspect_model_cls(normalized_arch)
            if model_info is not None:
                return (model_info, arch)

        # Fallback to transformers impl (before resolving runner_type)
        if (
            all(arch not in self.models for arch in architectures)
            and model_config.model_impl == "auto"
        ):
            arch = self._try_resolve_transformers(architectures[0], model_config)
            if arch is not None:
                model_info = self._try_inspect_model_cls(arch)
                if model_info is not None:
                    return (model_info, arch)

        return self._raise_for_unsupported(architectures)

    def resolve_model_cls(
        self,
        architectures: str | list[str],
        model_config: ModelConfig,
    ) -> tuple[type[nn.Module], str]:
        """Return `(model_cls, arch)` for the first loadable architecture.

        Uses the same resolution order as `inspect_model_cls`, but loads the
        model class instead of inspecting it.

        Raises:
            ValueError: If `architectures` is empty or nothing resolves.
        """
        if isinstance(architectures, str):
            architectures = [architectures]
        if not architectures:
            raise ValueError("No model architectures are specified")

        # Require transformers impl
        if model_config.model_impl == "transformers":
            arch = self._try_resolve_transformers(architectures[0], model_config)
            if arch is not None:
                model_cls = self._try_load_model_cls(arch)
                if model_cls is not None:
                    return (model_cls, arch)
        elif model_config.model_impl == "terratorch":
            arch = "Terratorch"
            model_cls = self._try_load_model_cls(arch)
            if model_cls is not None:
                return (model_cls, arch)

        # Fallback to transformers impl (after resolving convert_type)
        if (
            all(arch not in self.models for arch in architectures)
            and model_config.model_impl == "auto"
            and getattr(model_config, "convert_type", "none") == "none"
        ):
            arch = self._try_resolve_transformers(architectures[0], model_config)
            if arch is not None:
                model_cls = self._try_load_model_cls(arch)
                if model_cls is not None:
                    return (model_cls, arch)

        for arch in architectures:
            normalized_arch = self._normalize_arch(arch, model_config)
            model_cls = self._try_load_model_cls(normalized_arch)
            if model_cls is not None:
                return (model_cls, arch)

        # Fallback to transformers impl (before resolving runner_type)
        if (
            all(arch not in self.models for arch in architectures)
            and model_config.model_impl == "auto"
        ):
            arch = self._try_resolve_transformers(architectures[0], model_config)
            if arch is not None:
                model_cls = self._try_load_model_cls(arch)
                if model_cls is not None:
                    return (model_cls, arch)

        return self._raise_for_unsupported(architectures)

    # Capability queries: each resolves `architectures` via
    # `inspect_model_cls` (raising ValueError if unresolvable) and reports a
    # single flag of the resulting `_ModelInfo`.

    def is_text_generation_model(
        self,
        architectures: str | list[str],
        model_config: ModelConfig,
    ) -> bool:
        model_cls, _ = self.inspect_model_cls(architectures, model_config)
        return model_cls.is_text_generation_model

    def is_pooling_model(
        self,
        architectures: str | list[str],
        model_config: ModelConfig,
    ) -> bool:
        model_cls, _ = self.inspect_model_cls(architectures, model_config)
        return model_cls.is_pooling_model

    def is_cross_encoder_model(
        self,
        architectures: str | list[str],
        model_config: ModelConfig,
    ) -> bool:
        model_cls, _ = self.inspect_model_cls(architectures, model_config)
        return model_cls.supports_cross_encoding

    def is_multimodal_model(
        self,
        architectures: str | list[str],
        model_config: ModelConfig,
    ) -> bool:
        model_cls, _ = self.inspect_model_cls(architectures, model_config)
        return model_cls.supports_multimodal

    def is_multimodal_raw_input_only_model(
        self,
        architectures: str | list[str],
        model_config: ModelConfig,
    ) -> bool:
        model_cls, _ = self.inspect_model_cls(architectures, model_config)
        return model_cls.supports_multimodal_raw_input_only

    def is_pp_supported_model(
        self,
        architectures: str | list[str],
        model_config: ModelConfig,
    ) -> bool:
        model_cls, _ = self.inspect_model_cls(architectures, model_config)
        return model_cls.supports_pp

    def model_has_inner_state(
        self,
        architectures: str | list[str],
        model_config: ModelConfig,
    ) -> bool:
        model_cls, _ = self.inspect_model_cls(architectures, model_config)
        return model_cls.has_inner_state

    def is_attention_free_model(
        self,
        architectures: str | list[str],
        model_config: ModelConfig,
    ) -> bool:
        model_cls, _ = self.inspect_model_cls(architectures, model_config)
        return model_cls.is_attention_free

    def is_hybrid_model(
        self,
        architectures: str | list[str],
        model_config: ModelConfig,
    ) -> bool:
        model_cls, _ = self.inspect_model_cls(architectures, model_config)
        return model_cls.is_hybrid

    def is_noops_model(
        self,
        architectures: str | list[str],
        model_config: ModelConfig,
    ) -> bool:
        model_cls, _ = self.inspect_model_cls(architectures, model_config)
        return model_cls.has_noops

    def is_transcription_model(
        self,
        architectures: str | list[str],
        model_config: ModelConfig,
    ) -> bool:
        model_cls, _ = self.inspect_model_cls(architectures, model_config)
        return model_cls.supports_transcription

    def is_transcription_only_model(
        self,
        architectures: str | list[str],
        model_config: ModelConfig,
    ) -> bool:
        model_cls, _ = self.inspect_model_cls(architectures, model_config)
        return model_cls.supports_transcription_only

    def is_v1_compatible(
        self,
        architectures: str | list[str],
        model_config: ModelConfig,
    ) -> bool:
        model_cls, _ = self.inspect_model_cls(architectures, model_config)
        return not model_cls.supports_v0_only

1074

1075
1076
1077
1078
1079
1080
1081
1082
1083
# Global registry pre-populated with every built-in architecture as a lazy
# entry, so no model module is imported until its class is requested.
ModelRegistry = _ModelRegistry(
    models={
        arch_name: _LazyRegisteredModel(
            module_name=f"vllm.model_executor.models.{relative_module}",
            class_name=model_class_name,
        )
        for arch_name, (relative_module, model_class_name) in _VLLM_MODELS.items()
    }
)
1084
1085
1086
1087
1088

# Return type of the zero-argument callable passed to `_run_in_subprocess`.
_T = TypeVar("_T")


def _run_in_subprocess(fn: Callable[[], _T]) -> _T:
    """Run `fn` in a separate Python process and return its result.

    `fn` and an output file path are serialized with `cloudpickle` and passed
    to the child on stdin. The child is started via `_SUBPROCESS_COMMAND`,
    which presumably runs this module's `_run`; it is expected to pickle the
    result to the output file, which is read back here.

    Raises:
        RuntimeError: If the subprocess exits with a non-zero status. The
            message contains the child's stderr.
    """
    # NOTE: We use a temporary directory instead of a temporary file to avoid
    # issues like https://stackoverflow.com/questions/23212435/permission-denied-to-write-to-my-temporary-file
    with tempfile.TemporaryDirectory() as tempdir:
        output_filepath = os.path.join(tempdir, "registry_output.tmp")

        # `cloudpickle` allows pickling lambda functions directly
        import cloudpickle

        input_bytes = cloudpickle.dumps((fn, output_filepath))

        # cannot use `sys.executable __file__` here because the script
        # contains relative imports
        returned = subprocess.run(
            _SUBPROCESS_COMMAND, input=input_bytes, capture_output=True
        )

        # check if the subprocess is successful
        try:
            returned.check_returncode()
        except subprocess.CalledProcessError as e:
            # Wrap the error to surface the child's stderr. Decode leniently:
            # a UnicodeDecodeError raised here would mask the real failure.
            stderr_text = returned.stderr.decode(errors="replace")
            raise RuntimeError(f"Error raised in subprocess:\n{stderr_text}") from e

        # The file was written by our own child process, not external input.
        with open(output_filepath, "rb") as f:
            return pickle.load(f)


def _run() -> None:
    """Child-side counterpart of `_run_in_subprocess`.

    Loads vLLM's general plugins, reads a pickled `(fn, output_file)` pair
    from stdin, calls `fn()`, and writes the pickled result to `output_file`.
    """
    from vllm.plugins import load_general_plugins

    # Set up plugins before the payload is unpickled and executed.
    load_general_plugins()

    raw_payload = sys.stdin.buffer.read()
    task, destination = pickle.loads(raw_payload)

    outcome = task()

    with open(destination, "wb") as sink:
        sink.write(pickle.dumps(outcome))
1130
1131
1132


if __name__ == "__main__":
    # Script entry point; presumably launched via `_SUBPROCESS_COMMAND` as the
    # child process of `_run_in_subprocess`. `_run` reads its payload from stdin.
    _run()