# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

"""
Whenever you add an architecture to this page, please also update
`tests/models/registry.py` with example HuggingFace models for it.
"""
import hashlib
9
import importlib
10
import json
11
import os
12
import pickle
13
14
import subprocess
import sys
15
import tempfile
16
from abc import ABC, abstractmethod
17
from collections.abc import Callable, Set
18
from dataclasses import asdict, dataclass, field
19
from functools import lru_cache
20
from pathlib import Path
21
from typing import TypeVar
22
23

import torch.nn as nn
24
import transformers
25

26
from vllm import envs
27
28
29
30
31
from vllm.config import (
    ModelConfig,
    iter_architecture_defaults,
    try_match_architecture_defaults,
)
32
from vllm.logger import init_logger
33
from vllm.logging_utils import logtime
34
35
36
37
38
39
40
41
from vllm.transformers_utils.dynamic_module import try_get_class_from_dynamic_module

from .interfaces import (
    has_inner_state,
    has_noops,
    is_attention_free,
    is_hybrid,
    supports_cross_encoding,
42
    supports_mamba_prefix_caching,
43
44
45
46
47
48
49
50
51
52
53
    supports_multimodal,
    supports_multimodal_encoder_tp_data,
    supports_multimodal_raw_input_only,
    supports_pp,
    supports_transcription,
)
from .interfaces_base import (
    get_default_pooling_type,
    is_pooling_model,
    is_text_generation_model,
)

# Module-level logger used by the registry helpers in this file.
logger = init_logger(__name__)

# Each entry maps an architecture name to (module name, class name); the
# module is presumably resolved inside this package — see the registry setup.
_TEXT_GENERATION_MODELS = {
    # [Decoder-only]
    "AfmoeForCausalLM": ("afmoe", "AfmoeForCausalLM"),
    "ApertusForCausalLM": ("apertus", "ApertusForCausalLM"),
    "AquilaModel": ("llama", "LlamaForCausalLM"),
    "AquilaForCausalLM": ("llama", "LlamaForCausalLM"),  # AquilaChat2
    "ArceeForCausalLM": ("arcee", "ArceeForCausalLM"),
    "ArcticForCausalLM": ("arctic", "ArcticForCausalLM"),
    # baichuan-7b, upper case 'C' in the class name
    "BaiChuanForCausalLM": ("baichuan", "BaiChuanForCausalLM"),
    # baichuan-13b, lower case 'c' in the class name
    "BaichuanForCausalLM": ("baichuan", "BaichuanForCausalLM"),
    "BailingMoeForCausalLM": ("bailing_moe", "BailingMoeForCausalLM"),
    "BailingMoeV2ForCausalLM": ("bailing_moe", "BailingMoeV2ForCausalLM"),
    "BambaForCausalLM": ("bamba", "BambaForCausalLM"),
    "BloomForCausalLM": ("bloom", "BloomForCausalLM"),
    "ChatGLMModel": ("chatglm", "ChatGLMForCausalLM"),
    "ChatGLMForConditionalGeneration": ("chatglm", "ChatGLMForCausalLM"),
    "CohereForCausalLM": ("commandr", "CohereForCausalLM"),
    "Cohere2ForCausalLM": ("commandr", "CohereForCausalLM"),
    "CwmForCausalLM": ("llama", "LlamaForCausalLM"),
    "DbrxForCausalLM": ("dbrx", "DbrxForCausalLM"),
    "DeciLMForCausalLM": ("nemotron_nas", "DeciLMForCausalLM"),
    "DeepseekForCausalLM": ("deepseek_v2", "DeepseekForCausalLM"),
    "DeepseekV2ForCausalLM": ("deepseek_v2", "DeepseekV2ForCausalLM"),
    "DeepseekV3ForCausalLM": ("deepseek_v2", "DeepseekV3ForCausalLM"),
    "DeepseekV32ForCausalLM": ("deepseek_v2", "DeepseekV3ForCausalLM"),
    "Dots1ForCausalLM": ("dots1", "Dots1ForCausalLM"),
    "Ernie4_5ForCausalLM": ("ernie45", "Ernie4_5ForCausalLM"),
    "Ernie4_5_MoeForCausalLM": ("ernie45_moe", "Ernie4_5_MoeForCausalLM"),
    "ExaoneForCausalLM": ("exaone", "ExaoneForCausalLM"),
    "Exaone4ForCausalLM": ("exaone4", "Exaone4ForCausalLM"),
    "Fairseq2LlamaForCausalLM": ("fairseq2_llama", "Fairseq2LlamaForCausalLM"),
    "FalconForCausalLM": ("falcon", "FalconForCausalLM"),
    "FalconMambaForCausalLM": ("mamba", "MambaForCausalLM"),
    "FalconH1ForCausalLM": ("falcon_h1", "FalconH1ForCausalLM"),
    "FlexOlmoForCausalLM": ("flex_olmo", "FlexOlmoForCausalLM"),
    "GemmaForCausalLM": ("gemma", "GemmaForCausalLM"),
    "Gemma2ForCausalLM": ("gemma2", "Gemma2ForCausalLM"),
    "Gemma3ForCausalLM": ("gemma3", "Gemma3ForCausalLM"),
    "Gemma3nForCausalLM": ("gemma3n", "Gemma3nForCausalLM"),
    "Qwen3NextForCausalLM": ("qwen3_next", "Qwen3NextForCausalLM"),
    "GlmForCausalLM": ("glm", "GlmForCausalLM"),
    "Glm4ForCausalLM": ("glm4", "Glm4ForCausalLM"),
    "Glm4MoeForCausalLM": ("glm4_moe", "Glm4MoeForCausalLM"),
    "GptOssForCausalLM": ("gpt_oss", "GptOssForCausalLM"),
    "GPT2LMHeadModel": ("gpt2", "GPT2LMHeadModel"),
    "GPTBigCodeForCausalLM": ("gpt_bigcode", "GPTBigCodeForCausalLM"),
    "GPTJForCausalLM": ("gpt_j", "GPTJForCausalLM"),
    "GPTNeoXForCausalLM": ("gpt_neox", "GPTNeoXForCausalLM"),
    "GraniteForCausalLM": ("granite", "GraniteForCausalLM"),
    "GraniteMoeForCausalLM": ("granitemoe", "GraniteMoeForCausalLM"),
    "GraniteMoeHybridForCausalLM": ("granitemoehybrid", "GraniteMoeHybridForCausalLM"),  # noqa: E501
    "GraniteMoeSharedForCausalLM": ("granitemoeshared", "GraniteMoeSharedForCausalLM"),  # noqa: E501
    "GritLM": ("gritlm", "GritLM"),
    "Grok1ModelForCausalLM": ("grok1", "Grok1ForCausalLM"),
    "HunYuanMoEV1ForCausalLM": ("hunyuan_v1", "HunYuanMoEV1ForCausalLM"),
    "HunYuanDenseV1ForCausalLM": ("hunyuan_v1", "HunYuanDenseV1ForCausalLM"),
    "HCXVisionForCausalLM": ("hyperclovax_vision", "HCXVisionForCausalLM"),
    "InternLMForCausalLM": ("llama", "LlamaForCausalLM"),
    "InternLM2ForCausalLM": ("internlm2", "InternLM2ForCausalLM"),
    "InternLM2VEForCausalLM": ("internlm2_ve", "InternLM2VEForCausalLM"),
    "InternLM3ForCausalLM": ("llama", "LlamaForCausalLM"),
    "JAISLMHeadModel": ("jais", "JAISLMHeadModel"),
    "JambaForCausalLM": ("jamba", "JambaForCausalLM"),
    "KimiLinearForCausalLM": ("kimi_linear", "KimiLinearForCausalLM"),
    "Lfm2ForCausalLM": ("lfm2", "Lfm2ForCausalLM"),
    "Lfm2MoeForCausalLM": ("lfm2_moe", "Lfm2MoeForCausalLM"),
    "LlamaForCausalLM": ("llama", "LlamaForCausalLM"),
    "Llama4ForCausalLM": ("llama4", "Llama4ForCausalLM"),
    # For decapoda-research/llama-*
    "LLaMAForCausalLM": ("llama", "LlamaForCausalLM"),
    "LongcatFlashForCausalLM": ("longcat_flash", "LongcatFlashForCausalLM"),
    "MambaForCausalLM": ("mamba", "MambaForCausalLM"),
    "Mamba2ForCausalLM": ("mamba2", "Mamba2ForCausalLM"),
    "MiniCPMForCausalLM": ("minicpm", "MiniCPMForCausalLM"),
    "MiniCPM3ForCausalLM": ("minicpm3", "MiniCPM3ForCausalLM"),
    "MiniMaxForCausalLM": ("minimax_text_01", "MiniMaxText01ForCausalLM"),
    "MiniMaxText01ForCausalLM": ("minimax_text_01", "MiniMaxText01ForCausalLM"),
    "MiniMaxM1ForCausalLM": ("minimax_text_01", "MiniMaxText01ForCausalLM"),
    "MiniMaxM2ForCausalLM": ("minimax_m2", "MiniMaxM2ForCausalLM"),
    "MistralForCausalLM": ("llama", "LlamaForCausalLM"),
    "MixtralForCausalLM": ("mixtral", "MixtralForCausalLM"),
    # transformers's mpt class has lower case
    "MptForCausalLM": ("mpt", "MPTForCausalLM"),
    "MPTForCausalLM": ("mpt", "MPTForCausalLM"),
    "MiMoForCausalLM": ("mimo", "MiMoForCausalLM"),
    "NemotronForCausalLM": ("nemotron", "NemotronForCausalLM"),
    "NemotronHForCausalLM": ("nemotron_h", "NemotronHForCausalLM"),
    "OlmoForCausalLM": ("olmo", "OlmoForCausalLM"),
    "Olmo2ForCausalLM": ("olmo2", "Olmo2ForCausalLM"),
    "Olmo3ForCausalLM": ("olmo2", "Olmo2ForCausalLM"),
    "OlmoeForCausalLM": ("olmoe", "OlmoeForCausalLM"),
    "OPTForCausalLM": ("opt", "OPTForCausalLM"),
    "OrionForCausalLM": ("orion", "OrionForCausalLM"),
    "OuroForCausalLM": ("ouro", "OuroForCausalLM"),
    "PanguEmbeddedForCausalLM": ("openpangu", "PanguEmbeddedForCausalLM"),
    "PanguUltraMoEForCausalLM": ("openpangu", "PanguUltraMoEForCausalLM"),
    "PersimmonForCausalLM": ("persimmon", "PersimmonForCausalLM"),
    "PhiForCausalLM": ("phi", "PhiForCausalLM"),
    "Phi3ForCausalLM": ("phi3", "Phi3ForCausalLM"),
    "PhiMoEForCausalLM": ("phimoe", "PhiMoEForCausalLM"),
    "Plamo2ForCausalLM": ("plamo2", "Plamo2ForCausalLM"),
    "QWenLMHeadModel": ("qwen", "QWenLMHeadModel"),
    "Qwen2ForCausalLM": ("qwen2", "Qwen2ForCausalLM"),
    "Qwen2MoeForCausalLM": ("qwen2_moe", "Qwen2MoeForCausalLM"),
    "Qwen3ForCausalLM": ("qwen3", "Qwen3ForCausalLM"),
    "Qwen3MoeForCausalLM": ("qwen3_moe", "Qwen3MoeForCausalLM"),
    "RWForCausalLM": ("falcon", "FalconForCausalLM"),
    "SeedOssForCausalLM": ("seed_oss", "SeedOssForCausalLM"),
    "Step3TextForCausalLM": ("step3_text", "Step3TextForCausalLM"),
    "StableLMEpochForCausalLM": ("stablelm", "StablelmForCausalLM"),
    "StableLmForCausalLM": ("stablelm", "StablelmForCausalLM"),
    "Starcoder2ForCausalLM": ("starcoder2", "Starcoder2ForCausalLM"),
    "SolarForCausalLM": ("solar", "SolarForCausalLM"),
    "TeleChat2ForCausalLM": ("telechat2", "TeleChat2ForCausalLM"),
    "TeleFLMForCausalLM": ("teleflm", "TeleFLMForCausalLM"),
    "XverseForCausalLM": ("llama", "LlamaForCausalLM"),
    "Zamba2ForCausalLM": ("zamba2", "Zamba2ForCausalLM"),
}

_EMBEDDING_MODELS = {
    # [Text-only]
    "BertModel": ("bert", "BertEmbeddingModel"),
    "BertSpladeSparseEmbeddingModel": ("bert", "BertSpladeSparseEmbeddingModel"),
    "DeciLMForCausalLM": ("nemotron_nas", "DeciLMForCausalLM"),
    "Gemma2Model": ("gemma2", "Gemma2ForCausalLM"),
    "Gemma3TextModel": ("gemma3", "Gemma3Model"),
    "GlmForCausalLM": ("glm", "GlmForCausalLM"),
    "GPT2ForSequenceClassification": ("gpt2", "GPT2ForSequenceClassification"),
    "GritLM": ("gritlm", "GritLM"),
    "GteModel": ("bert_with_rope", "SnowflakeGteNewModel"),
    "GteNewModel": ("bert_with_rope", "GteNewModel"),
    "InternLM2ForRewardModel": ("internlm2", "InternLM2ForRewardModel"),
    "JambaForSequenceClassification": ("jamba", "JambaForSequenceClassification"),  # noqa: E501
    "LlamaModel": ("llama", "LlamaForCausalLM"),
    **{
        # Multiple models share the same architecture, so we include them all
        # (every text-generation entry backed by LlamaForCausalLM).
        k: (mod, arch)
        for k, (mod, arch) in _TEXT_GENERATION_MODELS.items()
        if arch == "LlamaForCausalLM"
    },
    "MistralModel": ("llama", "LlamaForCausalLM"),
    "ModernBertModel": ("modernbert", "ModernBertModel"),
    "NomicBertModel": ("bert_with_rope", "NomicBertModel"),
    "Phi3ForCausalLM": ("phi3", "Phi3ForCausalLM"),
    "Qwen2Model": ("qwen2", "Qwen2ForCausalLM"),
    "Qwen2ForCausalLM": ("qwen2", "Qwen2ForCausalLM"),
    "Qwen2ForRewardModel": ("qwen2_rm", "Qwen2ForRewardModel"),
    "Qwen2ForProcessRewardModel": ("qwen2_rm", "Qwen2ForProcessRewardModel"),
    "RobertaForMaskedLM": ("roberta", "RobertaEmbeddingModel"),
    "RobertaModel": ("roberta", "RobertaEmbeddingModel"),
    "TeleChat2ForCausalLM": ("telechat2", "TeleChat2ForCausalLM"),
    "XLMRobertaModel": ("roberta", "RobertaEmbeddingModel"),
    # [Multimodal]
    "CLIPModel": ("clip", "CLIPEmbeddingModel"),
    "LlavaNextForConditionalGeneration": (
        "llava_next",
        "LlavaNextForConditionalGeneration",
    ),
    "Phi3VForCausalLM": ("phi3v", "Phi3VForCausalLM"),
    "Qwen2VLForConditionalGeneration": ("qwen2_vl", "Qwen2VLForConditionalGeneration"),  # noqa: E501
    "SiglipModel": ("siglip", "SiglipEmbeddingModel"),
    # Technically Terratorch models work on images, both in input and output.
    # They are listed here because they piggy-back on embedding models for the
    # time being.
    "PrithviGeoSpatialMAE": ("terratorch", "Terratorch"),
    "Terratorch": ("terratorch", "Terratorch"),
}

_CROSS_ENCODER_MODELS = {
    "BertForSequenceClassification": ("bert", "BertForSequenceClassification"),
    "BertForTokenClassification": ("bert", "BertForTokenClassification"),
    "GteNewForSequenceClassification": (
        "bert_with_rope",
        "GteNewForSequenceClassification",
    ),
    "ModernBertForSequenceClassification": (
        "modernbert",
        "ModernBertForSequenceClassification",
    ),
    "ModernBertForTokenClassification": (
        "modernbert",
        "ModernBertForTokenClassification",
    ),
    "RobertaForSequenceClassification": ("roberta", "RobertaForSequenceClassification"),
    "XLMRobertaForSequenceClassification": (
        "roberta",
        "RobertaForSequenceClassification",
    ),
    # [Auto-converted (see adapters.py)]
    "JinaVLForRanking": ("jina_vl", "JinaVLForSequenceClassification"),
}

_MULTIMODAL_MODELS = {
    # [Decoder-only]
    "AriaForConditionalGeneration": ("aria", "AriaForConditionalGeneration"),
    "AyaVisionForConditionalGeneration": (
        "aya_vision",
        "AyaVisionForConditionalGeneration",
    ),
    "BeeForConditionalGeneration": ("bee", "BeeForConditionalGeneration"),
    "Blip2ForConditionalGeneration": ("blip2", "Blip2ForConditionalGeneration"),
    "ChameleonForConditionalGeneration": (
        "chameleon",
        "ChameleonForConditionalGeneration",
    ),
    "Cohere2VisionForConditionalGeneration": (
        "cohere2_vision",
        "Cohere2VisionForConditionalGeneration",
    ),
    "DeepseekVLV2ForCausalLM": ("deepseek_vl2", "DeepseekVLV2ForCausalLM"),
    "DeepseekOCRForCausalLM": ("deepseek_ocr", "DeepseekOCRForCausalLM"),
    "DotsOCRForCausalLM": ("dots_ocr", "DotsOCRForCausalLM"),
    "Ernie4_5_VLMoeForConditionalGeneration": (
        "ernie45_vl",
        "Ernie4_5_VLMoeForConditionalGeneration",
    ),
    "FuyuForCausalLM": ("fuyu", "FuyuForCausalLM"),
    "Gemma3ForConditionalGeneration": ("gemma3_mm", "Gemma3ForConditionalGeneration"),  # noqa: E501
    "Gemma3nForConditionalGeneration": (
        "gemma3n_mm",
        "Gemma3nForConditionalGeneration",
    ),
    "GLM4VForCausalLM": ("glm4v", "GLM4VForCausalLM"),
    "Glm4vForConditionalGeneration": ("glm4_1v", "Glm4vForConditionalGeneration"),  # noqa: E501
    "Glm4vMoeForConditionalGeneration": ("glm4_1v", "Glm4vMoeForConditionalGeneration"),  # noqa: E501
    "GraniteSpeechForConditionalGeneration": (
        "granite_speech",
        "GraniteSpeechForConditionalGeneration",
    ),
    "H2OVLChatModel": ("h2ovl", "H2OVLChatModel"),
    "InternVLChatModel": ("internvl", "InternVLChatModel"),
    "NemotronH_Nano_VL_V2": ("nano_nemotron_vl", "NemotronH_Nano_VL_V2"),
    "InternS1ForConditionalGeneration": (
        "interns1",
        "InternS1ForConditionalGeneration",
    ),
    "InternVLForConditionalGeneration": (
        "interns1",
        "InternS1ForConditionalGeneration",
    ),
    "Idefics3ForConditionalGeneration": (
        "idefics3",
        "Idefics3ForConditionalGeneration",
    ),
    "SmolVLMForConditionalGeneration": ("smolvlm", "SmolVLMForConditionalGeneration"),  # noqa: E501
    "KeyeForConditionalGeneration": ("keye", "KeyeForConditionalGeneration"),
    "KeyeVL1_5ForConditionalGeneration": (
        "keye_vl1_5",
        "KeyeVL1_5ForConditionalGeneration",
    ),
    "RForConditionalGeneration": ("rvl", "RForConditionalGeneration"),
    "KimiVLForConditionalGeneration": ("kimi_vl", "KimiVLForConditionalGeneration"),  # noqa: E501
    "LightOnOCRForConditionalGeneration": (
        "lightonocr",
        "LightOnOCRForConditionalGeneration",
    ),
    "Llama_Nemotron_Nano_VL": ("nemotron_vl", "LlamaNemotronVLChatModel"),
    "Llama4ForConditionalGeneration": ("mllama4", "Llama4ForConditionalGeneration"),  # noqa: E501
    "LlavaForConditionalGeneration": ("llava", "LlavaForConditionalGeneration"),
    "LlavaNextForConditionalGeneration": (
        "llava_next",
        "LlavaNextForConditionalGeneration",
    ),
    "LlavaNextVideoForConditionalGeneration": (
        "llava_next_video",
        "LlavaNextVideoForConditionalGeneration",
    ),
    "LlavaOnevisionForConditionalGeneration": (
        "llava_onevision",
        "LlavaOnevisionForConditionalGeneration",
    ),
    "MantisForConditionalGeneration": ("llava", "MantisForConditionalGeneration"),  # noqa: E501
    "MiDashengLMModel": ("midashenglm", "MiDashengLMModel"),
    "MiniMaxVL01ForConditionalGeneration": (
        "minimax_vl_01",
        "MiniMaxVL01ForConditionalGeneration",
    ),
    "MiniCPMO": ("minicpmo", "MiniCPMO"),
    "MiniCPMV": ("minicpmv", "MiniCPMV"),
    "Mistral3ForConditionalGeneration": (
        "mistral3",
        "Mistral3ForConditionalGeneration",
    ),
    "MolmoForCausalLM": ("molmo", "MolmoForCausalLM"),
    "NVLM_D": ("nvlm_d", "NVLM_D_Model"),
    "Ovis": ("ovis", "Ovis"),
    "Ovis2_5": ("ovis2_5", "Ovis2_5"),
    "PaddleOCRVLForConditionalGeneration": (
        "paddleocr_vl",
        "PaddleOCRVLForConditionalGeneration",
    ),
    "PaliGemmaForConditionalGeneration": (
        "paligemma",
        "PaliGemmaForConditionalGeneration",
    ),
    "Phi3VForCausalLM": ("phi3v", "Phi3VForCausalLM"),
    "Phi4MMForCausalLM": ("phi4mm", "Phi4MMForCausalLM"),
    "Phi4MultimodalForCausalLM": ("phi4_multimodal", "Phi4MultimodalForCausalLM"),  # noqa: E501
    "PixtralForConditionalGeneration": ("pixtral", "PixtralForConditionalGeneration"),  # noqa: E501
    "QwenVLForConditionalGeneration": ("qwen_vl", "QwenVLForConditionalGeneration"),  # noqa: E501
    "Qwen2VLForConditionalGeneration": ("qwen2_vl", "Qwen2VLForConditionalGeneration"),  # noqa: E501
    "Qwen2_5_VLForConditionalGeneration": (
        "qwen2_5_vl",
        "Qwen2_5_VLForConditionalGeneration",
    ),
    "Qwen2AudioForConditionalGeneration": (
        "qwen2_audio",
        "Qwen2AudioForConditionalGeneration",
    ),
    "Qwen2_5OmniModel": (
        "qwen2_5_omni_thinker",
        "Qwen2_5OmniThinkerForConditionalGeneration",
    ),
    "Qwen2_5OmniForConditionalGeneration": (
        "qwen2_5_omni_thinker",
        "Qwen2_5OmniThinkerForConditionalGeneration",
    ),
    "Qwen3OmniMoeForConditionalGeneration": (
        "qwen3_omni_moe_thinker",
        "Qwen3OmniMoeThinkerForConditionalGeneration",
    ),
    "Qwen3VLForConditionalGeneration": ("qwen3_vl", "Qwen3VLForConditionalGeneration"),  # noqa: E501
    "Qwen3VLMoeForConditionalGeneration": (
        "qwen3_vl_moe",
        "Qwen3VLMoeForConditionalGeneration",
    ),
    "SkyworkR1VChatModel": ("skyworkr1v", "SkyworkR1VChatModel"),
    "Step3VLForConditionalGeneration": ("step3_vl", "Step3VLForConditionalGeneration"),  # noqa: E501
    "TarsierForConditionalGeneration": ("tarsier", "TarsierForConditionalGeneration"),  # noqa: E501
    "Tarsier2ForConditionalGeneration": (
        "qwen2_vl",
        "Tarsier2ForConditionalGeneration",
    ),
    "UltravoxModel": ("ultravox", "UltravoxModel"),
    "VoxtralForConditionalGeneration": ("voxtral", "VoxtralForConditionalGeneration"),  # noqa: E501
    # [Encoder-decoder]
    "WhisperForConditionalGeneration": ("whisper", "WhisperForConditionalGeneration"),  # noqa: E501
}

_SPECULATIVE_DECODING_MODELS = {
    "MiMoMTPModel": ("mimo_mtp", "MiMoMTP"),
    "EagleLlamaForCausalLM": ("llama_eagle", "EagleLlamaForCausalLM"),
    "EagleLlama4ForCausalLM": ("llama4_eagle", "EagleLlama4ForCausalLM"),
    "EagleMiniCPMForCausalLM": ("minicpm_eagle", "EagleMiniCPMForCausalLM"),
    "Eagle3LlamaForCausalLM": ("llama_eagle3", "Eagle3LlamaForCausalLM"),
    "LlamaForCausalLMEagle3": ("llama_eagle3", "Eagle3LlamaForCausalLM"),
    "Eagle3Qwen2_5vlForCausalLM": ("llama_eagle3", "Eagle3LlamaForCausalLM"),
    "EagleDeepSeekMTPModel": ("deepseek_eagle", "EagleDeepseekV3ForCausalLM"),
    "DeepSeekMTPModel": ("deepseek_mtp", "DeepSeekMTP"),
    "ErnieMTPModel": ("ernie_mtp", "ErnieMTP"),
    "LongCatFlashMTPModel": ("longcat_flash_mtp", "LongCatFlashMTP"),
    "Glm4MoeMTPModel": ("glm4_moe_mtp", "Glm4MoeMTP"),
    "MedusaModel": ("medusa", "Medusa"),
    "OpenPanguMTPModel": ("openpangu_mtp", "OpenPanguMTP"),
    "Qwen3NextMTP": ("qwen3_next_mtp", "Qwen3NextMTP"),
    # Temporarily disabled.
    # # TODO(woosuk): Re-enable this once the MLP Speculator is supported in V1.
    # "MLPSpeculatorPreTrainedModel": ("mlp_speculator", "MLPSpeculator"),
}

_TRANSFORMERS_SUPPORTED_MODELS = {
    # Text generation models
    "SmolLM3ForCausalLM": ("transformers", "TransformersForCausalLM"),
    # Multimodal models
    "Emu3ForConditionalGeneration": (
        "transformers",
        "TransformersMultiModalForCausalLM",
    ),
}

# Every backend entry maps a class name to itself inside the "transformers"
# module.
_TRANSFORMERS_BACKEND_MODELS = {
    # Text generation models
    "TransformersForCausalLM": ("transformers", "TransformersForCausalLM"),
    "TransformersMoEForCausalLM": ("transformers", "TransformersMoEForCausalLM"),
    # Multimodal models
    "TransformersMultiModalForCausalLM": (
        "transformers",
        "TransformersMultiModalForCausalLM",
    ),
    "TransformersMultiModalMoEForCausalLM": (
        "transformers",
        "TransformersMultiModalMoEForCausalLM",
    ),
    # Embedding models
    "TransformersEmbeddingModel": ("transformers", "TransformersEmbeddingModel"),
    "TransformersMoEEmbeddingModel": ("transformers", "TransformersMoEEmbeddingModel"),
    "TransformersMultiModalEmbeddingModel": (
        "transformers",
        "TransformersMultiModalEmbeddingModel",
    ),
    # Sequence classification models
    "TransformersForSequenceClassification": (
        "transformers",
        "TransformersForSequenceClassification",
    ),
    "TransformersMoEForSequenceClassification": (
        "transformers",
        "TransformersMoEForSequenceClassification",
    ),
    "TransformersMultiModalForSequenceClassification": (
        "transformers",
        "TransformersMultiModalForSequenceClassification",
    ),
}

# Combined lookup of every in-tree architecture. On duplicate keys the entry
# from the later dict wins (e.g. "DeciLMForCausalLM" appears in both the
# text-generation and embedding tables with identical values).
_VLLM_MODELS = {
    **_TEXT_GENERATION_MODELS,
    **_EMBEDDING_MODELS,
    **_CROSS_ENCODER_MODELS,
    **_MULTIMODAL_MODELS,
    **_SPECULATIVE_DECODING_MODELS,
    **_TRANSFORMERS_SUPPORTED_MODELS,
    **_TRANSFORMERS_BACKEND_MODELS,
}

# This variable is used as the args for subprocess.run(). We
# can modify this variable to alter the args if needed. e.g.
# when we use par format to pack things together, sys.executable
# might not be the target we want to run.
_SUBPROCESS_COMMAND = [sys.executable, "-m", "vllm.model_executor.models.registry"]

# Architectures that are no longer supported, mapped to a vLLM version string
# (presumably the last release that still supported them — confirm with the
# code that reads this table).
_PREVIOUSLY_SUPPORTED_MODELS = {
    "MotifForCausalLM": "0.10.2",
    "Phi3SmallForCausalLM": "0.9.2",
    "Phi4FlashForCausalLM": "0.10.2",
    # encoder-decoder models except whisper
    # have been removed for V0 deprecation.
    "BartModel": "0.10.2",
    "BartForConditionalGeneration": "0.10.2",
    "DonutForConditionalGeneration": "0.10.2",
    "Florence2ForConditionalGeneration": "0.10.2",
    "MBartForConditionalGeneration": "0.10.2",
    "MllamaForConditionalGeneration": "0.10.2",
}


@dataclass(frozen=True)
class _ModelInfo:
    """Capability flags of a model class, as consumed by the registry.

    Instances are produced by `from_model_cls`. They are also round-tripped
    through `dataclasses.asdict` / `_ModelInfo(**...)` for the on-disk model
    info cache, so every field must stay a JSON-serializable keyword argument.
    """

    architecture: str
    is_text_generation_model: bool
    is_pooling_model: bool
    default_pooling_type: str
    supports_cross_encoding: bool
    supports_multimodal: bool
    supports_multimodal_raw_input_only: bool
    supports_multimodal_encoder_tp_data: bool
    supports_pp: bool
    has_inner_state: bool
    is_attention_free: bool
    is_hybrid: bool
    has_noops: bool
    supports_mamba_prefix_caching: bool
    supports_transcription: bool
    supports_transcription_only: bool

    @staticmethod
    def from_model_cls(model: type[nn.Module]) -> "_ModelInfo":
        """Build a `_ModelInfo` by querying the interface helpers on `model`.

        Args:
            model: The (already imported) model class to inspect.

        Returns:
            A frozen `_ModelInfo` describing the class's capabilities.
        """
        return _ModelInfo(
            architecture=model.__name__,
            is_text_generation_model=is_text_generation_model(model),
            is_pooling_model=is_pooling_model(model),
            default_pooling_type=get_default_pooling_type(model),
            supports_cross_encoding=supports_cross_encoding(model),
            supports_multimodal=supports_multimodal(model),
            supports_multimodal_raw_input_only=supports_multimodal_raw_input_only(
                model
            ),
            supports_multimodal_encoder_tp_data=supports_multimodal_encoder_tp_data(
                model
            ),
            supports_pp=supports_pp(model),
            has_inner_state=has_inner_state(model),
            is_attention_free=is_attention_free(model),
            is_hybrid=is_hybrid(model),
            supports_mamba_prefix_caching=supports_mamba_prefix_caching(model),
            supports_transcription=supports_transcription(model),
            # `supports_transcription_only` is only read after the
            # `supports_transcription` check succeeds, so models without the
            # attribute never have it accessed.
            supports_transcription_only=(
                supports_transcription(model) and model.supports_transcription_only
            ),
            has_noops=has_noops(model),
        )


class _BaseRegisteredModel(ABC):
    """Abstract registry entry that can describe and load one model class.

    Concrete entries either hold an already-imported class or import it
    lazily on demand.
    """

    @abstractmethod
    def inspect_model_cls(self) -> _ModelInfo:
        """Return the capability flags of the registered model class."""
        raise NotImplementedError

    @abstractmethod
    def load_model_cls(self) -> type[nn.Module]:
        """Return the registered model class itself."""
        raise NotImplementedError


@dataclass(frozen=True)
class _RegisteredModel(_BaseRegisteredModel):
    """
    Represents a model that has already been imported in the main process.
    """

    # Capability flags for `model_cls` (computed eagerly by `from_model_cls`).
    interfaces: _ModelInfo
    # The imported model class itself.
    model_cls: type[nn.Module]

    @staticmethod
    def from_model_cls(model_cls: type[nn.Module]) -> "_RegisteredModel":
        """Wrap an imported model class, inspecting its interfaces right away."""
        return _RegisteredModel(
            interfaces=_ModelInfo.from_model_cls(model_cls),
            model_cls=model_cls,
        )

    def inspect_model_cls(self) -> _ModelInfo:
        """Return the precomputed capability flags (no import or subprocess)."""
        return self.interfaces

    def load_model_cls(self) -> type[nn.Module]:
        """Return the stored model class."""
        return self.model_cls


@dataclass(frozen=True)
class _LazyRegisteredModel(_BaseRegisteredModel):
    """
    Represents a model that has not been imported in the main process.

    Inspection results are cached as JSON files under
    ``VLLM_CACHE_ROOT/modelinfos``, keyed by an MD5 of the model's source file,
    so the subprocess-based inspection only reruns when that file changes.
    Only modules that live directly in this package are cached.
    """

    # Dotted import path of the module that defines the model class.
    module_name: str
    # Attribute name of the model class inside `module_name`.
    class_name: str

    @staticmethod
    def _get_cache_dir() -> Path:
        """Return the directory holding the model-info cache files."""
        return Path(envs.VLLM_CACHE_ROOT) / "modelinfos"

    def _get_cache_filename(self) -> str:
        """Return this model's cache file name (dots replaced by dashes)."""
        cls_name = f"{self.module_name}-{self.class_name}".replace(".", "-")
        return f"{cls_name}.json"

    def _get_in_tree_module_hash(self) -> str | None:
        """Return the MD5 hex digest of this model's source file, or None.

        Only modules located directly in this package (the package that
        contains this registry file) are hashed. Matching on the last dotted
        component alone would let an out-of-tree module such as
        ``my_plugin.llama`` pick up the hash of the in-tree ``llama.py``. Its
        cached model info would then never be invalidated when the plugin
        itself changed.
        """
        package, _, basename = self.module_name.rpartition(".")
        if package != __package__:
            return None

        model_path = Path(__file__).parent / f"{basename}.py"
        if not model_path.exists():
            return None

        with open(model_path, "rb") as f:
            return hashlib.md5(f.read(), usedforsecurity=False).hexdigest()

    def _load_modelinfo_from_cache(self, module_hash: str) -> _ModelInfo | None:
        """Return the cached `_ModelInfo` if present and matching `module_hash`.

        Every failure (missing file, stale hash, unreadable or incompatible
        JSON) is treated as a cache miss and yields None.
        """
        try:
            try:
                modelinfo_path = self._get_cache_dir() / self._get_cache_filename()
                with open(modelinfo_path, encoding="utf-8") as file:
                    mi_dict = json.load(file)
            except FileNotFoundError:
                logger.debug(
                    ("Cached model info file for class %s.%s not found"),
                    self.module_name,
                    self.class_name,
                )
                return None

            if mi_dict["hash"] != module_hash:
                logger.debug(
                    ("Cached model info file for class %s.%s is stale"),
                    self.module_name,
                    self.class_name,
                )
                return None

            # file not changed, use cached _ModelInfo properties
            return _ModelInfo(**mi_dict["modelinfo"])
        except Exception:
            # Best-effort cache: keep the traceback so corrupt or outdated
            # cache files can be diagnosed, but fall back to re-inspection.
            logger.debug(
                ("Cached model info for class %s.%s error. "),
                self.module_name,
                self.class_name,
                exc_info=True,
            )
            return None

    def _save_modelinfo_to_cache(self, mi: _ModelInfo, module_hash: str) -> None:
        """save dictionary json file to cache"""
        from vllm.model_executor.model_loader.weight_utils import atomic_writer

        try:
            modelinfo_dict = {
                "hash": module_hash,
                "modelinfo": asdict(mi),
            }
            cache_dir = self._get_cache_dir()
            cache_dir.mkdir(parents=True, exist_ok=True)
            modelinfo_path = cache_dir / self._get_cache_filename()
            with atomic_writer(modelinfo_path, encoding="utf-8") as f:
                json.dump(modelinfo_dict, f, indent=2)
        except Exception:
            logger.exception("Error saving model info cache.")

    @logtime(logger=logger, msg="Registry inspect model class")
    def inspect_model_cls(self) -> _ModelInfo:
        """Return this model's `_ModelInfo` without importing it here.

        The on-disk cache is used when the in-tree source hash matches.
        Otherwise the class is inspected in a subprocess, to avoid
        initializing CUDA in this process, and the cache is refreshed
        (in-tree modules only).
        """
        module_hash = self._get_in_tree_module_hash()

        if module_hash is not None:
            mi = self._load_modelinfo_from_cache(module_hash)
            if mi is not None:
                logger.debug(
                    ("Loaded model info for class %s.%s from cache"),
                    self.module_name,
                    self.class_name,
                )
                return mi
            else:
                logger.debug(
                    ("Cache model info for class %s.%s miss. Loading model instead."),
                    self.module_name,
                    self.class_name,
                )

        # Performed in another process to avoid initializing CUDA
        mi = _run_in_subprocess(
            lambda: _ModelInfo.from_model_cls(self.load_model_cls())
        )
        logger.debug(
            "Loaded model info for class %s.%s", self.module_name, self.class_name
        )

        # save cache file
        if module_hash is not None:
            self._save_modelinfo_to_cache(mi, module_hash)

        return mi

    def load_model_cls(self) -> type[nn.Module]:
        """Import `module_name` and return its `class_name` attribute."""
        mod = importlib.import_module(self.module_name)
        return getattr(mod, self.class_name)


@lru_cache(maxsize=128)
def _try_load_model_cls(
    model_arch: str,
    model: _BaseRegisteredModel,
) -> type[nn.Module] | None:
    """Load the model class registered under ``model_arch``.

    The platform check runs first and sits outside the ``try``, so any
    error it raises propagates to the caller. Errors raised while
    importing the class are logged with a traceback and turned into a
    ``None`` result. Results, including ``None``, are memoized per
    ``(model_arch, model)`` pair.
    """
    from vllm.platforms import current_platform

    current_platform.verify_model_arch(model_arch)

    loaded: type[nn.Module] | None = None
    try:
        loaded = model.load_model_cls()
    except Exception:
        logger.exception("Error in loading model architecture '%s'", model_arch)
    return loaded
697
698


699
700
701
702
@lru_cache(maxsize=128)
def _try_inspect_model_cls(
    model_arch: str,
    model: _BaseRegisteredModel,
) -> _ModelInfo | None:
    """Return ``model``'s `_ModelInfo`, or ``None`` if inspection fails.

    Any exception from ``model.inspect_model_cls()`` is logged with a
    traceback rather than raised. Results, including ``None``, are memoized
    per ``(model_arch, model)`` pair.
    """
    try:
        info = model.inspect_model_cls()
    except Exception:
        logger.exception("Error in inspecting model architecture '%s'", model_arch)
        info = None
    return info
709
710


711
712
713
@dataclass
class _ModelRegistry:
    """Registry mapping architecture names to vLLM model implementations.

    Two entry points take a list of candidate architectures:

    - ``resolve_model_cls`` returns the model class.
    - ``inspect_model_cls`` returns the class's `_ModelInfo` metadata.

    Both honour ``model_config.model_impl`` and may fall back to the
    Transformers backend. The ``is_*`` / ``model_has_*`` helpers are thin
    wrappers that read one capability flag from the resolved `_ModelInfo`.
    """

    # Keyed by model_arch
    models: dict[str, _BaseRegisteredModel] = field(default_factory=dict)

    def get_supported_archs(self) -> Set[str]:
        """Return a live view of the registered architecture names."""
        return self.models.keys()

    def register_model(
        self,
        model_arch: str,
        model_cls: type[nn.Module] | str,
    ) -> None:
        """
        Register an external model to be used in vLLM.

        `model_cls` can be either:

        - A [`torch.nn.Module`][] class directly referencing the model.
        - A string in the format `<module>:<class>` which can be used to
          lazily import the model. This is useful to avoid initializing CUDA
          when importing the model and thus the related error
          `RuntimeError: Cannot re-initialize CUDA in forked subprocess`.

        An existing registration for `model_arch` is overwritten (with a
        warning).

        Raises:
            TypeError: If `model_arch` is not a string, or `model_cls` is
                neither a string nor a `torch.nn.Module` subclass.
            ValueError: If `model_cls` is a string that is not in the
                `<module>:<class>` format.
        """
        if not isinstance(model_arch, str):
            msg = f"`model_arch` should be a string, not a {type(model_arch)}"
            raise TypeError(msg)

        if isinstance(model_cls, str):
            split_str = model_cls.split(":")
            if len(split_str) != 2:
                msg = "Expected a string in the format `<module>:<class>`"
                raise ValueError(msg)

            model = _LazyRegisteredModel(*split_str)
        elif isinstance(model_cls, type) and issubclass(model_cls, nn.Module):
            model = _RegisteredModel.from_model_cls(model_cls)
        else:
            msg = (
                "`model_cls` should be a string or PyTorch model class, "
                f"not a {type(model_cls)}"
            )
            raise TypeError(msg)

        # Warn only after the new class has been validated, so a rejected
        # registration does not report an overwrite that never happens.
        if model_arch in self.models:
            logger.warning(
                "Model architecture %s is already registered, and will be "
                "overwritten by the new model class %s.",
                model_arch,
                model_cls,
            )

        self.models[model_arch] = model

    def _raise_for_unsupported(self, architectures: list[str]):
        """Raise a ``ValueError`` explaining why no architecture resolved.

        This method always raises and never returns. The message covers
        three cases:

        - Some architecture is registered but could not be loaded or
          inspected.
        - An architecture was removed in an earlier vLLM release.
        - None of the architectures was ever supported.
        """
        all_supported_archs = self.get_supported_archs()

        if any(arch in all_supported_archs for arch in architectures):
            raise ValueError(
                f"Model architectures {architectures} failed "
                "to be inspected. Please check the logs for more details."
            )

        for arch in architectures:
            if arch in _PREVIOUSLY_SUPPORTED_MODELS:
                previous_version = _PREVIOUSLY_SUPPORTED_MODELS[arch]

                raise ValueError(
                    f"Model architecture {arch} was supported in vLLM until "
                    f"v{previous_version}, and is not supported anymore. "
                    "Please use an older version of vLLM if you want to "
                    "use this model architecture."
                )

        raise ValueError(
            f"Model architectures {architectures} are not supported for now. "
            f"Supported architectures: {all_supported_archs}"
        )

    def _try_load_model_cls(self, model_arch: str) -> type[nn.Module] | None:
        """Load the class registered as ``model_arch``.

        Returns ``None`` if ``model_arch`` is unregistered or fails to load.
        """
        if model_arch not in self.models:
            return None

        return _try_load_model_cls(model_arch, self.models[model_arch])

    def _try_inspect_model_cls(self, model_arch: str) -> _ModelInfo | None:
        """Inspect the class registered as ``model_arch``.

        Returns ``None`` if ``model_arch`` is unregistered or inspection
        fails.
        """
        if model_arch not in self.models:
            return None

        return _try_inspect_model_cls(model_arch, self.models[model_arch])

    def _try_resolve_transformers(
        self,
        architecture: str,
        model_config: ModelConfig,
    ) -> str | None:
        """Return the Transformers-backend architecture for ``architecture``.

        If ``architecture`` is already a Transformers-backend model, it is
        returned as-is. Otherwise this checks that the Transformers model
        class can be found and is backend-compatible, then returns
        ``model_config._get_transformers_backend_cls()``. The class is
        looked up in the ``transformers`` package, or as a custom class via
        the config's ``auto_map``.

        When that check fails, the result depends on ``model_impl``:

        - ``"transformers"`` was explicitly requested: raise ``ValueError``.
        - Anything else: return ``None``.
        """
        if architecture in _TRANSFORMERS_BACKEND_MODELS:
            return architecture

        auto_map: dict[str, str] = (
            getattr(model_config.hf_config, "auto_map", None) or dict()
        )

        # Make sure that config class is always initialized before model class,
        # otherwise the model class won't be able to access the config class,
        # the expected auto_map should have correct order like:
        # "auto_map": {
        #     "AutoConfig": "<your-repo-name>--<config-name>",
        #     "AutoModel": "<your-repo-name>--<config-name>",
        #     "AutoModelFor<Task>": "<your-repo-name>--<config-name>",
        # },
        for prefix in ("AutoConfig", "AutoModel"):
            for name, module in auto_map.items():
                if name.startswith(prefix):
                    try_get_class_from_dynamic_module(
                        module,
                        model_config.model,
                        revision=model_config.revision,
                        warn_on_fail=False,
                    )

        model_module = getattr(transformers, architecture, None)

        if model_module is None:
            for name, module in auto_map.items():
                if name.startswith("AutoModel"):
                    model_module = try_get_class_from_dynamic_module(
                        module,
                        model_config.model,
                        revision=model_config.revision,
                        warn_on_fail=True,
                    )
                    if model_module is not None:
                        break
            else:
                # No AutoModel entry produced a class
                if model_config.model_impl != "transformers":
                    return None

                raise ValueError(
                    f"Cannot find model module. {architecture!r} is not a "
                    "registered model in the Transformers library (only "
                    "relevant if the model is meant to be in Transformers) "
                    "and 'AutoModel' is not present in the model config's "
                    "'auto_map' (relevant if the model is custom)."
                )

        if not model_module.is_backend_compatible():
            if model_config.model_impl != "transformers":
                return None

            raise ValueError(
                f"The Transformers implementation of {architecture!r} "
                "is not compatible with vLLM."
            )

        return model_config._get_transformers_backend_cls()

    def _normalize_arch(
        self,
        architecture: str,
        model_config: ModelConfig,
    ) -> str:
        """Map ``architecture`` onto a registered architecture name.

        Registered names are returned unchanged. Otherwise, suppose the name
        matches a default architecture suffix (per
        ``try_match_architecture_defaults``). Then that suffix is swapped
        for each default suffix in turn, and the first registered result is
        returned. If nothing matches, ``architecture`` is returned unchanged.
        """
        if architecture in self.models:
            return architecture

        # This may be called in order to resolve runner_type and convert_type
        # in the first place, in which case we consider the default match
        match = try_match_architecture_defaults(
            architecture,
            runner_type=getattr(model_config, "runner_type", None),
            convert_type=getattr(model_config, "convert_type", None),
        )
        if match:
            suffix, _ = match

            # Get the name of the base model to convert
            for repl_suffix, _ in iter_architecture_defaults():
                base_arch = architecture.replace(suffix, repl_suffix)
                if base_arch in self.models:
                    return base_arch

        return architecture

    def _resolve_arch(
        self,
        architectures: str | list[str],
        model_config: ModelConfig,
        try_arch: "Callable[[str], _T | None]",
    ) -> "tuple[_T, str]":
        """Shared resolution order for `inspect_model_cls` / `resolve_model_cls`.

        ``try_arch`` maps a registered architecture name to a result, or to
        ``None`` on failure. The first successful ``(result, arch)`` pair
        from the following steps is returned:

        1. ``model_impl == "transformers"``: the Transformers backend.
           ``model_impl == "terratorch"``: the ``"Terratorch"`` entry.
        2. If no candidate is registered, ``model_impl`` is ``"auto"`` and
           ``convert_type`` is ``"none"``: the Transformers backend.
        3. Each candidate in order, resolved via `_normalize_arch`. The
           returned name is the candidate itself, not the normalized one.
        4. If no candidate is registered and ``model_impl`` is ``"auto"``:
           the Transformers backend.

        Raises:
            ValueError: If ``architectures`` is empty, or (via
                `_raise_for_unsupported`) if every step fails.
        """
        if isinstance(architectures, str):
            architectures = [architectures]

        if not architectures:
            raise ValueError("No model architectures are specified")

        def try_transformers() -> "tuple[_T, str] | None":
            arch = self._try_resolve_transformers(architectures[0], model_config)
            if arch is not None:
                result = try_arch(arch)
                if result is not None:
                    return (result, arch)
            return None

        def none_registered() -> bool:
            return all(arch not in self.models for arch in architectures)

        # Require transformers impl
        if model_config.model_impl == "transformers":
            if (found := try_transformers()) is not None:
                return found
        elif model_config.model_impl == "terratorch":
            result = try_arch("Terratorch")
            if result is not None:
                return (result, "Terratorch")

        # Fallback to transformers impl (after resolving convert_type)
        if (
            none_registered()
            and model_config.model_impl == "auto"
            and getattr(model_config, "convert_type", "none") == "none"
        ):
            if (found := try_transformers()) is not None:
                return found

        for arch in architectures:
            normalized_arch = self._normalize_arch(arch, model_config)
            result = try_arch(normalized_arch)
            if result is not None:
                return (result, arch)

        # Fallback to transformers impl (before resolving runner_type)
        if none_registered() and model_config.model_impl == "auto":
            if (found := try_transformers()) is not None:
                return found

        return self._raise_for_unsupported(architectures)

    def inspect_model_cls(
        self,
        architectures: str | list[str],
        model_config: ModelConfig,
    ) -> tuple[_ModelInfo, str]:
        """Return ``(model_info, arch)`` for the first resolvable architecture.

        Raises ``ValueError`` if none of ``architectures`` can be inspected.
        """
        return self._resolve_arch(
            architectures, model_config, self._try_inspect_model_cls
        )

    def resolve_model_cls(
        self,
        architectures: str | list[str],
        model_config: ModelConfig,
    ) -> tuple[type[nn.Module], str]:
        """Return ``(model_cls, arch)`` for the first loadable architecture.

        Raises ``ValueError`` if none of ``architectures`` can be loaded.
        """
        return self._resolve_arch(
            architectures, model_config, self._try_load_model_cls
        )

    def is_text_generation_model(
        self,
        architectures: str | list[str],
        model_config: ModelConfig,
    ) -> bool:
        """Whether the resolved model is a text generation model."""
        model_info, _ = self.inspect_model_cls(architectures, model_config)
        return model_info.is_text_generation_model

    def is_pooling_model(
        self,
        architectures: str | list[str],
        model_config: ModelConfig,
    ) -> bool:
        """Whether the resolved model is a pooling model."""
        model_info, _ = self.inspect_model_cls(architectures, model_config)
        return model_info.is_pooling_model

    def is_cross_encoder_model(
        self,
        architectures: str | list[str],
        model_config: ModelConfig,
    ) -> bool:
        """Whether the resolved model supports cross-encoding."""
        model_info, _ = self.inspect_model_cls(architectures, model_config)
        return model_info.supports_cross_encoding

    def is_multimodal_model(
        self,
        architectures: str | list[str],
        model_config: ModelConfig,
    ) -> bool:
        """Whether the resolved model supports multimodal inputs."""
        model_info, _ = self.inspect_model_cls(architectures, model_config)
        return model_info.supports_multimodal

    def is_multimodal_raw_input_only_model(
        self,
        architectures: str | list[str],
        model_config: ModelConfig,
    ) -> bool:
        """Whether the resolved model reports ``supports_multimodal_raw_input_only``."""
        model_info, _ = self.inspect_model_cls(architectures, model_config)
        return model_info.supports_multimodal_raw_input_only

    def is_pp_supported_model(
        self,
        architectures: str | list[str],
        model_config: ModelConfig,
    ) -> bool:
        """Whether the resolved model supports pipeline parallelism."""
        model_info, _ = self.inspect_model_cls(architectures, model_config)
        return model_info.supports_pp

    def model_has_inner_state(
        self,
        architectures: str | list[str],
        model_config: ModelConfig,
    ) -> bool:
        """Whether the resolved model reports ``has_inner_state``."""
        model_info, _ = self.inspect_model_cls(architectures, model_config)
        return model_info.has_inner_state

    def is_attention_free_model(
        self,
        architectures: str | list[str],
        model_config: ModelConfig,
    ) -> bool:
        """Whether the resolved model reports ``is_attention_free``."""
        model_info, _ = self.inspect_model_cls(architectures, model_config)
        return model_info.is_attention_free

    def is_hybrid_model(
        self,
        architectures: str | list[str],
        model_config: ModelConfig,
    ) -> bool:
        """Whether the resolved model reports ``is_hybrid``."""
        model_info, _ = self.inspect_model_cls(architectures, model_config)
        return model_info.is_hybrid

    def is_noops_model(
        self,
        architectures: str | list[str],
        model_config: ModelConfig,
    ) -> bool:
        """Whether the resolved model reports ``has_noops``."""
        model_info, _ = self.inspect_model_cls(architectures, model_config)
        return model_info.has_noops

    def is_transcription_model(
        self,
        architectures: str | list[str],
        model_config: ModelConfig,
    ) -> bool:
        """Whether the resolved model supports transcription."""
        model_info, _ = self.inspect_model_cls(architectures, model_config)
        return model_info.supports_transcription

    def is_transcription_only_model(
        self,
        architectures: str | list[str],
        model_config: ModelConfig,
    ) -> bool:
        """Whether the resolved model reports ``supports_transcription_only``."""
        model_info, _ = self.inspect_model_cls(architectures, model_config)
        return model_info.supports_transcription_only

1096

1097
1098
1099
1100
1101
1102
1103
1104
1105
# Global registry, pre-populated with every built-in architecture listed in
# ``_VLLM_MODELS``. Each entry is registered lazily, so its module under
# ``vllm.model_executor.models`` is only imported when the class is needed.
ModelRegistry = _ModelRegistry(
    models={
        arch_name: _LazyRegisteredModel(
            module_name=f"vllm.model_executor.models.{relative_module}",
            class_name=class_name,
        )
        for arch_name, (relative_module, class_name) in _VLLM_MODELS.items()
    }
)
1106
1107
1108
1109
1110

_T = TypeVar("_T")


def _run_in_subprocess(fn: Callable[[], _T]) -> _T:
    """Run ``fn`` in a separate Python process and return its result.

    How it works:

    - ``fn`` and a path to an output file are serialized with
      ``cloudpickle``, so lambdas and closures work.
    - The pair is written to the stdin of ``_SUBPROCESS_COMMAND``,
      presumably this module's ``_run`` entry point.
    - The child is expected to pickle ``fn()``'s result into that file,
      inside a temporary directory. The result is unpickled here.

    Raises:
        RuntimeError: If the subprocess exits with a non-zero status. The
            message includes the child's stderr, and the underlying
            ``CalledProcessError`` is chained.
    """
    # NOTE: We use a temporary directory instead of a temporary file to avoid
    # issues like https://stackoverflow.com/questions/23212435/permission-denied-to-write-to-my-temporary-file
    with tempfile.TemporaryDirectory() as tempdir:
        output_filepath = os.path.join(tempdir, "registry_output.tmp")

        # `cloudpickle` allows pickling lambda functions directly
        import cloudpickle

        input_bytes = cloudpickle.dumps((fn, output_filepath))

        # cannot use `sys.executable __file__` here because the script
        # contains relative imports
        returned = subprocess.run(
            _SUBPROCESS_COMMAND, input=input_bytes, capture_output=True
        )

        # check if the subprocess is successful
        try:
            returned.check_returncode()
        except subprocess.CalledProcessError as e:
            # Wrap the error to surface the child's stderr. Decode leniently:
            # a strict decode of non-UTF-8 output would raise
            # UnicodeDecodeError here and hide the real failure.
            stderr_text = returned.stderr.decode(errors="replace")
            raise RuntimeError(f"Error raised in subprocess:\n{stderr_text}") from e

        with open(output_filepath, "rb") as f:
            return pickle.load(f)


def _run() -> None:
    """Subprocess side of ``_run_in_subprocess``.

    Steps:

    1. Load vLLM's general plugins.
    2. Read a pickled ``(fn, output_file)`` pair from stdin.
    3. Call ``fn()``.
    4. Pickle the returned value into ``output_file``.
    """
    from vllm.plugins import load_general_plugins

    # Plugins are loaded before the payload is unpickled or run.
    load_general_plugins()

    payload = sys.stdin.buffer.read()
    fn, output_file = pickle.loads(payload)

    # Compute first, so output_file is only created once fn() has succeeded.
    outcome = fn()

    with open(output_file, "wb") as sink:
        pickle.dump(outcome, sink)
1152
1153
1154


if __name__ == "__main__":
    # Script entry point. It is presumably what ``_SUBPROCESS_COMMAND``
    # launches; ``_run`` reads its work item from stdin.
    _run()