registry.py 50.9 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
4
5
6
"""
Whenever you add an architecture to this page, please also update
`tests/models/registry.py` with example HuggingFace models for it.
"""
7

8
import importlib
9
import json
10
import os
11
import pickle
12
13
import subprocess
import sys
14
import tempfile
15
from abc import ABC, abstractmethod
16
from collections.abc import Callable, Set
17
from dataclasses import asdict, dataclass, field
18
from functools import lru_cache
19
from pathlib import Path
20
from typing import TYPE_CHECKING, Any, TypeVar
21
22

import torch.nn as nn
23
import transformers
24

25
from vllm import envs
26
27
28
29
30
from vllm.config import (
    ModelConfig,
    iter_architecture_defaults,
    try_match_architecture_defaults,
)
31
from vllm.logger import init_logger
32
from vllm.logging_utils import logtime
33
from vllm.transformers_utils.dynamic_module import try_get_class_from_dynamic_module
34
from vllm.utils.hashing import safe_hash
35

36
37
# The precise annotation types are imported only for static type checkers.
# At runtime the same names are bound to `Any`, so annotations that mention
# them still evaluate without importing the config modules here.
if TYPE_CHECKING:
    from vllm.config.model import AttnTypeStr
    from vllm.config.pooler import SequencePoolingType, TokenPoolingType
else:
    AttnTypeStr = Any
    SequencePoolingType = Any
    TokenPoolingType = Any
43
44


45
46
47
48
49
from .interfaces import (
    has_inner_state,
    has_noops,
    is_attention_free,
    is_hybrid,
Patrick von Platen's avatar
Patrick von Platen committed
50
    requires_raw_input_tokens,
51
    supports_cross_encoding,
52
    supports_late_interaction,
53
    supports_mamba_prefix_caching,
54
55
56
57
58
59
60
    supports_multimodal,
    supports_multimodal_encoder_tp_data,
    supports_multimodal_raw_input_only,
    supports_pp,
    supports_transcription,
)
from .interfaces_base import (
61
    get_attn_type,
62
63
    get_default_seq_pooling_type,
    get_default_tok_pooling_type,
64
65
66
    is_pooling_model,
    is_text_generation_model,
)
67
68
69

# Module-level logger, created through vLLM's `init_logger` with this
# module's name.
logger = init_logger(__name__)

70
71
# Text-generation architectures.
# Each key is an architecture name; each value is a `(module, class)` pair of
# strings naming the implementation to use for it. Several keys may share one
# implementation (e.g. the many aliases that resolve to `LlamaForCausalLM`).
# NOTE(review): how the module string is turned into an importable module is
# not visible in this part of the file -- confirm against the loader.
_TEXT_GENERATION_MODELS = {
    # [Decoder-only]
    "AfmoeForCausalLM": ("afmoe", "AfmoeForCausalLM"),
    "ApertusForCausalLM": ("apertus", "ApertusForCausalLM"),
    "AquilaModel": ("llama", "LlamaForCausalLM"),
    "AquilaForCausalLM": ("llama", "LlamaForCausalLM"),  # AquilaChat2
    "ArceeForCausalLM": ("arcee", "ArceeForCausalLM"),
    "ArcticForCausalLM": ("arctic", "ArcticForCausalLM"),
    "AXK1ForCausalLM": ("AXK1", "AXK1ForCausalLM"),
    # baichuan-7b, upper case 'C' in the class name
    "BaiChuanForCausalLM": ("baichuan", "BaiChuanForCausalLM"),
    # baichuan-13b, lower case 'c' in the class name
    "BaichuanForCausalLM": ("baichuan", "BaichuanForCausalLM"),
    "BailingMoeForCausalLM": ("bailing_moe", "BailingMoeForCausalLM"),
    "BailingMoeV2ForCausalLM": ("bailing_moe", "BailingMoeV2ForCausalLM"),
    "BailingMoeV2_5ForCausalLM": ("bailing_moe_linear", "BailingMoeV25ForCausalLM"),
    "BambaForCausalLM": ("bamba", "BambaForCausalLM"),
    "BloomForCausalLM": ("bloom", "BloomForCausalLM"),
    "ChatGLMModel": ("chatglm", "ChatGLMForCausalLM"),
    "ChatGLMForConditionalGeneration": ("chatglm", "ChatGLMForCausalLM"),
    "CohereForCausalLM": ("commandr", "CohereForCausalLM"),
    "Cohere2ForCausalLM": ("commandr", "CohereForCausalLM"),
    "CwmForCausalLM": ("llama", "LlamaForCausalLM"),
    "DbrxForCausalLM": ("dbrx", "DbrxForCausalLM"),
    "DeciLMForCausalLM": ("nemotron_nas", "DeciLMForCausalLM"),
    "DeepseekForCausalLM": ("deepseek_v2", "DeepseekForCausalLM"),
    "DeepseekV2ForCausalLM": ("deepseek_v2", "DeepseekV2ForCausalLM"),
    "DeepseekV3ForCausalLM": ("deepseek_v2", "DeepseekV3ForCausalLM"),
    "DeepseekV32ForCausalLM": ("deepseek_v2", "DeepseekV3ForCausalLM"),
    "Dots1ForCausalLM": ("dots1", "Dots1ForCausalLM"),
    "Ernie4_5ForCausalLM": ("ernie45", "Ernie4_5ForCausalLM"),
    "Ernie4_5_MoeForCausalLM": ("ernie45_moe", "Ernie4_5_MoeForCausalLM"),
    "ExaoneForCausalLM": ("exaone", "ExaoneForCausalLM"),
    "Exaone4ForCausalLM": ("exaone4", "Exaone4ForCausalLM"),
    "ExaoneMoEForCausalLM": ("exaone_moe", "ExaoneMoeForCausalLM"),
    "Fairseq2LlamaForCausalLM": ("fairseq2_llama", "Fairseq2LlamaForCausalLM"),
    "FalconForCausalLM": ("falcon", "FalconForCausalLM"),
    "FalconMambaForCausalLM": ("mamba", "MambaForCausalLM"),
    "FalconH1ForCausalLM": ("falcon_h1", "FalconH1ForCausalLM"),
    "FlexOlmoForCausalLM": ("flex_olmo", "FlexOlmoForCausalLM"),
    "GemmaForCausalLM": ("gemma", "GemmaForCausalLM"),
    "Gemma2ForCausalLM": ("gemma2", "Gemma2ForCausalLM"),
    "Gemma3ForCausalLM": ("gemma3", "Gemma3ForCausalLM"),
    "Gemma3nForCausalLM": ("gemma3n", "Gemma3nForCausalLM"),
    "Qwen3NextForCausalLM": ("qwen3_next", "Qwen3NextForCausalLM"),
    "GlmForCausalLM": ("glm", "GlmForCausalLM"),
    "Glm4ForCausalLM": ("glm4", "Glm4ForCausalLM"),
    "Glm4MoeForCausalLM": ("glm4_moe", "Glm4MoeForCausalLM"),
    "Glm4MoeLiteForCausalLM": ("glm4_moe_lite", "Glm4MoeLiteForCausalLM"),
    "GlmMoeDsaForCausalLM": ("deepseek_v2", "GlmMoeDsaForCausalLM"),
    "GptOssForCausalLM": ("gpt_oss", "GptOssForCausalLM"),
    "GPT2LMHeadModel": ("gpt2", "GPT2LMHeadModel"),
    "GPTBigCodeForCausalLM": ("gpt_bigcode", "GPTBigCodeForCausalLM"),
    "GPTJForCausalLM": ("gpt_j", "GPTJForCausalLM"),
    "GPTNeoXForCausalLM": ("gpt_neox", "GPTNeoXForCausalLM"),
    "GraniteForCausalLM": ("granite", "GraniteForCausalLM"),
    "GraniteMoeForCausalLM": ("granitemoe", "GraniteMoeForCausalLM"),
    "GraniteMoeHybridForCausalLM": ("granitemoehybrid", "GraniteMoeHybridForCausalLM"),  # noqa: E501
    "GraniteMoeSharedForCausalLM": ("granitemoeshared", "GraniteMoeSharedForCausalLM"),  # noqa: E501
    "GritLM": ("gritlm", "GritLM"),
    "Grok1ModelForCausalLM": ("grok1", "GrokForCausalLM"),
    "Grok1ForCausalLM": ("grok1", "GrokForCausalLM"),
    "HunYuanMoEV1ForCausalLM": ("hunyuan_v1", "HunYuanMoEV1ForCausalLM"),
    "HunYuanDenseV1ForCausalLM": ("hunyuan_v1", "HunYuanDenseV1ForCausalLM"),
    "HCXVisionForCausalLM": ("hyperclovax_vision", "HCXVisionForCausalLM"),
    "InternLMForCausalLM": ("llama", "LlamaForCausalLM"),
    "InternLM2ForCausalLM": ("internlm2", "InternLM2ForCausalLM"),
    "InternLM2VEForCausalLM": ("internlm2_ve", "InternLM2VEForCausalLM"),
    "InternLM3ForCausalLM": ("llama", "LlamaForCausalLM"),
    "IQuestCoderForCausalLM": ("llama", "LlamaForCausalLM"),
    "IQuestLoopCoderForCausalLM": ("iquest_loopcoder", "IQuestLoopCoderForCausalLM"),
    "JAISLMHeadModel": ("jais", "JAISLMHeadModel"),
    "Jais2ForCausalLM": ("jais2", "Jais2ForCausalLM"),
    "JambaForCausalLM": ("jamba", "JambaForCausalLM"),
    "KimiLinearForCausalLM": ("kimi_linear", "KimiLinearForCausalLM"),  # noqa: E501
    "Lfm2ForCausalLM": ("lfm2", "Lfm2ForCausalLM"),
    "Lfm2MoeForCausalLM": ("lfm2_moe", "Lfm2MoeForCausalLM"),
    "LlamaForCausalLM": ("llama", "LlamaForCausalLM"),
    "Llama4ForCausalLM": ("llama4", "Llama4ForCausalLM"),
    # For decapoda-research/llama-*
    "LLaMAForCausalLM": ("llama", "LlamaForCausalLM"),
    "LongcatFlashForCausalLM": ("longcat_flash", "LongcatFlashForCausalLM"),
    "MambaForCausalLM": ("mamba", "MambaForCausalLM"),
    "Mamba2ForCausalLM": ("mamba2", "Mamba2ForCausalLM"),
    "MiniCPMForCausalLM": ("minicpm", "MiniCPMForCausalLM"),
    "MiniCPM3ForCausalLM": ("minicpm3", "MiniCPM3ForCausalLM"),
    "MiniMaxForCausalLM": ("minimax_text_01", "MiniMaxText01ForCausalLM"),
    "MiniMaxText01ForCausalLM": ("minimax_text_01", "MiniMaxText01ForCausalLM"),
    "MiniMaxM1ForCausalLM": ("minimax_text_01", "MiniMaxText01ForCausalLM"),
    "MiniMaxM2ForCausalLM": ("minimax_m2", "MiniMaxM2ForCausalLM"),
    "MistralForCausalLM": ("mistral", "MistralForCausalLM"),
    "MistralLarge3ForCausalLM": ("mistral_large_3", "MistralLarge3ForCausalLM"),
    "MixtralForCausalLM": ("mixtral", "MixtralForCausalLM"),
    # transformers's mpt class has lower case
    "MptForCausalLM": ("mpt", "MPTForCausalLM"),
    "MPTForCausalLM": ("mpt", "MPTForCausalLM"),
    "MiMoForCausalLM": ("mimo", "MiMoForCausalLM"),
    "MiMoV2FlashForCausalLM": ("mimo_v2_flash", "MiMoV2FlashForCausalLM"),
    "NemotronForCausalLM": ("nemotron", "NemotronForCausalLM"),
    "NemotronHForCausalLM": ("nemotron_h", "NemotronHForCausalLM"),
    "NemotronHPuzzleForCausalLM": ("nemotron_h", "NemotronHForCausalLM"),
    "OlmoForCausalLM": ("olmo", "OlmoForCausalLM"),
    "Olmo2ForCausalLM": ("olmo2", "Olmo2ForCausalLM"),
    "Olmo3ForCausalLM": ("olmo2", "Olmo2ForCausalLM"),
    "OlmoeForCausalLM": ("olmoe", "OlmoeForCausalLM"),
    "OPTForCausalLM": ("opt", "OPTForCausalLM"),
    "OrionForCausalLM": ("orion", "OrionForCausalLM"),
    "OuroForCausalLM": ("ouro", "OuroForCausalLM"),
    "PanguEmbeddedForCausalLM": ("openpangu", "PanguEmbeddedForCausalLM"),
    "PanguProMoEV2ForCausalLM": ("openpangu", "PanguProMoEV2ForCausalLM"),
    "PanguUltraMoEForCausalLM": ("openpangu", "PanguUltraMoEForCausalLM"),
    "PersimmonForCausalLM": ("persimmon", "PersimmonForCausalLM"),
    "PhiForCausalLM": ("phi", "PhiForCausalLM"),
    "Phi3ForCausalLM": ("phi3", "Phi3ForCausalLM"),
    "PhiMoEForCausalLM": ("phimoe", "PhiMoEForCausalLM"),
    "Plamo2ForCausalLM": ("plamo2", "Plamo2ForCausalLM"),
    "Plamo3ForCausalLM": ("plamo3", "Plamo3ForCausalLM"),
    "QWenLMHeadModel": ("qwen", "QWenLMHeadModel"),
    "Qwen2ForCausalLM": ("qwen2", "Qwen2ForCausalLM"),
    "Qwen2MoeForCausalLM": ("qwen2_moe", "Qwen2MoeForCausalLM"),
    "Qwen3ForCausalLM": ("qwen3", "Qwen3ForCausalLM"),
    "Qwen3MoeForCausalLM": ("qwen3_moe", "Qwen3MoeForCausalLM"),
    "RWForCausalLM": ("falcon", "FalconForCausalLM"),
    "SeedOssForCausalLM": ("seed_oss", "SeedOssForCausalLM"),
    "Step1ForCausalLM": ("step1", "Step1ForCausalLM"),
    "Step3TextForCausalLM": ("step3_text", "Step3TextForCausalLM"),
    "Step3p5ForCausalLM": ("step3p5", "Step3p5ForCausalLM"),
    "StableLMEpochForCausalLM": ("stablelm", "StablelmForCausalLM"),
    "StableLmForCausalLM": ("stablelm", "StablelmForCausalLM"),
    "Starcoder2ForCausalLM": ("starcoder2", "Starcoder2ForCausalLM"),
    "SolarForCausalLM": ("solar", "SolarForCausalLM"),
    "TeleChatForCausalLM": ("telechat2", "TeleChat2ForCausalLM"),
    "TeleChat2ForCausalLM": ("telechat2", "TeleChat2ForCausalLM"),
    "TeleFLMForCausalLM": ("teleflm", "TeleFLMForCausalLM"),
    "XverseForCausalLM": ("llama", "LlamaForCausalLM"),
    "Zamba2ForCausalLM": ("zamba2", "Zamba2ForCausalLM"),
}

# Embedding (pooling) architectures.
# Each key is an architecture name; each value is a `(module, class)` pair of
# strings naming the implementation to use for it. Every text-generation entry
# that resolves to `LlamaForCausalLM` is also copied in by the comprehension
# below, so those aliases are accepted here as well.
_EMBEDDING_MODELS = {
    # [Text-only]
    "BertModel": ("bert", "BertEmbeddingModel"),
    "BertSpladeSparseEmbeddingModel": ("bert", "BertSpladeSparseEmbeddingModel"),
    "HF_ColBERT": ("colbert", "ColBERTModel"),
    "ColBERTModernBertModel": ("colbert", "ColBERTModernBertModel"),
    "ColBERTJinaRobertaModel": ("colbert", "ColBERTJinaRobertaModel"),
    "DeciLMForCausalLM": ("nemotron_nas", "DeciLMForCausalLM"),
    "Gemma2Model": ("gemma2", "Gemma2ForCausalLM"),
    "Gemma3TextModel": ("gemma3", "Gemma3Model"),
    "GlmForCausalLM": ("glm", "GlmForCausalLM"),
    "GPT2ForSequenceClassification": ("gpt2", "GPT2ForSequenceClassification"),
    "GritLM": ("gritlm", "GritLM"),
    "GteModel": ("bert_with_rope", "SnowflakeGteNewModel"),
    "GteNewModel": ("bert_with_rope", "GteNewModel"),
    "InternLM2ForRewardModel": ("internlm2", "InternLM2ForRewardModel"),
    "JambaForSequenceClassification": ("jamba", "JambaForSequenceClassification"),  # noqa: E501
    "LlamaBidirectionalModel": ("llama", "LlamaBidirectionalModel"),
    "LlamaModel": ("llama", "LlamaForCausalLM"),
    **{
        # Multiple models share the same architecture, so we include them all
        k: (mod, arch)
        for k, (mod, arch) in _TEXT_GENERATION_MODELS.items()
        if arch == "LlamaForCausalLM"
    },
    "MistralModel": ("llama", "LlamaForCausalLM"),
    "ModernBertModel": ("modernbert", "ModernBertModel"),
    "NomicBertModel": ("bert_with_rope", "NomicBertModel"),
    "Phi3ForCausalLM": ("phi3", "Phi3ForCausalLM"),
    "Qwen2Model": ("qwen2", "Qwen2ForCausalLM"),
    "Qwen2ForCausalLM": ("qwen2", "Qwen2ForCausalLM"),
    "Qwen2ForRewardModel": ("qwen2_rm", "Qwen2ForRewardModel"),
    "Qwen2ForProcessRewardModel": ("qwen2_rm", "Qwen2ForProcessRewardModel"),
    "RobertaForMaskedLM": ("roberta", "RobertaEmbeddingModel"),
    "RobertaModel": ("roberta", "RobertaEmbeddingModel"),
    "TeleChatForCausalLM": ("telechat2", "TeleChat2ForCausalLM"),
    "TeleChat2ForCausalLM": ("telechat2", "TeleChat2ForCausalLM"),
    "VoyageQwen3BidirectionalEmbedModel": (
        "voyage",
        "VoyageQwen3BidirectionalEmbedModel",
    ),
    "XLMRobertaModel": ("roberta", "RobertaEmbeddingModel"),
    "BgeM3EmbeddingModel": ("roberta", "BgeM3EmbeddingModel"),
    # [Multimodal]
    "CLIPModel": ("clip", "CLIPEmbeddingModel"),
    "ColModernVBertForRetrieval": ("colmodernvbert", "ColModernVBertForRetrieval"),
    "LlavaNextForConditionalGeneration": (
        "llava_next",
        "LlavaNextForConditionalGeneration",
    ),
    "Phi3VForCausalLM": ("phi3v", "Phi3VForCausalLM"),
    "Qwen2VLForConditionalGeneration": ("qwen2_vl", "Qwen2VLForConditionalGeneration"),  # noqa: E501
    "ColQwen3": ("colqwen3", "ColQwen3Model"),
    "OpsColQwen3Model": ("colqwen3", "ColQwen3Model"),
    "Qwen3VLNemotronEmbedModel": ("colqwen3", "ColQwen3Model"),
    "SiglipModel": ("siglip", "SiglipEmbeddingModel"),
    "LlamaNemotronVLModel": (
        "nemotron_vl",
        "LlamaNemotronVLForEmbedding",
    ),
    # Technically Terratorch models work on images, both in
    # input and output. I am adding it here because it piggy-backs on embedding
    # models for the time being.
    "PrithviGeoSpatialMAE": ("terratorch", "Terratorch"),
    "Terratorch": ("terratorch", "Terratorch"),
}

275
276
# Cross-encoder / classification architectures.
# Each key is an architecture name; each value is a `(module, class)` pair of
# strings naming the implementation to use for it.
_CROSS_ENCODER_MODELS = {
    "BertForSequenceClassification": ("bert", "BertForSequenceClassification"),
    "BertForTokenClassification": ("bert", "BertForTokenClassification"),
    "GteNewForSequenceClassification": (
        "bert_with_rope",
        "GteNewForSequenceClassification",
    ),
    "JinaVLForRanking": ("jina_vl", "JinaVLForSequenceClassification"),
    "LlamaBidirectionalForSequenceClassification": (
        "llama",
        "LlamaBidirectionalForSequenceClassification",
    ),
    "LlamaNemotronVLForSequenceClassification": (
        "nemotron_vl",
        "LlamaNemotronVLForSequenceClassification",
    ),
    "ModernBertForSequenceClassification": (
        "modernbert",
        "ModernBertForSequenceClassification",
    ),
    "ModernBertForTokenClassification": (
        "modernbert",
        "ModernBertForTokenClassification",
    ),
    "RobertaForSequenceClassification": ("roberta", "RobertaForSequenceClassification"),
    "XLMRobertaForSequenceClassification": (
        "roberta",
        "RobertaForSequenceClassification",
    ),
}

306
# Multimodal architectures (decoder-only first, then encoder-decoder).
# Each key is an architecture name; each value is a `(module, class)` pair of
# strings naming the implementation to use for it. Several keys may share one
# implementation.
_MULTIMODAL_MODELS = {
    # [Decoder-only]
    "AriaForConditionalGeneration": ("aria", "AriaForConditionalGeneration"),
    "AudioFlamingo3ForConditionalGeneration": (
        "audioflamingo3",
        "AudioFlamingo3ForConditionalGeneration",
    ),
    "MusicFlamingoForConditionalGeneration": (
        "musicflamingo",
        "MusicFlamingoForConditionalGeneration",
    ),
    "AyaVisionForConditionalGeneration": (
        "aya_vision",
        "AyaVisionForConditionalGeneration",
    ),
    "BagelForConditionalGeneration": ("bagel", "BagelForConditionalGeneration"),
    "BeeForConditionalGeneration": ("bee", "BeeForConditionalGeneration"),
    "Blip2ForConditionalGeneration": ("blip2", "Blip2ForConditionalGeneration"),
    "ChameleonForConditionalGeneration": (
        "chameleon",
        "ChameleonForConditionalGeneration",
    ),
    "Cohere2VisionForConditionalGeneration": (
        "cohere2_vision",
        "Cohere2VisionForConditionalGeneration",
    ),
    "DeepseekVLV2ForCausalLM": ("deepseek_vl2", "DeepseekVLV2ForCausalLM"),
    "DeepseekOCRForCausalLM": ("deepseek_ocr", "DeepseekOCRForCausalLM"),
    "DeepseekOCR2ForCausalLM": ("deepseek_ocr2", "DeepseekOCR2ForCausalLM"),
    "DotsOCRForCausalLM": ("dots_ocr", "DotsOCRForCausalLM"),
    "Eagle2_5_VLForConditionalGeneration": (
        "eagle2_5_vl",
        "Eagle2_5_VLForConditionalGeneration",
    ),
    "Ernie4_5_VLMoeForConditionalGeneration": (
        "ernie45_vl",
        "Ernie4_5_VLMoeForConditionalGeneration",
    ),
    "FireRedASR2ForConditionalGeneration": (
        "fireredasr2",
        "FireRedASR2ForConditionalGeneration",
    ),
    "FunASRForConditionalGeneration": ("funasr", "FunASRForConditionalGeneration"),  # noqa: E501
    "FunAudioChatForConditionalGeneration": (
        "funaudiochat",
        "FunAudioChatForConditionalGeneration",
    ),
    "FuyuForCausalLM": ("fuyu", "FuyuForCausalLM"),
    "Gemma3ForConditionalGeneration": ("gemma3_mm", "Gemma3ForConditionalGeneration"),  # noqa: E501
    "Gemma3nForConditionalGeneration": (
        "gemma3n_mm",
        "Gemma3nForConditionalGeneration",
    ),
    "GlmAsrForConditionalGeneration": ("glmasr", "GlmAsrForConditionalGeneration"),
    "GLM4VForCausalLM": ("glm4v", "GLM4VForCausalLM"),
    "Glm4vForConditionalGeneration": ("glm4_1v", "Glm4vForConditionalGeneration"),
    "Glm4vMoeForConditionalGeneration": ("glm4_1v", "Glm4vMoeForConditionalGeneration"),
    "GlmOcrForConditionalGeneration": ("glm_ocr", "GlmOcrForConditionalGeneration"),  # noqa: E501
    "GraniteSpeechForConditionalGeneration": (
        "granite_speech",
        "GraniteSpeechForConditionalGeneration",
    ),
    "H2OVLChatModel": ("h2ovl", "H2OVLChatModel"),
    "HunYuanVLForConditionalGeneration": (
        "hunyuan_vision",
        "HunYuanVLForConditionalGeneration",
    ),
    "StepVLForConditionalGeneration": ("step_vl", "StepVLForConditionalGeneration"),
    "InternVLChatModel": ("internvl", "InternVLChatModel"),
    "NemotronH_Nano_VL_V2": ("nano_nemotron_vl", "NemotronH_Nano_VL_V2"),
    "OpenCUAForConditionalGeneration": (
        "opencua",
        "OpenCUAForConditionalGeneration",
    ),
    "InternS1ForConditionalGeneration": (
        "interns1",
        "InternS1ForConditionalGeneration",
    ),
    "InternVLForConditionalGeneration": (
        "interns1",
        "InternS1ForConditionalGeneration",
    ),
    "InternS1ProForConditionalGeneration": (
        "interns1_pro",
        "InternS1ProForConditionalGeneration",
    ),
    "Idefics3ForConditionalGeneration": (
        "idefics3",
        "Idefics3ForConditionalGeneration",
    ),
    "IsaacForConditionalGeneration": ("isaac", "IsaacForConditionalGeneration"),
    "SmolVLMForConditionalGeneration": ("smolvlm", "SmolVLMForConditionalGeneration"),  # noqa: E501
    "KananaVForConditionalGeneration": ("kanana_v", "KananaVForConditionalGeneration"),
    "KeyeForConditionalGeneration": ("keye", "KeyeForConditionalGeneration"),
    "KeyeVL1_5ForConditionalGeneration": (
        "keye_vl1_5",
        "KeyeVL1_5ForConditionalGeneration",
    ),
    "RForConditionalGeneration": ("rvl", "RForConditionalGeneration"),
    "KimiVLForConditionalGeneration": ("kimi_vl", "KimiVLForConditionalGeneration"),  # noqa: E501
    "KimiK25ForConditionalGeneration": ("kimi_k25", "KimiK25ForConditionalGeneration"),  # noqa: E501
    "LightOnOCRForConditionalGeneration": (
        "lightonocr",
        "LightOnOCRForConditionalGeneration",
    ),
    "Lfm2VlForConditionalGeneration": ("lfm2_vl", "Lfm2VLForConditionalGeneration"),
    "Llama_Nemotron_Nano_VL": ("nemotron_vl", "LlamaNemotronVLChatModel"),
    "Llama4ForConditionalGeneration": ("mllama4", "Llama4ForConditionalGeneration"),  # noqa: E501
    "LlavaForConditionalGeneration": ("llava", "LlavaForConditionalGeneration"),
    "LlavaNextForConditionalGeneration": (
        "llava_next",
        "LlavaNextForConditionalGeneration",
    ),
    "LlavaNextVideoForConditionalGeneration": (
        "llava_next_video",
        "LlavaNextVideoForConditionalGeneration",
    ),
    "LlavaOnevisionForConditionalGeneration": (
        "llava_onevision",
        "LlavaOnevisionForConditionalGeneration",
    ),
    "MantisForConditionalGeneration": ("llava", "MantisForConditionalGeneration"),  # noqa: E501
    "MiDashengLMModel": ("midashenglm", "MiDashengLMModel"),
    "MiniMaxVL01ForConditionalGeneration": (
        "minimax_vl_01",
        "MiniMaxVL01ForConditionalGeneration",
    ),
    "MiniCPMO": ("minicpmo", "MiniCPMO"),
    "MiniCPMV": ("minicpmv", "MiniCPMV"),
    "Mistral3ForConditionalGeneration": (
        "mistral3",
        "Mistral3ForConditionalGeneration",
    ),
    "MolmoForCausalLM": ("molmo", "MolmoForCausalLM"),
    "Molmo2ForConditionalGeneration": ("molmo2", "Molmo2ForConditionalGeneration"),
    "NVLM_D": ("nvlm_d", "NVLM_D_Model"),
    "OpenPanguVLForConditionalGeneration": (
        "openpangu_vl",
        "OpenPanguVLForConditionalGeneration",
    ),
    "Ovis": ("ovis", "Ovis"),
    "Ovis2_5": ("ovis2_5", "Ovis2_5"),
    "Ovis2_6ForCausalLM": ("ovis2_5", "Ovis2_5"),
    "Ovis2_6_MoeForCausalLM": ("ovis2_5", "Ovis2_5"),
    "PaddleOCRVLForConditionalGeneration": (
        "paddleocr_vl",
        "PaddleOCRVLForConditionalGeneration",
    ),
    "PaliGemmaForConditionalGeneration": (
        "paligemma",
        "PaliGemmaForConditionalGeneration",
    ),
    "Phi3VForCausalLM": ("phi3v", "Phi3VForCausalLM"),
    "Phi4MMForCausalLM": ("phi4mm", "Phi4MMForCausalLM"),
    "PixtralForConditionalGeneration": ("pixtral", "PixtralForConditionalGeneration"),  # noqa: E501
    "QwenVLForConditionalGeneration": ("qwen_vl", "QwenVLForConditionalGeneration"),  # noqa: E501
    "Qwen2VLForConditionalGeneration": ("qwen2_vl", "Qwen2VLForConditionalGeneration"),  # noqa: E501
    "Qwen2_5_VLForConditionalGeneration": (
        "qwen2_5_vl",
        "Qwen2_5_VLForConditionalGeneration",
    ),
    "Qwen2AudioForConditionalGeneration": (
        "qwen2_audio",
        "Qwen2AudioForConditionalGeneration",
    ),
    "Qwen2_5OmniModel": (
        "qwen2_5_omni_thinker",
        "Qwen2_5OmniThinkerForConditionalGeneration",
    ),
    "Qwen2_5OmniForConditionalGeneration": (
        "qwen2_5_omni_thinker",
        "Qwen2_5OmniThinkerForConditionalGeneration",
    ),
    "Qwen3OmniMoeForConditionalGeneration": (
        "qwen3_omni_moe_thinker",
        "Qwen3OmniMoeThinkerForConditionalGeneration",
    ),
    "Qwen3ASRForConditionalGeneration": (
        "qwen3_asr",
        "Qwen3ASRForConditionalGeneration",
    ),
    "Qwen3ASRRealtimeGeneration": (
        "qwen3_asr_realtime",
        "Qwen3ASRRealtimeGeneration",
    ),
    "Qwen3VLForConditionalGeneration": ("qwen3_vl", "Qwen3VLForConditionalGeneration"),  # noqa: E501
    "Qwen3VLMoeForConditionalGeneration": (
        "qwen3_vl_moe",
        "Qwen3VLMoeForConditionalGeneration",
    ),
    "Qwen3_5ForConditionalGeneration": (
        "qwen3_5",
        "Qwen3_5ForConditionalGeneration",
    ),
    "Qwen3_5MoeForConditionalGeneration": (
        "qwen3_5",
        "Qwen3_5MoeForConditionalGeneration",
    ),
    "SkyworkR1VChatModel": ("skyworkr1v", "SkyworkR1VChatModel"),
    "Step3VLForConditionalGeneration": ("step3_vl", "Step3VLForConditionalGeneration"),  # noqa: E501
    "TarsierForConditionalGeneration": ("tarsier", "TarsierForConditionalGeneration"),  # noqa: E501
    "Tarsier2ForConditionalGeneration": (
        "qwen2_vl",
        "Tarsier2ForConditionalGeneration",
    ),
    "UltravoxModel": ("ultravox", "UltravoxModel"),
    "VoxtralForConditionalGeneration": ("voxtral", "VoxtralForConditionalGeneration"),  # noqa: E501
    "VoxtralRealtimeGeneration": ("voxtral_realtime", "VoxtralRealtimeGeneration"),  # noqa: E501
    # [Encoder-decoder]
    "NemotronParseForConditionalGeneration": (
        "nemotron_parse",
        "NemotronParseForConditionalGeneration",
    ),
    "WhisperForConditionalGeneration": ("whisper", "WhisperForConditionalGeneration"),  # noqa: E501
}
521
522

# Speculative-decoding (draft / MTP / Eagle / Medusa) architectures.
# Each key is an architecture name; each value is a `(module, class)` pair of
# strings naming the implementation to use for it.
_SPECULATIVE_DECODING_MODELS = {
    "ExtractHiddenStatesModel": ("extract_hidden_states", "ExtractHiddenStatesModel"),
    "MiMoMTPModel": ("mimo_mtp", "MiMoMTP"),
    "EagleLlamaForCausalLM": ("llama_eagle", "EagleLlamaForCausalLM"),
    "EagleLlama4ForCausalLM": ("llama4_eagle", "EagleLlama4ForCausalLM"),
    "EagleMiniCPMForCausalLM": ("minicpm_eagle", "EagleMiniCPMForCausalLM"),
    "Eagle3LlamaForCausalLM": ("llama_eagle3", "Eagle3LlamaForCausalLM"),
    "LlamaForCausalLMEagle3": ("llama_eagle3", "Eagle3LlamaForCausalLM"),
    "Eagle3Qwen2_5vlForCausalLM": ("llama_eagle3", "Eagle3LlamaForCausalLM"),
    "Eagle3Qwen3vlForCausalLM": ("llama_eagle3", "Eagle3LlamaForCausalLM"),
    "EagleMistralLarge3ForCausalLM": (
        "mistral_large_3_eagle",
        "EagleMistralLarge3ForCausalLM",
    ),
    "EagleDeepSeekMTPModel": ("deepseek_eagle", "EagleDeepseekV3ForCausalLM"),
    "DeepSeekMTPModel": ("deepseek_mtp", "DeepSeekMTP"),
    "ErnieMTPModel": ("ernie_mtp", "ErnieMTP"),
    "ExaoneMoeMTP": ("exaone_moe_mtp", "ExaoneMoeMTP"),
    "NemotronHMTPModel": ("nemotron_h_mtp", "NemotronHMTP"),
    "LongCatFlashMTPModel": ("longcat_flash_mtp", "LongCatFlashMTP"),
    "Glm4MoeMTPModel": ("glm4_moe_mtp", "Glm4MoeMTP"),
    "Glm4MoeLiteMTPModel": ("glm4_moe_lite_mtp", "Glm4MoeLiteMTP"),
    "GlmOcrMTPModel": ("glm_ocr_mtp", "GlmOcrMTP"),
    "MedusaModel": ("medusa", "Medusa"),
    "OpenPanguMTPModel": ("openpangu_mtp", "OpenPanguMTP"),
    "Qwen3NextMTP": ("qwen3_next_mtp", "Qwen3NextMTP"),
    "Step3p5MTP": ("step3p5_mtp", "Step3p5MTP"),
    "Qwen3_5MTP": ("qwen3_5_mtp", "Qwen3_5MTP"),
    "Qwen3_5MoeMTP": ("qwen3_5_mtp", "Qwen3_5MoeMTP"),
    # Temporarily disabled.
    # # TODO(woosuk): Re-enable this once the MLP Speculator is supported in V1.
    # "MLPSpeculatorPreTrainedModel": ("mlp_speculator", "MLPSpeculator"),
}
555

556
# Architectures mapped onto the generic "transformers" implementation classes
# rather than a dedicated one. Each value is a `(module, class)` pair of
# strings.
_TRANSFORMERS_SUPPORTED_MODELS = {
    # Text generation models
    "SmolLM3ForCausalLM": ("transformers", "TransformersForCausalLM"),
    # Multimodal models
    "Emu3ForConditionalGeneration": (
        "transformers",
        "TransformersMultiModalForCausalLM",
    ),
}

# The generic "transformers" implementation classes themselves, registered
# under their own names. Each value is a `(module, class)` pair of strings
# whose class name equals the key.
_TRANSFORMERS_BACKEND_MODELS = {
    # Text generation models
    "TransformersForCausalLM": ("transformers", "TransformersForCausalLM"),
    "TransformersMoEForCausalLM": ("transformers", "TransformersMoEForCausalLM"),
    # Multimodal models
    "TransformersMultiModalForCausalLM": (
        "transformers",
        "TransformersMultiModalForCausalLM",
    ),
    "TransformersMultiModalMoEForCausalLM": (
        "transformers",
        "TransformersMultiModalMoEForCausalLM",
    ),
    # Embedding models
    "TransformersEmbeddingModel": ("transformers", "TransformersEmbeddingModel"),
    "TransformersMoEEmbeddingModel": ("transformers", "TransformersMoEEmbeddingModel"),
    "TransformersMultiModalEmbeddingModel": (
        "transformers",
        "TransformersMultiModalEmbeddingModel",
    ),
    # Sequence classification models
    "TransformersForSequenceClassification": (
        "transformers",
        "TransformersForSequenceClassification",
    ),
    "TransformersMoEForSequenceClassification": (
        "transformers",
        "TransformersMoEForSequenceClassification",
    ),
    "TransformersMultiModalForSequenceClassification": (
        "transformers",
        "TransformersMultiModalForSequenceClassification",
    ),
}
600

601
# Union of every per-category table above. When the same key appears in more
# than one table, the entry from the table unpacked later takes precedence.
_VLLM_MODELS = {
    **_TEXT_GENERATION_MODELS,
    **_EMBEDDING_MODELS,
    **_CROSS_ENCODER_MODELS,
    **_MULTIMODAL_MODELS,
    **_SPECULATIVE_DECODING_MODELS,
    **_TRANSFORMERS_SUPPORTED_MODELS,
    **_TRANSFORMERS_BACKEND_MODELS,
}

611
612
613
614
# This variable is used as the args for subprocess.run(). We
# can modify this variable to alter the args if needed, e.g.
# when we use the par format to pack things together, sys.executable
# might not be the target we want to run.
615
_SUBPROCESS_COMMAND = [sys.executable, "-m", "vllm.model_executor.models.registry"]
616

617
# Architectures that are no longer registered, keyed by architecture name.
# NOTE(review): each value is a version string, presumably the last vLLM
# release that still supported the architecture -- confirm against the code
# that reads this table.
_PREVIOUSLY_SUPPORTED_MODELS = {
    "MotifForCausalLM": "0.10.2",
    "Phi3SmallForCausalLM": "0.9.2",
    "Phi4FlashForCausalLM": "0.10.2",
    "Phi4MultimodalForCausalLM": "0.12.0",
    # encoder-decoder models except whisper
    # have been removed for V0 deprecation.
    "BartModel": "0.10.2",
    "BartForConditionalGeneration": "0.10.2",
    "DonutForConditionalGeneration": "0.10.2",
    "Florence2ForConditionalGeneration": "0.10.2",
    "MBartForConditionalGeneration": "0.10.2",
    "MllamaForConditionalGeneration": "0.10.2",
}
631

632

633
634
@dataclass(frozen=True)
class _ModelInfo:
    """
    Immutable record of the properties reported for one model class.

    Each field is produced by the interface helper of the same (or a
    closely related) name; see `from_model_cls`. The declaration order of
    the fields is significant for positional construction and for
    `dataclasses.asdict`, so it must not be changed casually.
    """

    # Name of the inspected class (`model.__name__`).
    architecture: str

    # Runner-level capabilities.
    is_text_generation_model: bool
    is_pooling_model: bool
    attn_type: AttnTypeStr
    default_seq_pooling_type: SequencePoolingType
    default_tok_pooling_type: TokenPoolingType
    supports_cross_encoding: bool
    supports_late_interaction: bool

    # Multimodal capabilities.
    supports_multimodal: bool
    supports_multimodal_raw_input_only: bool
    requires_raw_input_tokens: bool
    supports_multimodal_encoder_tp_data: bool

    # Parallelism and state-related capabilities.
    supports_pp: bool
    has_inner_state: bool
    is_attention_free: bool
    is_hybrid: bool
    has_noops: bool
    supports_mamba_prefix_caching: bool

    # Transcription capabilities.
    supports_transcription: bool
    supports_transcription_only: bool

    @staticmethod
    def from_model_cls(model: type[nn.Module]) -> "_ModelInfo":
        """
        Build a `_ModelInfo` by querying every interface helper with `model`.

        Args:
            model: The model class to inspect.

        Returns:
            A new `_ModelInfo` whose `architecture` is `model.__name__`.
        """
        return _ModelInfo(
            architecture=model.__name__,
            is_text_generation_model=is_text_generation_model(model),
            is_pooling_model=is_pooling_model(model),
            attn_type=get_attn_type(model),
            default_seq_pooling_type=get_default_seq_pooling_type(model),
            default_tok_pooling_type=get_default_tok_pooling_type(model),
            supports_cross_encoding=supports_cross_encoding(model),
            supports_late_interaction=supports_late_interaction(model),
            supports_multimodal=supports_multimodal(model),
            supports_multimodal_raw_input_only=supports_multimodal_raw_input_only(
                model
            ),
            requires_raw_input_tokens=requires_raw_input_tokens(model),
            supports_multimodal_encoder_tp_data=supports_multimodal_encoder_tp_data(
                model
            ),
            supports_pp=supports_pp(model),
            has_inner_state=has_inner_state(model),
            is_attention_free=is_attention_free(model),
            is_hybrid=is_hybrid(model),
            has_noops=has_noops(model),
            supports_mamba_prefix_caching=supports_mamba_prefix_caching(model),
            supports_transcription=supports_transcription(model),
            # The class attribute is only read when the model supports
            # transcription at all (short-circuit `and`).
            supports_transcription_only=(
                supports_transcription(model) and model.supports_transcription_only
            ),
        )
686
687


688
689
690
691
class _BaseRegisteredModel(ABC):
    """
    Interface shared by every entry stored in the model registry.

    An entry must be able to describe its model class (`inspect_model_cls`)
    and to hand out the class itself (`load_model_cls`).
    """

    @abstractmethod
    def inspect_model_cls(self) -> _ModelInfo:
        """Return the `_ModelInfo` describing the registered model class."""
        raise NotImplementedError

    @abstractmethod
    def load_model_cls(self) -> type[nn.Module]:
        """Return the registered model class."""
        raise NotImplementedError
696
697


698
699
700
701
702
703
704
@dataclass(frozen=True)
class _RegisteredModel(_BaseRegisteredModel):
    """
    Represents a model that has already been imported in the main process.

    Both the class and its `_ModelInfo` are stored on the instance, so
    neither accessor needs to do any work.
    """

    # Model info computed once when the entry is created.
    interfaces: _ModelInfo
    # The already-imported model class.
    model_cls: type[nn.Module]

    @staticmethod
    def from_model_cls(model_cls: type[nn.Module]):
        """Create an entry for `model_cls`, inspecting it immediately."""
        model_info = _ModelInfo.from_model_cls(model_cls)
        return _RegisteredModel(interfaces=model_info, model_cls=model_cls)

    def inspect_model_cls(self) -> _ModelInfo:
        """Return the model info computed at registration time."""
        return self.interfaces

    def load_model_cls(self) -> type[nn.Module]:
        """Return the stored model class."""
        return self.model_cls


@dataclass(frozen=True)
class _LazyRegisteredModel(_BaseRegisteredModel):
    """
    Represents a model that has not been imported in the main process.

    The model class is identified by `module_name` and `class_name` and is
    only imported by `load_model_cls`. `inspect_model_cls` obtains the
    model info either from a JSON cache file or by inspecting the class in
    a subprocess.
    """

    module_name: str
    class_name: str

    @staticmethod
    def _get_cache_dir() -> Path:
        """Return the directory that holds the cached model info files."""
        return Path(envs.VLLM_CACHE_ROOT) / "modelinfos"

    def _get_cache_filename(self) -> str:
        """Return the cache file name derived from module and class name."""
        cls_name = f"{self.module_name}-{self.class_name}".replace(".", "-")
        return f"{cls_name}.json"

    def _load_modelinfo_from_cache(self, module_hash: str) -> _ModelInfo | None:
        """
        Load the cached model info if it matches `module_hash`.

        Args:
            module_hash: Hash of the model source file the cache entry
                must have been created from.

        Returns:
            The cached `_ModelInfo`, or `None` when the cache file is
            missing, was written for a different hash, or cannot be used
            for any other reason. This method never raises.
        """
        try:
            modelinfo_path = self._get_cache_dir() / self._get_cache_filename()
            with open(modelinfo_path, encoding="utf-8") as file:
                mi_dict = json.load(file)

            if mi_dict["hash"] != module_hash:
                logger.debug(
                    "Cached model info file for class %s.%s is stale",
                    self.module_name,
                    self.class_name,
                )
                return None

            # file not changed, use cached _ModelInfo properties
            return _ModelInfo(**mi_dict["modelinfo"])
        except FileNotFoundError:
            logger.debug(
                "Cached model info file for class %s.%s not found",
                self.module_name,
                self.class_name,
            )
            return None
        except Exception:
            # Best effort: an unusable cache entry is treated as a miss, but
            # the cause is attached so that it can be diagnosed from the
            # debug log instead of being dropped.
            logger.debug(
                "Cached model info for class %s.%s error. ",
                self.module_name,
                self.class_name,
                exc_info=True,
            )
            return None

    def _save_modelinfo_to_cache(self, mi: _ModelInfo, module_hash: str) -> None:
        """
        Write `mi` together with `module_hash` to the JSON cache file.

        Failures are logged and otherwise ignored, since the cache is only
        an optimisation.
        """
        from vllm.model_executor.model_loader.weight_utils import atomic_writer

        try:
            modelinfo_dict = {
                "hash": module_hash,
                "modelinfo": asdict(mi),
            }
            cache_dir = self._get_cache_dir()
            cache_dir.mkdir(parents=True, exist_ok=True)
            modelinfo_path = cache_dir / self._get_cache_filename()
            with atomic_writer(modelinfo_path, encoding="utf-8") as f:
                json.dump(modelinfo_dict, f, indent=2)
        except Exception:
            logger.exception("Error saving model info cache.")

    @logtime(logger=logger, msg="Registry inspect model class")
    def inspect_model_cls(self) -> _ModelInfo:
        """
        Return the model info, using the cache when possible.

        The cache is only consulted (and later updated) when a source file
        for the module is found next to this file; its content hash is
        used to detect stale entries. Otherwise, or on a cache miss, the
        class is inspected in a subprocess.
        """
        # NOTE(review): only the last component of `module_name` is used to
        # locate the source file, and it is looked up next to this file. A
        # module from another package whose file name matches a file here
        # would therefore be hashed against that file - confirm intended.
        model_path = Path(__file__).parent / f"{self.module_name.split('.')[-1]}.py"
        module_hash = None

        if model_path.exists():
            with open(model_path, "rb") as f:
                module_hash = safe_hash(f.read(), usedforsecurity=False).hexdigest()

            mi = self._load_modelinfo_from_cache(module_hash)
            if mi is not None:
                logger.debug(
                    "Loaded model info for class %s.%s from cache",
                    self.module_name,
                    self.class_name,
                )
                return mi

            logger.debug(
                "Cache model info for class %s.%s miss. Loading model instead.",
                self.module_name,
                self.class_name,
            )

        # Performed in another process to avoid initializing CUDA
        mi = _run_in_subprocess(
            lambda: _ModelInfo.from_model_cls(self.load_model_cls())
        )
        logger.debug(
            "Loaded model info for class %s.%s", self.module_name, self.class_name
        )

        # save cache file
        if module_hash is not None:
            self._save_modelinfo_to_cache(mi, module_hash)

        return mi

    def load_model_cls(self) -> type[nn.Module]:
        """Import `module_name` and return its `class_name` attribute."""
        mod = importlib.import_module(self.module_name)
        return getattr(mod, self.class_name)


@lru_cache(maxsize=128)
def _try_load_model_cls(
    model_arch: str,
    model: _BaseRegisteredModel,
) -> type[nn.Module] | None:
    """
    Load the model class of `model`, returning `None` if loading fails.

    The current platform is asked to verify `model_arch` first. Results,
    including `None`, are memoised by `lru_cache` per argument pair.
    """
    from vllm.platforms import current_platform

    # Outside the try block on purpose: an error raised by the platform
    # check propagates to the caller instead of being logged and dropped.
    current_platform.verify_model_arch(model_arch)

    try:
        loaded_cls = model.load_model_cls()
    except Exception:
        logger.exception("Error in loading model architecture '%s'", model_arch)
        loaded_cls = None

    return loaded_cls
843
844


845
846
847
848
@lru_cache(maxsize=128)
def _try_inspect_model_cls(
    model_arch: str,
    model: _BaseRegisteredModel,
) -> _ModelInfo | None:
    """
    Inspect the model class of `model`, returning `None` if that fails.

    Any exception raised by the inspection is logged together with
    `model_arch`. Results, including `None`, are memoised by `lru_cache`
    per argument pair.
    """
    try:
        model_info = model.inspect_model_cls()
    except Exception:
        logger.exception("Error in inspecting model architecture '%s'", model_arch)
        model_info = None

    return model_info
855
856


857
858
859
@dataclass
class _ModelRegistry:
    """
    Maps architecture names to registered models and resolves the model
    class (or its `_ModelInfo`) to use for a given model configuration.
    """

    # Keyed by model_arch
    models: dict[str, _BaseRegisteredModel] = field(default_factory=dict)

    def get_supported_archs(self) -> Set[str]:
        """Return a view of all registered architecture names."""
        return self.models.keys()

    def register_model(
        self,
        model_arch: str,
        model_cls: type[nn.Module] | str,
    ) -> None:
        """
        Register an external model to be used in vLLM.

        `model_cls` can be either:

        - A [`torch.nn.Module`][] class directly referencing the model.
        - A string in the format `<module>:<class>` which can be used to
          lazily import the model. This is useful to avoid initializing CUDA
          when importing the model and thus the related error
          `RuntimeError: Cannot re-initialize CUDA in forked subprocess`.

        An existing registration for `model_arch` is overwritten after a
        warning has been logged.

        Raises:
            TypeError: If `model_arch` is not a string, or `model_cls` is
                neither a string nor a `torch.nn.Module` subclass.
            ValueError: If `model_cls` is a string that is not in the
                format `<module>:<class>`.
        """
        if not isinstance(model_arch, str):
            msg = f"`model_arch` should be a string, not a {type(model_arch)}"
            raise TypeError(msg)

        if model_arch in self.models:
            logger.warning(
                "Model architecture %s is already registered, and will be "
                "overwritten by the new model class %s.",
                model_arch,
                model_cls,
            )

        if isinstance(model_cls, str):
            split_str = model_cls.split(":")
            if len(split_str) != 2:
                msg = "Expected a string in the format `<module>:<class>`"
                raise ValueError(msg)

            model = _LazyRegisteredModel(*split_str)
        elif isinstance(model_cls, type) and issubclass(model_cls, nn.Module):
            model = _RegisteredModel.from_model_cls(model_cls)
        else:
            # Report the type of the offending argument (`model_cls`);
            # `model_arch` has already been validated above.
            msg = (
                "`model_cls` should be a string or PyTorch model class, "
                f"not a {type(model_cls)}"
            )
            raise TypeError(msg)

        self.models[model_arch] = model

    def _raise_for_unsupported(self, architectures: list[str]):
        """
        Raise the `ValueError` that best explains why resolution failed.

        This method never returns. The message distinguishes between
        registered architectures that could not be inspected, previously
        supported architectures, and unknown architectures.
        """
        all_supported_archs = self.get_supported_archs()

        if any(arch in all_supported_archs for arch in architectures):
            raise ValueError(
                f"Model architectures {architectures} failed "
                "to be inspected. Please check the logs for more details."
            )

        for arch in architectures:
            previous_version = _PREVIOUSLY_SUPPORTED_MODELS.get(arch)
            if previous_version is not None:
                raise ValueError(
                    f"Model architecture {arch} was supported in vLLM until "
                    f"v{previous_version}, and is not supported anymore. "
                    "Please use an older version of vLLM if you want to "
                    "use this model architecture."
                )

        raise ValueError(
            f"Model architectures {architectures} are not supported for now. "
            f"Supported architectures: {all_supported_archs}"
        )

    def _try_load_model_cls(self, model_arch: str) -> type[nn.Module] | None:
        """Load the class registered for `model_arch`, or return `None`."""
        if model_arch not in self.models:
            return None

        return _try_load_model_cls(model_arch, self.models[model_arch])

    def _try_inspect_model_cls(self, model_arch: str) -> _ModelInfo | None:
        """Inspect the class registered for `model_arch`, or return `None`."""
        if model_arch not in self.models:
            return None

        return _try_inspect_model_cls(model_arch, self.models[model_arch])

    def _try_resolve_transformers(
        self,
        architecture: str,
        model_config: ModelConfig,
    ) -> str | None:
        """
        Resolve the architecture name to use with the Transformers backend.

        Returns:
            `architecture` itself if it is already a Transformers backend
            architecture, otherwise the value of
            `model_config._get_transformers_backend_cls()`. `None` is
            returned when no compatible model class is found and
            `model_config.model_impl` is not `"transformers"`.

        Raises:
            ValueError: If no compatible model class is found and
                `model_config.model_impl` is `"transformers"`.
        """
        if architecture in _TRANSFORMERS_BACKEND_MODELS:
            return architecture

        auto_map: dict[str, str] = (
            getattr(model_config.hf_config, "auto_map", None) or dict()
        )

        # Make sure that config class is always initialized before model class,
        # otherwise the model class won't be able to access the config class,
        # the expected auto_map should have correct order like:
        # "auto_map": {
        #     "AutoConfig": "<your-repo-name>--<config-name>",
        #     "AutoModel": "<your-repo-name>--<config-name>",
        #     "AutoModelFor<Task>": "<your-repo-name>--<config-name>",
        # },
        for prefix in ("AutoConfig", "AutoModel"):
            for name, module in auto_map.items():
                if name.startswith(prefix):
                    try_get_class_from_dynamic_module(
                        module,
                        model_config.model,
                        revision=model_config.revision,
                        trust_remote_code=model_config.trust_remote_code,
                        warn_on_fail=False,
                    )

        model_module = getattr(transformers, architecture, None)

        if model_module is None:
            for name, module in auto_map.items():
                if name.startswith("AutoModel"):
                    model_module = try_get_class_from_dynamic_module(
                        module,
                        model_config.model,
                        revision=model_config.revision,
                        trust_remote_code=model_config.trust_remote_code,
                        warn_on_fail=True,
                    )
                    if model_module is not None:
                        break
            else:
                # Reached only when the loop finished without `break`,
                # i.e. no `AutoModel*` entry produced a class.
                if model_config.model_impl != "transformers":
                    return None

                raise ValueError(
                    f"Cannot find model module. {architecture!r} is not a "
                    "registered model in the Transformers library (only "
                    "relevant if the model is meant to be in Transformers) "
                    "and 'AutoModel' is not present in the model config's "
                    "'auto_map' (relevant if the model is custom)."
                )

        if not model_module.is_backend_compatible():
            if model_config.model_impl != "transformers":
                return None

            raise ValueError(
                f"The Transformers implementation of {architecture!r} "
                "is not compatible with vLLM."
            )

        return model_config._get_transformers_backend_cls()

    def _normalize_arch(
        self,
        architecture: str,
        model_config: ModelConfig,
    ) -> str:
        """
        Map `architecture` to a registered base architecture if possible.

        A registered name is returned unchanged. Otherwise, if the name
        matches one of the architecture defaults, its suffix is replaced
        by each known suffix in turn and the first registered result is
        returned. If nothing matches, `architecture` is returned as is.
        """
        if architecture in self.models:
            return architecture

        # This may be called in order to resolve runner_type and convert_type
        # in the first place, in which case we consider the default match
        match = try_match_architecture_defaults(
            architecture,
            runner_type=getattr(model_config, "runner_type", None),
            convert_type=getattr(model_config, "convert_type", None),
        )
        if match:
            suffix, _ = match

            # Get the name of the base model to convert
            for repl_suffix, _ in iter_architecture_defaults():
                base_arch = architecture.replace(suffix, repl_suffix)
                if base_arch in self.models:
                    return base_arch

        return architecture

    def _resolve_with(
        self,
        architectures: str | list[str],
        model_config: ModelConfig,
        try_get: Callable[[str], Any],
    ) -> tuple[Any, str]:
        """
        Shared resolution strategy of `inspect_model_cls` and
        `resolve_model_cls`.

        Args:
            architectures: One architecture name or a list of them.
            model_config: Configuration that selects the implementation.
            try_get: Called with an architecture name; returns the object
                to resolve (model info or model class) or `None` on
                failure.

        Returns:
            A tuple of the object returned by `try_get` and the name of the
            architecture it was resolved for.

        Raises:
            ValueError: If `architectures` is empty or nothing resolves.
        """
        if isinstance(architectures, str):
            architectures = [architectures]
        if not architectures:
            raise ValueError("No model architectures are specified")

        def via_transformers() -> tuple[Any, str] | None:
            # Only the first architecture is considered for the
            # Transformers backend.
            arch = self._try_resolve_transformers(architectures[0], model_config)
            if arch is None:
                return None
            found = try_get(arch)
            return None if found is None else (found, arch)

        # Require transformers impl
        if model_config.model_impl == "transformers":
            resolved = via_transformers()
            if resolved is not None:
                return resolved
        elif model_config.model_impl == "terratorch":
            found = try_get("Terratorch")
            # A failed lookup must not be returned as `(None, ...)`; fall
            # through so that the failure is reported as an error.
            if found is not None:
                return (found, "Terratorch")

        can_fall_back = model_config.model_impl == "auto" and all(
            arch not in self.models for arch in architectures
        )

        # Fallback to transformers impl (after resolving convert_type)
        if can_fall_back and getattr(model_config, "convert_type", "none") == "none":
            resolved = via_transformers()
            if resolved is not None:
                return resolved

        for arch in architectures:
            found = try_get(self._normalize_arch(arch, model_config))
            if found is not None:
                # The name as given by the caller is reported, not the
                # normalized one.
                return (found, arch)

        # Fallback to transformers impl (before resolving runner_type)
        if can_fall_back:
            resolved = via_transformers()
            if resolved is not None:
                return resolved

        return self._raise_for_unsupported(architectures)

    def inspect_model_cls(
        self,
        architectures: str | list[str],
        model_config: ModelConfig,
    ) -> tuple[_ModelInfo, str]:
        """
        Return the `_ModelInfo` of the first resolvable architecture and
        the name of that architecture.

        Raises:
            ValueError: If `architectures` is empty or none of them can be
                resolved.
        """
        return self._resolve_with(
            architectures, model_config, self._try_inspect_model_cls
        )

    def resolve_model_cls(
        self,
        architectures: str | list[str],
        model_config: ModelConfig,
    ) -> tuple[type[nn.Module], str]:
        """
        Return the model class of the first resolvable architecture and
        the name of that architecture.

        Raises:
            ValueError: If `architectures` is empty or none of them can be
                resolved.
        """
        return self._resolve_with(
            architectures, model_config, self._try_load_model_cls
        )

    def is_text_generation_model(
        self,
        architectures: str | list[str],
        model_config: ModelConfig,
    ) -> bool:
        """Return the `is_text_generation_model` flag of the resolved model."""
        model_info, _ = self.inspect_model_cls(architectures, model_config)
        return model_info.is_text_generation_model

    def is_pooling_model(
        self,
        architectures: str | list[str],
        model_config: ModelConfig,
    ) -> bool:
        """Return the `is_pooling_model` flag of the resolved model."""
        model_info, _ = self.inspect_model_cls(architectures, model_config)
        return model_info.is_pooling_model

    def is_cross_encoder_model(
        self,
        architectures: str | list[str],
        model_config: ModelConfig,
    ) -> bool:
        """Return the `supports_cross_encoding` flag of the resolved model."""
        model_info, _ = self.inspect_model_cls(architectures, model_config)
        return model_info.supports_cross_encoding

    def is_multimodal_model(
        self,
        architectures: str | list[str],
        model_config: ModelConfig,
    ) -> bool:
        """Return the `supports_multimodal` flag of the resolved model."""
        model_info, _ = self.inspect_model_cls(architectures, model_config)
        return model_info.supports_multimodal

    def is_multimodal_raw_input_only_model(
        self,
        architectures: str | list[str],
        model_config: ModelConfig,
    ) -> bool:
        """Return the `supports_multimodal_raw_input_only` flag."""
        model_info, _ = self.inspect_model_cls(architectures, model_config)
        return model_info.supports_multimodal_raw_input_only

    def is_pp_supported_model(
        self,
        architectures: str | list[str],
        model_config: ModelConfig,
    ) -> bool:
        """Return the `supports_pp` flag of the resolved model."""
        model_info, _ = self.inspect_model_cls(architectures, model_config)
        return model_info.supports_pp

    def model_has_inner_state(
        self,
        architectures: str | list[str],
        model_config: ModelConfig,
    ) -> bool:
        """Return the `has_inner_state` flag of the resolved model."""
        model_info, _ = self.inspect_model_cls(architectures, model_config)
        return model_info.has_inner_state

    def is_attention_free_model(
        self,
        architectures: str | list[str],
        model_config: ModelConfig,
    ) -> bool:
        """Return the `is_attention_free` flag of the resolved model."""
        model_info, _ = self.inspect_model_cls(architectures, model_config)
        return model_info.is_attention_free

    def is_hybrid_model(
        self,
        architectures: str | list[str],
        model_config: ModelConfig,
    ) -> bool:
        """Return the `is_hybrid` flag of the resolved model."""
        model_info, _ = self.inspect_model_cls(architectures, model_config)
        return model_info.is_hybrid

    def is_noops_model(
        self,
        architectures: str | list[str],
        model_config: ModelConfig,
    ) -> bool:
        """Return the `has_noops` flag of the resolved model."""
        model_info, _ = self.inspect_model_cls(architectures, model_config)
        return model_info.has_noops

    def is_transcription_model(
        self,
        architectures: str | list[str],
        model_config: ModelConfig,
    ) -> bool:
        """Return the `supports_transcription` flag of the resolved model."""
        model_info, _ = self.inspect_model_cls(architectures, model_config)
        return model_info.supports_transcription

    def is_transcription_only_model(
        self,
        architectures: str | list[str],
        model_config: ModelConfig,
    ) -> bool:
        """Return the `supports_transcription_only` flag of the resolved model."""
        model_info, _ = self.inspect_model_cls(architectures, model_config)
        return model_info.supports_transcription_only

1244

1245
1246
1247
1248
1249
1250
1251
1252
1253
# Module-level registry instance, pre-populated with one lazy entry per
# architecture listed in `_VLLM_MODELS`.
ModelRegistry = _ModelRegistry(
    models={
        arch: _LazyRegisteredModel(
            module_name=f"vllm.model_executor.models.{module_stem}",
            class_name=model_cls_name,
        )
        for arch, (module_stem, model_cls_name) in _VLLM_MODELS.items()
    }
)

# Return type of the callable executed by `_run_in_subprocess`.
_T = TypeVar("_T")


def _run_in_subprocess(fn: Callable[[], _T]) -> _T:
    """
    Execute `fn` in a separate interpreter and return its result.

    `fn` and the path of an output file are serialised with `cloudpickle`
    and passed on stdin to the command in `_SUBPROCESS_COMMAND`. The
    result is then unpickled from the output file.

    Args:
        fn: Zero-argument callable to run in the child process.

    Returns:
        The object unpickled from the output file.

    Raises:
        RuntimeError: If the child process exits with a non-zero status.
            The message contains the child's stderr output.
    """
    # `cloudpickle` allows pickling lambda functions directly
    import cloudpickle

    # NOTE: We use a temporary directory instead of a temporary file to avoid
    # issues like https://stackoverflow.com/questions/23212435/permission-denied-to-write-to-my-temporary-file
    with tempfile.TemporaryDirectory() as tempdir:
        output_filepath = os.path.join(tempdir, "registry_output.tmp")
        input_bytes = cloudpickle.dumps((fn, output_filepath))

        # cannot use `sys.executable __file__` here because the script
        # contains relative imports
        returned = subprocess.run(
            _SUBPROCESS_COMMAND, input=input_bytes, capture_output=True
        )

        # check if the subprocess is successful
        try:
            returned.check_returncode()
        except subprocess.CalledProcessError as e:
            # Wrap the exception to include the child's stderr. Undecodable
            # bytes are replaced instead of raising, so that a decoding
            # error cannot hide the actual failure.
            stderr_text = returned.stderr.decode(errors="replace")
            raise RuntimeError(
                f"Error raised in subprocess:\n{stderr_text}"
            ) from e

        # The file is expected to have been written by the child process
        # that was just run, so it is unpickled without further checks.
        with open(output_filepath, "rb") as f:
            return pickle.load(f)


def _run() -> None:
    """
    Entry point of the child process started by `_run_in_subprocess`.

    Reads a pickled `(fn, output_file)` pair from stdin, calls `fn`, and
    writes the pickled result to `output_file`.
    """
    # Setup plugins
    from vllm.plugins import load_general_plugins

    load_general_plugins()

    request_bytes = sys.stdin.buffer.read()
    fn, output_file = pickle.loads(request_bytes)

    result = fn()

    with open(output_file, "wb") as sink:
        sink.write(pickle.dumps(result))


# Only act as the subprocess entry point when executed as a script/module.
if __name__ == "__main__":
    _run()