# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Whenever you add an architecture to this file, please also update
`tests/models/registry.py` with example HuggingFace models for it.
"""
7

8
import hashlib
9
import importlib
10
import json
11
import os
12
import pickle
13
14
import subprocess
import sys
15
import tempfile
16
from abc import ABC, abstractmethod
17
from collections.abc import Set
18
from dataclasses import asdict, dataclass, field
19
from functools import lru_cache
20
from pathlib import Path
21
from typing import Callable, Optional, TypeVar, Union
22
23

import torch.nn as nn
24
import transformers
25

26
from vllm import envs
27
28
29
30
31
from vllm.config import (
    ModelConfig,
    iter_architecture_defaults,
    try_match_architecture_defaults,
)
32
from vllm.logger import init_logger
33
from vllm.logging_utils import logtime
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
from vllm.transformers_utils.dynamic_module import try_get_class_from_dynamic_module

from .interfaces import (
    has_inner_state,
    has_noops,
    is_attention_free,
    is_hybrid,
    supports_cross_encoding,
    supports_multimodal,
    supports_multimodal_encoder_tp_data,
    supports_multimodal_raw_input_only,
    supports_pp,
    supports_transcription,
    supports_v0_only,
)
from .interfaces_base import (
    get_default_pooling_type,
    is_pooling_model,
    is_text_generation_model,
)
54
55
56

# Module-level logger used for the registry's debug and error messages.
logger = init_logger(__name__)

57
58
# Maps an architecture name to the (module name, class name) of the class
# that implements it.
_TEXT_GENERATION_MODELS = {
    # [Decoder-only]
    "ApertusForCausalLM": ("apertus", "ApertusForCausalLM"),
    "AquilaModel": ("llama", "LlamaForCausalLM"),
    "AquilaForCausalLM": ("llama", "LlamaForCausalLM"),  # AquilaChat2
    "ArceeForCausalLM": ("arcee", "ArceeForCausalLM"),
    "ArcticForCausalLM": ("arctic", "ArcticForCausalLM"),
    "MiniMaxForCausalLM": ("minimax_text_01", "MiniMaxText01ForCausalLM"),
    "MiniMaxText01ForCausalLM": ("minimax_text_01", "MiniMaxText01ForCausalLM"),
    "MiniMaxM1ForCausalLM": ("minimax_text_01", "MiniMaxText01ForCausalLM"),
    # baichuan-7b, upper case 'C' in the class name
    "BaiChuanForCausalLM": ("baichuan", "BaiChuanForCausalLM"),
    # baichuan-13b, lower case 'c' in the class name
    "BaichuanForCausalLM": ("baichuan", "BaichuanForCausalLM"),
    "BailingMoeForCausalLM": ("bailing_moe", "BailingMoeForCausalLM"),
    "BailingMoeV2ForCausalLM": ("bailing_moe", "BailingMoeV2ForCausalLM"),
    "BambaForCausalLM": ("bamba", "BambaForCausalLM"),
    "BloomForCausalLM": ("bloom", "BloomForCausalLM"),
    "ChatGLMModel": ("chatglm", "ChatGLMForCausalLM"),
    "ChatGLMForConditionalGeneration": ("chatglm", "ChatGLMForCausalLM"),
    "CohereForCausalLM": ("commandr", "CohereForCausalLM"),
    "Cohere2ForCausalLM": ("commandr", "CohereForCausalLM"),
    "CwmForCausalLM": ("llama", "LlamaForCausalLM"),
    "DbrxForCausalLM": ("dbrx", "DbrxForCausalLM"),
    "DeciLMForCausalLM": ("nemotron_nas", "DeciLMForCausalLM"),
    "DeepseekForCausalLM": ("deepseek", "DeepseekForCausalLM"),
    "DeepseekV2ForCausalLM": ("deepseek_v2", "DeepseekV2ForCausalLM"),
    "DeepseekV3ForCausalLM": ("deepseek_v2", "DeepseekV3ForCausalLM"),
    "DeepseekV32ForCausalLM": ("deepseek_v2", "DeepseekV3ForCausalLM"),
    "Dots1ForCausalLM": ("dots1", "Dots1ForCausalLM"),
    "Ernie4_5ForCausalLM": ("ernie45", "Ernie4_5ForCausalLM"),
    "Ernie4_5_MoeForCausalLM": ("ernie45_moe", "Ernie4_5_MoeForCausalLM"),
    "ExaoneForCausalLM": ("exaone", "ExaoneForCausalLM"),
    "Exaone4ForCausalLM": ("exaone4", "Exaone4ForCausalLM"),
    "FalconForCausalLM": ("falcon", "FalconForCausalLM"),
    "Fairseq2LlamaForCausalLM": ("fairseq2_llama", "Fairseq2LlamaForCausalLM"),
    "GemmaForCausalLM": ("gemma", "GemmaForCausalLM"),
    "Gemma2ForCausalLM": ("gemma2", "Gemma2ForCausalLM"),
    "Gemma3ForCausalLM": ("gemma3", "Gemma3ForCausalLM"),
    "Gemma3nForCausalLM": ("gemma3n", "Gemma3nForCausalLM"),
    "Qwen3NextForCausalLM": ("qwen3_next", "Qwen3NextForCausalLM"),
    "GlmForCausalLM": ("glm", "GlmForCausalLM"),
    "Glm4ForCausalLM": ("glm4", "Glm4ForCausalLM"),
    "Glm4MoeForCausalLM": ("glm4_moe", "Glm4MoeForCausalLM"),
    "GptOssForCausalLM": ("gpt_oss", "GptOssForCausalLM"),
    "GPT2LMHeadModel": ("gpt2", "GPT2LMHeadModel"),
    "GPTBigCodeForCausalLM": ("gpt_bigcode", "GPTBigCodeForCausalLM"),
    "GPTJForCausalLM": ("gpt_j", "GPTJForCausalLM"),
    "GPTNeoXForCausalLM": ("gpt_neox", "GPTNeoXForCausalLM"),
    "GraniteForCausalLM": ("granite", "GraniteForCausalLM"),
    "GraniteMoeForCausalLM": ("granitemoe", "GraniteMoeForCausalLM"),
    "GraniteMoeHybridForCausalLM": ("granitemoehybrid", "GraniteMoeHybridForCausalLM"),  # noqa: E501
    "GraniteMoeSharedForCausalLM": ("granitemoeshared", "GraniteMoeSharedForCausalLM"),  # noqa: E501
    "GritLM": ("gritlm", "GritLM"),
    "Grok1ModelForCausalLM": ("grok1", "Grok1ForCausalLM"),
    "HunYuanMoEV1ForCausalLM": ("hunyuan_v1", "HunYuanMoEV1ForCausalLM"),
    "HunYuanDenseV1ForCausalLM": ("hunyuan_v1", "HunYuanDenseV1ForCausalLM"),
    "HCXVisionForCausalLM": ("hyperclovax_vision", "HCXVisionForCausalLM"),
    "InternLMForCausalLM": ("llama", "LlamaForCausalLM"),
    "InternLM2ForCausalLM": ("internlm2", "InternLM2ForCausalLM"),
    "InternLM2VEForCausalLM": ("internlm2_ve", "InternLM2VEForCausalLM"),
    "InternLM3ForCausalLM": ("llama", "LlamaForCausalLM"),
    "JAISLMHeadModel": ("jais", "JAISLMHeadModel"),
    "JambaForCausalLM": ("jamba", "JambaForCausalLM"),
    "Lfm2ForCausalLM": ("lfm2", "Lfm2ForCausalLM"),
    "Lfm2MoeForCausalLM": ("lfm2_moe", "Lfm2MoeForCausalLM"),
    "LlamaForCausalLM": ("llama", "LlamaForCausalLM"),
    "Llama4ForCausalLM": ("llama4", "Llama4ForCausalLM"),
    # For decapoda-research/llama-*
    "LLaMAForCausalLM": ("llama", "LlamaForCausalLM"),
    "LongcatFlashForCausalLM": ("longcat_flash", "LongcatFlashForCausalLM"),
    "MambaForCausalLM": ("mamba", "MambaForCausalLM"),
    "FalconMambaForCausalLM": ("mamba", "MambaForCausalLM"),
    "FalconH1ForCausalLM": ("falcon_h1", "FalconH1ForCausalLM"),
    "Mamba2ForCausalLM": ("mamba2", "Mamba2ForCausalLM"),
    "MiniCPMForCausalLM": ("minicpm", "MiniCPMForCausalLM"),
    "MiniCPM3ForCausalLM": ("minicpm3", "MiniCPM3ForCausalLM"),
    "MistralForCausalLM": ("llama", "LlamaForCausalLM"),
    "MixtralForCausalLM": ("mixtral", "MixtralForCausalLM"),
    # transformers's mpt class has lower case
    "MptForCausalLM": ("mpt", "MPTForCausalLM"),
    "MPTForCausalLM": ("mpt", "MPTForCausalLM"),
    "MiMoForCausalLM": ("mimo", "MiMoForCausalLM"),
    "NemotronForCausalLM": ("nemotron", "NemotronForCausalLM"),
    "NemotronHForCausalLM": ("nemotron_h", "NemotronHForCausalLM"),
    "OlmoForCausalLM": ("olmo", "OlmoForCausalLM"),
    "Olmo2ForCausalLM": ("olmo2", "Olmo2ForCausalLM"),
    "Olmo3ForCausalLM": ("olmo2", "Olmo2ForCausalLM"),
    "OlmoeForCausalLM": ("olmoe", "OlmoeForCausalLM"),
    "OPTForCausalLM": ("opt", "OPTForCausalLM"),
    "OrionForCausalLM": ("orion", "OrionForCausalLM"),
    "PersimmonForCausalLM": ("persimmon", "PersimmonForCausalLM"),
    "PhiForCausalLM": ("phi", "PhiForCausalLM"),
    "Phi3ForCausalLM": ("phi3", "Phi3ForCausalLM"),
    "PhiMoEForCausalLM": ("phimoe", "PhiMoEForCausalLM"),
    "Plamo2ForCausalLM": ("plamo2", "Plamo2ForCausalLM"),
    "QWenLMHeadModel": ("qwen", "QWenLMHeadModel"),
    "Qwen2ForCausalLM": ("qwen2", "Qwen2ForCausalLM"),
    "Qwen2MoeForCausalLM": ("qwen2_moe", "Qwen2MoeForCausalLM"),
    "Qwen3ForCausalLM": ("qwen3", "Qwen3ForCausalLM"),
    "Qwen3MoeForCausalLM": ("qwen3_moe", "Qwen3MoeForCausalLM"),
    "RWForCausalLM": ("falcon", "FalconForCausalLM"),
    "SeedOssForCausalLM": ("seed_oss", "SeedOssForCausalLM"),
    "Step3TextForCausalLM": ("step3_text", "Step3TextForCausalLM"),
    "StableLMEpochForCausalLM": ("stablelm", "StablelmForCausalLM"),
    "StableLmForCausalLM": ("stablelm", "StablelmForCausalLM"),
    "Starcoder2ForCausalLM": ("starcoder2", "Starcoder2ForCausalLM"),
    "SolarForCausalLM": ("solar", "SolarForCausalLM"),
    "TeleChat2ForCausalLM": ("telechat2", "TeleChat2ForCausalLM"),
    "TeleFLMForCausalLM": ("teleflm", "TeleFLMForCausalLM"),
    "XverseForCausalLM": ("llama", "LlamaForCausalLM"),
    "Zamba2ForCausalLM": ("zamba2", "Zamba2ForCausalLM"),
}

# Maps an architecture name to the (module name, class name) of the class
# used when the model is run as an embedding / pooling model.
_EMBEDDING_MODELS = {
    # [Text-only]
    "BertModel": ("bert", "BertEmbeddingModel"),
    "DeciLMForCausalLM": ("nemotron_nas", "DeciLMForCausalLM"),
    "Gemma2Model": ("gemma2", "Gemma2ForCausalLM"),
    "Gemma3TextModel": ("gemma3", "Gemma3Model"),
    "GlmForCausalLM": ("glm", "GlmForCausalLM"),
    "GPT2ForSequenceClassification": ("gpt2", "GPT2ForSequenceClassification"),
    "GritLM": ("gritlm", "GritLM"),
    "GteModel": ("bert_with_rope", "SnowflakeGteNewModel"),
    "GteNewModel": ("bert_with_rope", "GteNewModel"),
    "InternLM2ForRewardModel": ("internlm2", "InternLM2ForRewardModel"),
    "JambaForSequenceClassification": ("jamba", "JambaForSequenceClassification"),  # noqa: E501
    "LlamaModel": ("llama", "LlamaForCausalLM"),
    **{
        # Multiple models share the same architecture, so we include them all
        k: (mod, arch)
        for k, (mod, arch) in _TEXT_GENERATION_MODELS.items()
        if arch == "LlamaForCausalLM"
    },
    "MistralModel": ("llama", "LlamaForCausalLM"),
    "ModernBertModel": ("modernbert", "ModernBertModel"),
    "NomicBertModel": ("bert_with_rope", "NomicBertModel"),
    "Phi3ForCausalLM": ("phi3", "Phi3ForCausalLM"),
    "Qwen2Model": ("qwen2", "Qwen2ForCausalLM"),
    "Qwen2ForCausalLM": ("qwen2", "Qwen2ForCausalLM"),
    "Qwen2ForRewardModel": ("qwen2_rm", "Qwen2ForRewardModel"),
    "Qwen2ForProcessRewardModel": ("qwen2_rm", "Qwen2ForProcessRewardModel"),
    "RobertaForMaskedLM": ("roberta", "RobertaEmbeddingModel"),
    "RobertaModel": ("roberta", "RobertaEmbeddingModel"),
    "TeleChat2ForCausalLM": ("telechat2", "TeleChat2ForCausalLM"),
    "XLMRobertaModel": ("roberta", "RobertaEmbeddingModel"),
    # [Multimodal]
    "CLIPModel": ("clip", "CLIPEmbeddingModel"),
    "LlavaNextForConditionalGeneration": (
        "llava_next",
        "LlavaNextForConditionalGeneration",
    ),
    "Phi3VForCausalLM": ("phi3v", "Phi3VForCausalLM"),
    "Qwen2VLForConditionalGeneration": ("qwen2_vl", "Qwen2VLForConditionalGeneration"),  # noqa: E501
    # Technically, Terratorch models work on images for both input and
    # output. They are listed here because, for the time being, they
    # piggy-back on the embedding-model code path.
    "PrithviGeoSpatialMAE": ("terratorch", "Terratorch"),
    "Terratorch": ("terratorch", "Terratorch"),
}

218
219
# Maps an architecture name to the (module name, class name) of the class
# used for cross-encoding / classification.
_CROSS_ENCODER_MODELS = {
    "BertForSequenceClassification": ("bert", "BertForSequenceClassification"),
    "BertForTokenClassification": ("bert", "BertForTokenClassification"),
    "GteNewForSequenceClassification": (
        "bert_with_rope",
        "GteNewForSequenceClassification",
    ),
    "ModernBertForSequenceClassification": (
        "modernbert",
        "ModernBertForSequenceClassification",
    ),
    "ModernBertForTokenClassification": (
        "modernbert",
        "ModernBertForTokenClassification",
    ),
    "RobertaForSequenceClassification": ("roberta", "RobertaForSequenceClassification"),
    "XLMRobertaForSequenceClassification": (
        "roberta",
        "RobertaForSequenceClassification",
    ),
    # [Auto-converted (see adapters.py)]
    "JinaVLForRanking": ("jina_vl", "JinaVLForSequenceClassification"),
}

242
# Maps an architecture name to the (module name, class name) of the class
# implementing a multimodal model.
_MULTIMODAL_MODELS = {
    # [Decoder-only]
    "AriaForConditionalGeneration": ("aria", "AriaForConditionalGeneration"),
    "AyaVisionForConditionalGeneration": (
        "aya_vision",
        "AyaVisionForConditionalGeneration",
    ),
    "Blip2ForConditionalGeneration": ("blip2", "Blip2ForConditionalGeneration"),
    "ChameleonForConditionalGeneration": (
        "chameleon",
        "ChameleonForConditionalGeneration",
    ),
    "Cohere2VisionForConditionalGeneration": (
        "cohere2_vision",
        "Cohere2VisionForConditionalGeneration",
    ),
    "DeepseekVLV2ForCausalLM": ("deepseek_vl2", "DeepseekVLV2ForCausalLM"),
    "DotsOCRForCausalLM": ("dots_ocr", "DotsOCRForCausalLM"),
    "Ernie4_5_VLMoeForConditionalGeneration": (
        "ernie45_vl",
        "Ernie4_5_VLMoeForConditionalGeneration",
    ),
    "FuyuForCausalLM": ("fuyu", "FuyuForCausalLM"),
    "Gemma3ForConditionalGeneration": ("gemma3_mm", "Gemma3ForConditionalGeneration"),  # noqa: E501
    "Gemma3nForConditionalGeneration": (
        "gemma3n_mm",
        "Gemma3nForConditionalGeneration",
    ),
    "GLM4VForCausalLM": ("glm4v", "GLM4VForCausalLM"),
    "Glm4vForConditionalGeneration": ("glm4_1v", "Glm4vForConditionalGeneration"),  # noqa: E501
    "Glm4vMoeForConditionalGeneration": ("glm4_1v", "Glm4vMoeForConditionalGeneration"),  # noqa: E501
    "GraniteSpeechForConditionalGeneration": (
        "granite_speech",
        "GraniteSpeechForConditionalGeneration",
    ),
    "H2OVLChatModel": ("h2ovl", "H2OVLChatModel"),
    "InternVLChatModel": ("internvl", "InternVLChatModel"),
    "NemotronH_Nano_VL_V2": ("nano_nemotron_vl", "NemotronH_Nano_VL_V2"),
    "InternS1ForConditionalGeneration": (
        "interns1",
        "InternS1ForConditionalGeneration",
    ),
    "InternVLForConditionalGeneration": (
        "interns1",
        "InternS1ForConditionalGeneration",
    ),
    "Idefics3ForConditionalGeneration": (
        "idefics3",
        "Idefics3ForConditionalGeneration",
    ),
    "SmolVLMForConditionalGeneration": ("smolvlm", "SmolVLMForConditionalGeneration"),  # noqa: E501
    "KeyeForConditionalGeneration": ("keye", "KeyeForConditionalGeneration"),
    "KeyeVL1_5ForConditionalGeneration": (
        "keye_vl1_5",
        "KeyeVL1_5ForConditionalGeneration",
    ),
    "RForConditionalGeneration": ("rvl", "RForConditionalGeneration"),
    "KimiVLForConditionalGeneration": ("kimi_vl", "KimiVLForConditionalGeneration"),  # noqa: E501
    "Llama_Nemotron_Nano_VL": ("nemotron_vl", "LlamaNemotronVLChatModel"),
    "Llama4ForConditionalGeneration": ("mllama4", "Llama4ForConditionalGeneration"),  # noqa: E501
    "LlavaForConditionalGeneration": ("llava", "LlavaForConditionalGeneration"),
    "LlavaNextForConditionalGeneration": (
        "llava_next",
        "LlavaNextForConditionalGeneration",
    ),
    "LlavaNextVideoForConditionalGeneration": (
        "llava_next_video",
        "LlavaNextVideoForConditionalGeneration",
    ),
    "LlavaOnevisionForConditionalGeneration": (
        "llava_onevision",
        "LlavaOnevisionForConditionalGeneration",
    ),
    "MantisForConditionalGeneration": ("llava", "MantisForConditionalGeneration"),  # noqa: E501
    "MiDashengLMModel": ("midashenglm", "MiDashengLMModel"),
    "MiniMaxVL01ForConditionalGeneration": (
        "minimax_vl_01",
        "MiniMaxVL01ForConditionalGeneration",
    ),
    "MiniCPMO": ("minicpmo", "MiniCPMO"),
    "MiniCPMV": ("minicpmv", "MiniCPMV"),
    "Mistral3ForConditionalGeneration": (
        "mistral3",
        "Mistral3ForConditionalGeneration",
    ),
    "MolmoForCausalLM": ("molmo", "MolmoForCausalLM"),
    "NVLM_D": ("nvlm_d", "NVLM_D_Model"),
    "Ovis": ("ovis", "Ovis"),
    "Ovis2_5": ("ovis2_5", "Ovis2_5"),
    "PaliGemmaForConditionalGeneration": (
        "paligemma",
        "PaliGemmaForConditionalGeneration",
    ),
    "Phi3VForCausalLM": ("phi3v", "Phi3VForCausalLM"),
    "Phi4MMForCausalLM": ("phi4mm", "Phi4MMForCausalLM"),
    "Phi4MultimodalForCausalLM": ("phi4_multimodal", "Phi4MultimodalForCausalLM"),  # noqa: E501
    "PixtralForConditionalGeneration": ("pixtral", "PixtralForConditionalGeneration"),  # noqa: E501
    "QwenVLForConditionalGeneration": ("qwen_vl", "QwenVLForConditionalGeneration"),  # noqa: E501
    "Qwen2VLForConditionalGeneration": ("qwen2_vl", "Qwen2VLForConditionalGeneration"),  # noqa: E501
    "Qwen2_5_VLForConditionalGeneration": (
        "qwen2_5_vl",
        "Qwen2_5_VLForConditionalGeneration",
    ),
    "Qwen2AudioForConditionalGeneration": (
        "qwen2_audio",
        "Qwen2AudioForConditionalGeneration",
    ),
    "Qwen2_5OmniModel": (
        "qwen2_5_omni_thinker",
        "Qwen2_5OmniThinkerForConditionalGeneration",
    ),
    "Qwen2_5OmniForConditionalGeneration": (
        "qwen2_5_omni_thinker",
        "Qwen2_5OmniThinkerForConditionalGeneration",
    ),
    "Qwen3VLForConditionalGeneration": ("qwen3_vl", "Qwen3VLForConditionalGeneration"),  # noqa: E501
    "Qwen3VLMoeForConditionalGeneration": (
        "qwen3_vl_moe",
        "Qwen3VLMoeForConditionalGeneration",
    ),
    "SkyworkR1VChatModel": ("skyworkr1v", "SkyworkR1VChatModel"),
    "Step3VLForConditionalGeneration": ("step3_vl", "Step3VLForConditionalGeneration"),  # noqa: E501
    "TarsierForConditionalGeneration": ("tarsier", "TarsierForConditionalGeneration"),  # noqa: E501
    "Tarsier2ForConditionalGeneration": (
        "qwen2_vl",
        "Tarsier2ForConditionalGeneration",
    ),
    "UltravoxModel": ("ultravox", "UltravoxModel"),
    "VoxtralForConditionalGeneration": ("voxtral", "VoxtralForConditionalGeneration"),  # noqa: E501
    # [Encoder-decoder]
    "WhisperForConditionalGeneration": ("whisper", "WhisperForConditionalGeneration"),  # noqa: E501
}
374
375

# Maps a draft / speculative-decoding architecture name to the
# (module name, class name) of the class implementing it.
_SPECULATIVE_DECODING_MODELS = {
    "MiMoMTPModel": ("mimo_mtp", "MiMoMTP"),
    "EagleLlamaForCausalLM": ("llama_eagle", "EagleLlamaForCausalLM"),
    "EagleLlama4ForCausalLM": ("llama4_eagle", "EagleLlama4ForCausalLM"),
    "EagleMiniCPMForCausalLM": ("minicpm_eagle", "EagleMiniCPMForCausalLM"),
    "Eagle3LlamaForCausalLM": ("llama_eagle3", "Eagle3LlamaForCausalLM"),
    "LlamaForCausalLMEagle3": ("llama_eagle3", "Eagle3LlamaForCausalLM"),
    "Eagle3Qwen2_5vlForCausalLM": ("llama_eagle3", "Eagle3LlamaForCausalLM"),
    "EagleDeepSeekMTPModel": ("deepseek_eagle", "EagleDeepseekV3ForCausalLM"),
    "DeepSeekMTPModel": ("deepseek_mtp", "DeepSeekMTP"),
    "ErnieMTPModel": ("ernie_mtp", "ErnieMTP"),
    "LongCatFlashMTPModel": ("longcat_flash_mtp", "LongCatFlashMTP"),
    "Glm4MoeMTPModel": ("glm4_moe_mtp", "Glm4MoeMTP"),
    "MedusaModel": ("medusa", "Medusa"),
    "Qwen3NextMTP": ("qwen3_next_mtp", "Qwen3NextMTP"),
    # Temporarily disabled.
    # TODO(woosuk): Re-enable this once the MLP Speculator is supported in V1.
    # "MLPSpeculatorPreTrainedModel": ("mlp_speculator", "MLPSpeculator"),
}
394

395
# Architectures that are mapped explicitly onto the Transformers backend
# classes.
_TRANSFORMERS_SUPPORTED_MODELS = {
    # Text generation models
    "SmolLM3ForCausalLM": ("transformers", "TransformersForCausalLM"),
    # Multimodal models
    "Emu3ForConditionalGeneration": ("transformers", "TransformersForMultimodalLM"),  # noqa: E501
}

# The Transformers backend classes themselves, keyed by their own class name.
_TRANSFORMERS_BACKEND_MODELS = {
    "TransformersForCausalLM": ("transformers", "TransformersForCausalLM"),
    "TransformersForMultimodalLM": ("transformers", "TransformersForMultimodalLM"),  # noqa: E501
    "TransformersMoEForCausalLM": ("transformers_moe", "TransformersMoEForCausalLM"),  # noqa: E501
    "TransformersMoEForMultimodalLM": (
        "transformers_moe",
        "TransformersMoEForMultimodalLM",
    ),
    "TransformersEmbeddingModel": (
        "transformers_pooling",
        "TransformersEmbeddingModel",
    ),
    "TransformersForSequenceClassification": (
        "transformers_pooling",
        "TransformersForSequenceClassification",
    ),
    "TransformersMoEForSequenceClassification": (
        "transformers_pooling",
        "TransformersMoEForSequenceClassification",
    ),
    "TransformersMoEEmbeddingModel": (
        "transformers_pooling",
        "TransformersMoEEmbeddingModel",
    ),
}
427

428
# Union of all built-in architecture tables. When the same key appears in
# more than one table, the value from the table unpacked later wins.
_VLLM_MODELS = {
    **_TEXT_GENERATION_MODELS,
    **_EMBEDDING_MODELS,
    **_CROSS_ENCODER_MODELS,
    **_MULTIMODAL_MODELS,
    **_SPECULATIVE_DECODING_MODELS,
    **_TRANSFORMERS_SUPPORTED_MODELS,
    **_TRANSFORMERS_BACKEND_MODELS,
}

438
439
440
441
# This variable is used as the args for subprocess.run(). It can be modified
# to alter the args if needed, e.g. when the application is packed together
# in par format, sys.executable might not be the target we want to run.
_SUBPROCESS_COMMAND = [sys.executable, "-m", "vllm.model_executor.models.registry"]
443

444
# Architectures that used to be supported, mapped to a vLLM version string
# associated with their removal.
_PREVIOUSLY_SUPPORTED_MODELS = {
    "MotifForCausalLM": "0.10.2",
    "Phi3SmallForCausalLM": "0.9.2",
    "Phi4FlashForCausalLM": "0.10.2",
    # Encoder-decoder models other than Whisper were removed as part of the
    # V0 deprecation.
    "BartModel": "0.10.2",
    "BartForConditionalGeneration": "0.10.2",
    "DonutForConditionalGeneration": "0.10.2",
    "Florence2ForConditionalGeneration": "0.10.2",
    "MBartForConditionalGeneration": "0.10.2",
    "MllamaForConditionalGeneration": "0.10.2",
}
457

458

459
460
@dataclass(frozen=True)
class _ModelInfo:
    """Capability flags of a model class.

    Built by `from_model_cls` from the interface checks in `.interfaces` and
    `.interfaces_base`. Instances are immutable and can be round-tripped via
    `dataclasses.asdict` / `_ModelInfo(**d)`.
    """

    architecture: str
    is_text_generation_model: bool
    is_pooling_model: bool
    default_pooling_type: str
    supports_cross_encoding: bool
    supports_multimodal: bool
    supports_multimodal_raw_input_only: bool
    supports_multimodal_encoder_tp_data: bool
    supports_pp: bool
    has_inner_state: bool
    is_attention_free: bool
    is_hybrid: bool
    has_noops: bool
    supports_transcription: bool
    supports_transcription_only: bool
    supports_v0_only: bool

    @staticmethod
    def from_model_cls(model: type[nn.Module]) -> "_ModelInfo":
        """Collect the capability flags of ``model``.

        Args:
            model: The model class to inspect.

        Returns:
            A `_ModelInfo` whose ``architecture`` is ``model.__name__``.
        """
        transcription = supports_transcription(model)
        return _ModelInfo(
            architecture=model.__name__,
            is_text_generation_model=is_text_generation_model(model),
            is_pooling_model=is_pooling_model(model),
            default_pooling_type=get_default_pooling_type(model),
            supports_cross_encoding=supports_cross_encoding(model),
            supports_multimodal=supports_multimodal(model),
            supports_multimodal_raw_input_only=supports_multimodal_raw_input_only(
                model
            ),
            supports_multimodal_encoder_tp_data=supports_multimodal_encoder_tp_data(
                model
            ),
            supports_pp=supports_pp(model),
            has_inner_state=has_inner_state(model),
            is_attention_free=is_attention_free(model),
            is_hybrid=is_hybrid(model),
            supports_transcription=transcription,
            # The class attribute is only read when the transcription check
            # passes; presumably it is only guaranteed to exist on such models.
            supports_transcription_only=(
                transcription and model.supports_transcription_only
            ),
            supports_v0_only=supports_v0_only(model),
            has_noops=has_noops(model),
        )
504
505


506
507
508
509
class _BaseRegisteredModel(ABC):
    """Interface shared by all entries stored in the model registry."""

    @abstractmethod
    def inspect_model_cls(self) -> _ModelInfo:
        """Return the capability flags of the registered model class."""
        raise NotImplementedError

    @abstractmethod
    def load_model_cls(self) -> type[nn.Module]:
        """Return the registered model class, importing it if necessary."""
        raise NotImplementedError
514
515


516
517
518
519
520
521
522
@dataclass(frozen=True)
class _RegisteredModel(_BaseRegisteredModel):
    """
    Represents a model that has already been imported in the main process.
    """

    # Capability flags, computed once when the entry is created.
    interfaces: _ModelInfo
    # The already-imported model class.
    model_cls: type[nn.Module]

    @staticmethod
    def from_model_cls(model_cls: type[nn.Module]) -> "_RegisteredModel":
        """Wrap ``model_cls``, inspecting its interfaces eagerly."""
        return _RegisteredModel(
            interfaces=_ModelInfo.from_model_cls(model_cls),
            model_cls=model_cls,
        )

    def inspect_model_cls(self) -> _ModelInfo:
        """Return the flags captured in `from_model_cls`."""
        return self.interfaces

    def load_model_cls(self) -> type[nn.Module]:
        """Return the wrapped class; no import is required."""
        return self.model_cls


@dataclass(frozen=True)
class _LazyRegisteredModel(_BaseRegisteredModel):
    """
    Represents a model that has not been imported in the main process.

    The class is only imported by `load_model_cls`. `inspect_model_cls`
    inspects it in a subprocess instead. When the module is a direct child of
    this package, the resulting `_ModelInfo` is cached as JSON under
    `VLLM_CACHE_ROOT/modelinfos`, keyed by the MD5 hash of the module source.
    """

    # Importable module path (the `<module>` part of `<module>:<class>`).
    module_name: str
    # Name of the model class defined in `module_name`.
    class_name: str

    @staticmethod
    def _get_cache_dir() -> Path:
        """Return the directory that holds cached model-info JSON files."""
        return Path(envs.VLLM_CACHE_ROOT) / "modelinfos"

    def _get_cache_filename(self) -> str:
        """Return the cache file name for this module/class pair."""
        cls_name = f"{self.module_name}-{self.class_name}".replace(".", "-")
        return f"{cls_name}.json"

    def _get_module_source_path(self) -> Optional[Path]:
        """Return the source file of `module_name` if it lives in this package.

        Only direct children of this module's package are resolved. Matching
        on the last dotted component alone would make an out-of-tree module
        such as `my_plugin.llama` hash the in-tree `llama.py`, so edits to the
        plugin would never invalidate its cache entry.
        """
        package, _, stem = self.module_name.rpartition(".")
        if not package or package != __package__:
            return None
        model_path = Path(__file__).parent / f"{stem}.py"
        return model_path if model_path.exists() else None

    def _load_modelinfo_from_cache(self, module_hash: str) -> _ModelInfo | None:
        """Return the cached `_ModelInfo` if its stored hash matches.

        Returns None when the cache file is missing, was written for a
        different ``module_hash``, or cannot be read/parsed (the last case is
        logged with a traceback).
        """
        try:
            try:
                modelinfo_path = self._get_cache_dir() / self._get_cache_filename()
                with open(modelinfo_path, encoding="utf-8") as file:
                    mi_dict = json.load(file)
            except FileNotFoundError:
                logger.debug(
                    ("Cached model info file for class %s.%s not found"),
                    self.module_name,
                    self.class_name,
                )
                return None

            if mi_dict["hash"] != module_hash:
                logger.debug(
                    ("Cached model info file for class %s.%s is stale"),
                    self.module_name,
                    self.class_name,
                )
                return None

            # file not changed, use cached _ModelInfo properties
            return _ModelInfo(**mi_dict["modelinfo"])
        except Exception:
            logger.exception(
                ("Cached model info for class %s.%s error. "),
                self.module_name,
                self.class_name,
            )
            return None

    def _save_modelinfo_to_cache(self, mi: _ModelInfo, module_hash: str) -> None:
        """Write ``mi`` together with ``module_hash`` to the JSON cache file.

        The file is written through `atomic_writer`. Any failure is logged
        rather than raised.
        """
        from vllm.model_executor.model_loader.weight_utils import atomic_writer

        try:
            modelinfo_dict = {
                "hash": module_hash,
                "modelinfo": asdict(mi),
            }
            cache_dir = self._get_cache_dir()
            cache_dir.mkdir(parents=True, exist_ok=True)
            modelinfo_path = cache_dir / self._get_cache_filename()
            with atomic_writer(modelinfo_path, encoding="utf-8") as f:
                json.dump(modelinfo_dict, f, indent=2)
        except Exception:
            logger.exception("Error saving model info cache.")

    @logtime(logger=logger, msg="Registry inspect model class")
    def inspect_model_cls(self) -> _ModelInfo:
        """Return the model's capability flags without importing it here.

        A cached result is reused when the module's source hash matches.
        Otherwise the class is inspected in a subprocess and, if the module's
        source file could be resolved, the result is written to the cache.
        """
        model_path = self._get_module_source_path()
        module_hash = None

        if model_path is not None:
            with open(model_path, "rb") as f:
                module_hash = hashlib.md5(f.read(), usedforsecurity=False).hexdigest()

            mi = self._load_modelinfo_from_cache(module_hash)
            if mi is not None:
                logger.debug(
                    ("Loaded model info for class %s.%s from cache"),
                    self.module_name,
                    self.class_name,
                )
                return mi
            logger.debug(
                ("Cache model info for class %s.%s miss. Loading model instead."),
                self.module_name,
                self.class_name,
            )

        # Performed in another process to avoid initializing CUDA
        mi = _run_in_subprocess(
            lambda: _ModelInfo.from_model_cls(self.load_model_cls())
        )
        logger.debug(
            "Loaded model info for class %s.%s", self.module_name, self.class_name
        )

        # Only cache when the source hash is known; it is what detects stale
        # entries on later runs.
        if module_hash is not None:
            self._save_modelinfo_to_cache(mi, module_hash)

        return mi

    def load_model_cls(self) -> type[nn.Module]:
        """Import `module_name` and return its `class_name` attribute."""
        mod = importlib.import_module(self.module_name)
        return getattr(mod, self.class_name)


@lru_cache(maxsize=128)
def _try_load_model_cls(
    model_arch: str,
    model: _BaseRegisteredModel,
) -> Optional[type[nn.Module]]:
    """Load the model class registered under ``model_arch``.

    Any exception raised by ``current_platform.verify_model_arch`` propagates
    to the caller. Errors while loading the class itself are logged and
    reported as ``None``. Returned values, including ``None``, are memoized
    per ``(model_arch, model)`` pair by ``lru_cache``.
    """
    from vllm.platforms import current_platform

    current_platform.verify_model_arch(model_arch)
    try:
        return model.load_model_cls()
    except Exception:
        logger.exception("Error in loading model architecture '%s'", model_arch)
        return None
661
662


663
664
665
666
667
668
669
670
@lru_cache(maxsize=128)
def _try_inspect_model_cls(
    model_arch: str,
    model: _BaseRegisteredModel,
) -> Optional[_ModelInfo]:
    """Return the capability flags of ``model``, or ``None`` on failure.

    Errors are logged together with ``model_arch``. Returned values,
    including ``None``, are memoized per ``(model_arch, model)`` pair by
    ``lru_cache``.
    """
    try:
        return model.inspect_model_cls()
    except Exception:
        logger.exception("Error in inspecting model architecture '%s'", model_arch)
        return None
673
674


675
676
677
@dataclass
class _ModelRegistry:
    """Registry mapping architecture names to registered model implementations.

    Loading and inspection go through the module-level
    ``_try_load_model_cls`` / ``_try_inspect_model_cls`` helpers, which log
    failures and return ``None`` instead of raising.
    """

    # Keyed by model_arch
    models: dict[str, _BaseRegisteredModel] = field(default_factory=dict)

    def get_supported_archs(self) -> Set[str]:
        """Return the registered architecture names (a live view of the keys)."""
        return self.models.keys()

    def register_model(
        self,
        model_arch: str,
        model_cls: Union[type[nn.Module], str],
    ) -> None:
        """
        Register an external model to be used in vLLM.

        `model_cls` can be either:

        - A [`torch.nn.Module`][] class directly referencing the model.
        - A string in the format `<module>:<class>` which can be used to
          lazily import the model. This is useful to avoid initializing CUDA
          when importing the model and thus the related error
          `RuntimeError: Cannot re-initialize CUDA in forked subprocess`.

        Registering an architecture that is already present overwrites the
        previous entry, and a warning is logged.

        Raises:
            TypeError: if `model_arch` is not a string, or `model_cls` is
                neither a string nor an `nn.Module` subclass.
            ValueError: if a string `model_cls` is not of the form
                `<module>:<class>`.
        """
        if not isinstance(model_arch, str):
            msg = f"`model_arch` should be a string, not a {type(model_arch)}"
            raise TypeError(msg)

        if isinstance(model_cls, str):
            split_str = model_cls.split(":")
            if len(split_str) != 2:
                msg = "Expected a string in the format `<module>:<class>`"
                raise ValueError(msg)

            model = _LazyRegisteredModel(*split_str)
        elif isinstance(model_cls, type) and issubclass(model_cls, nn.Module):
            model = _RegisteredModel.from_model_cls(model_cls)
        else:
            # Report the type of the offending argument (`model_cls`);
            # `model_arch` has already been validated as a string above.
            msg = (
                "`model_cls` should be a string or PyTorch model class, "
                f"not a {type(model_cls)}"
            )
            raise TypeError(msg)

        # Warn only once the new entry is known to be valid, so a rejected
        # registration does not claim an overwrite that never happens.
        if model_arch in self.models:
            logger.warning(
                "Model architecture %s is already registered, and will be "
                "overwritten by the new model class %s.",
                model_arch,
                model_cls,
            )

        self.models[model_arch] = model

    def _raise_for_unsupported(self, architectures: list[str]):
        """Raise a ``ValueError`` explaining why ``architectures`` cannot be used.

        This never returns normally. The message distinguishes among three
        cases: a registered architecture whose loading/inspection failed, an
        architecture removed in an earlier vLLM release, and an unknown
        architecture.
        """
        all_supported_archs = self.get_supported_archs()

        if any(arch in all_supported_archs for arch in architectures):
            raise ValueError(
                f"Model architectures {architectures} failed "
                "to be inspected. Please check the logs for more details."
            )

        for arch in architectures:
            if arch in _PREVIOUSLY_SUPPORTED_MODELS:
                previous_version = _PREVIOUSLY_SUPPORTED_MODELS[arch]

                raise ValueError(
                    f"Model architecture {arch} was supported in vLLM until "
                    f"v{previous_version}, and is not supported anymore. "
                    "Please use an older version of vLLM if you want to "
                    "use this model architecture."
                )

        raise ValueError(
            f"Model architectures {architectures} are not supported for now. "
            f"Supported architectures: {all_supported_archs}"
        )

    def _try_load_model_cls(self, model_arch: str) -> Optional[type[nn.Module]]:
        """Load the class for ``model_arch``; ``None`` if unregistered or failed."""
        if model_arch not in self.models:
            return None

        return _try_load_model_cls(model_arch, self.models[model_arch])

    def _try_inspect_model_cls(self, model_arch: str) -> Optional[_ModelInfo]:
        """Inspect ``model_arch``; ``None`` if unregistered or inspection failed."""
        if model_arch not in self.models:
            return None

        return _try_inspect_model_cls(model_arch, self.models[model_arch])

    def _try_resolve_transformers(
        self,
        architecture: str,
        model_config: ModelConfig,
    ) -> Optional[str]:
        """Resolve ``architecture`` to the Transformers-backend architecture name.

        Returns ``architecture`` unchanged if it is already a Transformers
        backend architecture. Otherwise the function looks up the model class
        in ``transformers`` or through the config's ``auto_map``. If it
        succeeds, it returns ``model_config._get_transformers_backend_cls()``.

        Returns ``None`` when the model cannot be used with the Transformers
        backend, unless ``model_config.model_impl == "transformers"``, in which
        case a ``ValueError`` is raised instead.
        """
        if architecture in _TRANSFORMERS_BACKEND_MODELS:
            return architecture

        auto_map: dict[str, str] = (
            getattr(model_config.hf_config, "auto_map", None) or dict()
        )

        # Make sure that config class is always initialized before model class,
        # otherwise the model class won't be able to access the config class,
        # the expected auto_map should have correct order like:
        # "auto_map": {
        #     "AutoConfig": "<your-repo-name>--<config-name>",
        #     "AutoModel": "<your-repo-name>--<config-name>",
        #     "AutoModelFor<Task>": "<your-repo-name>--<config-name>",
        # },
        for prefix in ("AutoConfig", "AutoModel"):
            for name, module in auto_map.items():
                if name.startswith(prefix):
                    try_get_class_from_dynamic_module(
                        module,
                        model_config.model,
                        revision=model_config.revision,
                        warn_on_fail=False,
                    )

        model_module = getattr(transformers, architecture, None)

        if model_module is None:
            for name, module in auto_map.items():
                if name.startswith("AutoModel"):
                    model_module = try_get_class_from_dynamic_module(
                        module,
                        model_config.model,
                        revision=model_config.revision,
                        warn_on_fail=True,
                    )
                    if model_module is not None:
                        break
            else:
                # No `AutoModel*` entry yielded a class.
                if model_config.model_impl != "transformers":
                    return None

                raise ValueError(
                    f"Cannot find model module. {architecture!r} is not a "
                    "registered model in the Transformers library (only "
                    "relevant if the model is meant to be in Transformers) "
                    "and 'AutoModel' is not present in the model config's "
                    "'auto_map' (relevant if the model is custom)."
                )

        if not model_module.is_backend_compatible():
            if model_config.model_impl != "transformers":
                return None

            raise ValueError(
                f"The Transformers implementation of {architecture!r} "
                "is not compatible with vLLM."
            )

        return model_config._get_transformers_backend_cls()

    def _normalize_arch(
        self,
        architecture: str,
        model_config: ModelConfig,
    ) -> str:
        """Map ``architecture`` to a registered base architecture if possible.

        If ``architecture`` is not registered, but its suffix matches one of
        the architecture defaults, each default suffix is substituted in turn.
        The first registered result is returned. Otherwise ``architecture`` is
        returned unchanged.
        """
        if architecture in self.models:
            return architecture

        # This may be called in order to resolve runner_type and convert_type
        # in the first place, in which case we consider the default match
        match = try_match_architecture_defaults(
            architecture,
            runner_type=getattr(model_config, "runner_type", None),
            convert_type=getattr(model_config, "convert_type", None),
        )
        if match:
            suffix, _ = match

            # Get the name of the base model to convert
            for repl_suffix, _ in iter_architecture_defaults():
                base_arch = architecture.replace(suffix, repl_suffix)
                if base_arch in self.models:
                    return base_arch

        return architecture

    def inspect_model_cls(
        self,
        architectures: Union[str, list[str]],
        model_config: ModelConfig,
    ) -> tuple[_ModelInfo, str]:
        """Return ``(model_info, arch)`` for the first usable architecture.

        Resolution order:

        1. ``model_impl == "transformers"``: the Transformers backend for
           ``architectures[0]``. ``model_impl == "terratorch"``: the
           ``"Terratorch"`` entry.
        2. If no architecture is registered, ``model_impl == "auto"`` and no
           conversion is requested: the Transformers backend.
        3. Each architecture, after ``_normalize_arch``. The returned name is
           the original (un-normalized) one.
        4. If no architecture is registered and ``model_impl == "auto"``: the
           Transformers backend.

        Raises:
            ValueError: if ``architectures`` is empty or nothing resolves.
        """
        if isinstance(architectures, str):
            architectures = [architectures]
        if not architectures:
            raise ValueError("No model architectures are specified")

        # Require transformers impl
        if model_config.model_impl == "transformers":
            arch = self._try_resolve_transformers(architectures[0], model_config)
            if arch is not None:
                model_info = self._try_inspect_model_cls(arch)
                if model_info is not None:
                    return (model_info, arch)
        elif model_config.model_impl == "terratorch":
            # Like `resolve_model_cls`, only return on success. Returning a
            # `None` model info would break every `is_*` helper below.
            arch = "Terratorch"
            model_info = self._try_inspect_model_cls(arch)
            if model_info is not None:
                return (model_info, arch)

        # Fallback to transformers impl (after resolving convert_type)
        if (
            all(arch not in self.models for arch in architectures)
            and model_config.model_impl == "auto"
            and getattr(model_config, "convert_type", "none") == "none"
        ):
            arch = self._try_resolve_transformers(architectures[0], model_config)
            if arch is not None:
                model_info = self._try_inspect_model_cls(arch)
                if model_info is not None:
                    return (model_info, arch)

        for arch in architectures:
            normalized_arch = self._normalize_arch(arch, model_config)
            model_info = self._try_inspect_model_cls(normalized_arch)
            if model_info is not None:
                return (model_info, arch)

        # Fallback to transformers impl (before resolving runner_type)
        if (
            all(arch not in self.models for arch in architectures)
            and model_config.model_impl == "auto"
        ):
            arch = self._try_resolve_transformers(architectures[0], model_config)
            if arch is not None:
                model_info = self._try_inspect_model_cls(arch)
                if model_info is not None:
                    return (model_info, arch)

        return self._raise_for_unsupported(architectures)

    def resolve_model_cls(
        self,
        architectures: Union[str, list[str]],
        model_config: ModelConfig,
    ) -> tuple[type[nn.Module], str]:
        """Return ``(model_cls, arch)`` for the first loadable architecture.

        Follows the same resolution order as ``inspect_model_cls``, but loads
        the model class instead of inspecting it.

        Raises:
            ValueError: if ``architectures`` is empty or nothing resolves.
        """
        if isinstance(architectures, str):
            architectures = [architectures]
        if not architectures:
            raise ValueError("No model architectures are specified")

        # Require transformers impl
        if model_config.model_impl == "transformers":
            arch = self._try_resolve_transformers(architectures[0], model_config)
            if arch is not None:
                model_cls = self._try_load_model_cls(arch)
                if model_cls is not None:
                    return (model_cls, arch)
        elif model_config.model_impl == "terratorch":
            arch = "Terratorch"
            model_cls = self._try_load_model_cls(arch)
            if model_cls is not None:
                return (model_cls, arch)

        # Fallback to transformers impl (after resolving convert_type)
        if (
            all(arch not in self.models for arch in architectures)
            and model_config.model_impl == "auto"
            and getattr(model_config, "convert_type", "none") == "none"
        ):
            arch = self._try_resolve_transformers(architectures[0], model_config)
            if arch is not None:
                model_cls = self._try_load_model_cls(arch)
                if model_cls is not None:
                    return (model_cls, arch)

        for arch in architectures:
            normalized_arch = self._normalize_arch(arch, model_config)
            model_cls = self._try_load_model_cls(normalized_arch)
            if model_cls is not None:
                return (model_cls, arch)

        # Fallback to transformers impl (before resolving runner_type)
        if (
            all(arch not in self.models for arch in architectures)
            and model_config.model_impl == "auto"
        ):
            arch = self._try_resolve_transformers(architectures[0], model_config)
            if arch is not None:
                model_cls = self._try_load_model_cls(arch)
                if model_cls is not None:
                    return (model_cls, arch)

        return self._raise_for_unsupported(architectures)

    def is_text_generation_model(
        self,
        architectures: Union[str, list[str]],
        model_config: ModelConfig,
    ) -> bool:
        model_cls, _ = self.inspect_model_cls(architectures, model_config)
        return model_cls.is_text_generation_model

    def is_pooling_model(
        self,
        architectures: Union[str, list[str]],
        model_config: ModelConfig,
    ) -> bool:
        model_cls, _ = self.inspect_model_cls(architectures, model_config)
        return model_cls.is_pooling_model

    def is_cross_encoder_model(
        self,
        architectures: Union[str, list[str]],
        model_config: ModelConfig,
    ) -> bool:
        model_cls, _ = self.inspect_model_cls(architectures, model_config)
        return model_cls.supports_cross_encoding

    def is_multimodal_model(
        self,
        architectures: Union[str, list[str]],
        model_config: ModelConfig,
    ) -> bool:
        model_cls, _ = self.inspect_model_cls(architectures, model_config)
        return model_cls.supports_multimodal

    def is_multimodal_raw_input_only_model(
        self,
        architectures: Union[str, list[str]],
        model_config: ModelConfig,
    ) -> bool:
        model_cls, _ = self.inspect_model_cls(architectures, model_config)
        return model_cls.supports_multimodal_raw_input_only

    def is_pp_supported_model(
        self,
        architectures: Union[str, list[str]],
        model_config: ModelConfig,
    ) -> bool:
        model_cls, _ = self.inspect_model_cls(architectures, model_config)
        return model_cls.supports_pp

    def model_has_inner_state(
        self,
        architectures: Union[str, list[str]],
        model_config: ModelConfig,
    ) -> bool:
        model_cls, _ = self.inspect_model_cls(architectures, model_config)
        return model_cls.has_inner_state

    def is_attention_free_model(
        self,
        architectures: Union[str, list[str]],
        model_config: ModelConfig,
    ) -> bool:
        model_cls, _ = self.inspect_model_cls(architectures, model_config)
        return model_cls.is_attention_free

    def is_hybrid_model(
        self,
        architectures: Union[str, list[str]],
        model_config: ModelConfig,
    ) -> bool:
        model_cls, _ = self.inspect_model_cls(architectures, model_config)
        return model_cls.is_hybrid

    def is_noops_model(
        self,
        architectures: Union[str, list[str]],
        model_config: ModelConfig,
    ) -> bool:
        model_cls, _ = self.inspect_model_cls(architectures, model_config)
        return model_cls.has_noops

    def is_transcription_model(
        self,
        architectures: Union[str, list[str]],
        model_config: ModelConfig,
    ) -> bool:
        model_cls, _ = self.inspect_model_cls(architectures, model_config)
        return model_cls.supports_transcription

    def is_transcription_only_model(
        self,
        architectures: Union[str, list[str]],
        model_config: ModelConfig,
    ) -> bool:
        model_cls, _ = self.inspect_model_cls(architectures, model_config)
        return model_cls.supports_transcription_only

    def is_v1_compatible(
        self,
        architectures: Union[str, list[str]],
        model_config: ModelConfig,
    ) -> bool:
        model_cls, _ = self.inspect_model_cls(architectures, model_config)
        return not model_cls.supports_v0_only

1068

1069
1070
1071
1072
1073
1074
1075
1076
1077
# Registry pre-populated with every built-in architecture in `_VLLM_MODELS`.
# Each entry is a `_LazyRegisteredModel` that only stores module/class names,
# so no model module is imported when the registry is built.
ModelRegistry = _ModelRegistry(
    models={
        arch_name: _LazyRegisteredModel(
            module_name=f"vllm.model_executor.models.{relative_module}",
            class_name=class_name,
        )
        for arch_name, (relative_module, class_name) in _VLLM_MODELS.items()
    }
)


_T = TypeVar("_T")


def _run_in_subprocess(fn: Callable[[], _T]) -> _T:
    """Run ``fn`` in a separate Python process and return its result.

    ``(fn, output_path)`` is serialized with ``cloudpickle`` (so lambdas are
    allowed) and written to the stdin of ``_SUBPROCESS_COMMAND``. The result
    is then read back by unpickling the file at ``output_path``.

    Raises:
        RuntimeError: if the subprocess exits with a non-zero status. The
            message includes its stderr, and the ``CalledProcessError`` is
            chained as the cause.
    """
    # NOTE: We use a temporary directory instead of a temporary file to avoid
    # issues like https://stackoverflow.com/questions/23212435/permission-denied-to-write-to-my-temporary-file
    with tempfile.TemporaryDirectory() as tempdir:
        output_filepath = os.path.join(tempdir, "registry_output.tmp")

        # `cloudpickle` allows pickling lambda functions directly
        import cloudpickle

        input_bytes = cloudpickle.dumps((fn, output_filepath))

        # cannot use `sys.executable __file__` here because the script
        # contains relative imports
        returned = subprocess.run(
            _SUBPROCESS_COMMAND, input=input_bytes, capture_output=True
        )

        # check if the subprocess is successful
        try:
            returned.check_returncode()
        except subprocess.CalledProcessError as e:
            # Wrap the exception to provide more information. stderr is raw
            # bytes and need not be valid UTF-8. Decode leniently so a
            # UnicodeDecodeError cannot mask the actual subprocess failure.
            stderr_text = returned.stderr.decode(errors="replace")
            raise RuntimeError(
                f"Error raised in subprocess:\n{stderr_text}"
            ) from e

        with open(output_filepath, "rb") as f:
            # Written by our own child process into our private temp dir.
            return pickle.load(f)


def _run() -> None:
    """Script entry point for the child side of ``_run_in_subprocess``.

    Loads vLLM's general plugins, then reads a pickled ``(fn, output_file)``
    pair from stdin. It calls ``fn()`` and writes the pickled result to
    ``output_file``.
    """
    # Setup plugins before executing the payload.
    from vllm.plugins import load_general_plugins

    load_general_plugins()

    # NOTE(review): the payload is expected to come from the parent process
    # via `_run_in_subprocess`; unpickling is only safe for trusted input.
    payload = sys.stdin.buffer.read()
    fn, output_file = pickle.loads(payload)

    result = fn()

    with open(output_file, "wb") as out:
        out.write(pickle.dumps(result))
1124
1125
1126


if __name__ == "__main__":
1127
    _run()