registry.py 44.7 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
4
5
6
"""
Whenever you add an architecture to this file, please also update
`tests/models/registry.py` with example HuggingFace models for it.
"""
7

8
import hashlib
9
import importlib
10
import json
11
import os
12
import pickle
13
14
import subprocess
import sys
15
import tempfile
16
from abc import ABC, abstractmethod
17
from collections.abc import Callable, Set
18
from dataclasses import asdict, dataclass, field
19
from functools import lru_cache
20
from pathlib import Path
21
from typing import TypeVar
22
23

import torch.nn as nn
24
import transformers
25

26
from vllm import envs
27
28
29
30
31
from vllm.config import (
    ModelConfig,
    iter_architecture_defaults,
    try_match_architecture_defaults,
)
32
from vllm.logger import init_logger
33
from vllm.logging_utils import logtime
34
35
36
37
38
39
40
41
from vllm.transformers_utils.dynamic_module import try_get_class_from_dynamic_module

from .interfaces import (
    has_inner_state,
    has_noops,
    is_attention_free,
    is_hybrid,
    supports_cross_encoding,
42
    supports_mamba_prefix_caching,
43
44
45
46
47
48
49
50
51
52
53
    supports_multimodal,
    supports_multimodal_encoder_tp_data,
    supports_multimodal_raw_input_only,
    supports_pp,
    supports_transcription,
)
from .interfaces_base import (
    get_default_pooling_type,
    is_pooling_model,
    is_text_generation_model,
)
54
55
56

# Module-level logger used by the registry helpers in this file.
logger = init_logger(__name__)

57
58
_TEXT_GENERATION_MODELS = {
    # [Decoder-only]
59
    "AfmoeForCausalLM": ("afmoe", "AfmoeForCausalLM"),
60
    "ApertusForCausalLM": ("apertus", "ApertusForCausalLM"),
61
62
    "AquilaModel": ("llama", "LlamaForCausalLM"),
    "AquilaForCausalLM": ("llama", "LlamaForCausalLM"),  # AquilaChat2
Raghav Ravishankar's avatar
Raghav Ravishankar committed
63
    "ArceeForCausalLM": ("arcee", "ArceeForCausalLM"),
64
    "ArcticForCausalLM": ("arctic", "ArcticForCausalLM"),
65
66
67
68
    # baichuan-7b, upper case 'C' in the class name
    "BaiChuanForCausalLM": ("baichuan", "BaiChuanForCausalLM"),
    # baichuan-13b, lower case 'c' in the class name
    "BaichuanForCausalLM": ("baichuan", "BaichuanForCausalLM"),
69
    "BailingMoeForCausalLM": ("bailing_moe", "BailingMoeForCausalLM"),
ant-yy's avatar
ant-yy committed
70
    "BailingMoeV2ForCausalLM": ("bailing_moe", "BailingMoeV2ForCausalLM"),
Yu Chin Fabian Lim's avatar
Yu Chin Fabian Lim committed
71
    "BambaForCausalLM": ("bamba", "BambaForCausalLM"),
72
    "BloomForCausalLM": ("bloom", "BloomForCausalLM"),
73
    "ChatGLMModel": ("chatglm", "ChatGLMForCausalLM"),
74
    "ChatGLMForConditionalGeneration": ("chatglm", "ChatGLMForCausalLM"),
75
    "CohereForCausalLM": ("commandr", "CohereForCausalLM"),
76
    "Cohere2ForCausalLM": ("commandr", "CohereForCausalLM"),
77
    "CwmForCausalLM": ("llama", "LlamaForCausalLM"),
78
    "DbrxForCausalLM": ("dbrx", "DbrxForCausalLM"),
79
    "DeciLMForCausalLM": ("nemotron_nas", "DeciLMForCausalLM"),
80
    "DeepseekForCausalLM": ("deepseek_v2", "DeepseekForCausalLM"),
81
    "DeepseekV2ForCausalLM": ("deepseek_v2", "DeepseekV2ForCausalLM"),
82
    "DeepseekV3ForCausalLM": ("deepseek_v2", "DeepseekV3ForCausalLM"),
83
    "DeepseekV32ForCausalLM": ("deepseek_v2", "DeepseekV3ForCausalLM"),
84
    "Dots1ForCausalLM": ("dots1", "Dots1ForCausalLM"),
85
    "Ernie4_5ForCausalLM": ("ernie45", "Ernie4_5ForCausalLM"),
86
    "Ernie4_5_MoeForCausalLM": ("ernie45_moe", "Ernie4_5_MoeForCausalLM"),
87
    "ExaoneForCausalLM": ("exaone", "ExaoneForCausalLM"),
88
    "Exaone4ForCausalLM": ("exaone4", "Exaone4ForCausalLM"),
89
    "Fairseq2LlamaForCausalLM": ("fairseq2_llama", "Fairseq2LlamaForCausalLM"),
90
91
92
    "FalconForCausalLM": ("falcon", "FalconForCausalLM"),
    "FalconMambaForCausalLM": ("mamba", "MambaForCausalLM"),
    "FalconH1ForCausalLM": ("falcon_h1", "FalconH1ForCausalLM"),
93
    "FlexOlmoForCausalLM": ("flex_olmo", "FlexOlmoForCausalLM"),
94
95
    "GemmaForCausalLM": ("gemma", "GemmaForCausalLM"),
    "Gemma2ForCausalLM": ("gemma2", "Gemma2ForCausalLM"),
96
    "Gemma3ForCausalLM": ("gemma3", "Gemma3ForCausalLM"),
Nicolò Lucchesi's avatar
Nicolò Lucchesi committed
97
    "Gemma3nForCausalLM": ("gemma3n", "Gemma3nForCausalLM"),
98
    "Qwen3NextForCausalLM": ("qwen3_next", "Qwen3NextForCausalLM"),
99
    "GlmForCausalLM": ("glm", "GlmForCausalLM"),
Yuxuan Zhang's avatar
Yuxuan Zhang committed
100
    "Glm4ForCausalLM": ("glm4", "Glm4ForCausalLM"),
Yuxuan Zhang's avatar
Yuxuan Zhang committed
101
    "Glm4MoeForCausalLM": ("glm4_moe", "Glm4MoeForCausalLM"),
102
    "GptOssForCausalLM": ("gpt_oss", "GptOssForCausalLM"),
103
104
105
106
107
108
    "GPT2LMHeadModel": ("gpt2", "GPT2LMHeadModel"),
    "GPTBigCodeForCausalLM": ("gpt_bigcode", "GPTBigCodeForCausalLM"),
    "GPTJForCausalLM": ("gpt_j", "GPTJForCausalLM"),
    "GPTNeoXForCausalLM": ("gpt_neox", "GPTNeoXForCausalLM"),
    "GraniteForCausalLM": ("granite", "GraniteForCausalLM"),
    "GraniteMoeForCausalLM": ("granitemoe", "GraniteMoeForCausalLM"),
109
110
    "GraniteMoeHybridForCausalLM": ("granitemoehybrid", "GraniteMoeHybridForCausalLM"),  # noqa: E501
    "GraniteMoeSharedForCausalLM": ("granitemoeshared", "GraniteMoeSharedForCausalLM"),  # noqa: E501
111
    "GritLM": ("gritlm", "GritLM"),
Michael Goin's avatar
Michael Goin committed
112
    "Grok1ModelForCausalLM": ("grok1", "Grok1ForCausalLM"),
113
114
    "HunYuanMoEV1ForCausalLM": ("hunyuan_v1", "HunYuanMoEV1ForCausalLM"),
    "HunYuanDenseV1ForCausalLM": ("hunyuan_v1", "HunYuanDenseV1ForCausalLM"),
115
    "HCXVisionForCausalLM": ("hyperclovax_vision", "HCXVisionForCausalLM"),
116
117
    "InternLMForCausalLM": ("llama", "LlamaForCausalLM"),
    "InternLM2ForCausalLM": ("internlm2", "InternLM2ForCausalLM"),
118
    "InternLM2VEForCausalLM": ("internlm2_ve", "InternLM2VEForCausalLM"),
119
    "InternLM3ForCausalLM": ("llama", "LlamaForCausalLM"),
120
121
    "JAISLMHeadModel": ("jais", "JAISLMHeadModel"),
    "JambaForCausalLM": ("jamba", "JambaForCausalLM"),
122
    "KimiLinearForCausalLM": ("kimi_linear", "KimiLinearForCausalLM"),  # noqa: E501
123
    "Lfm2ForCausalLM": ("lfm2", "Lfm2ForCausalLM"),
Paul Pak's avatar
Paul Pak committed
124
    "Lfm2MoeForCausalLM": ("lfm2_moe", "Lfm2MoeForCausalLM"),
125
    "LlamaForCausalLM": ("llama", "LlamaForCausalLM"),
126
    "Llama4ForCausalLM": ("llama4", "Llama4ForCausalLM"),
127
128
    # For decapoda-research/llama-*
    "LLaMAForCausalLM": ("llama", "LlamaForCausalLM"),
XuruiYang's avatar
XuruiYang committed
129
    "LongcatFlashForCausalLM": ("longcat_flash", "LongcatFlashForCausalLM"),
130
    "MambaForCausalLM": ("mamba", "MambaForCausalLM"),
131
    "Mamba2ForCausalLM": ("mamba2", "Mamba2ForCausalLM"),
132
133
    "MiniCPMForCausalLM": ("minicpm", "MiniCPMForCausalLM"),
    "MiniCPM3ForCausalLM": ("minicpm3", "MiniCPM3ForCausalLM"),
134
135
136
    "MiniMaxForCausalLM": ("minimax_text_01", "MiniMaxText01ForCausalLM"),
    "MiniMaxText01ForCausalLM": ("minimax_text_01", "MiniMaxText01ForCausalLM"),
    "MiniMaxM1ForCausalLM": ("minimax_text_01", "MiniMaxText01ForCausalLM"),
137
    "MiniMaxM2ForCausalLM": ("minimax_m2", "MiniMaxM2ForCausalLM"),
138
139
140
141
142
    "MistralForCausalLM": ("llama", "LlamaForCausalLM"),
    "MixtralForCausalLM": ("mixtral", "MixtralForCausalLM"),
    # transformers's mpt class has lower case
    "MptForCausalLM": ("mpt", "MPTForCausalLM"),
    "MPTForCausalLM": ("mpt", "MPTForCausalLM"),
143
    "MiMoForCausalLM": ("mimo", "MiMoForCausalLM"),
144
    "NemotronForCausalLM": ("nemotron", "NemotronForCausalLM"),
Luis Vega's avatar
Luis Vega committed
145
    "NemotronHForCausalLM": ("nemotron_h", "NemotronHForCausalLM"),
146
    "OlmoForCausalLM": ("olmo", "OlmoForCausalLM"),
147
    "Olmo2ForCausalLM": ("olmo2", "Olmo2ForCausalLM"),
148
    "Olmo3ForCausalLM": ("olmo2", "Olmo2ForCausalLM"),
149
150
151
    "OlmoeForCausalLM": ("olmoe", "OlmoeForCausalLM"),
    "OPTForCausalLM": ("opt", "OPTForCausalLM"),
    "OrionForCausalLM": ("orion", "OrionForCausalLM"),
152
    "OuroForCausalLM": ("ouro", "OuroForCausalLM"),
153
154
    "PanguEmbeddedForCausalLM": ("openpangu", "PanguEmbeddedForCausalLM"),
    "PanguUltraMoEForCausalLM": ("openpangu", "PanguUltraMoEForCausalLM"),
155
156
157
158
    "PersimmonForCausalLM": ("persimmon", "PersimmonForCausalLM"),
    "PhiForCausalLM": ("phi", "PhiForCausalLM"),
    "Phi3ForCausalLM": ("phi3", "Phi3ForCausalLM"),
    "PhiMoEForCausalLM": ("phimoe", "PhiMoEForCausalLM"),
Shinichi Hemmi's avatar
Shinichi Hemmi committed
159
    "Plamo2ForCausalLM": ("plamo2", "Plamo2ForCausalLM"),
160
    "Plamo3ForCausalLM": ("plamo3", "Plamo3ForCausalLM"),
161
    "QWenLMHeadModel": ("qwen", "QWenLMHeadModel"),
162
163
    "Qwen2ForCausalLM": ("qwen2", "Qwen2ForCausalLM"),
    "Qwen2MoeForCausalLM": ("qwen2_moe", "Qwen2MoeForCausalLM"),
164
165
    "Qwen3ForCausalLM": ("qwen3", "Qwen3ForCausalLM"),
    "Qwen3MoeForCausalLM": ("qwen3_moe", "Qwen3MoeForCausalLM"),
166
    "RWForCausalLM": ("falcon", "FalconForCausalLM"),
167
    "SeedOssForCausalLM": ("seed_oss", "SeedOssForCausalLM"),
Song's avatar
Song committed
168
    "Step3TextForCausalLM": ("step3_text", "Step3TextForCausalLM"),
169
170
171
172
    "StableLMEpochForCausalLM": ("stablelm", "StablelmForCausalLM"),
    "StableLmForCausalLM": ("stablelm", "StablelmForCausalLM"),
    "Starcoder2ForCausalLM": ("starcoder2", "Starcoder2ForCausalLM"),
    "SolarForCausalLM": ("solar", "SolarForCausalLM"),
173
    "TeleChat2ForCausalLM": ("telechat2", "TeleChat2ForCausalLM"),
174
    "TeleFLMForCausalLM": ("teleflm", "TeleFLMForCausalLM"),
175
    "XverseForCausalLM": ("llama", "LlamaForCausalLM"),
176
    "Zamba2ForCausalLM": ("zamba2", "Zamba2ForCausalLM"),
177
178
179
}

# Embedding / pooling architectures.
# Maps architecture name -> (model module name, model class name). Some keys
# intentionally also appear in _TEXT_GENERATION_MODELS.
_EMBEDDING_MODELS = {
    # [Text-only]
    "BertModel": ("bert", "BertEmbeddingModel"),
    "BertSpladeSparseEmbeddingModel": ("bert", "BertSpladeSparseEmbeddingModel"),
    "DeciLMForCausalLM": ("nemotron_nas", "DeciLMForCausalLM"),
    "Gemma2Model": ("gemma2", "Gemma2ForCausalLM"),
    "Gemma3TextModel": ("gemma3", "Gemma3Model"),
    "GlmForCausalLM": ("glm", "GlmForCausalLM"),
    "GPT2ForSequenceClassification": ("gpt2", "GPT2ForSequenceClassification"),
    "GritLM": ("gritlm", "GritLM"),
    "GteModel": ("bert_with_rope", "SnowflakeGteNewModel"),
    "GteNewModel": ("bert_with_rope", "GteNewModel"),
    "InternLM2ForRewardModel": ("internlm2", "InternLM2ForRewardModel"),
    "JambaForSequenceClassification": ("jamba", "JambaForSequenceClassification"),  # noqa: E501
    "LlamaModel": ("llama", "LlamaForCausalLM"),
    **{
        # Multiple models share the same architecture, so we include them all
        k: (mod, arch)
        for k, (mod, arch) in _TEXT_GENERATION_MODELS.items()
        if arch == "LlamaForCausalLM"
    },
    "MistralModel": ("llama", "LlamaForCausalLM"),
    "ModernBertModel": ("modernbert", "ModernBertModel"),
    "NomicBertModel": ("bert_with_rope", "NomicBertModel"),
    "Phi3ForCausalLM": ("phi3", "Phi3ForCausalLM"),
    "Qwen2Model": ("qwen2", "Qwen2ForCausalLM"),
    "Qwen2ForCausalLM": ("qwen2", "Qwen2ForCausalLM"),
    "Qwen2ForRewardModel": ("qwen2_rm", "Qwen2ForRewardModel"),
    "Qwen2ForProcessRewardModel": ("qwen2_rm", "Qwen2ForProcessRewardModel"),
    "RobertaForMaskedLM": ("roberta", "RobertaEmbeddingModel"),
    "RobertaModel": ("roberta", "RobertaEmbeddingModel"),
    "TeleChat2ForCausalLM": ("telechat2", "TeleChat2ForCausalLM"),
    "XLMRobertaModel": ("roberta", "RobertaEmbeddingModel"),
    # [Multimodal]
    "CLIPModel": ("clip", "CLIPEmbeddingModel"),
    "LlavaNextForConditionalGeneration": (
        "llava_next",
        "LlavaNextForConditionalGeneration",
    ),
    "Phi3VForCausalLM": ("phi3v", "Phi3VForCausalLM"),
    "Qwen2VLForConditionalGeneration": ("qwen2_vl", "Qwen2VLForConditionalGeneration"),  # noqa: E501
    "SiglipModel": ("siglip", "SiglipEmbeddingModel"),
    # Terratorch models technically operate on images for both input and
    # output. They are listed here because they piggy-back on the embedding
    # models for the time being.
    "PrithviGeoSpatialMAE": ("terratorch", "Terratorch"),
    "Terratorch": ("terratorch", "Terratorch"),
}

228
229
_CROSS_ENCODER_MODELS = {
    "BertForSequenceClassification": ("bert", "BertForSequenceClassification"),
230
    "BertForTokenClassification": ("bert", "BertForTokenClassification"),
231
232
233
234
235
236
237
238
    "GteNewForSequenceClassification": (
        "bert_with_rope",
        "GteNewForSequenceClassification",
    ),
    "ModernBertForSequenceClassification": (
        "modernbert",
        "ModernBertForSequenceClassification",
    ),
239
240
241
242
    "ModernBertForTokenClassification": (
        "modernbert",
        "ModernBertForTokenClassification",
    ),
243
244
245
246
247
    "RobertaForSequenceClassification": ("roberta", "RobertaForSequenceClassification"),
    "XLMRobertaForSequenceClassification": (
        "roberta",
        "RobertaForSequenceClassification",
    ),
248
    # [Auto-converted (see adapters.py)]
249
    "JinaVLForRanking": ("jina_vl", "JinaVLForSequenceClassification"),  # noqa: E501,
250
251
}

252
_MULTIMODAL_MODELS = {
253
    # [Decoder-only]
254
    "AriaForConditionalGeneration": ("aria", "AriaForConditionalGeneration"),
255
256
257
    "AyaVisionForConditionalGeneration": (
        "aya_vision",
        "AyaVisionForConditionalGeneration",
258
    ),
259
    "BeeForConditionalGeneration": ("bee", "BeeForConditionalGeneration"),
260
    "Blip2ForConditionalGeneration": ("blip2", "Blip2ForConditionalGeneration"),
261
262
263
    "ChameleonForConditionalGeneration": (
        "chameleon",
        "ChameleonForConditionalGeneration",
264
    ),
265
266
267
    "Cohere2VisionForConditionalGeneration": (
        "cohere2_vision",
        "Cohere2VisionForConditionalGeneration",
268
    ),
269
    "DeepseekVLV2ForCausalLM": ("deepseek_vl2", "DeepseekVLV2ForCausalLM"),
270
    "DeepseekOCRForCausalLM": ("deepseek_ocr", "DeepseekOCRForCausalLM"),
Roger Wang's avatar
Roger Wang committed
271
    "DotsOCRForCausalLM": ("dots_ocr", "DotsOCRForCausalLM"),
272
273
274
    "Ernie4_5_VLMoeForConditionalGeneration": (
        "ernie45_vl",
        "Ernie4_5_VLMoeForConditionalGeneration",
275
    ),
276
    "FuyuForCausalLM": ("fuyu", "FuyuForCausalLM"),
277
    "Gemma3ForConditionalGeneration": ("gemma3_mm", "Gemma3ForConditionalGeneration"),  # noqa: E501
278
279
280
    "Gemma3nForConditionalGeneration": (
        "gemma3n_mm",
        "Gemma3nForConditionalGeneration",
281
    ),
282
    "GLM4VForCausalLM": ("glm4v", "GLM4VForCausalLM"),
283
    "Glm4vForConditionalGeneration": ("glm4_1v", "Glm4vForConditionalGeneration"),  # noqa: E501
Jee Jee Li's avatar
Jee Jee Li committed
284
    "Glm4vMoeForConditionalGeneration": ("glm4_1v", "Glm4vMoeForConditionalGeneration"),  # noqa: E501
285
286
287
    "GraniteSpeechForConditionalGeneration": (
        "granite_speech",
        "GraniteSpeechForConditionalGeneration",
288
    ),
289
    "H2OVLChatModel": ("h2ovl", "H2OVLChatModel"),
290
291
292
293
    "HunYuanVLForConditionalGeneration": (
        "hunyuan_vision",
        "HunYuanVLForConditionalGeneration",
    ),
294
    "InternVLChatModel": ("internvl", "InternVLChatModel"),
295
    "NemotronH_Nano_VL_V2": ("nano_nemotron_vl", "NemotronH_Nano_VL_V2"),
Zero's avatar
Zero committed
296
297
298
299
    "OpenCUAForConditionalGeneration": (
        "opencua",
        "OpenCUAForConditionalGeneration",
    ),
300
301
302
    "InternS1ForConditionalGeneration": (
        "interns1",
        "InternS1ForConditionalGeneration",
303
    ),
304
305
306
    "InternVLForConditionalGeneration": (
        "interns1",
        "InternS1ForConditionalGeneration",
307
    ),
308
309
310
311
312
    "Idefics3ForConditionalGeneration": (
        "idefics3",
        "Idefics3ForConditionalGeneration",
    ),
    "SmolVLMForConditionalGeneration": ("smolvlm", "SmolVLMForConditionalGeneration"),  # noqa: E501
313
    "KeyeForConditionalGeneration": ("keye", "KeyeForConditionalGeneration"),
314
315
316
    "KeyeVL1_5ForConditionalGeneration": (
        "keye_vl1_5",
        "KeyeVL1_5ForConditionalGeneration",
317
    ),
318
    "RForConditionalGeneration": ("rvl", "RForConditionalGeneration"),
319
    "KimiVLForConditionalGeneration": ("kimi_vl", "KimiVLForConditionalGeneration"),  # noqa: E501
320
321
322
323
    "LightOnOCRForConditionalGeneration": (
        "lightonocr",
        "LightOnOCRForConditionalGeneration",
    ),
324
    "Llama_Nemotron_Nano_VL": ("nemotron_vl", "LlamaNemotronVLChatModel"),
325
    "Llama4ForConditionalGeneration": ("mllama4", "Llama4ForConditionalGeneration"),  # noqa: E501
326
    "LlavaForConditionalGeneration": ("llava", "LlavaForConditionalGeneration"),
327
328
329
    "LlavaNextForConditionalGeneration": (
        "llava_next",
        "LlavaNextForConditionalGeneration",
330
    ),
331
332
333
    "LlavaNextVideoForConditionalGeneration": (
        "llava_next_video",
        "LlavaNextVideoForConditionalGeneration",
334
    ),
335
336
337
    "LlavaOnevisionForConditionalGeneration": (
        "llava_onevision",
        "LlavaOnevisionForConditionalGeneration",
338
    ),
339
    "MantisForConditionalGeneration": ("llava", "MantisForConditionalGeneration"),  # noqa: E501
340
    "MiDashengLMModel": ("midashenglm", "MiDashengLMModel"),
341
342
343
    "MiniMaxVL01ForConditionalGeneration": (
        "minimax_vl_01",
        "MiniMaxVL01ForConditionalGeneration",
344
    ),
345
    "MiniCPMO": ("minicpmo", "MiniCPMO"),
346
    "MiniCPMV": ("minicpmv", "MiniCPMV"),
347
348
349
    "Mistral3ForConditionalGeneration": (
        "mistral3",
        "Mistral3ForConditionalGeneration",
350
    ),
351
    "MolmoForCausalLM": ("molmo", "MolmoForCausalLM"),
352
    "NVLM_D": ("nvlm_d", "NVLM_D_Model"),
353
    "Ovis": ("ovis", "Ovis"),
354
    "Ovis2_5": ("ovis2_5", "Ovis2_5"),
355
356
357
358
    "PaddleOCRVLForConditionalGeneration": (
        "paddleocr_vl",
        "PaddleOCRVLForConditionalGeneration",
    ),
359
360
361
362
    "PaliGemmaForConditionalGeneration": (
        "paligemma",
        "PaliGemmaForConditionalGeneration",
    ),
363
    "Phi3VForCausalLM": ("phi3v", "Phi3VForCausalLM"),
364
365
    "Phi4MMForCausalLM": ("phi4mm", "Phi4MMForCausalLM"),
    "Phi4MultimodalForCausalLM": ("phi4_multimodal", "Phi4MultimodalForCausalLM"),  # noqa: E501
366
    "PixtralForConditionalGeneration": ("pixtral", "PixtralForConditionalGeneration"),  # noqa: E501
367
    "QwenVLForConditionalGeneration": ("qwen_vl", "QwenVLForConditionalGeneration"),  # noqa: E501
368
    "Qwen2VLForConditionalGeneration": ("qwen2_vl", "Qwen2VLForConditionalGeneration"),  # noqa: E501
369
370
371
    "Qwen2_5_VLForConditionalGeneration": (
        "qwen2_5_vl",
        "Qwen2_5_VLForConditionalGeneration",
372
    ),
373
374
375
    "Qwen2AudioForConditionalGeneration": (
        "qwen2_audio",
        "Qwen2AudioForConditionalGeneration",
376
    ),
377
378
379
    "Qwen2_5OmniModel": (
        "qwen2_5_omni_thinker",
        "Qwen2_5OmniThinkerForConditionalGeneration",
380
    ),
381
382
383
    "Qwen2_5OmniForConditionalGeneration": (
        "qwen2_5_omni_thinker",
        "Qwen2_5OmniThinkerForConditionalGeneration",
384
    ),
385
386
387
388
    "Qwen3OmniMoeForConditionalGeneration": (
        "qwen3_omni_moe_thinker",
        "Qwen3OmniMoeThinkerForConditionalGeneration",
    ),
389
    "Qwen3VLForConditionalGeneration": ("qwen3_vl", "Qwen3VLForConditionalGeneration"),  # noqa: E501
390
391
392
    "Qwen3VLMoeForConditionalGeneration": (
        "qwen3_vl_moe",
        "Qwen3VLMoeForConditionalGeneration",
393
    ),
394
    "SkyworkR1VChatModel": ("skyworkr1v", "SkyworkR1VChatModel"),
Song's avatar
Song committed
395
    "Step3VLForConditionalGeneration": ("step3_vl", "Step3VLForConditionalGeneration"),  # noqa: E501
汪志鹏's avatar
汪志鹏 committed
396
    "TarsierForConditionalGeneration": ("tarsier", "TarsierForConditionalGeneration"),  # noqa: E501
397
398
399
    "Tarsier2ForConditionalGeneration": (
        "qwen2_vl",
        "Tarsier2ForConditionalGeneration",
400
    ),
401
    "UltravoxModel": ("ultravox", "UltravoxModel"),
Patrick von Platen's avatar
Patrick von Platen committed
402
    "VoxtralForConditionalGeneration": ("voxtral", "VoxtralForConditionalGeneration"),  # noqa: E501
403
    # [Encoder-decoder]
404
    "WhisperForConditionalGeneration": ("whisper", "WhisperForConditionalGeneration"),  # noqa: E501
405
}
406
407

_SPECULATIVE_DECODING_MODELS = {
408
    "MiMoMTPModel": ("mimo_mtp", "MiMoMTP"),
409
    "EagleLlamaForCausalLM": ("llama_eagle", "EagleLlamaForCausalLM"),
zhiweiz's avatar
zhiweiz committed
410
    "EagleLlama4ForCausalLM": ("llama4_eagle", "EagleLlama4ForCausalLM"),
411
    "EagleMiniCPMForCausalLM": ("minicpm_eagle", "EagleMiniCPMForCausalLM"),
412
    "Eagle3LlamaForCausalLM": ("llama_eagle3", "Eagle3LlamaForCausalLM"),
413
    "LlamaForCausalLMEagle3": ("llama_eagle3", "Eagle3LlamaForCausalLM"),
414
    "Eagle3Qwen2_5vlForCausalLM": ("llama_eagle3", "Eagle3LlamaForCausalLM"),
415
    "EagleDeepSeekMTPModel": ("deepseek_eagle", "EagleDeepseekV3ForCausalLM"),
416
    "DeepSeekMTPModel": ("deepseek_mtp", "DeepSeekMTP"),
417
    "ErnieMTPModel": ("ernie_mtp", "ErnieMTP"),
XuruiYang's avatar
XuruiYang committed
418
    "LongCatFlashMTPModel": ("longcat_flash_mtp", "LongCatFlashMTP"),
Yuxuan Zhang's avatar
Yuxuan Zhang committed
419
    "Glm4MoeMTPModel": ("glm4_moe_mtp", "Glm4MoeMTP"),
420
    "MedusaModel": ("medusa", "Medusa"),
421
    "OpenPanguMTPModel": ("openpangu_mtp", "OpenPanguMTP"),
422
    "Qwen3NextMTP": ("qwen3_next_mtp", "Qwen3NextMTP"),
423
424
425
    # Temporarily disabled.
    # # TODO(woosuk): Re-enable this once the MLP Speculator is supported in V1.
    # "MLPSpeculatorPreTrainedModel": ("mlp_speculator", "MLPSpeculator"),
426
}
427

428
_TRANSFORMERS_SUPPORTED_MODELS = {
429
430
431
    # Text generation models
    "SmolLM3ForCausalLM": ("transformers", "TransformersForCausalLM"),
    # Multimodal models
432
433
434
435
    "Emu3ForConditionalGeneration": (
        "transformers",
        "TransformersMultiModalForCausalLM",
    ),
436
437
438
}

_TRANSFORMERS_BACKEND_MODELS = {
439
    # Text generation models
440
    "TransformersForCausalLM": ("transformers", "TransformersForCausalLM"),
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
    "TransformersMoEForCausalLM": ("transformers", "TransformersMoEForCausalLM"),
    # Multimodal models
    "TransformersMultiModalForCausalLM": (
        "transformers",
        "TransformersMultiModalForCausalLM",
    ),
    "TransformersMultiModalMoEForCausalLM": (
        "transformers",
        "TransformersMultiModalMoEForCausalLM",
    ),
    # Embedding models
    "TransformersEmbeddingModel": ("transformers", "TransformersEmbeddingModel"),
    "TransformersMoEEmbeddingModel": ("transformers", "TransformersMoEEmbeddingModel"),
    "TransformersMultiModalEmbeddingModel": (
        "transformers",
        "TransformersMultiModalEmbeddingModel",
    ),
    # Sequence classification models
459
    "TransformersForSequenceClassification": (
460
        "transformers",
461
        "TransformersForSequenceClassification",
462
    ),
463
    "TransformersMoEForSequenceClassification": (
464
        "transformers",
465
        "TransformersMoEForSequenceClassification",
466
    ),
467
468
469
    "TransformersMultiModalForSequenceClassification": (
        "transformers",
        "TransformersMultiModalForSequenceClassification",
470
    ),
471
}
472

473
# Union of every built-in registry table. On duplicate keys, entries from
# later tables in this list override earlier ones (standard dict-unpacking
# semantics).
_VLLM_MODELS = {
    **_TEXT_GENERATION_MODELS,
    **_EMBEDDING_MODELS,
    **_CROSS_ENCODER_MODELS,
    **_MULTIMODAL_MODELS,
    **_SPECULATIVE_DECODING_MODELS,
    **_TRANSFORMERS_SUPPORTED_MODELS,
    **_TRANSFORMERS_BACKEND_MODELS,
}

483
484
485
486
# This variable is used as the args for subprocess.run(). We
# can modify  this variable to alter the args if needed. e.g.
# when we use par format to pack things together, sys.executable
# might not be the target we want to run.
487
_SUBPROCESS_COMMAND = [sys.executable, "-m", "vllm.model_executor.models.registry"]
488

489
_PREVIOUSLY_SUPPORTED_MODELS = {
490
    "MotifForCausalLM": "0.10.2",
491
    "Phi3SmallForCausalLM": "0.9.2",
492
    "Phi4FlashForCausalLM": "0.10.2",
493
494
495
496
497
498
499
500
501
    # encoder-decoder models except whisper
    # have been removed for V0 deprecation.
    "BartModel": "0.10.2",
    "BartForConditionalGeneration": "0.10.2",
    "DonutForConditionalGeneration": "0.10.2",
    "Florence2ForConditionalGeneration": "0.10.2",
    "MBartForConditionalGeneration": "0.10.2",
    "MllamaForConditionalGeneration": "0.10.2",
}
502

503

504
505
@dataclass(frozen=True)
class _ModelInfo:
506
    architecture: str
507
    is_text_generation_model: bool
508
    is_pooling_model: bool
509
    default_pooling_type: str
510
    supports_cross_encoding: bool
511
    supports_multimodal: bool
512
    supports_multimodal_raw_input_only: bool
513
    supports_multimodal_encoder_tp_data: bool
514
    supports_pp: bool
515
516
    has_inner_state: bool
    is_attention_free: bool
517
    is_hybrid: bool
518
    has_noops: bool
519
    supports_mamba_prefix_caching: bool
520
    supports_transcription: bool
521
    supports_transcription_only: bool
522
523

    @staticmethod
524
    def from_model_cls(model: type[nn.Module]) -> "_ModelInfo":
525
        return _ModelInfo(
526
            architecture=model.__name__,
527
            is_text_generation_model=is_text_generation_model(model),
528
            is_pooling_model=is_pooling_model(model),
529
            default_pooling_type=get_default_pooling_type(model),
530
            supports_cross_encoding=supports_cross_encoding(model),
531
            supports_multimodal=supports_multimodal(model),
532
533
534
535
536
537
            supports_multimodal_raw_input_only=supports_multimodal_raw_input_only(
                model
            ),
            supports_multimodal_encoder_tp_data=supports_multimodal_encoder_tp_data(
                model
            ),
538
            supports_pp=supports_pp(model),
539
540
            has_inner_state=has_inner_state(model),
            is_attention_free=is_attention_free(model),
541
            is_hybrid=is_hybrid(model),
542
            supports_mamba_prefix_caching=supports_mamba_prefix_caching(model),
543
            supports_transcription=supports_transcription(model),
544
545
546
            supports_transcription_only=(
                supports_transcription(model) and model.supports_transcription_only
            ),
547
            has_noops=has_noops(model),
548
        )
549
550


551
552
553
554
class _BaseRegisteredModel(ABC):
    """Abstract interface for an entry stored in the model registry."""

    @abstractmethod
    def inspect_model_cls(self) -> _ModelInfo:
        """Return the interface summary of the registered model class."""
        raise NotImplementedError

    @abstractmethod
    def load_model_cls(self) -> type[nn.Module]:
        """Return the registered model class itself."""
        raise NotImplementedError
559
560


561
562
563
564
565
566
567
@dataclass(frozen=True)
class _RegisteredModel(_BaseRegisteredModel):
    """
    Represents a model that has already been imported in the main process.

    Because the class is already available, its interface summary is computed
    eagerly (in this process) when the entry is created.
    """

    # Interface summary computed from ``model_cls`` at construction time.
    interfaces: _ModelInfo
    # The already-imported model class.
    model_cls: type[nn.Module]

    @staticmethod
    def from_model_cls(model_cls: type[nn.Module]) -> "_RegisteredModel":
        """Wrap ``model_cls``, inspecting its interfaces immediately."""
        return _RegisteredModel(
            interfaces=_ModelInfo.from_model_cls(model_cls),
            model_cls=model_cls,
        )

    def inspect_model_cls(self) -> _ModelInfo:
        """Return the interface summary captured at construction time."""
        return self.interfaces

    def load_model_cls(self) -> type[nn.Module]:
        """Return the wrapped model class; no import is performed."""
        return self.model_cls


@dataclass(frozen=True)
class _LazyRegisteredModel(_BaseRegisteredModel):
    """
    Represents a model that has not been imported in the main process.

    `load_model_cls` imports ``module_name`` on demand. `inspect_model_cls`
    avoids importing the model in this process. It either returns a cached
    `_ModelInfo`, validated against an MD5 hash of the in-tree source file, or
    computes the summary in a subprocess.
    """

    # Dotted module path passed to ``importlib.import_module``.
    module_name: str
    # Attribute name of the model class inside ``module_name``.
    class_name: str

    @staticmethod
    def _get_cache_dir() -> Path:
        """Directory holding the per-class model-info JSON cache files."""
        return Path(envs.VLLM_CACHE_ROOT) / "modelinfos"

    def _get_cache_filename(self) -> str:
        """Cache file name built from module and class names (dots -> dashes)."""
        cls_name = f"{self.module_name}-{self.class_name}".replace(".", "-")
        return f"{cls_name}.json"

    def _get_module_source_path(self) -> Path | None:
        """Return the source file of ``module_name`` if it is an in-tree module.

        Only modules that live directly in this package are resolved. Any
        other module returns None, even when its last dotted component matches
        a file in this directory. Without that check, e.g. ``my_plugin.llama``
        would be hashed against this package's ``llama.py``. Its cached info
        would then never be invalidated when the plugin itself changed.
        """
        package, _, basename = self.module_name.rpartition(".")
        if package != __package__:
            return None
        model_path = Path(__file__).parent / f"{basename}.py"
        return model_path if model_path.exists() else None

    def _load_modelinfo_from_cache(self, module_hash: str) -> _ModelInfo | None:
        """Return the cached `_ModelInfo` if it was stored for ``module_hash``.

        Returns None when the cache file is missing, was written for a
        different source hash, or cannot be read or parsed. All of these are
        logged at debug level and never raised.
        """
        try:
            try:
                modelinfo_path = self._get_cache_dir() / self._get_cache_filename()
                with open(modelinfo_path, encoding="utf-8") as file:
                    mi_dict = json.load(file)
            except FileNotFoundError:
                logger.debug(
                    "Cached model info file for class %s.%s not found",
                    self.module_name,
                    self.class_name,
                )
                return None

            if mi_dict["hash"] != module_hash:
                logger.debug(
                    "Cached model info file for class %s.%s is stale",
                    self.module_name,
                    self.class_name,
                )
                return None

            # file not changed, use cached _ModelInfo properties
            return _ModelInfo(**mi_dict["modelinfo"])
        except Exception:
            logger.debug(
                "Cached model info for class %s.%s error. ",
                self.module_name,
                self.class_name,
            )
            return None

    def _save_modelinfo_to_cache(self, mi: _ModelInfo, module_hash: str) -> None:
        """Write ``mi`` and ``module_hash`` to this class's JSON cache file.

        The cache is best-effort: any failure is logged and swallowed.
        """
        from vllm.model_executor.model_loader.weight_utils import atomic_writer

        try:
            modelinfo_dict = {
                "hash": module_hash,
                "modelinfo": asdict(mi),
            }
            cache_dir = self._get_cache_dir()
            cache_dir.mkdir(parents=True, exist_ok=True)
            modelinfo_path = cache_dir / self._get_cache_filename()
            with atomic_writer(modelinfo_path, encoding="utf-8") as f:
                json.dump(modelinfo_dict, f, indent=2)
        except Exception:
            logger.exception("Error saving model info cache.")

    @logtime(logger=logger, msg="Registry inspect model class")
    def inspect_model_cls(self) -> _ModelInfo:
        """Return the model's interface summary without importing it here.

        For in-tree modules, the MD5 of the module's source file validates an
        on-disk cache, and a cache hit is returned directly. Only that single
        file is hashed, so edits to modules it imports do not invalidate the
        cache. Otherwise the class is imported and inspected in a subprocess.
        The result is written back to the cache when a source hash exists.
        """
        model_path = self._get_module_source_path()
        module_hash = None

        if model_path is not None:
            with open(model_path, "rb") as f:
                module_hash = hashlib.md5(f.read(), usedforsecurity=False).hexdigest()

            mi = self._load_modelinfo_from_cache(module_hash)
            if mi is not None:
                logger.debug(
                    "Loaded model info for class %s.%s from cache",
                    self.module_name,
                    self.class_name,
                )
                return mi
            else:
                logger.debug(
                    "Cache model info for class %s.%s miss. Loading model instead.",
                    self.module_name,
                    self.class_name,
                )

        # Performed in another process to avoid initializing CUDA
        mi = _run_in_subprocess(
            lambda: _ModelInfo.from_model_cls(self.load_model_cls())
        )
        logger.debug(
            "Loaded model info for class %s.%s", self.module_name, self.class_name
        )

        # save cache file
        if module_hash is not None:
            self._save_modelinfo_to_cache(mi, module_hash)

        return mi

    def load_model_cls(self) -> type[nn.Module]:
        """Import ``module_name`` and return its ``class_name`` attribute."""
        mod = importlib.import_module(self.module_name)
        return getattr(mod, self.class_name)


@lru_cache(maxsize=128)
def _try_load_model_cls(
    model_arch: str,
    model: _BaseRegisteredModel,
) -> type[nn.Module] | None:
    """Load the model class registered for `model_arch`, or return None.

    `current_platform.verify_model_arch` is called outside the error
    handling, so anything it raises propagates to the caller. Exceptions
    raised while loading the class itself are logged with a traceback and
    turned into a None result. Results are memoized per
    `(model_arch, model)` pair.
    """
    from vllm.platforms import current_platform

    current_platform.verify_model_arch(model_arch)

    try:
        model_cls = model.load_model_cls()
    except Exception:
        logger.exception("Error in loading model architecture '%s'", model_arch)
        model_cls = None
    return model_cls


@lru_cache(maxsize=128)
def _try_inspect_model_cls(
    model_arch: str,
    model: _BaseRegisteredModel,
) -> _ModelInfo | None:
    """Return `model.inspect_model_cls()`, or None if inspection fails.

    Failures are logged with a traceback instead of being propagated.
    Results, including None, are memoized per `(model_arch, model)` pair.
    """
    try:
        info = model.inspect_model_cls()
    except Exception:
        logger.exception("Error in inspecting model architecture '%s'", model_arch)
        info = None
    return info


@dataclass
class _ModelRegistry:
    """Registry mapping architecture names to registered model entries.

    Resolves a list of candidate architectures to either the model class
    (`resolve_model_cls`) or its `_ModelInfo` (`inspect_model_cls`). When
    `model_config.model_impl` allows it, resolution falls back to the
    Transformers backend.
    """

    # Keyed by model_arch
    models: dict[str, _BaseRegisteredModel] = field(default_factory=dict)

    def get_supported_archs(self) -> Set[str]:
        """Return a live view of the registered architecture names."""
        return self.models.keys()

    def register_model(
        self,
        model_arch: str,
        model_cls: type[nn.Module] | str,
    ) -> None:
        """
        Register an external model to be used in vLLM.

        `model_cls` can be either:

        - A [`torch.nn.Module`][] class directly referencing the model.
        - A string in the format `<module>:<class>` which can be used to
          lazily import the model. This is useful to avoid initializing CUDA
          when importing the model and thus the related error
          `RuntimeError: Cannot re-initialize CUDA in forked subprocess`.

        Registering an already-registered architecture logs a warning and
        replaces the previous entry.

        Raises:
            TypeError: If `model_arch` is not a string, or `model_cls` is
                neither a string nor an `nn.Module` subclass.
            ValueError: If `model_cls` is a string that does not contain
                exactly one `:` separator.
        """
        if not isinstance(model_arch, str):
            msg = f"`model_arch` should be a string, not a {type(model_arch)}"
            raise TypeError(msg)

        if model_arch in self.models:
            logger.warning(
                "Model architecture %s is already registered, and will be "
                "overwritten by the new model class %s.",
                model_arch,
                model_cls,
            )

        if isinstance(model_cls, str):
            split_str = model_cls.split(":")
            if len(split_str) != 2:
                msg = "Expected a string in the format `<module>:<class>`"
                raise ValueError(msg)

            model = _LazyRegisteredModel(*split_str)
        elif isinstance(model_cls, type) and issubclass(model_cls, nn.Module):
            model = _RegisteredModel.from_model_cls(model_cls)
        else:
            # Report the type of the offending `model_cls` (previously this
            # reported `model_arch`, which is always a str at this point).
            msg = (
                "`model_cls` should be a string or PyTorch model class, "
                f"not a {type(model_cls)}"
            )
            raise TypeError(msg)

        self.models[model_arch] = model

    def _raise_for_unsupported(self, architectures: list[str]):
        """Raise a `ValueError` explaining why `architectures` did not resolve.

        This method never returns:

        - if any of the architectures is registered, loading/inspecting it
          must have failed (details are in the logs);
        - if one was supported by an earlier vLLM release, the error names
          the last version that supported it;
        - otherwise, the error lists all supported architectures.
        """
        all_supported_archs = self.get_supported_archs()

        if any(arch in all_supported_archs for arch in architectures):
            raise ValueError(
                f"Model architectures {architectures} failed "
                "to be inspected. Please check the logs for more details."
            )

        for arch in architectures:
            if arch in _PREVIOUSLY_SUPPORTED_MODELS:
                previous_version = _PREVIOUSLY_SUPPORTED_MODELS[arch]

                raise ValueError(
                    f"Model architecture {arch} was supported in vLLM until "
                    f"v{previous_version}, and is not supported anymore. "
                    "Please use an older version of vLLM if you want to "
                    "use this model architecture."
                )

        raise ValueError(
            f"Model architectures {architectures} are not supported for now. "
            f"Supported architectures: {all_supported_archs}"
        )

    def _try_load_model_cls(self, model_arch: str) -> type[nn.Module] | None:
        """Load the class for a registered `model_arch`; None if unknown/failed."""
        if model_arch not in self.models:
            return None

        return _try_load_model_cls(model_arch, self.models[model_arch])

    def _try_inspect_model_cls(self, model_arch: str) -> _ModelInfo | None:
        """Inspect a registered `model_arch`; None if unknown or it failed."""
        if model_arch not in self.models:
            return None

        return _try_inspect_model_cls(model_arch, self.models[model_arch])

    def _try_resolve_transformers(
        self,
        architecture: str,
        model_config: ModelConfig,
    ) -> str | None:
        """Return the architecture to use with the Transformers backend.

        `architecture` is returned unchanged if it is already one of
        `_TRANSFORMERS_BACKEND_MODELS`. Otherwise any custom classes listed in
        the config's `auto_map` are loaded, the model class is looked up (in
        `transformers`, then via `auto_map`), and, if it reports being
        backend-compatible, the result of
        `model_config._get_transformers_backend_cls()` is returned.

        Returns None when the model cannot use the backend and
        `model_config.model_impl` is not "transformers".

        Raises:
            ValueError: If `model_config.model_impl == "transformers"` and the
                model class cannot be found or is not backend-compatible.
        """
        if architecture in _TRANSFORMERS_BACKEND_MODELS:
            return architecture

        auto_map: dict[str, str] = (
            getattr(model_config.hf_config, "auto_map", None) or dict()
        )

        # Make sure that config class is always initialized before model class,
        # otherwise the model class won't be able to access the config class,
        # the expected auto_map should have correct order like:
        # "auto_map": {
        #     "AutoConfig": "<your-repo-name>--<config-name>",
        #     "AutoModel": "<your-repo-name>--<config-name>",
        #     "AutoModelFor<Task>": "<your-repo-name>--<config-name>",
        # },
        for prefix in ("AutoConfig", "AutoModel"):
            for name, module in auto_map.items():
                if name.startswith(prefix):
                    try_get_class_from_dynamic_module(
                        module,
                        model_config.model,
                        revision=model_config.revision,
                        warn_on_fail=False,
                    )

        model_module = getattr(transformers, architecture, None)

        if model_module is None:
            for name, module in auto_map.items():
                if name.startswith("AutoModel"):
                    model_module = try_get_class_from_dynamic_module(
                        module,
                        model_config.model,
                        revision=model_config.revision,
                        warn_on_fail=True,
                    )
                    if model_module is not None:
                        break
            else:
                # Reached only if no `AutoModel*` entry produced a class.
                if model_config.model_impl != "transformers":
                    return None

                raise ValueError(
                    f"Cannot find model module. {architecture!r} is not a "
                    "registered model in the Transformers library (only "
                    "relevant if the model is meant to be in Transformers) "
                    "and 'AutoModel' is not present in the model config's "
                    "'auto_map' (relevant if the model is custom)."
                )

        if not model_module.is_backend_compatible():
            if model_config.model_impl != "transformers":
                return None

            raise ValueError(
                f"The Transformers implementation of {architecture!r} "
                "is not compatible with vLLM."
            )

        return model_config._get_transformers_backend_cls()

    def _normalize_arch(
        self,
        architecture: str,
        model_config: ModelConfig,
    ) -> str:
        """Map `architecture` to a registered base architecture if possible.

        Registered architectures are returned unchanged. Otherwise, if the
        name matches one of the architecture-default suffixes, the suffix is
        swapped for each known default suffix in turn and the first
        registered result is returned. Falls back to `architecture`.
        """
        if architecture in self.models:
            return architecture

        # This may be called in order to resolve runner_type and convert_type
        # in the first place, in which case we consider the default match
        match = try_match_architecture_defaults(
            architecture,
            runner_type=getattr(model_config, "runner_type", None),
            convert_type=getattr(model_config, "convert_type", None),
        )
        if match:
            suffix, _ = match

            # Get the name of the base model to convert
            for repl_suffix, _ in iter_architecture_defaults():
                base_arch = architecture.replace(suffix, repl_suffix)
                if base_arch in self.models:
                    return base_arch

        return architecture

    def inspect_model_cls(
        self,
        architectures: str | list[str],
        model_config: ModelConfig,
    ) -> tuple[_ModelInfo, str]:
        """Return `(model_info, arch)` for the first resolvable architecture.

        Follows the same resolution order as `resolve_model_cls`.

        Raises:
            ValueError: If no architecture is given or none can be resolved.
        """
        if isinstance(architectures, str):
            architectures = [architectures]
        if not architectures:
            raise ValueError("No model architectures are specified")

        # Require transformers impl
        if model_config.model_impl == "transformers":
            arch = self._try_resolve_transformers(architectures[0], model_config)
            if arch is not None:
                model_info = self._try_inspect_model_cls(arch)
                if model_info is not None:
                    return (model_info, arch)
        elif model_config.model_impl == "terratorch":
            # Fall through on failure (as `resolve_model_cls` does) instead of
            # returning `(None, "Terratorch")`.
            model_info = self._try_inspect_model_cls("Terratorch")
            if model_info is not None:
                return (model_info, "Terratorch")

        # Fallback to transformers impl (after resolving convert_type)
        if (
            all(arch not in self.models for arch in architectures)
            and model_config.model_impl == "auto"
            and getattr(model_config, "convert_type", "none") == "none"
        ):
            arch = self._try_resolve_transformers(architectures[0], model_config)
            if arch is not None:
                model_info = self._try_inspect_model_cls(arch)
                if model_info is not None:
                    return (model_info, arch)

        for arch in architectures:
            normalized_arch = self._normalize_arch(arch, model_config)
            model_info = self._try_inspect_model_cls(normalized_arch)
            if model_info is not None:
                return (model_info, arch)

        # Fallback to transformers impl (before resolving runner_type)
        if (
            all(arch not in self.models for arch in architectures)
            and model_config.model_impl == "auto"
        ):
            arch = self._try_resolve_transformers(architectures[0], model_config)
            if arch is not None:
                model_info = self._try_inspect_model_cls(arch)
                if model_info is not None:
                    return (model_info, arch)

        return self._raise_for_unsupported(architectures)

    def resolve_model_cls(
        self,
        architectures: str | list[str],
        model_config: ModelConfig,
    ) -> tuple[type[nn.Module], str]:
        """Return `(model_cls, arch)` for the first loadable architecture.

        Order tried:
        1. explicit `model_impl` ("transformers" / "terratorch");
        2. Transformers fallback when nothing is registered and no conversion
           is requested;
        3. each architecture, normalized via `_normalize_arch`;
        4. Transformers fallback when nothing is registered.

        Raises:
            ValueError: If no architecture is given or none can be resolved.
        """
        if isinstance(architectures, str):
            architectures = [architectures]
        if not architectures:
            raise ValueError("No model architectures are specified")

        # Require transformers impl
        if model_config.model_impl == "transformers":
            arch = self._try_resolve_transformers(architectures[0], model_config)
            if arch is not None:
                model_cls = self._try_load_model_cls(arch)
                if model_cls is not None:
                    return (model_cls, arch)
        elif model_config.model_impl == "terratorch":
            arch = "Terratorch"
            model_cls = self._try_load_model_cls(arch)
            if model_cls is not None:
                return (model_cls, arch)

        # Fallback to transformers impl (after resolving convert_type)
        if (
            all(arch not in self.models for arch in architectures)
            and model_config.model_impl == "auto"
            and getattr(model_config, "convert_type", "none") == "none"
        ):
            arch = self._try_resolve_transformers(architectures[0], model_config)
            if arch is not None:
                model_cls = self._try_load_model_cls(arch)
                if model_cls is not None:
                    return (model_cls, arch)

        for arch in architectures:
            normalized_arch = self._normalize_arch(arch, model_config)
            model_cls = self._try_load_model_cls(normalized_arch)
            if model_cls is not None:
                return (model_cls, arch)

        # Fallback to transformers impl (before resolving runner_type)
        if (
            all(arch not in self.models for arch in architectures)
            and model_config.model_impl == "auto"
        ):
            arch = self._try_resolve_transformers(architectures[0], model_config)
            if arch is not None:
                model_cls = self._try_load_model_cls(arch)
                if model_cls is not None:
                    return (model_cls, arch)

        return self._raise_for_unsupported(architectures)

    def is_text_generation_model(
        self,
        architectures: str | list[str],
        model_config: ModelConfig,
    ) -> bool:
        """Whether the resolved model is a text generation model."""
        model_cls, _ = self.inspect_model_cls(architectures, model_config)
        return model_cls.is_text_generation_model

    def is_pooling_model(
        self,
        architectures: str | list[str],
        model_config: ModelConfig,
    ) -> bool:
        """Whether the resolved model is a pooling model."""
        model_cls, _ = self.inspect_model_cls(architectures, model_config)
        return model_cls.is_pooling_model

    def is_cross_encoder_model(
        self,
        architectures: str | list[str],
        model_config: ModelConfig,
    ) -> bool:
        """Whether the resolved model supports cross-encoding."""
        model_cls, _ = self.inspect_model_cls(architectures, model_config)
        return model_cls.supports_cross_encoding

    def is_multimodal_model(
        self,
        architectures: str | list[str],
        model_config: ModelConfig,
    ) -> bool:
        """Whether the resolved model supports multimodal inputs."""
        model_cls, _ = self.inspect_model_cls(architectures, model_config)
        return model_cls.supports_multimodal

    def is_multimodal_raw_input_only_model(
        self,
        architectures: str | list[str],
        model_config: ModelConfig,
    ) -> bool:
        """Whether the resolved model only supports raw multimodal input."""
        model_cls, _ = self.inspect_model_cls(architectures, model_config)
        return model_cls.supports_multimodal_raw_input_only

    def is_pp_supported_model(
        self,
        architectures: str | list[str],
        model_config: ModelConfig,
    ) -> bool:
        """Whether the resolved model supports pipeline parallelism."""
        model_cls, _ = self.inspect_model_cls(architectures, model_config)
        return model_cls.supports_pp

    def model_has_inner_state(
        self,
        architectures: str | list[str],
        model_config: ModelConfig,
    ) -> bool:
        """Whether the resolved model has inner state."""
        model_cls, _ = self.inspect_model_cls(architectures, model_config)
        return model_cls.has_inner_state

    def is_attention_free_model(
        self,
        architectures: str | list[str],
        model_config: ModelConfig,
    ) -> bool:
        """Whether the resolved model is attention-free."""
        model_cls, _ = self.inspect_model_cls(architectures, model_config)
        return model_cls.is_attention_free

    def is_hybrid_model(
        self,
        architectures: str | list[str],
        model_config: ModelConfig,
    ) -> bool:
        """Whether the resolved model is a hybrid model."""
        model_cls, _ = self.inspect_model_cls(architectures, model_config)
        return model_cls.is_hybrid

    def is_noops_model(
        self,
        architectures: str | list[str],
        model_config: ModelConfig,
    ) -> bool:
        """Whether the resolved model has no-op layers."""
        model_cls, _ = self.inspect_model_cls(architectures, model_config)
        return model_cls.has_noops

    def is_transcription_model(
        self,
        architectures: str | list[str],
        model_config: ModelConfig,
    ) -> bool:
        """Whether the resolved model supports transcription."""
        model_cls, _ = self.inspect_model_cls(architectures, model_config)
        return model_cls.supports_transcription

    def is_transcription_only_model(
        self,
        architectures: str | list[str],
        model_config: ModelConfig,
    ) -> bool:
        """Whether the resolved model supports transcription only."""
        model_cls, _ = self.inspect_model_cls(architectures, model_config)
        return model_cls.supports_transcription_only


# Global registry pre-populated from `_VLLM_MODELS`. Every entry is a
# `_LazyRegisteredModel`, so its module is only imported when the class is
# actually loaded.
ModelRegistry = _ModelRegistry(
    models={
        arch_name: _LazyRegisteredModel(
            module_name=f"vllm.model_executor.models.{relative_module}",
            class_name=class_name,
        )
        for arch_name, (relative_module, class_name) in _VLLM_MODELS.items()
    }
)

# Return type of the callable passed to `_run_in_subprocess`.
_T = TypeVar("_T")


def _run_in_subprocess(fn: Callable[[], _T]) -> _T:
    """Run `fn` in a separate Python process and return its result.

    `(fn, output_path)` is serialized with `cloudpickle` and sent to the
    child's stdin. The child (`_SUBPROCESS_COMMAND`) is expected to write the
    pickled result of `fn()` to `output_path`, which is read back here.

    Raises:
        RuntimeError: If the child exits with a non-zero status. The message
            includes the child's stderr.
    """
    # NOTE: We use a temporary directory instead of a temporary file to avoid
    # issues like https://stackoverflow.com/questions/23212435/permission-denied-to-write-to-my-temporary-file
    with tempfile.TemporaryDirectory() as tempdir:
        output_filepath = os.path.join(tempdir, "registry_output.tmp")

        # `cloudpickle` allows pickling lambda functions directly
        import cloudpickle

        input_bytes = cloudpickle.dumps((fn, output_filepath))

        # cannot use `sys.executable __file__` here because the script
        # contains relative imports
        returned = subprocess.run(
            _SUBPROCESS_COMMAND, input=input_bytes, capture_output=True
        )

        # check if the subprocess is successful
        try:
            returned.check_returncode()
        except subprocess.CalledProcessError as e:
            # Wrap the raised exception to provide more information. Decode
            # leniently so that non-UTF-8 bytes in stderr cannot replace the
            # real failure with a UnicodeDecodeError.
            stderr = returned.stderr.decode(errors="replace")
            raise RuntimeError(f"Error raised in subprocess:\n{stderr}") from e

        with open(output_filepath, "rb") as f:
            return pickle.load(f)


def _run() -> None:
    """Worker side of `_run_in_subprocess`.

    Loads the general plugins, reads a pickled `(fn, output_file)` pair from
    stdin, calls `fn()`, and writes the pickled result to `output_file`.
    """
    # Setup plugins before unpickling / running the callable.
    from vllm.plugins import load_general_plugins

    load_general_plugins()

    payload = sys.stdin.buffer.read()
    fn, output_file = pickle.loads(payload)

    # Compute the result before opening the output file, so a failing `fn`
    # leaves no output file behind.
    result = fn()
    serialized = pickle.dumps(result)

    with open(output_file, "wb") as out:
        out.write(serialized)


if __name__ == "__main__":
    # Executed as a script/module: act as the subprocess worker.
    _run()