"vscode:/vscode.git/clone" did not exist on "4561f13985e86ea15c9c672fbcfa28f12530e3e5"
registry.py 50.8 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
4
5
6
"""
Whenever you add an architecture to this page, please also update
`tests/models/registry.py` with example HuggingFace models for it.
"""
7

8
import importlib
9
import json
10
import os
11
import pickle
12
13
import subprocess
import sys
14
import tempfile
15
from abc import ABC, abstractmethod
16
from collections.abc import Callable, Set
17
from dataclasses import asdict, dataclass, field
18
from functools import lru_cache
19
from pathlib import Path
20
from typing import TYPE_CHECKING, Any, TypeVar
21
22

import torch.nn as nn
23
import transformers
24

25
from vllm import envs
26
27
28
29
30
from vllm.config import (
    ModelConfig,
    iter_architecture_defaults,
    try_match_architecture_defaults,
)
31
from vllm.logger import init_logger
32
from vllm.logging_utils import logtime
33
from vllm.transformers_utils.dynamic_module import try_get_class_from_dynamic_module
34
from vllm.utils.hashing import safe_hash
35

36
37
if TYPE_CHECKING:
    from vllm.config.model import AttnTypeStr
38
    from vllm.config.pooler import SequencePoolingType, TokenPoolingType
39
40
else:
    AttnTypeStr = Any
41
42
    SequencePoolingType = Any
    TokenPoolingType = Any
43
44


45
46
47
48
49
from .interfaces import (
    has_inner_state,
    has_noops,
    is_attention_free,
    is_hybrid,
Patrick von Platen's avatar
Patrick von Platen committed
50
    requires_raw_input_tokens,
51
    supports_cross_encoding,
52
    supports_late_interaction,
53
    supports_mamba_prefix_caching,
54
55
56
57
58
59
60
    supports_multimodal,
    supports_multimodal_encoder_tp_data,
    supports_multimodal_raw_input_only,
    supports_pp,
    supports_transcription,
)
from .interfaces_base import (
61
    get_attn_type,
62
63
    get_default_seq_pooling_type,
    get_default_tok_pooling_type,
64
65
66
    is_pooling_model,
    is_text_generation_model,
)

# Logger named after this module.
logger = init_logger(__name__)

# Text-generation architectures.
# Each entry maps an architecture name to `(module_name, class_name)`; several
# names may alias one implementation (e.g. "AquilaModel" -> llama).
_TEXT_GENERATION_MODELS = {
    # [Decoder-only]
    "AfmoeForCausalLM": ("afmoe", "AfmoeForCausalLM"),
    "ApertusForCausalLM": ("apertus", "ApertusForCausalLM"),
    "AquilaModel": ("llama", "LlamaForCausalLM"),
    "AquilaForCausalLM": ("llama", "LlamaForCausalLM"),  # AquilaChat2
    "ArceeForCausalLM": ("arcee", "ArceeForCausalLM"),
    "ArcticForCausalLM": ("arctic", "ArcticForCausalLM"),
    "AXK1ForCausalLM": ("AXK1", "AXK1ForCausalLM"),
    # baichuan-7b, upper case 'C' in the class name
    "BaiChuanForCausalLM": ("baichuan", "BaiChuanForCausalLM"),
    # baichuan-13b, lower case 'c' in the class name
    "BaichuanForCausalLM": ("baichuan", "BaichuanForCausalLM"),
    "BailingMoeForCausalLM": ("bailing_moe", "BailingMoeForCausalLM"),
    "BailingMoeV2ForCausalLM": ("bailing_moe", "BailingMoeV2ForCausalLM"),
    "BailingMoeV2_5ForCausalLM": ("bailing_moe_linear", "BailingMoeV25ForCausalLM"),
    "BambaForCausalLM": ("bamba", "BambaForCausalLM"),
    "BloomForCausalLM": ("bloom", "BloomForCausalLM"),
    "ChatGLMModel": ("chatglm", "ChatGLMForCausalLM"),
    "ChatGLMForConditionalGeneration": ("chatglm", "ChatGLMForCausalLM"),
    "CohereForCausalLM": ("commandr", "CohereForCausalLM"),
    "Cohere2ForCausalLM": ("commandr", "CohereForCausalLM"),
    "CwmForCausalLM": ("llama", "LlamaForCausalLM"),
    "DbrxForCausalLM": ("dbrx", "DbrxForCausalLM"),
    "DeciLMForCausalLM": ("nemotron_nas", "DeciLMForCausalLM"),
    "DeepseekForCausalLM": ("deepseek_v2", "DeepseekForCausalLM"),
    "DeepseekV2ForCausalLM": ("deepseek_v2", "DeepseekV2ForCausalLM"),
    "DeepseekV3ForCausalLM": ("deepseek_v2", "DeepseekV3ForCausalLM"),
    "DeepseekV32ForCausalLM": ("deepseek_v2", "DeepseekV3ForCausalLM"),
    "Dots1ForCausalLM": ("dots1", "Dots1ForCausalLM"),
    "Ernie4_5ForCausalLM": ("ernie45", "Ernie4_5ForCausalLM"),
    "Ernie4_5_MoeForCausalLM": ("ernie45_moe", "Ernie4_5_MoeForCausalLM"),
    "ExaoneForCausalLM": ("exaone", "ExaoneForCausalLM"),
    "Exaone4ForCausalLM": ("exaone4", "Exaone4ForCausalLM"),
    "ExaoneMoEForCausalLM": ("exaone_moe", "ExaoneMoeForCausalLM"),
    "Fairseq2LlamaForCausalLM": ("fairseq2_llama", "Fairseq2LlamaForCausalLM"),
    "FalconForCausalLM": ("falcon", "FalconForCausalLM"),
    "FalconMambaForCausalLM": ("mamba", "MambaForCausalLM"),
    "FalconH1ForCausalLM": ("falcon_h1", "FalconH1ForCausalLM"),
    "FlexOlmoForCausalLM": ("flex_olmo", "FlexOlmoForCausalLM"),
    "GemmaForCausalLM": ("gemma", "GemmaForCausalLM"),
    "Gemma2ForCausalLM": ("gemma2", "Gemma2ForCausalLM"),
    "Gemma3ForCausalLM": ("gemma3", "Gemma3ForCausalLM"),
    "Gemma3nForCausalLM": ("gemma3n", "Gemma3nForCausalLM"),
    "Qwen3NextForCausalLM": ("qwen3_next", "Qwen3NextForCausalLM"),
    "GlmForCausalLM": ("glm", "GlmForCausalLM"),
    "Glm4ForCausalLM": ("glm4", "Glm4ForCausalLM"),
    "Glm4MoeForCausalLM": ("glm4_moe", "Glm4MoeForCausalLM"),
    "Glm4MoeLiteForCausalLM": ("glm4_moe_lite", "Glm4MoeLiteForCausalLM"),
    "GlmMoeDsaForCausalLM": ("deepseek_v2", "GlmMoeDsaForCausalLM"),
    "GptOssForCausalLM": ("gpt_oss", "GptOssForCausalLM"),
    "GPT2LMHeadModel": ("gpt2", "GPT2LMHeadModel"),
    "GPTBigCodeForCausalLM": ("gpt_bigcode", "GPTBigCodeForCausalLM"),
    "GPTJForCausalLM": ("gpt_j", "GPTJForCausalLM"),
    "GPTNeoXForCausalLM": ("gpt_neox", "GPTNeoXForCausalLM"),
    "GraniteForCausalLM": ("granite", "GraniteForCausalLM"),
    "GraniteMoeForCausalLM": ("granitemoe", "GraniteMoeForCausalLM"),
    "GraniteMoeHybridForCausalLM": ("granitemoehybrid", "GraniteMoeHybridForCausalLM"),  # noqa: E501
    "GraniteMoeSharedForCausalLM": ("granitemoeshared", "GraniteMoeSharedForCausalLM"),  # noqa: E501
    "GritLM": ("gritlm", "GritLM"),
    "Grok1ModelForCausalLM": ("grok1", "GrokForCausalLM"),
    "Grok1ForCausalLM": ("grok1", "GrokForCausalLM"),
    "HunYuanMoEV1ForCausalLM": ("hunyuan_v1", "HunYuanMoEV1ForCausalLM"),
    "HunYuanDenseV1ForCausalLM": ("hunyuan_v1", "HunYuanDenseV1ForCausalLM"),
    "HCXVisionForCausalLM": ("hyperclovax_vision", "HCXVisionForCausalLM"),
    "InternLMForCausalLM": ("llama", "LlamaForCausalLM"),
    "InternLM2ForCausalLM": ("internlm2", "InternLM2ForCausalLM"),
    "InternLM2VEForCausalLM": ("internlm2_ve", "InternLM2VEForCausalLM"),
    "InternLM3ForCausalLM": ("llama", "LlamaForCausalLM"),
    "IQuestCoderForCausalLM": ("llama", "LlamaForCausalLM"),
    "IQuestLoopCoderForCausalLM": ("iquest_loopcoder", "IQuestLoopCoderForCausalLM"),
    "JAISLMHeadModel": ("jais", "JAISLMHeadModel"),
    "Jais2ForCausalLM": ("jais2", "Jais2ForCausalLM"),
    "JambaForCausalLM": ("jamba", "JambaForCausalLM"),
    "KimiLinearForCausalLM": ("kimi_linear", "KimiLinearForCausalLM"),  # noqa: E501
    "Lfm2ForCausalLM": ("lfm2", "Lfm2ForCausalLM"),
    "Lfm2MoeForCausalLM": ("lfm2_moe", "Lfm2MoeForCausalLM"),
    "LlamaForCausalLM": ("llama", "LlamaForCausalLM"),
    "Llama4ForCausalLM": ("llama4", "Llama4ForCausalLM"),
    # For decapoda-research/llama-*
    "LLaMAForCausalLM": ("llama", "LlamaForCausalLM"),
    "LongcatFlashForCausalLM": ("longcat_flash", "LongcatFlashForCausalLM"),
    "MambaForCausalLM": ("mamba", "MambaForCausalLM"),
    "Mamba2ForCausalLM": ("mamba2", "Mamba2ForCausalLM"),
    "MiniCPMForCausalLM": ("minicpm", "MiniCPMForCausalLM"),
    "MiniCPM3ForCausalLM": ("minicpm3", "MiniCPM3ForCausalLM"),
    "MiniMaxForCausalLM": ("minimax_text_01", "MiniMaxText01ForCausalLM"),
    "MiniMaxText01ForCausalLM": ("minimax_text_01", "MiniMaxText01ForCausalLM"),
    "MiniMaxM1ForCausalLM": ("minimax_text_01", "MiniMaxText01ForCausalLM"),
    "MiniMaxM2ForCausalLM": ("minimax_m2", "MiniMaxM2ForCausalLM"),
    "MistralForCausalLM": ("mistral", "MistralForCausalLM"),
    "MistralLarge3ForCausalLM": ("mistral_large_3", "MistralLarge3ForCausalLM"),
    "MixtralForCausalLM": ("mixtral", "MixtralForCausalLM"),
    # The transformers MPT class name is spelled "Mpt" (lower case "pt").
    "MptForCausalLM": ("mpt", "MPTForCausalLM"),
    "MPTForCausalLM": ("mpt", "MPTForCausalLM"),
    "MiMoForCausalLM": ("mimo", "MiMoForCausalLM"),
    "MiMoV2FlashForCausalLM": ("mimo_v2_flash", "MiMoV2FlashForCausalLM"),
    "NemotronForCausalLM": ("nemotron", "NemotronForCausalLM"),
    "NemotronHForCausalLM": ("nemotron_h", "NemotronHForCausalLM"),
    "NemotronHPuzzleForCausalLM": ("nemotron_h", "NemotronHForCausalLM"),
    "OlmoForCausalLM": ("olmo", "OlmoForCausalLM"),
    "Olmo2ForCausalLM": ("olmo2", "Olmo2ForCausalLM"),
    "Olmo3ForCausalLM": ("olmo2", "Olmo2ForCausalLM"),
    "OlmoeForCausalLM": ("olmoe", "OlmoeForCausalLM"),
    "OPTForCausalLM": ("opt", "OPTForCausalLM"),
    "OrionForCausalLM": ("orion", "OrionForCausalLM"),
    "OuroForCausalLM": ("ouro", "OuroForCausalLM"),
    "PanguEmbeddedForCausalLM": ("openpangu", "PanguEmbeddedForCausalLM"),
    "PanguProMoEV2ForCausalLM": ("openpangu", "PanguProMoEV2ForCausalLM"),
    "PanguUltraMoEForCausalLM": ("openpangu", "PanguUltraMoEForCausalLM"),
    "PersimmonForCausalLM": ("persimmon", "PersimmonForCausalLM"),
    "PhiForCausalLM": ("phi", "PhiForCausalLM"),
    "Phi3ForCausalLM": ("phi3", "Phi3ForCausalLM"),
    "PhiMoEForCausalLM": ("phimoe", "PhiMoEForCausalLM"),
    "Plamo2ForCausalLM": ("plamo2", "Plamo2ForCausalLM"),
    "Plamo3ForCausalLM": ("plamo3", "Plamo3ForCausalLM"),
    "QWenLMHeadModel": ("qwen", "QWenLMHeadModel"),
    "Qwen2ForCausalLM": ("qwen2", "Qwen2ForCausalLM"),
    "Qwen2MoeForCausalLM": ("qwen2_moe", "Qwen2MoeForCausalLM"),
    "Qwen3ForCausalLM": ("qwen3", "Qwen3ForCausalLM"),
    "Qwen3MoeForCausalLM": ("qwen3_moe", "Qwen3MoeForCausalLM"),
    "RWForCausalLM": ("falcon", "FalconForCausalLM"),
    "SeedOssForCausalLM": ("seed_oss", "SeedOssForCausalLM"),
    "Step1ForCausalLM": ("step1", "Step1ForCausalLM"),
    "Step3TextForCausalLM": ("step3_text", "Step3TextForCausalLM"),
    "Step3p5ForCausalLM": ("step3p5", "Step3p5ForCausalLM"),
    "StableLMEpochForCausalLM": ("stablelm", "StablelmForCausalLM"),
    "StableLmForCausalLM": ("stablelm", "StablelmForCausalLM"),
    "Starcoder2ForCausalLM": ("starcoder2", "Starcoder2ForCausalLM"),
    "SolarForCausalLM": ("solar", "SolarForCausalLM"),
    "TeleChatForCausalLM": ("telechat2", "TeleChat2ForCausalLM"),
    "TeleChat2ForCausalLM": ("telechat2", "TeleChat2ForCausalLM"),
    "TeleFLMForCausalLM": ("teleflm", "TeleFLMForCausalLM"),
    "XverseForCausalLM": ("llama", "LlamaForCausalLM"),
    "Zamba2ForCausalLM": ("zamba2", "Zamba2ForCausalLM"),
}

# Embedding / pooling architectures (text-only and multimodal).
# Values use the same `(module_name, class_name)` shape as the
# text-generation registry.
_EMBEDDING_MODELS = {
    # [Text-only]
    "BertModel": ("bert", "BertEmbeddingModel"),
    "BertSpladeSparseEmbeddingModel": ("bert", "BertSpladeSparseEmbeddingModel"),
    "HF_ColBERT": ("colbert", "ColBERTModel"),
    "ColBERTModernBertModel": ("colbert", "ColBERTModernBertModel"),
    "ColBERTJinaRobertaModel": ("colbert", "ColBERTJinaRobertaModel"),
    "DeciLMForCausalLM": ("nemotron_nas", "DeciLMForCausalLM"),
    "Gemma2Model": ("gemma2", "Gemma2ForCausalLM"),
    "Gemma3TextModel": ("gemma3", "Gemma3Model"),
    "GlmForCausalLM": ("glm", "GlmForCausalLM"),
    "GPT2ForSequenceClassification": ("gpt2", "GPT2ForSequenceClassification"),
    "GritLM": ("gritlm", "GritLM"),
    "GteModel": ("bert_with_rope", "SnowflakeGteNewModel"),
    "GteNewModel": ("bert_with_rope", "GteNewModel"),
    "InternLM2ForRewardModel": ("internlm2", "InternLM2ForRewardModel"),
    "JambaForSequenceClassification": ("jamba", "JambaForSequenceClassification"),  # noqa: E501
    "LlamaBidirectionalModel": ("llama", "LlamaBidirectionalModel"),
    "LlamaModel": ("llama", "LlamaForCausalLM"),
    **{
        # Many architecture names share the Llama implementation, so every
        # text-generation alias backed by LlamaForCausalLM is included here too.
        k: (mod, arch)
        for k, (mod, arch) in _TEXT_GENERATION_MODELS.items()
        if arch == "LlamaForCausalLM"
    },
    "MistralModel": ("llama", "LlamaForCausalLM"),
    "ModernBertModel": ("modernbert", "ModernBertModel"),
    "NomicBertModel": ("bert_with_rope", "NomicBertModel"),
    "Phi3ForCausalLM": ("phi3", "Phi3ForCausalLM"),
    "Qwen2Model": ("qwen2", "Qwen2ForCausalLM"),
    "Qwen2ForCausalLM": ("qwen2", "Qwen2ForCausalLM"),
    "Qwen2ForRewardModel": ("qwen2_rm", "Qwen2ForRewardModel"),
    "Qwen2ForProcessRewardModel": ("qwen2_rm", "Qwen2ForProcessRewardModel"),
    "RobertaForMaskedLM": ("roberta", "RobertaEmbeddingModel"),
    "RobertaModel": ("roberta", "RobertaEmbeddingModel"),
    "TeleChatForCausalLM": ("telechat2", "TeleChat2ForCausalLM"),
    "TeleChat2ForCausalLM": ("telechat2", "TeleChat2ForCausalLM"),
    "VoyageQwen3BidirectionalEmbedModel": (
        "voyage",
        "VoyageQwen3BidirectionalEmbedModel",
    ),
    "XLMRobertaModel": ("roberta", "RobertaEmbeddingModel"),
    "BgeM3EmbeddingModel": ("roberta", "BgeM3EmbeddingModel"),
    # [Multimodal]
    "CLIPModel": ("clip", "CLIPEmbeddingModel"),
    "ColModernVBertForRetrieval": ("colmodernvbert", "ColModernVBertForRetrieval"),
    "LlavaNextForConditionalGeneration": (
        "llava_next",
        "LlavaNextForConditionalGeneration",
    ),
    "Phi3VForCausalLM": ("phi3v", "Phi3VForCausalLM"),
    "Qwen2VLForConditionalGeneration": ("qwen2_vl", "Qwen2VLForConditionalGeneration"),  # noqa: E501
    "ColQwen3": ("colqwen3", "ColQwen3Model"),
    "OpsColQwen3Model": ("colqwen3", "ColQwen3Model"),
    "Qwen3VLNemotronEmbedModel": ("colqwen3", "ColQwen3Model"),
    "SiglipModel": ("siglip", "SiglipEmbeddingModel"),
    "LlamaNemotronVLModel": (
        "nemotron_vl",
        "LlamaNemotronVLForEmbedding",
    ),
    # Technically Terratorch models work on images, both in input and output.
    # They are listed here because, for the time being, they piggy-back on the
    # embedding-model machinery.
    "PrithviGeoSpatialMAE": ("terratorch", "Terratorch"),
    "Terratorch": ("terratorch", "Terratorch"),
}

# Sequence/token-classification and ranking architectures,
# as `(module_name, class_name)` pairs.
_CROSS_ENCODER_MODELS = {
    "BertForSequenceClassification": ("bert", "BertForSequenceClassification"),
    "BertForTokenClassification": ("bert", "BertForTokenClassification"),
    "GteNewForSequenceClassification": (
        "bert_with_rope",
        "GteNewForSequenceClassification",
    ),
    "JinaVLForRanking": ("jina_vl", "JinaVLForSequenceClassification"),
    "LlamaBidirectionalForSequenceClassification": (
        "llama",
        "LlamaBidirectionalForSequenceClassification",
    ),
    "LlamaNemotronVLForSequenceClassification": (
        "nemotron_vl",
        "LlamaNemotronVLForSequenceClassification",
    ),
    "ModernBertForSequenceClassification": (
        "modernbert",
        "ModernBertForSequenceClassification",
    ),
    "ModernBertForTokenClassification": (
        "modernbert",
        "ModernBertForTokenClassification",
    ),
    "RobertaForSequenceClassification": ("roberta", "RobertaForSequenceClassification"),
    "XLMRobertaForSequenceClassification": (
        "roberta",
        "RobertaForSequenceClassification",
    ),
}

# Multimodal architectures (vision / audio / video inputs), as
# `(module_name, class_name)` pairs.
_MULTIMODAL_MODELS = {
    # [Decoder-only]
    "AriaForConditionalGeneration": ("aria", "AriaForConditionalGeneration"),
    "AudioFlamingo3ForConditionalGeneration": (
        "audioflamingo3",
        "AudioFlamingo3ForConditionalGeneration",
    ),
    "MusicFlamingoForConditionalGeneration": (
        "musicflamingo",
        "MusicFlamingoForConditionalGeneration",
    ),
    "AyaVisionForConditionalGeneration": (
        "aya_vision",
        "AyaVisionForConditionalGeneration",
    ),
    "BagelForConditionalGeneration": ("bagel", "BagelForConditionalGeneration"),
    "BeeForConditionalGeneration": ("bee", "BeeForConditionalGeneration"),
    "Blip2ForConditionalGeneration": ("blip2", "Blip2ForConditionalGeneration"),
    "ChameleonForConditionalGeneration": (
        "chameleon",
        "ChameleonForConditionalGeneration",
    ),
    "Cohere2VisionForConditionalGeneration": (
        "cohere2_vision",
        "Cohere2VisionForConditionalGeneration",
    ),
    "DeepseekVLV2ForCausalLM": ("deepseek_vl2", "DeepseekVLV2ForCausalLM"),
    "DeepseekOCRForCausalLM": ("deepseek_ocr", "DeepseekOCRForCausalLM"),
    "DeepseekOCR2ForCausalLM": ("deepseek_ocr2", "DeepseekOCR2ForCausalLM"),
    "DotsOCRForCausalLM": ("dots_ocr", "DotsOCRForCausalLM"),
    "Eagle2_5_VLForConditionalGeneration": (
        "eagle2_5_vl",
        "Eagle2_5_VLForConditionalGeneration",
    ),
    "Ernie4_5_VLMoeForConditionalGeneration": (
        "ernie45_vl",
        "Ernie4_5_VLMoeForConditionalGeneration",
    ),
    "FunASRForConditionalGeneration": ("funasr", "FunASRForConditionalGeneration"),  # noqa: E501
    "FunAudioChatForConditionalGeneration": (
        "funaudiochat",
        "FunAudioChatForConditionalGeneration",
    ),
    "FuyuForCausalLM": ("fuyu", "FuyuForCausalLM"),
    "Gemma3ForConditionalGeneration": ("gemma3_mm", "Gemma3ForConditionalGeneration"),  # noqa: E501
    "Gemma3nForConditionalGeneration": (
        "gemma3n_mm",
        "Gemma3nForConditionalGeneration",
    ),
    "GlmAsrForConditionalGeneration": ("glmasr", "GlmAsrForConditionalGeneration"),
    "GLM4VForCausalLM": ("glm4v", "GLM4VForCausalLM"),
    "Glm4vForConditionalGeneration": ("glm4_1v", "Glm4vForConditionalGeneration"),
    "Glm4vMoeForConditionalGeneration": ("glm4_1v", "Glm4vMoeForConditionalGeneration"),
    "GlmOcrForConditionalGeneration": ("glm_ocr", "GlmOcrForConditionalGeneration"),  # noqa: E501
    "GraniteSpeechForConditionalGeneration": (
        "granite_speech",
        "GraniteSpeechForConditionalGeneration",
    ),
    "H2OVLChatModel": ("h2ovl", "H2OVLChatModel"),
    "HunYuanVLForConditionalGeneration": (
        "hunyuan_vision",
        "HunYuanVLForConditionalGeneration",
    ),
    "StepVLForConditionalGeneration": ("step_vl", "StepVLForConditionalGeneration"),
    "InternVLChatModel": ("internvl", "InternVLChatModel"),
    "NemotronH_Nano_VL_V2": ("nano_nemotron_vl", "NemotronH_Nano_VL_V2"),
    "OpenCUAForConditionalGeneration": (
        "opencua",
        "OpenCUAForConditionalGeneration",
    ),
    "InternS1ForConditionalGeneration": (
        "interns1",
        "InternS1ForConditionalGeneration",
    ),
    "InternVLForConditionalGeneration": (
        "interns1",
        "InternS1ForConditionalGeneration",
    ),
    "InternS1ProForConditionalGeneration": (
        "interns1_pro",
        "InternS1ProForConditionalGeneration",
    ),
    "Idefics3ForConditionalGeneration": (
        "idefics3",
        "Idefics3ForConditionalGeneration",
    ),
    "IsaacForConditionalGeneration": ("isaac", "IsaacForConditionalGeneration"),
    "SmolVLMForConditionalGeneration": ("smolvlm", "SmolVLMForConditionalGeneration"),  # noqa: E501
    "KananaVForConditionalGeneration": ("kanana_v", "KananaVForConditionalGeneration"),
    "KeyeForConditionalGeneration": ("keye", "KeyeForConditionalGeneration"),
    "KeyeVL1_5ForConditionalGeneration": (
        "keye_vl1_5",
        "KeyeVL1_5ForConditionalGeneration",
    ),
    "RForConditionalGeneration": ("rvl", "RForConditionalGeneration"),
    "KimiVLForConditionalGeneration": ("kimi_vl", "KimiVLForConditionalGeneration"),  # noqa: E501
    "KimiK25ForConditionalGeneration": ("kimi_k25", "KimiK25ForConditionalGeneration"),  # noqa: E501
    "LightOnOCRForConditionalGeneration": (
        "lightonocr",
        "LightOnOCRForConditionalGeneration",
    ),
    "Lfm2VlForConditionalGeneration": ("lfm2_vl", "Lfm2VLForConditionalGeneration"),
    "Llama_Nemotron_Nano_VL": ("nemotron_vl", "LlamaNemotronVLChatModel"),
    "Llama4ForConditionalGeneration": ("mllama4", "Llama4ForConditionalGeneration"),  # noqa: E501
    "LlavaForConditionalGeneration": ("llava", "LlavaForConditionalGeneration"),
    "LlavaNextForConditionalGeneration": (
        "llava_next",
        "LlavaNextForConditionalGeneration",
    ),
    "LlavaNextVideoForConditionalGeneration": (
        "llava_next_video",
        "LlavaNextVideoForConditionalGeneration",
    ),
    "LlavaOnevisionForConditionalGeneration": (
        "llava_onevision",
        "LlavaOnevisionForConditionalGeneration",
    ),
    "MantisForConditionalGeneration": ("llava", "MantisForConditionalGeneration"),  # noqa: E501
    "MiDashengLMModel": ("midashenglm", "MiDashengLMModel"),
    "MiniMaxVL01ForConditionalGeneration": (
        "minimax_vl_01",
        "MiniMaxVL01ForConditionalGeneration",
    ),
    "MiniCPMO": ("minicpmo", "MiniCPMO"),
    "MiniCPMV": ("minicpmv", "MiniCPMV"),
    "Mistral3ForConditionalGeneration": (
        "mistral3",
        "Mistral3ForConditionalGeneration",
    ),
    "MolmoForCausalLM": ("molmo", "MolmoForCausalLM"),
    "Molmo2ForConditionalGeneration": ("molmo2", "Molmo2ForConditionalGeneration"),
    "NVLM_D": ("nvlm_d", "NVLM_D_Model"),
    "OpenPanguVLForConditionalGeneration": (
        "openpangu_vl",
        "OpenPanguVLForConditionalGeneration",
    ),
    "Ovis": ("ovis", "Ovis"),
    "Ovis2_5": ("ovis2_5", "Ovis2_5"),
    "Ovis2_6ForCausalLM": ("ovis2_5", "Ovis2_5"),
    "Ovis2_6_MoeForCausalLM": ("ovis2_5", "Ovis2_5"),
    "PaddleOCRVLForConditionalGeneration": (
        "paddleocr_vl",
        "PaddleOCRVLForConditionalGeneration",
    ),
    "PaliGemmaForConditionalGeneration": (
        "paligemma",
        "PaliGemmaForConditionalGeneration",
    ),
    "Phi3VForCausalLM": ("phi3v", "Phi3VForCausalLM"),
    "Phi4MMForCausalLM": ("phi4mm", "Phi4MMForCausalLM"),
    "PixtralForConditionalGeneration": ("pixtral", "PixtralForConditionalGeneration"),  # noqa: E501
    "QwenVLForConditionalGeneration": ("qwen_vl", "QwenVLForConditionalGeneration"),  # noqa: E501
    "Qwen2VLForConditionalGeneration": ("qwen2_vl", "Qwen2VLForConditionalGeneration"),  # noqa: E501
    "Qwen2_5_VLForConditionalGeneration": (
        "qwen2_5_vl",
        "Qwen2_5_VLForConditionalGeneration",
    ),
    "Qwen2AudioForConditionalGeneration": (
        "qwen2_audio",
        "Qwen2AudioForConditionalGeneration",
    ),
    "Qwen2_5OmniModel": (
        "qwen2_5_omni_thinker",
        "Qwen2_5OmniThinkerForConditionalGeneration",
    ),
    "Qwen2_5OmniForConditionalGeneration": (
        "qwen2_5_omni_thinker",
        "Qwen2_5OmniThinkerForConditionalGeneration",
    ),
    "Qwen3OmniMoeForConditionalGeneration": (
        "qwen3_omni_moe_thinker",
        "Qwen3OmniMoeThinkerForConditionalGeneration",
    ),
    "Qwen3ASRForConditionalGeneration": (
        "qwen3_asr",
        "Qwen3ASRForConditionalGeneration",
    ),
    "Qwen3ASRRealtimeGeneration": (
        "qwen3_asr_realtime",
        "Qwen3ASRRealtimeGeneration",
    ),
    "Qwen3VLForConditionalGeneration": ("qwen3_vl", "Qwen3VLForConditionalGeneration"),  # noqa: E501
    "Qwen3VLMoeForConditionalGeneration": (
        "qwen3_vl_moe",
        "Qwen3VLMoeForConditionalGeneration",
    ),
    "Qwen3_5ForConditionalGeneration": (
        "qwen3_5",
        "Qwen3_5ForConditionalGeneration",
    ),
    "Qwen3_5MoeForConditionalGeneration": (
        "qwen3_5",
        "Qwen3_5MoeForConditionalGeneration",
    ),
    "SkyworkR1VChatModel": ("skyworkr1v", "SkyworkR1VChatModel"),
    "Step3VLForConditionalGeneration": ("step3_vl", "Step3VLForConditionalGeneration"),  # noqa: E501
    "TarsierForConditionalGeneration": ("tarsier", "TarsierForConditionalGeneration"),  # noqa: E501
    "Tarsier2ForConditionalGeneration": (
        "qwen2_vl",
        "Tarsier2ForConditionalGeneration",
    ),
    "UltravoxModel": ("ultravox", "UltravoxModel"),
    "VoxtralForConditionalGeneration": ("voxtral", "VoxtralForConditionalGeneration"),  # noqa: E501
    "VoxtralRealtimeGeneration": ("voxtral_realtime", "VoxtralRealtimeGeneration"),  # noqa: E501
    # [Encoder-decoder]
    "NemotronParseForConditionalGeneration": (
        "nemotron_parse",
        "NemotronParseForConditionalGeneration",
    ),
    "WhisperForConditionalGeneration": ("whisper", "WhisperForConditionalGeneration"),  # noqa: E501
}

# Draft / MTP architectures used for speculative decoding, as
# `(module_name, class_name)` pairs.
_SPECULATIVE_DECODING_MODELS = {
    "ExtractHiddenStatesModel": ("extract_hidden_states", "ExtractHiddenStatesModel"),
    "MiMoMTPModel": ("mimo_mtp", "MiMoMTP"),
    "EagleLlamaForCausalLM": ("llama_eagle", "EagleLlamaForCausalLM"),
    "EagleLlama4ForCausalLM": ("llama4_eagle", "EagleLlama4ForCausalLM"),
    "EagleMiniCPMForCausalLM": ("minicpm_eagle", "EagleMiniCPMForCausalLM"),
    "Eagle3LlamaForCausalLM": ("llama_eagle3", "Eagle3LlamaForCausalLM"),
    "LlamaForCausalLMEagle3": ("llama_eagle3", "Eagle3LlamaForCausalLM"),
    "Eagle3Qwen2_5vlForCausalLM": ("llama_eagle3", "Eagle3LlamaForCausalLM"),
    "Eagle3Qwen3vlForCausalLM": ("llama_eagle3", "Eagle3LlamaForCausalLM"),
    "EagleMistralLarge3ForCausalLM": (
        "mistral_large_3_eagle",
        "EagleMistralLarge3ForCausalLM",
    ),
    "EagleDeepSeekMTPModel": ("deepseek_eagle", "EagleDeepseekV3ForCausalLM"),
    "DeepSeekMTPModel": ("deepseek_mtp", "DeepSeekMTP"),
    "ErnieMTPModel": ("ernie_mtp", "ErnieMTP"),
    "ExaoneMoeMTP": ("exaone_moe_mtp", "ExaoneMoeMTP"),
    "NemotronHMTPModel": ("nemotron_h_mtp", "NemotronHMTP"),
    "LongCatFlashMTPModel": ("longcat_flash_mtp", "LongCatFlashMTP"),
    "Glm4MoeMTPModel": ("glm4_moe_mtp", "Glm4MoeMTP"),
    "Glm4MoeLiteMTPModel": ("glm4_moe_lite_mtp", "Glm4MoeLiteMTP"),
    "GlmOcrMTPModel": ("glm_ocr_mtp", "GlmOcrMTP"),
    "MedusaModel": ("medusa", "Medusa"),
    "OpenPanguMTPModel": ("openpangu_mtp", "OpenPanguMTP"),
    "Qwen3NextMTP": ("qwen3_next_mtp", "Qwen3NextMTP"),
    "Step3p5MTP": ("step3p5_mtp", "Step3p5MTP"),
    "Qwen3_5MTP": ("qwen3_5_mtp", "Qwen3_5MTP"),
    "Qwen3_5MoeMTP": ("qwen3_5_mtp", "Qwen3_5MoeMTP"),
    # Temporarily disabled.
    # # TODO(woosuk): Re-enable this once the MLP Speculator is supported in V1.
    # "MLPSpeculatorPreTrainedModel": ("mlp_speculator", "MLPSpeculator"),
}

# Concrete HF architectures that are routed to the generic Transformers
# backend classes (module "transformers").
_TRANSFORMERS_SUPPORTED_MODELS = {
    # Text generation models
    "SmolLM3ForCausalLM": ("transformers", "TransformersForCausalLM"),
    # Multimodal models
    "Emu3ForConditionalGeneration": (
        "transformers",
        "TransformersMultiModalForCausalLM",
    ),
}

# The Transformers backend classes themselves, each registered under its own
# class name in module "transformers".
_TRANSFORMERS_BACKEND_MODELS = {
    # Text generation models
    "TransformersForCausalLM": ("transformers", "TransformersForCausalLM"),
    "TransformersMoEForCausalLM": ("transformers", "TransformersMoEForCausalLM"),
    # Multimodal models
    "TransformersMultiModalForCausalLM": (
        "transformers",
        "TransformersMultiModalForCausalLM",
    ),
    "TransformersMultiModalMoEForCausalLM": (
        "transformers",
        "TransformersMultiModalMoEForCausalLM",
    ),
    # Embedding models
    "TransformersEmbeddingModel": ("transformers", "TransformersEmbeddingModel"),
    "TransformersMoEEmbeddingModel": ("transformers", "TransformersMoEEmbeddingModel"),
    "TransformersMultiModalEmbeddingModel": (
        "transformers",
        "TransformersMultiModalEmbeddingModel",
    ),
    # Sequence classification models
    "TransformersForSequenceClassification": (
        "transformers",
        "TransformersForSequenceClassification",
    ),
    "TransformersMoEForSequenceClassification": (
        "transformers",
        "TransformersMoEForSequenceClassification",
    ),
    "TransformersMultiModalForSequenceClassification": (
        "transformers",
        "TransformersMultiModalForSequenceClassification",
    ),
}

# Union of every per-task registry above. On duplicate keys the later
# mapping takes precedence (standard `**` dict-display semantics).
_VLLM_MODELS = {
    **_TEXT_GENERATION_MODELS,
    **_EMBEDDING_MODELS,
    **_CROSS_ENCODER_MODELS,
    **_MULTIMODAL_MODELS,
    **_SPECULATIVE_DECODING_MODELS,
    **_TRANSFORMERS_SUPPORTED_MODELS,
    **_TRANSFORMERS_BACKEND_MODELS,
}

# This variable is used as the args for subprocess.run(). We
# can modify this variable to alter the args if needed, e.g.
# when we use the par format to pack things together, sys.executable
# might not be the target we want to run.
_SUBPROCESS_COMMAND = [sys.executable, "-m", "vllm.model_executor.models.registry"]

# Architectures that are no longer registered, mapped to a vLLM version string
# (presumably the last release that supported them — confirm at the call site).
_PREVIOUSLY_SUPPORTED_MODELS = {
    "MotifForCausalLM": "0.10.2",
    "Phi3SmallForCausalLM": "0.9.2",
    "Phi4FlashForCausalLM": "0.10.2",
    "Phi4MultimodalForCausalLM": "0.12.0",
    # encoder-decoder models except whisper
    # have been removed for V0 deprecation.
    "BartModel": "0.10.2",
    "BartForConditionalGeneration": "0.10.2",
    "DonutForConditionalGeneration": "0.10.2",
    "Florence2ForConditionalGeneration": "0.10.2",
    "MBartForConditionalGeneration": "0.10.2",
    "MllamaForConditionalGeneration": "0.10.2",
}


@dataclass(frozen=True)
class _ModelInfo:
    """Capability flags of a model class, collected by introspection.

    The dataclass is frozen, and every field is filled from one of the
    `interfaces` / `interfaces_base` helpers in `from_model_cls`.
    """

    architecture: str

    # Task / runner related capabilities
    is_text_generation_model: bool
    is_pooling_model: bool
    attn_type: AttnTypeStr
    default_seq_pooling_type: SequencePoolingType
    default_tok_pooling_type: TokenPoolingType
    supports_cross_encoding: bool
    supports_late_interaction: bool

    # Multimodal input handling
    supports_multimodal: bool
    supports_multimodal_raw_input_only: bool
    requires_raw_input_tokens: bool
    supports_multimodal_encoder_tp_data: bool

    # Execution characteristics
    supports_pp: bool
    has_inner_state: bool
    is_attention_free: bool
    is_hybrid: bool
    has_noops: bool
    supports_mamba_prefix_caching: bool

    # Transcription (speech-to-text)
    supports_transcription: bool
    supports_transcription_only: bool

    @staticmethod
    def from_model_cls(model: type[nn.Module]) -> "_ModelInfo":
        """Inspect the class *model* and return its capability flags."""
        can_transcribe = supports_transcription(model)

        return _ModelInfo(
            architecture=model.__name__,
            is_text_generation_model=is_text_generation_model(model),
            is_pooling_model=is_pooling_model(model),
            attn_type=get_attn_type(model),
            default_seq_pooling_type=get_default_seq_pooling_type(model),
            default_tok_pooling_type=get_default_tok_pooling_type(model),
            supports_cross_encoding=supports_cross_encoding(model),
            supports_late_interaction=supports_late_interaction(model),
            supports_multimodal=supports_multimodal(model),
            supports_multimodal_raw_input_only=supports_multimodal_raw_input_only(
                model
            ),
            requires_raw_input_tokens=requires_raw_input_tokens(model),
            supports_multimodal_encoder_tp_data=supports_multimodal_encoder_tp_data(
                model
            ),
            supports_pp=supports_pp(model),
            has_inner_state=has_inner_state(model),
            is_attention_free=is_attention_free(model),
            is_hybrid=is_hybrid(model),
            has_noops=has_noops(model),
            supports_mamba_prefix_caching=supports_mamba_prefix_caching(model),
            supports_transcription=can_transcribe,
            # Only consult the class attribute for transcription models.
            supports_transcription_only=(
                can_transcribe and model.supports_transcription_only
            ),
        )
682
683


684
685
686
687
class _BaseRegisteredModel(ABC):
    """Common interface of every entry stored in the model registry."""

    @abstractmethod
    def inspect_model_cls(self) -> _ModelInfo:
        """Return the capability flags of the registered model class."""
        raise NotImplementedError

    @abstractmethod
    def load_model_cls(self) -> type[nn.Module]:
        """Return the registered model class itself."""
        raise NotImplementedError
692
693


694
695
696
697
698
699
700
@dataclass(frozen=True)
class _RegisteredModel(_BaseRegisteredModel):
    """
    Represents a model that has already been imported in the main process.

    When built through `from_model_cls`, its capability flags are computed
    immediately, so inspecting it later is just an attribute read.
    """

    interfaces: _ModelInfo
    model_cls: type[nn.Module]

    @staticmethod
    def from_model_cls(model_cls: type[nn.Module]):
        """Wrap *model_cls*, introspecting its capabilities right away."""
        info = _ModelInfo.from_model_cls(model_cls)
        return _RegisteredModel(interfaces=info, model_cls=model_cls)

    def inspect_model_cls(self) -> _ModelInfo:
        # Nothing to compute: the info was captured at construction time.
        return self.interfaces

    def load_model_cls(self) -> type[nn.Module]:
        # The class object is held directly; no import is needed.
        return self.model_cls


@dataclass(frozen=True)
class _LazyRegisteredModel(_BaseRegisteredModel):
    """
    Represents a model that has not been imported in the main process.

    The model module is only imported on demand (`load_model_cls`).
    Inspection runs in a subprocess to avoid initializing CUDA in the main
    process. For in-tree models, the result is cached as JSON under
    `VLLM_CACHE_ROOT/modelinfos`, keyed by a hash of the model's source file.
    """

    module_name: str
    class_name: str

    @staticmethod
    def _get_cache_dir() -> Path:
        return Path(envs.VLLM_CACHE_ROOT) / "modelinfos"

    def _get_cache_filename(self) -> str:
        cls_name = f"{self.module_name}-{self.class_name}".replace(".", "-")
        return f"{cls_name}.json"

    def _get_source_path(self) -> Path | None:
        """Return the in-tree source file defining this model, if any.

        Only modules that live directly in this package (siblings of this
        file) are considered. An out-of-tree module whose last dotted
        component happens to match an in-tree file name (e.g.
        `my_plugin.llama`) must not be hashed against vLLM's own file.
        Otherwise its cached model info would never be invalidated when the
        plugin's code changes.
        """
        package, _, leaf = self.module_name.rpartition(".")
        if package != __package__:
            return None

        path = Path(__file__).parent / f"{leaf}.py"
        return path if path.is_file() else None

    def _load_modelinfo_from_cache(self, module_hash: str) -> _ModelInfo | None:
        """Return the cached `_ModelInfo` for this model.

        Returns None if the cache file is missing, was written for a
        different `module_hash`, or cannot be read or parsed.
        """
        try:
            modelinfo_path = self._get_cache_dir() / self._get_cache_filename()
            with open(modelinfo_path, encoding="utf-8") as file:
                mi_dict = json.load(file)

            if mi_dict["hash"] != module_hash:
                logger.debug(
                    "Cached model info file for class %s.%s is stale",
                    self.module_name,
                    self.class_name,
                )
                return None

            # file not changed, use cached _ModelInfo properties
            return _ModelInfo(**mi_dict["modelinfo"])
        except FileNotFoundError:
            logger.debug(
                "Cached model info file for class %s.%s not found",
                self.module_name,
                self.class_name,
            )
            return None
        except Exception:
            # Corrupt or outdated cache content is not fatal: fall back to
            # re-inspecting the model, but keep the details for debugging.
            logger.debug(
                "Cached model info for class %s.%s error. ",
                self.module_name,
                self.class_name,
                exc_info=True,
            )
            return None

    def _save_modelinfo_to_cache(self, mi: _ModelInfo, module_hash: str) -> None:
        """save dictionary json file to cache"""
        from vllm.model_executor.model_loader.weight_utils import atomic_writer

        try:
            modelinfo_dict = {
                "hash": module_hash,
                "modelinfo": asdict(mi),
            }
            cache_dir = self._get_cache_dir()
            cache_dir.mkdir(parents=True, exist_ok=True)
            modelinfo_path = cache_dir / self._get_cache_filename()
            with atomic_writer(modelinfo_path, encoding="utf-8") as f:
                json.dump(modelinfo_dict, f, indent=2)
        except Exception:
            # Caching is best effort; failing to save must not break loading.
            logger.exception("Error saving model info cache.")

    @logtime(logger=logger, msg="Registry inspect model class")
    def inspect_model_cls(self) -> _ModelInfo:
        module_hash = None
        model_path = self._get_source_path()

        if model_path is not None:
            module_hash = safe_hash(
                model_path.read_bytes(), usedforsecurity=False
            ).hexdigest()

            mi = self._load_modelinfo_from_cache(module_hash)
            if mi is not None:
                logger.debug(
                    "Loaded model info for class %s.%s from cache",
                    self.module_name,
                    self.class_name,
                )
                return mi

            logger.debug(
                "Cache model info for class %s.%s miss. Loading model instead.",
                self.module_name,
                self.class_name,
            )

        # Performed in another process to avoid initializing CUDA
        mi = _run_in_subprocess(
            lambda: _ModelInfo.from_model_cls(self.load_model_cls())
        )
        logger.debug(
            "Loaded model info for class %s.%s", self.module_name, self.class_name
        )

        # save cache file (only possible for in-tree models we could hash)
        if module_hash is not None:
            self._save_modelinfo_to_cache(mi, module_hash)

        return mi

    def load_model_cls(self) -> type[nn.Module]:
        mod = importlib.import_module(self.module_name)
        return getattr(mod, self.class_name)


@lru_cache(maxsize=128)
def _try_load_model_cls(
    model_arch: str,
    model: _BaseRegisteredModel,
) -> type[nn.Module] | None:
    """Load the class behind *model*; log and return None if that fails.

    The platform check on *model_arch* runs outside the error handling, so
    anything it raises propagates to the caller. Results are memoized per
    `(model_arch, model)` pair.
    """
    from vllm.platforms import current_platform

    current_platform.verify_model_arch(model_arch)

    loaded = None
    try:
        loaded = model.load_model_cls()
    except Exception:
        logger.exception("Error in loading model architecture '%s'", model_arch)
    return loaded
839
840


841
842
843
844
@lru_cache(maxsize=128)
def _try_inspect_model_cls(
    model_arch: str,
    model: _BaseRegisteredModel,
) -> _ModelInfo | None:
    """Return *model*'s capability flags; log and return None on failure.

    Results are memoized per `(model_arch, model)` pair.
    """
    info = None
    try:
        info = model.inspect_model_cls()
    except Exception:
        logger.exception("Error in inspecting model architecture '%s'", model_arch)
    return info
851
852


853
854
855
@dataclass
class _ModelRegistry:
    """Registry mapping architecture names to (possibly lazy) model classes.

    Besides plain lookup, the registry resolves architectures against a
    `ModelConfig`. It honours `model_impl` ("transformers", "terratorch",
    "auto") and falls back to the Transformers modeling backend when no
    architecture is registered natively. It also maps converted architecture
    names back to a registered base architecture (see `_normalize_arch`).
    """

    # Keyed by model_arch
    models: dict[str, _BaseRegisteredModel] = field(default_factory=dict)

    def get_supported_archs(self) -> Set[str]:
        """Return a live view of the registered architecture names."""
        return self.models.keys()

    def register_model(
        self,
        model_arch: str,
        model_cls: type[nn.Module] | str,
    ) -> None:
        """
        Register an external model to be used in vLLM.

        `model_cls` can be either:

        - A [`torch.nn.Module`][] class directly referencing the model.
        - A string in the format `<module>:<class>` which can be used to
          lazily import the model. This is useful to avoid initializing CUDA
          when importing the model and thus the related error
          `RuntimeError: Cannot re-initialize CUDA in forked subprocess`.
        """
        if not isinstance(model_arch, str):
            msg = f"`model_arch` should be a string, not a {type(model_arch)}"
            raise TypeError(msg)

        if model_arch in self.models:
            logger.warning(
                "Model architecture %s is already registered, and will be "
                "overwritten by the new model class %s.",
                model_arch,
                model_cls,
            )

        if isinstance(model_cls, str):
            split_str = model_cls.split(":")
            if len(split_str) != 2:
                msg = "Expected a string in the format `<module>:<class>`"
                raise ValueError(msg)

            model = _LazyRegisteredModel(*split_str)
        elif isinstance(model_cls, type) and issubclass(model_cls, nn.Module):
            model = _RegisteredModel.from_model_cls(model_cls)
        else:
            # Report the type of the offending argument (`model_cls`); the
            # type of `model_arch` has already been validated above.
            msg = (
                "`model_cls` should be a string or PyTorch model class, "
                f"not a {type(model_cls)}"
            )
            raise TypeError(msg)

        self.models[model_arch] = model

    def _raise_for_unsupported(self, architectures: list[str]):
        """Always raise a `ValueError` explaining why no architecture resolved.

        The message distinguishes three cases: a registered architecture
        failed to inspect or load, an architecture was removed in an earlier
        vLLM release, or no architecture is supported at all.
        """
        all_supported_archs = self.get_supported_archs()

        if any(arch in all_supported_archs for arch in architectures):
            raise ValueError(
                f"Model architectures {architectures} failed "
                "to be inspected. Please check the logs for more details."
            )

        for arch in architectures:
            if arch in _PREVIOUSLY_SUPPORTED_MODELS:
                previous_version = _PREVIOUSLY_SUPPORTED_MODELS[arch]

                raise ValueError(
                    f"Model architecture {arch} was supported in vLLM until "
                    f"v{previous_version}, and is not supported anymore. "
                    "Please use an older version of vLLM if you want to "
                    "use this model architecture."
                )

        raise ValueError(
            f"Model architectures {architectures} are not supported for now. "
            f"Supported architectures: {all_supported_archs}"
        )

    def _try_load_model_cls(self, model_arch: str) -> type[nn.Module] | None:
        if model_arch not in self.models:
            return None

        return _try_load_model_cls(model_arch, self.models[model_arch])

    def _try_inspect_model_cls(self, model_arch: str) -> _ModelInfo | None:
        if model_arch not in self.models:
            return None

        return _try_inspect_model_cls(model_arch, self.models[model_arch])

    def _try_resolve_transformers(
        self,
        architecture: str,
        model_config: ModelConfig,
    ) -> str | None:
        """Return the Transformers-backend architecture for *architecture*.

        Returns *architecture* itself if it already names a Transformers
        backend class. Returns None when the Transformers backend does not
        apply and `model_impl` is not "transformers"; in that mode the same
        situations raise `ValueError` instead.
        """
        if architecture in _TRANSFORMERS_BACKEND_MODELS:
            return architecture

        auto_map: dict[str, str] = (
            getattr(model_config.hf_config, "auto_map", None) or dict()
        )

        # Make sure that config class is always initialized before model class,
        # otherwise the model class won't be able to access the config class,
        # the expected auto_map should have correct order like:
        # "auto_map": {
        #     "AutoConfig": "<your-repo-name>--<config-name>",
        #     "AutoModel": "<your-repo-name>--<config-name>",
        #     "AutoModelFor<Task>": "<your-repo-name>--<config-name>",
        # },
        for prefix in ("AutoConfig", "AutoModel"):
            for name, module in auto_map.items():
                if name.startswith(prefix):
                    try_get_class_from_dynamic_module(
                        module,
                        model_config.model,
                        revision=model_config.revision,
                        trust_remote_code=model_config.trust_remote_code,
                        warn_on_fail=False,
                    )

        model_module = getattr(transformers, architecture, None)

        if model_module is None:
            for name, module in auto_map.items():
                if name.startswith("AutoModel"):
                    model_module = try_get_class_from_dynamic_module(
                        module,
                        model_config.model,
                        revision=model_config.revision,
                        trust_remote_code=model_config.trust_remote_code,
                        warn_on_fail=True,
                    )
                    if model_module is not None:
                        break
            else:
                if model_config.model_impl != "transformers":
                    return None

                raise ValueError(
                    f"Cannot find model module. {architecture!r} is not a "
                    "registered model in the Transformers library (only "
                    "relevant if the model is meant to be in Transformers) "
                    "and 'AutoModel' is not present in the model config's "
                    "'auto_map' (relevant if the model is custom)."
                )

        if not model_module.is_backend_compatible():
            if model_config.model_impl != "transformers":
                return None

            raise ValueError(
                f"The Transformers implementation of {architecture!r} "
                "is not compatible with vLLM."
            )

        return model_config._get_transformers_backend_cls()

    def _normalize_arch(
        self,
        architecture: str,
        model_config: ModelConfig,
    ) -> str:
        """Map *architecture* to a registered key when possible.

        Returns *architecture* itself if it is registered. Otherwise, if its
        suffix matches a known architecture default, returns the first
        registered name built from the same prefix with a different default
        suffix. If neither applies, returns *architecture* unchanged.
        """
        if architecture in self.models:
            return architecture

        # This may be called in order to resolve runner_type and convert_type
        # in the first place, in which case we consider the default match
        match = try_match_architecture_defaults(
            architecture,
            runner_type=getattr(model_config, "runner_type", None),
            convert_type=getattr(model_config, "convert_type", None),
        )
        if match:
            suffix, _ = match
            # Strip only the trailing suffix; `str.replace` would also rewrite
            # any earlier occurrence of the suffix text inside the name.
            base_name = architecture.removesuffix(suffix)

            # Get the name of the base model to convert
            for repl_suffix, _ in iter_architecture_defaults():
                base_arch = base_name + repl_suffix
                if base_arch in self.models:
                    return base_arch

        return architecture

    def _resolve_with(
        self,
        architectures: str | list[str],
        model_config: ModelConfig,
        try_arch: Callable[[str], Any],
    ) -> tuple[Any, str]:
        """Resolution strategy shared by `inspect_model_cls` and
        `resolve_model_cls`.

        `try_arch` maps a registry key to a result (model info or model
        class), or to None when that key cannot be used. Candidates are
        tried in this order:

        1. `model_impl == "transformers"`: the Transformers backend class.
           `model_impl == "terratorch"`: the "Terratorch" entry.
        2. `model_impl == "auto"`, no architecture registered natively and
           `convert_type == "none"`: the Transformers backend class.
        3. Each architecture in turn, after `_normalize_arch`.
        4. `model_impl == "auto"` and no architecture registered natively:
           the Transformers backend class.

        Returns `(result, arch)`. Raises `ValueError` (via
        `_raise_for_unsupported`) if every candidate fails.
        """
        if isinstance(architectures, str):
            architectures = [architectures]
        if not architectures:
            raise ValueError("No model architectures are specified")

        def try_transformers() -> tuple[Any, str] | None:
            arch = self._try_resolve_transformers(architectures[0], model_config)
            if arch is None:
                return None
            result = try_arch(arch)
            return None if result is None else (result, arch)

        # Require transformers impl
        if model_config.model_impl == "transformers":
            if (resolved := try_transformers()) is not None:
                return resolved
        elif model_config.model_impl == "terratorch":
            # Only return on success, so a failed Terratorch lookup falls
            # through to the generic error instead of yielding a None result.
            result = try_arch("Terratorch")
            if result is not None:
                return (result, "Terratorch")

        none_registered = all(arch not in self.models for arch in architectures)
        is_auto = model_config.model_impl == "auto"

        # Fallback to transformers impl (after resolving convert_type)
        if (
            none_registered
            and is_auto
            and getattr(model_config, "convert_type", "none") == "none"
        ):
            if (resolved := try_transformers()) is not None:
                return resolved

        for arch in architectures:
            normalized_arch = self._normalize_arch(arch, model_config)
            result = try_arch(normalized_arch)
            if result is not None:
                # Report the architecture as requested, not its normalized key.
                return (result, arch)

        # Fallback to transformers impl (before resolving runner_type)
        if none_registered and is_auto:
            if (resolved := try_transformers()) is not None:
                return resolved

        return self._raise_for_unsupported(architectures)

    def inspect_model_cls(
        self,
        architectures: str | list[str],
        model_config: ModelConfig,
    ) -> tuple[_ModelInfo, str]:
        """Return `(model_info, arch)` for the first resolvable architecture."""
        return self._resolve_with(
            architectures, model_config, self._try_inspect_model_cls
        )

    def resolve_model_cls(
        self,
        architectures: str | list[str],
        model_config: ModelConfig,
    ) -> tuple[type[nn.Module], str]:
        """Return `(model_cls, arch)` for the first resolvable architecture."""
        return self._resolve_with(
            architectures, model_config, self._try_load_model_cls
        )

    def is_text_generation_model(
        self,
        architectures: str | list[str],
        model_config: ModelConfig,
    ) -> bool:
        model_info, _ = self.inspect_model_cls(architectures, model_config)
        return model_info.is_text_generation_model

    def is_pooling_model(
        self,
        architectures: str | list[str],
        model_config: ModelConfig,
    ) -> bool:
        model_info, _ = self.inspect_model_cls(architectures, model_config)
        return model_info.is_pooling_model

    def is_cross_encoder_model(
        self,
        architectures: str | list[str],
        model_config: ModelConfig,
    ) -> bool:
        model_info, _ = self.inspect_model_cls(architectures, model_config)
        return model_info.supports_cross_encoding

    def is_multimodal_model(
        self,
        architectures: str | list[str],
        model_config: ModelConfig,
    ) -> bool:
        model_info, _ = self.inspect_model_cls(architectures, model_config)
        return model_info.supports_multimodal

    def is_multimodal_raw_input_only_model(
        self,
        architectures: str | list[str],
        model_config: ModelConfig,
    ) -> bool:
        model_info, _ = self.inspect_model_cls(architectures, model_config)
        return model_info.supports_multimodal_raw_input_only

    def is_pp_supported_model(
        self,
        architectures: str | list[str],
        model_config: ModelConfig,
    ) -> bool:
        model_info, _ = self.inspect_model_cls(architectures, model_config)
        return model_info.supports_pp

    def model_has_inner_state(
        self,
        architectures: str | list[str],
        model_config: ModelConfig,
    ) -> bool:
        model_info, _ = self.inspect_model_cls(architectures, model_config)
        return model_info.has_inner_state

    def is_attention_free_model(
        self,
        architectures: str | list[str],
        model_config: ModelConfig,
    ) -> bool:
        model_info, _ = self.inspect_model_cls(architectures, model_config)
        return model_info.is_attention_free

    def is_hybrid_model(
        self,
        architectures: str | list[str],
        model_config: ModelConfig,
    ) -> bool:
        model_info, _ = self.inspect_model_cls(architectures, model_config)
        return model_info.is_hybrid

    def is_noops_model(
        self,
        architectures: str | list[str],
        model_config: ModelConfig,
    ) -> bool:
        model_info, _ = self.inspect_model_cls(architectures, model_config)
        return model_info.has_noops

    def is_transcription_model(
        self,
        architectures: str | list[str],
        model_config: ModelConfig,
    ) -> bool:
        model_info, _ = self.inspect_model_cls(architectures, model_config)
        return model_info.supports_transcription

    def is_transcription_only_model(
        self,
        architectures: str | list[str],
        model_config: ModelConfig,
    ) -> bool:
        model_info, _ = self.inspect_model_cls(architectures, model_config)
        return model_info.supports_transcription_only

1240

1241
1242
1243
1244
1245
1246
1247
1248
1249
# Global registry, pre-populated with every built-in architecture. Each
# built-in entry is a lazy reference, so no model module is imported here.
ModelRegistry = _ModelRegistry(
    models={
        arch: _LazyRegisteredModel(
            module_name=f"vllm.model_executor.models.{relative_module}",
            class_name=class_name,
        )
        for arch, (relative_module, class_name) in _VLLM_MODELS.items()
    }
)
1250
1251
1252
1253
1254

_T = TypeVar("_T")


def _run_in_subprocess(fn: Callable[[], _T]) -> _T:
    """Run *fn* in a fresh Python process and return its result.

    *fn* is serialized with `cloudpickle`, so lambdas and closures work. It
    is sent to `_SUBPROCESS_COMMAND` on stdin, and the return value is read
    back from a pickle file the child writes. That file is unpickled in this
    process, so the child is fully trusted.

    Raises:
        RuntimeError: if the child exits with a non-zero status. The message
            contains the child's stderr.
    """
    # NOTE: We use a temporary directory instead of a temporary file to avoid
    # issues like https://stackoverflow.com/questions/23212435/permission-denied-to-write-to-my-temporary-file
    with tempfile.TemporaryDirectory() as tempdir:
        output_filepath = os.path.join(tempdir, "registry_output.tmp")

        # `cloudpickle` allows pickling lambda functions directly
        import cloudpickle

        input_bytes = cloudpickle.dumps((fn, output_filepath))

        # cannot use `sys.executable __file__` here because the script
        # contains relative imports
        returned = subprocess.run(
            _SUBPROCESS_COMMAND, input=input_bytes, capture_output=True
        )

        # check if the subprocess is successful
        try:
            returned.check_returncode()
        except subprocess.CalledProcessError as e:
            # Wrap the raised exception to provide more information. Decode
            # leniently so that non-UTF-8 output in stderr cannot replace the
            # real error with a UnicodeDecodeError.
            stderr = returned.stderr.decode(errors="replace")
            raise RuntimeError(f"Error raised in subprocess:\n{stderr}") from e

        with open(output_filepath, "rb") as f:
            return pickle.load(f)


def _run() -> None:
    """Entry point of the helper process spawned by `_run_in_subprocess`.

    Reads a pickled `(fn, output_file)` pair from stdin, calls `fn()` and
    pickles its return value into `output_file`.
    """
    # Set up general plugins before running the callable.
    from vllm.plugins import load_general_plugins

    load_general_plugins()

    payload = sys.stdin.buffer.read()
    fn, output_file = pickle.loads(payload)

    result = fn()

    with open(output_file, "wb") as sink:
        pickle.dump(result, sink)


if __name__ == "__main__":
    # Invoked as `python -m vllm.model_executor.models.registry`.
    _run()