registry.py 45.4 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
4
5
6
"""
Whenever you add an architecture to this page, please also update
`tests/models/registry.py` with example HuggingFace models for it.
"""
7

8
import importlib
9
import json
10
import os
11
import pickle
12
13
import subprocess
import sys
14
import tempfile
15
from abc import ABC, abstractmethod
16
from collections.abc import Callable, Set
17
from dataclasses import asdict, dataclass, field
18
from functools import lru_cache
19
from pathlib import Path
20
from typing import TYPE_CHECKING, Any, TypeVar
21
22

import torch.nn as nn
23
import transformers
24

25
from vllm import envs
26
27
28
29
30
from vllm.config import (
    ModelConfig,
    iter_architecture_defaults,
    try_match_architecture_defaults,
)
31
from vllm.logger import init_logger
32
from vllm.logging_utils import logtime
33
from vllm.transformers_utils.dynamic_module import try_get_class_from_dynamic_module
34
from vllm.utils.hashing import safe_hash
35

36
37
38
39
40
41
42
43
if TYPE_CHECKING:
    from vllm.config.model import AttnTypeStr
    from vllm.config.pooler import PoolingTypeStr
else:
    AttnTypeStr = Any
    PoolingTypeStr = Any


44
45
46
47
48
49
from .interfaces import (
    has_inner_state,
    has_noops,
    is_attention_free,
    is_hybrid,
    supports_cross_encoding,
50
    supports_mamba_prefix_caching,
51
52
53
54
55
56
57
    supports_multimodal,
    supports_multimodal_encoder_tp_data,
    supports_multimodal_raw_input_only,
    supports_pp,
    supports_transcription,
)
from .interfaces_base import (
58
    get_attn_type,
59
60
61
62
    get_default_pooling_type,
    is_pooling_model,
    is_text_generation_model,
)
63
64
65

logger = init_logger(__name__)

66
67
# Text-generation architectures.
# Maps an architecture name to a `(mod, arch)` pair -- presumably the model
# module name and the implementing class name inside that module (confirm
# against the code that builds the lazy registry entries).
_TEXT_GENERATION_MODELS = {
    # [Decoder-only]
    "AfmoeForCausalLM": ("afmoe", "AfmoeForCausalLM"),
    "ApertusForCausalLM": ("apertus", "ApertusForCausalLM"),
    "AquilaModel": ("llama", "LlamaForCausalLM"),
    "AquilaForCausalLM": ("llama", "LlamaForCausalLM"),  # AquilaChat2
    "ArceeForCausalLM": ("arcee", "ArceeForCausalLM"),
    "ArcticForCausalLM": ("arctic", "ArcticForCausalLM"),
    # baichuan-7b, upper case 'C' in the class name
    "BaiChuanForCausalLM": ("baichuan", "BaiChuanForCausalLM"),
    # baichuan-13b, lower case 'c' in the class name
    "BaichuanForCausalLM": ("baichuan", "BaichuanForCausalLM"),
    "BailingMoeForCausalLM": ("bailing_moe", "BailingMoeForCausalLM"),
    "BailingMoeV2ForCausalLM": ("bailing_moe", "BailingMoeV2ForCausalLM"),
    "BambaForCausalLM": ("bamba", "BambaForCausalLM"),
    "BloomForCausalLM": ("bloom", "BloomForCausalLM"),
    "ChatGLMModel": ("chatglm", "ChatGLMForCausalLM"),
    "ChatGLMForConditionalGeneration": ("chatglm", "ChatGLMForCausalLM"),
    "CohereForCausalLM": ("commandr", "CohereForCausalLM"),
    "Cohere2ForCausalLM": ("commandr", "CohereForCausalLM"),
    "CwmForCausalLM": ("llama", "LlamaForCausalLM"),
    "DbrxForCausalLM": ("dbrx", "DbrxForCausalLM"),
    "DeciLMForCausalLM": ("nemotron_nas", "DeciLMForCausalLM"),
    "DeepseekForCausalLM": ("deepseek_v2", "DeepseekForCausalLM"),
    "DeepseekV2ForCausalLM": ("deepseek_v2", "DeepseekV2ForCausalLM"),
    "DeepseekV3ForCausalLM": ("deepseek_v2", "DeepseekV3ForCausalLM"),
    "DeepseekV32ForCausalLM": ("deepseek_v2", "DeepseekV3ForCausalLM"),
    "Dots1ForCausalLM": ("dots1", "Dots1ForCausalLM"),
    "Ernie4_5ForCausalLM": ("ernie45", "Ernie4_5ForCausalLM"),
    "Ernie4_5_MoeForCausalLM": ("ernie45_moe", "Ernie4_5_MoeForCausalLM"),
    "ExaoneForCausalLM": ("exaone", "ExaoneForCausalLM"),
    "Exaone4ForCausalLM": ("exaone4", "Exaone4ForCausalLM"),
    "Fairseq2LlamaForCausalLM": ("fairseq2_llama", "Fairseq2LlamaForCausalLM"),
    "FalconForCausalLM": ("falcon", "FalconForCausalLM"),
    "FalconMambaForCausalLM": ("mamba", "MambaForCausalLM"),
    "FalconH1ForCausalLM": ("falcon_h1", "FalconH1ForCausalLM"),
    "FlexOlmoForCausalLM": ("flex_olmo", "FlexOlmoForCausalLM"),
    "GemmaForCausalLM": ("gemma", "GemmaForCausalLM"),
    "Gemma2ForCausalLM": ("gemma2", "Gemma2ForCausalLM"),
    "Gemma3ForCausalLM": ("gemma3", "Gemma3ForCausalLM"),
    "Gemma3nForCausalLM": ("gemma3n", "Gemma3nForCausalLM"),
    "Qwen3NextForCausalLM": ("qwen3_next", "Qwen3NextForCausalLM"),
    "GlmForCausalLM": ("glm", "GlmForCausalLM"),
    "Glm4ForCausalLM": ("glm4", "Glm4ForCausalLM"),
    "Glm4MoeForCausalLM": ("glm4_moe", "Glm4MoeForCausalLM"),
    "GptOssForCausalLM": ("gpt_oss", "GptOssForCausalLM"),
    "GPT2LMHeadModel": ("gpt2", "GPT2LMHeadModel"),
    "GPTBigCodeForCausalLM": ("gpt_bigcode", "GPTBigCodeForCausalLM"),
    "GPTJForCausalLM": ("gpt_j", "GPTJForCausalLM"),
    "GPTNeoXForCausalLM": ("gpt_neox", "GPTNeoXForCausalLM"),
    "GraniteForCausalLM": ("granite", "GraniteForCausalLM"),
    "GraniteMoeForCausalLM": ("granitemoe", "GraniteMoeForCausalLM"),
    "GraniteMoeHybridForCausalLM": ("granitemoehybrid", "GraniteMoeHybridForCausalLM"),  # noqa: E501
    "GraniteMoeSharedForCausalLM": ("granitemoeshared", "GraniteMoeSharedForCausalLM"),  # noqa: E501
    "GritLM": ("gritlm", "GritLM"),
    "Grok1ModelForCausalLM": ("grok1", "Grok1ForCausalLM"),
    "HunYuanMoEV1ForCausalLM": ("hunyuan_v1", "HunYuanMoEV1ForCausalLM"),
    "HunYuanDenseV1ForCausalLM": ("hunyuan_v1", "HunYuanDenseV1ForCausalLM"),
    "HCXVisionForCausalLM": ("hyperclovax_vision", "HCXVisionForCausalLM"),
    "InternLMForCausalLM": ("llama", "LlamaForCausalLM"),
    "InternLM2ForCausalLM": ("internlm2", "InternLM2ForCausalLM"),
    "InternLM2VEForCausalLM": ("internlm2_ve", "InternLM2VEForCausalLM"),
    "InternLM3ForCausalLM": ("llama", "LlamaForCausalLM"),
    "JAISLMHeadModel": ("jais", "JAISLMHeadModel"),
    "JambaForCausalLM": ("jamba", "JambaForCausalLM"),
    "KimiLinearForCausalLM": ("kimi_linear", "KimiLinearForCausalLM"),  # noqa: E501
    "Lfm2ForCausalLM": ("lfm2", "Lfm2ForCausalLM"),
    "Lfm2MoeForCausalLM": ("lfm2_moe", "Lfm2MoeForCausalLM"),
    "LlamaForCausalLM": ("llama", "LlamaForCausalLM"),
    "Llama4ForCausalLM": ("llama4", "Llama4ForCausalLM"),
    # For decapoda-research/llama-*
    "LLaMAForCausalLM": ("llama", "LlamaForCausalLM"),
    "LongcatFlashForCausalLM": ("longcat_flash", "LongcatFlashForCausalLM"),
    "MambaForCausalLM": ("mamba", "MambaForCausalLM"),
    "Mamba2ForCausalLM": ("mamba2", "Mamba2ForCausalLM"),
    "MiniCPMForCausalLM": ("minicpm", "MiniCPMForCausalLM"),
    "MiniCPM3ForCausalLM": ("minicpm3", "MiniCPM3ForCausalLM"),
    "MiniMaxForCausalLM": ("minimax_text_01", "MiniMaxText01ForCausalLM"),
    "MiniMaxText01ForCausalLM": ("minimax_text_01", "MiniMaxText01ForCausalLM"),
    "MiniMaxM1ForCausalLM": ("minimax_text_01", "MiniMaxText01ForCausalLM"),
    "MiniMaxM2ForCausalLM": ("minimax_m2", "MiniMaxM2ForCausalLM"),
    "MistralForCausalLM": ("llama", "LlamaForCausalLM"),
    "MistralLarge3ForCausalLM": ("mistral_large_3", "MistralLarge3ForCausalLM"),
    "MixtralForCausalLM": ("mixtral", "MixtralForCausalLM"),
    # transformers's mpt class has lower case
    "MptForCausalLM": ("mpt", "MPTForCausalLM"),
    "MPTForCausalLM": ("mpt", "MPTForCausalLM"),
    "MiMoForCausalLM": ("mimo", "MiMoForCausalLM"),
    "NemotronForCausalLM": ("nemotron", "NemotronForCausalLM"),
    "NemotronHForCausalLM": ("nemotron_h", "NemotronHForCausalLM"),
    "OlmoForCausalLM": ("olmo", "OlmoForCausalLM"),
    "Olmo2ForCausalLM": ("olmo2", "Olmo2ForCausalLM"),
    "Olmo3ForCausalLM": ("olmo2", "Olmo2ForCausalLM"),
    "OlmoeForCausalLM": ("olmoe", "OlmoeForCausalLM"),
    "OPTForCausalLM": ("opt", "OPTForCausalLM"),
    "OrionForCausalLM": ("orion", "OrionForCausalLM"),
    "OuroForCausalLM": ("ouro", "OuroForCausalLM"),
    "PanguEmbeddedForCausalLM": ("openpangu", "PanguEmbeddedForCausalLM"),
    "PanguUltraMoEForCausalLM": ("openpangu", "PanguUltraMoEForCausalLM"),
    "PersimmonForCausalLM": ("persimmon", "PersimmonForCausalLM"),
    "PhiForCausalLM": ("phi", "PhiForCausalLM"),
    "Phi3ForCausalLM": ("phi3", "Phi3ForCausalLM"),
    "PhiMoEForCausalLM": ("phimoe", "PhiMoEForCausalLM"),
    "Plamo2ForCausalLM": ("plamo2", "Plamo2ForCausalLM"),
    "Plamo3ForCausalLM": ("plamo3", "Plamo3ForCausalLM"),
    "QWenLMHeadModel": ("qwen", "QWenLMHeadModel"),
    "Qwen2ForCausalLM": ("qwen2", "Qwen2ForCausalLM"),
    "Qwen2MoeForCausalLM": ("qwen2_moe", "Qwen2MoeForCausalLM"),
    "Qwen3ForCausalLM": ("qwen3", "Qwen3ForCausalLM"),
    "Qwen3MoeForCausalLM": ("qwen3_moe", "Qwen3MoeForCausalLM"),
    "RWForCausalLM": ("falcon", "FalconForCausalLM"),
    "SeedOssForCausalLM": ("seed_oss", "SeedOssForCausalLM"),
    "Step3TextForCausalLM": ("step3_text", "Step3TextForCausalLM"),
    "StableLMEpochForCausalLM": ("stablelm", "StablelmForCausalLM"),
    "StableLmForCausalLM": ("stablelm", "StablelmForCausalLM"),
    "Starcoder2ForCausalLM": ("starcoder2", "Starcoder2ForCausalLM"),
    "SolarForCausalLM": ("solar", "SolarForCausalLM"),
    "TeleChatForCausalLM": ("telechat2", "TeleChat2ForCausalLM"),
    "TeleChat2ForCausalLM": ("telechat2", "TeleChat2ForCausalLM"),
    "TeleFLMForCausalLM": ("teleflm", "TeleFLMForCausalLM"),
    "XverseForCausalLM": ("llama", "LlamaForCausalLM"),
    "Zamba2ForCausalLM": ("zamba2", "Zamba2ForCausalLM"),
}

# Embedding / pooling architectures, using the same `(mod, arch)` value layout
# as `_TEXT_GENERATION_MODELS`.
_EMBEDDING_MODELS = {
    # [Text-only]
    "BertModel": ("bert", "BertEmbeddingModel"),
    "BertSpladeSparseEmbeddingModel": ("bert", "BertSpladeSparseEmbeddingModel"),
    "DeciLMForCausalLM": ("nemotron_nas", "DeciLMForCausalLM"),
    "Gemma2Model": ("gemma2", "Gemma2ForCausalLM"),
    "Gemma3TextModel": ("gemma3", "Gemma3Model"),
    "GlmForCausalLM": ("glm", "GlmForCausalLM"),
    "GPT2ForSequenceClassification": ("gpt2", "GPT2ForSequenceClassification"),
    "GritLM": ("gritlm", "GritLM"),
    "GteModel": ("bert_with_rope", "SnowflakeGteNewModel"),
    "GteNewModel": ("bert_with_rope", "GteNewModel"),
    "InternLM2ForRewardModel": ("internlm2", "InternLM2ForRewardModel"),
    "JambaForSequenceClassification": ("jamba", "JambaForSequenceClassification"),  # noqa: E501
    "LlamaModel": ("llama", "LlamaForCausalLM"),
    **{
        # Multiple models share the same architecture, so we include them all:
        # every text-generation entry implemented by `LlamaForCausalLM` is
        # also registered here.
        k: (mod, arch)
        for k, (mod, arch) in _TEXT_GENERATION_MODELS.items()
        if arch == "LlamaForCausalLM"
    },
    "MistralModel": ("llama", "LlamaForCausalLM"),
    "ModernBertModel": ("modernbert", "ModernBertModel"),
    "NomicBertModel": ("bert_with_rope", "NomicBertModel"),
    "Phi3ForCausalLM": ("phi3", "Phi3ForCausalLM"),
    "Qwen2Model": ("qwen2", "Qwen2ForCausalLM"),
    "Qwen2ForCausalLM": ("qwen2", "Qwen2ForCausalLM"),
    "Qwen2ForRewardModel": ("qwen2_rm", "Qwen2ForRewardModel"),
    "Qwen2ForProcessRewardModel": ("qwen2_rm", "Qwen2ForProcessRewardModel"),
    "RobertaForMaskedLM": ("roberta", "RobertaEmbeddingModel"),
    "RobertaModel": ("roberta", "RobertaEmbeddingModel"),
    "TeleChatForCausalLM": ("telechat2", "TeleChat2ForCausalLM"),
    "TeleChat2ForCausalLM": ("telechat2", "TeleChat2ForCausalLM"),
    "XLMRobertaModel": ("roberta", "RobertaEmbeddingModel"),
    # [Multimodal]
    "CLIPModel": ("clip", "CLIPEmbeddingModel"),
    "LlavaNextForConditionalGeneration": (
        "llava_next",
        "LlavaNextForConditionalGeneration",
    ),
    "Phi3VForCausalLM": ("phi3v", "Phi3VForCausalLM"),
    "Qwen2VLForConditionalGeneration": ("qwen2_vl", "Qwen2VLForConditionalGeneration"),  # noqa: E501
    "SiglipModel": ("siglip", "SiglipEmbeddingModel"),
    # Technically Terratorch models work on images, both in
    # input and output. I am adding it here because it piggy-backs on embedding
    # models for the time being.
    "PrithviGeoSpatialMAE": ("terratorch", "Terratorch"),
    "Terratorch": ("terratorch", "Terratorch"),
}

240
241
# Cross-encoder (sequence / token classification) architectures, using the
# same `(mod, arch)` value layout as the other registry tables.
_CROSS_ENCODER_MODELS = {
    "BertForSequenceClassification": ("bert", "BertForSequenceClassification"),
    "BertForTokenClassification": ("bert", "BertForTokenClassification"),
    "GteNewForSequenceClassification": (
        "bert_with_rope",
        "GteNewForSequenceClassification",
    ),
    "ModernBertForSequenceClassification": (
        "modernbert",
        "ModernBertForSequenceClassification",
    ),
    "ModernBertForTokenClassification": (
        "modernbert",
        "ModernBertForTokenClassification",
    ),
    "RobertaForSequenceClassification": ("roberta", "RobertaForSequenceClassification"),
    "XLMRobertaForSequenceClassification": (
        "roberta",
        "RobertaForSequenceClassification",
    ),
    # [Auto-converted (see adapters.py)]
    "JinaVLForRanking": ("jina_vl", "JinaVLForSequenceClassification"),  # noqa: E501
}

264
# Multimodal architectures, using the same `(mod, arch)` value layout as the
# other registry tables.
_MULTIMODAL_MODELS = {
    # [Decoder-only]
    "AriaForConditionalGeneration": ("aria", "AriaForConditionalGeneration"),
    "AyaVisionForConditionalGeneration": (
        "aya_vision",
        "AyaVisionForConditionalGeneration",
    ),
    "BeeForConditionalGeneration": ("bee", "BeeForConditionalGeneration"),
    "Blip2ForConditionalGeneration": ("blip2", "Blip2ForConditionalGeneration"),
    "ChameleonForConditionalGeneration": (
        "chameleon",
        "ChameleonForConditionalGeneration",
    ),
    "Cohere2VisionForConditionalGeneration": (
        "cohere2_vision",
        "Cohere2VisionForConditionalGeneration",
    ),
    "DeepseekVLV2ForCausalLM": ("deepseek_vl2", "DeepseekVLV2ForCausalLM"),
    "DeepseekOCRForCausalLM": ("deepseek_ocr", "DeepseekOCRForCausalLM"),
    "DotsOCRForCausalLM": ("dots_ocr", "DotsOCRForCausalLM"),
    "Ernie4_5_VLMoeForConditionalGeneration": (
        "ernie45_vl",
        "Ernie4_5_VLMoeForConditionalGeneration",
    ),
    "FuyuForCausalLM": ("fuyu", "FuyuForCausalLM"),
    "Gemma3ForConditionalGeneration": ("gemma3_mm", "Gemma3ForConditionalGeneration"),  # noqa: E501
    "Gemma3nForConditionalGeneration": (
        "gemma3n_mm",
        "Gemma3nForConditionalGeneration",
    ),
    "GLM4VForCausalLM": ("glm4v", "GLM4VForCausalLM"),
    "Glm4vForConditionalGeneration": ("glm4_1v", "Glm4vForConditionalGeneration"),  # noqa: E501
    "Glm4vMoeForConditionalGeneration": ("glm4_1v", "Glm4vMoeForConditionalGeneration"),  # noqa: E501
    "GraniteSpeechForConditionalGeneration": (
        "granite_speech",
        "GraniteSpeechForConditionalGeneration",
    ),
    "H2OVLChatModel": ("h2ovl", "H2OVLChatModel"),
    "HunYuanVLForConditionalGeneration": (
        "hunyuan_vision",
        "HunYuanVLForConditionalGeneration",
    ),
    "InternVLChatModel": ("internvl", "InternVLChatModel"),
    "NemotronH_Nano_VL_V2": ("nano_nemotron_vl", "NemotronH_Nano_VL_V2"),
    "OpenCUAForConditionalGeneration": (
        "opencua",
        "OpenCUAForConditionalGeneration",
    ),
    "InternS1ForConditionalGeneration": (
        "interns1",
        "InternS1ForConditionalGeneration",
    ),
    "InternVLForConditionalGeneration": (
        "interns1",
        "InternS1ForConditionalGeneration",
    ),
    "Idefics3ForConditionalGeneration": (
        "idefics3",
        "Idefics3ForConditionalGeneration",
    ),
    "SmolVLMForConditionalGeneration": ("smolvlm", "SmolVLMForConditionalGeneration"),  # noqa: E501
    "KeyeForConditionalGeneration": ("keye", "KeyeForConditionalGeneration"),
    "KeyeVL1_5ForConditionalGeneration": (
        "keye_vl1_5",
        "KeyeVL1_5ForConditionalGeneration",
    ),
    "RForConditionalGeneration": ("rvl", "RForConditionalGeneration"),
    "KimiVLForConditionalGeneration": ("kimi_vl", "KimiVLForConditionalGeneration"),  # noqa: E501
    "LightOnOCRForConditionalGeneration": (
        "lightonocr",
        "LightOnOCRForConditionalGeneration",
    ),
    "Llama_Nemotron_Nano_VL": ("nemotron_vl", "LlamaNemotronVLChatModel"),
    "Llama4ForConditionalGeneration": ("mllama4", "Llama4ForConditionalGeneration"),  # noqa: E501
    "LlavaForConditionalGeneration": ("llava", "LlavaForConditionalGeneration"),
    "LlavaNextForConditionalGeneration": (
        "llava_next",
        "LlavaNextForConditionalGeneration",
    ),
    "LlavaNextVideoForConditionalGeneration": (
        "llava_next_video",
        "LlavaNextVideoForConditionalGeneration",
    ),
    "LlavaOnevisionForConditionalGeneration": (
        "llava_onevision",
        "LlavaOnevisionForConditionalGeneration",
    ),
    "MantisForConditionalGeneration": ("llava", "MantisForConditionalGeneration"),  # noqa: E501
    "MiDashengLMModel": ("midashenglm", "MiDashengLMModel"),
    "MiniMaxVL01ForConditionalGeneration": (
        "minimax_vl_01",
        "MiniMaxVL01ForConditionalGeneration",
    ),
    "MiniCPMO": ("minicpmo", "MiniCPMO"),
    "MiniCPMV": ("minicpmv", "MiniCPMV"),
    "Mistral3ForConditionalGeneration": (
        "mistral3",
        "Mistral3ForConditionalGeneration",
    ),
    "MolmoForCausalLM": ("molmo", "MolmoForCausalLM"),
    "NVLM_D": ("nvlm_d", "NVLM_D_Model"),
    "Ovis": ("ovis", "Ovis"),
    "Ovis2_5": ("ovis2_5", "Ovis2_5"),
    "PaddleOCRVLForConditionalGeneration": (
        "paddleocr_vl",
        "PaddleOCRVLForConditionalGeneration",
    ),
    "PaliGemmaForConditionalGeneration": (
        "paligemma",
        "PaliGemmaForConditionalGeneration",
    ),
    "Phi3VForCausalLM": ("phi3v", "Phi3VForCausalLM"),
    "Phi4MMForCausalLM": ("phi4mm", "Phi4MMForCausalLM"),
    "Phi4MultimodalForCausalLM": ("phi4_multimodal", "Phi4MultimodalForCausalLM"),  # noqa: E501
    "PixtralForConditionalGeneration": ("pixtral", "PixtralForConditionalGeneration"),  # noqa: E501
    "QwenVLForConditionalGeneration": ("qwen_vl", "QwenVLForConditionalGeneration"),  # noqa: E501
    "Qwen2VLForConditionalGeneration": ("qwen2_vl", "Qwen2VLForConditionalGeneration"),  # noqa: E501
    "Qwen2_5_VLForConditionalGeneration": (
        "qwen2_5_vl",
        "Qwen2_5_VLForConditionalGeneration",
    ),
    "Qwen2AudioForConditionalGeneration": (
        "qwen2_audio",
        "Qwen2AudioForConditionalGeneration",
    ),
    "Qwen2_5OmniModel": (
        "qwen2_5_omni_thinker",
        "Qwen2_5OmniThinkerForConditionalGeneration",
    ),
    "Qwen2_5OmniForConditionalGeneration": (
        "qwen2_5_omni_thinker",
        "Qwen2_5OmniThinkerForConditionalGeneration",
    ),
    "Qwen3OmniMoeForConditionalGeneration": (
        "qwen3_omni_moe_thinker",
        "Qwen3OmniMoeThinkerForConditionalGeneration",
    ),
    "Qwen3VLForConditionalGeneration": ("qwen3_vl", "Qwen3VLForConditionalGeneration"),  # noqa: E501
    "Qwen3VLMoeForConditionalGeneration": (
        "qwen3_vl_moe",
        "Qwen3VLMoeForConditionalGeneration",
    ),
    "SkyworkR1VChatModel": ("skyworkr1v", "SkyworkR1VChatModel"),
    "Step3VLForConditionalGeneration": ("step3_vl", "Step3VLForConditionalGeneration"),  # noqa: E501
    "TarsierForConditionalGeneration": ("tarsier", "TarsierForConditionalGeneration"),  # noqa: E501
    "Tarsier2ForConditionalGeneration": (
        "qwen2_vl",
        "Tarsier2ForConditionalGeneration",
    ),
    "UltravoxModel": ("ultravox", "UltravoxModel"),
    "VoxtralForConditionalGeneration": ("voxtral", "VoxtralForConditionalGeneration"),  # noqa: E501
    # [Encoder-decoder]
    "WhisperForConditionalGeneration": ("whisper", "WhisperForConditionalGeneration"),  # noqa: E501
}
418
419

# Speculative-decoding (draft / MTP / Eagle / Medusa) architectures, using the
# same `(mod, arch)` value layout as the other registry tables.
_SPECULATIVE_DECODING_MODELS = {
    "MiMoMTPModel": ("mimo_mtp", "MiMoMTP"),
    "EagleLlamaForCausalLM": ("llama_eagle", "EagleLlamaForCausalLM"),
    "EagleLlama4ForCausalLM": ("llama4_eagle", "EagleLlama4ForCausalLM"),
    "EagleMiniCPMForCausalLM": ("minicpm_eagle", "EagleMiniCPMForCausalLM"),
    "Eagle3LlamaForCausalLM": ("llama_eagle3", "Eagle3LlamaForCausalLM"),
    "LlamaForCausalLMEagle3": ("llama_eagle3", "Eagle3LlamaForCausalLM"),
    "Eagle3Qwen2_5vlForCausalLM": ("llama_eagle3", "Eagle3LlamaForCausalLM"),
    "Eagle3Qwen3vlForCausalLM": ("llama_eagle3", "Eagle3LlamaForCausalLM"),
    "EagleMistralLarge3ForCausalLM": (
        "mistral_large_3_eagle",
        "EagleMistralLarge3ForCausalLM",
    ),
    "EagleDeepSeekMTPModel": ("deepseek_eagle", "EagleDeepseekV3ForCausalLM"),
    "DeepSeekMTPModel": ("deepseek_mtp", "DeepSeekMTP"),
    "ErnieMTPModel": ("ernie_mtp", "ErnieMTP"),
    "LongCatFlashMTPModel": ("longcat_flash_mtp", "LongCatFlashMTP"),
    "Glm4MoeMTPModel": ("glm4_moe_mtp", "Glm4MoeMTP"),
    "MedusaModel": ("medusa", "Medusa"),
    "OpenPanguMTPModel": ("openpangu_mtp", "OpenPanguMTP"),
    "Qwen3NextMTP": ("qwen3_next_mtp", "Qwen3NextMTP"),
    # Temporarily disabled.
    # TODO(woosuk): Re-enable this once the MLP Speculator is supported in V1.
    # "MLPSpeculatorPreTrainedModel": ("mlp_speculator", "MLPSpeculator"),
}
444

445
# Architectures mapped onto the generic classes in the `transformers` module
# rather than onto a dedicated model module.
_TRANSFORMERS_SUPPORTED_MODELS = {
    # Text generation models
    "SmolLM3ForCausalLM": ("transformers", "TransformersForCausalLM"),
    # Multimodal models
    "Emu3ForConditionalGeneration": (
        "transformers",
        "TransformersMultiModalForCausalLM",
    ),
}

# The generic `transformers`-module classes themselves, registered under
# their own names (every key equals the class name in its value).
_TRANSFORMERS_BACKEND_MODELS = {
    # Text generation models
    "TransformersForCausalLM": ("transformers", "TransformersForCausalLM"),
    "TransformersMoEForCausalLM": ("transformers", "TransformersMoEForCausalLM"),
    # Multimodal models
    "TransformersMultiModalForCausalLM": (
        "transformers",
        "TransformersMultiModalForCausalLM",
    ),
    "TransformersMultiModalMoEForCausalLM": (
        "transformers",
        "TransformersMultiModalMoEForCausalLM",
    ),
    # Embedding models
    "TransformersEmbeddingModel": ("transformers", "TransformersEmbeddingModel"),
    "TransformersMoEEmbeddingModel": ("transformers", "TransformersMoEEmbeddingModel"),
    "TransformersMultiModalEmbeddingModel": (
        "transformers",
        "TransformersMultiModalEmbeddingModel",
    ),
    # Sequence classification models
    "TransformersForSequenceClassification": (
        "transformers",
        "TransformersForSequenceClassification",
    ),
    "TransformersMoEForSequenceClassification": (
        "transformers",
        "TransformersMoEForSequenceClassification",
    ),
    "TransformersMultiModalForSequenceClassification": (
        "transformers",
        "TransformersMultiModalForSequenceClassification",
    ),
}
489

490
# Union of all per-category tables. The tables are unpacked in order, so when
# the same architecture name appears in more than one table, the entry from
# the table listed later here wins.
_VLLM_MODELS = {
    **_TEXT_GENERATION_MODELS,
    **_EMBEDDING_MODELS,
    **_CROSS_ENCODER_MODELS,
    **_MULTIMODAL_MODELS,
    **_SPECULATIVE_DECODING_MODELS,
    **_TRANSFORMERS_SUPPORTED_MODELS,
    **_TRANSFORMERS_BACKEND_MODELS,
}

500
501
502
503
# This variable is used as the args for subprocess.run(). We
# can modify this variable to alter the args if needed, e.g.
# when we use par format to pack things together, sys.executable
# might not be the target we want to run.
_SUBPROCESS_COMMAND = [sys.executable, "-m", "vllm.model_executor.models.registry"]
505

506
# Architectures that are no longer registered, keyed by architecture name.
# NOTE(review): the values look like vLLM version strings -- presumably the
# last release that still supported the architecture; confirm at the lookup
# site.
_PREVIOUSLY_SUPPORTED_MODELS = {
    "MotifForCausalLM": "0.10.2",
    "Phi3SmallForCausalLM": "0.9.2",
    "Phi4FlashForCausalLM": "0.10.2",
    # encoder-decoder models except whisper
    # have been removed for V0 deprecation.
    "BartModel": "0.10.2",
    "BartForConditionalGeneration": "0.10.2",
    "DonutForConditionalGeneration": "0.10.2",
    "Florence2ForConditionalGeneration": "0.10.2",
    "MBartForConditionalGeneration": "0.10.2",
    "MllamaForConditionalGeneration": "0.10.2",
}
519

520

521
522
@dataclass(frozen=True)
class _ModelInfo:
    """
    Immutable snapshot of the capability flags of a model class.

    Instances hold only plain values, so they can be built without keeping a
    reference to the model class itself (see `from_model_cls`).
    """

    # NOTE: keep the field order stable; instances may be constructed
    # positionally elsewhere.
    architecture: str
    is_text_generation_model: bool
    is_pooling_model: bool
    attn_type: AttnTypeStr
    default_pooling_type: PoolingTypeStr
    supports_cross_encoding: bool
    supports_multimodal: bool
    supports_multimodal_raw_input_only: bool
    supports_multimodal_encoder_tp_data: bool
    supports_pp: bool
    has_inner_state: bool
    is_attention_free: bool
    is_hybrid: bool
    has_noops: bool
    supports_mamba_prefix_caching: bool
    supports_transcription: bool
    supports_transcription_only: bool

    @staticmethod
    def from_model_cls(model: type[nn.Module]) -> "_ModelInfo":
        """
        Build a `_ModelInfo` by querying the interface helpers on `model`.

        Args:
            model: The model class (not an instance) to inspect.

        Returns:
            A `_ModelInfo` whose `architecture` is `model.__name__` and whose
            remaining fields are the results of the corresponding helper
            functions applied to `model`.
        """
        return _ModelInfo(
            architecture=model.__name__,
            is_text_generation_model=is_text_generation_model(model),
            is_pooling_model=is_pooling_model(model),
            default_pooling_type=get_default_pooling_type(model),
            attn_type=get_attn_type(model),
            supports_cross_encoding=supports_cross_encoding(model),
            supports_multimodal=supports_multimodal(model),
            supports_multimodal_raw_input_only=supports_multimodal_raw_input_only(
                model
            ),
            supports_multimodal_encoder_tp_data=supports_multimodal_encoder_tp_data(
                model
            ),
            supports_pp=supports_pp(model),
            has_inner_state=has_inner_state(model),
            is_attention_free=is_attention_free(model),
            is_hybrid=is_hybrid(model),
            supports_mamba_prefix_caching=supports_mamba_prefix_caching(model),
            supports_transcription=supports_transcription(model),
            # The class attribute is only read when the model supports
            # transcription at all (short-circuit `and`).
            supports_transcription_only=(
                supports_transcription(model) and model.supports_transcription_only
            ),
            has_noops=has_noops(model),
        )
568
569


570
571
572
573
class _BaseRegisteredModel(ABC):
    """
    Abstract interface of a registry entry.

    An entry must be able to report the model's `_ModelInfo` and to return
    the model class itself.
    """

    @abstractmethod
    def inspect_model_cls(self) -> _ModelInfo:
        """Return the `_ModelInfo` describing the model class."""
        raise NotImplementedError

    @abstractmethod
    def load_model_cls(self) -> type[nn.Module]:
        """Return the model class."""
        raise NotImplementedError
578
579


580
581
582
583
584
585
586
@dataclass(frozen=True)
class _RegisteredModel(_BaseRegisteredModel):
    """
    Represents a model that has already been imported in the main process.
    """

    # Capability flags computed once from `model_cls`.
    interfaces: _ModelInfo
    # The already-imported model class.
    model_cls: type[nn.Module]

    @staticmethod
    def from_model_cls(model_cls: type[nn.Module]) -> "_RegisteredModel":
        """Create an entry for `model_cls`, inspecting it immediately."""
        return _RegisteredModel(
            interfaces=_ModelInfo.from_model_cls(model_cls),
            model_cls=model_cls,
        )

    def inspect_model_cls(self) -> _ModelInfo:
        """Return the `_ModelInfo` stored when the entry was created."""
        return self.interfaces

    def load_model_cls(self) -> type[nn.Module]:
        """Return the stored model class; nothing is imported here."""
        return self.model_cls


@dataclass(frozen=True)
class _LazyRegisteredModel(_BaseRegisteredModel):
    """
    Represents a model that has not been imported in the main process.

    The model is identified by `module_name` / `class_name` and is imported
    only when `load_model_cls` is called. `inspect_model_cls` obtains the
    `_ModelInfo` through `_run_in_subprocess` and caches it as JSON on disk,
    keyed by a hash of the model's source file.
    """

    module_name: str
    class_name: str

    @staticmethod
    def _get_cache_dir() -> Path:
        """Return the directory holding cached model-info JSON files."""
        return Path(envs.VLLM_CACHE_ROOT) / "modelinfos"

    def _get_cache_filename(self) -> str:
        """Return the cache file name: `<module>-<class>.json`, dots -> dashes."""
        cls_name = f"{self.module_name}-{self.class_name}".replace(".", "-")
        return f"{cls_name}.json"

    def _load_modelinfo_from_cache(self, module_hash: str) -> _ModelInfo | None:
        """
        Try to read this model's `_ModelInfo` from the on-disk cache.

        Args:
            module_hash: Hash of the current model source file; the cache
                entry is used only if its stored hash is equal.

        Returns:
            The cached `_ModelInfo`, or `None` when the cache file is
            missing, stale, or cannot be read/decoded. Never raises for
            those cases.
        """
        try:
            try:
                modelinfo_path = self._get_cache_dir() / self._get_cache_filename()
                with open(modelinfo_path, encoding="utf-8") as file:
                    mi_dict = json.load(file)
            except FileNotFoundError:
                logger.debug(
                    "Cached model info file for class %s.%s not found",
                    self.module_name,
                    self.class_name,
                )
                return None

            if mi_dict["hash"] != module_hash:
                logger.debug(
                    "Cached model info file for class %s.%s is stale",
                    self.module_name,
                    self.class_name,
                )
                return None

            # file not changed, use cached _ModelInfo properties
            return _ModelInfo(**mi_dict["modelinfo"])
        except Exception:
            # Best-effort cache: any other failure (corrupt JSON, missing
            # keys, fields that no longer match `_ModelInfo`, ...) is treated
            # as a miss. Attach the traceback so the cause is not lost.
            logger.debug(
                "Cached model info for class %s.%s error. ",
                self.module_name,
                self.class_name,
                exc_info=True,
            )
            return None

    def _save_modelinfo_to_cache(self, mi: _ModelInfo, module_hash: str) -> None:
        """
        Write `mi` and `module_hash` to this model's cache file as JSON.

        Failures are logged and swallowed; the cache is best-effort.
        """
        from vllm.model_executor.model_loader.weight_utils import atomic_writer

        try:
            modelinfo_dict = {
                "hash": module_hash,
                "modelinfo": asdict(mi),
            }
            cache_dir = self._get_cache_dir()
            cache_dir.mkdir(parents=True, exist_ok=True)
            modelinfo_path = cache_dir / self._get_cache_filename()
            with atomic_writer(modelinfo_path, encoding="utf-8") as f:
                json.dump(modelinfo_dict, f, indent=2)
        except Exception:
            logger.exception("Error saving model info cache.")

    @logtime(logger=logger, msg="Registry inspect model class")
    def inspect_model_cls(self) -> _ModelInfo:
        """
        Return the `_ModelInfo` of the model without importing it here.

        The model's source file is looked up next to this file, using the
        last component of `module_name`. If it exists, its hash is used to
        consult the on-disk cache first; on a miss (or when the file does not
        exist) the info is computed via `_run_in_subprocess`, and saved back
        to the cache when a hash is available.
        """
        model_path = Path(__file__).parent / f"{self.module_name.split('.')[-1]}.py"
        module_hash = None

        if model_path.exists():
            with open(model_path, "rb") as f:
                module_hash = safe_hash(f.read(), usedforsecurity=False).hexdigest()

            mi = self._load_modelinfo_from_cache(module_hash)
            if mi is not None:
                logger.debug(
                    "Loaded model info for class %s.%s from cache",
                    self.module_name,
                    self.class_name,
                )
                return mi
            else:
                logger.debug(
                    "Cache model info for class %s.%s miss. Loading model instead.",
                    self.module_name,
                    self.class_name,
                )

        # Performed in another process to avoid initializing CUDA
        mi = _run_in_subprocess(
            lambda: _ModelInfo.from_model_cls(self.load_model_cls())
        )
        logger.debug(
            "Loaded model info for class %s.%s", self.module_name, self.class_name
        )

        # save cache file
        if module_hash is not None:
            self._save_modelinfo_to_cache(mi, module_hash)

        return mi

    def load_model_cls(self) -> type[nn.Module]:
        """Import `module_name` and return its `class_name` attribute."""
        mod = importlib.import_module(self.module_name)
        return getattr(mod, self.class_name)


@lru_cache(maxsize=128)
def _try_load_model_cls(
    model_arch: str,
    model: _BaseRegisteredModel,
) -> type[nn.Module] | None:
    """
    Load the model class for `model_arch`, returning `None` on failure.

    The architecture is first checked with the current platform's
    `verify_model_arch`; an exception raised by that check is not caught
    here. Any exception raised while loading the class is logged and turned
    into a `None` result. Results are memoized by `lru_cache`.
    """
    # Imported locally, as in the original implementation.
    from vllm.platforms import current_platform

    current_platform.verify_model_arch(model_arch)

    try:
        loaded_cls = model.load_model_cls()
    except Exception:
        logger.exception("Error in loading model architecture '%s'", model_arch)
        loaded_cls = None

    return loaded_cls
725
726


727
728
729
730
@lru_cache(maxsize=128)
def _try_inspect_model_cls(
    model_arch: str,
    model: _BaseRegisteredModel,
) -> _ModelInfo | None:
    """
    Inspect the model registered for `model_arch`, returning `None` on failure.

    Any exception raised by `model.inspect_model_cls()` is logged and turned
    into a `None` result. Results are memoized by `lru_cache`.
    """
    try:
        model_info = model.inspect_model_cls()
    except Exception:
        logger.exception("Error in inspecting model architecture '%s'", model_arch)
        model_info = None

    return model_info
737
738


739
740
741
@dataclass
class _ModelRegistry:
    """
    Mapping from model architecture names to registered model entries.

    Besides registration, this class resolves a list of candidate
    architectures (as found in a model config) to either the model class
    itself (`resolve_model_cls`) or its inspected `_ModelInfo`
    (`inspect_model_cls`), and exposes a set of capability queries built on
    top of `inspect_model_cls`.
    """

    # Keyed by model_arch
    models: dict[str, _BaseRegisteredModel] = field(default_factory=dict)

    def get_supported_archs(self) -> Set[str]:
        """Return the keys view of `self.models` (the registered names)."""
        return self.models.keys()

    def register_model(
        self,
        model_arch: str,
        model_cls: type[nn.Module] | str,
    ) -> None:
        """
        Register an external model to be used in vLLM.

        `model_cls` can be either:

        - A [`torch.nn.Module`][] class directly referencing the model.
        - A string in the format `<module>:<class>` which can be used to
          lazily import the model. This is useful to avoid initializing CUDA
          when importing the model and thus the related error
          `RuntimeError: Cannot re-initialize CUDA in forked subprocess`.

        An already registered `model_arch` is overwritten after logging a
        warning.

        Raises:
            TypeError: If `model_arch` is not a string, or `model_cls` is
                neither a string nor an `nn.Module` subclass.
            ValueError: If `model_cls` is a string that is not in the
                `<module>:<class>` format.
        """
        if not isinstance(model_arch, str):
            msg = f"`model_arch` should be a string, not a {type(model_arch)}"
            raise TypeError(msg)

        if model_arch in self.models:
            logger.warning(
                "Model architecture %s is already registered, and will be "
                "overwritten by the new model class %s.",
                model_arch,
                model_cls,
            )

        if isinstance(model_cls, str):
            split_str = model_cls.split(":")
            if len(split_str) != 2:
                msg = "Expected a string in the format `<module>:<class>`"
                raise ValueError(msg)

            model = _LazyRegisteredModel(*split_str)
        elif isinstance(model_cls, type) and issubclass(model_cls, nn.Module):
            model = _RegisteredModel.from_model_cls(model_cls)
        else:
            # Report the type of the offending argument (`model_cls`), not
            # of `model_arch`, which is already known to be a string here.
            msg = (
                "`model_cls` should be a string or PyTorch model class, "
                f"not a {type(model_cls)}"
            )
            raise TypeError(msg)

        self.models[model_arch] = model

    def _raise_for_unsupported(self, architectures: list[str]):
        """
        Raise a `ValueError` explaining why `architectures` cannot be used.

        This method never returns normally. The message depends on whether
        any architecture is registered (inspection failed), was supported in
        an earlier vLLM version, or is simply unknown.
        """
        all_supported_archs = self.get_supported_archs()

        if any(arch in all_supported_archs for arch in architectures):
            raise ValueError(
                f"Model architectures {architectures} failed "
                "to be inspected. Please check the logs for more details."
            )

        for arch in architectures:
            if arch in _PREVIOUSLY_SUPPORTED_MODELS:
                previous_version = _PREVIOUSLY_SUPPORTED_MODELS[arch]

                raise ValueError(
                    f"Model architecture {arch} was supported in vLLM until "
                    f"v{previous_version}, and is not supported anymore. "
                    "Please use an older version of vLLM if you want to "
                    "use this model architecture."
                )

        raise ValueError(
            f"Model architectures {architectures} are not supported for now. "
            f"Supported architectures: {all_supported_archs}"
        )

    def _try_load_model_cls(self, model_arch: str) -> type[nn.Module] | None:
        """Load the class registered for `model_arch`; `None` if unavailable."""
        if model_arch not in self.models:
            return None

        return _try_load_model_cls(model_arch, self.models[model_arch])

    def _try_inspect_model_cls(self, model_arch: str) -> _ModelInfo | None:
        """Inspect the model registered for `model_arch`; `None` if unavailable."""
        if model_arch not in self.models:
            return None

        return _try_inspect_model_cls(model_arch, self.models[model_arch])

    def _try_resolve_transformers(
        self,
        architecture: str,
        model_config: ModelConfig,
    ) -> str | None:
        """
        Resolve the architecture name to use with the Transformers backend.

        Returns `architecture` unchanged when it is listed in
        `_TRANSFORMERS_BACKEND_MODELS`. Otherwise the model class is looked
        up on the `transformers` package or through the config's `auto_map`,
        and the result of `model_config._get_transformers_backend_cls()` is
        returned when that class reports it is backend compatible.

        Returns `None` when the class cannot be found or is not compatible
        and `model_config.model_impl` is not `"transformers"`.

        Raises:
            ValueError: In those same two situations when
                `model_config.model_impl` is `"transformers"`.
        """
        if architecture in _TRANSFORMERS_BACKEND_MODELS:
            return architecture

        auto_map: dict[str, str] = (
            getattr(model_config.hf_config, "auto_map", None) or dict()
        )

        # Make sure that config class is always initialized before model class,
        # otherwise the model class won't be able to access the config class,
        # the expected auto_map should have correct order like:
        # "auto_map": {
        #     "AutoConfig": "<your-repo-name>--<config-name>",
        #     "AutoModel": "<your-repo-name>--<config-name>",
        #     "AutoModelFor<Task>": "<your-repo-name>--<config-name>",
        # },
        for prefix in ("AutoConfig", "AutoModel"):
            for name, module in auto_map.items():
                if name.startswith(prefix):
                    try_get_class_from_dynamic_module(
                        module,
                        model_config.model,
                        revision=model_config.revision,
                        warn_on_fail=False,
                    )

        model_module = getattr(transformers, architecture, None)

        if model_module is None:
            for name, module in auto_map.items():
                if name.startswith("AutoModel"):
                    model_module = try_get_class_from_dynamic_module(
                        module,
                        model_config.model,
                        revision=model_config.revision,
                        warn_on_fail=True,
                    )
                    if model_module is not None:
                        break
            else:
                # Reached only when no `AutoModel*` entry produced a class.
                if model_config.model_impl != "transformers":
                    return None

                raise ValueError(
                    f"Cannot find model module. {architecture!r} is not a "
                    "registered model in the Transformers library (only "
                    "relevant if the model is meant to be in Transformers) "
                    "and 'AutoModel' is not present in the model config's "
                    "'auto_map' (relevant if the model is custom)."
                )

        if not model_module.is_backend_compatible():
            if model_config.model_impl != "transformers":
                return None

            raise ValueError(
                f"The Transformers implementation of {architecture!r} "
                "is not compatible with vLLM."
            )

        return model_config._get_transformers_backend_cls()

    def _normalize_arch(
        self,
        architecture: str,
        model_config: ModelConfig,
    ) -> str:
        """
        Map `architecture` to a registered base architecture when possible.

        A registered name is returned as-is. Otherwise, if the name matches
        one of the architecture defaults, its matched suffix is swapped for
        each known suffix in turn and the first registered candidate is
        returned. When nothing matches, `architecture` is returned unchanged.
        """
        if architecture in self.models:
            return architecture

        # This may be called in order to resolve runner_type and convert_type
        # in the first place, in which case we consider the default match
        match = try_match_architecture_defaults(
            architecture,
            runner_type=getattr(model_config, "runner_type", None),
            convert_type=getattr(model_config, "convert_type", None),
        )
        if match:
            suffix, _ = match

            # Get the name of the base model to convert
            for repl_suffix, _ in iter_architecture_defaults():
                base_arch = architecture.replace(suffix, repl_suffix)
                if base_arch in self.models:
                    return base_arch

        return architecture

    def _resolve_with(
        self,
        architectures: str | list[str],
        model_config: ModelConfig,
        try_fn: Callable[[str], Any],
    ) -> tuple[Any, str]:
        """
        Shared resolution strategy behind inspect/resolve_model_cls.

        `try_fn` maps an architecture name to a result, or to `None` when
        that architecture cannot be used. The first non-`None` result is
        returned together with the architecture name it is reported under.

        Raises:
            ValueError: If `architectures` is empty or nothing resolves.
        """
        arch_list = [architectures] if isinstance(architectures, str) else architectures
        if not arch_list:
            raise ValueError("No model architectures are specified")

        def via_transformers() -> tuple[Any, str] | None:
            # Only the first listed architecture is tried with Transformers.
            arch = self._try_resolve_transformers(arch_list[0], model_config)
            if arch is None:
                return None
            resolved = try_fn(arch)
            return None if resolved is None else (resolved, arch)

        def may_fall_back() -> bool:
            # Re-evaluated on every call rather than cached up front.
            return (
                all(arch not in self.models for arch in arch_list)
                and model_config.model_impl == "auto"
            )

        # Require transformers impl
        if model_config.model_impl == "transformers":
            found = via_transformers()
            if found is not None:
                return found
        elif model_config.model_impl == "terratorch":
            resolved = try_fn("Terratorch")
            # Never hand back a `None` result: fall through to the regular
            # resolution (and ultimately the error path) instead.
            if resolved is not None:
                return (resolved, "Terratorch")

        # Fallback to transformers impl (after resolving convert_type)
        if (
            may_fall_back()
            and getattr(model_config, "convert_type", "none") == "none"
        ):
            found = via_transformers()
            if found is not None:
                return found

        for arch in arch_list:
            normalized_arch = self._normalize_arch(arch, model_config)
            resolved = try_fn(normalized_arch)
            if resolved is not None:
                # The original (non-normalized) name is what gets reported.
                return (resolved, arch)

        # Fallback to transformers impl (before resolving runner_type)
        if may_fall_back():
            found = via_transformers()
            if found is not None:
                return found

        return self._raise_for_unsupported(arch_list)

    def inspect_model_cls(
        self,
        architectures: str | list[str],
        model_config: ModelConfig,
    ) -> tuple[_ModelInfo, str]:
        """
        Return the `_ModelInfo` and architecture name for a model config.

        Raises:
            ValueError: If `architectures` is empty or none can be inspected.
        """
        return self._resolve_with(
            architectures, model_config, self._try_inspect_model_cls
        )

    def resolve_model_cls(
        self,
        architectures: str | list[str],
        model_config: ModelConfig,
    ) -> tuple[type[nn.Module], str]:
        """
        Return the model class and architecture name for a model config.

        Raises:
            ValueError: If `architectures` is empty or none can be loaded.
        """
        return self._resolve_with(
            architectures, model_config, self._try_load_model_cls
        )

    def is_text_generation_model(
        self,
        architectures: str | list[str],
        model_config: ModelConfig,
    ) -> bool:
        """Return the `is_text_generation_model` flag of the inspected model."""
        model_info, _ = self.inspect_model_cls(architectures, model_config)
        return model_info.is_text_generation_model

    def is_pooling_model(
        self,
        architectures: str | list[str],
        model_config: ModelConfig,
    ) -> bool:
        """Return the `is_pooling_model` flag of the inspected model."""
        model_info, _ = self.inspect_model_cls(architectures, model_config)
        return model_info.is_pooling_model

    def is_cross_encoder_model(
        self,
        architectures: str | list[str],
        model_config: ModelConfig,
    ) -> bool:
        """Return the `supports_cross_encoding` flag of the inspected model."""
        model_info, _ = self.inspect_model_cls(architectures, model_config)
        return model_info.supports_cross_encoding

    def is_multimodal_model(
        self,
        architectures: str | list[str],
        model_config: ModelConfig,
    ) -> bool:
        """Return the `supports_multimodal` flag of the inspected model."""
        model_info, _ = self.inspect_model_cls(architectures, model_config)
        return model_info.supports_multimodal

    def is_multimodal_raw_input_only_model(
        self,
        architectures: str | list[str],
        model_config: ModelConfig,
    ) -> bool:
        """Return the `supports_multimodal_raw_input_only` flag of the model."""
        model_info, _ = self.inspect_model_cls(architectures, model_config)
        return model_info.supports_multimodal_raw_input_only

    def is_pp_supported_model(
        self,
        architectures: str | list[str],
        model_config: ModelConfig,
    ) -> bool:
        """Return the `supports_pp` flag of the inspected model."""
        model_info, _ = self.inspect_model_cls(architectures, model_config)
        return model_info.supports_pp

    def model_has_inner_state(
        self,
        architectures: str | list[str],
        model_config: ModelConfig,
    ) -> bool:
        """Return the `has_inner_state` flag of the inspected model."""
        model_info, _ = self.inspect_model_cls(architectures, model_config)
        return model_info.has_inner_state

    def is_attention_free_model(
        self,
        architectures: str | list[str],
        model_config: ModelConfig,
    ) -> bool:
        """Return the `is_attention_free` flag of the inspected model."""
        model_info, _ = self.inspect_model_cls(architectures, model_config)
        return model_info.is_attention_free

    def is_hybrid_model(
        self,
        architectures: str | list[str],
        model_config: ModelConfig,
    ) -> bool:
        """Return the `is_hybrid` flag of the inspected model."""
        model_info, _ = self.inspect_model_cls(architectures, model_config)
        return model_info.is_hybrid

    def is_noops_model(
        self,
        architectures: str | list[str],
        model_config: ModelConfig,
    ) -> bool:
        """Return the `has_noops` flag of the inspected model."""
        model_info, _ = self.inspect_model_cls(architectures, model_config)
        return model_info.has_noops

    def is_transcription_model(
        self,
        architectures: str | list[str],
        model_config: ModelConfig,
    ) -> bool:
        """Return the `supports_transcription` flag of the inspected model."""
        model_info, _ = self.inspect_model_cls(architectures, model_config)
        return model_info.supports_transcription

    def is_transcription_only_model(
        self,
        architectures: str | list[str],
        model_config: ModelConfig,
    ) -> bool:
        """Return the `supports_transcription_only` flag of the inspected model."""
        model_info, _ = self.inspect_model_cls(architectures, model_config)
        return model_info.supports_transcription_only

1124

1125
1126
1127
1128
1129
1130
1131
1132
1133
# Module-level registry, pre-populated with one `_LazyRegisteredModel` entry
# for every architecture listed in `_VLLM_MODELS`.
ModelRegistry = _ModelRegistry(
    models={
        arch_name: _LazyRegisteredModel(
            module_name=f"vllm.model_executor.models.{module_relname}",
            class_name=class_name,
        )
        for arch_name, (module_relname, class_name) in _VLLM_MODELS.items()
    }
)
1134
1135
1136
1137
1138

_T = TypeVar("_T")


def _run_in_subprocess(fn: Callable[[], _T]) -> _T:
    """
    Run `fn` in a separate Python process and return its result.

    `fn` and an output file path are serialized with `cloudpickle` and sent
    to the child process on stdin; the child is started with
    `_SUBPROCESS_COMMAND`. The result is read back with `pickle` from the
    output file, which lives in a temporary directory that is removed when
    this function returns.

    Raises:
        RuntimeError: If the subprocess exits with a non-zero return code.
            The message contains the subprocess's stderr output.
    """
    # NOTE: We use a temporary directory instead of a temporary file to avoid
    # issues like https://stackoverflow.com/questions/23212435/permission-denied-to-write-to-my-temporary-file
    with tempfile.TemporaryDirectory() as tempdir:
        output_filepath = os.path.join(tempdir, "registry_output.tmp")

        # `cloudpickle` allows pickling lambda functions directly
        import cloudpickle

        input_bytes = cloudpickle.dumps((fn, output_filepath))

        # cannot use `sys.executable __file__` here because the script
        # contains relative imports
        returned = subprocess.run(
            _SUBPROCESS_COMMAND, input=input_bytes, capture_output=True
        )

        # check if the subprocess is successful
        try:
            returned.check_returncode()
        except subprocess.CalledProcessError as e:
            # wrap raised exception to provide more information
            # Decode leniently: stderr with invalid UTF-8 must not raise a
            # `UnicodeDecodeError` here and hide the real subprocess failure.
            stderr_text = returned.stderr.decode(errors="replace")
            raise RuntimeError(
                f"Error raised in subprocess:\n{stderr_text}"
            ) from e

        # The file was written by the subprocess started above, so this is
        # not unpickling externally supplied data.
        with open(output_filepath, "rb") as f:
            return pickle.load(f)


def _run() -> None:
    """
    Execute a pickled callable read from stdin and store its result.

    Stdin is expected to hold a pickled `(fn, output_file)` pair. General
    plugins are loaded before unpickling; then `fn()` is called and its
    pickled result is written to `output_file`.
    """
    # Setup plugins
    from vllm.plugins import load_general_plugins

    load_general_plugins()

    payload = sys.stdin.buffer.read()
    fn, output_file = pickle.loads(payload)

    result = fn()

    with open(output_file, "wb") as out:
        pickle.dump(result, out)


# Run the worker only when this module is executed as a script.
if __name__ == "__main__":
    _run()