registry.py 37.7 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
4
5
6
"""
Whenever you add an architecture to this page, please also update
`tests/models/registry.py` with example HuggingFace models for it.
"""
7
import importlib
8
import os
9
import pickle
10
11
import subprocess
import sys
12
import tempfile
13
from abc import ABC, abstractmethod
14
from collections.abc import Set
15
from dataclasses import dataclass, field
16
from functools import lru_cache
17
from typing import Callable, Optional, TypeVar, Union
18
19

import torch.nn as nn
20
import transformers
21

22
23
from vllm.config import (ModelConfig, ModelImpl, iter_architecture_defaults,
                         try_match_architecture_defaults)
24
from vllm.logger import init_logger
25
26
from vllm.transformers_utils.dynamic_module import (
    try_get_class_from_dynamic_module)
27

28
29
from .interfaces import (has_inner_state, has_noops, is_attention_free,
                         is_hybrid, supports_cross_encoding,
30
31
32
33
                         supports_multimodal,
                         supports_multimodal_encoder_tp_data,
                         supports_multimodal_raw_input, supports_pp,
                         supports_transcription, supports_v0_only)
34
35
from .interfaces_base import (get_default_pooling_type, is_pooling_model,
                              is_text_generation_model)
36
37
38

# Module-level logger, created with this module's name.
logger = init_logger(__name__)

39
# yapf: disable
40
41
# Text generation architectures.
# Each entry maps an architecture name to a `(module, class)` pair of strings.
# NOTE(review): presumably the first element is a module name inside this
# package and the second the class defined there -- confirm against the code
# that builds the registry from these tables.
_TEXT_GENERATION_MODELS = {
    # [Decoder-only]
    "AquilaModel": ("llama", "LlamaForCausalLM"),
    "AquilaForCausalLM": ("llama", "LlamaForCausalLM"),  # AquilaChat2
    "ArceeForCausalLM": ("arcee", "ArceeForCausalLM"),
    "ArcticForCausalLM": ("arctic", "ArcticForCausalLM"),
    "MiniMaxForCausalLM": ("minimax_text_01", "MiniMaxText01ForCausalLM"),
    "MiniMaxText01ForCausalLM": ("minimax_text_01", "MiniMaxText01ForCausalLM"),
    "MiniMaxM1ForCausalLM": ("minimax_text_01", "MiniMaxText01ForCausalLM"),
    # baichuan-7b, upper case 'C' in the class name
    "BaiChuanForCausalLM": ("baichuan", "BaiChuanForCausalLM"),
    # baichuan-13b, lower case 'c' in the class name
    "BaichuanForCausalLM": ("baichuan", "BaichuanForCausalLM"),
    "BailingMoeForCausalLM": ("bailing_moe", "BailingMoeForCausalLM"),
    "BambaForCausalLM": ("bamba", "BambaForCausalLM"),
    "BloomForCausalLM": ("bloom", "BloomForCausalLM"),
    "ChatGLMModel": ("chatglm", "ChatGLMForCausalLM"),
    "ChatGLMForConditionalGeneration": ("chatglm", "ChatGLMForCausalLM"),
    "CohereForCausalLM": ("commandr", "CohereForCausalLM"),
    "Cohere2ForCausalLM": ("commandr", "CohereForCausalLM"),
    "DbrxForCausalLM": ("dbrx", "DbrxForCausalLM"),
    "DeciLMForCausalLM": ("nemotron_nas", "DeciLMForCausalLM"),
    "DeepseekForCausalLM": ("deepseek", "DeepseekForCausalLM"),
    "DeepseekV2ForCausalLM": ("deepseek_v2", "DeepseekV2ForCausalLM"),
    "DeepseekV3ForCausalLM": ("deepseek_v2", "DeepseekV3ForCausalLM"),
    "Dots1ForCausalLM": ("dots1", "Dots1ForCausalLM"),
    "Ernie4_5ForCausalLM": ("ernie45", "Ernie4_5ForCausalLM"),
    "Ernie4_5_MoeForCausalLM": ("ernie45_moe", "Ernie4_5_MoeForCausalLM"),
    "ExaoneForCausalLM": ("exaone", "ExaoneForCausalLM"),
    "Exaone4ForCausalLM": ("exaone4", "Exaone4ForCausalLM"),
    "FalconForCausalLM": ("falcon", "FalconForCausalLM"),
    "Fairseq2LlamaForCausalLM": ("fairseq2_llama", "Fairseq2LlamaForCausalLM"),
    "GemmaForCausalLM": ("gemma", "GemmaForCausalLM"),
    "Gemma2ForCausalLM": ("gemma2", "Gemma2ForCausalLM"),
    "Gemma3ForCausalLM": ("gemma3", "Gemma3ForCausalLM"),
    "Gemma3nForCausalLM": ("gemma3n", "Gemma3nForCausalLM"),
    "GlmForCausalLM": ("glm", "GlmForCausalLM"),
    "Glm4ForCausalLM": ("glm4", "Glm4ForCausalLM"),
    "Glm4MoeForCausalLM": ("glm4_moe", "Glm4MoeForCausalLM"),
    "GptOssForCausalLM": ("gpt_oss", "GptOssForCausalLM"),
    "GPT2LMHeadModel": ("gpt2", "GPT2LMHeadModel"),
    "GPTBigCodeForCausalLM": ("gpt_bigcode", "GPTBigCodeForCausalLM"),
    "GPTJForCausalLM": ("gpt_j", "GPTJForCausalLM"),
    "GPTNeoXForCausalLM": ("gpt_neox", "GPTNeoXForCausalLM"),
    "GraniteForCausalLM": ("granite", "GraniteForCausalLM"),
    "GraniteMoeForCausalLM": ("granitemoe", "GraniteMoeForCausalLM"),
    "GraniteMoeHybridForCausalLM": ("granitemoehybrid", "GraniteMoeHybridForCausalLM"),   # noqa: E501
    "GraniteMoeSharedForCausalLM": ("granitemoeshared", "GraniteMoeSharedForCausalLM"),   # noqa: E501
    "GritLM": ("gritlm", "GritLM"),
    "Grok1ModelForCausalLM": ("grok1", "Grok1ForCausalLM"),
    "HunYuanMoEV1ForCausalLM": ("hunyuan_v1", "HunYuanMoEV1ForCausalLM"),
    "HunYuanDenseV1ForCausalLM": ("hunyuan_v1", "HunYuanDenseV1ForCausalLM"),
    "HCXVisionForCausalLM": ("hyperclovax_vision", "HCXVisionForCausalLM"),
    "InternLMForCausalLM": ("llama", "LlamaForCausalLM"),
    "InternLM2ForCausalLM": ("internlm2", "InternLM2ForCausalLM"),
    "InternLM2VEForCausalLM": ("internlm2_ve", "InternLM2VEForCausalLM"),
    "InternLM3ForCausalLM": ("llama", "LlamaForCausalLM"),
    "JAISLMHeadModel": ("jais", "JAISLMHeadModel"),
    "JambaForCausalLM": ("jamba", "JambaForCausalLM"),
    "Lfm2ForCausalLM": ("lfm2", "Lfm2ForCausalLM"),
    "LlamaForCausalLM": ("llama", "LlamaForCausalLM"),
    "Llama4ForCausalLM": ("llama4", "Llama4ForCausalLM"),  # noqa: E501
    # For decapoda-research/llama-*
    "LLaMAForCausalLM": ("llama", "LlamaForCausalLM"),
    "MambaForCausalLM": ("mamba", "MambaForCausalLM"),
    "FalconMambaForCausalLM": ("mamba", "MambaForCausalLM"),
    "FalconH1ForCausalLM":("falcon_h1", "FalconH1ForCausalLM"),
    "Mamba2ForCausalLM": ("mamba2", "Mamba2ForCausalLM"),
    "MiniCPMForCausalLM": ("minicpm", "MiniCPMForCausalLM"),
    "MiniCPM3ForCausalLM": ("minicpm3", "MiniCPM3ForCausalLM"),
    "MistralForCausalLM": ("llama", "LlamaForCausalLM"),
    "MixtralForCausalLM": ("mixtral", "MixtralForCausalLM"),
    "QuantMixtralForCausalLM": ("mixtral_quant", "MixtralForCausalLM"),
    # transformers's mpt class has lower case
    "MptForCausalLM": ("mpt", "MPTForCausalLM"),
    "MPTForCausalLM": ("mpt", "MPTForCausalLM"),
    "MiMoForCausalLM": ("mimo", "MiMoForCausalLM"),
    "NemotronForCausalLM": ("nemotron", "NemotronForCausalLM"),
    "NemotronHForCausalLM": ("nemotron_h", "NemotronHForCausalLM"),
    "OlmoForCausalLM": ("olmo", "OlmoForCausalLM"),
    "Olmo2ForCausalLM": ("olmo2", "Olmo2ForCausalLM"),
    "OlmoeForCausalLM": ("olmoe", "OlmoeForCausalLM"),
    "OPTForCausalLM": ("opt", "OPTForCausalLM"),
    "OrionForCausalLM": ("orion", "OrionForCausalLM"),
    "PersimmonForCausalLM": ("persimmon", "PersimmonForCausalLM"),
    "PhiForCausalLM": ("phi", "PhiForCausalLM"),
    "Phi3ForCausalLM": ("phi3", "Phi3ForCausalLM"),
    "PhiMoEForCausalLM": ("phimoe", "PhiMoEForCausalLM"),
    "Phi4FlashForCausalLM": ("phi4flash", "Phi4FlashForCausalLM"),
    "Plamo2ForCausalLM": ("plamo2", "Plamo2ForCausalLM"),
    "QWenLMHeadModel": ("qwen", "QWenLMHeadModel"),
    "Qwen2ForCausalLM": ("qwen2", "Qwen2ForCausalLM"),
    "Qwen2MoeForCausalLM": ("qwen2_moe", "Qwen2MoeForCausalLM"),
    "Qwen3ForCausalLM": ("qwen3", "Qwen3ForCausalLM"),
    "Qwen3MoeForCausalLM": ("qwen3_moe", "Qwen3MoeForCausalLM"),
    "RWForCausalLM": ("falcon", "FalconForCausalLM"),
    "SeedOssForCausalLM": ("seed_oss", "SeedOssForCausalLM"),
    "Step3TextForCausalLM": ("step3_text", "Step3TextForCausalLM"),
    "StableLMEpochForCausalLM": ("stablelm", "StablelmForCausalLM"),
    "StableLmForCausalLM": ("stablelm", "StablelmForCausalLM"),
    "Starcoder2ForCausalLM": ("starcoder2", "Starcoder2ForCausalLM"),
    "SolarForCausalLM": ("solar", "SolarForCausalLM"),
    "TeleChat2ForCausalLM": ("telechat2", "TeleChat2ForCausalLM"),
    "TeleFLMForCausalLM": ("teleflm", "TeleFLMForCausalLM"),
    "XverseForCausalLM": ("llama", "LlamaForCausalLM"),
    "Zamba2ForCausalLM": ("zamba2", "Zamba2ForCausalLM"),
    # [Encoder-decoder]
    "BartModel": ("bart", "BartForConditionalGeneration"),
    "BartForConditionalGeneration": ("bart", "BartForConditionalGeneration"),
    "MBartForConditionalGeneration": ("bart", "MBartForConditionalGeneration"),
}

# Embedding architectures, in the same `name -> (module, class)` layout as
# `_TEXT_GENERATION_MODELS`. Every text generation entry whose class is
# "LlamaForCausalLM" is also copied in by the `**{...}` comprehension below.
_EMBEDDING_MODELS = {
    # [Text-only]
    "BertModel": ("bert", "BertEmbeddingModel"),
    "DeciLMForCausalLM": ("nemotron_nas", "DeciLMForCausalLM"),
    "Gemma2Model": ("gemma2", "Gemma2ForCausalLM"),
    "GlmForCausalLM": ("glm", "GlmForCausalLM"),
    "GPT2ForSequenceClassification": ("gpt2", "GPT2ForSequenceClassification"),
    "GritLM": ("gritlm", "GritLM"),
    "GteModel": ("bert_with_rope", "SnowflakeGteNewModel"),
    "GteNewModel": ("bert_with_rope", "GteNewModel"),
    "InternLM2ForRewardModel": ("internlm2", "InternLM2ForRewardModel"),
    "JambaForSequenceClassification": ("jamba", "JambaForSequenceClassification"),  # noqa: E501
    "LlamaModel": ("llama", "LlamaForCausalLM"),
    **{
        # Multiple models share the same architecture, so we include them all
        k: (mod, arch) for k, (mod, arch) in _TEXT_GENERATION_MODELS.items()
        if arch == "LlamaForCausalLM"
    },
    "MistralModel": ("llama", "LlamaForCausalLM"),
    "ModernBertModel": ("modernbert", "ModernBertModel"),
    "NomicBertModel": ("bert_with_rope", "NomicBertModel"),
    "Phi3ForCausalLM": ("phi3", "Phi3ForCausalLM"),
    "Qwen2Model": ("qwen2", "Qwen2ForCausalLM"),
    "Qwen2ForCausalLM": ("qwen2", "Qwen2ForCausalLM"),
    "Qwen2ForRewardModel": ("qwen2_rm", "Qwen2ForRewardModel"),
    "Qwen2ForProcessRewardModel": ("qwen2_rm", "Qwen2ForProcessRewardModel"),
    "RobertaForMaskedLM": ("roberta", "RobertaEmbeddingModel"),
    "RobertaModel": ("roberta", "RobertaEmbeddingModel"),
    "TeleChat2ForCausalLM": ("telechat2", "TeleChat2ForCausalLM"),
    "XLMRobertaModel": ("roberta", "RobertaEmbeddingModel"),
    # [Multimodal]
    "LlavaNextForConditionalGeneration": ("llava_next", "LlavaNextForConditionalGeneration"),  # noqa: E501
    "Phi3VForCausalLM": ("phi3v", "Phi3VForCausalLM"),
    "Qwen2VLForConditionalGeneration": ("qwen2_vl", "Qwen2VLForConditionalGeneration"),  # noqa: E501
    # Technically PrithviGeoSpatialMAE is a model that works on images, both in
    # input and output. I am adding it here because it piggy-backs on embedding
    # models for the time being.
    "PrithviGeoSpatialMAE": ("prithvi_geospatial_mae", "PrithviGeoSpatialMAE"),
}

192
193
194
195
196
197
# Cross-encoder architectures, in the same `name -> (module, class)` layout
# as the other tables in this file.
_CROSS_ENCODER_MODELS = {
    "BertForSequenceClassification": ("bert", "BertForSequenceClassification"),
    "RobertaForSequenceClassification": ("roberta",
                                         "RobertaForSequenceClassification"),
    "XLMRobertaForSequenceClassification": ("roberta",
                                            "RobertaForSequenceClassification"),
    "ModernBertForSequenceClassification": ("modernbert",
                                            "ModernBertForSequenceClassification"),
    # [Auto-converted (see adapters.py)]
    "JinaVLForRanking": ("jina_vl", "JinaVLForSequenceClassification"), # noqa: E501
}

204
# Multimodal architectures, in the same `name -> (module, class)` layout as
# the other tables in this file.
_MULTIMODAL_MODELS = {
    # [Decoder-only]
    "AriaForConditionalGeneration": ("aria", "AriaForConditionalGeneration"),
    "AyaVisionForConditionalGeneration": ("aya_vision", "AyaVisionForConditionalGeneration"),  # noqa: E501
    "Blip2ForConditionalGeneration": ("blip2", "Blip2ForConditionalGeneration"),
    "ChameleonForConditionalGeneration": ("chameleon", "ChameleonForConditionalGeneration"),  # noqa: E501
    "Cohere2VisionForConditionalGeneration": ("cohere2_vision", "Cohere2VisionForConditionalGeneration"),  # noqa: E501
    "DeepseekVLV2ForCausalLM": ("deepseek_vl2", "DeepseekVLV2ForCausalLM"),
    "Ernie4_5_VLMoeForConditionalGeneration": ("ernie45_vl", "Ernie4_5_VLMoeForConditionalGeneration"),  # noqa: E501
    "FuyuForCausalLM": ("fuyu", "FuyuForCausalLM"),
    "Gemma3ForConditionalGeneration": ("gemma3_mm", "Gemma3ForConditionalGeneration"),  # noqa: E501
    "Gemma3nForConditionalGeneration": ("gemma3n_mm", "Gemma3nForConditionalGeneration"),    # noqa: E501
    "GLM4VForCausalLM": ("glm4v", "GLM4VForCausalLM"),
    "Glm4vForConditionalGeneration": ("glm4_1v", "Glm4vForConditionalGeneration"),  # noqa: E501
    "Glm4vMoeForConditionalGeneration": ("glm4_1v", "Glm4vMoeForConditionalGeneration"),  # noqa: E501
    "GraniteSpeechForConditionalGeneration": ("granite_speech", "GraniteSpeechForConditionalGeneration"),  # noqa: E501
    "H2OVLChatModel": ("h2ovl", "H2OVLChatModel"),
    "InternVLChatModel": ("internvl", "InternVLChatModel"),
    "InternS1ForConditionalGeneration": ("interns1", "InternS1ForConditionalGeneration"),  # noqa: E501
    "Idefics3ForConditionalGeneration":("idefics3","Idefics3ForConditionalGeneration"),
    "SmolVLMForConditionalGeneration": ("smolvlm","SmolVLMForConditionalGeneration"),  # noqa: E501
    "KeyeForConditionalGeneration": ("keye", "KeyeForConditionalGeneration"),
    "RForConditionalGeneration": ("rvl", "RForConditionalGeneration"),
    "KimiVLForConditionalGeneration": ("kimi_vl", "KimiVLForConditionalGeneration"),  # noqa: E501
    "Llama_Nemotron_Nano_VL": ("nemotron_vl", "LlamaNemotronVLChatModel"),
    "LlavaForConditionalGeneration": ("llava", "LlavaForConditionalGeneration"),
    "LlavaNextForConditionalGeneration": ("llava_next", "LlavaNextForConditionalGeneration"),  # noqa: E501
    "LlavaNextVideoForConditionalGeneration": ("llava_next_video", "LlavaNextVideoForConditionalGeneration"),  # noqa: E501
    "LlavaOnevisionForConditionalGeneration": ("llava_onevision", "LlavaOnevisionForConditionalGeneration"),  # noqa: E501
    "MantisForConditionalGeneration": ("llava", "MantisForConditionalGeneration"),  # noqa: E501
    "MiniMaxVL01ForConditionalGeneration": ("minimax_vl_01", "MiniMaxVL01ForConditionalGeneration"),  # noqa: E501
    "MiniCPMO": ("minicpmo", "MiniCPMO"),
    "MiniCPMV": ("minicpmv", "MiniCPMV"),
    "Mistral3ForConditionalGeneration": ("mistral3", "Mistral3ForConditionalGeneration"),  # noqa: E501
    "MolmoForCausalLM": ("molmo", "MolmoForCausalLM"),
    "NVLM_D": ("nvlm_d", "NVLM_D_Model"),
    "Ovis": ("ovis", "Ovis"),
    "Ovis2_5": ("ovis2_5", "Ovis2_5"),
    "PaliGemmaForConditionalGeneration": ("paligemma", "PaliGemmaForConditionalGeneration"),  # noqa: E501
    "Phi3VForCausalLM": ("phi3v", "Phi3VForCausalLM"),
    "Phi4MMForCausalLM": ("phi4mm", "Phi4MMForCausalLM"),
    "Phi4MultimodalForCausalLM": ("phi4_multimodal", "Phi4MultimodalForCausalLM"),  # noqa: E501
    "PixtralForConditionalGeneration": ("pixtral", "PixtralForConditionalGeneration"),  # noqa: E501
    "QwenVLForConditionalGeneration": ("qwen_vl", "QwenVLForConditionalGeneration"),  # noqa: E501
    "Qwen2VLForConditionalGeneration": ("qwen2_vl", "Qwen2VLForConditionalGeneration"),  # noqa: E501
    "Qwen2_5_VLForConditionalGeneration": ("qwen2_5_vl", "Qwen2_5_VLForConditionalGeneration"),  # noqa: E501
    "Qwen2AudioForConditionalGeneration": ("qwen2_audio", "Qwen2AudioForConditionalGeneration"),  # noqa: E501
    "Qwen2_5OmniModel": ("qwen2_5_omni_thinker", "Qwen2_5OmniThinkerForConditionalGeneration"),  # noqa: E501
    "Qwen2_5OmniForConditionalGeneration": ("qwen2_5_omni_thinker", "Qwen2_5OmniThinkerForConditionalGeneration"),  # noqa: E501
    "UltravoxModel": ("ultravox", "UltravoxModel"),
    "Step3VLForConditionalGeneration": ("step3_vl", "Step3VLForConditionalGeneration"),  # noqa: E501
    "TarsierForConditionalGeneration": ("tarsier", "TarsierForConditionalGeneration"),  # noqa: E501
    "Tarsier2ForConditionalGeneration": ("qwen2_vl", "Tarsier2ForConditionalGeneration"),  # noqa: E501
    "VoxtralForConditionalGeneration": ("voxtral", "VoxtralForConditionalGeneration"),  # noqa: E501
    # [Encoder-decoder]
    "DonutForConditionalGeneration": ("donut", "DonutForConditionalGeneration"),
    "Florence2ForConditionalGeneration": ("florence2", "Florence2ForConditionalGeneration"),  # noqa: E501
    "MllamaForConditionalGeneration": ("mllama", "MllamaForConditionalGeneration"),  # noqa: E501
    "Llama4ForConditionalGeneration": ("mllama4", "Llama4ForConditionalGeneration"),  # noqa: E501
    "SkyworkR1VChatModel": ("skyworkr1v", "SkyworkR1VChatModel"),
    "WhisperForConditionalGeneration": ("whisper", "WhisperForConditionalGeneration"),  # noqa: E501
}
266
267

# Speculative decoding architectures, in the same `name -> (module, class)`
# layout as the other tables in this file. Two entries are currently
# commented out; see the TODO notes next to them.
_SPECULATIVE_DECODING_MODELS = {
    "MiMoMTPModel": ("mimo_mtp", "MiMoMTP"),
    "EagleLlamaForCausalLM": ("llama_eagle", "EagleLlamaForCausalLM"),
    "EagleLlama4ForCausalLM": ("llama4_eagle", "EagleLlama4ForCausalLM"),
    "EagleMiniCPMForCausalLM": ("minicpm_eagle", "EagleMiniCPMForCausalLM"),
    "Eagle3LlamaForCausalLM": ("llama_eagle3", "Eagle3LlamaForCausalLM"),
    # TODO: Re-enable this once tests/models/test_initialization.py is fixed, see PR #22333 #22611  # noqa: E501
    # "LlamaForCausalLMEagle3": ("llama_eagle3", "Eagle3LlamaForCausalLM"),
    "EagleDeepSeekMTPModel": ("deepseek_eagle", "EagleDeepseekV3ForCausalLM"),
    "DeepSeekMTPModel": ("deepseek_mtp", "DeepSeekMTP"),
    "ErnieMTPModel": ("ernie_mtp", "ErnieMTP"),
    "Glm4MoeMTPModel": ("glm4_moe_mtp", "Glm4MoeMTP"),
    "MedusaModel": ("medusa", "Medusa"),
    # Temporarily disabled.
    # # TODO(woosuk): Re-enable this once the MLP Speculator is supported in V1.
    # "MLPSpeculatorPreTrainedModel": ("mlp_speculator", "MLPSpeculator"),
}
284

285
# Architectures that are mapped onto the generic classes in the
# "transformers" module rather than onto a dedicated implementation.
_TRANSFORMERS_SUPPORTED_MODELS = {
    # Text generation models
    "SmolLM3ForCausalLM": ("transformers", "TransformersForCausalLM"),
    # Multimodal models
    "Emu3ForConditionalGeneration": ("transformers", "TransformersForMultimodalLM"),  # noqa: E501
}

# The generic Transformers backend classes, registered under their own names.
# `_ModelRegistry._try_resolve_transformers` returns an architecture unchanged
# when it is already one of these keys.
_TRANSFORMERS_BACKEND_MODELS = {
    "TransformersModel": ("transformers", "TransformersModel"),
    "TransformersForCausalLM": ("transformers", "TransformersForCausalLM"),
    "TransformersForMultimodalLM": ("transformers", "TransformersForMultimodalLM"), # noqa: E501
}
297
# yapf: enable
298

299
# Union of all the tables above. Because the tables are merged with `**`
# unpacking, a key that appears in more than one table takes its value from
# the table listed last.
_VLLM_MODELS = {
    **_TEXT_GENERATION_MODELS,
    **_EMBEDDING_MODELS,
    **_CROSS_ENCODER_MODELS,
    **_MULTIMODAL_MODELS,
    **_SPECULATIVE_DECODING_MODELS,
    **_TRANSFORMERS_SUPPORTED_MODELS,
    **_TRANSFORMERS_BACKEND_MODELS,
}

309
310
311
312
313
314
315
316
# Argument list used for subprocess.run(). It is kept as a module-level
# variable so that it can be modified if needed, e.g. when the code is
# packaged in par format, where sys.executable might not be the target
# we want to run.
_SUBPROCESS_COMMAND = [
    sys.executable, "-m", "vllm.model_executor.models.registry"
]

317
318
# Architectures that are no longer supported, mapped to the vLLM version
# reported in the error message as the last one that supported them
# (see `_ModelRegistry._raise_for_unsupported`).
_PREVIOUSLY_SUPPORTED_MODELS = {"Phi3SmallForCausalLM": "0.9.2"}

319

320
321
@dataclass(frozen=True)
class _ModelInfo:
322
    architecture: str
323
    is_text_generation_model: bool
324
    is_pooling_model: bool
325
    default_pooling_type: str
326
    supports_cross_encoding: bool
327
    supports_multimodal: bool
328
    supports_multimodal_raw_input: bool
329
    supports_multimodal_encoder_tp_data: bool
330
    supports_pp: bool
331
332
    has_inner_state: bool
    is_attention_free: bool
333
    is_hybrid: bool
334
    has_noops: bool
335
    supports_transcription: bool
336
    supports_transcription_only: bool
337
    supports_v0_only: bool
338
339

    @staticmethod
340
    def from_model_cls(model: type[nn.Module]) -> "_ModelInfo":
341
        return _ModelInfo(
342
            architecture=model.__name__,
343
            is_text_generation_model=is_text_generation_model(model),
344
            is_pooling_model=is_pooling_model(model),
345
            default_pooling_type=get_default_pooling_type(model),
346
            supports_cross_encoding=supports_cross_encoding(model),
347
            supports_multimodal=supports_multimodal(model),
348
            supports_multimodal_raw_input=supports_multimodal_raw_input(model),
349
350
            supports_multimodal_encoder_tp_data=
            supports_multimodal_encoder_tp_data(model),
351
            supports_pp=supports_pp(model),
352
353
            has_inner_state=has_inner_state(model),
            is_attention_free=is_attention_free(model),
354
            is_hybrid=is_hybrid(model),
355
            supports_transcription=supports_transcription(model),
356
357
            supports_transcription_only=(supports_transcription(model) and
                                         model.supports_transcription_only),
358
            supports_v0_only=supports_v0_only(model),
359
            has_noops=has_noops(model),
360
        )
361
362


363
class _BaseRegisteredModel(ABC):
364

365
366
367
    @abstractmethod
    def inspect_model_cls(self) -> _ModelInfo:
        raise NotImplementedError
368

369
    @abstractmethod
370
    def load_model_cls(self) -> type[nn.Module]:
371
        raise NotImplementedError
372
373


374
375
376
377
378
379
380
@dataclass(frozen=True)
class _RegisteredModel(_BaseRegisteredModel):
    """
    Represents a model that has already been imported in the main process.
    """

    # Capabilities computed once, when the entry is created.
    interfaces: _ModelInfo
    # The imported model class.
    model_cls: type[nn.Module]

    @staticmethod
    def from_model_cls(model_cls: type[nn.Module]):
        """Create an entry for `model_cls`, inspecting it immediately."""
        info = _ModelInfo.from_model_cls(model_cls)
        return _RegisteredModel(interfaces=info, model_cls=model_cls)

    def inspect_model_cls(self) -> _ModelInfo:
        """Return the capabilities stored when the entry was created."""
        return self.interfaces

    def load_model_cls(self) -> type[nn.Module]:
        """Return the stored model class."""
        return self.model_cls


@dataclass(frozen=True)
class _LazyRegisteredModel(_BaseRegisteredModel):
    """
    Represents a model that has not been imported in the main process.
    """
    # Name handed to `importlib.import_module` when the class is loaded.
    module_name: str
    # Attribute looked up on the imported module.
    class_name: str

    def inspect_model_cls(self) -> _ModelInfo:
        """Load and inspect the model class via `_run_in_subprocess`.

        Performed in another process to avoid initializing CUDA.
        """
        return _run_in_subprocess(
            lambda: _ModelInfo.from_model_cls(self.load_model_cls()))

    def load_model_cls(self) -> type[nn.Module]:
        """Import `module_name` and return its `class_name` attribute."""
        module = importlib.import_module(self.module_name)
        model_cls = getattr(module, self.class_name)
        return model_cls


@lru_cache(maxsize=128)
def _try_load_model_cls(
    model_arch: str,
    model: _BaseRegisteredModel,
) -> Optional[type[nn.Module]]:
    """Load the model class for `model_arch`, logging instead of raising.

    Results, including a `None` from a failed load, are memoized by
    `lru_cache` per `(model_arch, model)` pair.

    Args:
        model_arch: Architecture name; checked by the current platform and
            used in the error log.
        model: Registry entry that knows how to load the class.

    Returns:
        The model class, or `None` if loading raised an exception.
    """
    # Imported inside the function rather than at module import time.
    from vllm.platforms import current_platform

    # Deliberately outside the `try`: errors raised here propagate.
    current_platform.verify_model_arch(model_arch)

    try:
        model_cls = model.load_model_cls()
    except Exception:
        logger.exception("Error in loading model architecture '%s'",
                         model_arch)
        return None
    return model_cls
428
429


430
431
432
433
434
435
436
437
438
439
440
@lru_cache(maxsize=128)
def _try_inspect_model_cls(
    model_arch: str,
    model: _BaseRegisteredModel,
) -> Optional[_ModelInfo]:
    try:
        return model.inspect_model_cls()
    except Exception:
        logger.exception("Error in inspecting model architecture '%s'",
                         model_arch)
        return None
441
442


443
444
445
@dataclass
class _ModelRegistry:
    """Registry of model entries, looked up by architecture name."""

    # Keyed by model_arch
    models: dict[str, _BaseRegisteredModel] = field(default_factory=dict)
447

448
    def get_supported_archs(self) -> Set[str]:
449
        return self.models.keys()
450

451
452
453
    def register_model(
        self,
        model_arch: str,
454
        model_cls: Union[type[nn.Module], str],
455
    ) -> None:
456
457
458
        """
        Register an external model to be used in vLLM.

459
        `model_cls` can be either:
460

461
        - A [`torch.nn.Module`][] class directly referencing the model.
462
        - A string in the format `<module>:<class>` which can be used to
463
464
          lazily import the model. This is useful to avoid initializing CUDA
          when importing the model and thus the related error
465
          `RuntimeError: Cannot re-initialize CUDA in forked subprocess`.
466
        """
467
468
469
470
        if not isinstance(model_arch, str):
            msg = f"`model_arch` should be a string, not a {type(model_arch)}"
            raise TypeError(msg)

471
        if model_arch in self.models:
472
473
474
            logger.warning(
                "Model architecture %s is already registered, and will be "
                "overwritten by the new model class %s.", model_arch,
475
476
477
478
479
480
481
                model_cls)

        if isinstance(model_cls, str):
            split_str = model_cls.split(":")
            if len(split_str) != 2:
                msg = "Expected a string in the format `<module>:<class>`"
                raise ValueError(msg)
482

483
            model = _LazyRegisteredModel(*split_str)
484
        elif isinstance(model_cls, type) and issubclass(model_cls, nn.Module):
485
            model = _RegisteredModel.from_model_cls(model_cls)
486
487
488
489
        else:
            msg = ("`model_cls` should be a string or PyTorch model class, "
                   f"not a {type(model_arch)}")
            raise TypeError(msg)
490

491
        self.models[model_arch] = model
492

493
    def _raise_for_unsupported(self, architectures: list[str]):
494
        all_supported_archs = self.get_supported_archs()
495

496
497
498
499
500
        if any(arch in all_supported_archs for arch in architectures):
            raise ValueError(
                f"Model architectures {architectures} failed "
                "to be inspected. Please check the logs for more details.")

501
502
503
504
505
506
507
508
509
510
        for arch in architectures:
            if arch in _PREVIOUSLY_SUPPORTED_MODELS:
                previous_version = _PREVIOUSLY_SUPPORTED_MODELS[arch]

                raise ValueError(
                    f"Model architecture {arch} was supported in vLLM until "
                    f"v{previous_version}, and is not supported anymore. "
                    "Please use an older version of vLLM if you want to "
                    "use this model architecture.")

511
512
513
        raise ValueError(
            f"Model architectures {architectures} are not supported for now. "
            f"Supported architectures: {all_supported_archs}")
514

515
    def _try_load_model_cls(self,
516
                            model_arch: str) -> Optional[type[nn.Module]]:
517
518
        if model_arch not in self.models:
            return None
519

520
        return _try_load_model_cls(model_arch, self.models[model_arch])
521

522
    def _try_inspect_model_cls(self, model_arch: str) -> Optional[_ModelInfo]:
523
524
        if model_arch not in self.models:
            return None
525

526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
        return _try_inspect_model_cls(model_arch, self.models[model_arch])

    def _try_resolve_transformers(
        self,
        architecture: str,
        model_config: ModelConfig,
    ) -> Optional[str]:
        """Resolve `architecture` to a Transformers backend architecture.

        Args:
            architecture: Architecture name taken from the model config.
            model_config: Supplies `hf_config` (for `auto_map`), `model`,
                `revision` and `model_impl`.

        Returns:
            `architecture` itself if it is already a Transformers backend
            class name; otherwise the value of
            `model_config._get_transformers_backend_cls()`; or `None` if no
            compatible model class is found and the Transformers
            implementation was not explicitly requested.

        Raises:
            ValueError: If `model_config.model_impl` is
                `ModelImpl.TRANSFORMERS` and either no model class can be
                found or the class found is not backend compatible.
        """
        if architecture in _TRANSFORMERS_BACKEND_MODELS:
            return architecture

        auto_map: dict[str, str] = getattr(model_config.hf_config, "auto_map",
                                           None) or {}

        # Make sure that config class is always initialized before model class,
        # otherwise the model class won't be able to access the config class,
        # the expected auto_map should have correct order like:
        # "auto_map": {
        #     "AutoConfig": "<your-repo-name>--<config-name>",
        #     "AutoModel": "<your-repo-name>--<config-name>",
        #     "AutoModelFor<Task>": "<your-repo-name>--<config-name>",
        # },
        for prefix in ("AutoConfig", "AutoModel"):
            for name, module in auto_map.items():
                if not name.startswith(prefix):
                    continue
                try_get_class_from_dynamic_module(
                    module,
                    model_config.model,
                    revision=model_config.revision,
                    warn_on_fail=False,
                )

        model_module = getattr(transformers, architecture, None)

        if model_module is None:
            # Not a class of the Transformers library itself: use the first
            # "AutoModel*" entry of the auto_map that can be loaded.
            for name, module in auto_map.items():
                if not name.startswith("AutoModel"):
                    continue
                model_module = try_get_class_from_dynamic_module(
                    module,
                    model_config.model,
                    revision=model_config.revision,
                    warn_on_fail=True,
                )
                if model_module is not None:
                    break

            # Still `None` here means that no entry could be loaded.
            if model_module is None:
                if model_config.model_impl != ModelImpl.TRANSFORMERS:
                    return None

                raise ValueError(
                    f"Cannot find model module. {architecture!r} is not a "
                    "registered model in the Transformers library (only "
                    "relevant if the model is meant to be in Transformers) "
                    "and 'AutoModel' is not present in the model config's "
                    "'auto_map' (relevant if the model is custom).")

        if model_module.is_backend_compatible():
            return model_config._get_transformers_backend_cls()

        if model_config.model_impl != ModelImpl.TRANSFORMERS:
            return None

        raise ValueError(
            f"The Transformers implementation of {architecture!r} "
            "is not compatible with vLLM.")
590

591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
    def _normalize_arch(
        self,
        architecture: str,
        model_config: ModelConfig,
    ) -> str:
        if architecture in self.models:
            return architecture

        # This may be called in order to resolve runner_type and convert_type
        # in the first place, in which case we consider the default match
        match = try_match_architecture_defaults(
            architecture,
            runner_type=getattr(model_config, "runner_type", None),
            convert_type=getattr(model_config, "convert_type", None),
        )
        if match:
            suffix, _ = match

            # Get the name of the base model to convert
            for repl_suffix, _ in iter_architecture_defaults():
                base_arch = architecture.replace(suffix, repl_suffix)
                if base_arch in self.models:
                    return base_arch

        return architecture
616

617
618
    def inspect_model_cls(
        self,
        architectures: Union[str, list[str]],
        model_config: ModelConfig,
    ) -> tuple[_ModelInfo, str]:
        """Return the model info of the first supported architecture.

        Resolution order:

        1. If `model_config.model_impl` is `ModelImpl.TRANSFORMERS`, try the
           Transformers backend.
        2. If no entry of `architectures` is registered, `model_impl` is
           `ModelImpl.AUTO` and `convert_type` is `"none"` (or unset), try
           the Transformers backend.
        3. Try every entry of `architectures` in order, after passing it
           through `_normalize_arch`.
        4. If no entry of `architectures` is registered and `model_impl` is
           `ModelImpl.AUTO`, try the Transformers backend.

        Only `architectures[0]` is ever offered to the Transformers backend.

        Args:
            architectures: One architecture name, or a list of candidates.
            model_config: Model configuration driving the resolution.

        Returns:
            `(model_info, arch)`. For step 3, `arch` is the original entry
            of `architectures`, not its normalized form; for the other
            steps it is the name returned by `_try_resolve_transformers`.

        Raises:
            ValueError: If `architectures` is empty.

        If nothing matches, the result of `_raise_for_unsupported` is
        returned (judging by its name, that helper presumably raises).
        """
        archs = ([architectures]
                 if isinstance(architectures, str) else architectures)
        if not archs:
            raise ValueError("No model architectures are specified")

        def inspect_via_transformers() -> Optional[tuple[_ModelInfo, str]]:
            # Shared by steps 1, 2 and 4 above.
            arch = self._try_resolve_transformers(archs[0], model_config)
            if arch is None:
                return None
            model_info = self._try_inspect_model_cls(arch)
            return None if model_info is None else (model_info, arch)

        def none_registered() -> bool:
            return all(arch not in self.models for arch in archs)

        # Require transformers impl
        if model_config.model_impl == ModelImpl.TRANSFORMERS:
            resolved = inspect_via_transformers()
            if resolved is not None:
                return resolved

        # Fallback to transformers impl (after resolving convert_type)
        if (none_registered()
                and model_config.model_impl == ModelImpl.AUTO
                and getattr(model_config, "convert_type", "none") == "none"):
            resolved = inspect_via_transformers()
            if resolved is not None:
                return resolved

        for arch in archs:
            normalized_arch = self._normalize_arch(arch, model_config)
            model_info = self._try_inspect_model_cls(normalized_arch)
            if model_info is not None:
                return (model_info, arch)

        # Fallback to transformers impl (before resolving runner_type)
        if (none_registered()
                and model_config.model_impl == ModelImpl.AUTO):
            resolved = inspect_via_transformers()
            if resolved is not None:
                return resolved

        return self._raise_for_unsupported(archs)
664

665
666
    def resolve_model_cls(
        self,
        architectures: Union[str, list[str]],
        model_config: ModelConfig,
    ) -> tuple[type[nn.Module], str]:
        """Return the model class of the first supported architecture.

        Resolution order:

        1. If `model_config.model_impl` is `ModelImpl.TRANSFORMERS`, try the
           Transformers backend.
        2. If no entry of `architectures` is registered, `model_impl` is
           `ModelImpl.AUTO` and `convert_type` is `"none"` (or unset), try
           the Transformers backend.
        3. Try every entry of `architectures` in order, after passing it
           through `_normalize_arch`.
        4. If no entry of `architectures` is registered and `model_impl` is
           `ModelImpl.AUTO`, try the Transformers backend.

        Only `architectures[0]` is ever offered to the Transformers backend.

        Args:
            architectures: One architecture name, or a list of candidates.
            model_config: Model configuration driving the resolution.

        Returns:
            `(model_cls, arch)`. For step 3, `arch` is the original entry
            of `architectures`, not its normalized form; for the other
            steps it is the name returned by `_try_resolve_transformers`.

        Raises:
            ValueError: If `architectures` is empty.

        If nothing matches, the result of `_raise_for_unsupported` is
        returned (judging by its name, that helper presumably raises).
        """
        archs = ([architectures]
                 if isinstance(architectures, str) else architectures)
        if not archs:
            raise ValueError("No model architectures are specified")

        def load_via_transformers() -> Optional[tuple[type[nn.Module], str]]:
            # Shared by steps 1, 2 and 4 above.
            arch = self._try_resolve_transformers(archs[0], model_config)
            if arch is None:
                return None
            model_cls = self._try_load_model_cls(arch)
            return None if model_cls is None else (model_cls, arch)

        def none_registered() -> bool:
            return all(arch not in self.models for arch in archs)

        # Require transformers impl
        if model_config.model_impl == ModelImpl.TRANSFORMERS:
            resolved = load_via_transformers()
            if resolved is not None:
                return resolved

        # Fallback to transformers impl (after resolving convert_type)
        if (none_registered()
                and model_config.model_impl == ModelImpl.AUTO
                and getattr(model_config, "convert_type", "none") == "none"):
            resolved = load_via_transformers()
            if resolved is not None:
                return resolved

        for arch in archs:
            normalized_arch = self._normalize_arch(arch, model_config)
            model_cls = self._try_load_model_cls(normalized_arch)
            if model_cls is not None:
                return (model_cls, arch)

        # Fallback to transformers impl (before resolving runner_type)
        if (none_registered()
                and model_config.model_impl == ModelImpl.AUTO):
            resolved = load_via_transformers()
            if resolved is not None:
                return resolved

        return self._raise_for_unsupported(archs)
712

713
714
    def is_text_generation_model(
        self,
        architectures: Union[str, list[str]],
        model_config: ModelConfig,
    ) -> bool:
        """Return the `is_text_generation_model` flag of the resolved model.

        Args:
            architectures: One architecture name, or a list of candidates.
            model_config: Configuration forwarded to `inspect_model_cls`.

        Returns:
            The `is_text_generation_model` attribute of the model info
            returned by `inspect_model_cls`.
        """
        model_info, _ = self.inspect_model_cls(architectures, model_config)
        return model_info.is_text_generation_model
720

721
    def is_pooling_model(
        self,
        architectures: Union[str, list[str]],
        model_config: ModelConfig,
    ) -> bool:
        """Return the `is_pooling_model` flag of the resolved model.

        Args:
            architectures: One architecture name, or a list of candidates.
            model_config: Configuration forwarded to `inspect_model_cls`.

        Returns:
            The `is_pooling_model` attribute of the model info returned by
            `inspect_model_cls`.
        """
        model_info, _ = self.inspect_model_cls(architectures, model_config)
        return model_info.is_pooling_model
728

729
730
    def is_cross_encoder_model(
        self,
        architectures: Union[str, list[str]],
        model_config: ModelConfig,
    ) -> bool:
        """Return the `supports_cross_encoding` flag of the resolved model.

        Args:
            architectures: One architecture name, or a list of candidates.
            model_config: Configuration forwarded to `inspect_model_cls`.

        Returns:
            The `supports_cross_encoding` attribute of the model info
            returned by `inspect_model_cls`.
        """
        model_info, _ = self.inspect_model_cls(architectures, model_config)
        return model_info.supports_cross_encoding
736

737
738
    def is_multimodal_model(
        self,
        architectures: Union[str, list[str]],
        model_config: ModelConfig,
    ) -> bool:
        """Return the `supports_multimodal` flag of the resolved model.

        Args:
            architectures: One architecture name, or a list of candidates.
            model_config: Configuration forwarded to `inspect_model_cls`.

        Returns:
            The `supports_multimodal` attribute of the model info returned
            by `inspect_model_cls`.
        """
        model_info, _ = self.inspect_model_cls(architectures, model_config)
        return model_info.supports_multimodal
744

745
746
747
    def supports_multimodal_raw_input(
        self,
        architectures: Union[str, list[str]],
        model_config: ModelConfig,
    ) -> bool:
        """Return the `supports_multimodal_raw_input` flag of the model.

        Args:
            architectures: One architecture name, or a list of candidates.
            model_config: Configuration forwarded to `inspect_model_cls`.

        Returns:
            The `supports_multimodal_raw_input` attribute of the model info
            returned by `inspect_model_cls`.
        """
        model_info, _ = self.inspect_model_cls(architectures, model_config)
        return model_info.supports_multimodal_raw_input

753
754
    def is_pp_supported_model(
        self,
        architectures: Union[str, list[str]],
        model_config: ModelConfig,
    ) -> bool:
        """Return the `supports_pp` flag of the resolved model.

        Args:
            architectures: One architecture name, or a list of candidates.
            model_config: Configuration forwarded to `inspect_model_cls`.

        Returns:
            The `supports_pp` attribute of the model info returned by
            `inspect_model_cls`.
        """
        model_info, _ = self.inspect_model_cls(architectures, model_config)
        return model_info.supports_pp
760

761
762
    def model_has_inner_state(
        self,
        architectures: Union[str, list[str]],
        model_config: ModelConfig,
    ) -> bool:
        """Return the `has_inner_state` flag of the resolved model.

        Args:
            architectures: One architecture name, or a list of candidates.
            model_config: Configuration forwarded to `inspect_model_cls`.

        Returns:
            The `has_inner_state` attribute of the model info returned by
            `inspect_model_cls`.
        """
        model_info, _ = self.inspect_model_cls(architectures, model_config)
        return model_info.has_inner_state
768

769
770
    def is_attention_free_model(
        self,
        architectures: Union[str, list[str]],
        model_config: ModelConfig,
    ) -> bool:
        """Return the `is_attention_free` flag of the resolved model.

        Args:
            architectures: One architecture name, or a list of candidates.
            model_config: Configuration forwarded to `inspect_model_cls`.

        Returns:
            The `is_attention_free` attribute of the model info returned by
            `inspect_model_cls`.
        """
        model_info, _ = self.inspect_model_cls(architectures, model_config)
        return model_info.is_attention_free
776

777
778
    def is_hybrid_model(
        self,
        architectures: Union[str, list[str]],
        model_config: ModelConfig,
    ) -> bool:
        """Return the `is_hybrid` flag of the resolved model.

        Args:
            architectures: One architecture name, or a list of candidates.
            model_config: Configuration forwarded to `inspect_model_cls`.

        Returns:
            The `is_hybrid` attribute of the model info returned by
            `inspect_model_cls`.
        """
        model_info, _ = self.inspect_model_cls(architectures, model_config)
        return model_info.is_hybrid

785
786
    def is_noops_model(
        self,
        architectures: Union[str, list[str]],
        model_config: ModelConfig,
    ) -> bool:
        """Return the `has_noops` flag of the resolved model.

        Args:
            architectures: One architecture name, or a list of candidates.
            model_config: Configuration forwarded to `inspect_model_cls`.

        Returns:
            The `has_noops` attribute of the model info returned by
            `inspect_model_cls`.
        """
        model_info, _ = self.inspect_model_cls(architectures, model_config)
        return model_info.has_noops

793
794
    def is_transcription_model(
        self,
        architectures: Union[str, list[str]],
        model_config: ModelConfig,
    ) -> bool:
        """Return the `supports_transcription` flag of the resolved model.

        Args:
            architectures: One architecture name, or a list of candidates.
            model_config: Configuration forwarded to `inspect_model_cls`.

        Returns:
            The `supports_transcription` attribute of the model info
            returned by `inspect_model_cls`.
        """
        model_info, _ = self.inspect_model_cls(architectures, model_config)
        return model_info.supports_transcription

801
802
803
    def is_transcription_only_model(
        self,
        architectures: Union[str, list[str]],
        model_config: ModelConfig,
    ) -> bool:
        """Return the `supports_transcription_only` flag of the model.

        Args:
            architectures: One architecture name, or a list of candidates.
            model_config: Configuration forwarded to `inspect_model_cls`.

        Returns:
            The `supports_transcription_only` attribute of the model info
            returned by `inspect_model_cls`.
        """
        model_info, _ = self.inspect_model_cls(architectures, model_config)
        return model_info.supports_transcription_only

809
810
    def is_v1_compatible(
        self,
        architectures: Union[str, list[str]],
        model_config: ModelConfig,
    ) -> bool:
        """Return whether the resolved model is not flagged as V0-only.

        Args:
            architectures: One architecture name, or a list of candidates.
            model_config: Configuration forwarded to `inspect_model_cls`.

        Returns:
            The negation of the `supports_v0_only` attribute of the model
            info returned by `inspect_model_cls`.
        """
        model_info, _ = self.inspect_model_cls(architectures, model_config)
        return not model_info.supports_v0_only

817
818

# Module-level registry instance. Every entry of `_VLLM_MODELS` is stored as
# a `_LazyRegisteredModel` that records only the dotted module name and the
# class name, keyed by architecture name.
ModelRegistry = _ModelRegistry({
    model_arch:
    _LazyRegisteredModel(
        module_name=f"vllm.model_executor.models.{mod_relname}",
        class_name=cls_name,
    )
    for model_arch, (mod_relname, cls_name) in _VLLM_MODELS.items()
})

# Return type of the callable passed to `_run_in_subprocess`.
_T = TypeVar("_T")


def _run_in_subprocess(fn: Callable[[], _T]) -> _T:
    """Call `fn` in a child process and return its unpickled result.

    `fn` and an output file path are serialized together with `cloudpickle`
    and sent to the child on stdin. The child is started with
    `_SUBPROCESS_COMMAND` (defined elsewhere in this module) and is expected
    to pickle the return value of `fn` into that file, which is what `_run`
    in this module does.

    Args:
        fn: Zero-argument callable that `cloudpickle` can serialize.

    Returns:
        The object unpickled from the child's output file.

    Raises:
        RuntimeError: If the child exits with a non-zero return code. The
            message contains the child's captured stderr.
    """
    # NOTE: We use a temporary directory instead of a temporary file to avoid
    # issues like https://stackoverflow.com/questions/23212435/permission-denied-to-write-to-my-temporary-file
    with tempfile.TemporaryDirectory() as tempdir:
        output_filepath = os.path.join(tempdir, "registry_output.tmp")

        # `cloudpickle` allows pickling lambda functions directly
        import cloudpickle
        input_bytes = cloudpickle.dumps((fn, output_filepath))

        # cannot use `sys.executable __file__` here because the script
        # contains relative imports
        returned = subprocess.run(_SUBPROCESS_COMMAND,
                                  input=input_bytes,
                                  capture_output=True)

        # check if the subprocess is successful
        try:
            returned.check_returncode()
        except subprocess.CalledProcessError as e:
            # wrap raised exception to provide more information;
            # undecodable stderr bytes are replaced so that they cannot
            # mask the real failure with a UnicodeDecodeError
            stderr_text = returned.stderr.decode(errors="replace")
            raise RuntimeError(f"Error raised in subprocess:\n"
                               f"{stderr_text}") from e

        # The file was written by our own child process inside a private
        # temporary directory, so it is treated as trusted pickle data.
        with open(output_filepath, "rb") as f:
            return pickle.load(f)


def _run() -> None:
    """Run the callable received on stdin and pickle its result to a file.

    Reads a pickled `(fn, output_file)` pair from stdin, calls `fn()` and
    writes the pickled return value to `output_file`.
    """
    # Setup plugins
    from vllm.plugins import load_general_plugins
    load_general_plugins()

    # NOTE: Unpickling can execute arbitrary code. stdin is assumed to carry
    # only the payload produced by `_run_in_subprocess` in the parent.
    fn, output_file = pickle.loads(sys.stdin.buffer.read())

    result = fn()

    with open(output_file, "wb") as f:
        pickle.dump(result, f)
869
870
871


if __name__ == "__main__":
    # Executed only when this module is run as a script: handle the pickled
    # payload arriving on stdin.
    _run()