registry.py 38.6 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
4
5
6
"""
Whenever you add an architecture to this file, please also update
`tests/models/registry.py` with example HuggingFace models for it.
"""
7
import importlib
8
import os
9
import pickle
10
11
import subprocess
import sys
12
import tempfile
13
from abc import ABC, abstractmethod
14
from collections.abc import Set
15
from dataclasses import dataclass, field
16
from functools import lru_cache
17
from typing import Callable, Optional, TypeVar, Union
18
19

import torch.nn as nn
20
import transformers
21

22
23
from vllm.config import (ModelConfig, ModelImpl, iter_architecture_defaults,
                         try_match_architecture_defaults)
24
from vllm.logger import init_logger
25
26
from vllm.transformers_utils.dynamic_module import (
    try_get_class_from_dynamic_module)
27

28
29
from .interfaces import (has_inner_state, has_noops, is_attention_free,
                         is_hybrid, supports_cross_encoding,
30
31
                         supports_multimodal,
                         supports_multimodal_encoder_tp_data,
32
                         supports_multimodal_raw_input_only, supports_pp,
33
                         supports_transcription, supports_v0_only)
34
35
from .interfaces_base import (get_default_pooling_type, is_pooling_model,
                              is_text_generation_model)
36
37
38

logger = init_logger(__name__)

39
# yapf: disable
40
41
# Maps an architecture name to a `(module, class)` pair of strings for
# text-generation models.
_TEXT_GENERATION_MODELS = {
    # [Decoder-only]
    "ApertusForCausalLM": ("apertus", "ApertusForCausalLM"),
    "AquilaModel": ("llama", "LlamaForCausalLM"),
    "AquilaForCausalLM": ("llama", "LlamaForCausalLM"),  # AquilaChat2
    "ArceeForCausalLM": ("arcee", "ArceeForCausalLM"),
    "ArcticForCausalLM": ("arctic", "ArcticForCausalLM"),
    "MiniMaxForCausalLM": ("minimax_text_01", "MiniMaxText01ForCausalLM"),
    "MiniMaxText01ForCausalLM": ("minimax_text_01", "MiniMaxText01ForCausalLM"),
    "MiniMaxM1ForCausalLM": ("minimax_text_01", "MiniMaxText01ForCausalLM"),
    # baichuan-7b, upper case 'C' in the class name
    "BaiChuanForCausalLM": ("baichuan", "BaiChuanForCausalLM"),
    # baichuan-13b, lower case 'c' in the class name
    "BaichuanForCausalLM": ("baichuan", "BaichuanForCausalLM"),
    "BailingMoeForCausalLM": ("bailing_moe", "BailingMoeForCausalLM"),
    "BambaForCausalLM": ("bamba", "BambaForCausalLM"),
    "BloomForCausalLM": ("bloom", "BloomForCausalLM"),
    "ChatGLMModel": ("chatglm", "ChatGLMForCausalLM"),
    "ChatGLMForConditionalGeneration": ("chatglm", "ChatGLMForCausalLM"),
    "CohereForCausalLM": ("commandr", "CohereForCausalLM"),
    "Cohere2ForCausalLM": ("commandr", "CohereForCausalLM"),
    "DbrxForCausalLM": ("dbrx", "DbrxForCausalLM"),
    "DeciLMForCausalLM": ("nemotron_nas", "DeciLMForCausalLM"),
    "DeepseekForCausalLM": ("deepseek", "DeepseekForCausalLM"),
    "DeepseekV2ForCausalLM": ("deepseek_v2", "DeepseekV2ForCausalLM"),
    "DeepseekV3ForCausalLM": ("deepseek_v2", "DeepseekV3ForCausalLM"),
    "Dots1ForCausalLM": ("dots1", "Dots1ForCausalLM"),
    "Ernie4_5ForCausalLM": ("ernie45", "Ernie4_5ForCausalLM"),
    "Ernie4_5_MoeForCausalLM": ("ernie45_moe", "Ernie4_5_MoeForCausalLM"),
    "ExaoneForCausalLM": ("exaone", "ExaoneForCausalLM"),
    "Exaone4ForCausalLM": ("exaone4", "Exaone4ForCausalLM"),
    "FalconForCausalLM": ("falcon", "FalconForCausalLM"),
    "Fairseq2LlamaForCausalLM": ("fairseq2_llama", "Fairseq2LlamaForCausalLM"),
    "GemmaForCausalLM": ("gemma", "GemmaForCausalLM"),
    "Gemma2ForCausalLM": ("gemma2", "Gemma2ForCausalLM"),
    "Gemma3ForCausalLM": ("gemma3", "Gemma3ForCausalLM"),
    "Gemma3nForCausalLM": ("gemma3n", "Gemma3nForCausalLM"),
    "GlmForCausalLM": ("glm", "GlmForCausalLM"),
    "Glm4ForCausalLM": ("glm4", "Glm4ForCausalLM"),
    "Glm4MoeForCausalLM": ("glm4_moe", "Glm4MoeForCausalLM"),
    "GptOssForCausalLM": ("gpt_oss", "GptOssForCausalLM"),
    "GPT2LMHeadModel": ("gpt2", "GPT2LMHeadModel"),
    "GPTBigCodeForCausalLM": ("gpt_bigcode", "GPTBigCodeForCausalLM"),
    "GPTJForCausalLM": ("gpt_j", "GPTJForCausalLM"),
    "GPTNeoXForCausalLM": ("gpt_neox", "GPTNeoXForCausalLM"),
    "GraniteForCausalLM": ("granite", "GraniteForCausalLM"),
    "GraniteMoeForCausalLM": ("granitemoe", "GraniteMoeForCausalLM"),
    "GraniteMoeHybridForCausalLM": ("granitemoehybrid", "GraniteMoeHybridForCausalLM"),   # noqa: E501
    "GraniteMoeSharedForCausalLM": ("granitemoeshared", "GraniteMoeSharedForCausalLM"),   # noqa: E501
    "GritLM": ("gritlm", "GritLM"),
    "Grok1ModelForCausalLM": ("grok1", "Grok1ForCausalLM"),
    "HunYuanMoEV1ForCausalLM": ("hunyuan_v1", "HunYuanMoEV1ForCausalLM"),
    "HunYuanDenseV1ForCausalLM": ("hunyuan_v1", "HunYuanDenseV1ForCausalLM"),
    "HCXVisionForCausalLM": ("hyperclovax_vision", "HCXVisionForCausalLM"),
    "InternLMForCausalLM": ("llama", "LlamaForCausalLM"),
    "InternLM2ForCausalLM": ("internlm2", "InternLM2ForCausalLM"),
    "InternLM2VEForCausalLM": ("internlm2_ve", "InternLM2VEForCausalLM"),
    "InternLM3ForCausalLM": ("llama", "LlamaForCausalLM"),
    "JAISLMHeadModel": ("jais", "JAISLMHeadModel"),
    "JambaForCausalLM": ("jamba", "JambaForCausalLM"),
    "Lfm2ForCausalLM": ("lfm2", "Lfm2ForCausalLM"),
    "LlamaForCausalLM": ("llama", "LlamaForCausalLM"),
    "Llama4ForCausalLM": ("llama4", "Llama4ForCausalLM"),  # noqa: E501
    # For decapoda-research/llama-*
    "LLaMAForCausalLM": ("llama", "LlamaForCausalLM"),
    "MambaForCausalLM": ("mamba", "MambaForCausalLM"),
    "FalconMambaForCausalLM": ("mamba", "MambaForCausalLM"),
    "FalconH1ForCausalLM": ("falcon_h1", "FalconH1ForCausalLM"),
    "Mamba2ForCausalLM": ("mamba2", "Mamba2ForCausalLM"),
    "MiniCPMForCausalLM": ("minicpm", "MiniCPMForCausalLM"),
    "MiniCPM3ForCausalLM": ("minicpm3", "MiniCPM3ForCausalLM"),
    "MistralForCausalLM": ("llama", "LlamaForCausalLM"),
    "MixtralForCausalLM": ("mixtral", "MixtralForCausalLM"),
    "QuantMixtralForCausalLM": ("mixtral_quant", "MixtralForCausalLM"),
    # transformers's mpt class has lower case
    "MptForCausalLM": ("mpt", "MPTForCausalLM"),
    "MPTForCausalLM": ("mpt", "MPTForCausalLM"),
    "MiMoForCausalLM": ("mimo", "MiMoForCausalLM"),
    "NemotronForCausalLM": ("nemotron", "NemotronForCausalLM"),
    "NemotronHForCausalLM": ("nemotron_h", "NemotronHForCausalLM"),
    "OlmoForCausalLM": ("olmo", "OlmoForCausalLM"),
    "Olmo2ForCausalLM": ("olmo2", "Olmo2ForCausalLM"),
    "OlmoeForCausalLM": ("olmoe", "OlmoeForCausalLM"),
    "OPTForCausalLM": ("opt", "OPTForCausalLM"),
    "OrionForCausalLM": ("orion", "OrionForCausalLM"),
    "PersimmonForCausalLM": ("persimmon", "PersimmonForCausalLM"),
    "PhiForCausalLM": ("phi", "PhiForCausalLM"),
    "Phi3ForCausalLM": ("phi3", "Phi3ForCausalLM"),
    "PhiMoEForCausalLM": ("phimoe", "PhiMoEForCausalLM"),
    "Phi4FlashForCausalLM": ("phi4flash", "Phi4FlashForCausalLM"),
    "Plamo2ForCausalLM": ("plamo2", "Plamo2ForCausalLM"),
    "QWenLMHeadModel": ("qwen", "QWenLMHeadModel"),
    "Qwen2ForCausalLM": ("qwen2", "Qwen2ForCausalLM"),
    "Qwen2MoeForCausalLM": ("qwen2_moe", "Qwen2MoeForCausalLM"),
    "Qwen3ForCausalLM": ("qwen3", "Qwen3ForCausalLM"),
    "Qwen3MoeForCausalLM": ("qwen3_moe", "Qwen3MoeForCausalLM"),
    "RWForCausalLM": ("falcon", "FalconForCausalLM"),
    "SeedOssForCausalLM": ("seed_oss", "SeedOssForCausalLM"),
    "Step3TextForCausalLM": ("step3_text", "Step3TextForCausalLM"),
    "StableLMEpochForCausalLM": ("stablelm", "StablelmForCausalLM"),
    "StableLmForCausalLM": ("stablelm", "StablelmForCausalLM"),
    "Starcoder2ForCausalLM": ("starcoder2", "Starcoder2ForCausalLM"),
    "SolarForCausalLM": ("solar", "SolarForCausalLM"),
    "TeleChat2ForCausalLM": ("telechat2", "TeleChat2ForCausalLM"),
    "TeleFLMForCausalLM": ("teleflm", "TeleFLMForCausalLM"),
    "XverseForCausalLM": ("llama", "LlamaForCausalLM"),
    "Zamba2ForCausalLM": ("zamba2", "Zamba2ForCausalLM"),
    # [Encoder-decoder]
    "BartModel": ("bart", "BartForConditionalGeneration"),
    "BartForConditionalGeneration": ("bart", "BartForConditionalGeneration"),
    "MBartForConditionalGeneration": ("bart", "MBartForConditionalGeneration"),
}

# Maps an architecture name to a `(module, class)` pair of strings for
# embedding models.
_EMBEDDING_MODELS = {
    # [Text-only]
    "BertModel": ("bert", "BertEmbeddingModel"),
    "DeciLMForCausalLM": ("nemotron_nas", "DeciLMForCausalLM"),
    "Gemma2Model": ("gemma2", "Gemma2ForCausalLM"),
    "Gemma3TextModel": ("gemma3", "Gemma3Model"),
    "GlmForCausalLM": ("glm", "GlmForCausalLM"),
    "GPT2ForSequenceClassification": ("gpt2", "GPT2ForSequenceClassification"),
    "GritLM": ("gritlm", "GritLM"),
    "GteModel": ("bert_with_rope", "SnowflakeGteNewModel"),
    "GteNewModel": ("bert_with_rope", "GteNewModel"),
    "InternLM2ForRewardModel": ("internlm2", "InternLM2ForRewardModel"),
    "JambaForSequenceClassification": ("jamba", "JambaForSequenceClassification"),  # noqa: E501
    "LlamaModel": ("llama", "LlamaForCausalLM"),
    **{
        # Multiple models share the same architecture, so we include them all
        k: (mod, arch) for k, (mod, arch) in _TEXT_GENERATION_MODELS.items()
        if arch == "LlamaForCausalLM"
    },
    "MistralModel": ("llama", "LlamaForCausalLM"),
    "ModernBertModel": ("modernbert", "ModernBertModel"),
    "NomicBertModel": ("bert_with_rope", "NomicBertModel"),
    "Phi3ForCausalLM": ("phi3", "Phi3ForCausalLM"),
    "Qwen2Model": ("qwen2", "Qwen2ForCausalLM"),
    "Qwen2ForCausalLM": ("qwen2", "Qwen2ForCausalLM"),
    "Qwen2ForRewardModel": ("qwen2_rm", "Qwen2ForRewardModel"),
    "Qwen2ForProcessRewardModel": ("qwen2_rm", "Qwen2ForProcessRewardModel"),
    "RobertaForMaskedLM": ("roberta", "RobertaEmbeddingModel"),
    "RobertaModel": ("roberta", "RobertaEmbeddingModel"),
    "TeleChat2ForCausalLM": ("telechat2", "TeleChat2ForCausalLM"),
    "XLMRobertaModel": ("roberta", "RobertaEmbeddingModel"),
    # [Multimodal]
    "LlavaNextForConditionalGeneration": ("llava_next", "LlavaNextForConditionalGeneration"),  # noqa: E501
    "Phi3VForCausalLM": ("phi3v", "Phi3VForCausalLM"),
    "Qwen2VLForConditionalGeneration": ("qwen2_vl", "Qwen2VLForConditionalGeneration"),  # noqa: E501
    # Technically, Terratorch models work on images, both as input and as
    # output. They are listed here because they piggy-back on embedding
    # models for the time being.
    "PrithviGeoSpatialMAE": ("terratorch", "Terratorch"),
    "Terratorch": ("terratorch", "Terratorch"),
}

195
196
# Maps an architecture name to a `(module, class)` pair of strings for
# cross-encoder models.
_CROSS_ENCODER_MODELS = {
    "BertForSequenceClassification": ("bert", "BertForSequenceClassification"),
    "GteNewForSequenceClassification": ("bert_with_rope",
                                        "GteNewForSequenceClassification"),
    "ModernBertForSequenceClassification": ("modernbert",
                                            "ModernBertForSequenceClassification"),
    "RobertaForSequenceClassification": ("roberta",
                                         "RobertaForSequenceClassification"),
    "XLMRobertaForSequenceClassification": ("roberta",
                                            "RobertaForSequenceClassification"),
    # [Auto-converted (see adapters.py)]
    "JinaVLForRanking": ("jina_vl", "JinaVLForSequenceClassification"),  # noqa: E501
}

209
# Maps an architecture name to a `(module, class)` pair of strings for
# multimodal models.
_MULTIMODAL_MODELS = {
    # [Decoder-only]
    "AriaForConditionalGeneration": ("aria", "AriaForConditionalGeneration"),
    "AyaVisionForConditionalGeneration": ("aya_vision", "AyaVisionForConditionalGeneration"),  # noqa: E501
    "Blip2ForConditionalGeneration": ("blip2", "Blip2ForConditionalGeneration"),
    "ChameleonForConditionalGeneration": ("chameleon", "ChameleonForConditionalGeneration"),  # noqa: E501
    "Cohere2VisionForConditionalGeneration": ("cohere2_vision", "Cohere2VisionForConditionalGeneration"),  # noqa: E501
    "DeepseekVLV2ForCausalLM": ("deepseek_vl2", "DeepseekVLV2ForCausalLM"),
    "Ernie4_5_VLMoeForConditionalGeneration": ("ernie45_vl", "Ernie4_5_VLMoeForConditionalGeneration"),  # noqa: E501
    "FuyuForCausalLM": ("fuyu", "FuyuForCausalLM"),
    "Gemma3ForConditionalGeneration": ("gemma3_mm", "Gemma3ForConditionalGeneration"),  # noqa: E501
    "Gemma3nForConditionalGeneration": ("gemma3n_mm", "Gemma3nForConditionalGeneration"),  # noqa: E501
    "GLM4VForCausalLM": ("glm4v", "GLM4VForCausalLM"),
    "Glm4vForConditionalGeneration": ("glm4_1v", "Glm4vForConditionalGeneration"),  # noqa: E501
    "Glm4vMoeForConditionalGeneration": ("glm4_1v", "Glm4vMoeForConditionalGeneration"),  # noqa: E501
    "GraniteSpeechForConditionalGeneration": ("granite_speech", "GraniteSpeechForConditionalGeneration"),  # noqa: E501
    "H2OVLChatModel": ("h2ovl", "H2OVLChatModel"),
    "InternVLChatModel": ("internvl", "InternVLChatModel"),
    "InternS1ForConditionalGeneration": ("interns1", "InternS1ForConditionalGeneration"),  # noqa: E501
    "InternVLForConditionalGeneration": ("interns1", "InternS1ForConditionalGeneration"),  # noqa: E501
    "Idefics3ForConditionalGeneration": ("idefics3", "Idefics3ForConditionalGeneration"),  # noqa: E501
    "SmolVLMForConditionalGeneration": ("smolvlm", "SmolVLMForConditionalGeneration"),  # noqa: E501
    "KeyeForConditionalGeneration": ("keye", "KeyeForConditionalGeneration"),
    "KeyeVL1_5ForConditionalGeneration": ("keye_vl1_5", "KeyeVL1_5ForConditionalGeneration"),  # noqa: E501
    "RForConditionalGeneration": ("rvl", "RForConditionalGeneration"),
    "KimiVLForConditionalGeneration": ("kimi_vl", "KimiVLForConditionalGeneration"),  # noqa: E501
    "Llama_Nemotron_Nano_VL": ("nemotron_vl", "LlamaNemotronVLChatModel"),
    "LlavaForConditionalGeneration": ("llava", "LlavaForConditionalGeneration"),
    "LlavaNextForConditionalGeneration": ("llava_next", "LlavaNextForConditionalGeneration"),  # noqa: E501
    "LlavaNextVideoForConditionalGeneration": ("llava_next_video", "LlavaNextVideoForConditionalGeneration"),  # noqa: E501
    "LlavaOnevisionForConditionalGeneration": ("llava_onevision", "LlavaOnevisionForConditionalGeneration"),  # noqa: E501
    "MantisForConditionalGeneration": ("llava", "MantisForConditionalGeneration"),  # noqa: E501
    "MiDashengLMModel": ("midashenglm", "MiDashengLMModel"),
    "MiniMaxVL01ForConditionalGeneration": ("minimax_vl_01", "MiniMaxVL01ForConditionalGeneration"),  # noqa: E501
    "MiniCPMO": ("minicpmo", "MiniCPMO"),
    "MiniCPMV": ("minicpmv", "MiniCPMV"),
    "Mistral3ForConditionalGeneration": ("mistral3", "Mistral3ForConditionalGeneration"),  # noqa: E501
    "MolmoForCausalLM": ("molmo", "MolmoForCausalLM"),
    "NVLM_D": ("nvlm_d", "NVLM_D_Model"),
    "Ovis": ("ovis", "Ovis"),
    "Ovis2_5": ("ovis2_5", "Ovis2_5"),
    "PaliGemmaForConditionalGeneration": ("paligemma", "PaliGemmaForConditionalGeneration"),  # noqa: E501
    "Phi3VForCausalLM": ("phi3v", "Phi3VForCausalLM"),
    "Phi4MMForCausalLM": ("phi4mm", "Phi4MMForCausalLM"),
    "Phi4MultimodalForCausalLM": ("phi4_multimodal", "Phi4MultimodalForCausalLM"),  # noqa: E501
    "PixtralForConditionalGeneration": ("pixtral", "PixtralForConditionalGeneration"),  # noqa: E501
    "QwenVLForConditionalGeneration": ("qwen_vl", "QwenVLForConditionalGeneration"),  # noqa: E501
    "Qwen2VLForConditionalGeneration": ("qwen2_vl", "Qwen2VLForConditionalGeneration"),  # noqa: E501
    "Qwen2_5_VLForConditionalGeneration": ("qwen2_5_vl", "Qwen2_5_VLForConditionalGeneration"),  # noqa: E501
    "Qwen2AudioForConditionalGeneration": ("qwen2_audio", "Qwen2AudioForConditionalGeneration"),  # noqa: E501
    "Qwen2_5OmniModel": ("qwen2_5_omni_thinker", "Qwen2_5OmniThinkerForConditionalGeneration"),  # noqa: E501
    "Qwen2_5OmniForConditionalGeneration": ("qwen2_5_omni_thinker", "Qwen2_5OmniThinkerForConditionalGeneration"),  # noqa: E501
    "UltravoxModel": ("ultravox", "UltravoxModel"),
    "Step3VLForConditionalGeneration": ("step3_vl", "Step3VLForConditionalGeneration"),  # noqa: E501
    "TarsierForConditionalGeneration": ("tarsier", "TarsierForConditionalGeneration"),  # noqa: E501
    "Tarsier2ForConditionalGeneration": ("qwen2_vl", "Tarsier2ForConditionalGeneration"),  # noqa: E501
    "VoxtralForConditionalGeneration": ("voxtral", "VoxtralForConditionalGeneration"),  # noqa: E501
    # [Encoder-decoder]
    "DonutForConditionalGeneration": ("donut", "DonutForConditionalGeneration"),
    "Florence2ForConditionalGeneration": ("florence2", "Florence2ForConditionalGeneration"),  # noqa: E501
    "MllamaForConditionalGeneration": ("mllama", "MllamaForConditionalGeneration"),  # noqa: E501
    "Llama4ForConditionalGeneration": ("mllama4", "Llama4ForConditionalGeneration"),  # noqa: E501
    "SkyworkR1VChatModel": ("skyworkr1v", "SkyworkR1VChatModel"),
    "WhisperForConditionalGeneration": ("whisper", "WhisperForConditionalGeneration"),  # noqa: E501
}
274
275

# Maps an architecture name to a `(module, class)` pair of strings for
# speculative-decoding models.
_SPECULATIVE_DECODING_MODELS = {
    "MiMoMTPModel": ("mimo_mtp", "MiMoMTP"),
    "EagleLlamaForCausalLM": ("llama_eagle", "EagleLlamaForCausalLM"),
    "EagleLlama4ForCausalLM": ("llama4_eagle", "EagleLlama4ForCausalLM"),
    "EagleMiniCPMForCausalLM": ("minicpm_eagle", "EagleMiniCPMForCausalLM"),
    "Eagle3LlamaForCausalLM": ("llama_eagle3", "Eagle3LlamaForCausalLM"),
    # TODO: Re-enable this once tests/models/test_initialization.py is fixed, see PR #22333 #22611  # noqa: E501
    # "LlamaForCausalLMEagle3": ("llama_eagle3", "Eagle3LlamaForCausalLM"),
    "EagleDeepSeekMTPModel": ("deepseek_eagle", "EagleDeepseekV3ForCausalLM"),
    "DeepSeekMTPModel": ("deepseek_mtp", "DeepSeekMTP"),
    "ErnieMTPModel": ("ernie_mtp", "ErnieMTP"),
    "Glm4MoeMTPModel": ("glm4_moe_mtp", "Glm4MoeMTP"),
    "MedusaModel": ("medusa", "Medusa"),
    # Temporarily disabled.
    # # TODO(woosuk): Re-enable this once the MLP Speculator is supported in V1.
    # "MLPSpeculatorPreTrainedModel": ("mlp_speculator", "MLPSpeculator"),
}
292

293
# Architectures whose `(module, class)` entry points at one of the
# `transformers` module's `Transformers*` classes.
_TRANSFORMERS_SUPPORTED_MODELS = {
    # Text generation models
    "SmolLM3ForCausalLM": ("transformers", "TransformersForCausalLM"),
    # Multimodal models
    "Emu3ForConditionalGeneration": ("transformers", "TransformersForMultimodalLM"),  # noqa: E501
}

# The `Transformers*` classes themselves, registered under their own names.
_TRANSFORMERS_BACKEND_MODELS = {
    "TransformersModel": ("transformers", "TransformersModel"),
    "TransformersForCausalLM": ("transformers", "TransformersForCausalLM"),
    "TransformersForMultimodalLM": ("transformers", "TransformersForMultimodalLM"),  # noqa: E501
}
305
# yapf: enable
306

307
# Union of all per-category tables. The tables are merged in the order
# listed, so when an architecture name appears in more than one of them the
# entry from the later table is the one that is kept.
_VLLM_MODELS = {
    **_TEXT_GENERATION_MODELS,
    **_EMBEDDING_MODELS,
    **_CROSS_ENCODER_MODELS,
    **_MULTIMODAL_MODELS,
    **_SPECULATIVE_DECODING_MODELS,
    **_TRANSFORMERS_SUPPORTED_MODELS,
    **_TRANSFORMERS_BACKEND_MODELS,
}

317
318
319
320
321
322
323
324
# This variable is used as the args for subprocess.run(). It can be
# modified to alter the args if needed, e.g. when the par format is used
# to pack things together, sys.executable might not be the target we want
# to run.
_SUBPROCESS_COMMAND = [
    sys.executable, "-m", "vllm.model_executor.models.registry"
]

325
326
# Architectures that are no longer supported, mapped to the version string
# reported in the "was supported in vLLM until v<version>" error message.
_PREVIOUSLY_SUPPORTED_MODELS = {"Phi3SmallForCausalLM": "0.9.2"}

327

328
329
@dataclass(frozen=True)
class _ModelInfo:
    """
    Immutable record of the properties detected on a model class.

    Instances are normally built with `from_model_cls`, which fills every
    field by calling the corresponding interface helper on the class.
    """
    # `__name__` of the inspected model class
    architecture: str
    is_text_generation_model: bool
    is_pooling_model: bool
    # Result of `get_default_pooling_type(model)`
    default_pooling_type: str
    supports_cross_encoding: bool
    supports_multimodal: bool
    supports_multimodal_raw_input_only: bool
    supports_multimodal_encoder_tp_data: bool
    supports_pp: bool
    has_inner_state: bool
    is_attention_free: bool
    is_hybrid: bool
    has_noops: bool
    supports_transcription: bool
    # Only taken from the class attribute when the model supports
    # transcription at all; falsy otherwise
    supports_transcription_only: bool
    supports_v0_only: bool

    @staticmethod
    def from_model_cls(model: type[nn.Module]) -> "_ModelInfo":
        """Build a `_ModelInfo` by probing `model` with the interface helpers."""
        return _ModelInfo(
            architecture=model.__name__,
            is_text_generation_model=is_text_generation_model(model),
            is_pooling_model=is_pooling_model(model),
            default_pooling_type=get_default_pooling_type(model),
            supports_cross_encoding=supports_cross_encoding(model),
            supports_multimodal=supports_multimodal(model),
            supports_multimodal_raw_input_only=
            supports_multimodal_raw_input_only(model),
            supports_multimodal_encoder_tp_data=
            supports_multimodal_encoder_tp_data(model),
            supports_pp=supports_pp(model),
            has_inner_state=has_inner_state(model),
            is_attention_free=is_attention_free(model),
            is_hybrid=is_hybrid(model),
            supports_transcription=supports_transcription(model),
            # Short-circuit so the attribute is only read on classes that
            # support transcription.
            supports_transcription_only=(supports_transcription(model) and
                                         model.supports_transcription_only),
            supports_v0_only=supports_v0_only(model),
            has_noops=has_noops(model),
        )
370
371


372
class _BaseRegisteredModel(ABC):
    """Abstract interface shared by the entries stored in the registry."""

    @abstractmethod
    def inspect_model_cls(self) -> _ModelInfo:
        """Return the `_ModelInfo` describing the model class."""
        raise NotImplementedError

    @abstractmethod
    def load_model_cls(self) -> type[nn.Module]:
        """Return the model class itself."""
        raise NotImplementedError
381
382


383
384
385
386
387
388
389
@dataclass(frozen=True)
class _RegisteredModel(_BaseRegisteredModel):
    """
    Represents a model that has already been imported in the main process.
    """

    # Model properties, computed once when the entry is created
    interfaces: _ModelInfo
    model_cls: type[nn.Module]

    @staticmethod
    def from_model_cls(model_cls: type[nn.Module]) -> "_RegisteredModel":
        """Wrap `model_cls` together with its freshly computed `_ModelInfo`."""
        return _RegisteredModel(
            interfaces=_ModelInfo.from_model_cls(model_cls),
            model_cls=model_cls,
        )

    def inspect_model_cls(self) -> _ModelInfo:
        """Return the `_ModelInfo` stored at construction time."""
        return self.interfaces

    def load_model_cls(self) -> type[nn.Module]:
        """Return the stored model class."""
        return self.model_cls


@dataclass(frozen=True)
class _LazyRegisteredModel(_BaseRegisteredModel):
    """
    Represents a model that has not been imported in the main process.
    """
    # Passed to `importlib.import_module` when the class is loaded
    module_name: str
    # Attribute looked up on the imported module
    class_name: str

    # Performed in another process to avoid initializing CUDA
    def inspect_model_cls(self) -> _ModelInfo:
        """Compute the `_ModelInfo` of the model class via `_run_in_subprocess`."""
        return _run_in_subprocess(
            lambda: _ModelInfo.from_model_cls(self.load_model_cls()))

    def load_model_cls(self) -> type[nn.Module]:
        """Import `module_name` and return its `class_name` attribute."""
        mod = importlib.import_module(self.module_name)
        return getattr(mod, self.class_name)


@lru_cache(maxsize=128)
def _try_load_model_cls(
    model_arch: str,
    model: _BaseRegisteredModel,
) -> Optional[type[nn.Module]]:
    """
    Load the model class of a registry entry, returning `None` on failure.

    The architecture is first checked with
    `current_platform.verify_model_arch`; an exception raised by that check
    is not caught here. Any exception raised while loading the class is
    logged and turned into a `None` result. Results, including `None`, are
    memoized by `lru_cache` per `(model_arch, model)` pair.
    """
    # Imported inside the function rather than at module import time.
    from vllm.platforms import current_platform
    current_platform.verify_model_arch(model_arch)
    try:
        return model.load_model_cls()
    except Exception:
        logger.exception("Error in loading model architecture '%s'",
                         model_arch)
        return None
437
438


439
440
441
442
443
444
445
446
447
448
449
@lru_cache(maxsize=128)
def _try_inspect_model_cls(
    model_arch: str,
    model: _BaseRegisteredModel,
) -> Optional[_ModelInfo]:
    """
    Inspect a registry entry, returning `None` when inspection fails.

    Any exception raised by `model.inspect_model_cls()` is logged together
    with `model_arch` and converted into a `None` result. Results, including
    `None`, are memoized by `lru_cache` per `(model_arch, model)` pair.
    """
    try:
        info = model.inspect_model_cls()
    except Exception:
        logger.exception("Error in inspecting model architecture '%s'",
                         model_arch)
        info = None
    return info
450
451


452
453
454
@dataclass
class _ModelRegistry:
    # Keyed by model_arch
455
    models: dict[str, _BaseRegisteredModel] = field(default_factory=dict)
456

457
    def get_supported_archs(self) -> Set[str]:
458
        return self.models.keys()
459

460
461
462
    def register_model(
        self,
        model_arch: str,
463
        model_cls: Union[type[nn.Module], str],
464
    ) -> None:
465
466
467
        """
        Register an external model to be used in vLLM.

468
        `model_cls` can be either:
469

470
        - A [`torch.nn.Module`][] class directly referencing the model.
471
        - A string in the format `<module>:<class>` which can be used to
472
473
          lazily import the model. This is useful to avoid initializing CUDA
          when importing the model and thus the related error
474
          `RuntimeError: Cannot re-initialize CUDA in forked subprocess`.
475
        """
476
477
478
479
        if not isinstance(model_arch, str):
            msg = f"`model_arch` should be a string, not a {type(model_arch)}"
            raise TypeError(msg)

480
        if model_arch in self.models:
481
482
483
            logger.warning(
                "Model architecture %s is already registered, and will be "
                "overwritten by the new model class %s.", model_arch,
484
485
486
487
488
489
490
                model_cls)

        if isinstance(model_cls, str):
            split_str = model_cls.split(":")
            if len(split_str) != 2:
                msg = "Expected a string in the format `<module>:<class>`"
                raise ValueError(msg)
491

492
            model = _LazyRegisteredModel(*split_str)
493
        elif isinstance(model_cls, type) and issubclass(model_cls, nn.Module):
494
            model = _RegisteredModel.from_model_cls(model_cls)
495
496
497
498
        else:
            msg = ("`model_cls` should be a string or PyTorch model class, "
                   f"not a {type(model_arch)}")
            raise TypeError(msg)
499

500
        self.models[model_arch] = model
501

502
    def _raise_for_unsupported(self, architectures: list[str]):
503
        all_supported_archs = self.get_supported_archs()
504

505
506
507
508
509
        if any(arch in all_supported_archs for arch in architectures):
            raise ValueError(
                f"Model architectures {architectures} failed "
                "to be inspected. Please check the logs for more details.")

510
511
512
513
514
515
516
517
518
519
        for arch in architectures:
            if arch in _PREVIOUSLY_SUPPORTED_MODELS:
                previous_version = _PREVIOUSLY_SUPPORTED_MODELS[arch]

                raise ValueError(
                    f"Model architecture {arch} was supported in vLLM until "
                    f"v{previous_version}, and is not supported anymore. "
                    "Please use an older version of vLLM if you want to "
                    "use this model architecture.")

520
521
522
        raise ValueError(
            f"Model architectures {architectures} are not supported for now. "
            f"Supported architectures: {all_supported_archs}")
523

524
    def _try_load_model_cls(self,
525
                            model_arch: str) -> Optional[type[nn.Module]]:
526
527
        if model_arch not in self.models:
            return None
528

529
        return _try_load_model_cls(model_arch, self.models[model_arch])
530

531
    def _try_inspect_model_cls(self, model_arch: str) -> Optional[_ModelInfo]:
532
533
        if model_arch not in self.models:
            return None
534

535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
        return _try_inspect_model_cls(model_arch, self.models[model_arch])

    def _try_resolve_transformers(
        self,
        architecture: str,
        model_config: ModelConfig,
    ) -> Optional[str]:
        """
        Resolve `architecture` to an architecture name from
        `_TRANSFORMERS_BACKEND_MODELS`.

        Returns:
            `architecture` itself when it is already a key of
            `_TRANSFORMERS_BACKEND_MODELS`; otherwise the value of
            `model_config._get_transformers_backend_cls()` when a compatible
            model class is found; `None` when resolution fails and
            `model_config.model_impl` is not `ModelImpl.TRANSFORMERS`.

        Raises:
            ValueError: If resolution fails while `model_config.model_impl`
                is `ModelImpl.TRANSFORMERS`.
        """
        if architecture in _TRANSFORMERS_BACKEND_MODELS:
            return architecture

        auto_map: dict[str, str] = getattr(model_config.hf_config, "auto_map",
                                           None) or {}

        # Make sure that config class is always initialized before model class,
        # otherwise the model class won't be able to access the config class,
        # the expected auto_map should have correct order like:
        # "auto_map": {
        #     "AutoConfig": "<your-repo-name>--<config-name>",
        #     "AutoModel": "<your-repo-name>--<config-name>",
        #     "AutoModelFor<Task>": "<your-repo-name>--<config-name>",
        # },
        for prefix in ("AutoConfig", "AutoModel"):
            for name, module in auto_map.items():
                if name.startswith(prefix):
                    try_get_class_from_dynamic_module(
                        module,
                        model_config.model,
                        revision=model_config.revision,
                        warn_on_fail=False,
                    )

        model_module = getattr(transformers, architecture, None)

        if model_module is None:
            # Not in the Transformers library: fall back to the first
            # `AutoModel*` entry of `auto_map` that can be loaded.
            for name, module in auto_map.items():
                if name.startswith("AutoModel"):
                    model_module = try_get_class_from_dynamic_module(
                        module,
                        model_config.model,
                        revision=model_config.revision,
                        warn_on_fail=True,
                    )
                    if model_module is not None:
                        break
            else:
                # The loop finished without `break`: nothing was loaded.
                if model_config.model_impl != ModelImpl.TRANSFORMERS:
                    return None

                raise ValueError(
                    f"Cannot find model module. {architecture!r} is not a "
                    "registered model in the Transformers library (only "
                    "relevant if the model is meant to be in Transformers) "
                    "and 'AutoModel' is not present in the model config's "
                    "'auto_map' (relevant if the model is custom).")

        if not model_module.is_backend_compatible():
            if model_config.model_impl != ModelImpl.TRANSFORMERS:
                return None

            raise ValueError(
                f"The Transformers implementation of {architecture!r} "
                "is not compatible with vLLM.")

        return model_config._get_transformers_backend_cls()
599

600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
    def _normalize_arch(
        self,
        architecture: str,
        model_config: ModelConfig,
    ) -> str:
        if architecture in self.models:
            return architecture

        # This may be called in order to resolve runner_type and convert_type
        # in the first place, in which case we consider the default match
        match = try_match_architecture_defaults(
            architecture,
            runner_type=getattr(model_config, "runner_type", None),
            convert_type=getattr(model_config, "convert_type", None),
        )
        if match:
            suffix, _ = match

            # Get the name of the base model to convert
            for repl_suffix, _ in iter_architecture_defaults():
                base_arch = architecture.replace(suffix, repl_suffix)
                if base_arch in self.models:
                    return base_arch

        return architecture
625

626
627
    def inspect_model_cls(
        self,
        architectures: Union[str, list[str]],
        model_config: ModelConfig,
    ) -> tuple[_ModelInfo, str]:
        """Resolve `architectures` to model info and the chosen architecture.

        Resolution order (the first step that succeeds wins; a step that
        fails falls through to the next one):

        1. `model_impl` is `TRANSFORMERS`: try the Transformers backend for
           the first architecture.
        2. `model_impl` is `TERRATORCH`: try the `"Terratorch"` entry.
        3. No architecture is registered, `model_impl` is `AUTO` and
           `convert_type` is `"none"` (or unset): Transformers backend.
        4. Each architecture in order, after `_normalize_arch`. The name
           returned is the original one, not the normalized one.
        5. No architecture is registered and `model_impl` is `AUTO`:
           Transformers backend, regardless of `convert_type`.
        6. Otherwise delegate to `_raise_for_unsupported`.

        Args:
            architectures: One architecture name or a list of candidates.
            model_config: Model configuration driving the resolution.

        Returns:
            `(model_info, architecture)` for the first successful match.

        Raises:
            ValueError: If `architectures` is empty.
        """
        if isinstance(architectures, str):
            architectures = [architectures]
        if not architectures:
            raise ValueError("No model architectures are specified")

        # Require transformers impl
        if model_config.model_impl == ModelImpl.TRANSFORMERS:
            arch = self._try_resolve_transformers(architectures[0],
                                                  model_config)
            if arch is not None:
                model_info = self._try_inspect_model_cls(arch)
                if model_info is not None:
                    return (model_info, arch)
        elif model_config.model_impl == ModelImpl.TERRATORCH:
            # Mirror `resolve_model_cls`: only return on success, so a failed
            # inspection falls through instead of yielding
            # `(None, "Terratorch")`, which violates the return type.
            arch = "Terratorch"
            model_info = self._try_inspect_model_cls(arch)
            if model_info is not None:
                return (model_info, arch)

        # Fallback to transformers impl (after resolving convert_type)
        if (all(arch not in self.models for arch in architectures)
                and model_config.model_impl == ModelImpl.AUTO
                and getattr(model_config, "convert_type", "none") == "none"):
            arch = self._try_resolve_transformers(architectures[0],
                                                  model_config)
            if arch is not None:
                model_info = self._try_inspect_model_cls(arch)
                if model_info is not None:
                    return (model_info, arch)

        for arch in architectures:
            normalized_arch = self._normalize_arch(arch, model_config)
            model_info = self._try_inspect_model_cls(normalized_arch)
            if model_info is not None:
                return (model_info, arch)

        # Fallback to transformers impl (before resolving runner_type)
        if (all(arch not in self.models for arch in architectures)
                and model_config.model_impl == ModelImpl.AUTO):
            arch = self._try_resolve_transformers(architectures[0],
                                                  model_config)
            if arch is not None:
                model_info = self._try_inspect_model_cls(arch)
                if model_info is not None:
                    return (model_info, arch)

        return self._raise_for_unsupported(architectures)
676

677
678
    def resolve_model_cls(
        self,
        architectures: Union[str, list[str]],
        model_config: ModelConfig,
    ) -> tuple[type[nn.Module], str]:
        """Resolve `architectures` to a model class and its architecture.

        Resolution order (the first step that succeeds wins; a step that
        fails falls through to the next one):

        1. `model_impl` is `TRANSFORMERS`: try the Transformers backend for
           the first architecture.
        2. `model_impl` is `TERRATORCH`: try the `"Terratorch"` entry.
        3. No architecture is registered, `model_impl` is `AUTO` and
           `convert_type` is `"none"` (or unset): Transformers backend.
        4. Each architecture in order, after `_normalize_arch`. The name
           returned is the original one, not the normalized one.
        5. No architecture is registered and `model_impl` is `AUTO`:
           Transformers backend, regardless of `convert_type`.
        6. Otherwise delegate to `_raise_for_unsupported`.

        Args:
            architectures: One architecture name or a list of candidates.
            model_config: Model configuration driving the resolution.

        Returns:
            `(model_cls, architecture)` for the first successful match.

        Raises:
            ValueError: If `architectures` is empty.
        """
        if isinstance(architectures, str):
            architectures = [architectures]
        if not architectures:
            raise ValueError("No model architectures are specified")

        # Require transformers impl
        if model_config.model_impl == ModelImpl.TRANSFORMERS:
            arch = self._try_resolve_transformers(architectures[0],
                                                  model_config)
            if arch is not None:
                model_cls = self._try_load_model_cls(arch)
                if model_cls is not None:
                    return (model_cls, arch)
        elif model_config.model_impl == ModelImpl.TERRATORCH:
            arch = "Terratorch"
            model_cls = self._try_load_model_cls(arch)
            if model_cls is not None:
                return (model_cls, arch)

        # Fallback to transformers impl (after resolving convert_type)
        if (all(arch not in self.models for arch in architectures)
                and model_config.model_impl == ModelImpl.AUTO
                and getattr(model_config, "convert_type", "none") == "none"):
            arch = self._try_resolve_transformers(architectures[0],
                                                  model_config)
            if arch is not None:
                model_cls = self._try_load_model_cls(arch)
                if model_cls is not None:
                    return (model_cls, arch)

        for arch in architectures:
            normalized_arch = self._normalize_arch(arch, model_config)
            model_cls = self._try_load_model_cls(normalized_arch)
            if model_cls is not None:
                return (model_cls, arch)

        # Fallback to transformers impl (before resolving runner_type)
        if (all(arch not in self.models for arch in architectures)
                and model_config.model_impl == ModelImpl.AUTO):
            arch = self._try_resolve_transformers(architectures[0],
                                                  model_config)
            if arch is not None:
                model_cls = self._try_load_model_cls(arch)
                if model_cls is not None:
                    return (model_cls, arch)

        return self._raise_for_unsupported(architectures)
729

730
731
    def is_text_generation_model(
        self,
        architectures: Union[str, list[str]],
        model_config: ModelConfig,
    ) -> bool:
        """Report the `is_text_generation_model` flag of the resolved model.

        Args:
            architectures: One architecture name or a list of candidates.
            model_config: Model configuration used for resolution.

        Returns:
            The `is_text_generation_model` attribute of the model info
            returned by `inspect_model_cls`.
        """
        model_info, _ = self.inspect_model_cls(architectures, model_config)
        return model_info.is_text_generation_model
737

738
    def is_pooling_model(
        self,
        architectures: Union[str, list[str]],
        model_config: ModelConfig,
    ) -> bool:
        """Report the `is_pooling_model` flag of the resolved model.

        Args:
            architectures: One architecture name or a list of candidates.
            model_config: Model configuration used for resolution.

        Returns:
            The `is_pooling_model` attribute of the model info returned by
            `inspect_model_cls`.
        """
        model_info, _ = self.inspect_model_cls(architectures, model_config)
        return model_info.is_pooling_model
745

746
747
    def is_cross_encoder_model(
        self,
        architectures: Union[str, list[str]],
        model_config: ModelConfig,
    ) -> bool:
        """Report the `supports_cross_encoding` flag of the resolved model.

        Args:
            architectures: One architecture name or a list of candidates.
            model_config: Model configuration used for resolution.

        Returns:
            The `supports_cross_encoding` attribute of the model info
            returned by `inspect_model_cls`.
        """
        model_info, _ = self.inspect_model_cls(architectures, model_config)
        return model_info.supports_cross_encoding
753

754
755
    def is_multimodal_model(
        self,
        architectures: Union[str, list[str]],
        model_config: ModelConfig,
    ) -> bool:
        """Report the `supports_multimodal` flag of the resolved model.

        Args:
            architectures: One architecture name or a list of candidates.
            model_config: Model configuration used for resolution.

        Returns:
            The `supports_multimodal` attribute of the model info returned
            by `inspect_model_cls`.
        """
        model_info, _ = self.inspect_model_cls(architectures, model_config)
        return model_info.supports_multimodal
761

762
    def is_multimodal_raw_input_only_model(
        self,
        architectures: Union[str, list[str]],
        model_config: ModelConfig,
    ) -> bool:
        """Report the `supports_multimodal_raw_input_only` flag of the model.

        Args:
            architectures: One architecture name or a list of candidates.
            model_config: Model configuration used for resolution.

        Returns:
            The `supports_multimodal_raw_input_only` attribute of the model
            info returned by `inspect_model_cls`.
        """
        model_info, _ = self.inspect_model_cls(architectures, model_config)
        return model_info.supports_multimodal_raw_input_only
769

770
771
    def is_pp_supported_model(
        self,
        architectures: Union[str, list[str]],
        model_config: ModelConfig,
    ) -> bool:
        """Report the `supports_pp` flag of the resolved model.

        Args:
            architectures: One architecture name or a list of candidates.
            model_config: Model configuration used for resolution.

        Returns:
            The `supports_pp` attribute of the model info returned by
            `inspect_model_cls`.
        """
        model_info, _ = self.inspect_model_cls(architectures, model_config)
        return model_info.supports_pp
777

778
779
    def model_has_inner_state(
        self,
        architectures: Union[str, list[str]],
        model_config: ModelConfig,
    ) -> bool:
        """Report the `has_inner_state` flag of the resolved model.

        Args:
            architectures: One architecture name or a list of candidates.
            model_config: Model configuration used for resolution.

        Returns:
            The `has_inner_state` attribute of the model info returned by
            `inspect_model_cls`.
        """
        model_info, _ = self.inspect_model_cls(architectures, model_config)
        return model_info.has_inner_state
785

786
787
    def is_attention_free_model(
        self,
        architectures: Union[str, list[str]],
        model_config: ModelConfig,
    ) -> bool:
        """Report the `is_attention_free` flag of the resolved model.

        Args:
            architectures: One architecture name or a list of candidates.
            model_config: Model configuration used for resolution.

        Returns:
            The `is_attention_free` attribute of the model info returned by
            `inspect_model_cls`.
        """
        model_info, _ = self.inspect_model_cls(architectures, model_config)
        return model_info.is_attention_free
793

794
795
    def is_hybrid_model(
        self,
        architectures: Union[str, list[str]],
        model_config: ModelConfig,
    ) -> bool:
        """Report the `is_hybrid` flag of the resolved model.

        Args:
            architectures: One architecture name or a list of candidates.
            model_config: Model configuration used for resolution.

        Returns:
            The `is_hybrid` attribute of the model info returned by
            `inspect_model_cls`.
        """
        model_info, _ = self.inspect_model_cls(architectures, model_config)
        return model_info.is_hybrid

802
803
    def is_noops_model(
        self,
        architectures: Union[str, list[str]],
        model_config: ModelConfig,
    ) -> bool:
        """Report the `has_noops` flag of the resolved model.

        Args:
            architectures: One architecture name or a list of candidates.
            model_config: Model configuration used for resolution.

        Returns:
            The `has_noops` attribute of the model info returned by
            `inspect_model_cls`.
        """
        model_info, _ = self.inspect_model_cls(architectures, model_config)
        return model_info.has_noops

810
811
    def is_transcription_model(
        self,
        architectures: Union[str, list[str]],
        model_config: ModelConfig,
    ) -> bool:
        """Report the `supports_transcription` flag of the resolved model.

        Args:
            architectures: One architecture name or a list of candidates.
            model_config: Model configuration used for resolution.

        Returns:
            The `supports_transcription` attribute of the model info
            returned by `inspect_model_cls`.
        """
        model_info, _ = self.inspect_model_cls(architectures, model_config)
        return model_info.supports_transcription

818
819
820
    def is_transcription_only_model(
        self,
        architectures: Union[str, list[str]],
        model_config: ModelConfig,
    ) -> bool:
        """Report the `supports_transcription_only` flag of the model.

        Args:
            architectures: One architecture name or a list of candidates.
            model_config: Model configuration used for resolution.

        Returns:
            The `supports_transcription_only` attribute of the model info
            returned by `inspect_model_cls`.
        """
        model_info, _ = self.inspect_model_cls(architectures, model_config)
        return model_info.supports_transcription_only

826
827
    def is_v1_compatible(
        self,
        architectures: Union[str, list[str]],
        model_config: ModelConfig,
    ) -> bool:
        """Report whether the resolved model is not flagged as V0-only.

        Args:
            architectures: One architecture name or a list of candidates.
            model_config: Model configuration used for resolution.

        Returns:
            The negation of the `supports_v0_only` attribute of the model
            info returned by `inspect_model_cls`.
        """
        model_info, _ = self.inspect_model_cls(architectures, model_config)
        return not model_info.supports_v0_only

834
835

# Module-level registry seeded from `_VLLM_MODELS`. Each architecture name
# maps to a `_LazyRegisteredModel` that only records the dotted module path
# and class name as strings (presumably so the model module is imported on
# demand rather than here -- confirm against `_LazyRegisteredModel`).
ModelRegistry = _ModelRegistry({
    model_arch:
    _LazyRegisteredModel(
        module_name=f"vllm.model_executor.models.{mod_relname}",
        class_name=cls_name,
    )
    for model_arch, (mod_relname, cls_name) in _VLLM_MODELS.items()
})

# Return type of the callable handed to `_run_in_subprocess`.
_T = TypeVar("_T")


def _run_in_subprocess(fn: Callable[[], _T]) -> _T:
    """Run `fn` in a separate process and return its result.

    `fn` and an output file path are serialized together with `cloudpickle`
    and fed to `_SUBPROCESS_COMMAND` on stdin. After the subprocess exits
    successfully, the result is unpickled from the output file.

    Args:
        fn: Zero-argument callable to execute in the subprocess.

    Returns:
        The object unpickled from the output file.

    Raises:
        RuntimeError: If the subprocess exits with a non-zero return code.
            The message contains the subprocess's stderr.
    """
    # NOTE: We use a temporary directory instead of a temporary file to avoid
    # issues like https://stackoverflow.com/questions/23212435/permission-denied-to-write-to-my-temporary-file
    with tempfile.TemporaryDirectory() as tempdir:
        output_filepath = os.path.join(tempdir, "registry_output.tmp")

        # `cloudpickle` allows pickling lambda functions directly
        import cloudpickle
        input_bytes = cloudpickle.dumps((fn, output_filepath))

        # cannot use `sys.executable __file__` here because the script
        # contains relative imports
        returned = subprocess.run(_SUBPROCESS_COMMAND,
                                  input=input_bytes,
                                  capture_output=True)

        # check if the subprocess is successful
        try:
            returned.check_returncode()
        except subprocess.CalledProcessError as e:
            # Decode leniently: undecodable bytes in stderr must not mask
            # the real failure with a UnicodeDecodeError.
            stderr_text = returned.stderr.decode(errors="replace")
            # wrap raised exception to provide more information
            raise RuntimeError(f"Error raised in subprocess:\n"
                               f"{stderr_text}") from e

        with open(output_filepath, "rb") as f:
            return pickle.load(f)


def _run() -> None:
    """Execute a pickled callable read from stdin and persist its result.

    Loads the general vLLM plugins, unpickles a `(fn, output_file)` pair
    from stdin, calls `fn()` and writes the pickled result to `output_file`.
    """
    # Setup plugins
    from vllm.plugins import load_general_plugins
    load_general_plugins()

    # NOTE(review): unpickling executes arbitrary code. This is only
    # acceptable if stdin is always produced by a trusted parent process
    # (see `_run_in_subprocess`); never feed it untrusted data.
    fn, output_file = pickle.loads(sys.stdin.buffer.read())

    result = fn()

    with open(output_file, "wb") as f:
        pickle.dump(result, f)
886
887
888


# Script entry point: when this module is executed directly, run the pickled
# callable supplied on stdin (presumably the mode `_SUBPROCESS_COMMAND`
# invokes -- confirm against its definition).
if __name__ == "__main__":
    _run()